In [1]:
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
import jraph

import sys
sys.path.append('../')

from models.gnn import GNN

In [83]:
n_nodes = 4000
n_features = 3

x_train = np.load("../../hierarchical-encdec/data/set_diffuser_data/train_halos.npy")[:, :n_nodes, :n_features]
x_train = x_train / 1000.

x_train.shape

(1800, 4000, 3)

In [84]:
@partial(jax.jit, static_argnums=(1,))
def nearest_neighbors(
    x: jnp.array,
    k: int,
    mask: jnp.array = None,
):
    """Returns the nearest neighbors of each node in x.

    Args:
        x (jnp.array): positions of nodes
        k (int): number of nearest neighbors to find
        boxsize (float, optional): size of box if perdioc boundary conditions. Defaults to None.
        unit_cell (jnp.array, optional): unit cell for applying periodic boundary conditions. Defaults to None.
        mask (jnp.array, optional): node mask. Defaults to None.

    Returns:
        sources, targets: pairs of neighbors
    """
    if mask is None:
        mask = jnp.ones((x.shape[0],), dtype=np.int32)

    n_nodes = x.shape[0]

    # Compute the vector difference between positions
    dr = x[:, None, :] - x[None, :, :]

    # Calculate the distance matrix
    distance_matrix = jnp.linalg.norm(dr, axis=-1)

    distance_matrix = jnp.where(mask[:, None], distance_matrix, jnp.inf)
    distance_matrix = jnp.where(mask[None, :], distance_matrix, jnp.inf)

    indices = jnp.argsort(distance_matrix, axis=-1)[:, :k]

    sources = indices[:, 0].repeat(k)
    targets = indices.reshape(n_nodes * (k))

    return sources, targets, dr[sources, targets]

In [151]:
from jax.experimental.sparse import BCOO
import dataclasses

class DiffPool(nn.Module):
    n_downsamples: int = 2  # Number of downsample layers
    d_downsampling_factor: int = 4  # Downsampling factor at each layer
    k: int = 10  # Number of nearest neighbors to consider after downsampling
    gnn_kwargs: dict = dataclasses.field(default_factory=lambda: {"d_hidden":64, "n_layers":3})
    symmetric: bool = True  # Symmetrize the adjacency matrix
    task: str = "node"  # Node or graph task
    combine_hierarchies_method: str = "mean"  # How to aggregate hierarchical embeddings; TODO: impl attention
    use_edge_features: bool = False  # Whether to use edge features in adjacency matrix

    @nn.compact
    def __call__(self, x):
        
        # If graph prediction task, collect pooled embeddings at each hierarchy level
        if self.task == "graph":
            x_pool = jnp.zeros((self.n_downsamples, self.gnn_kwargs['d_output']))

        for i in range(self.n_downsamples):
            
            # Original and downsampled number of nodes
            n_nodes = x.nodes.shape[0]
            n_nodes_downsampled = n_nodes // self.d_downsampling_factor

             # Eq. (5), graph embedding layer
            z = GNN(task='node', **self.gnn_kwargs)(x) 

            # Eq. (6), generate assignment matrix
            # Remove d_hidden from gnn_kwargs and replace it with n_nodes_downsampled
            gnn_kwargs = dict(self.gnn_kwargs.copy())
            gnn_kwargs['d_output'] = n_nodes_downsampled

            s = GNN(task='node', **gnn_kwargs,)(x).nodes  
            s = jax.nn.softmax(s, axis=1)  # Row-wise softmax
            
            # Sparse adjacency matrix
            # If edge features, use them as weights, otherwise use 1 to indicate connectivity
            edge_index = jnp.array([x.senders, x.receivers])
            if self.use_edge_features:  
                edge_weight = nn.Dense(1)(x.edges)[..., 0]  # Edges might have more than one feature; project down
            else:
                edge_weight = jnp.ones((x.edges.shape[0],))

            a = BCOO((edge_weight, edge_index.T), shape=(n_nodes, n_nodes))
            
            # Eq. (3), coarsened node features
            x = s.T @ z.nodes  
            
            # Eq. (4), coarsened adjacency matrix)
            # Sparse matmul S^T @ A @ S
            a = s.T @ a @ s  

            # Make adj symmetric
            if self.symmetric:
                a = (a + a.T) / 2

            # Take the coarsened adjacency matrix and make a KNN graph of it
            indices = np.argsort(a, axis=-1)[:, :self.k]

            sources = indices[:, 0].repeat(self.k)
            targets = indices.reshape(n_nodes_downsampled * (self.k))

            # Create new graph
            x = jraph.GraphsTuple(
                nodes=x,
                edges=a[sources, targets][..., None],
                senders=sources,
                receivers=targets,
                globals=None,
                n_node=n_nodes_downsampled,
                n_edge=self.k,
            )

            # If graph prediction task, get hierarchical embeddings
            if self.task == "graph":
                x_pool = x_pool.at[i].set(jnp.mean(x.nodes, axis=0))

        if self.task == "graph":
            if self.combine_hierarchies_method == "mean":  # Mean over hierarchy levels
                x_pool = jnp.mean(x_pool, axis=0)
            elif self.combine_hierarchies_method == "concat":  # Max over hierarchy levels
                x_pool = jnp.concatenate(x_pool, axis=0)
            else:
                raise ValueError(f"Unknown combine_hierarchies_method: {self.combine_hierarchies_method}")

            return (x, x_pool)
        
        return x
    
