<a href="https://colab.research.google.com/github/shu65/pytorch_geometric_examples/blob/main/PyTorch_Geometric_Variational_Graph_AutoEncoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.11.0+cu113.html

Looking in links: https://data.pyg.org/whl/torch-1.11.0+cu113.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-1.11.0%2Bcu113/torch_scatter-2.0.9-cp37-cp37m-linux_x86_64.whl (7.9 MB)
[K     |████████████████████████████████| 7.9 MB 2.8 MB/s 
[?25hCollecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-1.11.0%2Bcu113/torch_sparse-0.6.13-cp37-cp37m-linux_x86_64.whl (3.5 MB)
[K     |████████████████████████████████| 3.5 MB 46.8 MB/s 
[?25hCollecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-1.11.0%2Bcu113/torch_cluster-1.6.0-cp37-cp37m-linux_x86_64.whl (2.5 MB)
[K     |████████████████████████████████| 2.5 MB 56.7 MB/s 
[?25hCollecting torch-spline-conv
  Downloading https://data.pyg.org/whl/torch-1.11.0%2Bcu113/torch_spline_conv-1.2.1-cp37-cp37m-linux_x86_64.whl (750 kB)
[K     |████████████████████████████████| 750 kB 57.9 MB/s 
[?25hCollecting torch-geometric
  Downloading torch_geometric-2.0.4.tar.gz (407 kB)
[K     |███

In [2]:
!pip list | grep torch

torch                         1.11.0+cu113
torch-cluster                 1.6.0
torch-geometric               2.0.4
torch-scatter                 2.0.9
torch-sparse                  0.6.13
torch-spline-conv             1.2.1
torchaudio                    0.11.0+cu113
torchsummary                  1.5.1
torchtext                     0.12.0
torchvision                   0.12.0+cu113


In [3]:
import argparse
import os

import torch

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import VGAE, GCNConv

device = 'cpu'
transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToDevice(device),
    T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,
                      split_labels=True, add_negative_train_samples=False),
])
path = os.path.join("tmp", "data", "Planetoid")
dataset = Planetoid(path, "PubMed", transform=transform)
train_data, val_data, test_data = dataset[0]
train_data, val_data, test_data

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.test.index
Processing...
Done!


(Data(x=[19717, 500], edge_index=[2, 75352], y=[19717], train_mask=[19717], val_mask=[19717], test_mask=[19717], pos_edge_label=[37676], pos_edge_label_index=[2, 37676]),
 Data(x=[19717, 500], edge_index=[2, 75352], y=[19717], train_mask=[19717], val_mask=[19717], test_mask=[19717], pos_edge_label=[2216], pos_edge_label_index=[2, 2216], neg_edge_label=[2216], neg_edge_label_index=[2, 2216]),
 Data(x=[19717, 500], edge_index=[2, 79784], y=[19717], train_mask=[19717], val_mask=[19717], test_mask=[19717], pos_edge_label=[4432], pos_edge_label_index=[2, 4432], neg_edge_label=[4432], neg_edge_label_index=[2, 4432]))

In [4]:
train_data.x

tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0554, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0114, 0.0047,  ..., 0.0000, 0.0000, 0.0000],
        [0.0531, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0145, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])

In [5]:
train_data.y

tensor([1, 1, 0,  ..., 2, 0, 2])

In [6]:
class VariationalGCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 2 * out_channels)
        self.conv_mu = GCNConv(2 * out_channels, out_channels)
        self.conv_logstd = GCNConv(2 * out_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)


in_channels = dataset.num_features
out_channels = 16
model = VGAE(VariationalGCNEncoder(in_channels, out_channels))
model.encoder, model.decoder

(VariationalGCNEncoder(
   (conv1): GCNConv(500, 32)
   (conv_mu): GCNConv(32, 16)
   (conv_logstd): GCNConv(32, 16)
 ), InnerProductDecoder())

In [7]:
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(0, 400):
    model.train()
    optimizer.zero_grad()
    z = model.encode(train_data.x, train_data.edge_index)
    recon_loss = model.recon_loss(z, train_data.pos_edge_label_index)
    kl_loss = (1 / train_data.num_nodes) * model.kl_loss()
    loss = recon_loss + kl_loss
    loss.backward()
    optimizer.step()
    
    model.eval()
    z = model.encode(test_data.x, test_data.edge_index)
    auc, ap = model.test(z, test_data.pos_edge_label_index, test_data.neg_edge_label_index)

    print(f'Epoch: {epoch:03d}, AUC: {auc:.4f}, AP: {ap:.4f}')

Epoch: 000, AUC: 0.8592, AP: 0.8450
Epoch: 001, AUC: 0.8783, AP: 0.8559
Epoch: 002, AUC: 0.8833, AP: 0.8590
Epoch: 003, AUC: 0.8854, AP: 0.8605
Epoch: 004, AUC: 0.8866, AP: 0.8613
Epoch: 005, AUC: 0.8873, AP: 0.8618
Epoch: 006, AUC: 0.8879, AP: 0.8623
Epoch: 007, AUC: 0.8883, AP: 0.8626
Epoch: 008, AUC: 0.8887, AP: 0.8629
Epoch: 009, AUC: 0.8890, AP: 0.8633
Epoch: 010, AUC: 0.8893, AP: 0.8636
Epoch: 011, AUC: 0.8897, AP: 0.8640
Epoch: 012, AUC: 0.8900, AP: 0.8643
Epoch: 013, AUC: 0.8903, AP: 0.8646
Epoch: 014, AUC: 0.8904, AP: 0.8649
Epoch: 015, AUC: 0.8904, AP: 0.8650
Epoch: 016, AUC: 0.8901, AP: 0.8650
Epoch: 017, AUC: 0.8897, AP: 0.8648
Epoch: 018, AUC: 0.8894, AP: 0.8647
Epoch: 019, AUC: 0.8896, AP: 0.8648
Epoch: 020, AUC: 0.8901, AP: 0.8651
Epoch: 021, AUC: 0.8906, AP: 0.8654
Epoch: 022, AUC: 0.8909, AP: 0.8656
Epoch: 023, AUC: 0.8911, AP: 0.8658
Epoch: 024, AUC: 0.8912, AP: 0.8659
Epoch: 025, AUC: 0.8911, AP: 0.8659
Epoch: 026, AUC: 0.8909, AP: 0.8659
Epoch: 027, AUC: 0.8906, AP: