In [None]:

import torch 
import torch.nn as nn 
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt 
import numpy as np 
from tqdm import tqdm 


In [None]:

# Preprocessing: ToTensor followed by a per-channel (x - 0.5) / 0.5, i.e. the
# same x * 2 - 1 rescaling as before. Normalize is used instead of a Lambda
# wrapping a `lambda` because a lambda cannot be pickled, which breaks
# DataLoader worker processes on platforms that spawn them.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR-10 training split, downloaded into ./data on first use.
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Prefer the GPU when one is available; every tensor below is placed here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [None]:
class UNet(nn.Module):
    """Small two-level U-Net used as the noise predictor of the diffusion model.

    Conditioning is injected as two extra constant input planes concatenated
    to the image: one filled with the timestep value and one filled with a
    learned scalar label embedding (all zeros when no label is given).

    Args:
        in_channels: Number of channels of the input image and of the output.
        base_channels: Width of the first encoder stage; it doubles at each
            of the two deeper levels.
        num_classes: Number of rows in the label embedding table.
        cond_dropout: Probability, applied in training mode only, of
            discarding the labels of a whole batch so the network also learns
            the unconditional prediction (classifier-free guidance).
    """

    def __init__(self, in_channels=3, base_channels=64, num_classes=10, cond_dropout=0.1):
        super().__init__()
        self.cond_dropout = cond_dropout
        # One scalar per class, later broadcast to a full H x W plane.
        self.label_emb = nn.Embedding(num_classes, 1)

        # Downsampling path (+2 input channels: timestep plane and label plane)
        self.enc1 = self._double_conv(in_channels + 2, base_channels)
        self.pool1 = nn.MaxPool2d(2)

        self.enc2 = self._double_conv(base_channels, base_channels * 2)
        self.pool2 = nn.MaxPool2d(2)

        # Bottleneck
        self.middle = self._double_conv(base_channels * 2, base_channels * 4)

        # Upsampling path; each decoder stage receives the upsampled features
        # concatenated with the matching encoder skip, hence the doubled input.
        self.up1 = nn.ConvTranspose2d(base_channels * 4, base_channels * 2, 2, stride=2)
        self.dec1 = self._double_conv(base_channels * 4, base_channels * 2)

        self.up2 = nn.ConvTranspose2d(base_channels * 2, base_channels, 2, stride=2)
        self.dec2 = self._double_conv(base_channels * 2, base_channels)

        self.final = nn.Conv2d(base_channels, in_channels, 1)

    @staticmethod
    def _double_conv(in_ch, out_ch):
        """Return the Conv3x3-ReLU-Conv3x3-ReLU stage used at every level."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU()
        )

    def forward(self, x, t, y=None):
        """Predict the noise for a batch of noisy images.

        Args:
            x: Image batch of shape (B, in_channels, H, W). H and W have to be
                divisible by 4 so the two pool/upsample stages line up with
                the skip connections.
            t: Tensor of shape (B,) holding one timestep per sample.
            y: Optional tensor of shape (B,) with integer class labels;
                ``None`` requests the unconditional prediction.

        Returns:
            Tensor of shape (B, in_channels, H, W).
        """
        height, width = x.shape[2], x.shape[3]

        # Time embedding: one constant plane per sample, cast to the image
        # dtype so the concatenation below does not rely on type promotion.
        t_plane = t.to(x.dtype)[:, None, None, None].expand(-1, -1, height, width)

        # Classifier-free dropout. It must only fire while training: at
        # sampling time the caller decides explicitly (y vs. None) which
        # branch it wants, and dropping labels there would silently turn the
        # conditional prediction into the unconditional one.
        drop_labels = (
            self.training
            and y is not None
            and torch.rand(1).item() < self.cond_dropout
        )
        if y is None or drop_labels:
            y_emb = torch.zeros(x.size(0), 1, device=x.device)
        else:
            y_emb = self.label_emb(y)
        y_plane = y_emb[:, :, None, None].expand(-1, -1, height, width)

        x = torch.cat([x, t_plane, y_plane], dim=1)

        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        mid = self.middle(self.pool2(e2))

        d1 = self.up1(mid)
        d1 = self.dec1(torch.cat([d1, e2], dim=1))

        d2 = self.up2(d1)
        d2 = self.dec2(torch.cat([d2, e1], dim=1))

        return self.final(d2)

In [None]:

def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Return a linearly spaced noise-variance (beta) schedule.

    Args:
        timesteps: Number of diffusion steps, i.e. the length of the result.
        beta_start: Value of the first entry.
        beta_end: Value of the last entry.

    Returns:
        1-D tensor of length ``timesteps`` running from ``beta_start`` to
        ``beta_end`` inclusive.
    """
    return torch.linspace(beta_start, beta_end, timesteps)

