In [1]:
# Imports
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
from modules.d3pm.dit import DDiT_Llama

# Model configuration
N = 27  # text8 has 26 letters + 1 for padding/special token (vocab size)
dim = 256
n_layers = 6
n_heads = 8
multiple_of = 256
ffn_dim_multiplier = None
norm_eps = 1e-5
learn_gating = False

# Instantiate model
model = DDiT_Llama(
    N=N,
    dim=dim,
    n_layers=n_layers,
    n_heads=n_heads,
    multiple_of=multiple_of,
    ffn_dim_multiplier=ffn_dim_multiplier,
    norm_eps=norm_eps,
    learn_gating=learn_gating
)

In [3]:
# Print number of parameters
num_params = sum(p.numel() for p in model.parameters())
num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {num_params:,}")
print(f"Trainable parameters: {num_trainable_params:,}")

Total parameters: 7,769,627
Trainable parameters: 7,769,627


In [None]:

# Create dummy inputs for forward pass
batch_size = 4
seq_len = 128
x = torch.randint(0, N, (batch_size, seq_len))  # Input token indices
t = torch.randint(0, 1000, (batch_size,))  # Diffusion timesteps

# Character mapping for text8 (a-z = 0-25, space = 26)
chars = list("abcdefghijklmnopqrstuvwxyz ")
idx_to_char = {i: c for i, c in enumerate(chars)}

# Sample from dataset and mask random positions
# Note: Using a simple placeholder dataset for now (will use text8_indices after loading)
# For demonstration, create a simple dataset from repeated alphabet
sample_text = "the quick brown fox jumps over the lazy dog " * 20
sample_indices = torch.tensor([chars.index(c) for c in sample_text[:seq_len * batch_size]], dtype=torch.long)
sample_indices = sample_indices.view(batch_size, seq_len)

# Mask a percentage of random positions
mask_percent = 0.15  # 15% of positions will be masked
mask_token = N - 1  # Use last token index as mask token (or could use a special value)

# Create random mask
mask = torch.rand(batch_size, seq_len) < mask_percent
x = sample_indices.clone()
x[mask] = mask_token  # Replace masked positions with mask token

# Run forward pass
with torch.no_grad():
    output = model(x, t)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    
    # Convert original indices to text
    original_text = [''.join([idx_to_char.get(idx.item(), '?') for idx in seq]) for seq in sample_indices]
    print(f"\nOriginal (as text):")
    for i, txt in enumerate(original_text):
        print(f"  Sample {i}: '{txt[:50]}...'")
    
    # Convert masked input indices to text (show masked positions as '_')
    input_text = [''.join(['_' if idx.item() == mask_token else idx_to_char.get(idx.item(), '?') for idx in seq]) for seq in x]
    print(f"\nMasked Input (as text, '_' = masked):")
    for i, txt in enumerate(input_text):
        print(f"  Sample {i}: '{txt[:50]}...'")
    
    # Convert output logits to text (argmax to get predicted tokens)
    output_indices = output.argmax(dim=-1)
    output_text = [''.join([idx_to_char.get(idx.item(), '?') for idx in seq]) for seq in output_indices]
    print(f"\nOutput (as text, argmax of logits):")
    for i, txt in enumerate(output_text):
        print(f"  Sample {i}: '{txt[:50]}...'")

Input shape: torch.Size([4, 128])
Output shape: torch.Size([4, 128, 27])

Original (as text):
  Sample 0: 'the quick brown fox jumps over the lazy dog the qu...'
  Sample 1: 'dog the quick brown fox jumps over the lazy dog th...'
  Sample 2: 'azy dog the quick brown fox jumps over the lazy do...'
  Sample 3: 'he lazy dog the quick brown fox jumps over the laz...'

Masked Input (as text, '_' = masked):
  Sample 0: 'the_quic__b_own_fox_jumps___er__he_laz___og_t___qu...'
  Sample 1: 'dog_the_quic__brown_fox_jumps_over_the_laz___og_th...'
  Sample 2: 'az__dog_the_quick_br____fox_jumps_over__he_lazy_do...'
  Sample 3: 'he_laz__dog_th__q_ick__ro_n_fox_jum_s_ov_r_the_laz...'

Output (as text, argmax of logits):
  Sample 0: 'the quic  b own fox jumps   er  he laz   og t   qu...'
  Sample 1: 'dog the quic  brown fox jumps over the laz   og th...'
  Sample 2: 'az  dog the quick br    fox jumps over  he lazy do...'
  Sample 3: 'he laz  dog th  q ick  ro n fox jum s ov r the laz...'


