# Week 4 — Standalone GT Causal vs Augmented Regimes (PTB / WikiText-2)

This notebook is **standalone** and runs directly in Colab.
No Hugging Face login/token is required.

It compares six variants under a shared training budget:
- `transformer_causal`
- `transformer_future_hint`
- `gt_causal`
- `gt_noncausal`
- `gt_pred_next_detach`
- `gt_pred_prev_causal_detach`

Recommended workflow:
1. Start with `DATASET='ptb'`.
2. Then try `DATASET='wikitext2'`.
3. Compare strict-causal vs augmented-context vs transition behavior.


In [None]:
#@title Install dependencies (Colab)
# Upgrade pip first, then install the plotting / dataframe / array packages
# imported by the next cell.
# NOTE(review): torch is imported later but is not installed here —
# presumably it is preinstalled on Colab; verify for other environments.
!python -m pip install -q --upgrade pip
!python -m pip install -q matplotlib pandas numpy


In [None]:
#@title Imports and device setup
import math
import os
import random
from collections import Counter
from urllib import request

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

def set_seed(seed=0):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value handed unchanged to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # The CUDA generators are only touched when a GPU is actually present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Seed every generator once for the notebook, then pick the compute device:
# CUDA when it is available, CPU otherwise.
set_seed(0)
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('device:', device)


In [None]:
#@title Config
# Colab form parameters. The `#@param` markers are read by Colab's form UI.
DATASET = 'ptb'  #@param ['ptb', 'wikitext2']
MAX_VOCAB = 10000 #@param {type:'integer'}
CONTEXT = 128 #@param {type:'integer'}
D_MODEL = 64 #@param {type:'integer'}
N_LAYERS = 2 #@param {type:'integer'}
N_HEADS = 4 #@param {type:'integer'}
BATCH_SIZE = 24 #@param {type:'integer'}
NUM_ITERS = 5000 #@param {type:'integer'}
EVAL_EVERY = 500 #@param {type:'integer'}
LR = 3e-4 #@param {type:'number'}
WEIGHT_DECAY = 1e-5 #@param {type:'number'}
HINT_SCALE = 0.2 #@param {type:'number'}
PRED_TEMP = 1.0 #@param {type:'number'}

# Optional preset for stronger transition effects on PTB.
# It only ever raises CONTEXT / NUM_ITERS / EVAL_EVERY to the listed floors.
PHASE_TRANSITION_PRESET = True  #@param {type:'boolean'}
if PHASE_TRANSITION_PRESET and DATASET == 'ptb':
    CONTEXT = max(CONTEXT, 128)
    NUM_ITERS = max(NUM_ITERS, 5000)
    EVAL_EVERY = max(EVAL_EVERY, 500)

# Report the settings only after the preset has been applied, so the printed
# values are the ones the training cells will actually use.
print(dict(DATASET=DATASET, CONTEXT=CONTEXT, D_MODEL=D_MODEL, N_LAYERS=N_LAYERS, NUM_ITERS=NUM_ITERS))



**Note on phase-transition behavior**

For `gt_pred_next_detach`, transition-like behavior usually appears later than 1.2k steps.
Recommended PTB setting: `CONTEXT>=128`, `NUM_ITERS>=5000`, and optionally 2-3 seeds for stability.


In [None]:
#@title Load PTB/WikiText-2 from direct URLs (no HF token needed)
# The helpers below fetch the corpora as plain-text files with urllib and
# cache them on disk; this banner just makes that explicit in the cell output.
print('No Hugging Face datasets dependency: using direct URL text files.')
def download_ptb(save_dir='/tmp/ptb'):
    """Make sure the PTB train/valid/test text files exist in ``save_dir``.

    Files already present are reused; missing ones are downloaded.

    Args:
        save_dir: Directory used as the on-disk cache (created if missing).

    Returns:
        List of three file paths in train, valid, test order.
    """
    os.makedirs(save_dir, exist_ok=True)
    base = 'https://raw.githubusercontent.com/wojzaremba/lstm/master/data/'
    files = ['ptb.train.txt', 'ptb.valid.txt', 'ptb.test.txt']
    out = []
    for fn in files:
        path = os.path.join(save_dir, fn)
        if not os.path.exists(path):
            # Download under a temporary name and rename on success, so an
            # interrupted transfer never leaves a truncated file that a later
            # run would mistake for a complete download.
            tmp_path = path + '.part'
            try:
                request.urlretrieve(base + fn, tmp_path)
                os.replace(tmp_path, path)
            finally:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
        out.append(path)
    return out

