In [1]:
!pip install torch torch_geometric torch_sparse torch_scatter torch_cluster torch_spline_conv

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch_sparse
  Downloading torch_sparse-0.6.18.tar.gz (209 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m210.0/210.0 kB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torch_scatter
  Downloading torch_scatter-2.1.2.tar.gz (108 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m108.0/108.0 kB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torch_cluster
  Downloading torch_cluster-1.6.3.tar.gz (54 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.5/54.5 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torc

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.nn import GCNConv, SAGEConv
from torch_geometric.utils import train_test_split_edges
from torch_geometric.nn.models import GAE
from sklearn.metrics import roc_auc_score, average_precision_score

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# ── dataset: FacebookPagePage ────────────────────────────────────────────────
from torch_geometric.datasets import FacebookPagePage
from torch_geometric.transforms import RandomLinkSplit

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = FacebookPagePage(root="data/FacebookPagePage")
data = dataset[0]

# 85 % train ‑ 5 % val ‑ 10 % test  (+ negative sampling)
transform = RandomLinkSplit(num_val=0.05,
                            num_test=0.10,
                            is_undirected=True,
                            split_labels=True)

train_data, val_data, test_data = transform(data)
train_data, val_data, test_data = [d.to(device) for d in (train_data,
                                                          val_data,
                                                          test_data)]


Downloading https://graphmining.ai/datasets/ptg/facebook.npz
Processing...
Done!


In [4]:
class Encoder(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        return self.conv2(F.relu(self.conv1(x, edge_index)), edge_index)

model = GAE(Encoder(dataset.num_node_features,
                       hidden_channels=64,
                       out_channels=32)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)


In [5]:
def train():
    model.train()
    optimizer.zero_grad()
    z = model.encode(train_data.x, train_data.edge_index)

    loss = model.recon_loss(z,
                            pos_edge_index=train_data.pos_edge_label_index,
                            neg_edge_index=train_data.neg_edge_label_index)

    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def evaluate(loader_data):
    model.eval()
    z = model.encode(loader_data.x, loader_data.edge_index)
    auc, ap = model.test(z,
                         pos_edge_index=loader_data.pos_edge_label_index,
                         neg_edge_index=loader_data.neg_edge_label_index)
    return auc, ap


In [6]:
import torch
import pandas as pd
from tqdm import trange


def train_until_early_stop(
    train_step,
    eval_step,
    val_data,
    test_data,
    model,
    n_epochs: int = 300,
    log_every: int = 10,
    patience: int = 20,
    ckpt_path: str = "best_gae_facebooks.pt",
    metrics: list = []
):
    metrics = [] if len(metrics) == 0 else metrcis
    best_val_auc, bad = -float("inf"), 0

    for epoch in trange(1, n_epochs + 1):
        loss = train_step()

        if epoch % log_every == 0:
            val_auc, val_ap   = eval_step(val_data)
            test_auc, test_ap = eval_step(test_data)
            print(
                f"Epoch {epoch:03d} | loss {loss:.4f} | "
                f"val AUC/AP {val_auc:.4f}/{val_ap:.4f} | "
                f"test AUC/AP {test_auc:.4f}/{test_ap:.4f}"
            )

            metrics.append(
                {
                    "GCN val AUC": val_auc,
                    "GCN val AP": val_ap,
                    "GCN test AUC": test_auc,
                    "GCN test AP": test_ap,
                }
            )

            if val_auc > best_val_auc:
                best_val_auc, bad = val_auc, 0
                torch.save(model.state_dict(), ckpt_path)
            else:
                bad += 1
                if bad == patience:
                    print("Early stop.")
                    break

    return metrics


metrics = train_until_early_stop(
    train_step=train,
    eval_step=evaluate,
    val_data=val_data,
    test_data=test_data,
    model=model
)


  3%|▎         | 10/300 [00:08<04:09,  1.16it/s]

Epoch 010 | loss 0.9686 | val AUC/AP 0.9548/0.9587 | test AUC/AP 0.9555/0.9600


  7%|▋         | 20/300 [00:17<04:56,  1.06s/it]

Epoch 020 | loss 0.8512 | val AUC/AP 0.9655/0.9696 | test AUC/AP 0.9656/0.9697


 10%|█         | 30/300 [00:27<05:08,  1.14s/it]

Epoch 030 | loss 0.8159 | val AUC/AP 0.9703/0.9740 | test AUC/AP 0.9702/0.9738


 13%|█▎        | 40/300 [00:34<04:02,  1.07it/s]

Epoch 040 | loss 0.7979 | val AUC/AP 0.9735/0.9766 | test AUC/AP 0.9727/0.9761


 17%|█▋        | 50/300 [00:43<03:35,  1.16it/s]

Epoch 050 | loss 0.7874 | val AUC/AP 0.9754/0.9780 | test AUC/AP 0.9746/0.9777


 20%|██        | 60/300 [00:50<03:37,  1.10it/s]

Epoch 060 | loss 0.7804 | val AUC/AP 0.9759/0.9785 | test AUC/AP 0.9753/0.9783


 23%|██▎       | 70/300 [00:57<03:15,  1.17it/s]

Epoch 070 | loss 0.7751 | val AUC/AP 0.9766/0.9791 | test AUC/AP 0.9760/0.9788


 27%|██▋       | 80/300 [01:05<03:09,  1.16it/s]

Epoch 080 | loss 0.7710 | val AUC/AP 0.9769/0.9794 | test AUC/AP 0.9763/0.9791


 30%|███       | 90/300 [01:14<03:43,  1.06s/it]

Epoch 090 | loss 0.7676 | val AUC/AP 0.9770/0.9796 | test AUC/AP 0.9765/0.9793


 33%|███▎      | 100/300 [01:21<02:51,  1.16it/s]

Epoch 100 | loss 0.7646 | val AUC/AP 0.9771/0.9797 | test AUC/AP 0.9766/0.9794


 37%|███▋      | 110/300 [01:29<02:38,  1.20it/s]

Epoch 110 | loss 0.7620 | val AUC/AP 0.9771/0.9797 | test AUC/AP 0.9767/0.9794


 40%|████      | 120/300 [01:36<02:32,  1.18it/s]

Epoch 120 | loss 0.7597 | val AUC/AP 0.9771/0.9798 | test AUC/AP 0.9767/0.9794


 43%|████▎     | 130/300 [01:44<02:23,  1.18it/s]

Epoch 130 | loss 0.7576 | val AUC/AP 0.9771/0.9798 | test AUC/AP 0.9767/0.9795


 47%|████▋     | 140/300 [01:52<02:31,  1.05it/s]

Epoch 140 | loss 0.7557 | val AUC/AP 0.9771/0.9799 | test AUC/AP 0.9768/0.9795


 50%|█████     | 150/300 [01:59<02:04,  1.20it/s]

Epoch 150 | loss 0.7540 | val AUC/AP 0.9771/0.9799 | test AUC/AP 0.9768/0.9795


 53%|█████▎    | 160/300 [02:07<02:02,  1.14it/s]

Epoch 160 | loss 0.7525 | val AUC/AP 0.9771/0.9799 | test AUC/AP 0.9769/0.9796


 57%|█████▋    | 170/300 [02:14<01:51,  1.17it/s]

Epoch 170 | loss 0.7511 | val AUC/AP 0.9771/0.9799 | test AUC/AP 0.9769/0.9796


 60%|██████    | 180/300 [02:22<01:44,  1.15it/s]

Epoch 180 | loss 0.7497 | val AUC/AP 0.9771/0.9799 | test AUC/AP 0.9769/0.9796


 63%|██████▎   | 190/300 [02:31<01:43,  1.07it/s]

Epoch 190 | loss 0.7484 | val AUC/AP 0.9771/0.9799 | test AUC/AP 0.9768/0.9796


 67%|██████▋   | 200/300 [02:38<01:26,  1.16it/s]

Epoch 200 | loss 0.7472 | val AUC/AP 0.9771/0.9799 | test AUC/AP 0.9767/0.9795


 70%|███████   | 210/300 [02:46<01:17,  1.16it/s]

Epoch 210 | loss 0.7460 | val AUC/AP 0.9770/0.9799 | test AUC/AP 0.9767/0.9795


 73%|███████▎  | 220/300 [02:54<01:21,  1.01s/it]

Epoch 220 | loss 0.7449 | val AUC/AP 0.9770/0.9799 | test AUC/AP 0.9766/0.9795


 77%|███████▋  | 230/300 [03:02<01:02,  1.13it/s]

Epoch 230 | loss 0.7439 | val AUC/AP 0.9769/0.9799 | test AUC/AP 0.9766/0.9794


 80%|████████  | 240/300 [03:10<00:52,  1.15it/s]

Epoch 240 | loss 0.7428 | val AUC/AP 0.9770/0.9799 | test AUC/AP 0.9765/0.9794


 83%|████████▎ | 250/300 [03:17<00:43,  1.14it/s]

Epoch 250 | loss 0.7419 | val AUC/AP 0.9769/0.9799 | test AUC/AP 0.9764/0.9794


 87%|████████▋ | 260/300 [03:25<00:33,  1.21it/s]

Epoch 260 | loss 0.7409 | val AUC/AP 0.9769/0.9799 | test AUC/AP 0.9764/0.9793


 90%|█████████ | 270/300 [03:33<00:29,  1.03it/s]

Epoch 270 | loss 0.7400 | val AUC/AP 0.9769/0.9799 | test AUC/AP 0.9763/0.9793


 93%|█████████▎| 280/300 [03:40<00:17,  1.17it/s]

Epoch 280 | loss 0.7391 | val AUC/AP 0.9769/0.9799 | test AUC/AP 0.9763/0.9793


 97%|█████████▋| 290/300 [03:49<00:09,  1.04it/s]

Epoch 290 | loss 0.7382 | val AUC/AP 0.9769/0.9798 | test AUC/AP 0.9762/0.9792


100%|██████████| 300/300 [03:56<00:00,  1.27it/s]

Epoch 300 | loss 0.7373 | val AUC/AP 0.9768/0.9798 | test AUC/AP 0.9762/0.9792





In [7]:
model.load_state_dict(torch.load("best_gae_facebooks.pt"))
final_auc, final_ap = evaluate(test_data)
print(f"\nBest model ‑‑ Test AUC: {final_auc:.4f} | AP: {final_ap:.4f}")



Best model ‑‑ Test AUC: 0.9768 | AP: 0.9795


In [16]:
class GraphSAGE(nn.Module):
    def __init__(self, in_channels: int, hidden: int = 128):
        super().__init__()
        self.sage1 = SAGEConv(in_channels, hidden)
        self.sage2 = SAGEConv(hidden,   hidden)

    def encode(self, x, edge_index):
        x = F.relu(self.sage1(x, edge_index))
        return self.sage2(x, edge_index)

    @staticmethod
    def decode(z, edge_index):
        return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)

model = GraphSAGE(dataset.num_node_features, hidden=128).to(device)
opt   = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)


In [17]:
def train():
    model.train()
    opt.zero_grad()

    z = model.encode(train_data.x, train_data.edge_index)

    pos_logits = model.decode(z, train_data.pos_edge_label_index)
    neg_logits = model.decode(z, train_data.neg_edge_label_index)

    y = torch.cat([torch.ones_like(pos_logits), torch.zeros_like(neg_logits)])
    preds = torch.cat([pos_logits, neg_logits])

    loss = F.binary_cross_entropy_with_logits(preds, y)
    loss.backward()
    opt.step()
    return loss.item()


In [10]:
@torch.no_grad()
def evaluate(split_data):
    model.eval()
    z = model.encode(split_data.x, split_data.edge_index)

    pos = model.decode(z, split_data.pos_edge_label_index).sigmoid().cpu()
    neg = model.decode(z, split_data.neg_edge_label_index).sigmoid().cpu()

    y_true = torch.cat([torch.ones_like(pos), torch.zeros_like(neg)])
    y_pred = torch.cat([pos,               neg])

    auc = roc_auc_score(y_true, y_pred)
    ap  = average_precision_score(y_true, y_pred)
    return auc, ap


In [11]:
def train_sage(
    train_step,
    eval_step,
    val_data,
    test_data,
    model,
    metrics: list =[],
    metric_prefix: str = "GraphSAGE",
    n_epochs: int = 300,
    log_every: int = 10,
    patience: int = 20,
    ckpt_path: str = "best_sage_facebook.pt",
):
    best_auc, bad, idx = -float("inf"), 0, 0

    for epoch in range(1, n_epochs + 1):
        loss = train_step()

        if epoch % log_every == 0:
            val_auc, val_ap   = eval_step(val_data)
            test_auc, test_ap = eval_step(test_data)

            print(
                f"Ep{epoch:03d}  loss={loss:.4f}  "
                f"val AUC/AP={val_auc:.4f}/{val_ap:.4f}  "
                f"test AUC/AP={test_auc:.4f}/{test_ap:.4f}"
            )

            entry = {
                f"{metric_prefix} val AUC":  val_auc,
                f"{metric_prefix} val AP":   val_ap,
                f"{metric_prefix} test AUC": test_auc,
                f"{metric_prefix} test AP":  test_ap,
            }

            if len(metrics) <= idx:
                metrics.append(entry)
            else:
                metrics[idx].update(entry)

            if val_auc > best_auc:
                best_auc, bad = val_auc, 0
                torch.save(model.state_dict(), ckpt_path)
            else:
                bad += 1
                if bad == patience:
                    print("Early stop.")
                    break

            idx += 1

    model.load_state_dict(torch.load(ckpt_path))
    final_auc, final_ap = eval_step(test_data)
    print(
        f"\nBest model → Test AUC={final_auc:.4f}  AP={final_ap:.4f}"
    )

    return metrics, (final_auc, final_ap)

metrics = train_sage(
    train_step=train,
    eval_step=evaluate,
    val_data=val_data,
    test_data=test_data,
    model=model,
    metrics=metrics,
)


Ep010  loss=0.6690  val AUC/AP=0.8227/0.8128  test AUC/AP=0.8193/0.8127
Ep020  loss=0.5025  val AUC/AP=0.9255/0.9238  test AUC/AP=0.9249/0.9245
Ep030  loss=0.4506  val AUC/AP=0.9483/0.9482  test AUC/AP=0.9462/0.9459
Ep040  loss=0.4262  val AUC/AP=0.9568/0.9570  test AUC/AP=0.9547/0.9545
Ep050  loss=0.4124  val AUC/AP=0.9612/0.9614  test AUC/AP=0.9597/0.9596
Ep060  loss=0.4032  val AUC/AP=0.9629/0.9629  test AUC/AP=0.9619/0.9618
Ep070  loss=0.3961  val AUC/AP=0.9637/0.9638  test AUC/AP=0.9630/0.9633
Ep080  loss=0.3903  val AUC/AP=0.9640/0.9643  test AUC/AP=0.9636/0.9642
Ep090  loss=0.3853  val AUC/AP=0.9640/0.9646  test AUC/AP=0.9639/0.9649
Ep100  loss=0.3807  val AUC/AP=0.9638/0.9646  test AUC/AP=0.9639/0.9651
Ep110  loss=0.3764  val AUC/AP=0.9636/0.9645  test AUC/AP=0.9637/0.9652
Ep120  loss=0.3724  val AUC/AP=0.9632/0.9644  test AUC/AP=0.9635/0.9651
Ep130  loss=0.3685  val AUC/AP=0.9628/0.9641  test AUC/AP=0.9631/0.9650
Ep140  loss=0.3648  val AUC/AP=0.9625/0.9639  test AUC/AP=0.9626

In [12]:
import pandas as pd

metrics_df = pd.DataFrame(metrics[0])

In [18]:
import plotly.graph_objects as go

def plot_link_pred_metrics_plotly(df):
    """Visualise link‑prediction scores with Plotly (two separate figures)."""
    # ── 1. validation metrics ───────────────────────────────────────────────
    fig_val = go.Figure()
    fig_val.add_scatter(x=df.index, y=df["GCN val AUC"],  mode="lines+markers",
                        name="GCN val AUC")
    fig_val.add_scatter(x=df.index, y=df["GCN val AP"],   mode="lines+markers",
                        name="GCN val AP")
    fig_val.add_scatter(x=df.index, y=df["GraphSAGE val AUC"], mode="lines+markers",
                        name="GraphSAGE val AUC")
    fig_val.add_scatter(x=df.index, y=df["GraphSAGE val AP"],  mode="lines+markers",
                        name="GraphSAGE val AP")
    fig_val.update_layout(title="Validation metrics across runs",
                          xaxis_title="Run / seed",
                          yaxis_title="Score")
    fig_val.show()

    # ── 2. test metrics ────────────────────────────────────────────────────
    fig_test = go.Figure()
    fig_test.add_scatter(x=df.index, y=df["GCN test AUC"],  mode="lines+markers",
                         name="GCN test AUC")
    fig_test.add_scatter(x=df.index, y=df["GCN test AP"],   mode="lines+markers",
                         name="GCN test AP")
    fig_test.add_scatter(x=df.index, y=df["GraphSAGE test AUC"], mode="lines+markers",
                         name="GraphSAGE test AUC")
    fig_test.add_scatter(x=df.index, y=df["GraphSAGE test AP"],  mode="lines+markers",
                         name="GraphSAGE test AP")
    fig_test.update_layout(title="Test metrics across runs",
                           xaxis_title="Run / seed",
                           yaxis_title="Score")
    fig_test.show()
plot_link_pred_metrics_plotly(metrics_df)


In [14]:
pd.DataFrame(metrics)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,20,21,22,23,24,25,26,27,28,29
0,"{'GCN val AUC': 0.954808414834996, 'GCN val AP...","{'GCN val AUC': 0.9654533706034828, 'GCN val A...","{'GCN val AUC': 0.9702886287972255, 'GCN val A...","{'GCN val AUC': 0.9734752342253513, 'GCN val A...","{'GCN val AUC': 0.9754467116704663, 'GCN val A...","{'GCN val AUC': 0.975918669912632, 'GCN val AP...","{'GCN val AUC': 0.9766092572751345, 'GCN val A...","{'GCN val AUC': 0.9768861679282683, 'GCN val A...","{'GCN val AUC': 0.9770335405545894, 'GCN val A...","{'GCN val AUC': 0.9771097249402411, 'GCN val A...",...,"{'GCN val AUC': 0.9770376630567336, 'GCN val A...","{'GCN val AUC': 0.976986507173338, 'GCN val AP...","{'GCN val AUC': 0.9769412552044372, 'GCN val A...","{'GCN val AUC': 0.9769573971209456, 'GCN val A...","{'GCN val AUC': 0.9769448658064807, 'GCN val A...","{'GCN val AUC': 0.9769217006205907, 'GCN val A...","{'GCN val AUC': 0.9769024804781437, 'GCN val A...","{'GCN val AUC': 0.9769125273707864, 'GCN val A...","{'GCN val AUC': 0.9768698963304012, 'GCN val A...","{'GCN val AUC': 0.9768120857456967, 'GCN val A..."
1,0.963899,0.964853,,,,,,,,,...,,,,,,,,,,


In [15]:
metrics_df

Unnamed: 0,GCN val AUC,GCN val AP,GCN test AUC,GCN test AP,GraphSAGE val AUC,GraphSAGE val AP,GraphSAGE test AUC,GraphSAGE test AP
0,0.954808,0.958707,0.955522,0.959975,0.82267,0.812806,0.819272,0.812674
1,0.965453,0.96955,0.965556,0.969705,0.925506,0.923788,0.924932,0.924496
2,0.970289,0.974007,0.970187,0.973772,0.948341,0.948247,0.946167,0.945925
3,0.973475,0.97665,0.972697,0.976053,0.9568,0.957038,0.954656,0.954529
4,0.975447,0.978032,0.974635,0.977686,0.961195,0.961395,0.959726,0.959642
5,0.975919,0.978519,0.975343,0.978268,0.962882,0.96286,0.961871,0.961802
6,0.976609,0.979111,0.975975,0.978771,0.963655,0.963842,0.963039,0.963314
7,0.976886,0.979422,0.976343,0.979085,0.963984,0.964303,0.963641,0.964231
8,0.977034,0.979604,0.976481,0.979255,0.964017,0.964557,0.963899,0.964853
9,0.97711,0.979707,0.976594,0.979358,0.963821,0.964561,0.963889,0.965118
