Adapted from the PyTorch Geometric example at https://github.com/pyg-team/pytorch_geometric/blob/master/examples/autoencoder.py

In [21]:
import torch

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GAE, VGAE, GCNConv

In [22]:
# Run on the GPU whenever PyTorch can see one; fall back to the CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [23]:
# Preprocessing applied each time the dataset is indexed: normalize features,
# move tensors to `device`, then split the edges into train/val/test sets.
transform = T.Compose(
    [
        # Row-normalize node features so each row sums to one.
        T.NormalizeFeatures(),
        T.ToDevice(device),
        T.RandomLinkSplit(
            num_val=0.05,  # fraction of edges held out for validation
            num_test=0.1,  # fraction of edges held out for testing
            is_undirected=True,
            # With split_labels=True, positive and negative supervision edges
            # are stored in separate attributes (e.g. `pos_edge_label_index`
            # and `neg_edge_label_index`) instead of one combined tensor.
            split_labels=True,
            # Do not pre-sample negative edges (node pairs that are NOT
            # connected in the graph) for the training split. train() only
            # passes positive edges to `recon_loss`.
            add_negative_train_samples=False,
        ),
    ]
)

In [24]:
# Download/cache CiteSeer under ../data/Planetoid, relative to the working
# directory. The previous root, "\..", contained the invalid escape
# sequence `\.` (a SyntaxWarning on Python 3.12+) and resolved differently
# per OS (drive root on Windows, a directory literally named `\..` on POSIX).
dataset = Planetoid("../data/Planetoid", "CiteSeer", transform=transform)

# Indexing applies `transform`; RandomLinkSplit makes dataset[0] yield the
# (train, val, test) splits.
train_data, val_data, test_data = dataset[0]
train_data

Data(x=[3327, 3703], edge_index=[2, 7740], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327], pos_edge_label=[3870], pos_edge_label_index=[2, 3870])

In [25]:
class GCNEncoder(torch.nn.Module):
    """Two-layer GCN encoder used by the plain (non-variational) GAE.

    The hidden layer is ``2 * out_channels`` wide and followed by a ReLU.
    The second convolution produces the ``out_channels``-wide node embedding
    with no activation applied.

    Args:
        in_channels: Number of input features per node.
        out_channels: Width of the produced embedding.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden_width = 2 * out_channels
        # Registration order (conv1, then conv2) matches the original, so
        # parameter initialization and state_dict keys are unchanged.
        self.conv1 = GCNConv(in_channels, hidden_width)
        self.conv2 = GCNConv(hidden_width, out_channels)

    def forward(self, x, edge_index):
        """Encode node features ``x`` over the graph given by ``edge_index``."""
        hidden = self.conv1(x, edge_index).relu()
        embedding = self.conv2(hidden, edge_index)
        return embedding

class VariationalGCNEncoder(torch.nn.Module):
    """GCN encoder for the VGAE: one shared hidden layer and two output heads.

    GCNEncoder returns a single embedding. This encoder instead returns a
    pair ``(mu, logstd)``, each ``out_channels`` wide per node. VGAE uses
    them as the mean and log standard deviation of the latent Gaussian.
    train() adds the matching ``model.kl_loss()`` term for the VGAE variant.

    Args:
        in_channels: Number of input features per node.
        out_channels: Width of the latent embedding (and of each head).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Shared hidden layer, twice as wide as the latent space.
        self.conv1 = GCNConv(in_channels, 2 * out_channels)
        # Two parallel heads on top of the shared hidden representation.
        self.conv_mu = GCNConv(2 * out_channels, out_channels)
        self.conv_logstd = GCNConv(2 * out_channels, out_channels)

    def forward(self, x, edge_index):
        """Return ``(mu, logstd)`` for every node."""
        x = self.conv1(x, edge_index).relu()
        # Both heads read the same hidden features. Emitting two tensors
        # (mean and log-std) rather than one embedding is what makes this
        # encoder variational.
        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)

In [26]:
# Input width comes from the dataset; 16 is the latent embedding width.
in_channels, out_channels = dataset.num_features, 16

# Autoencoder flavour: "GAE" (GCNEncoder, single embedding) or
# "VGAE" (VariationalGCNEncoder, mu/logstd plus a KL term in train()).
variation = "GAE"
if variation == "GAE":
    model = GAE(GCNEncoder(in_channels, out_channels))
elif variation == "VGAE":
    model = VGAE(VariationalGCNEncoder(in_channels, out_channels))
