# Week 1 — Language Modeling: GT vs Transformer (PTB)
This notebook compares a **baseline Transformer** to a **Geometric Transformer (GT-Lite)** on a small language modeling task.
We use a tiny **Penn Treebank (PTB)** word-level dataset and train for a short run to keep things fast.

**Goal:** show how GT-style local smoothing can change learning dynamics relative to a standard Transformer.


In [None]:
# Core imports
import math
import os
import random
from urllib import request

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

def set_seed(seed=0):
    """Seed Python's `random` module and PyTorch's global RNG with `seed`."""
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)

set_seed(0)

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print('Device:', device)


## 1) Load PTB (word-level)
We download a tiny PTB word-level dataset and build a vocabulary from the training split.


In [None]:
def download_ptb(save_dir='/tmp/ptb'):
    """Download the PTB train/valid/test text files into `save_dir` if missing.

    Files already present in `save_dir` are reused without touching the
    network. Each missing file is first written to a temporary `.part` path and
    only renamed to its final name once the download has completed, so an
    interrupted download can never be mistaken for a valid cached file.

    Args:
        save_dir: Directory used as the on-disk cache (created if needed).

    Returns:
        A list with the paths of the train, valid and test files, in that order.
    """
    os.makedirs(save_dir, exist_ok=True)
    base = 'https://raw.githubusercontent.com/wojzaremba/lstm/master/data/'
    files = ['ptb.train.txt', 'ptb.valid.txt', 'ptb.test.txt']
    paths = []
    for fn in files:
        path = os.path.join(save_dir, fn)
        if not os.path.exists(path):
            tmp_path = path + '.part'
            try:
                request.urlretrieve(base + fn, tmp_path)
                # Publish the file under its real name only after a full download.
                os.replace(tmp_path, path)
            finally:
                # Still present only if the download or rename failed part-way.
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
        paths.append(path)
    return paths

def load_tokens(path):
    """Read a whitespace-tokenized text file into a flat list of tokens.

    Every line (including blank ones) contributes its tokens followed by an
    `<eos>` marker. The file is decoded as UTF-8 explicitly so the result does
    not depend on the platform's default encoding.

    Args:
        path: Path of the text file to read.

    Returns:
        A list of token strings with `<eos>` appended after each line.
    """
    tokens = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            # str.split() with no arguments already ignores surrounding whitespace.
            tokens.extend(line.split())
            tokens.append('<eos>')
    return tokens

# Fetch the three PTB splits (cached on disk after the first run) and tokenize each.
train_path, valid_path, test_path = download_ptb()
train_tokens, valid_tokens, test_tokens = (
    load_tokens(split_path) for split_path in (train_path, valid_path, test_path)
)

# The vocabulary comes from the training split only; ids follow sorted word order.
vocab = {word: idx for idx, word in enumerate(sorted(set(train_tokens)))}

def to_ids(tokens):
    """Map an iterable of tokens to a 1-D int64 tensor of `vocab` ids.

    Raises KeyError for any token that is not present in `vocab`.
    """
    ids = list(map(vocab.__getitem__, tokens))
    return torch.tensor(ids, dtype=torch.long)

# Numericalize every split with the training vocabulary.
train_ids, valid_ids, test_ids = map(to_ids, (train_tokens, valid_tokens, test_tokens))

vocab_size = len(vocab)
print('Vocab size:', vocab_size)
print('Train tokens:', len(train_ids))


## 2) Batching utilities
We use the classic PTB batching scheme (contiguous segments).


In [None]:
def batchify(data, batch_size):
    """Fold a 1-D token stream into `batch_size` parallel contiguous columns.

    Trailing tokens that do not fill a complete row are dropped. The result
    has shape (T, B): column b holds the b-th contiguous segment of the stream.
    """
    steps_per_column = data.size(0) // batch_size
    usable = data[:steps_per_column * batch_size]
    columns = usable.view(batch_size, -1)  # (B, T)
    return columns.transpose(0, 1).contiguous()  # (T, B)

