In [None]:
# !pip install torch-geometric sentence-transformers tqdm

In [2]:
# =========================================
# 1. SETUP
# =========================================
import torch
from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, Linear
from torch_geometric.utils import train_test_split_edges, negative_sampling
from tqdm import tqdm
import json
import numpy as np


In [11]:
%cd /content

/content


In [None]:
!mkdir -p data
%cd data

# Example: downloading all parts of train.txt.zip
!wget -c https://github.com/QianWangWPI/Released-Microsoft-dataset/raw/main/train.txt.zip.001
!wget -c https://github.com/QianWangWPI/Released-Microsoft-dataset/raw/main/train.txt.zip.002
!wget -c https://raw.githubusercontent.com/QianWangWPI/Released-Microsoft-dataset/refs/heads/main/test.txt
!wget -c https://raw.githubusercontent.com/QianWangWPI/Released-Microsoft-dataset/refs/heads/main/val.txt



In [14]:
!cat train.txt.zip.* > train.txt.zip
!unzip train.txt.zip

Archive:  train.txt.zip
  inflating: train.txt               


In [15]:
import json
import os

with open("train.txt", encoding="utf-8") as f:
    train_data = json.load(f)

with open("val.txt", encoding="utf-8") as f:
    val_data = json.load(f)

with open("test.txt", encoding="utf-8") as f:
    test_data = json.load(f)

print(f"Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}")
print(train_data[0])


