In [1]:
# Install required packages if needed
# import subprocess; subprocess.run(['pip', 'install', 'torch-geometric', 'sentence-transformers', 'tqdm'], check=True)

In [2]:
import torch
from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, Linear
from torch_geometric.utils import train_test_split_edges, negative_sampling
from tqdm import tqdm
import json
import numpy as np

GNN_config = {
    'epoch': 10,
    'batch': 128,
    'dropout': 0.2,
    'early_stopping': 2,
    'learning_rate': 1e-3
}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


Using device: cuda


In [3]:
import os
from pathlib import Path
import shutil

base_dir = Path('.')
data_dir = base_dir / 'data'
checkpoints_dir = base_dir / 'checkpoints'

# Fix nested directory structure if files were transferred from Turing
if (base_dir / 'data' / 'data').exists():
    print("Moving files from data/data/ to data/...")
    for item in (base_dir / 'data' / 'data').iterdir():
        if item.is_file():
            shutil.move(str(item), str(data_dir / item.name))
    print("✓ Done")

if (base_dir / 'checkpoints' / 'checkpoints').exists():
    print("Moving files from checkpoints/checkpoints/ to checkpoints/...")
    for item in (base_dir / 'checkpoints' / 'checkpoints').iterdir():
        if item.is_file():
            shutil.move(str(item), str(checkpoints_dir / item.name))
    print("✓ Done")

data_dir.mkdir(exist_ok=True)
checkpoints_dir.mkdir(exist_ok=True)
print(f"Data: {data_dir.absolute()}\nCheckpoints: {checkpoints_dir.absolute()}")

Data: /home/upandit/mag_citation_recommender/data
Checkpoints: /home/upandit/mag_citation_recommender/checkpoints


In [4]:
import urllib.request

urls = {
    'train.txt.zip.001': 'https://github.com/QianWangWPI/Released-Microsoft-dataset/raw/main/train.txt.zip.001',
    'train.txt.zip.002': 'https://github.com/QianWangWPI/Released-Microsoft-dataset/raw/main/train.txt.zip.002',
    'test.txt': 'https://raw.githubusercontent.com/QianWangWPI/Released-Microsoft-dataset/refs/heads/main/test.txt',
    'val.txt': 'https://raw.githubusercontent.com/QianWangWPI/Released-Microsoft-dataset/refs/heads/main/val.txt'
}

print("Downloading dataset files...")
for filename, url in urls.items():
    filepath = data_dir / filename
    if not filepath.exists():
        urllib.request.urlretrieve(url, filepath)
        print(f"✓ {filename}")
    else:
        print(f"⊙ {filename} (exists)")

print("Download complete!")



Downloading dataset files...
⊙ train.txt.zip.001 (exists)
⊙ train.txt.zip.002 (exists)
⊙ test.txt (exists)
⊙ val.txt (exists)
Download complete!


In [5]:
import zipfile
import glob

train_txt_path = data_dir / "train.txt"
train_zip_path = data_dir / "train.txt.zip"

if not train_txt_path.exists():
    zip_parts = sorted(glob.glob(str(data_dir / "train.txt.zip.*")))
    if zip_parts:
        print("Combining zip parts...")
        with open(train_zip_path, 'wb') as outfile:
            for part in zip_parts:
                with open(part, 'rb') as infile:
                    outfile.write(infile.read())
        print("✓ Combined")
    
    if train_zip_path.exists():
        print("Extracting...")
        with zipfile.ZipFile(train_zip_path, 'r') as zip_ref:
            zip_ref.extractall(data_dir)
        print("✓ Extracted")
else:
    print("train.txt already exists")

train.txt already exists


In [6]:
import json

with open(data_dir / "train.txt", encoding="utf-8") as f:
    train_data = json.load(f)
with open(data_dir / "val.txt", encoding="utf-8") as f:
    val_data = json.load(f)
with open(data_dir / "test.txt", encoding="utf-8") as f:
    test_data = json.load(f)

print(f"Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}")


Train: 42000, Val: 9000, Test: 9000


In [None]:
train_data[0].keys()

dict_keys(['publication_ID', 'Citations', 'pubDate', 'language', 'title', 'journal', 'abstract', 'keywords', 'authors', 'venue', 'doi'])

In [8]:
import pandas as pd

def preprocess(data):
    df = pd.DataFrame(data)

    # Split citations into list
    df["Citations"] = df["Citations"].apply(lambda x: x.split(";") if isinstance(x, str) else [])

    df["title"] = df["title"].astype(str)
    df["abstract"] = df["abstract"].astype(str)
    df["keywords"] = df["keywords"].astype(str)
    # Extract doi for reporting
    df["doi"] = df["doi"].apply(lambda x: x.strip() if isinstance(x, str) else None)

    # Combine title + abstract + keywords as paper text
    df["text"] = df["title"] + ". " + df["abstract"] + ". " + df["keywords"] + '. doi:' + df['doi']

    # Extract authors as list of IDs
    df["author_ids"] = df["authors"].apply(lambda authors: [a.get('id', a.get('name')) for a in authors] if isinstance(authors, list) else [])


    # Extract keywords as lowercase list
    df["keyword_list"] = df["keywords"].apply(lambda x: [kw.strip().lower() for kw in x.split(';')] if isinstance(x, str) else [])

    # Extract venues as lowercase list (will be node)
    df["venue"] = df["venue"].apply(lambda x: x.split("name':")[-1].strip()[1:-2] if isinstance(x, str) else None)

    # Extract pub date for attributes
    df["pubDate"] = df["pubDate"].apply(lambda x: x.strip() if isinstance(x, str) else None)

    # Comment out the following if we want to consider the full date.
    # My reasoning for only using the year is that it compresses the graph down significantly if multiple papers share a year node. May cause more noise tho
    df["pubDate"] = df["pubDate"].apply(lambda date: int(date.split(' ')[0]) if date[0].isdigit() else None)

    return df[["publication_ID", "title", "text", "Citations", "author_ids", "keyword_list", "venue", "pubDate"]] #Added title for inference later

train_df = preprocess(train_data)
val_df = preprocess(val_data)
test_df = preprocess(test_data)

train_df.head(2)
val_df.head(5)

