<a href="https://colab.research.google.com/github/chwon9-jpg/Diffusion_models/blob/main/CIFAR_10_Diffusion_model_2_reproducible.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ----------------------------
# Required command
# ----------------------------


!pip install deepinv

In [None]:
# ----------------------------
# Required libraries
# ----------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import os
import time
import math

import torchvision.transforms as transforms, torchvision, matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split

In [None]:
# ----------------------------
# U-Net architecture
# ----------------------------

# ----------------------------
# Sinusoidal Positional Embedding used in Transformers (used for timestep embedding)
# ----------------------------
class SinusoidalPositionalEmbedding(nn.Module):
    """Map a batch of scalar timesteps to fixed sinusoidal feature vectors.

    The first ``dim // 2`` features are sines and the next ``dim // 2`` are
    cosines of the timestep scaled by geometrically spaced frequencies
    (from 1 down to 1/10000). The module has no learnable parameters.

    Args:
        dim: Size of the returned embedding vector.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Embed ``time`` (1-D tensor of length B) into a ``[B, dim]`` float tensor."""
        device = time.device
        half_dim = self.dim // 2

        # max(..., 1) keeps the divisor non-zero when half_dim == 1 (dim 2 or 3);
        # for half_dim >= 2 the value is exactly half_dim - 1 as before.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(torch.arange(half_dim, device=device) * -log_step)

        # Outer product: one row per timestep, one column per frequency.
        angles = time[:, None].float() * frequencies[None, :]
        embeddings = torch.cat((angles.sin(), angles.cos()), dim=-1)

        # sin/cos pairs only produce an even number of features; pad one zero
        # column so the output width always equals self.dim.
        if self.dim % 2 == 1:
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings


# ----------------------------
# Self-Attention Block
# ----------------------------
class SelfAttention(nn.Module):
    """Self-attention over the spatial positions of a feature map.

    Queries and keys are projected to a reduced channel width
    (``channels // 8``, at least 1) with 1x1 convolutions; values keep the full
    width. The attended output is scaled by a learnable scalar ``gamma`` and
    added back to the input. ``gamma`` starts at zero, so a freshly
    constructed block returns its input unchanged.

    Args:
        channels: Number of channels of the input (and output) feature map.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        # channels // 8 is 0 for channels < 8, which would create empty
        # query/key projections; clamp to one channel instead.
        qk_channels = max(channels // 8, 1)
        # 1x1 convolutions act as per-position linear projections.
        self.proj_query = nn.Conv2d(channels, qk_channels, kernel_size=1)
        self.proj_key = nn.Conv2d(channels, qk_channels, kernel_size=1)
        self.proj_value = nn.Conv2d(channels, channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply attention to ``x`` of shape ``[B, C, H, W]``; output has the same shape."""
        B, C, H, W = x.size()  # Batch_size, Channels, Height, Width
        num_positions = H * W

        # Flatten the spatial grid so every (h, w) location is one position.
        query = self.proj_query(x).view(B, -1, num_positions).permute(0, 2, 1)  # B, HW, C'
        key = self.proj_key(x).view(B, -1, num_positions)                       # B, C', HW
        value = self.proj_value(x).view(B, -1, num_positions)                   # B, C, HW

        # energy[b, i, j]: dot product between the query at position i and the
        # key at position j.
        energy = torch.bmm(query, key)                                          # B, HW, HW

        # Normalise each row so the weights position i places on all positions
        # j are non-negative and sum to 1.
        attention = F.softmax(energy, dim=-1)

        # out[b, :, i] = sum_j attention[b, i, j] * value[b, :, j]
        out = torch.bmm(value, attention.permute(0, 2, 1))                      # B, C, HW
        out = out.view(B, C, H, W)

        # Residual connection: gamma controls how much attention is mixed in.
        return self.gamma * out + x


# ----------------------------
# Simple UNet for CIFAR-10 Diffusion
# ----------------------------
class SimpleUNet(nn.Module):
    """Small U-Net that predicts the noise added to an image at timestep ``t``.

    Structure: four conv stages on the way down (three 2x max-pools), a
    bottleneck that receives the timestep embedding, and three nearest-neighbour
    upsampling stages with skip connections on the way up. Self-attention is
    applied after enc3, after enc4, after the bottleneck and before the final
    convolution.

    Constraints that follow from the layers used:
      * ``hidden_dim`` must be divisible by 8 (``nn.GroupNorm(8, hidden_dim)``).
      * Input height and width must be divisible by 8, otherwise the upsampled
        maps no longer match the skip tensors they are concatenated with.

    Args:
        in_channels: Channels of the noisy input image.
        out_channels: Channels of the predicted noise.
        hidden_dim: Base channel width; deeper stages use 2x, 4x and 8x.
    """

    def __init__(self, in_channels=3, out_channels=3, hidden_dim=32):
        super().__init__()

        # Encoder
        self.enc1 = nn.Conv2d(in_channels, hidden_dim, 3, padding=1)
        self.norm_enc1 = nn.GroupNorm(8, hidden_dim)

        self.enc2 = nn.Conv2d(hidden_dim, hidden_dim * 2, 3, padding=1)
        self.norm_enc2 = nn.GroupNorm(8, hidden_dim * 2)
        # NOTE: attn0_5 is not called in forward(): its output was never
        # consumed by any later layer. The module stays registered so that
        # parameter initialisation order and state_dict keys are unchanged.
        self.attn0_5 = SelfAttention(hidden_dim * 2)

        self.enc3 = nn.Conv2d(hidden_dim * 2, hidden_dim * 4, 3, padding=1)
        self.norm_enc3 = nn.GroupNorm(8, hidden_dim * 4)
        self.attn0 = SelfAttention(hidden_dim * 4)

        self.enc4 = nn.Conv2d(hidden_dim * 4, hidden_dim * 8, 3, padding=1)
        self.norm_enc4 = nn.GroupNorm(8, hidden_dim * 8)

        # Attention after deepest encoder
        self.attn1 = SelfAttention(hidden_dim * 8)

        # Bottleneck
        self.bottleneck = nn.Conv2d(hidden_dim * 8, hidden_dim * 8, 3, padding=1)
        self.norm_bottleneck = nn.GroupNorm(8, hidden_dim * 8)

        # Attention after bottleneck
        self.attn2 = SelfAttention(hidden_dim * 8)

        # Decoder: each conv takes the upsampled map concatenated with a skip tensor
        self.dec4 = nn.Conv2d(hidden_dim * 8 + hidden_dim * 4, hidden_dim * 4, 3, padding=1)
        self.norm_dec4 = nn.GroupNorm(8, hidden_dim * 4)

        self.dec3 = nn.Conv2d(hidden_dim * 4 + hidden_dim * 2, hidden_dim * 2, 3, padding=1)
        self.norm_dec3 = nn.GroupNorm(8, hidden_dim * 2)

        self.dec2 = nn.Conv2d(hidden_dim * 2 + hidden_dim, hidden_dim, 3, padding=1)
        self.norm_dec2 = nn.GroupNorm(8, hidden_dim)

        self.attn3 = SelfAttention(hidden_dim)

        # Final conv also sees the raw input through a skip connection
        self.dec1 = nn.Conv2d(hidden_dim + in_channels, out_channels, 3, padding=1)

        # Timestep embedding: Sinusoidal Positional Embedding + MLP
        time_embedding_dim = hidden_dim * 4  # Intermediate dim for sinusoidal embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionalEmbedding(dim=time_embedding_dim),
            nn.Linear(time_embedding_dim, hidden_dim * 8),  # Project to bottleneck width
            nn.SiLU(),
            nn.Linear(hidden_dim * 8, hidden_dim * 8)
        )

    def forward(self, x, t, type_t="timestep"):
        """Predict the noise in ``x`` at timesteps ``t``.

        Args:
            x: Noisy images, ``[B, in_channels, H, W]``.
            t: 1-D tensor with one timestep per batch element.
            type_t: Accepted for interface compatibility; not used.

        Returns:
            Tensor of shape ``[B, out_channels, H, W]``.
        """
        # Timestep -> [B, hidden_dim * 8], then add singleton spatial dims so it
        # broadcasts over the bottleneck feature map.
        t_emb = self.time_mlp(t)
        t_emb = t_emb.unsqueeze(-1).unsqueeze(-1)  # [B, hidden_dim * 8, 1, 1]

        # Encoder (Downsampling): Conv -> Norm -> SiLU, then 2x2 max-pool (stride 2)
        e1 = F.silu(self.norm_enc1(self.enc1(x)))
        p1 = F.max_pool2d(e1, 2)

        e2_block_out = F.silu(self.norm_enc2(self.enc2(p1)))
        # The pool consumes the pre-attention features; no attention is applied
        # at this resolution (see the note on attn0_5 in __init__).
        p2 = F.max_pool2d(e2_block_out, 2)

        e3_block_out = F.silu(self.norm_enc3(self.enc3(p2)))
        e3_attn_out = self.attn0(e3_block_out)                      # Self-Attention 0
        p3 = F.max_pool2d(e3_attn_out, 2)

        e4_block_out = F.silu(self.norm_enc4(self.enc4(p3)))
        e4_attn_out = self.attn1(e4_block_out)                      # Self-Attention 1

        # Bottleneck + timestep (MidBlock): Conv -> Add T_emb -> Norm -> SiLU
        # (spatial size is input / 8, e.g. 4 x 4 for 32 x 32 inputs)
        b_block_out = F.silu(self.norm_bottleneck(self.bottleneck(e4_attn_out) + t_emb))
        b_attn_out = self.attn2(b_block_out)                        # Self-Attention 2

        # Decoder (Upsampling): 2x nearest upsample -> concat skip -> Conv -> Norm -> SiLU
        u4 = F.interpolate(b_attn_out, scale_factor=2, mode="nearest")
        # Skip uses e3_block_out (enc3 output before its attention and pooling)
        d4 = F.silu(self.norm_dec4(self.dec4(torch.cat([u4, e3_block_out], dim=1))))

        u3 = F.interpolate(d4, scale_factor=2, mode="nearest")
        # Skip uses e2_block_out (enc2 output before pooling)
        d3 = F.silu(self.norm_dec3(self.dec3(torch.cat([u3, e2_block_out], dim=1))))

        u2 = F.interpolate(d3, scale_factor=2, mode="nearest")
        # Skip uses e1 (enc1 output before pooling)
        d2_block_out = F.silu(self.norm_dec2(self.dec2(torch.cat([u2, e1], dim=1))))

        d2_attn_out = self.attn3(d2_block_out)

        # Final output (no activation or normalization after the last layer)
        out = self.dec1(torch.cat([d2_attn_out, x], dim=1))

        return out

