In [1]:
import jax
import jax.numpy as np
import jraph
import flax.linen as nn

from functools import partial

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys
sys.path.append("../")

from models.gnn import GraphConvNet
from models.graph_utils import nearest_neighbors



In [4]:
x = np.load("../data/halos_small.npy")
n_nodes = 5000
k = 20
sources, targets, dist = jax.vmap(nearest_neighbors, in_axes=(0, None))(x[:4, :, :3], k)

In [68]:
graph = jraph.GraphsTuple(
          n_node=np.array(4 * [[n_nodes]]), 
          n_edge=np.array(4 * [[k]]),
          nodes=x[:4, :, :6], 
          edges=dist,
          globals=np.ones((4, 7)),
          senders=sources,
          receivers=targets)

In [69]:
graph.n_edge[0]

Array([20], dtype=int32)

In [70]:
graph.n_node.shape, graph.n_edge.shape, graph.nodes.shape, graph.globals.shape, graph.senders.shape, graph.receivers.shape, 

((4, 1), (4, 1), (4, 5000, 6), (4, 7), (4, 100000), (4, 100000))

In [71]:
graph.edges.shape

(4, 100000, 3)

In [72]:
class GraphWrapper(nn.Module):
    @nn.compact
    def __call__(self, x):
        model = jax.vmap(GraphConvNet(
            latent_size=12, 
            hidden_size=32,
            num_mlp_layers=4, 
            attention=False,
            message_passing_steps=3, 
            shared_weights=False,
            skip_connections=True,))
        return model(x)

In [73]:
model = GraphWrapper()
rng = jax.random.PRNGKey(42)
graph_out, params = model.init_with_output(rng, graph)

graph_out

