In [11]:
import torch
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv
import torch.nn.functional as F
import torch.optim as optim
import os
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.transforms import ToUndirected

# torch.serialization.add_safe_globals(["ogb.nodeproppred.dataset"])

# Download dataset
# if "ogbn-mag" not in os.listdir("dataset"):
#     dataset = PygNodePropPredDataset(name = "ogbn-mag", root = "dataset/")

# Load the processed dataset
try:
    data, _ = torch.load(r"dataset/ogbn_mag/processed/geometric_data_processed.pt", weights_only=False)
except FileNotFoundError:
    os.chdir(os.path.abspath(os.path.join(os.getcwd(), "../")))
    data, _ = torch.load(r"dataset/ogbn_mag/processed/geometric_data_processed.pt", weights_only=False)

# Extract the paper_cites_edge_index (for the paper -> cites -> paper relation)
paper_cites_edge_index = data.edge_index_dict[('paper', 'cites', 'paper')]

# Extract the paper node features (assuming paper feature is in data.x['paper'])
paper_node_features = data.x_dict['paper']

# Get the number of papers
num_papers = paper_node_features.size(0)

data


Data(
  num_nodes_dict={
    author=1134649,
    field_of_study=59965,
    institution=8740,
    paper=736389,
  },
  edge_index_dict={
    (author, affiliated_with, institution)=[2, 1043998],
    (author, writes, paper)=[2, 7145660],
    (paper, cites, paper)=[2, 5416271],
    (paper, has_topic, field_of_study)=[2, 7505078],
  },
  x_dict={ paper=[736389, 128] },
  node_year={ paper=[736389, 1] },
  edge_reltype={
    (author, affiliated_with, institution)=[1043998, 1],
    (author, writes, paper)=[7145660, 1],
    (paper, cites, paper)=[5416271, 1],
    (paper, has_topic, field_of_study)=[7505078, 1],
  },
  y_dict={ paper=[736389, 1] }
)

In [69]:
import pandas as pd

train_idx = pd.read_csv(r"dataset/ogbn_mag/split/time/paper/train.csv.gz").values.squeeze()
valid_idx = pd.read_csv(r"dataset/ogbn_mag/split/time/paper/valid.csv.gz").values.squeeze()
test_idx = pd.read_csv(r"dataset/ogbn_mag/split/time/paper/test.csv.gz").values.squeeze()

# Define the training, validation, and test sets
train_data = paper_cites_edge_index[0, train_idx], paper_cites_edge_index[1, train_idx]
valid_data = paper_cites_edge_index[0, valid_idx], paper_cites_edge_index[1, valid_idx], paper_cites_edge_index[1, train_idx]
test_data = paper_cites_edge_index[0, test_idx], paper_cites_edge_index[1, test_idx]

train_data

(tensor([     0,      0,      0,  ..., 140773, 140773, 140773]),
 tensor([ 27449, 121051, 151667,  ..., 126967, 150164, 234911]))

tensor([    95,    134,    162,  ..., 140764, 140772, 140773])

In [4]:
import numpy as np

# Get the positive edges for paper_cites_edge_index
edges = paper_cites_edge_index.T.numpy()

edges_sample = edges[:20]
# Shuffle edges and split them into train and test sets (80% train, 20% test)
np.random.shuffle(edges_sample)
train_size = int(0.8 * len(edges_sample))
train_edges = edges_sample[:train_size]
test_edges = edges_sample[train_size:]

# Generate negative edges (pairs of nodes without edges between them)
def generate_negative_edges(num_samples, num_papers, existing_edges):
    neg_edges = set()
    while len(neg_edges) < num_samples:
        u = np.random.randint(0, num_papers)
        v = np.random.randint(0, num_papers)
        if (u, v) not in existing_edges and (v, u) not in existing_edges:
            neg_edges.add((u, v))
    return np.array(list(neg_edges))

# Get negative edges for training and testing
train_neg_edges = generate_negative_edges(len(train_edges), num_papers, set(map(tuple, train_edges)))
test_neg_edges = generate_negative_edges(len(test_edges), num_papers, set(map(tuple, train_edges)))

# Combine positive and negative edges for training
train_edges = np.concatenate([train_edges, train_neg_edges], axis=0)

# Convert the edges to PyTorch tensors
train_edges = torch.tensor(train_edges, dtype=torch.long).T  # Shape (2, num_edges)
test_edges = torch.tensor(test_edges, dtype=torch.long).T  # Shape (2, num_edges)
