In [1]:
import jax
import jax.numpy as np
import jraph
import flax.linen as nn

from functools import partial

In [2]:
# Enable autoreload so that edits to imported project modules are picked up live,
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [197]:
import sys
sys.path.append("../")

from models.gnn import GraphConvNet
from models.graph_utils import nearest_neighbors
from models.graph_utils import add_graphs_tuples

In [4]:
# Number of entries kept along the second axis of the loaded array.
n_nodes = 5000
# NOTE(review): hard-coded absolute cluster path — this cell only runs where that file exists.
# `np` is jax.numpy in this notebook (see the first import cell); the slice keeps the
# first n_nodes entries of axis 1 and everything along the other two axes.
x = np.load("/n/holyscratch01/iaifi_lab/ccuesta/data_for_sid/halos.npy")[:, :n_nodes,:]

In [5]:
# Map nearest_neighbors over the leading axis of the first four entries of x; the second
# argument (20) is shared across the batch (in_axes=None).
# Presumably 20 is the number of neighbours per node — confirm in models.graph_utils.
sources, targets = jax.vmap(nearest_neighbors, in_axes=(0, None))(x[:4], 20)

## EGNN

In [7]:
import jax.numpy as jnp
from jax import jit

def rotation_matrix(angle_deg, axis):
    """Build a 3x3 rotation matrix from an angle in degrees and a rotation axis.

    The axis is rescaled to unit length first, so it does not need to be
    normalised by the caller. The matrix entries are assembled from the
    half-angle terms a = cos(angle / 2) and (b, c, d) = -axis * sin(angle / 2).

    Args:
        angle_deg: rotation angle in degrees.
        axis: length-3 array giving the rotation axis (any non-zero length).

    Returns:
        A (3, 3) array.
    """
    half_angle = jnp.radians(angle_deg) / 2
    unit_axis = axis / jnp.linalg.norm(axis)

    a = jnp.cos(half_angle)
    b, c, d = -unit_axis * jnp.sin(half_angle)

    # Pairwise products reused across the matrix entries.
    aa, bb, cc, dd = a * a, b * b, c * c, d * d
    bc, ad, bd, ac, cd, ab = b * c, a * d, b * d, a * c, c * d, a * b

    row_0 = [aa + bb - cc - dd, 2 * (bc - ad), 2 * (bd + ac)]
    row_1 = [2 * (bc + ad), aa + cc - bb - dd, 2 * (cd - ab)]
    row_2 = [2 * (bd - ac), 2 * (cd + ab), aa + dd - bb - cc]
    return jnp.array([row_0, row_1, row_2])

@jit
def rotate_representation(data, angle_deg, axis):
    """Rotate the two 3-vector column groups of a 2D feature array.

    Columns 0-2 (treated as positions) and columns 3-5 (treated as velocities)
    are each multiplied by the same rotation matrix; every column from index 6
    onward is passed through unchanged.

    Args:
        data: 2D array with one row per item and at least 6 columns.
        angle_deg: rotation angle in degrees.
        axis: length-3 rotation axis.

    Returns:
        Array with the same shape as ``data``.
    """
    rot_mat = rotation_matrix(angle_deg, axis)

    def _rotate_rows(vectors):
        # Apply rot_mat to every row: (R @ V^T)^T.
        return jnp.matmul(rot_mat, vectors.T).T

    pos_rotated = _rotate_rows(data[:, :3])
    vel_rotated = _rotate_rows(data[:, 3:6])
    untouched = data[:, 6:]

    return jnp.concatenate([pos_rotated, vel_rotated, untouched], axis=1)

In [66]:
# Standardise x using the mean and standard deviation taken jointly over axes 0 and 1.
# NOTE(review): x is overwritten with a function of itself, so re-running this cell
# standardises already-standardised data; run it exactly once per kernel session.
x = (x - x.mean((0,1))) / x.std((0, 1))

In [476]:
from typing import Callable
import flax.linen as nn
import jax.numpy as jnp
import jraph

from models.graph_utils import add_graphs_tuples
from models.mlp import MLP

