# Stable Diffusion to generate CIFAR-10 images

In [1]:
import os

from typing import Dict, Tuple
from tqdm import tqdm
import torch
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchvision import models, transforms

from torchvision.datasets import CIFAR10
import matplotlib.pyplot as plt

import numpy as np
from IPython.display import HTML

from dataclasses import dataclass

from models.context_unet import ContextUnet
from utils import *
import utils

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Hyperparameters

In [2]:
@dataclass
class hyperparameters:
    """Single container for every tunable value used by this notebook.

    The defaults are read by the noise schedule, the model constructor,
    the sampler and the data loader further down.
    """
    # data hyperparams
    # number of label classes (presumably equal to n_cfeat, the one-hot context width; confirm)
    num_classes: int = 10

    # diffusion hyperparams
    # number of diffusion timesteps; the noise schedule holds timesteps + 1 entries
    timesteps: int = 500

    # ddpm hyperparams
    # first and last value of the linearly interpolated beta schedule
    beta1: float = 1e-4
    beta2: float = 0.02

    # model hyperparams
    n_feat: int = 64 # hidden feature dimension handed to ContextUnet
    n_cfeat: int = 10 # size of the context vector handed to ContextUnet
    height: int = 28 # images are treated as height x height (28x28)
    # NOTE(review): the notebook loads CIFAR-10, whose images are 32x32 RGB, while
    # height/n_channels describe 28x28 single-channel inputs; confirm the pairing.
    n_channels: int = 1
    # directory used as the base for the checkpoint path and for saved animations
    save_dir: str = './weights/ddpm/'

    # training hyperparams
    n_epochs: int = 32
    batch_size: int = 100
    learning_rate: float = 1e-3

hyperparams = hyperparameters()

# Denoising parameters schedule

DDPM defines a forward process whose variance $\beta_t$ follows a schedule over the timesteps. To denoise the samples, the reverse-process parameters must follow the same schedule:

* $ \beta_t = (B_2 - B_1)\,\frac{t}{N} + B_1 $
* $ \alpha_t = 1 - \beta_t $
* $ \bar{\alpha}_t = \prod_{s=0}^t \alpha_s $
* $ \bar{\alpha}_0 = 1 $

Where $B_1$ and $B_2$ are, respectively, the lower and upper bounds of the forward process variance and $N$ is the number of timesteps.

In [3]:
# Linear beta schedule evaluated at timesteps + 1 evenly spaced points,
# so an integer timestep t in [0, timesteps] indexes the tensors directly.
b_t = (hyperparams.beta2 - hyperparams.beta1) * torch.linspace(
    0, 1, hyperparams.timesteps + 1, device=DEVICE
) + hyperparams.beta1
a_t = 1 - b_t
# Running product of a_t, accumulated as a sum of logs and exponentiated back.
ab_t = torch.exp(torch.cumsum(torch.log(a_t), dim=0))
# Entry 0 is pinned to 1 (no noise at timestep 0).
ab_t[0] = 1

# Remove predicted noise

In the DDIM generative process, we use an estimated prediction $f_\theta (x_t)$ of the denoised observation $x_0$, which can be calculated by

$$ f_\theta (x_t) = \frac{1}{\sqrt{\bar{\alpha}_t}} (x_t - \sqrt{1 - \bar{\alpha}_t}\cdot \epsilon_\theta(x_t)) $$

Then we can generate a sample $x_{t-1}$ from a sample $x_t$ by "moving" from $x_0$ along the direction of $x_t$:

$$ x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \cdot f_\theta (x_t) + \sqrt{1-\bar{\alpha}_{t-1}} \cdot \epsilon_\theta(x_t) $$

