## Install MACE

In [1]:
# %%bash
# if test -d mace
# then
#     rm -rfv mace
# fi
# git clone --depth 1 --branch develop https://github.com/ACEsuit/mace.git 
# pip install mace/

In [2]:
# !pip install mace/

## Create Model

We will first create a model that we will dissect afterwards.

In [3]:
import numpy as np
import torch
import torch.nn.functional
from e3nn import o3
from matplotlib import pyplot as plt
%matplotlib inline

from mace import data, modules, tools
from mace.tools import torch_geometric
# from mace.modules.models_hariharr_dipole import *
# from mace.modules.models_hariharr_energy import *
from mace.modules.models_hariharr_energy_ewald import *
import warnings; warnings.simplefilter('ignore')

In [4]:
# Element table (atomic numbers 1 and 8) and the matching per-element energies.
z_table = tools.AtomicNumberTable([1, 8])
atomic_energies = np.array([-1.0, -3.0], dtype=float)
cutoff = 3

# Settings handed to the model through its `ewald_hyperparams` keyword.
ewald_hyperparams = {
    "k_cutoff": 0.6,  # frequency cutoff [Å^-1]
    "delta_k": 0.2,  # voxel grid resolution [Å^-1]
    "num_k_rbf": 128,  # Gaussian radial basis size (Fourier filter)
    "downprojection_size": 8,  # size of linear bottleneck layer
    "num_hidden": 0,  # number of residuals in update function
    # TODO(check): clarify what num_k_x / num_k_y / num_k_z mean
    "num_k_x": 1,
    "num_k_y": 1,
    "num_k_z": 3,
}

# The first layer and all subsequent layers use the same interaction block class.
residual_interaction_cls = modules.interaction_classes[
    "RealAgnosticResidualInteractionBlock"
]

model_config = {
    # number of chemical elements
    "num_elements": 2,
    # atomic energies used for normalisation
    "atomic_energies": atomic_energies,
    # average number of neighbours per atom, used for internal normalisation
    # of messages. TODO(check): maybe this should be increased
    "avg_num_neighbors": 8,
    # atomic numbers, used to specify chemical element embeddings of the model
    "atomic_numbers": z_table.zs,
    # cutoff
    "r_max": cutoff,
    # number of radial features
    "num_bessel": 8,
    # smoothness of the radial cutoff
    "num_polynomial_cutoff": 6,
    # expansion order of spherical harmonic edge attributes
    "max_ell": 2,
    # number of layers, typically 2
    "num_interactions": 2,
    # interaction block of the first layer
    "interaction_cls_first": residual_interaction_cls,
    # interaction block of subsequent layers
    "interaction_cls": residual_interaction_cls,
    # 32: number of embedding channels; 0e, 1o specify which equivariant
    # messages to use. Here up to L_max=1
    "hidden_irreps": o3.Irreps("32x0e + 32x1o"),
    # correlation order of the messages (body order - 1)
    "correlation": 3,
    # number of hidden dimensions of last layer readout MLP
    "MLP_irreps": o3.Irreps("16x0e"),
    # nonlinearity used in last layer readout MLP
    "gate": torch.nn.functional.silu,
    "ewald_hyperparams": ewald_hyperparams,
    "use_pbc": True,
}

model = TestMACE_Ewald(**model_config)

hidden_irreps: 32x0e+32x1o [32]
target_irreps in EquivariantProductBasisBlock: 32x0e+32x1o
interaction: 0 hidden_irreps: 32x0e [32, 32]
target_irreps in EquivariantProductBasisBlock: 32x0e


In [5]:
# z_table = tools.AtomicNumberTable([1, 8])
# atomic_energies = np.array([-1.0, -3.0], dtype=float)
# cutoff = 3

# model_config = dict(
#         num_elements=2,  # number of chemical elements
#         atomic_energies=atomic_energies,  # atomic energies used for normalisation
#         avg_num_neighbors=8,  # avg number of neighbours of the atoms, used for internal normalisation of messages
#         atomic_numbers=z_table.zs,  # atomic numbers, used to specify chemical element embeddings of the model
#         r_max=cutoff,  # cutoff
#         num_bessel=8,  # number of radial features
#         num_polynomial_cutoff=6,  # smoothness of the radial cutoff
#         max_ell=2,  # expansion order of spherical harmonic adge attributes
#         num_interactions=2,  # number of layers, typically 2
#         interaction_cls_first=modules.interaction_classes[
#             "RealAgnosticResidualInteractionBlock"
#         ],  # interation block of first layer
#         interaction_cls=modules.interaction_classes[
#             "RealAgnosticResidualInteractionBlock"
#         ],  # interaction block of subsequent layers
#         hidden_irreps=o3.Irreps("32x0e + 32x1o"),  # 32: number of embedding channels, 0e, 1o is specifying which equivariant messages to use. Here up to L_max=1
#         correlation=3,  # correlation order of the messages (body order - 1)
#         MLP_irreps=o3.Irreps("16x0e"),  # number of hidden dimensions of last layer readout MLP
#         gate=torch.nn.functional.silu,  # nonlinearity used in last layer readout MLP
#     )

