In [1]:
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel
from tqdm.auto import tqdm
import warnings
import json
import csv
warnings.filterwarnings('ignore')

# Paths
BASE_DIR = Path.cwd().parent.parent.parent.parent
DATA_DIR = BASE_DIR / "oc_mini"

# Add gnn package to path (parent directory)
sys.path.insert(0, str(Path.cwd().parent))

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


## Configuration

You can easily swap the transformer model and data source by modifying these variables below.

# GNN Baseline 1: Induced Subgraph with Frozen Transformer

**Experiment Goal**: Demonstrate that when a GNN is trained on an induced subgraph of training nodes (with frozen transformer), test nodes with no edges degrade to transformer-only performance.

**Setup**:
- 90% train nodes, 10% test nodes
- Graph contains ONLY edges between training nodes
- Transformer (SciBERT) is frozen
- Only GNN layers are trainable
- Input: Title + Abstract concatenated

**Expected Result**: Test nodes will have identical embeddings from transformer-only vs GNN, since they have no graph connectivity.

---

In [3]:
# Import GNN modules
from model import TransformerGNN
from graph_utils import (
    create_induced_subgraph,
    analyze_graph_statistics,
    get_node_texts
)
from split_utils import create_node_based_split

# Data paths
edgelist_path = DATA_DIR / "network" / "oc_mini_edgelist.csv"
metadata_path = DATA_DIR / "metadata" / "oc_mini_node_metadata.csv"

In [4]:
# Load metadata
metadata_df = pd.read_csv(metadata_path)
print(f"Metadata loaded: {len(metadata_df)} entries")

# Create train/test split (90/10 split)
all_node_ids = [str(node_id) for node_id in metadata_df['id'].values]
train_nodes, test_nodes = create_node_based_split(all_node_ids, test_ratio=0.1, seed=42)

print(f"\nTrain nodes: {len(train_nodes)} ({len(train_nodes)/len(all_node_ids)*100:.1f}%)")
print(f"Test nodes: {len(test_nodes)} ({len(test_nodes)/len(all_node_ids)*100:.1f}%)")

metadata_df.head()

Metadata loaded: 14442 entries

Train nodes: 12998 (90.0%)
Test nodes: 1444 (10.0%)


Unnamed: 0,id,doi,title,abstract
0,128,10.1101/2021.05.10.443415,Improved protein contact prediction using dime...,AbstractDeep residual learning has shown great...
1,163,10.1101/2021.05.07.443114,Following the Trail of One Million Genomes: Fo...,AbstractSevere acute respiratory syndrome coro...
2,200,10.1101/2021.05.11.443555,Mechanism of molnupiravir-induced SARS-CoV-2 m...,Molnupiravir is an orally available antiviral ...
3,941,10.3390/ijms20020449,Bactericidal and Cytotoxic Properties of Silve...,Silver nanoparticles (AgNPs) can be synthesize...
4,1141,10.3390/ijms20040865,Silver Nanoparticles: Synthesis and Applicatio...,"Over the past few decades, metal nanoparticles..."


In [5]:
# Create induced subgraph - KEY STEP FOR BASELINE
# This graph will ONLY contain edges between training nodes
# Test nodes will have NO edges (degree = 0)

edge_index, node_to_idx, idx_to_node = create_induced_subgraph(
    edgelist_path,
    train_nodes,
    metadata_df
)

print(f"\nInduced subgraph created:")
print(f"  Nodes in mapping: {len(node_to_idx)}")
print(f"  Edges: {edge_index.shape[1]}")

# Analyze statistics
analyze_graph_statistics(edge_index, train_nodes, node_to_idx, metadata_df)

Loading edgelist from /home/vikramr2/oc_mini/network/oc_mini_edgelist.csv...
  Full graph: 111873 edges

Filtering to induced subgraph of 12998 training nodes...
  Induced subgraph: 91586 edges
  Removed 20287 edges involving test nodes

Node mapping:
  Total nodes: 14442
  Train nodes: 12998
  Test nodes: 1444

Final edge_index shape: torch.Size([2, 183172])
  Directed edges: 183172

✓ Verification: 0 edges involve test nodes (should be 0)

Induced subgraph created:
  Nodes in mapping: 14442
  Edges: 183172

GRAPH STATISTICS

