In [1]:
import os
import sys
from dotenv import load_dotenv

# Load the .env file so PROJECT_ROOT and HF_CACHE are visible through os.getenv.
load_dotenv()

# Make project-local modules importable. Non-string entries in sys.path are
# ignored by the import system, so inserting None would silently do nothing;
# skip the insert when PROJECT_ROOT is not configured.
project_root = os.getenv("PROJECT_ROOT")
if project_root is not None:
    sys.path.insert(1, project_root)

# Point the Hugging Face cache at HF_CACHE. This is set before the
# `transformers` import below (presumably so the library picks it up on import).
# os.environ only accepts strings, so assigning None would raise TypeError;
# leave HF_HOME untouched when HF_CACHE is not configured.
hf_cache = os.getenv("HF_CACHE")
if hf_cache is not None:
    os.environ['HF_HOME'] = hf_cache

from transformers import AutoTokenizer, AutoModel
import torch
import pickle
from tqdm import tqdm
import torch.nn as nn
import dgl
from dgl import heterograph
from collections import defaultdict
import numpy as np
import networkx as nx
from pathlib import Path

# Choose the compute device, then load the SciBERT tokenizer and encoder.
# local_files_only=True means both must already be present in the local cache.
SCIBERT_CHECKPOINT = "allenai/scibert_scivocab_uncased"

if torch.cuda.is_available():
    device = torch.device("cuda:2")
else:
    device = torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(SCIBERT_CHECKPOINT, local_files_only=True)
