In [52]:
import sys
sys.path.append('../')

from typing import Dict

import jax
import jax.numpy as jnp

import flax.linen as nn
import e3nn_jax as e3nn
import jraph

from models.utils.equivariant_graph_utils import get_equivariant_graph
from models.utils.graph_utils import build_graph

from models.segnn import SEGNN
from models.gnn import GNN
from models.segnn import SEGNN
from models.egnn import EGNN
from models.nequip import NequIP

## Example graph and configuration

In [53]:
# Batch of 2 random point clouds, 1000 points each, 3 coordinates per point
# (jax.random.uniform samples from [0, 1) by default)
key = jax.random.PRNGKey(0)
x_points = jax.random.uniform(key, shape=(2, 1000, 3))

In [54]:
# --- Graph construction (forwarded to build_graph) ---
k = 10
r_max = 0.6  # also reused as the cutoff/r_max argument of the equivariant models below
n_radial = 64  # passed as n_radial_basis wherever a radial basis is configured
use_3d_distances = False

# --- Model options ---
l_max = 1  # used for SEGNN's l_max_hidden and the steerable attributes' lmax
position_features = True

## GraphNetwork

In [55]:
# Keyword arguments forwarded verbatim to ``GNN`` by ``GraphWrapperGNN``.
GNN_PARAMS = {
    "d_hidden": 128,
    "message_passing_steps": 3,
    "n_layers": 3,
    "activation": "gelu",
    "message_passing_agg": "mean",
    "readout_agg": "mean",
    "mlp_readout_widths": (4, 2, 2),
    "task": "graph",
    "n_outputs": 2,
    "norm": "none",
    # Take the flag from the configuration cell instead of duplicating the
    # literal, so that changing `position_features` there actually has an effect.
    "position_features": position_features,
    "residual": False,
}

class GraphWrapperGNN(nn.Module):
    """Flax wrapper that maps a ``GNN`` over the leading (batch) axis of a graph.

    Attributes:
        param_dict: Keyword arguments forwarded verbatim to the ``GNN`` constructor.
    """

    param_dict: Dict

    @nn.compact
    def __call__(self, x):
        """Apply the GNN to every graph in the batch ``x`` via ``jax.vmap``."""
        gnn = GNN(**self.param_dict)
        batched_gnn = jax.vmap(gnn)
        return batched_gnn(x)

In [56]:
# Build the batched graph from the random point cloud using the settings
# from the configuration cell; the second positional argument is left as None.
graph = build_graph(
    x_points,
    None,
    k=k,
    use_edges=True,
    n_radial_basis=n_radial,
    r_max=r_max,
    use_3d_distances=use_3d_distances,
)

# Initialise the wrapped GNN on the example graph; init_with_output returns
# the forward-pass result together with the freshly created variables.
model = GraphWrapperGNN(param_dict=GNN_PARAMS)
out, params = model.init_with_output(jax.random.PRNGKey(0), graph)

# Number of parameters
print(f"Number of parameters: {sum(p.size for p in jax.tree.leaves(params))}")

out

Number of parameters: 700674


Array([[ 0.12172805, -0.08840577],
       [ 0.11992944, -0.09075886]], dtype=float32)

## SEGNN

In [57]:
# Keyword arguments forwarded verbatim to ``SEGNN`` by ``GraphWrapperSEGNN``.
SEGNN_PARAMS = dict(
    d_hidden=128,
    l_max_hidden=l_max,
    n_layers=3,
    message_passing_steps=3,
    task="graph",
    output_irreps=e3nn.Irreps("1x0e"),
    hidden_irreps=None,
    message_passing_agg="mean",
    readout_agg="mean",
    n_outputs=2,
    scalar_activation="gelu",
    gate_activation="sigmoid",
    mlp_readout_widths=(4, 2, 2),
    residual=False,
)

class GraphWrapperSEGNN(nn.Module):
    """Flax wrapper that builds a steerable graph and maps ``SEGNN`` over the batch.

    Attributes:
        param_dict: Keyword arguments forwarded verbatim to the ``SEGNN`` constructor.
    """

    param_dict: Dict

    @nn.compact
    def __call__(self, x):
        """Convert ``x`` into an equivariant graph and apply SEGNN via ``jax.vmap``.

        Args:
            x: Graph providing ``nodes``, ``senders``, ``receivers``, ``n_node``,
                ``n_edge`` and ``globals``. The first three node features are
                used as positions.

        Returns:
            The output of the vmapped ``SEGNN`` applied to the steerable graph.
        """
        # The leading three node features are always treated as a position vector.
        positions = e3nn.IrrepsArray("1o", x.nodes[..., :3])

        # Exactly three features -> one odd vector; otherwise the features are
        # declared as two odd vectors (IrrepsArray expects a matching width).
        node_irreps = "1o" if x.nodes.shape[-1] == 3 else "1o + 1o"
        nodes = e3nn.IrrepsArray(node_irreps, x.nodes)

        # NOTE(review): the previous version also built a separate velocity
        # IrrepsArray from columns 3:6 but never used it, because `velocities=None`
        # is hard-coded below. The dead local was removed; confirm whether
        # velocities were meant to be forwarded (with steerable_velocities=True).
        st_graph = get_equivariant_graph(
            node_features=nodes,
            positions=positions,
            velocities=None,
            steerable_velocities=False,
            senders=x.senders,
            receivers=x.receivers,
            n_node=x.n_node,
            n_edge=x.n_edge,
            globals=x.globals,
            edges=None,
            lmax_attributes=l_max,
            n_radial_basis=n_radial,
            r_max=r_max,
        )

        return jax.vmap(SEGNN(**self.param_dict))(st_graph)


In [58]:
# Initialise the wrapped SEGNN on the example graph; init_with_output returns
# the forward-pass result together with the freshly created variables.
model = GraphWrapperSEGNN(param_dict=SEGNN_PARAMS)
out, params = model.init_with_output(jax.random.PRNGKey(0), graph)

# Number of parameters
print(f"Number of parameters: {sum(p.size for p in jax.tree.leaves(params))}")

out

Number of parameters: 573253


Array([[-0.00476563, -0.00017594],
       [-0.00483601, -0.00015454]], dtype=float32)

## NequIP

In [61]:
# Keyword arguments forwarded verbatim to ``NequIP`` by ``GraphWrapperNequIP``;
# the radial-basis size and cutoff reuse the values from the configuration cell.
NEQUIP_PARAMS = dict(
    n_outputs=2,
    n_radial_basis=n_radial,
    r_cutoff=r_max,
    sphharm_norm="component",
)

class GraphWrapperNequIP(nn.Module):
    """Flax wrapper that maps ``NequIP`` over the leading (batch) axis of a graph.

    Attributes:
        param_dict: Keyword arguments forwarded verbatim to the ``NequIP`` constructor.
    """

    param_dict: Dict

    @nn.compact
    def __call__(self, x):
        """Wrap the node features as irreps, rebuild the graph, and apply NequIP."""
        # Node features are tagged as a single odd vector (assumes they are positions).
        vector_nodes = e3nn.IrrepsArray("1o", x.nodes)

        # Same connectivity and globals as the input, but with irreps-typed
        # nodes and no edge features.
        irreps_graph = jraph.GraphsTuple(
            nodes=vector_nodes,
            edges=None,
            receivers=x.receivers,
            senders=x.senders,
            globals=x.globals,
            n_node=x.n_node,
            n_edge=x.n_edge,
        )

        nequip = NequIP(**self.param_dict)
        return jax.vmap(nequip)(irreps_graph)

In [62]:
# Initialise the wrapped NequIP on the example graph; init_with_output returns
# the forward-pass result together with the freshly created variables.
model = GraphWrapperNequIP(param_dict=NEQUIP_PARAMS)
out, params = model.init_with_output(jax.random.PRNGKey(0), graph)

# Number of parameters
print(f"Number of parameters: {sum(p.size for p in jax.tree.leaves(params))}")

out

Number of parameters: 391622


Array([[ -6.085656 , -11.367395 ],
       [ -5.1505995,  -9.6540985]], dtype=float32)

## EGNN

In [63]:
class GraphWrapperEGNN(nn.Module):
    """Flax wrapper that maps ``EGNN`` over the leading (batch) axis of a graph.

    Attributes:
        param_dict: Keyword overrides for the ``EGNN`` constructor. Entries given
            here take precedence over the defaults assembled in ``__call__``; an
            empty dict reproduces the previously hard-coded configuration.
    """

    param_dict: Dict

    @nn.compact
    def __call__(self, x):
        """Apply EGNN to every graph in the batch ``x`` via ``jax.vmap``."""
        # Defaults that used to be hard-coded in the constructor call;
        # `n_radial` and `r_max` come from the notebook's configuration cell.
        egnn_kwargs = {
            "positions_only": True,
            "n_outputs": 2,
            "n_layers": 4,
            "n_radial_basis": n_radial,
            "r_max": r_max,
            "tanh_out": True,
        }
        # `param_dict` used to be accepted but silently ignored; honour it now,
        # consistent with the other wrappers in this notebook.
        egnn_kwargs.update(self.param_dict)
        return jax.vmap(EGNN(**egnn_kwargs))(x)

In [64]:
# Initialise the wrapped EGNN on the example graph with an empty parameter
# dict; init_with_output returns the forward-pass result and the variables.
model = GraphWrapperEGNN(param_dict={})
out, params = model.init_with_output(jax.random.PRNGKey(0), graph)

# Number of parameters
print(f"Number of parameters: {sum(p.size for p in jax.tree.leaves(params))}")

out

Number of parameters: 441778


Array([[-0.00890504, -0.09374903],
       [-0.0085367 , -0.09540764]], dtype=float32)