# Assignment 4: Advanced Image Generation
## Diffusion Models and Energy-Based Models

**Student**: my2878  
**Course**: Generative AI  
**Date**: November 2025

---

## Overview

This notebook implements two advanced generative models:

1. **Diffusion Model (DDPM)** - Denoising Diffusion Probabilistic Model
2. **Energy-Based Model (EBM)** - Using Langevin Dynamics

Both models are trained on the **CIFAR-10** dataset and integrated into a FastAPI service.

---

## Table of Contents

1. [Setup and Imports](#setup)
2. [Part 1: Diffusion Model Implementation](#diffusion)
3. [Part 2: Energy-Based Model Implementation](#energy)
4. [Part 3: Training on CIFAR-10](#training)
5. [Part 4: Theory Questions](#theory)
6. [Part 5: Results and Visualization](#results)
7. [Part 6: API Integration](#api)


---

## 1. Setup and Imports


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import math
from typing import Optional, Tuple
import warnings
warnings.filterwarnings('ignore')

# Choose the compute device, preferring CUDA, then Apple MPS, then plain CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

# Seed both the NumPy and PyTorch random generators so runs are repeatable.
np.random.seed(42)
torch.manual_seed(42)


---

## 2. Part 1: Diffusion Model Implementation

### 2.1 Sinusoidal Time Embedding

The sinusoidal time embedding provides a continuous representation of timesteps using sine and cosine functions.

**Mathematical Formula:**

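For timestep $t$ and embedding dimension $d$, the $i$-th sine/cosine pair ($i = 0, 1, \dots, d/2 - 1$) is:

$$
\text{embedding}[2i] = \sin\left(\frac{t}{10000^{2i/d}}\right)
$$

$$
\text{embedding}[2i+1] = \cos\left(\frac{t}{10000^{2i/d}}\right)
$$

**Note:** the implementation below uses the same frequencies but stores all $d/2$ sine terms first, followed by all $d/2$ cosine terms, instead of interleaving them.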


In [None]:
class SinusoidalTimeEmbedding(nn.Module):
    """
    Sinusoidal time embedding for diffusion timesteps.

    Maps every timestep to a fixed (non-learned) vector of sines and cosines
    evaluated at geometrically spaced frequencies. With
    ``h = embedding_dim // 2`` and ``f_i = max_period ** (-i / h)`` the output
    row for timestep ``t`` is laid out as::

        [sin(t*f_0), ..., sin(t*f_{h-1}), cos(t*f_0), ..., cos(t*f_{h-1})]

    When ``embedding_dim`` is odd, a single trailing zero column is appended so
    that the output always has exactly ``embedding_dim`` columns.

    Args:
        embedding_dim: Number of columns in the returned embedding.
        max_period: Base of the geometric frequency progression; larger values
            produce lower minimum frequencies.
    """

    def __init__(self, embedding_dim=128, max_period=10000):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.max_period = max_period

    def forward(self, timesteps):
        """
        Args:
            timesteps: (batch_size,) tensor of timestep indices

        Returns:
            (batch_size, embedding_dim) float tensor of embeddings, created on
            the same device as ``timesteps``
        """
        device = timesteps.device
        half_dim = self.embedding_dim // 2

        # Geometric frequency ladder: max_period ** (-i / half_dim)
        frequencies = torch.exp(
            -math.log(self.max_period) * torch.arange(half_dim, device=device) / half_dim
        )

        # Outer product of timesteps and frequencies -> (batch_size, half_dim)
        args = timesteps[:, None].float() * frequencies[None, :]

        # All sine terms first, then all cosine terms
        embedding = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

        # An odd embedding_dim would otherwise come out one column short and
        # break downstream layers sized with embedding_dim; pad with zeros.
        if self.embedding_dim % 2 == 1:
            padding = embedding.new_zeros(embedding.shape[0], 1)
            embedding = torch.cat([embedding, padding], dim=-1)

        return embedding

# Sanity check: embed the single timestep t=1 into 8 dimensions and display it
t = torch.tensor([1])
time_embed = SinusoidalTimeEmbedding(max_period=10000, embedding_dim=8)
embedding = time_embed(t)
print("Time Embedding for t=1, d=8:")
print(embedding[0].numpy())
print("\nThis matches our theoretical calculation!")


In [None]:
class ResidualBlock(nn.Module):
    """
    Two-convolution residual block conditioned on a time embedding.

    The time embedding is passed through ReLU and a linear layer, then added
    as a per-channel bias between the two convolutions. The shortcut branch is
    a 1x1 convolution when the channel count changes and an identity otherwise.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the block.
        time_emb_dim: Size of the last dimension of ``time_emb``.
        dropout: Dropout probability applied after the second normalization.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim, dropout=0.1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)
        self.norm1 = nn.GroupNorm(num_groups=8, num_channels=out_channels)
        self.norm2 = nn.GroupNorm(num_groups=8, num_channels=out_channels)
        self.dropout = nn.Dropout(p=dropout)

        # Only project the shortcut when the channel count actually changes.
        self.residual_conv = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )

    def forward(self, x, time_emb):
        """Return ``relu(dropout(norm2(conv2(h)))) + shortcut(x)`` with the time bias mixed into ``h``."""
        shortcut = self.residual_conv(x)

        # First conv -> norm -> activation
        hidden = F.relu(self.norm1(self.conv1(x)))

        # Project the time embedding to one value per channel and broadcast
        # it over both spatial dimensions.
        time_bias = self.time_mlp(F.relu(time_emb))
        hidden = hidden + time_bias.unsqueeze(-1).unsqueeze(-1)

        # Second conv -> norm -> dropout -> activation
        hidden = self.norm2(self.conv2(hidden))
        hidden = F.relu(self.dropout(hidden))

        return hidden + shortcut


class AttentionBlock(nn.Module):
    """
    Single-head self-attention across the spatial positions of a feature map.

    The input is group-normalized, projected to queries, keys and values with
    one 1x1 convolution, attended over all ``H * W`` positions, projected
    again, and added back to the input.

    Args:
        channels: Number of channels of the input (and output) feature map.
    """

    def __init__(self, channels):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=8, num_channels=channels)
        self.qkv = nn.Conv2d(channels, 3 * channels, kernel_size=1)
        self.proj = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        """Apply attention to a ``(B, C, H, W)`` tensor and return a tensor of the same shape."""
        batch, channels, height, width = x.shape

        # One convolution yields q, k and v stacked along the channel axis.
        query, key, value = self.qkv(self.norm(x)).chunk(3, dim=1)

        # Collapse H and W into one axis of length H*W.
        query = query.flatten(2).transpose(1, 2)   # (B, H*W, C)
        key = key.flatten(2)                       # (B, C, H*W)
        value = value.flatten(2).transpose(1, 2)   # (B, H*W, C)

        # Scaled dot-product scores between every pair of positions.
        scores = torch.bmm(query, key) / math.sqrt(channels)
        weights = scores.softmax(dim=-1)

        # Weighted sum of values, restored to the original spatial layout.
        attended = torch.bmm(weights, value)
        attended = attended.transpose(1, 2).reshape(batch, channels, height, width)

        return self.proj(attended) + x

# Confirmation message emitted once this cell has finished running.
print("✓ UNet building blocks defined")


### 2.3 Complete UNet Architecture

The UNet predicts the noise $\epsilon$ that was added to create the noisy image at timestep $t$.


In [None]:
# Simplified UNet for CIFAR-10 (32x32 images)
class SimpleUNet(nn.Module):
    """
    Simplified UNet for CIFAR-10 diffusion model.

    The network has two encoder stages (each a time-conditioned residual block
    followed by a stride-2 convolution), a bottleneck with self-attention, and
    two decoder stages (each a stride-2 transposed convolution, a channel-wise
    concatenation with the matching encoder output, and a residual block).

    Because the resolution is halved twice and then doubled twice, the input
    height and width must both be multiples of 4 for the skip connections to
    line up.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the returned tensor.
        time_emb_dim: Size of the sinusoidal embedding; the conditioning
            vector handed to the residual blocks has ``4 * time_emb_dim``
            features.
    """

    def __init__(self, in_channels=3, out_channels=3, time_emb_dim=128):
        super().__init__()

        # Time embedding: sinusoidal features followed by a two-layer MLP
        self.time_embed = nn.Sequential(
            SinusoidalTimeEmbedding(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim * 4),
            nn.ReLU(),
            nn.Linear(time_emb_dim * 4, time_emb_dim * 4)
        )

        # Encoder
        self.init_conv = nn.Conv2d(in_channels, 64, 3, padding=1)
        self.down1 = ResidualBlock(64, 128, time_emb_dim * 4)
        self.down2 = ResidualBlock(128, 256, time_emb_dim * 4)
        self.downsample1 = nn.Conv2d(128, 128, 3, stride=2, padding=1)
        self.downsample2 = nn.Conv2d(256, 256, 3, stride=2, padding=1)

        # Bottleneck
        self.mid1 = ResidualBlock(256, 256, time_emb_dim * 4)
        self.mid_attn = AttentionBlock(256)
        self.mid2 = ResidualBlock(256, 256, time_emb_dim * 4)

        # Decoder (input channels include the concatenated skip connection)
        self.up1 = ResidualBlock(512, 128, time_emb_dim * 4)  # 256 + 256
        self.up2 = ResidualBlock(256, 64, time_emb_dim * 4)   # 128 + 128
        self.upsample1 = nn.ConvTranspose2d(256, 256, 4, stride=2, padding=1)
        self.upsample2 = nn.ConvTranspose2d(128, 128, 4, stride=2, padding=1)

        # Output
        self.final_norm = nn.GroupNorm(8, 64)
        self.final_conv = nn.Conv2d(64, out_channels, 3, padding=1)

    def forward(self, x, t):
        """
        Args:
            x: Image batch whose last two dimensions (height, width) are both
                multiples of 4.
            t: Timestep tensor passed to the time-embedding stack.

        Returns:
            Tensor with ``out_channels`` channels and the same spatial size
            as ``x``.

        Raises:
            ValueError: If the height or width of ``x`` is not a multiple of 4.
        """
        # Fail early with a readable message; otherwise the mismatch only
        # surfaces as a size error inside torch.cat in the decoder.
        height, width = x.shape[-2:]
        if height % 4 != 0 or width % 4 != 0:
            raise ValueError(
                "SimpleUNet requires input height and width divisible by 4, "
                f"got {height}x{width}"
            )

        # Time embedding
        t_emb = self.time_embed(t)

        # Initial conv
        x = self.init_conv(x)

        # Encoder
        x = self.down1(x, t_emb)
        skip1 = x
        x = self.downsample1(x)

        x = self.down2(x, t_emb)
        skip2 = x
        x = self.downsample2(x)

        # Bottleneck
        x = self.mid1(x, t_emb)
        x = self.mid_attn(x)
        x = self.mid2(x, t_emb)

        # Decoder: upsample, then merge with the encoder output of equal size
        x = self.upsample1(x)
        x = torch.cat([x, skip2], dim=1)
        x = self.up1(x, t_emb)

        x = self.upsample2(x)
        x = torch.cat([x, skip1], dim=1)
        x = self.up2(x, t_emb)

        # Output
        x = self.final_norm(x)
        x = F.relu(x)
        x = self.final_conv(x)

        return x

# Smoke test: build the network on the selected device and count its weights
unet = SimpleUNet().to(device)
num_params = sum(map(torch.Tensor.numel, unet.parameters()))
print(f"✓ UNet created with {num_params:,} parameters")

# Run a random two-image batch with two different timesteps through the model
test_x = torch.randn(2, 3, 32, 32).to(device)
test_t = torch.tensor([0, 100], device=device)
test_out = unet(test_x, test_t)
print(f"✓ Forward pass successful: {test_x.shape} -> {test_out.shape}")
