# In [None]:
# =========================================================
# DDIM ONLY DIFFUSION MODEL — MNIST
# Proper U-Net + Residual Blocks + Correct DDIM Sampling
# =========================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import math
import numpy as np

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ---------------------------------------------------------
# 1. SINUSOIDAL TIME EMBEDDING
# ---------------------------------------------------------

def time_embedding(t, dim):
    """Build sinusoidal embeddings for a batch of timesteps.

    Args:
        t: 1-D tensor of timesteps, one entry per batch element.
        dim: Width of the embedding to produce.

    Returns:
        Tensor of shape ``(len(t), dim)``: ``dim // 2`` sine features followed
        by ``dim // 2`` cosine features. When ``dim`` is odd, one trailing
        zero column is appended so the width always equals ``dim``.
    """
    half = dim // 2
    # Geometrically spaced frequencies: exp(-ln(10000) * k / half), k = 0..half-1.
    freqs = torch.exp(
        -math.log(10000) * torch.arange(half, device=t.device) / half
    )
    args = t[:, None] * freqs[None]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if dim % 2:
        # The sin/cos halves only fill 2 * (dim // 2) columns; pad the last
        # one so downstream layers sized to ``dim`` still accept the result.
        emb = torch.cat([emb, emb.new_zeros(emb.shape[0], 1)], dim=1)
    return emb

# ---------------------------------------------------------
# 2. RESIDUAL BLOCK
# ---------------------------------------------------------

class ResBlock(nn.Module):
    """Residual unit of two conv stages with a per-channel time bias.

    Each stage is Conv2d(3x3, padding 1) -> GroupNorm(8 groups) -> SiLU.
    The time vector is projected to ``out_ch`` values and added to the
    output of the first stage before the second stage runs.

    Args:
        in_ch: Channels of the incoming feature map.
        out_ch: Channels of the produced feature map (GroupNorm uses 8
            groups over these channels).
        time_dim: Width of the time vector passed to ``forward``.
    """

    def __init__(self, in_ch, out_ch, time_dim):
        super().__init__()

        # Projects the time vector to one additive bias per output channel.
        self.time_mlp = nn.Linear(time_dim, out_ch)

        self.block1 = self._conv_stage(in_ch, out_ch)
        self.block2 = self._conv_stage(out_ch, out_ch)

        # The skip path needs a 1x1 conv only when the channel count changes.
        if in_ch == out_ch:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1)

    @staticmethod
    def _conv_stage(src_ch, dst_ch):
        """Return one Conv2d -> GroupNorm -> SiLU stage."""
        layers = [
            nn.Conv2d(src_ch, dst_ch, 3, padding=1),
            nn.GroupNorm(8, dst_ch),
            nn.SiLU(),
        ]
        return nn.Sequential(*layers)

    def forward(self, x, t):
        """Apply the block to feature map ``x`` conditioned on time vector ``t``."""
        # Trailing singleton axes let the bias broadcast over the spatial dims.
        time_bias = self.time_mlp(t)[:, :, None, None]
        hidden = self.block1(x) + time_bias
        refined = self.block2(hidden)
        return refined + self.shortcut(x)

# ---------------------------------------------------------
# 3. U-NET (FIXED)
# ---------------------------------------------------------

class UNet(nn.Module):
    """Two-level U-Net mapping a 1-channel image and a timestep to a 1-channel map.

    The encoder widens 1 -> 64 -> 128 -> 256 channels with two 2x max-pools;
    the decoder mirrors it with two stride-2 transposed convolutions and
    concatenates the matching encoder features before each ``ResBlock``.
    Because of the two pool/upsample pairs, the input height and width must
    be divisible by 4 for the skip concatenations to line up.

    Args:
        time_dim: Width of the sinusoidal timestep embedding and of the
            conditioning vector handed to every ``ResBlock``.
    """

    def __init__(self, time_dim=128):
        super().__init__()

        # Kept so forward() builds embeddings of the width time_mlp expects
        # (previously forward() hard-coded 128 regardless of this argument).
        self.time_dim = time_dim

        self.time_mlp = nn.Sequential(
            nn.Linear(time_dim, time_dim * 4),
            nn.SiLU(),
            nn.Linear(time_dim * 4, time_dim)
        )

        self.in_conv = nn.Conv2d(1, 64, 3, padding=1)

        self.down1 = ResBlock(64, 128, time_dim)
        self.pool1 = nn.MaxPool2d(2)

        self.down2 = ResBlock(128, 256, time_dim)
        self.pool2 = nn.MaxPool2d(2)

        self.mid = ResBlock(256, 256, time_dim)

        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        # 128 upsampled channels + 256 skip channels from down2 = 384.
        self.up_block2 = ResBlock(384, 128, time_dim)

        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        # 64 upsampled channels + 128 skip channels from down1 = 192.
        self.up_block1 = ResBlock(192, 64, time_dim)

        self.out = nn.Conv2d(64, 1, 1)

    def forward(self, x, t):
        """Run the network.

        Args:
            x: Image batch with one channel.
            t: 1-D tensor of timesteps, one per batch element.

        Returns:
            Tensor with one channel and the same spatial size as ``x``.
        """
        t = self.time_mlp(time_embedding(t, self.time_dim))

        # Encoder: keep the pre-pool activations as skip connections.
        x1 = self.in_conv(x)
        d1 = self.down1(x1, t)
        p1 = self.pool1(d1)

        d2 = self.down2(p1, t)
        p2 = self.pool2(d2)

        mid = self.mid(p2, t)

        # Decoder: upsample, concatenate the matching skip, then refine.
        u2 = self.up2(mid)
        u2 = self.up_block2(torch.cat([u2, d2], dim=1), t)

        u1 = self.up1(u2)
        u1 = self.up_block1(torch.cat([u1, d1], dim=1), t)

        return self.out(u1)

