# Week 1 — Causal Leakage Check: Transformer vs GT

This notebook tests a specific question: does GT get unfair next-token information from a non-causal geometric mixing layer?

We run a synthetic stress test in which every token is drawn independently and uniformly at random, so the true next token cannot be predicted from the tokens that precede it.
- A strictly causal model should stay near chance.
- A model with future-token leakage can artificially achieve high accuracy.


In [None]:
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

def set_seed(seed=0):
    """Seed both PyTorch's global RNG and Python's `random` module with `seed`."""
    torch.manual_seed(seed)
    random.seed(seed)

# Fix the RNG state before any model or data is created.
set_seed(0)

# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('device:', device)


In [None]:
class TinyTransformerLM(nn.Module):
    """Small causal language model built on `nn.TransformerEncoder`.

    Token embeddings and learned position embeddings are summed, run through
    the encoder stack under an upper-triangular attention mask, and projected
    to `vocab_size` logits per position.
    """

    def __init__(self, vocab_size, d_model=64, nhead=4, num_layers=1, max_len=128):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model)
        self.pos = nn.Embedding(max_len, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=4 * d_model,
            batch_first=True,
        )
        self.tr = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Map token ids `x` of shape (batch, seq) to logits (batch, seq, vocab_size)."""
        batch, seq_len = x.shape
        steps = torch.arange(seq_len, device=x.device)
        position_ids = steps.expand(batch, seq_len)
        hidden = self.emb(x) + self.pos(position_ids)
        # True marks a blocked (query, key) pair: every key strictly after the query.
        future_mask = steps.unsqueeze(0) > steps.unsqueeze(1)
        encoded = self.tr(hidden, future_mask)
        return self.out(encoded)

class TinyGTBlock(nn.Module):
    """Attention + feed-forward block followed by a depthwise kernel-3 convolution.

    With `geo_causal=False` the convolution pads one step on each side, so the
    output at position t mixes in positions t-1, t and t+1. With
    `geo_causal=True` the input is left-padded by two steps instead, so the
    output at position t only mixes positions t-2, t-1 and t.
    """

    def __init__(self, d_model=64, nhead=4, geo_causal=False):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        ffn_width = 4 * d_model
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ffn_width),
            nn.GELU(),
            nn.Linear(ffn_width, d_model),
        )
        self.geo_causal = bool(geo_causal)
        # Causal mode pads by hand in forward(); otherwise the conv pads symmetrically.
        builtin_padding = 0 if self.geo_causal else 1
        self.conv = nn.Conv1d(
            d_model,
            d_model,
            kernel_size=3,
            padding=builtin_padding,
            groups=d_model,
        )
        self.n1 = nn.LayerNorm(d_model)
        self.n2 = nn.LayerNorm(d_model)
        self.ng = nn.LayerNorm(d_model)

    def forward(self, x, attn_mask=None):
        """Apply attention, feed-forward and convolutional mixing, each with residual + LayerNorm."""
        attended = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)[0]
        x = self.n1(x + attended)
        x = self.n2(x + self.ffn(x))

        # Conv1d convolves over the last axis, so swap axes 1 and 2 around it.
        conv_in = x.transpose(1, 2)
        if self.geo_causal:
            conv_in = F.pad(conv_in, (2, 0))
        mixed = self.conv(conv_in).transpose(1, 2)
        return self.ng(x + mixed)

class TinyGTLM(nn.Module):
    """Language model that stacks `TinyGTBlock` layers under a causal attention mask.

    `geo_causal` is forwarded to every block and selects whether the block's
    convolution is left-padded (causal) or symmetrically padded.
    """

    def __init__(self, vocab_size, d_model=64, nhead=4, num_layers=1, max_len=128, geo_causal=False):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model)
        self.pos = nn.Embedding(max_len, d_model)
        block_kwargs = {'d_model': d_model, 'nhead': nhead, 'geo_causal': geo_causal}
        blocks = [TinyGTBlock(**block_kwargs) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Map token ids `x` of shape (batch, seq) to logits (batch, seq, vocab_size)."""
        batch, seq_len = x.shape
        steps = torch.arange(seq_len, device=x.device)
        hidden = self.emb(x) + self.pos(steps.expand(batch, seq_len))
        # True marks a blocked (query, key) pair: every key strictly after the query.
        future_mask = steps.unsqueeze(0) > steps.unsqueeze(1)
        for block in self.layers:
            hidden = block(hidden, attn_mask=future_mask)
        return self.out(hidden)


In [None]:
@torch.no_grad()
def future_sensitivity(model, vocab_size=50, seq_len=24, batch=64, probe_t=8, trials=20, tol=1e-8):
    """Measure how much the logits at `probe_t` react to the token at `probe_t + 1`.

    For each trial a random batch of token ids is drawn, the token at position
    `probe_t + 1` is replaced by a *different* token in every row, and the
    model's logits at position `probe_t` are compared before and after. A
    model that is causal at the token level should report zero change.

    Args:
        model: module mapping (batch, seq) integer token ids to
            (batch, seq, vocab) logits. It is switched to eval mode and left
            in eval mode.
        vocab_size: number of distinct token ids to sample from (at least 2).
        seq_len: length of each sampled sequence.
        batch: number of sequences per trial.
        probe_t: position whose logits are inspected; `probe_t + 1` must be a
            valid position.
        trials: number of independent batches to average over.
        tol: per-row mean absolute logit difference above which a row counts
            as changed.

    Returns:
        dict with `mean_abs_logit_delta` (mean absolute logit difference,
        averaged over trials) and `fraction_changed` (share of rows whose
        difference exceeds `tol`).

    Raises:
        ValueError: if `vocab_size < 2`, `trials < 1`, `batch < 1`, or
            `probe_t` is not in `[0, seq_len - 2]`.
    """
    if vocab_size < 2:
        raise ValueError('vocab_size must be at least 2 to perturb a token')
    if trials < 1 or batch < 1:
        raise ValueError('trials and batch must both be at least 1')
    if not 0 <= probe_t < seq_len - 1:
        raise ValueError('probe_t must satisfy 0 <= probe_t < seq_len - 1')

    # Sample inputs on the model's own device instead of relying on a global.
    first_param = next(model.parameters(), None)
    dev = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    diffs = []
    changed = 0
    for _ in range(trials):
        x = torch.randint(0, vocab_size, (batch, seq_len), device=dev)
        x2 = x.clone()
        # A nonzero offset modulo vocab_size guarantees every row really gets a
        # new future token; an independent redraw would leave about
        # 1/vocab_size of the rows unperturbed and bias both metrics downward.
        offset = torch.randint(1, vocab_size, (batch,), device=dev)
        x2[:, probe_t + 1] = (x[:, probe_t + 1] + offset) % vocab_size

        y1 = model(x)[:, probe_t, :]
        y2 = model(x2)[:, probe_t, :]
        d = (y1 - y2).abs().mean(dim=1)
        diffs.append(d.mean().item())
        changed += int((d > tol).sum().item())

    return {
        'mean_abs_logit_delta': float(sum(diffs) / len(diffs)),
        'fraction_changed': float(changed / (trials * batch)),
    }

# Probe three freshly initialised (untrained) models for future-token sensitivity.
VOCAB = 64
cfg = {'vocab_size': VOCAB, 'd_model': 64, 'nhead': 4, 'num_layers': 1, 'max_len': 64}

m_tf = TinyTransformerLM(**cfg).to(device)
m_gt_nc = TinyGTLM(geo_causal=False, **cfg).to(device)
m_gt_c = TinyGTLM(geo_causal=True, **cfg).to(device)

# Insertion order fixes the order in which the probes consume the RNG.
probes = (
    ('transformer_causal', m_tf),
    ('gt_noncausal_conv', m_gt_nc),
    ('gt_causal_conv', m_gt_c),
)
sens = {label: future_sensitivity(net, vocab_size=VOCAB) for label, net in probes}
sens


In [None]:
def sample_batch(batch_size=64, seq_len=25, vocab_size=64):
    """Draw uniform random token sequences and split them into (inputs, targets).

    Returns two (batch_size, seq_len - 1) views of the same sample: the inputs
    drop the last token and the targets drop the first, so `targets[:, t]` is
    the token that follows `inputs[:, t]`.
    """
    tokens = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len), device=device)
    inputs, targets = tokens[:, :-1], tokens[:, 1:]
    return inputs, targets

def run_train(model, steps=250, lr=3e-4, batch_size=64, seq_len=25, vocab_size=64):
    """Train `model` with AdamW on freshly sampled random batches.

    Returns `(loss_history, accuracy_history)`, one entry per step. Both are
    computed from the logits produced before that step's parameter update.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_history = []
    acc_history = []

    model.train()
    for _ in range(steps):
        inputs, targets = sample_batch(batch_size=batch_size, seq_len=seq_len, vocab_size=vocab_size)
        logits = model(inputs)
        loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Token accuracy needs no gradient, so work on a detached copy.
        hits = logits.detach().argmax(dim=-1) == targets
        loss_history.append(loss.item())
        acc_history.append(hits.float().mean().item())

    return loss_history, acc_history

def eval_model(model, batches=40, batch_size=64, seq_len=25, vocab_size=64):
    """Evaluate `model` on freshly sampled random batches in eval mode.

    Returns `(mean_loss, mean_accuracy)` averaged over `batches` batches.
    """
    model.eval()
    per_batch = []
    with torch.no_grad():
        for _ in range(batches):
            inputs, targets = sample_batch(batch_size=batch_size, seq_len=seq_len, vocab_size=vocab_size)
            logits = model(inputs)
            loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
            hits = logits.argmax(dim=-1) == targets
            per_batch.append((loss.item(), hits.float().mean().item()))

    losses = [entry[0] for entry in per_batch]
    accs = [entry[1] for entry in per_batch]
    mean_loss = float(sum(losses) / len(losses))
    mean_acc = float(sum(accs) / len(accs))
    return mean_loss, mean_acc


In [None]:
# Re-seed so the training run does not depend on the cells executed before it.
set_seed(0)
cfg = {'vocab_size': 64, 'd_model': 64, 'nhead': 4, 'num_layers': 1, 'max_len': 64}

# Build every model before training starts; the order fixes the initial weights.
models = {}
models['transformer_causal'] = TinyTransformerLM(**cfg).to(device)
models['gt_noncausal_conv'] = TinyGTLM(geo_causal=False, **cfg).to(device)
models['gt_causal_conv'] = TinyGTLM(geo_causal=True, **cfg).to(device)

# Train and then evaluate each model in turn on independent random batches.
hist = {}
for name, m in models.items():
    train_loss, train_acc = run_train(m, steps=250, lr=3e-4, batch_size=64, seq_len=25, vocab_size=64)
    te_loss, te_acc = eval_model(m, batches=50, batch_size=64, seq_len=25, vocab_size=64)
    hist[name] = {
        'loss': train_loss,
        'acc': train_acc,
        'test_loss': te_loss,
        'test_acc': te_acc,
    }

# Reference values for a uniform guess over 64 symbols.
chance_acc = 1.0 / 64.0
chance_ce = math.log(64.0)
print('chance acc =', round(chance_acc, 4), 'chance CE =', round(chance_ce, 4))
for name, stats in hist.items():
    print(name, 'test_loss=', round(stats['test_loss'], 4), 'test_acc=', round(stats['test_acc'], 4))


In [None]:
# Training curves: loss on the left panel, token accuracy on the right.
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
loss_ax, acc_ax = ax[0], ax[1]

for name, d in hist.items():
    loss_ax.plot(d['loss'], label=name)
    acc_ax.plot(d['acc'], label=name)

# Dashed reference lines at the uniform-guess level for 64 symbols.
loss_ax.axhline(math.log(64.0), color='k', ls='--', alpha=0.5, label='chance CE')
acc_ax.axhline(1.0/64.0, color='k', ls='--', alpha=0.5, label='chance acc')

loss_ax.set_title('Train CE loss')
acc_ax.set_title('Train token accuracy')
for a in (loss_ax, acc_ax):
    a.set_xlabel('step')
    a.grid(alpha=0.3)
    a.legend()

plt.tight_layout()
plt.show()


## Interpretation

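If `gt_noncausal_conv` scores well above chance on this random-token task while the causal models stay near chance, the gain comes from future-token leakage, not better sequence modeling: its symmetric kernel-3 convolution reads position t+1, which is exactly the token being predicted at position t.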

For fair LM comparisons, use strictly causal geometric mixing (`geo_causal=True`) so the GT layer respects the same information boundary as baseline Transformers.
