Run this notebook with a GPU kernel.

In [None]:
!pip install WandB
!pip install pytorch_lightning

import torch

!pip uninstall torch-scatter torch-sparse torch-geometric torch-cluster -y
!pip install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install torch-cluster -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install git+https://github.com/pyg-team/pytorch_geometric.git


In [2]:
import os

import torch
import torch_scatter
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric
import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils

from torch import Tensor
from typing import Union, Tuple, Optional
from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType,
                                    OptTensor)

from torch.nn import Parameter, Linear
from torch_sparse import SparseTensor, set_diag
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax

## Custom Graph Layer Implementation

References:
- http://web.stanford.edu/class/cs224w/

The job of a message passing layer is to update the current feature representation or embedding of each node in a graph by propagating and transforming information within the graph. Overall, the general paradigm of a message passing layer is: 1) pre-processing -> 2) **message passing** / propagation -> 3) post-processing.

The `forward` function that we will implement for our message passing layer captures this execution logic. Namely, the `forward` function handles the pre- and post-processing of node features / embeddings, and it initiates message passing by calling the `propagate` function.

The `propagate` function encapsulates the message passing process. It does so by calling three important functions: 1) `message`, 2) `aggregate`, and 3) `update`. Our implementation deviates slightly from this: we will not implement `update` explicitly, but will instead place the logic for updating node embeddings after message passing, inside the `forward` function. In other words, after information is propagated (message passing), we further transform the node embeddings output by `propagate`. Therefore, the output of `forward` is exactly the node embeddings after one GNN layer.

Lastly, before starting to implement our own layer, let us dig a bit deeper into each of the functions described above:

1.

```
def propagate(edge_index, x=(x_i, x_j), extra=(extra_i, extra_j), size=size):
```
Calling `propagate` initiates the message passing process. Looking at the function parameters, we highlight a couple of key parameters.

  - `edge_index` is passed to the forward function and captures the edge structure of the graph.
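  - `x=(x_i, x_j)` represents the node features that will be used in message passing. To explain why we pass a tuple, we first look at how our edges are represented. For every edge $(i, j) \in {E}$, we differentiate $i$ as the central node ($x_{central}$), which receives the message, and $j$ as the neighboring node ($x_{neighbor}$), which sends it. With PyG's default `flow="source_to_target"`, `x_j` is gathered from the source row `edge_index[0]` and `x_i` from the target row `edge_index[1]`. (PyG itself reads a feature tuple as `(source, target)`, i.e. element 0 feeds `x_j` and element 1 feeds `x_i`; the order does not matter in this notebook because we always pass `x=(x, x)`.)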
  
    In message passing, for a central node $u$ we aggregate and transform all of the messages associated with the nodes $v$ s.t. $(u, v) \in {E}$ (i.e. $v \in \mathscr{N}_{u}$). The subscripts `_i` and `_j` therefore let us specifically differentiate features associated with central nodes (i.e. nodes receiving message information) from those of neighboring nodes (i.e. nodes passing messages).

    This can be confusing at first. The key point is that, depending on the perspective, a node $x$ acts either as a central node or as a neighboring node. In fact, in undirected graphs we store both edge directions (i.e. $(i, j)$ and $(j, i)$). From the central node perspective, `x_i`, $x$ collects neighboring information to update its embedding. From the neighboring node perspective, `x_j`, $x$ passes its message along the edge connecting it to a different central node.

  - `extra=(extra_i, extra_j)` represents additional information that we can associate with each node beyond its current feature embedding. In fact, we can include as many additional parameters of the form `param=(param_i, param_j)` as we would like. Again, we highlight that indexing with `_i` and `_j` allows us to differentiate central and neighboring nodes.

  The output of the `propagate` function is a matrix of node embeddings after the message passing process and has shape $[N, d]$.

2.
```
def message(x_j, ...):
```
The `message` function is called by `propagate` and constructs the messages from neighboring nodes $j$ to central nodes $i$ for each edge $(i, j)$ in *edge_index*. This function can take any argument that was initially passed to `propagate`. Furthermore, we can again differentiate central nodes and neighboring nodes by appending `_i` or `_j` to the variable name, e.g. `x_i` and `x_j`. Looking more specifically at the variables, we have:

  - `x_j` represents a matrix of feature embeddings for all neighboring nodes passing their messages along their respective edge (i.e. all nodes $j$ for edges $(i, j) \in {E}$). Thus, its shape is $[|{E}|, d]$!

  Critically, we see that the output of the `message` function is a matrix of neighboring node embeddings ready to be aggregated, having shape $[|{E}|, d]$.

3.
```
def aggregate(self, inputs, index, dim_size = None):
```
Lastly, the `aggregate` function is used to aggregate the messages from neighboring nodes. Looking at the parameters we highlight:

  - `inputs` represents a matrix of the messages passed from neighboring nodes (i.e. the output of the `message` function).
  - `index` is a 1-D tensor with one entry per row of `inputs` (length $|{E}|$) that gives the central node corresponding to each row / message $j$ in the `inputs` matrix. Thus, `index` tells us which rows / messages to aggregate for each central node.

  The output of `aggregate` is of shape $[N, d]$.
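To make the shapes above concrete, here is a minimal plain-PyTorch sketch of what `message` and a mean `aggregate` compute together, assuming PyG's default `source_to_target` flow (row 0 of `edge_index` holds the neighbors $j$, row 1 the central nodes $i$). The function `mean_message_passing` is a hypothetical illustration, not PyG's implementation, and it deliberately avoids `torch_scatter` so it runs with core PyTorch alone.

```python
import torch


def mean_message_passing(x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
    """Mean of neighbor features for every central node (source_to_target flow)."""
    src, dst = edge_index              # src: neighbor j (sender), dst: central i (receiver)
    x_j = x[src]                       # message(): one row per edge, shape [|E|, d]
    index = dst                        # aggregate(): central node of every message
    # (x[dst] would give the per-edge central-node features x_i.)

    num_nodes, d = x.shape
    summed = torch.zeros(num_nodes, d, dtype=x.dtype).index_add_(0, index, x_j)
    count = torch.zeros(num_nodes, dtype=x.dtype).index_add_(
        0, index, torch.ones(index.numel(), dtype=x.dtype))
    # Nodes without incoming edges keep a zero vector instead of dividing by zero.
    return summed / count.clamp(min=1).unsqueeze(-1)   # propagate() output, shape [N, d]


x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
edge_index = torch.tensor([[0, 1, 2, 0, 3],
                           [1, 0, 1, 3, 2]])
out = mean_message_passing(x, edge_index)
# Node 1 receives messages from nodes 0 and 2, so out[1] = (1 + 3) / 2 = 2.
# out is [[2.], [2.], [4.], [1.]]
```

The `[|E|, d]` message matrix is gathered once per edge, and the `[N, d]` result is produced by summing into one row per central node and dividing by its in-degree.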


For additional resources refer to the PyG documentation for implementing custom message passing layers: https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html

In [3]:
class CustomGraphSage(MessagePassing):

    def __init__(self, in_channels, out_channels, normalize = True,
                 bias = False, **kwargs):
        super(CustomGraphSage, self).__init__(**kwargs)

        self.in_channels = in_channels # 8 in the toy example below
        self.out_channels = out_channels # 2 in the toy example below
        self.normalize = normalize

        # self.lin_l transforms the embedding of the central node (skip connection).
        self.lin_l = Linear(in_channels, out_channels, bias=False)
        # Optional bias, only added in forward() when bias=True.
        self.use_bias = bias
        self.bias = Parameter(torch.empty(out_channels))
        # self.lin_r transforms the aggregated (mean) message from the neighbors.
        self.lin_r = Linear(in_channels, out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index, size = None):
        # 1. Message passing: mean of the neighbor embeddings, shape [N, in_channels].
        #    Only x_j is used, so the same tensor is passed for central and neighbor nodes.
        out = self.propagate(edge_index, x=(x, x), size=size)
        out = self.lin_r(out)

        # 2. Update: add the transformed embedding of the central node (skip connection).
        out = out + self.lin_l(x)
        if self.use_bias:
            out = out + self.bias

        # 3. Optional L2 normalization of every node embedding.
        if self.normalize:
            out = F.normalize(out, p=2., dim=-1)

        return out

    def message(self, x_j):
        # Each neighbor j sends its current embedding unchanged: shape [|E|, in_channels].
        return x_j

    def aggregate(self, inputs, index, dim_size = None):
        # Average all messages that share the same central node index.
        # self.node_dim is the axis that indexes nodes (-2 by default in MessagePassing).
        node_dim = self.node_dim
        out = torch_scatter.scatter(inputs, index, node_dim, dim_size=dim_size, reduce="mean")

        return out
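Written as an equation, the layer above computes the following for each central node $v$. This is read off the code: $\mathbf{W}_l$ and $\mathbf{W}_r$ are the weights of `lin_l` and `lin_r`, $\mathscr{N}_{v}$ is the set of nodes with an edge into $v$ (the mean term is $\mathbf{0}$ when $v$ has no incoming edges), $\mathbf{b}$ is present only when `bias=True`, and the division by the L2 norm is the `normalize=True` case (`F.normalize` additionally clamps the norm away from zero).

$$
\tilde{\mathbf{h}}_v = \mathbf{W}_l\,\mathbf{h}_v + \mathbf{W}_r \cdot \frac{1}{|\mathscr{N}_{v}|} \sum_{u \in \mathscr{N}_{v}} \mathbf{h}_u + \mathbf{b},
\qquad
\mathbf{h}_v' = \frac{\tilde{\mathbf{h}}_v}{\lVert \tilde{\mathbf{h}}_v \rVert_2}
$$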

In [4]:
node_embedings = torch.ones(4, 8) # A graph with 4 nodes and 8-dimensional node features
edge_index = torch.tensor([[0, 1, 2, 0, 3],
                            [1, 0, 1, 3, 2]], dtype=torch.long) # Example edge index

In [5]:
custom_layer = CustomGraphSage(in_channels = 8, out_channels = 2, normalize = True)

In [6]:
transformed_embeddings = custom_layer(node_embedings, edge_index)
assert transformed_embeddings.shape == torch.Size([4, 2]), "Incorrect shape returned."

In [7]:
layer = torch_geometric.nn.SAGEConv(8, 2, aggr='mean', normalize=True, bias = False)

In [8]:
transformed_embeddings = layer(node_embedings, edge_index)
assert transformed_embeddings.shape == torch.Size([4, 2]), "Incorrect shape returned."

## Train and Evaluate

Build a training loop and evaluate the custom graph layer implementation.

In [9]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint

import numpy as np
from sklearn.metrics import accuracy_score
import random

from torch_geometric.loader import DataLoader  # graph-aware DataLoader (formerly torch_geometric.data.DataLoader)
import wandb
from torch_geometric.nn import global_mean_pool, global_max_pool


In [10]:
from torch_geometric.datasets import TUDataset

dataset = TUDataset(name='ENZYMES', cleaned=False, root='')

def create_data_splits(dataset, test_size=100, val_size=50, seed=None):
    if seed is not None:  # Only set the seed if it's provided
        torch.manual_seed(seed)
    indices = torch.randperm(len(dataset))

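    # Positive offsets keep the split correct even when val_size or test_size is 0.
    num_train = len(dataset) - test_size - val_size
    train_indices = indices[:num_train]
    val_indices = indices[num_train:num_train + val_size]
    test_indices = indices[num_train + val_size:]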

    train_dataset = dataset[train_indices]
    val_dataset = dataset[val_indices]
    test_dataset = dataset[test_indices]

    return train_dataset, val_dataset, test_dataset

Downloading https://www.chrsmrrs.com/graphkerneldatasets/ENZYMES.zip
Extracting ./ENZYMES/ENZYMES.zip
Processing...
Done!
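`torch.manual_seed(seed)` inside `create_data_splits` reseeds PyTorch's global random number generator as a side effect, which also influences any weight initialization and shuffling that happen afterwards. Below is a sketch of an alternative that draws the permutation from a local `torch.Generator` instead; the helper name `split_indices` is hypothetical. Slicing with positive offsets also keeps the split correct when `val_size` or `test_size` is 0, since a negative slice such as `indices[:-0]` is empty.

```python
import torch


def split_indices(n: int, test_size: int = 100, val_size: int = 50, seed=None):
    """Return disjoint (train, val, test) index tensors for a dataset of size n."""
    gen = torch.Generator()
    if seed is not None:
        gen.manual_seed(seed)                  # seeds only this generator, not the global RNG
    perm = torch.randperm(n, generator=gen)
    num_train = n - test_size - val_size
    return (perm[:num_train],
            perm[num_train:num_train + val_size],
            perm[num_train + val_size:])


# ENZYMES has 600 graphs; the training loops below use test_size=100 and val_size=0.
train_idx, val_idx, test_idx = split_indices(600, test_size=100, val_size=0, seed=42)
```

As in `create_data_splits`, `dataset[train_idx]` then selects the corresponding graphs.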


In [11]:
class CustomMessagePassingNetworkLightning(pl.LightningModule):
    def __init__(self):
        super(CustomMessagePassingNetworkLightning, self).__init__()
        self.conv1 = CustomGraphSage(3, 32)
        self.conv2 = CustomGraphSage(32, 64)
        self.conv3 = CustomGraphSage(64, 64)
        self.lin1 = nn.Linear(64, 32)
        self.lin2 = nn.Linear(32, 6)

        self.loss = nn.CrossEntropyLoss()

        self.train_loss = []
        self.train_probabilities = []
        self.train_labels = []

        self.val_loss = []
        self.val_probabilities = []
        self.val_labels = []


    def forward(self, data):

        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))

        x = F.relu(self.lin1(x))
        x = self.lin2(x)

        h = global_mean_pool(x, data.batch)

        return h

    def compute_loss(self, logits: torch.Tensor, labels: torch.Tensor, validate: bool = False) -> torch.Tensor:
        loss = self.loss(logits, labels)
        return loss

    def training_step(self, batch, batch_idx):
        logits = self(batch)
        loss = self.compute_loss(logits, batch.y)
        train_probabilities = F.softmax(logits, dim=-1).detach().cpu().numpy()
        train_labels = batch.y.detach().cpu().numpy()

        self.train_loss.append(loss.detach().cpu().numpy())
        self.train_probabilities.append(train_probabilities)
        self.train_labels.append(train_labels)

        return loss

    def validation_step(self, batch, batch_idx):
        logits = self(batch)
        loss = self.compute_loss(logits, batch.y)
        val_probabilities = F.softmax(logits, dim=-1).detach().cpu().numpy()
        val_labels = batch.y.detach().cpu().numpy()

        self.val_loss.append(loss.detach().cpu().numpy())
        self.val_probabilities.append(val_probabilities)
        self.val_labels.append(val_labels)

        return loss

    def on_validation_epoch_end(self) -> None:
        val_loss = np.mean(self.val_loss)
        self.log("val/loss", val_loss, prog_bar=True)

        val_proba = np.concatenate(self.val_probabilities)
        val_labels = np.concatenate(self.val_labels)

        val_acc = accuracy_score(val_labels, np.argmax(val_proba, axis=-1))
        self.log("val/accuracy", val_acc, prog_bar=False)

        self.val_loss.clear()
        self.val_probabilities.clear()
        self.val_labels.clear()


    def test_step(self, batch, batch_idx):
        logits = self(batch)
        test_probabilities = F.softmax(logits, dim=-1).detach().cpu().numpy()
        test_labels = batch.y.detach().cpu().numpy()

        test_acc = accuracy_score(test_labels, np.argmax(test_probabilities, axis=-1))
        self.log("test/accuracy", test_acc, prog_bar=False)


    def on_train_epoch_end(self) -> None:
        train_loss = np.mean(self.train_loss)
        self.log("train/loss", train_loss, prog_bar=True)

        train_proba = np.concatenate(self.train_probabilities)
        train_labels = np.concatenate(self.train_labels)

        train_acc = accuracy_score(train_labels, np.argmax(train_proba, axis=-1))
        self.log("train/accuracy", train_acc, prog_bar=False)

        self.train_loss.clear()
        self.train_probabilities.clear()
        self.train_labels.clear()


    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001, weight_decay=5e-4)
        return optimizer
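In `forward`, the MLP (`lin1`, `lin2`) is applied to every node, and only then are the node logits averaged into one row per graph by `global_mean_pool`, using the `batch` vector that PyG's `DataLoader` builds when it merges a mini-batch of graphs into one large disconnected graph. A plain-PyTorch sketch of that pooling step, assuming torch >= 1.12 for `Tensor.scatter_reduce`; the helper `mean_pool` is hypothetical and not PyG's implementation.

```python
import torch


def mean_pool(x: torch.Tensor, batch: torch.Tensor, num_graphs: int) -> torch.Tensor:
    """Average node rows per graph; batch[k] is the graph id of node k."""
    out = torch.zeros(num_graphs, x.size(-1), dtype=x.dtype)
    # include_self=False: the initial zeros in `out` are not counted in the mean.
    return out.scatter_reduce(0, batch.unsqueeze(-1).expand_as(x), x,
                              reduce="mean", include_self=False)


node_logits = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
batch = torch.tensor([0, 0, 1])          # nodes 0 and 1 -> graph 0, node 2 -> graph 1
graph_logits = mean_pool(node_logits, batch, num_graphs=2)   # [[2., 3.], [5., 6.]]
```

The pooled `[num_graphs, 6]` logits are what `nn.CrossEntropyLoss` compares with the per-graph labels `batch.y`.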

In [None]:
for seed in [42,43,44]:
    experiment_name = 'custom_sage_seed_'+str(seed)+"_"+str(random.random()*10000)
    pl.seed_everything(seed)  # make weight initialization and data shuffling depend on the seed
    model_wrapper = CustomMessagePassingNetworkLightning()
    logger = WandbLogger(project="adlg-gnn", name = experiment_name)

    trainer = pl.Trainer(
        accelerator="cpu",
        max_epochs=200,
        logger=logger,
        log_every_n_steps=50,
        check_val_every_n_epoch=20,
    )

    train_dataset, val_dataset, test_dataset = create_data_splits(dataset, val_size=0, seed=seed)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64)

    trainer.fit(model_wrapper, train_loader)
    name = experiment_name + '.pth'
    trainer.save_checkpoint(name)
    wandb.save(name)

    trainer.test(dataloaders=test_loader, ckpt_path=name)

    wandb.finish()
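Each architecture is trained with three seeds, so a natural summary is the mean and standard deviation of the per-seed test accuracy. A minimal sketch with a hypothetical helper `summarize_runs`; it assumes the accuracies are collected from the return value of `trainer.test(...)`, which PyTorch Lightning returns as a list with one dict of logged metrics per test dataloader (here `results[0]["test/accuracy"]`).

```python
import statistics


def summarize_runs(accuracies: list[float]) -> tuple[float, float]:
    """Mean and sample standard deviation of per-seed test accuracies."""
    if len(accuracies) < 2:
        raise ValueError("need at least two runs to estimate a standard deviation")
    return statistics.mean(accuracies), statistics.stdev(accuracies)
```

Inside the loop, `results = trainer.test(dataloaders=test_loader, ckpt_path=name)` followed by `test_accs.append(results[0]["test/accuracy"])` would collect the values, and `summarize_runs(test_accs)` after the loop reports them (`test_accs` is a hypothetical list initialized before the loop).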

## Compare with SAGEConv

In [12]:
class SAGEConvLightning(pl.LightningModule):
    def __init__(self):
        super(SAGEConvLightning, self).__init__()
        self.conv1 = torch_geometric.nn.SAGEConv(3, 32, aggr='mean', normalize=True, bias = False)
        self.conv2 = torch_geometric.nn.SAGEConv(32, 64, aggr='mean', normalize=True, bias = False)
        self.conv3 = torch_geometric.nn.SAGEConv(64, 64, aggr='mean', normalize=True, bias = False)
        self.lin1 = nn.Linear(64, 32)
        self.lin2 = nn.Linear(32, 6)

        self.loss = nn.CrossEntropyLoss()

        self.train_loss = []
        self.train_probabilities = []
        self.train_labels = []

        self.val_loss = []
        self.val_probabilities = []
        self.val_labels = []


    def forward(self, data):

        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))

        x = F.relu(self.lin1(x))
        x = self.lin2(x)

        h = global_mean_pool(x, data.batch)

        return h

    def compute_loss(self, logits: torch.Tensor, labels: torch.Tensor, validate: bool = False) -> torch.Tensor:
        loss = self.loss(logits, labels)
        return loss

    def training_step(self, batch, batch_idx):
        logits = self(batch)
        loss = self.compute_loss(logits, batch.y)
        train_probabilities = F.softmax(logits, dim=-1).detach().cpu().numpy()
        train_labels = batch.y.detach().cpu().numpy()

        self.train_loss.append(loss.detach().cpu().numpy())
        self.train_probabilities.append(train_probabilities)
        self.train_labels.append(train_labels)

        return loss

    def validation_step(self, batch, batch_idx):
        logits = self(batch)
        loss = self.compute_loss(logits, batch.y)
        val_probabilities = F.softmax(logits, dim=-1).detach().cpu().numpy()
        val_labels = batch.y.detach().cpu().numpy()

        self.val_loss.append(loss.detach().cpu().numpy())
        self.val_probabilities.append(val_probabilities)
        self.val_labels.append(val_labels)

        return loss

    def on_validation_epoch_end(self) -> None:
        val_loss = np.mean(self.val_loss)
        self.log("val/loss", val_loss, prog_bar=True)

        val_proba = np.concatenate(self.val_probabilities)
        val_labels = np.concatenate(self.val_labels)

        val_acc = accuracy_score(val_labels, np.argmax(val_proba, axis=-1))
        self.log("val/accuracy", val_acc, prog_bar=False)

        self.val_loss.clear()
        self.val_probabilities.clear()
        self.val_labels.clear()


    def test_step(self, batch, batch_idx):
        logits = self(batch)
        test_probabilities = F.softmax(logits, dim=-1).detach().cpu().numpy()
        test_labels = batch.y.detach().cpu().numpy()

        test_acc = accuracy_score(test_labels, np.argmax(test_probabilities, axis=-1))
        self.log("test/accuracy", test_acc, prog_bar=False)


    def on_train_epoch_end(self) -> None:
        train_loss = np.mean(self.train_loss)
        self.log("train/loss", train_loss, prog_bar=True)

        train_proba = np.concatenate(self.train_probabilities)
        train_labels = np.concatenate(self.train_labels)

        train_acc = accuracy_score(train_labels, np.argmax(train_proba, axis=-1))
        self.log("train/accuracy", train_acc, prog_bar=False)

        self.train_loss.clear()
        self.train_probabilities.clear()
        self.train_labels.clear()


    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001, weight_decay=5e-4)
        return optimizer

In [None]:
for seed in [42,43,44]:
    experiment_name = 'sage_seed_'+str(seed)+"_"+str(random.random()*10000)
    pl.seed_everything(seed)  # make weight initialization and data shuffling depend on the seed
    model_wrapper = SAGEConvLightning()
    logger = WandbLogger(project="adlg-gnn", name = experiment_name)

    trainer = pl.Trainer(
        accelerator="cpu",
        max_epochs=200,
        logger=logger,
        log_every_n_steps=50,
        check_val_every_n_epoch=20,
    )

    train_dataset, val_dataset, test_dataset = create_data_splits(dataset, val_size=0, seed=seed)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64)

    trainer.fit(model_wrapper, train_loader)
    name = experiment_name + '.pth'
    trainer.save_checkpoint(name)
    wandb.save(name)

    trainer.test(dataloaders=test_loader, ckpt_path=name)

    wandb.finish()

## GRU Aggregation

In [23]:
class CustomGRU(MessagePassing):

    def __init__(self, in_channels, out_channels, normalize = True,
                 bias = False, **kwargs):
        super(CustomGRU, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize

        self.lin_l = Linear(in_channels, out_channels, bias=False)  # central node
        self.use_bias = bias
        self.bias = Parameter(torch.empty(out_channels))
        self.lin_r = Linear(in_channels, out_channels, bias=False)  # aggregated neighbors

        # GRU that reads the incoming messages of each central node as one sequence.
        # batch_first=True: input shape is [num_nodes, max_in_degree, in_channels].
        self.aggregator = torch.nn.GRU(in_channels, in_channels, num_layers=1,
                                       batch_first=True)

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()
        self.aggregator.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index, size = None):
        out = self.propagate(edge_index, x=(x, x), size=size)
        out = self.lin_r(out)

        out = out + self.lin_l(x)
        if self.use_bias:
            out = out + self.bias

        if self.normalize:
            out = F.normalize(out, p=2., dim=-1)

        return out


    def message(self, x_j):
        return x_j

    def aggregate(self, inputs, index, dim_size = None):
        # Sort messages by central node so that each node's messages are contiguous
        # (to_dense_batch expects a sorted index).
        perm = torch.argsort(index, stable=True)
        # dense: [num_nodes, max_in_degree, in_channels], zero-padded per node.
        dense, _ = pyg_utils.to_dense_batch(inputs[perm], index[perm],
                                            batch_size=dim_size)
        # Padded zero steps are also fed through the GRU (a simplification).
        _, h_n = self.aggregator(dense)

        # Final hidden state of the single GRU layer, one row per node: [N, in_channels]
        return h_n[-1]
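A GRU aggregates per node only if each central node's incoming messages form their own sequence: if all $|E|$ messages are fed to `torch.nn.GRU` as a single unbatched sequence, `h_n` has shape `[1, hidden_size]`, and adding it to per-node terms broadcasts the same vector to every node. Below is a plain-PyTorch sketch of the grouping step. The helper `group_messages` is hypothetical; `torch_geometric.utils.to_dense_batch` performs a similar zero-padded packing for a sorted index. Torch >= 1.13 is assumed for `argsort(..., stable=True)`.

```python
import torch


def group_messages(inputs: torch.Tensor, index: torch.Tensor, num_nodes: int):
    """Pack per-edge messages into a zero-padded [num_nodes, max_in_degree, d] tensor."""
    perm = torch.argsort(index, stable=True)          # bring each node's messages together
    inputs, index = inputs[perm], index[perm]
    deg = torch.bincount(index, minlength=num_nodes)  # in-degree of every central node
    start = torch.cumsum(deg, dim=0) - deg            # first sorted position of each group
    pos = torch.arange(index.numel()) - start[index]  # slot of each message inside its group
    dense = inputs.new_zeros(num_nodes, int(deg.max()), inputs.size(-1))
    dense[index, pos] = inputs
    return dense, deg


torch.manual_seed(0)
edge_index = torch.tensor([[0, 1, 2, 0, 3],
                           [1, 0, 1, 3, 2]])
x = torch.arange(4, dtype=torch.float).unsqueeze(-1)  # node k has feature [k]
messages = x[edge_index[0]]                            # x_j, shape [|E|, 1]

dense, deg = group_messages(messages, edge_index[1], num_nodes=4)
gru = torch.nn.GRU(input_size=1, hidden_size=3, batch_first=True)
_, h_n = gru(dense)        # h_n: [num_layers, num_nodes, hidden_size]
per_node = h_n[-1]         # one aggregated vector per central node, shape [4, 3]
```

Padded zero steps still run through the GRU in this sketch; `torch.nn.utils.rnn.pack_padded_sequence` can skip them, but it requires every sequence length to be positive, so nodes without incoming edges would need separate handling.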

In [14]:
custom_layer = CustomGRU(in_channels = 8, out_channels = 2, normalize = True)
transformed_embeddings = custom_layer(node_embedings, edge_index)
assert transformed_embeddings.shape == torch.Size([4, 2]), "Incorrect shape returned."

### Train

In [24]:
class GruLightning(pl.LightningModule):
    def __init__(self):
        super(GruLightning, self).__init__()
        self.conv1 = CustomGRU(3, 32)
        self.conv2 = CustomGRU(32, 64)
        self.conv3 = CustomGRU(64, 64)
        self.lin1 = nn.Linear(64, 32)
        self.lin2 = nn.Linear(32, 6)

        self.loss = nn.CrossEntropyLoss()

        self.train_loss = []
        self.train_probabilities = []
        self.train_labels = []

        self.val_loss = []
        self.val_probabilities = []
        self.val_labels = []


    def forward(self, data):

        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))

        x = F.relu(self.lin1(x))
        x = self.lin2(x)

        h = global_mean_pool(x, data.batch)

        return h

    def compute_loss(self, logits: torch.Tensor, labels: torch.Tensor, validate: bool = False) -> torch.Tensor:
        loss = self.loss(logits, labels)
        return loss

    def training_step(self, batch, batch_idx):
        logits = self(batch)
        loss = self.compute_loss(logits, batch.y)
        train_probabilities = F.softmax(logits, dim=-1).detach().cpu().numpy()
        train_labels = batch.y.detach().cpu().numpy()

        self.train_loss.append(loss.detach().cpu().numpy())
        self.train_probabilities.append(train_probabilities)
        self.train_labels.append(train_labels)

        return loss

    def validation_step(self, batch, batch_idx):
        logits = self(batch)
        loss = self.compute_loss(logits, batch.y)
        val_probabilities = F.softmax(logits, dim=-1).detach().cpu().numpy()
        val_labels = batch.y.detach().cpu().numpy()

        self.val_loss.append(loss.detach().cpu().numpy())
        self.val_probabilities.append(val_probabilities)
        self.val_labels.append(val_labels)

        return loss

    def on_validation_epoch_end(self) -> None:
        val_loss = np.mean(self.val_loss)
        self.log("val/loss", val_loss, prog_bar=True)

        val_proba = np.concatenate(self.val_probabilities)
        val_labels = np.concatenate(self.val_labels)

        val_acc = accuracy_score(val_labels, np.argmax(val_proba, axis=-1))
        self.log("val/accuracy", val_acc, prog_bar=False)

        self.val_loss.clear()
        self.val_probabilities.clear()
        self.val_labels.clear()


    def test_step(self, batch, batch_idx):
        logits = self(batch)
        test_probabilities = F.softmax(logits, dim=-1).detach().cpu().numpy()
        test_labels = batch.y.detach().cpu().numpy()

        test_acc = accuracy_score(test_labels, np.argmax(test_probabilities, axis=-1))
        self.log("test/accuracy", test_acc, prog_bar=False)


    def on_train_epoch_end(self) -> None:
        train_loss = np.mean(self.train_loss)
        self.log("train/loss", train_loss, prog_bar=True)

        train_proba = np.concatenate(self.train_probabilities)
        train_labels = np.concatenate(self.train_labels)

        train_acc = accuracy_score(train_labels, np.argmax(train_proba, axis=-1))
        self.log("train/accuracy", train_acc, prog_bar=False)

        self.train_loss.clear()
        self.train_probabilities.clear()
        self.train_labels.clear()


    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001, weight_decay=5e-4)
        return optimizer

In [None]:
## NOTE: Use a T4 GPU runtime here, as the GRU trains very slowly on CPU

for seed in [42,43,44]:
    experiment_name = 'gru_'+str(seed)+"_"+str(random.random()*10000)
    pl.seed_everything(seed)  # make weight initialization and data shuffling depend on the seed
    model_wrapper = GruLightning()
    logger = WandbLogger(project="adlg-gnn", name = experiment_name)

    # Instantiate a PyTorch Lightning trainer
    trainer = pl.Trainer(
        accelerator="gpu",
        max_epochs=200,
        logger=logger,
        log_every_n_steps=50,
        check_val_every_n_epoch=20,
    )

    train_dataset, val_dataset, test_dataset = create_data_splits(dataset, val_size=0, seed=seed)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64)

    trainer.fit(model_wrapper, train_loader)
    name = experiment_name + '.pth'
    trainer.save_checkpoint(name)
    wandb.save(name)

    trainer.test(dataloaders=test_loader, ckpt_path=name)

    wandb.finish()

## Attention

In [126]:
class Attention(MessagePassing):
    # Learns a weighted combination of mean, max, min and sum neighbor aggregation;
    # a single nn.Linear(4, 1) is shared across all feature dimensions.

    def __init__(self, in_channels, out_channels, normalize = True,
                 bias = False, **kwargs):
        super(Attention, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize

        self.lin_l = Linear(in_channels, out_channels, bias=False)
        self.use_bias = bias
        self.bias = Parameter(torch.empty(out_channels))
        self.lin_r = Linear(in_channels, out_channels, bias=False)

        # Mixes the 4 aggregations [mean, max, min, sum] of each feature into one value.
        self.aggregator = nn.Linear(4, 1)

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()
        self.aggregator.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index, size = None):
        out = self.propagate(edge_index, x=(x, x), size=size)
        out = self.lin_r(out)

        out = out + self.lin_l(x)
        if self.use_bias:
            out = out + self.bias

        if self.normalize:
            out = F.normalize(out, p=2., dim=-1)

        return out


    def message(self, x_j):
        return x_j

    def aggregate(self, inputs, index, dim_size = None):

        node_dim = self.node_dim
        # Four standard aggregations, each of shape [N, in_channels].
        agg_mean = torch_scatter.scatter(inputs, index, node_dim, dim_size=dim_size, reduce="mean")
        agg_max = torch_scatter.scatter(inputs, index, node_dim, dim_size=dim_size, reduce="max")
        agg_min = torch_scatter.scatter(inputs, index, node_dim, dim_size=dim_size, reduce="min")
        agg_sum = torch_scatter.scatter(inputs, index, node_dim, dim_size=dim_size, reduce="sum")

        # [N, in_channels, 4] -> [N, in_channels, 1] -> [N, in_channels]
        stacked = torch.stack([agg_mean, agg_max, agg_min, agg_sum], dim=-1)
        # squeeze(-1) rather than squeeze(), so a batch with N == 1 keeps its node axis.
        agg = self.aggregator(stacked).squeeze(-1)

        assert agg.shape == agg_mean.shape, "Shape " + str(agg.shape) + " != " + str(agg_mean.shape)

        # Learned combination of the four aggregations.
        return agg
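The aggregation above stacks four scatter reductions and mixes them with `nn.Linear(4, 1)`. The same idea can be sketched with core PyTorch only, via `Tensor.scatter_reduce` (torch >= 1.12 assumed); `multi_aggregate` and `mixer` are hypothetical names. Because the mixing weights are shared across feature dimensions and are not normalized, this is a learned linear mix of aggregators rather than softmax attention over neighbors.

```python
import torch
import torch.nn as nn


def multi_aggregate(inputs: torch.Tensor, index: torch.Tensor, num_nodes: int,
                    mixer: nn.Linear) -> torch.Tensor:
    """Learned mix of mean/max/min/sum neighbor aggregation (plain-PyTorch sketch)."""
    idx = index.unsqueeze(-1).expand_as(inputs)
    aggs = []
    for reduce in ("mean", "amax", "amin", "sum"):
        base = inputs.new_zeros(num_nodes, inputs.size(-1))
        # Nodes without messages keep the 0 from `base` for every reduction.
        aggs.append(base.scatter_reduce(0, idx, inputs, reduce=reduce, include_self=False))
    stacked = torch.stack(aggs, dim=-1)        # [num_nodes, d, 4]
    return mixer(stacked).squeeze(-1)          # [num_nodes, d]


inputs = torch.tensor([[1.0], [3.0], [5.0]])   # three messages with d = 1
index = torch.tensor([0, 0, 1])                # central node of each message
mixer = nn.Linear(4, 1, bias=False)
with torch.no_grad():
    mixer.weight.copy_(torch.tensor([[1.0, 0.0, 0.0, 0.0]]))  # select the mean only
out = multi_aggregate(inputs, index, num_nodes=3, mixer=mixer)  # [[2.], [5.], [0.]]
```

With `d = 1` as here, calling `squeeze()` without an axis would also drop the feature axis and return shape `[3]`, which is why the axis is given explicitly.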

In [116]:
class AttentionLightning(pl.LightningModule):
    def __init__(self):
        super(AttentionLightning, self).__init__()
        self.conv1 = Attention(3, 32)
        self.conv2 = Attention(32, 64)
        self.conv3 = Attention(64, 64)
        self.lin1 = nn.Linear(64, 32)
        self.lin2 = nn.Linear(32, 6)

        self.loss = nn.CrossEntropyLoss()

        self.train_loss = []
        self.train_probabilities = []
        self.train_labels = []

        self.val_loss = []
        self.val_probabilities = []
        self.val_labels = []


    def forward(self, data):

        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))

        x = F.relu(self.lin1(x))
        x = self.lin2(x)

        h = global_mean_pool(x, data.batch)

        return h

    def compute_loss(self, logits: torch.Tensor, labels: torch.Tensor, validate: bool = False) -> torch.Tensor:
        loss = self.loss(logits, labels)
        return loss

    def training_step(self, batch, batch_idx):
        logits = self(batch)
        loss = self.compute_loss(logits, batch.y)
        train_probabilities = F.softmax(logits, dim=-1).detach().cpu().numpy()
        train_labels = batch.y.detach().cpu().numpy()

        self.train_loss.append(loss.detach().cpu().numpy())
        self.train_probabilities.append(train_probabilities)
        self.train_labels.append(train_labels)

        return loss

    def validation_step(self, batch, batch_idx):
        logits = self(batch)
        loss = self.compute_loss(logits, batch.y)
        val_probabilities = F.softmax(logits, dim=-1).detach().cpu().numpy()
        val_labels = batch.y.detach().cpu().numpy()

        self.val_loss.append(loss.detach().cpu().numpy())
        self.val_probabilities.append(val_probabilities)
        self.val_labels.append(val_labels)

        return loss

    def on_validation_epoch_end(self) -> None:
        val_loss = np.mean(self.val_loss)
        self.log("val/loss", val_loss, prog_bar=True)

        val_proba = np.concatenate(self.val_probabilities)
        val_labels = np.concatenate(self.val_labels)

        val_acc = accuracy_score(val_labels, np.argmax(val_proba, axis=-1))
        self.log("val/accuracy", val_acc, prog_bar=False)

        self.val_loss.clear()
        self.val_probabilities.clear()
        self.val_labels.clear()


    def test_step(self, batch, batch_idx):
        logits = self(batch)
        test_probabilities = F.softmax(logits, dim=-1).detach().cpu().numpy()
        test_labels = batch.y.detach().cpu().numpy()

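        test_acc = accuracy_score(test_labels, np.argmax(test_probabilities, axis=-1))
        # Pass the batch size explicitly: inferring it from a PyG Batch is ambiguous, and Lightning
        # uses it to weight the epoch-level mean of these per-batch accuracies.
        self.log("test/accuracy", test_acc, prog_bar=False, batch_size=batch.num_graphs)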


    def on_train_epoch_end(self) -> None:
        train_loss = np.mean(self.train_loss)
        self.log("train/loss", train_loss, prog_bar=True)

        train_proba = np.concatenate(self.train_probabilities)
        train_labels = np.concatenate(self.train_labels)

        train_acc = accuracy_score(train_labels, np.argmax(train_proba, axis=-1))
        self.log("train/accuracy", train_acc, prog_bar=False)

        self.train_loss.clear()
        self.train_probabilities.clear()
        self.train_labels.clear()


    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001, weight_decay=5e-4)
        return optimizer
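The training and validation epoch-end hooks above pool every prediction of the epoch before calling `accuracy_score`, whereas `test_step` computes one accuracy per batch and leaves the averaging to Lightning. These only agree when the per-batch values are weighted by batch size. The test logs include a warning that Lightning inferred an ambiguous `batch_size` of 2; as we read it, using the same weight for every batch reduces to a plain mean that over-weights a small final batch. The pure-Python sketch below illustrates the difference on toy data. The helper names (`pooled_accuracy`, `mean_of_batch_accuracies`, `batch_weighted_accuracy`) are hypothetical and not part of the notebook or of Lightning's API.

```python
from typing import List, Sequence, Tuple

# One entry per batch: (predicted classes, true labels)
BatchResult = Tuple[Sequence[int], Sequence[int]]


def _num_correct(preds: Sequence[int], labels: Sequence[int]) -> int:
    """Count positions where the predicted class equals the label."""
    return sum(int(p == y) for p, y in zip(preds, labels))


def pooled_accuracy(batches: List[BatchResult]) -> float:
    """Accuracy over all samples at once, like concatenating every batch
    before calling accuracy_score in the epoch-end hooks."""
    correct = sum(_num_correct(preds, labels) for preds, labels in batches)
    total = sum(len(labels) for _, labels in batches)
    return correct / total


def mean_of_batch_accuracies(batches: List[BatchResult]) -> float:
    """Unweighted mean of per-batch accuracies: every batch counts equally."""
    accs = [_num_correct(p, y) / len(y) for p, y in batches]
    return sum(accs) / len(accs)


def batch_weighted_accuracy(batches: List[BatchResult]) -> float:
    """Mean of per-batch accuracies weighted by the number of samples per batch."""
    weighted_sum = sum((_num_correct(p, y) / len(y)) * len(y) for p, y in batches)
    total = sum(len(y) for _, y in batches)
    return weighted_sum / total


if __name__ == "__main__":
    # Toy data: a larger batch that is fully correct and a small last batch that is fully wrong
    batches = [([0, 1, 1, 0], [0, 1, 1, 0]), ([2], [0])]
    print(pooled_accuracy(batches))           # 0.8
    print(mean_of_batch_accuracies(batches))  # 0.5
    print(batch_weighted_accuracy(batches))   # 0.8
```

The batch-weighted mean equals the pooled accuracy, which is why supplying the true number of graphs per batch (for a PyG `Batch`, `batch.num_graphs`) should bring the logged test accuracy in line with how the training accuracy is computed. Accumulating test predictions and computing the accuracy once at epoch end, as the training hooks do, would be an equivalent alternative.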

In [127]:
for seed in [42,43,44]:
    experiment_name = 'att_'+str(seed)+"_"+str(random.random()*10000)
    model_wrapper = AttentionLightning()
    logger = WandbLogger(project="adlg-gnn", name = experiment_name)

    # Instantiate a PyTorch Lightning trainer (no val loader is passed, so check_val_every_n_epoch has no effect)
    trainer = pl.Trainer(
        accelerator="cpu",
        max_epochs=200,
        logger=logger,
        log_every_n_steps=50,
        check_val_every_n_epoch=20,
    )

    train_dataset, val_dataset, test_dataset = create_data_splits(dataset, val_size=0, seed=seed)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64)

    trainer.fit(model_wrapper, train_loader)
    name = experiment_name + '.pth'
    trainer.save_checkpoint(name)
    wandb.save(name)

    trainer.test(dataloaders=test_loader, ckpt_path=name)

    wandb.finish()

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loggers/wandb.py:389: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type  


INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=200` reached.
INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at att_42_8288.225557177662.pth
INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from the checkpoint at att_42_8288.225557177662.pth
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.



/usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 2. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.



Run history:
epoch                ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
test/accuracy        ▁
train/accuracy       ▁▁▂▂▃▃▃▃▃▃▄▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▆▇▇▇▇▇▇████
train/loss           █████▇▇▇▇▇▇▇▆▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁
trainer/global_step  ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███

Run summary:
epoch                200.0
test/accuracy        0.51649
train/accuracy       0.764
train/loss           0.74951
trainer/global_step  1600.0


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
[34m[1mwandb[0m: Currently logged in as: [33mmdueck[0m. Use [1m`wandb login --relogin`[0m to force relogin


INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type             | Params
-------------------------------------------
0 | conv1 | Attention        | 229   
1 | conv2 | Attention        | 4.2 K 
2 | conv3 | Attention        | 8.3 K 
3 | lin1  | Linear           | 2.1 K 
4 | lin2  | Linear           | 198   
5 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
14.9 K    Trainable params
0         Non-trainable params
14.9 K    Total params
0.060     Total estimated model params size (MB)
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batches (8) is smaller than the logging interval Trainer(log_ev


INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=200` reached.
INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at att_43_9513.734404915314.pth
INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from the checkpoint at att_43_9513.734404915314.pth
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.



Run history:
epoch                ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
test/accuracy        ▁
train/accuracy       ▁▁▁▁▂▂▂▂▃▃▃▃▃▄▃▄▄▄▅▅▅▅▅▆▅▆▆▆▆▆▇▇▇▇▇▇████
train/loss           ████████▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▄▄▄▄▄▃▃▃▃▂▂▂▂▂▁▁▁
trainer/global_step  ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███

Run summary:
epoch                200.0
test/accuracy        0.3151
train/accuracy       0.776
train/loss           0.73845
trainer/global_step  1600.0


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.


INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type             | Params
-------------------------------------------
0 | conv1 | Attention        | 229   
1 | conv2 | Attention        | 4.2 K 
2 | conv3 | Attention        | 8.3 K 
3 | lin1  | Linear           | 2.1 K 
4 | lin2  | Linear           | 198   
5 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
14.9 K    Trainable params
0         Non-trainable params
14.9 K    Total params
0.060     Total estimated model params size (MB)
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batches (8) is smaller than the logging interval Trainer(log_ev


INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=200` reached.
INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at att_44_2925.481188496739.pth
INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from the checkpoint at att_44_2925.481188496739.pth
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.



Run history:
epoch                ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
test/accuracy        ▁
train/accuracy       ▁▁▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▄▅▅▆▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇█
train/loss           █████▇▇▇▇▇▇▇▇▆▆▆▆▅▅▅▅▅▄▄▄▄▃▄▃▃▃▃▂▂▂▂▂▂▁▁
trainer/global_step  ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███

Run summary:
epoch                200.0
test/accuracy        0.3151
train/accuracy       0.614
train/loss           1.10803
trainer/global_step  1600.0
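Because the loop trains one model per seed, the three runs can be summarised as the mean and sample standard deviation of their `test/accuracy` values (0.51649, 0.3151 and 0.3151 in the run summaries above). The sketch below does this with the standard library only. The dictionary is copied by hand from the logs, and the helper name `summarize` is hypothetical; in practice the values could instead be collected from the return value of `trainer.test` inside the loop.

```python
import statistics
from typing import Dict, Iterable, Tuple

# test/accuracy values copied by hand from the three W&B run summaries (seeds 42, 43, 44)
test_accuracy: Dict[int, float] = {42: 0.51649, 43: 0.3151, 44: 0.3151}


def summarize(values: Iterable[float]) -> Tuple[float, float]:
    """Return (mean, sample standard deviation) of a collection of scores."""
    values = list(values)
    return statistics.mean(values), statistics.stdev(values)


if __name__ == "__main__":
    mean, std = summarize(test_accuracy.values())
    print(f"test accuracy over {len(test_accuracy)} seeds: {mean:.4f} ± {std:.4f}")
```

With only three seeds, the standard deviation is a coarse estimate of run-to-run variability.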