Train: 42000, Val: 9000, Test: 9000
{'publication_ID': 17396995, 'Citations': '17957262;21818356;24164861;21818356;24164861;28586396;28688377', 'pubDate': '2007 May 1', 'language': 'eng', 'title': 'Herpes simplex virus type 2 infection does not influence viral dynamics during early HIV 1 infection', 'journal': 'The Journal of infectious diseases', 'abstract': 'We sought to compare baseline and longitudinal plasma HIV-1 loads between herpes simplex virus type 2 (HSV-2)-seropositive and -seronegative individuals who are enrolled in a primary HIV-1 infection cohort in San Diego, California.', 'keywords': 'Adult;California;epidemiology;Cohort Studies;HIV Infections;blood;complications;epidemiology;virology;HIV-1;pathogenicity;Herpes Simplex;blood;complications;epidemiology;virology;Herpesvirus 2, Human;pathogenicity;Humans;Longitudinal Studies;Male;Prevalence;Retrospective Studies;Viral Load', 'authors': [{'name': 'Edward R Cachay', 'org': 'University of California, San Diego, La Jolla, CA

In [16]:
train_data[0].keys()

dict_keys(['publication_ID', 'Citations', 'pubDate', 'language', 'title', 'journal', 'abstract', 'keywords', 'authors', 'venue', 'doi'])

In [20]:
import pandas as pd

def preprocess(data):
    df = pd.DataFrame(data)

    # Split citations into list
    df["Citations"] = df["Citations"].apply(lambda x: x.split(";") if isinstance(x, str) else [])

    df["title"] = df["title"].astype(str)
    df["abstract"] = df["abstract"].astype(str)
    df["keywords"] = df["keywords"].astype(str)

    # Combine title + abstract + keywords as paper text
    df["text"] = df["title"] + ". " + df["abstract"] + ". " + df["keywords"]

    # Extract authors as list of IDs
    df["author_ids"] = df["authors"].apply(lambda authors: [a.get('id', a.get('name')) for a in authors] if isinstance(authors, list) else [])


    # Extract keywords as lowercase list
    df["keyword_list"] = df["keywords"].apply(lambda x: [kw.strip().lower() for kw in x.split(';')] if isinstance(x, str) else [])

    return df[["publication_ID", "text", "Citations", "author_ids", "keyword_list"]]

train_df = preprocess(train_data)
val_df = preprocess(val_data)
test_df = preprocess(test_data)

train_df.head(2)

Unnamed: 0,publication_ID,text,Citations,author_ids,keyword_list
0,17396995,Herpes simplex virus type 2 infection does not...,"[17957262, 21818356, 24164861, 21818356, 24164...","[56335e9345cedb339a968f1e, 5628adc145ce1e5965f...","[adult, california, epidemiology, cohort studi..."
1,16779733,Efficacy of the anti Candida rAls3p N or rAls1...,"[19197361, 19399183, 20041174, 20300572, 17311...","[5633adb045cedb339ab453d6, 5631b0d845ceb49c5e2...","[animals, candida, immunology, isolation & pur..."


In [43]:
import itertools

all_df = pd.concat([train_df, val_df, test_df], ignore_index=True)

# All papers (in data or cited)
all_papers = set(all_df['publication_ID'].astype(str))
for cits in all_df['Citations']:
    all_papers.update(cits)
paper2idx = {pid: i for i, pid in enumerate(sorted(all_papers))}

# All authors
all_authors = set(itertools.chain.from_iterable(all_df['author_ids']))
# Filter out None values before sorting
all_authors = {aid for aid in all_authors if aid is not None}
author2idx = {aid: i for i, aid in enumerate(sorted(all_authors))}

# All keywords
all_keywords = set(itertools.chain.from_iterable(all_df['keyword_list']))
keyword2idx = {kw: i for i, kw in enumerate(sorted(all_keywords))}

print(f"Total papers={len(paper2idx)}, authors={len(author2idx)}, keywords={len(keyword2idx)}")

Total papers=424279, authors=307266, keywords=17207


In [44]:
from torch_geometric.data import HeteroData
import torch

data = HeteroData()

# Paper node features
embedding_dim = 128
data['paper'].x = torch.randn(len(paper2idx), embedding_dim)

# Author node features (learnable embeddings)
data['author'].x = torch.randn(len(author2idx), embedding_dim)

# Keyword node features (learnable embeddings)
data['keyword'].x = torch.randn(len(keyword2idx), embedding_dim)


In [45]:
# Paper -> Paper citations
src, dst = [], []
for _, row in train_df.iterrows():
    pid = str(row['publication_ID'])
    for cited in row['Citations']:
        cited = str(cited)
        if cited in paper2idx:
            src.append(paper2idx[pid])
            dst.append(paper2idx[cited])
data['paper', 'cites', 'paper'].edge_index = torch.tensor([src, dst], dtype=torch.long)

# Paper -> Author and Author -> Paper
src_pa, dst_ap = [], []
for _, row in train_df.iterrows():
    pid = str(row['publication_ID'])
    for aid in row['author_ids']:
        if aid in author2idx:
            src_pa.append(paper2idx[pid])
            dst_ap.append(author2idx[aid])
data['paper', 'written_by', 'author'].edge_index = torch.tensor([src_pa, dst_ap], dtype=torch.long)
data['author', 'authored', 'paper'].edge_index = torch.tensor([dst_ap, src_pa], dtype=torch.long)

# Paper -> Keyword and Keyword -> Paper
src_pk, dst_kp = [], []
for _, row in train_df.iterrows():
    pid = str(row['publication_ID'])
    for kw in row['keyword_list']:
        if kw in keyword2idx:
            src_pk.append(paper2idx[pid])
            dst_kp.append(keyword2idx[kw])
data['paper', 'mentions', 'keyword'].edge_index = torch.tensor([src_pk, dst_kp], dtype=torch.long)
data['keyword', 'appears_in', 'paper'].edge_index = torch.tensor([dst_kp, src_pk], dtype=torch.long)


In [46]:
# from torch_geometric.utils import heterogeneous_random_walk

# Build edge_index_dict
edge_index_dict = {
    ('paper', 'cites', 'paper'): data['paper', 'cites', 'paper'].edge_index,
    ('paper', 'written_by', 'author'): data['paper', 'written_by', 'author'].edge_index,
    ('author', 'authored', 'paper'): data['author', 'authored', 'paper'].edge_index,
    ('paper', 'mentions', 'keyword'): data['paper', 'mentions', 'keyword'].edge_index,
    ('keyword', 'appears_in', 'paper'): data['keyword', 'appears_in', 'paper'].edge_index
}

node_type_dict = {
    'paper': torch.arange(data['paper'].num_nodes),
    'author': torch.arange(data['author'].num_nodes),
    'keyword': torch.arange(data['keyword'].num_nodes)
}


In [47]:
from torch_geometric.nn import HeteroConv, GATConv, Linear
import torch.nn.functional as F
import torch.nn as nn

class HeteroGNN(nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()

        self.conv1 = HeteroConv({
            ('paper', 'cites', 'paper'): GATConv(-1, hidden_channels, add_self_loops=True),
            ('paper', 'written_by', 'author'): GATConv(-1, hidden_channels, add_self_loops=False),
            ('author', 'authored', 'paper'): GATConv(-1, hidden_channels, add_self_loops=False),
            ('paper', 'mentions', 'keyword'): GATConv(-1, hidden_channels, add_self_loops=False),
            ('keyword', 'appears_in', 'paper'): GATConv(-1, hidden_channels, add_self_loops=False),
        }, aggr='sum')

        self.lin = nn.Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}
        x_dict = {key: self.lin(x) for key, x in x_dict.items()}
        return x_dict

In [48]:
model = HeteroGNN(hidden_channels=64, out_channels=64)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

print(model)


HeteroGNN(
  (conv1): HeteroConv(num_relations=5)
  (lin): Linear(in_features=64, out_features=64, bias=True)
)


In [49]:
from torch_geometric.utils import negative_sampling

def loss_fn(emb_src, emb_dst, pos_edges, neg_edges):
    pos_score = (emb_src[pos_edges[0]] * emb_dst[pos_edges[1]]).sum(dim=1)
    neg_score = (emb_src[neg_edges[0]] * emb_dst[neg_edges[1]]).sum(dim=1)
    labels = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
    scores = torch.cat([pos_score, neg_score])
    return F.binary_cross_entropy_with_logits(scores, labels)


In [51]:
edge_types = [
    ('paper', 'cites', 'paper'),
    ('paper', 'written_by', 'author'),
    ('author', 'authored', 'paper'),
    ('paper', 'mentions', 'keyword'),
    ('keyword', 'appears_in', 'paper')
]

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epochs = 50

for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()

    out_dict = model(data.x_dict, edge_index_dict)
    total_loss = 0

    for etype in edge_types:
        src, rel, dst = etype
        pos_edges = data[etype].edge_index
        neg_edges = negative_sampling(
            edge_index=pos_edges,
            num_nodes=(data[src].num_nodes, data[dst].num_nodes),
            num_neg_samples=pos_edges.size(1)
        )

        emb_src = out_dict[src]
        emb_dst = out_dict[dst]

        total_loss += loss_fn(emb_src, emb_dst, pos_edges, neg_edges)

    total_loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1} | Loss: {total_loss.item():.4f}")


Epoch 1 | Loss: 9.1004
Epoch 2 | Loss: 7.6939
Epoch 3 | Loss: 6.6027
Epoch 4 | Loss: 5.7752
Epoch 5 | Loss: 5.1711
Epoch 6 | Loss: 4.7281
Epoch 7 | Loss: 4.4031
Epoch 8 | Loss: 4.1595
Epoch 9 | Loss: 3.9678
Epoch 10 | Loss: 3.8108
Epoch 11 | Loss: 3.6784
Epoch 12 | Loss: 3.5604
Epoch 13 | Loss: 3.4543
Epoch 14 | Loss: 3.3636
Epoch 15 | Loss: 3.2823
Epoch 16 | Loss: 3.2129
Epoch 17 | Loss: 3.1491
Epoch 18 | Loss: 3.0932
Epoch 19 | Loss: 3.0444
Epoch 20 | Loss: 2.9982
Epoch 21 | Loss: 2.9566
Epoch 22 | Loss: 2.9164
Epoch 23 | Loss: 2.8769
Epoch 24 | Loss: 2.8402
Epoch 25 | Loss: 2.8053
Epoch 26 | Loss: 2.7722
Epoch 27 | Loss: 2.7412
Epoch 28 | Loss: 2.7112
Epoch 29 | Loss: 2.6835
Epoch 30 | Loss: 2.6542
Epoch 31 | Loss: 2.6328
Epoch 32 | Loss: 2.6095
Epoch 33 | Loss: 2.5877
Epoch 34 | Loss: 2.5665
Epoch 35 | Loss: 2.5473
Epoch 36 | Loss: 2.5291
Epoch 37 | Loss: 2.5088
Epoch 38 | Loss: 2.4933
Epoch 39 | Loss: 2.4763
Epoch 40 | Loss: 2.4604
Epoch 41 | Loss: 2.4439
Epoch 42 | Loss: 2.4271
E

In [52]:
torch.save(model.state_dict(), "hetero_gnn_checkpoint.pt")

In [93]:
from google.colab import drive
import torch
import os

# Mount Google Drive
drive.mount('/content/drive')

checkpoint_dir = '/content/drive/MyDrive/mag_citation_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Save model
checkpoint_path = os.path.join(checkpoint_dir, 'hetero_gnn_checkpoint.pt')
torch.save(model.state_dict(), checkpoint_path)
print(f"Model saved at {checkpoint_path}")

# Load:
# model.load_state_dict(torch.load(checkpoint_path))
# model.to(device)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Model saved at /content/drive/MyDrive/mag_citation_checkpoints/hetero_gnn_checkpoint.pt
