In [1]:
from mlff.src.basis_function.spherical import init_sph_fn
from mlff.src.random import set_seeds

import haiku as hk
import jax
import numpy as np
import jax.numpy as jnp
import e3nn_jax as e3nn
jnp.set_printoptions(precision=3, suppress=True)


# Irreducible Representations

In `e3nn` irreducible representations of the SO(3) group are represented using the `e3nn.Irrep` class. This class does not contain any data, it is a structure that describe the representation. It is typically used as argument of other classes of the library to define the input and output representations of functions. It can be used to define irreduciple representations of arbitrary degree $l$ and parity $p \in (+1,-1)$, where $+1$ correpsonds to even parity and $-1$ to odd parity. Parity determines the behavior under sign flip of the input e.g. flipping atomic coordinates as $(r_1, \dots, r_n) \rightarrow (-r_1, \dots, -r_n)$. For example the spherical harmonics have parity $p = (-1)^l$.

A scalar (degree $l=0$) Irreps with odd parity $p=-1$ can be created in two different ways using the `e3nn.Irrep` class. One can either pass $l$ and $p$ as integers or pass what I call an irreps string in the following.

In [2]:
e3nn.Irrep(l=0, p=-1)  # degree = 0 and odd parity

0o

In [3]:
e3nn.Irrep(l=1, p=+1)  # degree = 1 and even parity

1e

Within the irrep string, the number specifies the degree $l$ and the following character the parity. `e` for even parity ($p=+1$), `o` for odd parity ($p=-1$) and `y` for the parity of the spherical harmonics ($p=(-1)^l$).

In [4]:
e3nn.Irrep("3e")  # degree and even parity

3e

In [5]:
e3nn.Irrep("2y")  # degree and parity of spherical harmonics (-1)^2 = +1 = even

2e