In [None]:
import os
import zipfile
import urllib.request

# Download text8 dataset
data_dir = "data"
os.makedirs(data_dir, exist_ok=True)

text8_path = os.path.join(data_dir, "text8")
if not os.path.exists(text8_path):
    zip_path = os.path.join(data_dir, "text8.zip")
    if not os.path.exists(zip_path):
        print("Downloading text8 dataset...")
        urllib.request.urlretrieve("http://mattmahoney.net/dc/text8.zip", zip_path)
    print("Extracting...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(data_dir)
    os.remove(zip_path)
    print("Done!")
else:
    print("Text8 dataset already exists, skipping download.")

# Load text8 data
with open(text8_path, 'r') as f:
    text8_data = f.read()

print(f"Text8 dataset size: {len(text8_data)} characters")
print(f"Sample: {text8_data[:100]}")

# Create character to index mapping (a-z = 0-25, space = 26)
chars = list("abcdefghijklmnopqrstuvwxyz ")
char_to_idx = {c: i for i, c in enumerate(chars)}
idx_to_char = {i: c for c, i in char_to_idx.items()}

# Convert text to indices
text8_indices = torch.tensor([char_to_idx[c] for c in text8_data], dtype=torch.long)
print(f"Text8 indices shape: {text8_indices.shape}")

Text8 dataset size: 100000000 characters
Sample:  anarchism originated as a term of abuse first used against early working class radicals including t
Text8 indices shape: torch.Size([100000000])


In [4]:
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset

# Create train/test splits
# We'll split the text8 data into sequences of length seq_len
# Use 90% for training, 10% for testing

# First, we need to wait for text8_indices to be loaded (cell 4)
# For now, define the split function that will be used after loading

seq_len = 256  # Length of each sequence

# Reshape text8_indices into sequences
num_sequences = len(text8_indices) // seq_len
text8_sequences = text8_indices[:num_sequences * seq_len].view(num_sequences, seq_len)

# Split into train/test (90/10)
train_size = int(0.9 * num_sequences)
train_sequences = text8_sequences[:train_size]
test_sequences = text8_sequences[train_size:]

print(f"Total sequences: {num_sequences}")
print(f"Train sequences: {len(train_sequences)}")
print(f"Test sequences: {len(test_sequences)}")

# Create datasets and dataloaders
train_dataset = TensorDataset(train_sequences)
test_dataset = TensorDataset(test_sequences)

Total sequences: 390625
Train sequences: 351562
Test sequences: 39063


 ---
 **Cosine schedule:**
 
 Cosine-style keep mass:
 
 $$ \bar{\alpha}_t = \varepsilon + (1 - \varepsilon) \cos^2 \left( \frac{\pi}{2} \frac{t}{T} \right) $$
 
 Then again:
 
 $$ \beta_t = 1 - \frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}} $$

In [5]:
import numpy as np

def cosine_schedule(T, eps=1e-4):
    """
    Cosine noise schedule for diffusion models.
    
    Computes:
        alpha_bar_t = eps + (1 - eps) * cos^2(pi/2 * t/T)
        beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
        alpha_t = 1 - beta_t
    
    Args:
        T: number of timesteps
        eps: small constant to prevent alpha_bar from reaching 0
    
    Returns:
        betas: tensor of shape (T,) with beta values
        alphas: tensor of shape (T,) with alpha values
        alpha_bars: tensor of shape (T,) with cumulative alpha values
    """
    # TODO: Create timestep indices from 0 to T-1
    t = np.arange(0, T)
    
    # TODO: Compute alpha_bar_t = eps + (1 - eps) * cos^2(pi/2 * t/T)
    alpha_bar_t = eps + (1 - eps) * np.cos(np.pi / 2 * t / T) ** 2
    
    # TODO: Compute beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    # Hint: Handle t=0 case separately (beta_0 = 1 - alpha_bar_0)
    alpha_bar_t_prev = np.roll(alpha_bar_t, 1)
    beta_t = 1 - alpha_bar_t / alpha_bar_t_prev
    beta_t[0] = 1 - alpha_bar_t[0]
    
    # TODO: Clamp betas to valid range [0, 1)
    beta_t = np.clip(beta_t, 0, 1)
    
    # Compute alpha_t = 1 - beta_t
    alpha_t = 1 - beta_t
    
    # TODO: Return betas, alphas, and alpha_bars
    return beta_t, alpha_t, alpha_bar_t

 ---

 **Absorbing transition kernel**:

 For $k \neq MASK$:
 $$ Q_t[k,j] = (1-\beta_t) \mathbf{1}[j=k] + \beta_t \cdot \mathbf{1}[j=\text{MASK}] $$

 For $k = MASK$:
 $$ Q_t[MASK, j] = 1[j=MASK] $$

 *Stationary distribution:*
  $$ \pi = \delta_{MASK} \quad (\pi_{MASK}=1,\; \pi_k=0\;\text{for } k\neq MASK). $$
 
  > **Note:** $\delta_{MASK}$ is the **Dirac delta distribution** centered at the MASK token—a "one-hot" vector where all probability mass is on the MASK state. This makes sense because the absorbing kernel always transitions tokens toward MASK and never leaves MASK once reached, so at equilibrium, everything ends up in the MASK state.


In [6]:
def absorbing_kernel(beta_t, K, mask_token=None):
    """
    Compute the absorbing transition matrix Q_t for D3PM.

    The absorbing kernel transitions tokens to a special MASK token with probability beta_t,
    and stays in place with probability (1 - beta_t). Once in the MASK state, it stays there.

    For k != MASK:
        Q_t[k, j] = (1 - beta_t) * 1[j=k] + beta_t * 1[j=MASK]

    For k = MASK:
        Q_t[MASK, j] = 1[j=MASK]

    Args:
        beta_t: Noise level at timestep t (scalar, 0 < beta_t < 1)
        K: Number of discrete states/tokens (vocabulary size, including MASK token)
        mask_token: Index of the MASK token (default: K-1, the last token)

    Returns:
        Q_t: Transition matrix of shape [K, K] where Q_t[k, j] = P(x_t = j | x_{t-1} = k)
              Each row sums to 1.
        pi: Stationary distribution of shape [K] (delta at MASK)
    """
    # Set default mask_token to K-1 if not provided
    if mask_token is None:
        mask_token = K - 1

    Q_t = np.zeros((K, K))

    # TODO: : Implement transition rule
    # [1-β   0    0   ...  β ]
    # [ 0   1-β   0   ...  β ]
    # [ 0    0   1-β  ...  β ]
    # ...
    # [ 0    0    0   ...  1 ]

    # - For non-MASK tokens stay in place with probability (1 - beta_t): Q_t[k, k] = (1 - beta_t)
    idx = np.arange(K)
    non_mask = idx != mask_token
    Q_t[idx[non_mask], idx[non_mask]] = 1 - beta_t

    # - For MASK tokens transition to MASK with probability beta_t: Q_t[k, MASK] = beta_t
    Q_t[idx[non_mask], idx[mask_token]] = beta_t

    # - For MASK token stays in MASK state with probability 1 (0 otherwise)
    Q_t[mask_token, idx[non_mask]] = 0
    Q_t[mask_token, idx[mask_token]] = 1

    # TODO: Create stationary distribution pi = delta_{MASK} of being in each state K
    pi = np.zeros(K)
    pi[mask_token] = 1

    # TODO: Return the transition matrix Q_t and stationary distribution pi
    return Q_t, pi

  ---

  **Training Steps for D3PM:**
  1. Sample a batch of clean data $x_0$ from the dataset
  2. Sample random timesteps $t \sim \text{Uniform}(1, T)$ for each sample in the batch
  3. Corrupt $x_0$ to $x_t$ using the cumulative transition matrix $\bar{Q}_t$: $p(x_t | x_0) = x_0 \cdot \bar{Q}_t$
  4. Pass $x_t$ and $t$ to the model to predict $p(x_0 | x_t)$
  5. Compute cross-entropy loss between predicted $x_0$ distribution and true $x_0$
  6. Backpropagate and update model parameters

In [7]:
def train_step(model, x_0, optimizer, time_steps, Q, Q_bar, device):
    """
    Single training step for D3PM (Discrete Denoising Diffusion Probabilistic Model).

    In D3PM, we work with discrete tokens instead of continuous values.
    The forward process corrupts discrete states using transition matrices Q.

    Args:
        model: The D3PM model that predicts p(x_0 | x_t)
        x_0: Clean discrete tokens from dataset [batch_size, seq_len] (integer values 0 to K-1)
        optimizer: The optimizer
        time_steps: Total number of diffusion steps T
        Q: Transition matrices for each timestep [T, K, K] where Q[t, i, j] = p(x_t=j | x_{t-1}=i)
        Q_bar: Cumulative transition matrices [T, K, K] where Q_bar[t] = Q_1 @ Q_2 @ ... @ Q
        device: Device to run on

    Returns:
        loss value
    """
    optimizer.zero_grad()

    batch_size = x_0.shape[0]
    K = Q_bar.shape[-1]

    # TODO: Sample a batch of random timesteps t ~ Uniform(1, T)
    t = torch.randint(1, time_steps, (batch_size,), device=x_0.device)

    # TODO: Compute corrupted tokens x_t by sampling from categorical distribution
    # p(x_t | x_0) = x_0 @ Q_bar_t (one-hot x_0 times transition matrix) [B, S, K]
    # Hint 1: Convert x_0 to one-hot probability distribution
    # Hint 2.1: To multiply by Q_bar_t use torch.einsum for efficent BMM
    # Hint 2.2: we multiply prob distribution x_0 by transition matrix Q_bar to get a prob distribution over x_t
    # Hint 3: Sample x_t from p(x_t | x_0)
    x0_onehot = F.one_hot(x_0, num_classes=K).float()  # [B, S, K]
    p_xt = torch.einsum('bsk,bkj->bsj', x0_onehot, Q_bar[t])
    xt = torch.distributions.Categorical(probs=p_xt).sample()

    # TODO: Predict logits of p(x_0 | x_t) using model
    # Model outputs logits for each position predicting the original clean token
    p_x0 = model(xt, t).transpose(1,2) # [B, S, K] -> [B, K, S]

    # TODO: Compute cross-entropy loss between predicted x_0 and true x_0
    loss = F.cross_entropy(p_x0, x_0)

    loss.backward()
    optimizer.step()

    return loss.item()

In [8]:
def eval_step(model, x_0, time_steps, Q, Q_bar, device):
    """
    Evaluate D3PM model on a single batch and return the loss.

    Args:
        model: The D3PM model that predicts p(x_0 | x_t)
        x_0: Clean discrete tokens from dataset [batch_size, seq_len] (integer values 0 to K-1)
        time_steps: Total number of diffusion steps T
        Q: Transition matrices for each timestep [T, K, K]
        Q_bar: Cumulative transition matrices [T, K, K]
        device: Device to run on

    Returns:
        loss value
    """
    model.eval()

    batch_size = x_0.shape[0]
    K = Q_bar.shape[-1]

    # Sample a batch of random timesteps t ~ Uniform(1, T)
    t = torch.randint(1, time_steps, (batch_size,), device=device)

    # Compute corrupted tokens x_t by sampling from categorical distribution
    # p(x_t | x_0) = x_0 @ Q_bar_t (one-hot x_0 times transition matrix)
    x0_onehot = F.one_hot(x_0, num_classes=K).float()  # [B, S, K]
    Q_bar_t = Q_bar[t]  # [B, K, K]
    p_x_t = torch.einsum('bsk,bkj->bsj', x0_onehot, Q_bar_t)  # [B, S, K]
    x_t = torch.distributions.Categorical(probs=p_x_t).sample()

    # Predict p(x_0 | x_t) using model
    with torch.no_grad():
        logits_p_x_0 = model(x_t, t)

    # Compute cross-entropy loss between predicted x_0 and true x_0
    loss = F.cross_entropy(logits_p_x_0.transpose(1, 2), x_0)

    return loss.item()

  **Inference Steps for D3PM:**
  
  1. Start with $x_{t_{start}}$:
     - If starting from pure noise ($t_{start} = T$): $x_T \sim \text{Uniform}(1, K)$
     - If starting from partially noised data: provide $x_{t_{start}}$ directly
  2. For $t = t_{start}, t_{start}-1, \ldots, 1$:
     - Predict $\pi_i = p_\theta(x_0 = i | x_t, t)$ using the neural network
     - Compute conditional posterior: $q(x_{t-1}=k | x_t=j, x_0=i) = \frac{Q_t[k,j] \cdot \bar{Q}_{t-1}[i,k]}{\bar{Q}_t[i,j]}$
     - Marginalize: $p_\theta(x_{t-1}=k | x_t=j) = \sum_{i=1}^K \pi_i \cdot q(x_{t-1}=k | x_t=j, x_0=i)$
     - Sample $x_{t-1} \sim p_\theta(x_{t-1} | x_t)$
  3. Return $x_0$

In [9]:
@torch.no_grad()
def sample_d3pm(
    model,
    Q,
    Q_bar,
    device,
    init_probs,               # [K] initial probability distribution (defined upstream)
    *,
    x_t=None,                 # [B,S] int
    t_start=None,             # int
    shape=None,               # (B,S) if x_t is None
    return_intermediates=False,
    intermediate_steps=None,
    eps=1e-10,
):
    was_training = model.training
    model.eval()

    Q = Q.to(device=device)
    Q_bar = Q_bar.to(device=device)
    init_probs = init_probs.to(device=device)

    T = int(Q.shape[0])
    K = int(Q.shape[-1])

    if x_t is None:
        if shape is None:
            raise ValueError("Provide shape=(B,S) when x_t is None.")
        t_start = T - 1 if t_start is None else int(t_start)
        x_t = torch.distributions.Categorical(probs=init_probs).sample(shape)
    else:
        x_t = x_t.to(device=device)
        if t_start is None:
            raise ValueError("Provide t_start when x_t is given.")
        t_start = int(t_start)

    if not (0 <= t_start < T):
        raise ValueError(f"t_start must be in [0, {T-1}], got {t_start}.")

    B = int(x_t.shape[0])

    if return_intermediates:
        steps = [] if intermediate_steps is None else list(intermediate_steps)
        step_set = set(steps)
        x0_by_step = {}
        x_tm1_by_step = {}

    for t in reversed(range(1, t_start + 1)):
        t_batch = torch.full((B,), t, dtype=torch.long, device=device)

        # TODO: Get model prediction for p(x_0 | x_t)
        logits_p_x_0 = model(x_t, t_batch) # [B, S, K]
        pi = F.softmax(logits_p_x_0, dim=-1)                # [B,S,K] softmax over logits

        # TODO: Get transition matrices for current timestep
        # Q_t = ...      # [K,K]  (k -> j)
        Q_t = Q[t]
        # Qbar_t = ...   # [K,K]  (i -> j)
        Q_bar_t = Q_bar[t]
        # Qbar_tm1 = ... # [K,K]  (i -> k)
        Q_bar_tm1 = Q_bar[t - 1]

        # TODO: Compute posterior q(x_{t-1}=k | x_t=j, x_0=i)
        # for b in range(B):
        #     for s in range(S):
        #         j = x_t[b,s] # x_t
        #         for i in range(K): # x_0
        #             for k in range(K): # x_{t-1}
        #                 posterior[b,s,i,k] = Q_t[k,j] * Q_bar_tm1[i,k] / Q_bar_t[i,j]

        # Vectorized computation of posterior q(x_{t-1}=k | x_t=j, x_0=i)
        # x_t: [B, S] -> j indices
        # Q_t[k, j]: transition from k to j at time t
        # Q_bar_tm1[i, k]: cumulative transition from i to k at time t-1
        # Q_bar_t[i, j]: cumulative transition from i to j at time t
        
        j = x_t  # [B, S]
        Q_t_kj = Q_t[:, j]  # [K, B, S] -> Q_t[k, x_t[b,s]]
        Q_bar_t_ij = Q_bar_t[:, j]  # [K, B, S] -> Q_bar_t[i, x_t[b,s]]
        
        # posterior[b,s,i,k] = Q_t[k, j] * Q_bar_tm1[i, k] / Q_bar_t[i, j]
        # Q_t_kj: [K, B, S] -> need [B, S, 1, K]
        # Q_bar_tm1: [K, K] (i, k) -> need [1, 1, K, K]
        # Q_bar_t_ij: [K, B, S] -> need [B, S, K, 1]
        
        Q_t_kj = Q_t_kj.permute(1, 2, 0).unsqueeze(2)  # [B, S, 1, K]
        Q_bar_t_ij = Q_bar_t_ij.permute(1, 2, 0).unsqueeze(-1)  # [B, S, K, 1]
        Q_bar_tm1_ik = Q_bar_tm1.unsqueeze(0).unsqueeze(0)  # [1, 1, K, K]
        
        posterior = Q_t_kj * Q_bar_tm1_ik / (Q_bar_t_ij + eps)  # [B, S, K_x0, K_xtm1]

        # TODO: Compute p(x_{t-1} | x_t) by marginalizing over x_0
        # full_posterior = torch.empty(B,S,K)
        # for b in range(B):
        #     for s in range(S):
        #         for k in range(K): # x_{t-1}
        #             for i in range(K):
        #                 full_posterior[b,s,k] = (pi[b,s,i] * posterior[b,s,i,k]).sum()

        # Vectorized: pi[b,s,i] * posterior[b,s,i,k] summed over i
        # pi: [B, S, K] -> [B, S, K, 1]
        # posterior: [B, S, K_x0, K_xtm1] = [B, S, K, K]
        full_posterior = (pi.unsqueeze(-1) * posterior).sum(dim=2)  # [B, S, K]

        # TODO: Sample x_{t-1} from categorical distribution
        x_tm1 = torch.distributions.Categorical(probs=full_posterior).sample()

        if return_intermediates and t in step_set:
            x0_by_step[t] = pi.argmax(dim=-1).detach().clone()
            x_tm1_by_step[t] = x_tm1.detach().clone()

        x_t = x_tm1

    if return_intermediates and 0 in step_set:
        x0_by_step[0] = x_t.detach().clone()
        x_tm1_by_step[0] = x_t.detach().clone()

    if was_training:
        model.train()

    if return_intermediates:
        x0s = [x0_by_step[t] for t in steps]
        x_tm1s = [x_tm1_by_step[t] for t in steps]
        return x_t, x0s, x_tm1s

    return x_t

In [10]:
from tqdm import tqdm
from IPython.display import clear_output
import numpy as np
import os
import time

_chars = list("abcdefghijklmnopqrstuvwxyz ")
_idx_to_char = {i: c for i, c in enumerate(_chars)}


def _tokens_to_str(tokens):
    return ''.join(_idx_to_char.get(int(t), '?') for t in tokens.cpu())


def visualize_generation_samples(model, sample_fn, time_steps, epoch, n=5, seed=None):
    """Print n texts generated from pure noise."""
    if seed is not None:
        torch.manual_seed(seed)
    model.eval()
    with torch.no_grad():
        samples = sample_fn(model, n, time_steps)
    print(f"=== Generated Samples | Epoch {epoch} ===")
    for i, seq in enumerate(samples):
        print(f"  [{i+1}] {_tokens_to_str(seq)}")


def visualize_generation_timeline(model, sample_fn, time_steps, epoch, k=8, seed=None):
    """Print pred_x0 and x_{t-1} at k evenly-spaced denoising steps (T -> 0)."""
    if seed is not None:
        torch.manual_seed(seed)
    steps = sorted(set(np.linspace(0, time_steps - 1, k, dtype=int).tolist()), reverse=True)
    model.eval()
    with torch.no_grad():
        x0s, x_tm1s = sample_fn(
            model, 1, time_steps,
            return_intermediates=True,
            intermediate_steps=steps,
        )
    pad = len(str(time_steps - 1))
    print(f"=== Generation Timeline | Epoch {epoch} ===")
    for t, x0, x_tm1 in zip(steps, x0s, x_tm1s):
        print(f"  t={t:{pad}d}/{time_steps-1} | pred_x0: {_tokens_to_str(x0[0])}")
        print(f"  {' ' * (pad + 2 + len(str(time_steps-1)))}  | x_{{t-1}}: {_tokens_to_str(x_tm1[0])}")


def visualize_reconstruction_samples(
    model, sample_fn, Q_bar, time_steps, epoch, k=5, seed=None, dataset=None, t_noise=125
):
    """Print (original, noised, denoised) triplets for k dataset samples."""
    if seed is not None:
        torch.manual_seed(seed)
    indices = torch.randperm(len(dataset))[:k]
    x0 = torch.stack([dataset[i][0] for i in indices]).long()
    device = Q_bar.device
    K = Q_bar.shape[-1]
    x0 = x0.to(device)
    if x0.ndim > 2:
        x0 = x0.view(k, -1)
    t = min(t_noise, time_steps - 1)
    x0_onehot = F.one_hot(x0, num_classes=K).float()
    p_xt = torch.einsum('bsk,bkj->bsj', x0_onehot, Q_bar[t].unsqueeze(0).expand(k, -1, -1))
    x_noised = torch.distributions.Categorical(probs=p_xt).sample()
    model.eval()
    with torch.no_grad():
        denoised = sample_fn(model, k, t + 1, x_init=x_noised)
    print(f"=== Reconstruction Samples | Epoch {epoch} | t_noise={t} ===")
    for i in range(k):
        print(f"  [{i+1}] original : {_tokens_to_str(x0[i])}")
        print(f"  [{i+1}] noised   : {_tokens_to_str(x_noised[i])}")
        print(f"  [{i+1}] denoised : {_tokens_to_str(denoised[i])}")
        if i < k - 1:
            print()

