## **Conditional DDPM on the MNIST dataset**

This notebook trains a class-conditional DDPM with classifier-free guidance on the MNIST dataset.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Set device and hyperparameters
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Optimisation settings.
batch_size = 128
num_epochs = 50
learning_rate = 2e-4

# Noise schedule: betas are spaced linearly from beta_start to beta_end
# over `timesteps` diffusion steps.
beta_start = 1e-4
beta_end = 2e-2
timesteps = 1000

# Conditioning and model width.
num_classes = 10      # real labels are 0..num_classes-1; index num_classes is the "null" label
dropout_prob = 0.1    # chance of swapping a training label for the null label
base_channels = 64    # channel count of the first U-Net level

In [None]:
# Load MNIST dataset
# Preprocessing: convert each image to a tensor, then apply (x - 0.5) / 0.5
# on its single channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Training split; fetched into ./data if it is not already present.
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)

# Reshuffled mini-batches of `batch_size` images.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Define Complex Conditional U-Net
class ConditionalUNet(nn.Module):
    """Two-level U-Net that predicts noise for a class-conditional diffusion model.

    The class label is looked up in an embedding table and concatenated, as
    extra feature maps, to the input of both decoder stages. The table has
    ``num_classes + 1`` rows so that index ``num_classes`` can serve as the
    "no label" token used for classifier-free guidance.

    The output is the raw result of the final 1x1 convolution. It is left
    unbounded on purpose: the training target is Gaussian noise, which
    regularly falls outside [-1, 1], so squashing the prediction (e.g. with
    tanh) would make the target unreachable.

    NOTE(review): ``forward`` receives no diffusion timestep, so the network
    applies the same mapping at every noise level. Noise predictors for DDPM
    are normally conditioned on the timestep as well.

    Args:
        num_classes: number of real class labels.
        base_channels: channel width of the first encoder level; deeper
            levels use 2x and 4x this value.
        in_channels: channels of the input image; the output has the same
            number of channels. Defaults to 1 (grayscale).
    """

    def __init__(self, num_classes, base_channels=64, in_channels=1):
        super().__init__()
        c = base_channels

        # Class embedding (+1 row for the null / unconditional label)
        self.class_emb = nn.Embedding(num_classes + 1, c)

        # Encoder
        self.enc1 = self._double_conv(in_channels, c)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = self._double_conv(c, c * 2)
        self.pool2 = nn.MaxPool2d(2)

        # Middle
        self.middle = self._double_conv(c * 2, c * 4)

        # Decoder. Each stage consumes: upsampled features + skip connection
        # + broadcast class embedding (c channels).
        self.up1 = nn.ConvTranspose2d(c * 4, c * 2, 2, stride=2)
        self.dec1 = self._double_conv(c * 2 + c * 2 + c, c * 2)

        self.up2 = nn.ConvTranspose2d(c * 2, c, 2, stride=2)
        self.dec2 = self._double_conv(c + c + c, c)

        self.final = nn.Conv2d(c, in_channels, 1)

    @staticmethod
    def _double_conv(in_ch, out_ch):
        """Return two 3x3 same-padding convolutions, each followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU()
        )

    def forward(self, x, class_labels):
        """Predict the noise contained in ``x`` given ``class_labels``.

        Args:
            x: batch of images, ``(batch, in_channels, H, W)``. H and W should
                be divisible by 4 so that the two pool / upsample stages give
                back feature maps that match the skip connections.
            class_labels: integer tensor of shape ``(batch,)`` with values in
                ``[0, num_classes]``; ``num_classes`` selects the null label.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # (batch, C) -> (batch, C, 1, 1) so it can be expanded over H x W.
        class_emb = self.class_emb(class_labels).unsqueeze(-1).unsqueeze(-1)

        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        m = self.middle(self.pool2(e2))

        u1 = self.up1(m)
        emb1 = class_emb.expand(-1, -1, u1.shape[2], u1.shape[3])
        d1 = self.dec1(torch.cat([u1, e2, emb1], dim=1))

        u2 = self.up2(d1)
        emb2 = class_emb.expand(-1, -1, u2.shape[2], u2.shape[3])
        d2 = self.dec2(torch.cat([u2, e1, emb2], dim=1))

        # No output activation: the noise target is unbounded.
        return self.final(d2)

In [None]:
# Diffusion schedule
def get_betas(timesteps):
    """Return the linear beta schedule as a 1-D tensor on `device`.

    The schedule holds `timesteps` evenly spaced values running from the
    module-level `beta_start` to `beta_end`.
    """
    schedule = torch.linspace(start=beta_start, end=beta_end, steps=timesteps)
    return schedule.to(device)

def get_alphas(betas):
    """Derive the alpha terms from a beta schedule.

    Returns a pair ``(alphas, alphas_cumprod)`` where ``alphas = 1 - betas``
    and ``alphas_cumprod`` is the running product of ``alphas`` along dim 0.
    """
    one_minus_beta = 1.0 - betas
    running_product = one_minus_beta.cumprod(dim=0)
    return one_minus_beta, running_product

def q_sample(x_start, t, noise=None, alphas_cumprod=None):
    """Sample x_t from the forward diffusion process q(x_t | x_0).

    Computes ``sqrt(a_bar[t]) * x_start + sqrt(1 - a_bar[t]) * noise`` where
    ``a_bar`` is the cumulative product of the alphas.

    Args:
        x_start: clean inputs with the batch on dimension 0.
        t: integer timestep indices, one per sample (shape ``(batch,)``), or
            a single index applied to the whole batch.
        noise: noise to mix in, same shape as ``x_start``. Standard normal
            noise is drawn when omitted.
        alphas_cumprod: optional precomputed 1-D tensor of cumulative alpha
            products. When omitted it is rebuilt from
            ``get_betas(timesteps)``, which costs a few small tensor ops on
            every call.

    Returns:
        Noised tensor with the same shape as ``x_start``.
    """
    if noise is None:
        noise = torch.randn_like(x_start)

    if alphas_cumprod is None:
        _, alphas_cumprod = get_alphas(get_betas(timesteps))

    # Indexing with a (batch,) tensor yields a (batch,) result. Reshape it to
    # (batch, 1, 1, ...) so each sample's coefficient broadcasts over that
    # sample's remaining dimensions; left as (batch,), it would be aligned
    # with the LAST axis of x_start instead of the batch axis.
    coeff_shape = (-1,) + (1,) * (x_start.dim() - 1)
    a_bar_t = alphas_cumprod[t].reshape(coeff_shape)

    return torch.sqrt(a_bar_t) * x_start + torch.sqrt(1.0 - a_bar_t) * noise

In [None]:
# Training setup
# Noise-prediction network, placed on the selected device.
model = ConditionalUNet(num_classes, base_channels)
model = model.to(device)

# Mean-squared error between predicted and true noise, optimised with Adam.
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# Schedule tensors kept at module level; the sampling code indexes them by step.
betas = get_betas(timesteps)
alphas, alphas_cumprod = get_alphas(betas)

In [None]:
def train_model():
    """Train the module-level `model` and return the mean loss of each epoch.

    For every batch: draw a random timestep per image, replace a fraction of
    the labels with the null label, noise the images with `q_sample`, and
    regress the model output onto the noise that was added.

    Reads the module-level `model`, `optimizer`, `criterion`, `train_loader`,
    `device` and the hyperparameters `num_epochs`, `timesteps`,
    `dropout_prob`, `num_classes`.

    Returns:
        list[float]: average batch loss per epoch, in epoch order.
    """
    # Sampling switches the model to eval mode; make sure training always
    # runs in train mode even when this is called again after sampling.
    model.train()

    losses = []
    num_batches = len(train_loader)
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)
            n = images.shape[0]  # actual size of this batch

            # One diffusion step per image, uniform over [0, timesteps).
            t = torch.randint(0, timesteps, (n,), device=device)

            # Classifier-free guidance: with probability `dropout_prob` the
            # label is replaced by the null index `num_classes`.
            keep_label = torch.rand(n, device=device) > dropout_prob
            conditioned_labels = torch.where(
                keep_label, labels, torch.full_like(labels, num_classes)
            )

            noise = torch.randn_like(images)
            x_noisy = q_sample(images, t, noise)
            # NOTE(review): the model call receives no timestep; `t` only
            # affects how much noise q_sample mixes in.
            predicted_noise = model(x_noisy, conditioned_labels)

            loss = criterion(predicted_noise, noise)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

            if i % 100 == 0:
                print(f'Epoch [{epoch}/{num_epochs}], Step [{i}/{num_batches}], '
                      f'Loss: {loss.item():.4f}')

        # Average loss for the epoch
        losses.append(epoch_loss / num_batches)

    return losses

In [None]:
# Run the full training loop; `losses` receives one average loss value per epoch.
losses = train_model()

In [None]:
# Part 1: Plot training loss
# The x-axis is derived from `losses` itself so x and y always have the same
# length, even if `num_epochs` was edited after the training cell ran.
plt.figure(figsize=(10, 5))
plt.plot(range(1, len(losses) + 1), losses, label='Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Average Loss')
plt.title('Training Loss Over Time')
plt.grid(True)
plt.legend()
plt.show()

