# G-RAG Focused Learning 2: Document Graph Construction

## Learning Objective
Master the document graph construction methodology in G-RAG, where documents become nodes and shared AMR concepts create edges for improved reranking.

## Paper Context

### Key Paper Sections:
- **Section 3.1**: Establishing Document Graphs via AMR
- **Section 3.2.2**: Edge Features
- **Figure 1**: G-RAG framework illustration

### Paper Quote (Section 3.1):
> *"For each node $v_i \in V$, it corresponds to the document $p_i$. For $v_i, v_j \in V$, $i \neq j$, if the corresponding AMR $G_{qp_i}$ and $G_{qp_j}$ have common nodes, there will be an undirected edge between $v_i$ and $v_j$ denoted as $e_{ij} = (v_i, v_j) \in E$. We remove isolated nodes in $G_q$."*

### Innovation:
1. **Document-Level Graphs**: Unlike traditional approaches that work with text-level relations, G-RAG creates document-level connections
2. **AMR-Based Edges**: Edges represent shared semantic concepts, not just keyword overlap
3. **Cross-Document Reasoning**: Enables discovering relevant documents through their connections to other relevant documents

### Why This Matters:
- **Weak Connection Problem**: Some documents contain valuable information but have only a weak direct connection to the query
- **Transitive Relevance**: Document A might be relevant to query Q through its connection to clearly relevant document B
- **Semantic Clustering**: Groups semantically related documents for better ranking decisions

## Theoretical Foundation

### Document Graph Definition

Given a question $q$ and retrieved documents $\{p_1, p_2, ..., p_n\}$:

**Document Graph**: $G_q = \{V, E\}$ where:
- **Vertices (V)**: Each $v_i \in V$ corresponds to document $p_i$
- **Edges (E)**: $e_{ij} = (v_i, v_j) \in E$ if AMR graphs $G_{qp_i}$ and $G_{qp_j}$ share common nodes
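
The definition above can be sketched in a few lines. This is a minimal illustration, assuming each AMR graph $G_{qp_i}$ has already been reduced to a set of node labels; the function name `build_doc_edges` and the toy label sets are hypothetical and do not come from the paper.

```python
from itertools import combinations

def build_doc_edges(concept_sets):
    """Return (kept_nodes, edges) for documents given as sets of AMR node labels.

    An undirected edge (i, j) exists when documents i and j share at least
    one label. Documents that end up with no edge are dropped, mirroring
    the "remove isolated nodes" step in the quoted definition.
    """
    edges = {
        (i, j)
        for i, j in combinations(range(len(concept_sets)), 2)
        if concept_sets[i] & concept_sets[j]
    }
    kept_nodes = sorted({v for edge in edges for v in edge})
    return kept_nodes, edges

# Toy label sets (hypothetical); the third document shares nothing.
toy_sets = [
    {"sing-01", "sinatra", "nickname"},
    {"perform-01", "sinatra", "stage"},
    {"aurora", "particle"},
]
kept, doc_edges = build_doc_edges(toy_sets)
print(kept)       # [0, 1]
print(doc_edges)  # {(0, 1)}
```

Storing each edge once as `(i, j)` with `i < j` keeps the edge set undirected without needing a symmetric matrix.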

### Edge Feature Computation (Section 3.2.2)

For edge between documents $i$ and $j$:
$$\hat{E}_{ij} = \begin{cases}
0 & \text{if no connection between } G_{qp_i} \text{ and } G_{qp_j} \\
\begin{bmatrix}
\text{number of common nodes} \\
\text{number of common edges}
\end{bmatrix} & \text{otherwise}
\end{cases}$$
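
As a worked instance of the case formula, the sketch below computes $\hat{E}_{ij}$ for two toy AMR graphs. It assumes each graph is given as a set of node labels plus a set of `(source, relation, target)` triples, and that an edge counts as common only when all three parts match. The formula itself does not spell out that matching rule, so treat it as one possible reading. `edge_feature` is a hypothetical helper.

```python
def edge_feature(nodes_i, edges_i, nodes_j, edges_j):
    """Return [number of common nodes, number of common edges].

    When the two graphs share no node, the zero vector [0, 0] stands in
    for the "0" case of the formula.
    """
    common_nodes = len(nodes_i & nodes_j)
    if common_nodes == 0:
        return [0, 0]
    common_edges = len(edges_i & edges_j)
    return [common_nodes, common_edges]

# Toy graphs (hypothetical labels and relations).
nodes_a = {"sing-01", "sinatra", "nickname"}
edges_a = {("sing-01", ":ARG0", "sinatra")}
nodes_b = {"sing-01", "sinatra", "stage"}
edges_b = {("sing-01", ":ARG0", "sinatra"), ("sing-01", ":location", "stage")}

print(edge_feature(nodes_a, edges_a, nodes_b, edges_b))     # [2, 1]
print(edge_feature(nodes_a, edges_a, {"aurora"}, set()))    # [0, 0]
```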

### Key Insights:
1. **Semantic Similarity**: Shared AMR concepts indicate semantic similarity
2. **Transitive Information**: Information can flow through the graph
3. **Contextual Relevance**: Documents gain relevance through their neighbors
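
To make the second and third insights concrete, here is a minimal sketch of one round of neighbor averaging on a document graph. It is an illustration only and not the GNN used in G-RAG: the mixing weight `alpha`, the initial scores, and the function name `propagate_once` are all assumptions chosen for this example.

```python
def propagate_once(scores, neighbors, alpha=0.3):
    """Blend each document's score with the mean score of its neighbors.

    scores:    list of initial relevance scores, one per document
    neighbors: dict mapping a document index to a list of neighbor indices
    alpha:     share of the updated score taken from the neighbors
    """
    updated = []
    for i, score in enumerate(scores):
        nbrs = neighbors.get(i, [])
        if nbrs:
            neighbor_mean = sum(scores[j] for j in nbrs) / len(nbrs)
            updated.append((1 - alpha) * score + alpha * neighbor_mean)
        else:
            # Isolated documents keep their original score.
            updated.append(score)
    return updated

# D0 is clearly relevant, D1 is weakly scored but linked to D0, D2 is isolated.
initial_scores = [0.9, 0.2, 0.1]
doc_neighbors = {0: [1], 1: [0], 2: []}
new_scores = propagate_once(initial_scores, doc_neighbors)
print([round(s, 2) for s in new_scores])  # [0.69, 0.41, 0.1]
```

In this toy run the weakly scored document D1 rises from 0.2 to 0.41 because of its link to D0, while the isolated D2 is unchanged.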

In [None]:
# Setup environment for document graph construction
!pip install networkx matplotlib seaborn numpy pandas
!pip install scikit-learn

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
from collections import defaultdict, Counter
from typing import List, Dict, Tuple, Set, Optional
import itertools
import re
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer

# Set visualization style
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print("Environment ready for document graph construction!")

## AMR Graph Simulator for Document Graphs

Building on the AMR processing from the previous notebook, we create a specialized simulator for document graph construction.

In [None]:
class DocumentAMRSimulator:
    """
    Specialized AMR simulator for document graph construction.
    
    Focuses on creating realistic AMR graphs that enable meaningful
    document-level connections as described in the G-RAG paper.
    """
    
    def __init__(self):
        # Domain-specific concept mappings for better document connections
        self.concept_domains = {
            'music': {
                'concepts': ['sing-01', 'compose-01', 'perform-01', 'album', 'song', 'music', 'artist', 'stage'],
                'entities': ['sinatra', 'williams', 'orchestra', 'symphony']
            },
            'geography': {
                'concepts': ['capital', 'city', 'country', 'location', 'government', 'parliament'],
                'entities': ['australia', 'canberra', 'sydney', 'melbourne']
            },
            'technology': {
                'concepts': ['release-01', 'create-01', 'invent-01', 'product', 'company', 'innovation'],
                'entities': ['iphone', 'apple', 'smartphone', 'technology']
            },
            'science': {
                'concepts': ['cause-01', 'phenomenon', 'particle', 'magnetic-field', 'interact-01'],
                'entities': ['aurora', 'sun', 'earth', 'atmosphere']
            }
        }
        
        # Common AMR relations for more realistic graphs
        self.relations = [
            ':ARG0', ':ARG1', ':ARG2',  # Core arguments
            ':mod', ':domain', ':poss',  # Modifiers
            ':location', ':time', ':manner',  # Adjuncts
            ':name', ':wiki'  # Named entity relations
        ]
    
    def identify_document_domain(self, text: str) -> str:
        """Identify the primary domain of a document for targeted concept extraction."""
        text_lower = text.lower()
        domain_scores = {}
        
        for domain, domain_data in self.concept_domains.items():
            score = 0
            for concept in domain_data['concepts']:
                if concept.replace('-01', '') in text_lower:
                    score += 2
            for entity in domain_data['entities']:
                if entity in text_lower:
                    score += 3
            domain_scores[domain] = score
        
        return max(domain_scores, key=domain_scores.get) if max(domain_scores.values()) > 0 else 'general'
    
    def extract_domain_concepts(self, text: str, domain: str) -> List[str]:
        """Extract concepts relevant to the identified domain."""
        concepts = ['question']  # Always include question node
        text_lower = text.lower()
        
        if domain in self.concept_domains:
            domain_data = self.concept_domains[domain]
            
            # Add domain-specific concepts
            for concept in domain_data['concepts']:
                if concept.replace('-01', '') in text_lower:
                    concepts.append(concept)
            
            # Add domain-specific entities
            for entity in domain_data['entities']:
                if entity in text_lower:
                    concepts.append(entity)
        
        # Add general concepts from text
        words = re.findall(r'\b[a-zA-Z]{4,}\b', text_lower)
        for word in words[:10]:  # Limit to avoid too large graphs
            if word not in ['question', 'document', 'text'] and len(word) > 3:
                concepts.append(word)
        
        return list(set(concepts))
    
    def create_document_amr(self, question: str, document: str) -> nx.DiGraph:
        """Create AMR graph for a question-document pair."""
        combined_text = f"question: {question} {document}"
        
        # Identify domain and extract relevant concepts
        domain = self.identify_document_domain(combined_text)
        concepts = self.extract_domain_concepts(combined_text, domain)
        
        # Create directed graph
        G = nx.DiGraph()
        G.add_nodes_from(concepts)
        
        # Add metadata
        G.graph['domain'] = domain
        G.graph['question'] = question
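        # Note: the built-in hash() of a str is salted per process, so this ID
        # is only stable within a single Python session
        G.graph['document_id'] = hash(document) % 10000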
        
        # Create edges based on semantic relationships
        self._add_semantic_edges(G, question, document)
        
        return G
    
    def _add_semantic_edges(self, G: nx.DiGraph, question: str, document: str):
        """Add edges based on semantic relationships."""
        nodes = list(G.nodes())
        
        # Connect question node to relevant concepts
        if 'question' in nodes:
            question_words = set(question.lower().split())
            for node in nodes:
                if node != 'question' and (node in question_words or 
                                         any(qw in node for qw in question_words)):
                    G.add_edge('question', node, relation=':ARG1')
        
        # Connect co-occurring concepts in sentences
        sentences = re.split(r'[.!?]+', document)
        for sentence in sentences:
            sentence_lower = sentence.lower()
            sentence_nodes = [node for node in nodes if node in sentence_lower]
            
            # Connect concepts that appear together
            for i, node1 in enumerate(sentence_nodes):
                for node2 in sentence_nodes[i+1:]:
                    if not G.has_edge(node1, node2):
                        relation = np.random.choice(self.relations)
                        G.add_edge(node1, node2, relation=relation)
    
    def get_graph_statistics(self, G: nx.DiGraph) -> Dict:
        """Get statistics for an AMR graph."""
        return {
            'nodes': G.number_of_nodes(),
            'edges': G.number_of_edges(),
            'domain': G.graph.get('domain', 'unknown'),
            'has_question': 'question' in G.nodes(),
            'density': nx.density(G),
            'avg_degree': sum(dict(G.degree()).values()) / G.number_of_nodes() if G.number_of_nodes() > 0 else 0
        }

# Test the document AMR simulator
doc_amr_sim = DocumentAMRSimulator()

# Sample question and documents
sample_question = "What is the nickname of Frank Sinatra?"
sample_documents = [
    "Frank Sinatra was an American singer known for his blue eyes, earning him the nickname 'Ol' Blue Eyes'.",
    "The famous musician Sinatra performed on stage with his distinctive blue eyes captivating audiences.",
    "Many singers in the music industry have unique characteristics that make them memorable.",
    "Blue eyes are a common physical feature among many people in the entertainment world."
]

# Create AMR graphs for each document
amr_graphs = []
for i, doc in enumerate(sample_documents):
    amr = doc_amr_sim.create_document_amr(sample_question, doc)
    amr_graphs.append(amr)
    
    stats = doc_amr_sim.get_graph_statistics(amr)
    print(f"Document {i+1} AMR Stats: {stats}")
    print(f"  Text: {doc[:60]}...")
    print(f"  Concepts: {list(amr.nodes())[:8]}...")
    print()

print(f"Created {len(amr_graphs)} AMR graphs for document graph construction.")

## Document Graph Construction - Core Algorithm

### Implementation of Paper's Methodology
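Following Sections 3.1 and 3.2.2, we implement document graph construction on top of the simulated AMR graphs. The implementation deviates from the paper in two deliberate ways: the shared `question` node is excluded from the common-node count, and `min_shared_concepts` sets a configurable threshold for creating an edge.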
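
Section 3.1 removes isolated nodes from $G_q$. One way to apply that step to an adjacency matrix is sketched below, assuming a symmetric 0/1 matrix with a zero diagonal stored as nested lists. `drop_isolated` is a hypothetical helper, and the index list it returns is a convenience for mapping back to the original documents, not something the paper prescribes.

```python
def drop_isolated(adjacency):
    """Remove rows and columns of documents that have no neighbor.

    Returns (pruned_adjacency, kept_indices), where kept_indices[k] is the
    original index of the document now stored at position k.
    """
    kept_indices = [i for i, row in enumerate(adjacency) if any(row)]
    pruned = [[adjacency[i][j] for j in kept_indices] for i in kept_indices]
    return pruned, kept_indices

# Documents 0 and 1 are linked; document 2 has no connection.
toy_adjacency = [
    [0, 1, 0],
    [1, 0, 0],
    [0, 0, 0],
]
pruned, kept_indices = drop_isolated(toy_adjacency)
print(pruned)        # [[0, 1], [1, 0]]
print(kept_indices)  # [0, 1]
```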

In [None]:
class DocumentGraphBuilder:
    """
    Implements the document graph construction algorithm from G-RAG paper.
    
    Key Features:
    - Creates document-level graphs from AMR representations
    - Computes edge features based on shared concepts and relations
    - Handles isolated node removal as specified in paper
    """
    
    def __init__(self, min_shared_concepts: int = 1):
        self.min_shared_concepts = min_shared_concepts
    
    def compute_amr_similarity(self, amr1: nx.DiGraph, amr2: nx.DiGraph) -> Tuple[int, int]:
        """
        Compute shared concepts and relations between two AMR graphs.
        
        Returns:
            (common_nodes, common_edges) as specified in paper Section 3.2.2
        """
        # Get nodes (concepts) from both graphs
        nodes1 = set(amr1.nodes())
        nodes2 = set(amr2.nodes())
        
        # Count common nodes (excluding 'question' node for meaningful comparison)
        common_nodes = len(nodes1 & nodes2)
        if 'question' in (nodes1 & nodes2):
            common_nodes = max(0, common_nodes - 1)  # Don't count question node
        
        # Get edges as (source, target) pairs; relation labels are not compared
        edges1 = set(amr1.edges())
        edges2 = set(amr2.edges())
        
        # Count common edges
        common_edges = len(edges1 & edges2)
        
        return common_nodes, common_edges
    
    def build_document_graph(self, 
                           question: str, 
                           documents: List[str], 
                           amr_graphs: List[nx.DiGraph]) -> Dict:
        """
        Build a document graph following the methodology of the G-RAG paper (Section 3.1).
        
        Args:
            question: The input question
            documents: List of retrieved documents
            amr_graphs: Corresponding AMR graphs for each document
            
        Returns:
            Dictionary containing graph data ready for GNN processing
        """
        n_docs = len(documents)
        
        # Initialize adjacency matrix and edge features
        adjacency = np.zeros((n_docs, n_docs), dtype=int)
        edge_features = np.zeros((n_docs, n_docs, 2))  # [common_nodes, common_edges]
        
        # Compute pairwise similarities and build adjacency matrix
        connections = []
        for i in range(n_docs):
            for j in range(i + 1, n_docs):
                common_nodes, common_edges = self.compute_amr_similarity(
                    amr_graphs[i], amr_graphs[j]
                )
                
                # Create edge if documents share sufficient concepts
                if common_nodes >= self.min_shared_concepts:
                    adjacency[i, j] = adjacency[j, i] = 1
                    edge_features[i, j] = edge_features[j, i] = [common_nodes, common_edges]
                    
                    connections.append({
                        'doc_i': i,
                        'doc_j': j,
                        'common_nodes': common_nodes,
                        'common_edges': common_edges,
                        'strength': common_nodes + 0.5 * common_edges
                    })
        
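        # Identify isolated nodes. The paper removes them from the graph; here they
        # are only recorded (see 'isolated_docs') so that callers can filter them out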
        connected_docs = set()
        for i in range(n_docs):
            if np.sum(adjacency[i, :]) > 0:  # Node has at least one connection
                connected_docs.add(i)
        
        # Normalize each feature channel by its maximum to prevent explosive scaling
        # (as mentioned in paper); an all-zero channel is left unchanged
        normalized_edge_features = edge_features.copy()
        for channel in range(edge_features.shape[2]):
            channel_max = edge_features[:, :, channel].max() if n_docs > 0 else 0
            if channel_max > 0:
                normalized_edge_features[:, :, channel] = edge_features[:, :, channel] / channel_max
        
        return {
            'adjacency': adjacency,
            'edge_features': edge_features,
            'normalized_edge_features': normalized_edge_features,
            'connections': connections,
            'connected_docs': connected_docs,
            'isolated_docs': set(range(n_docs)) - connected_docs,
            'n_documents': n_docs,
            'question': question,
            'documents': documents,
            'amr_graphs': amr_graphs
        }
    
    def analyze_graph_structure(self, graph_data: Dict) -> Dict:
        """
        Analyze the structure of the created document graph.
        """
        adjacency = graph_data['adjacency']
        n_docs = graph_data['n_documents']
        connections = graph_data['connections']
        
        # Create NetworkX graph for analysis
        G = nx.Graph()
        G.add_nodes_from(range(n_docs))
        
        for conn in connections:
            G.add_edge(conn['doc_i'], conn['doc_j'], 
                      weight=conn['strength'])
        
        # Compute graph metrics
        analysis = {
            'total_documents': n_docs,
            'total_connections': len(connections),
            'connected_documents': len(graph_data['connected_docs']),
            'isolated_documents': len(graph_data['isolated_docs']),
            'graph_density': nx.density(G),
            'average_degree': sum(dict(G.degree()).values()) / n_docs if n_docs > 0 else 0,
            'clustering_coefficient': nx.average_clustering(G) if G.number_of_edges() > 0 else 0,
            'connected_components': nx.number_connected_components(G),
            'largest_component_size': len(max(nx.connected_components(G), key=len)) if n_docs > 0 else 0
        }
        
        # Connection strength statistics
        if connections:
            strengths = [conn['strength'] for conn in connections]
            analysis.update({
                'avg_connection_strength': np.mean(strengths),
                'max_connection_strength': np.max(strengths),
                'min_connection_strength': np.min(strengths)
            })
        
        return analysis
    
    def visualize_document_graph(self, graph_data: Dict, 
                                positive_docs: Optional[List[int]] = None,
                                title: str = "Document Graph",
                                figsize: Tuple[int, int] = (14, 10)):
        """
        Visualize the document graph with enhanced information display.
        """
        adjacency = graph_data['adjacency']
        connections = graph_data['connections']
        documents = graph_data['documents']
        n_docs = graph_data['n_documents']
        
        # Create NetworkX graph
        G = nx.Graph()
        G.add_nodes_from(range(n_docs))
        
        for conn in connections:
            G.add_edge(conn['doc_i'], conn['doc_j'], 
                      weight=conn['strength'],
                      common_nodes=conn['common_nodes'],
                      common_edges=conn['common_edges'])
        
        plt.figure(figsize=figsize)
        
        # Use spring layout with good spacing
        pos = nx.spring_layout(G, k=3, iterations=50, seed=42)
        
        # Color nodes based on positive/negative status
        node_colors = []
        for i in range(n_docs):
            if positive_docs and i in positive_docs:
                node_colors.append('lightgreen')  # Positive documents
            elif i in graph_data['isolated_docs']:
                node_colors.append('lightgray')   # Isolated documents
            else:
                node_colors.append('lightblue')   # Connected negative documents
        
        # Draw nodes with size based on degree
        node_sizes = [1000 + 200 * G.degree(node) for node in G.nodes()]
        nx.draw_networkx_nodes(G, pos, node_color=node_colors, 
                              node_size=node_sizes, alpha=0.8)
        
        # Draw edges with thickness based on connection strength
        for (u, v, d) in G.edges(data=True):
            weight = d['weight']
            nx.draw_networkx_edges(G, pos, [(u, v)], 
                                  width=max(1, weight * 2), alpha=0.6)
        
        # Draw node labels
        labels = {i: f"D{i}" for i in range(n_docs)}
        nx.draw_networkx_labels(G, pos, labels, font_size=12, font_weight='bold')
        
        # Draw edge labels with connection strength
        edge_labels = {}
        for (u, v, d) in G.edges(data=True):
            edge_labels[(u, v)] = f"{d['common_nodes']}"
        nx.draw_networkx_edge_labels(G, pos, edge_labels, font_size=8)
        
        plt.title(title, fontsize=16, fontweight='bold')
        plt.axis('off')
        
        # Add legend
        legend_elements = [
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightgreen', 
                      markersize=15, label='Positive Documents'),
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightblue', 
                      markersize=15, label='Connected Documents'),
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightgray', 
                      markersize=15, label='Isolated Documents')
        ]
        plt.legend(handles=legend_elements, loc='upper left', bbox_to_anchor=(1, 1))
        
        plt.tight_layout()
        plt.show()
        
        # Print detailed connection information
        print(f"\nDocument Graph Details:")
        print("-" * 50)
        
        for i, doc in enumerate(documents):
            status = "[POSITIVE]" if positive_docs and i in positive_docs else "[NEGATIVE]"
            isolated = "[ISOLATED]" if i in graph_data['isolated_docs'] else ""
            print(f"Document {i} {status}{isolated}: {doc[:80]}...")
        
        print(f"\nConnections:")
        for conn in connections:
            i, j = conn['doc_i'], conn['doc_j']
            strength = conn['strength']
            nodes = conn['common_nodes']
            edges = conn['common_edges']
            print(f"  D{i} ↔ D{j}: {nodes} shared concepts, {edges} shared relations (strength: {strength:.2f})")

# Test the document graph builder
graph_builder = DocumentGraphBuilder(min_shared_concepts=1)

# Build document graph
graph_data = graph_builder.build_document_graph(
    sample_question, sample_documents, amr_graphs
)

# Analyze graph structure
analysis = graph_builder.analyze_graph_structure(graph_data)

print("DOCUMENT GRAPH CONSTRUCTION RESULTS")
print("=" * 50)
print(f"Question: {sample_question}")
print(f"Documents: {len(sample_documents)}")
print()

print("Graph Analysis:")
for key, value in analysis.items():
    print(f"  {key}: {value}")
print()

# Visualize with positive document marking
positive_docs = [0, 1]  # First two documents are positive
graph_builder.visualize_document_graph(
    graph_data, positive_docs, 
    "Document Graph: Frank Sinatra Nickname Question"
)

## Edge Feature Analysis

### Deep Dive into Edge Features (Section 3.2.2)
Understanding how edge features capture semantic relationships between documents.
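
Raw counts can differ widely in scale between document pairs, so a per-channel rescaling is a common step before the features reach a model. The sketch below divides each channel by its maximum and guards against an all-zero channel. It assumes edge features are stored as a dictionary from document-index pairs to `[common_nodes, common_edges]`; that layout and the name `max_normalize` are choices made for this example.

```python
def max_normalize(features):
    """Scale each feature channel to [0, 1] by dividing by its maximum.

    features: dict mapping (i, j) pairs to equal-length lists of counts.
    A channel whose maximum is 0 is mapped to 0.0 for every pair.
    """
    if not features:
        return {}
    n_channels = len(next(iter(features.values())))
    maxima = [max(vec[c] for vec in features.values()) for c in range(n_channels)]
    return {
        pair: [vec[c] / maxima[c] if maxima[c] > 0 else 0.0 for c in range(n_channels)]
        for pair, vec in features.items()
    }

# Toy counts (hypothetical): [common_nodes, common_edges] per document pair.
raw_features = {(0, 1): [4, 2], (0, 2): [2, 0], (1, 2): [1, 1]}
normalized = max_normalize(raw_features)
print(normalized[(0, 1)])  # [1.0, 1.0]
print(normalized[(0, 2)])  # [0.5, 0.0]
print(normalized[(1, 2)])  # [0.25, 0.5]
```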

In [None]:
class EdgeFeatureAnalyzer:
    """
    Analyzes edge features in document graphs to understand 
    semantic relationships between documents.
    
    Implements detailed analysis of the edge feature computation
    described in Section 3.2.2 of the G-RAG paper.
    """
    
    def __init__(self):
        pass
    
    def analyze_edge_features(self, graph_data: Dict) -> Dict:
        """
        Perform comprehensive analysis of edge features.
        """
        edge_features = graph_data['edge_features']
        connections = graph_data['connections']
        n_docs = graph_data['n_documents']
        
        analysis = {
            'total_edges': len(connections),
            'edge_density': len(connections) / (n_docs * (n_docs - 1) / 2) if n_docs > 1 else 0
        }
        
        if connections:
            # Analyze shared concepts (first dimension)
            shared_concepts = [conn['common_nodes'] for conn in connections]
            analysis.update({
                'avg_shared_concepts': np.mean(shared_concepts),
                'max_shared_concepts': np.max(shared_concepts),
                'min_shared_concepts': np.min(shared_concepts),
                'std_shared_concepts': np.std(shared_concepts)
            })
            
            # Analyze shared relations (second dimension)
            shared_relations = [conn['common_edges'] for conn in connections]
            analysis.update({
                'avg_shared_relations': np.mean(shared_relations),
                'max_shared_relations': np.max(shared_relations),
                'min_shared_relations': np.min(shared_relations),
                'std_shared_relations': np.std(shared_relations)
            })
            
            # Correlation analysis
            if len(shared_concepts) > 1:
                correlation = np.corrcoef(shared_concepts, shared_relations)[0, 1]
                analysis['concept_relation_correlation'] = correlation if not np.isnan(correlation) else 0
        
        return analysis
    
    def visualize_edge_features(self, graph_data: Dict, figsize: Tuple[int, int] = (15, 5)):
        """
        Visualize edge feature distributions and relationships.
        """
        connections = graph_data['connections']
        
        if not connections:
            print("No connections found in the graph.")
            return
        
        fig, axes = plt.subplots(1, 3, figsize=figsize)
        
        # Extract data
        shared_concepts = [conn['common_nodes'] for conn in connections]
        shared_relations = [conn['common_edges'] for conn in connections]
        strengths = [conn['strength'] for conn in connections]
        
        # Plot 1: Shared concepts distribution
        axes[0].hist(shared_concepts, bins=max(1, len(set(shared_concepts))), 
                    alpha=0.7, color='skyblue', edgecolor='black')
        axes[0].set_title('Shared Concepts Distribution')
        axes[0].set_xlabel('Number of Shared Concepts')
        axes[0].set_ylabel('Frequency')
        axes[0].grid(True, alpha=0.3)
        
        # Plot 2: Shared relations distribution
        axes[1].hist(shared_relations, bins=max(1, len(set(shared_relations))), 
                    alpha=0.7, color='lightcoral', edgecolor='black')
        axes[1].set_title('Shared Relations Distribution')
        axes[1].set_xlabel('Number of Shared Relations')
        axes[1].set_ylabel('Frequency')
        axes[1].grid(True, alpha=0.3)
        
        # Plot 3: Relationship between concepts and relations
        scatter = axes[2].scatter(shared_concepts, shared_relations, 
                                 c=strengths, cmap='viridis', alpha=0.7, s=100)
        axes[2].set_title('Concepts vs Relations')
        axes[2].set_xlabel('Shared Concepts')
        axes[2].set_ylabel('Shared Relations')
        axes[2].grid(True, alpha=0.3)
        
        # Add colorbar for strength
        cbar = plt.colorbar(scatter, ax=axes[2])
        cbar.set_label('Connection Strength')
        
        plt.tight_layout()
        plt.show()
    
    def create_edge_feature_heatmap(self, graph_data: Dict, feature_type: str = 'concepts'):
        """
        Create heatmap showing edge features between all document pairs.
        
        Args:
            feature_type: 'concepts' or 'relations'
        """
        edge_features = graph_data['edge_features']
        documents = graph_data['documents']
        n_docs = graph_data['n_documents']
        
        # Select feature dimension
        feature_idx = 0 if feature_type == 'concepts' else 1
        feature_matrix = edge_features[:, :, feature_idx]
        
        plt.figure(figsize=(10, 8))
        
        # Create heatmap
        mask = np.triu(np.ones_like(feature_matrix, dtype=bool), k=1)
        sns.heatmap(feature_matrix, 
                   mask=mask,
                   annot=True, 
                   cmap='YlOrRd', 
                   square=True,
                   fmt='.0f',
                   cbar_kws={'label': f'Shared {feature_type.title()}'})
        
        plt.title(f'Document Similarity: Shared {feature_type.title()}', 
                 fontsize=14, fontweight='bold')
        plt.xlabel('Document Index')
        plt.ylabel('Document Index')
        
        # Add document labels on the side
        plt.figtext(0.02, 0.5, 
                   '\n'.join([f"D{i}: {doc[:40]}..." for i, doc in enumerate(documents)]), 
                   fontsize=8, verticalalignment='center')
        
        plt.tight_layout()
        plt.show()
    
    def compare_positive_negative_connections(self, graph_data: Dict, positive_docs: List[int]):
        """
        Compare edge features between positive and negative documents.
        """
        connections = graph_data['connections']
        
        # Categorize connections
        pos_pos_connections = []  # Positive to positive
        pos_neg_connections = []  # Positive to negative
        neg_neg_connections = []  # Negative to negative
        
        for conn in connections:
            i, j = conn['doc_i'], conn['doc_j']
            
            if i in positive_docs and j in positive_docs:
                pos_pos_connections.append(conn)
            elif (i in positive_docs and j not in positive_docs) or (i not in positive_docs and j in positive_docs):
                pos_neg_connections.append(conn)
            else:
                neg_neg_connections.append(conn)
        
        # Analyze each category
        categories = {
            'Positive-Positive': pos_pos_connections,
            'Positive-Negative': pos_neg_connections,
            'Negative-Negative': neg_neg_connections
        }
        
        print("CONNECTION TYPE ANALYSIS")
        print("=" * 40)
        
        for category, conns in categories.items():
            if conns:
                avg_concepts = np.mean([c['common_nodes'] for c in conns])
                avg_relations = np.mean([c['common_edges'] for c in conns])
                avg_strength = np.mean([c['strength'] for c in conns])
                
                print(f"{category}: {len(conns)} connections")
                print(f"  Avg shared concepts: {avg_concepts:.2f}")
                print(f"  Avg shared relations: {avg_relations:.2f}")
                print(f"  Avg strength: {avg_strength:.2f}")
                print()
            else:
                print(f"{category}: No connections")
                print()
        
        # Visualize comparison
        if any(categories.values()):
            self._plot_connection_comparison(categories)
    
    def _plot_connection_comparison(self, categories: Dict[str, List[Dict]]):
        """Plot comparison of connection types."""
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        
        metrics = ['common_nodes', 'common_edges', 'strength']
        titles = ['Shared Concepts', 'Shared Relations', 'Connection Strength']
        
        for idx, (metric, title) in enumerate(zip(metrics, titles)):
            data = []
            labels = []
            
            for category, conns in categories.items():
                if conns:
                    values = [conn[metric] for conn in conns]
                    data.append(values)
                    labels.append(f"{category}\n(n={len(conns)})")
            
            if data:
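                axes[idx].boxplot(data)
                # Set tick labels separately: the `labels` keyword of boxplot
                # was renamed to `tick_labels` in Matplotlib 3.9
                axes[idx].set_xticklabels(labels)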
                axes[idx].set_title(title)
                axes[idx].set_ylabel('Value')
                axes[idx].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()

# Analyze edge features
edge_analyzer = EdgeFeatureAnalyzer()

# Perform edge feature analysis
edge_analysis = edge_analyzer.analyze_edge_features(graph_data)

print("EDGE FEATURE ANALYSIS")
print("=" * 30)
for key, value in edge_analysis.items():
    print(f"{key}: {value}")
print()

# Visualize edge features
edge_analyzer.visualize_edge_features(graph_data)

# Create heatmaps
edge_analyzer.create_edge_feature_heatmap(graph_data, 'concepts')
edge_analyzer.create_edge_feature_heatmap(graph_data, 'relations')

# Compare positive vs negative document connections
positive_docs = [0, 1]  # First two documents are positive
edge_analyzer.compare_positive_negative_connections(graph_data, positive_docs)

## Comparative Analysis: Different Graph Construction Strategies

### Comparing AMR-based vs Alternative Approaches
Comparing AMR-based document graphs with simpler alternatives to see where the resulting connections differ.
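
A simple way to quantify how much two construction strategies agree is the Jaccard overlap of their edge sets. The sketch below assumes each strategy yields a symmetric 0/1 adjacency matrix, given here as nested lists. `edge_set` and `edge_jaccard` are hypothetical helpers, and the two toy matrices are made up for illustration rather than produced by any method in this notebook.

```python
def edge_set(adjacency):
    """Collect undirected edges (i, j) with i < j from a 0/1 adjacency matrix."""
    n = len(adjacency)
    return {(i, j) for i in range(n) for j in range(i + 1, n) if adjacency[i][j]}

def edge_jaccard(adj_a, adj_b):
    """Jaccard overlap of two edge sets; defined as 1.0 when both are empty."""
    edges_a, edges_b = edge_set(adj_a), edge_set(adj_b)
    union = edges_a | edges_b
    return len(edges_a & edges_b) / len(union) if union else 1.0

# Toy graphs over three documents (hypothetical).
graph_one = [[0, 1, 1], [1, 0, 0], [1, 0, 0]]   # edges: (0, 1), (0, 2)
graph_two = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]   # edges: (0, 1), (1, 2)
overlap = edge_jaccard(graph_one, graph_two)
print(f"{overlap:.3f}")  # 0.333
```

A low overlap means the two strategies connect different document pairs, which is where a closer look at individual edges is most informative.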

In [None]:
class GraphConstructionComparator:
    """
    Compares different document graph construction strategies:
    1. AMR-based (G-RAG approach)
    2. TF-IDF similarity
    3. Keyword overlap
    4. Embedding similarity (approximated here with TF-IDF vectors)
    
    This helps explain why the paper favors the AMR-based approach.
    """
    
    def __init__(self):
        self.tfidf_vectorizer = TfidfVectorizer(max_features=1000, stop_words='english')
    
    def build_tfidf_graph(self, documents: List[str], threshold: float = 0.1) -> np.ndarray:
        """
        Build document graph using TF-IDF similarity.
        """
        # Compute TF-IDF vectors
        tfidf_matrix = self.tfidf_vectorizer.fit_transform(documents)
        
        # Compute cosine similarity
        similarity_matrix = cosine_similarity(tfidf_matrix)
        
        # Create adjacency matrix with threshold
        adjacency = (similarity_matrix > threshold).astype(int)
        
        # Remove self-connections
        np.fill_diagonal(adjacency, 0)
        
        return adjacency
    
    def build_keyword_overlap_graph(self, documents: List[str], min_overlap: int = 2) -> np.ndarray:
        """
        Build document graph using keyword overlap.
        """
        n_docs = len(documents)
        adjacency = np.zeros((n_docs, n_docs), dtype=int)
        
        # Common words to ignore (built once, outside the per-document loop)
        stop_words = {'that', 'with', 'have', 'this', 'will', 'your', 'from', 'they', 'know', 
                      'want', 'been', 'good', 'much', 'some', 'time', 'very', 'when', 'come',
                      'here', 'just', 'like', 'long', 'make', 'many', 'over', 'such', 'take',
                      'than', 'them', 'well', 'were'}
        
        # Extract keywords from each document: words of 4+ letters that are not stop words
        doc_keywords = []
        for doc in documents:
            words = re.findall(r'\b[a-zA-Z]{4,}\b', doc.lower())
            doc_keywords.append({w for w in words if w not in stop_words})
        
        # Compute overlap
        for i in range(n_docs):
            for j in range(i + 1, n_docs):
                overlap = len(doc_keywords[i] & doc_keywords[j])
                if overlap >= min_overlap:
                    adjacency[i, j] = adjacency[j, i] = 1
        
        return adjacency
    
    def build_embedding_similarity_graph(self, documents: List[str], threshold: float = 0.5) -> np.ndarray:
        """
        Build document graph using a stand-in for embedding similarity.
        
        Simplified for demonstration: low-dimensional TF-IDF vectors stand in
        for learned dense embeddings, so results will resemble the TF-IDF graph.
        """
        # TF-IDF vectors used as a proxy for dense embeddings
        vectorizer = TfidfVectorizer(max_features=100)
        embeddings = vectorizer.fit_transform(documents).toarray()
        
        # Compute similarity
        similarity_matrix = cosine_similarity(embeddings)
        
        # Create adjacency matrix
        adjacency = (similarity_matrix > threshold).astype(int)
        np.fill_diagonal(adjacency, 0)
        
        return adjacency
    
    def compare_graph_methods(self, question: str, documents: List[str], 
                            positive_docs: List[int], amr_graphs: List[nx.DiGraph]):
        """
        Compare different graph construction methods.
        """
        print("GRAPH CONSTRUCTION COMPARISON")
        print("=" * 50)
        print(f"Question: {question}")
        print(f"Documents: {len(documents)}")
        print(f"Positive documents: {positive_docs}")
        print()
        
        # 1. AMR-based graph (G-RAG approach)
        graph_builder = DocumentGraphBuilder(min_shared_concepts=1)
        amr_graph_data = graph_builder.build_document_graph(question, documents, amr_graphs)
        amr_adjacency = amr_graph_data['adjacency']
        
        # 2. TF-IDF based graph
        tfidf_adjacency = self.build_tfidf_graph(documents, threshold=0.1)
        
        # 3. Keyword overlap graph
        keyword_adjacency = self.build_keyword_overlap_graph(documents, min_overlap=2)
        
        # 4. Embedding similarity graph
        embedding_adjacency = self.build_embedding_similarity_graph(documents, threshold=0.3)
        
        # Analyze each method
        methods = {
            'AMR-based (G-RAG)': amr_adjacency,
            'TF-IDF Similarity': tfidf_adjacency,
            'Keyword Overlap': keyword_adjacency,
            'Embedding Similarity': embedding_adjacency
        }
        
        results = {}
        for method_name, adjacency in methods.items():
            analysis = self._analyze_adjacency_matrix(adjacency, positive_docs)
            results[method_name] = analysis
            
            print(f"{method_name}:")
            for key, value in analysis.items():
                print(f"  {key}: {value}")
            print()
        
        # Visualize comparison
        self._visualize_method_comparison(methods, positive_docs, documents)
        
        return results
    
    def _analyze_adjacency_matrix(self, adjacency: np.ndarray, positive_docs: List[int]) -> Dict:
        """
        Analyze properties of an adjacency matrix.
        """
        n_docs = adjacency.shape[0]
        
        # Basic graph properties
        total_edges = np.sum(adjacency) // 2  # Undirected graph
        density = total_edges / (n_docs * (n_docs - 1) / 2) if n_docs > 1 else 0
        
        # Connected components
        G = nx.from_numpy_array(adjacency)
        components = list(nx.connected_components(G))
        largest_component = max(components, key=len) if components else set()
        
        # Positive document connectivity
        positive_connections = 0
        for i in positive_docs:
            for j in positive_docs:
                if i != j and adjacency[i, j] == 1:
                    positive_connections += 1
        positive_connections //= 2  # Undirected
        
        # Mixed connections (positive to negative)
        mixed_connections = 0
        for i in positive_docs:
            for j in range(n_docs):
                if j not in positive_docs and adjacency[i, j] == 1:
                    mixed_connections += 1
        
        return {
            'total_edges': total_edges,
            'density': round(density, 3),
            'connected_components': len(components),
            'largest_component_size': len(largest_component),
            'positive_connections': positive_connections,
            'mixed_connections': mixed_connections,
            'avg_degree': round(2 * total_edges / n_docs, 2) if n_docs > 0 else 0
        }
    
    def _visualize_method_comparison(self, methods: Dict[str, np.ndarray], 
                                   positive_docs: List[int], documents: List[str]):
        """
        Visualize adjacency matrices for different methods.
        """
        fig, axes = plt.subplots(2, 2, figsize=(16, 12))
        axes = axes.flatten()
        
        for idx, (method_name, adjacency) in enumerate(methods.items()):
            ax = axes[idx]
            
            # Create heatmap
            im = ax.imshow(adjacency, cmap='RdYlBu_r', vmin=0, vmax=1)
            
            # Highlight positive documents
            for pos_doc in positive_docs:
                # Add border around positive documents
                rect = plt.Rectangle((pos_doc-0.5, pos_doc-0.5), 1, 1, 
                                   fill=False, edgecolor='red', linewidth=3)
                ax.add_patch(rect)
            
            ax.set_title(f'{method_name}\n({np.sum(adjacency)//2} connections)')
            ax.set_xlabel('Document Index')
            ax.set_ylabel('Document Index')
            
            # Add grid
            ax.set_xticks(range(len(documents)))
            ax.set_yticks(range(len(documents)))
            ax.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
        # Print document reference
        print("Document Reference:")
        for i, doc in enumerate(documents):
            status = "[POS]" if i in positive_docs else "[NEG]"
            print(f"  D{i} {status}: {doc[:60]}...")

# Run comparison
comparator = GraphConstructionComparator()

# Compare different graph construction methods
comparison_results = comparator.compare_graph_methods(
    sample_question, sample_documents, positive_docs, amr_graphs
)

print("\nKEY INSIGHTS:")
print("=" * 30)
print("• AMR-based graphs connect documents through shared concepts rather than surface-level similarity")
print("• TF-IDF may miss semantic connections that AMR captures")
print("• Keyword overlap ignores relations between words, so it can link unrelated documents")
print("• AMR's structured representation lets edges carry shared-concept and shared-relation counts")
print("• The graph structure directly impacts the GNN's ability to propagate information")

## Scalability Analysis

### Understanding Performance with Different Document Set Sizes
Analyzing how document graph construction scales and performs with varying numbers of documents.

In [None]:
class ScalabilityAnalyzer:
    """
    Analyzes scalability of document graph construction.
    
    Tests performance with different document set sizes
    and provides insights for large-scale deployment.
    """
    
    def __init__(self):
        self.doc_amr_sim = DocumentAMRSimulator()
        self.graph_builder = DocumentGraphBuilder()
    
    def generate_test_documents(self, base_question: str, n_docs: int) -> Tuple[List[str], List[int]]:
        """
        Generate test documents for scalability analysis.
        
        Returns:
            (documents, positive_indices)
        """
        # Base documents with known positive examples
        base_docs = [
            "Frank Sinatra was known for his bright blue eyes, earning him the nickname 'Ol' Blue Eyes'.",
            "The famous singer Sinatra performed with distinctive blue eyes that captivated audiences.",
            "His blue eyes made Frank Sinatra instantly recognizable to fans everywhere."
        ]
        
        # Generate additional documents with varying relevance
        templates = [
            "The musician {name} was known for {characteristic} in the entertainment industry.",
            "Many singers like {name} have performed on stage with {characteristic}.",
            "The artist {name} gained fame through {characteristic} and musical talent.",
            "In the music world, {name} stood out due to {characteristic}.",
            "Entertainment history remembers {name} for {characteristic} and performances."
        ]
        
        names = ['Smith', 'Johnson', 'Williams', 'Brown', 'Davis', 'Miller', 'Wilson', 'Moore']
        characteristics = ['unique voice', 'stage presence', 'musical style', 'performance skills', 
                          'artistic vision', 'vocal range', 'charisma', 'talent']
        
        documents = base_docs[:n_docs]  # Never return more than n_docs documents
        positive_indices = list(range(len(documents)))  # Base positive documents
        
        # Add generated documents
        for i in range(n_docs - len(base_docs)):
            template = np.random.choice(templates)
            name = np.random.choice(names)
            char = np.random.choice(characteristics)
            
            # Occasionally add relevant documents
            if np.random.random() < 0.2:  # 20% chance of relevance
                char = 'blue eyes and distinctive appearance'
                if np.random.random() < 0.5:
                    name = 'Sinatra'
                positive_indices.append(len(documents))
            
            doc = template.format(name=name, characteristic=char)
            documents.append(doc)
        
        return documents, positive_indices
    
    def measure_scalability(self, document_sizes: List[int]) -> List[Dict]:
        """
        Measure performance across different document set sizes.
        
        Returns one result dictionary per tested size.
        """
        import time  # Local import; only needed for timing in this method
        
        question = "What is the nickname of Frank Sinatra?"
        results = []
        
        print("SCALABILITY ANALYSIS")
        print("=" * 30)
        
        for n_docs in document_sizes:
            print(f"\nTesting with {n_docs} documents...")
            
            # Generate test documents
            documents, positive_indices = self.generate_test_documents(question, n_docs)
            
            # Create AMR graphs (perf_counter is a monotonic, high-resolution timer)
            start_time = time.perf_counter()
            
            amr_graphs = []
            for doc in documents:
                amr = self.doc_amr_sim.create_document_amr(question, doc)
                amr_graphs.append(amr)
            
            amr_time = time.perf_counter() - start_time
            
            # Build document graph
            start_time = time.perf_counter()
            graph_data = self.graph_builder.build_document_graph(question, documents, amr_graphs)
            graph_time = time.perf_counter() - start_time
            
            # Analyze results
            analysis = self.graph_builder.analyze_graph_structure(graph_data)
            
            result = {
                'n_documents': n_docs,
                'positive_docs': len(positive_indices),
                'amr_processing_time': amr_time,
                'graph_construction_time': graph_time,
                'total_time': amr_time + graph_time,
                'total_connections': analysis['total_connections'],
                'graph_density': analysis['graph_density'],
                'connected_documents': analysis['connected_documents'],
                'largest_component_size': analysis['largest_component_size']
            }
            
            results.append(result)
            
            print(f"  AMR processing: {amr_time:.3f}s")
            print(f"  Graph construction: {graph_time:.3f}s")
            print(f"  Total connections: {analysis['total_connections']}")
            print(f"  Graph density: {analysis['graph_density']:.3f}")
        
        return results
    
    def visualize_scalability_results(self, results: List[Dict]):
        """
        Visualize scalability analysis results.
        """
        df = pd.DataFrame(results)
        
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        axes = axes.flatten()
        
        # Plot 1: Processing time vs document count
        axes[0].plot(df['n_documents'], df['amr_processing_time'], 'o-', label='AMR Processing')
        axes[0].plot(df['n_documents'], df['graph_construction_time'], 's-', label='Graph Construction')
        axes[0].plot(df['n_documents'], df['total_time'], '^-', label='Total Time')
        axes[0].set_xlabel('Number of Documents')
        axes[0].set_ylabel('Time (seconds)')
        axes[0].set_title('Processing Time vs Document Count')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)
        
        # Plot 2: Time complexity analysis
        axes[1].plot(df['n_documents'], df['total_time'] / (df['n_documents'] ** 2), 'o-')
        axes[1].set_xlabel('Number of Documents')
        axes[1].set_ylabel('Time / n²')
        axes[1].set_title('Time Complexity Analysis')
        axes[1].grid(True, alpha=0.3)
        
        # Plot 3: Number of connections vs document count
        axes[2].plot(df['n_documents'], df['total_connections'], 'o-', color='green')
        axes[2].set_xlabel('Number of Documents')
        axes[2].set_ylabel('Total Connections')
        axes[2].set_title('Graph Connections vs Document Count')
        axes[2].grid(True, alpha=0.3)
        
        # Plot 4: Graph density vs document count
        axes[3].plot(df['n_documents'], df['graph_density'], 's-', color='red')
        axes[3].set_xlabel('Number of Documents')
        axes[3].set_ylabel('Graph Density')
        axes[3].set_title('Graph Density vs Document Count')
        axes[3].grid(True, alpha=0.3)
        
        # Plot 5: Connected vs total documents
        axes[4].plot(df['n_documents'], df['connected_documents'], 'o-', label='Connected')
        axes[4].plot(df['n_documents'], df['n_documents'], '--', alpha=0.5, label='Total')
        axes[4].set_xlabel('Total Documents')
        axes[4].set_ylabel('Documents')
        axes[4].set_title('Connected vs Total Documents')
        axes[4].legend()
        axes[4].grid(True, alpha=0.3)
        
        # Plot 6: Largest component size
        axes[5].plot(df['n_documents'], df['largest_component_size'], '^-', color='purple')
        axes[5].set_xlabel('Number of Documents')
        axes[5].set_ylabel('Largest Component Size')
        axes[5].set_title('Largest Connected Component')
        axes[5].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
        # Print scalability insights
        print("\nSCALABILITY INSIGHTS:")
        print("=" * 30)
        
        if len(results) > 1:
            first, last = results[0], results[-1]
            doc_ratio = last['n_documents'] / first['n_documents']
            
            # Guard against a zero baseline when the smallest run is too fast to time
            if first['total_time'] > 0:
                # Values near 1 suggest linear scaling; values near doc_ratio suggest quadratic scaling
                time_growth_rate = (last['total_time'] / first['total_time']) / doc_ratio
                print(f"• Time growth rate: {time_growth_rate:.2f}x relative to document count")
            
            avg_density = np.mean([r['graph_density'] for r in results])
            print(f"• Average graph density: {avg_density:.3f}")
            
            connection_efficiency = last['total_connections'] / (last['n_documents'] * (last['n_documents'] - 1) / 2)
            print(f"• Connection efficiency at largest scale: {connection_efficiency:.3f}")
        
        print("• AMR processing is expected to scale roughly linearly with document count")
        print("• Graph construction is expected to scale quadratically (pairwise comparisons)")
        print("• Graph density often decreases as document count increases")
        print("• Connected components provide natural clustering for large document sets")

# Run scalability analysis
scalability_analyzer = ScalabilityAnalyzer()

# Test with different document sizes
test_sizes = [5, 10, 15, 20, 25]
scalability_results = scalability_analyzer.measure_scalability(test_sizes)

# Visualize results
scalability_analyzer.visualize_scalability_results(scalability_results)

print("\nScalability analysis complete!")
print("This provides insights for deploying G-RAG at different scales.")

## Integration with GNN Architecture

### Preparing Document Graph Data for Graph Neural Networks
Bridge between document graph construction and the GNN-based reranking module.

In [None]:
class DocumentGraphToGNNBridge:
    """
    Prepares document graph data for consumption by Graph Neural Networks.
    
    Implements the data preparation pipeline that connects document graph
    construction to the GNN-based reranking described in the paper.
    """
    
    def __init__(self):
        pass
    
    def prepare_gnn_data(self, graph_data: Dict, question: str) -> Dict:
        """
        Prepare document graph data for GNN processing.
        
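        Returns a dictionary of NumPy arrays laid out like PyTorch Geometric
        inputs (`edge_index` in COO format, one `edge_attr` row per edge).
        Converting them to tensors is left to the caller.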
        """
        adjacency = graph_data['adjacency']
        normalized_edge_features = graph_data['normalized_edge_features']
        documents = graph_data['documents']
        amr_graphs = graph_data['amr_graphs']
        n_docs = graph_data['n_documents']
        
        # Create edge index (COO format for PyTorch Geometric)
        edge_indices = np.where(adjacency)
        edge_index = np.vstack([edge_indices[0], edge_indices[1]])
        
        # Extract edge attributes
        edge_attr = []
        for i, j in zip(edge_indices[0], edge_indices[1]):
            edge_attr.append(normalized_edge_features[i, j])
        edge_attr = np.array(edge_attr) if edge_attr else np.empty((0, 2))
        
        # Prepare node features (to be combined with document embeddings)
        node_features = []
        for i in range(n_docs):
            amr_stats = self._extract_amr_statistics(amr_graphs[i])
            node_features.append({
                'document': documents[i],
                'amr_sequence': self._extract_amr_sequence(amr_graphs[i]),
                'amr_stats': amr_stats,
                'graph_position': i
            })
        
        gnn_data = {
            'edge_index': edge_index,
            'edge_attr': edge_attr,
            'node_features': node_features,
            'num_nodes': n_docs,
            'question': question,
            'adjacency': adjacency,
            'raw_graph_data': graph_data
        }
        
        return gnn_data
    
    def _extract_amr_statistics(self, amr_graph: nx.DiGraph) -> Dict:
        """
        Extract statistical features from AMR graph for node features.
        """
        return {
            'num_nodes': amr_graph.number_of_nodes(),
            'num_edges': amr_graph.number_of_edges(),
            'has_question_node': 'question' in amr_graph.nodes(),
            'avg_degree': np.mean(list(dict(amr_graph.degree()).values())) if amr_graph.number_of_nodes() > 0 else 0,
            'density': nx.density(amr_graph),
            'domain': amr_graph.graph.get('domain', 'unknown')
        }
    
    def _extract_amr_sequence(self, amr_graph: nx.DiGraph) -> str:
        """
        Extract AMR sequence using shortest paths from question node.
        
        This replicates the methodology from the previous notebook.
        """
        if 'question' not in amr_graph.nodes():
            return ''
        
        try:
            # Find shortest paths from question node
            paths = nx.single_source_shortest_path(amr_graph, 'question')
            
            # Extract concepts from paths (excluding question node)
            concepts = []
            for path in paths.values():
                concepts.extend(path[1:])  # Skip question node
            
            # Remove duplicates while preserving order
            unique_concepts = []
            seen = set()
            for concept in concepts:
                if concept not in seen:
                    unique_concepts.append(concept)
                    seen.add(concept)
            
            return ' '.join(unique_concepts)
        except (nx.NetworkXException, TypeError):
            # Graph traversal failed or node labels are not strings
            return ''
    
    def create_mock_embeddings(self, gnn_data: Dict, embedding_dim: int = 128) -> np.ndarray:
        """
        Create mock document embeddings for demonstration.
        
        In practice, these would come from a pre-trained language model
        like BERT, RoBERTa, or sentence transformers.
        """
        node_features = gnn_data['node_features']
        num_nodes = gnn_data['num_nodes']
        
        # Create embeddings based on document and AMR content
        embeddings = []
        
        for nf in node_features:
            # Combine document text and AMR sequence
            combined_text = f"{nf['document']} {nf['amr_sequence']}"
            
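            # Simple hash-based embedding (for demonstration).
            # zlib.crc32 is stable across runs, unlike the built-in hash() for strings,
            # and a local Generator avoids reseeding NumPy's global RNG.
            import zlib
            rng = np.random.default_rng(zlib.crc32(combined_text.encode('utf-8')))
            embedding = rng.standard_normal(embedding_dim)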
            
            # Add AMR statistics as additional features
            stats = nf['amr_stats']
            stat_features = np.array([
                stats['num_nodes'] / 20.0,  # Normalize
                stats['num_edges'] / 20.0,
                float(stats['has_question_node']),
                stats['avg_degree'] / 5.0,
                stats['density']
            ])
            
            # Combine text embedding with AMR features
            full_embedding = np.concatenate([embedding, stat_features])
            embeddings.append(full_embedding)
        
        return np.array(embeddings)
    
    def visualize_gnn_data_structure(self, gnn_data: Dict, embeddings: np.ndarray):
        """
        Visualize the structure of data prepared for GNN processing.
        """
        print("GNN DATA STRUCTURE")
        print("=" * 30)
        
        print(f"Number of nodes (documents): {gnn_data['num_nodes']}")
        print(f"Number of edges: {gnn_data['edge_index'].shape[1]}")
        print(f"Edge feature dimension: {gnn_data['edge_attr'].shape[1] if gnn_data['edge_attr'].size > 0 else 0}")
        print(f"Node embedding dimension: {embeddings.shape[1]}")
        print(f"Question: {gnn_data['question']}")
        print()
        
        # Show sample node features
        print("Sample Node Features:")
        for i, nf in enumerate(gnn_data['node_features'][:3]):
            print(f"\nNode {i}:")
            print(f"  Document: {nf['document'][:60]}...")
            print(f"  AMR sequence: {nf['amr_sequence'][:60]}...")
            print(f"  AMR stats: {nf['amr_stats']}")
            print(f"  Embedding shape: {embeddings[i].shape}")
        
        # Show edge information
        if gnn_data['edge_index'].shape[1] > 0:
            print(f"\nSample Edges:")
            for i in range(min(5, gnn_data['edge_index'].shape[1])):
                src = gnn_data['edge_index'][0, i]
                dst = gnn_data['edge_index'][1, i]
                attr = gnn_data['edge_attr'][i] if gnn_data['edge_attr'].size > 0 else "No attributes"
                print(f"  Edge {i}: Node {src} ↔ Node {dst}, Attributes: {attr}")
        else:
            print("\nNo edges in the graph.")
        
        # Visualize embedding space
        self._plot_embedding_space(embeddings, gnn_data)
    
    def _plot_embedding_space(self, embeddings: np.ndarray, gnn_data: Dict):
        """
        Plot 2D projection of embedding space.
        """
        from sklearn.decomposition import PCA
        
        # Project to 2D using PCA
        pca = PCA(n_components=2)
        embeddings_2d = pca.fit_transform(embeddings)
        
        plt.figure(figsize=(12, 8))
        
        # Plot nodes
        for i, (x, y) in enumerate(embeddings_2d):
            plt.scatter(x, y, s=200, alpha=0.7, label=f'Doc {i}')
            plt.annotate(f'D{i}', (x, y), xytext=(5, 5), textcoords='offset points')
        
        # Draw edges
        edge_index = gnn_data['edge_index']
        for i in range(edge_index.shape[1]):
            src, dst = edge_index[0, i], edge_index[1, i]
            x1, y1 = embeddings_2d[src]
            x2, y2 = embeddings_2d[dst]
            plt.plot([x1, x2], [y1, y2], 'k-', alpha=0.3, linewidth=1)
        
        plt.title('Document Embedding Space (2D Projection)\nConnected documents in graph')
        plt.xlabel(f'PC1 (explains {pca.explained_variance_ratio_[0]:.1%} variance)')
        plt.ylabel(f'PC2 (explains {pca.explained_variance_ratio_[1]:.1%} variance)')
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

# Demonstrate the bridge
bridge = DocumentGraphToGNNBridge()

# Prepare GNN data from our document graph
gnn_data = bridge.prepare_gnn_data(graph_data, sample_question)

# Create mock embeddings
embeddings = bridge.create_mock_embeddings(gnn_data, embedding_dim=128)

# Visualize the prepared data
bridge.visualize_gnn_data_structure(gnn_data, embeddings)

print("\nData preparation complete!")
print("This data is now ready for the GNN-based reranking module.")
print("The next notebook will cover the GNN architecture and training process.")

## Key Insights and Learning Summary

### 🎯 What We've Mastered:

#### 1. **Document Graph Construction Methodology**
- **Node Definition**: Each document becomes a node in the graph
- **Edge Creation**: Edges connect documents with shared AMR concepts
- **Feature Computation**: Edge features capture both shared nodes and relations
- **Normalization**: Edge features are normalized so their magnitudes do not blow up during GNN message passing
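
The points above can be sketched in a few lines. This is a minimal illustration and not the notebook's `DocumentGraphBuilder`: the function name `build_document_graph_sketch`, the `ignore` parameter (for nodes such as `'question'` that every AMR graph may share by construction), and the row-sum normalization are assumptions made for this example, and the paper's exact normalization may differ. Relation labels are ignored here, so two AMR edges count as shared when their endpoints match.

```python
import itertools

import networkx as nx
import numpy as np


def build_document_graph_sketch(amr_graphs, ignore=frozenset()):
    """Build adjacency and 2-d edge features from per-document AMR graphs.

    Hypothetical helper for illustration only.
    """
    n = len(amr_graphs)
    adjacency = np.zeros((n, n), dtype=int)
    edge_features = np.zeros((n, n, 2), dtype=float)

    node_sets = [set(g.nodes()) - set(ignore) for g in amr_graphs]
    edge_sets = [
        {(u, v) for u, v in g.edges() if u not in ignore and v not in ignore}
        for g in amr_graphs
    ]

    for i, j in itertools.combinations(range(n), 2):
        common_nodes = node_sets[i] & node_sets[j]
        if not common_nodes:
            continue  # no shared concept, so no edge between documents i and j
        common_edges = edge_sets[i] & edge_sets[j]
        adjacency[i, j] = adjacency[j, i] = 1
        edge_features[i, j] = edge_features[j, i] = (len(common_nodes), len(common_edges))

    # Assumed normalization: divide each feature channel by its row sum.
    row_sums = edge_features.sum(axis=1, keepdims=True)
    normalized = np.divide(
        edge_features, row_sums,
        out=np.zeros_like(edge_features), where=row_sums > 0,
    )

    # Isolated documents (no edges) are dropped, as in Section 3.1 of the paper.
    kept = [i for i in range(n) if adjacency[i].any()]
    return adjacency, edge_features, normalized, kept


# Toy AMR graphs: documents 0 and 1 share two concepts and one relation.
g0 = nx.DiGraph([("sinatra", "eyes"), ("eyes", "blue")])
g1 = nx.DiGraph([("sinatra", "eyes"), ("sinatra", "singer")])
g2 = nx.DiGraph([("paris", "france")])
adjacency, raw, normalized, kept = build_document_graph_sketch([g0, g1, g2])
print(raw[0, 1], kept)
```

Document 2 shares nothing with the others, so it ends up isolated and is excluded from `kept`.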

#### 2. **AMR-Based Semantic Connections**
- **Beyond Keywords**: AMR captures semantic relationships that keyword matching misses
- **Structured Representation**: AMR encodes concepts and their relations explicitly, which can yield more reliable connections than TF-IDF similarity
- **Domain Awareness**: Documents tend to cluster by semantic domain
- **Transitive Relevance**: Documents gain relevance through their connections

#### 3. **Edge Feature Analysis**
- **Two-Dimensional Features**: [common_nodes, common_edges] as described in paper
- **Connection Strength**: Combination of shared concepts and relations
- **Positive Document Patterns**: Positive documents often have stronger connections
- **Graph Topology**: Structure directly impacts information flow in GNN
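
To check the "positive documents often have stronger connections" pattern on a given graph, one option is to average connection strength separately for positive-positive and positive-negative pairs. The sketch below assumes an `(n, n, 2)` edge feature array like the one built earlier in this notebook and defines strength as shared nodes plus shared edges; both the helper name `connection_strength_by_group` and that definition of strength are assumptions for illustration.

```python
import numpy as np


def connection_strength_by_group(edge_features, positive_docs):
    """Mean strength of existing edges, split by pair type (hypothetical helper)."""
    edge_features = np.asarray(edge_features, dtype=float)
    n = edge_features.shape[0]
    # Assumed definition: strength = shared nodes + shared edges
    strength = edge_features.sum(axis=-1)
    positives = set(positive_docs)

    pos_pos, mixed = [], []
    for i in range(n):
        for j in range(i + 1, n):
            if strength[i, j] == 0:
                continue  # documents i and j are not connected
            i_pos, j_pos = i in positives, j in positives
            if i_pos and j_pos:
                pos_pos.append(strength[i, j])
            elif i_pos != j_pos:
                mixed.append(strength[i, j])

    def mean_or_zero(values):
        return float(np.mean(values)) if values else 0.0

    return {
        "positive_positive": mean_or_zero(pos_pos),
        "positive_negative": mean_or_zero(mixed),
    }


# Toy features: docs 0 and 1 (both positive) share 3 nodes and 2 edges,
# while doc 0 and doc 2 (negative) share a single node.
features = np.zeros((3, 3, 2))
features[0, 1] = features[1, 0] = (3, 2)
features[0, 2] = features[2, 0] = (1, 0)
strengths = connection_strength_by_group(features, positive_docs=[0, 1])
print(strengths)
```

A large gap between the two averages suggests the graph separates relevant documents well; a small gap suggests the edges carry little ranking signal.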

#### 4. **Scalability Considerations**
- **Time Complexity**: O(n²) for pairwise comparisons, but manageable for typical retrieval sets
- **Memory Efficiency**: Graph sparsity helps with large document collections
- **Connected Components**: Natural clustering for very large sets
- **Optimization Opportunities**: Caching and parallel processing
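
One way to reduce the pairwise cost in practice is an inverted index from concepts to documents, so that only pairs sharing at least one concept are ever compared. The sketch below is an assumption about how such an optimization could look, not something described in the paper; `candidate_pairs` and its input format (one collection of concept labels per document) are hypothetical. The worst case remains quadratic when a single concept appears in every document.

```python
from collections import defaultdict
from itertools import combinations


def candidate_pairs(doc_concepts):
    """Return document index pairs (i, j), i < j, that share at least one concept."""
    # Inverted index: concept -> documents containing it (in increasing index order)
    index = defaultdict(list)
    for doc_id, concepts in enumerate(doc_concepts):
        for concept in set(concepts):  # set() guards against duplicate labels
            index[concept].append(doc_id)

    pairs = set()
    for doc_ids in index.values():
        pairs.update(combinations(doc_ids, 2))
    return pairs


doc_concepts = [
    {"sinatra", "eyes"},
    {"eyes", "singer"},
    {"paris"},
    {"singer"},
]
pairs = candidate_pairs(doc_concepts)
print(sorted(pairs))
```

Only the returned pairs need the full shared-node and shared-edge computation, which helps when the graph is sparse.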

### 🔬 Research Implications:

1. **Graph-Based Information Retrieval**: Document graphs enable new approaches to relevance ranking
2. **Semantic Clustering**: AMR-based connections reveal semantic document clusters
3. **Weak Signal Amplification**: Marginally relevant documents benefit from strong neighbors
4. **Cross-Document Reasoning**: Information flows through the graph structure
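
Weak signal amplification can be illustrated with a simple score-smoothing heuristic: each document's score is blended with the mean score of its neighbors. This is a hand-written stand-in for intuition only, not the GNN used in the paper; the function `propagate_scores` and the blending weight `alpha` are assumptions for this example.

```python
import numpy as np


def propagate_scores(scores, adjacency, alpha=0.5):
    """Blend each score with the mean score of its graph neighbors (illustrative)."""
    scores = np.asarray(scores, dtype=float)
    adjacency = np.asarray(adjacency, dtype=float)
    degree = adjacency.sum(axis=1)
    # Isolated documents keep their own score as the "neighbor mean"
    neighbor_mean = np.divide(
        adjacency @ scores, degree, out=scores.copy(), where=degree > 0
    )
    return (1 - alpha) * scores + alpha * neighbor_mean


# Doc 0 is clearly relevant, doc 1 is weakly scored but connected to doc 0,
# and doc 2 is isolated.
initial_scores = [0.9, 0.2, 0.3]
adjacency = np.array([
    [0, 1, 0],
    [1, 0, 0],
    [0, 0, 0],
])
updated_scores = propagate_scores(initial_scores, adjacency, alpha=0.5)
print(updated_scores)
```

After one step, document 1 overtakes the isolated document 2 because it borrows evidence from its strongly scored neighbor.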

### 🚀 Connection to GNN Processing:

The document graphs we've constructed provide:
- **Node Features**: Document embeddings + AMR statistics
- **Edge Features**: Shared concept/relation counts
- **Graph Structure**: Adjacency matrix for message passing
- **Semantic Context**: Rich representation for neural processing
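
To make the role of these inputs concrete, here is a single hand-written message-passing step in NumPy that consumes node features, a COO `edge_index`, and two-dimensional `edge_attr` rows. It is a sketch under stated assumptions: the function `message_passing_step` is hypothetical, edge weights are taken as the sum of the two edge features, and the update is a plain residual sum with no learned parameters. The actual GNN architecture may differ.

```python
import numpy as np


def message_passing_step(x, edge_index, edge_attr):
    """One weighted-mean aggregation step with a residual connection (illustrative)."""
    x = np.asarray(x, dtype=float)
    src, dst = np.asarray(edge_index)
    # Assumed scalar weight per edge: sum of its two edge features
    weights = np.asarray(edge_attr, dtype=float).sum(axis=1)

    aggregated = np.zeros_like(x)
    weight_sum = np.zeros(x.shape[0])
    # Accumulate weighted neighbor features at each destination node
    np.add.at(aggregated, dst, weights[:, None] * x[src])
    np.add.at(weight_sum, dst, weights)

    # Turn sums into weighted means where a node has incoming edges
    has_neighbors = weight_sum > 0
    aggregated[has_neighbors] /= weight_sum[has_neighbors, None]

    return x + aggregated  # nodes without neighbors keep their features


node_x = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
edge_index = np.array([[0, 1], [1, 0]])  # both directions of the edge 0-1
edge_attr = np.array([[0.5, 0.5], [0.5, 0.5]])
updated_x = message_passing_step(node_x, edge_index, edge_attr)
print(updated_x)
```

Storing both directions of each undirected edge, as `np.where(adjacency)` does for a symmetric matrix, lets information flow both ways in one step.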

### 💡 Practical Applications:

- **Search Result Reranking**: Improve relevance through document connections
- **Recommendation Systems**: Leverage item relationships for better recommendations
- **Content Discovery**: Find related content through semantic graphs
- **Knowledge Organization**: Automatic semantic clustering of documents

### 🔍 Key Differences from Traditional Approaches:

| Aspect | Traditional | G-RAG Document Graphs |
|--------|-------------|----------------------|
| **Connection Basis** | Keyword overlap | AMR semantic concepts |
| **Relationship Type** | Surface similarity | Deep semantic relations |
| **Information Flow** | Independent scoring | Graph-based propagation |
| **Weak Connections** | Often missed | Amplified through neighbors |
| **Computational Cost** | Linear in documents | Quadratic in documents, but manageable for typical retrieval sets |

This foundation in document graph construction sets the stage for the next focused learning notebook, which covers how Graph Neural Networks use these semantic connections to rerank documents.