In [302]:
import sys
sys.path.append("../")

import jax
import jax.numpy as jnp
import jraph
from e3nn_jax import Irreps
from e3nn_jax import IrrepsArray

from models.nequip import NequIP

from models.utils.irreps_utils import balanced_irreps

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [303]:
import numpy as np

n_batch = 2
n_nodes = 1000
k = 20

x = np.load("../data/halos_small.npy")[:n_batch, :n_nodes, :]

# Normalize
x_mean = x.mean((0, 1))
x_std = x.std((0, 1))
x[:, :, 3:] = (x[:, :, 3:] - x_mean[3:]) / x_std[3:]
x[:, :, :3] = x[:, :, :3] / 1000.  # Divide by box size

In [304]:
from models.utils.graph_utils import nearest_neighbors, rotate_representation

# Get nearest neighbors graph
sources, targets = jax.vmap(nearest_neighbors, in_axes=(0, None))(x, k)

In [305]:
x_irreps = IrrepsArray("1o + 1o + 1x0e", x)

graph = jraph.GraphsTuple(
          n_node=n_nodes,
          n_edge=k,
          edges=None,
          globals=None,
          nodes=x_irreps[0], 
          senders=sources[0],
          receivers=targets[0])

segnn = NequIP(num_message_passing_steps=3)

key = jax.random.PRNGKey(0)
out, params = segnn.init_with_output(key, graph)

out

GraphsTuple(nodes=1x1o+1x1o+1x0e
[[-3.477482   -1.0110798   2.6792042  ... -1.6124381   4.2727103
   0.        ]
 [ 0.93758565  0.08349349 -1.1802998  ...  0.13315277 -1.8823047
   0.        ]
 [-2.2603192  -1.0508989   1.2991517  ... -1.67594     2.0718458
   0.        ]
 ...
 [ 1.4059191   4.6133456  -1.1233319  ...  7.3572173  -1.7914541
   0.        ]
 [-3.3451757   2.4346807   1.5755376  ...  3.8827515   2.5126173
   0.        ]
 [ 4.9496922   0.2703495  -4.059554   ...  0.4311448  -6.4740477
   0.        ]], edges=5x0e+5x1o+3x1e+3x2e
[[ 2.0264182  -4.0492897   0.         ...  0.          0.
   0.        ]
 [ 1.9892232  -3.9975798   0.14114611 ...  5.053578   -0.43218103
   7.5584116 ]
 [ 2.0021927  -4.015755    0.14781636 ...  4.2048016   2.8392394
   1.5501559 ]
 ...
 [ 1.4863317  -3.597436   -0.48142502 ...  4.064386    1.6350402
   1.6727726 ]
 [ 1.458281   -3.545297   -0.49156418 ...  4.382875    0.3291561
  -0.6398    ]
 [ 1.491209   -3.6063929  -0.48974892 ...  3.5069554   

In [306]:
sum(x.size for x in jax.tree_util.tree_leaves(params))

28542

In [307]:
axis = np.array([0, 1 / np.sqrt(2), 1 / np.sqrt(2)])
x_irreps = IrrepsArray("1o + 1o + 1x0e", rotate_representation(x[0], 45., axis))

graph = jraph.GraphsTuple(
          n_node=n_nodes,
          n_edge=k,
          edges=None,
          globals=None,
          nodes=x_irreps, 
          senders=sources[0],
          receivers=targets[0])

out_rot, params = segnn.init_with_output(key, graph)

out_rot

GraphsTuple(nodes=1x1o+1x1o+1x0e
[[-4.3040915   1.2680945   0.40003055 ...  2.022317    0.6379561
   0.        ]
 [ 1.2948667  -0.5703817  -0.526434   ... -0.90962666 -0.83954024
   0.        ]
 [-2.7733119   0.4234165  -0.17516556 ...  0.6752512  -0.2793485
   0.        ]
 ...
 [ 3.862483    3.0702646   0.41974172 ...  4.8963604   0.6693908
   0.        ]
 [-1.935802    3.9814637   0.02875094 ...  6.3495116   0.04585109
   0.        ]
 [ 5.6649175  -2.8385983  -0.95061153 ... -4.5269065  -1.516005
   0.        ]], edges=5x0e+5x1o+3x1e+3x2e
[[ 2.026419   -4.0492897   0.         ...  0.          0.
   0.        ]
 [ 1.9892237  -3.9975789   0.14114594 ... -0.30965698  2.4842327
   4.822027  ]
 [ 2.002193   -4.015756    0.14781645 ...  4.273547   -0.5245591
   8.238031  ]
 ...
 [ 1.4863294  -3.5974314  -0.48142612 ...  3.0217714  -0.57612306
   6.697022  ]
 [ 1.4582795  -3.5452893  -0.4915653  ...  2.0587928  -1.8301649
   5.0206404 ]
 [ 1.4912047  -3.6063921  -0.4897497  ...  3.8224385  

In [308]:
out_rot.nodes.array / rotate_representation(out.nodes.array, 45, axis)

Array([[0.9999996 , 1.0000025 , 0.99999374, ..., 1.0000027 , 0.9999929 ,
               nan],
       [0.9999976 , 1.000007  , 1.0000104 , ..., 1.0000072 , 1.0000103 ,
               nan],
       [0.9999998 , 0.9999966 , 1.0000023 , ..., 0.99999636, 1.0000024 ,
               nan],
       ...,
       [1.0000024 , 0.9999987 , 0.9999925 , ..., 0.9999986 , 0.99999225,
               nan],
       [0.9999883 , 1.0000035 , 0.9994077 , ..., 1.0000035 , 0.99940497,
               nan],
       [1.0000008 , 1.0000005 , 1.0000042 , ..., 1.0000006 , 1.000004  ,
               nan]], dtype=float32)

In [280]:
import flax.linen as nn
import e3nn_jax as e3nn

class GraphWrapper(nn.Module):
    @nn.compact
    def __call__(self, x):
        model = jax.vmap(SEGNN(num_message_passing_steps=3))
        return model(x)

model = GraphWrapper()
rng = jax.random.PRNGKey(42)

In [281]:
x_irreps = IrrepsArray("1o + 1o + 1x0e", x)

graph = jraph.GraphsTuple(
          n_node=jnp.array(n_batch * [[n_nodes]]), 
          n_edge=jnp.array(n_batch * [[k]]),
          nodes=x_irreps, 
          edges=None,
          globals=jnp.ones((n_batch, 7)),
          senders=sources,
          receivers=targets)

graph_out, _ = model.init_with_output(rng, graph)
x_out = graph_out  # Output features