# Number of diffusion steps and the per-step coefficients derived from the
# beta schedule. All tables are indexed by timestep and live on `device`.
T = 1000
betas = linear_beta_schedule(T).to(device)
alphas = 1. - betas
# Running product of the alphas up to and including each step.
alphas_cumprod = alphas.cumprod(dim=0)
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1 - alphas_cumprod).sqrt()

def q_sample(x_start, t, noise):
    """Mix clean images with noise according to the schedule at step ``t``.

    Args:
        x_start: Batch of clean images, indexed as (B, C, H, W).
        t: Tensor of shape (B,) with one timestep index per sample.
        noise: Tensor shaped like ``x_start``.

    Returns:
        ``sqrt_alphas_cumprod[t] * x_start +
        sqrt_one_minus_alphas_cumprod[t] * noise``, with the per-sample
        coefficients broadcast over the channel and spatial axes.
    """
    signal_scale = sqrt_alphas_cumprod[t][:, None, None, None]
    noise_scale = sqrt_one_minus_alphas_cumprod[t][:, None, None, None]
    return signal_scale * x_start + noise_scale * noise


In [None]:

model = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

for epoch in range(5):  # reduce for demonstration
    model.train()
    # Accumulate the sample-weighted loss on the device so the printed number
    # is the epoch mean rather than whatever the last mini-batch produced.
    loss_sum = torch.zeros((), device=device)
    n_seen = 0
    for x, y in tqdm(train_loader):
        x = x.to(device)
        y = y.to(device)
        # One uniformly drawn timestep per image.
        t = torch.randint(0, T, (x.size(0),), device=device).long()
        noise = torch.randn_like(x)
        x_t = q_sample(x, t, noise)
        # The network is trained to recover the noise that was mixed in.
        pred = model(x_t, t, y)
        loss = F.mse_loss(pred, noise)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_sum += loss.detach() * x.size(0)
        n_seen += x.size(0)
    # max(..., 1) keeps the report well defined for an empty loader.
    print(f"Epoch {epoch+1} loss: {loss_sum.item() / max(n_seen, 1):.4f}")


In [None]:

@torch.no_grad()
def p_sample(model, x, t, y, guidance_scale=5.0):
    """Run one reverse-diffusion step with classifier-free guidance.

    Args:
        model: Noise predictor called as ``model(x, t, y)``; ``y=None``
            requests the unconditional prediction.
        x: Current noisy batch, indexed as (B, C, H, W).
        t: Tensor of shape (B,) with the timestep index of every sample.
        y: Class labels passed to the conditional forward pass.
        guidance_scale: Weight ``w`` in
            ``eps = (1 + w) * eps_cond - w * eps_uncond``. ``0`` gives the
            purely conditional prediction.

    Returns:
        Tensor shaped like ``x``. Samples whose timestep is 0 receive the
        predicted clean image without added noise; all others receive the
        posterior mean plus ``sqrt(beta_t)``-scaled Gaussian noise.
    """
    eps = model(x, t, y)
    if guidance_scale != 0:
        # The unconditional pass only contributes when it has non-zero weight.
        eps_uncond = model(x, t, None)
        eps = (1 + guidance_scale) * eps - guidance_scale * eps_uncond

    # Per-sample schedule coefficients, broadcast over (C, H, W).
    alpha_t = alphas[t][:, None, None, None]
    beta_t = betas[t][:, None, None, None]
    sqrt_alpha_bar_t = sqrt_alphas_cumprod[t][:, None, None, None]
    sqrt_one_minus_alpha_bar_t = sqrt_one_minus_alphas_cumprod[t][:, None, None, None]

    x0_pred = (x - sqrt_one_minus_alpha_bar_t * eps) / sqrt_alpha_bar_t
    mean = (1 / torch.sqrt(alpha_t)) * (x - beta_t * eps / sqrt_one_minus_alpha_bar_t)

    # Decide per sample whether this is the final step, instead of looking
    # only at t[0], so batches with mixed timesteps are handled correctly.
    is_final = (t == 0)[:, None, None, None]
    if bool(is_final.all()):
        return x0_pred
    noise = torch.randn_like(x)
    return torch.where(is_final, x0_pred, mean + torch.sqrt(beta_t) * noise)


In [None]:

@torch.no_grad()
def sample_ddpm(model, label, guidance_scale=5.0, n_samples=16):
    """Sample a batch of 32x32 RGB images of one class and keep snapshots.

    Args:
        model: Trained noise predictor; it is switched to eval mode.
        label: Integer class index used for every sample in the batch.
        guidance_scale: Classifier-free guidance weight passed to ``p_sample``.
        n_samples: Number of images generated in parallel.

    Returns:
        List of CPU tensors of shape (n_samples, 3, 32, 32), clamped to
        [-1, 1]. One snapshot is stored after every step whose index is a
        multiple of 100, in sampling order (so the last entry is the result
        of step 0).
    """
    model.eval()
    x = torch.randn((n_samples, 3, 32, 32), device=device)
    # The labels do not change between steps, so build them once.
    y = torch.full((n_samples,), label, dtype=torch.long, device=device)
    imgs = []
    for t_val in reversed(range(T)):
        t = torch.full((n_samples,), t_val, dtype=torch.long, device=device)
        x = p_sample(model, x, t, y, guidance_scale)
        if t_val % 100 == 0:
            imgs.append(x.clamp(-1, 1).cpu())
    return imgs

