In [None]:
import os
import shutil
import time
from pathlib import Path

import cv2
import diffusers
import matplotlib.pyplot as plt
import numpy as np
import torch
from diffusers import LMSDiscreteScheduler, StableDiffusionPipeline
from PIL import Image
from skimage.exposure import match_histograms
from torch import autocast
from torchvision import transforms as tfms
from tqdm.auto import tqdm

from img2img import StableDiffusionImg2ImgPipeline

In [None]:
# Location of the Stable Diffusion v1.4 weights. Set the
# STABLE_DIFFUSION_MODEL_PATH environment variable to point elsewhere instead of
# editing the notebook; the default is the original ~/Desktop location.
STABLE_DIFFUSION_MODEL_PATH = Path(
    os.environ.get(
        "STABLE_DIFFUSION_MODEL_PATH",
        str(Path.home() / "Desktop/stable-diffusion-v1-4"),
    )
).expanduser()

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Load the txt_to_img pipeline.
# float16 weights are only used on the GPU: many CPU kernels have no
# half-precision implementation, so the CPU fallback needs float32.
model_dtype = torch.float16 if device == "cuda" else torch.float32
txt_to_img = StableDiffusionPipeline.from_pretrained(
    str(STABLE_DIFFUSION_MODEL_PATH), revision="fp16", torch_dtype=model_dtype
)
# Turn off safety_checker to avoid false positives: the replacement returns the
# images untouched together with a "nothing detected" flag.
txt_to_img.safety_checker = lambda images, **kwargs: (images, False)
# txt_to_img.enable_attention_slicing()  # uncomment to use less vram
txt_to_img = txt_to_img.to(device)

# Load the img2img pipeline, reusing the vae / text encoder / tokenizer / unet
# of the txt_to_img pipeline so the weights are not held in vram twice.
im2im = StableDiffusionImg2ImgPipeline(
    vae=txt_to_img.vae,
    text_encoder=txt_to_img.text_encoder,
    tokenizer=txt_to_img.tokenizer,
    unet=txt_to_img.unet,
    # LMS scheduler with a scaled-linear beta schedule over 1000 train steps.
    scheduler=LMSDiscreteScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
    ),
)
# im2im.enable_attention_slicing()  # uncomment to use less vram
im2im.progress_bar = lambda iterable: iterable  # turn off progress bar
im2im.to(device)

In [None]:
# Helpers


def timestamp():
    """Return the current local time as a compact, sortable ``YYYYmmdd-HHMMSS`` string.

    Used to give each run's output folder and notebook snapshot a unique name.
    """
    now = time.localtime()
    return time.strftime("%Y%m%d-%H%M%S", now)


def maintain_colors(prev_img, color_match_sample, mode):
    """Match the colour histogram of ``prev_img`` to that of ``color_match_sample``.

    Used to stop colours drifting over many img2img iterations. Adapted from the
    Deforum Stable Diffusion notebook:
    https://colab.research.google.com/github/deforum/stable-diffusion/blob/main/Deforum_Stable_Diffusion.ipynb#scrollTo=2g-f7cQmf2Nt

    Args:
        prev_img: RGB image array with channels on the last axis
            (e.g. ``np.array(pil_image)``) that is recoloured.
        color_match_sample: RGB image array whose colour distribution is the target.
        mode: ``"Match Frame 0 RGB"`` or ``"Match Frame 0 HSV"``; any other value
            falls back to matching in LAB space.

    Returns:
        The recoloured RGB image, cast back to ``prev_img``'s dtype so it can be
        passed straight to ``cv2.cvtColor`` / ``PIL.Image.fromarray``.
    """
    if mode == "Match Frame 0 RGB":
        # `channel_axis` replaces the `multichannel` flag removed from scikit-image.
        matched = match_histograms(prev_img, color_match_sample, channel_axis=-1)
        return _cast_like(matched, prev_img)

    if mode == "Match Frame 0 HSV":
        to_space, from_space = cv2.COLOR_RGB2HSV, cv2.COLOR_HSV2RGB
    else:  # Match Frame 0 LAB
        to_space, from_space = cv2.COLOR_RGB2LAB, cv2.COLOR_LAB2RGB

    prev_converted = cv2.cvtColor(prev_img, to_space)
    sample_converted = cv2.cvtColor(color_match_sample, to_space)
    matched = match_histograms(prev_converted, sample_converted, channel_axis=-1)
    # match_histograms may return float64, which cv2.cvtColor rejects; restore
    # the dtype of the converted input before converting back to RGB.
    return cv2.cvtColor(_cast_like(matched, prev_converted), from_space)