In [None]:
# Conditional sampling with trajectory
@torch.no_grad()
def sample_conditional(model, num_samples=1, class_label=None, guidance_scale=3.0,
                       save_trajectory=False, image_shape=(1, 28, 28)):
    """Generate images by running the reverse diffusion chain with classifier-free guidance.

    Starts from Gaussian noise and walks t = timesteps-1 ... 0. At each step
    the model is evaluated once with the requested labels and once with the
    null label (`num_classes`); the two predictions are combined as
    ``uncond + guidance_scale * (cond - uncond)``.

    Reads the module-level `device`, `num_classes`, `timesteps`, `betas`,
    `alphas` and `alphas_cumprod`.

    Args:
        model: noise-prediction module called as ``model(x, labels)``.
        num_samples: number of images to generate.
        class_label: label used for every sample; when None, sample ``k``
            gets label ``k % num_classes``.
        guidance_scale: 0 gives the unconditional prediction, 1 the plain
            conditional one, larger values extrapolate past it.
        save_trajectory: when True, also collect intermediate states.
        image_shape: ``(channels, height, width)`` of each generated sample.

    Returns:
        ``(x, trajectory)``: the final samples on `device`, and either None
        or a list of numpy arrays holding the initial noise followed by the
        state after every step whose index is a multiple of 100.
    """
    # Sample in eval mode, but hand the model back in whatever mode it was in.
    was_training = model.training
    model.eval()
    try:
        x = torch.randn(num_samples, *image_shape, device=device)

        if class_label is None:
            labels = torch.arange(num_samples, device=device) % num_classes
        else:
            labels = torch.full((num_samples,), class_label, dtype=torch.long, device=device)
        null_labels = torch.full((num_samples,), num_classes, dtype=torch.long, device=device)

        trajectory = [x.cpu().numpy()] if save_trajectory else None

        for t in reversed(range(timesteps)):
            # NOTE(review): the model is called without the timestep, so the
            # same network mapping is applied at every step of the chain.
            pred_noise_cond = model(x, labels)
            pred_noise_uncond = model(x, null_labels)
            predicted_noise = pred_noise_uncond + guidance_scale * (pred_noise_cond - pred_noise_uncond)

            beta_t = betas[t]
            alpha_t = alphas[t]
            alpha_bar_t = alphas_cumprod[t]

            # Mean of the reverse step given the predicted noise.
            x = (1 / torch.sqrt(alpha_t)) * (
                x - ((1 - alpha_t) / torch.sqrt(1 - alpha_bar_t)) * predicted_noise
            )

            # Add fresh noise on every step except the last one.
            if t > 0:
                x = x + torch.sqrt(beta_t) * torch.randn_like(x)

            if save_trajectory and t % 100 == 0:
                trajectory.append(x.cpu().numpy())
    finally:
        model.train(was_training)

    return x, trajectory

In [None]:
# Part 2: Samples for different class conditions
# One generated image per class label, filling a 2 x 5 grid row by row.
fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(15, 6))
for i, ax in enumerate(axes.flat):
    samples, _ = sample_conditional(model, num_samples=1, class_label=i)
    samples = samples.cpu().numpy()
    # First sample, first channel -> 2-D image.
    ax.imshow(samples[0, 0], cmap='gray')
    ax.set_title(f'Digit {i}')
    ax.axis('off')
plt.suptitle("Samples for Each Class")
plt.tight_layout()
plt.show()

In [None]:
# Part 3: Trajectory for digit 5
samples, trajectory = sample_conditional(model, num_samples=1, class_label=5, save_trajectory=True)
# Every saved frame has a leading sample axis of size 1, so concatenating on
# axis 0 stacks the frames: the result is (frames, channels, height, width).
trajectory = np.concatenate(trajectory, axis=0)

fig, axes = plt.subplots(1, len(trajectory), figsize=(15, 3))
for i, (img, ax) in enumerate(zip(trajectory, axes)):
    # `img` is one frame of shape (channels, height, width); pick channel 0
    # to get the 2-D array imshow expects. (Indexing [0, 0] here would return
    # a single 1-D pixel row, which imshow rejects.)
    ax.imshow(img[0], cmap='gray')
    ax.set_title(f'Step {timesteps - (i * 100)}')
    ax.axis('off')
plt.suptitle("Trajectory from Noise to Digit 5")
plt.tight_layout()
plt.show()

In [None]:
# Part 4: Samples with different guidance scales for digit 5
# Scale 0.0 uses only the null-label prediction and 1.0 only the class-conditioned
# one; larger values push further in the conditioned direction.
guidance_scales = [0.0, 1.0, 2.0, 3.0, 5.0]
fig, axes = plt.subplots(nrows=1, ncols=len(guidance_scales), figsize=(15, 3))

# One panel per scale, left to right in list order.
for ax, scale in zip(axes, guidance_scales):
    samples, _ = sample_conditional(model, num_samples=1, class_label=5, guidance_scale=scale)
    samples = samples.cpu().numpy()
    ax.imshow(samples[0, 0], cmap='gray')
    ax.set_title(f'Guidance Scale: {scale}')
    ax.axis('off')
plt.suptitle("Effect of Guidance Scale on Digit 5")
plt.tight_layout()
plt.show()