In [15]:
def get_device():
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    return "cpu"

def train(
    model,
    train_dataloader,
    test_dataloader,
    num_epochs,
    learning_rate,
    time_steps,
    schedule_fn,
    transition_kernel_fn,
    *,
    recon_t_noise=125,
    refresh_every_k_epochs=5,
):

    device = get_device()
    model.to(device)

    betas, _, _ = schedule_fn(time_steps)
    betas = torch.as_tensor(betas, device=device, dtype=torch.float32)

    x0_example, = next(iter(train_dataloader))
    x0_example = x0_example.to(device)
    S = int(x0_example.shape[1]) if x0_example.ndim == 2 else int(np.prod(x0_example.shape[1:]))
    K = getattr(train_dataloader.dataset, "K", None)
    if K is None:
        K = int(x0_example.max().item()) + 1
    K = int(K)

    I = torch.eye(K, device=device, dtype=torch.float32)
    Q = torch.empty((time_steps, K, K), device=device, dtype=torch.float32)
    Q_bar = torch.empty_like(Q)
    Q[0] = I
    Q_bar[0] = I

    init_probs = None
    for t in range(1, time_steps):
        Q_t, pi = transition_kernel_fn(float(betas[t].item()), K)
        Q_t = torch.as_tensor(Q_t, device=device, dtype=torch.float32)
        Q[t] = Q_t
        Q_bar[t] = Q_bar[t - 1] @ Q_t
        if init_probs is None:
            init_probs = torch.as_tensor(pi, device=device, dtype=torch.float32)

    if init_probs is None:
        init_probs = torch.full((K,), 1.0 / K, device=device, dtype=torch.float32)

    def sample_fn(model, n_samples, num_steps, *, x_init=None, return_intermediates=False, intermediate_steps=None):
        if x_init is None:
            out = sample_d3pm(
                model, Q, Q_bar, device, init_probs,
                shape=(n_samples, S),
                t_start=num_steps - 1,
                return_intermediates=return_intermediates,
                intermediate_steps=intermediate_steps,
            )
        else:
            x_t = x_init.view(-1, S).long().to(device)
            out = sample_d3pm(
                model, Q, Q_bar, device, init_probs,
                x_t=x_t,
                t_start=num_steps - 1,
                return_intermediates=return_intermediates,
                intermediate_steps=intermediate_steps,
            )
        if return_intermediates:
            _, x0s, x_tm1s = out
            return x0s, x_tm1s
        return out

    num_params = sum(p.numel() for p in model.parameters())
    num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    model_class_name = model.__class__.__name__
    model_dir = f"data/experiments/text8/{model_class_name}"
    os.makedirs(model_dir, exist_ok=True)

    print(f"Model: {model_class_name} | Params: {num_params:,} (trainable: {num_trainable_params:,})")
    model.eval()
    visualize_generation_samples(model, sample_fn, time_steps, epoch=-1, seed=42)
    visualize_generation_timeline(model, sample_fn, time_steps, epoch=-1, seed=42)
    visualize_reconstruction_samples(
        model, sample_fn, Q_bar, time_steps, epoch=-1, k=5, seed=42,
        dataset=train_dataloader.dataset,
        t_noise=min(int(recon_t_noise), time_steps - 1),
    )

    for epoch in tqdm(range(num_epochs), desc="Epochs"):
        epoch_start_time = time.time()
        total_loss = 0
        num_batches = 0

        model.train()
        for x_0, in tqdm(train_dataloader, desc=f"Epoch {epoch} Train", leave=False):
            x_0 = x_0.to(device)
            if x_0.ndim > 2:
                x_0 = x_0.view(x_0.shape[0], -1)
            loss = train_step(model, x_0.long(), optimizer, time_steps, Q, Q_bar, device)
            total_loss += loss
            num_batches += 1

        avg_train_loss = total_loss / max(num_batches, 1)

        model.eval()
        total_test_loss = 0
        num_test_batches = 0
        with torch.no_grad():
            for x_0, in tqdm(test_dataloader, desc=f"Epoch {epoch} Test", leave=False):
                x_0 = x_0.to(device)
                if x_0.ndim > 2:
                    x_0 = x_0.view(x_0.shape[0], -1)
                test_loss = eval_step(model, x_0.long(), time_steps, Q, Q_bar, device)
                total_test_loss += test_loss
                num_test_batches += 1

        avg_test_loss = total_test_loss / max(num_test_batches, 1)
        epoch_train_time = time.time() - epoch_start_time

        torch.save(model.state_dict(), f"{model_dir}/model_latest.pth")

        should_refresh = (
            refresh_every_k_epochs > 0
            and ((epoch + 1) % refresh_every_k_epochs == 0 or epoch == num_epochs - 1)
        )

        if should_refresh:
            clear_output(wait=True)
        print(f"Model: {model_class_name} | Params: {num_params:,} (trainable: {num_trainable_params:,})")
        print(
            f"Epoch {epoch}/{num_epochs-1} | "
            f"Train Loss: {avg_train_loss:.4f} | Test Loss: {avg_test_loss:.4f} | "
            f"Time: {epoch_train_time:.2f}s"
        )

        if should_refresh:
            visualize_generation_samples(model, sample_fn, time_steps, epoch, seed=42)
            visualize_generation_timeline(model, sample_fn, time_steps, epoch, k=8, seed=42)
            visualize_reconstruction_samples(
                model, sample_fn, Q_bar, time_steps, epoch, k=5, seed=42,
                dataset=train_dataloader.dataset,
                t_noise=min(int(recon_t_noise), time_steps - 1),
            )