Unnamed: 0,publication_ID,title,text,Citations,author_ids,keyword_list,venue,pubDate
0,23641233,Local functional connectivity as a pre surgica...,Local functional connectivity as a pre surgica...,"[23847586, 24073391, 26204264]","[560cdf1545ce1e5960a19851, 560cdf1545ce1e5960a...",[0],"Frontiers in neurology', 'id': '5451a5cae0cf0b...",2013
1,17157189,Heat shock response and acute lung injury,Heat shock response and acute lung injury. All...,"[20465849, 23536968, 24524071, 22140545, 21543...","[53f4442cdabfaee43ec75166, 54096d37dabfae8faa6...","[animals, heat-shock proteins, genetics, metab...","Free Radical Biology and Medicine', 'id': '545...",2007
2,15872007,STa and cGMP stimulate CFTR translocation to t...,STa and cGMP stimulate CFTR translocation to t...,"[21347269, 22069681, 24275951, 21347269, 22069...","[53f467c8dabfaeecd6a126b7, 5608ce0745cedb3396d...","[animals, bacterial toxins, pharmacology, biot...",American journal of physiology. Cell physiolog...,2005
3,20360276,Stigma and depression treatment utilization am...,Stigma and depression treatment utilization am...,"[27473569, 26576680, 24938081, 28774339, 29536...","[53f42d05dabfaedf43511829, 53f4263edabfaeb2acf...","[adolescent, adult, aged, antidepressive agent...","Psychiatric services (Washington, D.C.)",2010
4,15963034,Increased incidence and severity of diabetic k...,Increased incidence and severity of diabetic k...,"[31086620, 24355514, 34188679, 35511179]","[53f43fb3dabfaee4dc7be511, 53f4d409dabfaeedd17...","[adolescent, child, child, preschool, colorado...","Pediatric Diabetes', 'id': '5451a5c4e0cf0b02b5...",2005


In [9]:
import itertools

all_df = pd.concat([train_df, val_df, test_df], ignore_index=True)

# All papers (in data or cited)
all_papers = set(all_df['publication_ID'].astype(str))
for cits in all_df['Citations']:
    all_papers.update(cits)
paper2idx = {pid: i for i, pid in enumerate(sorted(all_papers))}

# All authors
all_authors = set(itertools.chain.from_iterable(all_df['author_ids']))
# Filter out None values before sorting
all_authors = {aid for aid in all_authors if aid is not None}
author2idx = {aid: i for i, aid in enumerate(sorted(all_authors))}

# All keywords
all_keywords = set(itertools.chain.from_iterable(all_df['keyword_list']))
keyword2idx = {kw: i for i, kw in enumerate(sorted(all_keywords))}

# All venues
all_venues = set(all_df['venue'].dropna())
venue2idx = {v: i for i, v in enumerate(sorted(all_venues))}

# All pub dates
all_pubDates = set(all_df['pubDate'].dropna())
pubDate2idx = {pd: i for i, pd in enumerate(sorted(all_pubDates))}

print(f"#papers={len(paper2idx)} #authors={len(author2idx)} #keywords={len(keyword2idx)} #venues={len(venue2idx)} #pubDates={len(pubDate2idx)}")

#papers=424279 #authors=307266 #keywords=17207 #venues=5501 #pubDates=35


# Upgrading the graph where paper node features come from real text embeddings

In [10]:
# Install sentence-transformers if needed (uncomment if required)
# import subprocess; subprocess.run(['pip', 'install', '-q', 'sentence-transformers'], check=True)

from sentence_transformers import SentenceTransformer
import torch
import numpy as np
from tqdm import tqdm


In [11]:
model = SentenceTransformer("all-MiniLM-L6-v2")  # 384-dim embeddings

In [12]:
import math
import numpy as np

paper_texts = []
valid_pids = [pid for pid in paper2idx.keys() if str(pid).isdigit()]  # keep only numeric strings

for pid in sorted(valid_pids, key=lambda x: int(x)):
    pid_int = int(pid)
    text = all_df.loc[all_df["publication_ID"] == pid_int, "text"]

    if len(text) > 0 and isinstance(text.values[0], str):
        paper_texts.append(text.values[0])
    else:
        paper_texts.append(" ")  # placeholder if no text found


In [None]:
print(f"Total paper nodes: {len(paper2idx)}")
print(f"Total valid papers with text: {len(paper_texts)}")


Total paper nodes: 424279
Total valid papers with text: 424278


In [14]:
paper_embs = model.encode(
    paper_texts,
    batch_size=64,
    show_progress_bar=True,
    convert_to_numpy=True,
    normalize_embeddings=True
)


Batches:   0%|          | 0/6630 [00:00<?, ?it/s]

In [15]:
# Save paper embeddings to data directory
np.save(data_dir / "paper_embeddings.npy", paper_embs)
print(f"Paper embeddings saved to {data_dir / 'paper_embeddings.npy'}")

Paper embeddings saved to data/paper_embeddings.npy


In [None]:
# Load the embeddings
# import numpy as np
# paper_embs = np.load(data_dir / "paper_embeddings.npy")
# print("Loaded paper embeddings:", paper_embs.shape)


In [17]:
# Jacob Question: Why are only author and keyword learnable embeddings? what about paper?

from torch_geometric.data import HeteroData
import torch

data = HeteroData()

# Paper node features
embedding_dim = 384 #this is because the transformer has 384-dim embeddings

data['paper'].x = torch.tensor(paper_embs, dtype=torch.float) #the paper node is now embeddings from transformer
data['author'].x = torch.randn(len(author2idx), embedding_dim)
data['keyword'].x = torch.randn(len(keyword2idx), embedding_dim)
data['venue'].x = torch.randn(len(venue2idx), embedding_dim)
data['pubDate'].x = torch.randn(len(pubDate2idx), embedding_dim)

In [None]:
print(data['paper'].x.shape)
print(data['author'].x.shape)
print(data['keyword'].x.shape)


torch.Size([424278, 384])
torch.Size([307266, 384])
torch.Size([17207, 384])


In [19]:
# Saw time inefficiency in the above code, the below is a fix
#-----Shik Fixed here -------------
# Fixed edge overwriting (previously only kept last record)

edge_store = {
    ('paper', 'cites', 'paper'): [[], []],
    ('paper', 'written_by', 'author'): [[], []],
    ('author', 'authored', 'paper'): [[], []],
    ('paper', 'mentions', 'keyword'): [[], []],
    ('keyword', 'appears_in', 'paper'): [[], []],
    ('paper', 'published_in', 'venue'): [[], []],
    ('venue', 'published', 'paper'): [[], []],
    ('paper', 'publication_date', 'pubDate'): [[], []],
}