GraphsTuple(nodes=Array([[[-0.1125825 , -0.09283566,  0.11377222],
        [-0.05893194, -0.10181687,  0.07655472],
        [-0.1651399 , -0.19137335,  0.06769908],
        ...,
        [-0.07573469, -0.13577965,  0.09486147],
        [-0.14018403, -0.16820411,  0.08733329],
        [-0.24084674, -0.20037228, -0.04847446]],

       [[-0.11877929, -0.19698136, -0.03018252],
        [-0.17410675, -0.20649745,  0.09587927],
        [-0.07466327, -0.10741072,  0.10991385],
        ...,
        [-0.08268358, -0.10388442,  0.08842175],
        [-0.1477319 , -0.1747698 ,  0.08993898],
        [-0.25986856, -0.2063719 , -0.04190782]],

       [[-0.02476275, -0.0844294 , -0.01236729],
        [-0.12253042, -0.16551678,  0.08586553],
        [-0.13652821, -0.15572608, -0.02954918],
        ...,
        [-0.08189055, -0.11284494,  0.13404271],
        [-0.19673815, -0.22579373,  0.02905469],
        [-0.14372848, -0.16787201, -0.00698224]],

       [[-0.10140395, -0.1546606 ,  0.06181228],
      

In [74]:
sum(x.size for x in jax.tree_util.tree_leaves(params))

38763

In [75]:
graph_out.nodes.shape

(4, 5000, 3)

## DynamicEdgeConv test

In [76]:
from models.dynamic_edge_conv import DynamicEdgeConvNet

In [77]:
class GraphWrapper(nn.Module):
    @nn.compact
    def __call__(self, x):
        model = jax.vmap(DynamicEdgeConvNet(
            latent_size=12, 
            hidden_size=32,
            num_mlp_layers=4, 
            message_passing_steps=4, 
            k=k,
            skip_connections=True,))
        return model(x)

In [15]:
model = GraphWrapper()
rng = jax.random.PRNGKey(42)
graph_out, params = model.init_with_output(rng, graph)

graph_out

GraphsTuple(nodes=Array([[[-0.10018387, -0.11065444, -0.05212637],
        [-0.08652786, -0.09941588, -0.05340162],
        [-0.08675987, -0.09036908, -0.04361029],
        ...,
        [ 0.0063136 ,  0.09881737, -0.03693786],
        [ 0.01601949,  0.10301742, -0.01412692],
        [ 0.0157237 ,  0.11117488, -0.01114986]],

       [[-0.07606154, -0.08917262, -0.04142701],
        [-0.07613653, -0.06745652, -0.04296787],
        [-0.079873  , -0.0667052 , -0.05580064],
        ...,
        [-0.00079158,  0.079138  , -0.02717378],
        [ 0.01088912,  0.11066834, -0.01844661],
        [ 0.01167948,  0.11263683, -0.01662316]],

       [[-0.05494767, -0.03250584, -0.03126454],
        [-0.06130821, -0.03933219, -0.0337631 ],
        [-0.09967669, -0.10617372, -0.05868438],
        ...,
        [ 0.00834937,  0.10542567, -0.02426643],
        [ 0.02006699,  0.11076612, -0.00348963],
        [ 0.00723587,  0.10155639, -0.0291658 ]],

       [[-0.06585068, -0.03040348, -0.0412342 ],
      

## ChebConv

In [78]:
graph = jraph.GraphsTuple(
          n_node=np.array(4 * [[n_nodes]]), 
          n_edge=np.array(4 * [[k]]),
          nodes=(x[:4, :, :6] - 500) / 500, 
          edges=dist.sum(-1) / 1000.,
          globals=np.ones((4, 7)),
          senders=sources,
          receivers=targets)

In [103]:
from jax.experimental.sparse import BCOO, eye

In [104]:
from typing import Optional, Tuple
import jax
import jax.numpy as jnp

def get_laplacian(
    edge_index: jnp.ndarray,
    edge_weight: Optional[jnp.ndarray] = None,
    dtype: Optional[jnp.dtype] = None,
    num_nodes: Optional[int] = None
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
    
    if edge_weight is None:
        edge_weight = np.ones_like(edge_index[0])

    A = BCOO((edge_weight, edge_index.T), shape=(num_nodes,num_nodes))
    deg = A.sum(axis=0)
    D = BCOO((deg.todense(), np.array([np.arange(num_nodes), np.arange(num_nodes)]).T), shape=(num_nodes, num_nodes))
    L = D - A
    
    return L.indices.T, L.data
    
def __norm__(edge_index, edge_weight, lambda_max=None, num_nodes=5000):
    
    # Adjusting the get_laplacian function call to the correct format
    edge_index, edge_weight = get_laplacian(edge_index, edge_weight, num_nodes=num_nodes)

    assert edge_weight is not None, "Edge weights cannot be None after getting the Laplacian."

    # If lambda_max is not specified, calculate it as twice the max of the edge weights
    if lambda_max is None:
        lambda_max = 2.0 * edge_weight.max()

    # Normalizing edge weights
    edge_weight = (2.0 * edge_weight) / lambda_max

    return edge_index, edge_weight

edge_index, edge_weight = __norm__(edge_index=np.array([sources[0], targets[0]]), edge_weight=graph.edges[0], lambda_max=None)

In [172]:
from typing import Callable
import jraph
from jraph._src import utils
from models.mlp import MLP

def get_node_mlp_updates() -> Callable:
    """Get a node MLP update  function

    Args:
        mlp_feature_sizes (int): number of features in the MLP
        name (str, optional): name of the update function. Defaults to None.

    Returns:
        Callable: update function
    """

    def update_fn(
        nodes: jnp.ndarray,
        sent_attributes: jnp.ndarray,
        received_attributes: jnp.ndarray,
        globals: jnp.ndarray,
    ) -> jnp.ndarray:
        
        return received_attributes[..., None] * nodes

    return update_fn

# aggregate_edges_for_nodes_fn

class ChebConv(nn.Module):
    
    out_channels: int = 128
    K: int = 6
    
    @nn.compact
    def __call__(self, graph: jraph.GraphsTuple, lambda_max: float = None) -> jraph.GraphsTuple:
        """Do message passing on graph

        Args:
            graphs (jraph.GraphsTuple): graph object

        Returns:
            jraph.GraphsTuple: updated graph object
        """
        
        (senders, receivers), norm = __norm__(edge_index=np.array([graph.senders, graph.receivers]), edge_weight=graph.edges, lambda_max=lambda_max)
        
        Tx_0 = graph.nodes
        Tx_1 = graph.nodes
        out = nn.Dense(self.out_channels)(Tx_0)
        
        if self.K > 1:
            graph_Tx_1 = graph._replace(senders=senders, receivers=receivers, edges=norm)
            graph_Tx_1 = jraph.GraphNetwork(update_node_fn=get_node_mlp_updates(), update_edge_fn=None)(graph_Tx_1)
            Tx_1 = graph_Tx_1.nodes
            
            out = out + nn.Dense(self.out_channels)(Tx_1)
            
        for _ in range(2, self.K):
            graph_Tx_2 = graph._replace(nodes=Tx_1, senders=senders, receivers=receivers, edges=norm)
            Tx_2 = 2. * graph_Tx_2.nodes - Tx_0
            out = out + nn.Dense(self.out_channels)(Tx_2)
            Tx_0, Tx_1 = Tx_1, Tx_2
        
        return graph._replace(nodes=out)            
    
class AdaLayerNorm(nn.Module):
    """Adaptive layer norm; generate scale and shift parameters from conditioning context."""

    @nn.compact
    def __call__(self, x, conditioning):
        # Compute scale and shift parameters from conditioning context
        scale_and_shift = nn.gelu(nn.Dense(2 * x.shape[-1])(conditioning))
        scale, shift = np.split(scale_and_shift, 2, axis=-1)

        # Apply layer norm
        # Don't use bias or scale since these will be learnable through the conditioning context
        x = nn.LayerNorm(use_bias=False, use_scale=False)(x)

        # Apply scale and shift
        # Apple same scale, shift to all elements in sequence
        x = x * (1 + scale[None, :]) + shift[None, :]

        return x
    
class ChebConvNet(nn.Module):
    
    out_channels: int = 128
    K: int = 6
    message_passing_steps: int = 4
    
    @nn.compact
    def __call__(self, graph: jraph.GraphsTuple, lambda_max: float = None) -> jraph.GraphsTuple:
        
        in_channels = graph.nodes.shape[-1]
        
        for _ in range(self.message_passing_steps):
            graph = ChebConv(out_channels=self.out_channels, K=self.K)(graph)
            graph = graph._replace(nodes=AdaLayerNorm()(nn.gelu(graph.nodes), graph.globals))
            
        # Readout
        graph = graph._replace(nodes=nn.Dense(in_channels)(graph.nodes))
            
        return graph
            

In [175]:
from models.chebconv import ChebConvNet

class GraphWrapper(nn.Module):
    @nn.compact
    def __call__(self, x):
        model = jax.vmap(ChebConvNet())
        return model(x)



In [176]:
model = GraphWrapper()
rng = jax.random.PRNGKey(42)
graph_out, params = model.init_with_output(rng, graph)

graph_out.nodes[0]

Array([[-0.12269443,  0.14158587,  1.762475  ,  1.150424  ,  0.42675796,
         1.1578609 ],
       [-0.82985544,  1.5333217 , -0.181583  ,  0.8682801 , -1.1688771 ,
        -0.50984365],
       [-1.1721395 ,  0.7606144 , -0.02651902,  0.12594089, -0.6179613 ,
         2.012218  ],
       ...,
       [-0.8236913 , -0.2686318 , -1.1953758 , -0.18628658, -2.049673  ,
        -1.0728295 ],
       [-1.012834  ,  0.58206487, -0.21994793,  0.08183961, -1.3585956 ,
         1.6713423 ],
       [-0.8718533 , -0.62092274,  1.6060736 , -0.95375323, -0.02815629,
         1.244173  ]], dtype=float32)

In [177]:
sum(x.size for x in jax.tree_util.tree_leaves(params))

311558