# Sampling from a diffusion model
In this notebook we will sample from the previously trained diffusion model.
- We are going to compare the samples from DDPM and DDIM samplers
- Visualize mixing samples with conditional diffusion models

In [None]:
from types import SimpleNamespace
from pathlib import Path
import torch
import torch.nn.functional as F
import numpy as np
from utilities import *

import wandb

In [None]:
# wandb.login(relogin=True) # uncomment if you want to login to wandb

# Setting Things Up

In [None]:
# Wandb Params
PROJECT = "sprite_diffusion"
MODEL_ARTIFACT = "deeplearning-ai-temp/model-registry/SpriteGen:latest"

# ddpm sampler hyperparameters
timesteps = 500     # number of steps in the noise schedule / sampling loops
beta1 = 1e-4        # first value of the linear beta schedule
beta2 = 0.02        # last value of the linear beta schedule
num_samples = 30    # batch size of the initial noise tensor
height = 16         # generated images are height x height
ddim_n = 25         # passed as `n` to the DDIM sampler

# Use the first GPU when one is available, otherwise fall back to the CPU.
# Both branches are plain strings so torch.device() always gets the same kind
# of argument.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# we are storing the parameters in a SimpleNamespace to be logged to wandb
config = SimpleNamespace(
    timesteps=timesteps,
    beta1=beta1,
    beta2=beta2,
    num_samples=num_samples,
    height=height,
    ddim_n=ddim_n,
    device=device,
)

We will load the model from a wandb.Artifact and set up the sampling loop.

In [None]:
def load_model(model_artifact_name):
    """Fetch a model artifact from W&B and rebuild the network stored in it.

    The artifact is downloaded, the config of the run that logged it is read
    back to recover the architecture hyperparameters, and a `ContextUnet` is
    built from them before the saved weights are loaded.

    Args:
        model_artifact_name: name of the W&B artifact (type "model") to fetch.

    Returns:
        The model with its weights loaded, set to eval mode and moved to the
        module-level `device`.
    """
    registry_entry = wandb.Api().artifact(model_artifact_name, type="model")
    download_dir = Path(registry_entry.download())

    # the run that produced the artifact carries the model hyperparameters
    training_run = registry_entry.logged_by()

    # weights are deserialized onto the CPU first and moved to `device` at the end
    # NOTE(review): torch.load unpickles the downloaded file; only use trusted artifacts
    state_dict = torch.load(download_dir / "context_model.pth", map_location="cpu")

    hparams = training_run.config
    net = ContextUnet(
        in_channels=3,
        n_feat=hparams["n_feat"],
        n_cfeat=hparams["n_cfeat"],
        height=hparams["height"],
    )
    net.load_state_dict(state_dict)

    # inference only
    net.eval()
    return net.to(device)

In [None]:
# Build the network from the registry artifact. The sampling functions below
# read it through the module-level name `nn_model`.
nn_model = load_model(MODEL_ARTIFACT)

## Sampling

We will sample and log the generated samples to wandb.

In [None]:
# construct DDPM noise schedule
# b_t: timesteps + 1 values spaced linearly from beta1 to beta2
b_t = torch.linspace(0, 1, timesteps + 1, device=device) * (beta2 - beta1) + beta1
# a_t: complement of b_t
a_t = 1 - b_t
# ab_t: cumulative product of a_t, accumulated as a sum of logs
ab_t = a_t.log().cumsum(dim=0).exp()
# index 0 is pinned to exactly 1
ab_t[0] = 1

In [None]:
# helper function; removes the predicted noise (but adds some noise back in to avoid collapse)
def denoise_add_noise(x, t, pred_noise, z=None):
    """Take one reverse step: subtract the predicted noise, then add scaled noise.

    Args:
        x: current noisy samples.
        t: index into the schedule tensors `b_t`, `a_t` and `ab_t`.
        pred_noise: noise predicted for `x`.
        z: noise to add back in; drawn with `torch.randn_like(x)` when None.

    Returns:
        The denoised mean plus `b_t.sqrt()[t] * z`.
    """
    injected = torch.randn_like(x) if z is None else z
    # weight of the predicted noise inside the mean
    eps_weight = (1 - a_t[t]) / (1 - ab_t[t]).sqrt()
    denoised_mean = (x - pred_noise * eps_weight) / a_t[t].sqrt()
    return denoised_mean + b_t.sqrt()[t] * injected

Sample with context using the standard DDPM algorithm.
We make two changes to the original algorithm: the model is conditioned on a context vector,
and the initial noise tensor (`samples`) is passed in instead of being drawn inside the function.

In [None]:
@torch.no_grad()
def sample_ddpm_context(samples, context, save_rate=20):
    """Run the DDPM sampling loop from step `timesteps` down to step 1.

    Args:
        samples: starting noise tensor; it is replaced by the denoised result
            at every step.
        context: conditioning passed to the model as `c`.
        save_rate: a snapshot is stored whenever the step index is a multiple
            of this value (also at the first step and for step indices below 8).

    Returns:
        A tuple of the final samples clipped to [-1, 1] and a numpy array
        stacking the stored snapshots.
    """
    # snapshots of the generated steps, kept for plotting
    snapshots = []
    for step in reversed(range(1, timesteps + 1)):
        # time value shaped (1, 1, 1, 1)
        t_frac = torch.tensor([step / timesteps]).reshape(1, 1, 1, 1).to(device)

        # noise to inject back in; none on the last step (step == 1)
        if step > 1:
            z = torch.randn_like(samples)
        else:
            z = 0

        predicted_noise = nn_model(samples, t_frac, c=context)
        samples = denoise_add_noise(samples, step, predicted_noise, z)

        keep_snapshot = step % save_rate == 0 or step == timesteps or step < 8
        if keep_snapshot:
            print(f'sampling timestep {step:3d}', end='\r')
            snapshots.append(samples.detach().cpu().numpy())

    return samples.clip(-1, 1), np.stack(snapshots)

