# G-Retrieval Style Comparison: LPG (GAT) vs RDF (TransE / DistMult)

This notebook trains graph encoders on the FinDER dual-graph PyG dataset and evaluates
how well learned graph embeddings capture answer-relevant information.

**Models:**
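- **GAT** (Graph Attention Network) on LPG (labeled property graph) subgraphs — 384-d node features + message passing
- **TransE** on RDF triples — translation-based: $\mathbf{h} + \mathbf{r} \approx \mathbf{t}$ (can model asymmetric relations)
- **DistMult** on RDF triples — bilinear: $\mathbf{h}^\top \operatorname{diag}(\mathbf{r})\, \mathbf{t}$ (symmetric in $\mathbf{h}$ and $\mathbf{t}$)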

**Evaluation:**
- Link prediction (self-supervised training objective)
- Graph→Answer retrieval (cosine similarity with sentence embeddings)
- Category-wise and graph-size breakdown

In [None]:
# Cell 1: Setup & Install
# Uncomment for Colab:
# !pip install torch torch-geometric sentence-transformers rouge-score
# !pip install torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-2.1.0+cu121.html

import sys, os, json, time, warnings
from pathlib import Path
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams.update({'font.size': 11, 'figure.dpi': 120})
warnings.filterwarnings('ignore', category=FutureWarning)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torch_geometric.nn import GATConv, global_mean_pool
from torch_geometric.nn.kge import TransE, DistMult
from torch_geometric.utils import negative_sampling, scatter

# Project imports — adjust path for Colab
PROJECT_ROOT = Path('..').resolve()
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from src.data import FinDERGraphQADataset, DualGraphBatch, dual_graph_collate_fn, VocabularyBuilder

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {device}')
if device.type == 'cuda':
    # Report which accelerator (and how much memory) this run will use.
    gpu_props = torch.cuda.get_device_properties(0)
    print(f'  GPU: {torch.cuda.get_device_name()}')
    print(f'  Memory: {gpu_props.total_memory / 1e9:.1f} GB')

In [None]:
# Cell 2: Data Loading
DATA_ROOT = PROJECT_ROOT / 'data' / 'processed' / 'finder_pyg'

train_ds = FinDERGraphQADataset(root=str(DATA_ROOT), split='train')
val_ds   = FinDERGraphQADataset(root=str(DATA_ROOT), split='val')
test_ds  = FinDERGraphQADataset(root=str(DATA_ROOT), split='test')

BATCH_SIZE = 32
# Page-locked host memory only helps host->GPU copies; requesting it on a CPU-only
# run just makes the DataLoader emit a warning.
PIN_MEMORY = device.type == 'cuda'
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=dual_graph_collate_fn, num_workers=0, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                          collate_fn=dual_graph_collate_fn, num_workers=0)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False,
                          collate_fn=dual_graph_collate_fn, num_workers=0)

# Load vocabularies and dataset metadata
vocabs = FinDERGraphQADataset.get_vocab(root=str(DATA_ROOT))
metadata = json.loads((DATA_ROOT / 'processed' / 'metadata.json').read_text())

NUM_RDF_ENTITIES  = metadata['vocab_sizes']['rdf_entities']   # 17,534
NUM_RDF_RELATIONS = metadata['vocab_sizes']['rdf_relations']  # 4,340
LPG_FEATURE_DIM   = metadata['lpg_feature_dim']               # 384

# Dataset statistics
print(f"\n{'='*50}")
print("Dataset Statistics")
print(f"{'='*50}")
print(f"  Train: {len(train_ds):,} | Val: {len(val_ds):,} | Test: {len(test_ds):,}")
print(f"  LPG feature dim: {LPG_FEATURE_DIM}")
print(f"  RDF entities: {NUM_RDF_ENTITIES:,} | relations: {NUM_RDF_RELATIONS:,}")

# Category distribution (full pass over the training split)
categories = [train_ds[i].category for i in range(len(train_ds))]
cat_counts = pd.Series(categories).value_counts()
print("\nCategory distribution (train):")
for cat, count in cat_counts.items():
    print(f"  {cat}: {count}")

# Graph sizes are *estimated* from the first STATS_SAMPLE_SIZE training samples;
# the printed label says so, and each sample is loaded only once.
STATS_SAMPLE_SIZE = 200
n_stats = min(STATS_SAMPLE_SIZE, len(train_ds))
stats_samples = [train_ds[i] for i in range(n_stats)]
lpg_nodes = [s.lpg_num_nodes.item() for s in stats_samples]
rdf_edges = [s.rdf_edge_index.shape[1] for s in stats_samples]
del stats_samples
if n_stats:  # np.mean([]) would warn and print nan
    print(f"\nAvg LPG nodes/sample (first {n_stats} train samples): "
          f"{np.mean(lpg_nodes):.1f} (±{np.std(lpg_nodes):.1f})")
    print(f"Avg RDF triples/sample (first {n_stats} train samples): "
          f"{np.mean(rdf_edges):.1f} (±{np.std(rdf_edges):.1f})")

In [None]:
# Cell 3: Model Definitions

# --- GAT for LPG (batched) ---

class BatchedGAT(nn.Module):
    """GAT encoder for LPG subgraphs with batched graph-level pooling.

    Based on MessagePassingGNN from src/_legacy/models.py but uses
    global_mean_pool for proper mini-batch support.

    Node pipeline: Linear(input_dim -> hidden_dim) + ReLU, then ``num_layers``
    GATConv layers whose ``heads`` outputs are concatenated (hidden_dim * heads
    channels), each followed by LayerNorm -> ReLU -> Dropout, then a final
    Linear(hidden_dim * heads -> output_dim). ``dropout`` is also passed to
    GATConv (attention-coefficient dropout in PyG).
    """

    def __init__(
        self,
        input_dim: int = 384,
        hidden_dim: int = 256,
        output_dim: int = 384,
        num_layers: int = 2,
        heads: int = 4,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        for i in range(num_layers):
            # The first conv sees the projected input; later convs see the
            # concatenated multi-head output of the previous layer.
            in_ch = hidden_dim if i == 0 else hidden_dim * heads
            self.convs.append(GATConv(in_ch, hidden_dim, heads=heads, dropout=dropout))
            self.norms.append(nn.LayerNorm(hidden_dim * heads))
        self.output_proj = nn.Linear(hidden_dim * heads, output_dim)
        self.dropout = nn.Dropout(dropout)

    def _encode_nodes(self, x, edge_index):
        """Shared node encoder used by both ``forward`` and ``get_node_embeddings``.

        Keeping a single implementation guarantees that the embeddings trained via
        link prediction are exactly the ones pooled into graph embeddings.
        """
        x = torch.relu(self.input_proj(x))
        for conv, norm in zip(self.convs, self.norms):
            x = conv(x, edge_index)
            x = norm(x)
            x = torch.relu(x)
            x = self.dropout(x)
        return self.output_proj(x)  # [sum(N_i), output_dim]

    def forward(self, x, edge_index, batch):
        """Forward pass with batched graph-level pooling.

        Args:
            x: [sum(N_i), input_dim] node features
            edge_index: [2, sum(E_i)] COO edges
            batch: [sum(N_i)] graph membership index

        Returns:
            [B, output_dim] graph-level embeddings (mean over each graph's nodes)
        """
        node_emb = self._encode_nodes(x, edge_index)  # [sum(N_i), output_dim]
        return global_mean_pool(node_emb, batch)  # [B, output_dim]

    def get_node_embeddings(self, x, edge_index):
        """Get per-node embeddings (no pooling). For link prediction decoding.

        Returns:
            [sum(N_i), output_dim] node embeddings.
        """
        return self._encode_nodes(x, edge_index)


# --- KGE for RDF (TransE / DistMult) with per-graph aggregation ---

class BatchedKGE(nn.Module):
    """KGE encoder for RDF triples with per-graph embedding aggregation.

    Learns global entity and relation embeddings (PyG TransE or DistMult), then
    for each question's RDF subgraph aggregates triple representations via
    scatter mean and projects the result to ``output_dim``.

    Note: ``loss`` only involves ``self.kge``, so KGE training does not update
    ``output_proj``; it keeps its initial weights unless trained elsewhere.

    Raises:
        ValueError: if ``model_type`` is not 'transe' or 'distmult'.
    """

    def __init__(
        self,
        model_type: str,  # 'transe' or 'distmult'
        num_entities: int,
        num_relations: int,
        hidden_dim: int = 256,
        output_dim: int = 384,
    ):
        super().__init__()
        self.model_type = model_type
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        if model_type == 'transe':
            self.kge = TransE(
                num_nodes=num_entities,
                num_relations=num_relations,
                hidden_channels=hidden_dim,
                margin=1.0,
                p_norm=1.0,
            )
        elif model_type == 'distmult':
            self.kge = DistMult(
                num_nodes=num_entities,
                num_relations=num_relations,
                hidden_channels=hidden_dim,
            )
        else:
            raise ValueError(f"Unknown model_type: {model_type}")

        self.output_proj = nn.Linear(hidden_dim, output_dim)

    def loss(self, head_index, rel_type, tail_index):
        """KGE training loss with built-in negative sampling (delegates to ``self.kge``)."""
        return self.kge.loss(head_index, rel_type, tail_index)

    def forward(self, batch: 'DualGraphBatch'):
        """Compute per-graph embeddings from RDF triples.

        For each triple (h, r, t) in the batch, computes h_emb + r_emb,
        then aggregates per-graph via scatter mean on the head node's
        graph membership. Graphs without triples aggregate to a zero vector,
        so every "no triples" graph maps to ``output_proj(0)`` — including the
        case where the whole batch has no triples.

        Returns:
            [B, output_dim] graph-level embeddings
        """
        edge_index = batch.rdf_edge_index  # [2, sum(T_i)]
        edge_type = batch.rdf_edge_type    # [sum(T_i)]
        rdf_batch = batch.rdf_batch        # [sum(N_rdf_i)]
        global_idx = batch.rdf_global_node_idx  # [sum(N_rdf_i)]

        if edge_index.shape[1] == 0:
            # Same representation an empty graph gets inside a non-empty batch
            # (zero aggregate -> output_proj), so a sample's embedding does not
            # depend on which other samples share its batch.
            empty = self.kge.node_emb.weight.new_zeros(batch.batch_size, self.hidden_dim)
            return self.output_proj(empty)

        head_local = edge_index[0]  # local node indices
        tail_local = edge_index[1]

        # Map local indices to global for embedding lookup
        head_global = global_idx[head_local]

        head_emb = self.kge.node_emb(head_global)  # [sum(T_i), hidden_dim]
        rel_emb = self.kge.rel_emb(edge_type)       # [sum(T_i), hidden_dim]

        # Triple representation: h + r for both model types.
        # NOTE(review): h + r is TransE's composition; DistMult scores with h * r.
        # Using h + r for both is kept as the shared design here — confirm intent.
        triple_emb = head_emb + rel_emb  # [sum(T_i), hidden_dim]

        # Determine graph membership for each triple (from head node)
        triple_graph = rdf_batch[head_local]  # [sum(T_i)]

        # Aggregate triples per graph (PyG scatter mean; empty graphs -> zeros)
        graph_emb = scatter(triple_emb, triple_graph, dim=0,
                            dim_size=batch.batch_size, reduce='mean')  # [B, hidden_dim]

        return self.output_proj(graph_emb)  # [B, output_dim]

    def get_entity_embeddings(self):
        """Export projected entity embeddings [num_entities, output_dim] (no grad)."""
        with torch.no_grad():
            return self.output_proj(self.kge.node_emb.weight)


# Quick sanity check: parameter counts of freshly initialised encoders
# (throwaway instances; the trained models are built in the training cells).
print('BatchedGAT params:', sum(p.numel() for p in BatchedGAT().parameters()) / 1e3, 'K')
for _kge_type, _kge_label in [('transe', 'TransE'), ('distmult', 'DistMult')]:
    _n_params = sum(
        p.numel() for p in BatchedKGE(_kge_type, NUM_RDF_ENTITIES, NUM_RDF_RELATIONS).parameters()
    )
    print(f'BatchedKGE ({_kge_label}) params:', _n_params / 1e6, 'M')
del _kge_type, _kge_label, _n_params

In [None]:
# Cell 4: Training — GAT (LPG) via Link Prediction

def train_gat_epoch(model, loader, optimizer):
    """Run one link-prediction training epoch for the GAT encoder.

    Per batch: encode all nodes, treat the batch's LPG edges as positives, draw
    the same number of random non-edges with ``negative_sampling`` as negatives,
    score each pair by the dot product of its node embeddings, and minimise
    BCE-with-logits (positives -> 1, negatives -> 0). Batches without edges are
    skipped and excluded from the average.

    NOTE(review): negatives are sampled over the whole batch's node set, so a
    negative pair can join nodes from two different graphs.

    Returns:
        Mean loss over the batches that took an optimiser step (0 if none did).
    """
    def dot_scores(node_emb, edges):
        # Logit for each (src, dst) column of ``edges``.
        return (node_emb[edges[0]] * node_emb[edges[1]]).sum(dim=-1)

    model.train()
    loss_sum, steps = 0, 0

    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        z = model.get_node_embeddings(batch.lpg_x, batch.lpg_edge_index)

        pos_edge = batch.lpg_edge_index
        total_nodes = batch.lpg_x.shape[0]
        num_pos = pos_edge.shape[1]
        if num_pos == 0:
            continue

        neg_edge = negative_sampling(pos_edge, num_nodes=total_nodes, num_neg_samples=num_pos)

        pos_logits = dot_scores(z, pos_edge)
        neg_logits = dot_scores(z, neg_edge)
        loss = (F.binary_cross_entropy_with_logits(pos_logits, torch.ones_like(pos_logits))
                + F.binary_cross_entropy_with_logits(neg_logits, torch.zeros_like(neg_logits)))

        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        steps += 1

    return loss_sum / max(steps, 1)


@torch.no_grad()
def eval_gat_link_prediction(model, loader, sample_size: int = 500, k: int = 10):
    """Evaluate GAT link prediction with sampled, unfiltered ranking: MRR and Hits@k.

    For up to ``sample_size`` randomly chosen positive edges (s, d) per batch,
    every node in the batch is scored against s by dot product. d's rank is the
    number of nodes scoring >= d (pessimistic ties, clamped to >= 1 so a NaN
    score cannot cause a division by zero).

    NOTE(review): candidates include nodes of other graphs in the same batch, and
    the evaluated edges are also used for message passing when computing ``z``.
    Scores are therefore optimistic relative to a held-out-edge protocol.

    Args:
        model: encoder exposing ``get_node_embeddings(x, edge_index)``.
        loader: yields batches with ``lpg_x`` and ``lpg_edge_index``.
        sample_size: max positive edges ranked per batch.
        k: cutoff for Hits@k.

    Returns:
        {'mrr': float, f'hits@{k}': float} — with the default k this is 'hits@10'.
    """
    model.eval()
    mrr_sum, hits_sum, count = 0.0, 0.0, 0

    for batch in loader:
        batch = batch.to(device)
        z = model.get_node_embeddings(batch.lpg_x, batch.lpg_edge_index)

        pos_edge = batch.lpg_edge_index
        if pos_edge.shape[1] == 0:
            continue

        # Score a sample of edges (full ranking is too slow)
        n_sample = min(sample_size, pos_edge.shape[1])
        idx = torch.randperm(pos_edge.shape[1])[:n_sample]
        src, dst = pos_edge[0, idx], pos_edge[1, idx]

        # One matmul scores every sampled source against every node: [n_sample, num_nodes].
        # This avoids a Python loop with a host sync (.item()) per edge.
        scores = z[src] @ z.T
        true_scores = scores.gather(1, dst.view(-1, 1))  # [n_sample, 1]
        ranks = (scores >= true_scores).sum(dim=1).clamp(min=1).double()

        mrr_sum += (1.0 / ranks).sum().item()
        hits_sum += (ranks <= k).sum().item()
        count += n_sample

    mrr = mrr_sum / max(count, 1)
    hits = hits_sum / max(count, 1)
    return {'mrr': mrr, f'hits@{k}': hits}


# Train GAT
GAT_EPOCHS = 50
GAT_LR = 1e-3

gat_model = BatchedGAT(input_dim=LPG_FEATURE_DIM).to(device)
gat_optimizer = torch.optim.Adam(gat_model.parameters(), lr=GAT_LR, weight_decay=1e-5)

# val_* lists get one entry per validation epoch (epoch 1 and every 5th epoch).
gat_history = {'train_loss': [], 'val_mrr': [], 'val_hits10': []}
best_val_mrr, best_gat_state = 0.0, None

print(f'Training GAT for {GAT_EPOCHS} epochs...')
for epoch in range(1, GAT_EPOCHS + 1):
    loss = train_gat_epoch(gat_model, train_loader, gat_optimizer)
    gat_history['train_loss'].append(loss)

    is_val_epoch = epoch == 1 or epoch % 5 == 0
    if not is_val_epoch:
        continue

    val_metrics = eval_gat_link_prediction(gat_model, val_loader)
    gat_history['val_mrr'].append(val_metrics['mrr'])
    gat_history['val_hits10'].append(val_metrics['hits@10'])
    print(f'  Epoch {epoch:3d} | Loss: {loss:.4f} | Val MRR: {val_metrics["mrr"]:.4f} | Val Hits@10: {val_metrics["hits@10"]:.3f}')

    # Snapshot the best weights on CPU so later epochs cannot overwrite them.
    if val_metrics['mrr'] > best_val_mrr:
        best_val_mrr = val_metrics['mrr']
        best_gat_state = {name: tensor.cpu().clone()
                          for name, tensor in gat_model.state_dict().items()}

# Restore the best checkpoint (none is saved if validation MRR never exceeded 0).
if best_gat_state is not None:
    gat_model.load_state_dict(best_gat_state)
    gat_model.to(device)
print(f'\nBest GAT Val MRR: {best_val_mrr:.4f}')

In [None]:
# Cell 5: Training — TransE & DistMult (RDF)

def collect_all_rdf_triples(loader):
    """Collect all (head_global, rel, tail_global) triples from the dataset.

    Local RDF node indices in each batch's ``rdf_edge_index`` are mapped to global
    entity ids through ``rdf_global_node_idx``, so the returned triples index
    directly into the KGE entity table. No de-duplication is done.

    Args:
        loader: iterable of batches exposing ``rdf_edge_index`` [2, T],
            ``rdf_edge_type`` [T] and ``rdf_global_node_idx`` [N].

    Returns:
        (heads, rels, tails): three 1-D tensors of equal length, in loader order.
        If no batch contains a triple, three empty long tensors (instead of the
        RuntimeError that ``torch.cat([])`` would raise).
    """
    heads, rels, tails = [], [], []
    for batch in loader:
        ei = batch.rdf_edge_index
        if ei.shape[1] == 0:
            continue
        gi = batch.rdf_global_node_idx
        heads.append(gi[ei[0]])
        tails.append(gi[ei[1]])
        rels.append(batch.rdf_edge_type)

    if not heads:
        empty = torch.empty(0, dtype=torch.long)
        return empty, empty.clone(), empty.clone()
    return torch.cat(heads), torch.cat(rels), torch.cat(tails)


def train_kge_epoch(model, head, rel, tail, optimizer, batch_size=512):
    """Run one KGE training epoch over all triples in shuffled mini-batches.

    Args:
        model: object exposing ``train()`` and ``loss(head, rel, tail)``.
        head, rel, tail: aligned 1-D index tensors, one entry per triple.
        optimizer: optimiser over the model's parameters.
        batch_size: triples per optimiser step.

    Returns:
        Mean of the per-mini-batch losses (0 when there are no triples).
    """
    model.train()
    num_triples = head.shape[0]
    order = torch.randperm(num_triples, device=head.device)
    batch_losses = []

    for start in range(0, num_triples, batch_size):
        sel = order[start:start + batch_size]
        optimizer.zero_grad()
        batch_loss = model.loss(head[sel], rel[sel], tail[sel])
        batch_loss.backward()
        optimizer.step()
        batch_losses.append(batch_loss.item())

    return sum(batch_losses) / max(len(batch_losses), 1)


@torch.no_grad()
def eval_kge(model, head, rel, tail, sample_size=1000, k=10):
    """Evaluate KGE with sampled, unfiltered tail ranking: MRR and Hits@K.

    Up to ``sample_size`` triples are drawn at random. For each one, every entity
    is scored as a candidate tail with the wrapped KGE model's own scoring function,
    ``model.kge(head, rel, tail)``, so the ranking uses exactly the score the model
    was trained on. A hand-written score would drift from it: PyG's TransE, for
    example, L_p-normalises head and tail embeddings inside its score.

    Rank = number of candidates scoring >= the true tail (pessimistic ties),
    clamped to >= 1.

    Args:
        model: BatchedKGE-like wrapper exposing ``eval()`` and ``kge`` (a PyG
            KGE model with ``node_emb``).
        head, rel, tail: aligned 1-D index tensors.
        sample_size: max triples ranked.
        k: cutoff for Hits@k.

    Returns:
        {'mrr': float, f'hits@{k}': float}; both 0.0 when there are no triples.
    """
    model.eval()
    n = min(sample_size, head.shape[0])
    if n == 0:
        return {'mrr': 0.0, f'hits@{k}': 0.0}

    idx = torch.randperm(head.shape[0])[:n]
    h, r, t = head[idx], rel[idx], tail[idx]

    num_entities = model.kge.node_emb.num_embeddings
    candidates = torch.arange(num_entities, device=h.device)

    mrr_sum, hits_sum = 0.0, 0.0
    for i in range(n):
        scores = model.kge(h[i].expand_as(candidates),
                           r[i].expand_as(candidates),
                           candidates)  # [num_entities], higher = better
        rank = max((scores >= scores[t[i]]).sum().item(), 1)
        mrr_sum += 1.0 / rank
        hits_sum += 1.0 if rank <= k else 0.0

    return {'mrr': mrr_sum / n, f'hits@{k}': hits_sum / n}


# Collect triples
# Flatten every batch's RDF subgraph into global (head, rel, tail) id triples and
# move them to the training device. A triple that appears in several subgraphs is
# collected once per occurrence (no de-duplication happens here).
print('Collecting RDF triples...')
train_h, train_r, train_t = collect_all_rdf_triples(train_loader)
val_h, val_r, val_t = collect_all_rdf_triples(val_loader)
train_h, train_r, train_t = train_h.to(device), train_r.to(device), train_t.to(device)
val_h, val_r, val_t = val_h.to(device), val_r.to(device), val_t.to(device)
print(f'  Train triples: {train_h.shape[0]:,} | Val triples: {val_h.shape[0]:,}')

# Train both models
KGE_EPOCHS = 100
KGE_LR = 1e-2
KGE_BATCH = 512

kge_models = {}     # model_type -> BatchedKGE with its best checkpoint restored
kge_histories = {}  # model_type -> {'train_loss', 'val_mrr', 'val_hits10'} lists

for model_type in ['transe', 'distmult']:
    print(f'\n{"="*50}')
    print(f'Training {model_type.upper()} for {KGE_EPOCHS} epochs...')
    print(f'{"="*50}')

    # Adam is given all BatchedKGE parameters, but model.loss only involves
    # model.kge, so output_proj receives no gradient and keeps its init weights.
    model = BatchedKGE(model_type, NUM_RDF_ENTITIES, NUM_RDF_RELATIONS).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=KGE_LR)
    history = {'train_loss': [], 'val_mrr': [], 'val_hits10': []}
    best_mrr = 0.0
    best_state = None

    for epoch in range(1, KGE_EPOCHS + 1):
        loss = train_kge_epoch(model, train_h, train_r, train_t, optimizer, KGE_BATCH)
        history['train_loss'].append(loss)

        # Validate on epoch 1 and every 10th epoch. eval_kge ranks a fresh random
        # sample of val triples each call (sample_size defaults to 1000), so the
        # model-selection metric below is noisy.
        if epoch % 10 == 0 or epoch == 1:
            val_metrics = eval_kge(model, val_h, val_r, val_t)
            history['val_mrr'].append(val_metrics['mrr'])
            history['val_hits10'].append(val_metrics['hits@10'])
            print(f'  Epoch {epoch:3d} | Loss: {loss:.4f} | Val MRR: {val_metrics["mrr"]:.4f} | Val Hits@10: {val_metrics["hits@10"]:.3f}')

            if val_metrics['mrr'] > best_mrr:
                best_mrr = val_metrics['mrr']
                # CPU snapshot so later optimiser steps don't overwrite it.
                best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

    # Restore the best checkpoint (skipped if validation MRR never exceeded 0).
    if best_state:
        model.load_state_dict(best_state)
        model.to(device)
    print(f'Best {model_type.upper()} Val MRR: {best_mrr:.4f}')

    kge_models[model_type] = model
    kge_histories[model_type] = history

# Loss curve comparison. The x-axis uses the same 1-based epoch numbers as the
# training logs; plotting the bare list would label the first epoch as 0.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
for mt, h in kge_histories.items():
    axes[0].plot(range(1, len(h['train_loss']) + 1), h['train_loss'], label=mt.upper())
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('KGE Training Loss')
axes[0].legend()

# Also plot GAT loss
gat_epochs = range(1, len(gat_history['train_loss']) + 1)
axes[1].plot(gat_epochs, gat_history['train_loss'], label='GAT', color='tab:green')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Loss')
axes[1].set_title('GAT Training Loss')
axes[1].legend()
plt.tight_layout()
plt.show()

---

# Track B: G-Retrieval E2E Pipeline + Bottleneck Profiling

**G-Retrieval Architecture:** Graph Encoder → Projection → LLM (Llama 3 via HuggingFace)

Profile the **full inference pipeline** stage-by-stage:

| Stage | LPG (GAT) | RDF (TransE / DistMult) |
|-------|-----------|------------------------|
| **SUBGRAPH_LOAD** | Extract tensors from batch | Extract edge_index, edge_type, global_idx |
| **CPU_GATHER / LOOKUP** | Feature store indexing on CPU | Embedding table lookup on CPU |
| **H2D_COPY** | Synchronous transfer to GPU | Synchronous transfer to GPU |
| **ENCODER_FWD** | GAT layers + global_mean_pool | h+r aggregation + scatter + proj |
| **PROJECTION** | Linear(384 → LLM hidden) | Linear(384 → LLM hidden) |
| **LLM_PREFILL** | LLM forward with graph soft tokens prepended | Same |

Simulates a **CPU-resident feature/embedding store** (G-Retrieval production pattern).

In [None]:
# Cell 6: Profiling Utilities + CPU Store Simulators

import time
from torch.profiler import profile as torch_profile, ProfilerActivity, record_function

# ── Constants ──
# Untimed warm-up batches, then timed measurement batches, per profiled model.
WARMUP_BATCHES, PROFILE_BATCHES = 3, 20
PROFILE_BATCH_SIZE = 32

# ── Timing Infrastructure ──

@dataclass
class StageTiming:
    """Timing record for one pipeline stage of one profiled batch.

    Attributes:
        stage: stage label, e.g. 'H2D_COPY'.
        cpu_ms: host wall-clock time in milliseconds.
        cuda_ms: CUDA-event time in milliseconds; NaN for CPU-only stages.
        bytes_transferred: payload size in bytes for transfer stages, else 0.
    """
    stage: str
    cpu_ms: float
    cuda_ms: float = field(default=float('nan'))
    bytes_transferred: int = field(default=0)

HAS_CUDA = device.type == 'cuda'


class CUDATimer:
    """Context manager timing one stage in host wall-clock AND CUDA-event time.

    The device is synchronised before the start timestamp and again after the end
    event, so ``cpu_ms`` covers all GPU work issued inside the ``with`` block.
    ``cuda_ms`` is the elapsed time between two CUDA events recorded on the current
    stream. Without CUDA, only ``cpu_ms`` is measured and ``cuda_ms`` is NaN.
    Exceptions raised inside the block propagate (``__exit__`` returns None).
    """
    def __init__(self, stage: str):
        self.stage = stage
        if not HAS_CUDA:
            return
        self.start_evt = torch.cuda.Event(enable_timing=True)
        self.end_evt = torch.cuda.Event(enable_timing=True)

    def __enter__(self):
        if not HAS_CUDA:
            self.cpu_start = time.perf_counter()
            return self
        torch.cuda.synchronize()  # keep previously queued work out of the window
        self.cpu_start = time.perf_counter()
        self.start_evt.record()
        return self

    def __exit__(self, *args):
        if HAS_CUDA:
            self.end_evt.record()
            torch.cuda.synchronize()  # wait until the end event has executed
        self.cpu_end = time.perf_counter()
        elapsed_cpu_ms = (self.cpu_end - self.cpu_start) * 1000
        if HAS_CUDA:
            elapsed_cuda_ms = self.start_evt.elapsed_time(self.end_evt)
        else:
            elapsed_cuda_ms = float('nan')
        self._result = StageTiming(stage=self.stage, cpu_ms=elapsed_cpu_ms,
                                   cuda_ms=elapsed_cuda_ms)

    def result(self) -> StageTiming:
        """Return the StageTiming captured by the most recent ``with`` block."""
        return self._result


class CPUTimer:
    """Context manager timing one stage in host wall-clock time only.

    No device synchronisation is performed, so asynchronously launched GPU work is
    not included. Intended for stages that run on the CPU; the resulting
    StageTiming keeps its default ``cuda_ms`` (NaN).
    """
    def __init__(self, stage: str):
        self.stage = stage

    def __enter__(self):
        self.cpu_start = time.perf_counter()
        return self

    def __exit__(self, *args):
        self.cpu_end = time.perf_counter()
        elapsed_ms = 1000 * (self.cpu_end - self.cpu_start)
        self._result = StageTiming(stage=self.stage, cpu_ms=elapsed_ms)

    def result(self) -> StageTiming:
        """Return the StageTiming captured by the most recent ``with`` block."""
        return self._result


# ── CPU Store Simulators (G-Retrieval production pattern) ──

class LPGCPUStore:
    """Host-memory (CPU) node-feature table for LPG subgraphs.

    Stands in for the G-Retrieval pattern where node features are fetched from a
    CPU-side feature store instead of being preloaded on the GPU. The constructor
    keeps a private CPU copy of the feature matrix and prints its shape and size.
    """
    def __init__(self, global_node_features: torch.Tensor):
        table = global_node_features.cpu().clone()  # presumably [num_global_nodes, feature_dim]
        size_mb = table.nelement() * table.element_size() / 1e6
        self.features = table
        print(f'  LPGCPUStore: {self.features.shape}, '
              f'{size_mb:.1f} MB on CPU')

    def gather(self, global_node_idx: torch.LongTensor) -> torch.Tensor:
        """Return the feature rows for ``global_node_idx`` (indices moved to CPU first)."""
        cpu_idx = global_node_idx.cpu()
        return self.features[cpu_idx]


class RDFCPUStore:
    """Host-memory copies of a trained KGE model's entity and relation tables.

    Simulates the production pattern where embedding tables live on the CPU
    (e.g. because they are too large for GPU memory). The tables are detached,
    cloned CPU tensors, so they do not track gradients and do not alias the model.
    """
    def __init__(self, kge_model: 'BatchedKGE'):
        kge = kge_model.kge
        self.node_emb_weight, self.rel_emb_weight = (
            emb.weight.detach().cpu().clone() for emb in (kge.node_emb, kge.rel_emb)
        )

        def size_mb(t):
            return t.nelement() * t.element_size() / 1e6

        node_mb, rel_mb = size_mb(self.node_emb_weight), size_mb(self.rel_emb_weight)
        print(f'  RDFCPUStore: node_emb {self.node_emb_weight.shape} ({node_mb:.1f} MB) + '
              f'rel_emb {self.rel_emb_weight.shape} ({rel_mb:.1f} MB) on CPU')

    def lookup(self, head_global: torch.LongTensor,
               edge_type: torch.LongTensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (head entity rows, relation rows) as CPU tensors."""
        return (self.node_emb_weight[head_global.cpu()],
                self.rel_emb_weight[edge_type.cpu()])


# ── Result Aggregation ──

def aggregate_timings(all_timings: List[Dict[str, 'StageTiming']],
                      batch_sizes: List[int],
                      node_counts: List[int],
                      edge_counts: List[int]) -> pd.DataFrame:
    """Aggregate per-batch stage timings into one summary row per stage.

    Args:
        all_timings: one dict per measured batch, mapping stage name -> record with
            ``cpu_ms``, ``cuda_ms`` and ``bytes_transferred``. Stage names and their
            order are taken from the first batch.
        batch_sizes, node_counts, edge_counts: per-batch workload sizes. Not used in
            the summary currently; accepted for interface compatibility.

    Returns:
        DataFrame with columns stage, cpu_mean/cpu_std/cpu_p95,
        cuda_mean/cuda_std/cuda_p95 (NaN when no batch has a CUDA time for that
        stage) and bytes_mean. An empty DataFrame with the same columns when
        ``all_timings`` is empty (previously an IndexError).
    """
    columns = ['stage', 'cpu_mean', 'cpu_std', 'cpu_p95',
               'cuda_mean', 'cuda_std', 'cuda_p95', 'bytes_mean']
    if not all_timings:
        return pd.DataFrame(columns=columns)

    nan = float('nan')
    rows = []
    for stage in all_timings[0]:
        cpu_vals = np.array([t[stage].cpu_ms for t in all_timings], dtype=float)
        cuda_vals = np.array([t[stage].cuda_ms for t in all_timings], dtype=float)
        bytes_vals = [t[stage].bytes_transferred for t in all_timings]
        # np.nan* reductions warn on all-NaN input, so decide once and short-circuit.
        has_cuda = not np.isnan(cuda_vals).all()
        rows.append({
            'stage': stage,
            'cpu_mean': np.mean(cpu_vals),
            'cpu_std': np.std(cpu_vals),
            'cpu_p95': np.percentile(cpu_vals, 95),
            'cuda_mean': np.nanmean(cuda_vals) if has_cuda else nan,
            'cuda_std': np.nanstd(cuda_vals) if has_cuda else nan,
            'cuda_p95': np.nanpercentile(cuda_vals, 95) if has_cuda else nan,
            'bytes_mean': np.mean(bytes_vals),
        })
    return pd.DataFrame(rows, columns=columns)


def print_profiling_table(df: pd.DataFrame, model_name: str, n_batches: int,
                          batch_size: Optional[int] = None):
    """Pretty-print a profiling summary table (one row per stage).

    Columns: mean±std CPU ms, mean CUDA ms ('—' when the stage has no CUDA time),
    the stage's share of the summed mean CPU time, and the host-to-device
    bandwidth in GB/s. Bandwidth is ``bytes_mean / cuda_mean`` and is shown only
    for rows that moved bytes and have a positive CUDA time.

    Args:
        df: summary DataFrame with stage, cpu_mean, cpu_std, cuda_mean, bytes_mean.
        model_name: label for the title line.
        n_batches: number of measured batches (title only).
        batch_size: batch size shown in the title; defaults to PROFILE_BATCH_SIZE.
    """
    if batch_size is None:
        batch_size = PROFILE_BATCH_SIZE
    total_cpu = df['cpu_mean'].sum()
    rule_w = 67  # sum of the column widths below, including separating spaces

    print(f'\n{"=" * rule_w}')
    print(f'{model_name} Profiling — {n_batches} batches, BS={batch_size}')
    print(f'{"=" * rule_w}')
    # The bandwidth column is reported in GB/s (the old header said MB/s).
    print(f'{"Stage":<18} {"CPU ms (±std)":>15} {"CUDA ms":>10} {"% total":>8} {"H2D BW":>12}')
    print('-' * rule_w)
    for _, r in df.iterrows():
        pct = r['cpu_mean'] / total_cpu * 100 if total_cpu > 0 else 0
        cpu_str = f'{r["cpu_mean"]:.2f}±{r["cpu_std"]:.1f}'
        has_cuda = not np.isnan(r['cuda_mean'])
        cuda_str = f'{r["cuda_mean"]:.2f}' if has_cuda else '—'
        bw_str = ''
        if r['bytes_mean'] > 0 and has_cuda and r['cuda_mean'] > 0:
            # bytes -> GB (1e9) over ms -> s (1e3)
            bw_gbps = (r['bytes_mean'] / 1e9) / (r['cuda_mean'] / 1e3)
            bw_str = f'{bw_gbps:.1f} GB/s'
        print(f'{r["stage"]:<18} {cpu_str:>15} {cuda_str:>10} {pct:>7.1f}% {bw_str:>12}')
    print('-' * rule_w)
    print(f'{"TOTAL":<18} {total_cpu:>15.2f}')


# ── Load Global LPG Features for CPU Store ──
# Builds `global_node_features`, the full feature table served by LPGCPUStore.
# It prefers the cached full LPG graph and otherwise rebuilds the table from the
# per-sample subgraphs of all three splits.
print('Loading global LPG features for CPU store simulation...')
_lpg_cache_path = PROJECT_ROOT / 'data' / 'processed' / 'lpg_full_graph.pt'

if _lpg_cache_path.exists():
    # NOTE(review): weights_only=False unpickles arbitrary Python objects — only
    # load this file if it comes from a trusted source.
    _lpg_cache = torch.load(str(_lpg_cache_path), weights_only=False, map_location='cpu')
    global_node_features = _lpg_cache['data'].x  # [13920, 384]
    del _lpg_cache  # drop the rest of the cached graph object
else:
    # Fallback: reconstruct global node features from dataset samples
    print('  lpg_full_graph.pt not found — reconstructing from dataset...')
    # Pass 1: find the largest global node id to size the table.
    _max_idx = 0
    for ds in [train_ds, val_ds, test_ds]:
        for i in range(len(ds)):
            gi = ds[i].lpg_global_node_idx
            if gi.numel() > 0:
                _max_idx = max(_max_idx, gi.max().item())
    global_node_features = torch.zeros(_max_idx + 1, LPG_FEATURE_DIM)
    # Pass 2: write each sample's node features into their global rows.
    # Assumes row j of sample.lpg_x belongs to global node lpg_global_node_idx[j].
    # Ids never seen in any sample stay all-zero; a node present in several
    # samples keeps the last sample's features.
    for ds in [train_ds, val_ds, test_ds]:
        for i in range(len(ds)):
            sample = ds[i]
            gi = sample.lpg_global_node_idx
            if gi.numel() > 0:
                global_node_features[gi] = sample.lpg_x
    print(f'  Reconstructed from {len(train_ds)+len(val_ds)+len(test_ds)} samples')
print(f'  Global LPG features: {global_node_features.shape}')

# Create CPU stores
# One LPG feature store, plus one RDF embedding store per trained KGE model
# (TransE and DistMult have separate embedding tables).
print('\nInitializing CPU stores...')
lpg_cpu_store = LPGCPUStore(global_node_features)
rdf_cpu_stores = {}
for mt, model in kge_models.items():
    print(f'  {mt.upper()}:')
    rdf_cpu_stores[mt] = RDFCPUStore(model)

# Profiling DataLoader (no pin_memory, no shuffle for reproducibility)
# pin_memory=False keeps batches in pageable host memory, so H2D_COPY times plain
# synchronous copies; shuffle=False profiles the same leading batches every run.
profile_loader = DataLoader(train_ds, batch_size=PROFILE_BATCH_SIZE, shuffle=False,
                            collate_fn=dual_graph_collate_fn, num_workers=0,
                            pin_memory=False)
# Batch count below uses floor division, so a final partial batch is not counted.
print(f'\nProfile loader: {len(train_ds)} samples, ~{len(train_ds)//PROFILE_BATCH_SIZE} batches')
print(f'Will use {WARMUP_BATCHES} warmup + {PROFILE_BATCHES} measurement batches')

In [None]:
# Cell 7: G-Retrieval Model Definition + Llama 3 Loading
#
# Architecture: Graph Encoder → Projection MLP → LLM (soft token prepend)
# The graph encoder output is projected into the LLM's embedding space
# and prepended as "graph tokens" before the text question tokens.

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# ── G-Retrieval Wrapper ──

class GraphLLM(nn.Module):
    """G-Retrieval style model: Graph Encoder → Projection → LLM.

    Encodes a graph subgraph into soft tokens and prepends them
    to the LLM's text input embeddings for graph-conditioned generation.

    Args:
        graph_encoder: module producing [B, graph_output_dim] graph embeddings
            (BatchedGAT for LPG, BatchedKGE for RDF).
        graph_output_dim: width of the graph embedding.
        llm: causal LM exposing ``config.hidden_size``, ``get_input_embeddings()``,
            ``__call__(inputs_embeds=..., attention_mask=...)`` and ``generate``.
        num_graph_tokens: number of soft tokens T produced per graph.

    Only ``projector`` is created here (randomly initialised); the encoder and the
    LLM are used as passed in.
    """
    def __init__(self, graph_encoder, graph_output_dim: int, llm,
                 num_graph_tokens: int = 1):
        super().__init__()
        self.graph_encoder = graph_encoder
        self.llm = llm
        self.num_graph_tokens = num_graph_tokens
        self.llm_hidden = llm.config.hidden_size  # e.g. 4096 for Llama 3

        # 2-layer projection MLP: graph_dim → llm_hidden × num_tokens
        self.projector = nn.Sequential(
            nn.Linear(graph_output_dim, self.llm_hidden),
            nn.GELU(),
            nn.Linear(self.llm_hidden, self.llm_hidden * num_graph_tokens),
        )

    def encode_graph_lpg(self, batch, dev):
        """Encode LPG subgraph → graph embedding [B, graph_output_dim]."""
        return self.graph_encoder(
            batch.lpg_x.to(dev), batch.lpg_edge_index.to(dev), batch.lpg_batch.to(dev)
        )

    def encode_graph_rdf(self, batch, dev):
        """Encode RDF subgraph → graph embedding [B, graph_output_dim]."""
        batch_gpu = batch.to(dev)
        return self.graph_encoder(batch_gpu)

    def project(self, graph_emb):
        """Project graph embedding to LLM token space [B, T, llm_hidden].

        The output is cast to the dtype of the LLM's input-embedding table, i.e.
        the dtype of the text embeddings it will be concatenated with.
        """
        proj = self.projector(graph_emb)  # [B, llm_hidden * T]
        B = proj.shape[0]
        proj = proj.to(self.llm.get_input_embeddings().weight.dtype)
        return proj.view(B, self.num_graph_tokens, self.llm_hidden)

    def _prepend_graph_tokens(self, graph_tokens, input_ids, attention_mask):
        """Embed ``input_ids``, prepend ``graph_tokens`` and extend the mask.

        Graph tokens are moved to the text embeddings' device and dtype: with
        ``device_map="auto"`` the embedding layer need not sit on the projector's
        device.

        Returns:
            (inputs_embeds [B, T+seq, H], attention_mask [B, T+seq])
        """
        text_emb = self.llm.get_input_embeddings()(input_ids)  # [B, seq, H]
        graph_tokens = graph_tokens.to(device=text_emb.device, dtype=text_emb.dtype)
        combined = torch.cat([graph_tokens, text_emb], dim=1)  # [B, T+seq, H]
        graph_mask = torch.ones(
            graph_tokens.shape[0], graph_tokens.shape[1],
            device=attention_mask.device, dtype=attention_mask.dtype
        )
        combined_mask = torch.cat([graph_mask, attention_mask], dim=1)
        return combined, combined_mask

    @torch.no_grad()
    def forward_prefill(self, graph_tokens, input_ids, attention_mask):
        """Single LLM forward pass with graph soft tokens prepended.

        Returns logits for next-token prediction (profiling target),
        shaped [B, T+seq, vocab].
        """
        combined, combined_mask = self._prepend_graph_tokens(
            graph_tokens, input_ids, attention_mask)
        outputs = self.llm(inputs_embeds=combined, attention_mask=combined_mask)
        return outputs.logits

    @torch.no_grad()
    def generate(self, graph_tokens, input_ids, attention_mask, tokenizer,
                 max_new_tokens=128):
        """Greedily generate answers conditioned on graph tokens + question.

        When ``generate`` is driven by ``inputs_embeds`` alone (no ``input_ids``),
        Hugging Face returns only the newly generated ids, because there are no
        prompt ids to echo back. The whole output is therefore decoded. Slicing
        off the prompt length would discard the first generated tokens.

        Returns:
            list of B decoded strings (special tokens skipped).
        """
        combined, combined_mask = self._prepend_graph_tokens(
            graph_tokens, input_ids, attention_mask)
        gen_ids = self.llm.generate(
            inputs_embeds=combined,
            attention_mask=combined_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )
        return tokenizer.batch_decode(gen_ids, skip_special_tokens=True)


# ── Load Llama 3 ──

LLM_MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
NUM_GRAPH_TOKENS = 1
GRAPH_OUTPUT_DIM = 384  # output dim of both BatchedGAT and BatchedKGE (their default output_dim)

print(f'Loading LLM: {LLM_MODEL_ID}')
print(f'  Graph tokens: {NUM_GRAPH_TOKENS}')

# 4-bit NF4 quantization (double-quantized, bf16 compute) for memory efficiency
# (T4: 16GB, A100: 40/80GB)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Batched generation with a decoder-only model needs left padding. With right
# padding, shorter prompts end in pad tokens and new tokens would be generated
# after the padding.
tokenizer.padding_side = 'left'

llm = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# The LLM is a frozen backbone: eval mode and no gradients.
llm.eval()
for p in llm.parameters():
    p.requires_grad = False

llm_mem_gb = sum(p.nelement() * p.element_size() for p in llm.parameters()) / 1e9
print(f'  LLM loaded: {llm_mem_gb:.1f} GB, hidden_size={llm.config.hidden_size}')

# ── Create GraphLLM instances ──
# All three instances share the same `llm` object. Each gets its own freshly
# initialised projector, moved to `device`; LLM placement is left to device_map.

graph_llms = {}

# GAT (LPG)
graph_llms['gat'] = GraphLLM(gat_model, GRAPH_OUTPUT_DIM, llm, NUM_GRAPH_TOKENS)
graph_llms['gat'].projector.to(device)

# TransE / DistMult (RDF)
for mt in ['transe', 'distmult']:
    graph_llms[mt] = GraphLLM(kge_models[mt], GRAPH_OUTPUT_DIM, llm, NUM_GRAPH_TOKENS)
    graph_llms[mt].projector.to(device)

proj_params = sum(p.numel() for p in graph_llms['gat'].projector.parameters())
print(f'  Projector params: {proj_params/1e3:.1f}K  ({GRAPH_OUTPUT_DIM} → {llm.config.hidden_size} × {NUM_GRAPH_TOKENS})')
print(f'  GraphLLM instances: {list(graph_llms.keys())}')

In [None]:
# Cell 8: E2E Stage-wise Profiling
#
# Profile the full G-Retrieval inference pipeline for all 3 encoder types.
# 6 stages: SUBGRAPH_LOAD → CPU_GATHER/LOOKUP → H2D_COPY → ENCODER_FWD → PROJECTION → LLM_PREFILL

PROMPT_TEMPLATE = (
    "Based on the graph context, answer the following question concisely.\n"
    "Question: {question}\nAnswer:"
)

def tokenize_questions(questions: List[str], max_length: int = 128):
    """Wrap each question in PROMPT_TEMPLATE and tokenize the whole batch.

    Prompts are padded across the batch and truncated to ``max_length`` tokens.

    Returns:
        (input_ids, attention_mask) as PyTorch tensors already moved to ``device``.
    """
    prompts = [PROMPT_TEMPLATE.format(question=question) for question in questions]
    encoded = tokenizer(
        prompts,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)
    return input_ids, attention_mask


# ── LPG (GAT) E2E Profiling ──

def profile_lpg_e2e(gllm, lpg_store, loader, n_warmup, n_measure):
    """Profile the GAT → projection → LLM-prefill pipeline stage by stage.

    Loader batches ``0 .. n_warmup-1`` run the whole pipeline untimed (warm-up);
    batches ``n_warmup .. n_warmup+n_measure-1`` are timed; iteration then stops.

    Timed stages, in order: SUBGRAPH_LOAD (read index tensors off the batch),
    CPU_GATHER (``lpg_store.gather`` on CPU), H2D_COPY (features, edge index and
    batch vector → ``device``; byte count attached to the timer result),
    ENCODER_FWD, PROJECTION and LLM_PREFILL (prompt tokenization included).

    Args:
        gllm: GraphLLM wrapper exposing ``graph_encoder``, ``project`` and
            ``forward_prefill``.
        lpg_store: store whose ``gather(global_node_idx)`` returns node features.
        loader: yields batches with ``lpg_global_node_idx``, ``lpg_edge_index``,
            ``lpg_batch``, ``questions`` and ``batch_size``.
        n_warmup: number of untimed warm-up batches.
        n_measure: number of timed batches.

    Returns:
        (timings_list, batch_sizes, node_counts, edge_counts), one entry per
        timed batch; each timings dict maps stage name → timer result.
    """
    gllm.graph_encoder.eval()
    timings_list, batch_sizes, node_counts, edge_counts = [], [], [], []
    stop_at = n_warmup + n_measure  # exclusive loader-index bound

    for step, batch in enumerate(loader):
        if step < n_warmup:
            # Warm-up: identical work to a timed iteration, results discarded.
            with torch.no_grad():
                feats = lpg_store.gather(batch.lpg_global_node_idx).to(device)
                edges = batch.lpg_edge_index.to(device)
                graph_ids = batch.lpg_batch.to(device)
                warm_emb = gllm.graph_encoder(feats, edges, graph_ids)
                warm_tokens = gllm.project(warm_emb)
                ids, mask = tokenize_questions(batch.questions)
                _ = gllm.forward_prefill(warm_tokens, ids, mask)
            continue
        if step >= stop_at:
            break

        stage_times = {}

        with CPUTimer('SUBGRAPH_LOAD') as timer:
            global_idx = batch.lpg_global_node_idx
            edge_index = batch.lpg_edge_index
            batch_vec = batch.lpg_batch
        stage_times['SUBGRAPH_LOAD'] = timer.result()

        with CPUTimer('CPU_GATHER') as timer:
            x_cpu = lpg_store.gather(global_idx)
        stage_times['CPU_GATHER'] = timer.result()

        with CUDATimer('H2D_COPY') as timer:
            x_gpu = x_cpu.to(device, non_blocking=False)
            ei_gpu = edge_index.to(device, non_blocking=False)
            b_gpu = batch_vec.to(device, non_blocking=False)
        stage_times['H2D_COPY'] = timer.result()
        stage_times['H2D_COPY'].bytes_transferred = sum(
            tensor.nelement() * tensor.element_size()
            for tensor in (x_cpu, edge_index, batch_vec)
        )

        with torch.no_grad(), CUDATimer('ENCODER_FWD') as timer:
            graph_emb = gllm.graph_encoder(x_gpu, ei_gpu, b_gpu)
        stage_times['ENCODER_FWD'] = timer.result()

        with torch.no_grad(), CUDATimer('PROJECTION') as timer:
            graph_tokens = gllm.project(graph_emb)
        stage_times['PROJECTION'] = timer.result()

        with torch.no_grad(), CUDATimer('LLM_PREFILL') as timer:
            input_ids, attn_mask = tokenize_questions(batch.questions)
            _ = gllm.forward_prefill(graph_tokens, input_ids, attn_mask)
        stage_times['LLM_PREFILL'] = timer.result()

        timings_list.append(stage_times)
        batch_sizes.append(batch.batch_size)
        node_counts.append(x_cpu.shape[0])
        edge_counts.append(edge_index.shape[1])

    return timings_list, batch_sizes, node_counts, edge_counts


# ── RDF (KGE) E2E Profiling ──

def profile_rdf_e2e(gllm, rdf_store, output_proj_gpu, loader, n_warmup, n_measure):
    """Profile the KGE → projection → LLM-prefill pipeline with CPU-resident embeddings.

    Per triple, the head-entity and relation embeddings are looked up on CPU
    (``rdf_store.lookup``), copied to ``device``, summed (h + r), mean-pooled
    per graph with ``scatter`` and passed through ``output_proj_gpu``.

    Batches with no RDF triples are skipped and do NOT consume a warm-up or
    measurement slot, so ``n_measure`` non-empty batches are timed whenever the
    loader provides that many.

    Timed stages: SUBGRAPH_LOAD (CPU index bookkeeping, including the
    per-triple graph ids), CPU_LOOKUP, H2D_COPY (host→device copies only; the
    byte count is taken from the tensors actually copied), ENCODER_FWD,
    PROJECTION and LLM_PREFILL (prompt tokenization included).

    Args:
        gllm: GraphLLM wrapper exposing ``project`` and ``forward_prefill``.
        rdf_store: store whose ``lookup(head_global_idx, edge_type)`` returns
            (head_emb, rel_emb) on CPU.
        output_proj_gpu: callable mapping pooled triple embeddings to the graph
            embedding fed to the projector.
        loader: yields batches with ``rdf_edge_index``, ``rdf_global_node_idx``,
            ``rdf_edge_type``, ``rdf_batch``, ``questions`` and ``batch_size``.
        n_warmup: number of untimed non-empty warm-up batches.
        n_measure: number of timed non-empty batches.

    Returns:
        (timings_list, batch_sizes, node_counts, edge_counts), one entry per
        timed batch; node_counts is the subgraph node count and edge_counts the
        triple count.
    """
    # Mirror the LPG profiler, which switches its encoder to eval mode; matters
    # if the projection contains dropout or normalization layers.
    if isinstance(output_proj_gpu, nn.Module):
        output_proj_gpu.eval()

    timings_list, batch_sizes, node_counts, edge_counts = [], [], [], []
    n_used = 0  # non-empty batches consumed so far (warm-up + timed)

    for batch in loader:
        ei = batch.rdf_edge_index
        if ei.shape[1] == 0:
            continue  # nothing to encode; does not count toward the budget
        if n_used >= n_warmup + n_measure:
            break
        is_warmup = n_used < n_warmup
        n_used += 1

        if is_warmup:
            with torch.no_grad():
                hg = batch.rdf_global_node_idx[ei[0]]
                h, r = rdf_store.lookup(hg, batch.rdf_edge_type)
                tg = batch.rdf_batch[ei[0]].to(device)
                g = scatter(h.to(device) + r.to(device), tg, dim=0,
                            dim_size=batch.batch_size, reduce='mean')
                g_emb = output_proj_gpu(g)
                g_tok = gllm.project(g_emb)
                ids, mask = tokenize_questions(batch.questions)
                _ = gllm.forward_prefill(g_tok, ids, mask)
            continue

        timings = {}

        with CPUTimer('SUBGRAPH_LOAD') as t:
            head_local = ei[0]
            global_idx = batch.rdf_global_node_idx
            head_global = global_idx[head_local]
            edge_type = batch.rdf_edge_type
            # Graph id of every triple: CPU indexing, so it is accounted here
            # rather than inflating H2D_COPY.
            triple_graph_cpu = batch.rdf_batch[head_local]
        timings['SUBGRAPH_LOAD'] = t.result()

        with CPUTimer('CPU_LOOKUP') as t:
            h_emb_cpu, r_emb_cpu = rdf_store.lookup(head_global, edge_type)
        timings['CPU_LOOKUP'] = t.result()

        with CUDATimer('H2D_COPY') as t:
            h_emb_gpu = h_emb_cpu.to(device, non_blocking=False)
            r_emb_gpu = r_emb_cpu.to(device, non_blocking=False)
            triple_graph = triple_graph_cpu.to(device, non_blocking=False)
        timings['H2D_COPY'] = t.result()
        timings['H2D_COPY'].bytes_transferred = sum(
            tensor.nelement() * tensor.element_size()
            for tensor in (h_emb_cpu, r_emb_cpu, triple_graph_cpu)
        )

        with torch.no_grad():
            with CUDATimer('ENCODER_FWD') as t:
                triple_emb = h_emb_gpu + r_emb_gpu
                graph_emb_raw = scatter(triple_emb, triple_graph, dim=0,
                                        dim_size=batch.batch_size, reduce='mean')
                graph_emb = output_proj_gpu(graph_emb_raw)
        timings['ENCODER_FWD'] = t.result()

        with torch.no_grad():
            with CUDATimer('PROJECTION') as t:
                graph_tokens = gllm.project(graph_emb)
        timings['PROJECTION'] = t.result()

        with torch.no_grad():
            with CUDATimer('LLM_PREFILL') as t:
                input_ids, attn_mask = tokenize_questions(batch.questions)
                _ = gllm.forward_prefill(graph_tokens, input_ids, attn_mask)
        timings['LLM_PREFILL'] = t.result()

        timings_list.append(timings)
        batch_sizes.append(batch.batch_size)
        node_counts.append(global_idx.shape[0])
        edge_counts.append(ei.shape[1])

    return timings_list, batch_sizes, node_counts, edge_counts


# ── Run All Profiling ──
# all_profile_results[key] holds the aggregated stage table ('df') plus the raw
# per-batch timings ('raw'), batch sizes ('bs'), node counts ('nc') and edge /
# triple counts ('ec') that the analysis cells read.
all_profile_results = {}

print('Profiling GAT (LPG) E2E pipeline...')
tl, bs, nc, ec = profile_lpg_e2e(
    graph_llms['gat'], lpg_cpu_store, profile_loader, WARMUP_BATCHES, PROFILE_BATCHES)
gat_df = aggregate_timings(tl, bs, nc, ec)
all_profile_results['gat'] = dict(df=gat_df, raw=tl, bs=bs, nc=nc, ec=ec)
print_profiling_table(gat_df, 'GAT (LPG)', len(tl))
torch.cuda.empty_cache()

for mt in ('transe', 'distmult'):
    run_label = f'{mt.upper()} (RDF)'
    print(f'\nProfiling {run_label} E2E pipeline...')
    tl, bs, nc, ec = profile_rdf_e2e(
        graph_llms[mt], rdf_cpu_stores[mt], kge_models[mt].output_proj,
        profile_loader, WARMUP_BATCHES, PROFILE_BATCHES)
    df = aggregate_timings(tl, bs, nc, ec)
    all_profile_results[mt] = dict(df=df, raw=tl, bs=bs, nc=nc, ec=ec)
    print_profiling_table(df, run_label, len(tl))
    torch.cuda.empty_cache()

In [None]:
# Cell 9: Combined Comparison Table + Bottleneck Analysis
#
# 3-model side-by-side profiling summary with bottleneck identification
# and H2D bandwidth analysis.

MODEL_LABELS = {'gat': 'GAT (LPG)', 'transe': 'TransE (RDF)', 'distmult': 'DistMult (RDF)'}
STAGE_ORDER = ['SUBGRAPH_LOAD', 'CPU_GATHER', 'CPU_LOOKUP', 'H2D_COPY',
               'ENCODER_FWD', 'PROJECTION', 'LLM_PREFILL']


def _stage_stat(df, stage, column):
    """Value of ``column`` in the first row whose 'stage' equals ``stage``; 0.0 if none."""
    matches = df[df['stage'] == stage]
    if len(matches) == 0:
        return 0.0
    return matches[column].values[0]


def get_stage_ms(df, stage):
    """'cpu_mean' of ``stage`` from an aggregated timing table, or 0.0 if the stage is absent."""
    return _stage_stat(df, stage, 'cpu_mean')


def get_stage_p95(df, stage):
    """'cpu_p95' of ``stage`` from an aggregated timing table, or 0.0 if the stage is absent."""
    return _stage_stat(df, stage, 'cpu_p95')


# ── Combined Table ──

# Stage names shared by all models. The gather/lookup step is recorded as
# CPU_GATHER by the GAT (LPG) profiler and CPU_LOOKUP by the KGE (RDF)
# profilers; both are reported under 'CPU_GATHER/LKP'.
UNIFIED_STAGES = ['SUBGRAPH_LOAD', 'CPU_GATHER/LKP', 'H2D_COPY',
                  'ENCODER_FWD', 'PROJECTION', 'LLM_PREFILL']


def get_unified_stage(df, unified_name, model_key):
    """Return (mean, p95) for a unified stage name of model ``model_key``.

    'CPU_GATHER/LKP' resolves to CPU_GATHER for 'gat' and CPU_LOOKUP for any
    other model; every other name is looked up unchanged. Missing stages yield 0.0.
    """
    stage = unified_name
    if unified_name == 'CPU_GATHER/LKP':
        stage = 'CPU_GATHER' if model_key == 'gat' else 'CPU_LOOKUP'
    return get_stage_ms(df, stage), get_stage_p95(df, stage)


# Build combined table.
# Layout: an 18-char stage column, then one 25-char group per model
# (mean ms, p95 ms, share of that model's summed stage means). Model labels
# sit on the line above their metric headers, and every rule spans the full
# row width.
TABLE_MODELS = ['gat', 'transe', 'distmult']
STAGE_COL_W = 18
MODEL_GROUP_W = 25  # == len(f"  {0:6.2f}  {0:6.2f}  {0:4.1f}%  ")
TABLE_W = STAGE_COL_W + MODEL_GROUP_W * len(TABLE_MODELS)

print(f"\n{'=' * TABLE_W}")
print(f"G-Retrieval E2E Bottleneck Analysis — {PROFILE_BATCHES} batches, BS={PROFILE_BATCH_SIZE}")
print('=' * TABLE_W)

label_line = ' ' * STAGE_COL_W
metric_line = f"{'Stage':<{STAGE_COL_W}}"
for mk in TABLE_MODELS:
    label_line += f"  {MODEL_LABELS[mk]:^{MODEL_GROUP_W - 2}}"
    metric_line += f"  {'mean':>6}  {'p95':>6}  {'%':>5}  "
print(label_line)
print(metric_line)
print('-' * TABLE_W)

# Per-model totals (sum of stage means) used as the % denominator
model_totals = {
    mk: sum(get_unified_stage(all_profile_results[mk]['df'], s, mk)[0] for s in UNIFIED_STAGES)
    for mk in TABLE_MODELS
}

# One row per stage; remember each model's largest-share stage as its bottleneck
bottleneck = {}
for stage in UNIFIED_STAGES:
    row_str = f"{stage:<{STAGE_COL_W}}"
    for mk in TABLE_MODELS:
        mean_ms, p95_ms = get_unified_stage(all_profile_results[mk]['df'], stage, mk)
        pct = mean_ms / model_totals[mk] * 100 if model_totals[mk] > 0 else 0
        row_str += f"  {mean_ms:6.2f}  {p95_ms:6.2f}  {pct:4.1f}%  "
        if mk not in bottleneck or pct > bottleneck[mk][1]:
            bottleneck[mk] = (stage, pct)
    print(row_str)

print('-' * TABLE_W)
total_str = f"{'TOTAL':<{STAGE_COL_W}}"
for mk in TABLE_MODELS:
    total_str += f"  {model_totals[mk]:6.2f}  {'':>6}  {'':>5}  "
print(total_str)

# ── Bottleneck Identification ──
print(f"\n{'=' * TABLE_W}")
print("Bottleneck Identification:")
for mk in TABLE_MODELS:
    stage, pct = bottleneck[mk]
    print(f"  {MODEL_LABELS[mk]:20s}: {stage} ({pct:.1f}%)")

# ── Encoder vs LLM Time Ratio ──
# Encoder side = graph encoder forward + projection into the LLM input space.
print("\nEncoder vs LLM Time Ratio:")
for mk in TABLE_MODELS:
    df = all_profile_results[mk]['df']
    enc_ms = get_unified_stage(df, 'ENCODER_FWD', mk)[0]
    proj_ms = get_unified_stage(df, 'PROJECTION', mk)[0]
    llm_ms = get_unified_stage(df, 'LLM_PREFILL', mk)[0]
    encoder_total = enc_ms + proj_ms
    ratio = encoder_total / llm_ms if llm_ms > 0 else float('inf')
    print(f"  {MODEL_LABELS[mk]:20s}: Encoder+Proj={encoder_total:.2f}ms, "
          f"LLM={llm_ms:.2f}ms, Ratio={ratio:.3f}x")

# ── H2D Bandwidth Analysis ──
print(f"\nH2D Transfer Bandwidth:")
# Approximate theoretical host→device PCIe bandwidth (GB/s, one direction),
# matched by substring against the lower-cased device name. The first match
# wins, so more specific names come first ('a100' before 'a10'). Unknown GPUs
# fall back to the T4 default.
PCIE_BW_GBPS = [
    ('h100', 64.0),  # PCIe Gen5 x16
    ('a100', 32.0),  # PCIe Gen4 x16
    ('a10', 32.0),   # A10 / A10G, PCIe Gen4 x16
    ('l4', 32.0),    # L4 (the substring also matches L40 / L40S), PCIe Gen4 x16
    ('v100', 16.0),  # PCIe Gen3 x16
]
DEFAULT_PCIE_BW_GBPS = 12.0  # T4 default

if HAS_CUDA:
    gpu_name = torch.cuda.get_device_name(0).lower()
    theoretical_bw = next(
        (bw for key, bw in PCIE_BW_GBPS if key in gpu_name), DEFAULT_PCIE_BW_GBPS)
    print(f"  Theoretical PCIe BW: ~{theoretical_bw:.0f} GB/s ({gpu_name})")
else:
    theoretical_bw = None
    print("  (No CUDA device — bandwidth measurement not applicable)")

for mk in ['gat', 'transe', 'distmult']:
    df = all_profile_results[mk]['df']
    h2d_row = df[df['stage'] == 'H2D_COPY']
    if len(h2d_row) == 0:
        continue
    bytes_mean = h2d_row['bytes_mean'].values[0]
    cuda_mean = h2d_row['cuda_mean'].values[0]
    if np.isnan(cuda_mean) or cuda_mean <= 0:
        continue  # no usable device-side timing for this model
    bw_gbps = (bytes_mean / 1e9) / (cuda_mean / 1e3)  # bytes / ms → GB/s
    line = (f"  {MODEL_LABELS[mk]:20s}: {bytes_mean/1e6:.2f} MB, "
            f"{cuda_mean:.2f} ms → {bw_gbps:.1f} GB/s")
    if theoretical_bw:
        line += f" ({bw_gbps / theoretical_bw * 100:.0f}% of theoretical)"
    print(line)

# ── Graph Size → Latency Correlation ──
# Size axis: subgraph node count for GAT (LPG), triple count for the KGE (RDF) models.
print(f"\nGraph Size → Total Latency Correlation (Spearman):")
from scipy.stats import spearmanr as _spearmanr
for mk in ['gat', 'transe', 'distmult']:
    res = all_profile_results[mk]
    is_lpg = mk == 'gat'
    sizes = np.array(res['nc'] if is_lpg else res['ec'])
    # Per-batch latency = sum of cpu_ms over every recorded stage of that batch.
    total_ms = np.array([sum(stage_timing.cpu_ms for stage_timing in t.values()) for t in res['raw']])
    if len(sizes) <= 3:
        continue  # too few batches for a meaningful rank correlation
    corr, pval = _spearmanr(sizes, total_ms)
    size_label = "nodes" if is_lpg else "triples"
    print(f"  {MODEL_LABELS[mk]:20s}: rho={corr:.3f}, p={pval:.3e} ({size_label})")


In [None]:
# Cell 10: Profiling Visualization
#
# 4 charts: (1) Stacked bar — stage-wise breakdown
#           (2) Scatter — graph size vs total latency
#           (3) Pie charts — time distribution per model
#           (4) Encoder vs LLM ratio bar chart

fig = plt.figure(figsize=(18, 14))

model_keys = ['gat', 'transe', 'distmult']

# Colors for each stage (shared by the stacked bar and the pie charts)
stage_colors = {
    'SUBGRAPH_LOAD': '#2196F3',
    'CPU_GATHER/LKP': '#FF9800',
    'H2D_COPY': '#F44336',
    'ENCODER_FWD': '#4CAF50',
    'PROJECTION': '#9C27B0',
    'LLM_PREFILL': '#795548',
}

# Mean time (ms) per model and unified stage, looked up once and reused by
# charts 1, 3 and 4.
stage_means = {
    mk: {stage: get_unified_stage(all_profile_results[mk]['df'], stage, mk)[0]
         for stage in UNIFIED_STAGES}
    for mk in model_keys
}

# ── Chart 1: Stacked Bar — Stage-wise Time Breakdown ──
ax1 = fig.add_subplot(2, 2, 1)

x_pos = np.arange(len(model_keys))
bar_width = 0.5

bottom = np.zeros(len(model_keys))
for stage in UNIFIED_STAGES:
    vals = np.array([stage_means[mk][stage] for mk in model_keys])
    ax1.bar(x_pos, vals, bar_width, bottom=bottom,
            label=stage, color=stage_colors[stage], edgecolor='white', linewidth=0.5)
    bottom += vals

ax1.set_xticks(x_pos)
ax1.set_xticklabels([MODEL_LABELS[mk] for mk in model_keys], fontsize=9)
ax1.set_ylabel('Time (ms)')
ax1.set_title('E2E Latency Breakdown by Stage')
ax1.legend(fontsize=7, loc='upper left')

# Add total labels on top
for i, mk in enumerate(model_keys):
    ax1.text(i, bottom[i] + 0.5, f'{bottom[i]:.1f}ms', ha='center', fontsize=8, fontweight='bold')


# ── Chart 2: Scatter — Graph Size vs Total Latency ──
ax2 = fig.add_subplot(2, 2, 2)

scatter_colors = {'gat': '#4CAF50', 'transe': '#2196F3', 'distmult': '#FF9800'}
for mk in model_keys:
    res = all_profile_results[mk]
    # Size axis: node count for GAT (LPG), triple count for the KGE (RDF) models.
    sizes = np.array(res['nc'] if mk == 'gat' else res['ec'])
    # Per-batch latency = sum of the per-stage cpu_ms values of that batch.
    total_ms = np.array([sum(t[s].cpu_ms for s in t) for t in res['raw']])

    ax2.scatter(sizes, total_ms, alpha=0.5, s=20, color=scatter_colors[mk], label=MODEL_LABELS[mk])

    # Linear trend line. Skipped when every batch has the same graph size:
    # a degree-1 fit is then rank-deficient and the line is meaningless.
    if len(sizes) > 3 and sizes.max() > sizes.min():
        z = np.polyfit(sizes, total_ms, 1)
        p = np.poly1d(z)
        x_line = np.linspace(sizes.min(), sizes.max(), 50)
        ax2.plot(x_line, p(x_line), '--', color=scatter_colors[mk], alpha=0.7, linewidth=1.5)

ax2.set_xlabel('Graph Size (nodes for GAT, triples for KGE)')
ax2.set_ylabel('Total Latency (ms)')
ax2.set_title('Graph Size vs E2E Latency')
ax2.legend(fontsize=8)


# ── Chart 3: Pie Charts — Time Distribution per Model ──
# Stages with zero mean time are left out of the pies.
for idx, mk in enumerate(model_keys):
    ax = fig.add_subplot(2, 6, 7 + idx * 2, aspect='equal')

    sizes_pie = []
    labels_pie = []
    colors_pie = []
    for stage in UNIFIED_STAGES:
        mean_ms = stage_means[mk][stage]
        if mean_ms > 0:
            sizes_pie.append(mean_ms)
            labels_pie.append(stage.replace('CPU_GATHER/LKP', 'GATHER/LKP'))
            colors_pie.append(stage_colors[stage])

    wedges, texts, autotexts = ax.pie(
        sizes_pie, labels=None, colors=colors_pie,
        autopct=lambda p: f'{p:.0f}%' if p > 5 else '',
        pctdistance=0.75, startangle=90, textprops={'fontsize': 6}
    )
    ax.set_title(MODEL_LABELS[mk], fontsize=9, fontweight='bold')


# ── Chart 4: Encoder vs LLM Ratio ──
ax4 = fig.add_subplot(2, 6, 12)

enc_times = [stage_means[mk]['ENCODER_FWD'] + stage_means[mk]['PROJECTION'] for mk in model_keys]
llm_times = [stage_means[mk]['LLM_PREFILL'] for mk in model_keys]

x_bar = np.arange(len(model_keys))
bar_w = 0.35
ax4.barh(x_bar - bar_w/2, enc_times, bar_w, label='Encoder+Proj', color='#4CAF50')
ax4.barh(x_bar + bar_w/2, llm_times, bar_w, label='LLM Prefill', color='#795548')
ax4.set_yticks(x_bar)
ax4.set_yticklabels([mk.upper() for mk in model_keys], fontsize=8)
ax4.set_xlabel('Time (ms)')
ax4.set_title('Encoder vs LLM', fontsize=9)
ax4.legend(fontsize=7)

plt.suptitle('G-Retrieval E2E Pipeline Profiling', fontsize=14, fontweight='bold', y=1.01)
plt.tight_layout()
plt.show()

print('\nTrack B profiling complete. Proceeding to Track A (retrieval evaluation)...')


In [None]:
# Cell 6: Graph Embedding Extraction

@torch.no_grad()
def extract_embeddings(gat_model, kge_models, loader):
    """Extract L2-normalized per-question graph embeddings from every model.

    Args:
        gat_model: LPG encoder called as ``gat_model(lpg_x, lpg_edge_index, lpg_batch)``.
        kge_models: mapping name → RDF encoder called as ``model(batch)``. Every
            name becomes a key of the result (this notebook uses 'transe' and
            'distmult').
        loader: iterable of batches that support ``.to(device)`` and expose
            ``questions``, ``answers``, ``question_ids`` and ``categories``.

    Returns:
        dict with:
            'gat': [N, D] LPG-GAT embeddings (D = 384 for the encoders used here)
            one [N, D] entry per key of ``kge_models`` (e.g. 'transe', 'distmult')
            'questions', 'answers', 'question_ids', 'categories': length-N lists
        Rows follow loader order.

    Raises:
        ValueError: if a ``kge_models`` key collides with a reserved result key,
            or if ``loader`` yields no batches.
    """
    reserved = {'gat', 'questions', 'answers', 'question_ids', 'categories'}
    clashes = reserved.intersection(kge_models)
    if clashes:
        raise ValueError(f'kge_models keys collide with reserved result keys: {sorted(clashes)}')

    gat_model.eval()
    for model in kge_models.values():
        model.eval()

    # Per-model lists of CPU embedding chunks, keyed by result name.
    emb_chunks = {'gat': []}
    emb_chunks.update({name: [] for name in kge_models})
    all_questions, all_answers, all_qids, all_cats = [], [], [], []

    for batch in loader:
        batch_gpu = batch.to(device)

        # GAT embedding
        gat_emb = gat_model(batch_gpu.lpg_x, batch_gpu.lpg_edge_index, batch_gpu.lpg_batch)
        emb_chunks['gat'].append(gat_emb.cpu())

        # KGE embeddings, stored under their own model name
        for name, model in kge_models.items():
            emb_chunks[name].append(model(batch_gpu).cpu())

        all_questions.extend(batch.questions)
        all_answers.extend(batch.answers)
        all_qids.extend(batch.question_ids)
        all_cats.extend(batch.categories)

    if not emb_chunks['gat']:
        raise ValueError('extract_embeddings: loader yielded no batches')

    result = {
        name: F.normalize(torch.cat(chunks, dim=0), dim=-1)
        for name, chunks in emb_chunks.items()
    }
    result.update({
        'questions': all_questions,
        'answers': all_answers,
        'question_ids': all_qids,
        'categories': all_cats,
    })
    return result


print('Extracting embeddings...')
train_emb = extract_embeddings(gat_model, kge_models, train_loader)
val_emb   = extract_embeddings(gat_model, kge_models, val_loader)
test_emb  = extract_embeddings(gat_model, kge_models, test_loader)

print(f'  Train: {train_emb["gat"].shape[0]} samples')
print(f'  Val:   {val_emb["gat"].shape[0]} samples')
print(f'  Test:  {test_emb["gat"].shape[0]} samples')
print(f'  Embedding dim: {train_emb["gat"].shape[1]}')

In [None]:
# Cell 7: Question-Answer Embedding Baseline

from sentence_transformers import SentenceTransformer

st_model = SentenceTransformer('all-MiniLM-L6-v2', device=str(device))


def encode_texts(texts, batch_size=64):
    """Embed ``texts`` with the shared sentence-transformer.

    Returns a CPU tensor of L2-normalized embeddings, one row per text
    ([N, 384] for all-MiniLM-L6-v2).
    """
    embeddings = st_model.encode(
        texts,
        batch_size=batch_size,
        show_progress_bar=False,
        convert_to_tensor=True,
        normalize_embeddings=True,
    )
    return embeddings.cpu()


def _mean_paired_cosine(a, b):
    """Mean of row-wise dot products (the mean cosine when rows are unit-norm)."""
    return (a * b).sum(dim=-1).mean().item()


print('Encoding answers with sentence-transformers...')
test_answer_emb = encode_texts(test_emb['answers'])
test_question_emb = encode_texts(test_emb['questions'])

# Quick check: cosine similarity between each graph embedding and its own answer embedding
print('\nMean cosine similarity (graph_emb · answer_emb):')
for name in ['gat', 'transe', 'distmult']:
    print(f'  {name:10s}: {_mean_paired_cosine(test_emb[name], test_answer_emb):.4f}')

# Baseline: question-only embedding
q_cos = _mean_paired_cosine(test_question_emb, test_answer_emb)
print(f'  {"question":10s}: {q_cos:.4f} (text-only baseline)')

In [None]:
# Cell 8: Retrieval Evaluation

def retrieval_eval(query_emb, corpus_emb, ks=(1, 5, 10)):
    """Score paired retrieval where query i's correct item is corpus i.

    Similarity is the dot product, which equals cosine similarity for
    L2-normalized inputs. The rank of the correct item is the number of corpus
    items scoring >= it, so the rank is 1-based and ties count against the query.

    Args:
        query_emb: [N, D] query embeddings (graph or question).
        corpus_emb: [N, D] corpus embeddings (answers), row-aligned with the queries.
        ks: cut-offs for Recall@K.

    Returns:
        {'mrr': float, 'recall@k': float for each k in ks}, in that order.
    """
    scores = torch.matmul(query_emb, corpus_emb.t())
    gold = torch.diagonal(scores).unsqueeze(-1)
    rank = (scores >= gold).sum(dim=1).float()

    metrics = {'mrr': (1.0 / rank).mean().item()}
    metrics.update({f'recall@{k}': (rank <= k).float().mean().item() for k in ks})
    return metrics


# Evaluate all models
print(f'{"Model":<12} {"MRR":>8} {"R@1":>8} {"R@5":>8} {"R@10":>8}')
print('-' * 48)


def _format_retrieval_row(label, scores):
    """Fixed-width table row: label, MRR, R@1, R@5, R@10."""
    return (f'{label:<12} {scores["mrr"]:8.4f} {scores["recall@1"]:8.4f} '
            f'{scores["recall@5"]:8.4f} {scores["recall@10"]:8.4f}')


retrieval_results = {}
for name in ['gat', 'transe', 'distmult']:
    r = retrieval_eval(test_emb[name], test_answer_emb)
    retrieval_results[name] = r
    print(_format_retrieval_row(name, r))

# Baseline: question text → answer text retrieval
r_baseline = retrieval_eval(test_question_emb, test_answer_emb)
retrieval_results['question'] = r_baseline
print(_format_retrieval_row('question', r_baseline))
print('  (question = text-only baseline, no graph)')

# Grouped bar chart: one group per model, one bar per metric
fig, ax = plt.subplots(figsize=(10, 5))
models = list(retrieval_results)
metrics = ['mrr', 'recall@1', 'recall@5', 'recall@10']
x = np.arange(len(models))
width = 0.2

for offset, metric in enumerate(metrics):
    heights = [retrieval_results[model][metric] for model in models]
    ax.bar(x + offset * width, heights, width, label=metric.upper())

ax.set_xticks(x + width * 1.5)
ax.set_xticklabels([model.upper() for model in models])
ax.set_ylabel('Score')
ax.set_title('Graph → Answer Retrieval Performance')
ax.legend()
ax.set_ylim(0, 1)
plt.tight_layout()
plt.show()

In [None]:
# Cell 9: Category-wise Analysis

def category_retrieval_eval(emb_dict, answer_emb, categories,
                            model_names=('gat', 'transe', 'distmult')):
    """Evaluate graph → answer retrieval separately inside each category.

    For every category with at least 2 samples, each model's embeddings for
    that category are ranked against that category's answers only (via
    ``retrieval_eval``). Because the corpus is the category itself, absolute
    scores depend on the category size ``n``: a uniformly random ranking has
    expected MRR = H_n / n (H_n = n-th harmonic number). That chance level is
    stored as ``random_mrr`` so rows of different sizes can be read side by
    side; model-vs-model comparisons within one row are unaffected.

    Args:
        emb_dict: mapping model name → [N, D] embeddings, row-aligned with
            ``answer_emb`` and ``categories``.
        answer_emb: [N, D] answer embeddings.
        categories: length-N sequence of category labels.
        model_names: models to evaluate; each must be a key of ``emb_dict``.

    Returns:
        DataFrame with one row per category: 'category', 'n',
        '<model>_<metric>' for every model and retrieval metric, and
        'random_mrr'. Categories with fewer than 2 samples are skipped.
    """
    cat_array = np.array(categories)
    rows = []

    for cat in sorted(set(categories)):
        indices = np.flatnonzero(cat_array == cat)
        n = len(indices)
        if n < 2:
            continue  # retrieval among a single candidate is trivial

        row = {'category': cat, 'n': int(n)}
        a_emb = answer_emb[indices]
        for model_name in model_names:
            r = retrieval_eval(emb_dict[model_name][indices], a_emb)
            for metric, val in r.items():
                row[f'{model_name}_{metric}'] = val
        # Expected MRR of a uniformly random ranking over n candidates: (1/n) * sum_{k=1..n} 1/k
        row['random_mrr'] = float(np.mean(1.0 / np.arange(1, n + 1)))
        rows.append(row)

    return pd.DataFrame(rows)


cat_df = category_retrieval_eval(test_emb, test_answer_emb, test_emb['categories'])
# Each category is scored against its own answers only, so MRR is inflated for
# small categories. Show the chance-level MRR next to the models when the
# evaluator provides it.
mrr_columns = ['category', 'n', 'gat_mrr', 'transe_mrr', 'distmult_mrr']
if 'random_mrr' in cat_df.columns:
    mrr_columns.append('random_mrr')
print('Category-wise MRR:')
print(cat_df[mrr_columns].to_string(index=False))

# Bar chart: category × model MRR
fig, ax = plt.subplots(figsize=(14, 5))
cats = cat_df['category'].values
x = np.arange(len(cats))
width = 0.25

for i, model in enumerate(['gat', 'transe', 'distmult']):
    vals = cat_df[f'{model}_mrr'].values
    ax.bar(x + i * width, vals, width, label=model.upper())

if 'random_mrr' in cat_df.columns:
    # Chance level for each category's corpus size, drawn across the bar group.
    ax.scatter(x + width, cat_df['random_mrr'].values, marker='_', s=600,
               color='black', zorder=3, label='RANDOM')

ax.set_xticks(x + width)
ax.set_xticklabels(cats, rotation=30, ha='right')
ax.set_ylabel('MRR')
ax.set_title('Category-wise Graph → Answer Retrieval MRR')
ax.legend()
plt.tight_layout()
plt.show()

# Text summary: best model per category (by MRR; first model wins ties)
print('\nBest model per category (MRR):')
for _, row in cat_df.iterrows():
    mrrs = {m: row[f'{m}_mrr'] for m in ['gat', 'transe', 'distmult']}
    best = max(mrrs, key=mrrs.get)
    print(f'  {row["category"]:25s} → {best.upper()} ({mrrs[best]:.4f})')

In [None]:
# Cell 10: Graph Size Effect

def compute_per_sample_rank(query_emb, corpus_emb):
    """Return the 1-based rank of each query's own corpus item as a float NumPy array.

    query_emb[i] is paired with corpus_emb[i]; the rank counts every corpus
    item whose dot-product score is >= the correct one (ties count against it).
    """
    scores = torch.matmul(query_emb, corpus_emb.t())
    correct = torch.diagonal(scores).unsqueeze(-1)
    return (scores >= correct).sum(dim=1).float().numpy()


# Collect graph sizes from the test set.
# Sizes are paired with test_emb / rr purely by position, so test_ds must be
# iterated in the same order as test_loader (no shuffling, no dropped samples).
# The length check below catches dropped samples; it cannot detect shuffling.
test_lpg_nodes = []
test_lpg_edges = []
test_rdf_triples = []

for i in range(len(test_ds)):
    d = test_ds[i]
    # int() accepts both a 1-element tensor and a plain Python int.
    test_lpg_nodes.append(int(d.lpg_num_nodes))
    test_lpg_edges.append(d.lpg_edge_index.shape[1])
    test_rdf_triples.append(d.rdf_edge_index.shape[1])

test_lpg_nodes = np.array(test_lpg_nodes)
test_lpg_edges = np.array(test_lpg_edges)
test_rdf_triples = np.array(test_rdf_triples)

assert len(test_lpg_nodes) == test_emb['gat'].shape[0], (
    f'test_ds has {len(test_lpg_nodes)} samples but test_emb has '
    f"{test_emb['gat'].shape[0]}; graph sizes would be misaligned with ranks")

# Compute per-sample reciprocal rank
rr = {}
for name in ['gat', 'transe', 'distmult']:
    ranks = compute_per_sample_rank(test_emb[name], test_answer_emb)
    rr[name] = 1.0 / ranks

# Scatter plots
fig, axes = plt.subplots(1, 3, figsize=(16, 4.5))

for ax, (name, size_arr, size_label) in zip(axes, [
    ('gat', test_lpg_nodes, 'LPG Nodes'),
    ('transe', test_rdf_triples, 'RDF Triples'),
    ('distmult', test_rdf_triples, 'RDF Triples'),
]):
    ax.scatter(size_arr, rr[name], alpha=0.3, s=10)
    ax.set_xlabel(size_label)
    ax.set_ylabel('Reciprocal Rank')
    ax.set_title(f'{name.upper()}')

    # Trend line via decile binning (duplicate edges merged; the max value is
    # clipped into the last bin)
    bins = np.percentile(size_arr, np.linspace(0, 100, 11))
    bins = np.unique(bins)
    if len(bins) >= 2:
        bin_idx = np.digitize(size_arr, bins) - 1
        bin_idx = np.clip(bin_idx, 0, len(bins) - 2)
        bin_centers = []
        bin_means = []
        for b in range(len(bins) - 1):
            mask = bin_idx == b
            if mask.sum() > 0:
                bin_centers.append((bins[b] + bins[b+1]) / 2)
                bin_means.append(rr[name][mask].mean())
        ax.plot(bin_centers, bin_means, 'r-o', linewidth=2, markersize=4, label='Binned mean')
        ax.legend()

plt.suptitle('Graph Size vs Retrieval Quality', y=1.02)
plt.tight_layout()
plt.show()

# Correlation
print('Spearman correlation (graph size vs reciprocal rank):')
from scipy.stats import spearmanr
for name, sizes in [('gat', test_lpg_nodes), ('transe', test_rdf_triples), ('distmult', test_rdf_triples)]:
    corr, pval = spearmanr(sizes, rr[name])
    print(f'  {name:10s}: rho={corr:.3f}, p={pval:.3e}')

In [None]:
# Cell 11: TransE vs DistMult Deep Dive

VIS_SEED = 42  # seeds the t-SNE subsample and t-SNE itself for reproducible figures

# 1. Loss curve comparison
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(kge_histories['transe']['train_loss'], label='TransE')
axes[0].plot(kge_histories['distmult']['train_loss'], label='DistMult')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss Comparison')
axes[0].legend()

# Val MRR over time. The x positions assume validation ran at epoch 1 and
# every 10th epoch (TODO confirm against the KGE training loop). Both series
# are cut to the shorter length so a schedule mismatch cannot make plot()
# fail on unequal x / y lengths.
eval_epochs_kge = [e for e in range(1, KGE_EPOCHS+1) if e % 10 == 0 or e == 1]
for mt in ['transe', 'distmult']:
    val_mrr = kge_histories[mt]['val_mrr']
    n = min(len(val_mrr), len(eval_epochs_kge))
    axes[1].plot(eval_epochs_kge[:n], val_mrr[:n], '-o', label=mt.upper())
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Val MRR')
axes[1].set_title('Validation MRR Over Training')
axes[1].legend()
plt.tight_layout()
plt.show()


# 2. Per-sample comparison: where does one beat the other?
transe_rr = rr['transe']
distmult_rr = rr['distmult']
diff = transe_rr - distmult_rr  # positive = TransE better

transe_wins = (diff > 0).sum()
distmult_wins = (diff < 0).sum()
ties = (diff == 0).sum()
pct_denom = max(len(diff), 1)  # avoid 0/0 (nan) percentages on an empty test set
print(f'Per-sample comparison (test set, N={len(diff)}):')
print(f'  TransE wins:  {transe_wins} ({100*transe_wins/pct_denom:.1f}%)')
print(f'  DistMult wins: {distmult_wins} ({100*distmult_wins/pct_denom:.1f}%)')
print(f'  Ties:         {ties} ({100*ties/pct_denom:.1f}%)')


# 3. Category-level TransE vs DistMult advantage
print('\nCategory-level advantage (MRR difference = TransE - DistMult):')
for _, row in cat_df.iterrows():
    delta = row['transe_mrr'] - row['distmult_mrr']
    if delta > 0:
        arrow = '→ TransE'
    elif delta < 0:
        arrow = '→ DistMult'
    else:
        arrow = '(tie)'
    print(f'  {row["category"]:25s}: {delta:+.4f} {arrow}')


# 4. t-SNE visualization of embedding spaces
from sklearn.manifold import TSNE

fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Seeded subsample of the test set (at most 300 points)
n_test_samples = test_emb['gat'].shape[0]
n_vis = min(300, n_test_samples)
vis_idx = np.random.default_rng(VIS_SEED).choice(n_test_samples, n_vis, replace=False)
vis_cats = np.array(test_emb['categories'])[vis_idx]
unique_cats = sorted(set(vis_cats))
# tab10 has 10 colors and maps every index >= 10 to its last color; switch to
# tab20 for larger category sets (colors only repeat beyond 20).
cat_cmap = plt.cm.tab20 if len(unique_cats) > 10 else plt.cm.tab10
cat_colors = {c: cat_cmap(i % cat_cmap.N) for i, c in enumerate(unique_cats)}

# sklearn's TSNE requires perplexity < n_samples; shrink it for small test sets.
tsne_perplexity = min(30, n_vis - 1)

for ax, name in zip(axes, ['gat', 'transe', 'distmult']):
    emb = test_emb[name][vis_idx].numpy()
    tsne = TSNE(n_components=2, perplexity=tsne_perplexity, random_state=VIS_SEED)
    proj = tsne.fit_transform(emb)

    for cat in unique_cats:
        mask = vis_cats == cat
        ax.scatter(proj[mask, 0], proj[mask, 1], c=[cat_colors[cat]],
                   s=15, alpha=0.6, label=cat[:15])
    ax.set_title(f'{name.upper()} Embedding Space')
    ax.set_xticks([])
    ax.set_yticks([])

axes[2].legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
plt.suptitle('t-SNE of Graph Embeddings (colored by category)', y=1.02)
plt.tight_layout()
plt.show()

In [None]:
# Cell 12: Summary & Conclusions

summary_rule = '=' * 60
print(summary_rule)
print('SUMMARY: G-Retrieval Style Comparison')
print(summary_rule)

# Overall retrieval results table: graph models only (the 'question'
# text-only baseline is left out)
summary_rows = []
for model_name, model_metrics in retrieval_results.items():
    if model_name == 'question':
        continue
    summary_rows.append({'model': model_name, **model_metrics})
results_df = pd.DataFrame(summary_rows)
print('\n--- Graph → Answer Retrieval (Test Set) ---')
print(results_df.to_string(index=False, float_format='%.4f'))

# Best model = highest test MRR
best_row_idx = results_df['mrr'].idxmax()
best_model = results_df.at[best_row_idx, 'model']
print(f'\nBest overall model: {best_model.upper()} (MRR={results_df["mrr"].max():.4f})')

# Category breakdown summary
print('\n--- Best Model per Category (MRR) ---')
category_records = []
for _, row in cat_df.iterrows():
    mrrs = {m: row[f'{m}_mrr'] for m in ['gat', 'transe', 'distmult']}
    best = max(mrrs, key=mrrs.get)
    category_records.append({'category': row['category'], 'best_model': best,
                             'mrr': mrrs[best], 'n': row['n']})
    print(f'  {row["category"]:25s} → {best.upper():10s} (MRR={mrrs[best]:.4f}, n={int(row["n"])})')

category_results = pd.DataFrame(category_records)

# Model strengths (qualitative notes, not derived from the numbers above)
print('\n--- Model Strengths & Weaknesses ---')
for note in (
    'GAT (LPG):      Pre-computed 384d node features + message passing.',
    '                 Best for: categories with rich LPG structure.',
    'TransE (RDF):    Translation h+r≈t. Asymmetric, handles directed relations.',
    '                 Best for: categories with directional relationships (OWNS, REPORTED).',
    'DistMult (RDF):  Bilinear h·r·t. Symmetric, simpler training dynamics.',
    '                 Best for: symmetric or co-occurrence patterns.',
):
    print(note)

# Verification assertions
assert len(results_df) == 3, f'Expected 3 models, got {len(results_df)}'
assert all(col in results_df.columns for col in ['model', 'mrr', 'recall@1', 'recall@5'])
assert len(category_results) >= 1, 'No category results'
print(f'\nVerification passed: {len(results_df)} models, {len(category_results)} categories.')