In [1]:
import os
import sys
from dotenv import load_dotenv

# Load the .env file
load_dotenv()

# Only touch sys.path / HF_HOME when the variables are actually defined:
# os.getenv returns None for a missing variable, and assigning None into
# os.environ raises TypeError while inserting None into sys.path leaves a
# non-string entry on the import path.
project_root = os.getenv("PROJECT_ROOT")
if project_root:
    sys.path.insert(1, project_root)

hf_cache = os.getenv("HF_CACHE")
if hf_cache:
    # Kept ahead of the `transformers` import further down, as in the original ordering.
    os.environ['HF_HOME'] = hf_cache

from transformers import AutoTokenizer, AutoModel
import torch
import pickle
from tqdm import tqdm
import torch.nn as nn
import dgl
from dgl import heterograph
from collections import defaultdict
import numpy as np
import networkx as nx
from pathlib import Path

# Load SciBERT model
MODEL_NAME = "allenai/scibert_scivocab_uncased"
PREFERRED_GPU = 7  # GPU index this notebook was originally pinned to

if torch.cuda.is_available():
    # Fall back to GPU 0 when the preferred index does not exist on this machine,
    # instead of failing later when the model is moved to a missing device.
    gpu_index = PREFERRED_GPU if torch.cuda.device_count() > PREFERRED_GPU else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

# local_files_only=True restricts loading to files already present locally.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, local_files_only=True)
model = AutoModel.from_pretrained(MODEL_NAME, local_files_only=True).to(device)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_scibert_embedding(text):
    """Encode ``text`` with the module-level SciBERT model and return its [CLS] vector.

    The text is tokenized (truncated to at most 512 tokens), moved to the
    module-level ``device`` and run through ``model`` without tracking gradients.

    Args:
        text: Text to embed (a single string yields a batch of one).

    Returns:
        Tensor holding the hidden state at sequence position 0 for each input,
        i.e. shape (1, hidden_size) for a single string.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    encoded = encoded.to(device)

    # Inference only: no autograd graph is needed for feature extraction.
    with torch.no_grad():
        model_out = model(**encoded)

    # Sequence position 0 is the CLS token.
    return model_out.last_hidden_state[:, 0, :]


def _load_pickle(path):
    """Load one pickled artifact and close the file handle afterwards.

    NOTE(review): pickle can execute arbitrary code while loading; these paths
    are presumably project-generated artifacts -- do not point this at
    untrusted files.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


paper_records = _load_pickle('data/teacher_graph/records/paper_records.pkl')
author_records = _load_pickle('data/teacher_graph/records/author_records.pkl')
fields_of_study = set()
affiliations = set()

all_ids = set(paper_records.keys())
student_paper_ids = _load_pickle('data/teacher_graph/paper_ids.pkl')
eval_ids = _load_pickle('data/teacher_graph/eval_ids.pickle')
test_ids = _load_pickle('data/teacher_graph/test_ids.pickle')
extra_train_ids = _load_pickle('data/teacher_graph/extra_train_ids.pickle')

# create a combined set of student paper ids and extra train ids
train_ids = set(student_paper_ids).union(extra_train_ids)

# combined_ids = set(train_ids).union(eval_ids)

# Map every paper id to a dense integer node index. The index follows the
# iteration order of the collection, which later cells reproduce by iterating
# the same (unmodified) collection again -- do not mutate train_ids afterwards.
train_ids_nums = {i: idx for idx, i in enumerate(train_ids)}
train_nodes = list(train_ids_nums.values())

eval_ids_nums = {i: idx for idx, i in enumerate(eval_ids)}
eval_nodes = list(eval_ids_nums.values())

# for author_id, author in tqdm(author_records.items()):
#     affiliations.update(author['affiliations'])

# for s2_id, paper in tqdm(paper_records.items()):
#     fields_of_study.update(paper['field_ids'])



def _citation_edges(source_ids, source_index, target_index):
    """Collect (source_node, target_node) index pairs for paper -> reference links.

    Args:
        source_ids: Iterable of citing paper ids (keys of ``paper_records``).
        source_index: Mapping from citing paper id to its node index.
        target_index: Mapping from cited paper id to its node index; references
            that are missing from this mapping (or are None) are skipped.

    Returns:
        List of ``(source_index[paper], target_index[reference])`` tuples.
    """
    pairs = []
    for s2_id in tqdm(source_ids):
        for ref_id in paper_records[s2_id]['reference_ids']:
            # Skip unresolved references before touching the index mapping.
            if ref_id is None or ref_id not in target_index:
                continue
            pairs.append((source_index[s2_id], target_index[ref_id]))  # paper -> reference
    return pairs


train_graph = nx.DiGraph()

# add nodes for all papers
train_graph.add_nodes_from(train_nodes)

# add edges for all papers (train paper -> train paper citations only)
edges = _citation_edges(train_ids, train_ids_nums, train_ids_nums)
train_graph.add_edges_from(edges)

# One feature row per train node, sized from the encoder instead of a literal 768.
train_feats = torch.zeros((len(train_ids), model.config.hidden_size))

# Citations from eval papers into the train set (eval node index -> train node index).
eval_edges = _citation_edges(eval_ids, eval_ids_nums, train_ids_nums)


def _embed_papers(paper_ids):
    """Build a feature matrix with one SciBERT embedding row per paper.

    Row ``i`` belongs to the ``i``-th id yielded when iterating ``paper_ids``,
    which is the same positional convention the id -> index maps are built with.

    Args:
        paper_ids: Sized iterable of paper ids (keys of ``paper_records``).

    Returns:
        CPU tensor of shape ``(len(paper_ids), model.config.hidden_size)``.
    """
    feats = torch.zeros(len(paper_ids), model.config.hidden_size)
    # Wrap the collection (not the enumerate object) so tqdm can read len()
    # and draw a real progress bar instead of a bare iteration counter.
    for row, paper_id in enumerate(tqdm(paper_ids)):
        pdata = paper_records[paper_id]
        # Missing title/abstract are treated as empty text rather than raising
        # on string concatenation.
        title = pdata['title'] or ""
        abstract = pdata['abstract'] or ""
        text = title + "\n" + abstract + "\n"
        # get_scibert_embedding returns a (1, hidden) batch; drop the batch axis.
        feats[row] = get_scibert_embedding(text).squeeze(0)
    return feats


# Compute and assign SciBERT embeddings for paper nodes
train_feats = _embed_papers(train_ids)

# 1. Prepare eval node SciBERT embeddings (once)
eval_feats = _embed_papers(eval_ids)

# train the model using the graph attention network
import torch.nn.functional as F
from torch_geometric.nn import GATConv

# Move both feature matrices onto the compute device, then stack them
# (train rows first, eval rows after) into a single matrix on that device.
train_feats, eval_feats = train_feats.to(device), eval_feats.to(device)

combined_feats = torch.cat([train_feats, eval_feats], dim=0).to(device)

100%|██████████| 11528/11528 [00:00<00:00, 49234.71it/s]
100%|██████████| 1500/1500 [00:00<00:00, 26119.17it/s]
11528it [01:37, 118.50it/s]
1500it [00:10, 142.27it/s]


In [3]:
# Evaluation graph: one node per eval paper and no edges at all.
eval_graph = nx.DiGraph()
eval_graph.add_nodes_from(eval_nodes)
# Explicitly edge-free (adding an empty collection leaves the graph unchanged).
eval_graph.add_edges_from(())

In [4]:
def compute_loss(pos_score, neg_score):
    """Binary link-prediction loss over positive and negative edge scores.

    Positive scores are pushed towards sigmoid(score) = 1 and negative scores
    towards sigmoid(score) = 0; each group is averaged separately and the two
    means are summed.

    Args:
        pos_score: Tensor of raw scores for observed edges.
        neg_score: Tensor of raw scores for sampled non-edges.

    Returns:
        Scalar tensor: -mean(logsigmoid(pos)) - mean(logsigmoid(-neg)).
    """
    log_prob_pos = F.logsigmoid(pos_score)
    log_prob_neg = F.logsigmoid(-neg_score)
    return -(log_prob_pos.mean()) - log_prob_neg.mean()

In [5]:
def calculate_recall_at_k_fixed(eval_ids, predicted_indices, k, records=None, train_index=None):
    """
    Calculate micro-averaged recall@k over all evaluation nodes.

    For every evaluation paper, the true references are the entries of its
    'reference_ids' that are present in ``train_index``, translated to train
    node indices. Hits and relevant counts are summed over all evaluation
    papers before dividing, so papers with more references weigh more.

    Args:
        eval_ids: Sequence of evaluation paper IDs; the i-th ID corresponds to
            row i of ``predicted_indices``.
        predicted_indices: 2-D tensor whose row i holds the ranked train node
            indices predicted for ``eval_ids[i]``.
        k: Number of top predictions (leading columns) to consider.
        records: Optional mapping paper ID -> record containing 'reference_ids'.
            Defaults to the module-level ``paper_records``.
        train_index: Optional mapping train paper ID -> train node index.
            Defaults to the module-level ``train_ids_nums``.
    Returns:
        recall: hits / total relevant references, or 0.0 when no evaluation
        paper has a reference inside the train set.
    """
    if records is None:
        records = paper_records
    if train_index is None:
        train_index = train_ids_nums

    # One device -> host transfer for the whole top-k slice instead of one per row.
    top_k_rows = predicted_indices[:, :k].cpu().tolist()

    relevant_count = 0
    total_relevant = 0

    for eval_idx, eval_id in enumerate(eval_ids):
        # References outside the train set (including None) cannot be predicted
        # and are therefore not counted as relevant.
        true_references = {
            train_index[ref]
            for ref in records[eval_id]['reference_ids']
            if ref in train_index
        }
        total_relevant += len(true_references)
        relevant_count += len(true_references.intersection(top_k_rows[eval_idx]))

    if total_relevant == 0:
        return 0.0

    return relevant_count / total_relevant

# Encoder hyperparameters: input width follows the node feature matrix,
# hidden width and attention-head count are fixed.
in_dim = train_feats.size(1)
hidden_dim, num_heads = 128, 4

class DotLinkPredictor(nn.Module):
    """Parameter-free edge scorer: dot product of the two endpoint embeddings."""

    def forward(self, h, src_idx, dst_idx):
        """Score candidate edges.

        Args:
            h: Node embedding matrix, indexed by node along the first axis.
            src_idx: Indices of the source node of each candidate edge.
            dst_idx: Indices of the destination node of each candidate edge.

        Returns:
            Tensor with one score per (src, dst) pair.
        """
        src_emb = h[src_idx]
        dst_emb = h[dst_idx]
        return torch.sum(src_emb * dst_emb, dim=-1)
    
    # Alternative implementation that avoids the in-place operation issue
class GATLinkPredictorFixed(nn.Module):
    """Two-layer graph-attention encoder that produces one embedding per node.

    The first layer uses ``num_heads`` attention heads; the second layer uses a
    single head and maps ``hidden_dim * num_heads`` inputs to ``hidden_dim``.
    """

    def __init__(self, in_dim, hidden_dim, num_heads):
        super().__init__()
        self.gat1 = GATConv(in_dim, hidden_dim, num_heads)
        self.gat2 = GATConv(hidden_dim * num_heads, hidden_dim, 1)

    @staticmethod
    def _edge_index(g, device):
        """Convert ``g`` into a ``(2, num_edges)`` long tensor on ``device``.

        Accepts a networkx DiGraph with integer node labels or an already
        built edge-index tensor.
        """
        if isinstance(g, torch.Tensor):
            return g.to(device=device, dtype=torch.long)
        if isinstance(g, nx.DiGraph):
            edge_list = list(g.edges())
            if not edge_list:
                # Edge-free graph: a genuinely empty index. The previous code
                # fabricated len(eval_edges) copies of a 0 -> 0 edge from an
                # unrelated global. NOTE(review): with GATConv's default
                # self-loop handling both should give the same output -- confirm.
                return torch.empty((2, 0), dtype=torch.long, device=device)
            return torch.tensor(edge_list, dtype=torch.long, device=device).t().contiguous()
        # The former fallback invoked the layers as (graph, features), which is
        # not the (features, edge_index) order used for the networkx path.
        raise TypeError(
            f"Unsupported graph type {type(g).__name__}; expected nx.DiGraph or an edge-index tensor"
        )

    def forward(self, g, features):
        """Encode all nodes.

        Args:
            g: networkx DiGraph (only its edges are read) or an edge-index tensor.
            features: Node feature matrix; its device is used for the edge index.

        Returns:
            Node embedding tensor produced by the second attention layer.
        """
        edge_index = self._edge_index(g, features.device)
        h = self.gat1(features, edge_index)
        # flatten/squeeze are kept from the original implementation.
        h = F.elu(h.flatten(1))
        h = self.gat2(h, edge_index).squeeze(1)
        return h

# Fixed training loop
model_fixed = GATLinkPredictorFixed(in_dim, hidden_dim, num_heads).to(device)
predictor = DotLinkPredictor()
optimizer = torch.optim.Adam(list(model_fixed.parameters()) + list(predictor.parameters()), lr=1e-3)

# The training edges never change between epochs, so the index tensors are
# built once up front instead of once per epoch.
edge_list = list(train_graph.edges())
if not edge_list:
    raise ValueError("train_graph has no edges to train on")
src, dst = zip(*edge_list)
src = torch.tensor(src, dtype=torch.long, device=device)
dst = torch.tensor(dst, dtype=torch.long, device=device)
n_edges = src.shape[0]

for epoch in range(50):
    model_fixed.train()

    # Losses of every mini-batch in this epoch (all sub-epochs together).
    batch_losses = []

    # Process in smaller sub-epochs
    for sub_epoch in range(5):  # 5 sub-epochs per epoch
        # Fresh edge order per sub-epoch, created directly on the compute device.
        perm = torch.randperm(n_edges, device=device)

        for i in tqdm(range(0, n_edges, 1024)):
            # Recompute embeddings for all nodes on every step: the parameters
            # changed in optimizer.step() and backward() is called without
            # retain_graph, so each step needs its own forward pass.
            h = model_fixed(train_graph, train_feats)

            batch_idx = perm[i:i+1024]
            batch_src = src[batch_idx]
            batch_dst = dst[batch_idx]

            # Negative sampling: one uniformly random destination node per positive edge.
            neg_dst = torch.randint(0, h.shape[0], batch_dst.shape, dtype=torch.long, device=device)

            pos_score = predictor(h, batch_src, batch_dst)
            neg_score = predictor(h, batch_src, neg_dst)

            loss = compute_loss(pos_score, neg_score)
            batch_losses.append(loss.item())

            optimizer.zero_grad()
            loss.backward()  # No retain_graph needed
            optimizer.step()

    avg_loss = sum(batch_losses) / len(batch_losses)
    print(f"Epoch {epoch} | Average Loss: {avg_loss:.4f}")

    # Evaluation
    model_fixed.eval()
    with torch.no_grad():
        # Get train node embeddings
        train_embs = model_fixed(train_graph, train_feats)

        # Get eval node embeddings without graph structure (eval_graph has no edges)
        eval_embs = model_fixed(eval_graph, eval_feats)

        # Compute similarity scores: rows are eval nodes, columns are train nodes
        scores = torch.matmul(eval_embs, train_embs.T)

        # Get top-k predictions (k is capped so topk cannot exceed the number of train nodes)
        topk = torch.topk(scores, k=min(20, scores.shape[1]), dim=1)
        predicted_indices = topk.indices

        # Calculate recall@k with the fixed function
        recall_at_10 = calculate_recall_at_k_fixed(eval_ids, predicted_indices, 10)
        recall_at_20 = calculate_recall_at_k_fixed(eval_ids, predicted_indices, 20)

        print(f"Recall@10: {recall_at_10:.4f} | Recall@20: {recall_at_20:.4f}")

100%|██████████| 63/63 [00:03<00:00, 17.08it/s]
100%|██████████| 63/63 [00:03<00:00, 18.83it/s]
100%|██████████| 63/63 [00:03<00:00, 18.91it/s]
100%|██████████| 63/63 [00:03<00:00, 16.37it/s]
100%|██████████| 63/63 [00:03<00:00, 15.78it/s]


Epoch 0 | Average Loss: 1.8426
Recall@10: 0.0032 | Recall@20: 0.0053


100%|██████████| 63/63 [00:03<00:00, 16.36it/s]
100%|██████████| 63/63 [00:03<00:00, 19.22it/s]
100%|██████████| 63/63 [00:03<00:00, 18.44it/s]
100%|██████████| 63/63 [00:03<00:00, 16.82it/s]
100%|██████████| 63/63 [00:03<00:00, 18.01it/s]


Epoch 1 | Average Loss: 1.0305
Recall@10: 0.0042 | Recall@20: 0.0064


