# Week 5 — Mini‑Democritus: Causal Triples → Manifold
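We embed a small set of causal relational triples with **GT‑Full** message passing, then project the resulting node embeddings onto a 2D plane (PCA) for visualization.
The model is not trained here (initial node features and layer weights are random), so this is a lightweight, classroom‑friendly proxy for the Democritus pipeline rather than a learned manifold.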


In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Seed both RNGs so the random node features and layer weights are reproducible.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device:', device)


## 1) Define a small set of causal triples
Each triple is (subject, relation, object). Subjects and objects become graph nodes, and each triple becomes a directed edge from subject to object, labeled with its relation.


In [None]:
# Toy (subject, relation, object) facts that define the causal graph.
triples = [
    ('smoking', 'causes', 'cancer'),
    ('pollution', 'causes', 'asthma'),
    ('exercise', 'reduces', 'risk'),
    ('sleep', 'improves', 'memory'),
    ('stress', 'causes', 'inflammation'),
    ('diet', 'influences', 'health'),
    ('exercise', 'improves', 'health'),
    ('sleep', 'reduces', 'stress'),
    ('smoking', 'increases', 'risk'),
    ('pollution', 'increases', 'risk'),
]

# Vocabularies: every subject or object is a node; relations get their own ids.
# Sorting makes the id assignment deterministic.
nodes = sorted({subj for subj, _, _ in triples} | {obj for _, _, obj in triples})
rels = sorted({rel for _, rel, _ in triples})
node2id = {name: idx for idx, name in enumerate(nodes)}
rel2id = {name: idx for idx, name in enumerate(rels)}

# One directed edge per triple (subject -> object), tagged with its relation id.
edges_src = [node2id[subj] for subj, _, _ in triples]
edges_dst = [node2id[obj] for _, _, obj in triples]
edge_index = torch.tensor([edges_src, edges_dst], dtype=torch.long)
rel_ids = torch.tensor([rel2id[rel] for _, rel, _ in triples], dtype=torch.long)

print('Nodes:', nodes)
print('Relations:', rels)


## 2) GT‑Full message passing to get embeddings


In [None]:
class SimplicialMessagePassing(nn.Module):
    def __init__(self, dim: int, num_rel: int, hidden_dim: int | None = None):
        super().__init__()
        hidden_dim = hidden_dim or dim
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * dim + num_rel, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, V, edge_index, rel_ids):
        src = edge_index[0].long()
        dst = edge_index[1].long()
        src_h = V[src]
        dst_h = V[dst]
        num_rel = int(rel_ids.max().item()) + 1 if rel_ids.numel() > 0 else 0
        rel_onehot = F.one_hot(rel_ids.long(), num_classes=num_rel).float()
        edge_feat = torch.cat([src_h, dst_h, rel_onehot], dim=-1)
        msg = self.edge_mlp(edge_feat)
        out = torch.zeros_like(V)
        out.index_add_(0, dst, msg)
        return V + out

class GeometricTransformerV2(nn.Module):
    """Stack of relation-aware message-passing layers followed by a LayerNorm.

    Args:
        dim: node feature width, kept constant through every layer.
        depth: number of SimplicialMessagePassing layers to apply in sequence.
        num_rel: relation vocabulary size handed to each layer.
    """

    def __init__(self, dim: int, depth: int, num_rel: int):
        super().__init__()
        stacked = [SimplicialMessagePassing(dim, num_rel) for _ in range(depth)]
        self.layers = nn.ModuleList(stacked)
        self.norm = nn.LayerNorm(dim)

    def forward(self, V, edge_index, rel_ids):
        """Run every layer in order on the same graph, then normalize the result."""
        for layer in self.layers:
            V = layer(V, edge_index, rel_ids)
        return self.norm(V)

# Random initial node features, then a single forward pass of the (untrained) model.
# Note: V must be drawn before the model is built so RNG consumption order is kept.
N = len(nodes)
dim = 16
V = torch.randn(N, dim, device=device)
gt = GeometricTransformerV2(dim, depth=2, num_rel=len(rels)).to(device)

ei_dev = edge_index.to(device)
rel_dev = rel_ids.to(device)
with torch.no_grad():
    out_dev = gt(V, ei_dev, rel_dev)
    H = out_dev.cpu().numpy()


## 3) 2D visualization (PCA)


In [None]:
# PCA to 2D: center the embeddings, then project onto the top-2 right singular vectors.
Hc = H - H.mean(axis=0, keepdims=True)
U, S, Vt = np.linalg.svd(Hc, full_matrices=False)
top2 = Vt[:2]
Z = Hc @ top2.T

fig, ax = plt.subplots(figsize=(6, 4))
ax.scatter(Z[:, 0], Z[:, 1], s=80)
# Label each point with its node name, nudged slightly up and to the right.
for name, (x, y) in zip(nodes, Z):
    ax.text(x + 0.02, y + 0.02, name)
ax.set_title('Mini‑Democritus manifold (GT‑Full embeddings)')
ax.axis('equal')
fig.tight_layout()
plt.show()