class DiffPoolWrapper(nn.Module):
    model_kwargs: dict = dataclasses.field(default_factory=lambda: {})

    @nn.compact
    def __call__(self, x):
        return jax.vmap(DiffPool(**self.model_kwargs))(x)

In [162]:
# Original graph

n_batch = 2
k = 15

sources, targets, distances = jax.vmap(nearest_neighbors, in_axes=(0, None))(x_train[:n_batch], k)

graph = jraph.GraphsTuple(
          n_node=np.array(n_batch * [[n_nodes]]), 
          n_edge=np.array(n_batch * [[k]]),
          nodes=x_train[:n_batch, :, :], 
          edges=np.linalg.norm(distances, axis=-1)[..., None],
          globals=None,
          senders=sources,
          receivers=targets)

gnn_kwargs = {"d_hidden": 64, "d_output": 16, "n_layers": 2, "message_passing_steps":2}
                 
model = DiffPoolWrapper(model_kwargs={"n_downsamples": 4, 
                                "d_downsampling_factor": 4, 
                                "k": k,
                                "gnn_kwargs": gnn_kwargs,
                                "combine_hierarchies_method": 'mean',
                                "use_edge_features": False,
                                "task": 'graph'})

rng = jax.random.PRNGKey(0)
# rng = jax.random.split(rng, 1)
(graph, x_pooled), params = model.init_with_output(rng, graph)

In [153]:
print("Number of parameters:", sum(p.size for p in jax.tree_util.tree_flatten(params)[0]))

Number of parameters: 477615


4

In [159]:
print(f"We started with {n_nodes} nodes and downsampled by a factor of {model.model_kwargs['d_downsampling_factor']} {model.model_kwargs['n_downsamples']} times, so we should have {n_nodes // model.model_kwargs['d_downsampling_factor'] ** model.model_kwargs['n_downsamples']} nodes now.")

We started with 4000 nodes and downsampled by a factor of 4 4 times, so we should have 15 nodes now.


In [160]:
graph.nodes.shape

(2, 15, 16)

In [161]:
x_pooled

Array([[ 0.37964916, -1.8670151 ,  0.5420827 , -3.4975605 , -1.0746057 ,
        -0.5992161 ,  0.3162737 , -0.5658225 , -1.6766647 , -1.7225434 ,
        -0.38327152,  2.1413724 ,  0.10378033,  0.41748714, -0.48726815,
        -1.0837824 ],
       [ 0.38148662, -1.8926976 ,  0.52135444, -3.5302162 , -1.0936382 ,
        -0.64446616,  0.28737193, -0.5451107 , -1.6957139 , -1.7556255 ,
        -0.37822193,  2.1825082 ,  0.09021281,  0.39656502, -0.48781922,
        -1.122267  ]], dtype=float32)