# Machine learning electric response

In [None]:
!pip install tensorial@git+https://github.com/camml-lab/tensorial.git@a129b31d2aa1feb0bc8b1eeb867d9207912689b7  e3response@git+https://github.com/camml-lab/e3response.git@92a02afa6923d0efda2f9c2173c6ca27f1980d59

In [None]:
DATA_URL = "https://github.com/camml-lab/e3response/raw/refs/heads/tutorial-zadar/tutorial/data.zip"

import requests
import zipfile
import io

response = requests.get(DATA_URL)
response.raise_for_status()  # Check if the download was successful
zip_file = zipfile.ZipFile(io.BytesIO(response.content))
zip_file.extractall()  # Extract all the contents into the current directory

In [None]:
import functools
import e3response as e3r
import e3nn_jax as e3j
from flax import linen
import jax.numpy as jnp
import jaxtyping as jt
import jraph
from matplotlib import pyplot as plt
import omegaconf
import reax
import tensorial
import tensorial.typing as tt
from tensorial import gcnn

## Learning response properties using autograd

You may have noticed that the response properties are defined as derivatives of the energy with respect to various quantities. This raises the question: can we build a model that does the same?  The answer is *yes*, and it is a great idea.

To do this, we have to make a few modifications:
1. The model now needs to predict the total energy, and,
2. the model needs to be an explicit function of the external electric field (even if many of the derivatives are taken at zero field), otherwise any derivatives thereof will be zero!
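
The second point can be checked on a toy example. The sketch below uses two hypothetical energy functions (neither is part of `e3response`): one ignores the field, the other couples to it linearly through a made-up dipole. Only the second has a nonzero derivative, even when evaluated at zero field.

```python
import jax
import jax.numpy as jnp

# Hypothetical dipole, used only for illustration
DIPOLE = jnp.array([0.1, -0.2, 0.3])


def energy_without_field(positions, field):
    # The field argument is accepted but never used
    return jnp.sum(positions**2)


def energy_with_field(positions, field):
    # Same energy plus a linear coupling to the external field
    return jnp.sum(positions**2) - jnp.dot(DIPOLE, field)


positions = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
zero_field = jnp.zeros(3)

# Differentiate with respect to the second argument (the field) at zero field
grad_without = jax.grad(energy_without_field, argnums=1)(positions, zero_field)
grad_with = jax.grad(energy_with_field, argnums=1)(positions, zero_field)

print(grad_without)  # all zeros: this model cannot respond to the field
print(grad_with)     # equals -DIPOLE, even though the field itself is zero
```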

We can satisfy the second requirement by adding the global electric field as a node attribute (in addition to the element, which is also stored on the node):

```yaml
  - _target_: tensorial.gcnn.NodewiseEmbedding
    attrs:
      species:
        _target_: tensorial.tensors.OneHot
        num_classes: 3
      globals.external_electric_field:
        _target_: tensorial.tensors.SphericalHarmonic
        irreps: 0e + 1o
        normalise: True
```

The atomic species and the global electric field (embedded as a spherical harmonic expansion up to $l=1$) will be concatenated to form the node attributes.

For the first requirement (predicting the total energy), we add the following:

```yaml
  - _target_: tensorial.gcnn.NodewiseLinear
    irreps_out: 1x0e
    out_field: predicted_energy # Per-atom energy

  # Final total energy
  - _target_: tensorial.gcnn.NodewiseReduce
    field: predicted_energy
    out_field: predicted_energy  # Global energy
```

Here the model predicts an energy for each atom and then sums (reduces) them to produce a global predicted energy.
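
The reduction step can be sketched in plain JAX. This only illustrates summing per-atom values within each graph of a batch; it is not the `NodewiseReduce` implementation, and the energies and graph sizes are made up.

```python
import jax
import jax.numpy as jnp

# Two graphs in one batch: the first has 3 atoms, the second has 2
n_node = jnp.array([3, 2])
per_atom_energy = jnp.array([-1.0, -2.0, -0.5, -3.0, -1.5])

# Index of the graph that each atom belongs to: [0, 0, 0, 1, 1]
graph_index = jnp.repeat(
    jnp.arange(n_node.shape[0]),
    n_node,
    total_repeat_length=per_atom_energy.shape[0],
)

# Sum the per-atom energies within each graph
total_energy = jax.ops.segment_sum(per_atom_energy, graph_index, num_segments=2)
print(total_energy)  # one total per graph: -3.5 and -4.5
```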

Have a look at `model/nequip_deriv.yaml` to see all the details.

We can calculate the Born effective charges using:

$Z^{*}_{\kappa,\alpha\beta} = \Omega \frac{\partial P_{\alpha}}{\partial u_{\kappa\beta}}$

where $\Omega$ is the unit cell volume, and $u_{\kappa\beta}$ are the atomic positions.

This module is a little complicated but if you take it step-by-step, I hope that it will make sense:

```python
class BornEffectiveCharges(linen.Module):
    polarization_fn: gcnn.GraphFunction

    def setup(self) -> None:
        # Compute the Jacobian of polarization with respect to atomic positions
        self._jacobian_fn = gcnn.jacobian(
            of="globals.predicted_polarization",
            wrt="nodes.positions",
            has_aux=True,
        )(self.polarization_fn)

    def __call__(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
        res = self._jacobian_fn(graph, graph.nodes["positions"])
        born_tensors: jt.Float[tt.ArrayType, "κ α β"] = res[0].swapaxes(0, 1)
        graph: jraph.GraphsTuple = res[1]

        if gcnn.keys.CELL in graph.globals:
            omega = unit_cell_volumes(graph)
            born_tensors = jax.vmap(jnp.multiply)(omega, born_tensors)

        updates = gcnn.utils.UpdateGraphDicts(graph)
        updates.nodes["predicted_born_charges"] = born_tensors
        return updates.get()
```

In `setup` we simply create a new function that takes the entire graph function up to that point (`polarization_fn`) and calculates the Jacobian by taking derivatives of the polarization with respect to the positions.

In `__call__` we actually calculate the derivatives and evaluate them at the current positions.  There's a little wrangling to get the indices in the right order, and then we're pretty much done...
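
The index wrangling can be seen on a toy model without any graphs. The sketch below assumes a simple point-charge polarization with made-up charges and cell volume (not the tutorial's model); the Born tensor of each atom should then reduce to its charge times the identity matrix.

```python
import jax
import jax.numpy as jnp

CHARGES = jnp.array([2.0, -2.0])  # hypothetical fixed point charges
VOLUME = 10.0  # hypothetical unit cell volume


def polarization(positions):
    # Point-charge model: P_alpha = (1 / volume) * sum_kappa q_kappa * u_(kappa, alpha)
    return jnp.einsum("k,ka->a", CHARGES, positions) / VOLUME


positions = jnp.array([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]])

# The Jacobian puts the output index first: shape (alpha, kappa, beta)
jac = jax.jacobian(polarization)(positions)

# Move the atom index to the front and scale by the volume: (kappa, alpha, beta)
born = VOLUME * jac.swapaxes(0, 1)

print(born.shape)  # (2, 3, 3)
print(born[0])     # 2.0 times the identity matrix
```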

But wait...where do we get the polarizations from?

### 🧪 Task 1: Add a module to calculate the polarizations

The polarizations can be calculated as:

$P_{\alpha} = -\frac{1}{\Omega} \frac{\partial E}{\partial \mathcal{E}_{\alpha}}$

where $E$ is the total energy, $\Omega$ is the unit cell volume, and $\mathcal{E}_{\alpha}$ is the component of the electric field in the $\alpha$ direction.
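
As a quick check of the sign convention, consider an assumed energy that is quadratic in the field, with a spontaneous polarization $P^{0}$ and a symmetric susceptibility-like tensor $\chi$ (both introduced here only for illustration, with constants such as $\epsilon_0$ absorbed into $\chi$):

$E(\mathcal{E}) = E_{0} - \Omega \sum_{\alpha} P^{0}_{\alpha} \mathcal{E}_{\alpha} - \frac{\Omega}{2} \sum_{\alpha\beta} \mathcal{E}_{\alpha} \chi_{\alpha\beta} \mathcal{E}_{\beta}$

Applying the definition gives

$P_{\alpha} = -\frac{1}{\Omega} \frac{\partial E}{\partial \mathcal{E}_{\alpha}} = P^{0}_{\alpha} + \sum_{\beta} \chi_{\alpha\beta} \mathcal{E}_{\beta}$

so at zero field the derivative returns $P^{0}_{\alpha}$, which is nonzero only because the energy depends explicitly on $\mathcal{E}$.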

✅ **What to do**:
- Complete the code snippet below to write the module that will calculate the polarizations
- Use your code from the previous notebook to calculate the RMSE difference between predictions and labels.
- Try a few different hyperparameters: how low can you get your RMSE?


In code, the skeleton of the module looks like this:

In [None]:
from flax import linen
from tensorial import gcnn
import jraph

class Polarization(linen.Module):
    energy_fn: gcnn.GraphFunction

    def setup(self) -> None:
        # Define the gradient of the energy wrt electric field function
        self._grad_fn = gcnn.grad(
            of=,  # Add here the quantity we want to take the gradient of
            wrt=,  # Add here the quantity we want to differentiate with respect to
            sign=-1,
            has_aux=True,
        )(self.energy_fn)

    def __call__(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
        polarizations, graphs = self._grad_fn(
            graph,
            # Add here the value of the electric field at which to evaluate the polarizations
        )
        updates = gcnn.utils.UpdateGraphDicts(graph)
        updates.globals["predicted_polarization"] = polarizations
        return updates.get()

Now let's load it and try!

In [None]:
R_MAX = 5.0

cfg = omegaconf.OmegaConf.load("model/nequip_deriv.yaml")
cfg.r_max = R_MAX

# Let our configuration know where it can find our polarization module
cfg.polarization = {"_target_": "__main__.Polarization", "_partial_": True}

module = tensorial.config.instantiate(cfg["module"])
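
The `_partial_` flag deserves a word. Assuming `tensorial.config.instantiate` follows the Hydra conventions for `_target_` and `_partial_` (an assumption, not verified here), the entry above does not build a `Polarization` straight away. It yields a factory that still waits for its remaining argument, here presumably `energy_fn`. The toy function below is a hypothetical stand-in for that behaviour, not the library code:

```python
import functools
import importlib


def instantiate_sketch(cfg: dict):
    """Toy stand-in for a Hydra-style instantiate (not the tensorial implementation)."""
    cfg = dict(cfg)
    module_name, _, attr = cfg.pop("_target_").rpartition(".")
    target = getattr(importlib.import_module(module_name), attr)
    if cfg.pop("_partial_", False):
        # Do not call the target yet: return a factory awaiting the remaining arguments
        return functools.partial(target, **cfg)
    return target(**cfg)


# `fractions.Fraction` is used here only as a convenient importable target
make_fraction = instantiate_sketch(
    {"_target_": "fractions.Fraction", "_partial_": True, "denominator": 4}
)
print(make_fraction(numerator=3))  # 3/4
```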

In [None]:
datamodule = e3r.data.BtoDataModule(r_max=R_MAX, batch_size=4)
trainer = reax.Trainer()

In [None]:
trainer.fit(module, datamodule=datamodule, max_epochs=10)

In [None]:
trainer.current_epoch

### 🧪 Task 2: Analyse the performance of the trained model

✅ **What to do**:
- Use your code from notebook 1 to plot the parity plots for the polarization.  🔍 **Hint**: You will find the Born charges in the `graph.nodes['born_charges']` and `graph.nodes['predicted_born_charges']` fields.
- Use your code from the previous notebook to calculate the RMSE difference between predictions and labels.
- Try a few different hyperparameters: how low can you get your RMSE?
- How does the performance of this derivative based (physics informed) model differ from that of the direct model in the first notebook?
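
If you no longer have the previous notebook to hand, the RMSE can be computed along these lines. This is a minimal sketch using made-up arrays in place of the `graph.nodes['born_charges']` and `graph.nodes['predicted_born_charges']` fields; the helper name `rmse` is hypothetical.

```python
import numpy as np


def rmse(predicted, target) -> float:
    """Root-mean-square error over all tensor components."""
    predicted = np.asarray(predicted, dtype=float)
    target = np.asarray(target, dtype=float)
    return float(np.sqrt(np.mean((predicted - target) ** 2)))


# Made-up 3x3 tensors for two atoms, standing in for the label and prediction fields
labels = np.stack([2.0 * np.eye(3), -2.0 * np.eye(3)])
predictions = labels + 0.1

print(rmse(predictions, labels))  # approximately 0.1
```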

### Conclusion

That's it!  Well done on getting this far.

There is much more that we can do with `e3response` that we simply don't have time for in this tutorial.  One important step would be to add temperature dependence.  We can do this by running molecular dynamics and predicting the tensorial quantities along the trajectory.  These can then be averaged to make a comparison with experiment under realistic conditions.

If you're interested in going further, don't hesitate to reach out to me at martin.uhrin@grenoble-inp.fr.