In [None]:
from datasets import load_dataset
from typing import List, Dict, Any
import os

# Set your Hugging Face token here
# Option 1: Set it as an environment variable
os.environ['HF_TOKEN'] = ''

# Option 2: Use huggingface_hub login (recommended)
# from huggingface_hub import login
# login(token='your_token_here')

def load_musique_from_hf() -> List[Dict[str, Any]]:
    ds = load_dataset("dgslibisey/MuSiQue")
    
    episodes = []
    for item in ds['train']:  # Adjust split as needed ('train', 'validation', 'test')
        chunks = []
        for p in item['paragraphs']:
            chunks.append({
                'text': f"{p['title']}\n{p['paragraph_text']}",
                'metadata': {
                    'title': p['title'],
                    'idx': p['idx'],
                    'is_supporting': p.get('is_supporting', False)
                }
            })
        episodes.append({
            'id': item['id'],
            'question': item['question'],
            'chunks': chunks
        })
    return episodes

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
datasets = {
        "musique": {
            "path": "Projects/QSGNN/reproduce/dataset/musique_all.json",
            "loader": load_musique_from_hf
        }
    }

In [3]:
processed_episodes =[]
for name, config in datasets.items():
        #data_path = config['path']
        #if not os.path.exists(data_path):
        #    print(f"Skipping {name}, file not found: {data_path}")
        #    continue
            
        #print(f"Processing {name} from {data_path}...")
        episodes = config['loader']()
        
        # Limit processing for large datasets if needed
        # processed_episodes = episodes[:100] 
        processed_episodes = episodes[:100]

Generating train split: 100%|██████████| 19938/19938 [00:00<00:00, 92898.72 examples/s]
Generating validation split: 100%|██████████| 2417/2417 [00:00<00:00, 106791.74 examples/s]


In [5]:
processed_episodes[:1]