class CoordNorm(nn.Module):
    """ Coordinate normalization, from 
    https://github.com/lucidrains/egnn-pytorch/blob/main/egnn_pytorch/egnn_pytorch.py#LL67C28-L67C28

    Each vector along the last axis is divided by its Euclidean norm (clamped
    from below by ``eps`` to avoid division by zero) and then multiplied by a
    single learnable scale.

    Attributes:
        eps: lower bound applied to the norm before dividing.
        scale_init: initial value of the learnable scale parameter.
    """
    eps: float = 1e-8
    scale_init: float = 1.

    def setup(self):
        # Initialise from scale_init; the attribute used to be ignored in favour
        # of a hard-coded ones initialiser. The default (1.0) gives the same values.
        self.scale = self.param('scale', nn.initializers.constant(self.scale_init), (1,))

    def __call__(self, coors):
        """Return ``coors`` rescaled to unit length along the last axis, times the scale."""
        norm = jnp.linalg.norm(coors, axis=-1, keepdims=True)
        normed_coors = coors / jax.lax.clamp(self.eps, norm, jnp.inf)
        return normed_coors * self.scale

def get_edge_mlp_updates(d_hidden, n_layers, activation, position_only=False) -> Callable:
    """Get an edge MLP update function

    Args:
        d_hidden (int): feature size used for the layers of the message MLPs
        n_layers (int): number of entries in each MLP's feature-size list
        activation (Callable): activation handed to the MLPs
        position_only (bool): if True, return the variant for nodes that carry
            coordinates only; otherwise the variant for nodes laid out as
            [coordinates (3), velocities (3), scalars (rest)]

    Returns:
        Callable: update function
    """

    def update_fn(
        edges: jnp.ndarray,
        senders: jnp.ndarray,
        receivers: jnp.ndarray,
        globals: jnp.ndarray,
    ):
        """Compute edge messages for nodes with coordinates, velocities and scalars

        Args:
            edges (jnp.ndarray): edge attributes (not used by this variant)
            senders (jnp.ndarray): senders node attributes
            receivers (jnp.ndarray): receivers node attributes
            globals (jnp.ndarray): global features, concatenated to the message inputs

        Returns:
            tuple: ``((x_i - x_j) * phi_x(m_ij), m_ij)`` — the weighted coordinate
            difference and the message it was computed from
        """

        # Coordinates are columns 0-2, scalar attrs are columns 6+; the velocity
        # columns (3-5) do not enter the messages.
        x_i, h_i = senders[:, :3], senders[:, 6:]
        x_j, h_j = receivers[:, :3], receivers[:, 6:]

        # Messages from Eqs. (3) and (4)/(7)
        phi_e = MLP([d_hidden] * (n_layers), activation=activation)
        phi_x = MLP([d_hidden] * (n_layers - 1) + [1], activation=activation)

        # Squared distance as a plain sum of squares: same value as norm ** 2, but
        # without the square root, whose gradient is NaN at zero separation.
        x_ij = x_i - x_j
        d_ij_sq = jnp.sum(x_ij ** 2, axis=1, keepdims=True)

        m_ij = phi_e(jnp.concatenate([h_i, h_j, d_ij_sq, globals], axis=-1))
        return x_ij * phi_x(m_ij), m_ij

    def update_fn_position_only(
        edges: jnp.ndarray,
        senders: jnp.ndarray,
        receivers: jnp.ndarray,
        globals: jnp.ndarray,
    ):
        """Compute edge messages for nodes that carry coordinates only

        Args:
            edges: edge attributes; may be None, an array of edge features, or the
                tuple returned by a previous call of this function
            senders (jnp.ndarray): senders node coordinates
            receivers (jnp.ndarray): receivers node coordinates
            globals (jnp.ndarray): global features, concatenated to the message inputs

        Returns:
            tuple: ``((x_i - x_j) * phi_x(m_ij), m_ij)``
        """

        # Node attributes are the coordinates themselves
        x_i = senders
        x_j = receivers

        # Messages from Eqs. (3) and (4)/(7)
        phi_e = MLP([d_hidden] * (n_layers), activation=activation)
        phi_x = MLP([d_hidden] * (n_layers - 1) + [1], activation=activation)

        # Get invariants (sum of squares rather than norm ** 2, see update_fn)
        x_ij = x_i - x_j
        d_ij_sq = jnp.sum(x_ij ** 2, axis=1, keepdims=True)
        message_scalars = jnp.concatenate([d_ij_sq, globals], axis=-1)

        if isinstance(edges, (tuple, list)):
            # When rounds are chained, the graph's edges hold this function's own
            # (coordinate message, m_ij) output from the previous round. Only m_ij is
            # a set of scalars, so only that part is fed back in; concatenating the
            # tuple itself would fail.
            edges = edges[-1]
        if edges is not None:
            message_scalars = jnp.concatenate([message_scalars, edges], axis=-1)  # Add edge features if available

        m_ij = phi_e(message_scalars)
        return x_ij * phi_x(m_ij), m_ij

    return update_fn if not position_only else update_fn_position_only

def get_node_mlp_updates(d_hidden, n_layers, activation, n_edge, position_only=False) -> Callable:
    """Get a node MLP update function

    Args:
        d_hidden (int): feature size used for the layers of the update MLPs
        n_layers (int): number of entries in each MLP's feature-size list
        activation (Callable): activation handed to the MLPs
        n_edge: edge count; the aggregated coordinate messages are divided by
            ``n_edge - 1`` in the full (non position-only) update
        position_only (bool): if True, return the variant for nodes that carry
            coordinates only

    Returns:
        Callable: update function
    """

    def update_fn(
        nodes: jnp.ndarray,
        senders: jnp.ndarray,
        receivers: jnp.ndarray,
        globals: jnp.ndarray,
    ) -> jnp.ndarray:
        """update node features (coordinates, velocities and scalars)

        Args:
            nodes (jnp.ndarray): node attributes laid out as
                [coordinates (3), velocities (3), scalars (rest)]
            senders: edge outputs aggregated per sending node (unused)
            receivers: edge outputs aggregated per receiving node, as a pair
                ``(summed coordinate messages, summed messages)``
            globals (jnp.ndarray): global features (unused)

        Returns:
            jnp.ndarray: updated node features, same layout as ``nodes``
        """

        sum_x_ij, m_i = receivers  # Get aggregated messages
        x_i, v_i, h_i = nodes[:, :3], nodes[:, 3:6], nodes[:, 6:]  # Split node attrs

        # From Eqs. (6) and (7)
        phi_v = MLP([d_hidden] * (n_layers - 1) + [1], activation=activation)
        phi_h = MLP([d_hidden] * (n_layers - 1) + [h_i.shape[-1]], activation=activation)

        # Apply updates
        v_i_p = sum_x_ij / (n_edge - 1) + phi_v(h_i) * v_i
        x_i_p = x_i + v_i_p
        h_i_p = phi_h(jnp.concatenate([h_i, m_i], -1)) + h_i  # Skip connection

        return jnp.concatenate([x_i_p, v_i_p, h_i_p], -1)

    def update_fn_position_only(
        nodes: jnp.ndarray,
        senders: jnp.ndarray,
        receivers: jnp.ndarray,
        globals: jnp.ndarray,
    ) -> jnp.ndarray:
        """update node coordinates when nodes carry coordinates only

        Args:
            nodes (jnp.ndarray): node coordinates
            senders: edge outputs aggregated per sending node (unused)
            receivers: edge outputs aggregated per receiving node, as a pair
                ``(summed coordinate messages, summed messages)``
            globals (jnp.ndarray): global features (unused)

        Returns:
            jnp.ndarray: updated node coordinates
        """

        # Only the coordinate messages are needed: there are no velocities or scalar
        # attributes to update, so no MLPs are built here. (An earlier version built
        # two unused MLPs from an undefined `h_i`, which raised NameError.)
        sum_x_ij, _m_i = receivers
        x_i = nodes

        # NOTE(review): unlike the full update, the aggregated message is not divided
        # by (n_edge - 1) here — confirm whether that asymmetry is intended.
        x_i_p = x_i + sum_x_ij

        return x_i_p

    return update_fn if not position_only else update_fn_position_only

class EGNN(nn.Module):
    """Stack of jraph.GraphNetwork rounds using the EGNN-style edge and node updates.

    Node features must either have exactly 3 columns (coordinates only) or at
    least 6 columns laid out as [coordinates (3), velocities (3), scalars (rest)].

    Attributes:
        message_passing_steps: number of GraphNetwork rounds applied.
        skip_connections: if True, each round's output is combined with its
            input through ``add_graphs_tuples``.
        norm_layer: if True, node features are normalised after every round.
        d_hidden: feature size used for the layers of the update MLPs.
        n_layers: number of entries in each update MLP's feature-size list.
        activation: name of the activation function, looked up on ``flax.linen``.
    """

    message_passing_steps: int = 4
    skip_connections: bool = False
    norm_layer: bool = True
    d_hidden: int = 64
    n_layers : int = 3
    activation: str = "gelu"

    @nn.compact
    def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Do message passing on graph

        Args:
            graphs (jraph.GraphsTuple): graph object

        Returns:
            jraph.GraphsTuple: updated graph object

        Raises:
            NotImplementedError: if the node feature count is neither 3 nor >= 6.
        """
        # Validate the node layout up front. The previous check (`< 6`) also rejected
        # the 3-feature case that its own message advertises, leaving the
        # positions-only path unreachable.
        n_features = graphs.nodes.shape[-1]
        if n_features != 3 and n_features < 6:
            raise NotImplementedError("Number of features should be either 3 (just positions) or >= 6 (positions, velocities, and scalars)")

        # Switch for whether to use positions-only version of edge/node updates
        positions_only = n_features == 3

        # Flatten globals to 2D, keeping the leading axis.
        processed_graphs = graphs._replace(
            globals=graphs.globals.reshape(graphs.globals.shape[0], -1)
        )

        activation = getattr(nn, self.activation)

        # Forward the switch so the update functions match the node layout.
        update_node_fn = get_node_mlp_updates(
            self.d_hidden,
            self.n_layers,
            activation,
            n_edge=processed_graphs.n_edge,
            position_only=positions_only,
        )
        update_edge_fn = get_edge_mlp_updates(
            self.d_hidden,
            self.n_layers,
            activation,
            position_only=positions_only,
        )

        # Apply message-passing rounds
        for _ in range(self.message_passing_steps):

            graph_net = jraph.GraphNetwork(update_node_fn=update_node_fn, update_edge_fn=update_edge_fn)

            if self.skip_connections:
                processed_graphs = add_graphs_tuples(graph_net(processed_graphs), processed_graphs)
            else:
                processed_graphs = graph_net(processed_graphs)

            if self.norm_layer:
                processed_graphs = self.norm(processed_graphs, positions_only=positions_only)

        return processed_graphs

    def norm(self, graph, positions_only=False):
        """Normalise node features of ``graph`` and return the updated graph.

        Coordinates (and velocities, when present) go through ``CoordNorm``; scalar
        features go through ``nn.LayerNorm`` only when there is more than one of them.
        """
        if not positions_only:
            x, v, h = graph.nodes[..., :3], graph.nodes[..., 3:6], graph.nodes[..., 6:]

            # Submodules are created in the order x, v, h.
            x = CoordNorm()(x)
            v = CoordNorm()(v)
            # Only apply LN if scalars have more than one feature. The old test
            # (`== 1`) also sent an empty scalar block through LayerNorm.
            if h.shape[-1] > 1:
                h = nn.LayerNorm()(h)
            graph = graph._replace(nodes=jnp.concatenate([x, v, h], -1))
        else:
            x = CoordNorm()(graph.nodes)
            graph = graph._replace(nodes=x)
        return graph
        

In [477]:
class GraphWrapper(nn.Module):
    """Wrapper that applies an EGNN to its input through ``jax.vmap``."""

    @nn.compact
    def __call__(self, x):
        # Build the EGNN inside the compact method, then map it over the leading axis.
        return jax.vmap(EGNN())(x)
    
# Batched EGNN wrapper used by the equivariance check below.
model = GraphWrapper()
# Fixed PRNG key (seed 42), passed to model.init_with_output below.
rng = jax.random.PRNGKey(42)

## Test equivariance

In [478]:
# Number of point clouds used in the check; matches the x[:4] slice the
# senders/receivers were computed from.
N_GRAPHS = 4


def make_test_graph(nodes):
    """Wrap a batch of node features in a GraphsTuple that reuses the precomputed edges.

    Everything except ``nodes`` is identical between the two graphs compared below,
    so it is built in one place instead of being copy-pasted.
    """
    return jraph.GraphsTuple(
        n_node=np.array(N_GRAPHS * [[n_nodes]]),
        # NOTE(review): 20 is the same value passed to nearest_neighbors above —
        # confirm it is what n_edge is meant to carry.
        n_edge=np.array(N_GRAPHS * [[20]]),
        nodes=nodes,
        edges=None,
        globals=np.ones((N_GRAPHS, 7)),
        senders=sources,
        receivers=targets)


angle_deg = 45.
axis = jnp.array([0, 1 / np.sqrt(2), 1 / np.sqrt(2)])
batched_rotate = jax.vmap(rotate_representation, in_axes=(0, None, None))

# Path 1 — run the model, then rotate its output: R f(x)
graph = make_test_graph(x[:N_GRAPHS, :, :])
graph_out, _ = model.init_with_output(rng, graph)
x_out_rot = batched_rotate(graph_out.nodes, angle_deg, axis)

# Path 2 — rotate the input, then run the model: f(R x)
# The same rng is reused, so both paths are initialised from the same key.
graph = make_test_graph(batched_rotate(x[:N_GRAPHS, :, :], angle_deg, axis))
graph_out, params = model.init_with_output(rng, graph)
x_out = graph_out.nodes

In [479]:
# Equivariance check: x_out is f(R x) and x_out_rot is R f(x); they should agree.
# The element-wise ratio is kept for reference, but it is unstable wherever the
# denominator is close to zero, so its extreme values can be far from 1 even when
# the two arrays agree well. The maximum absolute deviation is reported as a
# scale-stable complement.
eq_ratio = x_out / x_out_rot
print(eq_ratio.max(), eq_ratio.min(), eq_ratio)

max_abs_dev = jnp.abs(x_out - x_out_rot).max()
print("max |f(Rx) - R f(x)|:", max_abs_dev)

20.35442 -116.01335 [[[1.0002252  0.9961086  0.9995838  ... 0.9965355  0.99983704 0.9992793 ]
  [0.99985385 1.0009937  0.9863682  ... 1.0014853  0.98887014 0.99713904]
  [0.9999512  0.99969053 0.9997458  ... 0.9996097  0.99961597 0.9951595 ]
  ...
  [0.9999879  0.9939522  0.9981981  ... 0.9955679  0.99789745 1.0002339 ]
  [1.0008881  0.99951905 1.0003864  ... 0.9990917  1.0002927  0.9990781 ]
  [0.9994261  1.000005   0.99979764 ... 0.99999774 0.9999151  0.99951905]]

 [[0.9989653  1.0004969  0.961327   ... 1.0004354  0.9836156  0.9978476 ]
  [0.9988907  0.99992114 1.0001626  ... 0.999335   0.99997264 0.99622875]
  [1.0002766  0.9996672  0.99966156 ... 1.0004942  0.9988764  1.0002664 ]
  ...
  [0.99984944 1.0027136  1.0010624  ... 1.0038079  0.99937624 0.9998152 ]
  [1.0006504  1.0000986  1.0005808  ... 0.99983186 1.0003997  0.999679  ]
  [1.0000618  0.9999947  0.9998149  ... 1.0003642  0.9999997  1.0001197 ]]

 [[1.000036   0.9974575  1.0004278  ... 0.9963247  1.0001675  1.0005107 ]
  

In [480]:
# Total number of scalar parameters across all leaves of the params tree.
sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))

119316