### Sampling from a Diffusion Model Using DDIM

In [1]:
import torch
from pathlib import Path

from utils import sample, ddim_step, latents_to_img
from models import UNet2DModel, UNet2DConditionModel
from display import plot_sample, plot_sample_one
from IPython.display import HTML

# Pick the compute device once for the whole notebook: CUDA when PyTorch
# reports a usable GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

#### Fashion MNIST

In [None]:
# Class-conditional Fashion-MNIST sampling: load the trained conditional UNet,
# run 20 DDIM steps, and render the intermediate results as an animation.
save_dir = Path("./weights")

# Place the model on `device` (chosen in the first cell) instead of a
# hard-coded "cuda", so this cell also runs on machines without a GPU.
model = UNet2DConditionModel(
    n_classes=10, in_channels=1, out_channels=1, nfs=(32, 64, 128, 256), num_layers=2
).to(device)
# map_location remaps tensors that were saved from a GPU onto the current
# device; without it torch.load raises when CUDA is unavailable.
model.load_state_dict(
    torch.load(save_dir / "fashion_mnist_cond_model_24_bs_512.pth", map_location=device)
)
model.eval()
print("loaded in model with context")

# generate sample with condition as the label
# NOTE(review): sz is presumably (n_samples, channels, height, width) -- confirm
# against utils.sample.
sz = (16, 1, 32, 32)
cids = 0  # T-Shirts
# or use the line below to draw a random class id per sample
# cids = torch.randint(0, 10, (sz[0],), dtype=torch.int32)

samples, intermediates = sample(ddim_step, model, sz, c=cids, steps=20)
# Tie the number of plotted samples to sz[0] so the animation stays in sync
# when the requested sample count changes.
anim = plot_sample(intermediates, n_sample=sz[0], nrows=2, save_as="fashion_mnist_cond.gif")
HTML(anim.to_jshtml())

#### Bored Ape Yacht Club NFT

In [2]:
# BAYC sampling: load the trained latent-space UNet, run 20 DDIM steps, decode
# the trajectory of one randomly chosen sample with the Stable Diffusion VAE,
# and render it as an animation.
save_dir = Path("./weights")

# load the pre-trained model (see bayc_training.ipynb notebook for detail)
# The checkpoint was originally read from the working directory while save_dir
# went unused. Keep that location as the first choice and fall back to
# save_dir so the cell works with either layout.
weights_name = "bayc_model_16_bs_16.pth"
weights_path = Path(weights_name)
if not weights_path.exists():
    weights_path = save_dir / weights_name

# Use `device` (chosen in the first cell) rather than a hard-coded "cuda" so the
# model, the VAE and the latents all end up on the same device, GPU or CPU.
model = UNet2DModel(in_channels=4, out_channels=4, nfs=(32, 64, 128, 256), num_layers=2).to(device)
# map_location remaps GPU-saved tensors onto the current device; without it
# torch.load raises when CUDA is unavailable.
model.load_state_dict(torch.load(weights_path, map_location=device))
model.eval()
# NOTE(review): this message says "with context" although sample() is called
# without a condition below; the text is kept unchanged to match prior output.
print("loaded in model with context")

# generate sample
# NOTE(review): sz is presumably (n_samples, latent_channels, height, width) --
# confirm against utils.sample.
sz = (16, 4, 64, 64)
# Index of the single sample whose trajectory is animated; bounded by sz[0]
# instead of a duplicated literal.
random_select = torch.randint(0, sz[0], (1,))

# load VAE model
from diffusers import AutoencoderKL
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(device)

# generate samples using DDIM with 20 steps
samples, intermediates = sample(ddim_step, model, sz, steps=20)
# Keep only the selected sample at every recorded step and flatten to a batch
# of latents shaped like one entry of sz.
# NOTE(review): assumes intermediates is indexed as (step, sample, ...) -- confirm.
intermediates = torch.tensor(intermediates[:, random_select]).view(-1, *sz[1:]).to(device)
intermediates_decoded = latents_to_img(vae, intermediates)
anim = plot_sample_one(intermediates_decoded, save_as="BAYC_uncond.gif")
HTML(anim.to_jshtml())

loaded in model with context
saved gif at /home/ubuntu/animation/BAYC_uncond.gif