for _, row in tqdm(train_df.iterrows(), total=len(train_df), desc="Building edges"):
    pid = str(row['publication_ID'])
    if pid not in paper2idx:
        continue
    pidx = paper2idx[pid]

    # --- (a) Citations: paper → paper ---
    for cited in row['Citations']:
        cited = str(cited)
        if cited in paper2idx:
            edge_store[('paper', 'cites', 'paper')][0].append(pidx)
            edge_store[('paper', 'cites', 'paper')][1].append(paper2idx[cited])

    # --- (b) Authors: paper ↔ author ---
    for aid in row['author_ids']:
        if aid in author2idx:
            edge_store[('paper', 'written_by', 'author')][0].append(pidx)
            edge_store[('paper', 'written_by', 'author')][1].append(author2idx[aid])
            edge_store[('author', 'authored', 'paper')][0].append(author2idx[aid])
            edge_store[('author', 'authored', 'paper')][1].append(pidx)

    # --- (c) Keywords: paper ↔ keyword ---
    for kw in row['keyword_list']:
        if kw in keyword2idx:
            edge_store[('paper', 'mentions', 'keyword')][0].append(pidx)
            edge_store[('paper', 'mentions', 'keyword')][1].append(keyword2idx[kw])
            edge_store[('keyword', 'appears_in', 'paper')][0].append(keyword2idx[kw])
            edge_store[('keyword', 'appears_in', 'paper')][1].append(pidx)

    # --- (d) Venue: paper ↔ venue ---
    venue = row['venue']
    if venue in venue2idx:
        edge_store[('paper', 'published_in', 'venue')][0].append(pidx)
        edge_store[('paper', 'published_in', 'venue')][1].append(venue2idx[venue])
        edge_store[('venue', 'published', 'paper')][0].append(venue2idx[venue])
        edge_store[('venue', 'published', 'paper')][1].append(pidx)

    # --- (e) Publication Date: paper → pubDate ---
    pubDate = row['pubDate']
    if pubDate in pubDate2idx:
        edge_store[('paper', 'publication_date', 'pubDate')][0].append(pidx)
        edge_store[('paper', 'publication_date', 'pubDate')][1].append(pubDate2idx[pubDate])

for rel, (src, dst) in edge_store.items():
    if len(src) > 0:
        data[rel].edge_index = torch.tensor([src, dst], dtype=torch.long)
    else:
        print(f"No edges found for relation {rel}")


Building edges: 100%|██████████████████| 42000/42000 [00:03<00:00, 13453.79it/s]


In [None]:
def summarize_heterodata(data):
    print("=== Node Counts ===")
    for ntype in data.node_types:
        print(f"{ntype:<12} → {data[ntype].num_nodes:,} nodes | "
              f"Feature dim: {data[ntype].x.shape[1] if 'x' in data[ntype] else 'N/A'}")

    print("\n=== Edge Counts ===")
    for etype in data.edge_types:
        e = data[etype].edge_index
        print(f"{etype} → {e.shape[1]:,} edges (shape={tuple(e.shape)})")

summarize_heterodata(data)


=== Node Counts ===
paper        → 424,278 nodes | Feature dim: 384
author       → 307,266 nodes | Feature dim: 384
keyword      → 17,207 nodes | Feature dim: 384
venue        → 5,501 nodes | Feature dim: 384
pubDate      → 35 nodes | Feature dim: 384

=== Edge Counts ===
('paper', 'cites', 'paper') → 486,632 edges (shape=(2, 486632))
('paper', 'written_by', 'author') → 291,666 edges (shape=(2, 291666))
('author', 'authored', 'paper') → 291,666 edges (shape=(2, 291666))
('paper', 'mentions', 'keyword') → 948,766 edges (shape=(2, 948766))
('keyword', 'appears_in', 'paper') → 948,766 edges (shape=(2, 948766))
('paper', 'published_in', 'venue') → 42,000 edges (shape=(2, 42000))
('venue', 'published', 'paper') → 42,000 edges (shape=(2, 42000))
('paper', 'publication_date', 'pubDate') → 42,000 edges (shape=(2, 42000))


In [21]:
# Checking for self-loops
num_self_loops = (data['paper', 'cites', 'paper'].edge_index[0] ==
                  data['paper', 'cites', 'paper'].edge_index[1]).sum().item()
print(f"Self-loops in cites relation: {num_self_loops}")

# Checking for duplicate edges
ei = data['paper', 'cites', 'paper'].edge_index
num_unique = torch.unique(ei, dim=1).shape[1]
print(f"Duplicate edges: {ei.shape[1] - num_unique}")


Self-loops in cites relation: 0
Duplicate edges: 63994


In [22]:
import torch

# Remove duplicate edges from (paper, cites, paper)
ei = data['paper', 'cites', 'paper'].edge_index

# Sort columns so [src, dst] and [dst, src] duplicates align
ei_unique = torch.unique(ei, dim=1)
num_removed = ei.shape[1] - ei_unique.shape[1]

data['paper', 'cites', 'paper'].edge_index = ei_unique

print(f"Removed {num_removed} duplicate citation edges. New total: {ei_unique.shape[1]}")


Removed 63994 duplicate citation edges. New total: 422638


In [23]:
ei = data['paper', 'cites', 'paper'].edge_index
print(f"New edge count: {ei.shape[1]}")
num_unique = torch.unique(ei, dim=1).shape[1]
print(f"Remaining duplicates: {ei.shape[1] - num_unique}")


New edge count: 422638
Remaining duplicates: 0


In [None]:
from torch_geometric.nn import HeteroConv, GATConv, Linear
import torch.nn.functional as F
import torch.nn as nn

class HeteroGNN(nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()

        self.conv1 = HeteroConv({
            ('paper', 'cites', 'paper'): GATConv(-1, hidden_channels, residual=True, add_self_loops=True),
            ('paper', 'written_by', 'author'): GATConv(-1, hidden_channels, residual=True, add_self_loops=False),
            ('author', 'authored', 'paper'): GATConv(-1, hidden_channels, residual=True, add_self_loops=False),
            ('paper', 'mentions', 'keyword'): GATConv(-1, hidden_channels, residual=True, add_self_loops=False),
            ('keyword', 'appears_in', 'paper'): GATConv(-1, hidden_channels, residual=True, add_self_loops=False),
            ('paper', 'published_in', 'venue'): GATConv(-1, hidden_channels, residual=True, add_self_loops=False),
            ('venue', 'published', 'paper'): GATConv(-1, hidden_channels, residual=True, add_self_loops=False),
            ('paper','publication_date', 'pubDate'): GATConv(-1, hidden_channels, residual=True, add_self_loops=False)
        }, aggr='sum')

        self.lin = nn.Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}
        x_dict = {key: self.lin(x) for key, x in x_dict.items()}
        return x_dict

In [25]:
model = HeteroGNN(hidden_channels=64, out_channels=64).to(device)
data = data.to(device) #Needed for all values to be on same device
print(model)