def get_batch(source, i, bptt=35):
    """Slice an (input, target) pair starting at row `i` of a (T, B) tensor.

    The window is at most `bptt` rows long and is shortened near the end of
    `source` so that the target (the input shifted by one row) stays in range.
    """
    remaining = source.size(0) - 1 - i
    seq_len = bptt if bptt < remaining else remaining
    stop = i + seq_len
    return source[i:stop], source[i + 1:stop + 1]

# Batch width and truncated-BPTT window length shared by every split.
batch_size, bptt = 20, 35

train_data, valid_data, test_data = (
    batchify(ids, batch_size) for ids in (train_ids, valid_ids, test_ids)
)

print('Train batchified shape:', train_data.shape)  # (T, B)


## 3) Models: Transformer vs GT-Lite
GT-Lite adds a small 1D convolution in each block to encourage local smoothing.


In [None]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to (T, B, D) inputs.

    Even feature indices hold sines and odd indices hold cosines of the
    position scaled by geometrically spaced frequencies. Odd `d_model` values
    are supported: the last sine column simply has no cosine partner.

    Args:
        d_model: Feature dimension D of the inputs.
        max_len: Number of positions precomputed; inputs with more than
            `max_len` time steps cannot be broadcast against the table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = pos * div  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd index than even index, so trim
        # the cosine table to the number of odd columns.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        # A buffer follows the module across devices without being a parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return `x` (T, B, D) plus the encoding of positions 0..T-1."""
        return x + self.pe[:x.size(0)].unsqueeze(1)

class TransformerBlock(nn.Module):
    """Post-LayerNorm Transformer block: self-attention, then a feed-forward MLP.

    Each sub-layer is wrapped as LayerNorm(x + Dropout(sublayer(x))).
    Inputs and outputs are laid out as (T, B, D).
    """

    def __init__(self, d_model, n_heads, dim_ff, dropout=0.1):
        super().__init__()
        # Submodules are created in a fixed order so seeded initialization
        # draws the same random numbers for the same weights.
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.ln1 = nn.LayerNorm(d_model)
        mlp_stages = [
            nn.Linear(d_model, dim_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_ff, d_model),
        ]
        self.ff = nn.Sequential(*mlp_stages)
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        """Apply the block to `x` (T, B, D), optionally masking attention."""
        context = self.attn(x, x, x, attn_mask=attn_mask)[0]
        hidden = self.ln1(x + self.dropout(context))
        return self.ln2(hidden + self.dropout(self.ff(hidden)))

class GeomTransformerBlock(nn.Module):
    """Transformer block with an extra causal 1D convolution (GT-Lite).

    The block applies, in order: self-attention, a depth-mixing convolution
    along the time axis, and a feed-forward MLP, each followed by a residual
    connection and LayerNorm. Inputs and outputs are laid out as (T, B, D).

    The convolution is padded on the left only, so the output at time t depends
    on inputs at times <= t. A symmetric padding would let position t read the
    embedding of token t+1, which is exactly the token a causal language model
    is asked to predict at position t.

    Args:
        d_model: Feature dimension D.
        n_heads: Number of attention heads.
        dim_ff: Hidden width of the feed-forward MLP.
        dropout: Dropout probability used in attention, the MLP and residuals.
        kernel_size: Width of the temporal convolution window.
    """

    def __init__(self, d_model, n_heads, dim_ff, dropout=0.1, kernel_size=3):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.ln_attn = nn.LayerNorm(d_model)
        # No built-in padding: forward() pads the past side explicitly.
        self.causal_pad = kernel_size - 1
        self.conv = nn.Conv1d(d_model, d_model, kernel_size=kernel_size)
        self.ln_conv = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_ff, d_model),
        )
        self.ln_ff = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        """Apply the block to `x` (T, B, D), optionally masking attention."""
        attn_out, _ = self.attn(x, x, x, attn_mask=attn_mask)
        x = self.ln_attn(x + self.dropout(attn_out))

        # Conv1d expects (B, D, T); pad only the left (past) side of the time
        # axis so the output keeps length T without seeing future positions.
        conv_in = F.pad(x.permute(1, 2, 0), (self.causal_pad, 0))
        x_conv = self.conv(conv_in).permute(2, 0, 1)  # (T,B,D)
        # The smoothing branch is mixed in with a small fixed weight.
        x = self.ln_conv(x + 0.2 * x_conv)

        ff_out = self.ff(x)
        x = self.ln_ff(x + self.dropout(ff_out))
        return x

