In [1]:
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
import jraph

import sys
sys.path.append('../')

from models.gnn import GNN

In [2]:
n_nodes = 5000
n_features = 3

x_train = np.load("../../hierarchical-encdec/data/set_diffuser_data/train_halos.npy")[:, :n_nodes, :n_features]
x_train = x_train / 1000.

x_train.shape

(1800, 5000, 3)

In [3]:
@partial(jax.jit, static_argnums=(1,))
def nearest_neighbors(
    x: jnp.array,
    k: int,
    mask: jnp.array = None,
):
    """Returns the nearest neighbors of each node in x.

    Args:
        x (jnp.array): positions of nodes
        k (int): number of nearest neighbors to find
        boxsize (float, optional): size of box if perdioc boundary conditions. Defaults to None.
        unit_cell (jnp.array, optional): unit cell for applying periodic boundary conditions. Defaults to None.
        mask (jnp.array, optional): node mask. Defaults to None.

    Returns:
        sources, targets: pairs of neighbors
    """
    if mask is None:
        mask = jnp.ones((x.shape[0],), dtype=np.int32)

    n_nodes = x.shape[0]

    # Compute the vector difference between positions
    dr = x[:, None, :] - x[None, :, :]

    # Calculate the distance matrix
    distance_matrix = jnp.linalg.norm(dr, axis=-1)

    distance_matrix = jnp.where(mask[:, None], distance_matrix, jnp.inf)
    distance_matrix = jnp.where(mask[None, :], distance_matrix, jnp.inf)

    indices = jnp.argsort(distance_matrix, axis=-1)[:, :k]

    sources = indices[:, 0].repeat(k)
    targets = indices.reshape(n_nodes * (k))

    return sources, targets, dr[sources, targets]

In [4]:
from jax.experimental.sparse import BCOO
import dataclasses

class DiffPool(nn.Module):
    n_downsamples: int = 2  # Number of downsample layers
    d_downsampling_factor: int = 4  # Downsampling factor at each layer
    k: int = 10  # Number of nearest neighbors to consider after downsampling
    gnn_kwargs: dict = dataclasses.field(default_factory=lambda: {"d_hidden":64, "n_layers":3})
    symmetric: bool = True  # Symmetrize the adjacency matrix
    task: str = "node"  # Node or graph task
    combine_hierarchies_method: str = "mean"  # How to aggregate hierarchical embeddings; TODO: impl attention

    @nn.compact
    def __call__(self, x):
        
        # If graph prediction task, collect pooled embeddings at each hierarchy level
        if self.task == "graph":
            x_pool = jnp.zeros((self.n_downsamples, self.gnn_kwargs['d_hidden']))

        for i in range(self.n_downsamples):
            
            # Original and downsampled number of nodes
            n_nodes = x.nodes.shape[0]
            n_nodes_downsampled = n_nodes // self.d_downsampling_factor

             # Eq. (5), graph embedding layer
            z = GNN(task='node', **self.gnn_kwargs)(x) 

            # Eq. (6), generate assignment matrix
            # Remove d_hidden from gnn_kwargs and replace it with n_nodes_downsampled
            gnn_kwargs = dict(self.gnn_kwargs.copy())
            gnn_kwargs['d_hidden'] = n_nodes_downsampled

            s = GNN(task='node', **gnn_kwargs,)(x).nodes  
            s = jax.nn.softmax(s, axis=1)  # Row-wise softmax
            
            # Sparse adjacency matrix
            edge_index = jnp.array([x.senders, x.receivers])
            edge_weight = nn.Dense(1)(x.edges)[..., 0]  # Edges might have more than one feature; project down
            a = BCOO((edge_weight, edge_index.T), shape=(n_nodes, n_nodes))
            
            # Eq. (3), coarsened node features
            x = s.T @ z.nodes  
            
            # Eq. (4), coarsened adjacency matrix)
            # Sparse matmul S^T @ A @ S
            a = s.T @ a @ s  

            # Make adj symmetric
            if self.symmetric:
                a = (a + a.T) / 2

            # Take the coarsened adjacency matrix and make a KNN graph of it
            indices = np.argsort(a, axis=-1)[:, :self.k]

            sources = indices[:, 0].repeat(self.k)
            targets = indices.reshape(n_nodes_downsampled * (self.k))

            # Create new graph
            x = jraph.GraphsTuple(
                nodes=x,
                edges=a[sources, targets][..., None],
                senders=sources,
                receivers=targets,
                globals=None,
                n_node=n_nodes_downsampled,
                n_edge=self.k,
            )

            # If graph prediction task, get hierarchical embeddings
            if self.task == "graph":
                x_pool = x_pool.at[i].set(jnp.mean(x.nodes, axis=0))

            
        if self.task == "graph":
            if self.combine_hierarchies_method == "mean":  # Mean over hierarchy levels
                x_pool = jnp.mean(x_pool, axis=0)
            else:
                raise ValueError(f"Unknown combine_hierarchies_method: {self.combine_hierarchies_method}")

            return (x, x_pool)
        
        return x

