In [1]:
import torch
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.nn import Linear, Sequential, BatchNorm1d, ReLU, Dropout

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINConv, global_add_pool

dataset = TUDataset(root='data/TUDataset', name='MUTAG').shuffle()

In [2]:
# Create training, validation, and test sets
train_dataset = dataset[:int(len(dataset)*0.8)]
val_dataset   = dataset[int(len(dataset)*0.8):int(len(dataset)*0.9)]
test_dataset  = dataset[int(len(dataset)*0.9):]

# Create mini-batches
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=64, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=64, shuffle=True)

In [3]:
train_dataset[0]

Data(edge_index=[2, 46], x=[20, 7], edge_attr=[46, 4], y=[1])

In [4]:
class GIN(torch.nn.Module):
    """GIN"""
    def __init__(self, dim_h):
        super(GIN, self).__init__()
        self.conv1 = GINConv(
            Sequential(Linear(dataset.num_node_features, dim_h),
                       BatchNorm1d(dim_h), ReLU(),
                       Linear(dim_h, dim_h), ReLU()))
        self.conv2 = GINConv(
            Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(),
                       Linear(dim_h, dim_h), ReLU()))
        self.conv3 = GINConv(
            Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(),
                       Linear(dim_h, dim_h), ReLU()))
        self.lin1 = Linear(dim_h*3, dim_h*3)
        self.lin2 = Linear(dim_h*3, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        h1 = self.conv1(x, edge_index)
        h2 = self.conv2(h1, edge_index)
        h3 = self.conv3(h2, edge_index)

        h1 = global_add_pool(h1, batch)
        h2 = global_add_pool(h2, batch)
        h3 = global_add_pool(h3, batch)

        h = torch.cat((h1, h2, h3), dim=1)

        h = self.lin1(h)
        h = h.relu()
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.lin2(h)
        
        return F.log_softmax(h, dim=1)


In [5]:
model = GIN(dim_h=32)

@torch.no_grad()
def test(model, loader):
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()
    loss = 0
    acc = 0

    for data in loader:
        out = model(data.x, data.edge_index, data.batch)
        loss += criterion(out, data.y) / len(loader)
        acc += accuracy(out.argmax(dim=1), data.y) / len(loader)

    return loss, acc

def accuracy(pred_y, y):
    return ((pred_y == y).sum() / len(y)).item()

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
epochs = 200

model.train()
for epoch in range(epochs+1):
    total_loss = 0
    acc = 0
    val_loss = 0
    val_acc = 0

    # Train on batches
    for data in train_loader:
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out, data.y)
        total_loss += loss / len(train_loader)
        acc += accuracy(out.argmax(dim=1), data.y) / len(train_loader)
        loss.backward()
        optimizer.step()

        # Validation
        val_loss, val_acc = test(model, val_loader)

    # Print metrics every 20 epochs
    if(epoch % 20 == 0):
        print(f'Epoch {epoch:>3} | Train Loss: {total_loss:.2f} | Train Acc: {acc*100:>5.2f}% | Val Loss: {val_loss:.2f} | Val Acc: {val_acc*100:.2f}%')

test_loss, test_acc = test(model, test_loader)
print(f'Test Loss: {test_loss:.2f} | Test Acc: {test_acc*100:.2f}%')

Epoch   0 | Train Loss: 1.14 | Train Acc: 42.52% | Val Loss: 0.66 | Val Acc: 52.63%
Epoch  20 | Train Loss: 0.44 | Train Acc: 78.31% | Val Loss: 0.44 | Val Acc: 73.68%
Epoch  40 | Train Loss: 0.33 | Train Acc: 83.43% | Val Loss: 0.21 | Val Acc: 94.74%
Epoch  60 | Train Loss: 0.35 | Train Acc: 80.40% | Val Loss: 0.31 | Val Acc: 89.47%
Epoch  80 | Train Loss: 0.30 | Train Acc: 82.58% | Val Loss: 0.19 | Val Acc: 94.74%
Epoch 100 | Train Loss: 0.21 | Train Acc: 92.19% | Val Loss: 0.22 | Val Acc: 94.74%
Epoch 120 | Train Loss: 0.20 | Train Acc: 88.21% | Val Loss: 0.27 | Val Acc: 84.21%
Epoch 140 | Train Loss: 0.18 | Train Acc: 90.62% | Val Loss: 0.23 | Val Acc: 94.74%
Epoch 160 | Train Loss: 0.24 | Train Acc: 87.64% | Val Loss: 0.38 | Val Acc: 73.68%
Epoch 180 | Train Loss: 0.16 | Train Acc: 93.75% | Val Loss: 0.25 | Val Acc: 94.74%
Epoch 200 | Train Loss: 0.16 | Train Acc: 93.28% | Val Loss: 0.33 | Val Acc: 89.47%
Test Loss: 0.56 | Test Acc: 73.68%


In [13]:
from torch_geometric.explain import GNNExplainer, Explainer