In [1]:
import sys
sys.path.append("../")

import jax
import jax.numpy as np
from functools import partial
import jraph
from jax_md import space, partition

from models.graph_utils import nearest_neighbors, nearest_neighbors_ann, RadiusSearch

In [2]:
# Fixed PRNG key so every random draw in the notebook is reproducible.
rng = jax.random.PRNGKey(42)

# Slice sizes applied to the loaded array: (batch, nodes per batch element, position dims).
n_batch, n_nodes, n_pos = 4, 5000, 3
# Number of neighbours requested from the kNN helpers.
k = 10

# NOTE(review): absolute cluster path; the notebook only runs where this file is mounted.
halos_path = "/n/holyscratch01/iaifi_lab/ccuesta/data_for_sid/halos.npy"
x = np.load(halos_path)[:n_batch, :n_nodes, :n_pos]

## Brute force kNN

In [3]:
# Build neighbour edge lists for the first batch element with both helpers,
# using the same k, so the following cells can compare their outputs.
sources, targets = nearest_neighbors(x[0], k=k)
sources_ann, targets_ann = nearest_neighbors_ann(x[0], k=k)

In [4]:
# Exact-equality check of the source and target index arrays returned by the two helpers.
np.array_equal(sources_ann, sources), np.array_equal(targets_ann, targets)

(Array(True, dtype=bool), Array(False, dtype=bool))

In [5]:
# Count the target entries on which the two helpers agree element-wise.
(targets_ann == targets).sum()

Array(28243, dtype=int32)

In [6]:
%%timeit
# Time the `nearest_neighbors_ann` helper on a single batch element.
sources_ann, targets_ann = nearest_neighbors_ann(x[0], k=k)

4.88 ms ± 4.28 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
%%timeit
# Time the `nearest_neighbors` helper on the same input for comparison.
sources, targets = nearest_neighbors(x[0], k=k)

7.04 ms ± 3.89 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Using jax-md

Scale up without having to recompute the full pairwise distance matrix. Let's see whether this approach can scale to 50000 nodes.

In [8]:
# Configure the radius-based neighbour search and build the initial neighbour
# list from the first batch element only.
# NOTE(review): box_size and cutoff are presumably in the same length units as x — confirm.
ns = RadiusSearch(box_size=1000., cutoff=30.)
nbr = ns.init_neighbor_lst(x[0])

In [17]:
%%timeit
# Time the neighbour-list update on the full batch `x`; the second return value is discarded.
nbr_update, _ = ns.update_neighbor_lst(x, nbr)

87.1 ms ± 5.07 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [18]:
# Update over the full batch and keep the result. This rebinds `nbr` to the
# second return value, so re-running this cell starts from the previous run's `nbr`.
nbr_update, nbr = ns.update_neighbor_lst(x, nbr)

In [19]:
# Inspect the neighbour index array of the batched update.
nbr_update.idx

Array([[[   0,  331,    1, ..., 5000, 5000, 5000],
        [   0,    0,    1, ..., 5000, 5000, 5000]],

       [[   0, 1379,    1, ..., 5000, 5000, 5000],
        [   0,    0,    1, ..., 5000, 5000, 5000]],

       [[   0, 3839,    1, ..., 5000, 5000, 5000],
        [   0,    0,    1, ..., 5000, 5000, 5000]],

       [[   0, 4977, 4138, ..., 5000, 5000, 5000],
        [   0,    0,    0, ..., 5000, 5000, 5000]]], dtype=int32)

In [20]:
# Shape [batch, (senders, receivers), -1]: axis 1 holds the two index rows and
# the last axis is the edge slot.
# NOTE(review): trailing entries equal to n_nodes (5000 in the printed array)
# look like padding slots — confirm against the neighbour-list implementation.
nbr_update.idx.shape

(4, 2, 29868)

## Use in graph

In [21]:
from typing import Callable
import flax.linen as nn
import jax.numpy as jnp

from models.mlp import MLP


def get_node_mlp_updates(mlp_feature_sizes: list) -> Callable:
    """Get a node MLP update function.

    Args:
        mlp_feature_sizes (list): feature sizes forwarded unchanged to ``MLP``
            (``GraphConvNet`` passes a list of ints, one entry per layer)

    Returns:
        Callable: update function
    """

    def update_fn(
        nodes: jnp.ndarray,
        sent_attributes: jnp.ndarray,
        received_attributes: jnp.ndarray,
        globals: jnp.ndarray,
    ) -> jnp.ndarray:
        """Update node features.

        Every input that is not None is concatenated along axis 1 (so all of
        them must share their leading dimension) and passed through the MLP.

        Args:
            nodes (jnp.ndarray): node features
            sent_attributes (jnp.ndarray): attributes sent to neighbors;
                accepted to match the update-function signature but not used
            received_attributes (jnp.ndarray): attributes received from
                neighbors, or None
            globals (jnp.ndarray): global features, or None

        Returns:
            jnp.ndarray: updated node features
        """
        # Keep the original feature order and drop whatever is missing, so a
        # graph without edge messages and/or globals does not crash the concat.
        candidates = (nodes, received_attributes, globals)
        inputs = jnp.concatenate(
            [feature for feature in candidates if feature is not None], axis=1
        )
        return MLP(mlp_feature_sizes)(inputs)

    return update_fn