The `Irrep` class also offers additional functionality such as constructing it from a given rotation, see [here](https://e3nn-jax.readthedocs.io/en/latest/api/irreps.html#e3nn_jax.Irrep).

# Direct Sum of Irreducible Representations

The `Irreps` class allows to construct direct sums of irreducible representations. In the same fashion as before they can be either initialized using integers or a string. In addition to degree $l$ and parity $p$ there is an additional parameter which we call $F_{lp}$ which specifies the number of summands in the direct sum for a given combination of degree and parity.

Initialize an `Irreps` from integers. The input is a `list` of `tuples`, where each of the `tuples` has the form ($F$, ($l$, $p$)).

In [6]:
e3nn.Irreps([(100, (0, 1)), (50, (1, -1))])

100x0e+50x1o

An identical representation can be created using the irreps string `"100xe0 + 50x1o"`.

In [7]:
x = e3nn.Irreps("100x0e+50x1o")
x

100x0e+50x1o

One can obtain different properties from the class as e.g. the maximal degree, or the number of irreps.

In [8]:
x.lmax

1

In [9]:
x.num_irreps

150

# Irreps Array

`3enn` ususally takes `IrrepsArray` as input. Thus, one has to transform the corresponding inputs, here a scalar 
`s` ($l=0$) and a vector `v` ($l=1$) to `IrrepsArrays` first. `IrrepsArray` have information about the irreps in form of the irreps string as well as the values. They can be obtained independetly by accessing the class properties `.irreps` and `.array`, respectively.

In [10]:
set_seeds(0)
s = np.random.rand(1)
v = np.random.rand(3)

s = e3nn.IrrepsArray("0y", s)
v = e3nn.IrrepsArray("1y", v)

print('Full v: ', v)
print('Irreps of v: ', v.irreps)
print('Values of v: ', v.array)

Full v:  1x1o [0.715 0.603 0.545]
Irreps of v:  1x1o
Values of v:  [0.715 0.603 0.545]


# Spherical Harmonics

Calling the spherical harmonics function with integral normalization, reproduces the spherical harmonics
from the table in wikipedia. Note that they assume the input to be of order `[y,z,x]` (instead of `[x,y,z]`
see [here](https://github.com/e3nn/e3nn-jax/blob/0082115c15a386b9696a885d9899a6ec6d9a347d/e3nn_jax/_src/spherical_harmonics.py#L74). The returned quantity is an IrrepsArray itself, which returns the irreps string and then the `DeviceArray`.

In [11]:
Y0_e3nn = e3nn.spherical_harmonics([0, 1, 2], v, normalize=True, normalization='integral')
print(Y0_e3nn)
print('-'*100)
print('Irreps: \n{}\n'.format(Y0_e3nn.irreps))
print('Values: \n{}'.format(Y0_e3nn.array))

1x0e+1x1o+1x2e [ 0.282  0.323  0.272  0.246  0.363  0.402 -0.022  0.306 -0.1  ]
----------------------------------------------------------------------------------------------------
Irreps: 
1x0e+1x1o+1x2e

Values: 
[ 0.282  0.323  0.272  0.246  0.363  0.402 -0.022  0.306 -0.1  ]


Little hidden, there is also a function that takes np/jnp arrays of coordinates and returns the 
spherical harmonics.

In [12]:
R = np.random.rand(10,3)  # coordinates
e3nn.sh(irreps_out=[0,1,2], input=R, normalize=True, normalization='integral')

DeviceArray([[ 0.282,  0.233,  0.355,  0.241,  0.257,  0.379,  0.185,
               0.392,  0.008],
             [ 0.282,  0.319,  0.344,  0.137,  0.2  ,  0.502,  0.154,
               0.216, -0.189],
             [ 0.282,  0.349,  0.233,  0.25 ,  0.4  ,  0.372, -0.1  ,
               0.267, -0.135],
             [ 0.282,  0.485,  0.037,  0.046,  0.101,  0.083, -0.31 ,
               0.008, -0.534],
             [ 0.282,  0.009,  0.357,  0.334,  0.013,  0.014,  0.189,
               0.545,  0.254],
             [ 0.282,  0.277,  0.312,  0.255,  0.323,  0.395,  0.07 ,
               0.363, -0.027],
             [ 0.282,  0.247,  0.417,  0.063,  0.071,  0.471,  0.374,
               0.121, -0.13 ],
             [ 0.282,  0.272,  0.061,  0.401,  0.499,  0.076, -0.301,
               0.112,  0.199],
             [ 0.282,  0.356,  0.283,  0.18 ,  0.293,  0.46 ,  0.001,
               0.233, -0.215],
             [ 0.282,  0.356,  0.21 ,  0.261,  0.425,  0.341, -0.141,
               0.251,

Alternatively one can create the `IrrepsArray` from a `list` of irreps and `array` values.
The `1x` in front of the irreps string indicates the feature dimension `F`. One then observes the tensor product
structure (flattened) for equivariant features used in most current models.

In [13]:
F = 1
y_F1 = e3nn.IrrepsArray.from_list("1x0y + 1x1y + 1x2y", [np.ones((F,1)), 2*np.ones((F,3)), 3*np.ones((F,5))], ())
print('With feature dimension F = {}'.format(F))
print('Irreps: ', y_F1.irreps)
print('Values: ', y_F1.array)
print('Array shape:', y_F1.shape)
F = 8
y_F8 = e3nn.IrrepsArray.from_list("{}x0y +{}x1y +{}x2y".format(F,F,F), [np.ones((F,1)), 2*np.ones((F,3)), 3*np.ones((F,5))], ())
print('\nWith feature dimension F = {}'.format(F))
print('Irreps: ', y_F8.irreps)
print('Values: ', y_F8.array)
print('Array shape:', y_F8.shape)

With feature dimension F = 1
Irreps:  1x0e+1x1o+1x2e
Values:  [1. 2. 2. 2. 3. 3. 3. 3. 3.]
Array shape: (9,)

With feature dimension F = 8
Irreps:  8x0e+8x1o+8x2e
Values:  [1. 1. 1. 1. 1. 1. 1. 1. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.
 2. 2. 2. 2. 2. 2. 2. 2. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3.
 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]
Array shape: (72,)


Slicing is possible for IrrepsArray but only with valid slices, which slice full degrees (nice for safe code) 
s.t.

In [14]:
print(y_F1[:1])
# or
print(y_F1[1:4])
# or 
print(y_F8[8:32])
# but not
print(y_F1[:3])

1x0e [1.]
1x1o [2. 2. 2.]
8x1o [2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]


IndexError: Error in IrrepsArray.__getitem__, unable to slice 1x0e+1x1o+1x2e with 0:3.

There are multiple operations one `IrrepsArray` which can be found [here](https://e3nn-jax.readthedocs.io/en/latest/api/irreps_array.html). One example is `e3nn.norm` which computes the norm per irrep, which results in all-scalar irreps or `e3nn.sum` which can e.g. compute the sum along the batch dimension (0th axis) (should not be necessary with `jax.vmap`) or when applied to the axis with all irreps (here 1st axis) computes the sum across features. When no axis is specified it does both.

In [15]:
x = e3nn.IrrepsArray("2x0e + 8x1o + 2x2e", jnp.arange(36))

print('Irreps: ', e3nn.norm(x).irreps)
print('Values: ', e3nn.norm(x).array)

Irreps:  2x0e+8x0e+2x0e
Values:  [ 0.     1.     5.385 10.488 15.652 20.833 26.019 31.209 36.401 41.593
 62.69  73.858]


In [16]:
x = e3nn.IrrepsArray("3x0e + 2x1e", jnp.arange(2*9).reshape(2,9))

print('Along batch dimension')
print(e3nn.sum(x, axis=0))
print('Shape: ', e3nn.sum(x, axis=0).shape)
print('\nAlong irreps dimension')
print(e3nn.sum(x, axis=1))
print('Shape: ', e3nn.sum(x, axis=1).shape)
print('\nAlong both')
print(e3nn.sum(x))
print('Shape: ', e3nn.sum(x).shape)

Along batch dimension
3x0e+2x1e [ 9 11 13 15 17 19 21 23 25]
Shape:  (9,)

Along irreps dimension
1x0e+1x1e
[[ 3  9 11 13]
 [30 27 29 31]]
Shape:  (2, 4)

Along both
1x0e+1x1e [33 36 40 44]
Shape:  (4,)


# Tensor Products

## Standard Tensor Products

Standard tensor product of irreps, where for input degrees $l_1$ and $l_2$ one gets output degrees $l_3 \in (|l_1-l_2|, ..., l_1+l_2)$. Here $l_1 \in (0,1)$ and $l_2 \in (0,2)$, thus all $(l_1,l_2)$ combinations are $(0,0), (1,0), (0,2), (1,2)$. We see irreps with even and odd parity arising from the different $(l_1,l_2)$ combinations.

In [17]:
x = e3nn.IrrepsArray("0y + 1y", jnp.arange(4))
y = e3nn.IrrepsArray("0y + 2y", jnp.arange(6))
e3nn.tensor_product(x, y)

1x0e+1x1o+1x1o+1x2e+1x2o+1x3o
[ 0.     0.     0.     0.     0.147 11.463 10.299  0.     0.     0.
  0.     0.     4.082 -9.63   1.414 -0.328  2.449  5.657  6.928  5.184
 -1.613  8.613 11.547  9.899]

One can also filter the irreps during calculation.

In [18]:
e3nn.tensor_product(x, y, filter_ir_out=[e3nn.Irrep('0o'), e3nn.Irrep('1o'), e3nn.Irrep('2o')])

1x1o+1x1o+1x2o
[ 0.     0.     0.     0.147 11.463 10.299  4.082 -9.63   1.414 -0.328
  2.449]

## Other Tensor Product

`e3nn` also provides an class called `SymmetricTensorProduct` which performes a symmetric tensor product contraction with parameters or a `square_tensor_product` which performs a tensor product of an `IrrepsArray` with itself. See [here](https://e3nn-jax.readthedocs.io/en/latest/api/tensor_products.html) for all tensor products.

# Neural Network Stuff

e3nn uses `haiku` for building NN architectures. To that end, it follows the typical design of first intializing the parameters and then use the modules `__call__` function as forward function with the parameters as one of the arguments.

## Linear Layer
A linear layer is provided where on can use the irreps string to specify the output feature dimension for each irrep.

In [19]:
@hk.without_apply_rng
@hk.transform
def linear(x):
    output_irreps = "32x0e + 16x1o"
    return e3nn.Linear(output_irreps)(x)

x = e3nn.IrrepsArray("16x1o + 2x0e", jnp.ones(int(16*3+2*1)))
params = linear.init(jax.random.PRNGKey(0), x)
y = linear.apply(params, x)
y

32x0e+16x1o
[ 0.619  1.198 -1.09   0.056 -1.627 -0.487  0.172 -1.223 -0.446 -1.429
 -2.295  0.276 -1.155 -2.084 -0.895 -0.701 -1.916 -0.241  1.033  0.328
  1.15  -0.915  0.99   0.693  0.34  -0.069 -0.796 -1.181  0.029  2.111
  0.004 -0.611  1.419  1.419  1.419  1.146  1.146  1.146  1.459  1.459
  1.459  0.831  0.831  0.831 -2.729 -2.729 -2.729 -1.174 -1.174 -1.174
 -0.931 -0.931 -0.931 -0.116 -0.116 -0.116 -1.185 -1.185 -1.185 -1.587
 -1.587 -1.587 -0.669 -0.669 -0.669  0.317  0.317  0.317 -0.67  -0.67
 -0.67   0.896  0.896  0.896 -0.313 -0.313 -0.313 -0.278 -0.278 -0.278]

In [20]:
# as expeted the shape of the output irreps is shape 32*1 (l=0) + 16*3 (l=1) = 80
y.shape

(80,)

Interestingly, the linear layer seems to ignore irreps in the input that are not specified in the `output_irreps` string. Here we provide additionally an irrep of degree $l=2$ as input. The output, however, is the same as above. Thus one has to be careful when using the linear layer, since `e3nn.Linear` might silently remove irreps not specified as output.

In [21]:
@hk.without_apply_rng
@hk.transform
def linear(x):
    output_irreps = "32x0e + 16x1o"
    return e3nn.Linear(output_irreps)(x)

x = e3nn.IrrepsArray("1x2o + 16x1o + 2x0e", jnp.ones(int(1*5+16*3+2*1)))
params = linear.init(jax.random.PRNGKey(0), x)
y1 = linear.apply(params, x)
y1 - y

32x0e+16x1o
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0.]

## Gated Activation Function
They also provide a module for gated activations. The idea is that the scalars serve as coefficients for the degrees. Thus in order to apply a gated activation to an irrep of the form `2x1e + 1x2e` (three non-scalar irreps in total) one needs at least three scalars. If there are more scalars than non-scalars, the unused scalars are just kept in the scalar irrep. One can also apply activation function to the scalars that are unused and the scalars that are used as gates for the non-scalar irreps, where one can even apply different activation functiond depending on the parity of the irrep. See [here](https://e3nn-jax.readthedocs.io/en/latest/api/nn.html#e3nn_jax.gate). In the following we just give the irreps string as input to obtain what the output would look like.

In [25]:
e3nn.gate("3x0e + 2x1e + 1x2e")  # works

2x1e+1x2e

In [26]:
e3nn.gate("2x0e + 2x1e + 1x2e")  # fails

ValueError: The input must have at least as many scalars as the number of non-scalar irreps

In [27]:
e3nn.gate("12x0e + 2x1e + 1x2e")  # 12 - 2 - 1 = 9 scalars remain

9x0e+2x1e+1x2e

## Other stuff
Also batch norm and dropout are provided, which I guess we can skip.

# Build a Neural Network

Here a little example of a neural network using `n_layers` pairs of linear layers and gated activation.

In [28]:
class SO3MLP(hk.Module):
    def __init__(self, n_layers, out_irreps, activation_fn=jax.nn.sigmoid):
        super().__init__()
        
        self.n_layers = n_layers
        self.out_irreps = out_irreps
        self.activation_fn = activation_fn
    def __call__(self, x: e3nn.IrrepsArray):
        
        for _ in range(self.n_layers):
            x = e3nn.Linear(self.out_irreps)(x)
            x = e3nn.gate(x, 
                          even_act=self.activation_fn,
                          odd_act=self.activation_fn,
                          odd_gate_act=self.activation_fn,
                          even_gate_act=self.activation_fn)
        
        
        return x

In [30]:
r = np.random.rand(3)
spherical_harmonics = e3nn.sh(irreps_out=[0,1,2], input=R, normalize=True, normalization='integral')
x = e3nn.IrrepsArray("1x2y + 1x1y + 1x0y", spherical_harmonics)

@hk.without_apply_rng
@hk.transform
def so3_mlp(x):
    return SO3MLP(n_layers=2, out_irreps="8x2y + 16x1y + 32x0y")(x)
    
params = so3_mlp.init(jax.random.PRNGKey(0), x)
so3_mlp.apply(params, x)

8x0e+8x2e+16x1o
[[ 1.68   0.478  1.004  1.358  1.017  0.993  1.178  0.845 -0.085 -0.071
  -0.108 -0.073 -0.078 -0.257 -0.213 -0.324 -0.22  -0.234  0.318  0.262
   0.4    0.271  0.289  0.112  0.093  0.142  0.096  0.102 -0.071 -0.058
  -0.089 -0.06  -0.064 -0.065 -0.053 -0.081 -0.055 -0.059 -0.09  -0.074
  -0.113 -0.077 -0.082  0.036  0.03   0.045  0.031  0.033 -1.006 -0.492
  -1.039 -0.221 -0.108 -0.229 -0.388 -0.19  -0.401  0.402  0.197  0.416
  -0.104 -0.051 -0.107  0.248  0.121  0.257 -0.024 -0.012 -0.025 -0.996
  -0.487 -1.028 -1.049 -0.513 -1.084  0.538  0.263  0.556  0.733  0.359
   0.758  1.65   0.806  1.704 -0.021 -0.01  -0.021 -0.686 -0.335 -0.709
   0.449  0.22   0.464  0.389  0.19   0.402]
 [ 1.667  0.501  1.041  1.363  1.056  0.999  1.154  0.828 -0.087 -0.098
  -0.106 -0.042 -0.062 -0.253 -0.285 -0.308 -0.123 -0.179  0.316  0.356
   0.385  0.153  0.223  0.102  0.115  0.124  0.05   0.072 -0.075 -0.085
  -0.092 -0.037 -0.053 -0.067 -0.076 -0.082 -0.033 -0.048 -0.102 -0.116
  -