# Geometric Transformer + Diagrammatic Backprop (Sudoku 4x4)
This notebook is a gentle, end-to-end example for the **Category Theory for AGI** course.
We solve tiny 4x4 Sudoku puzzles with two models:
- **Baseline Transformer**
- **Geometric Transformer (GT-Lite)** with an optional **diagrammatic backprop (DB)** penalty

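The DB signal is implemented as a **triangle consistency** loss. A *triangle* is any three cells that share a row, column, or 2x2 block. For each triangle in this Sudoku constraint graph, we encourage the three cell embeddings to agree with their mean.
This is a simple, local proxy for “diagrammatic curvature” without introducing heavier topology machinery. Cells in the same row, column, or block must hold *different* digits, so this penalty pulls against the classification loss; keep its weight small.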


In [None]:
# Core imports
import math
import random
from dataclasses import dataclass
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

def get_device():
    """Pick the compute device: the default CUDA GPU when present, else the CPU."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def set_seed(seed=0):
    """Seed Python's ``random`` module and then torch's global generator with ``seed``."""
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)

# Seed Python's and torch's RNGs so dataset generation and weight init are reproducible
# when the notebook cells are run in order.
set_seed(0)
# Module-level device; the training loop and the inspection cell move tensors here.
device = get_device()
print('Device:', device)


## 1) Tiny 4x4 Sudoku dataset
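We generate valid 4x4 Sudoku solutions by applying validity-preserving transformations to a fixed base solution: relabeling digits, permuting rows/columns within a band, and swapping bands. We then hide some entries. There are only 288 valid 4x4 grids in total, so the training and validation sets share many solution grids; they differ mainly in which cells are hidden.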
Digits are encoded as `0..3`; blanks are `-1`.


In [None]:
@dataclass
class SudokuInstance:
    """One 4x4 Sudoku example: a partially masked grid paired with its full solution.

    As produced by ``make_sudoku_dataset``, both tensors are flattened row-major
    (cell index = row * 4 + col).
    """
    puzzle: torch.Tensor   # (16,) digits 0..3, or -1 for blank
    solution: torch.Tensor # (16,) digits 0..3

def base_solution_matrix():
    """Return a fixed valid 4x4 Sudoku grid (digits 0..3) as a (4, 4) long tensor.

    Every row, column and 2x2 block holds each digit exactly once; the random
    generator derives further solutions from this grid by permutation.
    """
    grid = (
        (0, 1, 2, 3),
        (2, 3, 0, 1),
        (1, 0, 3, 2),
        (3, 2, 1, 0),
    )
    return torch.tensor(grid, dtype=torch.long)

def random_sudoku_4x4_solution() -> torch.Tensor:
    """Return a random valid 4x4 Sudoku solution as a flat (16,) long tensor.

    Starts from ``base_solution_matrix()`` and applies only validity-preserving
    symmetries, in this order: relabel digits, shuffle rows inside each band,
    maybe swap the two row bands, shuffle columns inside each band, maybe swap
    the two column bands.
    """
    def shuffled_within_bands():
        # One random.shuffle per band of two lines, bands kept in place.
        order = []
        for band in ([0, 1], [2, 3]):
            lines = list(band)
            random.shuffle(lines)
            order += lines
        return order

    swap_bands = [2, 3, 0, 1]

    relabel = torch.randperm(4)
    grid = relabel[base_solution_matrix()]

    grid = grid[shuffled_within_bands(), :]
    if random.random() < 0.5:
        grid = grid[swap_bands, :]

    grid = grid[:, shuffled_within_bands()]
    if random.random() < 0.5:
        grid = grid[:, swap_bands]

    return grid.reshape(-1)  # (16,)

def mask_puzzle(solution: List[int], num_givens: int) -> List[int]:
    """Hide all but ``num_givens`` randomly chosen cells of ``solution``.

    The chosen positions are drawn from the first 16 indices; every other cell
    is replaced by -1 (blank).
    """
    order = list(range(16))
    random.shuffle(order)
    revealed = set(order[:num_givens])
    return [value if pos in revealed else -1 for pos, value in enumerate(solution)]

def make_sudoku_dataset(num_samples: int, num_givens: int) -> List[SudokuInstance]:
    """Generate ``num_samples`` random puzzles, each revealing ``num_givens`` cells.

    Each instance pairs a freshly generated solution (stored as a clone) with a
    masked copy of it in which hidden cells are -1.
    """
    def draw_one() -> SudokuInstance:
        full = random_sudoku_4x4_solution()
        masked = mask_puzzle(full.tolist(), num_givens)
        return SudokuInstance(
            puzzle=torch.tensor(masked, dtype=torch.long),
            solution=full.clone(),
        )

    return [draw_one() for _ in range(num_samples)]


## 2) Geometric Transformer blocks
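We use a lightweight GT block: self-attention plus a local 1-D convolution path. The convolution runs over the 16 cells in row-major order, so a window can straddle the end of one row and the start of the next rather than following the 2-D grid.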
The baseline Transformer omits the convolutional path.


In [None]:
class GeomTransLiteBlock(nn.Module):
    """GT-Lite block: self-attention, then a local 1-D convolution, then an MLP.

    Every sub-layer is applied post-norm: ``x = LayerNorm(x + Dropout(sublayer(x)))``.
    Input and output are (batch, seq, d_model) (attention uses ``batch_first=True``).

    Args:
        d_model: embedding width; ``nn.MultiheadAttention`` requires it to be
            divisible by ``n_heads``.
        n_heads: number of attention heads.
        conv_kernel: convolution window width along the sequence axis. Any
            positive width works; the sequence length is preserved.
        dropout: dropout rate applied to each sub-layer output before the
            residual add (attention weights themselves are not dropped).
    """

    def __init__(self, d_model: int, n_heads: int, conv_kernel: int = 3, dropout: float = 0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.attn_norm = nn.LayerNorm(d_model)

        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )
        self.mlp_norm = nn.LayerNorm(d_model)

        # padding='same' keeps the sequence length for every kernel width. A
        # symmetric padding of conv_kernel // 2 only does so for odd widths: an
        # even width yields seq + 1 outputs and the residual add in forward()
        # fails. For odd widths both choices are identical.
        self.conv = nn.Conv1d(
            d_model, d_model, kernel_size=conv_kernel, padding='same'
        )
        self.conv_norm = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map (batch, seq, d_model) to (batch, seq, d_model)."""
        # 1) Global mixing: self-attention over all positions.
        attn_out, _ = self.attn(x, x, x)
        x = self.attn_norm(x + self.dropout(attn_out))

        # 2) Local mixing: Conv1d expects (batch, channels, seq), hence the
        #    transposes. Neighbours are adjacent positions in sequence order.
        z = self.conv(x.transpose(1, 2)).transpose(1, 2)
        x = self.conv_norm(x + self.dropout(z))

        # 3) Position-wise MLP.
        mlp_out = self.mlp(x)
        x = self.mlp_norm(x + self.dropout(mlp_out))
        return x

class GTReasoner(nn.Module):
    """A stack of ``num_layers`` GeomTransLiteBlock layers followed by a final LayerNorm.

    Maps (batch, seq, d_model) to (batch, seq, d_model).
    """

    def __init__(self, d_model: int, n_heads: int, num_layers: int):
        super().__init__()
        blocks = []
        for _ in range(num_layers):
            blocks.append(GeomTransLiteBlock(d_model, n_heads))
        self.layers = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        h = x
        for block in self.layers:
            h = block(h)
        return self.norm(h)

class PlainTransformerBlock(nn.Module):
    """Post-norm Transformer encoder block: self-attention, then an MLP.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``. This is
    GeomTransLiteBlock without the convolutional path.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        hidden = 4 * d_model
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.attn_norm = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(nn.Linear(d_model, hidden), nn.GELU(), nn.Linear(hidden, d_model))
        self.mlp_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map (batch, seq, d_model) to (batch, seq, d_model)."""
        attended, _unused_weights = self.attn(x, x, x)
        h = self.attn_norm(x + self.dropout(attended))
        return self.mlp_norm(h + self.dropout(self.mlp(h)))

class PlainReasoner(nn.Module):
    """A stack of ``num_layers`` PlainTransformerBlock layers followed by a final LayerNorm.

    Maps (batch, seq, d_model) to (batch, seq, d_model).
    """

    def __init__(self, d_model: int, n_heads: int, num_layers: int):
        super().__init__()
        blocks = []
        for _ in range(num_layers):
            blocks.append(PlainTransformerBlock(d_model, n_heads))
        self.layers = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        h = x
        for block in self.layers:
            h = block(h)
        return self.norm(h)

class SudokuGT(nn.Module):
    """4x4 Sudoku solver built on the GT-Lite reasoner.

    Each cell is encoded as 5 features (one-hot digit + "given" flag), projected
    to ``d_model``, refined by ``GTReasoner``, and mapped to 4 digit logits.
    """

    def __init__(self, d_model=64, n_heads=4, num_layers=2):
        super().__init__()
        self.d_model = d_model
        n_digits = 4
        self.in_proj = nn.Linear(n_digits + 1, d_model)  # digit one-hot + given flag
        self.reasoner = GTReasoner(d_model, n_heads, num_layers)
        self.out_head = nn.Linear(d_model, n_digits)

    def forward(self, puzzle: torch.Tensor, return_embeddings: bool = False):
        """Predict per-cell digit logits.

        Args:
            puzzle: long tensor of digits 0..3 with -1 for blanks, shape (T,)
                or (B, T).
            return_embeddings: if True, also return the reasoner's output.

        Returns:
            logits (B, T, 4), plus embeddings (B, T, d_model) when requested;
            the leading batch axis is dropped for an unbatched input.
        """
        single = puzzle.ndim == 1
        if single:
            puzzle = puzzle.unsqueeze(0)

        is_given = puzzle != -1
        # Blanks are encoded as digit 0 here; the given flag below distinguishes them.
        digits = puzzle.masked_fill(~is_given, 0)
        features = torch.cat(
            [F.one_hot(digits, num_classes=4).float(), is_given.float().unsqueeze(-1)],
            dim=-1,
        )

        embeddings = self.reasoner(self.in_proj(features))
        logits = self.out_head(embeddings)

        if single:
            logits, embeddings = logits.squeeze(0), embeddings.squeeze(0)

        return (logits, embeddings) if return_embeddings else logits

class SudokuTransformer(nn.Module):
    """Baseline 4x4 Sudoku solver: cell features -> PlainReasoner -> digit logits.

    Uses the same input projection and output head as ``SudokuGT``, so the two
    models differ only in the reasoner (no convolutional path here).
    """

    def __init__(self, d_model=64, n_heads=4, num_layers=2):
        super().__init__()
        self.d_model = d_model
        # Per-cell input: 4 one-hot digit channels + 1 "is given" flag.
        self.in_proj = nn.Linear(4 + 1, d_model)
        self.reasoner = PlainReasoner(d_model, n_heads, num_layers)
        self.out_head = nn.Linear(d_model, 4)

    def forward(self, puzzle: torch.Tensor, return_embeddings: bool = False):
        """Predict per-cell digit logits.

        Args:
            puzzle: long tensor of digits 0..3 with -1 for blanks, shape (T,)
                or (B, T).
            return_embeddings: if True, also return the reasoner's output
                embeddings. This matches ``SudokuGT.forward`` so that
                ``train_sudoku_model`` can apply the triangle (DB) penalty to
                the baseline too; previously that call raised TypeError.

        Returns:
            logits (B, T, 4), plus embeddings (B, T, d_model) when requested;
            the leading batch axis is dropped for an unbatched input.
        """
        if puzzle.ndim == 1:
            puzzle = puzzle.unsqueeze(0)
            squeeze_back = True
        else:
            squeeze_back = False

        digits = puzzle.clone()
        digits[digits == -1] = 0  # blanks -> digit 0; the given flag disambiguates
        digit_oh = F.one_hot(digits, num_classes=4).float()
        given = (puzzle != -1).float().unsqueeze(-1)
        x_feats = torch.cat([digit_oh, given], dim=-1)

        h0 = self.in_proj(x_feats)
        hT = self.reasoner(h0)
        logits = self.out_head(hT)

        if squeeze_back:
            logits = logits.squeeze(0)
            hT = hT.squeeze(0)

        if return_embeddings:
            return logits, hT
        return logits


## 3) Diagrammatic backprop (triangle consistency)
We build triangles from the Sudoku constraints: every 3-cell subset of each row, column, and 2x2 block. That is 4 triangles per unit and 48 in total.
Each triangle’s embeddings are encouraged to agree with their mean.


In [None]:
import itertools

def build_sudoku_triangles():
    """Return every 3-cell subset of each Sudoku unit (row, column, 2x2 block).

    Cells are indexed row-major (``row * 4 + col``). The list holds rows first,
    then columns, then blocks; each unit contributes C(4, 3) = 4 triangles, for
    48 ``(i, j, k)`` tuples with ``i < j < k``.
    """
    rows = [[4 * r + c for c in range(4)] for r in range(4)]
    cols = [[4 * r + c for r in range(4)] for c in range(4)]
    boxes = [
        [4 * (top + dr) + (left + dc) for dr in (0, 1) for dc in (0, 1)]
        for top in (0, 2)
        for left in (0, 2)
    ]
    return [
        tri
        for unit in rows + cols + boxes
        for tri in itertools.combinations(unit, 3)
    ]

def triangle_consistency(hT: torch.Tensor, triangles, reduce: str = 'mean'):
    """Squared deviation of each triangle's embeddings from the triangle centroid.

    For a triangle (i, j, k) the per-example loss is
    ``sum_{v in (h_i, h_j, h_k)} ||v - mean(h_i, h_j, h_k)||^2``.

    Args:
        hT: cell embeddings, (T, D) or (B, T, D); a 2-D input is treated as a
            batch of one.
        triangles: sequence of ``(i, j, k)`` cell-index triples into the T axis.
        reduce: 'mean' averages each triangle's loss over the batch; 'sum' sums it.

    Returns:
        0-dim tensor: the batch-reduced per-triangle loss averaged over all
        triangles; zero when ``triangles`` is empty.

    Raises:
        ValueError: if ``reduce`` is neither 'mean' nor 'sum'.

    Note: cells sharing a row/column/block must hold different digits, so
    pulling their embeddings together competes with the classification loss;
    keep the penalty weight small.
    """
    if reduce not in ('mean', 'sum'):
        raise ValueError(f"reduce must be 'mean' or 'sum', got {reduce!r}")
    if hT.ndim == 2:
        hT = hT.unsqueeze(0)
    if len(triangles) == 0:
        return hT.new_zeros(())

    # Gather all triangle vertices at once instead of looping in Python.
    idx = torch.as_tensor(triangles, dtype=torch.long, device=hT.device)  # (N, 3)
    verts = hT[:, idx, :]  # (B, N, 3, D)
    centroid = verts.mean(dim=2, keepdim=True)
    tri_loss = (verts - centroid).pow(2).sum(dim=(-1, -2))  # (B, N)

    per_triangle = tri_loss.mean(dim=0) if reduce == 'mean' else tri_loss.sum(dim=0)
    return per_triangle.mean()

# Precompute all constraint triangles once (3 unit types x 4 units x C(4,3)=4 -> 48);
# they are passed to train_sudoku_model for the optional DB penalty.
triangles = build_sudoku_triangles()
print('Num triangles:', len(triangles))


## 4) Training loop
We train the baseline and the GT. The DB penalty is optional: it is applied only when `triangles` is passed and `lambda_db > 0`.
The run below uses `lambda_db=0.0` (DB off); try a small value such as `0.05`.


In [None]:
def _iterate_sudoku_batches(dataset, batch_size):
    """Yield ``(puzzles, solutions)`` batches on ``device`` forever.

    The index order is reshuffled before every pass; each pass yields
    ``ceil(len(dataset) / batch_size)`` batches, the last possibly smaller.
    ``dataset`` must be non-empty, otherwise this generator never yields.
    """
    idxs = list(range(len(dataset)))
    while True:
        random.shuffle(idxs)
        for start in range(0, len(idxs), batch_size):
            batch_idx = idxs[start:start + batch_size]
            puzzles = torch.stack([dataset[j].puzzle for j in batch_idx], dim=0)
            sols = torch.stack([dataset[j].solution for j in batch_idx], dim=0)
            yield puzzles.to(device), sols.to(device)


def _evaluate_sudoku(model, dataset, batch_size):
    """Return ``(cell_acc, puzzle_acc, blank_acc)`` for ``model`` on ``dataset``.

    cell_acc counts every cell (givens included); blank_acc counts only cells
    that are blank (-1) in the puzzle; puzzle_acc is the fraction of puzzles
    solved exactly. Each ratio is 0.0 when its denominator is zero. Puts the
    model in eval mode and leaves it there.
    """
    model.eval()
    correct_cells = total_cells = 0
    correct_blanks = total_blanks = 0
    full_correct = 0
    n_items = len(dataset)
    with torch.no_grad():
        for start in range(0, n_items, batch_size):
            chunk = [dataset[j] for j in range(start, min(start + batch_size, n_items))]
            puzzles = torch.stack([inst.puzzle for inst in chunk], dim=0).to(device)
            sols = torch.stack([inst.solution for inst in chunk], dim=0).to(device)
            pred = model(puzzles).argmax(dim=-1)  # (batch, cells) predicted digits
            hits = pred == sols
            blanks = puzzles == -1
            correct_cells += hits.sum().item()
            total_cells += hits.numel()
            correct_blanks += (hits & blanks).sum().item()
            total_blanks += blanks.sum().item()
            full_correct += hits.all(dim=-1).sum().item()

    cell_acc = correct_cells / total_cells if total_cells else 0.0
    puzzle_acc = full_correct / n_items if n_items else 0.0
    blank_acc = correct_blanks / total_blanks if total_blanks else 0.0
    return cell_acc, puzzle_acc, blank_acc


def train_sudoku_model(
    model,
    name,
    train_ds,
    val_ds,
    batch_size=32,
    num_epochs=50,
    lr=1e-3,
    triangles=None,
    lambda_db: float = 0.0,
):
    """Train ``model`` with Adam + per-cell cross-entropy, validating every epoch.

    Args:
        model: maps a (batch, cells) long puzzle tensor to (batch, cells, 4)
            logits. When the DB penalty is active it must also accept
            ``return_embeddings=True`` and return ``(logits, embeddings)``.
        name: label used in progress prints.
        train_ds, val_ds: sequences of items with ``.puzzle`` / ``.solution``.
        batch_size: training and validation batch size (must be >= 1).
        num_epochs: number of passes over ``train_ds``.
        lr: Adam learning rate.
        triangles: cell-index triples for ``triangle_consistency``.
        lambda_db: weight of the triangle penalty. The penalty is applied only
            when ``triangles`` is given and ``lambda_db > 0``.

    Returns:
        ``(model, logs)``. ``logs`` holds per-epoch lists under 'epoch',
        'train_loss', 'val_cell_acc', 'val_puzzle_acc' and 'val_blank_acc'.

    Raises:
        ValueError: if ``train_ds`` is empty or ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')
    if len(train_ds) == 0:
        raise ValueError('train_ds is empty')

    model = model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    train_it = _iterate_sudoku_batches(train_ds, batch_size)
    # ceil: a dataset smaller than one batch still trains, and one epoch is
    # exactly one shuffled pass (the generator's last batch may be partial).
    steps_per_epoch = math.ceil(len(train_ds) / batch_size)
    uses_db = triangles is not None and lambda_db > 0.0

    logs = {
        'epoch': [],
        'train_loss': [],
        'val_cell_acc': [],
        'val_puzzle_acc': [],
        'val_blank_acc': [],
    }

    for epoch in range(1, num_epochs + 1):
        model.train()
        running_loss = 0.0

        for _ in range(steps_per_epoch):
            puzzles, sols = next(train_it)
            if uses_db:
                logits, hT = model(puzzles, return_embeddings=True)
            else:
                logits = model(puzzles)

            loss = F.cross_entropy(logits.reshape(-1, 4), sols.reshape(-1))
            if uses_db:
                loss = loss + lambda_db * triangle_consistency(hT, triangles)

            opt.zero_grad()
            loss.backward()
            opt.step()

            running_loss += loss.item()

        avg_loss = running_loss / steps_per_epoch
        val_cell_acc, val_puzzle_acc, val_blank_acc = _evaluate_sudoku(model, val_ds, batch_size)

        logs['epoch'].append(epoch)
        logs['train_loss'].append(avg_loss)
        logs['val_cell_acc'].append(val_cell_acc)
        logs['val_puzzle_acc'].append(val_puzzle_acc)
        logs['val_blank_acc'].append(val_blank_acc)

        if epoch % 10 == 0 or epoch == 1:
            print(f'[{name}] Epoch {epoch:3d} | loss={avg_loss:.4f} | cell_acc={val_cell_acc:.4f} | puzzle_acc={val_puzzle_acc:.4f} | blank_acc={val_blank_acc:.4f}')

    return model, logs


## 5) Run a small experiment
Keep epochs small for quick classroom demos. Increase for stronger results.


In [None]:
# Dataset
# 600 training / 200 validation puzzles, each with 8 of the 16 cells revealed.
# Drawn from the global RNGs seeded by set_seed(0), so reruns are reproducible
# when the cells are executed in order.
num_train = 600
num_val = 200
num_givens = 8
train_ds = make_sudoku_dataset(num_train, num_givens)
val_ds = make_sudoku_dataset(num_val, num_givens)

# Baseline
# Plain Transformer: same input projection / output head as the GT model, no conv path.
tf_model = SudokuTransformer(d_model=64, n_heads=4, num_layers=2)
tf_model, tf_logs = train_sudoku_model(
    tf_model, 'Transformer', train_ds, val_ds,
    num_epochs=50, batch_size=32, lr=1e-3
)

# GT + optional DB
# With lambda_db=0.0 the triangle penalty is disabled (train_sudoku_model applies it
# only when lambda_db > 0), so this run compares the two architectures alone.
gt_model = SudokuGT(d_model=64, n_heads=4, num_layers=2)
gt_model, gt_logs = train_sudoku_model(
    gt_model, 'GT', train_ds, val_ds,
    num_epochs=50, batch_size=32, lr=1e-3,
    triangles=triangles, lambda_db=0.0  # try 0.05
)


## 6) Plot validation accuracy


In [None]:
# Compare per-epoch validation cell accuracy of the two models.
# Note: cell accuracy also counts the given cells, whose digits are part of the model input.
plt.figure(figsize=(6, 4))
plt.plot(tf_logs['epoch'], tf_logs['val_cell_acc'], label='Transformer')
plt.plot(gt_logs['epoch'], gt_logs['val_cell_acc'], label='GT (+DB optional)')
plt.xlabel('Epoch')
plt.ylabel('Validation cell accuracy')
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
plt.show()


## 7) Inspect a single puzzle
This is just for intuition: we print a masked validation puzzle alongside the GT model’s prediction and the true solution.


In [None]:
def format_grid(vec16):
    """Render a flat 16-cell sequence as four space-separated rows of four values."""
    lines = []
    for start in range(0, 16, 4):
        lines.append(' '.join(map(str, vec16[start:start + 4])))
    return '\n'.join(lines)

# Show the first validation puzzle next to the GT model's prediction and the true solution.
sample = val_ds[0]
puzzle = sample.puzzle.to(device)
sol = sample.solution

# No gradients needed here. gt_model is already in eval mode (dropout off):
# train_sudoku_model ends every epoch with an eval-mode validation pass.
with torch.no_grad():
    logits = gt_model(puzzle.unsqueeze(0))
    pred = logits.squeeze(0).argmax(dim=-1).cpu()

print('Puzzle (-1 = blank):')
print(format_grid(puzzle.cpu().tolist()))
print('Prediction:')
print(format_grid(pred.tolist()))
print('Solution:')
print(format_grid(sol.tolist()))