HeteroGNN(
  (conv1): HeteroConv(num_relations=8)
  (lin): Linear(in_features=64, out_features=64, bias=True)
)


In [None]:
from torch_geometric.utils import negative_sampling

def loss_fn(emb_src, emb_dst, pos_edges, neg_edges):
    pos_score = (emb_src[pos_edges[0]] * emb_dst[pos_edges[1]]).sum(dim=1)
    neg_score = (emb_src[neg_edges[0]] * emb_dst[neg_edges[1]]).sum(dim=1)
    labels = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
    scores = torch.cat([pos_score, neg_score])
    return F.binary_cross_entropy_with_logits(scores, labels)


In [27]:
def clean_invalid_edges(data):
    for etype in data.edge_types:
        if 'edge_index' not in data[etype]:
            continue

        ei = data[etype].edge_index
        src_type, _, dst_type = etype
        src_nodes = data[src_type].num_nodes
        dst_nodes = data[dst_type].num_nodes

        mask = (ei[0] < src_nodes) & (ei[1] < dst_nodes)
        valid_count = mask.sum().item()

        if valid_count < ei.shape[1]:
            print(f"Removing {ei.shape[1] - valid_count} invalid edges from {etype}")
            data[etype].edge_index = ei[:, mask]

clean_invalid_edges(data)


Removing 48 invalid edges from ('paper', 'cites', 'paper')


In [28]:
def check_hetero_integrity(data):
    print("=== Sanity Check for HeteroData ===")
    for etype in data.edge_types:
        ei = data[etype].edge_index
        src_type, _, dst_type = etype

        if ei is None or ei.numel() == 0:
            print(f"{etype} has no edges")
            continue

        src_nodes, dst_nodes = data[src_type].num_nodes, data[dst_type].num_nodes
        max_src, max_dst = ei[0].max().item(), ei[1].max().item()
        min_src, min_dst = ei[0].min().item(), ei[1].min().item()

        if max_src >= src_nodes or max_dst >= dst_nodes:
            print(f"{etype} has invalid indices: "
                  f"src max {max_src}/{src_nodes}, dst max {max_dst}/{dst_nodes}")
        elif min_src < 0 or min_dst < 0:
            print(f"{etype} has negative indices!")
        else:
            print(f"{etype} OK ({ei.shape[1]} edges)")

check_hetero_integrity(data)


=== Sanity Check for HeteroData ===
('paper', 'cites', 'paper') OK (422590 edges)
('paper', 'written_by', 'author') OK (291666 edges)
('author', 'authored', 'paper') OK (291666 edges)
('paper', 'mentions', 'keyword') OK (948766 edges)
('keyword', 'appears_in', 'paper') OK (948766 edges)
('paper', 'published_in', 'venue') OK (42000 edges)
('venue', 'published', 'paper') OK (42000 edges)
('paper', 'publication_date', 'pubDate') OK (42000 edges)


In [29]:
# Split Citation Edges
from torch_geometric.transforms import RandomLinkSplit

# Randomly split only the citation edges for link prediction
transform = RandomLinkSplit(
    num_val=0.1,                   # 10% validation
    num_test=0.1,                  # 10% test
    is_undirected=False,           # citations are directional
    add_negative_train_samples=True,
    edge_types=[('paper', 'cites', 'paper')],  # focus only on citation edges
    rev_edge_types=[None]          # no reverse relation
)

train_data, val_data, test_data = transform(data)
train_data = train_data.to(device)
val_data = val_data.to(device)
test_data = test_data.to(device)



In [30]:
# Fixed to focus only on paper, cites, paper edges
optimizer = torch.optim.Adam(model.parameters(), lr=GNN_config["learning_rate"])

for epoch in range(GNN_config['epoch']):
    model.train()
    optimizer.zero_grad()

    out_dict = model(train_data.x_dict, train_data.edge_index_dict)

    # --- Focus only on (paper, cites, paper) edges ---
    etype = ('paper', 'cites', 'paper')
    pos_edges = train_data[etype].edge_label_index  # edges for training

    # Negative sampling for link prediction
    neg_edges = negative_sampling(
        edge_index=pos_edges,
        num_nodes=(train_data['paper'].num_nodes, train_data['paper'].num_nodes),
        num_neg_samples=pos_edges.size(1)
    )


    emb_src = out_dict['paper']
    emb_dst = out_dict['paper']

    total_loss = loss_fn(emb_src, emb_dst, pos_edges, neg_edges)

    total_loss.backward()
    optimizer.step()

    print(f"Epoch {epoch+1}/{GNN_config['epoch']} | Loss: {total_loss.item():.4f}")

Epoch 1/10 | Loss: 0.7750
Epoch 2/10 | Loss: 0.7271
Epoch 3/10 | Loss: 0.7143
Epoch 4/10 | Loss: 0.7051
Epoch 5/10 | Loss: 0.6976
Epoch 6/10 | Loss: 0.6919
Epoch 7/10 | Loss: 0.6874
Epoch 8/10 | Loss: 0.6841
Epoch 9/10 | Loss: 0.6811
Epoch 10/10 | Loss: 0.6776


In [31]:
torch.save(model.state_dict(), "hetero_gnn_checkpoint.pt")

In [32]:
import torch
import os
from pathlib import Path

# Use local checkpoints directory (defined in cell 2)
# checkpoint_dir is already defined as checkpoints_dir from earlier setup
checkpoint_path = checkpoints_dir / 'hetero_gnn_checkpoint.pt'

# Save model
torch.save(model.state_dict(), checkpoint_path)
print(f"Model saved at {checkpoint_path.absolute()}")

# Load:
# model.load_state_dict(torch.load(checkpoint_path))
# model.to(device)


Model saved at /home/upandit/mag_citation_recommender/checkpoints/hetero_gnn_checkpoint.pt


# Inference

In [33]:
# Set model to evaluation mode
import torch
from pathlib import Path

# Load model from local checkpoints directory
checkpoint_path = checkpoints_dir / 'hetero_gnn_checkpoint.pt'

model = HeteroGNN(hidden_channels=64, out_channels=64)
model.load_state_dict(torch.load(checkpoint_path))
model = model.to(device)
model.eval()



HeteroGNN(
  (conv1): HeteroConv(num_relations=8)
  (lin): Linear(in_features=64, out_features=64, bias=True)
)

In [34]:
model.eval()
with torch.no_grad():
    out_dict = model(data.x_dict, data.edge_index_dict)
    paper_emb = out_dict['paper'].cpu()

In [35]:
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

query_pid = "17396995"  # Example paper ID
query_idx = paper2idx[query_pid]

if isinstance(paper_embs, torch.Tensor):
    paper_embs = paper_embs.cpu().numpy()

# Extract query vector
query_vec = paper_embs[query_idx].reshape(1, -1)

# Compute cosine similarity across all papers
sims = cosine_similarity(query_vec, paper_embs)[0]
topk = np.argsort(sims)[::-1][1:6]  # Skip itself, take top 5

# Show info about the query paper itself
query_row = all_df.loc[all_df["publication_ID"] == int(query_pid)]
if len(query_row) > 0:
    print("Query Paper:")
    print(f"ID: {query_pid}")
    print(f"Title: {query_row['title'].values[0] if 'title' in query_row.columns else 'Unknown'}")
    print(f"Year: {query_row['pubDate'].values[0]}")
    print(f"Venue: {query_row['venue'].values[0]}")
    print(f"Abstract: {query_row['text'].values[0][:400]}...")
else:
    print("Query paper not found in all_df.")

print("\nTop 5 similar papers:")
print("-" * 50)
for i in topk:
    pid = list(paper2idx.keys())[i]
    row = all_df.loc[all_df["publication_ID"] == int(pid)]
    if len(row) > 0:
        title = row["title"].values[0] if "title" in row.columns else "Unknown"
        print(f"{pid} — {title}")
    else:
        print(f"{pid} — Not found")


Query Paper:
ID: 17396995
Title: Herpes simplex virus type 2 infection does not influence viral dynamics during early HIV 1 infection
Year: 2007
Venue: The Journal of infectious diseases
Abstract: Herpes simplex virus type 2 infection does not influence viral dynamics during early HIV 1 infection. We sought to compare baseline and longitudinal plasma HIV-1 loads between herpes simplex virus type 2 (HSV-2)-seropositive and -seronegative individuals who are enrolled in a primary HIV-1 infection cohort in San Diego, California.. Adult;California;epidemiology;Cohort Studies;HIV Infections;blood...

Top 5 similar papers:
--------------------------------------------------
17264332 — Clinicopathologic features of osteosarcoma in patients with Rothmund Thomson syndrome
15372107 — PDX 1 haploinsufficiency limits the compensatory islet hyperplasia that occurs in response to insulin resistance
19759291 — Medial prefrontal cortex secondary hyperalgesia and the default mode network
15983384 — High 

In [None]:
from sklearn.metrics import roc_auc_score, average_precision_score

@torch.no_grad()
def evaluate(model, data, device):
    model.eval()
    out = model(data.x_dict, data.edge_index_dict)
    src_emb = out['paper']
    dst_emb = out['paper']

    # Positive and negative edges from test split
    pos_edges = data['paper', 'cites', 'paper'].edge_label_index[:, data['paper', 'cites', 'paper'].edge_label == 1]
    neg_edges = data['paper', 'cites', 'paper'].edge_label_index[:, data['paper', 'cites', 'paper'].edge_label == 0]

    # Compute scores (dot product similarity)
    pos_scores = (src_emb[pos_edges[0]] * dst_emb[pos_edges[1]]).sum(dim=1).cpu().numpy()
    neg_scores = (src_emb[neg_edges[0]] * dst_emb[neg_edges[1]]).sum(dim=1).cpu().numpy()

    y_true = np.concatenate([np.ones_like(pos_scores), np.zeros_like(neg_scores)])
    y_scores = np.concatenate([pos_scores, neg_scores])

    auc = roc_auc_score(y_true, y_scores)
    ap = average_precision_score(y_true, y_scores)

    return auc, ap


In [37]:
val_auc, val_ap = evaluate(model, val_data, device)
print(f"Validation AUC: {val_auc:.4f}, AP: {val_ap:.4f}")

test_auc, test_ap = evaluate(model, test_data, device)
print(f"Test AUC: {test_auc:.4f}, AP: {test_ap:.4f}")


Validation AUC: 0.8987, AP: 0.8412
Test AUC: 0.8974, AP: 0.8410


# =========================================
# PART 2: GRIL ALGORITHMS IMPLEMENTATION
# =========================================

