In [1]:
import jax
import jraph
import numpy as np
import jax.numpy as jnp

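`jraph` stores graphs in the `GraphsTuple` data type, which we initialize by providing the following fields (V = total number of nodes, E = total number of edges, G = number of graphs):
1. number of nodes per graph, `n_node`, shape G
2. number of edges per graph, `n_edge`, shape G
3. node features, `nodes`, V × d_V
4. edge features, `edges`, E × d_E
5. global features, `globals`, G × d_G

Edges in jraph are directed (an undirected graph is stored by including each edge in both directions), so we additionally provide the connectivity:

6. sender node indices, `senders`, an integer array of length E
7. receiver node indices, `receivers`, an integer array of length E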

In [2]:
# One toy graph: 3 nodes, 2 edges and a single set of global features.
n_node = jnp.array([3])
n_edge = jnp.array([2])
n_global = jnp.array([1])

# Feature widths for nodes, edges and globals.
d_node, d_edge, d_global = 4, 5, 6

# All-ones feature matrices with one row per node / edge / graph.
node_feats = jnp.ones((int(n_node[0]), d_node))
edge_feats = jnp.ones((int(n_edge[0]), d_edge))
global_feats = jnp.ones((int(n_global[0]), d_global))

# Directed edges 0 -> 2 and 1 -> 2, as parallel sender/receiver index arrays.
sender_nodes, receiver_nodes = jnp.array([0, 1]), jnp.array([2, 2])


In [3]:
# Bundle the toy graph's features, connectivity and counts into one GraphsTuple.
single_graph = jraph.GraphsTuple(
    nodes=node_feats,
    edges=edge_feats,
    senders=sender_nodes,
    receivers=receiver_nodes,
    globals=global_feats,
    n_node=n_node,
    n_edge=n_edge,
)
single_graph

GraphsTuple(nodes=DeviceArray([[1., 1., 1., 1.],
             [1., 1., 1., 1.],
             [1., 1., 1., 1.]], dtype=float32), edges=DeviceArray([[1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.]], dtype=float32), receivers=DeviceArray([2, 2], dtype=int32), senders=DeviceArray([0, 1], dtype=int32), globals=DeviceArray([[1., 1., 1., 1., 1., 1.]], dtype=float32), n_node=DeviceArray([3], dtype=int32), n_edge=DeviceArray([2], dtype=int32))

In [4]:
# Same graph, but with the per-graph count fields deliberately left as None.
single_graph_ = jraph.GraphsTuple(
    nodes=node_feats,
    edges=edge_feats,
    senders=sender_nodes,
    receivers=receiver_nodes,
    globals=global_feats,
    n_node=None,
    n_edge=None,
)
single_graph_

GraphsTuple(nodes=DeviceArray([[1., 1., 1., 1.],
             [1., 1., 1., 1.],
             [1., 1., 1., 1.]], dtype=float32), edges=DeviceArray([[1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.]], dtype=float32), receivers=DeviceArray([2, 2], dtype=int32), senders=DeviceArray([0, 1], dtype=int32), globals=DeviceArray([[1., 1., 1., 1., 1., 1.]], dtype=float32), n_node=None, n_edge=None)

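Curiously, we can omit some fields when initializing and get away with it. `GraphsTuple` is a plain named tuple, so it performs no validation and does not infer `n_node` and `n_edge` from the other fields; they simply stay `None`. Be careful, though: jraph utilities such as `jraph.GraphNetwork`, `jraph.batch` and `jraph.pad_with_graphs` read these counts, so a graph built without them will generally not work with those functions.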

A graph neural network (GNN) is usually composed of repeated whole-graph updates (propagation / message-passing steps). As a sanity check, we define update functions for edges, nodes, and globals that are identity maps, and then use them to build a GNN from one of the library's pre-defined models (`jraph.GraphNetwork`).

API for Jraph models can be found here: https://github.com/deepmind/jraph/blob/master/jraph/_src/models.py

In [5]:
def update_edge_fn(
    edge_features,
    sender_node_features,
    receiver_node_features,
    globals_):
    """Identity edge update.

    Ignores the sender, receiver and global inputs and hands the edge
    features back untouched.
    """
    del sender_node_features, receiver_node_features, globals_  # unused
    return edge_features

def update_node_fn(
    node_features,
    aggregated_sender_edge_features,
    aggregated_receiver_edge_features,
    globals_):
    """Identity node update.

    Ignores the aggregated edge messages and the globals and hands the node
    features back untouched.
    """
    del aggregated_sender_edge_features, aggregated_receiver_edge_features, globals_  # unused
    return node_features

def update_globals_fn(
    aggregated_node_features,
    aggregated_edge_features,
    globals_):
    """Identity global update.

    Ignores the aggregated node and edge features and hands the current
    globals back untouched.
    """
    del aggregated_node_features, aggregated_edge_features  # unused
    return globals_

In [6]:
# Wire the three identity functions into jraph's general GraphNetwork model.
test_gnn = jraph.GraphNetwork(
    update_edge_fn=update_edge_fn,
    update_node_fn=update_node_fn,
    update_global_fn=update_globals_fn,
)
test_gnn

<function jraph._src.models.GraphNetwork.<locals>._ApplyGraphNet(graph)>

In [7]:
# Run one propagation step of the identity GraphNetwork over the toy graph.
updated_graph = test_gnn(graph=single_graph)

In [9]:
# Identity updates must leave every field of the graph unchanged; assert that
# explicitly instead of eyeballing the printed repr, then display the result.
assert jax.tree_util.tree_all(
    jax.tree_util.tree_map(jnp.array_equal, updated_graph, single_graph)
), "identity GraphNetwork unexpectedly changed the graph"
updated_graph

GraphsTuple(nodes=DeviceArray([[1., 1., 1., 1.],
             [1., 1., 1., 1.],
             [1., 1., 1., 1.]], dtype=float32), edges=DeviceArray([[1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.]], dtype=float32), receivers=DeviceArray([2, 2], dtype=int32), senders=DeviceArray([0, 1], dtype=int32), globals=DeviceArray([[1., 1., 1., 1., 1., 1.]], dtype=float32), n_node=DeviceArray([3], dtype=int32), n_edge=DeviceArray([2], dtype=int32))


So we successfully performed an update, and the data is unchanged because we used identity maps for all component updates. Voilà! Next, we work through a practical binary classification example using a GNN.

For context, here is the description of the ogbg-molhiv example, copied from jraph's official repo:
```
The ogbg-molhiv dataset is a molecular property prediction dataset.
It is adopted from the MoleculeNet [1]. All the molecules are pre-processed
using RDKit [2].

Each graph represents a molecule, where nodes are atoms, and edges are chemical
bonds. Input node features are 9-dimensional, containing atomic number and
chirality, as well as other additional atom features such as formal charge and
whether the atom is in the ring or not.
The goal is to predict whether a molecule inhibits HIV virus replication or not.
Performance is measured in ROC-AUC.

This script uses a GraphNet to learn the prediction task.

Refs:
[1] Zhenqin Wu, Bharath Ramsundar, Evan N Feinberg, Joseph Gomes,
Caleb Geniesse, Aneesh SPappu, Karl Leswing, and Vijay Pande.
Moleculenet: a benchmark for molecular machine learning.
Chemical Science, 9(2):513–530, 2018.
[2] Greg Landrum et al. RDKit: Open-source cheminformatics, 2006.
```

In [10]:
# NOTE(review): `hk` (dm-haiku) is used below but is not imported in the import
# cell at the top of this notebook; add `import haiku as hk` there.


def _mlp(output_size: int, hidden_size: int = 128) -> "hk.Sequential":
    """Builds the Linear -> ReLU -> Linear MLP shared by all update functions."""
    return hk.Sequential(
        [hk.Linear(hidden_size), jax.nn.relu,
         hk.Linear(output_size)])


# `jraph.concatenated_args` concatenates all of an update function's inputs
# into a single feature array `feats` before the function is called.
@jraph.concatenated_args
def edge_update_fn(feats: jnp.ndarray) -> jnp.ndarray:
    """Edge update function for graph net."""
    return _mlp(128)(feats)


@jraph.concatenated_args
def node_update_fn(feats: jnp.ndarray) -> jnp.ndarray:
    """Node update function for graph net."""
    return _mlp(128)(feats)


@jraph.concatenated_args
def update_global_fn(feats: jnp.ndarray) -> jnp.ndarray:
    """Global update function for graph net."""
    # Molhiv is a binary classification task, so output pos neg logits.
    return _mlp(2)(feats)

In [11]:
def net_fn(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Embeds a graph with linear layers and applies one GraphNetwork step.

    A zero-valued global feature (one row per graph) is attached first, so
    that the global update function has something to produce the per-graph
    classification logits from.
    """
    n_graph = graph.n_node.shape[0]
    # Add a global parameter for graph classification.
    graph_with_globals = graph._replace(globals=jnp.zeros([n_graph, 1]))
    # Construction order edge -> node -> global is kept so Haiku assigns the
    # same module (and hence parameter) names as before.
    embedder = jraph.GraphMapFeatures(
        embed_edge_fn=hk.Linear(128),
        embed_node_fn=hk.Linear(128),
        embed_global_fn=hk.Linear(128))
    embedded = embedder(graph_with_globals)
    graph_net = jraph.GraphNetwork(
        update_edge_fn=edge_update_fn,
        update_node_fn=node_update_fn,
        update_global_fn=update_global_fn)
    return graph_net(embedded)

In [13]:
# Utility functions that pad graphs to a small set of fixed sizes, so the graph network is recompiled less often for different input sizes.
def _nearest_bigger_power_of_two(x: int) -> int:
    """Computes the nearest power of two greater than x for padding."""
    y = 2
    while y < x:
        y *= 2
    return y


def pad_graph_to_nearest_power_of_two(
    graphs_tuple: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Pads a batched `GraphsTuple` to the nearest power of two.

    For example, if a `GraphsTuple` has 7 nodes, 5 edges and 3 graphs, this
    method would pad the `GraphsTuple` nodes and edges:
      7 nodes --> 8 nodes (2^3)
      5 edges --> 8 edges (2^3)
    And since padding is accomplished using `jraph.pad_with_graphs`, an extra
    graph and node is added:
      8 nodes --> 9 nodes
      3 graphs --> 4 graphs
    Counts that are already a power of two are not rounded up further
    (e.g. 8 edges stay 8 edges).

    Args:
      graphs_tuple: a batched `GraphsTuple` (can be batch size 1). Its
        `n_node` and `n_edge` fields must be set.

    Returns:
      A graphs_tuple batched to the nearest power of two.
    """
    # Reduce the totals to Python ints once, so the power-of-two search
    # compares plain ints instead of dispatching an array comparison per step.
    total_nodes = int(jnp.sum(graphs_tuple.n_node))
    total_edges = int(jnp.sum(graphs_tuple.n_edge))
    # Add 1 since we need at least one padding node for pad_with_graphs.
    pad_nodes_to = _nearest_bigger_power_of_two(total_nodes) + 1
    pad_edges_to = _nearest_bigger_power_of_two(total_edges)
    # Add 1 since we need at least one padding graph for pad_with_graphs.
    # We do not pad to nearest power of two because the batch size is fixed.
    pad_graphs_to = graphs_tuple.n_node.shape[0] + 1
    return jraph.pad_with_graphs(graphs_tuple, pad_nodes_to, pad_edges_to,
                                 pad_graphs_to)