# G-Retrieval Comparison: LPG (GAT) vs RDF (TransE / DistMult)

End-to-end experiment notebook with two tracks:

## Track A: Representation Quality (Cells 1–5, 14–20)
Train graph encoders on FinDER dual-graph PyG dataset. Evaluate via **graph→answer cosine retrieval**.

| Model | Type | Input | Training |
|-------|------|-------|----------|
| **GAT** | GNN (2-layer, 4-head) | LPG subgraphs (384d SentenceTransformer node features) | Link prediction (BCE) |
| **TransE** | KGE (translation) | RDF triples (h + r ≈ t) | Margin ranking loss |
| **DistMult** | KGE (bilinear) | RDF triples (h · diag(r) · t) | Margin ranking loss (with random negative sampling) |

## Track B: G-Retrieval E2E Profiling (Cells 6–13)
Full **Graph Encoder → Projection → LLM** pipeline at D=256.

- Stage-wise profiling with `torch.profiler` + `record_function` labels
- CPU-store baseline: embedding tables forced to CPU, explicit lookup → H2D → GPU compute
- Attention map analysis: how graph soft tokens influence LLM self-attention
- LLM: Llama 3.1 8B Instruct (bfloat16, A100 80GB)

## Data
- **FinDER KG**: 2,542 samples (train 2,030 / val 251 / test 261)
- **LPG**: 13,920 nodes, 18,892 edges | **RDF**: 17,534 entities, 4,340 relations

In [None]:
# Cell 1: Setup & Install
# Uncomment for Colab:
# !pip install torch torch-geometric sentence-transformers rouge-score
# !pip install torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-2.1.0+cu121.html

import sys, os, json, time, warnings
from pathlib import Path
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams.update({'font.size': 11, 'figure.dpi': 120})
warnings.filterwarnings('ignore', category=FutureWarning)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torch_geometric.nn import GATConv, global_mean_pool
from torch_geometric.nn.kge import TransE, DistMult
from torch_geometric.utils import negative_sampling, scatter

# Project imports — adjust path for Colab
# Assumes the notebook is run from a direct subdirectory of the project (e.g. notebooks/),
# so '..' is the project root containing the `src` package — TODO confirm for Colab.
PROJECT_ROOT = Path('..').resolve()
if str(PROJECT_ROOT) not in sys.path:
    # Prepend so the local `src` package takes precedence over anything else on sys.path.
    sys.path.insert(0, str(PROJECT_ROOT))

from src.data import FinDERGraphQADataset, DualGraphBatch, dual_graph_collate_fn, VocabularyBuilder

# Single target device; later cells move models and batches here with `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')
if device.type == 'cuda':
    print(f'  GPU: {torch.cuda.get_device_name()}')
    print(f'  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')

In [None]:
# Cell 2: Data Loading
# Root of the pre-processed PyG dataset (dual LPG/RDF graphs per question).
DATA_ROOT = PROJECT_ROOT / 'data' / 'processed' / 'finder_pyg'

train_ds = FinDERGraphQADataset(root=str(DATA_ROOT), split='train')
val_ds   = FinDERGraphQADataset(root=str(DATA_ROOT), split='val')
test_ds  = FinDERGraphQADataset(root=str(DATA_ROOT), split='test')

BATCH_SIZE = 32
# Only the training loader shuffles; val/test keep a fixed order. num_workers=0 collates in
# the main process. pin_memory is enabled only for the training loader.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=dual_graph_collate_fn, num_workers=0, pin_memory=True)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                          collate_fn=dual_graph_collate_fn, num_workers=0)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False,
                          collate_fn=dual_graph_collate_fn, num_workers=0)

# Load vocabularies
# NOTE(review): `vocabs` is not used in this cell; presumably consumed by later cells.
vocabs = FinDERGraphQADataset.get_vocab(root=str(DATA_ROOT))
# metadata.json supplies the global vocabulary sizes and the LPG node-feature width.
metadata = json.loads((DATA_ROOT / 'processed' / 'metadata.json').read_text())

NUM_RDF_ENTITIES  = metadata['vocab_sizes']['rdf_entities']   # 17,534
NUM_RDF_RELATIONS = metadata['vocab_sizes']['rdf_relations']  # 4,340
LPG_FEATURE_DIM   = metadata['lpg_feature_dim']               # 384

# Dataset statistics
print(f"\n{'='*50}")
print(f"Dataset Statistics")
print(f"{'='*50}")
print(f"  Train: {len(train_ds):,} | Val: {len(val_ds):,} | Test: {len(test_ds):,}")
print(f"  LPG feature dim: {LPG_FEATURE_DIM}")
print(f"  RDF entities: {NUM_RDF_ENTITIES:,} | relations: {NUM_RDF_RELATIONS:,}")

# Category distribution (over the whole training split)
categories = [train_ds[i].category for i in range(len(train_ds))]
cat_counts = pd.Series(categories).value_counts()
print(f"\nCategory distribution (train):")
for cat, count in cat_counts.items():
    print(f"  {cat}: {count}")

# Average graph sizes — estimated from the first (at most) 200 training samples only.
lpg_nodes = [train_ds[i].lpg_num_nodes.item() for i in range(min(200, len(train_ds)))]
rdf_edges = [train_ds[i].rdf_edge_index.shape[1] for i in range(min(200, len(train_ds)))]
print(f"\nAvg LPG nodes/sample: {np.mean(lpg_nodes):.1f} (±{np.std(lpg_nodes):.1f})")
print(f"Avg RDF triples/sample: {np.mean(rdf_edges):.1f} (±{np.std(rdf_edges):.1f})")

In [None]:
# Cell 3: Model Definitions

# --- GAT for LPG (batched) ---

