In [None]:
import math
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms, utils
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

### Baseline: A simple diffusion model on MNIST

In [None]:
# This UNET-style prediction model was originally included as part of the Score-based generative modelling tutorial
# by Yang Song et al: https://colab.research.google.com/drive/120kYYBOVa1i0TD85RjlEkFjaWDxSFUx3?usp=sharing


class GaussianFourierProjection(nn.Module):
    """Embed scalar times with fixed random Fourier features.

    Each time value is multiplied by ``embed_dim // 2`` random frequencies
    (drawn once from a normal distribution scaled by ``scale``), and the sines and
    cosines of those products are concatenated, giving ``2 * (embed_dim // 2)``
    output features.
    """

    def __init__(self, embed_dim, scale=30.0):
        super().__init__()
        # The frequencies are sampled once and frozen (requires_grad=False), so the
        # optimizer never changes them; wrapping them in a Parameter still keeps them
        # in the state_dict and moves them along with .to(device).
        frequencies = torch.randn(embed_dim // 2) * scale
        self.W = nn.Parameter(frequencies, requires_grad=False)

    def forward(self, x):
        # Expects a 1-D tensor of times (B,); returns (B, 2 * (embed_dim // 2)).
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        return torch.cat((x_proj.sin(), x_proj.cos()), dim=-1)


class Dense(nn.Module):
    """Linear layer whose output is shaped as a 1x1 feature map.

    Maps a (..., input_dim) tensor to (..., output_dim, 1, 1), so the result can be
    broadcast-added onto convolutional feature maps of shape (B, output_dim, H, W).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        projected = self.dense(x)
        # Append two singleton axes for the spatial (H, W) dimensions
        return projected.unsqueeze(-1).unsqueeze(-1)


class ScoreNet(nn.Module):
    """A time-dependent score-based model built upon U-Net architecture.

    Takes images ``x`` of shape (B, 1, H, W) and a 1-D tensor of times ``t`` of
    length B (the time embedding indexes ``t[:, None]``). Returns a tensor shaped
    like ``x``, divided by ``marginal_prob_std(t)``.

    With the kernel sizes, strides and paddings below, the spatial sizes line up
    for 28x28 inputs:
      - encoder: 28 -> 26 -> 12 -> 5 -> 2
      - decoder: 2 -> 5 -> 12 -> 26 -> 28
    NOTE(review): other input sizes may not line up at the skip-connection
    concatenations; verify before reuse.

    In this notebook, callers pass a ``marginal_prob_std`` that returns ones, so the
    output is used unscaled (as a noise or velocity prediction).
    """

    def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256):
        """Initialize a time-dependent score-based network.

        Args:
          marginal_prob_std: A function that takes time t and gives the standard
            deviation of the perturbation kernel p_{0t}(x(t) | x(0)). ``forward``
            indexes its result with ``[:, None, None, None]``, so it must return a
            1-D tensor (length B, or length 1 to broadcast).
          channels: The number of channels for feature maps of each resolution.
            The list default is only read here, never mutated. Every entry is used
            with ``GroupNorm(32, ...)`` somewhere below, so each must be divisible
            by 32.
          embed_dim: The dimensionality of Gaussian random feature embeddings.

        Layers are created in a fixed order. Under a fixed seed, this order
        determines the initial weights, and the attribute names determine the
        state_dict keys.
        """
        super().__init__()
        # Gaussian random feature embedding layer for time
        self.embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim),
        )
        # Encoding layers where the resolution decreases
        # (no padding: conv1 shrinks 28 -> 26; each stride-2 conv roughly halves)
        self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
        self.dense1 = Dense(embed_dim, channels[0])
        self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])
        self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
        self.dense2 = Dense(embed_dim, channels[1])
        self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])
        self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
        self.dense3 = Dense(embed_dim, channels[2])
        self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
        self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
        self.dense4 = Dense(embed_dim, channels[3])
        self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])

        # Decoding layers where the resolution increases. Input channels are doubled
        # where encoder features are concatenated in (skip connections); the
        # output_padding values restore the exact encoder sizes (5 -> 12, 12 -> 26).
        self.tconv4 = nn.ConvTranspose2d(
            channels[3], channels[2], 3, stride=2, bias=False
        )
        self.dense5 = Dense(embed_dim, channels[2])
        self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])
        self.tconv3 = nn.ConvTranspose2d(
            channels[2] + channels[2],
            channels[1],
            3,
            stride=2,
            bias=False,
            output_padding=1,
        )
        self.dense6 = Dense(embed_dim, channels[1])
        self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])
        self.tconv2 = nn.ConvTranspose2d(
            channels[1] + channels[1],
            channels[0],
            3,
            stride=2,
            bias=False,
            output_padding=1,
        )
        self.dense7 = Dense(embed_dim, channels[0])
        self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
        # Final layer keeps its bias and maps back to a single image channel (26 -> 28)
        self.tconv1 = nn.ConvTranspose2d(channels[0] + channels[0], 1, 3, stride=1)

        # The swish activation function
        self.act = lambda x: x * torch.sigmoid(x)
        self.marginal_prob_std = marginal_prob_std

    def forward(self, x, t):
        # Obtain the Gaussian random feature embedding for t
        embed = self.act(self.embed(t))
        # Encoding path
        h1 = self.conv1(x)
        ## Incorporate information from t (Dense output broadcasts over H, W)
        h1 += self.dense1(embed)
        ## Group normalization
        h1 = self.gnorm1(h1)
        h1 = self.act(h1)
        h2 = self.conv2(h1)
        h2 += self.dense2(embed)
        h2 = self.gnorm2(h2)
        h2 = self.act(h2)
        h3 = self.conv3(h2)
        h3 += self.dense3(embed)
        h3 = self.gnorm3(h3)
        h3 = self.act(h3)
        h4 = self.conv4(h3)
        h4 += self.dense4(embed)
        h4 = self.gnorm4(h4)
        h4 = self.act(h4)

        # Decoding path
        h = self.tconv4(h4)
        ## Incorporate information from t
        h += self.dense5(embed)
        h = self.tgnorm4(h)
        h = self.act(h)
        ## Skip connection from the encoding path: concatenate encoder features of
        ## the same resolution along the channel axis
        h = self.tconv3(torch.cat([h, h3], dim=1))
        h += self.dense6(embed)
        h = self.tgnorm3(h)
        h = self.act(h)
        h = self.tconv2(torch.cat([h, h2], dim=1))
        h += self.dense7(embed)
        h = self.tgnorm2(h)
        h = self.act(h)
        h = self.tconv1(torch.cat([h, h1], dim=1))

        # Normalize output
        h = h / self.marginal_prob_std(t)[:, None, None, None]
        return h

In [None]:
# ExponentialMovingAverage implementation as used in pytorch vision
# https://github.com/pytorch/vision/blob/main/references/classification/utils.py#L159

# BSD 3-Clause License

# Copyright (c) Soumith Chintala 2016,
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.

# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.

# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
    """Exponential moving average (EMA) of a model's parameters and buffers.

    Each averaging step blends the running copy with the live weights as
    ``avg = decay * avg + (1 - decay) * current``. Because ``use_buffers=True``,
    buffers are averaged too. The averaged copy is exposed as ``self.module``.
    The update bookkeeping is delegated to
    `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_.
    """

    def __init__(self, model, decay, device="cpu"):
        # The third argument (number of models averaged so far) is not needed
        # for a fixed-decay EMA.
        blend = lambda averaged, current, _count: decay * averaged + (1 - decay) * current
        super().__init__(model, device=device, avg_fn=blend, use_buffers=True)

In [None]:
class DDPM(nn.Module):
    """Denoising Diffusion Probabilistic Model (Ho et al., 2020) on flattened 28x28 images.

    The wrapped network is trained to predict the Gaussian noise that was mixed
    into an image at a randomly chosen diffusion step t in {1, ..., T}.
    """

    def __init__(self, network, T=100, beta_1=1e-4, beta_T=2e-2):
        """
        Initialize Denoising Diffusion Probabilistic Model

        Parameters
        ----------
        network: nn.Module
            The inner neural network used by the diffusion process. Typically a Unet.
            It is called as ``network(x, t)`` with:
              - ``x``: images of shape (B, 1, 28, 28);
              - ``t``: a 1-D tensor of step indices divided by ``T``.
        T: int
            The number of diffusion steps.
        beta_1: float
            First value of the linear beta schedule (schedule index 0).
        beta_T: float
            Last value of the linear beta schedule (schedule index T, the final step).
        """

        super().__init__()

        # Keep the raw module as a registered submodule, so its parameters are
        # optimized and moved together with this model. Expose a wrapper that
        # accepts flattened images and (B, 1) integer step tensors.
        self._network = network
        # Use reshape(-1), not squeeze(). squeeze() turns a (1, 1) step tensor into
        # a 0-d scalar, which time embeddings that index t[:, None] (such as
        # GaussianFourierProjection) reject when the batch has a single item.
        # Dividing by T maps step indices into (0, 1].
        self.network = lambda x, t: (
            self._network(x.reshape(-1, 1, 28, 28), (t.reshape(-1) / T))
        ).reshape(-1, 28 * 28)

        # Total number of time steps
        self.T = T

        # Registering as buffers to ensure they get transferred to the GPU automatically.
        # The schedule has T + 1 entries so it can be indexed directly by t in 1..T.
        self.register_buffer("beta", torch.linspace(beta_1, beta_T, T + 1))
        self.register_buffer("alpha", 1 - self.beta)
        self.register_buffer("alpha_bar", self.alpha.cumprod(dim=0))

    def forward_diffusion(self, x0, t, epsilon):
        """
        q(x_t | x_0)
        Forward diffusion from an input datapoint x0 to an xt at timestep t, provided a N(0,1) noise sample epsilon.
        Note that we can do this operation in a single step

        Parameters
        ----------
        x0: torch.tensor
            x value at t=0 (an input image), shape (B, 784)
        t: torch.tensor
            integer step indices, shape (B, 1) so they broadcast over pixels
        epsilon:
            noise sample, same shape as x0

        Returns
        -------
        torch.tensor
            image at timestep t
        """

        mean = torch.sqrt(self.alpha_bar[t]) * x0
        std = torch.sqrt(1 - self.alpha_bar[t])

        return mean + std * epsilon

    def reverse_diffusion(self, xt, t, epsilon):
        """
        p(x_{t-1} | x_t)
        Single step in the reverse direction, from x_t (at timestep t) to x_{t-1}, provided a N(0,1) noise sample epsilon.

        Parameters
        ----------
        xt: torch.tensor
            x value at step t, shape (B, 784)
        t: torch.tensor
            integer step indices, shape (B, 1)
        epsilon:
            noise sample (0 on the final step)

        Returns
        -------
        torch.tensor
            image at timestep t-1
        """

        mean = (
            1.0
            / torch.sqrt(self.alpha[t])
            * (
                xt
                - (self.beta[t])
                / torch.sqrt(1 - self.alpha_bar[t])
                * self.network(xt, t)
            )
        )
        # Posterior standard deviation (beta-tilde); no noise at t == 0
        std = torch.where(
            t > 0,
            torch.sqrt(
                ((1 - self.alpha_bar[t - 1]) / (1 - self.alpha_bar[t])) * self.beta[t]
            ),
            0,
        )

        return mean + std * epsilon

    @torch.no_grad()
    def sample(self, shape):
        """
        Sample from diffusion model (Algorithm 2 in Ho et al, 2020)

        Parameters
        ----------
        shape: tuple
            Specify shape of sampled output. For MNIST: (nsamples, 28*28)

        Returns
        -------
        torch.tensor
            sampled image
        """

        # Sample xT: Gaussian noise
        xT = torch.randn(shape).to(self.beta.device)

        xt = xT
        for t in range(self.T, 0, -1):
            # No noise is added on the last step (t == 1 -> x_0)
            noise = torch.randn_like(xT) if t > 1 else 0
            t = torch.tensor(t).expand(xt.shape[0], 1).to(self.beta.device)
            xt = self.reverse_diffusion(xt, t, noise)

        return xt

    def elbo_simple(self, x0):
        """
        ELBO training objective (Algorithm 1 in Ho et al, 2020)

        Parameters
        ----------
        x0: torch.tensor
            Input image batch, shape (B, 784)

        Returns
        -------
        torch.tensor
            Scalar tensor: the negative mean squared error between the true and
            predicted noise (the simplified ELBO, up to constants).
        """

        # Sample time step t uniformly from {1, ..., T}. randint's upper bound is
        # exclusive, so it must be T + 1; otherwise step T, where sampling starts,
        # would never be trained.
        t = torch.randint(1, self.T + 1, (x0.shape[0], 1)).to(x0.device)

        # Sample noise
        epsilon = torch.randn_like(x0)

        # Forward diffusion to produce image at step t
        xt = self.forward_diffusion(x0, t, epsilon)

        return -nn.MSELoss(reduction="mean")(epsilon, self.network(xt, t))

    def loss(self, x0):
        """
        Loss function. Just the negative of the ELBO.
        """
        return -self.elbo_simple(x0).mean()


def train(
    model,
    optimizer,
    scheduler,
    dataloader,
    epochs,
    device,
    ema=True,
    per_epoch_callback=None,
):
    """
    Training loop

    Parameters
    ----------
    model: nn.Module
        Pytorch model exposing ``loss(x)``, which returns a scalar tensor
    optimizer: optim.Optimizer
        Pytorch optimizer to be used for training
    scheduler: optim.LRScheduler
        Pytorch learning rate scheduler (stepped once per batch)
    dataloader: utils.DataLoader
        Pytorch dataloader yielding ``(x, label)`` pairs; labels are ignored
    epochs: int
        Number of epochs to train
    device: torch.device
        Pytorch device specification
    ema: Boolean
        Whether to activate Exponential Model Averaging
    per_epoch_callback: function
        Called at the end of every epoch with the EMA model (if ``ema``) or ``model``

    Returns
    -------
    list of float
        The training loss of every batch, in order.
    """

    total_steps = len(dataloader) * epochs
    loss_history = []

    if ema:
        # EMA setup follows the torchvision classification reference:
        #   alpha = 1 - decay,
        #   alpha is scaled by the number of samples seen between EMA updates
        #   (relative to the epoch count), then clipped to 1.
        # alpha is derived from the target decay exactly once. Inverting it twice
        # makes alpha clip to 1 (decay 0), turning the "EMA" into a plain periodic
        # copy of the weights.
        ema_steps = 10
        ema_adjust = dataloader.batch_size * ema_steps / epochs
        ema_decay = 0.995
        ema_alpha = min(1.0, (1.0 - ema_decay) * ema_adjust)
        ema_model = ExponentialMovingAverage(
            model, device=device, decay=1.0 - ema_alpha
        )
        ema_global_step_counter = 0

    # The context manager closes the progress bar even if training is interrupted.
    with tqdm(range(total_steps), desc="Training") as progress_bar:
        for epoch in range(epochs):

            # Switch to train mode (a callback may have put the model in eval mode)
            model.train()

            for x, _ in dataloader:
                x = x.to(device)
                optimizer.zero_grad()
                loss = model.loss(x)
                loss.backward()
                optimizer.step()
                scheduler.step()

                # Read the scalar once per step (each .item() is a device sync)
                loss_value = loss.item()
                loss_history.append(loss_value)

                # Update progress bar
                progress_bar.set_postfix(
                    loss=f"⠀{loss_value:12.4f}",
                    epoch=f"{epoch+1}/{epochs}",
                    lr=f"{scheduler.get_last_lr()[0]:.2E}",
                )
                progress_bar.update()

                if ema:
                    ema_global_step_counter += 1
                    if ema_global_step_counter % ema_steps == 0:
                        ema_model.update_parameters(model)

            if per_epoch_callback:
                per_epoch_callback(ema_model.module if ema else model)

    return loss_history


# Parameters
T = 1000  # number of diffusion steps
learning_rate = 1e-3
epochs = 100
batch_size = 256


# Rather than treating MNIST images as discrete objects, as done in Ho et al 2020,
# we here treat them as continuous input data, by dequantizing the pixel values (adding noise to the input data).
# ToTensor maps the 0..255 pixel values to [0, 1]; we then map them to [-1, 1].
# We process the 28x28 pixel values as a flattened 784 tensor.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(
            lambda x: x + torch.rand(x.shape) / 255
        ),  # Dequantize: add up to one grey level of uniform noise (values in [0, 1 + 1/255))
        transforms.Lambda(lambda x: (x - 0.5) * 2.0),  # Map from [0, 1] -> [-1, 1] (plus the small dequantization overshoot)
        transforms.Lambda(lambda x: x.flatten()),
    ]
)

# Download and transform train dataset
dataloader_train = torch.utils.data.DataLoader(
    datasets.MNIST("./mnist_data", download=True, train=True, transform=transform),
    batch_size=batch_size,
    shuffle=True,
)

# Select device (same rule as the setup cell at the top; repeated so this cell is self-contained)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Construct Unet
# The original ScoreNet expects a function with std for all the
# different noise levels, such that the output can be rescaled.
# Since we are predicting the noise (rather than the score), we
# ignore this rescaling and just set std=1 for all t.
mnist_unet = ScoreNet((lambda t: torch.ones(1).to(device)))

# Construct model
model = DDPM(mnist_unet, T=T).to(device)

# Construct optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Setup simple scheduler (learning rate decays by 0.9999 per batch)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)


def reporter(model):
    """Per-epoch training callback: draw 10 samples and show them in a single row."""

    n_images = 10

    # Sampling only, so put the network in inference mode
    model.eval()

    with torch.no_grad():
        flat = model.sample((n_images, 28 * 28)).cpu()
        # Undo the [-1, 1] data scaling, then clip to the displayable [0, 1] range
        pixels = ((flat + 1) / 2).clamp(0.0, 1.0)
        images = pixels.reshape(-1, 1, 28, 28)

        grid = utils.make_grid(images, nrow=n_images)
        plt.gca().set_axis_off()
        plt.imshow(transforms.functional.to_pil_image(grid), cmap="gray")
        plt.show()


# Call training loop. Bind the returned per-batch loss history to a name, so the
# notebook does not echo the whole list (one float per batch) as cell output.
baseline_loss_history = train(
    model,
    optimizer,
    scheduler,
    dataloader_train,
    epochs=epochs,
    device=device,
    ema=True,
    per_epoch_callback=reporter,
)

### Direction 1: Converting the MNIST DDPM template to Flow Matching

#### Implementation

In [None]:
class FlowMatchingModel(nn.Module):
    """
    Flow Matching on flattened 28x28 images, using the conditional optimal-transport
    (straight-line) path between Gaussian noise x0 and data x1:

        x_t = (1 - t) * x0 + t * x1,    target velocity u_t = x1 - x0.
    """
    def __init__(self, network):
        super().__init__()
        # Time-conditioned image network, called as network(x, t) with
        # x of shape (B, 1, 28, 28) and t of shape (B,).
        self.network = network

    def velocity(self, x_flat, t):
        """Predict the velocity at flattened images ``x_flat`` (B, 784) and times ``t``.

        ``t`` may be shaped (B,) or (B, 1); the network receives a 1-D time vector.
        """
        times = t.squeeze(1) if t.ndim == 2 else t
        images = x_flat.reshape(-1, 1, 28, 28)
        return self.network(images, times).reshape(-1, 28*28)

    def loss(self, x1_flat):
        """Mean squared error between the predicted and the straight-line target velocity."""
        n = x1_flat.shape[0]
        # Draw the noise first and the times second; this order matters for seeded runs.
        noise = torch.randn_like(x1_flat)
        t = torch.rand(n, 1, device=x1_flat.device)
        x_t = t * x1_flat + (1.0 - t) * noise
        target = x1_flat - noise
        prediction = self.velocity(x_t, t)
        return F.mse_loss(prediction, target, reduction="mean")

    @torch.no_grad()
    def sample(self, nsamples, n_steps=50):
        """Generate ``nsamples`` flattened images by forward-Euler integration of the
        learned ODE from t=0 (noise) to t=1 (data) over ``n_steps`` uniform steps."""
        dev = next(self.parameters()).device
        x = torch.randn(nsamples, 28*28, device=dev)

        grid = torch.linspace(0.0, 1.0, n_steps + 1, device=dev)
        for t_now, t_next in zip(grid[:-1], grid[1:]):
            step = (t_next - t_now).item()
            x = x + step * self.velocity(x, t_now.expand(nsamples, 1))

        return x

# Flow Matching counterpart of the baseline setup. It uses the same ScoreNet backbone
# (output rescaling disabled by a constant std of 1), Adam, and a per-batch
# exponential LR decay.
fm_unet = ScoreNet(lambda t: torch.ones(1).to(device)).to(device)
fm_model = FlowMatchingModel(fm_unet).to(device)
fm_optimizer = torch.optim.Adam(params=fm_model.parameters(), lr=learning_rate)
fm_scheduler = torch.optim.lr_scheduler.ExponentialLR(fm_optimizer, gamma=0.9999)

def train_flow_matching(model, optimizer, scheduler, dataloader, epochs, device, per_epoch_callback=None):
    """Train a Flow Matching model by minimizing ``model.loss`` over ``dataloader`` (no EMA).

    Parameters
    ----------
    model: nn.Module
        Model exposing ``loss(x1)``, which returns a scalar tensor.
    optimizer: optim.Optimizer
        Optimizer used for training.
    scheduler: optim.LRScheduler
        Learning-rate scheduler, stepped once per batch.
    dataloader: utils.DataLoader
        Yields ``(x1, label)`` pairs of data samples; labels are ignored.
    epochs: int
        Number of epochs to train.
    device: torch.device
        Device the batches are moved to.
    per_epoch_callback: function, optional
        Called with ``model`` at the end of every epoch.

    Returns
    -------
    list of float
        The training loss of every batch, in order.
    """
    total_steps = len(dataloader) * epochs
    loss_history = []

    # The context manager closes the progress bar even if training is interrupted.
    with tqdm(range(total_steps), desc="Training (Flow Matching)") as progress_bar:
        for epoch in range(epochs):
            # A callback may have switched the model to eval mode
            model.train()
            for x1, _ in dataloader:
                x1 = x1.to(device)
                optimizer.zero_grad()
                loss = model.loss(x1)
                loss.backward()
                optimizer.step()
                scheduler.step()

                # Read the scalar once per step (each .item() is a device sync)
                loss_value = loss.item()
                loss_history.append(loss_value)

                progress_bar.set_postfix(loss=f"⠀{loss_value:12.4f}",
                                         epoch=f"{epoch+1}/{epochs}",
                                         lr=f"{scheduler.get_last_lr()[0]:.2E}")
                progress_bar.update()

            if per_epoch_callback:
                per_epoch_callback(model)

    return loss_history

def reporter_flow_matching(model):
    """Per-epoch callback: draw 10 Flow Matching samples (50 Euler steps) and show them in one row."""
    n_images = 10
    model.eval()

    with torch.no_grad():
        pixels = model.sample(nsamples=n_images, n_steps=50).cpu()
        # Undo the [-1, 1] data scaling, then clip to the displayable [0, 1] range
        pixels = ((pixels + 1) / 2).clamp(0.0, 1.0)
        images = pixels.reshape(-1, 1, 28, 28)

        grid = utils.make_grid(images, nrow=n_images)
        plt.gca().set_axis_off()
        plt.imshow(transforms.functional.to_pil_image(grid), cmap="gray")
        plt.show()

#### Comparison with DDPM

In [None]:
# Head-to-head run: train a DDPM and a Flow Matching model with the same backbone,
# optimizer, LR schedule, data pipeline and epoch budget, and keep both loss curves.
# The single seed below makes the whole cell reproducible. However, the FM model is
# built after DDPM training has consumed random numbers, so the two ScoreNets do
# not start from identical weights.
torch.manual_seed(0)
epochs_full = epochs
batch_size_full = batch_size

dataloader_full = torch.utils.data.DataLoader(
    datasets.MNIST("./mnist_data", download=True, train=True, transform=transform),
    batch_size=batch_size_full,
    shuffle=True,
)

# DDPM training
# NOTE: `train` only hands the EMA weights to the per-epoch callback (disabled
# here). ddpm_full itself keeps the raw, non-averaged weights used for sampling below.
ddpm_full = DDPM(ScoreNet(lambda t: torch.ones(1).to(device)), T=T).to(device)
opt_d_full = torch.optim.Adam(ddpm_full.parameters(), lr=learning_rate)
sch_d_full = torch.optim.lr_scheduler.ExponentialLR(opt_d_full, 0.9999)

loss_ddpm_full = train(
    ddpm_full,
    opt_d_full,
    sch_d_full,
    dataloader_full,
    epochs=epochs_full,
    device=device,
    ema=True,
    per_epoch_callback=None,
)

# Flow Matching training
fm_full = FlowMatchingModel(ScoreNet(lambda t: torch.ones(1).to(device)).to(device)).to(device)
opt_f_full = torch.optim.Adam(fm_full.parameters(), lr=learning_rate)
sch_f_full = torch.optim.lr_scheduler.ExponentialLR(opt_f_full, 0.9999)

loss_fm_full = train_flow_matching(
    fm_full,
    opt_f_full,
    sch_f_full,
    dataloader_full,
    epochs=epochs_full,
    device=device,
    per_epoch_callback=None,
)

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(14, 4))

# Panel 1: training losses, each smoothed with a 100-step moving average.
# The two objectives differ (noise MSE vs. velocity MSE), so compare trends, not levels.
window = 100
kernel = np.ones(window) / window
ddpm_smooth = np.convolve(loss_ddpm_full, kernel, mode='valid')
fm_smooth = np.convolve(loss_fm_full, kernel, mode='valid')

for curve, label in ((ddpm_smooth, "DDPM"), (fm_smooth, "Flow Matching")):
    axes[0].plot(curve, label=label, linewidth=2, alpha=0.8)
axes[0].set_title("Training Loss Comparison", fontsize=11)
axes[0].set_xlabel("Training Step", fontsize=10)
axes[0].set_ylabel("Loss", fontsize=10)
axes[0].legend(fontsize=9)
axes[0].grid(True, alpha=0.3)

# Panels 2 and 3: 4x4 sample grids, pixels mapped from [-1, 1] back to [0, 1].
# DDPM is sampled before Flow Matching, so the random draws match the seeded order.
ddpm_full.eval()
fm_full.eval()
with torch.no_grad():
    ddpm_samples = ((ddpm_full.sample((16, 28 * 28)).cpu() + 1) / 2).clamp(0, 1)
    ddpm_grid_plot2 = utils.make_grid(ddpm_samples.reshape(-1, 1, 28, 28), nrow=4, padding=1, pad_value=1.0)

    fm_samples = ((fm_full.sample(nsamples=16, n_steps=100).cpu() + 1) / 2).clamp(0, 1)
    fm_grid_plot3 = utils.make_grid(fm_samples.reshape(-1, 1, 28, 28), nrow=4, padding=1, pad_value=1.0)

for ax, grid_img, title in (
    (axes[1], ddpm_grid_plot2, "DDPM Samples"),
    (axes[2], fm_grid_plot3, "Flow Matching Samples"),
):
    ax.imshow(transforms.functional.to_pil_image(grid_img), cmap="gray")
    ax.axis("off")
    ax.set_title(title, fontsize=11)

plt.tight_layout()
plt.show()

#### Failure Mode Analysis

In [None]:
# 1. Too few ODE steps
#
# Re-sample the fully trained Flow Matching model with increasingly fine Euler
# discretizations of the same ODE (seeded for reproducibility).

torch.manual_seed(42)

step_counts = [1, 5, 10, 20, 200, 1000]
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

fm_full.eval()
with torch.no_grad():
    for ax, n_steps in zip(axes, step_counts):
        samples = fm_full.sample(nsamples=16, n_steps=n_steps).cpu()
        samples = ((samples + 1) / 2).clamp(0, 1)
        grid = utils.make_grid(samples.reshape(-1, 1, 28, 28), nrow=4)

        ax.imshow(transforms.functional.to_pil_image(grid), cmap="gray")
        ax.axis("off")
        ax.set_title(f"FM with {n_steps} ODE steps")

plt.suptitle("Failure Mode: Effect of ODE Integration Steps on Sample Quality", fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# 2. Inadequate training (early stopping)
#
# Train a fresh Flow Matching model for each epoch budget. The RNG is re-seeded at
# the start of every run, *before* the model is constructed, so all runs start from
# the same initial weights and draw the same data order; they differ only in how
# long they train. Note: this trains 1+5+10+20+50+100 = 186 epochs in total.

epoch_counts = [1, 5, 10, 20, 50, 100]
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for i, n_epochs in enumerate(epoch_counts):
    # Seed first: ScoreNet's weight initialization consumes random numbers
    torch.manual_seed(42)
    fm_test = FlowMatchingModel(ScoreNet(lambda t: torch.ones(1).to(device)).to(device)).to(device)
    opt_test = torch.optim.Adam(fm_test.parameters(), lr=learning_rate)
    sch_test = torch.optim.lr_scheduler.ExponentialLR(opt_test, 0.9999)

    dataloader_test = torch.utils.data.DataLoader(
        datasets.MNIST("./mnist_data", download=True, train=True, transform=transform),
        batch_size=batch_size,
        shuffle=True,
    )

    _ = train_flow_matching(
        fm_test,
        opt_test,
        sch_test,
        dataloader_test,
        epochs=n_epochs,
        device=device,
        per_epoch_callback=None,
    )

    fm_test.eval()
    with torch.no_grad():
        samples = fm_test.sample(nsamples=16, n_steps=50).cpu()
        samples = ((samples + 1) / 2).clamp(0, 1)
        grid = utils.make_grid(samples.reshape(-1, 1, 28, 28), nrow=4)

        axes[i].imshow(transforms.functional.to_pil_image(grid), cmap="gray")
        axes[i].axis("off")
        axes[i].set_title(f"FM trained for {n_epochs} epochs")

plt.suptitle("Failure Mode: Effect of Training Duration on Sample Quality", fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# 3. High learning rate instability
#
# Train a fresh Flow Matching model for each learning rate. The RNG is re-seeded at
# the start of every run, *before* the model is constructed, so all runs share the
# same initial weights and data order; only the learning rate differs.
# Note: this trains 6 x 30 = 180 epochs in total.

learning_rates = [1e-2, 5e-3, 1e-3, 5e-4, 1e-4, 1e-5]
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

training_epochs = 30

for i, lr in enumerate(learning_rates):
    # Seed first: ScoreNet's weight initialization consumes random numbers
    torch.manual_seed(42)
    fm_test = FlowMatchingModel(ScoreNet(lambda t: torch.ones(1).to(device)).to(device)).to(device)
    opt_test = torch.optim.Adam(fm_test.parameters(), lr=lr)
    sch_test = torch.optim.lr_scheduler.ExponentialLR(opt_test, 0.9999)

    dataloader_test = torch.utils.data.DataLoader(
        datasets.MNIST("./mnist_data", download=True, train=True, transform=transform),
        batch_size=batch_size,
        shuffle=True,
    )

    _ = train_flow_matching(
        fm_test,
        opt_test,
        sch_test,
        dataloader_test,
        epochs=training_epochs,
        device=device,
        per_epoch_callback=None,
    )

    fm_test.eval()
    with torch.no_grad():
        samples = fm_test.sample(nsamples=16, n_steps=50).cpu()
        samples = ((samples + 1) / 2).clamp(0, 1)
        grid = utils.make_grid(samples.reshape(-1, 1, 28, 28), nrow=4)

        axes[i].imshow(transforms.functional.to_pil_image(grid), cmap="gray")
        axes[i].axis("off")
        axes[i].set_title(f"FM with LR={lr:.0e}")

plt.suptitle("Failure Mode: Effect of Learning Rate on Sample Quality", fontsize=14)
plt.tight_layout()
plt.savefig("fm_failure_learning_rate.png", dpi=150, bbox_inches='tight')
plt.show()

### Direction 2: Latent Diffusion Model

#### Implementation

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder mapping (B, 1, 28, 28) images to (B, latent_dim) vectors.

    Three stride-2 convolution stages (32, 64, 128 channels; 28 -> 14 -> 7 -> 4 pixels),
    each followed by BatchNorm and LeakyReLU(0.2). The flattened 128x4x4 features
    are then linearly projected to the latent size.
    """
    def __init__(self, latent_dim=32):
        super().__init__()
        self.latent_dim = latent_dim

        layers = []
        in_channels = 1
        for out_channels in (32, 64, 128):
            layers.extend([
                nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2),
            ])
            in_channels = out_channels
        layers.extend([nn.Flatten(), nn.Linear(128 * 4 * 4, latent_dim)])

        # All layers live in one Sequential, so state_dict keys are conv_layers.<index>.*
        self.conv_layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv_layers(x)


class Decoder(nn.Module):
    """Convolutional decoder mapping (B, latent_dim) vectors to (B, 1, 28, 28) images in [-1, 1]."""
    def __init__(self, latent_dim=32):
        super().__init__()
        self.latent_dim = latent_dim

        # Project the latent code up to a 128-channel 4x4 feature map
        self.fc = nn.Linear(latent_dim, 128 * 4 * 4)

        # Upsample 4x4 -> 7x7 -> 14x14 -> 28x28; the final tanh keeps pixels in [-1, 1]
        self.deconv_layers = nn.Sequential(
            *self._up_block(128, 64, output_padding=0),
            *self._up_block(64, 32, output_padding=1),
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Tanh(),
        )

    @staticmethod
    def _up_block(in_channels, out_channels, output_padding):
        """Stride-2 transposed convolution followed by BatchNorm and LeakyReLU(0.2)."""
        return (
            nn.ConvTranspose2d(in_channels, out_channels, 3, stride=2, padding=1, output_padding=output_padding),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, z):
        return self.deconv_layers(self.fc(z).view(-1, 128, 4, 4))


class Autoencoder(nn.Module):
    """Deterministic convolutional autoencoder trained with a pixel-wise MSE reconstruction loss.

    Inputs may be flattened (B, 784) or image-shaped (B, 1, 28, 28); flattened
    batches are reshaped to images first. Decoded outputs are image-shaped.
    """
    def __init__(self, latent_dim=32):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    @staticmethod
    def _as_images(x):
        """Reshape a 2-D (B, 784) batch to (B, 1, 28, 28); leave other ranks untouched."""
        return x.reshape(-1, 1, 28, 28) if x.ndim == 2 else x

    def encode(self, x):
        """Encode image to latent space."""
        return self.encoder(self._as_images(x))

    def decode(self, z):
        """Decode latent vector to image."""
        return self.decoder(z)

    def forward(self, x):
        """Full forward pass: encode then decode. Returns (reconstruction, latent)."""
        latent = self.encode(self._as_images(x))
        return self.decode(latent), latent

    def loss(self, x):
        """Reconstruction loss (MSE)."""
        images = self._as_images(x)
        reconstruction, _ = self.forward(images)
        return F.mse_loss(reconstruction, images, reduction="mean")


def train_autoencoder(model, optimizer, scheduler, dataloader, epochs, device, per_epoch_callback=None):
    """Training loop for the autoencoder.

    Parameters
    ----------
    model: nn.Module
        Model exposing ``loss(x)``, which returns a scalar reconstruction loss.
    optimizer: optim.Optimizer
        Optimizer used for training.
    scheduler: optim.LRScheduler
        Learning-rate scheduler, stepped once per batch.
    dataloader: utils.DataLoader
        Yields ``(x, label)`` pairs; labels are ignored. Flattened (B, 784)
        batches are reshaped to (B, 1, 28, 28) before the loss is computed.
    epochs: int
        Number of epochs to train.
    device: torch.device
        Device the batches are moved to.
    per_epoch_callback: function, optional
        Called with ``model`` at the end of every epoch.

    Returns
    -------
    list of float
        The training loss of every batch, in order.
    """
    total_steps = len(dataloader) * epochs
    loss_history = []

    # The context manager closes the progress bar even if training is interrupted.
    with tqdm(range(total_steps), desc="Training Autoencoder") as progress_bar:
        for epoch in range(epochs):
            model.train()
            for x, _ in dataloader:
                x = x.to(device)
                if x.ndim == 2:
                    x = x.reshape(-1, 1, 28, 28)

                optimizer.zero_grad()
                loss = model.loss(x)
                loss.backward()
                optimizer.step()
                scheduler.step()

                # Read the scalar once per step (each .item() is a device sync)
                loss_value = loss.item()
                loss_history.append(loss_value)

                progress_bar.set_postfix(
                    loss=f"⠀{loss_value:12.6f}",
                    epoch=f"{epoch+1}/{epochs}",
                    lr=f"{scheduler.get_last_lr()[0]:.2E}"
                )
                progress_bar.update()

            if per_epoch_callback:
                per_epoch_callback(model)

    return loss_history

In [None]:
class LatentUNet(nn.Module):
    """
    Time-conditioned noise predictor for latent vectors.

    Despite its name this is a residual MLP, not a convolutional U-Net:
      - the latent is projected to ``hidden_dim``;
      - it passes through three residual MLP blocks, each adding the time embedding
        to its input and followed by LayerNorm;
      - the result is projected back to ``latent_dim``.
    ``t`` is fed to GaussianFourierProjection, so it is expected as a 1-D tensor.
    """
    def __init__(self, latent_dim=32, hidden_dim=256, embed_dim=128):
        super().__init__()
        self.latent_dim = latent_dim

        def mlp():
            # Two-layer SiLU MLP that keeps the hidden width
            return nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.SiLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )

        self.time_embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.input_proj = nn.Linear(latent_dim, hidden_dim)

        # Submodules are registered in the order block1, norm1, block2, norm2, ...
        self.block1, self.norm1 = mlp(), nn.LayerNorm(hidden_dim)
        self.block2, self.norm2 = mlp(), nn.LayerNorm(hidden_dim)
        self.block3, self.norm3 = mlp(), nn.LayerNorm(hidden_dim)

        self.output_proj = nn.Linear(hidden_dim, latent_dim)

    def forward(self, z, t):
        t_emb = self.time_embed(t)
        h = self.input_proj(z)
        stages = ((self.block1, self.norm1), (self.block2, self.norm2), (self.block3, self.norm3))
        for block, norm in stages:
            h = norm(h + block(h + t_emb))
        return self.output_proj(h)


class LatentDDPM(nn.Module):
    """
    DDPM operating in the latent space of a pre-trained autoencoder.

    The autoencoder is frozen (all its parameters get ``requires_grad = False``) and only
    ``network`` learns. It predicts the noise added to latents ``z0 = autoencoder.encode(x)``.
    Sampling runs the reverse chain in latent space and decodes the final latent.

    Args:
        autoencoder: Module exposing ``encode``, ``decode`` and a ``latent_dim`` attribute.
            Assumes ``encode`` maps (batch, 1, 28, 28) images to (batch, latent_dim)
            latents — TODO confirm against the Autoencoder definition.
        network: Noise predictor called as ``network(z_t, t / T)``.
        T: Number of diffusion steps.
        beta_1, beta_T: End points of the linear beta schedule.

    NOTE(review): latents are used unscaled. The DDPM schedule assumes data of roughly
    unit scale, so latents with a very different spread may need normalising.
    """
    def __init__(self, autoencoder, network, T=1000, beta_1=1e-4, beta_T=2e-2):
        super().__init__()

        self.autoencoder = autoencoder
        # Freeze the autoencoder: only the latent noise predictor is trained.
        for param in self.autoencoder.parameters():
            param.requires_grad = False

        self.network = network
        self.T = T

        # T + 1 entries so timesteps 1..T index directly. Entry 0 is only read as
        # alpha_bar[t - 1] at t = 1.
        self.register_buffer("beta", torch.linspace(beta_1, beta_T, T + 1))
        self.register_buffer("alpha", 1 - self.beta)
        self.register_buffer("alpha_bar", self.alpha.cumprod(dim=0))

    def _predict_noise(self, zt, t):
        """Run the noise predictor with t flattened to (batch,) and normalised by T."""
        # reshape(-1) instead of squeeze(): squeeze turns a (1, 1) t into a 0-d tensor,
        # which breaks the time embedding when sampling a single latent.
        return self.network(zt, t.reshape(-1).float() / self.T)

    def forward_diffusion(self, z0, t, epsilon):
        """Forward diffusion in latent space: z_t ~ q(z_t | z_0), using the given noise."""
        mean = torch.sqrt(self.alpha_bar[t]) * z0
        std = torch.sqrt(1 - self.alpha_bar[t])
        return mean + std * epsilon

    def reverse_diffusion(self, zt, t, epsilon):
        """Reverse diffusion step in latent space: z_{t-1} ~ p(z_{t-1} | z_t)."""
        pred_noise = self._predict_noise(zt, t)

        mean = 1.0 / torch.sqrt(self.alpha[t]) * (
            zt - self.beta[t] / torch.sqrt(1 - self.alpha_bar[t]) * pred_noise
        )
        std = torch.where(
            t > 0,
            torch.sqrt(((1 - self.alpha_bar[t - 1]) / (1 - self.alpha_bar[t])) * self.beta[t]),
            torch.zeros_like(t, dtype=torch.float32),
        )

        return mean + std * epsilon

    @torch.no_grad()
    def sample(self, nsamples):
        """
        Sample ``nsamples`` latents with the reverse chain and decode them.

        Returns:
            Decoded samples flattened to shape (nsamples, -1).
        """
        device = self.beta.device
        latent_dim = self.autoencoder.latent_dim

        zt = torch.randn(nsamples, latent_dim, device=device)

        for t in range(self.T, 0, -1):
            # No noise on the final step (t = 1).
            noise = torch.randn_like(zt) if t > 1 else torch.zeros_like(zt)
            t_tensor = torch.tensor(t, device=device).expand(nsamples, 1)
            zt = self.reverse_diffusion(zt, t_tensor, noise)

        self.autoencoder.eval()
        x_decoded = self.autoencoder.decode(zt)

        return x_decoded.reshape(nsamples, -1)

    def loss(self, x):
        """
        Simplified DDPM loss (noise-prediction MSE) in latent space.

        Args:
            x: Image batch; it is reshaped to (batch, 1, 28, 28) before encoding.

        Returns:
            Scalar mean-squared error between the true and predicted noise.
        """
        self.autoencoder.eval()
        with torch.no_grad():
            z0 = self.autoencoder.encode(x.reshape(-1, 1, 28, 28))

        # randint's upper bound is exclusive. Using T + 1 makes t uniform on {1, ..., T},
        # so the step t = T that sampling starts from is also trained on.
        t = torch.randint(1, self.T + 1, (z0.shape[0], 1), device=x.device)

        epsilon = torch.randn_like(z0)

        zt = self.forward_diffusion(z0, t, epsilon)

        epsilon_pred = self._predict_noise(zt, t)

        return F.mse_loss(epsilon_pred, epsilon, reduction="mean")


def train_latent_ddpm(model, optimizer, scheduler, dataloader, epochs, device, per_epoch_callback=None):
    """
    Train a LatentDDPM, taking one optimizer step and one scheduler step per batch.

    Args:
        model: Object exposing ``train()`` and ``loss(x)``, which returns a scalar tensor.
        optimizer: Optimizer over the trainable parameters.
        scheduler: LR scheduler, stepped after every optimizer step (per batch, not per epoch).
        dataloader: Yields (images, labels); the labels are ignored.
        epochs: Number of passes over ``dataloader``.
        device: Device the images are moved to.
        per_epoch_callback: Optional callable invoked with ``model`` after each epoch.

    Returns:
        List containing the loss of every training step.
    """
    total_steps = len(dataloader) * epochs
    loss_history = []

    # The context manager closes the progress bar even if training is interrupted.
    with tqdm(total=total_steps, desc="Training Latent DDPM") as progress_bar:
        for epoch in range(epochs):
            # Re-enter training mode every epoch: per_epoch_callback may switch to eval.
            model.train()
            for x, _ in dataloader:
                x = x.to(device)

                optimizer.zero_grad()
                loss = model.loss(x)
                loss.backward()
                optimizer.step()
                scheduler.step()

                # Read the scalar once: each .item() is a host/device sync.
                loss_value = loss.item()
                loss_history.append(loss_value)

                progress_bar.set_postfix(
                    loss=f"⠀{loss_value:12.6f}",
                    epoch=f"{epoch+1}/{epochs}",
                    lr=f"{scheduler.get_last_lr()[0]:.2E}"
                )
                progress_bar.update()

            if per_epoch_callback:
                per_epoch_callback(model)

    return loss_history

#### Comparison with DDPM

In [None]:
# Latent DDPM comparison, step 1: shared settings and data, plus the pixel-space DDPM
# baseline. `transform`, `DDPM`, `ScoreNet`, `train`, `T` and `learning_rate` are
# defined in earlier cells.

# Seed torch's global RNG so model initialisation and training randomness are repeatable.
torch.manual_seed(0)

# Shared parameters
comparison_epochs = 50  # epochs for both diffusion models
comparison_batch_size = 256
ae_latent_dim = 32  # latent size for the autoencoder / LatentUNet (next cell)
ae_epochs = 30  # autoencoder pre-training epochs (next cell)

comparison_dataloader = torch.utils.data.DataLoader(
    datasets.MNIST("./mnist_data", download=True, train=True, transform=transform),
    batch_size=comparison_batch_size,
    shuffle=True,
)

# Train Standard DDPM (pixel space)
# marginal_prob_std returns ones, so the network output is not rescaled by a noise std.
# Presumably the DDPM uses it as a direct epsilon prediction — TODO confirm against ScoreNet.

standard_ddpm = DDPM(
    ScoreNet(lambda t: torch.ones(1).to(device)), 
    T=T
).to(device)

# gamma = 0.9999 decay; presumably stepped once per batch, as in train_latent_ddpm.
standard_opt = torch.optim.Adam(standard_ddpm.parameters(), lr=learning_rate)
standard_sch = torch.optim.lr_scheduler.ExponentialLR(standard_opt, 0.9999)

# NOTE(review): the baseline is trained with EMA (handled inside `train`), while the
# latent model in the next cell is not, so the comparison is not strictly like-for-like.
standard_loss_history = train(
    standard_ddpm,
    standard_opt,
    standard_sch,
    comparison_dataloader,
    epochs=comparison_epochs,
    device=device,
    ema=True,
    per_epoch_callback=None,
)

In [None]:
# Train Latent DDPM (latent space)
# Stage 1: pre-train a fresh autoencoder (`Autoencoder` and `train_autoencoder` come from
# earlier cells) on the same loader as the pixel-space baseline.

fresh_autoencoder = Autoencoder(latent_dim=ae_latent_dim).to(device)
fresh_ae_opt = torch.optim.Adam(fresh_autoencoder.parameters(), lr=learning_rate)
fresh_ae_sch = torch.optim.lr_scheduler.ExponentialLR(fresh_ae_opt, 0.9999)

fresh_ae_loss = train_autoencoder(
    fresh_autoencoder,
    fresh_ae_opt,
    fresh_ae_sch,
    comparison_dataloader,
    epochs=ae_epochs,
    device=device,
    per_epoch_callback=None,
)

# Then train latent DDPM
# Stage 2: LatentDDPM freezes the autoencoder, so the optimiser is given only the
# LatentUNet's parameters.
fresh_latent_unet = LatentUNet(latent_dim=ae_latent_dim, hidden_dim=256, embed_dim=128).to(device)
fresh_latent_ddpm = LatentDDPM(fresh_autoencoder, fresh_latent_unet, T=T).to(device)

fresh_latent_opt = torch.optim.Adam(fresh_latent_unet.parameters(), lr=learning_rate)
fresh_latent_sch = torch.optim.lr_scheduler.ExponentialLR(fresh_latent_opt, 0.9999)

fresh_latent_loss = train_latent_ddpm(
    fresh_latent_ddpm,
    fresh_latent_opt,
    fresh_latent_sch,
    comparison_dataloader,
    epochs=comparison_epochs,
    device=device,
    per_epoch_callback=None,
)

In [None]:
# Compare the two models trained above: smoothed loss curves and 16 samples from each.
fig, axes = plt.subplots(1, 3, figsize=(14, 4))

# Plot 1: Training loss comparison with smoothing
# 100-step moving average; mode='valid' drops the first window - 1 steps of each curve.
# Both losses are noise-prediction MSEs, but the latent one is measured in latent space,
# so absolute levels are not directly comparable.
window = 100
standard_smooth = np.convolve(standard_loss_history, np.ones(window)/window, mode='valid')
latent_smooth = np.convolve(fresh_latent_loss, np.ones(window)/window, mode='valid')

axes[0].plot(standard_smooth, label="Standard DDPM", linewidth=2, alpha=0.8)
axes[0].plot(latent_smooth, label="Latent DDPM", linewidth=2, alpha=0.8)
axes[0].set_title("Training Loss Comparison", fontsize=11)
axes[0].set_xlabel("Training Step", fontsize=10)
axes[0].set_ylabel("Loss", fontsize=10)
axes[0].legend(fontsize=9)
axes[0].grid(True, alpha=0.3)

# Plot 2: Standard DDPM samples
# (x + 1) / 2 maps [-1, 1] to [0, 1] for display. This assumes outputs are in [-1, 1];
# for the latent model that depends on the autoencoder's decoder — TODO confirm.
standard_ddpm.eval()
with torch.no_grad():
    standard_samples = standard_ddpm.sample((16, 28 * 28)).cpu()
    standard_samples = ((standard_samples + 1) / 2).clamp(0, 1)
    standard_grid = utils.make_grid(standard_samples.reshape(-1, 1, 28, 28), nrow=4, padding=1, pad_value=1.0)

axes[1].imshow(transforms.functional.to_pil_image(standard_grid), cmap="gray")
axes[1].axis("off")
axes[1].set_title("Standard DDPM Samples", fontsize=11)

# Plot 3: Latent DDPM samples
# LatentDDPM.sample already decodes to image space and flattens to (16, -1).
fresh_latent_ddpm.eval()
with torch.no_grad():
    latent_samples = fresh_latent_ddpm.sample(16).cpu()
    latent_samples = ((latent_samples + 1) / 2).clamp(0, 1)
    latent_grid = utils.make_grid(latent_samples.reshape(-1, 1, 28, 28), nrow=4, padding=1, pad_value=1.0)

axes[2].imshow(transforms.functional.to_pil_image(latent_grid), cmap="gray")
axes[2].axis("off")
axes[2].set_title("Latent DDPM Samples", fontsize=11)

plt.tight_layout()
plt.show()

#### Failure Mode Analysis

In [None]:
# 1. Too few diffusion steps during sampling
#
# Sampling with fewer steps needs a *respaced* reverse process. A jump from t to an
# earlier t_prev must use the cumulative ratio alpha_bar[t] / alpha_bar[t_prev], not the
# single-step alpha[t] / beta[t]. Calling `reverse_diffusion` (a one-step t -> t-1 update)
# at strided timesteps removes only one step's worth of noise per jump and never visits
# small t. That measures a broken sampler rather than the cost of using fewer steps.

torch.manual_seed(42)

step_counts = [10, 50, 100, 250, 500, 1000]


@torch.no_grad()
def sample_latent_strided(ddpm, n_steps, nsamples=16):
    """
    Sample from a trained LatentDDPM using ``n_steps`` evenly spaced reverse steps.

    Each jump t -> t_prev uses the DDPM posterior with a_t = alpha_bar[t] / alpha_bar[t_prev]
    in place of alpha[t]. When ``n_steps == ddpm.T``, every t_prev is t - 1 and this
    reduces to the update in ``LatentDDPM.reverse_diffusion``.

    Args:
        ddpm: Trained LatentDDPM.
        n_steps: Number of reverse steps (>= 1). Values above ``ddpm.T`` are capped at T.
        nsamples: Number of samples to draw.

    Returns:
        Decoded samples flattened to (nsamples, -1).
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")

    device = ddpm.beta.device
    alpha_bar = ddpm.alpha_bar
    # Evenly spaced integer timesteps T, ..., 0. np.unique drops duplicates when n_steps > T.
    timesteps = np.unique(np.linspace(0, ddpm.T, n_steps + 1).round().astype(int))[::-1]

    zt = torch.randn(nsamples, ddpm.autoencoder.latent_dim, device=device)
    for t, t_prev in zip(timesteps[:-1], timesteps[1:]):
        t, t_prev = int(t), int(t_prev)
        t_tensor = torch.full((nsamples,), t, device=device)
        eps = ddpm.network(zt, t_tensor.float() / ddpm.T)

        ab_t, ab_prev = alpha_bar[t], alpha_bar[t_prev]
        a_t = ab_t / ab_prev  # effective alpha for the jump t -> t_prev
        mean = (zt - (1 - a_t) / torch.sqrt(1 - ab_t) * eps) / torch.sqrt(a_t)
        if t_prev > 0:
            var = (1 - ab_prev) / (1 - ab_t) * (1 - a_t)
            zt = mean + torch.sqrt(var) * torch.randn_like(zt)
        else:
            # The final jump (to t_prev = 0) is deterministic, like the t = 1 step of sample().
            zt = mean

    return ddpm.autoencoder.decode(zt).reshape(nsamples, -1)


fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

fresh_latent_ddpm.eval()
for i, n_steps in enumerate(step_counts):
    samples = sample_latent_strided(fresh_latent_ddpm, n_steps, nsamples=16).cpu()
    # Assumes decoder outputs lie in [-1, 1] — TODO confirm against the Autoencoder.
    samples = ((samples + 1) / 2).clamp(0, 1)
    grid = utils.make_grid(samples.reshape(-1, 1, 28, 28), nrow=4)

    axes[i].imshow(transforms.functional.to_pil_image(grid), cmap="gray")
    axes[i].axis("off")
    axes[i].set_title(f"Latent DDPM with {n_steps} steps")

plt.suptitle("Failure Mode: Effect of Diffusion Steps on Sample Quality", fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# 2. Wrong initial noise scale in latent space
# The reverse chain expects z_T ~ N(0, I). Here the starting latent is drawn from
# N(0, scale^2 I), then the same full T-step loop as LatentDDPM.sample is run, to show
# how sensitive the sampler is to a mismatched prior.

torch.manual_seed(42)

# Test different noise scales for initial latent sampling
noise_scales = [0.1, 0.5, 1.0, 2.0, 5.0, 10.0]
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

fresh_latent_ddpm.eval()
for i, scale in enumerate(noise_scales):
    with torch.no_grad():
        device_ddpm = fresh_latent_ddpm.beta.device
        latent_dim_ddpm = fresh_latent_ddpm.autoencoder.latent_dim
        nsamples = 16

        # Only difference from LatentDDPM.sample: the prior sample is multiplied by `scale`.
        zt = torch.randn(nsamples, latent_dim_ddpm, device=device_ddpm) * scale

        for t in range(fresh_latent_ddpm.T, 0, -1):
            # No noise on the final step (t = 1).
            noise = torch.randn_like(zt) if t > 1 else torch.zeros_like(zt)
            t_tensor = torch.tensor(t, device=device_ddpm).expand(nsamples, 1)
            zt = fresh_latent_ddpm.reverse_diffusion(zt, t_tensor, noise)

        x_decoded = fresh_latent_ddpm.autoencoder.decode(zt)
        samples = x_decoded.reshape(nsamples, -1).cpu()
        # Assumes decoder outputs lie in [-1, 1] — TODO confirm against the Autoencoder.
        samples = ((samples + 1) / 2).clamp(0, 1)
        grid = utils.make_grid(samples.reshape(-1, 1, 28, 28), nrow=4)

        axes[i].imshow(transforms.functional.to_pil_image(grid), cmap="gray")
        axes[i].axis("off")
        axes[i].set_title(f"Initial noise scale={scale:.1f}")

plt.suptitle("Failure Mode: Effect of Initial Noise Scale in Latent Space", fontsize=14)
plt.tight_layout()
plt.show()

### Direction 3: Continuous-Time SDE Diffusion Model

#### Implementation

In [None]:
class VPSDE(nn.Module):
    """
    Variance Preserving SDE (continuous-time limit of DDPM).

    Forward SDE: dx = -1/2 beta(t) x dt + sqrt(beta(t)) dw, for t in [0, T].
    beta increases linearly from ``beta_min`` at t = 0 to ``beta_max`` at t = T.
    ``network`` predicts the noise epsilon in x_t = m(t) x_0 + s(t) epsilon, and the score
    is recovered as -epsilon / s(t).

    Args:
        network: Called as ``network(x_img, t / T)`` with x_img of shape (batch, 1, 28, 28).
            Its output is reshaped to (batch, 784).
        beta_min: Noise rate at t = 0.
        beta_max: Noise rate at t = T.
        T: Final diffusion time.
    """
    def __init__(self, network, beta_min=0.1, beta_max=20.0, T=1.0):
        super().__init__()

        self._network = network
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.T = T

    def beta(self, t):
        """Linear noise schedule: beta(t) = beta_min + (t / T) * (beta_max - beta_min)."""
        # Normalising by T keeps beta in [beta_min, beta_max] over [0, T] for any T.
        # The result is identical to the un-normalised form for the default T = 1.
        return self.beta_min + (t / self.T) * (self.beta_max - self.beta_min)

    def integral_beta(self, t):
        """Integral of beta from 0 to t: beta_min * t + (beta_max - beta_min) * t^2 / (2T)."""
        return self.beta_min * t + 0.5 * t**2 / self.T * (self.beta_max - self.beta_min)

    def marginal_prob_mean_coeff(self, t):
        """Mean coefficient of p_{0t}(x(t) | x(0)): exp(-0.5 * integral_0^t beta(s) ds)."""
        return torch.exp(-0.5 * self.integral_beta(t))

    def marginal_prob_std(self, t):
        """Standard deviation of p_{0t}(x(t) | x(0)): sqrt(1 - exp(-integral_0^t beta(s) ds))."""
        return torch.sqrt(1.0 - torch.exp(-self.integral_beta(t)))

    def score(self, x, t):
        """
        Estimate the score grad_x log p_t(x) from the network's noise prediction.

        Args:
            x: Flattened images of shape (batch, 784).
            t: Times of shape (batch,) or (batch, 1).

        Returns:
            -network_output / marginal_prob_std(t), of shape (batch, 784).
        """
        if t.ndim == 2:
            t = t.squeeze(1)

        t_normalized = t / self.T
        x_img = x.reshape(-1, 1, 28, 28)

        output = self._network(x_img, t_normalized).reshape(-1, 28 * 28)

        std = self.marginal_prob_std(t)[:, None]
        return -output / std

    def forward_sample(self, x0, t, epsilon):
        """Sample x(t) ~ p_{0t}(x(t) | x(0)) from the closed-form marginal, using noise epsilon."""
        mean_coeff = self.marginal_prob_mean_coeff(t)
        std = self.marginal_prob_std(t)

        # Per-sample coefficients of shape (batch,) broadcast over the pixel dimension.
        if mean_coeff.ndim == 1:
            mean_coeff = mean_coeff[:, None]
        if std.ndim == 1:
            std = std[:, None]

        return mean_coeff * x0 + std * epsilon

    def loss(self, x0):
        """
        Denoising score matching loss (continuous-time version).

        Uses the weighting lambda(t) = std(t)^2, under which the loss equals the squared
        noise-prediction error summed over pixels and averaged over the batch.
        """
        batch_size = x0.shape[0]
        device = x0.device

        # Keep t away from 0, where marginal_prob_std(t) -> 0 and the score diverges.
        eps = 1e-5
        t = torch.rand(batch_size, device=device) * (self.T - eps) + eps

        epsilon = torch.randn_like(x0)

        xt = self.forward_sample(x0, t, epsilon)

        score_pred = self.score(xt, t)

        std = self.marginal_prob_std(t)[:, None]
        score_true = -epsilon / std

        return torch.mean(torch.sum((score_pred - score_true)**2 * std**2, dim=1))

    @torch.no_grad()
    def sample_sde(self, shape, n_steps=1000):
        """
        Sample by integrating the reverse-time SDE with Euler-Maruyama.

        dx = [f(x, t) - g(t)^2 score(x, t)] dt + g(t) dw_bar, with f = -beta(t) x / 2 and
        g = sqrt(beta(t)). Integrates from t = T down to T / n_steps with dt = -T / n_steps.
        No noise is added on the last step.
        """
        device = next(self.parameters()).device

        x = torch.randn(shape, device=device)

        dt = -self.T / n_steps
        sqrt_abs_dt = math.sqrt(abs(dt))

        for i in range(n_steps):
            t = self.T - i * self.T / n_steps
            t_tensor = torch.full((shape[0],), t, device=device)

            beta_t = self.beta(t)
            f = -0.5 * beta_t * x
            g = math.sqrt(beta_t)

            score = self.score(x, t_tensor)

            drift = f - beta_t * score  # g^2 == beta(t)

            x = x + drift * dt
            if i < n_steps - 1:
                x = x + g * sqrt_abs_dt * torch.randn_like(x)

        return x

    @torch.no_grad()
    def sample_ode(self, shape, n_steps=1000):
        """
        Sample deterministically with the probability-flow ODE (explicit Euler).

        dx = [f(x, t) - g(t)^2 score(x, t) / 2] dt, integrated from t = T down to
        T / n_steps with dt = -T / n_steps.
        """
        device = next(self.parameters()).device

        x = torch.randn(shape, device=device)

        dt = -self.T / n_steps

        for i in range(n_steps):
            t = self.T - i * self.T / n_steps
            t_tensor = torch.full((shape[0],), t, device=device)

            beta_t = self.beta(t)
            f = -0.5 * beta_t * x

            score = self.score(x, t_tensor)

            drift = f - 0.5 * beta_t * score  # g^2 == beta(t)

            x = x + drift * dt

        return x

    @torch.no_grad()
    def sample(self, shape, method='sde', n_steps=1000):
        """Unified sampling interface: method is 'sde' (stochastic) or 'ode' (deterministic)."""
        if method == 'sde':
            return self.sample_sde(shape, n_steps)
        elif method == 'ode':
            return self.sample_ode(shape, n_steps)
        else:
            raise ValueError(f"Unknown sampling method: {method}")


def train_sde(model, optimizer, scheduler, dataloader, epochs, device, per_epoch_callback=None,
              max_grad_norm=1.0):
    """
    Train a continuous-time SDE model, taking one optimizer step and one scheduler step per batch.

    Args:
        model: Object exposing ``train()``, ``parameters()`` and ``loss(x)``, which returns a
            scalar tensor.
        optimizer: Optimizer over the model parameters.
        scheduler: LR scheduler, stepped after every optimizer step (per batch).
        dataloader: Yields (images, labels); the labels are ignored.
        epochs: Number of passes over ``dataloader``.
        device: Device the images are moved to.
        per_epoch_callback: Optional callable invoked with ``model`` after each epoch.
        max_grad_norm: Global gradient-norm clipping threshold (default 1.0). ``None``
            disables clipping.

    Returns:
        List containing the loss of every training step.
    """
    total_steps = len(dataloader) * epochs
    loss_history = []

    # The context manager closes the progress bar even if training is interrupted.
    with tqdm(total=total_steps, desc="Training SDE Model") as progress_bar:
        for epoch in range(epochs):
            # Re-enter training mode every epoch: per_epoch_callback may switch to eval.
            model.train()
            for x, _ in dataloader:
                x = x.to(device)

                optimizer.zero_grad()
                loss = model.loss(x)
                loss.backward()

                # VPSDE.loss sums over all pixels, so raw gradient norms can be large.
                # Presumably this is why clipping is applied here.
                if max_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

                optimizer.step()
                scheduler.step()

                # Read the scalar once: each .item() is a host/device sync.
                loss_value = loss.item()
                loss_history.append(loss_value)

                progress_bar.set_postfix(
                    loss=f"⠀{loss_value:12.4f}",
                    epoch=f"{epoch+1}/{epochs}",
                    lr=f"{scheduler.get_last_lr()[0]:.2E}"
                )
                progress_bar.update()

            if per_epoch_callback:
                per_epoch_callback(model)

    return loss_history

In [None]:
# Stand-alone VP-SDE model and optimiser for this section. `vpsde_reporter` below can be
# passed as `per_epoch_callback` to `train_sde`.
# NOTE(review): `vpsde_model` is not trained in the cells visible here (the comparison
# below builds `vpsde_d3`). Confirm it is used elsewhere, or drop it.
vpsde_unet = ScoreNet(lambda t: torch.ones(1).to(device)).to(device)
vpsde_model = VPSDE(vpsde_unet, beta_min=0.1, beta_max=20.0, T=1.0).to(device)

vpsde_optimizer = torch.optim.Adam(vpsde_model.parameters(), lr=learning_rate)
vpsde_scheduler = torch.optim.lr_scheduler.ExponentialLR(vpsde_optimizer, 0.9999)

def vpsde_reporter(model):
    """
    Per-epoch callback: draw 10 samples with the probability-flow ODE (200 steps) and
    show them in a single row.

    Leaves ``model`` in eval mode; ``train_sde`` switches back to train mode at the start
    of each epoch.
    """
    n_show = 10
    model.eval()
    with torch.no_grad():
        images = model.sample((n_show, 28*28), method='ode', n_steps=200).cpu()
        # Map [-1, 1] to [0, 1] for display (assumes the data transform scales to [-1, 1]).
        images = ((images + 1) / 2).clamp(0.0, 1.0)
        grid = utils.make_grid(images.reshape(-1, 1, 28, 28), nrow=n_show)

        plt.figure(figsize=(15, 2))
        axis = plt.gca()
        axis.set_axis_off()
        plt.imshow(transforms.functional.to_pil_image(grid), cmap="gray")
        plt.title("VP-SDE Samples (ODE)")
        plt.show()

#### Comparison with DDPM

In [None]:
# Direction 3 comparison: pixel-space DDPM vs continuous-time VP-SDE, using the same
# data, epochs and batch size.
torch.manual_seed(42)

comparison_epochs_d3 = 50
comparison_batch_size_d3 = 256

# Same settings as `comparison_dataloader`, rebuilt for this section.
comparison_dataloader_d3 = torch.utils.data.DataLoader(
    datasets.MNIST("./mnist_data", download=True, train=True, transform=transform),
    batch_size=comparison_batch_size_d3,
    shuffle=True,
)

ddpm_d3 = DDPM(
    ScoreNet(lambda t: torch.ones(1).to(device)), 
    T=T
).to(device)

ddpm_d3_opt = torch.optim.Adam(ddpm_d3.parameters(), lr=learning_rate)
ddpm_d3_sch = torch.optim.lr_scheduler.ExponentialLR(ddpm_d3_opt, 0.9999)

loss_ddpm_d3 = train(
    ddpm_d3,
    ddpm_d3_opt,
    ddpm_d3_sch,
    comparison_dataloader_d3,
    epochs=comparison_epochs_d3,
    device=device,
    ema=True,
    per_epoch_callback=None,
)

# VP-SDE with the same ScoreNet backbone. NOTE(review): train_sde clips gradients
# (norm 1.0) and uses no EMA, unlike the DDPM above.
vpsde_d3 = VPSDE(
    ScoreNet(lambda t: torch.ones(1).to(device)).to(device),
    beta_min=0.1,
    beta_max=20.0,
    T=1.0
).to(device)

vpsde_d3_opt = torch.optim.Adam(vpsde_d3.parameters(), lr=learning_rate)
vpsde_d3_sch = torch.optim.lr_scheduler.ExponentialLR(vpsde_d3_opt, 0.9999)

loss_vpsde_d3 = train_sde(
    vpsde_d3,
    vpsde_d3_opt,
    vpsde_d3_sch,
    comparison_dataloader_d3,
    epochs=comparison_epochs_d3,
    device=device,
    per_epoch_callback=None,
)

In [None]:
# Compare DDPM and VP-SDE: smoothed loss curves and 16 samples from each.
fig, axes = plt.subplots(1, 3, figsize=(14, 4))

# Plot 1: Training loss comparison with smoothing
# VPSDE.loss sums the squared noise error over all 28*28 pixels. The DDPM loss is
# presumably a per-element mean MSE, like LatentDDPM / ClassifierFreeDDPM (TODO confirm
# against DDPM.loss). Divide the SDE loss by the pixel count so both curves are per-pixel
# noise MSE on the same scale, instead of differing by a factor of about 784.
window = 100
n_pixels = 28 * 28
ddpm_d3_smooth = np.convolve(loss_ddpm_d3, np.ones(window)/window, mode='valid')
vpsde_d3_smooth = np.convolve(np.asarray(loss_vpsde_d3) / n_pixels, np.ones(window)/window, mode='valid')

axes[0].plot(ddpm_d3_smooth, label="DDPM", linewidth=2, alpha=0.8)
axes[0].plot(vpsde_d3_smooth, label="Continuous-Time SDE (per pixel)", linewidth=2, alpha=0.8)
axes[0].set_title("Training Loss Comparison", fontsize=11)
axes[0].set_xlabel("Training Step", fontsize=10)
axes[0].set_ylabel("Loss", fontsize=10)
axes[0].legend(fontsize=9)
axes[0].grid(True, alpha=0.3)

# Plot 2: DDPM samples
ddpm_d3.eval()
with torch.no_grad():
    ddpm_d3_samples = ddpm_d3.sample((16, 28 * 28)).cpu()
    ddpm_d3_samples = ((ddpm_d3_samples + 1) / 2).clamp(0, 1)
    ddpm_d3_grid = utils.make_grid(ddpm_d3_samples.reshape(-1, 1, 28, 28), nrow=4, padding=1, pad_value=1.0)

axes[1].imshow(transforms.functional.to_pil_image(ddpm_d3_grid), cmap="gray")
axes[1].axis("off")
axes[1].set_title("DDPM Samples", fontsize=11)

# Plot 3: Continuous-Time SDE samples (using ODE for deterministic generation)
vpsde_d3.eval()
with torch.no_grad():
    vpsde_d3_samples = vpsde_d3.sample((16, 28 * 28), method='ode', n_steps=500).cpu()
    vpsde_d3_samples = ((vpsde_d3_samples + 1) / 2).clamp(0, 1)
    vpsde_d3_grid = utils.make_grid(vpsde_d3_samples.reshape(-1, 1, 28, 28), nrow=4, padding=1, pad_value=1.0)

axes[2].imshow(transforms.functional.to_pil_image(vpsde_d3_grid), cmap="gray")
axes[2].axis("off")
axes[2].set_title("Continuous-Time SDE Samples", fontsize=11)

plt.tight_layout()
plt.show()

#### Failure Mode Analysis

In [None]:
# 1. Numerical instability when beta(t) grows too quickly
#
# The network trained in `vpsde_d3` (beta_min=0.1, beta_max=20) is reused with other
# schedules, so every panel also reflects a train/sample schedule mismatch. Numerical
# blow-ups in the ODE produce NaN/inf values rather than exceptions. They are detected
# explicitly here instead of being clamped into a silently blank image.

torch.manual_seed(42)

# Test different beta schedules (varying max beta)
beta_configs = [
    (0.1, 5.0, "Stable (0.1, 5.0)"),
    (0.1, 20.0, "Standard (0.1, 20.0)"),
    (0.1, 50.0, "High (0.1, 50.0)"),
    (0.1, 100.0, "Very High (0.1, 100.0)"),
    (0.1, 200.0, "Extreme (0.1, 200.0)"),
    (0.1, 500.0, "Unstable (0.1, 500.0)"),
]

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for i, (beta_min_test, beta_max_test, label) in enumerate(beta_configs):
    # Share the trained network instead of copying its weights into a freshly initialised
    # ScoreNet: sampling never updates the weights.
    test_model = VPSDE(
        vpsde_d3._network,
        beta_min=beta_min_test,
        beta_max=beta_max_test,
        T=1.0
    ).to(device)

    test_model.eval()
    with torch.no_grad():
        try:
            samples = test_model.sample((16, 28 * 28), method='ode', n_steps=500).cpu()
            n_bad = int((~torch.isfinite(samples)).sum())
            if n_bad:
                raise FloatingPointError(f"{n_bad} non-finite pixel values")
            samples = ((samples + 1) / 2).clamp(0, 1)
            grid = utils.make_grid(samples.reshape(-1, 1, 28, 28), nrow=4)

            axes[i].imshow(transforms.functional.to_pil_image(grid), cmap="gray")
            axes[i].axis("off")
            axes[i].set_title(label)
        except Exception as e:
            # Deliberately broad: any failure is reported in its panel, not raised.
            axes[i].text(0.5, 0.5, f"Failed:\n{str(e)[:50]}", 
                        ha='center', va='center', fontsize=8)
            axes[i].axis("off")
            axes[i].set_title(label + " (Failed)")

plt.suptitle("Failure Mode: Effect of Beta Schedule Growth Rate on Numerical Stability", fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# 2. Miscalibrated noise leading to over-smooth or noisy samples
# Same explicit-Euler probability-flow ODE loop as VPSDE.sample_ode (500 steps), but
# starting from N(0, scale^2 I) instead of the prior N(0, I) the sampler assumes.

torch.manual_seed(42)

# Test different initial noise scales
noise_scales = [0.1, 0.5, 1.0, 2.0, 5.0, 10.0]
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

vpsde_d3.eval()
for i, scale in enumerate(noise_scales):
    with torch.no_grad():
        device_vp = next(vpsde_d3.parameters()).device

        x = torch.randn((16, 28 * 28), device=device_vp) * scale

        # Integrate backwards in time: dt is negative, t runs from T down to T / 500.
        dt = -vpsde_d3.T / 500

        for step in range(500):
            t = vpsde_d3.T - step * vpsde_d3.T / 500
            t_tensor = torch.full((16,), t, device=device_vp)

            beta_t = vpsde_d3.beta(t)
            f = -0.5 * beta_t * x
            g_squared = beta_t

            score = vpsde_d3.score(x, t_tensor)

            # Probability-flow drift: f - g^2 * score / 2.
            drift = f - 0.5 * g_squared * score

            x = x + drift * dt

        samples = x.cpu()
        samples = ((samples + 1) / 2).clamp(0, 1)
        grid = utils.make_grid(samples.reshape(-1, 1, 28, 28), nrow=4)

        axes[i].imshow(transforms.functional.to_pil_image(grid), cmap="gray")
        axes[i].axis("off")
        axes[i].set_title(f"Initial noise scale={scale:.1f}")

plt.suptitle("Failure Mode: Effect of Miscalibrated Initial Noise on Sample Quality", fontsize=14)
plt.tight_layout()
plt.show()

### Direction 4: Classifier-Free Guidance

#### Implementation

In [None]:
class ConditionalScoreNet(nn.Module):
    """
    A conditional version of ScoreNet that can take class labels as input.
    Supports classifier-free guidance by conditioning on labels or null token.

    Layout:
    - four conv encoder stages;
    - three transposed-conv decoder stages with skip connections to the encoder outputs;
    - a final transposed conv back to one channel.
    Every stage except the last adds a Dense projection of the concatenated
    [time embedding, class embedding] vector (size 2 * embed_dim) as a per-channel bias.

    Args:
        marginal_prob_std: Callable of t; the network output is divided by it. This
            notebook passes ``lambda t: torch.ones(1)``, which makes the division a no-op.
        num_classes: Number of real classes. The embedding table has ``num_classes + 1``
            rows; index ``num_classes`` is the null token (ClassifierFreeDDPM uses it for
            unconditional predictions).
        channels: Feature-map channels per resolution. Every entry is normalised with
            GroupNorm(32, ...) somewhere, so each must be divisible by 32.
        embed_dim: Size of the time embedding and of the class embedding (each).
    """
    def __init__(self, marginal_prob_std, num_classes=10, channels=[32, 64, 128, 256], embed_dim=256):
        super().__init__()

        self.time_embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim),
        )

        # One extra row for the null (unconditional) label.
        self.class_embed = nn.Embedding(num_classes + 1, embed_dim)

        # Encoding layers (3x3 kernels, no padding). For a 28x28 input the spatial sizes
        # are 26 -> 12 -> 5 -> 2.
        self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
        self.dense1 = Dense(embed_dim * 2, channels[0])  # 2 * embed_dim: time + class embeddings
        self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])

        self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
        self.dense2 = Dense(embed_dim * 2, channels[1])
        self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])

        self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
        self.dense3 = Dense(embed_dim * 2, channels[2])
        self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])

        self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
        self.dense4 = Dense(embed_dim * 2, channels[3])
        self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])

        # Decoding layers. For a 28x28 input the sizes go 2 -> 5 -> 12 -> 26 -> 28;
        # output_padding=1 makes 5 -> 12 and 12 -> 26 match the encoder skips.
        self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
        self.dense5 = Dense(embed_dim * 2, channels[2])
        self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])

        # Input channels are doubled because the encoder skip is concatenated on dim 1.
        self.tconv3 = nn.ConvTranspose2d(
            channels[2] + channels[2], channels[1], 3, stride=2, bias=False, output_padding=1
        )
        self.dense6 = Dense(embed_dim * 2, channels[1])
        self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])

        self.tconv2 = nn.ConvTranspose2d(
            channels[1] + channels[1], channels[0], 3, stride=2, bias=False, output_padding=1
        )
        self.dense7 = Dense(embed_dim * 2, channels[0])
        self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])

        self.tconv1 = nn.ConvTranspose2d(channels[0] + channels[0], 1, 3, stride=1)

        # Swish / SiLU activation: x * sigmoid(x).
        self.act = lambda x: x * torch.sigmoid(x)
        self.marginal_prob_std = marginal_prob_std

    def forward(self, x, t, y):
        """
        Forward pass with conditioning.

        Args:
            x: Image batch of shape (batch, 1, H, W); callers in this notebook pass 28x28.
            t: Times of shape (batch,). GaussianFourierProjection indexes t[:, None], so t
               must be 1-D.
            y: Integer labels of shape (batch,), in [0, num_classes]. num_classes is the
               null token.

        Returns:
            Tensor of shape (batch, 1, H, W), divided by marginal_prob_std(t).
        """
        t_embed = self.act(self.time_embed(t))
        y_embed = self.class_embed(y)

        # Shared conditioning vector of size 2 * embed_dim, injected at every stage.
        embed = torch.cat([t_embed, y_embed], dim=-1)

        # Encoding path
        h1 = self.conv1(x)
        h1 += self.dense1(embed)
        h1 = self.gnorm1(h1)
        h1 = self.act(h1)

        h2 = self.conv2(h1)
        h2 += self.dense2(embed)
        h2 = self.gnorm2(h2)
        h2 = self.act(h2)

        h3 = self.conv3(h2)
        h3 += self.dense3(embed)
        h3 = self.gnorm3(h3)
        h3 = self.act(h3)

        h4 = self.conv4(h3)
        h4 += self.dense4(embed)
        h4 = self.gnorm4(h4)
        h4 = self.act(h4)

        # Decoding path (each stage concatenates the matching encoder output on dim 1)
        h = self.tconv4(h4)
        h += self.dense5(embed)
        h = self.tgnorm4(h)
        h = self.act(h)

        h = self.tconv3(torch.cat([h, h3], dim=1))
        h += self.dense6(embed)
        h = self.tgnorm3(h)
        h = self.act(h)

        h = self.tconv2(torch.cat([h, h2], dim=1))
        h += self.dense7(embed)
        h = self.tgnorm2(h)
        h = self.act(h)

        h = self.tconv1(torch.cat([h, h1], dim=1))

        # Per-sample normalisation; a no-op when marginal_prob_std returns ones.
        h = h / self.marginal_prob_std(t)[:, None, None, None]
        return h


class ClassifierFreeDDPM(nn.Module):
    """
    DDPM with Classifier-Free Guidance.

    A single conditional noise predictor is trained on labelled data. Each label is
    replaced by a null token (index ``num_classes``) with probability ``p_uncond``.
    At sampling time the conditional and unconditional predictions are combined as
        eps = eps_uncond + w * (eps_cond - eps_uncond),
    where w is the guidance scale:
    - w = 0 gives unconditional samples;
    - w = 1 gives plain conditional samples;
    - w > 1 pushes samples more strongly towards the label.

    Args:
        network: Conditional model called as ``network(x_img, t / T, y)`` with x_img of
            shape (batch, 1, 28, 28). Its output is reshaped to (batch, 784).
        num_classes: Number of real classes; label ``num_classes`` is the null token.
        T: Number of diffusion steps.
        beta_1, beta_T: End points of the linear beta schedule.
        p_uncond: Probability of dropping the label during training.
        guidance_scale: Default guidance weight w used by ``sample``.
    """
    def __init__(self, network, num_classes=10, T=1000, beta_1=1e-4, beta_T=2e-2,
                 p_uncond=0.1, guidance_scale=1.0):
        super().__init__()

        self._network = network
        self.num_classes = num_classes
        self.null_class = num_classes
        self.p_uncond = p_uncond
        self.guidance_scale = guidance_scale

        self.T = T

        # T + 1 entries so timesteps 1..T index directly.
        self.register_buffer("beta", torch.linspace(beta_1, beta_T, T + 1))
        self.register_buffer("alpha", 1 - self.beta)
        self.register_buffer("alpha_bar", self.alpha.cumprod(dim=0))

    def network(self, x, t, y):
        """
        Predict the noise in flattened images ``x`` (batch, 784) at integer steps ``t``
        for labels ``y``.

        Implemented as a method rather than a lambda attribute so the module can be pickled
        (e.g. by ``torch.save(model)``). ``t`` is flattened with ``reshape(-1)`` rather than
        ``squeeze()``: squeeze turns a (1, 1) tensor into a 0-d one, which breaks the
        time embedding for batch size 1.
        """
        t_normalized = t.reshape(-1).float() / self.T
        return self._network(x.reshape(-1, 1, 28, 28), t_normalized, y).reshape(-1, 28 * 28)

    def forward_diffusion(self, x0, t, epsilon):
        """Forward diffusion: x_t ~ q(x_t | x_0), using the given noise."""
        mean = torch.sqrt(self.alpha_bar[t]) * x0
        std = torch.sqrt(1 - self.alpha_bar[t])
        return mean + std * epsilon

    def reverse_diffusion(self, xt, t, y, epsilon, guidance_scale=None):
        """
        One guided reverse step: sample x_{t-1} given x_t.

        Args:
            xt: Current flattened images, shape (batch, 784).
            t: Integer steps, shape (batch, 1).
            y: Integer labels, shape (batch,).
            epsilon: Noise for this step (0 on the final step).
            guidance_scale: Guidance weight for this call; defaults to ``self.guidance_scale``.
        """
        w = self.guidance_scale if guidance_scale is None else guidance_scale

        eps_cond = self.network(xt, t, y)

        y_uncond = torch.full_like(y, self.null_class)
        eps_uncond = self.network(xt, t, y_uncond)

        eps_guided = eps_uncond + w * (eps_cond - eps_uncond)

        mean = 1.0 / torch.sqrt(self.alpha[t]) * (
            xt - self.beta[t] / torch.sqrt(1 - self.alpha_bar[t]) * eps_guided
        )
        std = torch.where(
            t > 0,
            torch.sqrt(((1 - self.alpha_bar[t - 1]) / (1 - self.alpha_bar[t])) * self.beta[t]),
            torch.zeros_like(t, dtype=torch.float32),
        )

        return mean + std * epsilon

    @torch.no_grad()
    def sample(self, shape, y=None, guidance_scale=None):
        """
        Sample images with classifier-free guidance.

        Args:
            shape: Output shape, (nsamples, 784).
            y: Labels to condition on. Accepts a tensor or sequence of length nsamples, or a
               single int applied to every sample. If None, labels are drawn uniformly from
               the real classes. Labels are moved to the model's device.
            guidance_scale: Overrides ``self.guidance_scale`` for this call only; the
               attribute itself is never modified.

        Returns:
            Tensor of shape ``shape``.
        """
        device = self.beta.device
        nsamples = shape[0]

        if y is None:
            y = torch.randint(0, self.num_classes, (nsamples,), device=device)
        else:
            y = torch.as_tensor(y, dtype=torch.long, device=device)
            if y.ndim == 0:
                y = torch.full((nsamples,), int(y), dtype=torch.long, device=device)

        xT = torch.randn(shape).to(device)
        xt = xT

        for t in range(self.T, 0, -1):
            # No noise on the final step (t = 1).
            noise = torch.randn_like(xT) if t > 1 else 0
            t_tensor = torch.tensor(t).expand(xt.shape[0], 1).to(device)
            xt = self.reverse_diffusion(xt, t_tensor, y, noise, guidance_scale)

        return xt

    def loss(self, x0, y):
        """
        Noise-prediction MSE with random label dropout, for classifier-free guidance.

        Each label is replaced by the null token with probability ``p_uncond``, so the
        same network also learns the unconditional model.

        Args:
            x0: Clean images; assumed flattened to (batch, 784) by the dataset transform
                defined in an earlier cell — TODO confirm.
            y: Integer labels of shape (batch,).
        """
        batch_size = x0.shape[0]
        device = x0.device

        # randint's upper bound is exclusive. Using T + 1 makes t uniform on {1, ..., T},
        # covering the step t = T that sampling starts from.
        t = torch.randint(1, self.T + 1, (batch_size, 1)).to(device)

        epsilon = torch.randn_like(x0)

        xt = self.forward_diffusion(x0, t, epsilon)

        mask = torch.rand(batch_size, device=device) < self.p_uncond
        y_train = y.clone()
        y_train[mask] = self.null_class

        epsilon_pred = self.network(xt, t, y_train)

        return F.mse_loss(epsilon_pred, epsilon, reduction="mean")


def train_cfg(model, optimizer, scheduler, dataloader, epochs, device, per_epoch_callback=None):
    """Training loop for classifier-free guidance model.

    Args:
        model: Module exposing ``loss(x, y)`` plus the usual ``nn.Module`` API.
        optimizer: Optimizer over the model's trainable parameters.
        scheduler: LR scheduler, stepped once per batch (not per epoch).
        dataloader: Iterable of ``(x, y)`` batches; must support ``len``.
        epochs: Number of passes over ``dataloader``.
        device: Device every batch is moved to.
        per_epoch_callback: Optional callable invoked with ``model`` after each epoch.

    Returns:
        List with one float loss value per optimisation step.
    """
    total_steps = len(dataloader) * epochs
    loss_history = []

    # The context manager closes the bar even if training is interrupted.
    with tqdm(total=total_steps, desc="Training CFG Model") as progress_bar:
        for epoch in range(epochs):
            model.train()
            for x, y in dataloader:
                x = x.to(device)
                y = y.to(device)

                optimizer.zero_grad()
                loss = model.loss(x, y)
                loss.backward()
                optimizer.step()
                scheduler.step()

                # One host sync per step, reused for the history and the display.
                loss_value = loss.item()
                loss_history.append(loss_value)

                progress_bar.set_postfix(
                    loss=f"⠀{loss_value:12.4f}",
                    epoch=f"{epoch+1}/{epochs}",
                    lr=f"{scheduler.get_last_lr()[0]:.2E}"
                )
                progress_bar.update()

            if per_epoch_callback:
                per_epoch_callback(model)

    return loss_history

#### Comparison with DDPM

In [None]:
# Train an unconditional DDPM and a classifier-free-guidance model on the same
# loader, for the same number of epochs, so their losses and samples can be
# compared in the next cell.
# `transform`, `T`, `learning_rate`, `DDPM`, `train`, `ConditionalScoreNet` and
# `ClassifierFreeDDPM` come from earlier cells.
# Both models are built after one seed, so CFG's initial weights depend on the
# RNG consumed by DDPM training: keep the statement order to reproduce results.
torch.manual_seed(42)

comparison_epochs_d4 = 50
comparison_batch_size_d4 = 256

# Shuffled MNIST training loader shared by both models (CFG also uses the labels).
comparison_dataloader_d4 = torch.utils.data.DataLoader(
    datasets.MNIST("./mnist_data", download=True, train=True, transform=transform),
    batch_size=comparison_batch_size_d4,
    shuffle=True,
)

# Train standard DDPM (unconditional)
# The constant marginal_prob_std (always 1) presumably disables ScoreNet's output
# rescaling so the network predicts noise directly — TODO confirm in ScoreNet.forward.
ddpm_d4 = DDPM(
    ScoreNet(lambda t: torch.ones(1).to(device)), 
    T=T
).to(device)

ddpm_d4_opt = torch.optim.Adam(ddpm_d4.parameters(), lr=learning_rate)
ddpm_d4_sch = torch.optim.lr_scheduler.ExponentialLR(ddpm_d4_opt, 0.9999)

# NOTE(review): the DDPM is trained with ema=True, whereas train_cfg below has no
# EMA option, so the two models are not trained under identical conditions.
loss_ddpm_d4 = train(
    ddpm_d4,
    ddpm_d4_opt,
    ddpm_d4_sch,
    comparison_dataloader_d4,
    epochs=comparison_epochs_d4,
    device=device,
    ema=True,
    per_epoch_callback=None,
)

# Train Classifier-Free Guidance DDPM
cfg_network = ConditionalScoreNet(
    lambda t: torch.ones(1).to(device),
    num_classes=10
).to(device)

# guidance_scale is the default used by cfg_model.sample when no override is given.
cfg_model = ClassifierFreeDDPM(
    cfg_network,
    num_classes=10,
    T=T,
    p_uncond=0.1,  # 10% unconditional training
    guidance_scale=2.0
).to(device)

# train_cfg steps the scheduler once per batch, so the LR decays by 0.9999 per step.
cfg_opt = torch.optim.Adam(cfg_model.parameters(), lr=learning_rate)
cfg_sch = torch.optim.lr_scheduler.ExponentialLR(cfg_opt, 0.9999)

loss_cfg_d4 = train_cfg(
    cfg_model,
    cfg_opt,
    cfg_sch,
    comparison_dataloader_d4,
    epochs=comparison_epochs_d4,
    device=device,
    per_epoch_callback=None,
)

In [None]:
def _to_unit_grid(flat_samples):
    """Map flat samples from [-1, 1] to [0, 1] and tile them four per row.

    Returns the clamped samples and the padded image grid.
    """
    unit_samples = ((flat_samples + 1) / 2).clamp(0, 1)
    image_grid = utils.make_grid(unit_samples.reshape(-1, 1, 28, 28), nrow=4, padding=1, pad_value=1.0)
    return unit_samples, image_grid


fig, axes = plt.subplots(1, 3, figsize=(14, 4))
loss_ax, ddpm_ax, cfg_ax = axes

# Panel 1: 100-step moving average of both training-loss curves.
window = 100
moving_average_kernel = np.ones(window) / window
ddpm_d4_smooth = np.convolve(loss_ddpm_d4, moving_average_kernel, mode='valid')
cfg_d4_smooth = np.convolve(loss_cfg_d4, moving_average_kernel, mode='valid')

for curve, curve_label in ((ddpm_d4_smooth, "DDPM"), (cfg_d4_smooth, "Classifier-Free Guidance")):
    loss_ax.plot(curve, label=curve_label, linewidth=2, alpha=0.8)
loss_ax.set_title("Training Loss Comparison", fontsize=11)
loss_ax.set_xlabel("Training Step", fontsize=10)
loss_ax.set_ylabel("Loss", fontsize=10)
loss_ax.legend(fontsize=9)
loss_ax.grid(True, alpha=0.3)

# Panel 2: unconditional DDPM samples.
ddpm_d4.eval()
with torch.no_grad():
    ddpm_d4_samples, ddpm_d4_grid = _to_unit_grid(ddpm_d4.sample((16, 28 * 28)).cpu())

# Panel 3: CFG samples, four per class for classes 0-3 (one class per grid row).
cfg_model.eval()
with torch.no_grad():
    y_cfg = torch.arange(4, device=device).repeat_interleave(4)
    cfg_d4_samples, cfg_d4_grid = _to_unit_grid(
        cfg_model.sample((16, 28 * 28), y=y_cfg, guidance_scale=2.0).cpu()
    )

for image_ax, image_grid, panel_title in (
    (ddpm_ax, ddpm_d4_grid, "DDPM Samples"),
    (cfg_ax, cfg_d4_grid, "CFG Samples (classes 0-3)"),
):
    image_ax.imshow(transforms.functional.to_pil_image(image_grid), cmap="gray")
    image_ax.axis("off")
    image_ax.set_title(panel_title, fontsize=11)

plt.tight_layout()
plt.show()

#### Failure Mode Analysis

In [None]:
# 1. Effect of guidance scale on sample quality and diversity.
# The generator is re-seeded before every scale, so each panel starts from the
# same initial noise and draws the same per-step noise. Differences between
# panels are then due to the guidance scale alone (previously the seed was set
# once, so every scale also saw different noise).

GUIDANCE_SEED = 42
guidance_scales = [0.0, 0.5, 1.0, 2.0, 5.0, 10.0]
guidance_class = 3  # one fixed digit class, so only the guidance scale varies

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

cfg_model.eval()
y_test = torch.full((16,), guidance_class, dtype=torch.long, device=device)
for ax, scale in zip(axes, guidance_scales):
    torch.manual_seed(GUIDANCE_SEED)
    with torch.no_grad():
        samples = cfg_model.sample((16, 28 * 28), y=y_test, guidance_scale=scale).cpu()
    samples = ((samples + 1) / 2).clamp(0, 1)
    grid = utils.make_grid(samples.reshape(-1, 1, 28, 28), nrow=4)

    ax.imshow(transforms.functional.to_pil_image(grid), cmap="gray")
    ax.axis("off")
    ax.set_title(f"Guidance scale={scale:.1f}")

plt.suptitle("Failure Mode: Effect of Guidance Scale on Sample Quality", fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# 2. Effect of the unconditional-training probability p_uncond.
# p_uncond = 0.0 never trains the null-class (unconditional) prediction that
# guidance relies on. p_uncond = 1.0 never trains the class-conditional
# prediction. Values in between trade the two off.
#
# Each run is re-seeded before its network is built and again before sampling.
# All six models therefore share initial weights and sampling noise and differ
# only in p_uncond (previously only the first model used the fixed seed).

P_UNCOND_SEED = 42
p_uncond_values = [0.0, 0.05, 0.1, 0.2, 0.5, 1.0]
training_epochs_cfg = 20
sampled_class = 7  # digit shown in every panel

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

# The loader does not depend on p_uncond, so it is built once instead of per run.
dataloader_cfg_test = torch.utils.data.DataLoader(
    datasets.MNIST("./mnist_data", download=True, train=True, transform=transform),
    batch_size=comparison_batch_size_d4,
    shuffle=True,
)

for ax, p_uncond_val in zip(axes, p_uncond_values):
    # Same seed before construction -> identical initial weights for every run.
    torch.manual_seed(P_UNCOND_SEED)

    cfg_net_test = ConditionalScoreNet(
        lambda t: torch.ones(1).to(device),
        num_classes=10
    ).to(device)

    cfg_test = ClassifierFreeDDPM(
        cfg_net_test,
        num_classes=10,
        T=T,
        p_uncond=p_uncond_val,
        guidance_scale=2.0
    ).to(device)

    cfg_opt_test = torch.optim.Adam(cfg_test.parameters(), lr=learning_rate)
    cfg_sch_test = torch.optim.lr_scheduler.ExponentialLR(cfg_opt_test, 0.9999)

    train_cfg(
        cfg_test,
        cfg_opt_test,
        cfg_sch_test,
        dataloader_cfg_test,
        epochs=training_epochs_cfg,
        device=device,
        per_epoch_callback=None,
    )

    cfg_test.eval()
    # Same seed before sampling -> every model denoises the same starting noise.
    torch.manual_seed(P_UNCOND_SEED)
    y_test = torch.full((16,), sampled_class, dtype=torch.long, device=device)
    with torch.no_grad():
        samples = cfg_test.sample((16, 28 * 28), y=y_test, guidance_scale=2.0).cpu()
    samples = ((samples + 1) / 2).clamp(0, 1)
    grid = utils.make_grid(samples.reshape(-1, 1, 28, 28), nrow=4)

    ax.imshow(transforms.functional.to_pil_image(grid), cmap="gray")
    ax.axis("off")
    ax.set_title(f"p_uncond={p_uncond_val:.2f}")

plt.suptitle("Failure Mode: Effect of Unconditional Training Probability", fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# 3. Per-class sample grids: some digit classes may be harder to generate than others.

torch.manual_seed(42)

fig, axes = plt.subplots(2, 5, figsize=(15, 6))
axes = axes.flatten()

cfg_model.eval()
for class_idx, class_ax in enumerate(axes[:10]):
    # 16 samples of a single digit, all with the default guidance strength 2.0.
    y_class = torch.full((16,), class_idx, dtype=torch.long, device=device)
    with torch.no_grad():
        class_samples = cfg_model.sample((16, 28 * 28), y=y_class, guidance_scale=2.0).cpu()
        samples = ((class_samples + 1) / 2).clamp(0, 1)
        grid = utils.make_grid(samples.reshape(-1, 1, 28, 28), nrow=4)

    class_ax.imshow(transforms.functional.to_pil_image(grid), cmap="gray")
    class_ax.axis("off")
    class_ax.set_title(f"Class {class_idx}")

plt.suptitle("Failure Mode: Class-Specific Quality Variation", fontsize=14)
plt.tight_layout()
plt.show()

### Model Comparison: Baseline vs. Four Directions

In [None]:
# Extra dependencies for the evaluation metrics (FID needs a matrix square root,
# IS and the diversity score need KL divergence / entropy).
import torch.nn.functional as F
from scipy.linalg import sqrtm
from scipy.stats import entropy
import warnings
# NOTE(review): this silences *all* warnings for the rest of the kernel session,
# including any numerical warnings raised while computing the metrics below.
# Consider a narrower filter (specific category/module) instead.
warnings.filterwarnings('ignore')

class MNISTClassifier(nn.Module):
    """Small two-block CNN used as classifier and feature extractor for the metrics.

    Two conv -> ReLU -> 2x2 max-pool stages take a single-channel 28x28 image
    down to 64 x 7 x 7. That is projected to a 128-d hidden layer, then to 10
    class logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def get_features(self, x):
        """Extract features before final classification layer"""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        flat = x.view(-1, 64 * 7 * 7)
        return F.relu(self.fc1(flat))

    def forward(self, x):
        """Return unnormalised class logits (the final layer applied to the features)."""
        return self.fc2(self.get_features(x))

def train_classifier(epochs=5):
    """Fit an ``MNISTClassifier`` on the MNIST training split for use by the metrics.

    Images are normalised with ``Normalize((0.5,), (0.5,))``. The evaluation cells
    map generated samples with the same ``(x - 0.5) / 0.5`` before scoring them.

    Args:
        epochs: Number of full passes over the training set.

    Returns:
        The trained classifier, on ``device``.
    """
    classifier = MNISTClassifier().to(device)
    optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)

    normalise = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST("./mnist_data", download=True, train=True, transform=normalise),
        batch_size=128,
        shuffle=True,
    )

    classifier.train()
    for epoch_number in range(1, epochs + 1):
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            batch_loss = F.cross_entropy(classifier(images), labels)
            batch_loss.backward()
            optimizer.step()
        print(f"Classifier training epoch {epoch_number}/{epochs} complete")

    return classifier

def calculate_inception_score(samples, classifier, splits=10):
    """
    Calculate Inception Score for generated samples.
    Higher is better. Measures quality and diversity.

    IS = exp(mean_x KL(p(y|x) || p(y))) is evaluated separately on ``splits``
    contiguous chunks of ``samples`` (in the given order). The function reports
    the mean and standard deviation across chunks. Each chunk's marginal p(y)
    is estimated from that chunk only, so samples should not be sorted by class.

    Args:
        samples: Batch accepted by ``classifier``.
        classifier: Module returning class logits for ``samples``.
        splits: Number of chunks.

    Returns:
        ``(mean, std)`` of the per-chunk scores.

    Raises:
        ValueError: if ``splits`` < 1 or there are fewer samples than splits
            (a chunk would be empty and its score undefined).
    """
    classifier.eval()
    with torch.no_grad():
        preds = F.softmax(classifier(samples), dim=1).cpu().numpy()

    if splits < 1 or len(preds) < splits:
        raise ValueError(
            f"need 1 <= splits <= number of samples, got splits={splits} "
            f"for {len(preds)} samples"
        )

    split_scores = []
    # array_split keeps every sample (chunk sizes differ by at most one)
    # instead of silently dropping the remainder.
    for part in np.array_split(preds, splits):
        py = part.mean(axis=0)
        kl_terms = [entropy(pyx, py) for pyx in part]
        split_scores.append(np.exp(np.mean(kl_terms)))

    return np.mean(split_scores), np.std(split_scores)

def calculate_fid(real_features, fake_features, eps=1e-6):
    """
    Calculate Fréchet Inception Distance.
    Lower is better. Measures similarity to real distribution.

    Fits a Gaussian (mean, covariance) to each feature set and returns
    ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^{1/2}).

    Args:
        real_features: Array of shape (n_real, d).
        fake_features: Array of shape (n_fake, d).
        eps: Diagonal offset added to both covariances, only when the matrix
            square root of their product is not finite. This can happen when
            some feature dimensions are constant and the covariances are singular.

    Returns:
        The FID as a float.
    """
    mu1, sigma1 = real_features.mean(axis=0), np.cov(real_features, rowvar=False)
    mu2, sigma2 = fake_features.mean(axis=0), np.cov(fake_features, rowvar=False)

    ssdiff = np.sum((mu1 - mu2) ** 2)
    covmean = sqrtm(sigma1.dot(sigma2))

    if not np.isfinite(covmean).all():
        # Singular covariances can make sqrtm blow up; regularise and retry.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Round-off can introduce a tiny imaginary component; it is discarded.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = ssdiff + np.trace(sigma1 + sigma2 - 2 * covmean)
    return float(fid)

def calculate_quality_score(samples, classifier):
    """
    Measures average confidence of classifier on generated samples.
    Higher is better.

    Returns:
        ``(mean, std)`` of the per-sample maximum softmax probability as Python
        floats; ``std`` is torch's default (unbiased) standard deviation.
    """
    classifier.eval()
    with torch.no_grad():
        logits = classifier(samples)
        confidence = F.softmax(logits, dim=1).amax(dim=1)
    return confidence.mean().item(), confidence.std().item()

def calculate_diversity_score(samples, classifier):
    """
    Measures diversity by entropy of predicted class distribution.
    Higher is better (max ~2.3 for uniform over 10 classes).

    The per-sample softmax outputs are averaged into one marginal class
    distribution. The function returns its Shannon entropy (natural log).
    """
    classifier.eval()
    with torch.no_grad():
        class_probs = F.softmax(classifier(samples), dim=1)
        marginal = class_probs.mean(dim=0)
    return entropy(marginal.cpu().numpy())

#### Train Classifier for Metrics

In [None]:
mnist_classifier = train_classifier(epochs=5)

# Real reference set for FID: one shuffled batch of 1000 training images,
# normalised exactly like the classifier's own training data.
real_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./mnist_data",
        download=True,
        train=True,
        transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]),
    ),
    batch_size=1000,
    shuffle=True,
)

real_images = next(iter(real_loader))[0].to(device)
with torch.no_grad():
    real_feature_tensor = mnist_classifier.get_features(real_images)
    real_features = real_feature_tensor.cpu().numpy()

print(f"Classifier trained. Real features shape: {real_features.shape}")

#### Recreate and Evaluate All Five Models

In [None]:
# Hyper-parameters shared by the five models compared below.
comparison_T = 1_000          # steps for the discrete-time models (DDPM, latent, CFG)
comparison_lr = 1e-3
comparison_epochs = 50
comparison_batch_size = 256
num_eval_samples = 1_000      # generated samples scored per model

comparison_dataloader = torch.utils.data.DataLoader(
    datasets.MNIST("./mnist_data", download=True, train=True, transform=transform),
    shuffle=True,
    batch_size=comparison_batch_size,
)

# Metric rows keyed by model name; each model cell below adds one entry.
results = dict()

In [None]:
# MODEL 1: Baseline DDPM
# Train the unconditional DDPM baseline, draw `num_eval_samples` samples, score
# them and store the formatted metrics under results['Baseline DDPM'].
print("\n[1/5] Training Baseline DDPM...")
# Every model cell re-seeds so each model starts from the same RNG state.
torch.manual_seed(42)

baseline_model = DDPM(
    ScoreNet(lambda t: torch.ones(1).to(device)), 
    T=comparison_T
).to(device)

baseline_opt = torch.optim.Adam(baseline_model.parameters(), lr=comparison_lr)
baseline_sch = torch.optim.lr_scheduler.ExponentialLR(baseline_opt, 0.9999)

# NOTE(review): this is the only model cell that passes ema=True; the other
# model cells below pass no `ema` argument to their training functions.
baseline_loss = train(
    baseline_model,
    baseline_opt,
    baseline_sch,
    comparison_dataloader,
    epochs=comparison_epochs,
    device=device,
    ema=True,
    per_epoch_callback=None,
)

baseline_model.eval()
with torch.no_grad():
    baseline_samples = baseline_model.sample((num_eval_samples, 28 * 28)).cpu()
    # Map to [0, 1] and clamp (this is what `baseline_samples` holds afterwards) ...
    baseline_samples = ((baseline_samples + 1) / 2).clamp(0, 1)
    baseline_samples_img = baseline_samples.reshape(-1, 1, 28, 28)

    # ... then back to [-1, 1], the range the metric classifier was trained on.
    # Net effect of the two steps: the raw samples clamped to [-1, 1].
    baseline_samples_norm = (baseline_samples_img - 0.5) / 0.5
    baseline_samples_norm = baseline_samples_norm.to(device)
    
    is_mean, is_std = calculate_inception_score(baseline_samples_norm, mnist_classifier)
    # FID compares classifier features of these samples with `real_features`.
    fake_features = mnist_classifier.get_features(baseline_samples_norm).cpu().numpy()
    fid = calculate_fid(real_features, fake_features)
    quality_mean, quality_std = calculate_quality_score(baseline_samples_norm, mnist_classifier)
    diversity = calculate_diversity_score(baseline_samples_norm, mnist_classifier)
    
    results['Baseline DDPM'] = {
        'IS': f"{is_mean:.3f} ± {is_std:.3f}",
        'FID': f"{fid:.2f}",
        'Quality': f"{quality_mean:.3f} ± {quality_std:.3f}",
        'Diversity': f"{diversity:.3f}",
        # Loss of the last training minibatch only. The training objectives
        # presumably differ between model families, so compare this column with care.
        'Final Loss': f"{baseline_loss[-1]:.4f}"
    }

print(f"Baseline DDPM - IS: {is_mean:.3f}, FID: {fid:.2f}")

In [None]:
# MODEL 2: Direction 1 - Flow Matching
# Train the flow-matching model with the shared settings, then score
# `num_eval_samples` samples with the same metric pipeline as the baseline.
print("\n[2/5] Training Flow Matching Model...")
torch.manual_seed(42)

fm_model = FlowMatchingModel(
    ScoreNet(lambda t: torch.ones(1).to(device)).to(device)
).to(device)

fm_opt = torch.optim.Adam(fm_model.parameters(), lr=comparison_lr)
fm_sch = torch.optim.lr_scheduler.ExponentialLR(fm_opt, 0.9999)

fm_loss = train_flow_matching(
    fm_model,
    fm_opt,
    fm_sch,
    comparison_dataloader,
    epochs=comparison_epochs,
    device=device,
    per_epoch_callback=None,
)

fm_model.eval()
with torch.no_grad():
    # n_steps=50 presumably sets the number of integration steps, far fewer than
    # the `comparison_T` reverse steps used by the DDPM baseline — TODO confirm.
    fm_samples = fm_model.sample(nsamples=num_eval_samples, n_steps=50).cpu()
    # Round-trip through [0, 1] clamps the samples to [-1, 1] for the classifier.
    fm_samples = ((fm_samples + 1) / 2).clamp(0, 1)
    fm_samples_img = fm_samples.reshape(-1, 1, 28, 28)
    fm_samples_norm = (fm_samples_img - 0.5) / 0.5
    fm_samples_norm = fm_samples_norm.to(device)
    
    is_mean, is_std = calculate_inception_score(fm_samples_norm, mnist_classifier)
    fake_features = mnist_classifier.get_features(fm_samples_norm).cpu().numpy()
    fid = calculate_fid(real_features, fake_features)
    quality_mean, quality_std = calculate_quality_score(fm_samples_norm, mnist_classifier)
    diversity = calculate_diversity_score(fm_samples_norm, mnist_classifier)
    
    results['Direction 1: Flow Matching'] = {
        'IS': f"{is_mean:.3f} ± {is_std:.3f}",
        'FID': f"{fid:.2f}",
        'Quality': f"{quality_mean:.3f} ± {quality_std:.3f}",
        'Diversity': f"{diversity:.3f}",
        # Last-minibatch loss of this model's own objective.
        'Final Loss': f"{fm_loss[-1]:.4f}"
    }

print(f"Flow Matching - IS: {is_mean:.3f}, FID: {fid:.2f}")

In [None]:
# MODEL 3: Direction 2 - Latent Diffusion Model
# Stage 1: train an autoencoder with latent size `ae_latent_dim` using an MSE
# reconstruction loss. The latent diffusion model in the next cell is built on it.
print("\n[3/5] Training Latent Diffusion Model...")
torch.manual_seed(42)

ae_latent_dim = 32
ae_epochs = 30

autoencoder = Autoencoder(latent_dim=ae_latent_dim).to(device)
ae_opt = torch.optim.Adam(autoencoder.parameters(), lr=comparison_lr)

for epoch in range(ae_epochs):
    autoencoder.train()
    total_loss = 0
    for batch_idx, (data, _) in enumerate(comparison_dataloader):
        data = data.to(device)
        ae_opt.zero_grad()
        # Second output (presumably the latent code) is not used here.
        reconstructed, _ = autoencoder(data)
        # Compare against the batch viewed as (N, 1, 28, 28) images.
        data_img = data.reshape(-1, 1, 28, 28)
        loss = F.mse_loss(reconstructed, data_img)
        loss.backward()
        ae_opt.step()
        total_loss += loss.item()
    # Report the mean per-batch loss every 10 epochs.
    # NOTE(review): the autoencoder is left in train() mode after this loop; if it
    # contains dropout/batch-norm, consider autoencoder.eval() before reuse.
    if (epoch + 1) % 10 == 0:
        print(f"  Autoencoder epoch {epoch+1}/{ae_epochs}, Loss: {total_loss/len(comparison_dataloader):.4f}")

In [None]:
# Stage 2: train a DDPM over the autoencoder's latent space and evaluate it.
latent_ddpm = LatentDDPM(
    autoencoder,
    LatentUNet(latent_dim=ae_latent_dim).to(device),
    T=comparison_T
).to(device)

# Only the latent network's parameters are given to the optimizer, so the
# autoencoder's weights are not updated in this stage.
latent_opt = torch.optim.Adam(latent_ddpm.network.parameters(), lr=comparison_lr)
latent_sch = torch.optim.lr_scheduler.ExponentialLR(latent_opt, 0.9999)

latent_loss = train_latent_ddpm(
    latent_ddpm,
    latent_opt,
    latent_sch,
    comparison_dataloader,
    epochs=comparison_epochs,
    device=device,
    per_epoch_callback=None,
)

latent_ddpm.eval()
with torch.no_grad():
    # Presumably returns decoded images in [-1, 1] — TODO confirm in LatentDDPM.sample.
    latent_samples = latent_ddpm.sample(num_eval_samples).cpu()
    # Round-trip through [0, 1] clamps the samples to [-1, 1] for the classifier.
    latent_samples = ((latent_samples + 1) / 2).clamp(0, 1)
    latent_samples_img = latent_samples.reshape(-1, 1, 28, 28)
    latent_samples_norm = (latent_samples_img - 0.5) / 0.5
    latent_samples_norm = latent_samples_norm.to(device)
    
    is_mean, is_std = calculate_inception_score(latent_samples_norm, mnist_classifier)
    fake_features = mnist_classifier.get_features(latent_samples_norm).cpu().numpy()
    fid = calculate_fid(real_features, fake_features)
    quality_mean, quality_std = calculate_quality_score(latent_samples_norm, mnist_classifier)
    diversity = calculate_diversity_score(latent_samples_norm, mnist_classifier)
    
    results['Direction 2: Latent Diffusion'] = {
        'IS': f"{is_mean:.3f} ± {is_std:.3f}",
        'FID': f"{fid:.2f}",
        'Quality': f"{quality_mean:.3f} ± {quality_std:.3f}",
        'Diversity': f"{diversity:.3f}",
        # Last-minibatch loss, measured in latent space for this model.
        'Final Loss': f"{latent_loss[-1]:.4f}"
    }

print(f"Latent Diffusion - IS: {is_mean:.3f}, FID: {fid:.2f}")

In [None]:
# MODEL 4: Direction 3 - Continuous-Time SDE (VP-SDE)
# Train the continuous-time VP-SDE model and score it like the other models.
print("\n[4/5] Training VP-SDE Model...")
torch.manual_seed(42)

# Time runs over [0, T=1.0] here (continuous), unlike the discrete `comparison_T`
# used elsewhere; beta_min/beta_max presumably bound the noise schedule.
vpsde_model = VPSDE(
    ScoreNet(lambda t: torch.ones(1).to(device)),
    beta_min=0.1,
    beta_max=20.0,
    T=1.0
).to(device)

vpsde_opt = torch.optim.Adam(vpsde_model.parameters(), lr=comparison_lr)
vpsde_sch = torch.optim.lr_scheduler.ExponentialLR(vpsde_opt, 0.9999)

vpsde_loss = train_sde(
    vpsde_model,
    vpsde_opt,
    vpsde_sch,
    comparison_dataloader,
    epochs=comparison_epochs,
    device=device,
    per_epoch_callback=None,
)

vpsde_model.eval()
with torch.no_grad():
    # n_steps=100 presumably sets the reverse-time discretisation — TODO confirm.
    vpsde_samples = vpsde_model.sample((num_eval_samples, 28 * 28), n_steps=100).cpu()
    # Round-trip through [0, 1] clamps the samples to [-1, 1] for the classifier.
    vpsde_samples = ((vpsde_samples + 1) / 2).clamp(0, 1)
    vpsde_samples_img = vpsde_samples.reshape(-1, 1, 28, 28)
    vpsde_samples_norm = (vpsde_samples_img - 0.5) / 0.5
    vpsde_samples_norm = vpsde_samples_norm.to(device)
    
    is_mean, is_std = calculate_inception_score(vpsde_samples_norm, mnist_classifier)
    fake_features = mnist_classifier.get_features(vpsde_samples_norm).cpu().numpy()
    fid = calculate_fid(real_features, fake_features)
    quality_mean, quality_std = calculate_quality_score(vpsde_samples_norm, mnist_classifier)
    diversity = calculate_diversity_score(vpsde_samples_norm, mnist_classifier)
    
    results['Direction 3: VP-SDE'] = {
        'IS': f"{is_mean:.3f} ± {is_std:.3f}",
        'FID': f"{fid:.2f}",
        'Quality': f"{quality_mean:.3f} ± {quality_std:.3f}",
        'Diversity': f"{diversity:.3f}",
        # Last-minibatch loss of this model's own objective.
        'Final Loss': f"{vpsde_loss[-1]:.4f}"
    }

print(f"VP-SDE - IS: {is_mean:.3f}, FID: {fid:.2f}")

In [None]:
# MODEL 5: Direction 4 - Classifier-Free Guidance
# Train a CFG model with the shared settings, then score an equal number of
# samples per digit class.
print("\n[5/5] Training Classifier-Free Guidance Model...")
torch.manual_seed(42)

# Same data as `comparison_dataloader`; CFG additionally consumes the labels.
labeled_dataloader = torch.utils.data.DataLoader(
    datasets.MNIST("./mnist_data", download=True, train=True, transform=transform),
    batch_size=comparison_batch_size,
    shuffle=True,
)

# NOTE: this rebinds `cfg_network` / `cfg_model` from the earlier CFG section,
# so re-running cells above that use `cfg_model` will pick up this model.
cfg_network = ConditionalScoreNet(
    marginal_prob_std=lambda t: torch.ones(1).to(device),
    num_classes=10
).to(device)

cfg_model = ClassifierFreeDDPM(
    cfg_network,
    num_classes=10,
    T=comparison_T,
    p_uncond=0.1,
    guidance_scale=2.0
).to(device)

cfg_opt = torch.optim.Adam(cfg_model.parameters(), lr=comparison_lr)
cfg_sch = torch.optim.lr_scheduler.ExponentialLR(cfg_opt, 0.9999)

cfg_loss = train_cfg(
    cfg_model,
    cfg_opt,
    cfg_sch,
    labeled_dataloader,
    epochs=comparison_epochs,
    device=device,
    per_epoch_callback=None,
)

cfg_model.eval()
with torch.no_grad():
    samples_per_class = num_eval_samples // 10
    all_samples = []

    for class_idx in range(10):
        class_labels = torch.full((samples_per_class,), class_idx, dtype=torch.long).to(device)
        class_samples = cfg_model.sample(
            (samples_per_class, 28 * 28),
            y=class_labels,
            guidance_scale=2.0
        ).cpu()
        all_samples.append(class_samples)

    # Interleave classes (0, 1, ..., 9, 0, 1, ...) instead of concatenating
    # class-sorted blocks. calculate_inception_score estimates p(y) on contiguous
    # chunks; with sorted samples every chunk holds a single class, and IS
    # collapses to ~1 regardless of sample quality. FID, quality and diversity
    # do not depend on sample order.
    cfg_samples = torch.stack(all_samples, dim=1).reshape(-1, 28 * 28)
    # Round-trip through [0, 1] clamps the samples to [-1, 1] for the classifier.
    cfg_samples = ((cfg_samples + 1) / 2).clamp(0, 1)
    cfg_samples_img = cfg_samples.reshape(-1, 1, 28, 28)
    cfg_samples_norm = (cfg_samples_img - 0.5) / 0.5
    cfg_samples_norm = cfg_samples_norm.to(device)

    is_mean, is_std = calculate_inception_score(cfg_samples_norm, mnist_classifier)
    fake_features = mnist_classifier.get_features(cfg_samples_norm).cpu().numpy()
    fid = calculate_fid(real_features, fake_features)
    quality_mean, quality_std = calculate_quality_score(cfg_samples_norm, mnist_classifier)
    # The labels are forced to be uniform over the 10 classes, so this score is
    # driven by that choice and is not directly comparable to the unconditional models.
    diversity = calculate_diversity_score(cfg_samples_norm, mnist_classifier)

    results['Direction 4: Classifier-Free Guidance'] = {
        'IS': f"{is_mean:.3f} ± {is_std:.3f}",
        'FID': f"{fid:.2f}",
        'Quality': f"{quality_mean:.3f} ± {quality_std:.3f}",
        'Diversity': f"{diversity:.3f}",
        'Final Loss': f"{cfg_loss[-1]:.4f}"
    }

print(f"Classifier-Free Guidance - IS: {is_mean:.3f}, FID: {fid:.2f}")

In [None]:
# One row per model, one column per metric (values are preformatted strings).
df_results = pd.DataFrame.from_dict(results, orient="index")
df_results