In [1]:
## Standard libraries
import os
import json
import math
import numpy as np
import time

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline
# NOTE(review): `IPython.display.set_matplotlib_formats` is deprecated in recent IPython releases in
# favour of `matplotlib_inline.backend_inline.set_matplotlib_formats`; the stray warning line in this
# cell's output presumably comes from that. Confirm against the installed IPython before switching.
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()
sns.set()

## Progress bar
from tqdm.notebook import tqdm

## To run JAX on TPU in Google Colab, uncomment the two lines below
# import jax.tools.colab_tpu
# jax.tools.colab_tpu.setup_tpu()

## JAX
import jax
import jax.numpy as jnp
from jax import random
# Seeding for random operations
main_rng = random.PRNGKey(42)

## Flax (NN in JAX)
try:
    import flax
except ModuleNotFoundError: # Install flax if missing
    !pip install --quiet flax
    import flax
from flax import linen as nn
from flax.training import train_state, checkpoints

## Optax (Optimizers in JAX)
try:
    import optax
except ModuleNotFoundError: # Install optax if missing
    !pip install --quiet optax
    import optax

## PyTorch
import torch
import torch.utils.data as data
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms

# Folder where the datasets (e.g. CIFAR10) are stored, or should be downloaded to if missing
DATASET_PATH = "../../data"
# Folder where the pretrained models of this tutorial are saved
CHECKPOINT_PATH = "../../saved_models/tutorial7_jax"

# Show the first device JAX reports, to confirm which backend the notebook runs on
print("Device:", jax.devices()[0])

  set_matplotlib_formats('svg', 'pdf') # For export


Device: TFRT_CPU_0


In [2]:
import urllib.request
from urllib.error import HTTPError
# Github URL where saved models are stored for this tutorial
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/JAX/tutorial7/"
# Files to download (paths relative to base_url; may contain sub-folders)
pretrained_files = []

# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name)
    # Make sure the target folder exists. os.path.dirname replaces the manual split on "/",
    # so the parent folder is found regardless of which path separator ends up in file_path.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    if os.path.isfile(file_path):
        # Already downloaded in a previous run - nothing to do
        continue
    file_url = base_url + file_name
    print(f"Downloading {file_url}...")
    try:
        urllib.request.urlretrieve(file_url, file_path)
    except HTTPError as e:
        # Best effort: report the failure and carry on with the remaining files
        print("Something went wrong. Please contact the author with the full output including the following error:\n", e)

In [3]:
class GCNLayer(nn.Module):
    """Graph convolution layer: a shared linear projection followed by a mean over each node's neighbourhood."""
    c_out : int  # Output feature size

    @nn.compact
    def __call__(self, node_feats, adj_matrix):
        """
        Inputs:
            node_feats - Array with node features of shape [batch_size, num_nodes, c_in]
            adj_matrix - Batch of adjacency matrices of shape [batch_size, num_nodes, num_nodes].
                         Row i selects (with entries of 1) the nodes whose projected features are
                         averaged into node i, so non-symmetric matrices describe directed graphs.
                         The identity connections are assumed to be included already.
        Returns:
            Array of shape [batch_size, num_nodes, c_out] with the averaged neighbourhood features.
        """
        # Same dense projection for every node; the explicit name fixes the parameter key
        projected = nn.Dense(features=self.c_out, name='projection')(node_feats)
        # Row i of the product is the sum of the projected features of all nodes selected by row i
        aggregated = jax.lax.batch_matmul(adj_matrix, projected)
        # Row sums give the neighbourhood sizes; dividing by them turns the sum into a mean
        neighbourhood_size = adj_matrix.sum(axis=-1, keepdims=True)
        return aggregated / neighbourhood_size

In [4]:
# Toy input: a single graph (batch size 1) with 4 nodes and 2 features per node
node_feats = jnp.arange(8, dtype=jnp.float32).reshape(1, 4, 2)
# Adjacency matrix of the toy graph, self-connections included
adj_matrix = jnp.array(
    [[[1, 1, 0, 0],
      [1, 1, 1, 1],
      [0, 1, 1, 1],
      [0, 1, 1, 1]]],
    dtype=jnp.float32,
)

print("Node features:\n", node_feats)
print("\nAdjacency matrix:\n", adj_matrix)

Node features:
 [[[0. 1.]
  [2. 3.]
  [4. 5.]
  [6. 7.]]]

Adjacency matrix:
 [[[1. 1. 0. 0.]
  [1. 1. 1. 1.]
  [0. 1. 1. 1.]
  [0. 1. 1. 1.]]]


In [5]:
# Hand-picked parameters instead of random initialization: an identity projection with zero bias,
# so the output of the layer is simply the average of the input features over each neighbourhood
params = {
    'projection': {
        'kernel': jnp.eye(2),
        'bias': jnp.zeros(2),
    }
}
layer = GCNLayer(c_out=2)
out_feats = layer.apply({'params': params}, node_feats, adj_matrix)

print("Adjacency matrix", adj_matrix)
print("Input features", node_feats)
print("Output features", out_feats)

Adjacency matrix [[[1. 1. 0. 0.]
  [1. 1. 1. 1.]
  [0. 1. 1. 1.]
  [0. 1. 1. 1.]]]
Input features [[[0. 1.]
  [2. 3.]
  [4. 5.]
  [6. 7.]]]
Output features [[[1. 2.]
  [3. 4.]
  [4. 5.]
  [4. 5.]]]


#### Attention-based Graph Neural Network