100%|██████████| 63/63 [00:03<00:00, 18.23it/s]
100%|██████████| 63/63 [00:03<00:00, 18.63it/s]
100%|██████████| 63/63 [00:03<00:00, 19.84it/s]
100%|██████████| 63/63 [00:03<00:00, 17.87it/s]
100%|██████████| 63/63 [00:03<00:00, 17.52it/s]


Epoch 2 | Average Loss: 1.0167
Recall@10: 0.0044 | Recall@20: 0.0078


100%|██████████| 63/63 [00:03<00:00, 16.82it/s]
100%|██████████| 63/63 [00:03<00:00, 16.35it/s]
100%|██████████| 63/63 [00:03<00:00, 19.75it/s]
100%|██████████| 63/63 [00:03<00:00, 17.67it/s]
100%|██████████| 63/63 [00:03<00:00, 19.29it/s]


Epoch 3 | Average Loss: 1.0126
Recall@10: 0.0045 | Recall@20: 0.0077


100%|██████████| 63/63 [00:03<00:00, 20.26it/s]
100%|██████████| 63/63 [00:03<00:00, 17.17it/s]
100%|██████████| 63/63 [00:03<00:00, 17.09it/s]
100%|██████████| 63/63 [00:03<00:00, 15.88it/s]
100%|██████████| 63/63 [00:04<00:00, 15.19it/s]


Epoch 4 | Average Loss: 1.0024
Recall@10: 0.0052 | Recall@20: 0.0110


100%|██████████| 63/63 [00:03<00:00, 17.84it/s]
100%|██████████| 63/63 [00:03<00:00, 18.58it/s]
100%|██████████| 63/63 [00:03<00:00, 18.90it/s]
100%|██████████| 63/63 [00:03<00:00, 18.54it/s]
100%|██████████| 63/63 [00:03<00:00, 19.30it/s]


Epoch 5 | Average Loss: 0.9926
Recall@10: 0.0059 | Recall@20: 0.0133


100%|██████████| 63/63 [00:03<00:00, 16.37it/s]
100%|██████████| 63/63 [00:03<00:00, 16.70it/s]
100%|██████████| 63/63 [00:04<00:00, 12.73it/s]
100%|██████████| 63/63 [00:03<00:00, 15.92it/s]
100%|██████████| 63/63 [00:03<00:00, 18.60it/s]


Epoch 6 | Average Loss: 0.9781
Recall@10: 0.0185 | Recall@20: 0.0378


100%|██████████| 63/63 [00:03<00:00, 20.85it/s]
100%|██████████| 63/63 [00:03<00:00, 17.12it/s]
100%|██████████| 63/63 [00:04<00:00, 15.14it/s]
100%|██████████| 63/63 [00:03<00:00, 16.24it/s]
100%|██████████| 63/63 [00:03<00:00, 16.98it/s]


Epoch 7 | Average Loss: 0.9669
Recall@10: 0.0344 | Recall@20: 0.0587


100%|██████████| 63/63 [00:03<00:00, 17.43it/s]
100%|██████████| 63/63 [00:04<00:00, 14.64it/s]
100%|██████████| 63/63 [00:03<00:00, 20.21it/s]
100%|██████████| 63/63 [00:03<00:00, 19.66it/s]
100%|██████████| 63/63 [00:03<00:00, 16.57it/s]


Epoch 8 | Average Loss: 0.9572
Recall@10: 0.0451 | Recall@20: 0.0770


100%|██████████| 63/63 [00:04<00:00, 14.96it/s]
100%|██████████| 63/63 [00:04<00:00, 13.24it/s]
100%|██████████| 63/63 [00:03<00:00, 18.17it/s]
100%|██████████| 63/63 [00:03<00:00, 18.80it/s]
100%|██████████| 63/63 [00:03<00:00, 20.29it/s]


Epoch 9 | Average Loss: 0.9521
Recall@10: 0.0531 | Recall@20: 0.0907


100%|██████████| 63/63 [00:03<00:00, 18.34it/s]
100%|██████████| 63/63 [00:05<00:00, 11.52it/s]
100%|██████████| 63/63 [00:04<00:00, 14.75it/s]
100%|██████████| 63/63 [00:03<00:00, 16.06it/s]
100%|██████████| 63/63 [00:04<00:00, 14.08it/s]


Epoch 10 | Average Loss: 0.9420
Recall@10: 0.0546 | Recall@20: 0.0864


100%|██████████| 63/63 [00:04<00:00, 13.50it/s]
100%|██████████| 63/63 [00:03<00:00, 18.70it/s]
100%|██████████| 63/63 [00:03<00:00, 20.64it/s]
100%|██████████| 63/63 [00:05<00:00, 11.19it/s]
100%|██████████| 63/63 [00:03<00:00, 17.84it/s]


Epoch 11 | Average Loss: 0.9354
Recall@10: 0.0645 | Recall@20: 0.1017


100%|██████████| 63/63 [00:03<00:00, 16.82it/s]
100%|██████████| 63/63 [00:03<00:00, 16.56it/s]
100%|██████████| 63/63 [00:03<00:00, 16.28it/s]
100%|██████████| 63/63 [00:03<00:00, 18.17it/s]
100%|██████████| 63/63 [00:03<00:00, 16.67it/s]


Epoch 12 | Average Loss: 0.9302
Recall@10: 0.0810 | Recall@20: 0.1259


100%|██████████| 63/63 [00:05<00:00, 10.51it/s]
100%|██████████| 63/63 [00:06<00:00, 10.06it/s]
100%|██████████| 63/63 [00:04<00:00, 15.16it/s]
100%|██████████| 63/63 [00:05<00:00, 11.50it/s]
100%|██████████| 63/63 [00:03<00:00, 17.50it/s]


Epoch 13 | Average Loss: 0.9221
Recall@10: 0.0852 | Recall@20: 0.1347


100%|██████████| 63/63 [00:03<00:00, 20.36it/s]
100%|██████████| 63/63 [00:04<00:00, 12.63it/s]
100%|██████████| 63/63 [00:04<00:00, 15.61it/s]
100%|██████████| 63/63 [00:03<00:00, 17.33it/s]
100%|██████████| 63/63 [00:03<00:00, 17.54it/s]


Epoch 14 | Average Loss: 0.9163
Recall@10: 0.0826 | Recall@20: 0.1287


100%|██████████| 63/63 [00:03<00:00, 16.76it/s]
100%|██████████| 63/63 [00:03<00:00, 16.64it/s]
100%|██████████| 63/63 [00:04<00:00, 15.23it/s]
100%|██████████| 63/63 [00:04<00:00, 12.67it/s]
100%|██████████| 63/63 [00:05<00:00, 11.53it/s]


Epoch 15 | Average Loss: 0.9149
Recall@10: 0.0902 | Recall@20: 0.1394


100%|██████████| 63/63 [00:03<00:00, 17.12it/s]
100%|██████████| 63/63 [00:03<00:00, 16.09it/s]
100%|██████████| 63/63 [00:05<00:00, 12.07it/s]
100%|██████████| 63/63 [00:03<00:00, 20.61it/s]
100%|██████████| 63/63 [00:05<00:00, 11.81it/s]


Epoch 16 | Average Loss: 0.9133
Recall@10: 0.0891 | Recall@20: 0.1396


100%|██████████| 63/63 [00:03<00:00, 17.22it/s]
100%|██████████| 63/63 [00:03<00:00, 19.96it/s]
100%|██████████| 63/63 [00:03<00:00, 20.31it/s]
100%|██████████| 63/63 [00:03<00:00, 18.12it/s]
100%|██████████| 63/63 [00:03<00:00, 19.37it/s]


Epoch 17 | Average Loss: 0.9061
Recall@10: 0.0890 | Recall@20: 0.1355


100%|██████████| 63/63 [00:05<00:00, 12.18it/s]
100%|██████████| 63/63 [00:04<00:00, 14.63it/s]
100%|██████████| 63/63 [00:03<00:00, 15.76it/s]
100%|██████████| 63/63 [00:05<00:00, 11.99it/s]
100%|██████████| 63/63 [00:05<00:00, 10.96it/s]


Epoch 18 | Average Loss: 0.9045
Recall@10: 0.0882 | Recall@20: 0.1307


100%|██████████| 63/63 [00:04<00:00, 14.13it/s]
100%|██████████| 63/63 [00:03<00:00, 18.76it/s]
100%|██████████| 63/63 [00:04<00:00, 13.17it/s]
100%|██████████| 63/63 [00:03<00:00, 18.06it/s]
100%|██████████| 63/63 [00:03<00:00, 17.43it/s]


Epoch 19 | Average Loss: 0.9010
Recall@10: 0.0879 | Recall@20: 0.1293


100%|██████████| 63/63 [00:03<00:00, 19.67it/s]
100%|██████████| 63/63 [00:03<00:00, 18.89it/s]
100%|██████████| 63/63 [00:03<00:00, 19.78it/s]
100%|██████████| 63/63 [00:03<00:00, 17.26it/s]
100%|██████████| 63/63 [00:05<00:00, 12.22it/s]


Epoch 20 | Average Loss: 0.9061
Recall@10: 0.0989 | Recall@20: 0.1532


100%|██████████| 63/63 [00:04<00:00, 15.69it/s]
100%|██████████| 63/63 [00:04<00:00, 14.54it/s]
100%|██████████| 63/63 [00:03<00:00, 17.00it/s]
100%|██████████| 63/63 [00:03<00:00, 16.70it/s]
100%|██████████| 63/63 [00:03<00:00, 18.13it/s]


Epoch 21 | Average Loss: 0.8972
Recall@10: 0.0952 | Recall@20: 0.1418


100%|██████████| 63/63 [00:03<00:00, 17.52it/s]
100%|██████████| 63/63 [00:03<00:00, 19.84it/s]
100%|██████████| 63/63 [00:05<00:00, 11.01it/s]
100%|██████████| 63/63 [00:03<00:00, 17.17it/s]
100%|██████████| 63/63 [00:03<00:00, 18.85it/s]


Epoch 22 | Average Loss: 0.8944
Recall@10: 0.1004 | Recall@20: 0.1486


100%|██████████| 63/63 [00:03<00:00, 19.61it/s]
100%|██████████| 63/63 [00:03<00:00, 17.21it/s]
100%|██████████| 63/63 [00:03<00:00, 18.41it/s]
100%|██████████| 63/63 [00:03<00:00, 19.02it/s]
100%|██████████| 63/63 [00:04<00:00, 14.92it/s]


Epoch 23 | Average Loss: 0.8904
Recall@10: 0.0986 | Recall@20: 0.1479


100%|██████████| 63/63 [00:04<00:00, 15.26it/s]
100%|██████████| 63/63 [00:04<00:00, 14.77it/s]
100%|██████████| 63/63 [00:04<00:00, 13.58it/s]
100%|██████████| 63/63 [00:05<00:00, 12.24it/s]
100%|██████████| 63/63 [00:04<00:00, 13.16it/s]


Epoch 24 | Average Loss: 0.8929
Recall@10: 0.1025 | Recall@20: 0.1553


100%|██████████| 63/63 [00:03<00:00, 17.36it/s]
100%|██████████| 63/63 [00:04<00:00, 13.67it/s]
100%|██████████| 63/63 [00:04<00:00, 13.41it/s]
100%|██████████| 63/63 [00:03<00:00, 17.61it/s]
100%|██████████| 63/63 [00:03<00:00, 17.70it/s]


Epoch 25 | Average Loss: 0.8864
Recall@10: 0.0984 | Recall@20: 0.1511


100%|██████████| 63/63 [00:03<00:00, 16.18it/s]
100%|██████████| 63/63 [00:03<00:00, 17.22it/s]
100%|██████████| 63/63 [00:04<00:00, 15.67it/s]
100%|██████████| 63/63 [00:04<00:00, 13.23it/s]
100%|██████████| 63/63 [00:04<00:00, 14.61it/s]


Epoch 26 | Average Loss: 0.8853
Recall@10: 0.1051 | Recall@20: 0.1668


100%|██████████| 63/63 [00:03<00:00, 16.24it/s]
100%|██████████| 63/63 [00:05<00:00, 11.62it/s]
100%|██████████| 63/63 [00:03<00:00, 16.12it/s]
100%|██████████| 63/63 [00:03<00:00, 18.52it/s]
100%|██████████| 63/63 [00:05<00:00, 12.48it/s]


Epoch 27 | Average Loss: 0.8833
Recall@10: 0.0972 | Recall@20: 0.1383


100%|██████████| 63/63 [00:03<00:00, 16.20it/s]
100%|██████████| 63/63 [00:03<00:00, 17.83it/s]
100%|██████████| 63/63 [00:03<00:00, 18.80it/s]
100%|██████████| 63/63 [00:03<00:00, 20.44it/s]
100%|██████████| 63/63 [00:03<00:00, 16.98it/s]


Epoch 28 | Average Loss: 0.8797
Recall@10: 0.1057 | Recall@20: 0.1584


100%|██████████| 63/63 [00:04<00:00, 13.14it/s]
100%|██████████| 63/63 [00:04<00:00, 13.67it/s]
100%|██████████| 63/63 [00:04<00:00, 13.38it/s]
100%|██████████| 63/63 [00:04<00:00, 15.18it/s]
100%|██████████| 63/63 [00:03<00:00, 16.99it/s]


Epoch 29 | Average Loss: 0.8810
Recall@10: 0.1030 | Recall@20: 0.1530


100%|██████████| 63/63 [00:03<00:00, 18.99it/s]
100%|██████████| 63/63 [00:03<00:00, 19.36it/s]
100%|██████████| 63/63 [00:05<00:00, 10.99it/s]
100%|██████████| 63/63 [00:04<00:00, 12.75it/s]
100%|██████████| 63/63 [00:03<00:00, 17.11it/s]


Epoch 30 | Average Loss: 0.8794
Recall@10: 0.1016 | Recall@20: 0.1542


100%|██████████| 63/63 [00:03<00:00, 16.52it/s]
100%|██████████| 63/63 [00:03<00:00, 18.08it/s]
100%|██████████| 63/63 [00:04<00:00, 15.51it/s]
100%|██████████| 63/63 [00:04<00:00, 12.78it/s]
100%|██████████| 63/63 [00:04<00:00, 13.88it/s]


Epoch 31 | Average Loss: 0.8767
Recall@10: 0.1065 | Recall@20: 0.1574


100%|██████████| 63/63 [00:03<00:00, 16.10it/s]
100%|██████████| 63/63 [00:05<00:00, 12.15it/s]
100%|██████████| 63/63 [00:03<00:00, 15.78it/s]
100%|██████████| 63/63 [00:03<00:00, 20.34it/s]
100%|██████████| 63/63 [00:05<00:00, 12.12it/s]


Epoch 32 | Average Loss: 0.8788
Recall@10: 0.1059 | Recall@20: 0.1598


100%|██████████| 63/63 [00:03<00:00, 17.06it/s]
100%|██████████| 63/63 [00:03<00:00, 16.81it/s]
100%|██████████| 63/63 [00:03<00:00, 16.96it/s]
100%|██████████| 63/63 [00:04<00:00, 14.35it/s]
100%|██████████| 63/63 [00:03<00:00, 16.25it/s]


Epoch 33 | Average Loss: 0.8772
Recall@10: 0.1054 | Recall@20: 0.1601


100%|██████████| 63/63 [00:05<00:00, 12.01it/s]
100%|██████████| 63/63 [00:04<00:00, 13.45it/s]
100%|██████████| 63/63 [00:04<00:00, 13.48it/s]
100%|██████████| 63/63 [00:03<00:00, 17.77it/s]
100%|██████████| 63/63 [00:03<00:00, 18.56it/s]


Epoch 34 | Average Loss: 0.8723
Recall@10: 0.1140 | Recall@20: 0.1726


100%|██████████| 63/63 [00:03<00:00, 16.53it/s]
100%|██████████| 63/63 [00:03<00:00, 16.35it/s]
100%|██████████| 63/63 [00:03<00:00, 19.24it/s]
100%|██████████| 63/63 [00:04<00:00, 13.48it/s]
100%|██████████| 63/63 [00:03<00:00, 16.01it/s]


Epoch 35 | Average Loss: 0.8729
Recall@10: 0.1076 | Recall@20: 0.1616


100%|██████████| 63/63 [00:03<00:00, 18.30it/s]
100%|██████████| 63/63 [00:03<00:00, 19.41it/s]
100%|██████████| 63/63 [00:03<00:00, 18.81it/s]
100%|██████████| 63/63 [00:05<00:00, 11.90it/s]
100%|██████████| 63/63 [00:04<00:00, 12.80it/s]


Epoch 36 | Average Loss: 0.8719
Recall@10: 0.1146 | Recall@20: 0.1702