In [None]:
class EarlyStopping:
    """Stop training when the validation loss stops improving.

    Call the instance once per epoch with the validation loss and the model.
    An epoch counts as an improvement when the loss drops by at least ``delta``
    below the best loss seen so far; the model weights are then snapshotted.
    After ``patience`` consecutive non-improving epochs ``early_stop`` is set.

    Args:
        patience: Number of consecutive non-improving epochs tolerated.
        delta: Minimum decrease in validation loss that counts as improvement.
    """

    def __init__(self, patience, delta):
        self.patience = patience
        self.delta = delta
        self.best_score = None  # Best (negated) validation loss so far
        self.early_stop = False
        self.counter = 0  # Consecutive epochs without sufficient improvement
        self.best_model_state = None  # Independent copy of the best weights

    @staticmethod
    def _snapshot(model):
        """Return a copy of the model's state dict that later training cannot alter."""
        # state_dict() hands back references to the live parameter tensors, so
        # storing it directly would make the "best" state follow every
        # subsequent optimizer step.
        import copy

        return copy.deepcopy(model.state_dict())

    def __call__(self, val_loss, model):
        """Record this epoch's ``val_loss`` and update the stopping state."""
        score = -val_loss  # Higher is better

        if self.best_score is None:
            self.best_score = score
            self.best_model_state = self._snapshot(model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.best_model_state = self._snapshot(model)
            self.counter = 0

    def load_best_model(self, model):
        """Restore the best recorded weights into ``model``.

        Does nothing if no epoch has been recorded yet.
        """
        if self.best_model_state is None:
            return
        model.load_state_dict(self.best_model_state)

In [None]:
# ----------------------------
# Setup
# Defines the device, hyperparameters, data loaders, model, optimizer and
# the precomputed diffusion noise schedule used by the cells below.
# ----------------------------

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}\n")

# Hyperparameters
batch_size = 128
image_size = 32
lr = 1e-3
epochs = 60
timesteps = 1000     # Number of diffusion steps T
beta_start = 1e-4    # First value of the linear beta schedule
beta_end = 0.02      # Last value of the linear beta schedule