def download_wikitext2(save_dir='/tmp/wikitext2'):
    """Make sure the WikiText-2 train/valid/test text files exist in ``save_dir``.

    Files already present are reused; missing ones are downloaded.

    Args:
        save_dir: Directory used as the on-disk cache (created if missing).

    Returns:
        List of three file paths in train, valid, test order.
    """
    os.makedirs(save_dir, exist_ok=True)
    base = 'https://raw.githubusercontent.com/pytorch/examples/master/word_language_model/data/wikitext-2/'
    files = ['train.txt', 'valid.txt', 'test.txt']
    out = []
    for fn in files:
        path = os.path.join(save_dir, fn)
        if not os.path.exists(path):
            # Download under a temporary name and rename on success, so an
            # interrupted transfer never leaves a truncated file that a later
            # run would mistake for a complete download.
            tmp_path = path + '.part'
            try:
                request.urlretrieve(base + fn, tmp_path)
                os.replace(tmp_path, path)
            finally:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
        out.append(path)
    return out

def load_tokens(path):
    """Read a UTF-8 text file into a flat list of whitespace-separated tokens.

    Every line, including an empty one, contributes a trailing ``'<eos>'``.

    Args:
        path: Path of the text file to read.

    Returns:
        List of token strings.
    """
    tokens = []
    with open(path, 'r', encoding='utf-8') as fh:
        for raw_line in fh:
            # str.split() with no arguments already ignores surrounding whitespace.
            tokens.extend(raw_line.split())
            tokens.append('<eos>')
    return tokens

def get_token_splits(name):
    """Return the (train, valid, test) token lists for the named corpus.

    Args:
        name: ``'ptb'`` or ``'wikitext2'``.

    Returns:
        Tuple of three token lists in train, valid, test order.

    Raises:
        ValueError: If ``name`` is not one of the supported corpora.
    """
    if name not in ('ptb', 'wikitext2'):
        raise ValueError(name)
    paths = download_ptb() if name == 'ptb' else download_wikitext2()
    train_path, valid_path, test_path = paths
    return tuple(load_tokens(p) for p in (train_path, valid_path, test_path))

def build_vocab(train_tokens, max_vocab=10000):
    """Build a frequency-ranked vocabulary with ``<pad>`` and ``<unk>`` up front.

    ``<pad>`` is always index 0 and ``<unk>`` index 1. Corpus tokens that
    spell one of these reserved symbols are not added a second time, so every
    vocabulary entry is unique and ``stoi`` maps each entry to its own slot.

    Args:
        train_tokens: Iterable of token strings to count.
        max_vocab: Maximum vocabulary size, reserved symbols included.

    Returns:
        ``(vocab, stoi)``: the list of tokens by index and the inverse
        token-to-index dict.
    """
    specials = ['<pad>', '<unk>']
    counts = Counter(train_tokens)
    budget = max(0, max_vocab - len(specials))
    # Drop reserved symbols before truncating: a literal '<unk>' in the corpus
    # would otherwise appear twice and leave index 1 unreachable through stoi.
    ranked = [w for w, _ in counts.most_common() if w not in specials]
    vocab = specials + ranked[:budget]
    stoi = {w: i for i, w in enumerate(vocab)}
    return vocab, stoi