class LMModel(nn.Module):
    """Causal language model: embedding, positional encoding, N blocks, output head.

    `gt=True` stacks `GeomTransformerBlock`s; otherwise plain `TransformerBlock`s.
    Token ids go in as (T, B) and logits come out as (T, B, vocab_size).
    """

    def __init__(self, vocab_size, d_model=128, n_heads=4, n_layers=2, dim_ff=512, gt=False):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, d_model)
        self.pos = PositionalEncoding(d_model)
        block_cls = GeomTransformerBlock if gt else TransformerBlock
        stack = []
        for _ in range(n_layers):
            stack.append(block_cls(d_model, n_heads, dim_ff))
        self.layers = nn.ModuleList(stack)
        self.ln = nn.LayerNorm(d_model)
        self.out = nn.Linear(d_model, vocab_size)

    def _causal_mask(self, T, device):
        """Return a (T, T) additive mask: 0 on/below the diagonal, -inf above."""
        return torch.full((T, T), float('-inf'), device=device).triu(diagonal=1)

    def forward(self, x):
        """Map token ids `x` (T, B) to next-token logits (T, B, vocab_size)."""
        mask = self._causal_mask(x.size(0), x.device)
        h = self.pos(self.tok(x))  # (T, B, D)
        for layer in self.layers:
            h = layer(h, attn_mask=mask)
        return self.out(self.ln(h))


## 4) Training + evaluation
We keep it short for classroom use. Increase `max_steps` for stronger results.