100%|██████████| 63/63 [00:04<00:00, 13.97it/s]
100%|██████████| 63/63 [00:03<00:00, 16.02it/s]
100%|██████████| 63/63 [00:03<00:00, 16.84it/s]
100%|██████████| 63/63 [00:04<00:00, 14.81it/s]
100%|██████████| 63/63 [00:03<00:00, 19.16it/s]


Epoch 37 | Average Loss: 0.8704
Recall@10: 0.1072 | Recall@20: 0.1591


100%|██████████| 63/63 [00:04<00:00, 14.20it/s]
100%|██████████| 63/63 [00:04<00:00, 13.08it/s]
100%|██████████| 63/63 [00:03<00:00, 17.62it/s]
100%|██████████| 63/63 [00:03<00:00, 19.68it/s]
100%|██████████| 63/63 [00:03<00:00, 19.83it/s]


Epoch 38 | Average Loss: 0.8675
Recall@10: 0.1078 | Recall@20: 0.1594


100%|██████████| 63/63 [00:03<00:00, 20.18it/s]
100%|██████████| 63/63 [00:03<00:00, 16.62it/s]
100%|██████████| 63/63 [00:06<00:00,  9.48it/s]
100%|██████████| 63/63 [00:05<00:00, 11.70it/s]
100%|██████████| 63/63 [00:04<00:00, 15.09it/s]


Epoch 39 | Average Loss: 0.8708
Recall@10: 0.1068 | Recall@20: 0.1552


100%|██████████| 63/63 [00:05<00:00, 11.64it/s]
100%|██████████| 63/63 [00:04<00:00, 15.44it/s]
100%|██████████| 63/63 [00:03<00:00, 18.25it/s]
100%|██████████| 63/63 [00:03<00:00, 17.20it/s]
100%|██████████| 63/63 [00:04<00:00, 13.88it/s]


Epoch 40 | Average Loss: 0.8693
Recall@10: 0.0973 | Recall@20: 0.1471


100%|██████████| 63/63 [00:04<00:00, 14.07it/s]
100%|██████████| 63/63 [00:03<00:00, 16.44it/s]
100%|██████████| 63/63 [00:03<00:00, 17.42it/s]
100%|██████████| 63/63 [00:03<00:00, 18.23it/s]
100%|██████████| 63/63 [00:03<00:00, 19.85it/s]


Epoch 41 | Average Loss: 0.8680
Recall@10: 0.1087 | Recall@20: 0.1560


100%|██████████| 63/63 [00:03<00:00, 17.63it/s]
100%|██████████| 63/63 [00:04<00:00, 12.92it/s]
100%|██████████| 63/63 [00:04<00:00, 14.26it/s]
100%|██████████| 63/63 [00:04<00:00, 15.32it/s]
100%|██████████| 63/63 [00:04<00:00, 13.78it/s]


Epoch 42 | Average Loss: 0.8656
Recall@10: 0.1033 | Recall@20: 0.1569


100%|██████████| 63/63 [00:03<00:00, 15.92it/s]
100%|██████████| 63/63 [00:03<00:00, 18.80it/s]
100%|██████████| 63/63 [00:04<00:00, 14.18it/s]
100%|██████████| 63/63 [00:04<00:00, 14.01it/s]
100%|██████████| 63/63 [00:04<00:00, 14.65it/s]


Epoch 43 | Average Loss: 0.8620
Recall@10: 0.1161 | Recall@20: 0.1664


100%|██████████| 63/63 [00:03<00:00, 18.31it/s]
100%|██████████| 63/63 [00:03<00:00, 19.37it/s]
100%|██████████| 63/63 [00:03<00:00, 17.31it/s]
100%|██████████| 63/63 [00:03<00:00, 19.75it/s]
100%|██████████| 63/63 [00:03<00:00, 20.51it/s]


Epoch 44 | Average Loss: 0.8643
Recall@10: 0.1105 | Recall@20: 0.1562


100%|██████████| 63/63 [00:04<00:00, 15.44it/s]
100%|██████████| 63/63 [00:03<00:00, 16.07it/s]
100%|██████████| 63/63 [00:04<00:00, 15.07it/s]
100%|██████████| 63/63 [00:04<00:00, 15.35it/s]
100%|██████████| 63/63 [00:05<00:00, 12.05it/s]


Epoch 45 | Average Loss: 0.8654
Recall@10: 0.1083 | Recall@20: 0.1572


100%|██████████| 63/63 [00:03<00:00, 18.93it/s]
100%|██████████| 63/63 [00:03<00:00, 18.04it/s]
100%|██████████| 63/63 [00:05<00:00, 11.81it/s]
100%|██████████| 63/63 [00:03<00:00, 17.90it/s]
100%|██████████| 63/63 [00:04<00:00, 14.01it/s]


Epoch 46 | Average Loss: 0.8619
Recall@10: 0.1112 | Recall@20: 0.1587


100%|██████████| 63/63 [00:03<00:00, 16.86it/s]
100%|██████████| 63/63 [00:03<00:00, 17.09it/s]
100%|██████████| 63/63 [00:03<00:00, 17.59it/s]
100%|██████████| 63/63 [00:05<00:00, 11.78it/s]
100%|██████████| 63/63 [00:04<00:00, 14.11it/s]


Epoch 47 | Average Loss: 0.8618
Recall@10: 0.1137 | Recall@20: 0.1661


100%|██████████| 63/63 [00:04<00:00, 15.40it/s]
100%|██████████| 63/63 [00:03<00:00, 17.21it/s]
100%|██████████| 63/63 [00:06<00:00, 10.33it/s]
100%|██████████| 63/63 [00:03<00:00, 17.33it/s]
100%|██████████| 63/63 [00:03<00:00, 17.04it/s]


Epoch 48 | Average Loss: 0.8598
Recall@10: 0.1074 | Recall@20: 0.1538


100%|██████████| 63/63 [00:05<00:00, 12.21it/s]
100%|██████████| 63/63 [00:04<00:00, 15.37it/s]
100%|██████████| 63/63 [00:03<00:00, 19.48it/s]
100%|██████████| 63/63 [00:03<00:00, 18.81it/s]
100%|██████████| 63/63 [00:03<00:00, 16.24it/s]


Epoch 49 | Average Loss: 0.8632
Recall@10: 0.1106 | Recall@20: 0.1645


In [8]:

# Continues training the existing model_fixed with the existing optimizer
# state for another 50 epochs; the epoch counter in the log restarts at 0.

# The training edges never change between epochs, so the index tensors are
# built once up front instead of once per epoch.
edge_list = list(train_graph.edges())
if not edge_list:
    raise ValueError("train_graph has no edges to train on")
src, dst = zip(*edge_list)
src = torch.tensor(src, dtype=torch.long, device=device)
dst = torch.tensor(dst, dtype=torch.long, device=device)
n_edges = src.shape[0]

for epoch in range(50):
    model_fixed.train()

    # Losses of every mini-batch in this epoch (all sub-epochs together).
    batch_losses = []

    # Process in smaller sub-epochs
    for sub_epoch in range(5):  # 5 sub-epochs per epoch
        # Fresh edge order per sub-epoch, created directly on the compute device.
        perm = torch.randperm(n_edges, device=device)

        for i in tqdm(range(0, n_edges, 1024)):
            # Recompute embeddings for all nodes on every step: the parameters
            # changed in optimizer.step() and backward() is called without
            # retain_graph, so each step needs its own forward pass.
            h = model_fixed(train_graph, train_feats)

            batch_idx = perm[i:i+1024]
            batch_src = src[batch_idx]
            batch_dst = dst[batch_idx]

            # Negative sampling: one uniformly random destination node per positive edge.
            neg_dst = torch.randint(0, h.shape[0], batch_dst.shape, dtype=torch.long, device=device)

            pos_score = predictor(h, batch_src, batch_dst)
            neg_score = predictor(h, batch_src, neg_dst)

            loss = compute_loss(pos_score, neg_score)
            batch_losses.append(loss.item())

            optimizer.zero_grad()
            loss.backward()  # No retain_graph needed
            optimizer.step()

    avg_loss = sum(batch_losses) / len(batch_losses)
    print(f"Epoch {epoch} | Average Loss: {avg_loss:.4f}")

    # Evaluation
    model_fixed.eval()
    with torch.no_grad():
        # Get train node embeddings
        train_embs = model_fixed(train_graph, train_feats)

        # Get eval node embeddings without graph structure (eval_graph has no edges)
        eval_embs = model_fixed(eval_graph, eval_feats)

        # Compute similarity scores: rows are eval nodes, columns are train nodes
        scores = torch.matmul(eval_embs, train_embs.T)

        # Get top-k predictions (k is capped so topk cannot exceed the number of train nodes)
        topk = torch.topk(scores, k=min(20, scores.shape[1]), dim=1)
        predicted_indices = topk.indices

        # Calculate recall@k with the fixed function
        recall_at_10 = calculate_recall_at_k_fixed(eval_ids, predicted_indices, 10)
        recall_at_20 = calculate_recall_at_k_fixed(eval_ids, predicted_indices, 20)

        print(f"Recall@10: {recall_at_10:.4f} | Recall@20: {recall_at_20:.4f}")

100%|██████████| 63/63 [00:04<00:00, 14.63it/s]
100%|██████████| 63/63 [00:04<00:00, 13.88it/s]
100%|██████████| 63/63 [00:03<00:00, 18.48it/s]
100%|██████████| 63/63 [00:03<00:00, 19.19it/s]
100%|██████████| 63/63 [00:03<00:00, 19.65it/s]


Epoch 0 | Average Loss: 0.8395
Recall@10: 0.1176 | Recall@20: 0.1656


100%|██████████| 63/63 [00:03<00:00, 16.98it/s]
100%|██████████| 63/63 [00:05<00:00, 11.73it/s]
100%|██████████| 63/63 [00:05<00:00, 12.50it/s]
100%|██████████| 63/63 [00:04<00:00, 13.47it/s]
100%|██████████| 63/63 [00:03<00:00, 15.94it/s]


Epoch 1 | Average Loss: 0.8387
Recall@10: 0.1208 | Recall@20: 0.1733


100%|██████████| 63/63 [00:03<00:00, 18.56it/s]
100%|██████████| 63/63 [00:03<00:00, 19.44it/s]
100%|██████████| 63/63 [00:03<00:00, 16.40it/s]
100%|██████████| 63/63 [00:04<00:00, 14.64it/s]
100%|██████████| 63/63 [00:04<00:00, 15.36it/s]


Epoch 2 | Average Loss: 0.8390
Recall@10: 0.1123 | Recall@20: 0.1596


100%|██████████| 63/63 [00:03<00:00, 17.55it/s]
100%|██████████| 63/63 [00:03<00:00, 19.60it/s]
100%|██████████| 63/63 [00:03<00:00, 19.64it/s]
100%|██████████| 63/63 [00:03<00:00, 19.61it/s]
100%|██████████| 63/63 [00:04<00:00, 13.77it/s]


Epoch 3 | Average Loss: 0.8408
Recall@10: 0.1140 | Recall@20: 0.1590


100%|██████████| 63/63 [00:04<00:00, 14.21it/s]
100%|██████████| 63/63 [00:03<00:00, 16.06it/s]
100%|██████████| 63/63 [00:03<00:00, 16.94it/s]
100%|██████████| 63/63 [00:04<00:00, 13.73it/s]
100%|██████████| 63/63 [00:03<00:00, 16.34it/s]


Epoch 4 | Average Loss: 0.8377
Recall@10: 0.1185 | Recall@20: 0.1670


100%|██████████| 63/63 [00:04<00:00, 13.23it/s]
100%|██████████| 63/63 [00:04<00:00, 12.84it/s]
100%|██████████| 63/63 [00:03<00:00, 17.29it/s]
100%|██████████| 63/63 [00:05<00:00, 12.38it/s]
100%|██████████| 63/63 [00:04<00:00, 13.40it/s]


Epoch 5 | Average Loss: 0.8382
Recall@10: 0.1201 | Recall@20: 0.1662


100%|██████████| 63/63 [00:03<00:00, 17.03it/s]
100%|██████████| 63/63 [00:03<00:00, 18.80it/s]
100%|██████████| 63/63 [00:03<00:00, 19.07it/s]
100%|██████████| 63/63 [00:03<00:00, 18.46it/s]
100%|██████████| 63/63 [00:04<00:00, 13.40it/s]


Epoch 6 | Average Loss: 0.8377
Recall@10: 0.1176 | Recall@20: 0.1713


100%|██████████| 63/63 [00:06<00:00,  9.85it/s]
100%|██████████| 63/63 [00:04<00:00, 15.29it/s]
100%|██████████| 63/63 [00:04<00:00, 15.46it/s]
100%|██████████| 63/63 [00:03<00:00, 18.87it/s]
100%|██████████| 63/63 [00:03<00:00, 17.86it/s]


Epoch 7 | Average Loss: 0.8381
Recall@10: 0.1197 | Recall@20: 0.1710


100%|██████████| 63/63 [00:03<00:00, 16.03it/s]
100%|██████████| 63/63 [00:04<00:00, 15.74it/s]
100%|██████████| 63/63 [00:05<00:00, 11.20it/s]
100%|██████████| 63/63 [00:04<00:00, 15.42it/s]
100%|██████████| 63/63 [00:03<00:00, 18.95it/s]


Epoch 8 | Average Loss: 0.8362
Recall@10: 0.1138 | Recall@20: 0.1706


100%|██████████| 63/63 [00:03<00:00, 19.55it/s]
100%|██████████| 63/63 [00:03<00:00, 19.50it/s]
100%|██████████| 63/63 [00:04<00:00, 15.68it/s]
100%|██████████| 63/63 [00:04<00:00, 13.73it/s]
100%|██████████| 63/63 [00:04<00:00, 13.91it/s]


Epoch 9 | Average Loss: 0.8370
Recall@10: 0.1160 | Recall@20: 0.1737


100%|██████████| 63/63 [00:03<00:00, 16.78it/s]
100%|██████████| 63/63 [00:04<00:00, 12.76it/s]
100%|██████████| 63/63 [00:04<00:00, 13.12it/s]
100%|██████████| 63/63 [00:03<00:00, 18.46it/s]
100%|██████████| 63/63 [00:03<00:00, 16.52it/s]


Epoch 10 | Average Loss: 0.8361
Recall@10: 0.1244 | Recall@20: 0.1765


100%|██████████| 63/63 [00:03<00:00, 15.85it/s]
100%|██████████| 63/63 [00:04<00:00, 14.45it/s]
100%|██████████| 63/63 [00:03<00:00, 17.24it/s]
100%|██████████| 63/63 [00:03<00:00, 18.37it/s]
100%|██████████| 63/63 [00:03<00:00, 19.10it/s]


Epoch 11 | Average Loss: 0.8346
Recall@10: 0.1194 | Recall@20: 0.1714


100%|██████████| 63/63 [00:03<00:00, 19.06it/s]
100%|██████████| 63/63 [00:04<00:00, 14.26it/s]
100%|██████████| 63/63 [00:04<00:00, 14.08it/s]
100%|██████████| 63/63 [00:04<00:00, 13.57it/s]
100%|██████████| 63/63 [00:04<00:00, 13.99it/s]


Epoch 12 | Average Loss: 0.8346
Recall@10: 0.1015 | Recall@20: 0.1438


100%|██████████| 63/63 [00:03<00:00, 16.53it/s]
100%|██████████| 63/63 [00:03<00:00, 19.34it/s]
100%|██████████| 63/63 [00:03<00:00, 20.21it/s]
100%|██████████| 63/63 [00:04<00:00, 13.40it/s]
100%|██████████| 63/63 [00:04<00:00, 13.70it/s]


Epoch 13 | Average Loss: 0.8369
Recall@10: 0.1175 | Recall@20: 0.1661


100%|██████████| 63/63 [00:04<00:00, 15.17it/s]
100%|██████████| 63/63 [00:03<00:00, 19.43it/s]
100%|██████████| 63/63 [00:03<00:00, 19.61it/s]
100%|██████████| 63/63 [00:03<00:00, 20.16it/s]
100%|██████████| 63/63 [00:03<00:00, 19.15it/s]


Epoch 14 | Average Loss: 0.8346
Recall@10: 0.1174 | Recall@20: 0.1630


100%|██████████| 63/63 [00:03<00:00, 16.01it/s]
100%|██████████| 63/63 [00:04<00:00, 15.45it/s]
100%|██████████| 63/63 [00:03<00:00, 16.22it/s]
100%|██████████| 63/63 [00:03<00:00, 18.50it/s]
100%|██████████| 63/63 [00:03<00:00, 17.69it/s]


Epoch 15 | Average Loss: 0.8374
Recall@10: 0.1175 | Recall@20: 0.1652


