# Install Libraries =====================

In [1]:
# Install the PyTorch Geometric companion wheels from the PyG wheel index, then
# torch-geometric itself.
# NOTE(review): the index URL targets torch 2.1.0 + CUDA 11.8 — confirm it matches
# the torch build installed in this runtime.
!pip install -q torch-scatter torch-sparse torch-cluster torch-spline-conv pyg-lib \
  -f https://data.pyg.org/whl/torch-2.1.0+cu118.html
!pip install -q torch-geometric


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.2/10.2 MB[0m [31m47.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.9/4.9 MB[0m [31m31.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m35.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m891.8/891.8 kB[0m [31m22.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.6/2.6 MB[0m [31m20.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m21.7 MB/s[0m eta [36m0:00:00[0m
[?25h

# Import Libraries =======================

In [2]:
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import negative_sampling, train_test_split_edges
from sklearn.metrics import roc_auc_score, average_precision_score



# Load data from Cit-HepTh.txt =========

In [3]:
def load_edge_index_from_txt(path):
    """Read a whitespace-separated edge list and return a compact edge index.

    Each data line must hold two integer node ids (``src dst``). Lines that are
    blank or start with ``#`` (after stripping surrounding whitespace) are
    ignored. Original node ids are re-labelled to consecutive indices
    ``0..N-1`` following ascending order of the original ids.

    Args:
        path: Path of the text file to read.

    Returns:
        A tuple ``(edge_index, node_id_map)`` where ``edge_index`` is a
        ``torch.long`` tensor of shape ``[2, num_edges]`` (``[2, 0]`` when the
        file holds no edges) and ``node_id_map`` maps each original node id to
        its compact index.

    Raises:
        ValueError: If a data line does not contain exactly two integers.
    """
    edge_list = []
    with open(path, 'r') as f:
        for line in f:
            line = line.strip()
            # Blank lines (e.g. a trailing newline) and comments carry no edge.
            if not line or line.startswith('#'):
                continue
            src, dst = map(int, line.split())
            edge_list.append((src, dst))

    # Sorting makes the re-labelling deterministic and independent of file order.
    node_ids = sorted({n for edge in edge_list for n in edge})
    node_id_map = {nid: i for i, nid in enumerate(node_ids)}

    edges = [(node_id_map[src], node_id_map[dst]) for src, dst in edge_list]
    # reshape(-1, 2) keeps the result 2-D even when there are no edges at all.
    edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()
    return edge_index, node_id_map

In [4]:
# Seed torch's global RNG so the random node features below are repeatable
# across "Restart Kernel -> Run All"; later torch-based random draws in this
# kernel start from the same seeded state as well.
torch.manual_seed(42)

# Parse the citation edge list; node_id_map maps original node ids to compact indices.
edge_index, node_id_map = load_edge_index_from_txt("/kaggle/input/cit-hep/Cit-HepTh.txt")
num_nodes = len(node_id_map)
# No node attributes are loaded, so random 64-dimensional features stand in for them.
x = torch.randn((num_nodes, 64))  # Random features

In [5]:
# Build the graph and split its edges into train/val/test sets
data = Data(x=x, edge_index=edge_index)
# Explicitly set the node-level masks to None before the edge split
data.train_mask = None
data.val_mask = None
data.test_mask = None
data = train_test_split_edges(data)

# Use the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')



# GraphSAGE ===================== 

In [6]:
class UnsupervisedGraphSAGE(torch.nn.Module):
    """Two-layer GraphSAGE encoder that maps node features to embeddings.

    The first ``SAGEConv`` layer is followed by a ReLU; the second layer's
    output is returned as-is (no activation).
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        """Create the two convolution layers.

        Args:
            in_channels: Size of each input node feature vector.
            hidden_channels: Output size of the first layer.
            out_channels: Output size of the second layer (embedding size).
        """
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return node embeddings for features ``x`` over ``edge_index``."""
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

In [7]:
class EarlyStopping:
    """Track a loss value and raise a flag once it stops improving.

    A call counts as an improvement only when the new loss is lower than the
    best loss seen so far by more than ``delta``. After ``patience``
    consecutive non-improving calls, ``early_stop`` becomes ``True``.
    """

    def __init__(self, patience=10, delta=0.0):
        """Store the thresholds and reset the tracking state.

        Args:
            patience: Number of consecutive non-improving calls tolerated.
            delta: Minimum decrease in loss that counts as an improvement.
        """
        self.patience, self.delta = patience, delta
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, current_loss):
        """Update the tracker with the latest loss value."""
        improved = current_loss < self.best_loss - self.delta
        if improved:
            # New best: remember it and restart the patience countdown.
            self.best_loss = current_loss
            self.counter = 0
            return
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True


In [8]:
def unsupervised_loss(z, pos_edge_index, neg_edge_index):
    """Negative-sampling loss on dot-product edge scores.

    Positive edges are pushed towards high scores and negative edges towards
    low scores via ``-logsigmoid(score)`` and ``-logsigmoid(-score)``; the two
    per-set means are added.

    Args:
        z: Node embedding matrix, indexed by node along dim 0.
        pos_edge_index: Index tensor whose rows 0 and 1 are the endpoints of
            the positive edges.
        neg_edge_index: Same layout, for the negative edges.

    Returns:
        Scalar tensor: mean positive loss plus mean negative loss.
    """
    def edge_scores(edge_index):
        # Dot product between the embeddings of each edge's two endpoints.
        src, dst = edge_index[0], edge_index[1]
        return (z[src] * z[dst]).sum(dim=1)

    pos_loss = -F.logsigmoid(edge_scores(pos_edge_index)).mean()
    neg_loss = -F.logsigmoid(-edge_scores(neg_edge_index)).mean()
    return pos_loss + neg_loss

In [9]:
def contrastive_loss(z, pos_edge_index, neg_edge_index, margin=1.0):
    """Margin-based contrastive loss on cosine distance.

    Cosine distance is ``1 - cosine_similarity``. Positive pairs are penalised
    by their squared distance; negative pairs are penalised by the squared
    amount by which their distance falls short of ``margin``.

    Args:
        z: Node embedding matrix, indexed by node along dim 0.
        pos_edge_index: Index tensor whose rows 0 and 1 are the endpoints of
            the positive edges.
        neg_edge_index: Same layout, for the negative edges.
        margin: Distance below which negative pairs are penalised.

    Returns:
        Scalar tensor: mean positive term plus mean negative term.
    """
    pos_similarity = F.cosine_similarity(z[pos_edge_index[0]], z[pos_edge_index[1]])
    neg_similarity = F.cosine_similarity(z[neg_edge_index[0]], z[neg_edge_index[1]])

    pos_dist = 1 - pos_similarity
    neg_dist = 1 - neg_similarity

    pull_term = (pos_dist ** 2).mean()
    push_term = (F.relu(margin - neg_dist) ** 2).mean()
    return pull_term + push_term


In [10]:
def info_nce_loss(z, pos_edge_index, temperature=0.5):
    """InfoNCE loss over positive edges, computed in log-space.

    For each positive edge ``(i, j)`` the loss is
    ``-log( exp(z_i.z_j / T) / sum_k exp(z_i.z_k / T) )``, averaged over the
    edges. The sum in the denominator runs over every node ``k`` (including
    ``i`` itself). The expression is evaluated as ``logit - logsumexp(logits)``
    so that large dot products do not overflow ``exp`` and turn the loss into
    ``inf``/``nan``.

    Args:
        z: Node embedding matrix, indexed by node along dim 0.
        pos_edge_index: Index tensor whose rows 0 and 1 are the endpoints of
            the positive edges.
        temperature: Divisor applied to every dot-product similarity.

    Returns:
        Scalar tensor holding the mean loss over the positive edges.
    """
    src, dst = pos_edge_index[0], pos_edge_index[1]

    # Dense all-pairs similarity matrix (num_nodes x num_nodes).
    logits = torch.mm(z, z.t()) / temperature
    pos_logits = (z[src] * z[dst]).sum(dim=1) / temperature

    # log of the per-row denominator, gathered for each edge's source node.
    log_denom = torch.logsumexp(logits, dim=1)[src]
    return -(pos_logits - log_denom).mean()


# Training =======================

In [11]:
# Show a summary of the split graph object, then move it to the selected device
# so the training loop can pass its tensors straight to the model.
print(data)
data = data.to(device)


Data(x=[27770, 64], val_pos_edge_index=[2, 5221], test_pos_edge_index=[2, 10443], train_pos_edge_index=[2, 177542], train_neg_adj_mask=[27770, 27770], val_neg_edge_index=[2, 5221], test_neg_edge_index=[2, 10443])


In [12]:
from torch.optim.lr_scheduler import StepLR

# Build the model, optimizer, LR schedule and early-stopping tracker exactly once.
model = UnsupervisedGraphSAGE(64, 128, 128).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=50, gamma=0.5)  # halve the LR every 50 epochs
early_stopper = EarlyStopping(patience=10)

for epoch in range(1, 301):
    model.train()
    optimizer.zero_grad()

    # Encode every node using the training edges as the graph structure.
    z = model(data.x, data.train_pos_edge_index)

    # Draw a fresh set of negative edges for this epoch.
    neg_edge_index = negative_sampling(
        edge_index=data.train_pos_edge_index,
        num_nodes=data.num_nodes
    ).to(device)

    loss = unsupervised_loss(z, data.train_pos_edge_index, neg_edge_index)

    loss.backward()
    optimizer.step()
    scheduler.step()

    # Read the scalar once and reuse it for both logging and early stopping.
    loss_value = loss.item()
    print(f"[Epoch {epoch:03d}] Loss: {loss_value:.4f} | LR: {scheduler.get_last_lr()[0]:.6f}")

    # NOTE(review): early stopping monitors the training loss, which is computed
    # on newly sampled negatives each epoch — consider a validation metric.
    early_stopper(loss_value)
    if early_stopper.early_stop:
        print(f"Early stopping at epoch {epoch}")
        break



[Epoch 001] Loss: 3.9492 | LR: 0.001000
[Epoch 002] Loss: 2.4804 | LR: 0.001000
[Epoch 003] Loss: 1.7942 | LR: 0.001000
[Epoch 004] Loss: 1.5641 | LR: 0.001000
[Epoch 005] Loss: 1.4935 | LR: 0.001000
[Epoch 006] Loss: 1.4713 | LR: 0.001000
[Epoch 007] Loss: 1.4852 | LR: 0.001000
[Epoch 008] Loss: 1.5008 | LR: 0.001000
[Epoch 009] Loss: 1.5143 | LR: 0.001000
[Epoch 010] Loss: 1.5047 | LR: 0.001000
[Epoch 011] Loss: 1.4894 | LR: 0.001000
[Epoch 012] Loss: 1.4697 | LR: 0.001000
[Epoch 013] Loss: 1.4472 | LR: 0.001000
[Epoch 014] Loss: 1.4269 | LR: 0.001000
[Epoch 015] Loss: 1.4070 | LR: 0.001000
[Epoch 016] Loss: 1.3883 | LR: 0.001000
[Epoch 017] Loss: 1.3676 | LR: 0.001000
[Epoch 018] Loss: 1.3489 | LR: 0.001000
[Epoch 019] Loss: 1.3286 | LR: 0.001000
[Epoch 020] Loss: 1.3166 | LR: 0.001000
[Epoch 021] Loss: 1.3008 | LR: 0.001000
[Epoch 022] Loss: 1.2883 | LR: 0.001000
[Epoch 023] Loss: 1.2738 | LR: 0.001000
[Epoch 024] Loss: 1.2630 | LR: 0.001000
[Epoch 025] Loss: 1.2472 | LR: 0.001000


# Evaluation =======================

In [13]:
@torch.no_grad()
def evaluate_link_prediction(z, pos_edge_index, neg_edge_index):
    """Score edges by embedding dot product and report ROC AUC and AP.

    Args:
        z: Node embedding matrix, indexed by node along dim 0.
        pos_edge_index: Index tensor whose rows 0 and 1 are the endpoints of
            the positive (label 1) edges.
        neg_edge_index: Same layout, for the negative (label 0) edges.

    Returns:
        Tuple ``(auc, ap)`` from ``roc_auc_score`` and
        ``average_precision_score``.
    """
    def dot_scores(edge_index):
        # One dot-product score per edge, moved to a NumPy array on the CPU.
        src, dst = edge_index[0], edge_index[1]
        return (z[src] * z[dst]).sum(dim=1).cpu().numpy()

    pos_scores = dot_scores(pos_edge_index)
    neg_scores = dot_scores(neg_edge_index)

    # Positives first, then negatives, in both the label and score vectors.
    labels = np.concatenate([np.ones(len(pos_scores)), np.zeros(len(neg_scores))])
    scores = np.concatenate([pos_scores, neg_scores])

    return roc_auc_score(labels, scores), average_precision_score(labels, scores)

# Compute final embeddings in eval mode, using the training edges as the graph.
model.eval()
eval_features = data.x.to(device)
eval_edges = data.train_pos_edge_index.to(device)
with torch.no_grad():
    z = model(eval_features, eval_edges)

# Score the held-out test edges against the sampled test negatives.
test_metrics = evaluate_link_prediction(z, data.test_pos_edge_index, data.test_neg_edge_index)
auc, ap = test_metrics
print(f"[Evaluation] ROC AUC: {auc:.4f} | Average Precision: {ap:.4f}")

[Evaluation] ROC AUC: 0.8121 | Average Precision: 0.7908


# Save Embeddings and Model =======================

In [14]:
# Export the final node embeddings as a CSV with one row per node.
z_np = z.cpu().numpy()

# Look the id mapping up explicitly instead of catching NameError, so a missing
# mapping is an explicit, visible branch rather than a swallowed exception.
id_map = globals().get("node_id_map")
if id_map is not None:
    # Label row i with the original node id whose compact index is i; sorting by
    # the mapped index avoids relying on the dict's insertion order.
    original_ids = sorted(id_map, key=id_map.get)
    df = pd.DataFrame(z_np, index=original_ids)
else:
    # No mapping available: fall back to positional row labels.
    df = pd.DataFrame(z_np)
df.index.name = "node_id"
df.to_csv("graphsage_embeddings.csv")
print("Embedding saved to graphsage_embeddings.csv")

# ==== Save Model ====
# Only the learned parameters (state_dict) are stored, not the module object.
torch.save(model.state_dict(), "graphsage_model.pt")
print("Model saved to graphsage_model.pt")

Embedding saved to graphsage_embeddings.csv
Model saved to graphsage_model.pt
