In [53]:
import jax
import jax.numpy as np
import jraph
import flax.linen as nn

from functools import partial

In [54]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [55]:
import sys
sys.path.append("../")

from models.gnn import GraphConvNet
from models.graph_utils import nearest_neighbors

In [56]:
# Note: `np` is jax.numpy in this notebook (see the import cell), so `x` is a JAX array.
n_nodes = 5000  # how many entries to keep along axis 1 of the stored array
halos_path = "/n/holyscratch01/iaifi_lab/ccuesta/data_for_sid/halos.npy"
x = np.load(halos_path)[:, :n_nodes, :]

In [57]:
# Map `nearest_neighbors` over the leading axis of the first four entries of `x`;
# the second argument (20) is passed unbatched to every call.
# NOTE(review): 20 is presumably the number of neighbours per node - confirm
# against models.graph_utils.nearest_neighbors.
batched_nearest_neighbors = jax.vmap(nearest_neighbors, in_axes=(0, None))
sources, targets = batched_nearest_neighbors(x[:4], 20)

In [58]:
# Batched GraphsTuple for the first four entries of `x`; every field carries a
# leading axis of length 4 (edges are left unset).
# NOTE(review): n_edge is filled with 20 per graph. If `sources` holds more than
# 20 edges per graph (e.g. 20 per node), this should be the true edge count - confirm.
# NOTE(review): globals is a (4, 7) array of ones, presumably a placeholder - confirm.
graph = jraph.GraphsTuple(
    nodes=x[:4],
    edges=None,
    receivers=targets,
    senders=sources,
    globals=np.ones((4, 7)),
    n_node=np.array(4 * [[n_nodes]]),
    n_edge=np.array(4 * [[20]]),
)

## EGNN

In [59]:
import jax.numpy as jnp
from jax import jit

def rotation_matrix(angle_deg, axis):
    """Build a 3x3 rotation matrix from an angle in degrees and an axis vector.

    The axis is normalised to unit length, then four half-angle parameters are
    formed: ``a = cos(angle / 2)`` and ``(b, c, d) = -unit_axis * sin(angle / 2)``.
    Because of the minus sign on ``(b, c, d)``, left-multiplying a column vector
    by the result turns it by ``-angle_deg`` about ``axis`` in the right-handed
    sense (for example, 90 degrees about z maps x-hat to -y-hat).

    Args:
        angle_deg: rotation angle in degrees.
        axis: length-3 array giving the rotation axis; it does not need to be
            normalised, but a zero vector yields NaNs through the division.

    Returns:
        A ``(3, 3)`` array.
    """
    half_angle = jnp.radians(angle_deg) / 2
    unit_axis = axis / jnp.linalg.norm(axis)

    a = jnp.cos(half_angle)
    b, c, d = -unit_axis * jnp.sin(half_angle)

    # Pairwise products reused by the matrix entries below.
    aa, bb, cc, dd = a * a, b * b, c * c, d * d
    ab, ac, ad = a * b, a * c, a * d
    bc, bd, cd = b * c, b * d, c * d

    row_0 = [aa + bb - cc - dd, 2 * (bc - ad), 2 * (bd + ac)]
    row_1 = [2 * (bc + ad), aa + cc - bb - dd, 2 * (cd - ab)]
    row_2 = [2 * (bd - ac), 2 * (cd + ab), aa + dd - bb - cc]
    return jnp.array([row_0, row_1, row_2])

@jit
def rotate_representation(data, angle_deg, axis):
    """Rotate the vector-valued columns of a 2-D feature array.

    Columns 0-2 and columns 3-5 are each multiplied by the matrix returned from
    ``rotation_matrix(angle_deg, axis)``; any columns from index 6 onwards are
    passed through unchanged.

    Args:
        data: 2-D array whose rows are feature vectors laid out as
            ``[position (3), velocity (3), scalars (...)]``.
        angle_deg: rotation angle in degrees, forwarded to ``rotation_matrix``.
        axis: length-3 rotation axis, forwarded to ``rotation_matrix``.

    Returns:
        Array with the same column layout as ``data``.
    """
    rot_mat = rotation_matrix(angle_deg, axis)

    # Split points 3 and 6 give the column groups [:3], [3:6] and [6:].
    positions, velocities, scalars = jnp.split(data, [3, 6], axis=1)

    def _rotate_rows(vectors):
        # Treat each row as a column vector, rotate, and restore row layout.
        return jnp.matmul(rot_mat, vectors.T).T

    return jnp.concatenate(
        [_rotate_rows(positions), _rotate_rows(velocities), scalars], axis=1
    )