model = AutoModel.from_pretrained(SCIBERT_CHECKPOINT, local_files_only=True)
model = model.to(device)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_scibert_embedding(text):
    """Encode `text` with SciBERT and return the hidden state of its first token.

    The text is tokenized (truncated to at most 512 tokens), moved to the
    module-level `device`, and passed through the module-level `model` with
    gradient tracking disabled. The position-0 (CLS) vector of the last hidden
    layer is returned.

    Returns:
        Tensor of shape (batch, hidden_size); batch is 1 for a single string.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    encoded = encoded.to(device)

    with torch.no_grad():
        last_hidden = model(**encoded).last_hidden_state

    # Position 0 along the sequence axis is the CLS token.
    return last_hidden[:, 0, :]


def _load_pickle(path):
    """Unpickle and return the object stored at `path`, closing the file afterwards."""
    # NOTE: unpickling can execute arbitrary code; only load files you trust.
    with open(path, "rb") as fh:
        return pickle.load(fh)


paper_records = _load_pickle('data/teacher_graph/records/paper_records.pkl')
author_records = _load_pickle('data/teacher_graph/records/author_records.pkl')
fields_of_study = set()
affiliations = set()

all_ids = set(paper_records.keys())
student_paper_ids = _load_pickle('data/teacher_graph/paper_ids.pkl')
eval_ids = _load_pickle('data/teacher_graph/eval_ids.pickle')
test_ids = _load_pickle('data/teacher_graph/test_ids.pickle')
extra_train_ids = _load_pickle('data/teacher_graph/extra_train_ids.pickle')

# create a combined set of student paper ids and extra train ids
train_ids = set(student_paper_ids).union(extra_train_ids)

# combined_ids = set(train_ids).union(eval_ids)

# Map every paper id to a contiguous node index. The index is the paper's
# position when enumerating `train_ids` / `eval_ids`, so any later loop that
# enumerates the same (unmodified) collection sees the same numbering.
train_ids_nums = {paper_id: idx for idx, paper_id in enumerate(train_ids)}
train_nodes = list(train_ids_nums.values())

eval_ids_nums = {paper_id: idx for idx, paper_id in enumerate(eval_ids)}
eval_nodes = list(eval_ids_nums.values())

# for author_id, author in tqdm(author_records.items()):
#     affiliations.update(author['affiliations'])

# for s2_id, paper in tqdm(paper_records.items()):
#     fields_of_study.update(paper['field_ids'])



def _citation_edges(source_ids, source_index, target_index, records):
    """Collect (source_node, target_node) pairs for citations that land in `target_index`.

    Args:
        source_ids: Iterable of citing paper ids.
        source_index: Mapping citing paper id -> node index.
        target_index: Mapping cited paper id -> node index.
        records: Mapping paper id -> record holding a 'reference_ids' list.
    Returns:
        List of (citing node index, cited node index) tuples. References that
        are None or absent from `target_index` are skipped.
    """
    pairs = []
    for paper_id in tqdm(source_ids):
        citing_node = source_index[paper_id]
        for ref_id in records[paper_id]['reference_ids']:
            # Filter out missing references before the lookup.
            if ref_id is not None and ref_id in target_index:
                pairs.append((citing_node, target_index[ref_id]))  # paper -> reference
    return pairs


# Directed citation graph over the training papers (edge = paper -> reference).
train_graph = nx.DiGraph()
train_graph.add_nodes_from(train_nodes)

edges = _citation_edges(train_ids, train_ids_nums, train_ids_nums, paper_records)
train_graph.add_edges_from(edges)

# Feature matrix for the train nodes, filled with SciBERT embeddings later on;
# the width 768 has to match the size of those embeddings.
train_feats = torch.zeros((len(train_ids), 768))

# Citations from eval papers to train papers (eval node index -> train node index).
eval_edges = _citation_edges(eval_ids, eval_ids_nums, train_ids_nums, paper_records)


def _title_and_abstract(pdata):
    """Return the paper's title and abstract, each followed by a newline.

    A missing (None) abstract is treated as an empty string.
    """
    abstract = "" if pdata['abstract'] is None else pdata['abstract']
    return pdata['title'] + "\n" + abstract + "\n"


# Compute and assign SciBERT embeddings for the train paper nodes.
# NOTE(review): train text appends venue/year/count metadata, whereas the eval
# text further down is title + abstract only -- confirm the asymmetry is intended.
for local_idx, global_id in tqdm(enumerate(train_ids), total=len(train_ids)):
    pdata = paper_records[global_id]
    text = _title_and_abstract(pdata)
    # Inner quotes are single quotes so the f-string also parses on
    # Python < 3.12, which rejects reusing the enclosing quote character.
    text += (
        f"This paper was published in {pdata['venue']} in {pdata['year']}. "
        f"It has {len(pdata['author_ids'])} authors and {pdata['referenceCount']} references. "
        f"It has {pdata['citationCount']} citations and {pdata['influentialCitationCount']} influential citations."
    )
    emb = get_scibert_embedding(text)
    train_feats[local_idx] = emb[0]  # emb is (1, hidden_size); store the single row

# Prepare eval node SciBERT embeddings (once).
eval_feats = torch.zeros(len(eval_ids), 768)
for local_idx, global_id in tqdm(enumerate(eval_ids), total=len(eval_ids)):
    pdata = paper_records[global_id]
    text = _title_and_abstract(pdata)
    emb = get_scibert_embedding(text)
    eval_feats[local_idx] = emb[0]

# train the model using the graph attention network
import torch.nn.functional as F
from torch_geometric.nn import GATConv

# Move both feature matrices to the training device, then stack them with the
# train rows first and the eval rows after.
train_feats, eval_feats = train_feats.to(device), eval_feats.to(device)
combined_feats = torch.cat([train_feats, eval_feats], dim=0).to(device)

100%|██████████| 11528/11528 [00:00<00:00, 40314.18it/s]
100%|██████████| 1500/1500 [00:00<00:00, 75110.20it/s]
11528it [01:26, 133.05it/s]
1500it [00:10, 139.35it/s]


In [3]:
# Evaluation graph: one node per eval paper; no edges are added.
eval_graph = nx.DiGraph()
eval_graph.add_nodes_from(eval_nodes)

In [4]:
def compute_loss(pos_score, neg_score):
    """Link-prediction loss from raw scores of positive and negative edges.

    The positive term is the mean of -log(sigmoid(pos_score)); the negative term
    is the mean of -log(sigmoid(-neg_score)). Each mean is taken over its own
    tensor, and the two terms are summed.
    """
    positive_term = F.logsigmoid(pos_score).mean().neg()
    negative_term = F.logsigmoid(neg_score.neg()).mean().neg()
    return positive_term + negative_term

In [5]:
def calculate_recall_at_k_fixed(eval_ids, predicted_indices, k, records=None, train_index=None):
    """
    Calculate recall@k of predicted references, pooled over all evaluation papers.

    For each evaluation paper the ground truth is the set of its references that
    are train nodes, expressed as train node indices. A hit is a ground-truth
    index that appears among the paper's first k predicted indices. The result is
    (total hits) / (total ground-truth references) over all evaluation papers.
    Predictions and ground truth are both compared as train node indices.

    Args:
        eval_ids: Evaluation paper IDs (global IDs), ordered like the rows of predicted_indices
        predicted_indices: 2-D tensor of predicted train node indices, one row per
            eval paper; only the first k columns are used
        k: Number of top predictions to consider
        records: Optional mapping paper ID -> record with a 'reference_ids' list.
            Defaults to the module-level paper_records.
        train_index: Optional mapping train paper ID -> train node index.
            Defaults to the module-level train_ids_nums.
    Returns:
        recall: Recall@k as a float; 0.0 when no evaluation paper references a train node
    """
    if records is None:
        records = paper_records
    if train_index is None:
        train_index = train_ids_nums

    # Single device-to-host transfer for every row instead of one per eval paper.
    top_k_rows = predicted_indices[:, :k].cpu().tolist()

    relevant_count = 0
    total_relevant = 0

    for eval_idx, eval_id in enumerate(eval_ids):
        # Ground truth: referenced train papers, converted to train node indices.
        true_references = {
            train_index[ref]
            for ref in records[eval_id]['reference_ids']
            if ref in train_index
        }
        total_relevant += len(true_references)
        relevant_count += len(true_references.intersection(top_k_rows[eval_idx]))

    if total_relevant == 0:
        return 0.0

    return relevant_count / total_relevant

# GAT hyperparameters: the input width follows the node-feature matrix;
# hidden_dim is the per-head output width and num_heads the number of heads
# passed to the first GAT layer.
hidden_dim, num_heads = 128, 4
in_dim = train_feats.size(1)

class DotLinkPredictor(nn.Module):
    """Parameter-free edge scorer: dot product of the two endpoint embeddings."""

    def forward(self, h, src_idx, dst_idx):
        """Return one score per (src_idx, dst_idx) pair, computed from rows of `h`."""
        src_emb = h[src_idx]
        dst_emb = h[dst_idx]
        return torch.sum(src_emb * dst_emb, dim=-1)
    
    # Alternative implementation that avoids the in-place operation issue
class GATLinkPredictorFixed(nn.Module):
    """Two-layer graph attention encoder that yields one embedding per node.

    Layer 1 applies `num_heads` attention heads of width `hidden_dim`; its output
    (width hidden_dim * num_heads) feeds layer 2, a single head of width
    `hidden_dim`.

    Args:
        in_dim: Width of the input node features.
        hidden_dim: Per-head output width of both layers.
        num_heads: Number of attention heads in the first layer.
    """

    def __init__(self, in_dim, hidden_dim, num_heads):
        super().__init__()
        self.gat1 = GATConv(in_dim, hidden_dim, num_heads)
        self.gat2 = GATConv(hidden_dim * num_heads, hidden_dim, 1)

    @staticmethod
    def _edge_index(g, target_device):
        """Convert `g` into a (2, num_edges) long tensor on `target_device`."""
        if isinstance(g, nx.DiGraph):
            edge_list = list(g.edges())
            if not edge_list:
                # An edgeless graph needs an explicit (2, 0) long tensor:
                # torch.tensor([]) would be 1-D float. No placeholder edges are
                # fabricated; each node then only attends to itself through the
                # self-loops GATConv adds by default.
                return torch.empty((2, 0), dtype=torch.long, device=target_device)
            pairs = torch.tensor(edge_list, dtype=torch.long, device=target_device)
            return pairs.t().contiguous()
        if torch.is_tensor(g):
            # Already an edge index; just make sure dtype and device match.
            return g.to(device=target_device, dtype=torch.long)
        raise TypeError(
            f"Unsupported graph type {type(g).__name__}; expected networkx.DiGraph "
            "or a (2, num_edges) edge-index tensor"
        )

    def forward(self, g, features):
        """Encode all nodes.

        Args:
            g: networkx.DiGraph whose node labels are row indices into `features`,
               or a (2, num_edges) edge-index tensor.
            features: Node feature matrix of shape (num_nodes, in_dim).
        Returns:
            Tensor of shape (num_nodes, hidden_dim).
        Raises:
            TypeError: If `g` is neither a networkx.DiGraph nor a tensor.
        """
        # Follow the features' device rather than a module-level global.
        edge_index = self._edge_index(g, features.device)
        # GATConv returns a 2-D (num_nodes, heads * out_channels) tensor, so no
        # flatten/squeeze is needed between or after the layers.
        h = F.elu(self.gat1(features, edge_index))
        return self.gat2(h, edge_index)

# Training configuration
NUM_EPOCHS = 40
SUB_EPOCHS_PER_EPOCH = 5   # full passes over the edge list per reported epoch
EDGE_BATCH_SIZE = 1024     # positive edges per optimizer step

model_fixed = GATLinkPredictorFixed(in_dim, hidden_dim, num_heads).to(device)
predictor = DotLinkPredictor()
optimizer = torch.optim.Adam(list(model_fixed.parameters()) + list(predictor.parameters()), lr=1e-3)

# The positive edges never change during training, so build their endpoint
# tensors once instead of once per epoch.
edge_list = list(train_graph.edges())
if not edge_list:
    raise ValueError("train_graph has no edges; link prediction needs at least one positive edge")
src, dst = zip(*edge_list)
src = torch.tensor(src, dtype=torch.long, device=device)
dst = torch.tensor(dst, dtype=torch.long, device=device)
n_edges = src.shape[0]

# NOTE(review): in the logged run Recall@20 peaks mid-training (0.1058 at
# epoch 22) and ends lower (0.0772 at epoch 39); model_fixed holds the weights
# of the last epoch. Consider checkpointing on the best eval recall.
for epoch in range(NUM_EPOCHS):
    model_fixed.train()
    batch_losses = []

    for sub_epoch in range(SUB_EPOCHS_PER_EPOCH):
        # Fresh shuffle of the positive edges for every pass.
        perm = torch.randperm(n_edges, device=device)

        for i in tqdm(range(0, n_edges, EDGE_BATCH_SIZE)):
            # Node embeddings are recomputed every step because the parameters
            # changed in the previous optimizer step.
            h = model_fixed(train_graph, train_feats)

            batch_idx = perm[i:i + EDGE_BATCH_SIZE]
            batch_src = src[batch_idx]
            batch_dst = dst[batch_idx]

            # Negative sampling: one uniformly random destination per positive edge.
            neg_dst = torch.randint(0, h.shape[0], batch_dst.shape, dtype=torch.long, device=device)

            pos_score = predictor(h, batch_src, batch_dst)
            neg_score = predictor(h, batch_src, neg_dst)

            loss = compute_loss(pos_score, neg_score)
            batch_losses.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    avg_loss = sum(batch_losses) / len(batch_losses)
    print(f"Epoch {epoch} | Average Loss: {avg_loss:.4f}")

    # Evaluation
    model_fixed.eval()
    with torch.no_grad():
        # Train nodes are encoded with the citation graph; eval nodes with the
        # edgeless eval graph.
        train_embs = model_fixed(train_graph, train_feats)
        eval_embs = model_fixed(eval_graph, eval_feats)

        # Score every (eval, train) pair by dot product.
        scores = torch.matmul(eval_embs, train_embs.T)

        # topk rejects k larger than the number of candidates, so clamp it.
        topk = torch.topk(scores, k=min(20, train_embs.shape[0]), dim=1)
        predicted_indices = topk.indices

        recall_at_10 = calculate_recall_at_k_fixed(eval_ids, predicted_indices, 10)
        recall_at_20 = calculate_recall_at_k_fixed(eval_ids, predicted_indices, 20)

        print(f"Recall@10: {recall_at_10:.4f} | Recall@20: {recall_at_20:.4f}")

100%|██████████| 63/63 [00:05<00:00, 11.86it/s]
100%|██████████| 63/63 [00:04<00:00, 14.67it/s]
100%|██████████| 63/63 [00:04<00:00, 15.52it/s]
100%|██████████| 63/63 [00:03<00:00, 17.57it/s]
100%|██████████| 63/63 [00:03<00:00, 17.87it/s]


Epoch 0 | Average Loss: 1.9430
Recall@10: 0.0012 | Recall@20: 0.0022


100%|██████████| 63/63 [00:03<00:00, 20.44it/s]
100%|██████████| 63/63 [00:03<00:00, 17.86it/s]
100%|██████████| 63/63 [00:04<00:00, 14.79it/s]
100%|██████████| 63/63 [00:03<00:00, 18.12it/s]
100%|██████████| 63/63 [00:03<00:00, 20.56it/s]


Epoch 1 | Average Loss: 1.0357
Recall@10: 0.0016 | Recall@20: 0.0035


100%|██████████| 63/63 [00:03<00:00, 16.44it/s]
100%|██████████| 63/63 [00:03<00:00, 18.79it/s]
100%|██████████| 63/63 [00:03<00:00, 18.82it/s]
100%|██████████| 63/63 [00:03<00:00, 18.38it/s]
100%|██████████| 63/63 [00:04<00:00, 12.74it/s]


Epoch 2 | Average Loss: 1.0149
Recall@10: 0.0036 | Recall@20: 0.0066


100%|██████████| 63/63 [00:05<00:00, 12.15it/s]
100%|██████████| 63/63 [00:04<00:00, 15.51it/s]
100%|██████████| 63/63 [00:04<00:00, 15.18it/s]
100%|██████████| 63/63 [00:03<00:00, 15.89it/s]
100%|██████████| 63/63 [00:03<00:00, 17.47it/s]


Epoch 3 | Average Loss: 1.0028
Recall@10: 0.0041 | Recall@20: 0.0085


100%|██████████| 63/63 [00:03<00:00, 20.65it/s]
100%|██████████| 63/63 [00:04<00:00, 14.06it/s]
100%|██████████| 63/63 [00:03<00:00, 18.17it/s]
100%|██████████| 63/63 [00:03<00:00, 18.26it/s]
100%|██████████| 63/63 [00:03<00:00, 20.04it/s]


Epoch 4 | Average Loss: 0.9905
Recall@10: 0.0050 | Recall@20: 0.0100


100%|██████████| 63/63 [00:03<00:00, 18.40it/s]
100%|██████████| 63/63 [00:03<00:00, 18.83it/s]
100%|██████████| 63/63 [00:03<00:00, 18.86it/s]
100%|██████████| 63/63 [00:03<00:00, 17.08it/s]
100%|██████████| 63/63 [00:05<00:00, 12.33it/s]


Epoch 5 | Average Loss: 0.9835
Recall@10: 0.0074 | Recall@20: 0.0148


100%|██████████| 63/63 [00:04<00:00, 14.75it/s]
100%|██████████| 63/63 [00:04<00:00, 13.50it/s]
100%|██████████| 63/63 [00:04<00:00, 15.46it/s]
100%|██████████| 63/63 [00:03<00:00, 16.32it/s]
100%|██████████| 63/63 [00:03<00:00, 20.26it/s]


Epoch 6 | Average Loss: 0.9695
Recall@10: 0.0109 | Recall@20: 0.0218


100%|██████████| 63/63 [00:03<00:00, 15.81it/s]
100%|██████████| 63/63 [00:03<00:00, 17.25it/s]
100%|██████████| 63/63 [00:03<00:00, 18.68it/s]
100%|██████████| 63/63 [00:03<00:00, 20.96it/s]
100%|██████████| 63/63 [00:03<00:00, 18.79it/s]


Epoch 7 | Average Loss: 0.9668
Recall@10: 0.0101 | Recall@20: 0.0193


100%|██████████| 63/63 [00:03<00:00, 19.16it/s]
100%|██████████| 63/63 [00:03<00:00, 19.62it/s]
100%|██████████| 63/63 [00:03<00:00, 17.14it/s]
100%|██████████| 63/63 [00:03<00:00, 18.38it/s]
100%|██████████| 63/63 [00:03<00:00, 17.42it/s]


Epoch 8 | Average Loss: 0.9558
Recall@10: 0.0163 | Recall@20: 0.0299


100%|██████████| 63/63 [00:04<00:00, 13.29it/s]
100%|██████████| 63/63 [00:04<00:00, 14.38it/s]
100%|██████████| 63/63 [00:04<00:00, 14.18it/s]
100%|██████████| 63/63 [00:03<00:00, 18.48it/s]
100%|██████████| 63/63 [00:03<00:00, 17.70it/s]


Epoch 9 | Average Loss: 0.9526
Recall@10: 0.0498 | Recall@20: 0.0801


100%|██████████| 63/63 [00:05<00:00, 10.77it/s]
100%|██████████| 63/63 [00:04<00:00, 15.42it/s]
100%|██████████| 63/63 [00:03<00:00, 17.63it/s]
100%|██████████| 63/63 [00:03<00:00, 20.14it/s]
100%|██████████| 63/63 [00:03<00:00, 18.19it/s]


Epoch 10 | Average Loss: 0.9458
Recall@10: 0.0539 | Recall@20: 0.0825


100%|██████████| 63/63 [00:03<00:00, 19.90it/s]
100%|██████████| 63/63 [00:05<00:00, 11.39it/s]
100%|██████████| 63/63 [00:05<00:00, 11.84it/s]
100%|██████████| 63/63 [00:03<00:00, 16.12it/s]
100%|██████████| 63/63 [00:03<00:00, 17.49it/s]


Epoch 11 | Average Loss: 0.9389
Recall@10: 0.0470 | Recall@20: 0.0722


100%|██████████| 63/63 [00:04<00:00, 12.99it/s]
100%|██████████| 63/63 [00:04<00:00, 15.39it/s]
100%|██████████| 63/63 [00:03<00:00, 16.70it/s]
100%|██████████| 63/63 [00:04<00:00, 15.35it/s]
100%|██████████| 63/63 [00:04<00:00, 14.60it/s]


Epoch 12 | Average Loss: 0.9306
Recall@10: 0.0385 | Recall@20: 0.0581


100%|██████████| 63/63 [00:03<00:00, 17.01it/s]
100%|██████████| 63/63 [00:03<00:00, 19.49it/s]
100%|██████████| 63/63 [00:03<00:00, 19.39it/s]
100%|██████████| 63/63 [00:03<00:00, 16.90it/s]
100%|██████████| 63/63 [00:03<00:00, 18.46it/s]


Epoch 13 | Average Loss: 0.9325
Recall@10: 0.0536 | Recall@20: 0.0828


100%|██████████| 63/63 [00:03<00:00, 20.15it/s]
100%|██████████| 63/63 [00:04<00:00, 13.49it/s]
100%|██████████| 63/63 [00:04<00:00, 14.94it/s]
100%|██████████| 63/63 [00:04<00:00, 15.01it/s]
100%|██████████| 63/63 [00:04<00:00, 12.83it/s]


Epoch 14 | Average Loss: 0.9239
Recall@10: 0.0589 | Recall@20: 0.0876


100%|██████████| 63/63 [00:04<00:00, 14.63it/s]
100%|██████████| 63/63 [00:04<00:00, 15.45it/s]
100%|██████████| 63/63 [00:03<00:00, 17.76it/s]
100%|██████████| 63/63 [00:03<00:00, 17.57it/s]
100%|██████████| 63/63 [00:05<00:00, 11.55it/s]


Epoch 15 | Average Loss: 0.9250
Recall@10: 0.0655 | Recall@20: 0.0931


100%|██████████| 63/63 [00:03<00:00, 17.54it/s]
100%|██████████| 63/63 [00:03<00:00, 20.19it/s]
100%|██████████| 63/63 [00:03<00:00, 18.88it/s]
100%|██████████| 63/63 [00:03<00:00, 15.78it/s]
100%|██████████| 63/63 [00:04<00:00, 14.25it/s]


Epoch 16 | Average Loss: 0.9223
Recall@10: 0.0547 | Recall@20: 0.0774


100%|██████████| 63/63 [00:04<00:00, 12.65it/s]
100%|██████████| 63/63 [00:05<00:00, 11.99it/s]
100%|██████████| 63/63 [00:04<00:00, 13.32it/s]
100%|██████████| 63/63 [00:04<00:00, 12.84it/s]
100%|██████████| 63/63 [00:04<00:00, 12.90it/s]


Epoch 17 | Average Loss: 0.9171
Recall@10: 0.0522 | Recall@20: 0.0757


100%|██████████| 63/63 [00:03<00:00, 19.56it/s]
100%|██████████| 63/63 [00:05<00:00, 11.00it/s]
100%|██████████| 63/63 [00:04<00:00, 14.77it/s]
100%|██████████| 63/63 [00:03<00:00, 16.31it/s]
100%|██████████| 63/63 [00:03<00:00, 18.44it/s]


Epoch 18 | Average Loss: 0.9106
Recall@10: 0.0664 | Recall@20: 0.0991


100%|██████████| 63/63 [00:03<00:00, 18.47it/s]
100%|██████████| 63/63 [00:04<00:00, 14.87it/s]
100%|██████████| 63/63 [00:06<00:00,  9.92it/s]
100%|██████████| 63/63 [00:05<00:00, 12.27it/s]
100%|██████████| 63/63 [00:05<00:00, 10.80it/s]


Epoch 19 | Average Loss: 0.9124
Recall@10: 0.0671 | Recall@20: 0.0979


100%|██████████| 63/63 [00:04<00:00, 15.54it/s]
100%|██████████| 63/63 [00:03<00:00, 19.29it/s]
100%|██████████| 63/63 [00:05<00:00, 11.08it/s]
100%|██████████| 63/63 [00:04<00:00, 12.94it/s]
100%|██████████| 63/63 [00:03<00:00, 19.23it/s]


Epoch 20 | Average Loss: 0.9042
Recall@10: 0.0585 | Recall@20: 0.0876


100%|██████████| 63/63 [00:04<00:00, 15.45it/s]
100%|██████████| 63/63 [00:03<00:00, 19.19it/s]
100%|██████████| 63/63 [00:03<00:00, 19.66it/s]
100%|██████████| 63/63 [00:04<00:00, 12.61it/s]
100%|██████████| 63/63 [00:05<00:00, 11.78it/s]


Epoch 21 | Average Loss: 0.9047
Recall@10: 0.0541 | Recall@20: 0.0797


100%|██████████| 63/63 [00:05<00:00, 11.70it/s]
100%|██████████| 63/63 [00:04<00:00, 13.92it/s]
100%|██████████| 63/63 [00:03<00:00, 16.08it/s]
100%|██████████| 63/63 [00:03<00:00, 20.68it/s]
100%|██████████| 63/63 [00:03<00:00, 16.93it/s]


Epoch 22 | Average Loss: 0.9044
Recall@10: 0.0732 | Recall@20: 0.1058


100%|██████████| 63/63 [00:04<00:00, 14.72it/s]
100%|██████████| 63/63 [00:03<00:00, 16.36it/s]
100%|██████████| 63/63 [00:03<00:00, 19.78it/s]
100%|██████████| 63/63 [00:03<00:00, 18.74it/s]
100%|██████████| 63/63 [00:03<00:00, 16.39it/s]


Epoch 23 | Average Loss: 0.9040
Recall@10: 0.0532 | Recall@20: 0.0841


100%|██████████| 63/63 [00:04<00:00, 15.26it/s]
100%|██████████| 63/63 [00:04<00:00, 12.93it/s]
100%|██████████| 63/63 [00:06<00:00, 10.44it/s]
100%|██████████| 63/63 [00:06<00:00,  9.96it/s]
100%|██████████| 63/63 [00:03<00:00, 16.33it/s]


Epoch 24 | Average Loss: 0.8971
Recall@10: 0.0544 | Recall@20: 0.0782


100%|██████████| 63/63 [00:03<00:00, 16.24it/s]
100%|██████████| 63/63 [00:03<00:00, 19.05it/s]
100%|██████████| 63/63 [00:03<00:00, 18.26it/s]
100%|██████████| 63/63 [00:04<00:00, 14.96it/s]
100%|██████████| 63/63 [00:04<00:00, 15.70it/s]


Epoch 25 | Average Loss: 0.8952
Recall@10: 0.0673 | Recall@20: 0.0984


100%|██████████| 63/63 [00:03<00:00, 19.41it/s]
100%|██████████| 63/63 [00:03<00:00, 19.14it/s]
100%|██████████| 63/63 [00:03<00:00, 19.64it/s]
100%|██████████| 63/63 [00:04<00:00, 13.31it/s]
100%|██████████| 63/63 [00:04<00:00, 14.28it/s]


Epoch 26 | Average Loss: 0.8944
Recall@10: 0.0552 | Recall@20: 0.0782


100%|██████████| 63/63 [00:04<00:00, 12.86it/s]
100%|██████████| 63/63 [00:04<00:00, 13.80it/s]
100%|██████████| 63/63 [00:05<00:00, 11.82it/s]
100%|██████████| 63/63 [00:04<00:00, 13.92it/s]
100%|██████████| 63/63 [00:03<00:00, 19.73it/s]


Epoch 27 | Average Loss: 0.8942
Recall@10: 0.0692 | Recall@20: 0.1024


100%|██████████| 63/63 [00:05<00:00, 11.38it/s]
100%|██████████| 63/63 [00:03<00:00, 16.49it/s]
100%|██████████| 63/63 [00:03<00:00, 18.79it/s]
100%|██████████| 63/63 [00:03<00:00, 18.91it/s]
100%|██████████| 63/63 [00:03<00:00, 19.77it/s]


Epoch 28 | Average Loss: 0.8934
Recall@10: 0.0607 | Recall@20: 0.0852


100%|██████████| 63/63 [00:03<00:00, 19.19it/s]
100%|██████████| 63/63 [00:05<00:00, 12.31it/s]
100%|██████████| 63/63 [00:04<00:00, 12.79it/s]
100%|██████████| 63/63 [00:03<00:00, 17.46it/s]
100%|██████████| 63/63 [00:04<00:00, 14.90it/s]


Epoch 29 | Average Loss: 0.8879
Recall@10: 0.0601 | Recall@20: 0.0859


100%|██████████| 63/63 [00:04<00:00, 13.29it/s]
100%|██████████| 63/63 [00:03<00:00, 17.48it/s]
100%|██████████| 63/63 [00:03<00:00, 18.85it/s]
100%|██████████| 63/63 [00:03<00:00, 19.63it/s]
100%|██████████| 63/63 [00:03<00:00, 19.24it/s]


Epoch 30 | Average Loss: 0.8878
Recall@10: 0.0555 | Recall@20: 0.0824


100%|██████████| 63/63 [00:05<00:00, 10.60it/s]
100%|██████████| 63/63 [00:03<00:00, 15.95it/s]
100%|██████████| 63/63 [00:03<00:00, 18.36it/s]
100%|██████████| 63/63 [00:03<00:00, 19.61it/s]
100%|██████████| 63/63 [00:03<00:00, 17.91it/s]


Epoch 31 | Average Loss: 0.8870
Recall@10: 0.0637 | Recall@20: 0.0895


100%|██████████| 63/63 [00:03<00:00, 20.05it/s]
100%|██████████| 63/63 [00:04<00:00, 14.25it/s]
100%|██████████| 63/63 [00:05<00:00, 12.13it/s]
100%|██████████| 63/63 [00:05<00:00, 10.98it/s]
100%|██████████| 63/63 [00:04<00:00, 13.49it/s]


Epoch 32 | Average Loss: 0.8855
Recall@10: 0.0656 | Recall@20: 0.0907


100%|██████████| 63/63 [00:04<00:00, 12.92it/s]
100%|██████████| 63/63 [00:03<00:00, 18.47it/s]
100%|██████████| 63/63 [00:04<00:00, 13.82it/s]
100%|██████████| 63/63 [00:03<00:00, 17.00it/s]
100%|██████████| 63/63 [00:03<00:00, 15.90it/s]


Epoch 33 | Average Loss: 0.8849
Recall@10: 0.0519 | Recall@20: 0.0757


100%|██████████| 63/63 [00:03<00:00, 17.11it/s]
100%|██████████| 63/63 [00:03<00:00, 19.67it/s]
100%|██████████| 63/63 [00:03<00:00, 18.98it/s]
100%|██████████| 63/63 [00:03<00:00, 18.77it/s]
100%|██████████| 63/63 [00:03<00:00, 16.02it/s]


Epoch 34 | Average Loss: 0.8793
Recall@10: 0.0586 | Recall@20: 0.0826


100%|██████████| 63/63 [00:05<00:00, 12.04it/s]
100%|██████████| 63/63 [00:04<00:00, 15.58it/s]
100%|██████████| 63/63 [00:05<00:00, 11.30it/s]
100%|██████████| 63/63 [00:05<00:00, 11.92it/s]
100%|██████████| 63/63 [00:04<00:00, 14.12it/s]


Epoch 35 | Average Loss: 0.8804
Recall@10: 0.0658 | Recall@20: 0.0924


100%|██████████| 63/63 [00:03<00:00, 16.51it/s]
100%|██████████| 63/63 [00:04<00:00, 12.80it/s]
100%|██████████| 63/63 [00:04<00:00, 13.89it/s]
100%|██████████| 63/63 [00:03<00:00, 16.54it/s]
100%|██████████| 63/63 [00:03<00:00, 16.05it/s]


Epoch 36 | Average Loss: 0.8807
Recall@10: 0.0712 | Recall@20: 0.1024


100%|██████████| 63/63 [00:04<00:00, 15.49it/s]
100%|██████████| 63/63 [00:05<00:00, 12.51it/s]
100%|██████████| 63/63 [00:05<00:00, 11.58it/s]
100%|██████████| 63/63 [00:04<00:00, 15.72it/s]
100%|██████████| 63/63 [00:03<00:00, 16.65it/s]


Epoch 37 | Average Loss: 0.8802
Recall@10: 0.0639 | Recall@20: 0.0869


100%|██████████| 63/63 [00:04<00:00, 13.48it/s]
100%|██████████| 63/63 [00:04<00:00, 15.30it/s]
100%|██████████| 63/63 [00:03<00:00, 19.85it/s]
100%|██████████| 63/63 [00:04<00:00, 14.88it/s]
100%|██████████| 63/63 [00:04<00:00, 13.75it/s]


Epoch 38 | Average Loss: 0.8790
Recall@10: 0.0685 | Recall@20: 0.0965


100%|██████████| 63/63 [00:03<00:00, 16.22it/s]
100%|██████████| 63/63 [00:03<00:00, 20.52it/s]
100%|██████████| 63/63 [00:03<00:00, 20.27it/s]
100%|██████████| 63/63 [00:03<00:00, 16.52it/s]
100%|██████████| 63/63 [00:04<00:00, 14.49it/s]


Epoch 39 | Average Loss: 0.8769
Recall@10: 0.0552 | Recall@20: 0.0772


In [11]:
# Persist the trained GAT weights. Create the target directory first so the
# save does not fail when 'model/' does not exist yet.
model_path = Path('model/gat_model_plus.pt')
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model_fixed.state_dict(), model_path)

In [12]:
# Rebuild the architecture and restore the saved weights.
model2 = GATLinkPredictorFixed(in_dim, hidden_dim, num_heads)
# map_location remaps tensors saved from another device (e.g. a GPU that is not
# present here); weights_only=True restricts unpickling to tensor data, which is
# all a state_dict contains.
state_dict = torch.load('model/gat_model_plus.pt', map_location=device, weights_only=True)
model2.load_state_dict(state_dict)
model2.to(device)  # move to GPU if needed
model2.eval()      # set to evaluation mode

  model2.load_state_dict(torch.load('model/gat_model_plus.pt'))


GATLinkPredictorFixed(
  (gat1): GATConv(768, 128, heads=4)
  (gat2): GATConv(512, 128, heads=1)
)