else:
    # ValueError subclasses Exception, so existing `except Exception`
    # handlers still catch it; the message now names the offending value.
    raise ValueError(
        f"Unknown model variation {variation!r}; expected 'GAE' or 'VGAE'"
    )

model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [27]:
def train():
    """Run one full-graph optimization step on the training split.

    Uses the module-level ``model``, ``optimizer``, ``train_data`` and
    ``variation``.

    Returns:
        The training loss for this step as a plain Python ``float``.
    """
    model.train()
    optimizer.zero_grad()
    z = model.encode(train_data.x, train_data.edge_index)
    # Only positive edges are passed here; no negative training edges were
    # pre-sampled (add_negative_train_samples=False in the transform).
    loss = model.recon_loss(z, train_data.pos_edge_label_index)

    # Only the variational autoencoder has a KL term. It is scaled by
    # 1 / num_nodes, as in the upstream PyG example, presumably to keep it on
    # a per-node scale comparable to the reconstruction loss.
    if variation == "VGAE":
        loss = loss + (1 / train_data.num_nodes) * model.kl_loss()
    loss.backward()
    optimizer.step()

    # Convert the 0-dim tensor to a Python number so callers get a plain
    # float for printing/logging and do not keep the autograd graph (or
    # device memory) alive. `loss.item()` is an equivalent spelling.
    return float(loss)

@torch.no_grad()
def test(data):
    """Score link prediction on the held-out edges of ``data``.

    Encodes the nodes of ``data`` over its message-passing edges, then
    scores its positive and negative supervision edges with ``model.test``.
    The caller unpacks the result as ``(auc, ap)``.
    Gradients are disabled by the decorator.
    """
    model.eval()
    embeddings = model.encode(data.x, data.edge_index)
    positive_edges = data.pos_edge_label_index
    negative_edges = data.neg_edge_label_index
    return model.test(embeddings, positive_edges, negative_edges)

In [28]:
import time

# Train for a fixed number of epochs, timing each one and reporting
# link-prediction metrics after every optimization step.
# NOTE(review): metrics are computed on `test_data` every epoch and `val_data`
# is never used; use `val_data` here if these numbers drive model selection.
times = []
epochs = 100
for epoch in range(1, epochs + 1):
    start = time.time()
    loss = train()
    auc, ap = test(test_data)
    print(f'Epoch: {epoch:03d}, AUC: {auc:.4f}, AP: {ap:.4f}')
    times.append(time.time() - start)
# Fixed: the format string previously had a stray ')' before the unit.
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")



Epoch: 001, AUC: 0.6400, AP: 0.6731
Epoch: 002, AUC: 0.6309, AP: 0.6675
Epoch: 003, AUC: 0.6335, AP: 0.6700
Epoch: 004, AUC: 0.6399, AP: 0.6759
Epoch: 005, AUC: 0.6478, AP: 0.6860
Epoch: 006, AUC: 0.6515, AP: 0.6953
Epoch: 007, AUC: 0.6529, AP: 0.7023
Epoch: 008, AUC: 0.6531, AP: 0.7061
Epoch: 009, AUC: 0.6543, AP: 0.7082
Epoch: 010, AUC: 0.6550, AP: 0.7092
Epoch: 011, AUC: 0.6579, AP: 0.7113
Epoch: 012, AUC: 0.6704, AP: 0.7174
Epoch: 013, AUC: 0.7037, AP: 0.7322
Epoch: 014, AUC: 0.7420, AP: 0.7534
Epoch: 015, AUC: 0.7577, AP: 0.7646
Epoch: 016, AUC: 0.7642, AP: 0.7689
Epoch: 017, AUC: 0.7675, AP: 0.7708
Epoch: 018, AUC: 0.7701, AP: 0.7727
Epoch: 019, AUC: 0.7704, AP: 0.7727
Epoch: 020, AUC: 0.7686, AP: 0.7710
Epoch: 021, AUC: 0.7712, AP: 0.7724
Epoch: 022, AUC: 0.7763, AP: 0.7753
Epoch: 023, AUC: 0.7762, AP: 0.7753
Epoch: 024, AUC: 0.7747, AP: 0.7738
Epoch: 025, AUC: 0.7803, AP: 0.7776
Epoch: 026, AUC: 0.7832, AP: 0.7802
Epoch: 027, AUC: 0.7844, AP: 0.7816
Epoch: 028, AUC: 0.7894, AP: