In [1]:
import jax
import jax.numpy as np
import jraph
import flax.linen as nn

from functools import partial

In [2]:
%load_ext autoreload
%autoreload 2

In [197]:
import sys
sys.path.append("../")

from models.gnn import GraphConvNet
from models.graph_utils import nearest_neighbors
from models.graph_utils import add_graphs_tuples

In [4]:
n_nodes = 5000
x = np.load("/n/holyscratch01/iaifi_lab/ccuesta/data_for_sid/halos.npy")[:, :n_nodes,:]

In [5]:
sources, targets = jax.vmap(nearest_neighbors, in_axes=(0, None))(x[:4], 20)

## EGNN

In [7]:
import jax.numpy as jnp
from jax import jit

def rotation_matrix(angle_deg, axis):
    angle_rad = jnp.radians(angle_deg)
    axis = axis / jnp.linalg.norm(axis)
    
    a = jnp.cos(angle_rad / 2)
    b, c, d = -axis * jnp.sin(angle_rad / 2)
    
    return jnp.array([
        [a*a + b*b - c*c - d*d, 2*(b*c - a*d), 2*(b*d + a*c)],
        [2*(b*c + a*d), a*a + c*c - b*b - d*d, 2*(c*d - a*b)],
        [2*(b*d - a*c), 2*(c*d + a*b), a*a + d*d - b*b - c*c]
    ])

@jit
def rotate_representation(data, angle_deg, axis):
    rot_mat = rotation_matrix(angle_deg, axis)
    
    positions = data[:, :3]
    velocities = data[:, 3:6]
    scalars = data[:, 6:]
    
    rotated_positions = jnp.matmul(rot_mat, positions.T).T
    rotated_velocities = jnp.matmul(rot_mat, velocities.T).T
    
    return jnp.concatenate([rotated_positions, rotated_velocities, scalars], axis=1)

In [66]:
x = (x - x.mean((0,1))) / x.std((0, 1))

In [366]:
from typing import Callable
import flax.linen as nn
import jax.numpy as jnp
import jraph

from models.graph_utils import add_graphs_tuples
from models.mlp import MLP

class CoordNorm(nn.Module):
    """ Coordinate normalization, from 
    https://github.com/lucidrains/egnn-pytorch/blob/main/egnn_pytorch/egnn_pytorch.py#LL67C28-L67C28
    """
    eps: float = 1e-8
    scale_init: float = 1.

    def setup(self):
        self.scale = self.param('scale', nn.initializers.ones, (1,))

    def __call__(self, coors):
        norm = jnp.linalg.norm(coors, axis=-1, keepdims=True)
        normed_coors = coors / jax.lax.clamp(self.eps, norm, jnp.inf)
        return normed_coors * self.scale

def get_edge_mlp_updates(d_hidden, n_layers, activation) -> Callable:
    """Get an edge MLP update function

    Args:
        mlp_feature_sizes (int): number of features in the MLP

    Returns:
        Callable: update function
    """

    def update_fn(
        edges: jnp.ndarray,
        senders: jnp.ndarray,
        receivers: jnp.ndarray,
        globals: jnp.ndarray,
    ) -> jnp.ndarray:
        """update edge features

        Args:
            edges (jnp.ndarray): edge attributes
            senders (jnp.ndarray): senders node attributes
            receivers (jnp.ndarray): receivers node attributes
            globals (jnp.ndarray): global features

        Returns:
            jnp.ndarray: updated edge features
        """
                
        # Split senders and receivers into coordinates, velocities, and scalar attrs
        x_i, v_i, h_i = senders[:, :3], senders[:, 3:6], senders[:, 6:]
        x_j, v_j, h_j = receivers[:, :3], receivers[:, 3:6], receivers[:, 6:]
                
        # Messages from Eqs. (3) and (4)/(7) 
        phi_e = MLP([d_hidden] * (n_layers), activation=activation)
        phi_x = MLP([d_hidden] * (n_layers - 1) + [1], activation=activation)
        
        m_ij =  phi_e(jnp.concatenate([h_i, h_j, jnp.linalg.norm(x_i - x_j, axis=1, keepdims=True) ** 2, globals], axis=-1))
        return (x_i - x_j) * phi_x(m_ij), m_ij

    return update_fn

def get_node_mlp_updates(d_hidden, n_layers, activation) -> Callable:
    """Get an node MLP update function

    Args:
        mlp_feature_sizes (int): number of features in the MLP

    Returns:
        Callable: update function
    """

    def update_fn(
        nodes: jnp.ndarray,
        senders: jnp.ndarray,
        receivers: jnp.ndarray,
        globals: jnp.ndarray,
    ) -> jnp.ndarray:
        """update edge features

        Args:
            edges (jnp.ndarray): edge attributes
            senders (jnp.ndarray): senders node attributes
            receivers (jnp.ndarray): receivers node attributes
            globals (jnp.ndarray): global features

        Returns:
            jnp.ndarray: updated edge features
        """
        sum_x_ij, m_i = receivers  # Get aggregated messages
        x_i, v_i, h_i = nodes[:, :3], nodes[:, 3:6], nodes[:, 6:]  # Split node attrs
                
        # From Eqs. (6) and (7)
        phi_v = MLP([d_hidden] * (n_layers - 1) + [1], activation=activation)
        phi_h = MLP([d_hidden] * (n_layers - 1) + [h_i.shape[-1]], activation=activation)
                    
        # Apply updates
        v_i_p = sum_x_ij / 20. + phi_v(h_i) * v_i 
        x_i_p = x_i + v_i_p
        h_i_p = phi_h(jnp.concatenate([h_i, m_i], -1)) + h_i  # Skip connection
                
        return jnp.concatenate([x_i_p, v_i_p, h_i_p], -1)

    return update_fn

class GraphConvNet(nn.Module):
    """A simple graph convolutional network"""

    message_passing_steps: int = 4
    skip_connections: bool = True
    norm_layer: bool = True
    d_hidden: int = 64
    n_layers : int = 3
    activation: str = "swish"

    @nn.compact
    def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Do message passing on graph

        Args:
            graphs (jraph.GraphsTuple): graph object

        Returns:
            jraph.GraphsTuple: updated graph object
        """
        in_features = graphs.nodes.shape[-1]
        processed_graphs = graphs
        processed_graphs = processed_graphs._replace(globals=processed_graphs.globals.reshape(processed_graphs.globals.shape[0], -1))
        
        activation = getattr(nn, self.activation)
        
        update_node_fn = get_node_mlp_updates(self.d_hidden, self.n_layers, activation)
        update_edge_fn = get_edge_mlp_updates(self.d_hidden, self.n_layers, activation)

        # Apply message-passing rounds
        for _ in range(self.message_passing_steps):
            
            graph_net = jraph.GraphNetwork(update_node_fn=update_node_fn, update_edge_fn=update_edge_fn)
            
            if self.skip_connections:
                processed_graphs = add_graphs_tuples(graph_net(processed_graphs), processed_graphs)
            else:
                processed_graphs = graph_net(processed_graphs)
            
            if self.norm_layer:
                processed_graphs = self.norm(processed_graphs)

        return processed_graphs
    
    def norm(self, graph):
        x, v, h = graph.nodes[..., :3], graph.nodes[..., 3:6], graph.nodes[..., 6:]
        x, v, h = CoordNorm()(x), CoordNorm()(v), h if h.shape[-1] == 1 else nn.LayerNorm()(h)
        graph = graph._replace(nodes = jnp.concatenate([x, v, h], -1))
        return graph
        

In [367]:
class GraphWrapper(nn.Module):
    @nn.compact
    def __call__(self, x):
        model = jax.vmap(GraphConvNet())
        return model(x)
    
model = GraphWrapper()
rng = jax.random.PRNGKey(42)

## Test equivariance

In [368]:
graph = jraph.GraphsTuple(
          n_node=np.array(4 * [[n_nodes]]), 
          n_edge=np.array(4 * [[20]]),
          nodes=x[:4, :, :], 
          edges=None,
          globals=np.ones((4, 7)),
          senders=sources,
          receivers=targets)

graph_out, _ = model.init_with_output(rng, graph)

x_out = graph_out.nodes

angle_deg = 45.
axis = jnp.array([0, 1 / np.sqrt(2), 1 / np.sqrt(2)])

x_out_rot = jax.vmap(rotate_representation, in_axes=(0,None,None))(x_out, angle_deg, axis)

graph = jraph.GraphsTuple(
          n_node=np.array(4 * [[n_nodes]]), 
          n_edge=np.array(4 * [[20]]),
          nodes=jax.vmap(rotate_representation, in_axes=(0,None,None))(x[:4, :, :], angle_deg, axis),
          edges=None,
          globals=np.ones((4, 7)),
          senders=sources,
          receivers=targets)

graph_out, params = model.init_with_output(rng, graph)
x_out = graph_out.nodes

In [369]:
# Equivariance ratio
eq_ratio = x_out / x_out_rot
print(eq_ratio.max(), eq_ratio.min(), eq_ratio)

153.23671 -53.711906 [[[1.001014   0.99972224 0.9964565  ... 0.9187607  1.0029248  1.0020633 ]
  [0.9998553  0.9999207  1.000664   ... 1.0002093  1.0034618  1.0014204 ]
  [1.0003986  0.99552363 0.9945452  ... 1.0126435  0.97210914 1.0019491 ]
  ...
  [0.9999123  0.9999768  1.0001493  ... 0.9993647  1.0003253  1.0005534 ]
  [1.0010532  0.99940616 1.0015911  ... 0.9988841  1.0018824  1.000044  ]
  [0.9993006  1.0006421  0.9991711  ... 1.001663   0.99832714 1.000007  ]]

 [[1.000759   0.9996155  0.9991985  ... 0.99973816 0.9984365  0.9991454 ]
  [0.99905485 0.9972541  1.0007825  ... 0.9945778  1.0015038  1.0034573 ]
  [0.99137956 1.0010234  1.0025904  ... 1.0010153  1.0107495  1.0036522 ]
  ...
  [1.0001246  0.9994388  0.9999993  ... 1.0006801  0.99917036 1.000143  ]
  [0.998208   0.9997999  1.0004056  ... 1.0003096  1.0016159  0.99994296]
  [0.99504125 0.9977888  1.0017054  ... 0.99480414 1.0023159  1.0002432 ]]

 [[0.99987024 1.0001653  1.0002958  ... 0.9998616  1.0002657  0.99954647]
 

In [371]:
sum(x.size for x in jax.tree_util.tree_leaves(params))

119316