# Intro to GNNs using PyTorch Geometric

Adapted from the PyTorch Geometric documentation (https://pytorch-geometric.readthedocs.io/)

For this notebook, we'll work with PyTorch Geometric, a package designed for building GNNs. You can install the packages below into the environment you already made or into a new one; it's up to you!

<a name='section_0'></a>
<h2 style="border:1px; border-style:solid; padding: 0.25em; color: #FFFFFF; background-color: #1f77b4">0. Imports & installs</h2>

In [None]:
%pip install torch_geometric deepsnap torch-sparse torch-scatter networkx

In [None]:
import pandas as pd
import numpy as np
import glob
import time
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
from torch_geometric.data import Data
from torch_geometric.utils.convert import to_networkx
import networkx as nx # for visualizing graphs
from sklearn.preprocessing import StandardScaler
from IPython.display import Javascript

import torch_geometric.nn as pyg_nn
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
import torch.optim as optim
from copy import deepcopy
from torch_geometric.nn import GINConv
from sklearn.metrics import *
from torch.nn import Sequential, Linear, ReLU
from deepsnap.dataset import GraphDataset
from deepsnap.batch import Batch

<h2 style="border:1px; border-style:solid; padding: 0.25em; color: #FFFFFF; background-color: #1f77b4">1. Graph visualization</h2>

Let's start very simply: a graph with 3 nodes (0, 1, and 2) and two edges connecting nodes 0&harr;1 and 1&harr;2.
Note that `edge_index` is a `[2, num_edges]` tensor: its first row holds the starting node of each edge and its second row the ending node. An undirected edge is therefore stored twice, once in each direction.
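To see this layout concretely, here is a small sketch that builds the same `edge_index` from a list of undirected node pairs using plain NumPy. The helper `pairs_to_edge_index` is hypothetical (our own, not part of PyG); it simply appends the reverse of each pair so that both directions are stored.

```python
import numpy as np

# Undirected edges of the toy graph, each listed only once.
undirected_pairs = [(0, 1), (1, 2)]

def pairs_to_edge_index(pairs, add_reverse=True):
    """Build a [2, num_edges] array in the layout used by edge_index.

    Row 0 holds the starting node of each edge, row 1 the ending node.
    If add_reverse is True, every (u, v) pair also contributes (v, u),
    which is how an undirected edge is stored.
    """
    src, dst = [], []
    for u, v in pairs:
        src.append(u)
        dst.append(v)
        if add_reverse:
            src.append(v)
            dst.append(u)
    return np.array([src, dst], dtype=np.int64)

edge_index_np = pairs_to_edge_index(undirected_pairs)
print(edge_index_np)        # rows: [0 1 1 2] and [1 0 2 1]
print(edge_index_np.shape)  # (2, 4): 2 rows, 4 directed edges
```

Wrapping the result with `torch.tensor(edge_index_np, dtype=torch.long)` should give the same tensor as the cell below. PyG also provides `torch_geometric.utils.to_undirected`, which adds missing reverse edges (and drops duplicates) for you.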

In [None]:
### Define node features 
x = torch.tensor([[-1], # node 0
                  [0],  # node 1
                  [1]], # node 2
                 dtype=torch.float)

### Define two (bidirectional) edges between nodes [0,1] and between nodes [1,2]:
edge_index = torch.tensor([[0, 1, 1, 2], # starting node of each edge
                           [1, 0, 2, 1]  # ending node of each edge
                          ], dtype=torch.long)

### Put it all together into a graph structure
data = Data(x=x, edge_index=edge_index)

### Visualize the graph
plt.figure() 
nx.draw(to_networkx(data), 
        cmap='plasma', 
        node_color = np.arange(data.num_nodes),
        with_labels=True,
        font_weight='bold',
        font_color='white',
        node_size=400, linewidths=6)

In [None]:
print("Graph nodes have {} node feature(s) each.".format(data.num_node_features))
print("Graph has {} nodes and {} edges.".format(data.num_nodes, data.num_edges)) # each direction of the edge counts as an edge
print("Is the graph undirected?: {}".format(data.is_undirected())) # this means the edges are bi-directional (connect both nodes in both directions)
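`is_undirected()` asks whether every edge comes with its reverse. The sketch below approximates that check in plain Python, assuming `edge_index` is given as two plain lists (for example from `data.edge_index.tolist()`). `is_undirected_sketch` is a hypothetical helper; unlike PyG's version, it ignores edge attributes.

```python
def is_undirected_sketch(edge_index):
    """Return True if every edge (u, v) also appears as (v, u).

    edge_index: two equal-length lists [starting nodes, ending nodes].
    Use plain Python ints (e.g. via .tolist()); tensor elements would
    not compare correctly inside a set.
    Self-loops (u, u) count as their own reverse.
    """
    src, dst = edge_index
    edges = set(zip(src, dst))
    return all((v, u) in edges for (u, v) in edges)

print(is_undirected_sketch([[0, 1, 1, 2], [1, 0, 2, 1]]))  # True
print(is_undirected_sketch([[0, 1], [1, 2]]))              # False: 0->1 and 1->2 lack reverses
```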

<a name='problem_x'></a> 
### <span style="border:3px; border-style:solid; padding: 0.15em; border-color: #1f77b4; color: #1f77b4;">Exercise: Playing with edges</span>

Re-run the graph visualization cell several times and notice what changes.

Then, modify the code above to create some additional graph structures:
1. A directed graph with connections from 0&rarr;1, 1&rarr;2, 2&rarr;0
2. A graph with 5 nodes, where all nodes connect to node 2.
3. A graph with 4 nodes with only self-connections (0&harr;0, 1&harr;1, 2&harr;2, 3&harr;3) 

<h2 style="border:1px; border-style:solid; padding: 0.25em; color: #FFFFFF; background-color: #1f77b4">2. Load the ENZYMES dataset</h2>

Let's work with a slightly more interesting toy dataset: proteins! 

In [None]:
### Load the dataset
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)

<a name='problem_x'></a> 
### <span style="border:3px; border-style:solid; padding: 0.15em; border-color: #1f77b4; color: #1f77b4;">Exercise: Extract dataset properties</span>
*Hint: you can use `dir(dataset)` to list available properties.*
1. How many entries are in this dataset?
2. How many classes are there?
3. How many node features are there?
4. Look at the first entry -- how many nodes and edges are present?
5. Is the first graph undirected?

<a name='problem_x'></a> 
### <span style="border:3px; border-style:solid; padding: 0.15em; border-color: #1f77b4; color: #1f77b4;">Exercise: Visualize the ENZYMES graphs</span>

1. Now use the `networkx` package, as we did above, to visualize a random graph from this dataset.
2. Add a title indicating which class the graph belongs to.
3. Choose a different colormap.
4. Then make a figure with 6 subplots (in a 2x3 arrangement) showing one example of each of the 6 classes. Draw all the nodes of a graph in a single color, using a different color for each class.

<h2 style="border:1px; border-style:solid; padding: 0.25em; color: #FFFFFF; background-color: #1f77b4">3. Make a dataloader for graphs</h2>

In [None]:
from torch_geometric.loader import DataLoader # use this class: it knows how to batch graphs

### <span style="border:3px; border-style:solid; padding: 0.15em; border-color: #1f77b4; color: #1f77b4;">Exercise: Make a DataLoader</span>

1. Using the `DataLoader` class imported above, make a dataloader that loads the dataset in batches of size 32.
2. Make sure to use `shuffle=True`.
3. Print out the number of graphs in the first batch (should be 32, i.e. your batch size).
4. Print out the first batch of graphs (should be a `DataBatch` object with various properties listed). What does each number mean?
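When the `DataLoader` builds a batch, PyG merges the graphs into one large graph made of disconnected pieces. The NumPy sketch below mimics that idea under simplifying assumptions (graphs as plain dicts, hypothetical helper `collate_graphs`): node features are stacked, each graph's node ids in `edge_index` are shifted by the number of nodes that came before it, `batch` records which graph each node belongs to, and `ptr` marks where each graph's nodes start. In a printed `DataBatch`, the numbers in brackets are the shapes of these stacked tensors.

```python
import numpy as np

def collate_graphs(graphs):
    """Merge small graphs into one big disconnected graph, PyG-style.

    Each graph is a dict with 'x' ([num_nodes, num_features]),
    'edge_index' ([2, num_edges]) and 'y' (an integer graph label).
    """
    xs, edge_indices, ys, batch = [], [], [], []
    ptr = [0]  # nodes of graph i occupy rows ptr[i]:ptr[i+1]
    for graph_id, g in enumerate(graphs):
        num_nodes = g["x"].shape[0]
        xs.append(g["x"])
        # Shift node ids so they index into the stacked feature matrix
        edge_indices.append(g["edge_index"] + ptr[-1])
        ys.append(g["y"])
        batch.extend([graph_id] * num_nodes)
        ptr.append(ptr[-1] + num_nodes)
    return {
        "x": np.concatenate(xs, axis=0),
        "edge_index": np.concatenate(edge_indices, axis=1),
        "y": np.array(ys),
        "batch": np.array(batch),
        "ptr": np.array(ptr),
    }

g1 = {"x": np.zeros((3, 2)), "edge_index": np.array([[0, 1], [1, 2]]), "y": 4}
g2 = {"x": np.ones((2, 2)), "edge_index": np.array([[0], [1]]), "y": 1}
b = collate_graphs([g1, g2])
print(b["x"].shape)     # (5, 2): 3 + 2 nodes stacked
print(b["edge_index"])  # g2's edge 0->1 becomes 3->4
print(b["batch"])       # [0 0 0 1 1]
print(b["ptr"])         # [0 3 5]
```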

<h2 style="border:1px; border-style:solid; padding: 0.25em; color: #FFFFFF; background-color: #1f77b4">4. Train/Val/Test splits</h2>

Define hyperparameters for training: 

In [None]:
args = {
    "device" : 'cpu', # unless you're using a GPU
    "hidden_size" : 50,
    "epochs" : 10,
    "lr" : 0.001,
    "num_layers": 3,
    "dataset" : "ENZYMES",
    "batch_size": 32,
}

### <span style="border:3px; border-style:solid; padding: 0.15em; border-color: #1f77b4; color: #1f77b4;">Exercise: train/val/test splits</span>

Split the dataset into 80% train, 10% val, and 10% test sets in the cell below.

In [None]:
### Define the train/val/test datasets
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='./data', name=args["dataset"])

dataset = dataset.shuffle()
dataset_train = ???
dataset_val = ???
dataset_test = ???

print(f'Number of training graphs: {len(dataset_train)} ({100*len(dataset_train)/len(dataset):.0f}% of total)')
print(f'Number of val graphs: {len(dataset_val)} ({100*len(dataset_val)/len(dataset):.0f}% of total)')
print(f'Number of test graphs: {len(dataset_test)} ({100*len(dataset_test)/len(dataset):.0f}% of total)')

num_node_features = dataset.num_node_features
num_classes = dataset.num_classes

### Load batches of graphs at a time
### (PyG's DataLoader collates Data objects into a DataBatch itself, so no collate_fn is needed)
train_loader = DataLoader(dataset_train, batch_size=args["batch_size"], shuffle=True)
val_loader = DataLoader(dataset_val, batch_size=args["batch_size"])
test_loader = DataLoader(dataset_test, batch_size=args["batch_size"])

<h2 style="border:1px; border-style:solid; padding: 0.25em; color: #FFFFFF; background-color: #1f77b4">5. Define the model</h2>

Let's build a basic GNN -- a Graph Convolutional Network (GCN):
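For intuition about what each `GCNConv` layer computes: the GCN layer of Kipf & Welling, which `GCNConv` implements, updates node features as $H = \hat{D}^{-1/2}(A + I)\hat{D}^{-1/2} X W$, where $A$ is the adjacency matrix, $I$ adds self-loops, $\hat{D}$ is the degree matrix of $A + I$, and $W$ is a learned weight matrix. The NumPy sketch below applies one such layer to the 3-node graph from Section 1, with $W$ fixed to the identity and the learnable bias that `GCNConv` also adds left out.

```python
import numpy as np

# Toy graph from Section 1: path 0 - 1 - 2, one scalar feature per node.
A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)
X = np.array([[-1.0], [0.0], [1.0]])
W = np.array([[1.0]])  # identity weights, so the output is easy to read

A_hat = A + np.eye(3)                      # add self-loops
deg = A_hat.sum(axis=1)                    # node degrees, counting the self-loop
D_inv_sqrt = np.diag(deg ** -0.5)
A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt   # symmetric normalization

H = A_norm @ X @ W                         # one GCN layer, before any nonlinearity
print(H.round(3))
```

Each node ends up with a degree-weighted average of itself and its neighbors: the middle node sees $-1$ and $+1$ with equal weight, so its new feature is 0.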

In [None]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        
        return x
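`global_mean_pool(x, batch)` turns a variable number of node embeddings into one fixed-size vector per graph by averaging the rows of `x` that share the same `batch` index; this is what lets a single `Linear` layer classify graphs of any size. A minimal NumPy equivalent, with `mean_pool_sketch` as a hypothetical name:

```python
import numpy as np

def mean_pool_sketch(x, batch, num_graphs):
    """Average node embeddings per graph.

    x: [num_nodes, hidden] node embeddings; batch: [num_nodes] graph id per node.
    Returns [num_graphs, hidden], one row per graph.
    """
    sums = np.zeros((num_graphs, x.shape[1]))
    np.add.at(sums, batch, x)  # scatter-add each node's row into its graph's row
    counts = np.bincount(batch, minlength=num_graphs)[:, None]
    return sums / np.maximum(counts, 1)  # guard against graphs with no nodes

x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
batch = np.array([0, 0, 1])  # nodes 0 and 1 belong to graph 0, node 2 to graph 1
print(mean_pool_sketch(x, batch, num_graphs=2))
# [[2. 3.]
#  [5. 6.]]
```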

In [None]:
model = GCN(hidden_channels=args["hidden_size"])
print(model)

We want this network to perform multi-class graph classification: given an input graph, our GNN should output which class of protein (enzyme) it belongs to. Which loss function should we pick? Insert it below:

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=args["lr"])
criterion = ???

def train():
    model.train()

    for data in train_loader:  # Iterate in batches over the training dataset.
        out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
        loss = criterion(out, data.y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients.

@torch.no_grad()  # No gradients are needed for evaluation.
def test(loader):
    model.eval()
    correct = 0
    total_loss = 0.0
    for data in loader:  # Iterate in batches over the given (train/val/test) dataset.
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out, data.y)  # Mean loss over this batch (default reduction='mean').
        total_loss += loss.item() * data.num_graphs  # Convert back to a sum over graphs.
        pred = out.argmax(dim=1)  # Use the class with the highest score.
        correct += int((pred == data.y).sum())  # Check against ground-truth labels.
    # Return (accuracy, mean loss per graph) over the whole dataset.
    return correct / len(loader.dataset), total_loss / len(loader.dataset)
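A note on averaging losses: PyTorch loss modules return the mean over the batch by default (`reduction='mean'`), and the last batch of a loader is often smaller than the rest. The sketch below, using made-up numbers, contrasts a per-graph average (each batch mean weighted by its number of graphs) with a plain average over batches.

```python
# Hypothetical per-batch mean losses for 70 graphs in batches of up to 32.
batch_mean_losses = [1.2, 1.0, 2.0]
batch_sizes = [32, 32, 6]

# Per-graph average: weight each batch mean by how many graphs it covered.
total_loss = sum(loss * n for loss, n in zip(batch_mean_losses, batch_sizes))
per_graph_loss = total_loss / sum(batch_sizes)

# Plain average over batches: the small last batch counts as much as a full one.
per_batch_loss = sum(batch_mean_losses) / len(batch_mean_losses)

print(round(per_graph_loss, 4), round(per_batch_loss, 4))
```

Either quantity can be tracked consistently, but only the per-graph version matches "average loss per graph" when batch sizes differ.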

<h2 style="border:1px; border-style:solid; padding: 0.25em; color: #FFFFFF; background-color: #1f77b4">6. Train & evaluate</h2>

In [None]:
import wandb  # Weights & Biases, used to log training curves
use_wandb = True  # set to False to skip W&B logging

In [None]:
from tqdm.auto import trange

# Initialize a W&B run for training (name your project whatever you like)
if use_wandb:
    wandb.init(project="intro_to_pyg")

for epoch in trange(args["epochs"]):
    train()
    train_acc, train_loss = test(train_loader)
    val_acc, val_loss = test(val_loader)
    
    # Log metrics to W&B
    if use_wandb:
        wandb.log({
            "train/loss": train_loss,
            "train/acc": train_acc,
            "val/acc": val_acc,
            "val/loss": val_loss,
        })

    torch.save(model, "graph_classification_model.pt")

# Finish the W&B run
if use_wandb:
    wandb.finish()

Now let's evaluate the performance of the model on the test set:

### <span style="border:3px; border-style:solid; padding: 0.15em; border-color: #1f77b4; color: #1f77b4;">Exercise: Evaluate & improve the model</span>
1. What accuracy would you expect to see with random guessing?
2. Evaluate on your holdout test set. How does your test accuracy compare?
3. Can you improve the model by changing some of the hyperparameters to get a total test accuracy $\geq$ 40%? 
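One way to organize the search for question 3 is to enumerate hyperparameter combinations and compare them on the validation set, using the test set only once for the final choice. The sketch below only generates the configurations; the search-space values are illustrative assumptions, not tuned settings, and for each config you would rebuild the model and optimizer, train, and record validation accuracy (not shown).

```python
from itertools import product

# A subset of the args used above.
base_args = {"hidden_size": 50, "epochs": 10, "lr": 0.001, "batch_size": 32}

# Hypothetical search space; values are illustrative only.
search_space = {
    "hidden_size": [32, 64, 128],
    "lr": [0.01, 0.001],
    "epochs": [50, 100],
}

def make_configs(base, space):
    """Yield one full config dict per combination in the search space."""
    keys = list(space)
    for values in product(*(space[k] for k in keys)):
        cfg = dict(base)  # copy so the base dict is not mutated
        cfg.update(zip(keys, values))
        yield cfg

configs = list(make_configs(base_args, search_space))
print(len(configs))  # 3 * 2 * 2 = 12
print(configs[0])
```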