In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Set device and hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's random number generators so that data shuffling, the noise
# drawn during training and the initial noise used for sampling are repeatable
# after Restart Kernel -> Run All.
seed = 42
torch.manual_seed(seed)

# Optimisation
batch_size = 128
num_epochs = 50
learning_rate = 0.0002

# Linear beta (noise variance) schedule over `timesteps` diffusion steps
beta_start = 0.0001
beta_end = 0.02
timesteps = 1000

# Model width and the DDIM step counts compared in the sampling figure
base_channels = 64
ddim_steps = [10, 50, 100]

In [None]:
# Load CIFAR-10 dataset
# Each channel is normalised with mean 0.5 and std 0.5; the plotting cells
# undo this with `* 0.5 + 0.5` before calling imshow.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Training split only; the files are downloaded into ./data when missing.
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Define Unconditional U-Net for CIFAR-10
class UNet(nn.Module):
    """Small two-level U-Net that maps a 3-channel image to a 3-channel output.

    In this notebook it is trained to predict the Gaussian noise that was added
    to an image, so the output layer is linear: the regression target is
    unbounded, and squashing the prediction into (-1, 1) would make every
    target value with magnitude above 1 unreachable.

    NOTE(review): the network receives only the noisy image, not the diffusion
    timestep, so it cannot tell how much noise is present. Adding a timestep
    embedding requires coordinated changes in the training and sampling code.

    Args:
        base_channels: Number of feature channels at full resolution; the two
            deeper levels use 2x and 4x this width.
    """

    def __init__(self, base_channels=64):
        super().__init__()
        c1, c2, c4 = base_channels, base_channels * 2, base_channels * 4

        # Encoder: two resolution levels, each followed by 2x downsampling.
        self.enc1 = self._conv_block(3, c1)
        self.pool1 = nn.MaxPool2d(2)

        self.enc2 = self._conv_block(c1, c2)
        self.pool2 = nn.MaxPool2d(2)

        # Bottleneck at 1/4 of the input resolution.
        self.middle = self._conv_block(c2, c4)

        # Decoder: each level upsamples 2x, then fuses the matching encoder
        # features, which is why the conv blocks take twice the upsampled width.
        self.up1 = nn.ConvTranspose2d(c4, c2, 2, stride=2)
        self.dec1 = self._conv_block(c4, c2)

        self.up2 = nn.ConvTranspose2d(c2, c1, 2, stride=2)
        self.dec2 = self._conv_block(c2, c1)

        # 1x1 projection back to the 3 image channels.
        self.final = nn.Conv2d(c1, 3, 1)

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Return two 3x3 same-padding convolutions, each followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1), nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, padding=1), nn.ReLU()
        )

    def forward(self, x):
        """Run the U-Net on a batch of images.

        Args:
            x: Tensor of shape (N, 3, H, W). H and W must be divisible by 4 so
                the upsampled maps line up with the skip connections.

        Returns:
            Tensor of shape (N, 3, H, W) with unbounded values.
        """
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))

        m = self.middle(self.pool2(e2))

        # Skip connections: concatenate along the channel dimension.
        d1 = self.dec1(torch.cat([self.up1(m), e2], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d1), e1], dim=1))

        return self.final(d2)

In [None]:
# Diffusion schedule
def get_betas(timesteps):
    """Build the linear beta schedule.

    Args:
        timesteps: Number of entries in the schedule.

    Returns:
        1-D tensor on `device` with `timesteps` evenly spaced values running
        from the global `beta_start` to the global `beta_end`.
    """
    schedule = torch.linspace(start=beta_start, end=beta_end, steps=timesteps)
    return schedule.to(device)

def get_alphas(betas):
    """Derive the alpha terms from a beta schedule.

    Args:
        betas: 1-D tensor of per-step beta values.

    Returns:
        Tuple `(alphas, alphas_cumprod)` where `alphas = 1 - betas` and
        `alphas_cumprod[i]` is the product of `alphas[0..i]`.
    """
    alphas = 1.0 - betas
    return alphas, alphas.cumprod(dim=0)

def q_sample(x_start, t, noise=None):
    """Sample a noised version of `x_start` at timestep(s) `t`.

    Computes `sqrt(abar_t) * x_start + sqrt(1 - abar_t) * noise`, where
    `abar_t` is the cumulative alpha product at index `t`.

    Args:
        x_start: Batch of clean images, shape (N, C, H, W).
        t: Integer timestep indices into the schedule, one per image.
        noise: Optional noise tensor shaped like `x_start`; standard normal
            noise is drawn when omitted.

    Returns:
        Tensor shaped like `x_start` holding the noised images.
    """
    eps = torch.randn_like(x_start) if noise is None else noise

    # The schedule is rebuilt from the global `timesteps` on every call.
    _, abar = get_alphas(get_betas(timesteps))
    abar_t = abar[t]

    # Per-sample scalars reshaped to (N, 1, 1, 1) and expanded over the image.
    signal_scale = abar_t.sqrt().view(-1, 1, 1, 1).expand_as(x_start)
    noise_scale = (1. - abar_t).sqrt().view(-1, 1, 1, 1).expand_as(x_start)

    return signal_scale * x_start + noise_scale * eps

In [None]:
# Training setup
# Noise schedule; `alphas_cumprod` is read as a global by sample_ddim.
betas = get_betas(timesteps)
alphas, alphas_cumprod = get_alphas(betas)

# Network, loss (MSE between predicted and true noise) and optimiser.
model = UNet(base_channels=base_channels).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Training loop with loss tracking
def train_model():
    """Train the global `model` to predict the noise added by `q_sample`.

    For every batch a random timestep is drawn per image, the image is noised
    with `q_sample`, and the model is optimised with `criterion` to recover the
    noise from the noised image. Uses the globals `model`, `optimizer`,
    `criterion`, `train_loader`, `device`, `timesteps` and `num_epochs`.

    NOTE(review): the timestep `t` is used to create the noisy input but is not
    passed to the model, so the network is not conditioned on the noise level.

    Returns:
        List with the average training loss of each epoch.
    """
    # sample_ddim() switches the model to eval mode; make sure training always
    # runs in train mode even when this cell is re-run after sampling.
    model.train()

    losses = []
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for images, _ in train_loader:
            images = images.to(device)
            # The last batch may be smaller than the configured batch size.
            current_batch_size = images.shape[0]

            t = torch.randint(0, timesteps, (current_batch_size,), device=device)

            noise = torch.randn_like(images)
            x_noisy = q_sample(images, t, noise)
            predicted_noise = model(x_noisy)

            loss = criterion(predicted_noise, noise)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(train_loader)
        losses.append(avg_loss)
        print(f'Epoch [{epoch + 1}/{num_epochs}], Average Loss: {avg_loss:.4f}')

    return losses

