Adapted from the PyTorch Geometric graph auto-encoder example: "https://github.com/pyg-team/pytorch_geometric/blob/master/examples/autoencoder.py"

In [2]:
import torch

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GAE, VGAE, GCNConv
import os

In [3]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [4]:
transform = T.Compose([
    # Row-normalizes the node features so that each row sums to one.
    T.NormalizeFeatures(),
    # Move every tensor of the graph to the selected device before splitting.
    T.ToDevice(device),
    # Hold out 5% of the edges for validation and 10% for testing.
    T.RandomLinkSplit(num_val=.05, num_test=.1, is_undirected=True,
                      # split_labels=True stores positive and negative edges in
                      # separate attributes (pos_edge_label_index /
                      # neg_edge_label_index) instead of one labelled tensor.
                      split_labels=True,
                      # add_negative_train_samples=False: do not pre-sample
                      # negative (non-existent) edges for the training split;
                      # the val/test splits still get negative edges, which
                      # test() reads via neg_edge_label_index.
                      add_negative_train_samples=False)
])

In [5]:
"""Preprocessed data """
dataset_dir = os.path.join(os.path.pardir, 'data/preprocessed/')
print(dataset_dir)
file_path=os.path.join(dataset_dir, 'meshgraphnets_miniset30traj5ts_vis.pt')
print(file_path)
dataset_full_timesteps = torch.load(file_path)
dataset = torch.load(file_path)[:1]
data2 = torch.load(file_path)[1:2]
print(len(dataset_full_timesteps)/5)

..\data/preprocessed/
..\data/preprocessed/meshgraphnets_miniset30traj5ts_vis.pt
30.0


In [6]:
# The upstream PyG example split Planetoid("CiteSeer"); here the same pipeline
# is applied to the MeshGraphNets slice loaded above. Entry 0 of the result is
# the (train, val, test) triple produced by RandomLinkSplit.
pls = transform(dataset)
split_triple = pls[0]
(train_data, val_data, test_data) = split_triple
train_data

Data(x=[1923, 11], edge_index=[2, 9412], edge_attr=[9412, 3], y=[1923, 2], p=[1923, 1], pos_edge_label=[4706], pos_edge_label_index=[2, 4706])

In [9]:
class GCNEncoder(torch.nn.Module):
    """Two-layer GCN encoder for the plain (non-variational) graph auto-encoder.

    Args:
        in_channels: width of each input node-feature vector.
        out_channels: width of each output node embedding; the hidden layer
            is twice this width.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden_channels = 2 * out_channels
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Embed the nodes: GCNConv -> ReLU -> GCNConv."""
        hidden = torch.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

