In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: mps


In [2]:
import re
from collections import Counter
import json

class SimpleBPE:
    """
    Simple Byte Pair Encoding tokenizer.
    Trains custom vocabulary on your text data.

    Training starts from the characters present in the text and repeatedly
    merges the most frequent adjacent pair of tokens. The text is first split
    into runs of non-whitespace and runs of whitespace, so merges never span
    those boundaries.

    Caveat: `encode` uses greedy longest-match over the learned vocabulary
    instead of replaying `merges` in order. Its segmentation can therefore
    differ from textbook BPE, but decoding the result still reproduces the
    input exactly.
    """

    def __init__(self):
        self.vocab = {}           # token_id -> token_string
        self.vocab_inv = {}       # token_string -> token_id
        self.merges = []          # ordered (left_id, right_id) merge rules
        self.special_tokens = {}  # special token string -> token_id
        self.max_token_len = 1    # longest token string; bounds encode's search window

    def _get_stats(self, token_seqs):
        """Count the frequency of every adjacent (left, right) id pair across all sequences."""
        pairs = Counter()
        for seq in token_seqs:
            for i in range(len(seq) - 1):
                pairs[(seq[i], seq[i + 1])] += 1
        return pairs

    def _merge_pair(self, token_seqs, pair, new_token):
        """Return new sequences where each left-to-right, non-overlapping occurrence of `pair` becomes `new_token`."""
        new_seqs = []
        for seq in token_seqs:
            new_seq = []
            i = 0
            while i < len(seq):
                if i < len(seq) - 1 and seq[i] == pair[0] and seq[i + 1] == pair[1]:
                    new_seq.append(new_token)
                    i += 2
                else:
                    new_seq.append(seq[i])
                    i += 1
            new_seqs.append(new_seq)
        return new_seqs

    def train(self, text, vocab_size=1000, verbose=True):
        """
        Train BPE on the given text.

        Args:
            text: training corpus. An empty string yields an empty vocabulary.
            vocab_size: target vocabulary size (base characters + merges).
                Training stops early once no adjacent pairs remain.
            verbose: print progress messages.

        Returns:
            self, so calls can be chained.

        Retraining replaces the whole vocabulary. Previously added special
        tokens are therefore discarded, because their ids are no longer valid.
        """
        if verbose:
            print(f"Training BPE tokenizer on {len(text):,} characters...")

        # Split into words (keep whitespace attached)
        words = re.findall(r'\S+|\s+', text)
        token_seqs = [[c for c in word] for word in words]

        # Initial vocabulary = unique characters
        chars = sorted(set(text))
        self.vocab = {i: c for i, c in enumerate(chars)}
        self.vocab_inv = {c: i for i, c in enumerate(chars)}
        # Old special-token ids pointed into the vocabulary we just replaced.
        self.special_tokens = {}
        next_id = len(chars)

        if verbose:
            print(f"Base vocabulary: {len(chars)} characters")

        # Convert to token IDs
        token_seqs = [[self.vocab_inv[c] for c in seq] for seq in token_seqs]

        # Iteratively merge most frequent pairs
        self.merges = []
        num_merges = vocab_size - len(chars)

        for i in range(num_merges):
            stats = self._get_stats(token_seqs)
            if not stats:
                break

            best_pair = max(stats, key=stats.get)
            new_token_str = self.vocab[best_pair[0]] + self.vocab[best_pair[1]]

            self.vocab[next_id] = new_token_str
            self.vocab_inv[new_token_str] = next_id
            self.merges.append(best_pair)

            token_seqs = self._merge_pair(token_seqs, best_pair, next_id)

            if verbose and (i + 1) % 200 == 0:
                print(f"  {i+1}/{num_merges} merges completed...")

            next_id += 1

        # Set max token length for fast encoding (default keeps empty-text training from crashing)
        self.max_token_len = max((len(t) for t in self.vocab.values()), default=1)

        if verbose:
            print(f"Training complete! Vocabulary size: {len(self.vocab)}")

        return self

    def add_special_token(self, token_str):
        """
        Add a special token like <MASK> and return its id.

        Idempotent: re-adding an already registered special token returns its
        existing id instead of allocating a duplicate entry. This matters when
        notebook cells are re-run.
        """
        if token_str in self.special_tokens:
            return self.special_tokens[token_str]
        # Next unused id; identical to len(vocab) for the usual contiguous ids.
        token_id = max(self.vocab, default=-1) + 1
        self.vocab[token_id] = token_str
        self.vocab_inv[token_str] = token_id
        self.special_tokens[token_str] = token_id
        self.max_token_len = max(self.max_token_len, len(token_str))
        return token_id

    def encode(self, text):
        """
        Fast encoding using greedy longest-match.

        At each position the longest vocabulary string (up to max_token_len
        characters) that matches is emitted.

        Raises:
            ValueError: if a character is not covered by the vocabulary.
        """
        tokens = []
        i = 0
        n = len(text)

        while i < n:
            # Try longest match first, then shorter
            for length in range(min(self.max_token_len, n - i), 0, -1):
                substr = text[i:i + length]
                if substr in self.vocab_inv:
                    tokens.append(self.vocab_inv[substr])
                    i += length
                    break
            else:
                raise ValueError(f"Unknown character at position {i}: {repr(text[i])}")

        return tokens

    def decode(self, token_ids):
        """Decode token IDs back to text (raises KeyError on unknown ids)."""
        return ''.join(self.vocab[i] for i in token_ids)

    def save(self, path):
        """Save tokenizer to a JSON file."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump({
                'vocab': {str(k): v for k, v in self.vocab.items()},
                'merges': self.merges,
                'special_tokens': self.special_tokens,
                'max_token_len': self.max_token_len
            }, f)
        print(f"Tokenizer saved to {path}")

    def load(self, path):
        """Load tokenizer from a JSON file written by `save`; returns self."""
        with open(path, encoding='utf-8') as f:
            data = json.load(f)
        # JSON object keys are strings; restore integer ids.
        self.vocab = {int(k): v for k, v in data['vocab'].items()}
        self.vocab_inv = {v: k for k, v in self.vocab.items()}
        self.merges = [tuple(m) for m in data['merges']]
        self.special_tokens = data.get('special_tokens', {})
        if 'max_token_len' in data:
            self.max_token_len = data['max_token_len']
        else:
            self.max_token_len = max((len(t) for t in self.vocab.values()), default=1)
        print(f"Tokenizer loaded from {path} (vocab size: {len(self.vocab)})")
        return self

In [3]:
# Load the BPE vocabulary the model was trained with (load() returns the same instance).
tokenizer = SimpleBPE()
tokenizer.load('tokenizer_bpe.json')

def decode(ids):
    """Convert a sequence of token ids back into text using the loaded tokenizer."""
    text = tokenizer.decode(ids)
    return text

Tokenizer loaded from tokenizer_bpe.json (vocab size: 1001)


In [4]:
# Hyperparameters

block_size = 512
final_vocab_size = len(tokenizer.vocab)
batch_size = 16        # Fits in GPU memory (Physical Batch)
target_batch_size = 64 # What we want mathematically (Effective Batch)
# Gradient accumulation reaches the effective batch exactly only if it divides evenly.
assert target_batch_size % batch_size == 0, "target_batch_size must be a multiple of batch_size"
grad_accum_steps = target_batch_size // batch_size # 64 // 16 = 4 steps
n_embd = 512
n_head = 4
n_blocks = 6
# MASK_TOKEN is used as an integer token id (torch.full fill value, logits index).
# Fail fast here rather than falling back to a string sentinel that breaks much later.
if '<MASK>' not in tokenizer.special_tokens:
    raise KeyError("Tokenizer has no '<MASK>' special token; the masked diffusion model requires one.")
MASK_TOKEN = tokenizer.special_tokens['<MASK>']

In [5]:
# time embedding

class SinusoidalTimeEmbedding(nn.Module):
    """
    Fixed (parameter-free) sinusoidal embedding of a diffusion time t.

    Args:
        dim: output embedding size. Odd sizes are supported via one zero
            column of padding.
        max_period: base of the geometric frequency ladder (default 10000).
        time_scale: multiplier applied to t before the sinusoids (default 1000).
            It stretches t in [0, 1] over a range comparable to integer timesteps.
    """
    def __init__(self, dim, max_period=10000.0, time_scale=1000.0):
        super().__init__()
        self.dim = dim
        self.max_period = max_period
        self.time_scale = time_scale

    def forward(self, t):
        """
        Args:
            t: tensor of times with shape [B]. A 0-dim tensor is treated as [1].

        Returns:
            [B, dim] tensor: half_dim sine features followed by half_dim cosine
            features, plus a trailing zero column when dim is odd.
        """
        if t.dim() == 0:
            t = t.unsqueeze(0)
        half_dim = self.dim // 2

        # max(..., 1) avoids ZeroDivisionError for dim < 4 (half_dim == 1).
        exponent = math.log(self.max_period) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -exponent)
        args = t.unsqueeze(1) * freqs.unsqueeze(0) * self.time_scale  # [B, 1] * [1, half_dim] = [B, half_dim]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # [B, 2*half_dim]

        # Keep the output width equal to dim so downstream Linear(dim, ...) layers fit.
        if self.dim % 2 == 1:
            emb = F.pad(emb, (0, 1))

        return emb

In [6]:
# MHA
from torchtune.modules import RotaryPositionalEmbeddings

class MHA(nn.Module):
    """
    Bidirectional (non-causal) multi-head self-attention with rotary position
    embeddings.

    Args:
        n_embd: model width C; must be divisible by n_head.
        n_head: number of attention heads.
        dropout: dropout probability on the attention weights (training only)
            and on the output projection.
    """
    def __init__(self, n_embd, n_head, dropout=0.1):
        super().__init__()
        assert n_embd % n_head == 0

        self.n_head = n_head
        self.n_embd = n_embd
        self.head_dim = n_embd // n_head

        self.c_attn = nn.Linear(n_embd, 3*n_embd)  # fused q, k, v projection
        self.c_proj = nn.Linear(n_embd, n_embd)

        self.rope = RotaryPositionalEmbeddings(dim=self.head_dim)

        # attn_dropout holds only its probability; it is forwarded to the fused
        # attention kernel below instead of being applied to an explicit [T, T] matrix.
        self.attn_dropout = nn.Dropout(dropout)
        self.residual_dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over x of shape [B, T, C]; returns [B, T, C]."""
        B, T, C = x.shape
        # c_attn(x): [B, T, 3*C] -> three tensors of [B, T, C]
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        # RoPE is applied on the [B, T, n_head, head_dim] layout, i.e. before transposing heads forward.
        q = self.rope(q.view(B, T, self.n_head, self.head_dim)).transpose(1, 2)  # [B, n_head, T, head_dim]
        k = self.rope(k.view(B, T, self.n_head, self.head_dim)).transpose(1, 2)  # [B, n_head, T, head_dim]
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)            # [B, n_head, T, head_dim]

        # No causal mask: every position attends to every position.
        # Default scaling is 1/sqrt(head_dim), matching softmax(q k^T / sqrt(d)) v.
        out = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_dropout.p if self.training else 0.0,
            is_causal=False,
        )  # [B, n_head, T, head_dim]

        out = out.transpose(1, 2).contiguous().view(B, T, C)
        # dropout on the residual stream
        return self.residual_dropout(self.c_proj(out))  # [B, T, C]


In [7]:
class SwiGLU(nn.Module):
    """
    Parameter-free SwiGLU gate.

    The last dimension of the input is split into a value half and a gate half,
    and value * SiLU(gate) is returned, which halves the last dimension. The
    projections W, V (and biases) live in the preceding nn.Linear, so together
    they compute SwiGLU(x) = (xW + b) * SiLU(xV + c).
    """
    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        gated = F.silu(gate)
        return value * gated

In [8]:
# FFN

class FFN(nn.Module):
    """
    Position-wise SwiGLU feed-forward network.

    c_fc projects C -> 2*H (value and gate halves), SwiGLU gates that down to H,
    and c_proj projects H -> C, where H = hidden_mult * C.

    Args:
        n_embd: model width C.
        dropout: dropout probability applied to the output.
        hidden_mult: width multiplier for the gated hidden layer. The default
            of 4 reproduces the original fixed 4*C hidden size, so existing
            checkpoints still load.
    """
    def __init__(self, n_embd, dropout=0.1, hidden_mult=4):
        super().__init__()
        hidden = int(hidden_mult * n_embd)
        self.c_fc = nn.Linear(n_embd, 2 * hidden)
        self.swiglu = SwiGLU()
        self.c_proj = nn.Linear(hidden, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.c_fc(x)     # [B, T, C] -> [B, T, 2H]
        x = self.swiglu(x)   # [B, T, 2H] -> [B, T, H]
        x = self.c_proj(x)   # [B, T, H] -> [B, T, C]
        return self.dropout(x)

In [9]:
# Block

class Block(nn.Module):
    """
    Pre-norm Transformer block (attention + SwiGLU FFN) with optional time
    conditioning.

    When `use_time` is True and a time embedding is supplied, each RMSNorm
    output is modulated as norm(x) * (1 + scale) + shift. The per-sample
    (shift, scale) pairs are produced from the time embedding by `time_ffn`.
    RMSNorm only rescales its input (it does not subtract a mean), so this
    learned shift is the only additive offset after normalization. It is how
    the block is told the noise level.

    Args:
        n_embd: model width C.
        n_head: number of attention heads.
        dropout: dropout probability for attention and FFN.
        use_time: whether to build the time-modulation MLP.
    """
    def __init__(self, n_embd, n_head, dropout=0.1, use_time=True):
        super().__init__()
        self.use_time = use_time
        self.rms_norm_1 = nn.RMSNorm(n_embd)
        self.rms_norm_2 = nn.RMSNorm(n_embd)
        self.attn = MHA(n_embd, n_head, dropout)
        self.ffn = FFN(n_embd, dropout)

        if use_time:
            # [B, C] -> [B, 2C] -> SwiGLU (halves) -> [B, C] -> [B, 4C] = shift1 | scale1 | shift2 | scale2
            self.time_ffn = nn.Sequential(
                nn.Linear(n_embd, 2 * n_embd),
                SwiGLU(),
                nn.Linear(n_embd, 4 * n_embd),
            )

    @staticmethod
    def _modulate(normed, shift, scale):
        """Broadcast per-sample modulation over time: [B, T, C] * (1 + [B, 1, C]) + [B, 1, C]."""
        return normed * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

    def forward(self, x, time_emb=None):
        conditioned = self.use_time and time_emb is not None

        if not conditioned:
            x = x + self.attn(self.rms_norm_1(x))
            return x + self.ffn(self.rms_norm_2(x))

        shift1, scale1, shift2, scale2 = self.time_ffn(time_emb).chunk(4, dim=-1)  # each [B, C]
        x = x + self.attn(self._modulate(self.rms_norm_1(x), shift1, scale1))
        return x + self.ffn(self._modulate(self.rms_norm_2(x), shift2, scale2))


In [10]:
# Full MDLM

class MDLM(nn.Module):
    """
    Masked Diffusion Language Model denoiser.

    A stack of bidirectional Transformer blocks that maps a token sequence
    (some positions possibly [MASK]) plus an optional diffusion time t to
    per-position vocabulary logits. Input and output embeddings are tied.

    Args:
        vocab_size: number of token ids, including special tokens.
        n_embd: model width.
        n_head: attention heads per block.
        n_block: number of Transformer blocks.
        block_size: training context length (stored; not checked in forward).
        dropout: dropout probability.
        use_time: whether blocks are conditioned on t.
    """
    def __init__(
            self,
            vocab_size,
            n_embd,
            n_head,
            n_block,
            block_size,
            dropout=0.1,
            use_time = True
    ):
        super().__init__()
        self.block_size = block_size
        self.use_time = use_time

        # Submodules are registered in a fixed order; self.apply() below visits
        # them in that order, so reordering would change the random init.
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        if use_time:
            self.time_emb = SinusoidalTimeEmbedding(n_embd)
            # [B, C] -> [B, 2C] -> SwiGLU -> [B, C] -> [B, C]
            self.time_proj = nn.Sequential(
                nn.Linear(n_embd, 2 * n_embd),
                SwiGLU(),
                nn.Linear(n_embd, n_embd),
            )
        self.blocks = nn.ModuleList(
            Block(n_embd, n_head, dropout, use_time) for _ in range(n_block)
        )
        self.rms_norm_final = nn.RMSNorm(n_embd)

        # Weight tying: the output projection reuses the token-embedding matrix.
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        self.lm_head.weight = self.tok_emb.weight

        self.dropout = nn.Dropout(dropout)
        self.apply(self._init_weights)

        param_count = sum(p.numel() for p in self.parameters())
        print(f'Model has {param_count/1e6:.2f}M parameters.')

    def _init_weights(self, module):
        """GPT-style init: N(0, 0.02) for Linear/Embedding weights, zero biases, unit RMSNorm gains."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.RMSNorm):
            nn.init.ones_(module.weight)

    def forward(self, x, t=None):
        """
        Args:
            x: token ids of shape [B, T].
            t: optional diffusion times, presumably shape [B]; ignored when
                use_time is False.

        Returns:
            logits of shape [B, T, vocab_size].
        """
        _batch, _seq_len = x.shape  # unpacking also rejects non-2-D input
        hidden = self.dropout(self.tok_emb(x))  # [B, T, n_embd]

        cond = self.time_proj(self.time_emb(t)) if (self.use_time and t is not None) else None

        for layer in self.blocks:
            hidden = layer(hidden, cond)

        return self.lm_head(self.rms_norm_final(hidden))  # [B, T, V]

In [15]:
# inference

import os

# 1. Define the path (must match your training loop)
CKPT_PATH = './ckpt/latest.pt'

# 2. Re-initialize the model architecture (hyperparameters must match the checkpoint)
model = MDLM(
    vocab_size=final_vocab_size,
    n_embd=n_embd,
    n_head=n_head,
    n_block=n_blocks,
    block_size=block_size,
).to(device)

# 3. Load the checkpoint
if os.path.exists(CKPT_PATH):
    print(f"Loading model weights from {CKPT_PATH}...")
    # map_location ensures that if you trained on GPU but load on CPU, it works.
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which can execute
    # code. Only load checkpoints you created yourself; switch to weights_only=True if the
    # checkpoint contains only tensors, dicts and primitive values.
    checkpoint = torch.load(CKPT_PATH, map_location=device, weights_only=False)

    # We only need the 'model_state_dict' for inference.
    # The optimizer states are only needed if we plan to resume training.
    model.load_state_dict(checkpoint['model_state_dict'])
    # 'epoch' is optional metadata; don't fail the load if an older checkpoint lacks it.
    epoch = checkpoint.get('epoch')
    if epoch is None:
        print("Model loaded successfully (epoch unknown)")
    else:
        print(f"Model loaded successfully (from Epoch {epoch+1})")
else:
    print(f"Warning: No checkpoint found at {CKPT_PATH}. Using random initialization.")

def _masked_logits(model, x, t_value, mask_token):
    """Run the denoiser on x at scalar time `t_value` and forbid predicting `mask_token`."""
    t_batch = torch.full((x.size(0),), float(t_value), device=x.device)
    logits = model(x, t_batch)  # [B, seq_len, vocab_size]
    logits[:, :, mask_token] = float('-inf')
    return logits


def _draw_tokens(logits, temperature):
    """Pick one token per position: sample at `temperature`, or take the argmax when temperature <= 0."""
    if temperature <= 0:
        return logits.argmax(dim=-1)
    probs = F.softmax(logits / temperature, dim=-1)
    flat = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1)
    return flat.view(probs.shape[:-1])


@torch.no_grad()
def sample(model, seq_len, num_steps=100, temperature=0.5, mask_token=None):
    """
    Generate text using the reverse diffusion (unmasking) process.

    The sequence starts as all [MASK] and t walks linearly from 1 to 0 in
    `num_steps` steps. At every step the model predicts all positions. Each
    still-masked position is then revealed independently with probability
    (t_current - t_next) / t_current. Once revealed, a token is never re-masked.

    Args:
        model: trained MDLM, called as model(x, t) -> logits [1, seq_len, vocab_size].
        seq_len: length of sequence to generate.
        num_steps: number of denoising steps.
        temperature: sampling temperature; values <= 0 use greedy argmax
            instead of dividing by zero.
        mask_token: id of the [MASK] token; defaults to the notebook's MASK_TOKEN.

    Returns:
        Generated token ids as a list of length seq_len.

    The model's train/eval mode is restored on exit.
    """
    if mask_token is None:
        mask_token = MASK_TOKEN
    # Follow the model's own device so this works even if it differs from the global one.
    run_device = next(model.parameters()).device
    was_training = model.training
    model.eval()

    try:
        # Start with all masks
        x = torch.full((1, seq_len), mask_token, dtype=torch.long, device=run_device)
        # Linearly spaced time steps from 1 to 0
        timesteps = torch.linspace(1, 0, num_steps + 1, device=run_device)

        for i in range(num_steps):
            t_current = timesteps[i]
            t_next = timesteps[i + 1]

            # The model predicts every position; only newly revealed ones are kept.
            logits = _masked_logits(model, x, t_current, mask_token)
            pred_tokens = _draw_tokens(logits, temperature)

            # Ancestral step: roll an independent die per position.
            p_unmask = (t_current - t_next) / t_current if t_current > 0 else 1.0
            should_unmask = torch.rand(x.shape, device=run_device) < p_unmask

            # Only update positions that are CURRENTLY masked AND won the roll.
            update_mask = (x == mask_token) & should_unmask
            x[update_mask] = pred_tokens[update_mask]

        # Final cleanup: greedily fill any position still masked.
        is_mask = x == mask_token
        if is_mask.any():
            logits = _masked_logits(model, x, 0.0, mask_token)
            x[is_mask] = logits.argmax(dim=-1)[is_mask]
    finally:
        model.train(was_training)

    return x[0].tolist()

# Generate some samples! (one full context window from pure [MASK] noise)
print("Generated samples:")
print("=" * 60)

tokens = sample(model, seq_len=block_size, num_steps=100, temperature=1.0)
text = decode(tokens)
print(text)  # decode() already returns a str, so no join is needed

Model has 35.97M parameters.
Loading model weights from ./ckpt/latest.pt...
Model loaded successfully (from Epoch 100)
Generated samples:
was a tin face — reliving in piece of report 
that he had been following cousin every day — the 
subject that Dudley was grudging laughter and certed 
him firmly. He was sure that Dudley’s face and sense. They were lowered his 
angry teeth, further and foggy fell slowly against the ground 
again as he pointed at Harry as Dudley already gave it 
again. 
“See, knowing.” 
“ARY!” Marge exploded Uncle Vernon, biding 
of fe again. Dudley his heart became still gasping 
more bravely. “You’ll see Aunt Petunia till any 
palciage.” 
“And they’re quite haven’t kept to buy a stop, 
coweringling in his stest. 
“You might be all looking forward into fat clothes. Marge 
think she needs Dudtilk, stabbling them up.” 
“We’ll escape they are going to reward Muggles!” ” ( 
cle Vernon drew Only looking each other over his 
shoulder), was almost pleasant. Uncle Vernon; Ve