In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Pick the best available accelerator: CUDA first, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: mps


In [2]:
# We offload training and encoding to the Rust implementation
# Here we only need it for loading and decoding tokens

import json

class SimpleBPE:
    """Minimal, decode-only view of a BPE tokenizer stored as JSON.

    The JSON file is expected to contain a ``'vocab'`` object mapping token IDs
    (as strings) to token strings and, optionally, a ``'special_tokens'``
    object mapping special-token names to their IDs.
    """

    def __init__(self):
        self.vocab = {}           # token_id (int) -> token string
        self.special_tokens = {}  # special-token name -> token_id (int)

    def decode(self, token_ids):
        """Concatenate the token strings for ``token_ids`` into one text.

        Raises:
            KeyError: if an ID is not present in the loaded vocabulary.
        """
        return ''.join(self.vocab[i] for i in token_ids)

    def load(self, path):
        """Load the vocabulary and special tokens from ``path``.

        Args:
            path: Path to the tokenizer JSON file.

        Returns:
            ``self``, so that ``SimpleBPE().load(path)`` can be chained.
        """
        # Read explicitly as UTF-8: the platform default encoding is not
        # guaranteed to be UTF-8, and BPE vocabularies routinely contain
        # non-ASCII token strings.
        with open(path, encoding='utf-8') as f:
            data = json.load(f)
        # JSON object keys are always strings; convert IDs back to ints.
        self.vocab = {int(k): v for k, v in data['vocab'].items()}
        self.special_tokens = {k: int(v) for k, v in data.get('special_tokens', {}).items()}
        print(f"Tokenizer loaded from {path} (vocab size: {len(self.vocab)})")
        return self

In [3]:
# Notebook-wide tokenizer instance (load() returns the instance it was called on).
tokenizer = SimpleBPE()
tokenizer = tokenizer.load('./data/tokenizer_bpe.json')


def decode(ids):
    """Convert a sequence of token IDs to text with the notebook-wide tokenizer."""
    return tokenizer.decode(ids)

Tokenizer loaded from ./data/tokenizer_bpe.json (vocab size: 8192)


In [4]:
# Hyperparameters

block_size = 1024                        # sequence length passed to MDLM as block_size
final_vocab_size = len(tokenizer.vocab)  # vocabulary size taken from the loaded tokenizer
n_embd = 768                             # embedding / residual-stream width
n_head = 12                              # attention heads per block
n_blocks = 12                            # number of transformer blocks

# The mask token id is used as a tensor fill value and as a logit index further
# down. Fail here with a clear message instead of letting a placeholder string
# propagate and break those operations with an unrelated-looking error.
if '<MASK>' not in tokenizer.special_tokens:
    raise KeyError(
        "Tokenizer has no '<MASK>' special token; the masked diffusion model "
        "cannot run without a mask token id."
    )
MASK_TOKEN = tokenizer.special_tokens['<MASK>']

In [5]:
# time embedding

class SinusoidalTimeEmbedding(nn.Module):
    """Sinusoidal embedding of a scalar diffusion time.

    For an input ``t`` of shape ``[B]`` the output has shape
    ``[B, 2 * (dim // 2)]`` (i.e. ``[B, dim]`` for even ``dim``): the first
    half holds sines and the second half cosines of ``1000 * t * freq_k``,
    where the frequencies decay geometrically from 1 down to 1/10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Embed ``t`` (shape ``[B]``) into sinusoidal features."""
        device = t.device
        half_dim = self.dim // 2

        # Log-spacing between consecutive frequencies. With a single frequency
        # (dim of 2 or 3) ``half_dim - 1`` is 0, so clamp the divisor to avoid a
        # ZeroDivisionError; the lone frequency is exp(0) = 1 either way.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -log_step)
        # t is scaled by 1000 before being multiplied with the frequencies.
        angles = t.unsqueeze(1) * freqs.unsqueeze(0) * 1000  # [B, 1] * [1, half_dim] = [B, half_dim]

        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

In [6]:
# MHA
from torchtune.modules import RotaryPositionalEmbeddings
# torchtune's RoPE implementation expects input of shape [B, T, n_head, head_dim], which matches the layout used in MHA.forward below

class MHA(nn.Module):
    """Bidirectional multi-head self-attention with rotary position embeddings.

    Queries, keys and values come from one fused linear projection. RoPE is
    applied to queries and keys while they are laid out as
    ``[B, T, n_head, head_dim]``. No causal mask is applied, so every position
    attends to every other position.
    """

    def __init__(self, n_embd, n_head, dropout=0.1):
        super().__init__()
        assert n_embd % n_head == 0

        self.n_head = n_head
        self.n_embd = n_embd
        self.head_dim = n_embd // n_head

        # Fused q/k/v projection and the output projection.
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        self.c_proj = nn.Linear(n_embd, n_embd)

        self.rope = RotaryPositionalEmbeddings(dim=self.head_dim)

        self.attn_dropout = nn.Dropout(dropout)
        self.residual_dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``[B, T, n_embd]``; the shape is preserved."""
        batch, seq_len, channels = x.shape
        per_head = (batch, seq_len, self.n_head, self.head_dim)

        # [B, T, 3*n_embd] -> three tensors of [B, T, n_embd]
        q, k, v = torch.split(self.c_attn(x), self.n_embd, dim=2)

        # Rotate q and k in the [B, T, n_head, head_dim] layout, then move the
        # head axis forward: [B, n_head, T, head_dim].
        q = self.rope(q.view(per_head)).transpose(1, 2)
        k = self.rope(k.view(per_head)).transpose(1, 2)
        v = v.view(per_head).transpose(1, 2)

        # Scaled dot-product scores: [B, n_head, T, T]. No causal mask.
        scale = 1.0 / math.sqrt(self.head_dim)
        weights = F.softmax(torch.matmul(q, k.transpose(-2, -1)) * scale, dim=-1)
        weights = self.attn_dropout(weights)

        # [B, n_head, T, T] @ [B, n_head, T, head_dim] -> [B, n_head, T, head_dim],
        # then merge the heads back into the channel axis: [B, T, C].
        mixed = torch.matmul(weights, v).transpose(1, 2).contiguous().view(batch, seq_len, channels)

        # Output projection followed by dropout on the residual branch.
        return self.residual_dropout(self.c_proj(mixed))


In [7]:
class SwiGLU(nn.Module):
    """
    Gated SiLU activation over a pre-projected input.

    The last dimension of the input is split into two equal halves: the first
    half is the 'value' path and the second half is the 'gate' path. The
    output is ``value * SiLU(gate)``, so the last dimension is halved.

    Combined with a preceding linear layer that produces both halves, this
    implements SwiGLU(x) = (xW + b) * SiLU(xV + c).
    """
    def forward(self, x):
        # First half -> value, second half -> gate.
        value, gate = torch.chunk(x, 2, dim=-1)
        # Gate the value with the SiLU (Swish) of the gate path.
        return value * F.silu(gate)

In [8]:
# FFN

class FFN(nn.Module):
    """Position-wise feed-forward network with a SwiGLU activation.

    The input is projected up to ``8 * n_embd`` (value and gate halves),
    gated down to ``4 * n_embd`` by SwiGLU, projected back to ``n_embd`` and
    passed through dropout.
    """

    def __init__(self, n_embd, dropout=0.1):
        super().__init__()
        hidden = 4 * n_embd
        # Up-projection yields both SwiGLU halves, hence 2 * hidden outputs.
        self.c_fc = nn.Linear(n_embd, 2 * hidden)
        self.swiglu = SwiGLU()
        self.c_proj = nn.Linear(hidden, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # [B, T, C] -> c_fc -> [B, T, 8*C] -> swiglu -> [B, T, 4*C]
        gated = self.swiglu(self.c_fc(x))
        # [B, T, 4*C] -> c_proj -> [B, T, C], then dropout
        return self.dropout(self.c_proj(gated))

In [9]:
# Block

class Block(nn.Module):
    """Pre-norm transformer block with optional time conditioning.

    When ``use_time`` is set and a time embedding is passed to ``forward``,
    the output of each RMSNorm is modulated as ``norm(x) * (1 + scale) + shift``
    with per-sample ``shift``/``scale`` vectors predicted from the time
    embedding. Otherwise the block is a plain pre-norm attention + FFN block.
    """

    def __init__(self, n_embd, n_head, dropout=0.1, use_time=True):
        super().__init__()
        self.use_time = use_time
        self.rms_norm_1 = nn.RMSNorm(n_embd)
        self.rms_norm_2 = nn.RMSNorm(n_embd)
        self.attn = MHA(n_embd, n_head, dropout)
        self.ffn = FFN(n_embd, dropout)

        # RMSNorm only rescales its input; it has no additive (shift) term.
        # The time-conditioned shift and scale applied after each norm are how
        # the noise level (diffusion time) is injected into the block.
        if use_time:
            self.time_ffn = nn.Sequential(
                nn.Linear(n_embd, 2 * n_embd),
                SwiGLU(),  # halves the feature dimension: 2*n_embd -> n_embd
                nn.Linear(n_embd, 4 * n_embd),
            )

    def forward(self, x, time_emb=None):
        """Run the block on ``x`` ([B, T, C]); ``time_emb`` is [B, n_embd] or None."""
        # Unconditioned path: plain pre-norm residual sub-layers.
        if not self.use_time or time_emb is None:
            x = x + self.attn(self.rms_norm_1(x))
            return x + self.ffn(self.rms_norm_2(x))

        # [B, n_embd] -> [B, 4*n_embd] -> four [B, 1, n_embd] modulation vectors,
        # in the order shift1, scale1, shift2, scale2.
        shift1, scale1, shift2, scale2 = (
            part.unsqueeze(1) for part in self.time_ffn(time_emb).chunk(4, dim=-1)
        )

        # [B, T, C] * [B, 1, C] + [B, 1, C] = [B, T, C]
        attn_in = self.rms_norm_1(x) * (1 + scale1) + shift1
        x = x + self.attn(attn_in)

        ffn_in = self.rms_norm_2(x) * (1 + scale2) + shift2
        return x + self.ffn(ffn_in)  # [B, T, C]


In [10]:
# Full MDLM

class MDLM(nn.Module):
    """Masked diffusion language model: a stack of time-conditioned blocks.

    Token IDs are embedded, passed through ``n_block`` ``Block`` layers
    (optionally conditioned on an embedding of the diffusion time ``t``),
    normalised, and projected to vocabulary logits with a head whose weight is
    tied to the token embedding.

    Note: ``forward`` reads the module-level ``MASK_TOKEN`` to suppress the
    mask token's logit.
    """

    def __init__(
            self,
            vocab_size,
            n_embd,
            n_head,
            n_block,
            block_size,
            dropout=0.1,
            use_time = True
    ):
        super().__init__()

        # block_size is stored on the instance; it is not used elsewhere in this class.
        self.block_size = block_size
        self.use_time = use_time

        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        if use_time:
            # Scalar time -> sinusoidal features -> small SwiGLU MLP.
            self.time_emb = SinusoidalTimeEmbedding(n_embd)
            self.time_proj = nn.Sequential(
                nn.Linear(n_embd, 2 * n_embd),
                SwiGLU(),
                nn.Linear(n_embd, n_embd),
            )
        self.blocks = nn.ModuleList(
            [Block(n_embd, n_head, dropout, use_time) for _ in range(n_block)]
        )

        self.rms_norm_final = nn.RMSNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

        # Weight tying: the output head shares the token-embedding matrix.
        self.lm_head.weight = self.tok_emb.weight

        self.dropout = nn.Dropout(dropout)
        self.apply(self._init_weights)

        n_params = sum(p.numel() for p in self.parameters())
        print(f'Model has {n_params/1e6:.2f}M parameters.')

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights with N(0, 0.02), zero biases, unit RMSNorm gains."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.RMSNorm):
            torch.nn.init.ones_(module.weight)

    def forward(self, x, t=None):
        """Compute logits ``[B, T, V]`` for token IDs ``x`` ([B, T]) at time ``t`` ([B] or None)."""
        _batch, _seq_len = x.shape  # unpacking requires x to be 2-D

        h = self.dropout(self.tok_emb(x))  # [B, T, n_embd]

        # Time conditioning is used only when enabled and a time is supplied.
        t_emb = None
        if self.use_time and t is not None:
            t_emb = self.time_proj(self.time_emb(t))  # [B, n_embd]

        for block in self.blocks:
            h = block(h, t_emb)

        logits = self.lm_head(self.rms_norm_final(h))  # [B, T, V]

        # Make sure model doesn't learn the behavior of predicting a mask token. Aligns with inference
        logits[:, :, MASK_TOKEN] = float('-inf')

        return logits

In [11]:
# inference

import os

# 1. Define the path (must match your training loop)
CKPT_PATH = './ckpt/latest.pt'

# 2. Build the architecture with the notebook hyperparameters
model = MDLM(
    vocab_size=final_vocab_size,
    n_embd=n_embd,
    n_head=n_head,
    n_block=n_blocks,
    block_size=block_size,
)
model = model.to(device)

# 3. Restore weights if a checkpoint is present
if not os.path.exists(CKPT_PATH):
    print(f"Warning: No checkpoint found at {CKPT_PATH}. Using random initialization.")
else:
    print(f"Loading model weights from {CKPT_PATH}...")
    # map_location moves the stored tensors onto the current device.
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects - only load checkpoints from a trusted source.
    checkpoint = torch.load(CKPT_PATH, map_location=device, weights_only=False)
    state_dict = checkpoint['model_state_dict']

    # Keys that carry an '_orig_mod.' prefix are stored with the prefix
    # stripped so that they match this model's parameter names.
    compiled_prefix = '_orig_mod.'
    new_state_dict = {}
    for key, value in state_dict.items():
        new_key = key[len(compiled_prefix):] if key.startswith(compiled_prefix) else key
        new_state_dict[new_key] = value

    model.load_state_dict(new_state_dict)
    print(f"Model loaded successfully (from Epoch {checkpoint['epoch']+1})")

def _top_p_filter(logits, top_p):
    """Set every logit outside the top-p nucleus of its position to -inf (in place)."""
    # Sort logits in descending order
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)

    # Cumulative probability mass along the sorted vocabulary
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p

    # Shift right by one so the token that crosses the threshold is kept, and
    # always keep the most likely token (covers a very small top_p).
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    # Map the removal mask from sorted order back to vocabulary order
    indices_to_remove = sorted_indices_to_remove.scatter(
        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
    )
    logits[indices_to_remove] = float('-inf')
    return logits


@torch.no_grad()
def sample(model, seq_len, num_steps=100, temperature=0.7, top_p=0.9, mask_token=None):
    """
    Generate one token sequence with the reverse (unmasking) diffusion process.

    Starting from a fully masked sequence, each step predicts all positions and
    then reveals each still-masked position independently with probability
    ``(t_current - t_next) / t_current``. Revealed tokens are never re-masked.

    Args:
        model: Trained MDLM, called as ``model(x, t)`` and returning logits of
            shape [1, seq_len, vocab_size]. It is switched to eval mode and
            left in eval mode.
        seq_len: Length of sequence to generate
        num_steps: Number of denoising steps
        temperature: Sampling temperature (must be > 0). It is applied after
            the top-p filter.
        top_p: Nucleus threshold; values >= 1.0 disable top-p filtering.
        mask_token: ID of the mask token. Defaults to the module-level
            ``MASK_TOKEN`` when omitted.

    Returns:
        Generated token sequence as a list of ``seq_len`` ints

    Raises:
        ValueError: if ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        # Dividing logits that contain -inf by a non-positive temperature
        # yields NaN probabilities, so reject it up front.
        raise ValueError(f"temperature must be > 0, got {temperature}")
    if mask_token is None:
        mask_token = MASK_TOKEN

    model.eval()

    # Run on the model's own device; fall back to the notebook-level `device`
    # only for a model without parameters.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    # Start with all masks
    x = torch.full((1, seq_len), mask_token, dtype=torch.long, device=run_device)

    # Linearly spaced time steps from 1 to 0
    timesteps = torch.linspace(1, 0, num_steps + 1, device=run_device)

    for i in range(num_steps):
        t_current = timesteps[i]
        t_next = timesteps[i + 1]

        # The model predicts all positions at each forward pass
        logits = model(x, t_current.reshape(1))  # [1, seq_len, vocab_size]

        if top_p < 1.0:
            logits = _top_p_filter(logits, top_p)

        # Don't predict [MASK] token during sampling
        logits[:, :, mask_token] = float('-inf')

        # Temperature, then probabilities: [1, seq_len, vocab_size]
        probs = F.softmax(logits / temperature, dim=-1)

        # Sample a candidate token for every position; only the positions
        # selected for unmasking below are actually written.
        pred_tokens = torch.multinomial(
            probs.view(-1, probs.size(-1)),
            num_samples=1
        ).view(1, seq_len)

        # Per-position unmasking probability for this step (ancestral sampling).
        # With the linear schedule above t_current is > 0 inside the loop and
        # the last step has p_unmask == 1; the guard only protects the division.
        if t_current > 0:
            p_unmask = (t_current - t_next) / t_current
        else:
            p_unmask = 1.0

        # Roll the dice for every position: reveal where value < p_unmask
        random_values = torch.rand_like(probs[:, :, 0])  # [1, seq_len]
        should_unmask = random_values < p_unmask

        # Only positions that are currently masked may be updated
        update_mask = (x == mask_token) & should_unmask
        x[update_mask] = pred_tokens[update_mask]

    # Final cleanup: greedily fill any position that is still masked
    # (this is the whole generation when num_steps == 0).
    is_mask = (x == mask_token)
    if is_mask.any():
        logits = model(x, torch.tensor([0.0], device=run_device))
        logits[:, :, mask_token] = float('-inf')
        probs = F.softmax(logits / temperature, dim=-1)
        final_preds = probs.argmax(dim=-1)
        x[is_mask] = final_preds[is_mask]

    return x[0].tolist()

# Generate some samples!
print("Generated samples:")
print("=" * 60)

# Sample one full-length sequence, then decode it to text.
tokens = sample(
    model,
    seq_len=1024,
    num_steps=100,
    temperature=1.0,
    top_p=0.9,
)
text = decode(tokens)
# decode() already returns a single string, so no extra join is needed;
# end-of-text markers are shown as blank lines.
print(text.replace('<|endoftext|>', '\n\n'))

Model has 163.97M parameters.
Loading model weights from ./ckpt/latest.pt...
Model loaded successfully (from Epoch 7)
Generated samples:
One sunny day, Lily and her mom went to the park. It was a big festival. All many children went there and had lots of fun. Lily was very scared that something bad might happen there.
Lily's mom loved going to the festival because there was a little boy there. But her mom told her to be kind to everyone when it was a reason. Suddenly, Lily noticed that the boy was hiding behind a tree. Lily was not sure what to do, so she went back to hide in her room and waited. She heard her mom calling her name.
Lily went to the door and saw the little girl looking for her. Her mom had gone to the festival, and it was time to go home. Lily was so happy that she gave the little girl a big hug. She knew it was important because she helped the girl who would have a chance to arrive.
When they arrived at the park, everyone danced and clapped their hands and games. They 