# Machine learning electric response

Install the dependencies and import the required modules.

In [None]:
!pip install tensorial@git+https://github.com/camml-lab/tensorial.git@a129b31d2aa1feb0bc8b1eeb867d9207912689b7  e3response@git+https://github.com/camml-lab/e3response.git@92a02afa6923d0efda2f9c2173c6ca27f1980d59

In [None]:
DATA_URL = "https://github.com/camml-lab/e3response/raw/refs/heads/tutorial-zadar/tutorial/data.zip"

import requests
import zipfile
import io

response = requests.get(DATA_URL)
response.raise_for_status()  # Check if the download was successful
zip_file = zipfile.ZipFile(io.BytesIO(response.content))
zip_file.extractall()  # Extract all the contents into the current directory

In [None]:
import e3response as e3r
import e3nn_jax as e3j
from matplotlib import pyplot as plt
import omegaconf
import reax
import tensorial


## Direct learning of tensorial quantities

To start the tutorial, we will train a model that takes the atomic structure (represented as a graph) and learns to predict Born effective charges directly, without needing to predict, or even consider, the energy of the system.

As a reminder, one way to calculate the Born effective charge tensors is:

$Z^{*}_{\kappa, \alpha \beta} = \frac{\partial F_{\kappa, \alpha}}{\partial \mathcal{E}_\beta}$

where $\kappa$ is the atom index, $\alpha$ and $\beta$ are Cartesian indices, $F$ is the force and $\mathcal{E}$ is the electric field.
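As a concrete illustration of this definition, the sketch below recovers $Z^*$ for a single atom by central finite differences of the force with respect to the field. The linear `toy_force` function and the `Z_TRUE` values are made up for illustration and are not taken from the tutorial data; in practice the forces would come from, e.g., DFT calculations under a small finite field.

```python
import numpy as np

# Hypothetical "true" Born effective charge of one atom (3x3, units of e).
# These numbers are invented purely for this illustration.
Z_TRUE = np.array([[2.5, 0.1, 0.0],
                   [0.1, 2.5, 0.0],
                   [0.0, 0.0, 3.0]])


def toy_force(field: np.ndarray) -> np.ndarray:
    """Toy linear-response force on one atom: F_alpha = sum_beta Z_alpha,beta * E_beta."""
    return Z_TRUE @ field


def born_charge_fd(force_fn, h: float = 1e-4) -> np.ndarray:
    """Estimate Z*_{alpha beta} = dF_alpha / dE_beta by central finite differences."""
    z = np.zeros((3, 3))
    for beta in range(3):
        step = np.zeros(3)
        step[beta] = h
        # Column beta of Z* is the change in force per unit field applied along beta
        z[:, beta] = (force_fn(step) - force_fn(-step)) / (2 * h)
    return z


Z_est = born_charge_fd(toy_force)
print(np.allclose(Z_est, Z_TRUE))  # True
```

Because the toy force is exactly linear in the field, the central difference reproduces `Z_TRUE`; with real, non-linear data the step size `h` becomes a trade-off between truncation and numerical noise.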

In [None]:
R_MAX = 5.0  # The cutoff sphere radius

Here we load the definition of our model from a YAML file; this makes reproducibility easier and gives us a central place where we can see how the model is built. Have a look at the `model/nequip_tensors.yaml` file to see what's going on.

In [None]:
cfg = omegaconf.OmegaConf.load("model/nequip_tensors.yaml")
cfg["r_max"] = R_MAX

Let's look at the model's readout block, which is defined like this:

```yaml
# Per-atom Born effective charge tensor
- _target_: tensorial.gcnn.NodewiseLinear
  irreps_out: 1x0e + 1x1e + 1x2e
  out_field: predicted_born_charges
```

The `NodewiseLinear` module takes the current features (stored in `nodes["features"]`), applies a learnable linear map from the current feature irreps (in this case "4x0e + 4x0o + 4x1e + 4x1o + 4x2e + 4x2o + 2x3o + 2x3e") down to the irreps of the Born effective charge tensor, and stores the result in the `nodes.predicted_born_charges` field.
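To see what this linear layer maps between, it helps to count components: an irrep of degree $l$ has $2l+1$ components, independent of its parity. The small pure-Python helper below (`irreps_dim` is hypothetical, written only for this illustration; `e3nn-jax`'s `Irreps` objects expose the same number through their `dim` attribute) shows that the readout maps 100 feature components per atom down to 9, the number of entries in a $3 \times 3$ Born charge tensor.

```python
import re


def irreps_dim(irreps: str) -> int:
    """Total number of components in an irreps string such as "4x0e + 2x1o".

    Each irrep of degree l has 2l + 1 components; the parity letter (e/o)
    does not change the dimension. A missing multiplicity means 1.
    """
    total = 0
    for term in irreps.split("+"):
        term = term.strip()
        match = re.fullmatch(r"(?:(\d+)x)?(\d+)([eo])", term)
        if match is None:
            raise ValueError(f"cannot parse irrep term: {term!r}")
        mul = int(match.group(1) or 1)
        l = int(match.group(2))
        total += mul * (2 * l + 1)
    return total


features = "4x0e + 4x0o + 4x1e + 4x1o + 4x2e + 4x2o + 2x3o + 2x3e"
born = "1x0e + 1x1e + 1x2e"
print(irreps_dim(features))  # 100
print(irreps_dim(born))      # 9, i.e. the 3 x 3 entries of a Cartesian rank-2 tensor
```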

To perform the change of basis from the spherical harmonic basis to the familiar Cartesian basis, we can use the `NodewiseDecoding` module as follows:


```yaml
- _target_: tensorial.gcnn.NodewiseDecoding
  in_field: predicted_born_charges
  attrs:
    predicted_born_charges:
      _target_: tensorial.CartesianTensor
      formula: ij
      i: 1o
      j: 1o
```

The `formula` tells the module that the Born effective charge tensor is rank 2 (`ij`) and can be constructed as the Cartesian tensor product of two vectors:

\begin{equation*}
  \left[\begin{array}{c}
      x_1 \\ y_1 \\ z_1
    \end{array}\right]
  \otimes
  \left[\begin{array}{c}
      x_2 \\ y_2 \\ z_2
    \end{array}\right] =
  \left[\begin{array}{ccc}
      x_1 x_2 & x_1 y_2 & x_1 z_2 \\
      y_1 x_2 & y_1 y_2 & y_1 z_2 \\
      z_1 x_2 & z_1 y_2 & z_1 z_2
    \end{array}\right]
\end{equation*}

Using this, `e3nn-jax` can find the change-of-basis matrix. The module stores the result in the `nodes.predicted_born_charges` field.
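In Cartesian terms, the three irreps `0e`, `1e` and `2e` correspond to the familiar split of a $3 \times 3$ tensor into an isotropic part (its trace), an antisymmetric part and a symmetric traceless part, with 1, 3 and 5 independent components respectively. The NumPy sketch below performs that split directly. It only illustrates the idea; it is not how `tensorial` or `e3nn-jax` compute the change of basis, which additionally fixes a particular spherical-harmonic basis and normalisation.

```python
import numpy as np


def decompose_rank2(t: np.ndarray):
    """Split a 3x3 Cartesian tensor into its 0e, 1e and 2e parts.

    Returns (isotropic, antisymmetric, symmetric_traceless); they sum back to t.
    """
    isotropic = np.trace(t) / 3.0 * np.eye(3)          # 1 independent component (l = 0)
    antisymmetric = 0.5 * (t - t.T)                    # 3 independent components (l = 1)
    symmetric_traceless = 0.5 * (t + t.T) - isotropic  # 5 independent components (l = 2)
    return isotropic, antisymmetric, symmetric_traceless


rng = np.random.default_rng(0)
z = rng.normal(size=(3, 3))  # random stand-in for one atom's Born charge tensor
iso, anti, sym = decompose_rank2(z)
print(np.allclose(iso + anti + sym, z))  # True: the three parts reconstruct the tensor
print(np.isclose(np.trace(sym), 0.0))    # True: the l = 2 part is traceless
```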

In [None]:
module = tensorial.config.instantiate(cfg["module"])

We now instantiate the `DataModule` class, which splits the data into train/test/validation sets and creates batches of a fixed size that fits into the available memory.

In [None]:
datamodule = e3r.data.BtoDataModule(r_max=R_MAX, batch_size=16)

We will use the REAX library (similar to PyTorch Lightning, but for JAX) to train our model; it simplifies things by providing much of the boilerplate code for us.

In [None]:
trainer = reax.Trainer()

Now we use the Trainer's `fit()` method to fit the model.

Notice that it takes some time before the progress bar starts to move. This is because JAX compiles and optimises our code before executing it. Once compiled, the code runs _much_ faster; however, changing the shape of any of our arrays triggers a re-compilation, so the `DataModule` uses a batcher that pads every batch to the same size.
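The idea behind the padding can be sketched in a few lines of NumPy: each structure's per-atom array is padded with zeros to a fixed number of nodes, and a boolean mask records which rows are real atoms. This is a simplified illustration with a hypothetical `pad_nodes` helper; the actual batcher used by the `DataModule` may handle more (e.g. edges) and is not reproduced here.

```python
import numpy as np


def pad_nodes(features: np.ndarray, max_nodes: int):
    """Pad a (n_nodes, ...) array to (max_nodes, ...) with zeros and return a mask.

    Keeping every batch at the same shape means a jit-compiled function only
    has to be compiled once; the mask marks which rows are real atoms.
    """
    n = features.shape[0]
    if n > max_nodes:
        raise ValueError(f"{n} nodes exceed the padding size {max_nodes}")
    pad_width = [(0, max_nodes - n)] + [(0, 0)] * (features.ndim - 1)
    padded = np.pad(features, pad_width)  # constant mode, pads with zeros
    mask = np.arange(max_nodes) < n
    return padded, mask


# Two hypothetical structures with different numbers of atoms (one 3x3 tensor per atom)
a, mask_a = pad_nodes(np.ones((5, 3, 3)), max_nodes=8)
b, mask_b = pad_nodes(np.ones((7, 3, 3)), max_nodes=8)
print(a.shape == b.shape)          # True: both are (8, 3, 3)
print(mask_a.sum(), mask_b.sum())  # 5 7
```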

In [None]:
trainer.fit(module, datamodule=datamodule, max_epochs=100)

Now that our module is trained, we can use the `trainer.predict()` method to perform inference and get the predictions for the test set.

In [None]:
predict = trainer.predict(module, dataloaders=datamodule.test_dataloader(), keep_predictions=True)

In [None]:
len(predict.predictions)

In [None]:
first_batch = predict.predictions[0]
print(first_batch._asdict().keys())

In [None]:
first_batch.nodes.keys()

In [None]:
fig = plt.figure()
ax = fig.add_subplot()
ax.scatter(first_batch.nodes["born_charges"], first_batch.nodes["predicted_born_charges"])
ax.axis("equal")
ax.set_xlabel("Born charges")
ax.set_ylabel("Predicted Born charges");
# plt.show()

### 🧪 Task 1: Analyse the performance of the trained model

We want a way to understand how well our model is performing on the 'unseen' test set.

✅ **What to do**:
- Extend the code above to make the parity plot show the data for _all_ batches.  🔍 **Hint**: You will need to loop over the results list in `predict.predictions`.
- Write a function that calculates the mean absolute error (MAE) and root mean square error (RMSE) between the labels and the predicted Born charge tensors.  This will give us a way to quickly understand how well the model is performing.
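One possible starting point for the metrics is sketched below. It works on plain arrays, so you would first gather `born_charges` and `predicted_born_charges` from every entry of `predict.predictions` (for example with `np.concatenate`). The optional `mask` argument is an assumption of this sketch: use it only if your batches contain padded nodes that should be excluded.

```python
import numpy as np


def mae_rmse(labels, predictions, mask=None):
    """Mean absolute error and root mean square error over all tensor components.

    `mask` (optional, shape (n_atoms,)) selects the real atoms if the arrays contain padding.
    """
    labels = np.asarray(labels, dtype=float)
    predictions = np.asarray(predictions, dtype=float)
    if mask is not None:
        mask = np.asarray(mask, dtype=bool)
        labels, predictions = labels[mask], predictions[mask]
    err = predictions - labels
    mae = float(np.mean(np.abs(err)))
    rmse = float(np.sqrt(np.mean(err**2)))
    return mae, rmse


# Tiny worked example: errors of +1 and -3 give MAE = 2 and RMSE = sqrt(5)
print(mae_rmse([0.0, 0.0], [1.0, -3.0]))  # (2.0, 2.23606797749979)
```

Because RMSE squares the errors, it is always at least as large as MAE; a large gap between the two suggests a few atoms with big errors.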

### 🧪 Task 2: Perform a hyperparameter optimisation

Now that we can measure the performance of our model, let's see if we can improve the results.

✅ **What to try**:
- Try changing the cutoff radius `R_MAX`.  You could start by looking at values like 3Å, 4Å, etc.
- Change the irreps used in the hidden layer. They are currently "16x0o + 16x0e + 16x1o + 16x1e", but you can change the multiplicities (16) or even introduce higher-degree irreps, e.g. " + 8x2o + 8x2e".
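When scanning `R_MAX`, keep in mind that the number of edges in each graph (and hence the cost of training) grows roughly with the cube of the cutoff. The toy NumPy sketch below counts neighbours in a made-up simple-cubic cluster for a few cutoffs. It ignores periodic boundary conditions, so it only indicates the trend and is not how the tutorial builds its graphs.

```python
import numpy as np


def mean_neighbours(positions: np.ndarray, r_max: float) -> float:
    """Average number of neighbours within r_max (no periodic images, toy example)."""
    diff = positions[:, None, :] - positions[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    # Exclude self-pairs (distance 0 on the diagonal)
    within = (dist < r_max) & ~np.eye(len(positions), dtype=bool)
    return float(within.sum(axis=1).mean())


# Hypothetical simple-cubic cluster with a 2 Å lattice spacing (216 atoms)
grid = np.arange(6) * 2.0
positions = np.stack(np.meshgrid(grid, grid, grid, indexing="ij"), axis=-1).reshape(-1, 3)
for r_max in (3.0, 4.0, 5.0):
    print(r_max, mean_neighbours(positions, r_max))
```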

### 🧪 Task 3: Add the ability to predict the Raman tensors

For now, our model only predicts the Born effective charges, but we could also predict the Raman tensors. These can be calculated as:

$R_{\mu\nu}^{(\kappa\lambda)} = \frac{\partial^3 E}{\partial \mathcal{E}_\mu \, \partial \mathcal{E}_\nu \, \partial u_{\kappa\lambda}}$

where $E$ is the energy of the system and $u$ are the atomic displacements.  We immediately see the following:

* $R_{\mu\nu}^{(\kappa\lambda)}$ is a rank 3 tensor ($\mu$, $\nu$ and $\lambda$),
* it is calculated by taking derivatives with respect to vectors (as opposed to pseudo-vectors) $u$ and $\mathcal{E}$,
* and it has permutational symmetry in $\mu$ and $\nu$.

✅ **What to try**:
- Edit the `nequip_tensors.yaml` file to include a new `NodewiseLinear` readout module based on the Born effective charges example.  If you're not sure of the irreps that this tensor should have, you can always calculate them as follows:

In [None]:
e3j.reduced_tensor_product_basis("ijk=jik", i="1o", j="1o", k="1o").irreps

Here, the formula "ijk=jik" expresses the symmetry of the tensor under exchange of its first two Cartesian indices.
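As a cross-check on the irreps returned by the cell above, their total dimension should equal the number of independent components of a Cartesian tensor obeying $R_{ijk} = R_{jik}$. The NumPy sketch below counts these by computing the rank of the symmetrisation projector; it is an independent illustration, not part of the tutorial's code.

```python
import numpy as np

# Build the symmetriser P = (I + S) / 2 acting on flattened rank-3 tensors,
# where S swaps the first two indices (i <-> j) of T_ijk.
dim = 3
n = dim**3
swap = np.zeros((n, n))
for i in range(dim):
    for j in range(dim):
        for k in range(dim):
            src = np.ravel_multi_index((i, j, k), (dim, dim, dim))
            dst = np.ravel_multi_index((j, i, k), (dim, dim, dim))
            swap[dst, src] = 1.0
projector = 0.5 * (np.eye(n) + swap)

# The rank of the projector is the number of independent components of a
# tensor with R_ijk = R_jik: 6 symmetric (i, j) pairs times 3 values of k.
print(np.linalg.matrix_rank(projector))  # 18
```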

✅ **What to try (continued)**:
- Add a `NodewiseDecoding` block to convert the spherical harmonic representation back to a Cartesian one.
- Add a term to your loss function that accounts for the Raman tensors. 🔍 **Hint**: `loss_fns` contains a list in which each entry starts with `-`, so you can add a new loss term modelled on the one for the Born effective charges. The target tensors are already available in `nodes.raman_tensors`; you just need to add the field where you store the predictions. You will also need to add a corresponding entry to the `weights` list. Note that the values of the Raman tensors are about 1000 times smaller than the Born effective charges (BECs).
- Now train your model. Do you get the performance you expected? Has the result for the Born effective charges gotten worse?
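When choosing the new weight, note that how the scale difference enters depends on the loss: for a mean squared error, values 1000 times smaller give errors roughly $10^6$ times smaller at the same relative accuracy, whereas for a mean absolute error the factor is about 1000. The NumPy sketch below uses made-up random labels (not the tutorial data) and a hypothetical `mse` helper to show the squared scaling; check which loss your `loss_fns` entries actually use before picking a weight.

```python
import numpy as np


def mse(labels, predictions):
    """Mean squared error over all components."""
    return float(np.mean((np.asarray(predictions) - np.asarray(labels)) ** 2))


rng = np.random.default_rng(0)
bec = rng.normal(size=(10, 3, 3))               # hypothetical BEC labels, order 1
raman = 1e-3 * rng.normal(size=(10, 3, 3, 3))   # hypothetical Raman labels, ~1000x smaller

# Predictions with the same *relative* error (10%) for both properties
bec_loss = mse(bec, 1.1 * bec)
raman_loss = mse(raman, 1.1 * raman)
ratio = bec_loss / raman_loss
print(f"{ratio:.2e}")  # of order 1e6: MSE scales with the square of the magnitude
```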