class BatchedGAT(nn.Module):
    """GAT encoder for LPG subgraphs with batched graph-level pooling.

    Based on MessagePassingGNN from src/_legacy/models.py but uses
    global_mean_pool for proper mini-batch support.

    Pipeline: Linear(input_dim -> hidden_dim) + ReLU, then ``num_layers``
    GATConv layers with ``heads`` concatenated attention heads (output width
    ``hidden_dim * heads``), each followed by LayerNorm, ReLU and dropout,
    and finally Linear(hidden_dim * heads -> output_dim).

    Args:
        input_dim: node feature width.
        hidden_dim: per-head channel count of every GATConv layer.
        output_dim: width of the node / graph embeddings returned.
        num_layers: number of GATConv layers.
        heads: attention heads per layer (outputs are concatenated).
        dropout: used both as GATConv attention dropout and as feature dropout.
    """

    def __init__(
        self,
        input_dim: int = 384,
        hidden_dim: int = 256,
        output_dim: int = 384,
        num_layers: int = 2,
        heads: int = 4,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        for i in range(num_layers):
            # The first layer sees the projected input; later layers see the
            # concatenated multi-head output of the previous layer.
            in_ch = hidden_dim if i == 0 else hidden_dim * heads
            self.convs.append(GATConv(in_ch, hidden_dim, heads=heads, dropout=dropout))
            self.norms.append(nn.LayerNorm(hidden_dim * heads))
        self.output_proj = nn.Linear(hidden_dim * heads, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, batch):
        """Forward pass with batched graph-level pooling.

        Args:
            x: [sum(N_i), input_dim] node features
            edge_index: [2, sum(E_i)] COO edges
            batch: [sum(N_i)] graph membership index

        Returns:
            [B, output_dim] graph-level embeddings (mean over each graph's nodes)
        """
        # Reuse the node encoder so graph- and node-level paths cannot drift apart.
        node_emb = self.get_node_embeddings(x, edge_index)  # [sum(N_i), output_dim]
        return global_mean_pool(node_emb, batch)  # [B, output_dim]

    def get_node_embeddings(self, x, edge_index):
        """Get per-node embeddings (no pooling). For link prediction decoding.

        Args:
            x: [sum(N_i), input_dim] node features
            edge_index: [2, sum(E_i)] COO edges

        Returns:
            [sum(N_i), output_dim] node embeddings
        """
        x = torch.relu(self.input_proj(x))
        for conv, norm in zip(self.convs, self.norms):
            x = conv(x, edge_index)
            x = norm(x)
            x = torch.relu(x)
            x = self.dropout(x)
        return self.output_proj(x)  # [sum(N_i), output_dim]


# --- KGE for RDF (TransE / DistMult) with per-graph aggregation ---

class BatchedKGE(nn.Module):
    """KGE encoder for RDF triples with per-graph embedding aggregation.

    Learns global entity and relation embeddings (PyG ``TransE`` or
    ``DistMult``), then for each question's RDF subgraph aggregates triple
    representations via scatter mean and projects them to ``output_dim``.

    Args:
        model_type: ``'transe'`` or ``'distmult'``.
        num_entities: size of the global entity vocabulary.
        num_relations: size of the global relation vocabulary.
        hidden_dim: KGE embedding width.
        output_dim: width of the graph-level embedding returned by ``forward``.

    Raises:
        ValueError: if ``model_type`` is not one of the supported names.
    """

    def __init__(
        self,
        model_type: str,  # 'transe' or 'distmult'
        num_entities: int,
        num_relations: int,
        hidden_dim: int = 256,
        output_dim: int = 384,
    ):
        super().__init__()
        self.model_type = model_type
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        if model_type == 'transe':
            self.kge = TransE(
                num_nodes=num_entities,
                num_relations=num_relations,
                hidden_channels=hidden_dim,
                margin=1.0,
                p_norm=1.0,
            )
        elif model_type == 'distmult':
            self.kge = DistMult(
                num_nodes=num_entities,
                num_relations=num_relations,
                hidden_channels=hidden_dim,
            )
        else:
            raise ValueError(f"Unknown model_type: {model_type}")

        # Created after the KGE so parameter/RNG initialisation order is unchanged.
        self.output_proj = nn.Linear(hidden_dim, output_dim)

    def loss(self, head_index, rel_type, tail_index):
        """KGE training loss with built-in negative sampling.

        Only ``self.kge`` participates; ``output_proj`` receives no gradient here.
        """
        return self.kge.loss(head_index, rel_type, tail_index)

    def forward(self, batch: 'DualGraphBatch'):
        """Compute per-graph embeddings from RDF triples.

        For each triple (h, r, t) in the batch, computes ``emb(h) + emb(r)``,
        then averages these per graph (graph membership taken from the head
        node) and applies ``output_proj``.

        NOTE(review): a batch with no triples at all returns exact zeros, while
        a graph without triples inside a non-empty batch gets
        ``output_proj(0)`` (i.e. the projection bias) — the two paths differ.

        Returns:
            [B, output_dim] graph-level embeddings
        """
        edge_index = batch.rdf_edge_index  # [2, sum(T_i)]
        edge_type = batch.rdf_edge_type    # [sum(T_i)]
        rdf_batch = batch.rdf_batch        # [sum(N_rdf_i)]
        global_idx = batch.rdf_global_node_idx  # [sum(N_rdf_i)]

        if edge_index.shape[1] == 0:
            return torch.zeros(batch.batch_size, self.output_dim, device=edge_index.device)

        # Only the head node enters the triple representation, so the tail
        # indices are not gathered.
        head_local = edge_index[0]  # local node indices

        # Map local indices to global ids for the embedding lookup.
        head_emb = self.kge.node_emb(global_idx[head_local])  # [sum(T_i), hidden_dim]
        rel_emb = self.kge.rel_emb(edge_type)                 # [sum(T_i), hidden_dim]

        # Triple representation: h + r (translation-style) for both model types.
        # NOTE(review): DistMult scores triples as sum(h * r * t), so h * r would
        # match its geometry better; kept as h + r to preserve existing results.
        triple_emb = head_emb + rel_emb  # [sum(T_i), hidden_dim]

        # Graph membership of each triple, taken from its head node.
        triple_graph = rdf_batch[head_local]  # [sum(T_i)]

        # Mean of triple embeddings per graph (PyG scatter, reduce='mean').
        graph_emb = scatter(triple_emb, triple_graph, dim=0,
                            dim_size=batch.batch_size, reduce='mean')  # [B, hidden_dim]

        return self.output_proj(graph_emb)  # [B, output_dim]

    def get_entity_embeddings(self):
        """Export projected entity embeddings [num_entities, output_dim] (no grad)."""
        with torch.no_grad():
            return self.output_proj(self.kge.node_emb.weight)


# Quick sanity check
# Builds throwaway models only to count their parameters. For the KGE models the count is
# dominated by the entity/relation embedding tables. Constructing them consumes torch RNG state.
print('BatchedGAT params:', sum(p.numel() for p in BatchedGAT().parameters()) / 1e3, 'K')
print('BatchedKGE (TransE) params:',
      sum(p.numel() for p in BatchedKGE('transe', NUM_RDF_ENTITIES, NUM_RDF_RELATIONS).parameters()) / 1e6, 'M')
print('BatchedKGE (DistMult) params:',
      sum(p.numel() for p in BatchedKGE('distmult', NUM_RDF_ENTITIES, NUM_RDF_RELATIONS).parameters()) / 1e6, 'M')

In [None]:
# Cell 4: Training — GAT (LPG) via Link Prediction

def train_gat_epoch(model, loader, optimizer):
    """Train the GAT for one epoch with a link-prediction objective.

    For every mini-batch, node embeddings are scored by dot product on the
    observed LPG edges (positives) and on an equal number of edges drawn by
    ``negative_sampling`` (negatives). The loss minimised is the sum of the
    two BCE-with-logits terms.

    Args:
        model: encoder exposing ``get_node_embeddings(x, edge_index)``.
        loader: iterable of batches with ``lpg_x``, ``lpg_edge_index`` and a
            ``.to(device)`` method.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        Mean loss over batches that contained at least one edge (0.0 if none did).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch in loader:
        batch = batch.to(device)

        # Positive edges
        pos_edge = batch.lpg_edge_index
        # Skip edge-less batches *before* the forward pass, so no autograd graph
        # is built (and then discarded) for a batch that contributes no loss.
        if pos_edge.shape[1] == 0:
            continue

        optimizer.zero_grad()

        # Get node embeddings (no pooling)
        z = model.get_node_embeddings(batch.lpg_x, pos_edge)
        num_nodes = batch.lpg_x.shape[0]

        # Negative sampling over the whole mini-batch node set (one negative per positive)
        neg_edge = negative_sampling(
            pos_edge, num_nodes=num_nodes,
            num_neg_samples=pos_edge.shape[1],
        )

        # Score positive and negative edges via dot product
        pos_score = (z[pos_edge[0]] * z[pos_edge[1]]).sum(dim=-1)
        neg_score = (z[neg_edge[0]] * z[neg_edge[1]]).sum(dim=-1)

        # BCE loss
        pos_loss = F.binary_cross_entropy_with_logits(pos_score, torch.ones_like(pos_score))
        neg_loss = F.binary_cross_entropy_with_logits(neg_score, torch.zeros_like(neg_score))
        loss = pos_loss + neg_loss

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / max(num_batches, 1)


@torch.no_grad()
def eval_gat_link_prediction(model, loader, sample_size=500):
    """Evaluate GAT link prediction: MRR and Hits@10.

    For up to ``sample_size`` randomly chosen edges (s, d) per mini-batch,
    every node of the mini-batch is scored against s by dot product. The rank
    of d is the number of candidates scoring >= d (ties count against d).
    The query node s itself is excluded as a candidate unless the edge is a
    self-loop; otherwise the trivial self-score ||z_s||^2, which usually
    dominates, would inflate ranks.

    Notes:
        - Raw (unfiltered) ranking: other true neighbours of s stay in the
          candidate set, so metrics are a pessimistic estimate.
        - Message passing uses all batch edges, including the ones ranked.
        - Edge sampling is random, so repeated calls give slightly different values.

    Args:
        model: encoder exposing ``get_node_embeddings(x, edge_index)``.
        loader: iterable of batches with ``lpg_x`` / ``lpg_edge_index``.
        sample_size: maximum number of edges scored per mini-batch.

    Returns:
        ``{'mrr': float, 'hits@10': float}`` (both 0.0 if no edge was scored).
    """
    model.eval()
    mrr_sum, hits10_sum, count = 0.0, 0.0, 0

    for batch in loader:
        batch = batch.to(device)
        pos_edge = batch.lpg_edge_index
        if pos_edge.shape[1] == 0:
            continue

        z = model.get_node_embeddings(batch.lpg_x, pos_edge)

        # Score a sample of edges (full ranking over all edges is unnecessary)
        n_sample = min(sample_size, pos_edge.shape[1])
        idx = torch.randperm(pos_edge.shape[1], device=pos_edge.device)[:n_sample]
        src, dst = pos_edge[0, idx], pos_edge[1, idx]

        # All sampled queries vs all batch nodes in a single matmul.
        scores = z[src] @ z.t()  # [n_sample, num_nodes]
        rows = torch.arange(n_sample, device=scores.device)

        # Remove the query node itself from the candidates (keep it for self-loops,
        # where it *is* the true target).
        not_self_loop = src != dst
        scores[rows[not_self_loop], src[not_self_loop]] = float('-inf')

        true_scores = scores[rows, dst].unsqueeze(1)           # [n_sample, 1]
        ranks = (scores >= true_scores).sum(dim=1).float()     # >= 1 (d counts itself)

        mrr_sum += (1.0 / ranks).sum().item()
        hits10_sum += (ranks <= 10).sum().item()
        count += n_sample

    mrr = mrr_sum / max(count, 1)
    hits10 = hits10_sum / max(count, 1)
    return {'mrr': mrr, 'hits@10': hits10}


# Train GAT
GAT_EPOCHS = 50
GAT_LR = 1e-3

gat_model = BatchedGAT(input_dim=LPG_FEATURE_DIM).to(device)
gat_optimizer = torch.optim.Adam(gat_model.parameters(), lr=GAT_LR, weight_decay=1e-5)

# train_loss gets one entry per epoch; the val_* lists only get entries on evaluated epochs
# (epoch 1 and every 5th), so they are shorter.
gat_history = {'train_loss': [], 'val_mrr': [], 'val_hits10': []}
best_val_mrr = 0.0
best_gat_state = None

print(f'Training GAT for {GAT_EPOCHS} epochs...')
for epoch in range(1, GAT_EPOCHS + 1):
    loss = train_gat_epoch(gat_model, train_loader, gat_optimizer)
    gat_history['train_loss'].append(loss)

    if epoch % 5 == 0 or epoch == 1:
        # Validation scores a random edge sample (randperm in eval_gat_link_prediction),
        # so checkpoint selection below is somewhat noisy.
        val_metrics = eval_gat_link_prediction(gat_model, val_loader)
        gat_history['val_mrr'].append(val_metrics['mrr'])
        gat_history['val_hits10'].append(val_metrics['hits@10'])
        print(f'  Epoch {epoch:3d} | Loss: {loss:.4f} | Val MRR: {val_metrics["mrr"]:.4f} | Val Hits@10: {val_metrics["hits@10"]:.3f}')

        if val_metrics['mrr'] > best_val_mrr:
            best_val_mrr = val_metrics['mrr']
            # Detached CPU copy, so later optimizer steps cannot overwrite the snapshot.
            best_gat_state = {k: v.cpu().clone() for k, v in gat_model.state_dict().items()}

# Restore best model (only if some evaluation exceeded MRR 0.0)
if best_gat_state:
    gat_model.load_state_dict(best_gat_state)
    gat_model.to(device)
print(f'\nBest GAT Val MRR: {best_val_mrr:.4f}')

In [None]:
# Cell 5: Training — TransE & DistMult (RDF)

def collect_all_rdf_triples(loader):
    """Collect all (head_global, rel, tail_global) triples from the dataset.

    Local node indices in each batch's ``rdf_edge_index`` are mapped to global
    entity ids through ``rdf_global_node_idx``. Batches without triples are
    skipped.

    Args:
        loader: iterable of batches with ``rdf_edge_index`` [2, T],
            ``rdf_edge_type`` [T] and ``rdf_global_node_idx`` [N].

    Returns:
        ``(heads, rels, tails)`` as 1-D tensors of equal length. If no batch
        contains any triple, three empty ``torch.long`` CPU tensors are
        returned instead of raising from ``torch.cat([])``.
    """
    heads, rels, tails = [], [], []
    for batch in loader:
        ei = batch.rdf_edge_index
        if ei.shape[1] == 0:
            continue
        gi = batch.rdf_global_node_idx
        heads.append(gi[ei[0]])
        tails.append(gi[ei[1]])
        rels.append(batch.rdf_edge_type)

    if not heads:
        empty = torch.empty(0, dtype=torch.long)
        return empty, empty.clone(), empty.clone()
    return torch.cat(heads), torch.cat(rels), torch.cat(tails)


def train_kge_epoch(model, head, rel, tail, optimizer, batch_size=512):
    """Run one optimisation pass over every (head, rel, tail) triple.

    The triples are visited in a fresh random order and cut into consecutive
    mini-batches of ``batch_size``. Each mini-batch takes exactly one optimizer
    step on ``model.loss``.

    Returns:
        The mean of the per-batch losses (0.0 when there are no triples).
    """
    model.train()
    num_triples = head.shape[0]
    order = torch.randperm(num_triples, device=head.device)

    batch_losses = []
    for start in range(0, num_triples, batch_size):
        chunk = order[start:start + batch_size]
        optimizer.zero_grad()
        step_loss = model.loss(head[chunk], rel[chunk], tail[chunk])
        step_loss.backward()
        optimizer.step()
        batch_losses.append(step_loss.item())

    return sum(batch_losses) / max(len(batch_losses), 1)


@torch.no_grad()
def eval_kge(model, head, rel, tail, sample_size=1000, k=10):
    """Evaluate KGE with sampled tail ranking: MRR and Hits@K.

    Up to ``sample_size`` triples are drawn at random. For each one, the
    candidate triples (h, r, e) for every entity e are scored with the
    model's *own* scoring function ``model.kge(head, rel, tail)``. This keeps
    evaluation consistent with training: PyG's TransE L_p-normalises head and
    tail embeddings before computing -||h + r - t||, which a hand-written
    score on the raw embedding table does not reproduce.

    Ranking is raw (unfiltered), and ties count against the true tail.

    Args:
        model: wrapper exposing ``.kge`` (a PyG KGE model with ``node_emb``).
        head, rel, tail: 1-D index tensors of equal length.
        sample_size: maximum number of triples evaluated.
        k: cut-off for Hits@K.

    Returns:
        ``{'mrr': float, f'hits@{k}': float}`` (both 0.0 when there are no triples).
    """
    model.eval()
    n = min(sample_size, head.shape[0])
    if n == 0:
        return {'mrr': 0.0, f'hits@{k}': 0.0}

    idx = torch.randperm(head.shape[0], device=head.device)[:n]
    h, r, t = head[idx], rel[idx], tail[idx]

    num_entities = model.kge.node_emb.num_embeddings
    all_entities = torch.arange(num_entities, device=head.device)

    mrr_sum, hits_sum = 0.0, 0.0
    for i in range(n):
        # One query at a time keeps memory at O(num_entities * hidden_dim).
        scores = model.kge(
            h[i].repeat(num_entities), r[i].repeat(num_entities), all_entities,
        )  # [num_entities], higher = better
        rank = (scores >= scores[t[i]]).sum().item()
        mrr_sum += 1.0 / max(rank, 1)
        hits_sum += 1.0 if rank <= k else 0.0

    return {'mrr': mrr_sum / n, f'hits@{k}': hits_sum / n}


# Collect triples
# Triples are gathered once (as global entity/relation ids) and kept as flat tensors on
# `device`; the KGE training below iterates over these tensors rather than the loaders.
print('Collecting RDF triples...')
train_h, train_r, train_t = collect_all_rdf_triples(train_loader)
val_h, val_r, val_t = collect_all_rdf_triples(val_loader)
train_h, train_r, train_t = train_h.to(device), train_r.to(device), train_t.to(device)
val_h, val_r, val_t = val_h.to(device), val_r.to(device), val_t.to(device)
print(f'  Train triples: {train_h.shape[0]:,} | Val triples: {val_h.shape[0]:,}')

# Train both models
KGE_EPOCHS = 100
KGE_LR = 1e-2
KGE_BATCH = 512

# model_type -> trained (best-checkpoint) BatchedKGE / its history dict
kge_models = {}
kge_histories = {}

for model_type in ['transe', 'distmult']:
    print(f'\n{"="*50}')
    print(f'Training {model_type.upper()} for {KGE_EPOCHS} epochs...')
    print(f'{"="*50}')

    model = BatchedKGE(model_type, NUM_RDF_ENTITIES, NUM_RDF_RELATIONS).to(device)
    # NOTE: BatchedKGE.loss only calls self.kge.loss, so output_proj gets no gradient
    # here and keeps its random initialisation.
    optimizer = torch.optim.Adam(model.parameters(), lr=KGE_LR)
    history = {'train_loss': [], 'val_mrr': [], 'val_hits10': []}
    best_mrr = 0.0
    best_state = None

    for epoch in range(1, KGE_EPOCHS + 1):
        loss = train_kge_epoch(model, train_h, train_r, train_t, optimizer, KGE_BATCH)
        history['train_loss'].append(loss)

        # Evaluate on epoch 1 and every 10th epoch (on a random sample of val triples).
        if epoch % 10 == 0 or epoch == 1:
            val_metrics = eval_kge(model, val_h, val_r, val_t)
            history['val_mrr'].append(val_metrics['mrr'])
            history['val_hits10'].append(val_metrics['hits@10'])
            print(f'  Epoch {epoch:3d} | Loss: {loss:.4f} | Val MRR: {val_metrics["mrr"]:.4f} | Val Hits@10: {val_metrics["hits@10"]:.3f}')

            if val_metrics['mrr'] > best_mrr:
                best_mrr = val_metrics['mrr']
                # Detached CPU snapshot of the best checkpoint so far.
                best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

    if best_state:
        model.load_state_dict(best_state)
        model.to(device)
    print(f'Best {model_type.upper()} Val MRR: {best_mrr:.4f}')

    kge_models[model_type] = model
    kge_histories[model_type] = history

# Loss curve comparison
# Left: per-epoch KGE training loss for each model type; right: per-epoch GAT training loss.
# The x-axis is the list index, i.e. epoch - 1.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
for mt, h in kge_histories.items():
    axes[0].plot(h['train_loss'], label=mt.upper())
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('KGE Training Loss')
axes[0].legend()

# Also plot GAT loss
axes[1].plot(gat_history['train_loss'], label='GAT', color='tab:green')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Loss')
axes[1].set_title('GAT Training Loss')
axes[1].legend()
plt.tight_layout()
plt.show()

---

# Track B: G-Retrieval E2E Profiling (D=256)

**Architecture:** Graph Encoder → Projection MLP → LLM (Llama 3.1 8B Instruct, bfloat16, no quantization)

**D=256**: LPG features PCA-projected from 384→256. KGE embeddings natively 256d.

**CPU-store baseline**: Embedding tables forced to CPU. Explicit stage separation:

| # | Label | LPG (GAT) | RDF (DistMult) |
|---|-------|-----------|----------------|
| 1 | `SUBGRAPH_LOAD` | Extract index tensors from batch | Extract edge_index, edge_type, global_idx |
| 2 | `LPG_CPU_GATHER_X` / `RDF_KGE_CPU_LOOKUP_ENT_REL` | CPU feature store indexing | CPU embedding table lookup |
| 3 | `LPG_H2D_COPY` / `RDF_KGE_H2D_COPY` | Sync transfer to GPU | Sync transfer to GPU |
| 4 | `LPG_GAT_FWD` / `RDF_KGE_FWD_DISTMULT` | GAT layers + global_mean_pool | h*r (DistMult) + scatter + proj |
| 5 | `PROJECTION` | Linear(256 → LLM hidden) | Linear(256 → LLM hidden) |
| 6 | `LLM_PREFILL` | LLM forward with graph soft tokens | Same |

**Two profiling passes:**
1. Custom CUDA Event timers → quantitative summary tables
2. `torch.profiler` + `record_function` → Chrome trace (`chrome://tracing`)

### Cell 6: Profiling Utilities + CPU Store Simulators

**Objective:** Set up timing infrastructure and CPU-resident feature/embedding stores for G-Retrieval profiling.

**Key components:**

| Component | Purpose |
|-----------|---------|
| `GRAPH_DIM = 256` | Unified embedding dimension (PCA-projected from 384) |
| `StageTiming` | Dataclass holding CPU ms, CUDA ms, bytes transferred per stage |
| `CUDATimer` | Context manager using `torch.cuda.Event` pairs for GPU timing |
| `CPUTimer` | Context manager using `time.perf_counter` for CPU-only stages |
| `LPGCPUStore` | CPU-resident node features [13920, 256] — forces explicit `gather()` + H2D |
| `RDFCPUStore` | CPU-resident KGE embeddings (entity [17534, 256] + relation [4340, 256]) — forces explicit `lookup()` + H2D |

**PCA Projection:** `torch.pca_lowrank(384→256)` applied offline to global LPG node features. Reports explained variance ratio.

**Design decision:** `non_blocking=False` for all H2D copies to ensure stage isolation in profiling. `pin_memory=False` on DataLoader for baseline measurement.

In [None]:
# Cell 6: Profiling Utilities + CPU Store Simulators (D=256)

import time
from torch.profiler import profile as torch_profile, ProfilerActivity, record_function

# ── Constants ──
WARMUP_BATCHES = 3  # warmup batches run before measurement (not recorded, presumably)
PROFILE_BATCHES = 20  # measured batches
PROFILE_BATCH_SIZE = 128  # A100 80GB can handle large batches
GRAPH_DIM = 256  # Unified graph embedding dim (PCA-projected from 384)

# CUDATimer uses this to skip CUDA events / synchronize on CPU-only runs.
HAS_CUDA = device.type == 'cuda'

# ── Timing Infrastructure ──

@dataclass
class StageTiming:
    """Timing record for one profiled stage of one batch."""
    stage: str  # name of the profiled stage
    cpu_ms: float  # host wall-clock duration in milliseconds
    cuda_ms: float = float('nan')  # CUDA-event duration in ms; NaN when not measured
    bytes_transferred: int = 0  # bytes moved during the stage (set by the caller; 0 if untracked)


class CUDATimer:
    """Context manager timing a stage in host wall-clock *and* CUDA-event time.

    When CUDA is available, both entry and exit call
    ``torch.cuda.synchronize()``. Previously queued kernels therefore do not
    leak into the interval, and the stage's own kernels finish before the
    host clock is read. Without CUDA, only host time is recorded and
    ``cuda_ms`` is NaN. The measurement is available from ``result()`` after
    the ``with`` block exits.
    """

    def __init__(self, stage: str):
        self.stage = stage
        if HAS_CUDA:
            self.start_evt, self.end_evt = (
                torch.cuda.Event(enable_timing=True),
                torch.cuda.Event(enable_timing=True),
            )

    def __enter__(self):
        if not HAS_CUDA:
            self.cpu_start = time.perf_counter()
            return self
        torch.cuda.synchronize()
        self.cpu_start = time.perf_counter()
        self.start_evt.record()
        return self

    def __exit__(self, *args):
        if HAS_CUDA:
            self.end_evt.record()
            torch.cuda.synchronize()
            self.cpu_end = time.perf_counter()
            gpu_elapsed = self.start_evt.elapsed_time(self.end_evt)
        else:
            self.cpu_end = time.perf_counter()
            gpu_elapsed = float('nan')
        host_elapsed = 1000 * (self.cpu_end - self.cpu_start)
        self._result = StageTiming(stage=self.stage, cpu_ms=host_elapsed, cuda_ms=gpu_elapsed)

    def result(self):
        """Return the ``StageTiming`` captured when the ``with`` block exited."""
        return self._result


class CPUTimer:
    """Context manager measuring host wall-clock time of a stage.

    No CUDA synchronisation is done, so asynchronous GPU work launched inside
    the block is not necessarily included. The ``StageTiming`` it produces
    keeps the dataclass default for ``cuda_ms``.
    """

    def __init__(self, stage: str):
        self.stage = stage

    def __enter__(self):
        self.cpu_start = time.perf_counter()
        return self

    def __exit__(self, *args):
        self.cpu_end = finished = time.perf_counter()
        self._result = StageTiming(stage=self.stage, cpu_ms=1000 * (finished - self.cpu_start))

    def result(self):
        """Return the ``StageTiming`` captured when the ``with`` block exited."""
        return self._result


# ── Result Aggregation ──

def aggregate_timings(all_timings, batch_sizes, node_counts, edge_counts):
    """Summarise per-batch stage timings into one row per stage.

    Args:
        all_timings: list with one dict per profiled batch, mapping stage name
            to an object with ``cpu_ms``, ``cuda_ms`` and ``bytes_transferred``
            (e.g. ``StageTiming``). Stage order comes from the first dict.
        batch_sizes, node_counts, edge_counts: currently unused; accepted for
            call-site compatibility.

    Returns:
        DataFrame with columns ``stage``, ``cpu_mean``/``cpu_std``/``cpu_p95``,
        ``cuda_mean``/``cuda_std``/``cuda_p95`` (NaN when a stage has no CUDA
        measurement at all) and ``bytes_mean``. When ``all_timings`` is empty,
        an empty DataFrame with those columns is returned.
    """
    columns = ['stage', 'cpu_mean', 'cpu_std', 'cpu_p95',
               'cuda_mean', 'cuda_std', 'cuda_p95', 'bytes_mean']
    if not all_timings:
        return pd.DataFrame(columns=columns)

    nan = float('nan')
    rows = []
    for stage in all_timings[0]:
        cpu_vals = np.array([t[stage].cpu_ms for t in all_timings], dtype=float)
        cuda_vals = np.array([t[stage].cuda_ms for t in all_timings], dtype=float)
        bytes_vals = [t[stage].bytes_transferred for t in all_timings]
        # nan-aware stats only when at least one CUDA sample exists (avoids
        # all-NaN RuntimeWarnings).
        has_cuda = not np.isnan(cuda_vals).all()
        rows.append({
            'stage': stage,
            'cpu_mean': np.mean(cpu_vals),
            'cpu_std': np.std(cpu_vals),
            'cpu_p95': np.percentile(cpu_vals, 95),
            'cuda_mean': np.nanmean(cuda_vals) if has_cuda else nan,
            'cuda_std': np.nanstd(cuda_vals) if has_cuda else nan,
            'cuda_p95': np.nanpercentile(cuda_vals, 95) if has_cuda else nan,
            'bytes_mean': np.mean(bytes_vals),
        })
    return pd.DataFrame(rows, columns=columns)


def print_profiling_table(df, model_name, n_batches):
    """Print an ``aggregate_timings`` DataFrame as a fixed-width text table.

    Each stage's share is its mean CPU time over the sum of all stages' mean
    CPU times. A bandwidth column (GB/s) is shown only for stages that moved
    bytes and have a positive CUDA time.
    """
    heavy_rule = '=' * 70
    light_rule = '-' * 70
    total_cpu = df['cpu_mean'].sum()

    print(f'\n{heavy_rule}')
    print(f'{model_name} Profiling -- {n_batches} batches, BS={PROFILE_BATCH_SIZE}, D={GRAPH_DIM}')
    print(heavy_rule)
    print(f'{"Stage":<30} {"CPU ms":>10} {"CUDA ms":>10} {"% total":>8} {"BW":>10}')
    print(light_rule)

    for rec in df.itertuples(index=False):
        share = rec.cpu_mean / total_cpu * 100 if total_cpu > 0 else 0
        cuda_known = not np.isnan(rec.cuda_mean)
        cuda_txt = f'{rec.cuda_mean:.2f}' if cuda_known else '--'

        bw_txt = ''
        if rec.bytes_mean > 0 and cuda_known and rec.cuda_mean > 0:
            gb_per_s = (rec.bytes_mean / 1e9) / (rec.cuda_mean / 1e3)
            bw_txt = f'{gb_per_s:.1f} GB/s'

        print(f'{rec.stage:<30} {rec.cpu_mean:>8.2f}+/-{rec.cpu_std:<4.1f} '
              f'{cuda_txt:>8} {share:>7.1f}%  {bw_txt}')

    print(light_rule)
    print(f'{"TOTAL":<30} {total_cpu:>10.2f}')


# ══════════════════════════════════════════════════
# Offline PCA Projection: 384 -> 256
# ══════════════════════════════════════════════════
print(f'Loading & projecting LPG features (384d -> {GRAPH_DIM}d via PCA)...')
_lpg_cache_path = PROJECT_ROOT / 'data' / 'processed' / 'lpg_full_graph.pt'

if _lpg_cache_path.exists():
    # weights_only=False unpickles arbitrary objects: only load a trusted, locally produced file.
    _lpg_cache = torch.load(str(_lpg_cache_path), weights_only=False, map_location='cpu')
    _feat_384 = _lpg_cache['data'].x  # [13920, 384]
    del _lpg_cache
else:
    print('  lpg_full_graph.pt not found -- reconstructing from dataset...')
    # Pass 1: find the largest global node id to size the feature table.
    _max_idx = 0
    for ds in [train_ds, val_ds, test_ds]:
        for i in range(len(ds)):
            gi = ds[i].lpg_global_node_idx
            if gi.numel() > 0:
                _max_idx = max(_max_idx, gi.max().item())
    # Pass 2: scatter each sample's node features into the global table.
    # Rows for global ids that never appear in any sample stay all-zero.
    _feat_384 = torch.zeros(_max_idx + 1, LPG_FEATURE_DIM)
    for ds in [train_ds, val_ds, test_ds]:
        for i in range(len(ds)):
            sample = ds[i]
            gi = sample.lpg_global_node_idx
            if gi.numel() > 0:
                _feat_384[gi] = sample.lpg_x
    print(f'  Reconstructed from {len(train_ds)+len(val_ds)+len(test_ds)} samples')

_centered = _feat_384 - _feat_384.mean(dim=0)
# Randomised low-rank PCA: the basis can vary slightly between runs unless torch is seeded.
U, S, V = torch.pca_lowrank(_centered, q=GRAPH_DIM)
global_node_features = (_centered @ V[:, :GRAPH_DIM]).float()  # [N, 256]
# Explained variance = captured variance / *total* variance of the centered data
# (its squared Frobenius norm). pca_lowrank returns only q=GRAPH_DIM singular
# values, so dividing by (S**2).sum() would always report 1.0.
_var_ratio = (S[:GRAPH_DIM] ** 2).sum() / _centered.pow(2).sum()
print(f'  PCA: {_feat_384.shape} -> {global_node_features.shape}')
print(f'  Explained variance ratio: {_var_ratio:.3f}')
del _feat_384, _centered, U, S, V


# ══════════════════════════════════════════════════
# CPU Store Simulators (forced CPU residence)
# ══════════════════════════════════════════════════
# The stores keep their lookup tables in host memory and return CPU tensors, so the
# profiled pipeline must gather on CPU and then copy to the GPU explicitly.
print('\nInitializing CPU stores...')


class LPGCPUStore:
    """Host-resident node-feature table for the LPG path (D=256).

    The table is a private CPU copy of ``features``. ``gather`` always indexes
    on the CPU and returns a CPU tensor, so any device transfer must be done
    explicitly by the caller.
    """
    def __init__(self, features):
        table = features.cpu().clone()
        self.features = table
        size_mb = table.element_size() * table.nelement() / 1e6
        print(f'  LPGCPUStore: {self.features.shape}, {size_mb:.1f} MB on CPU')

    def gather(self, global_node_idx):
        """Return the CPU feature rows selected by ``global_node_idx``."""
        cpu_idx = global_node_idx.cpu()
        return self.features[cpu_idx]


class RDFCPUStore:
    """Host-resident copies of a trained KGE model's embedding tables (D=256).

    The entity (``kge.node_emb``) and relation (``kge.rel_emb``) weights are
    detached and cloned to CPU. ``lookup`` indexes them on the CPU and returns
    CPU tensors, leaving the host-to-device copy to the caller.
    """
    def __init__(self, kge_model):
        backbone = kge_model.kge
        self.node_emb_weight = backbone.node_emb.weight.detach().cpu().clone()
        self.rel_emb_weight = backbone.rel_emb.weight.detach().cpu().clone()

        def _megabytes(t):
            return t.nelement() * t.element_size() / 1e6

        node_mb, rel_mb = _megabytes(self.node_emb_weight), _megabytes(self.rel_emb_weight)
        print(f'  RDFCPUStore: node_emb {self.node_emb_weight.shape} ({node_mb:.1f} MB)')
        print(f'               rel_emb  {self.rel_emb_weight.shape} ({rel_mb:.1f} MB)')

    def lookup(self, head_global, edge_type):
        """Return ``(head_embeddings, relation_embeddings)`` as CPU tensors."""
        heads = self.node_emb_weight[head_global.cpu()]
        relations = self.rel_emb_weight[edge_type.cpu()]
        return heads, relations


lpg_cpu_store = LPGCPUStore(global_node_features)
rdf_cpu_store = RDFCPUStore(kge_models['distmult'])

# Profiling DataLoader (no pin_memory for baseline)
profile_loader = DataLoader(train_ds, batch_size=PROFILE_BATCH_SIZE, shuffle=False,
                            collate_fn=dual_graph_collate_fn, num_workers=0,
                            pin_memory=False)
# len(DataLoader) is the exact batch count: drop_last defaults to False, so the final
# partial batch is included (ceil, not floor, of samples / batch size).
n_profile_batches = len(profile_loader)
print(f'\nProfile loader: {len(train_ds)} samples, {n_profile_batches} batches')
print(f'Config: {WARMUP_BATCHES} warmup + {PROFILE_BATCHES} measure, D={GRAPH_DIM}, sync H2D')
if WARMUP_BATCHES + PROFILE_BATCHES > n_profile_batches:
    # Surface the mismatch up front: a single pass over profile_loader cannot supply
    # all requested warmup + measured batches.
    print(f'  WARNING: {WARMUP_BATCHES} warmup + {PROFILE_BATCHES} measured batches requested, '
          f'but profile_loader yields only {n_profile_batches} batches per pass.')

### Cell 7: Profiling Models (D=256) + Llama 3 Loading

**Objective:** Create profiling-specific 256d encoders and load the LLM.

**Models created (separate from Track A's 384d trained models):**
- `prof_gat`: `BatchedGAT(input_dim=256, output_dim=256)` — random init (latency is weight-independent)
- `prof_kge_proj`: `Linear(256→256)` — replaces Track A's `256→384` output projection
- `GraphLLM`: Projection MLP `256 → LLM_hidden × T` + LLM forward with `inputs_embeds`

**LLM:** `meta-llama/Llama-3.1-8B-Instruct` loaded in **bfloat16** (no quantization) to eliminate quant/dequant overhead from profiling. ~16GB VRAM on A100 80GB.

**Graph tokens:** `NUM_GRAPH_TOKENS = 4` soft tokens prepended to text input embeddings.

In [None]:
# Cell 7: Profiling Models (D=256) + Llama 3 Loading
#
# Profiling-specific 256d encoders (separate from Track A's 384d trained models).
# Untrained weights are fine for latency profiling (compute cost is identical).

from transformers import AutoModelForCausalLM, AutoTokenizer

# ── Profiling Encoders at D=256 ──
print(f'Creating profiling encoders at D={GRAPH_DIM}...')

# GAT encoder (256d input from PCA-projected features)
# Randomly initialised and never trained here; eval() disables dropout for the timed passes.
prof_gat = BatchedGAT(
    input_dim=GRAPH_DIM, hidden_dim=256, output_dim=GRAPH_DIM,
    num_layers=2, heads=4, dropout=0.1,
).to(device)
prof_gat.eval()
print(f'  prof_gat: {sum(p.numel() for p in prof_gat.parameters())/1e3:.1f}K params')

# KGE output projection (256 -> 256, replaces Track A's 256->384)
# Input width 256 matches the KGE hidden size used by BatchedKGE's default hidden_dim.
prof_kge_proj = nn.Linear(256, GRAPH_DIM).to(device)
prof_kge_proj.eval()
print(f'  prof_kge_proj: Linear(256 -> {GRAPH_DIM})')


# ── GraphLLM: Graph Encoder -> Projection -> LLM ──

class GraphLLM(nn.Module):
    """Graph embedding -> soft prompt tokens -> causal LLM.

    A trainable MLP (``projector``) maps each pooled graph embedding to
    ``num_graph_tokens`` vectors of the LLM's hidden size. ``forward_prefill``
    prepends those vectors to the text token embeddings and runs the LLM.

    Args:
        graph_output_dim: width of the graph encoder's per-sample embedding.
        llm: causal LM exposing ``config.hidden_size``, ``dtype``,
            ``get_input_embeddings()`` and a forward that accepts
            ``inputs_embeds`` / ``attention_mask`` and returns ``.logits``.
        num_graph_tokens: number of soft tokens produced per sample.
    """

    def __init__(self, graph_output_dim, llm, num_graph_tokens=1):
        super().__init__()
        hidden = llm.config.hidden_size
        self.llm = llm
        self.num_graph_tokens = num_graph_tokens
        self.llm_hidden = hidden

        # Two-layer MLP; the final layer emits all T soft tokens side by side.
        self.projector = nn.Sequential(
            nn.Linear(graph_output_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, hidden * num_graph_tokens),
        )

    def project(self, graph_emb):
        """Turn ``[B, graph_output_dim]`` embeddings into ``[B, T, llm_hidden]`` soft tokens.

        The projector runs in its own dtype. The result is cast to ``llm.dtype``
        so it can be concatenated with the LLM's token embeddings.
        """
        tokens = self.projector(graph_emb).to(self.llm.dtype)
        return tokens.view(tokens.shape[0], self.num_graph_tokens, self.llm_hidden)

    @torch.no_grad()
    def forward_prefill(self, graph_tokens, input_ids, attention_mask):
        """Run one LLM forward over ``[graph_tokens | text tokens]`` and return logits.

        The graph prefix is always attended: its mask entries are all ones,
        created with the same device/dtype as ``attention_mask``.
        """
        batch, n_graph = graph_tokens.shape[:2]
        token_embedding = self.llm.get_input_embeddings()
        inputs_embeds = torch.cat((graph_tokens, token_embedding(input_ids)), dim=1)
        prefix_mask = attention_mask.new_ones((batch, n_graph))
        full_mask = torch.cat((prefix_mask, attention_mask), dim=1)
        outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=full_mask)
        return outputs.logits


# ── Load Llama 3.1 8B (bfloat16, no quantization) ──
LLM_MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
NUM_GRAPH_TOKENS = 4  # More graph context tokens (A100 80GB has headroom)

print(f'\nLoading LLM: {LLM_MODEL_ID} (bfloat16, no quantization)')
print(f'  A100 80GB: loading full bfloat16 for clean profiling (no quant overhead)')

tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID)
# Batched tokenization uses padding=True (see tokenize_questions), which needs a
# pad token; fall back to EOS if the tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Weight placement is delegated to device_map="auto".
# NOTE(review): Cell 11 calls this model with output_attentions=True. Depending on
# the transformers version, the default SDPA attention may not return attention
# weights; consider attn_implementation="eager" for that analysis. Confirm.
llm = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
llm.eval()
# Frozen: only the projector is ever trainable.
for p in llm.parameters():
    p.requires_grad = False

# Parameter bytes only (excludes activations and KV cache).
llm_mem_gb = sum(p.nelement() * p.element_size() for p in llm.parameters()) / 1e9
print(f'  LLM: {llm_mem_gb:.1f} GB, hidden_size={llm.config.hidden_size}')

# ── GraphLLM instance (shared LLM, D=256 projector) ──
graph_llm = GraphLLM(GRAPH_DIM, llm, NUM_GRAPH_TOKENS)
# Only the projector is moved explicitly; the LLM keeps its device_map placement.
graph_llm.projector.to(device)
proj_params = sum(p.numel() for p in graph_llm.projector.parameters())
print(f'  Projector: {GRAPH_DIM} -> {llm.config.hidden_size} x {NUM_GRAPH_TOKENS}, {proj_params/1e3:.1f}K params')

### Cell 8: E2E Stage-wise Profiling

**Objective:** Profile the full G-Retrieval inference pipeline with two complementary methods.

**Pass 1 — Custom CUDA Event Timers** (20 batches):
Quantitative per-stage timing with mean/std/p95 and H2D bandwidth.

**Pass 2 — `torch.profiler` + `record_function`** (20 batches):
Chrome-trace-compatible profiling. Export to `gat_profile_trace.json` and `rdf_distmult_profile_trace.json` for visualization at `chrome://tracing` or `perfetto.dev`.

**Stage labels (6 stages per pipeline):**

| # | GAT (LPG) | DistMult (RDF) |
|---|-----------|----------------|
| 1 | `SUBGRAPH_LOAD` | `SUBGRAPH_LOAD` |
| 2 | `LPG_CPU_GATHER_X` | `RDF_KGE_CPU_LOOKUP_ENT_REL` |
| 3 | `LPG_H2D_COPY` | `RDF_KGE_H2D_COPY` |
| 4 | `LPG_GAT_FWD` | `RDF_KGE_FWD_DISTMULT` |
| 5 | `PROJECTION` | `PROJECTION` |
| 6 | `LLM_PREFILL` | `LLM_PREFILL` |

In [None]:
# Cell 8: E2E Stage-wise Profiling
#
# Pass 1: Custom CUDA Event timers -> quantitative summary tables
# Pass 2: torch.profiler + record_function -> Chrome trace
#
# Labels:
#   LPG:  SUBGRAPH_LOAD, LPG_CPU_GATHER_X, LPG_H2D_COPY, LPG_GAT_FWD, PROJECTION, LLM_PREFILL
#   RDF:  SUBGRAPH_LOAD, RDF_KGE_CPU_LOOKUP_ENT_REL, RDF_KGE_H2D_COPY, RDF_KGE_FWD_DISTMULT, PROJECTION, LLM_PREFILL

from itertools import islice, cycle

# Prompt wrapped around every question; the graph context itself is supplied
# separately as soft tokens, not as text.
PROMPT_TEMPLATE = "\n".join((
    "Based on the graph context, answer the following question concisely.",
    "Question: {question}",
    "Answer:",
))

def tokenize_questions(questions, max_length=128):
    """Format each question with PROMPT_TEMPLATE and batch-tokenize it.

    Prompts are padded to the longest one in the batch and truncated to
    ``max_length`` tokens.

    Returns:
        ``(input_ids, attention_mask)``, both moved to ``device``.
    """
    prompts = [PROMPT_TEMPLATE.format(question=question) for question in questions]
    encoded = tokenizer(
        prompts,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)
    return input_ids, attention_mask


# ═══════════════════════════════════════════════════════════
# Pass 1: Custom Timers -- GAT (LPG)
# ═══════════════════════════════════════════════════════════

def profile_lpg(gat_model, lpg_store, gllm, loader, n_warmup, n_measure):
    """Profile each stage of the LPG pipeline: CPU gather -> H2D -> GAT -> projector -> LLM prefill.

    Runs ``n_warmup + n_measure`` iterations over ``loader``. ``itertools.cycle``
    replays batches when the loader is shorter than that. The first
    ``n_warmup`` iterations run the same ops without timing. Every later
    iteration is timed per stage: ``CPUTimer`` for host-only stages,
    ``CUDATimer`` for stages that touch the GPU.

    Args:
        gat_model: graph encoder called as ``gat_model(x, edge_index, batch_vec)``.
        lpg_store: CPU feature store; ``lpg_store.gather(global_node_idx)``
            returns the node-feature rows for the batch.
        gllm: ``GraphLLM`` providing ``project`` and ``forward_prefill``.
        loader: iterable of collated dual-graph batches.
        n_warmup: untimed warmup iterations.
        n_measure: timed iterations.

    Returns:
        ``(timings_list, batch_sizes, node_counts, edge_counts)``, with one
        entry per timed batch. Each ``timings_list`` item maps a stage label to
        that stage's ``t.result()``.
    """
    gat_model.eval()
    timings_list, batch_sizes, node_counts, edge_counts = [], [], [], []

    for i, batch in enumerate(islice(cycle(loader), n_warmup + n_measure)):

        # Warmup (same ops, no timing)
        if i < n_warmup:
            with torch.no_grad():
                x = lpg_store.gather(batch.lpg_global_node_idx).to(device)
                g = gat_model(x, batch.lpg_edge_index.to(device), batch.lpg_batch.to(device))
                gt = gllm.project(g)
                ids, mask = tokenize_questions(batch.questions)
                _ = gllm.forward_prefill(gt, ids, mask)
            continue

        timings = {}

        # Stage 1: SUBGRAPH_LOAD -- only reads attributes of the already-collated batch
        with CPUTimer('SUBGRAPH_LOAD') as t:
            global_idx = batch.lpg_global_node_idx
            edge_index = batch.lpg_edge_index
            batch_vec = batch.lpg_batch
        timings['SUBGRAPH_LOAD'] = t.result()

        # Stage 2: LPG_CPU_GATHER_X -- CPU feature store indexing
        with CPUTimer('LPG_CPU_GATHER_X') as t:
            x_cpu = lpg_store.gather(global_idx)  # CPU tensor [sum(N_i), 256]
        timings['LPG_CPU_GATHER_X'] = t.result()

        # Stage 3: LPG_H2D_COPY -- sync transfer to GPU
        with CUDATimer('LPG_H2D_COPY') as t:
            x_gpu = x_cpu.to(device, non_blocking=False)
            ei_gpu = edge_index.to(device, non_blocking=False)
            b_gpu = batch_vec.to(device, non_blocking=False)
        timings['LPG_H2D_COPY'] = t.result()
        # Bytes of the three tensors copied above; used later for H2D bandwidth.
        timings['LPG_H2D_COPY'].bytes_transferred = (
            x_cpu.nelement() * x_cpu.element_size() +
            edge_index.nelement() * edge_index.element_size() +
            batch_vec.nelement() * batch_vec.element_size()
        )

        # Stage 4: LPG_GAT_FWD -- GAT layers + global_mean_pool
        with torch.no_grad():
            with CUDATimer('LPG_GAT_FWD') as t:
                graph_emb = gat_model(x_gpu, ei_gpu, b_gpu)  # [B, 256]
        timings['LPG_GAT_FWD'] = t.result()

        # Stage 5: PROJECTION -- Linear(256 -> LLM_hidden)
        with torch.no_grad():
            with CUDATimer('PROJECTION') as t:
                graph_tokens = gllm.project(graph_emb)
        timings['PROJECTION'] = t.result()

        # Stage 6: LLM_PREFILL -- LLM forward with graph soft tokens.
        # NOTE: tokenization (CPU) and the input-id copy to `device` run inside
        # this timer, so the stage includes them as well as the LLM forward.
        with torch.no_grad():
            with CUDATimer('LLM_PREFILL') as t:
                input_ids, attn_mask = tokenize_questions(batch.questions)
                _ = gllm.forward_prefill(graph_tokens, input_ids, attn_mask)
        timings['LLM_PREFILL'] = t.result()

        timings_list.append(timings)
        batch_sizes.append(batch.batch_size)
        node_counts.append(x_cpu.shape[0])
        edge_counts.append(edge_index.shape[1])

    return timings_list, batch_sizes, node_counts, edge_counts


# ═══════════════════════════════════════════════════════════
# Pass 1: Custom Timers -- DistMult (RDF)
# ═══════════════════════════════════════════════════════════

def profile_rdf_distmult(rdf_store, kge_proj, gllm, loader, n_warmup, n_measure):
    """Profile each stage of the RDF/DistMult pipeline: CPU lookup -> H2D -> h*r pooling -> projector -> LLM prefill.

    Runs ``n_warmup + n_measure`` iterations over ``loader`` (cycling when the
    loader is shorter). The first ``n_warmup`` iterations run the same ops
    without timing. The rest are timed per stage with ``CPUTimer`` /
    ``CUDATimer``.

    Batches with no RDF triples are skipped but still consume an iteration
    index. Fewer than ``n_measure`` batches may therefore be measured, so use
    ``len(timings_list)`` rather than ``n_measure`` downstream.

    Args:
        rdf_store: CPU store; ``rdf_store.lookup(head_ids, rel_ids)`` returns the
            pair of tensors used as head and relation embeddings.
        kge_proj: module applied to the per-graph mean of the triple embeddings.
        gllm: ``GraphLLM`` providing ``project`` and ``forward_prefill``.
        loader: iterable of collated dual-graph batches.
        n_warmup: untimed warmup iterations (upper bound, see above).
        n_measure: timed iterations (upper bound, see above).

    Returns:
        ``(timings_list, batch_sizes, node_counts, edge_counts)``, with one
        entry per measured batch. Each ``timings_list`` item maps a stage label
        to that stage's timer result.
    """
    kge_proj.eval()
    timings_list, batch_sizes, node_counts, edge_counts = [], [], [], []

    for i, batch in enumerate(islice(cycle(loader), n_warmup + n_measure)):
        ei = batch.rdf_edge_index
        if ei.shape[1] == 0:
            # No triples -> nothing to encode; this iteration is simply dropped.
            continue

        if i < n_warmup:
            with torch.no_grad():
                hg = batch.rdf_global_node_idx[ei[0]]
                h, r = rdf_store.lookup(hg, batch.rdf_edge_type)
                tg = batch.rdf_batch[ei[0]].to(device)
                te = h.to(device) * r.to(device)
                ge = scatter(te, tg, dim=0, dim_size=batch.batch_size, reduce='mean')
                ge = kge_proj(ge)
                gt = gllm.project(ge)
                ids, mask = tokenize_questions(batch.questions)
                _ = gllm.forward_prefill(gt, ids, mask)
            continue

        timings = {}

        # Stage 1: SUBGRAPH_LOAD (includes mapping local head ids to global ids)
        with CPUTimer('SUBGRAPH_LOAD') as t:
            head_local = ei[0]
            global_idx = batch.rdf_global_node_idx
            head_global = global_idx[head_local]
            edge_type = batch.rdf_edge_type
            rdf_batch_vec = batch.rdf_batch
        timings['SUBGRAPH_LOAD'] = t.result()

        # Stage 2: RDF_KGE_CPU_LOOKUP_ENT_REL -- CPU embedding table lookup
        with CPUTimer('RDF_KGE_CPU_LOOKUP_ENT_REL') as t:
            h_emb_cpu, r_emb_cpu = rdf_store.lookup(head_global, edge_type)
        timings['RDF_KGE_CPU_LOOKUP_ENT_REL'] = t.result()

        # Stage 3: RDF_KGE_H2D_COPY -- sync transfer to GPU.
        # The per-triple graph-id gather (rdf_batch_vec[head_local]) runs on CPU
        # inside this timer, matching profiler_rdf_pass.
        with CUDATimer('RDF_KGE_H2D_COPY') as t:
            h_emb_gpu = h_emb_cpu.to(device, non_blocking=False)
            r_emb_gpu = r_emb_cpu.to(device, non_blocking=False)
            triple_graph = rdf_batch_vec[head_local].to(device, non_blocking=False)
        timings['RDF_KGE_H2D_COPY'] = t.result()
        timings['RDF_KGE_H2D_COPY'].bytes_transferred = (
            h_emb_cpu.nelement() * h_emb_cpu.element_size() +
            r_emb_cpu.nelement() * r_emb_cpu.element_size() +
            # Count the tensor actually copied (per-triple graph ids) with its
            # real element size, instead of assuming 8-byte integers.
            triple_graph.nelement() * triple_graph.element_size()
        )

        # Stage 4: RDF_KGE_FWD_DISTMULT -- h*r (DistMult) + scatter + proj
        with torch.no_grad():
            with CUDATimer('RDF_KGE_FWD_DISTMULT') as t:
                triple_emb = h_emb_gpu * r_emb_gpu  # DistMult: element-wise product
                graph_emb_raw = scatter(triple_emb, triple_graph, dim=0,
                                        dim_size=batch.batch_size, reduce='mean')
                graph_emb = kge_proj(graph_emb_raw)  # [B, 256]
        timings['RDF_KGE_FWD_DISTMULT'] = t.result()

        # Stage 5: PROJECTION
        with torch.no_grad():
            with CUDATimer('PROJECTION') as t:
                graph_tokens = gllm.project(graph_emb)
        timings['PROJECTION'] = t.result()

        # Stage 6: LLM_PREFILL (includes tokenization, as in profile_lpg)
        with torch.no_grad():
            with CUDATimer('LLM_PREFILL') as t:
                input_ids, attn_mask = tokenize_questions(batch.questions)
                _ = gllm.forward_prefill(graph_tokens, input_ids, attn_mask)
        timings['LLM_PREFILL'] = t.result()

        timings_list.append(timings)
        batch_sizes.append(batch.batch_size)
        node_counts.append(global_idx.shape[0])
        edge_counts.append(ei.shape[1])

    return timings_list, batch_sizes, node_counts, edge_counts


# ═══════════════════════════════════════════════════════════
# Run Pass 1: Custom Timer Profiling
# ═══════════════════════════════════════════════════════════

all_profile_results = {}

_pass1_rule = '=' * 60
print(_pass1_rule)
print('Pass 1: Custom CUDA Event Timer Profiling')
print(_pass1_rule)

# (result key, display label, profiling function, model-specific leading args)
_pass1_jobs = (
    ('gat', 'GAT (LPG)', profile_lpg, (prof_gat, lpg_cpu_store)),
    ('distmult', 'DistMult (RDF)', profile_rdf_distmult, (rdf_cpu_store, prof_kge_proj)),
)
for _job_key, _job_label, _profile_fn, _model_args in _pass1_jobs:
    print(f'\n[{_job_label}]')
    tl, bs, nc, ec = _profile_fn(
        *_model_args, graph_llm, profile_loader, WARMUP_BATCHES, PROFILE_BATCHES)
    _job_df = aggregate_timings(tl, bs, nc, ec)
    all_profile_results[_job_key] = {'df': _job_df, 'raw': tl, 'bs': bs, 'nc': nc, 'ec': ec}
    print_profiling_table(_job_df, _job_label, len(tl))
    torch.cuda.empty_cache()

# Keep the per-model DataFrame names available to later cells.
gat_df = all_profile_results['gat']['df']
dm_df = all_profile_results['distmult']['df']


# ═══════════════════════════════════════════════════════════
# Pass 2: torch.profiler + record_function -> Chrome Trace
# ═══════════════════════════════════════════════════════════

TRACE_BATCHES = 20  # batches recorded per model in each torch.profiler trace
_trace_rule = '=' * 60
print(f'\n\n{_trace_rule}')
print(f'Pass 2: torch.profiler trace ({TRACE_BATCHES} batches)')
print(_trace_rule)


def profiler_lpg_pass(gat_model, lpg_store, gllm, loader, n_batches):
    """Record a torch.profiler trace of the LPG pipeline over ``n_batches`` batches.

    Runs the same six stages as ``profile_lpg``, each wrapped in a
    ``record_function`` label so it appears as a named range in the exported
    Chrome trace. CUDA activity is recorded only when ``HAS_CUDA`` is true.

    There is no warmup loop here. NOTE(review): this presumably relies on
    Pass 1 having already warmed up kernels and the allocator; confirm before
    running it standalone.
    ``record_shapes`` and ``profile_memory`` are enabled and may add profiler
    overhead, so prefer Pass 1 numbers for absolute latencies.

    Returns:
        The finished profiler object from ``torch_profile``.
    """
    gat_model.eval()
    batch_iter = islice(cycle(loader), n_batches)
    activities = [ProfilerActivity.CPU] + ([ProfilerActivity.CUDA] if HAS_CUDA else [])

    with torch_profile(activities=activities, record_shapes=True, profile_memory=True) as prof:
        for batch in batch_iter:
            with torch.no_grad():
                with record_function("SUBGRAPH_LOAD"):
                    global_idx = batch.lpg_global_node_idx
                    edge_index = batch.lpg_edge_index
                    batch_vec = batch.lpg_batch

                with record_function("LPG_CPU_GATHER_X"):
                    x_cpu = lpg_store.gather(global_idx)

                with record_function("LPG_H2D_COPY"):
                    x_gpu = x_cpu.to(device, non_blocking=False)
                    ei_gpu = edge_index.to(device, non_blocking=False)
                    b_gpu = batch_vec.to(device, non_blocking=False)

                with record_function("LPG_GAT_FWD"):
                    graph_emb = gat_model(x_gpu, ei_gpu, b_gpu)

                with record_function("PROJECTION"):
                    graph_tokens = gllm.project(graph_emb)

                # Tokenization is inside this range, as in Pass 1's LLM_PREFILL stage.
                with record_function("LLM_PREFILL"):
                    ids, mask = tokenize_questions(batch.questions)
                    _ = gllm.forward_prefill(graph_tokens, ids, mask)
    return prof


def profiler_rdf_pass(rdf_store, kge_proj, gllm, loader, n_batches):
    """Record a torch.profiler trace of the RDF/DistMult pipeline over up to ``n_batches`` batches.

    Runs the same six stages as ``profile_rdf_distmult``, each wrapped in a
    ``record_function`` label for the Chrome trace. Batches without RDF
    triples are skipped but still count toward ``n_batches``, so the trace may
    contain fewer iterations than requested.
    ``record_shapes`` and ``profile_memory`` are enabled and may add profiler
    overhead.

    Returns:
        The finished profiler object from ``torch_profile``.
    """
    kge_proj.eval()
    batch_iter = islice(cycle(loader), n_batches)
    activities = [ProfilerActivity.CPU] + ([ProfilerActivity.CUDA] if HAS_CUDA else [])

    with torch_profile(activities=activities, record_shapes=True, profile_memory=True) as prof:
        for batch in batch_iter:
            ei = batch.rdf_edge_index
            if ei.shape[1] == 0:
                continue
            with torch.no_grad():
                with record_function("SUBGRAPH_LOAD"):
                    head_local = ei[0]
                    head_global = batch.rdf_global_node_idx[head_local]
                    edge_type = batch.rdf_edge_type
                    rdf_batch_vec = batch.rdf_batch

                with record_function("RDF_KGE_CPU_LOOKUP_ENT_REL"):
                    h_emb_cpu, r_emb_cpu = rdf_store.lookup(head_global, edge_type)

                # The per-triple graph-id gather runs on CPU inside this range.
                with record_function("RDF_KGE_H2D_COPY"):
                    h_emb_gpu = h_emb_cpu.to(device, non_blocking=False)
                    r_emb_gpu = r_emb_cpu.to(device, non_blocking=False)
                    triple_graph = rdf_batch_vec[head_local].to(device, non_blocking=False)

                # DistMult-style triple embedding h*r, mean-pooled per graph, then projected.
                with record_function("RDF_KGE_FWD_DISTMULT"):
                    triple_emb = h_emb_gpu * r_emb_gpu
                    graph_emb_raw = scatter(triple_emb, triple_graph, dim=0,
                                            dim_size=batch.batch_size, reduce='mean')
                    graph_emb = kge_proj(graph_emb_raw)

                with record_function("PROJECTION"):
                    graph_tokens = gllm.project(graph_emb)

                with record_function("LLM_PREFILL"):
                    ids, mask = tokenize_questions(batch.questions)
                    _ = gllm.forward_prefill(graph_tokens, ids, mask)
    return prof


# Run torch.profiler passes, print their top ops, and export Chrome traces.
n_loader_batches = len(profile_loader)
print(
    f'\nDataLoader has {n_loader_batches} batches (BS={PROFILE_BATCH_SIZE}), '
    f'TRACE_BATCHES={TRACE_BATCHES} (cycling if needed)'
)

print('\n[GAT (LPG)] torch.profiler trace...')
gat_prof = profiler_lpg_pass(prof_gat, lpg_cpu_store, graph_llm, profile_loader, TRACE_BATCHES)

print('[DistMult (RDF)] torch.profiler trace...')
rdf_prof = profiler_rdf_pass(rdf_cpu_store, prof_kge_proj, graph_llm, profile_loader, TRACE_BATCHES)

# (display label, finished profiler, Chrome trace output path)
_trace_outputs = (
    ('GAT (LPG)', gat_prof, 'gat_profile_trace.json'),
    ('DistMult (RDF)', rdf_prof, 'rdf_distmult_profile_trace.json'),
)

for _trace_label, _trace_prof, _trace_path in _trace_outputs:
    print(f'\n--- {_trace_label} torch.profiler key_averages ---')
    print(_trace_prof.key_averages().table(sort_by="cpu_time_total", row_limit=15))

for _trace_label, _trace_prof, _trace_path in _trace_outputs:
    _trace_prof.export_chrome_trace(_trace_path)

print('\nChrome traces saved:')
for _trace_label, _trace_prof, _trace_path in _trace_outputs:
    print(f'  {_trace_path}')
print('  View at: chrome://tracing or https://ui.perfetto.dev/')
torch.cuda.empty_cache()


### Cell 9: Combined Comparison Table + Bottleneck Analysis

**Objective:** Side-by-side comparison of GAT vs DistMult profiling results.

**Outputs:**
1. **Combined table** — Unified stage names × 2 models (mean ms, p95 ms, % of total)
2. **Bottleneck identification** — Which stage dominates each model's latency?
3. **Encoder vs LLM ratio** — How much time is graph encoding vs LLM prefill?
4. **H2D bandwidth** — Measured vs. nominal PCIe bandwidth (values assumed in code: A100/L4 ≈ 32 GB/s, V100 ≈ 16 GB/s, all other GPUs including T4 ≈ 12 GB/s)
5. **Graph size → latency correlation** — Spearman rho with p-value

In [None]:
# Cell 9: Combined Comparison Table + Bottleneck Analysis

# Display label per model key; key order defines model column order everywhere.
MODEL_LABELS = dict(gat='GAT (LPG)', distmult='DistMult (RDF)')
MODEL_KEYS = list(MODEL_LABELS)

# Model-agnostic stage names, in pipeline order, used for cross-model comparison.
UNIFIED_STAGES = [
    'SUBGRAPH_LOAD',
    'CPU_GATHER/LOOKUP',
    'H2D_COPY',
    'ENCODER_FWD',
    'PROJECTION',
    'LLM_PREFILL',
]

# Model-specific timer / record_function labels, aligned position-by-position
# with UNIFIED_STAGES.
_MODEL_STAGE_LABELS = {
    'gat': (
        'SUBGRAPH_LOAD', 'LPG_CPU_GATHER_X', 'LPG_H2D_COPY',
        'LPG_GAT_FWD', 'PROJECTION', 'LLM_PREFILL',
    ),
    'distmult': (
        'SUBGRAPH_LOAD', 'RDF_KGE_CPU_LOOKUP_ENT_REL', 'RDF_KGE_H2D_COPY',
        'RDF_KGE_FWD_DISTMULT', 'PROJECTION', 'LLM_PREFILL',
    ),
}

# STAGE_MAP[model][unified_stage] -> that model's own stage label
STAGE_MAP = {
    model_key: dict(zip(UNIFIED_STAGES, labels))
    for model_key, labels in _MODEL_STAGE_LABELS.items()
}


def get_stage_val(df, stage_name, col='cpu_mean'):
    """Return ``df[col]`` for the first row whose ``stage`` equals ``stage_name``.

    Returns ``0.0`` when no row matches, so absent stages contribute nothing
    to sums and plots.
    """
    matches = df.loc[df['stage'] == stage_name, col]
    if matches.empty:
        return 0.0
    return matches.values[0]


# ── Combined Table ──
# Each model occupies one column group, rendered as "  {mean:6}  {p95:6}  {pct:5}",
# i.e. 2 + 6 + 2 + 6 + 2 + 5 = 23 characters. The model-label header row must
# use that same group width; otherwise each subsequent label drifts out of place.
_TABLE_GROUP_W = 2 + 6 + 2 + 6 + 2 + 5

print(f'\n{"="*75}')
print(f'G-Retrieval E2E Bottleneck Analysis -- {PROFILE_BATCHES} batches, BS={PROFILE_BATCH_SIZE}, D={GRAPH_DIM}')
print(f'{"="*75}')

# Header: metric names, then model labels centred over their column group
print(f'{"Stage":<25}', end='')
for mk in MODEL_KEYS:
    print(f'  {"mean":>6}  {"p95":>6}  {"%":>5}', end='')
print()
print(f'{"":<25}', end='')
for mk in MODEL_KEYS:
    # 2-space gap + centred label == exactly one column group
    print(f'  {MODEL_LABELS[mk]:^{_TABLE_GROUP_W - 2}}', end='')
print()
print('-' * 75)

# Per-model totals: sum of per-stage mean times (ms), used as the % denominator
model_totals = {}
for mk in MODEL_KEYS:
    df = all_profile_results[mk]['df']
    total = sum(get_stage_val(df, STAGE_MAP[mk][s]) for s in UNIFIED_STAGES)
    model_totals[mk] = total

# Print rows, tracking the stage with the largest share of each model's total
bottleneck = {}
for stage in UNIFIED_STAGES:
    print(f'{stage:<25}', end='')
    for mk in MODEL_KEYS:
        df = all_profile_results[mk]['df']
        actual_stage = STAGE_MAP[mk][stage]
        mean_ms = get_stage_val(df, actual_stage, 'cpu_mean')
        p95_ms = get_stage_val(df, actual_stage, 'cpu_p95')
        pct = mean_ms / model_totals[mk] * 100 if model_totals[mk] > 0 else 0
        print(f'  {mean_ms:6.2f}  {p95_ms:6.2f}  {pct:4.1f}%', end='')
        if mk not in bottleneck or pct > bottleneck[mk][1]:
            bottleneck[mk] = (stage, pct)
    print()

print('-' * 75)
print(f'{"TOTAL":<25}', end='')
for mk in MODEL_KEYS:
    print(f'  {model_totals[mk]:6.2f}  {"":>6}  {"":>5}', end='')
print()

# ── Bottleneck ID ──
print(f'\n{"="*75}')
print('Bottleneck Identification:')
for mk in MODEL_KEYS:
    stage, pct = bottleneck[mk]
    label = MODEL_LABELS[mk]
    print(f'  {label:20s}: {stage} ({pct:.1f}%)')

# ── Encoder vs LLM ──
# Graph-side cost (encoder forward + projector) relative to LLM prefill cost.
print(f'\nEncoder vs LLM Time Ratio:')
for mk in MODEL_KEYS:
    df = all_profile_results[mk]['df']
    stage_names = STAGE_MAP[mk]
    enc_ms, proj_ms, llm_ms = (
        get_stage_val(df, stage_names[unified])
        for unified in ('ENCODER_FWD', 'PROJECTION', 'LLM_PREFILL')
    )
    enc_total = enc_ms + proj_ms
    if llm_ms > 0:
        ratio = enc_total / llm_ms
    else:
        ratio = float('inf')
    print(f'  {MODEL_LABELS[mk]:20s}: Encoder+Proj={enc_total:.2f}ms, LLM={llm_ms:.2f}ms, ratio={ratio:.3f}x')

# ── H2D Bandwidth ──
print(f'\nH2D Transfer Bandwidth:')
# Nominal host->device bandwidth (GB/s) used as the efficiency reference,
# chosen by substring match on the lower-cased device name.
theoretical_bw = None
if HAS_CUDA:
    gpu_name = torch.cuda.get_device_name(0).lower()
    if 'a100' in gpu_name or 'l4' in gpu_name:
        theoretical_bw = 32.0
    elif 'v100' in gpu_name:
        theoretical_bw = 16.0
    else:
        theoretical_bw = 12.0
    print(f'  Theoretical PCIe BW: ~{theoretical_bw:.0f} GB/s ({gpu_name})')

for mk in MODEL_KEYS:
    df = all_profile_results[mk]['df']
    h2d_row = df[df['stage'] == STAGE_MAP[mk]['H2D_COPY']]
    if len(h2d_row) == 0:
        continue
    bytes_mean = h2d_row['bytes_mean'].values[0]
    cuda_mean = h2d_row['cuda_mean'].values[0]
    # Skip models without a usable GPU-side duration (NaN or non-positive).
    if np.isnan(cuda_mean) or cuda_mean <= 0:
        continue
    bw_gbps = (bytes_mean / 1e9) / (cuda_mean / 1e3)  # bytes/ms -> GB/s
    eff = f' ({bw_gbps/theoretical_bw*100:.0f}% eff)' if theoretical_bw else ''
    print(f'  {MODEL_LABELS[mk]:20s}: {bytes_mean/1e6:.2f} MB, {cuda_mean:.2f} ms -> {bw_gbps:.1f} GB/s{eff}')

# ── Graph Size -> Latency Correlation ──
print(f'\nGraph Size -> Total Latency Correlation (Spearman):')
from scipy.stats import spearmanr as _spearmanr
for mk in MODEL_KEYS:
    res = all_profile_results[mk]
    # Size measure: LPG node count for GAT, RDF triple count for the KGE model
    size_key, label = ('nc', 'nodes') if mk == 'gat' else ('ec', 'triples')
    sizes = np.array(res[size_key])
    # Per-batch end-to-end latency = sum of every stage's CPU-side time
    total_ms = np.array([
        sum(stage_timer.cpu_ms for stage_timer in per_batch.values())
        for per_batch in res['raw']
    ])
    if len(sizes) <= 3:
        continue
    corr, pval = _spearmanr(sizes, total_ms)
    print(f'  {MODEL_LABELS[mk]:20s}: rho={corr:.3f}, p={pval:.3e} ({label})')

### Cell 10: Profiling Visualization

**Objective:** Visual summary of the profiling results: four charts plus a label-reference panel.

**Charts:**
1. **Stacked bar** — Stage-wise latency breakdown (ms) per model, with total labels
2. **Scatter + trend** — Graph size (nodes/triples) vs total E2E latency
3. **Pie charts** — Time distribution per model (% of total)
4. **Encoder vs LLM** — Horizontal bar comparing graph encoding time vs LLM prefill time
5. **Label reference** — `record_function` label names + Chrome trace file paths

In [None]:
# Cell 10: Profiling Visualization
#
# Layout: top row = stacked bar + size/latency scatter (2x2 grid slots 1-2);
# bottom row = two pies, encoder-vs-LLM bars, label reference (2x4 slots 5-8).

fig = plt.figure(figsize=(16, 12))

# One fixed colour per unified stage, shared by the stacked bar and the pies.
stage_colors = dict([
    ('SUBGRAPH_LOAD', '#2196F3'),
    ('CPU_GATHER/LOOKUP', '#FF9800'),
    ('H2D_COPY', '#F44336'),
    ('ENCODER_FWD', '#4CAF50'),
    ('PROJECTION', '#9C27B0'),
    ('LLM_PREFILL', '#795548'),
])


# ── Chart 1: Stacked Bar ──
ax1 = fig.add_subplot(2, 2, 1)
x_pos = np.arange(len(MODEL_KEYS))
bar_width = 0.4
bottom = np.zeros(len(MODEL_KEYS))  # running stack height per model

for stage in UNIFIED_STAGES:
    # Mean time of this unified stage for every model (0.0 if the stage is absent)
    vals = np.array([
        get_stage_val(all_profile_results[mk]['df'], STAGE_MAP[mk][stage])
        for mk in MODEL_KEYS
    ])
    ax1.bar(x_pos, vals, bar_width, bottom=bottom, label=stage,
            color=stage_colors[stage], edgecolor='white', linewidth=0.5)
    bottom += vals

ax1.set_xticks(x_pos)
ax1.set_xticklabels([MODEL_LABELS[mk] for mk in MODEL_KEYS], fontsize=10)
ax1.set_ylabel('Time (ms)')
ax1.set_title(f'E2E Latency Breakdown (D={GRAPH_DIM})')
ax1.legend(fontsize=7, loc='upper left')
# Total latency printed just above each stacked bar
for i, stack_top in enumerate(bottom):
    ax1.text(i, stack_top + 0.5, f'{stack_top:.1f}ms', ha='center', fontsize=9, fontweight='bold')


# ── Chart 2: Scatter -- Graph Size vs Latency ──
ax2 = fig.add_subplot(2, 2, 2)
scatter_colors = {'gat': '#4CAF50', 'distmult': '#FF9800'}

for mk in MODEL_KEYS:
    res = all_profile_results[mk]
    color = scatter_colors[mk]
    # x-axis: node count for GAT, triple count for DistMult
    sizes = np.array(res['nc'] if mk == 'gat' else res['ec'])
    total_ms = np.array([
        sum(stage_timer.cpu_ms for stage_timer in per_batch.values())
        for per_batch in res['raw']
    ])
    ax2.scatter(sizes, total_ms, alpha=0.5, s=25, color=color, label=MODEL_LABELS[mk])
    if len(sizes) <= 3:
        continue
    # Least-squares linear trend line over the observed size range
    trend = np.poly1d(np.polyfit(sizes, total_ms, 1))
    x_line = np.linspace(sizes.min(), sizes.max(), 50)
    ax2.plot(x_line, trend(x_line), '--', color=color, alpha=0.7, linewidth=1.5)

ax2.set_xlabel('Graph Size (nodes for GAT, triples for DistMult)')
ax2.set_ylabel('Total Latency (ms)')
ax2.set_title('Graph Size vs E2E Latency')
ax2.legend(fontsize=9)


# ── Chart 3: Pie Charts ──
# One pie per model in the bottom row (2x4 grid slots 5 and 6).
for idx, mk in enumerate(MODEL_KEYS):
    ax = fig.add_subplot(2, 4, 5 + idx)
    df = all_profile_results[mk]['df']
    stage_vals = [(stage, get_stage_val(df, STAGE_MAP[mk][stage])) for stage in UNIFIED_STAGES]
    # Only stages with a positive mean time get a wedge
    nonzero = [(stage, val) for stage, val in stage_vals if val > 0]
    sizes_pie = [val for _, val in nonzero]
    colors_pie = [stage_colors[stage] for stage, _ in nonzero]
    labels_pie = [stage for stage, _ in nonzero]
    ax.pie(sizes_pie, labels=None, colors=colors_pie,
           autopct=lambda p: f'{p:.0f}%' if p > 5 else '',
           pctdistance=0.75, startangle=90, textprops={'fontsize': 7})
    ax.set_title(MODEL_LABELS[mk], fontsize=10, fontweight='bold')


# ── Chart 4: Encoder vs LLM ──
# NOTE: this cell runs at notebook top level, so per-model values are appended
# directly instead of being bound to a temporary named `llm`. Such a name would
# rebind the global Llama model to a float and break later cells that read
# `llm.config` (e.g. the attention-analysis cell).
ax4 = fig.add_subplot(2, 4, 7)
enc_times, llm_times = [], []
for mk in MODEL_KEYS:
    df = all_profile_results[mk]['df']
    stage_names = STAGE_MAP[mk]
    # Graph-side cost = encoder forward + projector
    enc_times.append(get_stage_val(df, stage_names['ENCODER_FWD'])
                     + get_stage_val(df, stage_names['PROJECTION']))
    llm_times.append(get_stage_val(df, stage_names['LLM_PREFILL']))

x_bar = np.arange(len(MODEL_KEYS))
bar_w = 0.35
ax4.barh(x_bar - bar_w/2, enc_times, bar_w, label='Encoder+Proj', color='#4CAF50')
ax4.barh(x_bar + bar_w/2, llm_times, bar_w, label='LLM Prefill', color='#795548')
ax4.set_yticks(x_bar)
ax4.set_yticklabels([mk.upper() for mk in MODEL_KEYS], fontsize=9)
ax4.set_xlabel('Time (ms)')
ax4.set_title('Encoder vs LLM', fontsize=10)
ax4.legend(fontsize=8)

# ── Panel 5 (slot 8 of the 2x4 bottom row): record_function label reference ──
ax5 = fig.add_subplot(2, 4, 8)
ax5.axis('off')
label_lines = ['record_function labels:', '', 'GAT (LPG):']
label_lines += [f'  {s}' for s in STAGE_MAP['gat'].values()]
label_lines += ['', 'DistMult (RDF):']
label_lines += [f'  {s}' for s in STAGE_MAP['distmult'].values()]
label_lines += ['', 'Chrome traces:', '  gat_profile_trace.json', '  rdf_distmult_profile_trace.json']
label_text = '\n'.join(label_lines)
ax5.text(0.05, 0.95, label_text, transform=ax5.transAxes, fontsize=7,
         verticalalignment='top', fontfamily='monospace',
         bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))

plt.suptitle(f'G-Retrieval E2E Profiling (D={GRAPH_DIM})', fontsize=14, fontweight='bold', y=1.01)
plt.tight_layout()
plt.show()

# The cells that follow are Track B-2 (attention map analysis), not Track A.
print('Track B profiling complete. Proceeding to Track B-2 (attention map analysis)...')

---

## Track B-2: Attention Map Analysis

How do graph soft tokens influence the LLM's self-attention?

- **Graph tokens** are prepended at positions `0..NUM_GRAPH_TOKENS-1`
- **Text tokens** follow at positions `NUM_GRAPH_TOKENS..`
- We extract attention weights from every layer/head and measure how much text tokens "look at" graph tokens
- Compare GAT vs DistMult: does the graph encoder affect attention patterns?

### Cell 11: Attention Map Analysis — Graph Token Influence

**Objective:** Visualize how the LLM's self-attention interacts with prepended graph soft tokens.

**Method:**
1. Run LLM forward with `output_attentions=True`
2. Extract attention weights from all 32 layers × 32 heads
3. Measure: `attention[text_positions → graph_positions]` per layer/head
4. Aggregate across 8 samples for stability

**Visualizations (4 charts):**
1. **Layer × Head heatmap** — GAT and DistMult side-by-side; top-5 (layer, head) pairs marked with ★
2. **Per-layer attention curve** — Mean graph attention across layers (early/mid/late pattern)
3. **Per-position attention** — Which text tokens attend most to graph? (mid-layer average)
4. Uniform baseline reference line: `T / seq_len`

**Key questions:**
- Do early layers (feature extraction) or late layers (reasoning) attend more to graph tokens?
- Does graph encoder choice (GAT vs DistMult) change the attention pattern?
- Are there specialized "graph-reading" attention heads?

In [None]:
# Cell 11: Attention Map Analysis — Graph Token Influence
#
# For every LLM layer and head, measure how much attention the text positions
# place on the prepended graph soft-token positions.

# Kept small: output_attentions returns a full [heads, seq, seq] map per layer.
N_ATTN_SAMPLES = 8
_attn_cfg = llm.config
NUM_LAYERS = _attn_cfg.num_hidden_layers    # 32 for Llama 3.1 8B
NUM_HEADS = _attn_cfg.num_attention_heads   # 32 for Llama 3.1 8B

print(
    f'Attention analysis: {NUM_LAYERS} layers x {NUM_HEADS} heads, '
    f'{NUM_GRAPH_TOKENS} graph tokens, {N_ATTN_SAMPLES} samples'
)


@torch.no_grad()
def extract_attention_to_graph(gllm, graph_emb, questions, n_graph_tokens):
    """Run the LLM with output_attentions=True and measure attention to graph tokens.

    Only ``graph_emb[:1]`` and ``questions[0]`` are used; callers loop over
    samples and pass one sample at a time. The graph soft tokens are prepended
    at positions ``0..n_graph_tokens-1`` and the tokenized prompt follows.

    Args:
        gllm: Graph-LLM wrapper exposing ``project`` (graph embedding -> soft
            tokens, [1, T, H]) and ``llm`` (the underlying language model).
        graph_emb: Graph-level embeddings; only the first row is used.
        questions: Question strings; only the first one is used.
        n_graph_tokens: Number of soft tokens ``gllm.project`` produces.

    Returns:
        attn_to_graph: [num_layers, num_heads] mean over text positions of the
            total attention mass each text position puts on the graph tokens.
        token_attn:    [num_layers, seq_len] per-position attention mass on the
            graph tokens, averaged over heads.
        tokens:        list of token strings for labeling ('[G]' for graph slots).

    Raises:
        ValueError: if the projector emits a different number of graph tokens
            than ``n_graph_tokens``.
        RuntimeError: if the LLM returned no attention weights.
    """
    graph_tokens = gllm.project(graph_emb[:1])  # [1, T, H]
    if graph_tokens.shape[1] != n_graph_tokens:
        # The attention mask and the graph/text slicing below both rely on this count.
        raise ValueError(
            f'projector produced {graph_tokens.shape[1]} graph tokens, '
            f'expected n_graph_tokens={n_graph_tokens}')

    enc = tokenizer(
        PROMPT_TEMPLATE.format(question=questions[0]),
        return_tensors='pt', truncation=True, max_length=128,
    )
    input_ids = enc['input_ids'].to(device)
    attention_mask = enc['attention_mask'].to(device)

    text_emb = gllm.llm.get_input_embeddings()(input_ids)
    combined = torch.cat([graph_tokens, text_emb], dim=1)
    # Graph soft tokens are always attendable.
    graph_mask = torch.ones(1, n_graph_tokens, device=device, dtype=attention_mask.dtype)
    combined_mask = torch.cat([graph_mask, attention_mask], dim=1)

    outputs = gllm.llm(
        inputs_embeds=combined,
        attention_mask=combined_mask,
        output_attentions=True,
    )

    # outputs.attentions: tuple of [1, num_heads, seq_len, seq_len] per layer
    attentions = outputs.attentions
    if not attentions or attentions[0] is None:
        # NOTE(review): some attention backends (e.g. SDPA/flash) do not return
        # weights; the eager implementation presumably does — confirm for the
        # installed transformers version.
        raise RuntimeError(
            'LLM returned no attention weights; load it with '
            'attn_implementation="eager" to use output_attentions=True')

    num_layers = len(attentions)
    # Read the head count from the tensor itself rather than a global setting.
    _, num_heads, seq_len, _ = attentions[0].shape

    # Mean attention from text tokens -> graph tokens, per layer x head
    attn_to_graph = torch.zeros(num_layers, num_heads)
    # Per-position attention to graph tokens (averaged across heads)
    token_attn = torch.zeros(num_layers, seq_len)

    for layer_idx, attn in enumerate(attentions):
        # [heads, seq, seq]; upcast so sums/means are not done in bf16.
        a = attn[0].float()

        # Rows = query positions, columns = key positions.
        # text queries: n_graph_tokens..seq_len, graph keys: 0..n_graph_tokens
        text_to_graph = a[:, n_graph_tokens:, :n_graph_tokens]  # [heads, text_len, T]
        attn_to_graph[layer_idx] = text_to_graph.sum(dim=-1).mean(dim=-1).cpu()  # [heads]

        # Per-position: total mass each query puts on graph keys (head-averaged)
        token_attn[layer_idx] = a[:, :, :n_graph_tokens].sum(dim=-1).mean(dim=0).cpu()  # [seq]

    # Token labels for x-axis
    token_ids = input_ids[0].cpu().tolist()
    token_strs = ['[G]'] * n_graph_tokens + [
        tokenizer.decode([t]).strip()[:10] for t in token_ids
    ]

    return attn_to_graph, token_attn, token_strs


# ── Extract attention for both encoders ──
attn_results = {}
sample_batch = next(iter(profile_loader))

for enc_name in ['gat', 'distmult']:
    print(f'\n[{enc_name.upper()}] Extracting attention maps...')

    # The encoder forward is inference-only. Without no_grad, an encoder whose
    # parameters still require grad would record an autograd graph and keep its
    # activations alive while the (memory-heavy) attention maps are extracted.
    with torch.no_grad():
        if enc_name == 'gat':
            # LPG path: gather node features from the CPU store, run GAT on the device.
            x_cpu = lpg_cpu_store.gather(sample_batch.lpg_global_node_idx)
            graph_emb = prof_gat(
                x_cpu.to(device),
                sample_batch.lpg_edge_index.to(device),
                sample_batch.lpg_batch.to(device),
            )
        else:
            # RDF path: per triple, look up head-entity and relation embeddings,
            # combine them elementwise (h * r), mean-pool triples per graph, project.
            ei = sample_batch.rdf_edge_index
            head_local = ei[0]
            head_global = sample_batch.rdf_global_node_idx[head_local]
            h_emb, r_emb = rdf_cpu_store.lookup(head_global, sample_batch.rdf_edge_type)
            triple_graph = sample_batch.rdf_batch[head_local].to(device)
            triple_emb = h_emb.to(device) * r_emb.to(device)
            graph_emb_raw = scatter(triple_emb, triple_graph, dim=0,
                                    dim_size=sample_batch.batch_size, reduce='mean')
            graph_emb = prof_kge_proj(graph_emb_raw)

    # Aggregate across multiple samples (one LLM forward per sample)
    all_attn_to_graph = []
    n = min(N_ATTN_SAMPLES, sample_batch.batch_size)
    for s in range(n):
        a2g, tok_a, tok_s = extract_attention_to_graph(
            graph_llm, graph_emb[s:s+1], [sample_batch.questions[s]], NUM_GRAPH_TOKENS)
        all_attn_to_graph.append(a2g)

    mean_attn = torch.stack(all_attn_to_graph).mean(dim=0)  # [layers, heads]
    attn_results[enc_name] = {
        'attn_to_graph': mean_attn,
        'last_token_attn': tok_a,    # per-position map of the LAST sample only
        'token_strs': tok_s,
    }
    print(f'  Mean attention to graph tokens: {mean_attn.mean():.4f}')
    # Layer holding the global max (max over heads, argmax over layers) and
    # head holding it (max over layers, argmax over heads).
    print(f'  Max (layer, head): layer={mean_attn.max(dim=1).values.argmax().item()}, '
          f'head={mean_attn.max(dim=0).values.argmax().item()}, val={mean_attn.max():.4f}')

torch.cuda.empty_cache()


# ══════════════════════════════════════════════════════
# Visualization
# ══════════════════════════════════════════════════════

fig = plt.figure(figsize=(20, 16))

# ── Chart 1 & 2: Layer x Head heatmaps (GAT vs DistMult) ──
for idx, enc_name in enumerate(['gat', 'distmult']):
    ax = fig.add_subplot(2, 2, idx + 1)
    a2g = attn_results[enc_name]['attn_to_graph'].numpy()
    im = ax.imshow(a2g, aspect='auto', cmap='YlOrRd', interpolation='nearest')
    ax.set_xlabel('Head')
    ax.set_ylabel('Layer')
    ax.set_title(f'{enc_name.upper()}: Text -> Graph Token Attention')
    plt.colorbar(im, ax=ax, shrink=0.8, label='Attention weight')

    # Mark top-5 (layer, head) pairs
    flat_idx = np.argsort(a2g.ravel())[-5:]
    for fi in flat_idx:
        ly, hd = np.unravel_index(fi, a2g.shape)
        ax.plot(hd, ly, 'k*', markersize=8)

# ── Chart 3: Per-layer mean attention to graph (both models) ──
ax3 = fig.add_subplot(2, 2, 3)
for enc_name, color in [('gat', '#4CAF50'), ('distmult', '#FF9800')]:
    a2g = attn_results[enc_name]['attn_to_graph']
    layer_mean = a2g.mean(dim=1).numpy()  # [layers]
    ax3.plot(range(len(layer_mean)), layer_mean, '-o', color=color,
             markersize=3, label=f'{enc_name.upper()}', linewidth=1.5)

# Uniform reference: if attention were spread evenly over all seq_len key
# positions, the T graph tokens would receive T / seq_len of it. seq_len is the
# real graph+text length of the stored sample, not a fixed 128 (the tokenizer's
# max_length, which omits the graph tokens and ignores the actual prompt length).
# With causal masking, text position p sees only p + 1 keys, so T / seq_len is a
# lower bound on the per-position uniform share. Prompt lengths also vary across
# the averaged samples, so treat this line as approximate.
ref_seq_len = attn_results['gat']['last_token_attn'].shape[1]
ax3.axhline(y=NUM_GRAPH_TOKENS / ref_seq_len, color='gray', linestyle='--', alpha=0.5,
            label=f'Uniform baseline ({NUM_GRAPH_TOKENS}/{ref_seq_len})')
ax3.set_xlabel('Layer')
ax3.set_ylabel('Mean Attention to Graph Tokens')
ax3.set_title('Graph Token Attention Across Layers')
ax3.legend()

# ── Chart 4: Per-position attention (last sample, both models) ──
ax4 = fig.add_subplot(2, 2, 4)
# Average over the middle half of the layers, [L/4, 3L/4): layers 8-23 for 32 layers.
mid_start, mid_end = NUM_LAYERS // 4, 3 * NUM_LAYERS // 4
for enc_name, color in [('gat', '#4CAF50'), ('distmult', '#FF9800')]:
    tok_a = attn_results[enc_name]['last_token_attn']
    mid_avg = tok_a[mid_start:mid_end].mean(dim=0).numpy()  # [seq_len]
    seq_len = len(mid_avg)
    ax4.plot(range(seq_len), mid_avg, alpha=0.7, color=color, label=f'{enc_name.upper()}')

# Highlight graph token region
ax4.axvspan(0, NUM_GRAPH_TOKENS - 0.5, alpha=0.15, color='blue', label='Graph tokens')
ax4.set_xlabel('Position')
# mid_end is exclusive, so the last averaged layer is mid_end - 1.
ax4.set_ylabel(f'Attention to Graph (layers {mid_start}-{mid_end - 1} avg)')
ax4.set_title('Per-Position Attention to Graph Tokens')
ax4.legend(fontsize=8)

# Add token labels (sparse). Both encoders used the same last sample, so the
# GAT token strings label both curves.
tok_strs = attn_results['gat']['token_strs']
step = max(1, len(tok_strs) // 20)
ax4.set_xticks(range(0, len(tok_strs), step))
ax4.set_xticklabels([tok_strs[i] for i in range(0, len(tok_strs), step)],
                     rotation=45, ha='right', fontsize=6)

plt.suptitle(f'Graph Token Attention Analysis (D={GRAPH_DIM}, T={NUM_GRAPH_TOKENS})',
             fontsize=14, fontweight='bold', y=1.01)
plt.tight_layout()
plt.show()


# ── Summary statistics ──
print('\n' + '=' * 70)
print(f'Attention Analysis Summary (T={NUM_GRAPH_TOKENS} graph tokens)')
print('=' * 70)

for enc_name in ['gat', 'distmult']:
    a2g = attn_results[enc_name]['attn_to_graph']
    layer_mean = a2g.mean(dim=1)

    # Which layers attend most to graph?
    top_layers = layer_mean.argsort(descending=True)[:5].tolist()

    # Early / mid / late groups split at the layer quarters (same window as the
    # mid-layer average in Chart 4). For 32 layers: 0-7 / 8-23 / 24-31.
    n_layers = layer_mean.shape[0]
    early_end, late_start = n_layers // 4, 3 * n_layers // 4
    early = layer_mean[:early_end].mean().item()
    mid = layer_mean[early_end:late_start].mean().item()
    late = layer_mean[late_start:].mean().item()

    print(f'\n  {enc_name.upper()}:')
    print(f'    Overall mean attention to graph: {a2g.mean():.4f}')
    for group_label, group_val in [
        (f'Early layers (0-{early_end - 1}):', early),
        (f'Mid layers ({early_end}-{late_start - 1}):', mid),
        (f'Late layers ({late_start}-{n_layers - 1}):', late),
    ]:
        print(f'    {group_label:<22}{group_val:.4f}')
    print(f'    Top-5 layers: {top_layers}')
    print(f'    Max attention: layer={a2g.max(dim=1).values.argmax().item()}, '
          f'head={a2g.max(dim=0).values.argmax().item()}, val={a2g.max():.4f}')

# Compare GAT vs DistMult
gat_mean = attn_results['gat']['attn_to_graph'].mean().item()
dm_mean = attn_results['distmult']['attn_to_graph'].mean().item()
ratio = gat_mean / dm_mean if dm_mean > 0 else float('inf')
print(f'\n  GAT vs DistMult attention ratio: {ratio:.2f}x')
# ±10% band treated as "similar"
if ratio > 1.1:
    print('  -> GAT graph tokens receive MORE attention (richer graph signal)')
elif ratio < 0.9:
    print('  -> DistMult graph tokens receive MORE attention')
else:
    print('  -> Similar attention levels (graph encoder choice has limited impact on attention)')

print('\nTrack B complete.')


---

## Track A: Representation Quality Evaluation

Evaluate how well learned graph embeddings (384d) capture answer-relevant information via **cosine retrieval** against SentenceTransformer answer embeddings.

**Models compared:** GAT (LPG) vs TransE (RDF) vs DistMult (RDF) vs Question-only baseline

### Cell 6: Graph Embedding Extraction

**Objective:** Extract per-question graph-level embeddings from all 3 trained models (Track A's 384d models).

**Process:**
- GAT: `forward(x, edge_index, batch)` → `global_mean_pool` → [B, 384]
- TransE/DistMult: `forward(batch)` → triple aggregation via `scatter(mean)` → [B, 384]
- All embeddings L2-normalized for cosine similarity

**Outputs:** `train_emb`, `val_emb`, `test_emb` dicts with keys: `gat`, `transe`, `distmult`, `questions`, `answers`, `question_ids`, `categories`

In [None]:
# Cell 6: Graph Embedding Extraction

@torch.no_grad()
def extract_embeddings(gat_model, kge_models, loader):
    """Extract L2-normalized, per-question graph embeddings from all models.

    Args:
        gat_model: LPG encoder, called as ``gat_model(x, edge_index, batch)``.
        kge_models: Mapping ``name -> model``. Each model is called with the
            whole device-resident batch, and each name becomes a key of the
            returned dict (the notebook uses ``'transe'`` and ``'distmult'``).
        loader: Iterable of batches exposing ``lpg_x``, ``lpg_edge_index`` and
            ``lpg_batch``, plus the ``questions``, ``answers``,
            ``question_ids`` and ``categories`` lists.

    Returns dict with:
        'gat': [N, D] LPG-GAT embeddings (D = 384 in this notebook)
        '<kge name>': [N, D] embeddings for every entry of ``kge_models``
            (e.g. 'transe', 'distmult')
        'questions': list of question strings
        'answers': list of answer strings
        'question_ids': list of question IDs
        'categories': list of category strings
    """
    gat_model.eval()
    for m in kge_models.values():
        m.eval()

    all_gat = []
    # One chunk list per KGE model, keyed by its name, so no name is ever
    # routed into another model's output.
    all_kge = {name: [] for name in kge_models}
    all_questions, all_answers, all_qids, all_cats = [], [], [], []

    for batch in loader:
        batch_gpu = batch.to(device)

        # GAT embedding
        gat_emb = gat_model(batch_gpu.lpg_x, batch_gpu.lpg_edge_index, batch_gpu.lpg_batch)
        all_gat.append(gat_emb.cpu())

        # KGE embeddings
        for name, model in kge_models.items():
            all_kge[name].append(model(batch_gpu).cpu())

        all_questions.extend(batch.questions)
        all_answers.extend(batch.answers)
        all_qids.extend(batch.question_ids)
        all_cats.extend(batch.categories)

    # Normalize so that dot products downstream are cosine similarities.
    result = {'gat': F.normalize(torch.cat(all_gat, dim=0), dim=-1)}
    for name, chunks in all_kge.items():
        result[name] = F.normalize(torch.cat(chunks, dim=0), dim=-1)
    result.update({
        'questions': all_questions,
        'answers': all_answers,
        'question_ids': all_qids,
        'categories': all_cats,
    })
    return result


print('Extracting embeddings...')
# Run the same extractor over each split, in train -> val -> test order.
train_emb, val_emb, test_emb = (
    extract_embeddings(gat_model, kge_models, split_loader)
    for split_loader in (train_loader, val_loader, test_loader)
)

# Report split sizes and the embedding width (taken from the GAT output).
print(f'  Train: {train_emb["gat"].shape[0]} samples')
print(f'  Val:   {val_emb["gat"].shape[0]} samples')
print(f'  Test:  {test_emb["gat"].shape[0]} samples')
print(f'  Embedding dim: {train_emb["gat"].shape[1]}')

### Cell 7: Question-Answer Embedding Baseline

**Objective:** Create text-only baseline using SentenceTransformer (`all-MiniLM-L6-v2`, 384d).

**Process:**
1. Encode all test answers → `test_answer_emb` [N, 384]
2. Encode all test questions → `test_question_emb` [N, 384]
3. Quick sanity check: mean cosine similarity between graph embeddings and answer embeddings

**Note:** This baseline measures how well pure text similarity (question↔answer) performs, without any graph structure. It is the reference point that graph models must beat to justify their added complexity.

In [None]:
# Cell 7: Question-Answer Embedding Baseline

from sentence_transformers import SentenceTransformer

# MiniLM sentence encoder (384-d), placed on the same device as the graph models.
st_model = SentenceTransformer(model_name_or_path='all-MiniLM-L6-v2', device=str(device))

def encode_texts(texts, batch_size=64):
    """Embed *texts* with the module-level SentenceTransformer ``st_model``.

    Returns an [N, 384] CPU tensor with L2-normalized rows, so a row-wise dot
    product between two outputs is their cosine similarity.
    """
    encoded = st_model.encode(
        texts,
        batch_size=batch_size,
        show_progress_bar=False,
        convert_to_tensor=True,
        normalize_embeddings=True,
    )
    return encoded.cpu()


print('Encoding answers and questions with sentence-transformers...')
test_answer_emb = encode_texts(test_emb['answers'])
test_question_emb = encode_texts(test_emb['questions'])

# Quick check: mean cosine similarity between each sample's graph embedding and
# its OWN answer embedding. Both sides are L2-normalized, so the row-wise dot
# product is the cosine.
# NOTE(review): this number is only meaningful if a model's output space is
# aligned with the SentenceTransformer space. TransE/DistMult are trained on
# RDF triples (per the notebook header), so confirm such an alignment exists.
print('\nMean cosine similarity (graph_emb · answer_emb):')
for name in ['gat', 'transe', 'distmult']:
    cos_sim = (test_emb[name] * test_answer_emb).sum(dim=-1).mean().item()
    print(f'  {name:10s}: {cos_sim:.4f}')

# Baseline: question-only embedding
q_cos = (test_question_emb * test_answer_emb).sum(dim=-1).mean().item()
print(f'  {"question":10s}: {q_cos:.4f} (text-only baseline)')

### Cell 8: Retrieval Evaluation (Graph → Answer)

**Objective:** Core evaluation — can graph embeddings retrieve the correct answer?

**Task setup:**
- Query: `graph_emb[i]` (384d, from GAT/TransE/DistMult)
- Corpus: `answer_emb[j]` (384d, from SentenceTransformer)
- Score: cosine similarity
- Ground truth: `i == j` (diagonal of similarity matrix)

**Metrics:**
- **MRR** (Mean Reciprocal Rank) — average of 1/rank of correct answer
- **Recall@1/5/10** — fraction of queries where correct answer is in top-K

**Baseline:** Question text embedding → Answer text embedding (no graph). If this beats graph models, graph information adds no value over text.

In [None]:
# Cell 8: Retrieval Evaluation

def retrieval_eval(query_emb, corpus_emb, ks=(1, 5, 10)):
    """Score diagonal retrieval: query i is correct iff it retrieves corpus item i.

    No normalization happens here. Callers pass L2-normalized embeddings, so
    the dot product is cosine similarity. The rank of the correct item counts
    every corpus item scoring greater than or equal to it, including itself,
    so ranks are 1-based and ties count against the query.

    Args:
        query_emb: [N, D] query embeddings (graph or question).
        corpus_emb: [M, D] corpus embeddings (answers), M >= N; rows 0..N-1
            are the targets and any extra rows act as distractors.
        ks: K values at which Recall@K is reported.

    Returns:
        dict with 'mrr' followed by 'recall@k' for each k in ``ks``.
    """
    scores = query_emb.matmul(corpus_emb.t())     # [N, M]
    target = scores.diagonal().unsqueeze(1)       # [N, 1] score of the correct item
    rank = (scores >= target).sum(dim=1).float()  # [N], 1-based

    metrics = {'mrr': (1.0 / rank).mean().item()}
    metrics.update({f'recall@{k}': (rank <= k).float().mean().item() for k in ks})
    return metrics


# Evaluate all models
# Prints one row per graph encoder (MRR, Recall@1/5/10 on the test split), then
# the text-only question->answer baseline. Everything is collected in
# `retrieval_results` (keyed by model name, plus 'question') for later cells.
print(f'{"Model":<12} {"MRR":>8} {"R@1":>8} {"R@5":>8} {"R@10":>8}')
print('-' * 48)

retrieval_results = {}
for name in ['gat', 'transe', 'distmult']:
    r = retrieval_eval(test_emb[name], test_answer_emb)
    retrieval_results[name] = r
    print(f'{name:<12} {r["mrr"]:8.4f} {r["recall@1"]:8.4f} {r["recall@5"]:8.4f} {r["recall@10"]:8.4f}')

# Baseline: question text → answer text retrieval
r_baseline = retrieval_eval(test_question_emb, test_answer_emb)
retrieval_results['question'] = r_baseline
print(f'{"question":<12} {r_baseline["mrr"]:8.4f} {r_baseline["recall@1"]:8.4f} {r_baseline["recall@5"]:8.4f} {r_baseline["recall@10"]:8.4f}')
print('  (question = text-only baseline, no graph)')

# Bar chart
# Grouped bars: one group per model, one bar per metric. Four bars of width 0.2
# sit at offsets 0..3*width, so each group's centre is at x + 1.5 * width.
fig, ax = plt.subplots(figsize=(10, 5))
models = list(retrieval_results.keys())
metrics = ['mrr', 'recall@1', 'recall@5', 'recall@10']
x = np.arange(len(models))
width = 0.2

for i, m in enumerate(metrics):
    vals = [retrieval_results[model][m] for model in models]
    ax.bar(x + i * width, vals, width, label=m.upper())

ax.set_xticks(x + width * 1.5)
ax.set_xticklabels([m.upper() for m in models])
ax.set_ylabel('Score')
ax.set_title('Graph → Answer Retrieval Performance')
ax.legend()
ax.set_ylim(0, 1)
plt.tight_layout()
plt.show()

### Cell 9: Category-wise Analysis

**Objective:** Break down retrieval performance by FinDER's 8 question categories.

**Method:**
- For each category, compute MRR and Recall@K using only that category's samples
- Identify best model per category
- Bar chart comparison across categories

**Categories:** Financials, Insurance, Banking, Real Estate, Securities, Economics, Accounting, General

**Hypothesis:** Different graph structures may capture different types of financial relationships better. For example, GAT may excel in categories with densely connected LPG subgraphs, while KGE models may perform better for categories with clear directional relationships.

In [None]:
# Cell 9: Category-wise Analysis

def category_retrieval_eval(emb_dict, answer_emb, categories,
                            model_names=('gat', 'transe', 'distmult'),
                            min_samples=2):
    """Evaluate diagonal retrieval separately within each category.

    For every category, both the queries and the corpus are restricted to that
    category's samples. Each category is therefore scored against a corpus of
    a different size. Recall@K is trivially 1.0 whenever K >= the category
    size, and MRR is easier for small categories, so compare across categories
    with care.

    Args:
        emb_dict: Mapping model name -> [N, D] query embeddings (row i = sample i).
        answer_emb: [N, D] answer embeddings aligned with the rows of ``emb_dict``.
        categories: Length-N sequence of category labels.
        model_names: Keys of ``emb_dict`` to evaluate (default: the three graph models).
        min_samples: Categories with fewer samples than this are skipped.

    Returns DataFrame: rows=categories, columns ``category``, ``n`` and
    ``{model}_{metric}`` for each model and each metric from ``retrieval_eval``.
    """
    cat_array = np.asarray(categories)
    rows = []

    for cat in sorted(set(categories)):
        indices = np.flatnonzero(cat_array == cat)
        n = len(indices)
        if n < min_samples:
            continue

        row = {'category': cat, 'n': int(n)}
        a_emb = answer_emb[indices]  # category-restricted corpus, shared by all models
        for model_name in model_names:
            r = retrieval_eval(emb_dict[model_name][indices], a_emb)
            row.update({f'{model_name}_{metric}': val for metric, val in r.items()})
        rows.append(row)

    return pd.DataFrame(rows)


# Per-category retrieval on the test split (categories with fewer than 2
# samples are skipped by category_retrieval_eval).
cat_df = category_retrieval_eval(test_emb, test_answer_emb, test_emb['categories'])
print('Category-wise MRR:')
print(cat_df[['category', 'n', 'gat_mrr', 'transe_mrr', 'distmult_mrr']].to_string(index=False))

# Bar chart: category × model MRR
# Three bars per category at offsets 0, width, 2*width; ticks sit on the middle bar.
fig, ax = plt.subplots(figsize=(14, 5))
cats = cat_df['category'].values
x = np.arange(len(cats))
width = 0.25

for i, model in enumerate(['gat', 'transe', 'distmult']):
    vals = cat_df[f'{model}_mrr'].values
    ax.bar(x + i * width, vals, width, label=model.upper())

ax.set_xticks(x + width)
ax.set_xticklabels(cats, rotation=30, ha='right')
ax.set_ylabel('MRR')
ax.set_title('Category-wise Graph → Answer Retrieval MRR')
ax.legend()
plt.tight_layout()
plt.show()

# Text summary (not a plot): best model per category by MRR.
# On exact ties, max() keeps the first model in ['gat', 'transe', 'distmult'] order.
print('\nBest model per category (MRR):')
for _, row in cat_df.iterrows():
    mrrs = {m: row[f'{m}_mrr'] for m in ['gat', 'transe', 'distmult']}
    best = max(mrrs, key=mrrs.get)
    print(f'  {row["category"]:25s} → {best.upper()} ({mrrs[best]:.4f})')

### Cell 10: Graph Size Effect

**Objective:** Analyze correlation between subgraph size and retrieval quality.

**Method:**
- Collect per-sample graph sizes (LPG nodes, RDF triples) from test set
- Compute per-sample reciprocal rank for each model
- Scatter plot with binned trend lines (decile bins)
- Spearman rank correlation with p-values

**Expected patterns:**
- Positive correlation → more context helps capture answer information
- Negative correlation → noise from mean pooling dilutes signal
- No correlation → graph size is orthogonal to answer relevance

In [None]:
# Cell 10: Graph Size Effect

def compute_per_sample_rank(query_emb, corpus_emb):
    """Return the 1-based rank of the correct corpus item for every query.

    Query i's correct item is corpus row i. Ties count against the query: any
    item scoring equal to the target is ranked ahead of it. The result is a
    float32 NumPy array of length N, so the inputs must be CPU tensors.
    """
    scores = query_emb.matmul(corpus_emb.t())
    target = scores.diagonal().unsqueeze(1)
    at_or_above = scores >= target
    return at_or_above.sum(dim=1).float().numpy()


# Collect graph sizes from test set
# NOTE(review): row i of test_emb is paired with test_ds[i] below, which assumes
# test_loader iterates test_ds in order without shuffling — confirm.
test_lpg_nodes = []
test_lpg_edges = []  # collected, but not plotted or correlated in this cell
test_rdf_triples = []

for i in range(len(test_ds)):
    d = test_ds[i]
    test_lpg_nodes.append(d.lpg_num_nodes.item())
    test_lpg_edges.append(d.lpg_edge_index.shape[1])
    test_rdf_triples.append(d.rdf_edge_index.shape[1])  # one edge column per triple

test_lpg_nodes = np.array(test_lpg_nodes)
test_lpg_edges = np.array(test_lpg_edges)
test_rdf_triples = np.array(test_rdf_triples)

# Compute per-sample reciprocal rank
# (1.0 = correct answer ranked first; 1/N = ranked last among N answers)
rr = {}
for name in ['gat', 'transe', 'distmult']:
    ranks = compute_per_sample_rank(test_emb[name], test_answer_emb)
    rr[name] = 1.0 / ranks

# Scatter plots
# GAT is plotted against LPG node count; both KGE models against RDF triple count.
fig, axes = plt.subplots(1, 3, figsize=(16, 4.5))

for ax, (name, size_arr, size_label) in zip(axes, [
    ('gat', test_lpg_nodes, 'LPG Nodes'),
    ('transe', test_rdf_triples, 'RDF Triples'),
    ('distmult', test_rdf_triples, 'RDF Triples'),
]):
    ax.scatter(size_arr, rr[name], alpha=0.3, s=10)
    ax.set_xlabel(size_label)
    ax.set_ylabel('Reciprocal Rank')
    ax.set_title(f'{name.upper()}')

    # Trend line via binning
    # Decile edges of the size distribution; np.unique drops duplicate edges
    # (common for skewed integer sizes), so fewer than 10 bins may remain.
    bins = np.percentile(size_arr, np.linspace(0, 100, 11))
    bins = np.unique(bins)
    if len(bins) >= 2:
        # np.digitize returns 1-based bin numbers and puts values equal to the
        # top edge past the last bin; shift and clip so every sample falls in
        # bins 0..len(bins)-2.
        bin_idx = np.digitize(size_arr, bins) - 1
        bin_idx = np.clip(bin_idx, 0, len(bins) - 2)
        bin_centers = []
        bin_means = []
        for b in range(len(bins) - 1):
            mask = bin_idx == b
            if mask.sum() > 0:
                bin_centers.append((bins[b] + bins[b+1]) / 2)
                bin_means.append(rr[name][mask].mean())
        ax.plot(bin_centers, bin_means, 'r-o', linewidth=2, markersize=4, label='Binned mean')
        ax.legend()

plt.suptitle('Graph Size vs Retrieval Quality', y=1.02)
plt.tight_layout()
plt.show()

# Correlation
# Spearman rank correlation (monotonic, not necessarily linear) with p-values.
print('Spearman correlation (graph size vs reciprocal rank):')
from scipy.stats import spearmanr
for name, sizes in [('gat', test_lpg_nodes), ('transe', test_rdf_triples), ('distmult', test_rdf_triples)]:
    corr, pval = spearmanr(sizes, rr[name])
    print(f'  {name:10s}: rho={corr:.3f}, p={pval:.3e}')

### Cell 11: TransE vs DistMult Deep Dive

**Objective:** Head-to-head comparison of the two KGE models on RDF triples.

**Analysis:**
1. **Training dynamics** — Loss curves + validation MRR over epochs
2. **Per-sample win/loss** — For each test sample, which model ranks the correct answer higher?
3. **Category-level advantage** — MRR difference (TransE − DistMult) per category; identifies which categories benefit from asymmetric (TransE) vs symmetric (DistMult) modeling
4. **t-SNE embedding visualization** — 2D projection of graph embeddings colored by category; reveals clustering quality

**Key question:** Does TransE's ability to model directed relations (OWNS, REPORTED) outweigh DistMult's training stability in specific categories?

In [None]:
# Cell 11: TransE vs DistMult Deep Dive

# 1. Loss curve comparison
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(kge_histories['transe']['train_loss'], label='TransE')
axes[0].plot(kge_histories['distmult']['train_loss'], label='DistMult')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss Comparison')
axes[0].legend()

# Val MRR over time
# The x positions assume validation ran at epoch 1 and every 10th epoch,
# presumably mirroring the KGE training loop's schedule (confirm if that changes).
# Truncate to the shorter of the schedule and the recorded history so the x and
# y sequences always have equal length.
eval_epochs_kge = [e for e in range(1, KGE_EPOCHS+1) if e % 10 == 0 or e == 1]
for mt in ['transe', 'distmult']:
    n = min(len(kge_histories[mt]['val_mrr']), len(eval_epochs_kge))
    axes[1].plot(eval_epochs_kge[:n], kge_histories[mt]['val_mrr'][:n], '-o', label=mt.upper())
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Val MRR')
axes[1].set_title('Validation MRR Over Training')
axes[1].legend()
plt.tight_layout()
plt.show()


# 2. Per-sample comparison: where does one beat the other?
# rr holds the per-sample reciprocal ranks on the test set (Cell 10).
transe_rr = rr['transe']
distmult_rr = rr['distmult']
diff = transe_rr - distmult_rr  # positive = TransE better

transe_wins = (diff > 0).sum()
distmult_wins = (diff < 0).sum()
ties = (diff == 0).sum()
print(f'Per-sample comparison (test set, N={len(diff)}):')
print(f'  TransE wins:  {transe_wins} ({100*transe_wins/len(diff):.1f}%)')
print(f'  DistMult wins: {distmult_wins} ({100*distmult_wins/len(diff):.1f}%)')
print(f'  Ties:         {ties} ({100*ties/len(diff):.1f}%)')


# 3. Category-level TransE vs DistMult advantage
print('\nCategory-level advantage (MRR difference = TransE - DistMult):')
for _, row in cat_df.iterrows():
    delta = row['transe_mrr'] - row['distmult_mrr']
    # Report exact ties explicitly instead of crediting them to DistMult.
    if delta > 0:
        arrow = '→ TransE'
    elif delta < 0:
        arrow = '→ DistMult'
    else:
        arrow = '→ tie'
    print(f'  {row["category"]:25s}: {delta:+.4f} {arrow}')


# 4. t-SNE visualization of embedding spaces
from sklearn.manifold import TSNE

fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Use test set embeddings (subsample if needed). A seeded local generator makes
# the subsample reproducible and leaves the global NumPy RNG state untouched.
n_total = test_emb['gat'].shape[0]
n_vis = min(300, n_total)
vis_rng = np.random.default_rng(42)
vis_idx = vis_rng.choice(n_total, n_vis, replace=False)
vis_cats = np.array(test_emb['categories'])[vis_idx]
unique_cats = sorted(set(vis_cats))
# tab10 provides 10 distinct colours; additional categories would share colours.
cat_colors = {c: plt.cm.tab10(i) for i, c in enumerate(unique_cats)}
colors = [cat_colors[c] for c in vis_cats]

# scikit-learn's TSNE requires perplexity < n_samples; shrink it for tiny splits.
tsne_perplexity = min(30, n_vis - 1)

for ax, name in zip(axes, ['gat', 'transe', 'distmult']):
    emb = test_emb[name][vis_idx].numpy()
    tsne = TSNE(n_components=2, perplexity=tsne_perplexity, random_state=42)
    proj = tsne.fit_transform(emb)

    for cat in unique_cats:
        mask = vis_cats == cat
        ax.scatter(proj[mask, 0], proj[mask, 1], c=[cat_colors[cat]],
                   s=15, alpha=0.6, label=cat[:15])
    ax.set_title(f'{name.upper()} Embedding Space')
    ax.set_xticks([])
    ax.set_yticks([])

axes[2].legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
plt.suptitle('t-SNE of Graph Embeddings (colored by category)', y=1.02)
plt.tight_layout()
plt.show()

### Cell 12: Summary & Conclusions

**Objective:** Consolidate all results into a final summary table.

**Outputs:**
- Overall retrieval results table (MRR, Recall@K for all 3 models)
- Best model identification (by MRR)
- Per-category best model recommendation
- Model strengths & weaknesses interpretation
- Verification assertions (sanity checks)

**Interpretation guide:** See `docs/design/g_retrieval_experiment.md` § "예상 결과 해석 가이드" ("Expected Results Interpretation Guide") for the mapping from observed result patterns to interpretations.

In [None]:
# Cell 12: Summary & Conclusions

print('=' * 60)
print('SUMMARY: G-Retrieval Style Comparison')
print('=' * 60)

# Overall retrieval results table (graph models only; the text-only baseline
# is reported separately below)
results_df = pd.DataFrame([
    {'model': name, **metrics}
    for name, metrics in retrieval_results.items()
    if name != 'question'
])
print('\n--- Graph → Answer Retrieval (Test Set) ---')
print(results_df.to_string(index=False, float_format='%.4f'))

# Best model
best_model = results_df.loc[results_df['mrr'].idxmax(), 'model']
print(f'\nBest overall model: {best_model.upper()} (MRR={results_df["mrr"].max():.4f})')

# Compare with the question-only text baseline from Cell 8. By this notebook's
# own criterion, graph models must beat it to justify their added complexity.
q_baseline = retrieval_results.get('question')
if q_baseline is not None:
    mrr_gap = results_df['mrr'].max() - q_baseline['mrr']
    print(f'Question-only baseline MRR: {q_baseline["mrr"]:.4f} '
          f'(best graph model - baseline = {mrr_gap:+.4f})')
    if mrr_gap <= 0:
        print('  -> No graph model beats the text-only baseline on MRR.')

# Category breakdown summary
# On exact ties, max() keeps the first model in ['gat', 'transe', 'distmult'] order.
print('\n--- Best Model per Category (MRR) ---')
category_results = []
for _, row in cat_df.iterrows():
    mrrs = {m: row[f'{m}_mrr'] for m in ['gat', 'transe', 'distmult']}
    best = max(mrrs, key=mrrs.get)
    category_results.append({'category': row['category'], 'best_model': best,
                             'mrr': mrrs[best], 'n': row['n']})
    print(f'  {row["category"]:25s} → {best.upper():10s} (MRR={mrrs[best]:.4f}, n={int(row["n"])})')

category_results = pd.DataFrame(category_results)

# Model strengths
print('\n--- Model Strengths & Weaknesses ---')
print('GAT (LPG):      Pre-computed 384d node features + message passing.')
print('                 Best for: categories with rich LPG structure.')
print('TransE (RDF):    Translation h+r≈t. Asymmetric, handles directed relations.')
print('                 Best for: categories with directional relationships (OWNS, REPORTED).')
print('DistMult (RDF):  Bilinear h·r·t. Symmetric, simpler training dynamics.')
print('                 Best for: symmetric or co-occurrence patterns.')

# Verification assertions (notebook sanity checks)
assert len(results_df) == 3, f'Expected 3 models, got {len(results_df)}'
assert all(col in results_df.columns for col in ['model', 'mrr', 'recall@1', 'recall@5'])
assert len(category_results) >= 1, 'No category results'
print(f'\nVerification passed: {len(results_df)} models, {len(category_results)} categories.')