In [6]:
from diffusers import UNet2DModel, DDIMScheduler, VQModel
import torch
import PIL.Image
import numpy as np
import tqdm

# Display the installed PyTorch version as this cell's output.
torch.__version__

'1.13.1+cu117'

In [7]:
# Load the three pretrained components (UNet denoiser, VQ-VAE, DDIM scheduler)
# from the sub-folders of a single Hub checkpoint.
MODEL_ID = "CompVis/ldm-celebahq-256"

unet = UNet2DModel.from_pretrained(MODEL_ID, subfolder="unet")
vqvae = VQModel.from_pretrained(MODEL_ID, subfolder="vqvae")
# Schedulers are loaded with `from_pretrained`; handing a repo id to
# `from_config` is a deprecated call form in diffusers.
scheduler = DDIMScheduler.from_pretrained(MODEL_ID, subfolder="scheduler")

# Run on the GPU when one is available, otherwise fall back to the CPU.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

unet.to(torch_device)
# trailing `;` suppresses the module repr that would otherwise be the cell output
vqvae.to(torch_device);

The config attributes {'timestep_values': None, 'timesteps': 1000} were passed to DDIMScheduler, but are not expected and will be ignored. Please verify your scheduler_config.json configuration file.


In [8]:
seed = 3

# Seed torch's default generator and draw the initial latent from it.
# The tensor is created without a device argument and only then moved to
# `torch_device`.
generator = torch.manual_seed(seed)
noise = torch.randn(
    # Read the latent shape from `unet.config`; accessing these values as
    # direct attributes of the model (`unet.in_channels`) is deprecated.
    (1, unet.config.in_channels, unet.config.sample_size, unet.config.sample_size),
    generator=generator,
).to(torch_device)

# Number of DDIM denoising steps used by the sampling loop.
scheduler.set_timesteps(num_inference_steps=200)

  (1, unet.in_channels, unet.sample_size, unet.sample_size),


In [9]:
# Walk the scheduler's timesteps starting from the noise latent, keeping the
# decoded output of every intermediate step in `frames_unsampled`.
image = noise
frames_unsampled = []
for t in tqdm.tqdm(scheduler.timesteps):
    with torch.no_grad():
        # UNet prediction for the current latent at timestep t
        residual = unet(image, t)["sample"]

    # DDIM update with eta=0.0; the result becomes the input of the next step
    image = scheduler.step(residual, t, image, eta=0.0)["prev_sample"]

    with torch.no_grad():
        # decode the updated latent with the VQ-VAE and store the raw decoder output
        frames_unsampled.append(vqvae.decode(image))

100%|██████████| 200/200 [00:14<00:00, 13.67it/s]


In [11]:
from pathlib import Path

from PIL import Image, ImageSequence

# Pillow interprets GIF `duration` in milliseconds (stored as centiseconds).
FRAME_DURATION_MS = 20    # per-frame delay for the intermediate frames
FINAL_HOLD_MS = 100000    # keep the final frame on screen before looping

frames = []

for frame in frames_unsampled:
    # channels-first decoder output -> channels-last layout expected by PIL
    image_processed = frame.sample.cpu().permute(0, 2, 3, 1)
    # map values from [-1, 1] to [0, 255] and convert to 8-bit
    image_processed = (image_processed + 1.0) * 127.5
    image_processed = image_processed.clamp(0, 255).numpy().astype(np.uint8)
    # each decoded batch holds one image; take it as the frame
    image_pil = PIL.Image.fromarray(image_processed[0])
    frames.append(image_pil)

# A scalar `duration` passed to save() applies to every frame and makes Pillow
# ignore any per-frame `info["duration"]`, so the hold on the last frame has
# to be expressed as a per-frame list.
durations = [FRAME_DURATION_MS] * (len(frames) - 1) + [FINAL_HOLD_MS]

# Create the output directory first so the save works on a fresh checkout.
output_path = Path("figures") / f"output_{seed}.gif"
output_path.parent.mkdir(parents=True, exist_ok=True)
frames[0].save(
    output_path,
    format="GIF",
    append_images=frames[1:],
    save_all=True,
    duration=durations,
    loop=0,
)