In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [2]:
from patch_gnn.graph import extract_neighborhood, generate_feature_dataframe
from pyprojroot import here
from proteingraph import read_pdb

TypeError: jit(): incompatible function arguments. The following argument types are supported:
    1. (arg0: function, arg1: function, arg2: function, arg3: List[int]) -> jaxlib.xla_extension.jax_jit.CompiledFunction

Invoked with: <function _rfft_transpose at 0x7fcd1c4f73a0>, <function _cpp_jit.<locals>.cache_miss at 0x7fcd1c4f7430>, <function _cpp_jit.<locals>.get_device_info at 0x7fcd1c4f74c0>, <function _cpp_jit.<locals>.get_jax_enable_x64 at 0x7fcd1c4f7550>, <function _cpp_jit.<locals>.get_jax_disable_jit_flag at 0x7fcd1c4f75e0>, (0, 2)

## Introduction

In this notebook, I will demo how to use the functions provided in `patch_gnn` and other packages
to do graph neural network training.

## Read in example data

This is an example dataset, an HIV Protease, from the PDB.

Firstly, we use `proteingraph.read_pdb` to get back a Graph object.

_Note: This should later be delegated to `graphein`._

In [None]:
# Parse the example PDB file into a graph object.
hiv_pdb_path = here() / "data" / "hiv1_homology_model.pdb"
hiv_graph = read_pdb(hiv_pdb_path)
hiv_graph

Quickly inspect what nodes are present.

In [None]:
# Display the nodes present in the parsed graph.
hiv_graph.nodes()

Now, let's generate all patches of radius 3 from the graph.

In [None]:
from patch_gnn.graph import generate_patches

In [None]:
# Radius of each patch extracted from the graph.
patch_radius = 3
graph_patches = generate_patches(hiv_graph, patch_radius)

Let's visualize a few of them to see what they look like.

In [None]:
import networkx as nx

# Draw the first patch.
first_patch = graph_patches[0]
nx.draw(first_patch)

In [None]:
# Draw the patch ten positions from the end of the list.
late_patch = graph_patches[-10]
nx.draw(late_patch)

## Generate the node input data

We are going to now generate the node feature matrices for each of the graphs.

In [None]:
import pandas as pd

In [None]:
# Amino-acid property table; the first CSV column becomes the row index.
aa_properties_path = here() / "data" / "amino_acid_properties.csv"
aa_feats = pd.read_csv(aa_properties_path, index_col=0)
aa_feats

In [None]:
def featurize_amino_acid(n, d, aa_feats: pd.DataFrame) -> pd.Series:
    """
    Look up the feature vector for a single residue node.

    The residue's three-letter name is read from the node attributes and
    used to select a column of ``aa_feats`` (so ``aa_feats`` is assumed to
    have one column per residue name — TODO confirm against the CSV).

    :param n: Graph node; used as the name of the returned Series.
    :param d: Graph node attributes; must contain ``"residue_name"``.
    :param aa_feats: Dataframe containing amino acid features.
    :returns: The selected feature column, renamed to ``n``.
    """
    residue = d["residue_name"]
    column = aa_feats[residue]
    return pd.Series(column, name=n)

Next up, collect the node featurization functions into a list.
We must use `partial` to ensure that each function's signature is limited to `n, d`.

In [None]:
from functools import partial

# Bind the property table so the featurizer only takes (n, d).
featurize_with_table = partial(featurize_amino_acid, aa_feats=aa_feats)
feature_funcs = [featurize_with_table]

Now, we stack the feature tensors for all graphs together.

In [None]:
from patch_gnn.graph import stack_feature_tensors

# Stack the node feature matrices of every patch into a single tensor.
Fs = stack_feature_tensors(
    graph_patches,
    funcs=feature_funcs,
)
Fs.shape

In [None]:
import numpy as np
from patch_gnn.graph import (
    identity_matrix, 
    adjacency_matrix, 
    laplacian_matrix,
    to_adjacency_xarray
)

### Adjacency tensors

Next up, we stack the adjacency tensors together.
We are going to use 5 adjacency-like matrices, the 1st-3rd power adjacency matrices,
followed by the identity matrix and the graph laplacian matrix.

In [None]:
from functools import partial

# Adjacency matrices raised to the 1st, 2nd and 3rd power, as described above.
# (A power of 0 would presumably just reproduce the identity matrix, which is
# already added explicitly below.)
adjacency_funcs = [
    partial(adjacency_matrix, power=power, name=f"adjacency_{power}")
    for power in range(1, 4)
]
adjacency_funcs.extend(
    [
        identity_matrix,
        laplacian_matrix,
    ]
)

In [None]:
# Display the adjacency-like functions that will be stacked per patch.
adjacency_funcs

In [None]:
# `stack_adjacency_tensors` is not imported by any earlier cell; import it here
# so the notebook survives Restart & Run All.
# NOTE(review): assumed to live next to `stack_feature_tensors` in
# `patch_gnn.graph` — confirm.
from patch_gnn.graph import stack_adjacency_tensors

As = stack_adjacency_tensors(graph_patches, funcs=adjacency_funcs)

In [None]:
# Shape of the stacked adjacency tensor.
As.shape

## Now, we build the neural network layers.

In [None]:
import jax.numpy as np

from jax import lax, vmap, jit, grad
from jax.experimental import stax

from patch_gnn.layers import MessagePassing, GraphAverage, GraphSummation

### Example model that we might write

Firstly, we might want a custom graph embedding.

Here, we stack together a message passing layer,
followed by a Dense-Sigmoid transformation,
followed by a graph summation op,
then another linear projection to `n_output` dimensions (256 in this demo).

In [None]:
def CustomGraphEmbedding(n_output: int, n_hidden: int = 2048):
    """
    Return a stax layer that embeds a graph in ``n_output`` dimensions.

    Layer sequence: MessagePassing -> Dense(n_hidden) -> Sigmoid
    -> GraphSummation -> Dense(n_output).

    :param n_output: Dimensionality of the final graph embedding.
    :param n_hidden: Width of the hidden Dense layer that follows message
        passing. Defaults to 2048, the previously hard-coded width.
    :returns: The ``(init_fun, apply_fun)`` pair produced by ``stax.serial``.
    """
    init_fun, apply_fun = stax.serial(
        MessagePassing(),
        stax.Dense(n_hidden),
        stax.Sigmoid,
        GraphSummation(),
        stax.Dense(n_output),
    )
    return init_fun, apply_fun

# Stand-alone 256-dimensional embedding, used later to inspect embeddings.
embedding_init_fun, embedding_apply_fun = CustomGraphEmbedding(n_output=256)

In [None]:
def LinearRegression(num_outputs):
    """
    Build a linear-regression head as a stax layer.

    :param num_outputs: Number of output units.
    :returns: The ``(init_fun, apply_fun)`` pair from ``stax.serial`` wrapping
        a single ``stax.Dense(num_outputs)`` layer.
    """
    dense_layer = stax.Dense(num_outputs)
    return stax.serial(dense_layer)

def LogisticRegression(num_outputs):
    """
    Build a logistic-regression head as a stax layer.

    With ``num_outputs > 1`` this is multinomial logistic regression: a Dense
    layer followed by a softmax across the output units. With a single output
    a softmax would always return 1.0, so a sigmoid is used instead, which
    gives the binary logistic-regression probability.

    :param num_outputs: Number of output units / classes.
    :returns: The ``(init_fun, apply_fun)`` pair produced by ``stax.serial``.
    """
    # Softmax over a single unit is identically 1, which would make the head useless.
    activation = stax.Sigmoid if num_outputs == 1 else stax.Softmax
    init_fun, apply_fun = stax.serial(
        stax.Dense(num_outputs),
        activation,
    )
    return init_fun, apply_fun

# `PRNGKey` is not imported by any earlier cell; bring it into scope here.
from jax.random import PRNGKey

# Full model: 256-d graph embedding followed by a single-output linear head.
model_init_fun, model_apply_fun = stax.serial(
    CustomGraphEmbedding(256),
    LinearRegression(1),
)

# Fixed seed so parameter initialisation is reproducible.
# NOTE(review): input_shape packs one patch's feature-matrix shape plus the last
# axis of one patch's adjacency tensor; confirm this matches what
# MessagePassing's init_fun expects.
output_shape, params = model_init_fun(PRNGKey(42), input_shape=(*Fs[0].shape, As[0].shape[-1]))

### Now, we pass the data through the model!

In [None]:
# Pair adjacency and feature tensors; vmap maps the model over the leading
# axis of both (presumably one entry per patch).
inputs = (As, Fs)
batched_model = vmap(partial(model_apply_fun, params))
out = batched_model(inputs)
out.shape

In [None]:
# params[0] holds the parameters of the first serial stage (the graph
# embedding), so the stand-alone embedding apply_fun can reuse them.
embed_one = partial(embedding_apply_fun, params[0])
embedding = vmap(embed_one)(inputs)
embedding.shape

### Now try a trivial learning task: fitting random numbers.

In [None]:
import numpy as onp

# Seeded generator so the random regression targets are reproducible.
rng = onp.random.default_rng(42)
outputs = rng.normal(size=(len(graph_patches), 1))
outputs

Now, we try to predict these random numbers (one per patch)!

In [None]:
from patch_gnn.training import mseloss

dloss = grad(mseloss)

# Baseline loss using the freshly initialised parameters.
initial_loss = mseloss(params, model_apply_fun, inputs, outputs)
initial_loss

Remember this loss — it's pretty high. We'll compare against it after training.

In [None]:
from jax import grad
from jax.experimental.optimizers import adam

from patch_gnn.training import mseloss

# Gradient of the MSE loss with respect to its first argument (the params).
dmseloss = grad(mseloss)

In [None]:
import jax
from typing import Tuple