In [None]:
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f"Train batches: {len(train_loader)}")
print(f"Test batches: {len(test_loader)}")

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

NUM_EPOCHS = 100
LEARNING_RATE = 1e-4
TIME_STEPS = 256
SCHEDULE_FN = cosine_schedule
TRANSITION_KERNEL_FN = absorbing_kernel
RECON_T_NOISE = 125
REFRESH_EVERY_K_EPOCHS = 5

train(
    model,
    train_dataloader,
    test_dataloader,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    time_steps=TIME_STEPS,
    schedule_fn=SCHEDULE_FN,
    transition_kernel_fn=TRANSITION_KERNEL_FN,
    recon_t_noise=RECON_T_NOISE,
    refresh_every_k_epochs=REFRESH_EVERY_K_EPOCHS,
)

Train batches: 5494
Test batches: 611
Model: DDiT_Llama | Params: 7,769,627 (trainable: 7,769,627)
=== Generated Samples | Epoch -1 ===
  [1] soujbhelureuyalhpllzsjnpinaqdbjpwdtjwsftxsbasrjjsedvavoshjxbdjvynucuwlafwihcbfjtrlakqmwmyfghweoyoxxgwahxoqhcswljgomjujbslmfrvrmdzklwrmvglqrvuiqvaxaqxgvpaqwqoypfutgiemccisgxsdpfvxrrrwysmchvofxtesqeslktvyigsdwzpaemyykmgdmfojvxibbogbwfsglgkxcqvufpkjqjinsejnqk
  [2] zeoagiwcghgxhhsstfuivanhgcnszkynghkjyamaewbcgevkaezsedbzwagtcbsjxnmxzdkanyjtorgaucfxhwhjpynmjggwxgnaozwdrzlmtkdxskbhhckgbllihovnmfrieepqypxwefdsguyoufwsqjnxjinucqirsynobzuyxecnpuzvzzttbmgmatrpdojbbajygaiytoaddtazsrvmlzluoecvaqyoeisbgqfugufzcgdsntitvfrbpmpd
  [3] ufyxwwnygjwilddujasjqnzqryauxdfkjbvifrzgtkzfpexjruayikmpcrqfzzqtqksxuogsnznfmrvjbusmfdoaqzsjjinpecmnycloamevzealzsapwqhaplxazfherqfypblsiwqqzivofemvtfyaawjubmeptwmlildullddtyamomuuffqcbegcnftqkypjaldpzbcohzcfmfklebqslhtzgmhfovuyxqmobabfdnrhflgoxtfmarhhfgzm
  [4] bvqnmyfvuxqpypvxsxjoafvdggauathyhqtdecfmfgcnmairtrenurihzdfcuiknureiv

Epochs:   0%|          | 0/100 [01:52<?, ?it/s]


KeyboardInterrupt: 