# model = TestMACE(**model_config)

In [6]:
# cutoff = 3
# num_bessel = 8
# num_polynomial_cutoff = 6
# max_ell = 2
# num_interactions = 4
# num_elements = 2
# MLP_irreps = o3.Irreps("16x0e")
# hidden_irreps = o3.Irreps("32x0e + 32x1o")
# MLP_irreps = o3.Irreps("16x0e")
# avg_num_neighbors = 8
# z_table = tools.AtomicNumberTable([1, 8])
# atomic_energies = np.array([-1.0, -3.0], dtype=float)
# correlation = 3
# gate = torch.nn.functional.silu


# model = TestEnergyDipolesMACE(
#     r_max=cutoff,
#     num_bessel=num_bessel,
#     num_polynomial_cutoff=num_polynomial_cutoff,
#     max_ell=max_ell,
#     interaction_cls=modules.interaction_classes[
#             "RealAgnosticResidualInteractionBlock"
#         ],
#     interaction_cls_first=modules.interaction_classes[
#             "RealAgnosticResidualInteractionBlock"
#         ],
#     num_interactions=num_interactions,
#     num_elements=num_elements,
#     hidden_irreps=hidden_irreps,
#     MLP_irreps=MLP_irreps,
#     avg_num_neighbors=avg_num_neighbors,
#     atomic_numbers=z_table.zs,
#     correlation=correlation,
#     gate=gate,
#     atomic_energies=atomic_energies
# )

In [7]:
# Print the nested module tree of the model built above.
print(model)