# `step` was used here but never defined or imported in this notebook, so the
# cell failed on a fresh kernel. Define it explicitly.
def step(i, state, dloss_fun, apply_fun, update_fun, get_params, inputs, outputs):
    """
    Take one optimizer step.

    :param i: Step index, forwarded to ``update_fun``.
    :param state: Current optimizer state.
    :param dloss_fun: Gradient of the loss w.r.t. params, called as
        ``dloss_fun(params, apply_fun, inputs, outputs)``.
    :param apply_fun: Model apply function, passed through to ``dloss_fun``.
    :param update_fun: Optimizer update, called as ``update_fun(i, grads, state)``.
    :param get_params: Extracts the current params from the optimizer state.
    :param inputs: Model inputs.
    :param outputs: Regression targets.
    :returns: The updated optimizer state.
    """
    current_params = get_params(state)
    grads = dloss_fun(current_params, apply_fun, inputs, outputs)
    return update_fun(i, grads, state)


init, update, get_params = adam(step_size=1e-3)
get_params = jit(get_params)
state = init(params)

random_training_step = partial(
    step,
    dloss_fun=dmseloss,
    apply_fun=model_apply_fun,
    update_fun=update,
    get_params=get_params,
    inputs=inputs,
    outputs=outputs,
)
random_training_step = jit(random_training_step)

In [None]:
from tqdm.autonotebook import tqdm

# Number of optimizer steps to run.
N_TRAINING_STEPS = 1000
for i in tqdm(range(N_TRAINING_STEPS)):
    state = random_training_step(i, state, inputs=inputs, outputs=outputs)

In [None]:
# Loss after training, to compare with the initial loss.
params_final = get_params(state)
final_loss = mseloss(params_final, model_apply_fun, inputs, outputs)
final_loss

In [None]:
# Loss with the original (untrained) parameters, for comparison with the trained loss.
mseloss(params, model_apply_fun, inputs, outputs)

In [None]:
# Predictions from the trained parameters, one per patch.
preds = vmap(lambda example: model_apply_fun(params_final, example))(inputs)


In [None]:
# Predictions from the untrained parameters, one per patch.
original_preds = vmap(lambda example: model_apply_fun(params, example))(inputs)

In [None]:
import matplotlib.pyplot as plt

# Predicted vs. target values after training; labelled so the figure stands alone.
fig, ax = plt.subplots()
ax.scatter(preds.squeeze(), outputs.squeeze())
ax.set(
    title="Trained model: prediction vs. target",
    xlabel="Predicted value",
    ylabel="Target (random) value",
);

In [None]:
# Predicted vs. target values before training; labelled so the figure stands alone.
fig, ax = plt.subplots()
ax.scatter(original_preds.squeeze(), outputs.squeeze())
ax.set(
    title="Untrained model: prediction vs. target",
    xlabel="Predicted value",
    ylabel="Target (random) value",
);

## Are the graphs distinguishable?

In [None]:
# Embedding for first graph (index 0), unoptimized params.
# params[0] are the untrained embedding-stage parameters.
vmap(partial(embedding_apply_fun, params[0]))(inputs)[0]

In [None]:
# Embedding for second graph (index 1), unoptimized params.
vmap(partial(embedding_apply_fun, params[0]))(inputs)[1]

In [None]:
# Embedding for first graph (index 0), optimized params.
vmap(partial(embedding_apply_fun, params_final[0]))(inputs)[0]

In [None]:
# Embedding for second graph (index 1), optimized params.
vmap(partial(embedding_apply_fun, params_final[0]))(inputs)[1]

In [None]:
from jax.tree_util import tree_map, tree_flatten, tree_multimap

In [None]:
# Split the parameter pytree into its leaves and the tree structure needed to
# rebuild it, then show the structure object's type.
arr, unflattener = tree_flatten(params)
unflattener.__class__

In [None]:
def array_diff(a1, a2):
    """Return the elementwise difference ``a1 - a2``."""
    return a1 - a2

# Per-leaf mean and standard deviation of (initial - trained) parameters.
param_deltas = tree_multimap(array_diff, params, params_final)
tree_map(np.mean, param_deltas), tree_map(np.std, param_deltas)

## GRAVEYARD

In [None]:
# Deliberately stop "Run All" here: the cells below are retired scratch work and
# are not expected to run (they reference `subG` and `F`, which are never
# defined in this notebook).
raise Exception("You've hit the graveyard!")

In [None]:
# NOTE(review): `subG` is not defined anywhere in this notebook; this cell only
# works with leftover kernel state.
G_adj = adjacency_matrix(subG)
G_adj

In [None]:
# NOTE(review): `F` is not defined anywhere in this notebook.
G_adj.shape, F.shape

In [None]:
# This is message passing in linear algebra form: multiplying the feature
# matrix by the adjacency matrix aggregates each node's neighbours' features
# (one hop). Assumes G_adj is (n_nodes, n_nodes) and F is (n_nodes, n_feats) — TODO confirm.
# Note: `np` is `jax.numpy` at this point (rebound by the stax import cell).
F1 = np.dot(G_adj, F)
F1.shape

In [None]:
# Second hop of propagation (equivalent to G_adj @ G_adj @ F).
F2 = np.dot(G_adj, F1)
F2

In [None]:
# Third hop of propagation.
F3 = np.dot(G_adj, F2)
F3

In [None]:
# Shape after three hops of propagation.
F3.shape