In [60]:
from typing import Callable
import flax.linen as nn
import jax.numpy as jnp
import jraph

from models.graph_utils import add_graphs_tuples
from models.mlp import MLP

def get_edge_mlp_updates() -> Callable:
    """Build the edge update function used by the EGNN-style network below.

    Takes no arguments; the MLP sizes are fixed inside the returned function.

    Returns:
        Callable: function ``(edges, senders, receivers, globals)`` that returns
        a pair ``(weighted coordinate difference, message)`` for every edge.
    """

    def update_fn(
        edges: jnp.ndarray,
        senders: jnp.ndarray,
        receivers: jnp.ndarray,
        globals: jnp.ndarray,
    ) -> jnp.ndarray:
        """Compute the per-edge message and coordinate update term.

        Args:
            edges (jnp.ndarray): current edge attributes (not read here)
            senders (jnp.ndarray): node attributes of each edge's sender
            receivers (jnp.ndarray): node attributes of each edge's receiver
            globals (jnp.ndarray): global features (not read here)

        Returns:
            tuple: ``((x_i - x_j) * phi_x(m_ij), m_ij)`` where ``i`` is the
            sender and ``j`` the receiver of the edge.
        """
        # Columns 0-2 are coordinates and columns 6+ are scalar attributes.
        # Columns 3-5 (velocities) do not enter the edge update.
        pos_send, attr_send = senders[:, :3], senders[:, 6:]
        pos_recv, attr_recv = receivers[:, :3], receivers[:, 6:]

        # Message networks, Eqs. (3) and (4)/(7) in the original author's numbering.
        # NOTE(review): the list argument is presumably the MLP layer widths -
        # confirm against models.mlp.MLP.
        phi_e = MLP([64, 64])
        phi_x = MLP([64, 1])

        rel_pos = pos_send - pos_recv
        sq_dist = jnp.linalg.norm(rel_pos, axis=1, keepdims=True) ** 2

        # The message depends only on the scalar attributes and the squared distance.
        m_ij = phi_e(jnp.concatenate([attr_send, attr_recv, sq_dist], axis=-1))
        return rel_pos * phi_x(m_ij), m_ij

    return update_fn

def get_node_mlp_updates(n_neighbors: float = 20.0) -> Callable:
    """Build the node update function used by the EGNN-style network below.

    Args:
        n_neighbors (float): divisor applied to the aggregated coordinate
            messages before they enter the velocity update. Defaults to 20,
            the value that used to be hard-coded here.
            NOTE(review): 20 presumably mirrors the neighbour count passed to
            ``nearest_neighbors`` elsewhere in this notebook - confirm.

    Returns:
        Callable: function ``(nodes, senders, receivers, globals)`` returning
        the updated node attributes.
    """

    def update_fn(
        nodes: jnp.ndarray,
        senders: jnp.ndarray,
        receivers: jnp.ndarray,
        globals: jnp.ndarray,
    ) -> jnp.ndarray:
        """Update node coordinates, velocities and scalar attributes.

        Args:
            nodes (jnp.ndarray): node attributes laid out as
                ``[coordinates (3), velocities (3), scalars (...)]``
            senders: edge features aggregated over edges each node sends (not read here)
            receivers: pair ``(sum_x_ij, m_i)`` of aggregated edge outputs for
                the edges each node receives
            globals (jnp.ndarray): global features (not read here)

        Returns:
            jnp.ndarray: updated node attributes, same column layout as ``nodes``
        """
        sum_x_ij, m_i = receivers  # Aggregated outputs of the edge update
        x_i, v_i, h_i = nodes[:, :3], nodes[:, 3:6], nodes[:, 6:]  # Split node attrs

        # From Eqs. (6) and (7) in the original author's numbering.
        # NOTE(review): the list argument is presumably the MLP layer widths -
        # confirm against models.mlp.MLP.
        phi_v = MLP([64, 1])
        phi_h = MLP([64, h_i.shape[-1]])

        # Velocity: normalised coordinate messages plus a learned scaling of the old velocity.
        v_i_p = sum_x_ij / n_neighbors + phi_v(h_i) * v_i
        # Coordinates move by the new velocity.
        x_i_p = x_i + v_i_p
        # Scalars are recomputed from the old scalars and the aggregated messages.
        h_i_p = phi_h(jnp.concatenate([h_i, m_i], -1))

        return jnp.concatenate([x_i_p, v_i_p, h_i_p], -1)

    return update_fn