In [5]:
# Original graph

n_batch = 2
k = 5

sources, targets, distances = jax.vmap(nearest_neighbors, in_axes=(0, None))(x_train[:n_batch], k)

graph = jraph.GraphsTuple(
          n_node=np.array(n_batch * [[n_nodes]]), 
          n_edge=np.array(n_batch * [[k]]),
          nodes=x_train[:n_batch, :, :], 
          edges=np.linalg.norm(distances, axis=-1)[..., None],
          globals=None,
          senders=sources,
          receivers=targets)

gnn_kwargs = {"d_hidden": 64, "n_layers": 2}

model = DiffPool(n_downsamples=4, 
                 d_downsampling_factor=4, 
                 k=k,
                 gnn_kwargs=gnn_kwargs,
                 task='graph')
                 
rng = jax.random.PRNGKey(0)
(graph, x_pooled), params = jax.vmap(partial(model.init_with_output, rng))(graph)



In [6]:
print(f"We started with {n_nodes} nodes and downsampled by a factor of {model.d_downsampling_factor} {model.n_downsamples} times, so we should have {n_nodes // model.d_downsampling_factor**model.n_downsamples} nodes now.")

We started with 5000 nodes and downsampled by a factor of 4 4 times, so we should have 19 nodes now.


In [7]:
graph.nodes.shape

(2, 19, 64)

In [8]:
x_pooled

Array([[ 2.3131170e+00,  5.9959066e-01, -8.2782376e-01, -1.3430538e+00,
        -5.9773880e-01, -2.6204677e+00,  2.1586642e+00,  1.2976978e+00,
         1.8912312e+00, -9.0804523e-01,  1.4964399e+00, -2.2673976e-01,
         2.4867387e+00,  2.5094144e+00, -1.6578006e+00, -5.7519598e+00,
         4.2466869e+00, -3.0880690e-02, -1.8448907e+00,  1.4376051e+00,
         2.1066837e-01,  3.3694351e-01, -1.2878375e+00, -4.3846827e+00,
         4.5278823e-01, -9.2205226e-01,  4.5015043e-01, -6.6028333e-01,
         6.0408354e-02,  5.2189142e-02, -2.7403590e-01, -3.3584299e+00,
        -3.7789006e+00, -1.2336316e+00, -2.1005521e+00, -2.0387125e+00,
         2.0132084e+00,  2.4599993e-01,  1.0808258e+00,  1.0420587e+00,
        -1.9011109e+00, -1.0029547e+00, -1.1978803e-01,  1.0380412e+00,
         4.2466733e-01, -3.2987196e+00,  2.4638064e+00,  8.9791083e-01,
         4.9437940e-02, -2.7256694e+00,  1.3670267e-01,  3.0473718e-01,
        -2.0539993e-01,  2.0859385e+00,  2.6627638e+00,  6.63552