# Training GNN Models for Graph-Based Inference

## Setup and Imports

In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.loader import DataLoader
import json
import pandas as pd
import torch
from torch_geometric.data import Data
import numpy as np
import ast
import re
import random

In [2]:
INIT_DIM = 16
NUM_EPOCHS = 5

In [3]:
edge_path = '/home/student/FinalProject/PaperFeedback/Graphs/author_collaboration_edge_index'
node_path = '/home/student/FinalProject/PaperFeedback/Graphs/author_collaboration_node_index'

In [4]:
with open(node_path, 'r') as f:
    nodes = json.load(f)

with open(edge_path, 'r') as f:
    edges = json.load(f)



In [5]:
labels_df = pd.read_csv('/home/student/FinalProject/PaperFeedback/Datasets/acm_citation_network_v8_labeled.csv')

In [6]:
labels_df

Unnamed: 0.1,Unnamed: 0,index,title,authors,year,venue,references,abstract,id,clustered_labels
0,0,558ac6e0612c41e6b9d39eed,INFORMS Journal on Computing,,2014.0,INFORMS Journal on Computing,,,0,1
1,1,5390879920f70186a0d422b8,Pushout-complements and basic concepts of gram...,Yasuo Kawahara,1990.0,Theoretical Computer Science,,,1,2
2,2,5390879920f70186a0d422b6,Effective constructors the formal series of tr...,Symeon Bozapalidis,1990.0,Theoretical Computer Science,,,2,2
3,3,555aa9a345ce207198fe0ae8,The DataPaper: living in the virtual world,"Mark Green, Chris Shaw",1990.0,Graphics Interface 1990,,Virtual reality user interfaces are a new type...,3,0
4,4,5390879920f70186a0d422ab,Using program slicing in software maintenance,Keith Brian Gallagher,1990.0,Using program slicing in software maintenance,,,4,0
...,...,...,...,...,...,...,...,...,...,...
2381670,2381670,5590cf2f0cf2ce4b6f3a00cf,The QCP File Format and Media Types for Speech...,"R. Gellens, H. Garudadri",2003.0,The QCP File Format and Media Types for Speech...,,RFC 2658 specifies the streaming format for 3G...,2381670,0
2381671,2381671,558c06f4612c2ba45fe93996,Multicast Source Discovery Protocol (MSDP),"B. Fenner, D. Meyer",2003.0,Multicast Source Discovery Protocol (MSDP),,The Multicast Source Discovery Protocol (MSDP)...,2381671,3
2381672,2381672,5590d19b0cf237666fc28f94,RTP Control Protocol Extended Reports (RTCP XR),"T. Friedman, R. Caceres, A. Clark",2003.0,RTP Control Protocol Extended Reports (RTCP XR),,This document defines the Extended Report (XR)...,2381672,3
2381673,2381673,558d109b0cf2a2c70f68ccca,Uniform Resource Identifier (URI) Scheme and A...,E. Lear,2003.0,Uniform Resource Identifier (URI) Scheme and A...,,The Trivial File Transfer Protocol (TFTP) is a...,2381673,3


In [7]:
labels_df['clustered_labels'], unique_topics = pd.factorize(labels_df['clustered_labels'])

In [8]:
with open('id_map.json', 'r') as f:
    id_map = json.load(f)



## Dataset

In [None]:
id_map = dict(zip(labels_df['index'], labels_df['id']))

In [10]:
edge_index_proxy = []
for edge in edges.keys():
    edge_rep = edge.split(';')
    try:
        src = id_map[edge_rep[0]]
        dst = id_map[edge_rep[1]]
        edge_index_proxy.append((src, dst))
    except KeyError:
        continue
edge_index = torch.tensor(edge_index_proxy, dtype=torch.long).t().contiguous()
edge_index.shape

torch.Size([2, 64889322])

In [11]:
labels_dict = dict(zip(labels_df['id'], labels_df['clustered_labels']))
labels = []
for node in nodes:
    try:
        labels.append(labels_dict[id_map[node]])
    except KeyError:
        continue
labels = torch.tensor(labels, dtype=torch.long)
labels.shape

torch.Size([2381675])

In [12]:
reduced_embeddings = pd.read_csv('/home/student/FinalProject/PaperFeedback/Datasets/reduced_embeddings.csv')


In [13]:
reduced_embeddings = reduced_embeddings.drop(columns='Unnamed: 0')

In [14]:
reduced_embeddings

Unnamed: 0,PC1,PC2,PC3,PC4,PC5,PC6,PC7,PC8,PC9,PC10,PC11,PC12,PC13,PC14,PC15,PC16
0,0.244551,0.072160,0.040480,-0.008222,-0.188693,-0.173248,-0.141227,-0.243266,0.228278,-0.163842,0.185187,-0.166343,-0.082616,0.019520,0.002732,-0.056307
1,0.059597,-0.029985,-0.151846,0.005764,-0.181205,-0.143536,0.053723,0.141302,0.159624,-0.003362,-0.174040,-0.080417,0.069381,0.021649,-0.085207,0.057289
2,0.036556,0.048582,-0.309378,0.063445,-0.274076,-0.079415,0.078125,0.037037,-0.071647,-0.111977,-0.102856,0.069368,0.008462,0.129782,0.000145,0.071227
3,0.332041,-0.112594,0.124543,-0.164285,0.223807,-0.230389,0.174200,0.025000,0.010656,-0.050118,0.016155,-0.024879,-0.055583,0.054578,0.155759,0.019210
4,0.097950,0.262777,-0.237347,-0.148805,0.034486,0.053386,-0.088703,0.017467,-0.202946,0.027979,0.045043,0.120020,0.120238,-0.026877,-0.017690,-0.113315
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2381670,-0.085549,0.065734,0.233048,0.051984,0.076737,-0.153134,-0.270285,0.075621,0.097902,-0.040609,-0.217339,-0.117657,0.190760,0.093550,0.015645,-0.115214
2381671,-0.020302,0.188478,0.227094,0.223172,0.094980,0.004704,0.097351,0.057981,-0.004704,-0.114065,-0.002729,-0.074940,-0.075747,0.085331,-0.143944,-0.207621
2381672,0.009575,0.142867,0.119553,0.056473,0.033336,0.008027,-0.145779,0.186886,0.107563,-0.021122,-0.081155,-0.097481,0.132234,0.171594,-0.022116,-0.118640
2381673,0.046728,0.264669,0.181376,0.242812,-0.008225,-0.101865,-0.023464,0.222635,-0.089073,-0.060676,0.048192,-0.047393,-0.087395,-0.063408,0.035975,-0.045400