def _cast_like(matched, reference):
    """Return ``matched`` converted to ``reference``'s dtype (rounded and clipped for integer dtypes)."""
    if np.issubdtype(reference.dtype, np.integer) and np.issubdtype(matched.dtype, np.floating):
        limits = np.iinfo(reference.dtype)
        matched = np.clip(np.rint(matched), limits.min, limits.max)
    return matched.astype(reference.dtype, copy=False)

In [None]:
# Run configuration: everything a reader might want to tune lives here.
OUTPUT_DIR = Path(f"images/{timestamp()}")
PROMPT_A = "A photo of a bowl of fruit"  # prompt for the initial frame (frame 0)
PROMPT_B = "A photo of an acrobat"  # prompt used for every img2img step
GUIDANCE_SCALE = 7.5
IMG2IMG_STRENGTH = 0.45
NUM_IMG2IMG_STEPS = 100  # total frames written, including frame 0
SEED = 0
WIDTH = 512
HEIGHT = 512
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
# Use same random seed for everything. The generator lives on the same device as
# the pipelines, so the notebook also runs on CPU-only machines.
generator = torch.Generator(device).manual_seed(SEED)
# To know where things were. __vsc_ipynb_file__ is only injected by VS Code;
# elsewhere notebook_path is None (and the notebook snapshot is skipped).
_vsc_notebook_file = globals().get("__vsc_ipynb_file__")
notebook_path = Path(_vsc_notebook_file) if _vsc_notebook_file else None
if notebook_path is not None:
    print(f"{notebook_path.resolve() = }")
print(f"{OUTPUT_DIR.resolve() = }")

In [None]:
# Save a copy of this notebook in OUTPUT_DIR, for reproducibility.
# Skipped (with a message) when the notebook's own path is unknown.
if notebook_path is None:
    print("Notebook path unavailable (not running in VS Code); skipping snapshot.")
else:
    snapshot_path = OUTPUT_DIR / f"{notebook_path.stem}_{timestamp()}.ipynb"
    shutil.copy(src=notebook_path, dst=snapshot_path)

In [None]:
# Generate the initial image (frame 0) from PROMPT_A.
# A fresh generator seeded with SEED makes this cell reproducible on re-run,
# instead of continuing from whatever state a shared generator was left in.
init_generator = torch.Generator(device).manual_seed(SEED)
with autocast("cuda"), torch.no_grad():
    init_image = txt_to_img(
        [PROMPT_A], width=WIDTH, height=HEIGHT, generator=init_generator
    )["sample"][0]
init_image.save(OUTPUT_DIR / f"{PROMPT_A}_{PROMPT_B}_{0:04d}.jpg")

In [None]:
# Generate the remaining frames 1..NUM_IMG2IMG_STEPS-1: each frame is produced
# by running img2img with PROMPT_B on the previous frame.
init_pixels = np.array(init_image)  # colour reference, invariant across frames
image = init_image
for frame_idx in tqdm(range(1, NUM_IMG2IMG_STEPS)):
    # Try to prevent colours from going red by matching frame 0's histogram.
    image = maintain_colors(np.array(image), init_pixels, "Match Frame 0 RGB")
    image = Image.fromarray(image)
    # Per-frame seed offset by SEED so changing SEED changes every frame
    # (identical to seeding with the frame index when SEED == 0). A separate
    # name keeps the config cell's `generator` untouched.
    frame_generator = torch.Generator(device).manual_seed(SEED + frame_idx)
    with autocast("cuda"), torch.no_grad():
        image = im2im(
            PROMPT_B,
            image,
            strength=IMG2IMG_STRENGTH,
            guidance_scale=GUIDANCE_SCALE,
            generator=frame_generator,
        )["sample"][0]
    image.save(OUTPUT_DIR / f"{PROMPT_A}_{PROMPT_B}_{frame_idx:04d}.jpg")