# Week 4 — GT‑Full Toy Node Labeling
We compare a simple MLP baseline against a **GT‑Full** message‑passing model on a toy relational graph.
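GT‑Full uses **relation‑aware simplicial message passing**: each edge message is conditioned on its relation type (edge type), so it should learn the relation‑dependent labels faster than the feature‑only baseline.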


In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Seed the global NumPy and PyTorch RNGs so reruns are reproducible,
# then pick the GPU when one is visible, falling back to the CPU.
np.random.seed(0)
torch.manual_seed(0)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device:', device)


## 1) Build a toy relational graph
We create two relation types: `friend` and `colleague`.
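Labels are a simple function of relation‑specific neighborhood counts. A node gets label 1 if it has at least as many incoming `friend` edges as incoming `colleague` edges; ties, including nodes with no incoming edges, count as 1. Otherwise it gets label 0.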


In [None]:
# Toy graph
N = 60
rel_names = ['friend', 'colleague']
num_rel = len(rel_names)


def _draw_relation(gen):
    """Sample the relation for one ordered node pair, or None for "no edge".

    First draw: 6% chance of ``friend`` (0). Only if that fails, a second,
    independent draw gives a 4% chance of ``colleague`` (1).
    """
    if gen.random() < 0.06:
        return 0
    if gen.random() < 0.04:
        return 1
    return None


# Sample directed edges (no self-loops) from a dedicated seeded generator.
rng = np.random.default_rng(0)
edges_src, edges_dst, rel_list = [], [], []
for u in range(N):
    for v in range(N):
        if u == v:
            continue
        rel = _draw_relation(rng)
        if rel is None:
            continue
        edges_src.append(u)
        edges_dst.append(v)
        rel_list.append(rel)

edge_index = torch.tensor([edges_src, edges_dst], dtype=torch.long)
rel_ids = torch.tensor(rel_list, dtype=torch.long)

# Node features: Gaussian noise, not used to build the labels.
X = torch.randn(N, 8)

# Per-node incoming-edge counts, split by relation (edges count at their destination).
dst_nodes = edge_index[1]
deg_friend = torch.bincount(dst_nodes[rel_ids == 0], minlength=N).float()
deg_coll = torch.bincount(dst_nodes[rel_ids != 0], minlength=N).float()

# Label 1 when incoming friend edges >= incoming colleague edges
# (ties, including nodes with no incoming edges, get label 1), else 0.
y = (deg_friend >= deg_coll).long()

# Train/val split: first 40 of a random permutation train, the rest validate.
idx = torch.randperm(N)
train_idx = idx[:40]
val_idx = idx[40:]


## 2) Baseline vs GT‑Full model
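The baseline ignores relations and only uses node features. Those features are random noise drawn independently of the labels, so the baseline has nothing beyond the class prior to learn from.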
GT‑Full aggregates relation‑aware messages along edges.


In [None]:
# GT‑Full message passing from the repo (minimal copy)
class SimplicialMessagePassing(nn.Module):
    def __init__(self, dim: int, num_rel: int, hidden_dim: int | None = None):
        super().__init__()
        hidden_dim = hidden_dim or dim
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * dim + num_rel, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, V, edge_index, rel_ids):
        if edge_index.shape[0] == 2:
            src = edge_index[0].long()
            dst = edge_index[1].long()
        else:
            src = edge_index[:,0].long()
            dst = edge_index[:,1].long()
        src_h = V[src]
        dst_h = V[dst]
        num_rel = int(rel_ids.max().item()) + 1 if rel_ids.numel() > 0 else 0
        rel_onehot = F.one_hot(rel_ids.long(), num_classes=num_rel).float()
        edge_feat = torch.cat([src_h, dst_h, rel_onehot], dim=-1)
        msg = self.edge_mlp(edge_feat)
        out = torch.zeros_like(V)
        out.index_add_(0, dst, msg)
        return V + out

class GeometricTransformerV2(nn.Module):
    """Stack of ``depth`` relation-aware message-passing layers followed by LayerNorm."""

    def __init__(self, dim: int, depth: int, num_rel: int):
        super().__init__()
        blocks = [SimplicialMessagePassing(dim, num_rel) for _ in range(depth)]
        self.layers = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(dim)

    def forward(self, V, edge_index, rel_ids):
        """Run every message-passing layer in order, then normalize the result."""
        hidden = V
        for block in self.layers:
            hidden = block(hidden, edge_index, rel_ids)
        return self.norm(hidden)

class BaselineMLP(nn.Module):
    """Two-layer MLP producing 2-class logits from node features alone."""

    def __init__(self, in_dim, hidden=32):
        super().__init__()
        layers = [nn.Linear(in_dim, hidden), nn.ReLU(), nn.Linear(hidden, 2)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return per-row logits of shape ``(..., 2)``."""
        return self.net(x)

class GTFullClassifier(nn.Module):
    """Node classifier: linear input projection -> GT message-passing stack -> 2-way linear head."""

    def __init__(self, in_dim, hidden=32, depth=2, num_rel=2):
        super().__init__()
        # Submodules are created in this order (affects RNG use and state_dict order).
        self.in_proj = nn.Linear(in_dim, hidden)
        self.gt = GeometricTransformerV2(hidden, depth, num_rel)
        self.out = nn.Linear(hidden, 2)

    def forward(self, x, edge_index, rel_ids):
        """Return 2-class logits for every node."""
        projected = self.in_proj(x)
        return self.out(self.gt(projected, edge_index, rel_ids))


## 3) Train and compare


In [None]:
def train_model(model, is_gt=False, epochs=200, lr=1e-2):
    """Train ``model`` full-batch with Adam and track validation accuracy.

    Args:
        model: Called as ``model(x)`` when ``is_gt`` is False, or as
            ``model(x, edge_index, rel_ids)`` when ``is_gt`` is True.
        is_gt: Whether to pass the graph (``edge_index``, ``rel_ids``) to the model.
        epochs: Number of full-batch optimization steps.
        lr: Adam learning rate.

    Returns:
        A list of ``(train_loss, val_acc)`` tuples, one per epoch.
        ``train_loss`` is the cross-entropy computed before that epoch's update.
        ``val_acc`` is measured after the update in eval mode, so the last
        entry reflects the final trained model.

    Reads the notebook globals ``X``, ``edge_index``, ``rel_ids``, ``y``,
    ``train_idx``, ``val_idx`` and ``device``.
    """
    model = model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    # The data is static, so move it to the device once rather than every epoch.
    x_dev = X.to(device)
    y_dev = y.to(device)
    graph = (edge_index.to(device), rel_ids.to(device)) if is_gt else ()
    train_dev = train_idx.to(device)
    val_dev = val_idx.to(device)

    logs = []
    for _ in range(epochs):
        model.train()
        logits = model(x_dev, *graph)
        loss = F.cross_entropy(logits[train_dev], y_dev[train_dev])
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Evaluate the *updated* weights in eval mode. Reusing the training
        # logits would lag one step behind and never score the final model.
        model.eval()
        with torch.no_grad():
            preds = model(x_dev, *graph).argmax(dim=-1)
            acc = (preds[val_dev] == y_dev[val_dev]).float().mean().item()
        logs.append((loss.item(), acc))
    return logs

# Take the input width from the data instead of hard-coding 8, so the models
# stay in sync with X. The baseline is still built first, keeping parameter
# initialization in the same order against the seeded RNG.
n_features = X.shape[1]
baseline = BaselineMLP(in_dim=n_features)
gtfull = GTFullClassifier(in_dim=n_features, depth=2, num_rel=num_rel)

b_logs = train_model(baseline, is_gt=False, epochs=200)
g_logs = train_model(gtfull, is_gt=True, epochs=200)

# Each log entry is (train_loss, val_acc); plot the validation-accuracy curves.
plt.figure(figsize=(6, 4))
plt.plot([acc for _, acc in b_logs], label='Baseline MLP')
plt.plot([acc for _, acc in g_logs], label='GT‑Full')
plt.xlabel('Epoch')
plt.ylabel('Val accuracy')
plt.title('GT‑Full vs baseline on relational labels')
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
plt.show()