def to_ids(tokens, stoi):
    """Encode tokens as an int64 NumPy array of vocabulary indices.

    Args:
        tokens: Iterable of token strings.
        stoi: Token-to-index mapping; must contain ``'<unk>'``.

    Returns:
        1-D ``np.int64`` array; tokens missing from ``stoi`` get the
        ``'<unk>'`` index.
    """
    fallback = stoi['<unk>']
    encoded = [stoi[tok] if tok in stoi else fallback for tok in tokens]
    return np.array(encoded, dtype=np.int64)

# Fetch (or reuse cached) splits, build the vocabulary from the training
# split only, then encode every split with that vocabulary.
train_tokens, valid_tokens, test_tokens = get_token_splits(DATASET)
vocab, stoi = build_vocab(train_tokens, MAX_VOCAB)
train_np, valid_np, test_np = (
    to_ids(split, stoi) for split in (train_tokens, valid_tokens, test_tokens)
)
vocab_size = len(vocab)
print(
    'vocab_size:', vocab_size,
    'train_tokens:', len(train_np),
    'valid_tokens:', len(valid_np),
    'test_tokens:', len(test_np),
)


In [None]:
#@title Models (Transformer + GT-Lite variants)
class TransformerLM(nn.Module):
    """Causally masked Transformer language model with an optional future hint.

    Token and learned position embeddings are summed, passed through a stack
    of ``nn.TransformerEncoderLayer`` blocks under an upper-triangular
    attention mask, and projected to vocabulary logits. With ``future_hint``
    enabled, every position except the last additionally receives
    ``hint_scale`` times the final hidden state of the position to its right,
    which deliberately breaks causality.

    Args:
        vocab_size: Number of token ids (embedding rows and output logits).
        d_model: Hidden width.
        n_layers: Number of encoder layers.
        n_heads: Attention heads per layer.
        max_len: Size of the learned position table; the longest usable input.
        future_hint: Enable the non-causal look-ahead mix described above.
        hint_scale: Weight of the look-ahead term.
    """

    def __init__(self, vocab_size, d_model, n_layers, n_heads, max_len=2048, future_hint=False, hint_scale=0.2):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model)
        self.pos = nn.Embedding(max_len, d_model)
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, dim_feedforward=4*d_model, batch_first=True)
            for _ in range(n_layers)
        ])
        self.out = nn.Linear(d_model, vocab_size)
        self.future_hint = future_hint
        self.hint_scale = float(hint_scale)

    def forward(self, x):
        """Return logits of shape ``(B, L, vocab_size)`` for ids ``x`` of shape ``(B, L)``.

        Raises:
            ValueError: If ``L`` exceeds the size of the position table.
        """
        B, L = x.shape
        max_len = self.pos.num_embeddings
        if L > max_len:
            # Fail early with a readable message instead of an out-of-range
            # embedding lookup.
            raise ValueError(f'sequence length {L} exceeds max_len={max_len}')
        pos = torch.arange(L, device=x.device).unsqueeze(0).expand(B, L)
        h = self.emb(x) + self.pos(pos)
        # True above the diagonal: position t may not attend to positions > t.
        mask = torch.triu(torch.ones(L, L, device=x.device, dtype=torch.bool), diagonal=1)

        for blk in self.blocks:
            h = blk(h, src_mask=mask)

        if self.future_hint and L > 1:
            # Explicit non-causal local hint control. Built out of place so the
            # activation produced by the last block is never mutated.
            hinted = h[:, :-1, :] + self.hint_scale * h[:, 1:, :]
            h = torch.cat([hinted, h[:, -1:, :]], dim=1)

        return self.out(h)