In [4]:
# helper function: one deterministic DDIM update from timestep t to timestep t_prev
def denoise_ddim(x, t, t_prev, pred_noise):
    """Move a noisy sample from timestep ``t`` to timestep ``t_prev``.

    The predicted noise is removed from ``x`` to estimate the clean sample,
    which is then re-noised to the level of ``t_prev`` using the *same*
    predicted noise. No fresh random noise is drawn in this function.

    Args:
        x: current noisy sample(s).
        t: index into the module-level ``ab_t`` schedule for the current step.
        t_prev: index into ``ab_t`` for the step being moved to.
        pred_noise: the network's noise estimate for ``x``.

    Returns:
        The sample at timestep ``t_prev``.
    """
    alpha_bar_now = ab_t[t]
    alpha_bar_prev = ab_t[t_prev]

    # x with the predicted noise taken out (still scaled by sqrt(alpha_bar_now))
    noise_removed = x - (1 - alpha_bar_now).sqrt() * pred_noise
    # clean-sample estimate rescaled to the signal level of t_prev
    scaled_x0 = alpha_bar_prev.sqrt() / alpha_bar_now.sqrt() * noise_removed
    # predicted noise rescaled to the noise level of t_prev
    toward_xt = (1 - alpha_bar_prev).sqrt() * pred_noise

    return scaled_x0 + toward_xt

# Creating the model

In [5]:
# construct model; the channel count comes from the shared hyperparameters so
# the network and the sampler's initial noise always agree
nn_model = ContextUnet(
    in_channels=hyperparams.n_channels,
    n_feat=hyperparams.n_feat,
    n_cfeat=hyperparams.n_cfeat,
    height=hyperparams.height,
).to(DEVICE)
optim = torch.optim.Adam(nn_model.parameters(), lr=hyperparams.learning_rate)

# load in model weights and set to eval mode
# weights_only=True keeps torch.load from unpickling arbitrary objects; the
# file is fed to load_state_dict, so it only needs to contain tensors.
nn_model.load_state_dict(
    torch.load(
        f"{hyperparams.save_dir}/../ddpm.pth",
        map_location=DEVICE,
        weights_only=True,
    )
)
nn_model.eval()
print("Loaded in Model")

Loaded in Model


  nn_model.load_state_dict(torch.load(f"{hyperparams.save_dir}/../ddpm.pth", map_location=DEVICE))


# Sampling intermediate latents along the Markov chain 