def get_edge_mlp_updates(mlp_feature_sizes: list) -> Callable:
    """Get an edge MLP update function.

    Args:
        mlp_feature_sizes (list): feature sizes forwarded unchanged to ``MLP``
            (``GraphConvNet`` passes a list of ints, one entry per layer)

    Returns:
        Callable: update function
    """

    def update_fn(
        edges: jnp.ndarray,
        senders: jnp.ndarray,
        receivers: jnp.ndarray,
        globals: jnp.ndarray,
    ) -> jnp.ndarray:
        """Update edge features.

        Every input that is not None is concatenated along axis 1 (so all of
        them must share their leading dimension) and passed through the MLP.

        Args:
            edges (jnp.ndarray): edge attributes, or None
            senders (jnp.ndarray): senders node attributes
            receivers (jnp.ndarray): receivers node attributes
            globals (jnp.ndarray): global features, or None

        Returns:
            jnp.ndarray: updated edge features
        """
        # Keep the original feature order and drop whatever is missing, so a
        # graph without edge attributes and/or globals does not crash the concat.
        candidates = (edges, senders, receivers, globals)
        inputs = jnp.concatenate(
            [feature for feature in candidates if feature is not None], axis=1
        )
        return MLP(mlp_feature_sizes)(inputs)

    return update_fn

In [22]:
class GraphConvNet(nn.Module):
    """A simple graph convolutional network.

    Node features are embedded with a dense layer, then refined by a fixed
    number of ``jraph.GraphNetwork`` message-passing rounds whose node and edge
    updates are MLPs.
    """
    # Width of the node embedding and of every MLP layer.
    latent_size: int = 32
    # Number of entries in the feature-size list handed to each MLP.
    num_mlp_layers: int = 3
    # Number of message-passing rounds.
    message_passing_steps: int = 3
    # If True, add each round's output to its input via add_graphs_tuples.
    skip_connections: bool = True
    # If True, apply LayerNorm to the node features after every round.
    layer_norm: bool = True

    @nn.compact
    def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Do message passing on graph

        Args:
            graphs (jraph.GraphsTuple): graph object

        Returns:
            jraph.GraphsTuple: updated graph object
        """
        # Embed the raw node features into the latent space.
        embedder = jraph.GraphMapFeatures(embed_node_fn=nn.Dense(self.latent_size))
        processed_graphs = embedder(graphs)
        # Flatten the globals to 2D (leading axis kept) so the update functions
        # can concatenate them with node/edge features along axis 1.
        processed_graphs = processed_graphs._replace(
            globals=processed_graphs.globals.reshape(processed_graphs.globals.shape[0], -1),
        )

        mlp_feature_sizes = [self.latent_size] * self.num_mlp_layers
        update_node_fn = get_node_mlp_updates(mlp_feature_sizes)
        update_edge_fn = get_edge_mlp_updates(mlp_feature_sizes)

        # Now, we will apply the GCN once for each message-passing round.
        for _ in range(self.message_passing_steps):
            graph_net = jraph.GraphNetwork(
                update_node_fn=update_node_fn,
                update_edge_fn=update_edge_fn,
            )
            if self.skip_connections:
                # add_graphs_tuples is imported from models.gnn elsewhere in the
                # notebook; it only has to be in scope when the model is called.
                processed_graphs = add_graphs_tuples(
                    graph_net(processed_graphs), processed_graphs
                )
            else:
                processed_graphs = graph_net(processed_graphs)

            if self.layer_norm:
                processed_graphs = processed_graphs._replace(
                    nodes=nn.LayerNorm()(processed_graphs.nodes),
                )
        # Return the processed graph; returning the input `graphs` would
        # silently discard the embedding and every message-passing round.
        return processed_graphs

In [23]:
from models.gnn import add_graphs_tuples

In [24]:
# Random per-batch global features, shape (n_batch, 1, 4).
# NOTE(review): the name `globals` shadows the Python builtin in this namespace.
globals = jax.random.normal(key=rng, shape=(n_batch, 1, 4))

# Convert every batch element's neighbour list to a jraph graph in one vmapped call.
batched_to_jraph = jax.vmap(partition.to_jraph)
graph = batched_to_jraph(
    nbr_update,
    mask=None,
    nodes=x,
    edges=None,
    globals=globals,
)

In [25]:
# Shapes of the batched graph's fields; the leading axis comes from the vmap over the batch.
graph.n_node.shape, graph.n_edge.shape, graph.nodes.shape, graph.globals.shape, graph.senders.shape, graph.receivers.shape

((4, 2), (4, 2), (4, 5001, 3), (4, 2, 4), (4, 29868), (4, 29868))

In [26]:
class GraphWrapper(nn.Module):
    """Apply a single GraphConvNet, with shared parameters, to a batch of graphs."""

    @nn.compact
    def __call__(self, x):
        """Run GraphConvNet on every graph along the leading axis of ``x``.

        Args:
            x: batched graph whose array fields all carry a leading batch axis

        Returns:
            The batched output of GraphConvNet, with the same leading batch axis.
        """
        # Raw jax.vmap must not wrap a Flax module created inside a compact
        # method (Flax raises JaxTransformError); nn.vmap is the lifted form.
        batched_net = nn.vmap(
            GraphConvNet,
            in_axes=0,
            out_axes=0,
            variable_axes={"params": None},  # one parameter set shared by the whole batch
            split_rngs={"params": False},  # same init RNG for every batch element
        )
        return batched_net()(x)

In [None]:
# Initialise the parameters and run one forward pass in a single call;
# init_with_output returns (output, variables).
# NOTE(review): `rng` was already used to draw `globals`; consider splitting the key.
model = GraphWrapper()
graph_update, params  = model.init_with_output(rng, graph)
graph_update.nodes.shape  # Output nodes