class GeometricTransformerBlock(nn.Module):
    """Attention + feed-forward block followed by a depthwise-conv mixing branch.

    Each of the three sub-steps (self-attention, feed-forward, kernel-3
    depthwise ``Conv1d`` along the sequence) is added residually, with dropout
    on the branch and a LayerNorm after the sum. With ``geo_causal`` the
    convolution input is left-padded by two so position ``t`` only sees
    ``t-2..t``; otherwise symmetric padding of one lets it see ``t-1..t+1``.

    Args:
        d_model: Hidden width (also the number of depthwise conv groups).
        n_heads: Attention heads.
        geo_causal: Use the left-padded (causal) convolution.
        dropout: Dropout probability applied to every residual branch.
    """

    def __init__(self, d_model, n_heads, geo_causal=False, dropout=0.1):
        super().__init__()
        # Sub-modules are created in a fixed order so that seeded parameter
        # initialisation stays reproducible.
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4*d_model),
            nn.GELU(),
            nn.Linear(4*d_model, d_model),
        )
        self.geo_causal = bool(geo_causal)
        conv_padding = 0 if geo_causal else 1
        self.geo_conv = nn.Conv1d(d_model, d_model, kernel_size=3, padding=conv_padding, groups=d_model)
        self.n1 = nn.LayerNorm(d_model)
        self.n2 = nn.LayerNorm(d_model)
        self.ng = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def _geo_mix(self, src):
        """Run the depthwise convolution over the sequence axis of ``src``."""
        channels_first = src.transpose(1, 2)
        if self.geo_causal:
            # Two zeros on the left make the kernel-3 window end at position t.
            channels_first = F.pad(channels_first, (2, 0))
        return self.geo_conv(channels_first).transpose(1, 2)

    def forward(self, h, mask, geo_in=None):
        """Apply the block to ``h``.

        Args:
            h: Hidden states, batch-first.
            mask: Attention mask forwarded to ``nn.MultiheadAttention``.
            geo_in: Optional tensor fed to the convolution branch instead of
                the post-feed-forward hidden states.

        Returns:
            Updated hidden states.
        """
        attn_out, _ = self.attn(h, h, h, attn_mask=mask, need_weights=False)
        h = self.n1(h + self.drop(attn_out))
        ff_out = self.ffn(h)
        h = self.n2(h + self.drop(ff_out))

        conv_src = geo_in if geo_in is not None else h
        mixed = self._geo_mix(conv_src)
        return self.ng(h + self.drop(mixed))

class GTLiteLM(nn.Module):
    """Language model built from ``GeometricTransformerBlock`` layers.

    ``geo_mix_source`` selects what each layer's convolution branch reads:

    - ``'hidden'``: the layer's own hidden states.
    - ``'pred_next_detach'``: a detached "soft prediction" embedding, i.e. the
      softmax of the output head applied to the current hidden states,
      multiplied by the embedding matrix.
    - ``'pred_prev_causal_detach'``: the same soft prediction shifted right
      by one position, with zeros at position 0.

    Args:
        vocab_size: Number of token ids.
        d_model: Hidden width.
        n_layers: Number of blocks.
        n_heads: Attention heads per block.
        max_len: Size of the learned position table; the longest usable input.
        geo_causal: Passed to every block's convolution branch.
        geo_mix_source: One of the three modes listed above.
        pred_temp: Softmax temperature used for the soft prediction.

    Raises:
        ValueError: If ``geo_mix_source`` is not a supported mode.
    """

    _GEO_MIX_SOURCES = ('hidden', 'pred_next_detach', 'pred_prev_causal_detach')

    def __init__(self, vocab_size, d_model, n_layers, n_heads, max_len=2048, geo_causal=False, geo_mix_source='hidden', pred_temp=1.0):
        super().__init__()
        # Reject typos up front: an unknown mode would otherwise silently
        # behave like 'hidden' and mislabel the experiment.
        if geo_mix_source not in self._GEO_MIX_SOURCES:
            raise ValueError(
                f'unknown geo_mix_source {geo_mix_source!r}; expected one of {self._GEO_MIX_SOURCES}'
            )
        self.emb = nn.Embedding(vocab_size, d_model)
        self.pos = nn.Embedding(max_len, d_model)
        self.layers = nn.ModuleList([GeometricTransformerBlock(d_model, n_heads, geo_causal=geo_causal) for _ in range(n_layers)])
        self.out = nn.Linear(d_model, vocab_size)
        self.geo_mix_source = geo_mix_source
        self.pred_temp = float(pred_temp)

    def _soft_emb(self, h):
        """Return the probability-weighted average of the token embeddings."""
        logits = self.out(h) / max(1e-6, self.pred_temp)
        probs = torch.softmax(logits, dim=-1)
        return probs @ self.emb.weight

    def _pred_emb(self, h, detach=True):
        """Soft prediction embedding for ``h``; gradient-free when ``detach``."""
        if not detach:
            return self._soft_emb(h)
        # Same values as computing and then detaching, without recording an
        # autograd graph over the vocabulary-sized softmax.
        with torch.no_grad():
            return self._soft_emb(h)

    def forward(self, x):
        """Return logits of shape ``(B, L, vocab_size)`` for ids ``x`` of shape ``(B, L)``.

        Raises:
            ValueError: If ``L`` exceeds the size of the position table.
        """
        B, L = x.shape
        max_len = self.pos.num_embeddings
        if L > max_len:
            raise ValueError(f'sequence length {L} exceeds max_len={max_len}')
        pos = torch.arange(L, device=x.device).unsqueeze(0).expand(B, L)
        h = self.emb(x) + self.pos(pos)
        # True above the diagonal: position t may not attend to positions > t.
        mask = torch.triu(torch.ones(L, L, device=x.device, dtype=torch.bool), diagonal=1)

        for layer in self.layers:
            geo_in = None
            if self.geo_mix_source == 'pred_next_detach':
                geo_in = self._pred_emb(h, detach=True)
            elif self.geo_mix_source == 'pred_prev_causal_detach':
                pe = self._pred_emb(h, detach=True)
                # Shift right by one so position t reads the prediction made
                # at t-1; position 0 has no predecessor and stays zero.
                geo_in = torch.zeros_like(pe)
                geo_in[:, 1:, :] = pe[:, :-1, :]
            h = layer(h, mask, geo_in=geo_in)

        return self.out(h)