In [None]:
def evaluate(model, data_source, bptt=35):
    """Compute the token-averaged cross-entropy and perplexity on `data_source`.

    The model is switched to eval mode for the duration of the pass and then
    returned to whichever mode (train or eval) it was in on entry, so calling
    this in the middle of training does not silently disable dropout afterwards.

    Args:
        model: Module mapping (T, B) token ids to (T, B, V) logits.
        data_source: Batchified (T, B) tensor of token ids.
        bptt: Maximum window length per forward pass.

    Returns:
        A tuple `(avg_loss, ppl)` where `ppl = exp(avg_loss)`.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    try:
        with torch.no_grad():
            for i in range(0, data_source.size(0) - 1, bptt):
                data, targets = get_batch(data_source, i, bptt)
                data = data.to(device)
                targets = targets.to(device)
                logits = model(data)
                # logits: (T,B,V) -> (T*B,V); V is read from the logits themselves.
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.reshape(-1))
                # Weight each window's mean loss by its token count so the
                # shorter final window does not skew the average.
                total_loss += loss.item() * targets.numel()
                total_tokens += targets.numel()
    finally:
        model.train(was_training)
    avg_loss = total_loss / total_tokens
    ppl = math.exp(avg_loss)
    return avg_loss, ppl

def train_model(model, name, train_data, valid_data, max_steps=300, bptt=35, lr=1e-3):
    """Train `model` with Adam for `max_steps` optimizer steps.

    The training data is cycled through as many times as needed. The training
    loss is recorded every 50 steps and validation perplexity every 100 steps.

    Args:
        model: Module mapping (T, B) token ids to (T, B, V) logits.
        name: Label used in the progress printout.
        train_data: Batchified (T, B) training token ids.
        valid_data: Batchified (T, B) validation token ids.
        max_steps: Number of optimizer steps to run.
        bptt: Maximum window length per step.
        lr: Adam learning rate.

    Returns:
        A tuple `(losses, val_ppl)` of the recorded training losses and
        validation perplexities.

    Raises:
        ValueError: If `train_data` has fewer than two time steps, in which
            case no (input, target) window can be formed.
    """
    if train_data.size(0) < 2:
        raise ValueError('train_data needs at least 2 time steps to form a batch')

    model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    losses = []
    val_ppl = []
    step = 0

    model.train()
    # Keep cycling over the data until the step budget is used up.
    while step < max_steps:
        for i in range(0, train_data.size(0) - 1, bptt):
            data, targets = get_batch(train_data, i, bptt)
            data = data.to(device)
            targets = targets.to(device)

            logits = model(data)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.reshape(-1))

            opt.zero_grad()
            loss.backward()
            opt.step()

            step += 1
            if step % 50 == 0:
                losses.append(loss.item())

            if step % 100 == 0:
                _, v_ppl = evaluate(model, valid_data, bptt)
                # Evaluation may have left the model in eval mode; make sure
                # dropout is active again for the following training steps.
                model.train()
                val_ppl.append(v_ppl)
                print(f'[{name}] step {step:4d} | train_loss={loss.item():.3f} | val_ppl={v_ppl:.2f}')

            if step >= max_steps:
                break

    return losses, val_ppl


## 5) Run GT vs Transformer
Try 300 steps for a fast demo; increase to 1000+ for stronger separation.


In [None]:
# Hyperparameters
d_model = 128
n_heads = 4
n_layers = 2
dim_ff = 512
max_steps = 300

# Each run is re-seeded so both models are built and trained starting from the
# same RNG state; otherwise the second model's initialization and dropout noise
# would depend on how many random numbers the first run consumed.

# Baseline Transformer
set_seed(0)
tf_model = LMModel(vocab_size, d_model, n_heads, n_layers, dim_ff, gt=False)
tf_losses, tf_ppl = train_model(tf_model, 'Transformer', train_data, valid_data, max_steps=max_steps, bptt=bptt)

# Geometric Transformer (GT-Lite)
set_seed(0)
gt_model = LMModel(vocab_size, d_model, n_heads, n_layers, dim_ff, gt=True)
gt_losses, gt_ppl = train_model(gt_model, 'GT', train_data, valid_data, max_steps=max_steps, bptt=bptt)


## 6) Plot validation perplexity


In [None]:
# One curve per model; the x axis counts validation checkpoints, not steps.
fig, ax = plt.subplots(figsize=(6, 4))
for curve, label in ((tf_ppl, 'Transformer'), (gt_ppl, 'GT-Lite')):
    ax.plot(curve, label=label)
ax.set_xlabel('Eval checkpoints')
ax.set_ylabel('Validation perplexity')
ax.grid(True, alpha=0.3)
ax.legend()
fig.tight_layout()
plt.show()


## 7) Quick text generation (greedy)
We greedily decode a short continuation from each model for intuition.


In [None]:
# Reverse lookup (id -> word) used to turn generated ids back into text
inv_vocab = dict(zip(vocab.values(), vocab.keys()))

def greedy_generate(model, start_tokens, max_new=25):
    """Extend `start_tokens` (T, B) by `max_new` greedily chosen tokens.

    At every step the whole sequence so far is fed through `model` and the
    argmax of the last position's logits is appended for each batch column.
    The model is put in eval mode; the input tensor is not modified.

    Returns:
        A (T + max_new, B) tensor containing the prefix and the new tokens.
    """
    model.eval()
    generated = start_tokens.clone()  # (T, B)
    with torch.no_grad():
        remaining = max_new
        while remaining > 0:
            last_logits = model(generated)[-1]  # (B, V)
            best = last_logits.argmax(dim=-1).unsqueeze(0)  # (1, B)
            generated = torch.cat((generated, best), dim=0)  # (T+1, B)
            remaining -= 1
    return generated

# Condition both models on the same first few validation tokens
prefix_len = 10
start = valid_data[:prefix_len].to(device)  # (T,B)

tf_seq, gt_seq = (
    greedy_generate(lm, start, max_new=20).cpu() for lm in (tf_model, gt_model)
)

def decode(seq, col=0):
    """Return the words of column `col` of a (T, B) id tensor, space-separated."""
    column_ids = seq[:, col].tolist()
    return ' '.join(inv_vocab[token_id] for token_id in column_ids)

# The generated tensors start with the shared prefix; slice it off so each
# "continuation" shows only the tokens the model produced.
n_prefix = start.size(0)

print('Prefix:')
print(decode(start.cpu(), col=0))
print('\nTransformer continuation:')
print(decode(tf_seq[n_prefix:], col=0))
print('\nGT continuation:')
print(decode(gt_seq[n_prefix:], col=0))
