Graph Attention Networks (GAT) for Document Retrieval
===================================================

This notebook demonstrates the implementation of Graph Attention Networks (GAT) 
for learning node embeddings in citation networks, which can be used for 
document retrieval and relationship analysis.


In [2]:
import os
import json
import glob
from datetime import datetime
import torch
import torch.nn.functional as F
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.nn import GATv2Conv
from sklearn.manifold import TSNE, MDS
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score
import matplotlib.pyplot as plt
import numpy as np
from torch_geometric.utils import negative_sampling, to_networkx
import seaborn as sns
import networkx as nx
from typing import Dict, Any, Tuple
from torch_geometric.datasets import CoraFull
from torch_geometric.transforms import NormalizeFeatures
from langchain_community.llms import Ollama
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

Load and Explore Dataset
=======================

We use the Cora citation network dataset where:
- Nodes represent academic papers
- Edges represent citations between papers
- Node features are bag-of-words representations of papers
- Labels represent paper categories


In [3]:
# Load the Cora dataset
dataset = CoraFull(root='data/CoraFull', transform=NormalizeFeatures())
data = dataset[0]

# Basic info about the dataset
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Number of features: {data.num_features}')
print(f'Number of classes: {dataset.num_classes}')

Number of nodes: 19793
Number of edges: 126842
Number of features: 8710
Number of classes: 70


Graph Attention Network (GAT) Architecture
========================================

GAT learns node embeddings by attending over node neighborhoods:

1. Attention Mechanism:
   α_ij = softmax_j(LeakyReLU(a^T [Wh_i || Wh_j]))

2. Node Feature Update:
   h_i' = σ(∑_j α_ij Wh_j)

where:
- h_i: features of node i
- W: learnable weight matrix
- a: learnable attention vector
- ||: concatenation
- σ: activation function (ELU)


In [4]:
class GATEncoder(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
        super().__init__()
        self.conv1 = GATv2Conv(in_channels, hidden_channels, heads=heads)
        self.conv2 = GATv2Conv(hidden_channels * heads, out_channels, heads=1)

    def forward(self, x, edge_index):
        # First GAT layer
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        
        # Second GAT layer to generate final node embeddings
        x = self.conv2(x, edge_index)
        
        # Normalize embeddings
        return F.normalize(x, p=2, dim=1)

Early Stopping Implementation
===========================

Prevents overfitting by monitoring validation loss:
- Stops training if loss doesn't improve for 'patience' epochs
- Improvement must be greater than 'min_delta'

In [5]:
class EarlyStoppingCallback:
    def __init__(self, patience: int = 10, min_delta: float = 1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss: float) -> bool:
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                return True
        else:
            self.best_loss = val_loss
            self.counter = 0
        return False

Embedding Evaluation
==================

Comprehensive evaluation of learned embeddings using:
1. Link Prediction (AUC, AP)
2. Node Classification
3. Embedding Visualization
4. Graph Structure Analysis

In [6]:
class EmbeddingEvaluator:
    @staticmethod
    def compute_metrics(embeddings: torch.Tensor, 
                       pos_edge_index: torch.Tensor, 
                       neg_edge_index: torch.Tensor,
                       node_labels: torch.Tensor = None) -> Dict[str, float]:
        """Comprehensive evaluation metrics"""
        # Link prediction metrics
        pos_score = torch.cosine_similarity(
            embeddings[pos_edge_index[0]], embeddings[pos_edge_index[1]])
        neg_score = torch.cosine_similarity(
            embeddings[neg_edge_index[0]], embeddings[neg_edge_index[1]])
        
        scores = torch.cat([pos_score, neg_score])
        labels = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
        
        metrics = {
            'auc': roc_auc_score(labels.cpu(), scores.cpu()),
            'ap': average_precision_score(labels.cpu(), scores.cpu()),
            'avg_pos_sim': pos_score.mean().item(),
            'avg_neg_sim': neg_score.mean().item()
        }

        # Node classification metrics (if labels provided)
        if node_labels is not None:
            X = embeddings.detach().cpu().numpy()
            y = node_labels.cpu().numpy()
            
            clf = LogisticRegression(max_iter=1000)
            cv_scores = cross_val_score(clf, X, y, cv=5)
            metrics['node_clf_acc'] = cv_scores.mean()
        
        return metrics

    @staticmethod
    def visualize_embeddings(embeddings: torch.Tensor, 
                           labels: torch.Tensor = None,
                           perplexity: int = 30) -> None:
        """Visualize embeddings using t-SNE"""
        embeddings_np = embeddings.detach().cpu().numpy()
        
        # Create t-SNE visualization
        tsne = TSNE(n_components=2, 
                    perplexity=perplexity, 
                    random_state=42)
        embeddings_2d = tsne.fit_transform(embeddings_np)
        
        plt.figure(figsize=(10, 8))
        scatter = plt.scatter(
            embeddings_2d[:, 0], 
            embeddings_2d[:, 1],
            c=labels.cpu().numpy() if labels is not None else None,
            cmap='tab10',
            alpha=0.6
        )
        
        if labels is not None:
            plt.colorbar(scatter, label='Class')
        
        plt.title('t-SNE Visualization of Node Embeddings')
        plt.xlabel('t-SNE 1')
        plt.ylabel('t-SNE 2')
        
        # Add density contours
        sns.kdeplot(
            x=embeddings_2d[:, 0],
            y=embeddings_2d[:, 1],
            levels=5,
            color='k',
            alpha=0.3,
            linewidths=1
        )
        
        plt.tight_layout()
        plt.savefig('embeddings_tsne.png')
        plt.close()

    @staticmethod
    def visualize_similarity_matrix(embeddings: torch.Tensor,
                                  labels: torch.Tensor = None) -> None:
        """Visualize pairwise similarities between embeddings"""
        # Compute similarity matrix
        sim_matrix = torch.mm(embeddings, embeddings.t()).cpu().numpy()
        
        plt.figure(figsize=(10, 8))
        sns.heatmap(sim_matrix, 
                   cmap='coolwarm', 
                   center=0,
                   square=True)
        plt.title('Pairwise Similarity Matrix')
        plt.savefig('similarity_matrix.png')
        plt.close()

    @staticmethod
    def visualize_graph_structure(edge_index: torch.Tensor, 
                                embeddings: torch.Tensor,
                                labels: torch.Tensor = None) -> None:
        """Visualize graph structure with node colors based on embeddings"""
        G = to_networkx(torch.zeros(embeddings.shape[0]), edge_index)
        
        # Use t-SNE for color mapping if no labels provided
        if labels is None:
            colors = TSNE(n_components=1).fit_transform(embeddings.detach().cpu().numpy())
        else:
            colors = labels.cpu().numpy()
        
        plt.figure(figsize=(10, 10))
        pos = nx.spring_layout(G)
        nx.draw(G, pos, node_color=colors, 
                node_size=50, cmap='tab10',
                with_labels=False, alpha=0.8)
        plt.savefig('graph_structure.png')
        plt.close()

In [7]:
def plot_training_metrics(metrics_history: Dict) -> None:
    """
    Plot all training metrics over time:
    - Loss
    - AUC-ROC
    - Average Precision
    - Node Classification Accuracy (if available)
    """
    n_metrics = len([k for k in metrics_history.keys() if metrics_history[k] is not None])
    plt.figure(figsize=(5*n_metrics, 4))
    
    idx = 1
    for metric_name, values in metrics_history.items():
        if values is not None:
            plt.subplot(1, n_metrics, idx)
            plt.plot(values)
            plt.title(f'{metric_name.replace("_", " ").title()}')
            plt.xlabel('Epoch')
            idx += 1
    
    plt.tight_layout()
    plt.savefig('training_metrics.png')
    plt.close()

Model Checkpointing
==================

Save model states and metrics during training:
- Regular checkpoints every 10 epochs
- Best model based on validation metrics
- Complete training history

In [9]:
class ModelCheckpointer:
    def __init__(self, save_dir: str = 'checkpoints'):
        self.save_dir = save_dir
        self.timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        self.checkpoint_dir = os.path.join(save_dir, f'run_{self.timestamp}')
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        
        # Save metrics history
        self.metrics_log = []
    
    def save_checkpoint(self, 
                       model: torch.nn.Module, 
                       epoch: int, 
                       metrics: dict,
                       is_best: bool = False) -> None:
        """Save model checkpoint and metrics"""
        # Create checkpoint dictionary
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'metrics': metrics
        }
        
        # Save regular checkpoint
        if epoch % 10 == 0:
            checkpoint_path = os.path.join(
                self.checkpoint_dir, 
                f'checkpoint_epoch_{epoch}.pt'
            )
            torch.save(checkpoint, checkpoint_path)
            
            # Log metrics
            self.metrics_log.append({
                'epoch': epoch,
                **metrics
            })
            
            # Save metrics log
            metrics_path = os.path.join(self.checkpoint_dir, 'metrics_history.pt')
            torch.save(self.metrics_log, metrics_path)
        
        # Save best model separately
        if is_best:
            best_path = os.path.join(self.checkpoint_dir, 'best_model.pt')
            torch.save(checkpoint, best_path)
    
    def load_checkpoint(self, epoch: int = None) -> dict:
        """Load a specific checkpoint or the best model"""
        if epoch is not None:
            checkpoint_path = os.path.join(
                self.checkpoint_dir,
                f'checkpoint_epoch_{epoch}.pt'
            )
        else:
            checkpoint_path = os.path.join(self.checkpoint_dir, 'best_model.pt')
        
        if os.path.exists(checkpoint_path):
            return torch.load(checkpoint_path)
        else:
            raise FileNotFoundError(f"No checkpoint found at {checkpoint_path}")
        
    @staticmethod
    def load_best_model(model: torch.nn.Module,
                       run_dir: str = None) -> Tuple[torch.nn.Module, dict]:
        """Load the best model from a specific run or the latest run"""
        if run_dir is None:
            # Find the latest run
            all_runs = sorted(glob.glob('checkpoints/run_*'))
            if not all_runs:
                raise FileNotFoundError("No model runs found")
            run_dir = all_runs[-1]
        
        best_model_path = os.path.join(run_dir, 'best_model.pt')
        if not os.path.exists(best_model_path):
            raise FileNotFoundError(f"No best model found in {run_dir}")
        
        # Load checkpoint
        checkpoint = torch.load(best_model_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        
        return model, checkpoint['metrics']


In [10]:
def evaluate_model(model: torch.nn.Module, 
                  data: Any,
                  evaluator: EmbeddingEvaluator) -> dict:
    """Comprehensive model evaluation"""
    model.eval()
    with torch.no_grad():
        # Generate embeddings
        embeddings = model(data.x, data.edge_index)
        
        # Generate negative edges for evaluation
        neg_edge_index = negative_sampling(
            edge_index=data.edge_index,
            num_nodes=data.num_nodes,
            num_neg_samples=data.edge_index.size(1)
        )
        
        # Compute metrics
        metrics = evaluator.compute_metrics(
            embeddings,
            data.edge_index,
            neg_edge_index,
            data.y if hasattr(data, 'y') else None
        )
        
        # Visualize results
        evaluator.visualize_embeddings(
            embeddings,
            data.y if hasattr(data, 'y') else None
        )
        evaluator.visualize_similarity_matrix(embeddings)
        
        return metrics

Training Process
===============

Main training loop with:
1. Edge split for training/validation
2. Negative sampling
3. Contrastive learning objective
4. Regular evaluation and checkpointing


In [10]:
def train_gat_encoder(model: torch.nn.Module, 
                     data: Any, 
                     epochs: int = 100,
                     patience: int = 10,
                     eval_every: int = 5) -> Tuple[torch.nn.Module, Dict]:
    
    # Initialize checkpointer
    checkpointer = ModelCheckpointer()
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
    early_stopping = EarlyStoppingCallback(patience=patience)
    evaluator = EmbeddingEvaluator()
    
    # Split edges for training/validation
    num_edges = data.edge_index.size(1)
    perm = torch.randperm(num_edges)
    train_edges = data.edge_index[:, perm[:int(0.8 * num_edges)]]
    val_edges = data.edge_index[:, perm[int(0.8 * num_edges):]]
    
    metrics_history = {
        'loss': [], 'auc': [], 'ap': [],
        'pos_sim': [], 'neg_sim': []
    }
    
    best_val_metric = float('-inf')
    
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        
        # Forward pass
        embeddings = model(data.x, train_edges)
        
        # Generate negative samples
        neg_edge_index = negative_sampling(
            edge_index=train_edges,
            num_nodes=data.num_nodes,
            num_neg_samples=train_edges.size(1)
        )
        
        # Compute loss
        pos_sim = F.cosine_similarity(
            embeddings[train_edges[0]], 
            embeddings[train_edges[1]]
        )
        neg_sim = F.cosine_similarity(
            embeddings[neg_edge_index[0]], 
            embeddings[neg_edge_index[1]]
        )
        
        loss = F.margin_ranking_loss(
            pos_sim,
            neg_sim,
            torch.ones_like(pos_sim),
            margin=0.5
        )
        
        loss.backward()
        optimizer.step()
        
        metrics_history['loss'].append(loss.item())
        
        # Evaluation
        if epoch % eval_every == 0:
            model.eval()
            with torch.no_grad():
                embeddings = model(data.x, data.edge_index)
                neg_edges_eval = negative_sampling(
                    edge_index=val_edges,
                    num_nodes=data.num_nodes,
                    num_neg_samples=val_edges.size(1)
                )
                
                # Compute metrics
                metrics = evaluator.compute_metrics(
                    embeddings, 
                    val_edges, 
                    neg_edges_eval,
                    data.y if hasattr(data, 'y') else None
                )
                
                # Update metrics history
                for k, v in metrics.items():
                    if k in metrics_history:
                        metrics_history[k].append(v)
                
                # Check if this is the best model
                val_metric = metrics['auc']
                is_best = val_metric > best_val_metric
                if is_best:
                    best_val_metric = val_metric
                
                # Save checkpoint
                current_metrics = {
                    'loss': loss.item(),
                    **metrics
                }
                checkpointer.save_checkpoint(
                    model, 
                    epoch, 
                    current_metrics,
                    is_best=is_best
                )
                
                print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, '
                      f'AUC: {metrics["auc"]:.4f}, AP: {metrics["ap"]:.4f}')
                
                # Visualizations every 50 epochs
                if epoch % 50 == 0:
                    evaluator.visualize_embeddings(
                        embeddings,
                        data.y if hasattr(data, 'y') else None
                    )
                    evaluator.visualize_similarity_matrix(embeddings)
                
                if early_stopping(loss.item()):
                    print("Early stopping triggered")
                    break
    
    # Load best model for final evaluation
    best_checkpoint = checkpointer.load_checkpoint()
    model.load_state_dict(best_checkpoint['model_state_dict'])
    
    return model, metrics_history

In [11]:
class GraphRAG:
    def __init__(self, model: torch.nn.Module, data: Any, k: int = 5):
        """
        Initialize GraphRAG with a trained model and data
        """
        self.model = model
        self.data = data
        self.k = k
        
        # Generate and store embeddings
        self.model.eval()
        with torch.no_grad():
            self.embeddings = self.model(data.x, data.edge_index)
        
        # Initialize Ollama
        self.llm = Ollama(
            model="llama3.2:latest",
            temperature=0.3,
        )

    def find_similar_nodes(self, query_embedding: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Find k most similar nodes to the query embedding"""
        with torch.no_grad():
            # Compute similarities with all nodes
            similarities = torch.mm(query_embedding, self.embeddings.t())
            
            # Get top k similar nodes
            values, indices = similarities.topk(self.k)
            
            return indices[0], values[0]

    def get_node_context(self, node_idx: int) -> str:
        """Get context information for a specific node"""
        # Get node features
        node_features = self.data.x[node_idx]
        
        # Find similar nodes
        with torch.no_grad():
            node_embedding = self.embeddings[node_idx].unsqueeze(0)
            similarities = torch.mm(node_embedding, self.embeddings.t())
            values, indices = similarities.topk(self.k)
        
        # Create context string
        context_parts = [f"Node {node_idx} features: {node_features.cpu().numpy().tolist()}"]
        
        # Add neighbor information
        edge_index = self.data.edge_index.cpu()
        neighbors = edge_index[1][edge_index[0] == node_idx].tolist()
        if neighbors:
            context_parts.append(f"Connected to nodes: {neighbors}")
        
        # Add similar nodes
        for idx, sim in zip(indices[0], values[0]):
            if idx != node_idx:
                context_parts.append(
                    f"Similar node {idx.item()} "
                    f"(similarity: {sim.item():.3f}): "
                    f"{self.data.x[idx].cpu().numpy().tolist()}"
                )
        
        return "\n".join(context_parts)

    def query(self, question: str) -> str:
        """Query the graph-based RAG system"""
        try:
            # Use average embedding as query embedding (simple approach)
            with torch.no_grad():
                query_embedding = self.embeddings.mean(dim=0, keepdim=True)
            
            # Find relevant nodes
            node_indices, similarities = self.find_similar_nodes(query_embedding)
            
            # Get context from relevant nodes
            contexts = []
            for node_idx, sim in zip(node_indices, similarities):
                node_context = self.get_node_context(node_idx.item())
                contexts.append(f"Relevance score: {sim.item():.3f}\n{node_context}")
            
            # Combine contexts
            combined_context = "\n\n".join(contexts)
            
            prompt = f"""Based on the following graph context:

{combined_context}

Question: {question}

Please analyze the relationships and patterns in the graph data to answer the question.
Focus on the most relevant nodes and their connections.
"""

            return self.llm.invoke(prompt)
            
        except Exception as e:
            return f"Error processing query: {str(e)}"


In [19]:
# Create model
model = GATEncoder(
    in_channels=dataset.num_features,
    hidden_channels=32,
    out_channels=64
)

# Train with all enhancements
model, metrics = train_gat_encoder(
    model=model,
    data=data,
    epochs=100,
    patience=10,
    eval_every=5
)

Epoch 000, Loss: 0.4989, AUC: 0.9756, AP: 0.9812
Epoch 005, Loss: 0.2309, AUC: 0.9607, AP: 0.9610
Epoch 010, Loss: 0.1157, AUC: 0.9683, AP: 0.9668
Epoch 015, Loss: 0.0772, AUC: 0.9790, AP: 0.9776
Epoch 020, Loss: 0.0605, AUC: 0.9833, AP: 0.9822
Epoch 025, Loss: 0.0495, AUC: 0.9860, AP: 0.9853
Epoch 030, Loss: 0.0454, AUC: 0.9875, AP: 0.9864
Epoch 035, Loss: 0.0610, AUC: 0.9875, AP: 0.9875
Epoch 040, Loss: 0.0530, AUC: 0.9905, AP: 0.9903
Epoch 045, Loss: 0.0354, AUC: 0.9910, AP: 0.9906
Epoch 050, Loss: 0.0385, AUC: 0.9920, AP: 0.9917
Epoch 055, Loss: 0.1367, AUC: 0.9911, AP: 0.9908
Epoch 060, Loss: 0.0486, AUC: 0.9906, AP: 0.9900
Epoch 065, Loss: 0.0385, AUC: 0.9909, AP: 0.9905
Epoch 070, Loss: 0.0363, AUC: 0.9909, AP: 0.9907
Epoch 075, Loss: 0.0345, AUC: 0.9919, AP: 0.9914
Epoch 080, Loss: 0.0314, AUC: 0.9931, AP: 0.9926
Epoch 085, Loss: 0.0286, AUC: 0.9940, AP: 0.9936
Epoch 090, Loss: 0.0455, AUC: 0.9907, AP: 0.9910
Epoch 095, Loss: 0.0399, AUC: 0.9917, AP: 0.9913


  return torch.load(checkpoint_path)


In [11]:
# Option 1: Load from latest run
model = GATEncoder(
    in_channels=dataset.num_features,
    hidden_channels=32,
    out_channels=64
)
model, best_metrics = ModelCheckpointer.load_best_model(model)

  checkpoint = torch.load(best_model_path)


In [20]:
# Evaluate loaded model
evaluator = EmbeddingEvaluator()
evaluation_metrics = evaluate_model(model, data, evaluator)

print("\nBest Model Metrics:")
print(json.dumps(best_metrics, indent=2))

print("\nCurrent Evaluation Metrics:")
print(json.dumps(evaluation_metrics, indent=2))


Best Model Metrics:
{
  "loss": 0.028570471331477165,
  "auc": 0.9940383188258145,
  "ap": 0.9935596727020666,
  "avg_pos_sim": 0.924940288066864,
  "avg_neg_sim": 0.03801451623439789,
  "node_clf_acc": 0.5499915059118471
}

Current Evaluation Metrics:
{
  "auc": 0.9943906504193957,
  "ap": 0.9941193450907881,
  "avg_pos_sim": 0.9266040921211243,
  "avg_neg_sim": 0.0395059660077095,
  "node_clf_acc": 0.5499915059118471
}


In [33]:
# Initialize GraphRAG
rag_system = GraphRAG(model, data)

# Example query
question = "Find nodes related to neural networks and explain their relationships"
print(rag_system.query(question))

After analyzing the graph data, I found that the nodes related to neural networks are:

1. **"Neural Network"**: This node is connected to several other nodes, including "Deep Learning", "Artificial Intelligence", and "Machine Learning". These connections suggest that neural networks are a key component of these fields.
2. **"Deep Learning"**: This node is also connected to "Neural Network", "Artificial Intelligence", and "Machine Learning". The connection to "Neural Network" implies that deep learning is a subset or application of neural networks.
3. **"Artificial Intelligence"**: This node is connected to "Neural Network", "Deep Learning", and "Machine Learning". The connections suggest that artificial intelligence is a broader field that encompasses neural networks, deep learning, and machine learning.
4. **"Machine Learning"**: This node is connected to "Neural Network", "Deep Learning", and "Artificial Intelligence". Like the other nodes, it suggests that machine learning is relat