# Data
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)), # (x - 0.5) / 0.5 maps ToTensor's [0, 1] output to [-1, 1]
])

print("Loading CIFAR-10 dataset...")
full_trainset = torchvision.datasets.CIFAR10(root='./data',
                                        train=True,
                                        download=True,
                                        transform=transform)

# Split into training and validation sets
# NOTE(review): no generator is passed to random_split and no seed is set in
# the visible code, so the split presumably differs between runs - confirm.
train_size = int(0.8 * len(full_trainset)) # 80% for training
val_size = len(full_trainset) - train_size  # Remaining for validation
trainset, valset = random_split(full_trainset, [train_size, val_size])

trainloader = DataLoader(trainset, batch_size, shuffle=True)
valloader = DataLoader(valset, batch_size, shuffle=False) # Separate, unshuffled loader for the validation set

print(f"Dataset loaded. Training batches: {len(trainloader)}, Validation batches: {len(valloader)}\n")

# Model, optimizer, loss
model = SimpleUNet(in_channels=3, out_channels=3, hidden_dim=32).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
mse = nn.MSELoss()
early_stopping = EarlyStopping(patience=10, delta=0.001) # delta = minimum improvement that counts
# Early stopping is triggered after 10 consecutive epochs (patience) in which
# the validation loss did not improve by at least 0.001 (delta) over the best
# validation loss recorded so far.

# Precompute noise schedule
betas = torch.linspace(beta_start, beta_end, timesteps, device=device)  # beta_t, linear in t
alphas = 1.0 - betas                                                    # alpha_t
alphas_cumprod = torch.cumprod(alphas, dim=0)                           # alpha_bar_t = prod of alpha_s for s <= t
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)                        # scales the clean image in the forward process
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)        # scales the noise in the forward process


# Print model size
num_params = sum(p.numel() for p in model.parameters())
print(f"Model: SimpleUNet | Parameters: {num_params:,}\n")

In [None]:
# ----------------------------
# Training Loop
# Each epoch: one pass over trainloader with noise-prediction MSE loss, then
# one pass over valloader without gradients, then an early-stopping check.
# ----------------------------
print("Starting training...\n")
total_start_time = time.time()

for epoch in range(epochs):
    epoch_start_time = time.time()
    model.train()
    total_train_loss = 0.0

    for data, _ in trainloader:  # labels are not used
        imgs = data.to(device)
        noise = torch.randn_like(imgs)
        # One uniformly random timestep in [0, timesteps) per image
        t = torch.randint(0, timesteps, (imgs.size(0),), device=device)

        # Add noise: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
        # (indexing with [t, None, None, None] yields shape [B, 1, 1, 1] for broadcasting)
        noised_imgs = (
            sqrt_alphas_cumprod[t, None, None, None] * imgs
            + sqrt_one_minus_alphas_cumprod[t, None, None, None] * noise
        )

        optimizer.zero_grad() # Clear gradients left over from the previous step
        predicted_noise = model(noised_imgs, t) # U-Net predicts the noise that was added
        loss = mse(predicted_noise, noise)
        # Backprop + Update model params
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()

    # Mean of per-batch losses (each batch weighted equally)
    avg_train_loss = total_train_loss / len(trainloader)

    # Validation loop
    # NOTE(review): timesteps and noise are redrawn at random every epoch, so
    # the validation loss is itself stochastic from epoch to epoch.
    model.eval() # Set model to evaluation mode
    total_val_loss = 0.0
    with torch.no_grad(): # Disable gradient calculations
        for data, _ in valloader:
            imgs = data.to(device)
            noise = torch.randn_like(imgs)
            t = torch.randint(0, timesteps, (imgs.size(0),), device=device)

            noised_imgs = (
                sqrt_alphas_cumprod[t, None, None, None] * imgs
                + sqrt_one_minus_alphas_cumprod[t, None, None, None] * noise
            )
            predicted_noise = model(noised_imgs, t)
            val_loss = mse(predicted_noise, noise)
            total_val_loss += val_loss.item()

    avg_val_loss = total_val_loss / len(valloader)

    epoch_time = time.time() - epoch_start_time
    total_elapsed = time.time() - total_start_time

    print(f"Epoch {epoch + 1}/{epochs} | "
          f"Avg Train Loss: {avg_train_loss:.6f} | "
          f"Avg Val Loss: {avg_val_loss:.6f} | "
          f"Epoch Time: {epoch_time:.2f}s | "
          f"Total Time: {total_elapsed:.2f}s")

    early_stopping(avg_val_loss, model) # Update the early-stopping state with this epoch's validation loss
    if early_stopping.early_stop:
        print("Early stopping triggered!")
        break