In [6]:
class GATLayer(nn.Module):
    """
    Multi-head graph attention layer working on dense, batched adjacency matrices.

    Node features are linearly projected and split into heads. For every ordered pair of nodes (i, j)
    and every head, the logit LeakyReLU(a^T [Wh_i || Wh_j]) is computed, pairs without an edge are
    masked out, and the new features of node i are the softmax-weighted sum of the projected features
    of the nodes j it is connected to.
    """
    c_out : int  # Dimensionality of output features
    num_heads : int  # Number of heads, i.e. attention mechanisms to apply in parallel.
    concat_heads : bool = True  # If True, the output of the different heads is concatenated instead of averaged.
    alpha : float = 0.2  # Negative slope of the LeakyReLU activation.

    def setup(self):
        """Creates the projection layer and the per-head attention vector `a`."""
        if self.concat_heads:
            assert self.c_out % self.num_heads == 0, "Number of output features must be a multiple of the count of heads."
            c_out_per_head = self.c_out // self.num_heads
        else:
            c_out_per_head = self.c_out

        # Sub-modules and parameters needed in the layer
        self.projection = nn.Dense(c_out_per_head * self.num_heads,
                                   kernel_init=nn.initializers.glorot_uniform())
        # Shape [num_heads, 2 * c_out_per_head]: first half scores Wh_i, second half scores Wh_j
        self.a = self.param('a',
                            nn.initializers.glorot_uniform(),
                            (self.num_heads, 2 * c_out_per_head))  # One per head


    def __call__(self, node_feats, adj_matrix, print_attn_probs=False):
        """
        Inputs:
            node_feats - Input features of the nodes. Shape: [batch_size, num_nodes, c_in]
            adj_matrix - Adjacency matrix including self-connections. Shape: [batch_size, num_nodes, num_nodes].
                         An entry adj_matrix[b,i,j] == 1 lets node i attend to node j.
            print_attn_probs - If True, the attention weights are printed during the forward pass (for debugging purposes)
        Returns:
            Array of shape [batch_size, num_nodes, c_out]. The heads are concatenated along the last
            axis if concat_heads is True, and averaged otherwise.
        """
        batch_size, num_nodes = node_feats.shape[0], node_feats.shape[1]

        # Apply linear layer and sort nodes by head -> [batch_size, num_nodes, num_heads, c_out_per_head]
        node_feats = self.projection(node_feats)
        node_feats = node_feats.reshape((batch_size, num_nodes, self.num_heads, -1))

        # We need to calculate the attention logits for every edge in the adjacency matrix
        # In order to take advantage of JAX's just-in-time compilation, we should not use
        # arrays with shapes that depend on e.g. the number of edges. Hence, we calculate
        # the logit for every possible combination of nodes. For efficiency, we can split
        # a[Wh_i||Wh_j] = a_:d/2 * Wh_i + a_d/2: * Wh_j.
        # The split point is half of the LAST axis of `a` (= c_out_per_head). Axis 0 is the number of
        # heads, so deriving the split from it is only correct when num_heads == 2 * c_out_per_head.
        half = self.a.shape[-1] // 2
        logit_parent = (node_feats * self.a[None,None,:,:half]).sum(axis=-1)
        logit_child = (node_feats * self.a[None,None,:,half:]).sum(axis=-1)
        # attn_logits[b,i,j,h] = score of node i (parent) attending to node j (child) in head h
        attn_logits = logit_parent[:,:,None,:] + logit_child[:,None,:,:]
        attn_logits = nn.leaky_relu(attn_logits, self.alpha)

        # Mask out nodes that do not have an edge between them. The very large negative value
        # makes their softmax probability vanish next to any real edge.
        attn_logits = jnp.where(adj_matrix[...,None] == 1.,
                                attn_logits,
                                jnp.ones_like(attn_logits) * (-9e15))

        # Weighted average of attention: normalize over the attended nodes j (axis 2)
        attn_probs = nn.softmax(attn_logits, axis=2)
        if print_attn_probs:
            # Printed as [batch_size, num_heads, num_nodes, num_nodes] for readability
            print("Attention probs\n", attn_probs.transpose(0, 3, 1, 2))
        node_feats = jnp.einsum('bijh,bjhc->bihc', attn_probs, node_feats)

        # If heads should be concatenated, we can do this by reshaping. Otherwise, take mean
        if self.concat_heads:
            node_feats = node_feats.reshape(batch_size, num_nodes, -1)
        else:
            node_feats = node_feats.mean(axis=2)

        return node_feats

In [7]:
# Hand-picked parameters: identity projection with zero bias, plus one attention vector per head
# (first entry scores the attending node, second entry scores the attended node)
params = {
    'projection': {
        'kernel': jnp.eye(2),
        'bias': jnp.zeros(2),
    },
    'a': jnp.array([[-0.2, 0.3], [0.1, -0.1]]),
}
layer = GATLayer(c_out=2, num_heads=2)
out_feats = layer.apply({'params': params}, node_feats, adj_matrix, print_attn_probs=True)

print("Adjacency matrix", adj_matrix)
print("Input features", node_feats)
print("Output features", out_feats)

Attention probs
 [[[[0.35434368 0.6456563  0.         0.        ]
   [0.10956531 0.14496915 0.264151   0.48131454]
   [0.         0.18580715 0.28850412 0.5256887 ]
   [0.         0.23912403 0.26961157 0.49126434]]

  [[0.5099987  0.49000132 0.         0.        ]
   [0.2975179  0.24358703 0.23403586 0.22485918]
   [0.         0.38382432 0.31424877 0.3019269 ]
   [0.         0.40175956 0.3289329  0.2693075 ]]]]
Adjacency matrix [[[1. 1. 0. 0.]
  [1. 1. 1. 1.]
  [0. 1. 1. 1.]
  [0. 1. 1. 1.]]]
Input features [[[0. 1.]
  [2. 3.]
  [4. 5.]
  [6. 7.]]]
Output features [[[1.2913126 1.9800026]
  [4.2344294 3.7724724]
  [4.679763  4.836205 ]
  [4.50428   4.7350955]]]