This section implements the GRIL algorithms for citation recommendation:
- Algorithm 1: Attention-based Graph Retriever
- Complexity Assessment Module (CAM)
- Joint Training Framework
- Placeholders for SAG Pooling (Sam's task) and LLM Integration


In [49]:
# Verify Part 1 outputs are available for Part 2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData

# Check required variables from Part 1
required_vars = {
    'model': 'HeteroGNN model (loaded from checkpoint)',
    'data': 'HeteroData graph',
    'paper2idx': 'Paper ID to index mapping',
    'all_df': 'DataFrame with paper metadata',
    'device': 'torch.device (cuda/cpu)',
    'checkpoints_dir': 'Path to checkpoints directory'
}

missing = []
for var_name, description in required_vars.items():
    if var_name not in globals():
        missing.append(f"{var_name} ({description})")
    else:
        print(f"✓ {var_name}: {type(globals()[var_name]).__name__}")

if missing:
    print(f"\n⚠️  Missing variables: {', '.join(missing)}")
    print("Please run Part 1 cells first (especially model loading cell).")
else:
    print("\n✅ All required variables from Part 1 are available!")
    print(f"   Model: {type(model).__name__}")
    print(f"   Data: {type(data).__name__} with {len(data.node_types)} node types")
    print(f"   Device: {device}")


✓ model: HeteroGNN
✓ data: HeteroData
✓ paper2idx: dict
✓ all_df: DataFrame
✓ device: device
✓ checkpoints_dir: PosixPath

✅ All required variables from Part 1 are available!
   Model: HeteroGNN
   Data: HeteroData with 5 node types
   Device: cuda


## Query Encoder (SentenceTransformer)


In [50]:
# =========================================
# QUERY ENCODER
# =========================================
from sentence_transformers import SentenceTransformer
from typing import List
import torch

class QueryEncoder:
    """
    Encodes natural language queries into dense vectors using SentenceTransformer.
    """
    def __init__(self, model_name: str = 'all-MiniLM-L6-v2', device=None):
        """
        Args:
            model_name: SentenceTransformer model name (default: 'all-MiniLM-L6-v2' - 384 dim)
            device: Device to run the model on
        """
        self.device = device if device is not None else torch.device('cpu')
        self.model = SentenceTransformer(model_name, device=self.device)
        self.embedding_dim = self.model.get_sentence_embedding_dimension()
        print(f"QueryEncoder initialized with {model_name}, embedding_dim={self.embedding_dim}")
    
    def encode(self, query_text: str, convert_to_tensor: bool = True) -> torch.Tensor:
        """Encode a single query text into a dense vector."""
        embedding = self.model.encode(query_text, convert_to_tensor=convert_to_tensor, device=self.device)
        return embedding
    
    def encode_batch(self, query_texts: List[str], batch_size: int = 32) -> torch.Tensor:
        """Encode a batch of query texts."""
        embeddings = self.model.encode(query_texts, batch_size=batch_size, 
                                      convert_to_tensor=True, device=self.device)
        return embeddings

# Initialize query encoder
query_encoder = QueryEncoder(device=device)


QueryEncoder initialized with all-MiniLM-L6-v2, embedding_dim=384


## Attention-based Relevance Scorer


In [51]:
# =========================================
# ATTENTION-BASED RELEVANCE SCORER
# =========================================
import torch.nn as nn
import torch.nn.functional as F

class RelevanceScorer(nn.Module):
    """
    Computes relevance scores between query, source node, and destination node embeddings.
    This is the core attention mechanism for Algorithm 1.
    """
    def __init__(self, query_dim: int, node_dim: int, hidden_dim: int = 128):
        super().__init__()
        
        # Project query to node dimension if needed
        if query_dim != node_dim:
            self.query_proj = nn.Linear(query_dim, node_dim)
        else:
            self.query_proj = nn.Identity()
        
        # MLP for computing relevance: f(query, src, dst) -> score
        self.mlp = nn.Sequential(
            nn.Linear(node_dim * 3, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )
    
    def forward(self, query_emb: torch.Tensor, src_embs: torch.Tensor, 
                dst_embs: torch.Tensor) -> torch.Tensor:
        """Compute relevance scores for edges."""
        query_proj = self.query_proj(query_emb)
        if query_proj.dim() == 1:
            query_proj = query_proj.unsqueeze(0)
        
        num_edges = src_embs.size(0)
        query_expanded = query_proj.expand(num_edges, -1)
        combined = torch.cat([query_expanded, src_embs, dst_embs], dim=1)
        scores = self.mlp(combined).squeeze(-1)
        scores = torch.sigmoid(scores)
        return scores

# Initialize relevance scorer
query_dim = query_encoder.embedding_dim  # 384
node_dim = 64  # From HeteroGNN output
relevance_scorer = RelevanceScorer(query_dim=query_dim, node_dim=node_dim).to(device)
print(f"✅ RelevanceScorer initialized (query_dim={query_dim}, node_dim={node_dim})")


✅ RelevanceScorer initialized (query_dim=384, node_dim=64)


## Algorithm 1: Attention-based Growing and Pruning


In [52]:
# =========================================
# ALGORITHM 1: ATTENTION-BASED GROWING AND PRUNING
# =========================================
from typing import Dict, List, Tuple, Set, Optional
import torch
import numpy as np
from torch_geometric.data import HeteroData

def attention_based_graph_retriever(
    query_text: str,
    query_seed_entities: Dict[str, List[int]],
    gnn_model: nn.Module,
    full_data: HeteroData,
    query_encoder: QueryEncoder,
    relevance_scorer: RelevanceScorer,
    max_hops: int = 2,
    relevance_threshold: float = 0.1,
    max_nodes_per_hop: Optional[int] = None,
    device: torch.device = None
) -> HeteroData:
    """
    Implements Algorithm 1: Attention-based Growing and Pruning for dynamic subgraph retrieval.
    
    Args:
        query_text: Natural language query (e.g., paper title + abstract)
        query_seed_entities: Dict mapping node types to lists of node indices to start from
        gnn_model: Trained HeteroGNN model
        full_data: Full HeteroData graph
        query_encoder: QueryEncoder instance
        relevance_scorer: RelevanceScorer instance
        max_hops: Maximum number of hops to expand (default: 2)
        relevance_threshold: Minimum relevance score to retain an edge (default: 0.1)
        max_nodes_per_hop: Maximum number of nodes to retain per hop (None = no limit)
        device: Device to run computation on
    
    Returns:
        HeteroData: Retrieved subgraph containing only relevant nodes and edges
    """
    if device is None:
        device = next(gnn_model.parameters()).device
    
    # 1. Encode query
    query_emb = query_encoder.encode(query_text).to(device)
    
    # 2. Get GNN embeddings for all nodes
    gnn_model.eval()
    with torch.no_grad():
        out_dict = gnn_model(full_data.x_dict, full_data.edge_index_dict)
    
    # 3. Initialize current nodes (seed entities)
    current_nodes: Dict[str, Set[int]] = {k: set(v) for k, v in query_seed_entities.items()}
    
    # 4. Track all retained edges and nodes across hops
    retained_edges: Dict[Tuple[str, str, str], List[Tuple[int, int, float]]] = {}
    all_retained_nodes: Dict[str, Set[int]] = {k: set(v) for k, v in current_nodes.items()}
    
    # 5. Multi-hop expansion
    for hop in range(max_hops):
        newly_retained_nodes: Dict[str, Set[int]] = {}
        
        # Process each edge type
        for etype in full_data.edge_types:
            src_type, rel_type, dst_type = etype
            edge_index = full_data[etype].edge_index
            
            # 5.1. Grow: Filter edges starting from current nodes
            if src_type not in current_nodes or len(current_nodes[src_type]) == 0:
                continue
            
            src_nodes_tensor = torch.tensor(list(current_nodes[src_type]), device=device)
            src_mask = torch.isin(edge_index[0], src_nodes_tensor)
            potential_edge_indices = torch.where(src_mask)[0]
            
            if potential_edge_indices.numel() == 0:
                continue
            
            # Get source and destination indices for potential edges
            potential_src = edge_index[0, potential_edge_indices]
            potential_dst = edge_index[1, potential_edge_indices]
            
            # 5.2. Score: Calculate relevance for each potential edge
            # Clone tensors to detach from inference mode (allows use in autograd if needed)
            src_node_embs = out_dict[src_type][potential_src].clone()
            dst_node_embs = out_dict[dst_type][potential_dst].clone()
            # For inference, wrap in no_grad to avoid gradient computation
            with torch.no_grad():
                relevance_scores = relevance_scorer(query_emb, src_node_embs, dst_node_embs)
            
            # 5.3. Prune: Filter edges below threshold
            relevant_mask = relevance_scores > relevance_threshold
            
            if relevant_mask.sum() == 0:
                continue
            
            # Get retained edges
            retained_src = potential_src[relevant_mask].cpu().numpy()
            retained_dst = potential_dst[relevant_mask].cpu().numpy()
            retained_scores = relevance_scores[relevant_mask].cpu().numpy()
            
            # Optionally limit nodes per hop
            if max_nodes_per_hop is not None:
                top_k = min(max_nodes_per_hop, len(retained_scores))
                top_indices = np.argsort(retained_scores)[::-1][:top_k]
                retained_src = retained_src[top_indices]
                retained_dst = retained_dst[top_indices]
                retained_scores = retained_scores[top_indices]
            
            # Store retained edges
            edge_triples = list(zip(retained_src.tolist(), retained_dst.tolist(), retained_scores.tolist()))
            if etype not in retained_edges:
                retained_edges[etype] = []
            retained_edges[etype].extend(edge_triples)
            
            # 5.4. Update: Add destination nodes for next hop
            newly_retained_nodes.setdefault(dst_type, set()).update(retained_dst.tolist())
            all_retained_nodes.setdefault(dst_type, set()).update(retained_dst.tolist())
        
        # Merge newly retained nodes for next iteration
        for node_type, new_nodes in newly_retained_nodes.items():
            current_nodes.setdefault(node_type, set()).update(new_nodes)
    
    # 6. Construct subgraph from retained nodes and edges
    subgraph_data = _construct_subgraph(full_data, all_retained_nodes, retained_edges, device)
    return subgraph_data


def _construct_subgraph(
    full_data: HeteroData,
    retained_nodes: Dict[str, Set[int]],
    retained_edges: Dict[Tuple[str, str, str], List[Tuple[int, int, float]]],
    device: torch.device
) -> HeteroData:
    """Construct a HeteroData subgraph from retained nodes and edges."""
    subgraph_data = HeteroData().to(device)
    
    # Add node features for retained nodes
    for node_type in retained_nodes:
        if node_type not in full_data.node_types:
            continue
        
        node_indices = sorted(list(retained_nodes[node_type]))
        if len(node_indices) == 0:
            continue
        
        node_mapping = {orig_idx: new_idx for new_idx, orig_idx in enumerate(node_indices)}
        node_tensor = torch.tensor(node_indices, device=device)
        subgraph_data[node_type].x = full_data[node_type].x[node_tensor]
        subgraph_data[node_type]._node_mapping = node_mapping
    
    # Add edges
    for etype, edge_list in retained_edges.items():
        if len(edge_list) == 0:
            continue
        
        src_type, rel_type, dst_type = etype
        src_mapping = subgraph_data[src_type]._node_mapping
        dst_mapping = subgraph_data[dst_type]._node_mapping
        
        edge_src = []
        edge_dst = []
        
        for src, dst, score in edge_list:
            if src in src_mapping and dst in dst_mapping:
                edge_src.append(src_mapping[src])
                edge_dst.append(dst_mapping[dst])
        
        if len(edge_src) > 0:
            edge_index = torch.tensor([edge_src, edge_dst], dtype=torch.long, device=device)
            subgraph_data[etype].edge_index = edge_index
    
    return subgraph_data

print("✅ Algorithm 1 functions defined")


✅ Algorithm 1 functions defined


In [53]:
# =========================================
# COMPLEXITY ASSESSMENT MODULE (CAM)
# =========================================
class ComplexityAssessmentModule(nn.Module):
    """
    MLP classifier that predicts question complexity (number of hops).
    Output determines retrieval budget: number of triplets = 5 × predicted_hops
    """
    def __init__(self, query_dim: int, hidden_dim: int = 256, max_hops: int = 4):
        super().__init__()
        self.query_dim = query_dim
        self.hidden_dim = hidden_dim
        self.max_hops = max_hops
        
        self.mlp = nn.Sequential(
            nn.Linear(query_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, max_hops + 1)
        )
    
    def forward(self, query_emb: torch.Tensor) -> torch.Tensor:
        """Predict complexity (number of hops)."""
        return self.mlp(query_emb)
    
    def predict_hops(self, query_emb: torch.Tensor) -> int:
        """Predict number of hops for a query."""
        with torch.no_grad():
            logits = self.forward(query_emb)
            if logits.dim() == 1:
                predicted = torch.argmax(logits, dim=0).item()
            else:
                predicted = torch.argmax(logits, dim=1).item()
        return predicted
    
    def get_retrieval_budget(self, query_emb: torch.Tensor) -> int:
        """Get number of triplets to retrieve based on predicted complexity."""
        hops = self.predict_hops(query_emb)
        return 5 * hops

# Initialize CAM (will be trained separately or jointly)
# cam = ComplexityAssessmentModule(query_dim=query_encoder.embedding_dim).to(device)
print("✅ ComplexityAssessmentModule class defined")


✅ ComplexityAssessmentModule class defined


## Joint Training Framework


In [54]:
# =========================================
# JOINT TRAINING ALGORITHM
# =========================================
class JointTrainingLoss(nn.Module):
    """
    Joint loss for LLM and retriever optimization.
    Implements Algorithm 4 from GRIL paper, Section 4.3.
    """
    def __init__(self, alpha: float = 1.0, beta: float = 1.0, gamma: float = 0.1, 
                 use_graph_supervision: bool = False):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.use_graph_supervision = use_graph_supervision
    
    def forward(self, llm_logits: torch.Tensor, ground_truth: torch.Tensor,
                triplet_probabilities: Optional[torch.Tensor] = None,
                shortest_path_entities: Optional[torch.Tensor] = None,
                retrieved_entities: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Compute joint training loss."""
        # 1. LLM Accuracy Loss
        if llm_logits.dim() == 2 and llm_logits.size(1) > 1:
            if ground_truth.dtype == torch.long:
                accuracy_loss = F.cross_entropy(llm_logits, ground_truth)
            else:
                accuracy_loss = F.binary_cross_entropy_with_logits(llm_logits, ground_truth)
        else:
            accuracy_loss = F.mse_loss(llm_logits, ground_truth.float())
        
        # 2. Retriever Feedback Loss
        retriever_feedback_loss = torch.tensor(0.0, device=llm_logits.device)
        if triplet_probabilities is not None:
            probs = F.softmax(triplet_probabilities, dim=-1)
            entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=-1).mean()
            retriever_feedback_loss = -entropy
        
        # 3. Graph Supervision Loss
        graph_supervision_loss = torch.tensor(0.0, device=llm_logits.device)
        if self.use_graph_supervision and shortest_path_entities is not None:
            if retrieved_entities is not None:
                graph_supervision_loss = F.binary_cross_entropy_with_logits(
                    retrieved_entities.float(), shortest_path_entities.float())
        
        total_loss = (self.alpha * accuracy_loss + self.beta * retriever_feedback_loss + 
                     self.gamma * graph_supervision_loss)
        
        loss_dict = {
            'total_loss': total_loss.item(),
            'accuracy_loss': accuracy_loss.item(),
            'retriever_feedback_loss': retriever_feedback_loss.item(),
            'graph_supervision_loss': graph_supervision_loss.item()
        }
        return total_loss, loss_dict

print("✅ JointTrainingLoss class defined")


✅ JointTrainingLoss class defined


## Placeholder: SAG (Self-Attention Graph) Pooling Layer
**TODO: Sam's Implementation**


In [55]:
print("SAG Pooling layer - TODO: Sam's implementation")


SAG Pooling layer - TODO: Sam's implementation


## Placeholder: LLM Integration
**TODO: Sam & Ujjwal's Implementation**


In [56]:
print("LLM Integration - TODO: Sam & Ujjwal's implementation")


LLM Integration - TODO: Sam & Ujjwal's implementation


## Test: Algorithm 1 Verification

**What this test does:**
- Verifies Algorithm 1 can retrieve a relevant subgraph from the full graph
- Tests with a real query paper (ID: 17396995)
- Shows subgraph statistics (nodes, edges, reduction from full graph)

**Dependencies needed (all available after Part 1 & Part 2 initialization):**
- ✅ `model`: Trained HeteroGNN (loaded from checkpoint)
- ✅ `data`: Full HeteroData graph
- ✅ `paper2idx`: Paper ID to index mapping
- ✅ `all_df`: DataFrame with paper metadata
- ✅ `query_encoder`: QueryEncoder instance
- ✅ `relevance_scorer`: RelevanceScorer instance
- ✅ `device`: torch.device

**What we're NOT waiting for:**
- Algorithm 1 can be tested immediately after Part 1 & Part 2 setup
- No need to wait for SAG Pooling or LLM integration
- This test verifies the core retrieval mechanism works


In [57]:
# =========================================
# TEST: Algorithm 1 - retrieve a subgraph for a query paper

print("="*80)
print("TESTING ALGORITHM 1: Attention-based Graph Retriever")
print("="*80)

# Check dependencies
required = {
    'model': model,
    'data': data,
    'paper2idx': paper2idx,
    'all_df': all_df,
    'query_encoder': query_encoder,
    'relevance_scorer': relevance_scorer,
    'device': device
}

missing = [k for k, v in required.items() if v is None or k not in globals()]
if missing:
    print(f"❌ Missing dependencies: {missing}")
    print("Please run Part 1 and Part 2 initialization cells first.")
else:
    print("✅ All dependencies available")
    
    # Test with a known paper ID
    query_paper_id = "17396995"  # Example paper from our dataset
    
    try:
        # Check if paper exists
        if query_paper_id not in paper2idx:
            print(f"❌ Paper {query_paper_id} not found in paper2idx")
        else:
            query_paper_idx = paper2idx[query_paper_id]
            print(f"✓ Found query paper: {query_paper_id} (index: {query_paper_idx})")
            
            # Get query text
            query_row = all_df[all_df['publication_ID'] == int(query_paper_id)]
            if len(query_row) == 0:
                print(f"❌ Paper {query_paper_id} not found in all_df")
            else:
                query_text = query_row['text'].values[0]
                print(f"✓ Query text length: {len(query_text)} characters")
                
                # Define seed entities
                query_seed_entities = {'paper': [query_paper_idx]}
                print(f"✓ Seed entities: {query_seed_entities}")
                
                # Run Algorithm 1
                print("\n" + "-"*80)
                print("Running Algorithm 1...")
                print("-"*80)
                
                retrieved_subgraph = attention_based_graph_retriever(
                    query_text=query_text,
                    query_seed_entities=query_seed_entities,
                    gnn_model=model,
                    full_data=data,
                    query_encoder=query_encoder,
                    relevance_scorer=relevance_scorer,
                    max_hops=2,
                    relevance_threshold=0.1,
                    max_nodes_per_hop=100,
                    device=device
                )
                
                # Verify results
                print("\n" + "="*80)
                print("✅ ALGORITHM 1 TEST PASSED!")
                print("="*80)
                print(f"Retrieved subgraph statistics:")
                print(f"  Node types: {len(retrieved_subgraph.node_types)}")
                for node_type in retrieved_subgraph.node_types:
                    num_nodes = retrieved_subgraph[node_type].num_nodes
                    print(f"    {node_type}: {num_nodes} nodes")
                
                print(f"\n  Edge types: {len(retrieved_subgraph.edge_types)}")
                for etype in retrieved_subgraph.edge_types:
                    num_edges = retrieved_subgraph[etype].edge_index.size(1) if hasattr(retrieved_subgraph[etype], 'edge_index') else 0
                    print(f"    {etype}: {num_edges} edges")
                
                # Compare with full graph
                print(f"\n  Comparison with full graph:")
                print(f"    Full graph papers: {data['paper'].num_nodes:,}")
                print(f"    Retrieved papers: {retrieved_subgraph['paper'].num_nodes:,}")
                print(f"    Reduction: {(1 - retrieved_subgraph['paper'].num_nodes / data['paper'].num_nodes) * 100:.2f}%")
                
                print("\n✅ Algorithm 1 is working correctly!")
                
    except Exception as e:
        print(f"\n❌ ERROR during Algorithm 1 test:")
        print(f"   {type(e).__name__}: {str(e)}")
        import traceback
        print("\nFull traceback:")
        traceback.print_exc()


TESTING ALGORITHM 1: Attention-based Graph Retriever
✅ All dependencies available
✓ Found query paper: 17396995 (index: 18104)
✓ Query text length: 645 characters
✓ Seed entities: {'paper': [18104]}

--------------------------------------------------------------------------------
Running Algorithm 1...
--------------------------------------------------------------------------------

✅ ALGORITHM 1 TEST PASSED!
Retrieved subgraph statistics:
  Node types: 5
    paper: 185 nodes
    author: 16 nodes
    keyword: 28 nodes
    venue: 2 nodes
    pubDate: 1 nodes

  Edge types: 8
    ('paper', 'cites', 'paper'): 15 edges
    ('paper', 'written_by', 'author'): 21 edges
    ('paper', 'mentions', 'keyword'): 73 edges
    ('paper', 'published_in', 'venue'): 3 edges
    ('paper', 'publication_date', 'pubDate'): 3 edges
    ('author', 'authored', 'paper'): 6 edges
    ('keyword', 'appears_in', 'paper'): 100 edges
    ('venue', 'published', 'paper'): 100 edges

  Comparison with full graph:
    Ful