## Install MACE

In [None]:
# %%bash
# if test -d mace
# then
#     rm -rfv mace
# fi
# git clone --depth 1 --branch develop https://github.com/ACEsuit/mace.git 
# pip install mace/

In [None]:
# !pip install mace/

## Create Model

We will first create a model that we will dissect afterwards.

In [None]:
import numpy as np
import torch
import torch.nn.functional
from e3nn import o3
from matplotlib import pyplot as plt
%matplotlib inline

from mace import data, modules, tools
from mace.tools import torch_geometric
# from mace.modules.models_hariharr_dipole import *
# from mace.modules.models_hariharr_energy import *
from mace.modules.models_hariharr_energy_ewald import *
import warnings; warnings.simplefilter('ignore')

In [None]:
# Element table for atomic numbers 1 and 8, the per-element atomic energies
# used for normalisation, and the radial cutoff shared with the model below.
z_table = tools.AtomicNumberTable([1, 8])
atomic_energies = np.array([-1.0, -3.0], dtype=float)
cutoff = 3

# Settings handed to the model through its `ewald_hyperparams` keyword.
ewald_hyperparams = {
    "k_cutoff": 0.6,  # frequency cutoff [Å^-1]
    "delta_k": 0.2,  # voxel grid resolution [Å^-1]
    "num_k_rbf": 128,  # Gaussian radial basis size (Fourier filter)
    "downprojection_size": 8,  # size of the linear bottleneck layer
    "num_hidden": 0,  # number of residuals in the update function
}

# The first and the subsequent interaction layers use the same block class,
# so it is looked up once and reused for both keywords.
residual_interaction_cls = modules.interaction_classes[
    "RealAgnosticResidualInteractionBlock"
]

model_config = {
    # number of chemical elements
    "num_elements": 2,
    # atomic energies used for normalisation
    "atomic_energies": atomic_energies,
    # average number of neighbours per atom, used for internal normalisation of messages
    "avg_num_neighbors": 8,
    # atomic numbers, used to specify the chemical element embeddings of the model
    "atomic_numbers": z_table.zs,
    # radial cutoff
    "r_max": cutoff,
    # number of radial features
    "num_bessel": 8,
    # smoothness of the radial cutoff
    "num_polynomial_cutoff": 6,
    # expansion order of the spherical harmonic edge attributes
    "max_ell": 2,
    # number of layers, typically 2
    "num_interactions": 2,
    # interaction block of the first layer
    "interaction_cls_first": residual_interaction_cls,
    # interaction block of subsequent layers
    "interaction_cls": residual_interaction_cls,
    # 32 embedding channels; 0e and 1o select which equivariant messages to use (here up to L_max=1)
    "hidden_irreps": o3.Irreps("32x0e + 32x1o"),
    # correlation order of the messages (body order - 1)
    "correlation": 3,
    # hidden dimensions of the last-layer readout MLP
    "MLP_irreps": o3.Irreps("16x0e"),
    # nonlinearity used in the last-layer readout MLP
    "gate": torch.nn.functional.silu,
    "ewald_hyperparams": ewald_hyperparams,
}

model = TestMACE_Ewald(**model_config)

In [None]:
# z_table = tools.AtomicNumberTable([1, 8])
# atomic_energies = np.array([-1.0, -3.0], dtype=float)
# cutoff = 3

# model_config = dict(
#         num_elements=2,  # number of chemical elements
#         atomic_energies=atomic_energies,  # atomic energies used for normalisation
#         avg_num_neighbors=8,  # avg number of neighbours of the atoms, used for internal normalisation of messages
#         atomic_numbers=z_table.zs,  # atomic numbers, used to specify chemical element embeddings of the model
#         r_max=cutoff,  # cutoff
#         num_bessel=8,  # number of radial features
#         num_polynomial_cutoff=6,  # smoothness of the radial cutoff
#         max_ell=2,  # expansion order of spherical harmonic adge attributes
#         num_interactions=2,  # number of layers, typically 2
#         interaction_cls_first=modules.interaction_classes[
#             "RealAgnosticResidualInteractionBlock"
#         ],  # interation block of first layer
#         interaction_cls=modules.interaction_classes[
#             "RealAgnosticResidualInteractionBlock"
#         ],  # interaction block of subsequent layers
#         hidden_irreps=o3.Irreps("32x0e + 32x1o"),  # 32: number of embedding channels, 0e, 1o is specifying which equivariant messages to use. Here up to L_max=1
#         correlation=3,  # correlation order of the messages (body order - 1)
#         MLP_irreps=o3.Irreps("16x0e"),  # number of hidden dimensions of last layer readout MLP
#         gate=torch.nn.functional.silu,  # nonlinearity used in last layer readout MLP
#     )

# model = TestMACE(**model_config)

In [None]:
# cutoff = 3
# num_bessel = 8
# num_polynomial_cutoff = 6
# max_ell = 2
# num_interactions = 4
# num_elements = 2
# MLP_irreps = o3.Irreps("16x0e")
# hidden_irreps = o3.Irreps("32x0e + 32x1o")
# MLP_irreps = o3.Irreps("16x0e")
# avg_num_neighbors = 8
# z_table = tools.AtomicNumberTable([1, 8])
# atomic_energies = np.array([-1.0, -3.0], dtype=float)
# correlation = 3
# gate = torch.nn.functional.silu


# model = TestEnergyDipolesMACE(
#     r_max=cutoff,
#     num_bessel=num_bessel,
#     num_polynomial_cutoff=num_polynomial_cutoff,
#     max_ell=max_ell,
#     interaction_cls=modules.interaction_classes[
#             "RealAgnosticResidualInteractionBlock"
#         ],
#     interaction_cls_first=modules.interaction_classes[
#             "RealAgnosticResidualInteractionBlock"
#         ],
#     num_interactions=num_interactions,
#     num_elements=num_elements,
#     hidden_irreps=hidden_irreps,
#     MLP_irreps=MLP_irreps,
#     avg_num_neighbors=avg_num_neighbors,
#     atomic_numbers=z_table.zs,
#     correlation=correlation,
#     gate=gate,
#     atomic_energies=atomic_energies
# )

In [None]:
# Print the string representation of the model built above.
print(model)

Create a water molecule as an example:

In [None]:
# Three-atom example (one atom with Z=8, two with Z=1) with reference forces
# and a reference total energy attached.
config = data.Configuration(
    atomic_numbers=np.array([8, 1, 1]),
    positions=np.array([[0.0, -2.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]),
    forces=np.array([[0.0, -1.3, 0.0], [1.0, 0.2, 0.0], [0.0, 1.1, 0.3]]),
    energy=-1.5,
)

# Convert the configuration to graph data, reusing the cutoff stored on the model.
atomic_data = data.AtomicData.from_config(
    config,
    z_table=z_table,
    cutoff=float(model.r_max),
)

# A loader over the single structure; the first (and only) batch is taken from it.
data_loader = torch_geometric.dataloader.DataLoader(
    dataset=[atomic_data], batch_size=1, shuffle=True, drop_last=False
)
batch = next(iter(data_loader))

# One print call with a newline separator emits the same four lines of output
# as four consecutive print calls would.
print(
    "The data is stored in batches. Each batch is a single graph, potentially made up of several disjointed sub-graphs corresponding to different chemical structures. ",
    batch,
    "\nbatch.edge_index contains which atoms are connected within the cutoff. It is the adjacency matrix in sparse format.\n",
    batch.edge_index,
    sep="\n",
)

## MACE Forward      

In [None]:
# Set up names for the batch and a set of boolean flags, then run the forward pass.
# NOTE(review): `data` rebinds the name brought in by `from mace import data`.
# After this cell has run, re-executing an earlier cell that uses
# `data.Configuration` / `data.AtomicData` needs the import cell to be re-run first.
data = batch

# Plain booleans. A trailing comma after each value would instead bind
# one-element tuples such as `(False,)`, which are truthy.
training = False
compute_force = True
compute_virials = False
compute_stress = False
compute_displacement = False

# NOTE(review): the flags above are not passed to the call below, so the
# model's own defaults apply — TODO confirm this is intended.
outputs = model.forward(batch)

## MACE readout

To create the output of the model we use the node features from all layers $s$:

\begin{equation}
    \mathcal{R}^{(s)} \left( \boldsymbol{h}_i^{(s)} \right) =
    \begin{cases}
      \sum_{k}W^{(s)}_{k}h^{(s)}_{i,k00}     & \text{if} \;\; 1 \leq s < S \\[13pt]
      {\rm MLP} \left( \left\{ h^{(s)}_{i,k00} \right\}_k \right)  &\text{if} \;\; s = S
    \end{cases}
\end{equation}

The first linear readout is implemented in

```py
class LinearReadoutBlock(torch.nn.Module):
```

In our example, this maps the 32-dimensional $h^{(1)}_{i,k00}$ (the invariant part of the node features after the first interaction) to the first term in the atomic site energy:

In [None]:
# Print the first entry of the model's `readouts` container.
print(model.readouts[0])

In [None]:
# Apply the first readout to the node features and squeeze away the last axis.
# NOTE(review): `node_feats` is not assigned in any cell visible in this notebook;
# unless it is defined elsewhere, this line raises NameError on a fresh kernel —
# TODO confirm which tensor it should be.
node_energies = model.readouts[0](node_feats).squeeze(-1)

The last-layer readout block is a multilayer perceptron (MLP) with one hidden layer:

```py
class NonLinearReadoutBlock(torch.nn.Module):
```

In [None]:
# Print the second entry of the model's `readouts` container.
print(model.readouts[1])

It is also possible to have equivariant readouts. This can be achieved by using Gated non-linearities. See as an example:

```py
class NonLinearDipoleReadoutBlock(torch.nn.Module):
```

These readouts are formed for each node in the batch. To turn them into a graph-level readout we use a scatter-sum operation, which sums the node energies for each graph (separate chemical structure) in the batch. This is followed by summing the atomic energy and the 1st, 2nd, etc. layer contributions to form the final model output.

In [None]:
# energy = scatter_sum(
#                 src=node_energies, index=batch["batch"], dim=-1, dim_size=batch.num_graphs
#             )  # [n_graphs,]
# # in the code this step is done for each layer followed by summing the layer-wise output
# print("Energy:",energy)

In [None]:
# Display the `cell` attribute stored on the batch.
batch.cell