100%|██████████| 63/63 [00:03<00:00, 18.75it/s]
100%|██████████| 63/63 [00:03<00:00, 19.45it/s]
100%|██████████| 63/63 [00:03<00:00, 16.92it/s]
100%|██████████| 63/63 [00:04<00:00, 14.22it/s]
100%|██████████| 63/63 [00:04<00:00, 15.64it/s]


Epoch 16 | Average Loss: 0.8336
Recall@10: 0.1233 | Recall@20: 0.1755


100%|██████████| 63/63 [00:04<00:00, 15.73it/s]
100%|██████████| 63/63 [00:03<00:00, 17.47it/s]
100%|██████████| 63/63 [00:03<00:00, 17.63it/s]
100%|██████████| 63/63 [00:03<00:00, 18.28it/s]
100%|██████████| 63/63 [00:03<00:00, 18.21it/s]


Epoch 17 | Average Loss: 0.8369
Recall@10: 0.1164 | Recall@20: 0.1672


100%|██████████| 63/63 [00:03<00:00, 16.05it/s]
100%|██████████| 63/63 [00:04<00:00, 15.74it/s]
100%|██████████| 63/63 [00:04<00:00, 14.76it/s]
100%|██████████| 63/63 [00:04<00:00, 15.69it/s]
100%|██████████| 63/63 [00:03<00:00, 18.54it/s]


Epoch 18 | Average Loss: 0.8331
Recall@10: 0.1212 | Recall@20: 0.1750


100%|██████████| 63/63 [00:03<00:00, 18.76it/s]
100%|██████████| 63/63 [00:04<00:00, 13.49it/s]
100%|██████████| 63/63 [00:03<00:00, 18.62it/s]
100%|██████████| 63/63 [00:03<00:00, 19.32it/s]
100%|██████████| 63/63 [00:04<00:00, 14.09it/s]


Epoch 19 | Average Loss: 0.8348
Recall@10: 0.1216 | Recall@20: 0.1737


100%|██████████| 63/63 [00:04<00:00, 14.39it/s]
100%|██████████| 63/63 [00:03<00:00, 17.78it/s]
100%|██████████| 63/63 [00:03<00:00, 18.85it/s]
100%|██████████| 63/63 [00:03<00:00, 19.15it/s]
100%|██████████| 63/63 [00:03<00:00, 19.17it/s]


Epoch 20 | Average Loss: 0.8362
Recall@10: 0.1097 | Recall@20: 0.1588


100%|██████████| 63/63 [00:03<00:00, 16.10it/s]
100%|██████████| 63/63 [00:04<00:00, 15.07it/s]
100%|██████████| 63/63 [00:03<00:00, 16.31it/s]
100%|██████████| 63/63 [00:03<00:00, 16.12it/s]
100%|██████████| 63/63 [00:03<00:00, 17.86it/s]


Epoch 21 | Average Loss: 0.8370
Recall@10: 0.1167 | Recall@20: 0.1662


100%|██████████| 63/63 [00:03<00:00, 16.66it/s]
100%|██████████| 63/63 [00:03<00:00, 19.17it/s]
100%|██████████| 63/63 [00:03<00:00, 19.28it/s]
100%|██████████| 63/63 [00:03<00:00, 18.98it/s]
100%|██████████| 63/63 [00:03<00:00, 16.60it/s]


Epoch 22 | Average Loss: 0.8324
Recall@10: 0.1192 | Recall@20: 0.1650


100%|██████████| 63/63 [00:03<00:00, 16.40it/s]
100%|██████████| 63/63 [00:03<00:00, 16.54it/s]
100%|██████████| 63/63 [00:03<00:00, 18.60it/s]
100%|██████████| 63/63 [00:03<00:00, 18.69it/s]
100%|██████████| 63/63 [00:03<00:00, 18.88it/s]


Epoch 23 | Average Loss: 0.8363
Recall@10: 0.1210 | Recall@20: 0.1685


100%|██████████| 63/63 [00:03<00:00, 17.71it/s]
100%|██████████| 63/63 [00:03<00:00, 16.81it/s]
100%|██████████| 63/63 [00:04<00:00, 15.67it/s]
100%|██████████| 63/63 [00:03<00:00, 16.23it/s]
100%|██████████| 63/63 [00:03<00:00, 16.39it/s]


Epoch 24 | Average Loss: 0.8353
Recall@10: 0.1184 | Recall@20: 0.1740


100%|██████████| 63/63 [00:03<00:00, 18.08it/s]
100%|██████████| 63/63 [00:03<00:00, 17.91it/s]
100%|██████████| 63/63 [00:03<00:00, 18.43it/s]
100%|██████████| 63/63 [00:03<00:00, 18.74it/s]
100%|██████████| 63/63 [00:03<00:00, 16.97it/s]


Epoch 25 | Average Loss: 0.8347
Recall@10: 0.1166 | Recall@20: 0.1702


100%|██████████| 63/63 [00:04<00:00, 15.69it/s]
100%|██████████| 63/63 [00:03<00:00, 16.13it/s]
100%|██████████| 63/63 [00:03<00:00, 19.61it/s]
100%|██████████| 63/63 [00:03<00:00, 19.85it/s]
100%|██████████| 63/63 [00:03<00:00, 20.05it/s]


Epoch 26 | Average Loss: 0.8347
Recall@10: 0.1157 | Recall@20: 0.1669


100%|██████████| 63/63 [00:03<00:00, 17.24it/s]
100%|██████████| 63/63 [00:04<00:00, 14.94it/s]
100%|██████████| 63/63 [00:04<00:00, 13.28it/s]
100%|██████████| 63/63 [00:04<00:00, 14.93it/s]
100%|██████████| 63/63 [00:04<00:00, 14.06it/s]


Epoch 27 | Average Loss: 0.8335
Recall@10: 0.1264 | Recall@20: 0.1817


100%|██████████| 63/63 [00:03<00:00, 15.81it/s]
100%|██████████| 63/63 [00:03<00:00, 19.45it/s]
100%|██████████| 63/63 [00:03<00:00, 19.97it/s]
100%|██████████| 63/63 [00:03<00:00, 20.06it/s]
100%|██████████| 63/63 [00:03<00:00, 19.33it/s]


Epoch 28 | Average Loss: 0.8336
Recall@10: 0.1190 | Recall@20: 0.1684


100%|██████████| 63/63 [00:03<00:00, 18.14it/s]
100%|██████████| 63/63 [00:03<00:00, 16.90it/s]
100%|██████████| 63/63 [00:03<00:00, 16.67it/s]
100%|██████████| 63/63 [00:03<00:00, 18.69it/s]
100%|██████████| 63/63 [00:03<00:00, 17.13it/s]


Epoch 29 | Average Loss: 0.8308
Recall@10: 0.1190 | Recall@20: 0.1719


100%|██████████| 63/63 [00:03<00:00, 17.81it/s]
100%|██████████| 63/63 [00:03<00:00, 18.66it/s]
100%|██████████| 63/63 [00:03<00:00, 18.34it/s]
100%|██████████| 63/63 [00:04<00:00, 14.49it/s]
100%|██████████| 63/63 [00:04<00:00, 13.64it/s]


Epoch 30 | Average Loss: 0.8322
Recall@10: 0.1166 | Recall@20: 0.1656


100%|██████████| 63/63 [00:04<00:00, 12.88it/s]
100%|██████████| 63/63 [00:04<00:00, 13.73it/s]
100%|██████████| 63/63 [00:03<00:00, 17.01it/s]
100%|██████████| 63/63 [00:04<00:00, 15.46it/s]
100%|██████████| 63/63 [00:04<00:00, 14.99it/s]


Epoch 31 | Average Loss: 0.8346
Recall@10: 0.1143 | Recall@20: 0.1675


100%|██████████| 63/63 [00:03<00:00, 18.62it/s]
100%|██████████| 63/63 [00:03<00:00, 19.01it/s]
100%|██████████| 63/63 [00:03<00:00, 19.63it/s]
100%|██████████| 63/63 [00:04<00:00, 12.61it/s]
100%|██████████| 63/63 [00:03<00:00, 17.54it/s]


Epoch 32 | Average Loss: 0.8374
Recall@10: 0.1184 | Recall@20: 0.1703


100%|██████████| 63/63 [00:03<00:00, 18.15it/s]
100%|██████████| 63/63 [00:03<00:00, 17.47it/s]
100%|██████████| 63/63 [00:03<00:00, 19.11it/s]
100%|██████████| 63/63 [00:03<00:00, 19.13it/s]
100%|██████████| 63/63 [00:03<00:00, 19.50it/s]


Epoch 33 | Average Loss: 0.8338
Recall@10: 0.1162 | Recall@20: 0.1621


100%|██████████| 63/63 [00:04<00:00, 14.50it/s]
100%|██████████| 63/63 [00:04<00:00, 15.54it/s]
100%|██████████| 63/63 [00:03<00:00, 16.39it/s]
100%|██████████| 63/63 [00:03<00:00, 16.53it/s]
100%|██████████| 63/63 [00:04<00:00, 15.72it/s]


Epoch 34 | Average Loss: 0.8327
Recall@10: 0.1145 | Recall@20: 0.1591


100%|██████████| 63/63 [00:03<00:00, 18.99it/s]
100%|██████████| 63/63 [00:03<00:00, 18.95it/s]
100%|██████████| 63/63 [00:03<00:00, 19.10it/s]
100%|██████████| 63/63 [00:05<00:00, 11.44it/s]
100%|██████████| 63/63 [00:03<00:00, 16.74it/s]


Epoch 35 | Average Loss: 0.8332
Recall@10: 0.1149 | Recall@20: 0.1604


100%|██████████| 63/63 [00:03<00:00, 16.35it/s]
100%|██████████| 63/63 [00:03<00:00, 18.18it/s]
100%|██████████| 63/63 [00:03<00:00, 17.98it/s]
100%|██████████| 63/63 [00:03<00:00, 19.01it/s]
100%|██████████| 63/63 [00:03<00:00, 18.77it/s]


Epoch 36 | Average Loss: 0.8342
Recall@10: 0.1212 | Recall@20: 0.1748


100%|██████████| 63/63 [00:03<00:00, 18.41it/s]
100%|██████████| 63/63 [00:03<00:00, 16.07it/s]
100%|██████████| 63/63 [00:04<00:00, 14.31it/s]
100%|██████████| 63/63 [00:04<00:00, 15.46it/s]
100%|██████████| 63/63 [00:03<00:00, 17.02it/s]


Epoch 37 | Average Loss: 0.8357
Recall@10: 0.1171 | Recall@20: 0.1665


100%|██████████| 63/63 [00:03<00:00, 17.17it/s]
100%|██████████| 63/63 [00:03<00:00, 16.54it/s]
100%|██████████| 63/63 [00:04<00:00, 14.01it/s]
100%|██████████| 63/63 [00:03<00:00, 18.36it/s]
100%|██████████| 63/63 [00:03<00:00, 18.72it/s]


Epoch 38 | Average Loss: 0.8332
Recall@10: 0.1241 | Recall@20: 0.1782


100%|██████████| 63/63 [00:03<00:00, 19.24it/s]
100%|██████████| 63/63 [00:03<00:00, 18.56it/s]
100%|██████████| 63/63 [00:03<00:00, 18.53it/s]
100%|██████████| 63/63 [00:05<00:00, 11.50it/s]
100%|██████████| 63/63 [00:03<00:00, 18.20it/s]


Epoch 39 | Average Loss: 0.8324
Recall@10: 0.1124 | Recall@20: 0.1596


100%|██████████| 63/63 [00:05<00:00, 12.59it/s]
100%|██████████| 63/63 [00:03<00:00, 19.34it/s]
100%|██████████| 63/63 [00:03<00:00, 19.21it/s]
100%|██████████| 63/63 [00:03<00:00, 19.17it/s]
100%|██████████| 63/63 [00:03<00:00, 19.32it/s]


Epoch 40 | Average Loss: 0.8327
Recall@10: 0.1171 | Recall@20: 0.1665


100%|██████████| 63/63 [00:04<00:00, 13.91it/s]
100%|██████████| 63/63 [00:04<00:00, 14.91it/s]
100%|██████████| 63/63 [00:04<00:00, 14.06it/s]
100%|██████████| 63/63 [00:03<00:00, 16.19it/s]
100%|██████████| 63/63 [00:03<00:00, 16.63it/s]


Epoch 41 | Average Loss: 0.8305
Recall@10: 0.1190 | Recall@20: 0.1640


100%|██████████| 63/63 [00:03<00:00, 17.44it/s]
100%|██████████| 63/63 [00:04<00:00, 13.68it/s]
100%|██████████| 63/63 [00:03<00:00, 18.05it/s]
100%|██████████| 63/63 [00:03<00:00, 17.36it/s]
100%|██████████| 63/63 [00:03<00:00, 18.79it/s]


Epoch 42 | Average Loss: 0.8312
Recall@10: 0.1165 | Recall@20: 0.1665


100%|██████████| 63/63 [00:05<00:00, 12.31it/s]
100%|██████████| 63/63 [00:03<00:00, 16.41it/s]
100%|██████████| 63/63 [00:04<00:00, 15.33it/s]
100%|██████████| 63/63 [00:03<00:00, 19.06it/s]
100%|██████████| 63/63 [00:03<00:00, 19.48it/s]


Epoch 43 | Average Loss: 0.8296
Recall@10: 0.1171 | Recall@20: 0.1702


100%|██████████| 63/63 [00:03<00:00, 19.19it/s]
100%|██████████| 63/63 [00:03<00:00, 19.44it/s]
100%|██████████| 63/63 [00:04<00:00, 13.13it/s]
100%|██████████| 63/63 [00:04<00:00, 14.78it/s]
100%|██████████| 63/63 [00:04<00:00, 15.13it/s]


Epoch 44 | Average Loss: 0.8323
Recall@10: 0.1196 | Recall@20: 0.1694


100%|██████████| 63/63 [00:03<00:00, 16.73it/s]
100%|██████████| 63/63 [00:03<00:00, 16.84it/s]
100%|██████████| 63/63 [00:03<00:00, 17.12it/s]
100%|██████████| 63/63 [00:04<00:00, 13.79it/s]
100%|██████████| 63/63 [00:04<00:00, 14.53it/s]


Epoch 45 | Average Loss: 0.8337
Recall@10: 0.1091 | Recall@20: 0.1620


100%|██████████| 63/63 [00:03<00:00, 18.25it/s]
100%|██████████| 63/63 [00:03<00:00, 19.18it/s]
100%|██████████| 63/63 [00:04<00:00, 14.12it/s]
100%|██████████| 63/63 [00:04<00:00, 14.27it/s]
100%|██████████| 63/63 [00:03<00:00, 18.39it/s]


Epoch 46 | Average Loss: 0.8330
Recall@10: 0.1134 | Recall@20: 0.1603


100%|██████████| 63/63 [00:03<00:00, 17.63it/s]
100%|██████████| 63/63 [00:03<00:00, 16.85it/s]
100%|██████████| 63/63 [00:03<00:00, 16.46it/s]
100%|██████████| 63/63 [00:04<00:00, 13.43it/s]
100%|██████████| 63/63 [00:04<00:00, 15.24it/s]


Epoch 47 | Average Loss: 0.8295
Recall@10: 0.1192 | Recall@20: 0.1655


100%|██████████| 63/63 [00:03<00:00, 16.37it/s]
100%|██████████| 63/63 [00:04<00:00, 15.28it/s]
100%|██████████| 63/63 [00:03<00:00, 16.65it/s]
100%|██████████| 63/63 [00:03<00:00, 16.01it/s]
100%|██████████| 63/63 [00:04<00:00, 12.75it/s]


Epoch 48 | Average Loss: 0.8295
Recall@10: 0.1135 | Recall@20: 0.1667


100%|██████████| 63/63 [00:04<00:00, 15.72it/s]
100%|██████████| 63/63 [00:03<00:00, 17.34it/s]
100%|██████████| 63/63 [00:04<00:00, 14.50it/s]
100%|██████████| 63/63 [00:04<00:00, 13.46it/s]
100%|██████████| 63/63 [00:03<00:00, 15.95it/s]


Epoch 49 | Average Loss: 0.8308
Recall@10: 0.1200 | Recall@20: 0.1736


In [10]:
# Persist the trained weights of `model_fixed` (state_dict only, not the whole module object).
# torch.save does not create missing parent directories, so make sure the output
# folder exists first; otherwise a fresh checkout fails here after the full training run.
# NOTE(review): this stores whatever weights `model_fixed` holds when the cell runs
# (presumably the final epoch). In the visible training log Recall@20 peaks at
# epoch 27 (0.1817) versus 0.1736 at epoch 49 — consider checkpointing the best
# epoch inside the training loop instead; confirm against the training cell.
gat_checkpoint_path = Path('release/gat_model.pt')
gat_checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model_fixed.state_dict(), gat_checkpoint_path)