# ---------------------------------------------------------
# 4. DIFFUSION (DDIM ONLY)
# ---------------------------------------------------------

class Diffusion:
    """Linear-beta forward noising process with deterministic DDIM sampling.

    Args:
        T: Number of timesteps in the noise schedule.
        device: Device holding the schedule tensors and the generated
            samples. ``None`` (the default) uses the module-level ``device``.
    """

    def __init__(self, T=1000, device=None):
        self.T = T
        self.device = self._default_device() if device is None else device
        # beta rises linearly from 1e-4 to 0.02; alpha_bar is the running
        # product of (1 - beta).
        self.beta = torch.linspace(1e-4, 0.02, T).to(self.device)
        self.alpha = 1. - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)

    @staticmethod
    def _default_device():
        """Return the module-level ``device`` used when none is given."""
        return device

    def add_noise(self, x0, t):
        """Noise ``x0`` to the level of timesteps ``t``.

        Args:
            x0: Batch of clean samples; the first axis is the batch axis.
            t: 1-D integer tensor of timestep indices, one per sample.

        Returns:
            ``(x_t, noise)`` where
            ``x_t = sqrt(alpha_bar[t]) * x0 + sqrt(1 - alpha_bar[t]) * noise``
            and ``noise`` is the Gaussian noise that was drawn.
        """
        noise = torch.randn_like(x0)
        # One coefficient per sample, broadcast over all remaining axes.
        a_bar = self.alpha_bar[t].reshape(-1, *([1] * (x0.dim() - 1)))
        return torch.sqrt(a_bar) * x0 + torch.sqrt(1 - a_bar) * noise, noise

    @torch.no_grad()
    def ddim_sample(self, model, steps=50, n=64, shape=(1, 28, 28)):
        """Generate samples with deterministic DDIM updates.

        Args:
            model: Callable ``model(x, t)`` returning the predicted noise for
                batch ``x`` at integer timesteps ``t`` (shape ``(n,)``).
            steps: Number of evenly spaced timesteps between 0 and ``T - 1``
                to visit. With fewer than 2 steps no update is performed and
                the clamped initial noise is returned.
            n: Number of samples to generate.
            shape: Per-sample shape; defaults to one 28x28 channel.

        Returns:
            Tensor of shape ``(n, *shape)`` clamped to ``[-1, 1]``. The last
            update targets timestep 0, so the result sits at noise level
            ``alpha_bar[0]``.
        """
        times = np.linspace(0, self.T - 1, steps).astype(int)
        x = torch.randn(n, *shape).to(self.device)

        for i in reversed(range(1, steps)):
            # Plain Python ints avoid passing NumPy scalars into torch.
            cur, prev = int(times[i]), int(times[i - 1])
            t = torch.full((n,), cur, dtype=torch.long, device=self.device)

            eps = model(x, t)

            # Every sample shares the timestep, so scalar coefficients suffice.
            a = self.alpha_bar[cur]
            a_prev = self.alpha_bar[prev]

            # Predict the clean sample, then re-noise it to the previous level
            # using the same predicted noise (no fresh randomness).
            x0_pred = (x - torch.sqrt(1 - a) * eps) / torch.sqrt(a)
            x = torch.sqrt(a_prev) * x0_pred + torch.sqrt(1 - a_prev) * eps

        return torch.clamp(x, -1, 1)

# ---------------------------------------------------------
# 5. DATA
# ---------------------------------------------------------

# Convert images to tensors, then apply (x - 0.5) / 0.5 per pixel.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# MNIST training split, downloaded into ./data when it is not already there.
dataset = datasets.MNIST(
    "./data",
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 128 images, loaded by two worker processes.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

# ---------------------------------------------------------
# 6. TRAINING
# ---------------------------------------------------------

# Network, noise schedule and optimiser.
model = UNet().to(device)
diffusion = Diffusion()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

epochs = 20

for epoch in range(epochs):
    batch_losses = []
    for x, _ in loader:
        x = x.to(device)

        # One uniformly drawn timestep per image, then noise the batch to it.
        t = torch.randint(0, diffusion.T, (x.size(0),), device=device)
        x_noisy, noise = diffusion.add_noise(x, t)

        # The network is trained to recover the noise that was added.
        optimizer.zero_grad()
        pred_noise = model(x_noisy, t)
        loss = F.mse_loss(pred_noise, noise)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    # Report the mean per-batch loss for the epoch.
    total_loss = sum(batch_losses)
    print(f"Epoch {epoch+1} | Loss: {total_loss / len(loader):.5f}")

# ---------------------------------------------------------
# 7. DDIM SAMPLING
# ---------------------------------------------------------

# Draw 64 samples with 50 DDIM steps and show the first 36 in a 6x6 grid.
samples = diffusion.ddim_sample(model, steps=50, n=64)

plt.figure(figsize=(6,6))
for i, sample in enumerate(samples[:36]):
    plt.subplot(6, 6, i + 1)
    # Channel 0 is the only channel; move it to the CPU for matplotlib.
    plt.imshow(sample[0].cpu(), cmap="gray")
    plt.axis("off")
plt.show()