In [1]:
!pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
!pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
!pip install -q torch-cluster -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
!pip install -q torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
!pip install -q torch-geometric

[K     |████████████████████████████████| 2.6MB 7.4MB/s 
[K     |████████████████████████████████| 1.4MB 8.1MB/s 
[K     |████████████████████████████████| 931kB 7.2MB/s 
[K     |████████████████████████████████| 389kB 309kB/s 
[K     |████████████████████████████████| 225kB 8.4MB/s 
[K     |████████████████████████████████| 235kB 48.1MB/s 
[K     |████████████████████████████████| 51kB 8.6MB/s 
[?25h  Building wheel for torch-geometric (setup.py) ... [?25l[?25hdone


In [17]:
import torch
import torch.nn.functional as F
from torch.nn import ReLU, Module
from torch_geometric.nn import GCNConv, Sequential
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import train_test_split_edges, negative_sampling
from torch_geometric.transforms import NormalizeFeatures
from sklearn.metrics import roc_auc_score


class Net(Module):
    def __init__(self, channels):
        super().__init__()
        convs = []
        for i in range(len(channels) - 1):
            convs.append((
                GCNConv(channels[i], channels[i + 1]),
                "x, edge_index -> x"
            ))
            convs.append(ReLU())
        convs = convs[:-1]
        self.convs = Sequential("x, edge_index", convs)

    def encode(self, x, edge_index):
        return self.convs(x, edge_index)

    def decode(self, x, pos_edge_index, neg_edge_index):
        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
        return (x[edge_index[0]] * x[edge_index[1]]).sum(dim=-1)

    def decode_all(self, x):
        prob_adj = x @ x.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()

In [15]:
def get_link_labels(pos_edge_index, neg_edge_index):
    num_links = pos_edge_index.size(1) + neg_edge_index.size(1)
    link_labels = torch.zeros(num_links, dtype=torch.float)
    link_labels[:pos_edge_index.size(1)] = 1.
    return link_labels

def train(data, model, optimizer):
    model.train()

    neg_edge_index = negative_sampling(
        edge_index=data.train_pos_edge_index,
        num_nodes=data.num_nodes,
        num_neg_samples=data.train_pos_edge_index.size(1)
    )

    optimizer.zero_grad()
    x = model.encode(data.x, data.train_pos_edge_index)
    link_logits = model.decode(x, data.train_pos_edge_index, neg_edge_index)
    link_labels = get_link_labels(data.train_pos_edge_index, neg_edge_index).to(data.x.device)
    loss = F.binary_cross_entropy_with_logits(link_logits, link_labels)
    loss.backward()
    optimizer.step()

    return loss

@torch.no_grad()
def test(data, model):
    model.eval()

    x = model.encode(data.x, data.train_pos_edge_index)

    results = []
    for prefix in ["val", "test"]:
        pos_edge_index = data[f"{prefix}_pos_edge_index"]
        neg_edge_index = data[f"{prefix}_neg_edge_index"]
        link_logits = model.decode(x, pos_edge_index, neg_edge_index)
        link_probs = link_logits.sigmoid()
        link_labels = get_link_labels(pos_edge_index, neg_edge_index)
        results.append(roc_auc_score(link_labels.cpu(), link_probs.cpu()))

    return results

In [18]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = Planetoid("data", "Cora", transform=NormalizeFeatures())
data = dataset[0]

data.train_mask = data.val_mask = data.test_mask = data.y = None
data = train_test_split_edges(data)
data.to(device)

model = Net([dataset.num_features, 128, 64]).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

best_val_auc = test_auc = 0
for epoch in range(100):
    loss = train(data, model, optimizer)
    val_auc, tmp_test_auc = test(data, model)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        test_auc = tmp_test_auc
    print(f"Epoch {epoch + 1} Loss {loss:.4f}\nVal Auc {val_auc} Test Auc {test_auc}")

x = model.encode(data.x, data.train_pos_edge_index)
final_edge_index = model.decode_all(x)

Epoch 1 Loss 0.6930
Val Auc 0.6925212161517442 Test Auc 0.6961066363253388
Epoch 2 Loss 0.6807
Val Auc 0.6894707166505225 Test Auc 0.6961066363253388
Epoch 3 Loss 0.7151
Val Auc 0.6897887782098917 Test Auc 0.6961066363253388
Epoch 4 Loss 0.6766
Val Auc 0.7045208113461232 Test Auc 0.7155752550147806
Epoch 5 Loss 0.6848
Val Auc 0.7396593849846029 Test Auc 0.760855006139078
Epoch 6 Loss 0.6891
Val Auc 0.7714004828752766 Test Auc 0.7903909926583108
Epoch 7 Loss 0.6906
Val Auc 0.7354884413537857 Test Auc 0.7903909926583108
Epoch 8 Loss 0.6906
Val Auc 0.7075279388165218 Test Auc 0.7903909926583108
Epoch 9 Loss 0.6895
Val Auc 0.697407798291142 Test Auc 0.7903909926583108
Epoch 10 Loss 0.6870
Val Auc 0.6940103225433361 Test Auc 0.7903909926583108
Epoch 11 Loss 0.6830
Val Auc 0.6925645881825673 Test Auc 0.7903909926583108
Epoch 12 Loss 0.6798
Val Auc 0.7025401552718704 Test Auc 0.7903909926583108
Epoch 13 Loss 0.6793
Val Auc 0.7138024259422573 Test Auc 0.7903909926583108
Epoch 14 Loss 0.6742
Va

In [19]:
(x, x.shape), (final_edge_index, final_edge_index.shape), best_val_auc

((tensor([[ 0.1374, -0.1213, -0.0143,  ..., -0.2595,  0.0808, -0.1902],
          [ 0.0888, -0.0660,  0.3673,  ..., -0.0973, -0.1418, -0.1412],
          [ 0.1031, -0.0878,  0.2762,  ..., -0.2299, -0.1008, -0.2051],
          ...,
          [ 0.0132, -0.0505, -0.3180,  ..., -0.0432, -0.1375, -0.1339],
          [ 0.1229, -0.1517,  0.0348,  ..., -0.3297, -0.1083, -0.2379],
          [ 0.1277, -0.1434, -0.0026,  ..., -0.3051, -0.0228, -0.2208]],
         device='cuda:0', grad_fn=<AddBackward0>), torch.Size([2708, 64])),
 (tensor([[   0,    0,    0,  ..., 2707, 2707, 2707],
          [   0,    1,    2,  ..., 2705, 2706, 2707]], device='cuda:0'),
  torch.Size([2, 3394548])),
 0.9039309517269298)