In [6]:
# sample with context using the strided DDIM update
@torch.no_grad()
def sample_ddim(n_sample, context, n=20):
    """Generate samples by walking the schedule from ``timesteps`` down to 0.

    Args:
        n_sample: number of samples to draw in one batch.
        context: conditioning passed to the model as ``c``.
        n: requested number of sampling steps. The stride is
            ``timesteps // n`` (at least 1), so the actual number of steps
            is ``ceil(timesteps / stride)``.

    Returns:
        ``(samples, intermediate)`` where ``samples`` is the final batch on
        ``DEVICE`` and ``intermediate`` is a NumPy array stacking the batch
        after every step along a new leading axis.

    Raises:
        ValueError: if ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")

    # x_T ~ N(0, 1), sample initial noise
    samples = torch.randn(n_sample, hyperparams.n_channels, hyperparams.height, hyperparams.height).to(DEVICE)

    # array to keep track of generated steps for plotting
    intermediate = []
    # a stride of 0 (n > timesteps) would make range() raise, so floor it at 1
    step_size = max(hyperparams.timesteps // n, 1)
    for i in range(hyperparams.timesteps, 0, -step_size):
        print(f'sampling timestep {i:3d}', end='\r')

        # reshape time tensor
        t = torch.tensor([i / hyperparams.timesteps])[:, None, None, None].to(DEVICE)

        eps = nn_model(samples, t, c=context)    # predict noise e_(x_t,t)

        # When the stride does not divide timesteps, the last hop would land
        # below 0 and a negative index would silently read the *end* of the
        # schedule; clamp so the final step always targets timestep 0.
        t_prev = max(i - step_size, 0)
        samples = denoise_ddim(samples, i, t_prev, eps)
        intermediate.append(samples.detach().cpu().numpy())

    intermediate = np.stack(intermediate)
    return samples, intermediate

In [7]:
# helper function: perturbs an image to a specified noise level
def perturb_input(x, t, noise):
    """Apply the closed-form forward diffusion step to ``x``.

    Computes ``sqrt(ab_t[t]) * x + sqrt(1 - ab_t[t]) * noise``. The noise
    coefficient is the square root of ``1 - ab_t[t]``, which is the same
    factor ``denoise_ddim`` uses when it removes the predicted noise.

    Args:
        x: clean batch; the schedule value is broadcast over three trailing axes.
        t: index (or tensor of per-sample indices) into the ``ab_t`` schedule.
        noise: noise with the same shape as ``x``.

    Returns:
        The perturbed batch.
    """
    # index first, then add three trailing axes so it broadcasts against x
    ab = ab_t[t, None, None, None]
    return ab.sqrt() * x + (1 - ab).sqrt() * noise

# Loading data

In [8]:
# CIFAR-10 from ./data (downloaded on first use), using the dataset's default split.
# NOTE(review): CIFAR-10 images are 32x32 RGB, whereas the model above is built
# for hyperparams.n_channels x hyperparams.height inputs; confirm before training.
dataset = CIFAR10(
    root='./data',
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),                 # from [0,255] to range [0.0,1.0]
        transforms.Normalize((0.5,), (0.5,)),  # (v - 0.5) / 0.5 -> range [-1,1]
    ]),
)

# Shuffled mini-batches of hyperparams.batch_size images, one loader worker.
dataloader = DataLoader(
    dataset,
    batch_size=hyperparams.batch_size,
    shuffle=True,
    num_workers=1,
)

# Training loop

Training in DDIM is the same as in DDPM

In [9]:
# # training without context code

# # set into train mode
# nn_model.train()

# for ep in range(hyperparams.n_epochs):
#     print(f'epoch {ep}')
    
#     # linearly decay learning rate
#     optim.param_groups[0]['lr'] = hyperparams.learning_rate*(1-ep/hyperparams.n_epochs)
    
#     pbar = tqdm(dataloader, mininterval=2 )
#     for x, c in pbar:   # x: images  c: context
#         optim.zero_grad()
#         x = x.to(DEVICE)
#         c = F.one_hot(c, num_classes=10).float().to(DEVICE)

#         # randomly mask out c
#         context_mask = torch.bernoulli(torch.zeros(c.shape[0]) + 0.9).to(DEVICE)
#         c = c * context_mask.unsqueeze(-1)
        
#         # perturb data
#         noise = torch.randn_like(x)
#         t = torch.randint(1, hyperparams.timesteps + 1, (x.shape[0],)).to(DEVICE) 
#         x_pert = perturb_input(x, t, noise)
        
#         # use network to recover noise
#         pred_noise = nn_model(x_pert, t / hyperparams.timesteps, c=c)
        
#         # loss is mean squared error between the predicted and true noise
#         loss = F.mse_loss(pred_noise, noise)
#         loss.backward()
        
#         optim.step()

#     # save model periodically
#     if ep%4==0 or ep == int(hyperparams.n_epochs-1):
#         if not os.path.exists(hyperparams.save_dir):
#             os.mkdir(hyperparams.save_dir)
#         torch.save(nn_model.state_dict(), hyperparams.save_dir + f"context_model_{ep}.pth")
#         print('saved model at ' + hyperparams.save_dir + f"context_model_{ep}.pth")

# Visualizing results

In [11]:
# visualize samples with randomly selected context
plt.clf()
# draw one random class label per sample and one-hot encode it
ctx = F.one_hot(torch.randint(0, hyperparams.num_classes, (32,)), hyperparams.num_classes).to(DEVICE).float()
samples, intermediate = sample_ddim(32, ctx)
# intermediate is (steps, samples, channels, height, width); drop only the
# channel axis. A bare squeeze() would also collapse any other axis of size 1.
animation_ddpm_context = utils.plot_sample(intermediate.squeeze(axis=2),32,4,hyperparams.save_dir, "ani_run", None, save=False)
HTML(animation_ddpm_context.to_jshtml())

gif animating frame 19 of 20

<Figure size 640x480 with 0 Axes>

In [13]:
# Class-conditional showcase: one sample for each of the labels 1..9
# (torch.arange(1, 10) excludes label 0), with the animation saved as a GIF.

plt.clf()
ctx = F.one_hot(torch.arange(1, 10), 10).to(DEVICE).float()
samples, intermediate = sample_ddim(9, ctx)
# squeeze(axis=2) drops only the singleton channel axis of the stacked frames;
# the arguments 9 and 3 are presumably sample count and grid rows (confirm in utils.plot_sample)
animation_ddpm_context = utils.plot_sample(intermediate.squeeze(axis=2), 9, 3, hyperparams.save_dir, "ani_run", None, save=True)
HTML(animation_ddpm_context.to_jshtml())

saved gif at ./weights/ddpm/ani_run_wNone.gif
gif animating frame 19 of 20

<Figure size 640x480 with 0 Axes>