Let's define a set of noises and a context vector to condition on.

In [None]:
# Noise vector
# x_T ~ N(0, 1), sample initial noise
noises = torch.randn(num_samples, 3, height, height).to(device)

# A fixed context vector to sample from: one one-hot row per noise sample.
# The class ids are derived from num_samples so the context batch always has
# the same length as `noises`; the samples are split into contiguous,
# (near-)equal groups per class. For num_samples = 30 this is six rows each of
# classes 0, 1, 2, 3, 4.
# NOTE(review): 5 classes matches the label list in ctx_to_classes; presumably
# it must also equal the model's n_cfeat - confirm.
n_classes = 5
class_ids = torch.tensor(
    [i * n_classes // num_samples for i in range(num_samples)],
    dtype=torch.long,  # F.one_hot requires an integer index tensor
)
ctx_vector = F.one_hot(class_ids, n_classes).to(device=device).float()

Let's bring that faster DDIM sampler from the diffusion course.

In [None]:
# define sampling function for DDIM
# removes the noise using ddim
def denoise_ddim(x, t, t_prev, pred_noise):
    """Take one DDIM step from schedule index `t` to schedule index `t_prev`.

    Args:
        x: current noisy samples.
        t: index into `ab_t` for the current step.
        t_prev: index into `ab_t` for the step being moved to.
        pred_noise: noise predicted for `x`.

    Returns:
        The sum of the rescaled noise-free part of `x` and the predicted noise
        weighted by `sqrt(1 - ab_t[t_prev])`.
    """
    alpha_bar_now = ab_t[t]
    alpha_bar_prev = ab_t[t_prev]

    rescale = alpha_bar_prev.sqrt() / alpha_bar_now.sqrt()
    signal_part = rescale * (x - (1 - alpha_bar_now).sqrt() * pred_noise)
    noise_part = (1 - alpha_bar_prev).sqrt() * pred_noise
    return signal_part + noise_part

In [None]:
# fast sampling algorithm with context
@torch.no_grad()
def sample_ddim_context(samples, context, n=25):
    """Run the DDIM sampling loop, visiting every `timesteps // n`-th step.

    Args:
        samples: starting noise tensor; it is replaced by the denoised result
            at every visited step.
        context: conditioning passed to the model as `c`.
        n: controls the stride through the schedule (stride = timesteps // n).
            It does not have to divide `timesteps` evenly.

    Returns:
        A tuple of the final samples clipped to [-1, 1] and a numpy array
        stacking the samples after every visited step.
    """
    # array to keep track of generated steps for plotting
    intermediate = []
    step_size = timesteps // n
    for i in range(timesteps, 0, -step_size):
        print(f'sampling timestep {i:3d}', end='\r')

        # reshape time tensor
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)

        eps = nn_model(samples, t, c=context)    # predict noise e_(x_t,t)

        # When n does not divide timesteps, the last i - step_size is negative
        # and would index ab_t from its end (the noisiest part of the
        # schedule). Clamp so the final hop always lands on index 0.
        t_prev = max(i - step_size, 0)
        samples = denoise_ddim(samples, i, t_prev, eps)
        intermediate.append(samples.detach().cpu().numpy())

    intermediate = np.stack(intermediate)
    return samples.clip(-1, 1), intermediate

Let's create a `wandb.Table` to store our generations

In [None]:
# One row per generated image: the starting noise, the output of each sampler,
# and the name of the class it was conditioned on.
table = wandb.Table(columns=["input_noise", "ddpm", "ddim", "class"])

Let's compute the DDPM samples as before.

In [None]:
# Only the final images are kept; the stack of intermediate steps is discarded.
ddpm_samples, _ = sample_ddpm_context(noises, ctx_vector)

For DDIM we can control the step size by the `n` param:

In [None]:
# Same starting noise and context as the DDPM run, so the two samplers can be
# compared row by row; the intermediate steps are discarded here too.
ddim_samples, _ = sample_ddim_context(noises, ctx_vector, n=ddim_n)

In [None]:
def ctx_to_classes(ctx_vector):
    """Translate each row of a context matrix into its class name.

    Args:
        ctx_vector: 2-D tensor; the position of the largest entry in each row
            selects the class.

    Returns:
        A list with one class-name string per row.
    """
    class_names = ["hero", "non-hero", "food", "spell", "side-facing"]
    labels = []
    for class_idx in ctx_vector.argmax(dim=1).tolist():
        labels.append(class_names[class_idx])
    return labels

Let's keep track of the sampling params in the `config` object we defined earlier.

We can add the rows to the table one by one. We also cast the images to `wandb.Image` so they render correctly in the UI.

In [None]:
# One table row per sample: the three images (noise, DDPM, DDIM) followed by
# the class name.
labels = ctx_to_classes(ctx_vector)
for row in zip(noises, ddpm_samples, ddim_samples, labels):
    *tensors, label = row
    table.add_data(*(wandb.Image(img) for img in tensors), label)

Finally, we log the table to W&B. We use `wandb.init` as a context manager, which ensures that the run is finished when the `with` block exits.

In [None]:
# Open a run with the sampling config attached and log the comparison table
# under the key "samplers_table".
with wandb.init(project=PROJECT, job_type="samplers_battle", anonymous="allow", config=config):
    wandb.log({"samplers_table":table})