[{'id': '2hop__482757_12019',
  'question': 'When was the institute that owned The Collegian founded?',
  'chunks': [{'text': 'Pakistan Super League\nPakistan Super League (Urdu: پاکستان سپر لیگ \u202c \u200e; PSL) is a Twenty20 cricket league, founded in Lahore on 9 September 2015 with five teams and now comprises six teams. Instead of operating as an association of independently owned teams, the league is a single entity in which each franchise is owned and controlled by investors.',
    'metadata': {'title': 'Pakistan Super League',
     'idx': 0,
     'is_supporting': False}},
   {'text': 'Serena Wilson\nSerena Wilson (August 8, 1933 – June 17, 2007), often known just as "Serena", was a well-known dancer, choreographer, and teacher who helped popularize belly dance in the United States. Serena\'s work also helped legitimize the dance form and helped it to be perceived as more than burlesque or stripping. Serena danced in clubs in her younger years, opened her own studio, hosted her

In [46]:
import torch
import numpy as np
from torch_geometric.data import HeteroData
import hashlib

# --- SIMULATED DATA (From MuSiQue Example) ---
# Question: "When was the institute that owned The Collegian founded?"
passage_content = (
    "The Collegian (Houston Baptist University)\n"
    "The Collegian is the bi-weekly official student publication of Houston Baptist University in Houston, Texas. "
    "It was founded in 1963 as a newsletter."
)

# 1. Sentences extracted from passage
sentences = [
    "The Collegian is the bi-weekly official student publication of Houston Baptist University in Houston, Texas.",
    "It was founded in 1963 as a newsletter."
]

# 2. Entities extracted from sentences
sentence_entities = [
    ["The Collegian", "Houston Baptist University", "Houston", "Texas"], # From Sentence 0
    ["1963", "newsletter"]                                             # From Sentence 1
]

# 3. Triples (OpenIE results) representing Entity-Entity relations
# Format: (Subject, Relation, Object)
triples = [
    ("The Collegian", "is publication of", "Houston Baptist University"),
    ("The Collegian", "founded in", "1963")
]

# --- CONSTRUCTION LOGIC ---

def compute_id(text, prefix=""):
    """Helper to create consistent IDs for nodes."""
    return prefix + hashlib.mdsafe(text.encode()).hexdigest()[:10]

def test_construction():
    data = HeteroData()
    
    # In the real code, these are real embeddings (e.g., from NV-Embed-v2)
    # Here we use dummy vectors of size 128
    EMB_DIM = 128
    
    # Unique lists for nodes
    all_entities = sorted(list(set([t[0] for t in triples] + [t[2] for t in triples] + [ent for sublist in sentence_entities for ent in sublist])))
    entity_to_idx = {ent: i for i, ent in enumerate(all_entities)}
    
    # 1. Define Nodes (Level 1: Entity, Level 2: Sentence, Level 3: Passage)
    data['entity'].x = torch.randn(len(all_entities), EMB_DIM)
    data['sentence'].x = torch.randn(len(sentences), EMB_DIM)
    data['passage'].x = torch.randn(1, EMB_DIM)
    
    print(f"Nodes Created: {len(all_entities)} Entities, {len(sentences)} Sentences, 1 Passage")

    # 2. Define Edges
    
    # A. Entity -> Entity (from Triples)
    e2e_edges = []
    entity_to_entity_attrs =[]
    for sub, rel, obj in triples: # no bidirectional edges where edges have attributes
        e2e_edges.append([entity_to_idx[sub], entity_to_idx[obj]])
        entity_to_entity_attrs.append(obj)
        #e2e_edges.append([entity_to_idx[obj], entity_to_idx[sub]]) 
    data['entity', 're', 'entity'].edge_index = torch.tensor(e2e_edges).t().contiguous()

    # B. Sentence -> Passage (Hierarchical)
    s2p_edges = [[i, 0] for i in range(len(sentences))]
    data['sentence', 'in', 'passage'].edge_index = torch.tensor(s2p_edges).t().contiguous()
    data['passage', 'hv', 'sentence'].edge_index = torch.tensor([[0, i] for i in range(len(sentences))]).t().contiguous()

    # C. Entity -> Sentence (Mapping entities to where they appear)
    e2s_edges = []
    for s_idx, ents in enumerate(sentence_entities):
        for ent in ents:
            e2s_edges.append([entity_to_idx[ent], s_idx])
    data['entity', 'in', 'passage'].edge_index = torch.tensor(e2s_edges).t().contiguous()
    data['passage', 'hv', 'entity'].edge_index = torch.tensor([[v, k] for k, v in e2s_edges]).t().contiguous()

    # D. Sentence -> Sentence (Intra-level context)
    # Fully connected within the passage
    s2s_edges = []
    for i in range(len(sentences)):
        for j in range(len(sentences)):
            if i != j: s2s_edges.append([i, j])
    data['sentence', 're', 'sentence'].edge_index = torch.tensor(s2s_edges).t().contiguous()

    print("\n--- Graph Edge Summary ---")
    for edge_type in data.edge_types:
        print(f"Edge {edge_type}: {data[edge_type].edge_index.shape[1]} connections")

    return data

if __name__ == "__main__":
    graph = test_construction()
    print("\nFinal HeteroData Object:")
    print(graph)

Nodes Created: 6 Entities, 2 Sentences, 1 Passage

--- Graph Edge Summary ---
Edge ('entity', 're', 'entity'): 2 connections
Edge ('sentence', 'in', 'passage'): 2 connections
Edge ('passage', 'hv', 'sentence'): 2 connections
Edge ('entity', 'in', 'passage'): 6 connections
Edge ('passage', 'hv', 'entity'): 6 connections
Edge ('sentence', 're', 'sentence'): 2 connections

Final HeteroData Object:
HeteroData(
  entity={ x=[6, 128] },
  sentence={ x=[2, 128] },
  passage={ x=[1, 128] },
  (entity, re, entity)={ edge_index=[2, 2] },
  (sentence, in, passage)={ edge_index=[2, 2] },
  (passage, hv, sentence)={ edge_index=[2, 2] },
  (entity, in, passage)={ edge_index=[2, 6] },
  (passage, hv, entity)={ edge_index=[2, 6] },
  (sentence, re, sentence)={ edge_index=[2, 2] }
)


In [9]:
all_entities = sorted(list(set([t[0] for t in triples] + [t[2] for t in triples] + [ent for sublist in sentence_entities for ent in sublist])))
entity_to_idx = {ent: i for i, ent in enumerate(all_entities)}
    

In [10]:
entity_to_idx

{'1963': 0,
 'Houston': 1,
 'Houston Baptist University': 2,
 'Texas': 3,
 'The Collegian': 4,
 'newsletter': 5}

In [11]:
e2e_edges_1 = []
for sub, rel, obj in triples:
    e2e_edges_1.append([entity_to_idx[sub], entity_to_idx[obj]])
    e2e_edges_1.append([entity_to_idx[obj], entity_to_idx[sub]]) # Bidirectional

#torch.tensor(e2e_edges_1).t().contiguous()

In [12]:
e2e_edges_1

[[4, 2], [2, 4], [4, 0], [0, 4]]

In [14]:
torch.tensor(e2e_edges_1).t().contiguous()

tensor([[4, 2, 4, 0],
        [2, 4, 0, 4]])

In [18]:
torch.tensor(e2e_edges_1).shape

torch.Size([4, 2])

In [None]:
torch.tensor(e2e_edges_1).t().

torch.Size([2, 4])

In [22]:
class MockEmbeddingModel:
    def __init__(self, dim=128):
        self.dim = dim

    #check what happens here
    def batch_encode(self, texts: List[str]) -> List[np.ndarray]:
        print(f"  [AI Model] Encoding {len(texts)} new strings...")
        # Simulating vector generation
        return [np.random.rand(self.dim).astype(np.float32) for _ in texts]

In [23]:
# --- 2. STRIPPED DOWN EMBEDDING STORE ---
# Simplified version of src/qsgnn_rag/embedding_store.py
class SimpleEmbeddingStore:
    def __init__(self, model, namespace: str):
        self.model = model
        self.namespace = namespace
        self.cache = {} # hash_id -> {'content': str, 'embedding': np.array}
        self.hash_id_to_idx = {} # hash_id -> integer index (for GNN)

    def _compute_hash(self, text: str) -> str:
        """Creates a unique ID for every string."""
        return self.namespace + "-" + hashlib.md5(text.encode()).hexdigest()[:12]

    def insert(self, texts: List[str]):
        """Inserts strings, only encoding those that aren't already cached."""
        unique_texts = list(set(texts))
        
        # Step 1: Identify which strings are actually new
        to_encode = []
        for text in unique_texts:
            h_id = self._compute_hash(text)
            if h_id not in self.cache:
                to_encode.append(text)
        
        # Step 2: Only call the expensive AI model for NEW strings
        if to_encode:
            new_vectors = self.model.batch_encode(to_encode)
            for text, vec in zip(to_encode, new_vectors):
                h_id = self._compute_hash(text)
                self.cache[h_id] = {'content': text, 'embedding': vec}

        # Step 3: Refresh the index mapping (needed for Graph construction)
        all_ids = sorted(list(self.cache.keys()))
        self.hash_id_to_idx = {h_id: i for i, h_id in enumerate(all_ids)}

    def get_all_vectors(self) -> np.ndarray:
        """Returns a matrix of all embeddings in indexed order."""
        sorted_ids = sorted(list(self.cache.keys()))
        return np.array([self.cache[h_id]['embedding'] for h_id in sorted_ids])


In [24]:
# --- 3. DEMONSTRATION ---
if __name__ == "__main__":
    model = MockEmbeddingModel(dim=8) # Small dim for visibility
    store = SimpleEmbeddingStore(model, namespace="entity")

    # Day 1: Indexing some entities
    print("Step 1: First indexing...")
    entities_v1 = ["Houston", "Texas", "University"]
    store.insert(entities_v1)

    # Day 2: Indexing more entities (including some duplicates)
    print("\nStep 2: Second indexing (with duplicates)...")
    entities_v2 = ["Houston", "USA", "NASA"] # Houston is a duplicate
    store.insert(entities_v2)

    # Final Result
    vectors = store.get_all_vectors()
    print(f"\nFinal Store Size: {len(store.cache)} unique entities")
    print(f"Matrix Shape for GNN: {vectors.shape}")
    
    # How it links to the Graph:
    sample_entity = "Houston"
    h_id = store._compute_hash(sample_entity)
    idx = store.hash_id_to_idx[h_id]
    print(f"The entity '{sample_entity}' is mapped to GNN Node Index: {idx}")

Step 1: First indexing...
  [AI Model] Encoding 3 new strings...

Step 2: Second indexing (with duplicates)...
  [AI Model] Encoding 2 new strings...

Final Store Size: 5 unique entities
Matrix Shape for GNN: (5, 8)
The entity 'Houston' is mapped to GNN Node Index: 3


In [48]:

import torch
import numpy as np
from torch_geometric.data import HeteroData
import hashlib

# --- SIMULATED DATA ---
triples = [
    ("The Collegian", "is publication of", "Houston Baptist University"),
    ("The Collegian", "founded in", "1963")
]
passages=["The Collegian is the bi-weekly official student publication of Houston Baptist University in Houston, Texas.",
    "It was founded in 1963 as a newsletter."]

sentences = [
    "The Collegian is the bi-weekly official student publication of Houston Baptist University in Houston, Texas.",
    "It was founded in 1963 as a newsletter."
]

sentence_entities = [
    ["The Collegian", "Houston Baptist University", "Houston", "Texas"],
    ["1963", "newsletter"]
]
structural_rel_emb = {
    "in": torch.randn(1, 128),      # "Contains" relationship
    "hv": torch.randn(1, 128),      # "Held by" relationship
    "seq": torch.randn(1, 128)      # "Sequential" relationship
}
def test_construction():
    data = HeteroData()
    EMB_DIM = 128
    
    # Unique entities
    all_entities = sorted(list(set([t[0] for t in triples] + [t[2] for t in triples] + [ent for sublist in sentence_entities for ent in sublist])))
    entity_to_idx = {ent: i for i, ent in enumerate(all_entities)}
    
    # 1. Nodes
    data['entity'].x = torch.randn(len(all_entities), EMB_DIM)
    data['sentence'].x = torch.randn(len(sentences), EMB_DIM)
    data['passage'].x = torch.randn(len(passages), EMB_DIM)

    # 2. Edges with Attributes
    # A. Entity -> Entity (from Triples)
    e2e_edges = []
    e2e_relation_texts = []
    for sub, rel, obj in triples:
        s_idx, o_idx = entity_to_idx[sub], entity_to_idx[obj]
        
        # A. FORWARD RELATION
        e2e_edges.append([s_idx, o_idx])
        e2e_relation_texts.append(rel)
        
        # B. INVERSE RELATION (The "Reasoning" Path)
        e2e_edges.append([o_idx, s_idx])
        e2e_relation_texts.append(f"inverse {rel}") # or use a mapping like 'is_part_of' -> 'contains'
        

        
    data['entity', 're', 'entity'].edge_index = torch.tensor(e2e_edges).t().contiguous()
    # In real code, your embedding model would encode these strings
    # Forward and Inverse edges now have distinct embeddings!
    data['entity', 're', 'entity'].edge_attr = torch.randn(len(e2e_relation_texts), EMB_DIM)

    # B. Sentence -> Passage (Hierarchical)
    s2p_edges = [[i, i] for i in range(len(sentences))]
    data['sentence', 'in', 'passage'].edge_index = torch.tensor([[i, 0] for i in range(len(sentences))]).t().contiguous()
    data['passage', 'hv', 'sentence'].edge_index = torch.tensor([[0, i] for i in range(len(sentences))]).t().contiguous()

    data['sentence', 'in', 'passage'].edge_attr = structural_rel_emb["in"].expand(len([[i, 0] for i in range(len(sentences))]), -1)
    data['passage', 'hv', 'sentence'].edge_attr = structural_rel_emb["hv"].expand(len([[0, i] for i in range(len(sentences))]), -1)

    # C. Entity -> Sentence
    e2s_edges = []
    for s_idx, ents in enumerate(sentence_entities):
        for ent in ents:
            e2s_edges.append([entity_to_idx[ent], s_idx])
    
    e2p_edges = []
    for ents in sentence_entities:
        for ent in ents:
            # We map the entity index to the passage index (which is 0 in this mock)
            # Use a check to avoid duplicate edges if an entity appears in multiple sentences
            edge = [entity_to_idx[ent], 0]
            if edge not in e2p_edges:
                e2p_edges.append(edge)

    data['entity', 'in', 'sentence'].edge_index = torch.tensor(e2s_edges).t().contiguous()
    data['sentence', 'hv', 'entity'].edge_index = torch.tensor([[v, k] for k, v in e2s_edges]).t().contiguous()

    # FIX: Use len(e2s_edges) instead of the sentences list
    data['entity', 'in', 'passage'].edge_attr = structural_rel_emb["in"].expand(len(e2p_edges), -1)
    data['passage', 'hv', 'entity'].edge_attr = structural_rel_emb["hv"].expand(len(e2p_edges), -1)

    # D. Sentence -> Sentence (Intra-passage flow)
    s2s_edges = []
    for i in range(len(sentences)):
        for j in range(len(sentences)):
            if i != j: s2s_edges.append([i, j])
            
    data['sentence', 're', 'sentence'].edge_index = torch.tensor(s2s_edges).t().contiguous()
    
    # FIX: Use len(s2s_edges) and assign after edge_index is created
    data['sentence', 're', 'sentence'].edge_attr = structural_rel_emb["seq"].expand(len(s2s_edges), -1)

    return data, e2e_relation_texts

if __name__ == "__main__":
    graph, relations = test_construction()
    print("--- Graph with Edge Attributes ---")
    print(graph)
    print(f"Total Entity Edges: {graph['entity', 're', 'entity'].edge_index.shape[1]}")
    for i, rel in enumerate(relations):
        src = graph['entity', 're', 'entity'].edge_index[0, i]
        tgt = graph['entity', 're', 'entity'].edge_index[1, i]
        print(f"  Edge {i}: {rel} ({src.item()} -> {tgt.item()})")


KeyError: '1'

In [41]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import softmax
from torch_scatter import scatter_add
from torch_geometric.nn import HeteroConv

# --- EDGE-AWARE MESSAGE PASSING LAYER (Strict Algorithm 2) ---
class EdgeAwareMessagePassing(nn.Module):
    def __init__(self, h_dim):
        super().__init__()
        self.h_dim = h_dim
        # MLP to compute attention score from [src, tgt, edge] only
        # Changed from 4 * h_dim to 3 * h_dim
        self.attn_mlp = nn.Linear(3 * h_dim, 1) 

    def forward(self, x, edge_index, edge_attr=None):
        if isinstance(x, tuple):
            src_h, tgt_h = x
        else:
            src_h = tgt_h = x # for enoty-- entity relation
        src, tgt = edge_index

        h_i = src_h[src]      # Source node features (h_i)
        h_j = tgt_h[tgt]      # Target node features (h_j)
        e_ij = edge_attr      # Edge attributes (e_ij)
        
        # Line 7 of Algorithm 2: Score = Attention(hi, hj, eij)
        # Query is no longer part of this local edge calculation
        combined = torch.cat([h_i, h_j, e_ij], dim=-1)
        score = self.attn_mlp(combined)
        
        # Line 9: Normalize attention weights across neighbors
        alpha = softmax(score, tgt)
        
        # Line 10: Aggregate neighbor features using the learned alpha
        out = scatter_add(alpha * h_i, tgt, dim=0, dim_size=tgt_h.size(0))
        return out

# Demonstration
if __name__ == "__main__":
    # ... (assuming graph construction is same as before) ...
    h_dim = 128
    #layer = EdgeAwareMessagePassing(h_dim)
    
    # Process entity-to-entity relations without needing a query_vec here
    # updated_entities = layer(
    #     src_h=graph['entity'].x,
    #     tgt_h=graph['entity'].x,
    #     edge_attr=graph['entity', 're', 'entity'].edge_attr,
    #     edge_index=graph['entity', 're', 'entity'].edge_index
    # )
    # updated_sentences = layer(
    #     src_h=graph['sentence'].x,
    #     tgt_h=graph['sentence'].x,
    #     edge_attr=graph['sentence', 're', 'sentence'].edge_attr,
    #     edge_index=graph['sentence', 're', 'sentence'].edge_index
    # )
    conv = HeteroConv({
        ('entity', 're', 'entity'): EdgeAwareMessagePassing(h_dim),
        ('sentence', 're', 'sentence'): EdgeAwareMessagePassing(h_dim),
        ('entity', 'in', 'sentence'): EdgeAwareMessagePassing(h_dim),
        ('sentence', 'in', 'passage'): EdgeAwareMessagePassing(h_dim),
        ('passage', 'hv', 'sentence'): EdgeAwareMessagePassing(h_dim),
    }, aggr='sum')
    out_dict = conv(
        x_dict=graph.x_dict, 
        edge_index_dict=graph.edge_index_dict,
        edge_attr_dict=graph.edge_attr_dict # Passes the right attr to the right layer!
    )
    print("\n--- Edge-Aware Message Passing Result (Algorithm 2) ---")
    print(f"Updated Entities Shape: {len(out_dict)}")


--- Edge-Aware Message Passing Result (Algorithm 2) ---
Updated Entities Shape: 3


In [None]:
class AuditableHybridGNN(torch.nn.Module):
    def __init__(self, node_types, edge_types, hidden_dim):
        # Local Track: Heterogeneous Graph Transformer (HGT)
        # Captures: "Sentence IN Passage", "Entity RE Entity" (The Audit Path)
        self.local_hgt = HGTConv(hidden_dim, hidden_dim, metadata)
        
        # Global Track: SGFormer-style All-to-All Attention
        # Captures: Semantic connections across the whole KG (Pluralistic Oversight)
        self.global_attn = MultiheadAttention(hidden_dim, heads=4)
        
        # Query-Gating Layer: Forces model to align with User Intent
        self.query_gate = Linear(hidden_dim * 2, hidden_dim)

    def forward(self, hetero_data, query_emb):
        # 1. Structural Reasoning (Local Track)
        # We extract attention weights here for the "Audit Trail" visualization
        h_dict, local_attn_weights = self.local_hgt(hetero_data.x_dict, hetero_data.edge_index_dict)
        
        # 2. Semantic Cross-Referencing (Global Track)
        # Convert to homogeneous to let every node talk to every other node
        h_all = to_homogeneous(h_dict)
        h_global = self.global_attn(h_all, h_all, h_all) 
        
        # 3. Query Alignment (The QSGNN Secret Sauce)
        # Scale every node's importance based on the query signal
        aligned_h_dict = {}
        for node_type, h in h_dict.items():
            # Concatenate node feature with query vector
            gate_input = torch.cat([h, query_emb.repeat(h.size(0), 1)], dim=-1)
            importance_score = self.query_gate(gate_input).sigmoid()
            # The node is updated by both local path and global verification
            aligned_h_dict[node_type] = (h + h_global[node_type]) * importance_score
            
        # 4. Final Scoring (Retrieval Head)
        # Output relevance scores for the Documents (D) or Chunks (C)
        return self.scoring_head(aligned_h_dict['passage'], query_emb)

In [None]:
import torch
import torch.nn as nn
from torch_geometric.nn import HGTConv, Linear
from torch_geometric.utils import to_homogeneous, from_homogeneous

class AuditableHybridGNN(torch.nn.Module):
    def __init__(self, metadata, hidden_dim, out_channels, num_heads=4):
        super().__init__()
        # 1. Local Track (Heterogeneous Graph Transformer)
        # metadata contains (node_types, edge_types)
        self.local_hgt = HGTConv(hidden_dim, hidden_dim, metadata, num_heads)
        
        # 2. Global Track (SGFormer style)
        self.global_attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, batch_first=True)
        
        # 3. Query Alignment Gate
        # We take node features + query features
        self.query_gate = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        
        # 4. Final Scoring Head (The Learned Scoring Head from Algorithm 4)
        self.scoring_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, x_dict, edge_index_dict, query_emb):
        # --- PHASE 1: Structural Reasoning (Local) ---
        # HGT processes nodes based on their specific relationships
        h_dict = self.local_hgt(x_dict, edge_index_dict)
        
        # --- PHASE 2: Semantic Oversight (Global) ---
        # 1. Collapse to homogeneous to allow all-to-all attention
        # 'edge_index' is ignored here because we want a full attention matrix
        h_homo, _ = self.collapse_to_homogeneous(h_dict) 
        
        # 2. Global Self-Attention (every node looks at every other node)
        # h_homo: [num_total_nodes, hidden_dim] -> add batch dim for MHA
        h_global, _ = self.global_attn(query_emb, h_homo.unsqueeze(0), h_homo.unsqueeze(0))
        h_global = h_global.squeeze(0)
        
        # --- PHASE 3: Query Gating & Fusion ---
        out_dict = {}
        for node_type, h_local in h_dict.items():
            # Fuse local and global knowledge
            # (In reality, we map the global embeddings back to their node types)
            h_fused = h_local + self.extract_global_for_type(h_global, node_type)
            
            # Query Alignment: Scale node importance based on query
            # Repeat query_emb to match number of nodes of this type
            q_expanded = query_emb.expand(h_fused.size(0), -1)
            gate_input = torch.cat([h_fused, q_expanded], dim=-1)
            importance = self.query_gate(gate_input)
            
            out_dict[node_type] = h_fused * importance

        # --- PHASE 4: Final Retrieval Scoring ---
        # Scoring based on Passage nodes as in RAG
        passage_embeddings = out_dict['passage']
        q_for_scoring = query_emb.expand(passage_embeddings.size(0), -1)
        scores = self.scoring_head(torch.cat([passage_embeddings, q_for_scoring], dim=-1))
        
        return scores

In [47]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import HGTConv
#from torch_geometric.utils import to_homogeneous

# --- 1. THE HYBRID MODEL ---
import torch
import torch.nn as nn
from torch_geometric.nn import HGTConv
from torch_geometric.utils import scatter


class AuditableHybridGNN(nn.Module):
    def __init__(self, metadata, hidden_dim):
        super().__init__()
        self.local_hgt = HGTConv(hidden_dim, hidden_dim, metadata, heads=4)
        
        # Cross-Attention: Query attends to Entities ONLY
        self.entity_global_attn = nn.MultiheadAttention(hidden_dim, num_heads=4, batch_first=True)
        # --- NEW: STABILITY LAYERS ---
        # LayerNormy ensures that adding two vectors doesn't make the values "explode"
        self.entity_norm = nn.LayerNorm(hidden_dim)
        self.passage_norm = nn.LayerNorm(hidden_dim)
        self.alpha = 0.1
        # Scoring Head
        # SCORING HEAD
        self.scoring_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x_dict, edge_index_dict, query_emb):
        if query_emb.dim() == 1:
            query_emb = query_emb.unsqueeze(0)
        # 1. Local Track (HGT)
        h_dict = self.local_hgt(x_dict, edge_index_dict)
        
        # 2. Global Entity Reasoning
        # query_emb: [1, 1, dim]
        # entities: [Num_Entities, dim]
        entities = h_dict['entity'].unsqueeze(0) # [1, Num_Entities, dim]
        query = query_emb.unsqueeze(0)           # [1, 1, dim]
        
        # Query acts as the 'Key' and 'Value', Entities are the 'Query'
        # Or vice-versa: Let Entities attend to each other globally
        h_local = h_dict['entity'] 
        h_ent_global, _ = self.entity_global_attn(entities, entities, entities)
        #overrighting - means catestraophic forgetting
        #h_dict['entity'] = h_ent_global.squeeze(0)
        # 3. ADD them instead of OVERWRITING (Residual Connection)
        # This ensures the model keeps BOTH structural and semantic info.
        #h_dict['entity'] = h_local + h_ent_global
        h_dict['entity'] = self.entity_norm((1 - self.alpha) * h_local + self.alpha * h_ent_global) # weighted fushion and normalization
        
        # 3. The "Broadcast" Step
        # Now, update Passages using the NEW Entity info
        # We can use the 'entity_in_passage' edges to pull info
        # or a simple mean-pooling of entities per passage
        # 3. The "Broadcast" Step (Query-Guided)
        e2p_index = edge_index_dict[('entity', 'in', 'passage')]
        ent_idx, psg_idx = e2p_index

        # Step A: Calculate which entities match the query (Relevance Score)
        # Shape: [Num_Entities]
        q_expanded = query_emb.expand(h_dict['entity'].size(0), -1)
        relevance = torch.sum(h_dict['entity'] * q_expanded, dim=-1).sigmoid()

        # Step B: Weight the entity features before sending them to the passage
        # This effectively "mutes" irrelevant entities in the document
        weighted_ent_features = h_dict['entity'][ent_idx] * relevance[ent_idx].unsqueeze(-1)

        # Step C: Aggregate (Sum) to get the "Query-Relevant Document Context"
        psg_context = scatter(src=weighted_ent_features, 
                            index=psg_idx, 
                            dim=0, 
                            dim_size=h_dict['passage'].size(0), 
                            reduce='sum')

        # Step D: Update
        h_dict['passage'] = self.passage_norm(h_dict['passage'] + psg_context)


        # 4. Final Scoring
        passages = h_dict['passage']
        q_scoring = query_emb.expand(passages.size(0), -1)
        
        # Combine Passage info + Query info
        return self.scoring_head(torch.cat([passages, q_scoring], dim=-1)).squeeze()


# --- 2. THE TRAINING & EVALUATION LOGIC ---
def train_step(model, data, query_emb, target_passage_idx, optimizer, criterion):
    model.train() # Set to training mode (enables dropout, etc.)
    optimizer.zero_grad() # Clear previous gradients
    
    # 1. Forward Pass
    # We pass the graph data and the query we are looking for
    predicted_scores = model(data.x_dict, data.edge_index_dict, query_emb)
    
    # 2. Calculate Loss
    # We want the 'target_passage_idx' to have the highest score
    # CrossEntropy expects (all_scores, index_of_the_right_one)
    loss = criterion(predicted_scores.unsqueeze(0), torch.tensor([target_passage_idx]).to(predicted_scores.device))
    
    # 3. Backward Pass (The "Update" phase)
    loss.backward() # Calculate gradients
    optimizer.step() # Update weights based on gradients
    
    return loss.item()

def test_step(model, data, query_emb, target_passage_idx):
    model.eval() # Set to evaluation mode (disables dropout)
    with torch.no_grad(): # Don't calculate gradients (saves memory/time)
        scores = model(data.x_dict, data.edge_index_dict, query_emb)
        # Find which passage got the highest score
        pred_idx = torch.argmax(scores).item()
        is_correct = (pred_idx == target_passage_idx)
    return is_correct

# --- 3. EXECUTION ---
# Setup
hidden_dim = 128
# Assuming 'graph' is your HeteroData object from previous steps
model = AuditableHybridGNN(graph.metadata(), hidden_dim)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Mock training data (1 query, looking for passage index 5)
mock_query = torch.randn(1, hidden_dim)
gold_passage = 5

# Training Loop
for epoch in range(10):
    loss = train_step(model, graph, mock_query, gold_passage, optimizer, criterion)
    correct = test_step(model, graph, mock_query, gold_passage)
    print(f"Epoch {epoch} | Loss: {loss:.4f} | Correct: {correct}")

IndexError: Found indices in 'edge_index' that are larger than 8 (got 9). Please ensure that all indices in 'edge_index' point to valid indices in the interval [0, 9) in your node feature matrix and try again.

## lets construct the knowledge graph

In [1]:
import json
dataset_name = "musique"

corpus_path = f"dataset/{dataset_name}_corpus.json"
with open(corpus_path, "r") as f:
    corpus = json.load(f)

docs = [f"{doc['title']}\n{doc['text']}" for doc in corpus]

In [2]:
docs[0]

'Catalan language\nThe Germanic superstrate has had different outcomes in Spanish and Catalan. For example, Catalan fang "mud" and rostir "to roast", of Germanic origin, contrast with Spanish lodo and asar, of Latin origin; whereas Catalan filosa "spinning wheel" and pols "temple", of Latin origin, contrast with Spanish rueca and sien, of Germanic origin.'

In [None]:
self.chunk_embedding_store.insert_strings(docs)
chunk_to_rows = self.chunk_embedding_store.get_all_id_to_rows()

In [3]:
from hashlib import md5

def compute_mdhash_id(content: str, prefix: str = "") -> str:
    return prefix + md5(content.encode()).hexdigest()


In [4]:
embeddings = []
texts = []
hash_ids = []
def _upsert(hash_ids, texts, embeddings):
    embeddings.extend(embeddings)
    hash_ids.extend(hash_ids)
    texts.extend(texts)

    #logger.info(f"Saving new records.")
    #self._save_data()

In [4]:
import torch
import numpy as np
from tqdm import tqdm
from transformers import AutoModel
from typing import List, Optional

class NVEmbedV2EmbeddingModel:
    def __init__(self, model_name: str = "nvidia/NV-Embed-v2"):
        """
        Initializes the NV-Embed-v2 model.
        """
        self.model_name = model_name
        
        # Load the model with trust_remote_code=True for NV-Embed
        self.model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None
        )
        
    def batch_encode(
        self, 
        texts: List[str], 
        batch_size: int = 16, 
        max_length: int = 32768, 
        instruction: str = "", 
        norm: bool = True
    ) -> np.ndarray:
        """
        Encodes a list of strings into embeddings.
        """
        if isinstance(texts, str):
            texts = [texts]

        # Prepare instructions according to NV-Embed format
        formatted_instruction = f"Instruct: {instruction}\nQuery: " if instruction else ""

        results = []
        if len(texts) <= batch_size:
            # Single batch
            results = self.model.encode(
                prompts=texts, 
                instruction=formatted_instruction, 
                max_length=max_length
            )
        else:
            # Multiple batches with progress bar
            for i in tqdm(range(0, len(texts), batch_size), desc="Batch Encoding"):
                batch_texts = texts[i : i + batch_size]
                batch_results = self.model.encode(
                    prompts=batch_texts, 
                    instruction=formatted_instruction, 
                    max_length=max_length
                )
                results.append(batch_results)
            results = torch.cat(results, dim=0)

        # Convert to numpy
        if isinstance(results, torch.Tensor):
            results = results.cpu().float().numpy()
            
        # L2 Normalization
        if norm:
            results = (results.T / np.linalg.norm(results, axis=1)).T

        return results