TestMACE_Ewald(
  (down): Dense(
    (linear): Linear(in_features=31, out_features=8, bias=False)
    (_activation): Identity()
  )
  (ewald_blocks): ModuleList(
    (0-1): 2 x EwaldBlock(
      (down): Dense(
        (linear): Linear(in_features=31, out_features=8, bias=False)
        (_activation): Identity()
      )
      (up): Dense(
        (linear): Linear(in_features=8, out_features=32, bias=False)
        (_activation): Identity()
      )
      (pre_residual): ResidualLayer(
        (dense_mlp): Sequential(
          (0): Dense(
            (linear): Linear(in_features=32, out_features=32, bias=False)
            (_activation): ScaledSiLU(
              (_activation): SiLU()
            )
          )
          (1): Dense(
            (linear): Linear(in_features=32, out_features=32, bias=False)
            (_activation): ScaledSiLU(
              (_activation): SiLU()
            )
          )
        )
      )
      (ewald_layers): ModuleList(
        (0): Dense(
          (li

Create a water molecule as an example configuration:

In [8]:
# Three-atom example (one O, two H) with reference forces and a reference energy.
example_positions = np.array(
    [
        [0.0, -2.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
    ]
)
example_forces = np.array(
    [
        [0.0, -1.3, 0.0],
        [1.0, 0.2, 0.0],
        [0.0, 1.1, 0.3],
    ]
)
config = data.Configuration(
    atomic_numbers=np.array([8, 1, 1]),
    positions=example_positions,
    forces=example_forces,
    energy=-1.5,
    pbc=(True, True, True),
)

# Build the graph representation using the model's own cutoff radius.
atomic_data = data.AtomicData.from_config(
    config,
    z_table=z_table,
    cutoff=float(model.r_max),
)

# Wrap the single structure in a loader and pull out its only batch.
data_loader = torch_geometric.dataloader.DataLoader(
    dataset=[atomic_data],
    batch_size=3,
    shuffle=True,
    drop_last=False,
)
batch = next(iter(data_loader))

batch_intro = "The data is stored in batches. Each batch is a single graph, potentially made up of several disjointed sub-graphs corresponding to different chemical structures. "
edge_index_intro = "\nbatch.edge_index contains which atoms are connected within the cutoff. It is the adjacency matrix in sparse format.\n"
print(batch_intro)
print(batch)
print(edge_index_intro)
print(batch.edge_index)

pbc: (True, True, True)
The data is stored in batches. Each batch is a single graph, potentially made up of several disjointed sub-graphs corresponding to different chemical structures. 
Batch(batch=[3], cell=[3, 3], edge_index=[2, 834], energy=[1], energy_weight=[1], forces=[3, 3], forces_weight=[1], node_attrs=[3, 2], positions=[3, 3], ptr=[2], shifts=[834, 3], stress_weight=[1], unit_shifts=[834, 3], virials_weight=[1], weight=[1])

batch.edge_index contains which atoms are connected within the cutoff. It is the adjacency matrix in sparse format.

tensor([[0, 0, 0,  ..., 2, 2, 2],
        [0, 1, 2,  ..., 0, 1, 2]])


In [9]:
# NOTE(review): this cell repeats the code of the preceding cell; it rebuilds
# `config`, `atomic_data`, `data_loader` and `batch` with the same values.
config = data.Configuration(
    atomic_numbers=np.array([8, 1, 1]),
    positions=np.array([[0.0, -2.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]),
    forces=np.array([[0.0, -1.3, 0.0], [1.0, 0.2, 0.0], [0.0, 1.1, 0.3]]),
    energy=-1.5,
    pbc=(True, True, True),
)

# The neighbour cutoff is taken from the model so graph and model agree.
model_cutoff = float(model.r_max)
atomic_data = data.AtomicData.from_config(config, z_table=z_table, cutoff=model_cutoff)

loader_options = {"batch_size": 3, "shuffle": True, "drop_last": False}
data_loader = torch_geometric.dataloader.DataLoader(dataset=[atomic_data], **loader_options)
batch = next(iter(data_loader))

print("The data is stored in batches. Each batch is a single graph, potentially made up of several disjointed sub-graphs corresponding to different chemical structures. ")
print(batch)
print("\nbatch.edge_index contains which atoms are connected within the cutoff. It is the adjacency matrix in sparse format.\n")
print(batch.edge_index)

pbc: (True, True, True)
The data is stored in batches. Each batch is a single graph, potentially made up of several disjointed sub-graphs corresponding to different chemical structures. 
Batch(batch=[3], cell=[3, 3], edge_index=[2, 834], energy=[1], energy_weight=[1], forces=[3, 3], forces_weight=[1], node_attrs=[3, 2], positions=[3, 3], ptr=[2], shifts=[834, 3], stress_weight=[1], unit_shifts=[834, 3], virials_weight=[1], weight=[1])

batch.edge_index contains which atoms are connected within the cutoff. It is the adjacency matrix in sparse format.

tensor([[0, 0, 0,  ..., 2, 2, 2],
        [0, 1, 2,  ..., 0, 1, 2]])


In [10]:
# Inspect the cell tensor stored on the batch (reported as shape [3, 3] in the
# batch summary).
# NOTE(review): the Configuration built in this notebook is given no `cell`
# argument while pbc is (True, True, True) — confirm that the cell used here
# is the intended one.
# batch.cell
cells = batch.cell
# Show the third column of the cell tensor; the trailing comment keeps an
# alternative expression (cross product of columns 1 and 2) for reference.
cells[:, 2] # torch.cross(cells[:, 1], cells[:, 2], dim=-1)

tensor([0., 0., 0.])

## MACE Forward      

In [11]:
# Inputs for a forward pass of the model on the example batch.
#
# NOTE: rebinding `data` shadows the `mace.data` module imported at the top of
# the notebook; re-running an earlier cell that calls `data.Configuration`
# after this cell will fail until the import cell is run again.
data = batch

# Plain booleans. The original lines ended with a trailing comma, which bound
# one-element tuples such as `(False,)` — and a non-empty tuple is always
# truthy, so `training` would have behaved like True in any condition.
training = False
compute_force = True
compute_virials = False
compute_stress = False
compute_displacement = False

# The flags above are not passed to the call below, so whatever defaults
# `forward` defines are used here.
outputs = model.forward(batch)

data type: <class 'mace.tools.torch_geometric.batch.Batch'>
cross_a2a3: torch.Size([3])
cross_a3a1: torch.Size([3])
cross_a1a2: torch.Size([3])
bcells: torch.Size([3, 3]) vol: torch.Size([1])
k_grid shape if periodic: torch.Size([1, 31, 3])
k_index_product_set shape if periodic: torch.Size([31, 3])
k_cell shape if periodic: torch.Size([3, 3])
self.slice_indices: [32, 32]
node_feats: torch.Size([3, 32])
interaction layer: 0
from ewald block b: torch.Size([3, 31, 3]) from ewald block k: torch.Size([1, 31, 3])
interaction layer: 0 node feats after ewald: torch.Size([3, 32])
interaction layer: 0 node feats after MACE: torch.Size([3, 128])
interaction layer: 0 node_feats inside interaction: torch.Size([3, 128])
interaction layer: 1
interaction layer: 1 node feats after ewald: torch.Size([3, 32])
interaction layer: 1 node feats after MACE: torch.Size([3, 32])
interaction layer: 1 node_feats inside interaction: torch.Size([3, 32])


In [12]:
# energy = scatter_sum(
#                 src=node_energies, index=batch["batch"], dim=-1, dim_size=batch.num_graphs
#             )  # [n_graphs,]
# # in the code this step is done for each layer followed by summing the layer-wise output
# print("Energy:",energy)

In [13]:
# Display the batch summary (field names with their tensor shapes).
batch

Batch(batch=[3], cell=[3, 3], edge_index=[2, 834], energy=[1], energy_weight=[1], forces=[3, 3], forces_weight=[1], node_attrs=[3, 2], positions=[3, 3], ptr=[2], shifts=[834, 3], stress_weight=[1], unit_shifts=[834, 3], virials_weight=[1], weight=[1])