# Generate and visualize with different guidance scales
scales = [0.0, 2.5, 5.0]
n_cols = 5
fig, axs = plt.subplots(len(scales), n_cols, figsize=(15, 9))

# sample_ddpm stores one snapshot after every step whose index is a multiple
# of 100, in sampling order; reconstruct those indices so each panel can be
# titled with the step it really shows.
snapshot_ts = [t_val for t_val in reversed(range(T)) if t_val % 100 == 0]

for i, scale in enumerate(scales):
    imgs = sample_ddpm(model, label=3, guidance_scale=scale)
    # Spread the columns evenly over all snapshots so the last column is the
    # finished sample instead of stopping half-way through the denoising.
    picks = np.linspace(0, len(imgs) - 1, n_cols).round().astype(int)
    for j, k in enumerate(picks):
        grid = make_grid(imgs[k], nrow=4, normalize=True, value_range=(-1, 1))
        ax = axs[i, j]
        ax.imshow(grid.permute(1, 2, 0))
        ax.set_title(f"t={snapshot_ts[k]}")
        # Hide ticks and frame only: axis('off') would also hide the row label.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    axs[i, 0].set_ylabel(f"Scale={scale}", fontsize=14)

plt.tight_layout()
plt.show()


In [None]:

@torch.no_grad()
def sample_multiple_classes(model, labels, guidance_scale=5.0):
    """Generate one 32x32 RGB image per entry of ``labels``.

    Args:
        model: Trained noise predictor; it is switched to eval mode.
        labels: Sequence of integer class indices, one per generated image.
        guidance_scale: Classifier-free guidance weight passed to ``p_sample``.

    Returns:
        CPU tensor of shape (len(labels), 3, 32, 32), clamped to [-1, 1].
    """
    model.eval()
    n = len(labels)
    # Built once: the labels are identical for every one of the T steps.
    y = torch.as_tensor(labels, dtype=torch.long, device=device)
    x = torch.randn((n, 3, 32, 32), device=device)
    for t_val in reversed(range(T)):
        t = torch.full((n,), t_val, dtype=torch.long, device=device)
        x = p_sample(model, x, t, y, guidance_scale)
    return x.clamp(-1, 1).cpu()

# One guided sample for each of the ten class indices, shown side by side.
label_names = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
labels = list(range(10))
samples = sample_multiple_classes(model, labels, guidance_scale=5.0)

fig, axs = plt.subplots(1, 10, figsize=(15, 2))
for ax, img, name in zip(axs, samples, label_names):
    # Channels-last layout, rescaled from [-1, 1] to [0, 1] for imshow.
    rgb = (img.permute(1, 2, 0) + 1) / 2
    ax.imshow(rgb)
    ax.set_title(name, fontsize=9)
    ax.axis('off')
plt.suptitle("Classifier-Free Guidance Samples (1 per class)")
plt.show()


In [None]:

@torch.no_grad()
def visualize_trajectory(model, label=3, guidance_scale=5.0):
    """Sample a single image and plot snapshots of its denoising trajectory.

    A snapshot is stored after the first step (index ``T - 1``) and after
    every step whose index is a multiple of 100. Each panel is titled with
    the index of the step that produced it.

    Args:
        model: Trained noise predictor; it is switched to eval mode.
        label: Integer class index to condition on.
        guidance_scale: Classifier-free guidance weight passed to ``p_sample``.
    """
    # Match the other samplers: never sample with the model in training mode.
    model.eval()
    x = torch.randn((1, 3, 32, 32), device=device)
    y = torch.tensor([label], dtype=torch.long, device=device)
    trajectory = []  # (step index, image) pairs in sampling order

    for t_val in reversed(range(T)):
        t = torch.full((1,), t_val, dtype=torch.long, device=device)
        x = p_sample(model, x, t, y, guidance_scale)
        if t_val % 100 == 0 or t_val == T - 1:
            trajectory.append((t_val, x[0].clamp(-1, 1).cpu()))

    # Plot (squeeze=False keeps `axs` 2-D even when there is a single panel)
    fig, axs = plt.subplots(1, len(trajectory), figsize=(18, 2), squeeze=False)
    for ax, (t_val, img) in zip(axs[0], trajectory):
        ax.imshow((img.permute(1, 2, 0) + 1) / 2)
        # Use the recorded step so the first panel is not mislabelled as t=T.
        ax.set_title(f"t={t_val}")
        ax.axis("off")
    plt.suptitle("Sampling Trajectory for One Image")
    plt.show()

# Visualize
visualize_trajectory(model, label=3, guidance_scale=5.0)