In [15]:

embeddings_np = reduced_embeddings.to_numpy()
X = torch.tensor(embeddings_np, dtype=torch.float32)
X.shape

torch.Size([2381675, 16])

In [16]:
data = Data(x=X, edge_index=edge_index, y=labels)

## Model

In [17]:
class GNN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim_1, hidden_dim_2, out_dim, num_heads=3):
        super(GNN, self).__init__()
        self.conv1 = GATConv(input_dim, hidden_dim_1, heads=num_heads)
        self.conv2 = GCNConv(hidden_dim_1*num_heads, hidden_dim_2)
        self.w1 = torch.nn.Linear(hidden_dim_2, hidden_dim_2)
        self.w2 = torch.nn.Linear(hidden_dim_2, out_dim)
        
    def forward(self, data, return_hidden=False):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        hidden = self.w1(x)
        output = self.w2(hidden)
        if return_hidden:
            return output, hidden
        else:
            return output


In [18]:
input_dim = INIT_DIM 
hidden_dim_1 = 16
hidden_dim_2 = 8 
output_dim = labels.max().item() + 1  

# Instantiate the model
model = GNN(input_dim=input_dim, hidden_dim_1=hidden_dim_1, hidden_dim_2=hidden_dim_2, out_dim=output_dim)

# Use Adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)

# Loss function (cross-entropy loss for classification)
criterion = torch.nn.CrossEntropyLoss()

In [19]:
if torch.cuda.is_available():
    print('using GPU')
    device = 'cuda'
else:
    print('using CPU')
    device = 'cpu'

model = model.to(device=device)    
data = data.to(device=device)

using GPU


In [20]:
from torch_geometric.loader import NeighborLoader


# Assuming `data` is your single large graph
training_loader = NeighborLoader(data, num_neighbors=[50, 20], batch_size=4096, shuffle=True)
inference_loader = NeighborLoader(data, num_neighbors=[50, 20], batch_size=1024, shuffle=False)

In [21]:
from tqdm import tqdm

In [22]:
for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for batch_data in tqdm(training_loader): 
        optimizer.zero_grad()
        output = model(batch_data)
        pred = output[batch_data.batch].argmax(dim=2)
        correct += (pred == batch_data.y[batch_data.batch]).sum().item()
        total += batch_data.y.size(0)
        loss = criterion(output[batch_data.batch].squeeze(), batch_data.y[batch_data.batch].squeeze())
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    accuracy = correct / total
    print(f'Epoch: [{epoch}/{NUM_EPOCHS}], Loss: {total_loss/len(training_loader):.4f}, Accuracy: {accuracy * 100:.2f}%')

  0%|          | 0/582 [00:00<?, ?it/s]

100%|██████████| 582/582 [03:48<00:00,  2.55it/s]


Epoch: [1/5], Loss: 1.4868, Accuracy: 28.84%


100%|██████████| 582/582 [03:47<00:00,  2.56it/s]


Epoch: [2/5], Loss: 1.4862, Accuracy: 28.93%


100%|██████████| 582/582 [03:45<00:00,  2.58it/s]


Epoch: [3/5], Loss: 1.4861, Accuracy: 28.96%


100%|██████████| 582/582 [03:47<00:00,  2.56it/s]


Epoch: [4/5], Loss: 1.4860, Accuracy: 28.97%


100%|██████████| 582/582 [03:47<00:00,  2.56it/s]

Epoch: [5/5], Loss: 1.4860, Accuracy: 28.98%





## Inference

In [23]:
def set_seed(seed: int, deterministic: bool = False):
    """
    Sets the seed for reproducibility across PyTorch, NumPy, and Python's `random` module.
    
    Parameters:
    - seed (int): The seed value to set.
    - deterministic (bool): If True, sets PyTorch to deterministic mode for reproducibility at the cost of performance.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        torch.backends.cudnn.benchmark = True

In [24]:
set_seed(42)

In [25]:
num_nodes = labels.shape[0]
global_embeddings = torch.zeros((num_nodes, hidden_dim_2))

for batch_data in tqdm(inference_loader):
    model.eval()
    
    with torch.no_grad():
        _, hiddens = model(batch_data, return_hidden=True)
        hiddens = hiddens.cpu()
        global_embeddings[batch_data.n_id] = hiddens



  0%|          | 2/2326 [00:00<02:16, 16.97it/s]

100%|██████████| 2326/2326 [02:49<00:00, 13.74it/s]


In [26]:
global_embeddings = global_embeddings.numpy()

In [27]:
# Step 1: Create a DataFrame with embeddings
df = pd.DataFrame(global_embeddings, columns=[f'dim_{i+1}' for i in range(global_embeddings.shape[1])])

# Step 2: Add the ID column at the beginning
df.insert(0, 'id', list(labels_df['id']))

In [28]:
df.to_csv('/home/student/FinalProject/PaperFeedback/Datasets/author_colab.csv')