In [None]:
#@title Training + evaluation utilities
def sample_windows(arr, batch_size, context):
    """Sample random (input, target) windows for next-token prediction.

    Each window starts at a uniformly random offset ``i``; the input is
    ``arr[i:i+context]`` and the target is the same span shifted by one.

    Args:
        arr: 1-D sequence of token ids.
        batch_size: Number of windows to draw.
        context: Window length.

    Returns:
        ``(x, y)``: two int64 tensors of shape ``(batch_size, context)``.

    Raises:
        ValueError: If ``arr`` has fewer than ``context + 1`` tokens.
    """
    # A window starting at i reads arr[i : i + context + 1], so the valid
    # starts are 0 .. len(arr) - context - 1 inclusive. randint's upper bound
    # is exclusive, hence len(arr) - context.
    n_starts = len(arr) - context
    if n_starts <= 0:
        raise ValueError(
            f'need at least context + 1 = {context + 1} tokens, got {len(arr)}'
        )
    ids = np.random.randint(0, n_starts, size=batch_size)
    x = np.stack([arr[i:i+context] for i in ids])
    y = np.stack([arr[i+1:i+1+context] for i in ids])
    return torch.from_numpy(x).long(), torch.from_numpy(y).long()

@torch.no_grad()
def eval_model(model, arr, batch_size=32, context=64, num_batches=30):
    """Estimate cross-entropy and perplexity of ``model`` on random windows of ``arr``.

    The model is switched to eval mode for scoring and put back into the
    mode it was in on entry, so calling this in the middle of training does
    not leave dropout disabled for the remaining steps.

    Args:
        model: Module mapping a batch of token ids to logits.
        arr: 1-D array of token ids to sample windows from.
        batch_size: Windows per batch.
        context: Window length.
        num_batches: Number of random batches to average over.

    Returns:
        ``(avg_loss, perplexity)`` where ``avg_loss`` is the token-weighted
        mean cross-entropy and ``perplexity`` is ``exp(avg_loss)``.
    """
    was_training = model.training
    model.eval()
    ce = nn.CrossEntropyLoss()
    total_loss = 0.0
    total_tok = 0
    try:
        for _ in range(num_batches):
            x, y = sample_windows(arr, batch_size, context)
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = ce(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
            # The loss is a per-batch mean; weight it by the token count.
            total_loss += float(loss.item()) * y.numel()
            total_tok += y.numel()
    finally:
        # Restore the caller's train/eval mode even if scoring fails.
        model.train(was_training)
    avg = total_loss / max(1, total_tok)
    return avg, math.exp(avg)

def run_variant(name, model):
    """Train ``model`` for ``NUM_ITERS`` steps and evaluate it.

    Reads the notebook-level settings ``LR``, ``WEIGHT_DECAY``, ``NUM_ITERS``,
    ``BATCH_SIZE``, ``CONTEXT`` and ``EVAL_EVERY``, the encoded splits
    ``train_np`` / ``valid_np`` / ``test_np``, and ``device``.

    Args:
        name: Label used in the progress printout.
        model: Module to train; expected to be on ``device`` already.

    Returns:
        ``(hist, test_loss, test_ppl)`` where ``hist`` is a list of dicts with
        keys ``'iter'``, ``'val_loss'`` and ``'val_ppl'``, one per evaluation.
    """
    ce = nn.CrossEntropyLoss()
    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    hist = []
    model.train()
    for it in range(1, NUM_ITERS + 1):
        x, y = sample_windows(train_np, BATCH_SIZE, CONTEXT)
        x, y = x.to(device), y.to(device)

        opt.zero_grad()
        logits = model(x)
        loss = ce(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

        if it % EVAL_EVERY == 0 or it == NUM_ITERS:
            vloss, vppl = eval_model(model, valid_np, batch_size=BATCH_SIZE, context=CONTEXT, num_batches=25)
            # Scoring runs the model in eval mode; make sure the remaining
            # training steps run in train mode so dropout stays active.
            model.train()
            hist.append({'iter': it, 'val_loss': vloss, 'val_ppl': vppl})
            print(f'[{name}] iter={it} val_loss={vloss:.3f} val_ppl={vppl:.3f}')

    tloss, tppl = eval_model(model, test_np, batch_size=BATCH_SIZE, context=CONTEXT, num_batches=40)
    return hist, tloss, tppl


In [None]:
#@title Run six-way causal-ablation comparison
set_seed(0)


def _make_transformer(future_hint):
    """Build a TransformerLM from the notebook config."""
    return TransformerLM(vocab_size, D_MODEL, N_LAYERS, N_HEADS, max_len=CONTEXT, future_hint=future_hint, hint_scale=HINT_SCALE)


def _make_gt(geo_causal, geo_mix_source):
    """Build a GTLiteLM from the notebook config."""
    return GTLiteLM(vocab_size, D_MODEL, N_LAYERS, N_HEADS, max_len=CONTEXT, geo_causal=geo_causal, geo_mix_source=geo_mix_source, pred_temp=PRED_TEMP)


# Insertion order is the training order; models are only constructed when
# their turn comes, so the order also fixes how the shared RNG is consumed.
variants = {
    'transformer_causal': lambda: _make_transformer(False),
    'transformer_future_hint': lambda: _make_transformer(True),
    'gt_noncausal': lambda: _make_gt(False, 'hidden'),
    'gt_causal': lambda: _make_gt(True, 'hidden'),
    'gt_pred_next_detach': lambda: _make_gt(False, 'pred_next_detach'),
    'gt_pred_prev_causal_detach': lambda: _make_gt(True, 'pred_prev_causal_detach'),
}

results = {}
rows = []
for name, ctor in variants.items():
    print(f'\n=== {name} ===')
    model = ctor().to(device)
    hist, test_loss, test_ppl = run_variant(name, model)
    results[name] = dict(history=hist, test_loss=test_loss, test_ppl=test_ppl)
    # Fall back to NaN when no validation point was recorded.
    if hist:
        last = hist[-1]
    else:
        last = {'val_loss': float('nan'), 'val_ppl': float('nan')}
    rows.append(dict(
        model=name,
        last_val_loss=last['val_loss'],
        last_val_ppl=last['val_ppl'],
        test_loss=test_loss,
        test_ppl=test_ppl,
    ))

df = pd.DataFrame(rows)
display(df.sort_values('test_ppl'))


In [None]:
#@title Plot validation loss and perplexity curves
# One figure per metric: (history key, title prefix, y-axis label).
curve_specs = (
    ('val_loss', 'Val loss', 'Val loss'),
    ('val_ppl', 'Val perplexity', 'Val ppl'),
)
for metric, title, ylabel in curve_specs:
    plt.figure(figsize=(9,4))
    for name, obj in results.items():
        h = obj['history']
        if h:
            # Variants without any recorded evaluation are simply not drawn.
            xs = [r['iter'] for r in h]
            ys = [r[metric] for r in h]
            plt.plot(xs, ys, 'o--', label=name)
    plt.title(f'{title} ({DATASET})')
    plt.xlabel('Iteration')
    plt.ylabel(ylabel)
    plt.grid(alpha=0.3)
    plt.legend(bbox_to_anchor=(1.02,1), loc='upper left')
    plt.tight_layout()
    plt.show()


In [None]:
#@title Auto-interpretation summary
def regime(m):
    """Map a variant name to its regime label.

    Returns ``'C'`` for the strict-causal variants, ``'A'`` for the
    augmented-context variants, ``'Transition'`` for ``gt_pred_next_detach``
    and ``'?'`` for anything else.
    """
    groups = (
        ('C', ('transformer_causal', 'gt_causal', 'gt_pred_prev_causal_detach')),
        ('A', ('transformer_future_hint', 'gt_noncausal')),
        ('Transition', ('gt_pred_next_detach',)),
    )
    for label, members in groups:
        if m in members:
            return label
    return '?'

# Tag every variant with its regime and show the ranked table.
df['regime'] = df['model'].map(regime)
display(df[['model','regime','last_val_ppl','test_ppl']].sort_values('test_ppl'))

# Test perplexity per variant, then the best/worst value inside each cluster.
m = {name: float(ppl) for name, ppl in zip(df['model'], df['test_ppl'])}
causal_ppls = [m[k] for k in ('transformer_causal', 'gt_causal', 'gt_pred_prev_causal_detach')]
aug_ppls = [m[k] for k in ('transformer_future_hint', 'gt_noncausal')]
causal_best, causal_worst = min(causal_ppls), max(causal_ppls)
aug_best, aug_worst = min(aug_ppls), max(aug_ppls)
pred_next = m['gt_pred_next_detach']
# Fraction of the causal-to-augmented gap that the transition variant closes;
# the small constant keeps the denominator away from exactly zero.
gap = causal_best - aug_best + 1e-12
gain = (causal_best - pred_next) / gap

print(f'Dataset: {DATASET}')
print(f'Causal cluster test PPL range: [{causal_best:.3f}, {causal_worst:.3f}]')
print(f'Augmented cluster test PPL range: [{aug_best:.3f}, {aug_worst:.3f}]')
print(f'gt_pred_next_detach test PPL: {pred_next:.3f}')
print(f'transition_gain: {gain:.3f}')

if pred_next <= aug_worst:
    verdict = 'Interpretation: transition variant is in/near augmented regime for this run.'
elif pred_next < causal_best:
    verdict = 'Interpretation: transition variant is intermediate (better than strict-causal, not yet augmented).'
else:
    verdict = 'Interpretation: transition variant remains strict-causal-like for this run.'
print(verdict)


## Interpreting the Result (Student Notes)
- In this run, `gt_pred_next_detach` starts near the strict-causal models, then drops sharply after ~2k-2.5k steps.
- This is the **phase-transition pattern**: prediction-based hints become useful enough to move performance toward the augmented regime.
- `gt_noncausal` remains the upper bound here because it has the strongest augmented-context pathway.
- `gt_pred_prev_causal_detach` staying near causal baselines is an important control: strict causal self-conditioning alone does not create the same jump.
- Practical lesson: the transition is sensitive to training horizon and context length (typically clearer with larger `CONTEXT` and longer runs).