class GraphConvNet(nn.Module):
    """Message-passing network built from the edge/node update functions above.

    NOTE(review): this definition shadows the ``GraphConvNet`` imported from
    ``models.gnn`` earlier in the notebook; later cells use this one.
    """

    # Number of times the graph network is applied in __call__.
    message_passing_steps: int
    # NOTE(review): declared but not read anywhere in __call__.
    skip_connections: bool = True
    # NOTE(review): declared but not read anywhere in __call__.
    layer_norm: bool = True

    @nn.compact
    def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Do message passing on graph

        Args:
            graphs (jraph.GraphsTuple): graph object

        Returns:
            jraph.GraphsTuple: graph object after ``message_passing_steps``
            applications of the graph network
        """
        update_node_fn = get_node_mlp_updates()
        update_edge_fn = get_edge_mlp_updates()

        # The GraphNetwork wrapper is identical for every round, so build it
        # once. The update functions still construct their MLPs on each call,
        # so hoisting does not change which modules get created.
        graph_net = jraph.GraphNetwork(
            update_node_fn=update_node_fn, update_edge_fn=update_edge_fn
        )

        processed_graphs = graphs
        for _ in range(self.message_passing_steps):
            processed_graphs = graph_net(processed_graphs)

        return processed_graphs

In [61]:
class GraphWrapper(nn.Module):
    """Apply a single-step GraphConvNet to every entry along the leading axis."""

    @nn.compact
    def __call__(self, x):
        gnn = GraphConvNet(message_passing_steps=1, skip_connections=True)
        # jax.vmap with its default in_axes maps over axis 0 of every array in `x`.
        # NOTE(review): Flax also offers nn.vmap for lifting modules; check
        # whether it is the better fit once parameters are reused via apply().
        batched_gnn = jax.vmap(gnn)
        return batched_gnn(x)


model = GraphWrapper()
rng = jax.random.PRNGKey(312)  # fixed seed, reused for every init below

## Test equivariance

In [62]:
def _make_graph(nodes):
    """Wrap a batch of node features in a GraphsTuple with the shared edge lists."""
    return jraph.GraphsTuple(
        n_node=np.array(4 * [[n_nodes]]),
        n_edge=np.array(4 * [[20]]),
        nodes=nodes,
        edges=None,
        globals=None,
        senders=sources,
        receivers=targets,
    )


# Apply the same rotation to every entry along the leading axis.
_rotate_batch = jax.vmap(rotate_representation, in_axes=(0, None, None))

angle_deg = 30
axis = jnp.array([0, 1, 0])

# Path 1: run the model on the original nodes, then rotate its output.
graph = _make_graph(x[:4, :, :])
graph_out, _ = model.init_with_output(rng, graph)
x_out_rot = _rotate_batch(graph_out.nodes, angle_deg, axis)

# Path 2: rotate the nodes first, then run the model. The same `rng` is used for
# both initialisations, and `graph`, `graph_out`, `x_out` end up holding the
# results of this second path.
graph = _make_graph(_rotate_batch(x[:4, :, :], angle_deg, axis))
graph_out, _ = model.init_with_output(rng, graph)
x_out = graph_out.nodes

In [63]:
# Equivariance ratio: element-wise model(rotated input) / rotate(model(input)).
# Entries close to 1 mean the two paths agree; entries whose denominator is
# near zero can deviate strongly even when the absolute difference is small.
jnp.divide(x_out, x_out_rot)

Array([[[0.99968636, 0.99905145, 1.0019348 , ..., 0.99905145,
         1.0019348 , 1.0000005 ],
        [0.99963415, 1.0244254 , 0.9853004 , ..., 1.0244254 ,
         0.9853004 , 1.        ],
        [1.0003184 , 0.9998195 , 1.0005709 , ..., 0.9998195 ,
         1.0005709 , 0.99999994],
        ...,
        [1.0000124 , 1.000107  , 0.9994209 , ..., 1.000107  ,
         0.9994209 , 1.0000002 ],
        [0.9995397 , 0.99955285, 0.9989698 , ..., 0.99955285,
         0.9989698 , 1.0000012 ],
        [1.0004483 , 0.99985904, 1.0009383 , ..., 0.99985904,
         1.0009383 , 1.0000004 ]],

       [[0.99998456, 1.0002854 , 1.0000325 , ..., 1.0002854 ,
         1.0000325 , 1.        ],
        [1.0005915 , 1.0005159 , 1.0001292 , ..., 1.0005159 ,
         1.0001292 , 1.        ],
        [0.9998185 , 0.99903196, 1.0007113 , ..., 0.99903196,
         1.0007113 , 1.        ],
        ...,
        [1.0001528 , 1.0001528 , 0.99961823, ..., 1.0001528 ,
         0.99961823, 0.99999946],
        [0.9