class VariationalGCNEncoder(torch.nn.Module):
    """GCN encoder for the variational graph auto-encoder (VGAE).

    It shares the first layer with the layout of ``GCNEncoder``. Instead of a
    single output layer, it ends in two parallel GCN heads. One produces the
    mean (mu) and the other the log standard deviation (logstd) that VGAE uses
    to sample node embeddings.

    Args:
        in_channels: width of each input node-feature vector.
        out_channels: width of the latent node embedding; the shared hidden
            layer is twice this width.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 2 * out_channels)
        self.conv_mu = GCNConv(2 * out_channels, out_channels)
        self.conv_logstd = GCNConv(2 * out_channels, out_channels)

    def forward(self, x, edge_index):
        """Return ``(mu, logstd)`` computed from a shared hidden representation."""
        x = self.conv1(x, edge_index).relu()
        # Returning both heads, rather than a single embedding, is what
        # distinguishes this encoder from GCNEncoder.
        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)

In [10]:
# Take the input width from the split training graph, which is a `Data` object.
# `dataset` is whatever torch.load returned, sliced; depending on its type it
# may not expose `num_features`. The transforms do not change the feature width.
in_channels, out_channels = train_data.num_features, 16
variation = "GAE"  # "GAE" or "VGAE"; also read by train()
if variation == "GAE":
    model = GAE(GCNEncoder(in_channels, out_channels))
elif variation == "VGAE":
    model = VGAE(VariationalGCNEncoder(in_channels, out_channels))
else:
    raise ValueError(
        f"Unknown model variation {variation!r}; expected 'GAE' or 'VGAE'")

model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [11]:
def train():
    """Run one optimisation step of the auto-encoder on ``train_data``.

    Reads the notebook globals ``model``, ``optimizer``, ``train_data`` and
    ``variation``. Returns the loss of this step as a plain Python float.
    """
    model.train()
    optimizer.zero_grad()
    embeddings = model.encode(train_data.x, train_data.edge_index)
    # Only positive training edges are passed. When no negative edges are
    # given, PyG's recon_loss samples them itself.
    loss = model.recon_loss(embeddings, train_data.pos_edge_label_index)

    if variation == "VGAE":
        # The KL regulariser exists only for the variational model; it is
        # scaled by 1 / num_nodes.
        kl_term = (1 / train_data.num_nodes) * model.kl_loss()
        loss = loss + kl_term

    loss.backward()
    optimizer.step()

    # .item() turns the 0-dim loss tensor into a Python float, exactly like
    # float(loss), so the caller holds no reference to the autograd graph.
    return loss.item()

@torch.no_grad()
def test(data):
    """Score link prediction on a held-out split.

    Encodes ``data`` with the global ``model`` in eval mode. It then scores the
    split's positive and negative edges with ``model.test``, whose result is
    unpacked by the caller as ``(auc, ap)``.
    """
    model.eval()
    embeddings = model.encode(data.x, data.edge_index)
    positives = data.pos_edge_label_index
    negatives = data.neg_edge_label_index
    return model.test(embeddings, positives, negatives)

In [12]:
import time

# Train for a fixed number of epochs, reporting loss and link-prediction
# AUC / AP after every step.
# NOTE(review): val_data is never used, so these per-epoch metrics come from
# the test split. Consider monitoring val_data and evaluating test_data once
# at the end.
times = []
epochs = 100
for epoch in range(1, epochs + 1):
    # perf_counter is the monotonic, high-resolution clock meant for intervals.
    start = time.perf_counter()
    loss = train()
    auc, ap = test(test_data)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, AUC: {auc:.4f}, AP: {ap:.4f}')
    times.append(time.perf_counter() - start)
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")



Epoch: 001, AUC: 0.6576, AP: 0.6994
Epoch: 002, AUC: 0.6525, AP: 0.6956
Epoch: 003, AUC: 0.6515, AP: 0.6952
Epoch: 004, AUC: 0.6535, AP: 0.6983
Epoch: 005, AUC: 0.6549, AP: 0.7052
Epoch: 006, AUC: 0.6577, AP: 0.7144
Epoch: 007, AUC: 0.6613, AP: 0.7224
Epoch: 008, AUC: 0.6622, AP: 0.7253
Epoch: 009, AUC: 0.6624, AP: 0.7255
Epoch: 010, AUC: 0.6618, AP: 0.7242
Epoch: 011, AUC: 0.6635, AP: 0.7243
Epoch: 012, AUC: 0.6731, AP: 0.7280
Epoch: 013, AUC: 0.7053, AP: 0.7424
Epoch: 014, AUC: 0.7548, AP: 0.7716
Epoch: 015, AUC: 0.7820, AP: 0.7919
Epoch: 016, AUC: 0.7910, AP: 0.7984
Epoch: 017, AUC: 0.7922, AP: 0.7990
Epoch: 018, AUC: 0.7919, AP: 0.7977
Epoch: 019, AUC: 0.7963, AP: 0.8007
Epoch: 020, AUC: 0.7967, AP: 0.7999
Epoch: 021, AUC: 0.7949, AP: 0.7971
Epoch: 022, AUC: 0.8001, AP: 0.8007
Epoch: 023, AUC: 0.8053, AP: 0.8046
Epoch: 024, AUC: 0.8086, AP: 0.8081
Epoch: 025, AUC: 0.8111, AP: 0.8120
Epoch: 026, AUC: 0.8195, AP: 0.8203
Epoch: 027, AUC: 0.8300, AP: 0.8294
Epoch: 028, AUC: 0.8400, AP: