In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Return linearly spaced per-step noise variances beta_1 .. beta_T.

    The defaults (1e-4 -> 0.02) are the linear schedule used by DDPM.

    Args:
        timesteps: number of diffusion steps T (length of the result).
        beta_start: variance added at the first step.
        beta_end: variance added at the last step.

    Returns:
        1-D tensor of length ``timesteps`` (torch's default floating dtype).
    """
    return torch.linspace(beta_start, beta_end, timesteps)

T = 1000  # number of diffusion timesteps
betas = linear_beta_schedule(T)
alphas = 1. - betas
# alpha_bar_t = prod_{s <= t} alpha_s: fraction of signal variance kept at step t.
alphas_cumprod = torch.cumprod(alphas, dim=0)
# Coefficients of the closed-form forward process
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
# (consumed by q_sample). These are CPU tensors.
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod)

In [None]:
class UNet(nn.Module):
    """Small two-level U-Net that predicts the noise present in an image.

    The integer diffusion timestep is fed in as one extra, spatially constant
    input channel holding ``t / T`` (``T`` is the module-level step count), so
    the first convolution sees ``in_channels + 1`` channels. Resolution is
    halved twice by max-pooling and restored by two stride-2 transposed
    convolutions; H and W must therefore be divisible by 4 for the skip
    connections to concatenate.

    Args:
        in_channels: image channels; also the number of output channels.
        base_channels: width of the top level, doubled at each level down.
    """

    @staticmethod
    def _double_conv(c_in, c_out):
        # Two 3x3 same-padding convolutions, each followed by ReLU.
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, 3, padding=1), nn.ReLU(),
            nn.Conv2d(c_out, c_out, 3, padding=1), nn.ReLU(),
        )

    def __init__(self, in_channels=3, base_channels=64):
        super().__init__()
        c1, c2, c4 = base_channels, base_channels * 2, base_channels * 4

        # Encoder (image + timestep plane) -> c1 -> c2, pooling after each level.
        self.enc1 = self._double_conv(in_channels + 1, c1)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = self._double_conv(c1, c2)
        self.pool2 = nn.MaxPool2d(2)

        # Bottleneck at 1/4 resolution.
        self.mid = self._double_conv(c2, c4)

        # Decoder: upsample, concatenate the matching encoder features, fuse.
        self.up1 = nn.ConvTranspose2d(c4, c2, 2, stride=2)
        self.dec1 = self._double_conv(c2 + c2, c2)
        self.up2 = nn.ConvTranspose2d(c2, c1, 2, stride=2)
        self.dec2 = self._double_conv(c1 + c1, c1)

        # 1x1 projection back to image channels.
        self.out = nn.Conv2d(c1, in_channels, 1)

    def forward(self, x, t):
        """Predict noise for ``x`` of shape (B, C, H, W) at timesteps ``t`` of shape (B,)."""
        height, width = x.size(2), x.size(3)
        # Normalized timestep broadcast to a constant (B, 1, H, W) plane.
        t_plane = t.float().div(T)[:, None, None, None].expand(-1, 1, height, width)

        skip1 = self.enc1(torch.cat([x, t_plane], dim=1))
        skip2 = self.enc2(self.pool1(skip1))
        bottleneck = self.mid(self.pool2(skip2))

        up = self.dec1(torch.cat([self.up1(bottleneck), skip2], dim=1))
        up = self.dec2(torch.cat([self.up2(up), skip1], dim=1))
        return self.out(up)

In [None]:
def q_sample(x_0, t, noise):
    """Draw x_t from q(x_t | x_0) in closed form.

    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, using the
    module-level ``sqrt_alphas_cumprod`` / ``sqrt_one_minus_alphas_cumprod``.

    Args:
        x_0: clean data of shape (B, ...); any number of trailing dims.
        t: integer timesteps of shape (B,), on any device.
        noise: tensor broadcastable to ``x_0`` (typically standard Gaussian).

    Returns:
        The noised tensor, on ``x_0``'s device.
    """
    # The schedule tensors are created on the CPU. Indexing a CPU tensor with
    # CUDA indices raises, so index on t's device and then move the per-sample
    # coefficients next to x_0.
    bcast = (-1,) + (1,) * (x_0.dim() - 1)  # (B, 1, 1, ...) for any rank
    coef_x0 = sqrt_alphas_cumprod.to(t.device)[t].to(x_0.device).reshape(bcast)
    coef_noise = sqrt_one_minus_alphas_cumprod.to(t.device)[t].to(x_0.device).reshape(bcast)
    return coef_x0 * x_0 + coef_noise * noise

def train_ddim(model, loader, optimizer, epochs=10, device='cuda'):
    """Train ``model`` to predict the noise injected by ``q_sample``.

    The objective is the standard noise-prediction MSE; DDIM reuses it and
    differs from DDPM only at sampling time.

    Args:
        model: callable ``model(x_t, t)`` returning a tensor shaped like ``x_t``.
        loader: iterable of ``(images, labels)`` batches; labels are ignored.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over ``loader``.
        device: device the images and timesteps are moved to.

    Returns:
        List with one mean loss per epoch, averaged over the samples actually
        yielded by ``loader`` (correct with ``drop_last``, samplers or subsets).

    Raises:
        ValueError: if ``loader`` yields no samples in an epoch.
    """
    model.train()
    train_losses = []
    for epoch in range(epochs):
        loss_sum, n_seen = 0.0, 0
        for x, _ in tqdm(loader, desc=f"Epoch {epoch+1}"):
            x = x.to(device)
            batch_size = x.size(0)
            # Uniform random timestep per sample, then noise it in closed form.
            t = torch.randint(0, T, (batch_size,), device=device)
            noise = torch.randn_like(x)
            pred = model(q_sample(x, t, noise), t)
            loss = F.mse_loss(pred, noise)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # mse_loss is a batch mean; weight by batch size so the epoch
            # average is per-sample even when batches differ in size.
            loss_sum += loss.item() * batch_size
            n_seen += batch_size
        if n_seen == 0:
            raise ValueError("loader yielded no samples")
        avg_loss = loss_sum / n_seen
        train_losses.append(avg_loss)
        print(f"Epoch {epoch+1} - Loss: {avg_loss:.4f}")
    return train_losses


In [None]:
@torch.no_grad()
def ddim_sample(model, steps=50, eta=0.0):
    device = next(model.parameters()).device
    x = torch.randn((16, 3, 32, 32)).to(device)
    trajectory = []

    times = torch.linspace(T-1, 0, steps, dtype=torch.long)
    for i in range(steps):
        t = times[i].expand(x.size(0)).to(device)
        noise_pred = model(x, t)

        alpha_bar = alphas_cumprod[t].to(device)[:, None, None, None]
        sqrt_ab = torch.sqrt(alpha_bar)
        sqrt_one_minus_ab = torch.sqrt(1 - alpha_bar)

        x0_pred = (x - sqrt_one_minus_ab * noise_pred) / sqrt_ab
        if i < steps - 1:
            t_next = times[i+1].expand(x.size(0)).to(device)
            alpha_bar_next = alphas_cumprod[t_next].to(device)[:, None, None, None]
            sigma = eta * torch.sqrt((1 - alpha_bar / alpha_bar_next) * (1 - alpha_bar_next) / (1 - alpha_bar))
            noise = torch.randn_like(x) if eta > 0 else 0
            x = torch.sqrt(alpha_bar_next) * x0_pred + torch.sqrt(1 - alpha_bar_next - sigma**2) * noise_pred + sigma * noise
        else:
            x = x0_pred
        if i % 10 == 0 or i == steps - 1:
            trajectory.append(x.cpu().clone())
    return x.clamp(-1, 1).cpu(), trajectory


In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ToTensor yields pixels in [0, 1]; Normalize with mean=std=0.5 computes
# (x - 0.5) / 0.5 == 2x - 1, i.e. maps them to [-1, 1], the range the sampler
# clamps to. Unlike a lambda, Normalize is picklable, so this transform keeps
# working if the loader is later given num_workers > 0.
to_model_range = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
transform = transforms.Compose([transforms.ToTensor(), to_model_range])

dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
loader = DataLoader(dataset, batch_size=128, shuffle=True)

In [None]:
# Seed torch's RNGs (all devices) so weight init, data shuffling, the
# timestep/noise draws in training and the sampler's noise are repeatable on a
# fresh-kernel Run All. Some CUDA kernels may still be nondeterministic.
SEED = 0
torch.manual_seed(SEED)

model = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

In [None]:
# Train and keep the per-epoch mean loss for the plot below.
EPOCHS = 10
losses = train_ddim(model, loader, optimizer, epochs=EPOCHS, device=device)

In [None]:
# Per-epoch mean training loss. Epochs are numbered from 1 to match the
# "Epoch N" lines printed during training (a bare plot(losses) starts at 0).
epoch_numbers = range(1, len(losses) + 1)
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(epoch_numbers, losses, marker='o')
ax.set(xlabel="Epoch", ylabel="Loss", title="Training Loss (MSE)")
ax.grid(True)
plt.show()

In [None]:
samples, traj = ddim_sample(model, steps=50, eta=0.0)
# Samples are already clamped to [-1, 1]; map that known range to [0, 1].
# make_grid(normalize=True) would instead min-max stretch using this batch's
# own min/max, which alters apparent contrast and brightness.
grid = make_grid((samples + 1) / 2, nrow=4)
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(grid.permute(1, 2, 0))
ax.axis('off')
ax.set_title("DDIM Samples (eta=0)")
plt.show()

In [None]:
# Visualize how the first sample evolves during DDIM sampling.
# These must match the ddim_sample call above (steps=50) and its recording
# rule (x is recorded after every 10th step and after the final step).
SAMPLE_STEPS, RECORD_EVERY = 50, 10
recorded_steps = [i for i in range(SAMPLE_STEPS) if i % RECORD_EVERY == 0 or i == SAMPLE_STEPS - 1]
assert len(traj) == len(recorded_steps), "SAMPLE_STEPS/RECORD_EVERY out of sync with the ddim_sample call"
# Same timestep grid ddim_sample walks.
timesteps = torch.linspace(T - 1, 0, SAMPLE_STEPS, dtype=torch.long)

fig, axes = plt.subplots(1, len(traj), figsize=(15, 3), squeeze=False)
for ax, img, step in zip(axes[0], traj, recorded_steps):
    # Trajectory snapshots are not clamped; clip after rescaling to [0, 1] so
    # imshow does not warn about out-of-range float data.
    ax.imshow(((img[0].permute(1, 2, 0) + 1) / 2).clamp(0, 1))
    ax.axis("off")
    # After step i the state sits at timestep timesteps[i + 1]; the final step
    # produces the predicted clean image x0.
    ax.set_title(f"t={int(timesteps[step + 1])}" if step + 1 < SAMPLE_STEPS else "x0")
fig.suptitle("Trajectory of 1 Sample Through DDIM")
plt.show()