In [None]:
# DDIM sampling with trajectory
@torch.no_grad()
def sample_ddim(model, num_samples=1, ddim_steps=100, eta=0.0, save_trajectory=False,
                trajectory_interval=100):
    """Generate images with DDIM sampling over a subsequence of the schedule.

    Starts from standard normal noise and applies `ddim_steps` updates along an
    evenly spaced, strictly decreasing timestep schedule from `timesteps - 1`
    to 0. Reads the globals `timesteps`, `alphas_cumprod` and `device`.

    NOTE(review): the model is called with the current sample only; no timestep
    is passed to it.

    Args:
        model: Noise-prediction network called as `model(x)`.
        num_samples: Number of images to generate.
        ddim_steps: Number of sampling updates (>= 1). Values above
            `timesteps - 1` are capped so that no timestep repeats.
        eta: Stochasticity; 0 gives deterministic DDIM, larger values inject
            fresh noise at every step.
        save_trajectory: When True, intermediate samples are recorded.
        trajectory_interval: A snapshot is stored each time the sampler crosses
            a multiple of this many timesteps, and at the final step.

    Returns:
        Tuple `(x, trajectory)`: `x` is the final (num_samples, 3, 32, 32)
        tensor on `device`; `trajectory` is a list of NumPy arrays (initial
        noise first, final sample last), or None when `save_trajectory` is
        False.

    Raises:
        ValueError: If `ddim_steps` or `trajectory_interval` is below 1.
    """
    if ddim_steps < 1:
        raise ValueError(f"ddim_steps must be >= 1, got {ddim_steps}")
    if trajectory_interval < 1:
        raise ValueError(f"trajectory_interval must be >= 1, got {trajectory_interval}")

    # Integer arithmetic gives exactly num_steps + 1 timesteps from
    # timesteps - 1 down to 0, so the loop below performs num_steps updates.
    num_steps = max(1, min(int(ddim_steps), timesteps - 1))
    schedule = [(timesteps - 1) * (num_steps - k) // num_steps for k in range(num_steps + 1)]

    # Restore the caller's train/eval mode when sampling is done.
    was_training = model.training
    model.eval()
    try:
        x = torch.randn(num_samples, 3, 32, 32, device=device)
        trajectory = [x.cpu().numpy()] if save_trajectory else None

        for t, t_next in zip(schedule[:-1], schedule[1:]):
            predicted_noise = model(x)

            alpha_cumprod_t = alphas_cumprod[t]
            alpha_cumprod_t_next = alphas_cumprod[t_next]

            # Standard deviation of the noise injected at this step (0 when eta == 0).
            sigma_t = eta * torch.sqrt(
                (1 - alpha_cumprod_t_next) / (1 - alpha_cumprod_t)
                * (1 - alpha_cumprod_t / alpha_cumprod_t_next)
            )

            # Estimate of the clean image implied by the predicted noise.
            x0_t = (x - torch.sqrt(1 - alpha_cumprod_t) * predicted_noise) / torch.sqrt(alpha_cumprod_t)
            # Clamp guards against a slightly negative radicand from rounding
            # (or from eta > 1), which would otherwise produce NaNs.
            direction_pointing_to_xt = torch.sqrt(
                torch.clamp(1 - alpha_cumprod_t_next - sigma_t ** 2, min=0.0)
            ) * predicted_noise

            x = torch.sqrt(alpha_cumprod_t_next) * x0_t + direction_pointing_to_xt

            # Decide on the Python float to avoid a device sync every step.
            if eta > 0:
                x = x + sigma_t * torch.randn_like(x)

            crossed_interval = t_next // trajectory_interval < t // trajectory_interval
            if save_trajectory and (crossed_interval or t_next == 0):
                trajectory.append(x.cpu().numpy())
    finally:
        model.train(was_training)

    return x, trajectory

In [None]:
# Run the training loop; `losses` holds the average loss of each epoch.
losses = train_model()

In [None]:
# Plot training loss
# The x-axis is derived from the recorded losses rather than `num_epochs`, so
# the plot still works if the two have diverged in a live kernel.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(range(1, len(losses) + 1), losses, label='Training Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Average Loss')
ax.set_title('Training Loss Over Time')
ax.grid(True)
ax.legend()
plt.show()

In [None]:
# Samples with different DDIM steps
# squeeze=False keeps `axes` two-dimensional, so indexing also works when
# `ddim_steps` contains a single entry.
fig, axes = plt.subplots(1, len(ddim_steps), figsize=(15, 5), squeeze=False)
for ax, steps in zip(axes[0], ddim_steps):
    samples, _ = sample_ddim(model, num_samples=1, ddim_steps=steps)
    samples = samples.cpu().numpy()
    # Undo the (0.5, 0.5) normalisation and clip to the range imshow accepts
    # for float RGB data; the sampler output is not guaranteed to be bounded.
    image = np.clip(np.transpose(samples[0], (1, 2, 0)) * 0.5 + 0.5, 0.0, 1.0)
    ax.imshow(image)
    ax.set_title(f'DDIM Steps: {steps}')
    ax.axis('off')
plt.suptitle("Samples with Different DDIM Steps")
plt.tight_layout()
plt.show()

In [None]:
# Trajectory with DDIM (using 100 steps)
samples, trajectory = sample_ddim(model, num_samples=1, ddim_steps=100, save_trajectory=True)
trajectory = np.array(trajectory)

# squeeze=False keeps `axes` two-dimensional, so the loop also works when the
# trajectory contains a single frame.
fig, axes = plt.subplots(1, len(trajectory), figsize=(15, 3), squeeze=False)
for i, ax in enumerate(axes[0]):
    # Undo the (0.5, 0.5) normalisation and clip to the range imshow accepts;
    # early frames are close to pure Gaussian noise and fall far outside [0, 1].
    img = np.transpose(trajectory[i][0], (1, 2, 0)) * 0.5 + 0.5
    ax.imshow(np.clip(img, 0.0, 1.0))
    # Labels assume consecutive frames are roughly 100 timesteps apart.
    ax.set_title(f'Step {timesteps - (i * 100)}')
    ax.axis('off')
plt.suptitle("Trajectory with DDIM (100 Steps)")
plt.tight_layout()
plt.show()