In [63]:
import numpy as np

import torch
import torch.nn.functional as F

from torch_geometric.utils import train_test_split_edges, negative_sampling
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.nn import SAGEConv

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score

from tqdm.auto import tqdm

In [64]:
# Load saved graph
data = torch.load('data/embedded-ego-networks/facebook/ego_network_data.pt')
print(data)  

for key in list(data.keys()):                      # iterate over attribute names
    val = data[key]
    if (isinstance(val, list) or isinstance(val, str)) and \
       val and len(val) == data.edge_index.size(1):
        print(f"Removing non-tensor edge attr → {key}")
        delattr(data, key)

print(data)

# --- z-score scalar columns, keep SBERT as-is ---------------------------
scalars = np.stack([data.x[:,i].cpu().numpy() for i,name in
                    enumerate(['degree','indegree','outdegree',
                               'betweenness','pagerank','clustering_coeff',
                               'eigenvector'])], axis=1)
scaled   = StandardScaler().fit_transform(scalars).astype('float32')
data.x   = torch.cat([torch.tensor(scaled),                 # (N,7)
                      data.x[:,7:]                          # (N,768)
                     ],  dim=1)

Data(edge_index=[2, 5038], name=[348], id=[348], referral=[5038], edge__igraph_index=[5038], x=[348, 775], edge_attr=[5038, 768], y=[348])
Removing non-tensor edge attr → referral
Data(edge_index=[2, 5038], name=[348], id=[348], edge__igraph_index=[5038], x=[348, 775], edge_attr=[5038, 768], y=[348])


In [65]:
# Edge split (80/10/10)
split = RandomLinkSplit(num_val  = 0.1,
                        num_test = 0.1,
                        add_negative_train_samples = True,
                        neg_sampling_ratio = 1.0,
                        is_undirected = False)      # keep direction
train_data, val_data, test_data = split(data)       # PyG transform
print(train_data)

Data(edge_index=[2, 4032], name=[348], id=[348], edge__igraph_index=[4032], x=[348, 775], edge_attr=[4032, 768], y=[348], edge_label=[8064], edge_label_index=[2, 8064])


In [66]:
# GraphSAGE Model
class GraphSAGE(torch.nn.Module):
    def __init__(self, in_dim, hid=256):
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hid)
        self.conv2 = SAGEConv(hid, hid)
    def encode(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)            # (N, hid)
    def decode(self, z, edge_index):                # no sigmoid here
        src, dst = edge_index
        return (z[src] * z[dst]).sum(dim=-1)        # raw dot-product
    def forward(self, data):
        z = self.encode(data.x, data.edge_index)
        return self.decode(z, data.edge_label_index)
    
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model  = GraphSAGE(train_data.num_features).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion  = torch.nn.BCEWithLogitsLoss()

In [67]:
# Training Helpers
def train_epoch(data):
    model.train(); optimizer.zero_grad()
    out  = model(data)
    loss = criterion(out, data.edge_label.float())
    loss.backward(); optimizer.step()
    return loss.item()

@torch.no_grad()
def evaluate(data, k=50):
    model.eval()
    logit = model(data)
    prob  = torch.sigmoid(logit).cpu()
    y     = data.edge_label.cpu()
    auc   = roc_auc_score(y, prob)
    _, idx = prob.topk(k)
    hits = y[idx].sum().item() / k
    return auc, hits

In [71]:
# Training & Evaluation Loops
train_data = train_data.to(device)
val_data   = val_data.to(device)
test_data  = test_data.to(device)

EPOCHS = 1000
for epoch in tqdm(range(EPOCHS)):
    loss = train_epoch(train_data)
    if epoch % 50 == 0:
        auc, hits = evaluate(val_data)
        print(f'E{epoch:02d}  loss {loss:.4f}  val-AUC {auc:.4f}  hits@50 {hits:.3f}')

  1%|          | 6/1000 [00:00<00:43, 23.07it/s]

E00  loss 13.5050  val-AUC 0.7359  hits@50 0.640


  6%|▌         | 56/1000 [00:01<00:28, 33.17it/s]

E50  loss 8.4324  val-AUC 0.7579  hits@50 0.640


 10%|█         | 104/1000 [00:03<00:27, 32.79it/s]

E100  loss 5.8491  val-AUC 0.7679  hits@50 0.780


 16%|█▌        | 155/1000 [00:04<00:26, 31.36it/s]

E150  loss 4.1435  val-AUC 0.7587  hits@50 0.820


 21%|██        | 207/1000 [00:06<00:24, 32.41it/s]

E200  loss 2.9435  val-AUC 0.7656  hits@50 0.860


 26%|██▌       | 255/1000 [00:07<00:22, 33.23it/s]

E250  loss 2.1717  val-AUC 0.7947  hits@50 0.900


 31%|███       | 307/1000 [00:09<00:20, 33.26it/s]

E300  loss 1.7671  val-AUC 0.8003  hits@50 0.840


 36%|███▌      | 355/1000 [00:10<00:19, 33.29it/s]

E350  loss 1.4822  val-AUC 0.8119  hits@50 0.780


 41%|████      | 407/1000 [00:12<00:18, 32.07it/s]

E400  loss 1.2727  val-AUC 0.8182  hits@50 0.720


 46%|████▌     | 455/1000 [00:13<00:16, 33.55it/s]

E450  loss 1.1147  val-AUC 0.8219  hits@50 0.720


 51%|█████     | 507/1000 [00:15<00:14, 33.45it/s]

E500  loss 0.9921  val-AUC 0.8263  hits@50 0.800


 56%|█████▌    | 555/1000 [00:16<00:13, 33.62it/s]

E550  loss 0.8946  val-AUC 0.8299  hits@50 0.820


 61%|██████    | 607/1000 [00:18<00:11, 33.43it/s]

E600  loss 0.8162  val-AUC 0.8336  hits@50 0.860


 66%|██████▌   | 655/1000 [00:19<00:10, 33.45it/s]

E650  loss 0.7552  val-AUC 0.8363  hits@50 0.840


 71%|███████   | 707/1000 [00:21<00:08, 33.40it/s]

E700  loss 0.6985  val-AUC 0.8395  hits@50 0.840


 76%|███████▌  | 755/1000 [00:22<00:07, 33.66it/s]

E750  loss 0.6513  val-AUC 0.8445  hits@50 0.840


 81%|████████  | 806/1000 [00:24<00:05, 32.93it/s]

E800  loss 0.6135  val-AUC 0.8479  hits@50 0.840


 85%|████████▌ | 854/1000 [00:25<00:04, 32.45it/s]

E850  loss 0.5828  val-AUC 0.8484  hits@50 0.840


 91%|█████████ | 906/1000 [00:27<00:02, 33.33it/s]

E900  loss 0.5531  val-AUC 0.8530  hits@50 0.820


 95%|█████████▌| 954/1000 [00:28<00:01, 32.98it/s]

E950  loss 0.5282  val-AUC 0.8553  hits@50 0.820


100%|██████████| 1000/1000 [00:30<00:00, 33.02it/s]


In [72]:
test_auc, test_hits = evaluate(test_data)
print(f'\nTEST  AUC={test_auc:.4f}  Hits@50={test_hits:.3f}')


TEST  AUC=0.8722  Hits@50=0.880