Training nodes (12998 nodes):
  Mean degree: 14.09
  Median degree: 10
  Max degree: 1413
  Isolated nodes: 79

Test nodes (1444 nodes):
  Mean degree: 0.00
  Median degree: 0
  Max degree: 0
  Isolated nodes: 1444 (should be ALL)



In [6]:
# Initialize TransformerGNN model with SciBERT
# Key: Transformer is FROZEN, only GNN layers are trainable

model_name = 'allenai/scibert_scivocab_uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)

model = TransformerGNN(
    model_name=model_name,
    gnn_type='gcn',           # Use GCN layers
    hidden_dim=768,           # Match SciBERT output
    num_gnn_layers=2,         # 2 GNN layers
    dropout=0.1,
    pooling='cls',
    freeze_transformer=True   # IMPORTANT: Freeze transformer!
).to(device)

print("\nModel Summary:")
print(f"  Transformer: {model_name}")
print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"  Trainable (GNN only): {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
print(f"  Frozen (Transformer): {sum(p.numel() for p in model.parameters() if not p.requires_grad):,}")

✓ Transformer weights frozen
✓ Created 2-layer GCN model
  Total GNN parameters: 1,181,184

Model Summary:
  Transformer: allenai/scibert_scivocab_uncased
  Total parameters: 111,102,720
  Trainable (GNN only): 1,184,256
  Frozen (Transformer): 109,918,464


In [None]:
# BASELINE DEMONSTRATION
# Compare transformer-only vs GNN embeddings for test nodes
# NOTE: get_node_texts() concatenates title + abstract for each node

print("="*70)
print("BASELINE DEMONSTRATION: GNN Degradation for Test Nodes")
print("="*70)

# Sample a few test nodes
import random
sample_test_nodes = random.sample(test_nodes, min(5, len(test_nodes)))

model.eval()
with torch.no_grad():
    for node_id in sample_test_nodes:
        # 1. Get transformer-only embedding
        # get_node_texts concatenates title + abstract
        text = get_node_texts([node_id], metadata_df)[0]
        inputs = tokenizer(
            [text],
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        ).to(device)
        
        transformer_emb = model.encode_text(inputs['input_ids'], inputs['attention_mask'])
        transformer_emb = transformer_emb.cpu().numpy()[0]
        
        # 2. Get GNN embedding (full pipeline)
        # First get all transformer embeddings (title + abstract for each)
        all_node_ids = sorted(node_to_idx.keys(), key=lambda x: node_to_idx[x])
        all_texts = get_node_texts(all_node_ids, metadata_df)
        
        # Encode all texts in batches
        all_embs = []
        batch_size = 32
        for i in range(0, len(all_texts), batch_size):
            batch_texts = all_texts[i:i+batch_size]
            batch_inputs = tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt"
            ).to(device)
            batch_embs = model.encode_text(batch_inputs['input_ids'], batch_inputs['attention_mask'])
            all_embs.append(batch_embs)
        
        x = torch.cat(all_embs, dim=0)
        
        # Apply GNN
        gnn_embs = model(x, edge_index.to(device))
        
        # Get embedding for this test node
        node_idx = node_to_idx[node_id]
        gnn_emb = gnn_embs[node_idx].cpu().numpy()
        
        # 3. Compare
        cosine_sim = (transformer_emb * gnn_emb).sum()
        l2_diff = np.sqrt(((transformer_emb - gnn_emb) ** 2).sum())
        
        print(f"\nTest Node {node_id}:")
        print(f"  Cosine similarity: {cosine_sim:.6f} (1.0 = identical)")
        print(f"  L2 difference: {l2_diff:.8f} (0.0 = identical)")
        
        if cosine_sim > 0.999:
            print(f"  ✓ GNN = Transformer (as expected for nodes with no edges)")
        else:
            print(f"  ! Embeddings differ (unexpected)")

print("\n" + "="*70)
print("KEY INSIGHT:")
print("Test nodes have NO edges in the induced subgraph,")
print("so GNN provides NO benefit - it's just the transformer!")
print("="*70)

BASELINE DEMONSTRATION: GNN Degradation for Test Nodes

Test Node 1499677:
  Cosine similarity: 0.593490 (1.0 = identical)
  L2 difference: 0.90167600 (0.0 = identical)
  ! Embeddings differ (unexpected)