# --- Usage ---
# 1. Initialize the model (only needs to be done once)
# model_instance = NVEmbedV2EmbeddingModel()

# 2. Encode your texts
# missing_embeddings = model_instance.batch_encode(texts_to_encode)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

nodes_dict = {}
hashes=[]
content=[]
for text in docs:
    hash_id = compute_mdhash_id(text, prefix='docs' + "-")
    nodes_dict[hash_id] = {'content': text}
    hashes.append(hash_id)
    content.append(text)

hash_id_to_row = {
                h: {"hash_id": h, "content": t}
                for h, t in zip(hashes,content)
            }

# Get all hash_ids from the input dictionary.
all_hash_ids = list(nodes_dict.keys())
#if not all_hash_ids:
#    return  # Nothing to insert.

existing = hash_id_to_row.keys()

# Filter out the missing hash_ids.
missing_ids = [hash_id for hash_id in all_hash_ids if hash_id not in existing]

print(f"Inserting {len(missing_ids)} new records, {len(all_hash_ids) - len(missing_ids)} records already exist.")

#if not missing_ids:
#    return  {}# All records already exist.

# Prepare the texts to encode from the "content" field.
texts_to_encode = [nodes_dict[hash_id]["content"] for hash_id in missing_ids]


# --- Usage ---
# 1. Initialize the model (only needs to be done once)
model_instance = NVEmbedV2EmbeddingModel()

# 2. Encode your texts
missing_embeddings = model_instance.batch_encode(texts_to_encode)

Inserting 0 new records, 15803 records already exist.


`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards:  25%|██▌       | 1/4 [00:08<00:26,  8.94s/it]