# Tutorial 7: Graph Neural Networks

![Status](https://img.shields.io/static/v1.svg?label=Status&message=Under%20development&color=red)

**Filled notebook:** 
[![View on Github](https://img.shields.io/static/v1.svg?logo=github&label=Repo&message=View%20On%20Github&color=lightgrey)](https://github.com/phlippe/uvadlc_notebooks/blob/master/docs/tutorial_notebooks/tutorial7/GNN_overview.ipynb)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/phlippe/uvadlc_notebooks/blob/master/docs/tutorial_notebooks/tutorial7/GNN_overview.ipynb)  
**Empty notebook:** 
[![View on Github](https://img.shields.io/static/v1.svg?logo=github&label=Repo&message=View%20On%20Github&color=lightgrey)](https://github.com/phlippe/uvadlc_notebooks/blob/master/docs/tutorial_notebooks/tutorial7/GNN_overview.ipynb)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/phlippe/uvadlc_notebooks/blob/master/docs/tutorial_notebooks/tutorial7/GNN_overview.ipynb)  
**Pre-trained models:** 

In [1]:
## Standard libraries
import os
import json
import math
import numpy as np 

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline 
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()
sns.set()

## Progress bar
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
# Torchvision
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms
# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install pytorch-lightning==1.0.3
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Folder (relative to the current working directory) where the datasets
# are stored or will be downloaded to (e.g. CIFAR10)
DATASET_PATH = "../data"
# Folder (relative to the current working directory) where the pretrained
# models for this tutorial are saved
CHECKPOINT_PATH = "../saved_models/tutorial7"

# Fix the random seed via PyTorch Lightning so that repeated runs of the
# notebook start from the same state
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility.
# NOTE: the flag is spelled `deterministic`; a misspelled attribute is silently
# accepted on the module object and leaves cuDNN in its default, non-deterministic mode.
torch.backends.cudnn.deterministic = True
# Disable the cuDNN auto-tuner, which may select different algorithms between runs
torch.backends.cudnn.benchmark = False

## GNN models

### Graph representation

* Adjacency matrix
* List of edges

Expressing a graph as a list of edges is more memory- and compute-efficient, especially for sparse graphs, whereas an adjacency matrix is more intuitive and simpler to implement. Thus, we will represent graphs as adjacency matrices in this notebook.

### Graph Convolutions

In [20]:
class GCNLayer(nn.Module):
    """Graph convolution on dense adjacency matrices.

    Every node's features are linearly projected and then replaced by the
    mean of the projected features of its neighbours.

    Args:
        c_in: Dimensionality of the input node features.
        c_out: Dimensionality of the output node features.
    """

    def __init__(self, c_in, c_out):
        super().__init__()
        self.projection = nn.Linear(c_in, c_out)

    def forward(self, node_feats, adj_matrix):
        """Apply the graph convolution.

        Args:
            node_feats: Tensor of shape [batch_size, num_nodes, c_in].
            adj_matrix: Tensor of shape [batch_size, num_nodes, num_nodes].
                Row i lists the nodes whose features are averaged into
                node i, so self-connections have to be present in the
                matrix if a node should see its own features.

        Returns:
            Tensor of shape [batch_size, num_nodes, c_out]. Nodes without
            any neighbour (all-zero row) receive an all-zero feature vector.
        """
        # Number of neighbours per node = row sum of the adjacency matrix
        num_neighbours = adj_matrix.sum(dim=-1, keepdim=True)
        # An all-zero row would otherwise produce 0 / 0 = NaN below
        num_neighbours = num_neighbours.masked_fill(num_neighbours == 0, 1.0)
        node_feats = self.projection(node_feats)
        # Sum the projected features of all neighbours, then normalise to a mean
        node_feats = torch.bmm(adj_matrix, node_feats)
        node_feats = node_feats / num_neighbours
        return node_feats

In [22]:
# Sanity check of the GCN layer: with an identity projection and zero bias,
# the output is simply the average of the neighbours' input features.
layer = GCNLayer(c_in=2, c_out=2)
with torch.no_grad():
    layer.projection.weight.copy_(torch.eye(2))
    layer.projection.bias.zero_()

# Three nodes with two features each; node 0 additionally receives from node 1
node_feats = torch.tensor([[[0., 1.],
                            [2., 3.],
                            [4., 5.]]])
adj_matrix = torch.tensor([[[1., 1., 0.],
                            [0., 1., 0.],
                            [0., 0., 1.]]])
out_feats = layer(node_feats, adj_matrix)

print("Adjacency matrix", adj_matrix)
print("Input features", node_feats)
print("Output features", out_feats)

Adjacency matrix tensor([[[1., 1., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]])
Input features tensor([[[0., 1.],
         [2., 3.],
         [4., 5.]]])
Output features tensor([[[1., 2.],
         [2., 3.],
         [4., 5.]]], grad_fn=<DivBackward0>)


### Graph Attention 

In [54]:
class GATLayer(nn.Module):
    """Multi-head graph attention layer on dense adjacency matrices.

    Args:
        c_in: Dimensionality of the input node features.
        c_out: Dimensionality of the output node features. If
            ``concat_heads`` is True this is the size after concatenating
            all heads and must be divisible by ``num_heads``; otherwise it
            is the size of every individual head.
        num_heads: Number of attention heads applied in parallel.
        concat_heads: If True, the head outputs are concatenated; if False,
            they are averaged.
        alpha: Negative slope of the LeakyReLU applied to the attention logits.
    """

    def __init__(self, c_in, c_out, num_heads=1, concat_heads=True, alpha=0.2):
        super().__init__()
        self.num_heads = num_heads
        self.concat_heads = concat_heads
        if self.concat_heads:
            assert c_out % num_heads == 0, "Number of output features must be a multiple of the count of heads."
            c_out = c_out // num_heads

        # One projection produces the features of all heads at once
        self.projection = nn.Linear(c_in, c_out * num_heads)
        # Per-head attention vector; first half scores the receiving node,
        # second half scores the sending node
        self.a = nn.Parameter(torch.Tensor(num_heads, 2 * c_out))
        self.leakyrelu = nn.LeakyReLU(alpha)

        nn.init.xavier_uniform_(self.projection.weight.data, gain=1.414)
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

    def forward(self, node_feats, adj_matrix, print_attn_probs=False):
        """Apply the attention layer.

        Args:
            node_feats: Tensor of shape [batch_size, num_nodes, c_in].
            adj_matrix: Tensor of shape [batch_size, num_nodes, num_nodes];
                a zero at [b, i, j] prevents node i from attending to node j.
            print_attn_probs: If True, print the attention matrix of every
                head (debugging aid).

        Returns:
            Tensor of shape [batch_size, num_nodes, c_out].
        """
        batch_size, num_nodes = node_feats.size(0), node_feats.size(1)
        node_feats = self.projection(node_feats)
        node_feats = node_feats.view(batch_size, num_nodes, self.num_heads, -1)

        # a^T [h_i || h_j] = a_src^T h_i + a_dst^T h_j, so the logits can be
        # built from two per-node scores instead of materialising the
        # concatenated [batch, num_nodes, num_nodes, num_heads, 2 * c] tensor.
        c_head = node_feats.size(-1)
        src_scores = torch.einsum('bihc,hc->bih', node_feats, self.a[:, :c_head])
        dst_scores = torch.einsum('bjhc,hc->bjh', node_feats, self.a[:, c_head:])
        attn_logits = src_scores[:, :, None, :] + dst_scores[:, None, :, :] # Shape: [batch, num_nodes, num_nodes, num_heads]
        attn_logits = self.leakyrelu(attn_logits)

        # Non-edges get a very negative logit so that softmax assigns them ~0 probability
        attn_logits = attn_logits.masked_fill(adj_matrix[...,None] == 0., -9e15)
        attn_probs = F.softmax(attn_logits, dim=2)
        if print_attn_probs:
            print("Attention probs\n", attn_probs.permute(0, 3, 1, 2))
        node_feats = torch.einsum('bijh,bjhc->bihc', attn_probs, node_feats)

        if self.concat_heads:
            # reshape instead of view: the einsum result is not guaranteed to be contiguous
            node_feats = node_feats.reshape(batch_size, num_nodes, -1)
        else:
            node_feats = node_feats.mean(dim=2)

        return node_feats

In [55]:
# Sanity check of the GAT layer: two heads with one feature each, an identity
# projection and zero bias, so each head attends over one input channel.
layer = GATLayer(c_in=2, c_out=2, num_heads=2)
with torch.no_grad():
    layer.projection.weight.copy_(torch.eye(2))
    layer.projection.bias.zero_()

# Three nodes with two features each; node 0 can additionally attend to node 1
node_feats = torch.tensor([[[0., 1.],
                            [2., 3.],
                            [4., 5.]]])
adj_matrix = torch.tensor([[[1., 1., 0.],
                            [0., 1., 0.],
                            [0., 0., 1.]]])
out_feats = layer(node_feats, adj_matrix, print_attn_probs=True)

print("Adjacency matrix", adj_matrix)
print("Input features", node_feats)
print("Output features", out_feats)

Attention probs
 tensor([[[[0.0372, 0.9628, 0.0000],
          [0.0000, 1.0000, 0.0000],
          [0.0000, 0.0000, 1.0000]],

         [[0.5013, 0.4987, 0.0000],
          [0.0000, 1.0000, 0.0000],
          [0.0000, 0.0000, 1.0000]]]], grad_fn=<PermuteBackward>)
Adjacency matrix tensor([[[1., 1., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]])
Input features tensor([[[0., 1.],
         [2., 3.],
         [4., 5.]]])
Output features tensor([[[1.9255, 1.9973],
         [2.0000, 3.0000],
         [4.0000, 5.0000]]], grad_fn=<ViewBackward>)


## Graph Tasks

### Node-level tasks: Semi-supervised node classification

### Edge-level tasks: Link prediction

### Graph-level tasks: Graph classification