# Week 6 — GT‑Full on PTB (Tiny Language Model)
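We compare a small **Transformer LM** against a **GT‑Full LM** on the Penn Treebank (PTB) word-level corpus. The GT‑Full model uses relation‑aware message passing on the sequence graph: every token is a node, and typed edges connect nearby positions.
Both models are kept intentionally tiny so the notebook runs quickly in a classroom setting.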


In [None]:
import math
import os
from urllib import request
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Seed PyTorch's global RNG so weight initialisation and dropout are repeatable.
torch.manual_seed(0)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device:', device)


## 1) Load PTB (word-level)


In [None]:
def download_ptb(save_dir='/tmp/ptb'):
    """Make sure the three PTB text splits exist locally and return their paths.

    Files already present in ``save_dir`` are reused; missing ones are fetched
    from the ``wojzaremba/lstm`` GitHub repository.

    Args:
        save_dir: Directory used as the on-disk cache (created if missing).

    Returns:
        ``[train_path, valid_path, test_path]`` in that order.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises on a failed download
        (for example ``urllib.error.URLError``).
    """
    os.makedirs(save_dir, exist_ok=True)
    base = 'https://raw.githubusercontent.com/wojzaremba/lstm/master/data/'
    files = ['ptb.train.txt', 'ptb.valid.txt', 'ptb.test.txt']
    paths = []
    for fn in files:
        path = os.path.join(save_dir, fn)
        if not os.path.exists(path):
            # Download under a temporary name and rename only on success, so an
            # interrupted transfer can never be mistaken for a cached file.
            tmp_path = path + '.part'
            try:
                request.urlretrieve(base + fn, tmp_path)
                os.replace(tmp_path, path)
            finally:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
        paths.append(path)
    return paths

def load_tokens(path):
    """Read a whitespace-tokenised text file into one flat list of tokens.

    Every line (including an empty one) contributes its tokens followed by a
    single ``'<eos>'`` marker.

    Args:
        path: Path of the text file to read.

    Returns:
        A list of token strings.
    """
    tokens = []
    # Explicit encoding: the platform default is locale-dependent.
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            tokens.extend(line.split())
            tokens.append('<eos>')
    return tokens

# Fetch (or reuse) the PTB splits, then tokenise the train and validation files.
train_path, valid_path, test_path = download_ptb()
train_tokens, valid_tokens = load_tokens(train_path), load_tokens(valid_path)

# Vocabulary: each distinct training token gets an index in sorted order.
vocab = {word: idx for idx, word in enumerate(sorted(set(train_tokens)))}
def to_ids(tokens):
    """Convert a token sequence into a 1-D ``torch.long`` tensor of vocab indices.

    Raises ``KeyError`` for a token that is not in the module-level ``vocab``.
    """
    indices = list(map(vocab.__getitem__, tokens))
    return torch.tensor(indices, dtype=torch.long)

# Encode both splits as tensors of vocabulary indices and report the vocab size.
train_ids, valid_ids = to_ids(train_tokens), to_ids(valid_tokens)
vocab_size = len(vocab)
print('Vocab size:', vocab_size)


## 2) Batching utilities


In [None]:
def batchify(data, batch_size):
    """Arrange a token stream into ``batch_size`` parallel columns.

    The stream is cut into ``batch_size`` equal contiguous pieces (trailing
    elements that do not fill a whole row are dropped) and each piece becomes
    one column of the result.

    Args:
        data: Tensor whose first dimension is the token stream.
        batch_size: Number of parallel sequences (columns).

    Returns:
        A contiguous tensor of shape ``(T, B)`` with
        ``T = data.size(0) // batch_size``. When the stream is shorter than
        ``batch_size`` an empty ``(0, batch_size)`` tensor is returned.
    """
    n_batch = data.size(0) // batch_size
    if n_batch == 0:
        # view(batch_size, -1) cannot infer -1 from zero elements.
        return data.new_empty((0, batch_size))
    data = data[:n_batch * batch_size]
    data = data.view(batch_size, -1).t().contiguous()  # (T, B)
    return data

def get_batch(source, i, bptt=20):
    """Slice an (input, target) pair starting at row ``i`` of ``source``.

    The target is the input shifted one row forward. The window is ``bptt``
    rows long, shortened near the end so the target never runs past the
    last row of ``source``.
    """
    rows_left = source.size(0) - 1 - i
    window = bptt if bptt < rows_left else rows_left
    inputs = source[i:i + window]
    labels = source[i + 1:i + 1 + window]
    return inputs, labels

# Number of parallel sequences per batch and the length of each training window.
batch_size = bptt = 20

# Column-wise layouts of the encoded splits, shape (T, batch_size).
train_data, valid_data = batchify(train_ids, batch_size), batchify(valid_ids, batch_size)


## 3) Models


In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    Args:
        d_model: Feature dimension of the embeddings (even or odd).
        max_len: Longest sequence length that can be encoded.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = pos * div  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        # Buffer: moves with .to(device) and is saved in state_dict, but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x + pe`` for ``x`` of shape ``(T, B, d_model)``.

        Raises:
            ValueError: If ``T`` exceeds the ``max_len`` given at construction.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f'sequence length {seq_len} exceeds max_len {self.pe.size(0)}'
            )
        # unsqueeze(1) broadcasts the same encoding over the batch dimension.
        return x + self.pe[:seq_len].unsqueeze(1)

class TransformerBlock(nn.Module):
    """Post-norm Transformer layer: self-attention, then a two-layer feed-forward.

    Each sub-layer's output goes through dropout, is added to its input and
    then layer-normalised. Inputs are sequence-first, ``(T, B, d_model)``.
    """

    def __init__(self, d_model, n_heads, dim_ff, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.ln1 = nn.LayerNorm(d_model)
        feed_forward = [
            nn.Linear(d_model, dim_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_ff, d_model),
        ]
        self.ff = nn.Sequential(*feed_forward)
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        """Apply the block; ``attn_mask`` is passed straight to the attention module."""
        # Self-attention sub-layer (the attention weights are discarded).
        attended = self.attn(x, x, x, attn_mask=attn_mask)[0]
        x = self.ln1(x + self.dropout(attended))
        # Position-wise feed-forward sub-layer.
        return self.ln2(x + self.dropout(self.ff(x)))

class SimplicialMessagePassing(nn.Module):
    """One round of relation-aware message passing with a residual connection.

    For every edge ``src -> dst`` an MLP maps
    ``[h_src, h_dst, one_hot(relation)]`` to a message. Messages are summed
    at their destination node and added to the node's features.

    Args:
        dim: Node feature dimension.
        num_rel: Number of relation types; fixes the width of the one-hot
            relation encoding fed to the edge MLP.
        hidden_dim: Hidden width of the edge MLP (defaults to ``dim``).
    """

    def __init__(self, dim, num_rel, hidden_dim=None):
        super().__init__()
        hidden_dim = hidden_dim or dim
        self.num_rel = num_rel
        self.edge_mlp = nn.Sequential(nn.Linear(2*dim + num_rel, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, dim))

    def forward(self, V, edge_index, rel_ids):
        """Update node features.

        Args:
            V: Node features, shape ``(num_nodes, dim)``.
            edge_index: Integer tensor of shape ``(2, E)``; row 0 holds source
                node ids and row 1 holds destination node ids.
            rel_ids: Integer tensor of shape ``(E,)`` with values in
                ``[0, num_rel)``.

        Returns:
            A new tensor with the same shape as ``V``.
        """
        if rel_ids.numel() == 0:
            # No edges means no messages: the residual update is the identity.
            return V.clone()
        src = edge_index[0].long()
        dst = edge_index[1].long()
        src_h = V[src]
        dst_h = V[dst]
        # Use the configured relation count, not the largest id in this batch:
        # the edge MLP's input width is fixed at 2*dim + num_rel.
        rel_onehot = F.one_hot(rel_ids.long(), num_classes=self.num_rel).to(V.dtype)
        edge_feat = torch.cat([src_h, dst_h, rel_onehot], dim=-1)
        msg = self.edge_mlp(edge_feat)
        out = torch.zeros_like(V)
        out.index_add_(0, dst, msg)  # sum incoming messages per destination
        return V + out

class GTFullLM(nn.Module):
    """Language model whose token mixing is done by message passing on a graph.

    Tokens are embedded, given sinusoidal position encodings, flattened into
    one node per (sequence, position) pair, passed through ``n_layers``
    ``SimplicialMessagePassing`` layers, layer-normalised and projected to
    vocabulary logits.

    Args:
        vocab_size: Number of distinct tokens.
        d_model: Embedding / node feature dimension.
        n_layers: Number of message-passing layers.
        num_rel: Number of edge relation types.
    """

    def __init__(self, vocab_size, d_model=96, n_layers=2, num_rel=3):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, d_model)
        self.pos = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([SimplicialMessagePassing(d_model, num_rel) for _ in range(n_layers)])
        self.ln = nn.LayerNorm(d_model)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, x, edge_index, rel_ids):
        """Compute next-token logits.

        Args:
            x: Token ids of shape ``(T, B)`` (sequence-first).
            edge_index: ``(2, E)`` node-id pairs using the batch-major
                numbering ``node = b * T + t``, which is the numbering
                ``build_edges`` produces.
            rel_ids: ``(E,)`` relation id of each edge.

        Returns:
            Logits of shape ``(T, B, vocab_size)``.

        NOTE(review): any edge pointing from a later position to an earlier
        one (relation 1 in ``build_edges``) lets position ``t`` receive the
        embedding of token ``t + 1``, which is exactly its training target.
        Confirm this is intended before comparing perplexity with a causal
        model.
        """
        emb = self.tok(x)
        h = self.pos(emb)
        T, B, D = h.shape
        # Flatten batch-major so row b*T + t holds position t of sequence b,
        # matching the node ids used in edge_index. A plain reshape of the
        # (T, B, D) tensor would give t*B + b and wire up unrelated tokens.
        h = h.transpose(0, 1).reshape(B * T, D)
        for layer in self.layers:
            h = layer(h, edge_index, rel_ids)
        # Back to sequence-first; contiguous so callers may .view() the logits.
        h = self.ln(h).reshape(B, T, D).transpose(0, 1).contiguous()
        return self.out(h)

class TransformerLM(nn.Module):
    """Small causal Transformer language model over sequence-first token ids.

    Pipeline: token embedding -> sinusoidal positions -> ``n_layers``
    ``TransformerBlock`` layers under a causal mask -> LayerNorm -> linear
    projection to vocabulary logits.
    """

    def __init__(self, vocab_size, d_model=64, n_heads=2, n_layers=1, dim_ff=128):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, d_model)
        self.pos = PositionalEncoding(d_model)
        blocks = []
        for _ in range(n_layers):
            blocks.append(TransformerBlock(d_model, n_heads, dim_ff))
        self.layers = nn.ModuleList(blocks)
        self.ln = nn.LayerNorm(d_model)
        self.out = nn.Linear(d_model, vocab_size)

    def _causal_mask(self, T, device):
        """Return a ``(T, T)`` additive mask: ``-inf`` above the diagonal, 0 elsewhere."""
        future = torch.ones(T, T, dtype=torch.bool, device=device).triu(diagonal=1)
        return torch.zeros(T, T, device=device).masked_fill(future, float('-inf'))

    def forward(self, x):
        """Map token ids ``x`` of shape ``(T, B)`` to logits ``(T, B, vocab_size)``."""
        h = self.pos(self.tok(x))
        mask = self._causal_mask(x.size(0), x.device)
        for block in self.layers:
            h = block(h, attn_mask=mask)
        return self.out(self.ln(h))


## 4) Training / evaluation


In [None]:
def build_edges(T, B, causal=False):
    """Build the typed edge list of ``B`` independent length-``T`` chains.

    Node ids are batch-major: position ``t`` of sequence ``b`` is node
    ``b * T + t``. For each sequence the edges are, in order:

    * relation 0: ``t -> t + 1`` (adjacent, forward)
    * relation 1: ``t + 1 -> t`` (adjacent, backward), interleaved with relation 0
    * relation 2: ``t -> t + 2`` (skip-2, forward)

    Args:
        T: Sequence length (nodes per sequence).
        B: Number of sequences.
        causal: When true, the backward relation-1 edges are omitted so that
            information only flows from earlier to later positions. The
            default keeps all three relations. Note that relation-1 edges
            carry token ``t + 1`` back to position ``t``, which is the
            next-token target in language modelling.

    Returns:
        ``(edge_index, rel_ids)``: a ``(2, E)`` long tensor of
        ``[sources; destinations]`` and an ``(E,)`` long tensor of relation
        ids, both on the module-level ``device``.
    """
    # Build the edge pattern once for a single sequence (node ids 0..T-1) ...
    src = []
    dst = []
    rel = []
    for t in range(T - 1):
        src.append(t)
        dst.append(t + 1)
        rel.append(0)
        if not causal:
            src.append(t + 1)
            dst.append(t)
            rel.append(1)
    for t in range(T - 2):
        src.append(t)
        dst.append(t + 2)
        rel.append(2)
    template = torch.tensor([src, dst], dtype=torch.long, device=device)  # (2, E1)
    rel_template = torch.tensor(rel, dtype=torch.long, device=device)     # (E1,)

    # ... then shift it by b*T for every sequence via broadcasting, instead of
    # re-running the Python loop B times.
    num_seq = max(B, 0)
    edges_per_seq = template.size(1)
    offsets = torch.arange(num_seq, dtype=torch.long, device=device) * T
    edge_index = (template.unsqueeze(1) + offsets.view(1, num_seq, 1)).reshape(2, num_seq * edges_per_seq)
    rel_ids = rel_template.repeat(num_seq)
    return edge_index, rel_ids

def evaluate(model, data_source, is_gt=False):
    """Compute token-averaged cross-entropy and perplexity on ``data_source``.

    The data is walked in windows of the module-level ``bptt`` rows. The
    model runs in eval mode without gradients, and its previous
    training/eval mode is restored before returning, so calling this in the
    middle of training does not leave dropout switched off.

    Args:
        model: Language model returning logits of shape ``(T, B, vocab)``.
        data_source: Batchified token ids of shape ``(T_total, B)``.
        is_gt: When true the model is called as
            ``model(data, edge_index, rel_ids)`` with edges from
            ``build_edges``; otherwise as ``model(data)``.

    Returns:
        ``(avg_loss, ppl)`` where ``ppl = exp(avg_loss)``.

    Raises:
        ValueError: If ``data_source`` yields no (input, target) pairs.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    try:
        with torch.no_grad():
            for i in range(0, data_source.size(0) - 1, bptt):
                data, targets = get_batch(data_source, i, bptt)
                data = data.to(device)
                targets = targets.to(device)
                if is_gt:
                    edge_index, rel_ids = build_edges(data.size(0), data.size(1))
                    logits = model(data, edge_index, rel_ids)
                else:
                    logits = model(data)
                # Flatten using the logits' own last dimension; reshape also
                # copes with non-contiguous model outputs.
                loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
                # cross_entropy returns a per-token mean, so weight it by the
                # token count to get a correct average over unequal windows.
                total_loss += loss.item() * targets.numel()
                total_tokens += targets.numel()
    finally:
        model.train(was_training)
    if total_tokens == 0:
        raise ValueError('data_source must contain at least two time steps to evaluate')
    avg_loss = total_loss / total_tokens
    ppl = math.exp(avg_loss)
    return avg_loss, ppl

def train_model(model, name, train_data, valid_data, is_gt=False, max_steps=300):
    """Train ``model`` with Adam for up to ``max_steps`` optimiser steps.

    The model is moved to the module-level ``device``. Validation perplexity
    is measured and printed every 50 steps.

    Args:
        model: Language model returning logits of shape ``(T, B, vocab)``.
        name: Label used in the progress printout.
        train_data: Batchified training ids of shape ``(T_total, B)``.
        valid_data: Batchified validation ids, same layout.
        is_gt: When true the model is called as
            ``model(data, edge_index, rel_ids)`` with edges from
            ``build_edges``; otherwise as ``model(data)``.
        max_steps: Stop after this many optimiser steps.

    Returns:
        The list of validation perplexities, one per evaluation checkpoint.
    """
    model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)
    val_ppl = []
    step = 0
    model.train()
    # The epoch bound is only a safety cap; the loop normally exits via max_steps.
    for epoch in range(1000):
        for i in range(0, train_data.size(0) - 1, bptt):
            data, targets = get_batch(train_data, i, bptt)
            data = data.to(device)
            targets = targets.to(device)
            if is_gt:
                edge_index, rel_ids = build_edges(data.size(0), data.size(1))
                logits = model(data, edge_index, rel_ids)
            else:
                logits = model(data)
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
            opt.zero_grad()
            loss.backward()
            opt.step()
            step += 1
            if step % 50 == 0:
                v_loss, v_ppl = evaluate(model, valid_data, is_gt=is_gt)
                # Evaluation puts the model in eval mode; switch back so that
                # dropout stays active for the remaining training steps.
                model.train()
                val_ppl.append(v_ppl)
                print(f'[{name}] step {step:4d} | val_ppl={v_ppl:.2f}')
            if step >= max_steps:
                return val_ppl
    # Reached only if the data ran out before max_steps (e.g. empty train_data);
    # still return the history rather than an implicit None.
    return val_ppl

# Train
# Build both models with their default sizes. The Transformer is created first,
# so the construction order determines which random initial weights each gets.
tf_model = TransformerLM(vocab_size)
gt_model = GTFullLM(vocab_size)

# Train each model for 300 optimiser steps on the same data. Each call returns
# the list of validation perplexities recorded every 50 steps.
tf_ppl = train_model(tf_model, 'Transformer', train_data, valid_data, is_gt=False, max_steps=300)
gt_ppl = train_model(gt_model, 'GT‑Full', train_data, valid_data, is_gt=True, max_steps=300)

import numpy as np

# Plot the running minimum of validation perplexity for each model, i.e. the
# best value seen up to each evaluation checkpoint.
plt.figure(figsize=(6,4))
plt.title('GT‑Full vs Transformer (PTB, tiny)')
plt.xlabel('Eval checkpoints')
plt.ylabel('Validation perplexity')
plt.plot(np.minimum.accumulate(tf_ppl), label='Transformer (best-so-far)')
plt.plot(np.minimum.accumulate(gt_ppl), label='GT‑Full (best-so-far)')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