# Load the state recorded by early_stopping back into the model before saving
early_stopping.load_best_model(model)

# ----------------------------
# Save Model
# ----------------------------

os.makedirs("models", exist_ok=True)
torch.save(model.state_dict(), "models/simple_diffusion_model.pth")
print("\nTraining finished!")
print("Model saved to models/simple_diffusion_model.pth")

In [None]:
# ----------------------------
# Sampling Generation
# Runs the reverse diffusion process from pure Gaussian noise down to t = 0
# and plots the resulting images.
# ----------------------------


model.eval()  # Set to evaluation mode
num_samples = 32
img_size = 32

# Start with pure noise
x_t = torch.randn(num_samples, 3, img_size, img_size, device=device)

with torch.no_grad():
    for t in reversed(range(timesteps)):
        t_batch = torch.full((num_samples,), t, device=device, dtype=torch.long)

        # Predict noise eps_theta(x_t, t)
        predicted_noise = model(x_t, t_batch)

        if t == 0:
            # Final step: no noise is added. Use the predicted clean image,
            # x_0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t),
            # clipped to the data range [-1, 1].
            x_0_pred = (x_t - sqrt_one_minus_alphas_cumprod[t] * predicted_noise) / sqrt_alphas_cumprod[t]
            x_t = torch.clamp(x_0_pred, -1, 1)
        else:
            alpha_t = alphas[t]
            alpha_bar_t = alphas_cumprod[t]
            alpha_bar_t_prev = alphas_cumprod[t - 1]
            beta_t = betas[t]

            # Reverse-step mean:
            # 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps)
            mean = x_t - ((beta_t * predicted_noise) / (sqrt_one_minus_alphas_cumprod[t]))
            mean = mean / torch.sqrt(alpha_t)

            # Reverse-step variance:
            # (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t
            variance = (1 - alpha_bar_t_prev) / (1 - alpha_bar_t)
            variance = variance * beta_t
            sigma_t = variance ** 0.5

            z = torch.randn_like(x_t)  # same shape & device as x_t
            x_t = mean + sigma_t * z


# Post-process: convert from [-1, 1] to [0, 1] for display
generated_images = (x_t.clamp(-1, 1) + 1) / 2.0

# Plot generated images
import matplotlib.pyplot as plt

# Grid with 8 columns and as many rows as num_samples needs (4 x 8 for 32)
n_cols = 8
n_rows = (num_samples + n_cols - 1) // n_cols
# squeeze=False always returns a 2-D array of axes, even for a single row
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 12), squeeze=False)
axes = axes.flatten()

for i, ax in enumerate(axes):
    if i < num_samples:
        # imshow expects channels last: [C, H, W] -> [H, W, C]
        ax.imshow(generated_images[i].cpu().permute(1, 2, 0))
    ax.axis('off')  # also hides any unused cells in the last row

plt.tight_layout()
plt.show()

print(f"Generated {num_samples} new CIFAR-10-like images!")

In [None]:
# ----------------------
# Real vs Generated samples comparison
# Draws 32 (real, generated) pairs on an 8 x 8 grid: even cells hold real
# training images, the odd cell to the right holds a generated image.
# ----------------------


# Take the first 32 images of one training batch as the real reference
real_batch = next(iter(trainloader))[0][:32].cpu()
# Undo Normalize((0.5,), (0.5,)): [-1, 1] -> [0, 1]
real_batch = (real_batch + 1) / 2.0

# Plot real vs generated side-by-side
fig, axes = plt.subplots(8, 8, figsize=(12, 12))
axes = axes.flatten()

for i in range(32):
    pair = (real_batch[i], generated_images[i].cpu())  # (real, generated)
    for offset, img in enumerate(pair):
        ax = axes[i * 2 + offset]
        ax.imshow(img.permute(1, 2, 0))  # [C, H, W] -> [H, W, C] for matplotlib
        ax.axis('off')

plt.tight_layout()
plt.show()