# Equivariant Neural Network for Predicting Trajectories

```{admonition} Authors:
[Sam Cox](https://github.com/SamCox822)
```

In this example, we will train an equivariant neural network to predict the next frame in the trajectory alignment example in 10.6. As stated in 10.3.8, for time-dependent trajectories, we do not need to concern ourselves with permutation equivariance because it is implied that the order of the points does not change. Thus, we can treat this example as a point cloud, meaning that any deep learning model that we train on this data should have rotation and translation equivariance. In other words, our model should be E(3)-equivariant. [E3NN](https://e3nn.org) is a library built to create equivariant neural networks for this group, so it's a great choice for this problem.
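
To make the equivariance requirement concrete, here is a minimal NumPy sketch. The names `rotation_about_z` and `toy_model` are hypothetical stand-ins and not part of E3NN or of this chapter's data. The sketch only illustrates the property we are after: transforming the input and then applying the model should give the same result as applying the model and then transforming the output.

```python
import numpy as np


def rotation_about_z(angle_rad):
    """3x3 rotation matrix for a rotation about the z-axis."""
    c, s = np.cos(angle_rad), np.sin(angle_rad)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def toy_model(points):
    """Hypothetical stand-in for a model: move each point halfway to the centroid."""
    centroid = points.mean(axis=0, keepdims=True)
    return points + 0.5 * (centroid - points)


rng = np.random.default_rng(0)
points = rng.normal(size=(12, 3))  # one frame: 12 points in 3D
R = rotation_about_z(np.pi / 3)
t = np.array([1.0, -2.0, 0.5])

# transform the input first, then apply the model
transform_first = toy_model(points @ R.T + t)
# apply the model first, then transform the output
transform_last = toy_model(points) @ R.T + t

equivariant = np.allclose(transform_first, transform_last)
print("equivariant:", equivariant)
```

The toy model passes this check because it only uses the points and their centroid, both of which rotate and translate together.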


## Retrieving Data from Trajectory Alignment Example

---



First, let's borrow a few cells from Chapter 10 to download our data and view the first frame. 

In [None]:
import matplotlib.pyplot as plt
import urllib
import urllib.request
import numpy as np
import jax
import jax.numpy as jnp

In [None]:
urllib.request.urlretrieve(
    "https://github.com/whitead/dmol-book/raw/master/data/paths.npz", "paths.npz"
)
paths = np.load("paths.npz")["arr"]
# plot the first point
plt.title("First Frame")
plt.plot(paths[0, :, 0], paths[0, :, 1], "o-")
plt.xticks([])
plt.yticks([])
plt.show()

## Additional Installations and Imports
## Running This Notebook


Click the &nbsp;<i aria-label="Launch interactive content" class="fas fa-rocket"></i>&nbsp; above to launch this page as an interactive Google Colab. See details below on installing packages.

````{tip} My title
:class: dropdown
To install packages, execute this code in a new cell. 

```
!pip install dmol-book

```
````

In [None]:
# additional imports
import torch_geometric
from torch_geometric.data import Data, Dataset, DataLoader
import torch
import e3nn
import math
from e3nn.nn.models.gate_points_2101 import Network
from e3nn import o3
import dmol

## Baseline Model

Before we build our E3NN network, it's always a good idea to build a baseline model for comparison.

First, let's discuss what the input and output should be for this model. The input should be the coordinates of the 12 points: one frame. What should the output be? We want to train a neural network to predict the next position of each point, i.e., the next frame, so our output should be the same type and size as our input.

Thus,

**Inputs:**
* 12 sets of coordinates

**Outputs:**
* 12 sets of coordinates

Note: since we are trying to build an E(3)-equivariant neural network, which should be equivariant to transformations in 3D space, we need to make these coordinates 3D. This is easy: we will just put zero for the z-coordinates. We'll do this now.

In [None]:
# append a zero z-coordinate to every point: (2048, 12, 2) -> (2048, 12, 3)
n_frames, n_points, _ = paths.shape
traj_3d = np.concatenate([paths, np.zeros((n_frames, n_points, 1))], axis=-1)

Interestingly, for this example, we want our prediction from one frame to match the following frame. So our features and labels will be nearly identical, offset by one. For the features, we want to include everything except for the final frame, which has no "next frame" in our data. We can extrapolate with our model to predict this "next frame" as a final step if we want. For our labels, we want to include everything except for the first frame, which is not the "next frame" of anything in our data. We can also go ahead and split our data into training and testing sets. Let's do approximately an 80:20 split here. We want to make sure not to shuffle our data, since the frames are ordered in time and each label is the frame that follows its feature.

In [None]:
features = traj_3d[:-1]
labels = traj_3d[1:]

# split data 80:20
training_set = features[:1637]
training_labels = labels[:1637]
valid_set = features[1637:]
valid_labels = labels[1637:]

# convert to jnp arrays
training_setbl = jnp.asarray(training_set)
training_labelsbl = jnp.asarray(training_labels)
valid_setbl = jnp.asarray(valid_set)
valid_labelsbl = jnp.asarray(valid_labels)

Let's check to make sure our data matches up. Frame 2 in the features set should be the same as Frame 1 in the labels set.

In [None]:
def mse(y, yhat):
    return np.mean((yhat - y) ** 2)

In [None]:
print("features, frame 2: \n", features[1])
print("labels, frame 1: \n", labels[0])
if mse(features[1], labels[0]) == 0:
    print("success! they match!")
else:
    print(mse(features[1], labels[0]))

Great, they match! Now we are ready to build our baseline model!


In [None]:
@jax.jit
def baseline_model(inputs, w, b):
    yhat = inputs @ w + b
    return yhat


def baseline_loss(inputs, y, w, b):
    return mse(y, baseline_model(inputs, w, b))


bl_loss_grad = jax.grad(baseline_loss, (2, 3))

w = np.zeros((3, 3))
w = jnp.asarray(w)
b = 0.0

In [None]:
epochs = 16
eta = 1e-6

baseline_val_loss = [0.0 for _ in range(epochs)]

for epoch in range(epochs):
    for d in range(1637):
        inputs = training_setbl[d]
        y = training_labelsbl[d]
        grad_bl = bl_loss_grad(inputs, y, w, b)
        # update w & b
        w -= eta * grad_bl[0]
        b -= eta * grad_bl[1]

    for i in range(410):
        inputs_v = valid_setbl[i]
        y_v = valid_labelsbl[i]
        baseline_val_loss[epoch] += baseline_loss(inputs_v, y_v, w, b)
    baseline_val_loss[epoch] = jnp.sqrt(baseline_val_loss[epoch] / 410)


plt.plot(baseline_val_loss)
plt.xlabel("Epoch")
plt.ylabel("Val Loss")
plt.show()

In [None]:
print("Final loss value: ", baseline_val_loss[-1])

Now let's view a parity plot to see if we're learning the right trend here.

In [None]:
ys = jnp.array([])
yhats = jnp.array([])

for i in range(205):
    inputs_v = valid_setbl[i]
    ys = jnp.append(ys, valid_labelsbl[i])
    yhat = baseline_model(inputs_v, w, b)
    yhats = jnp.append(yhats, yhat)

plt.plot(ys, ys, "-")
plt.plot(ys, yhats, ".")
plt.xlabel("Trajectory")
plt.ylabel("Predicted Trajectory")
plt.show()

This is difficult to read, since our xyz coordinates are very different in magnitude. Instead, let's look at three plots, one for each coordinate. 

In [None]:
ys_x = jnp.array([])
ys_y = jnp.array([])
ys_z = jnp.array([])
yh_x = jnp.array([])
yh_y = jnp.array([])
yh_z = jnp.array([])

for i in range(205):
    inputs_v = valid_setbl[i]
    y = valid_labelsbl[i]
    yhat = baseline_model(inputs_v, w, b)

    # collect each coordinate of all 12 points in this frame
    ys_x = jnp.append(ys_x, y[:, 0])
    ys_y = jnp.append(ys_y, y[:, 1])
    ys_z = jnp.append(ys_z, y[:, 2])

    yh_x = jnp.append(yh_x, yhat[:, 0])
    yh_y = jnp.append(yh_y, yhat[:, 1])
    yh_z = jnp.append(yh_z, yhat[:, 2])

plt.plot(ys_x, ys_x, "-")
plt.plot(ys_x, yh_x, ".")
plt.xlabel("X-Coordinate of Trajectory")
plt.ylabel("X-Coordinate of Predicted Trajectory")
plt.show()

plt.plot(ys_y, ys_y, "-")
plt.plot(ys_y, yh_y, ".")
plt.xlabel("Y-Coordinate of Trajectory")
plt.ylabel("Y-Coordinate of Predicted Trajectory")
plt.show()

plt.plot(ys_z, ys_z, "-")
plt.plot(ys_z, yh_z, ".")
plt.xlabel("Z-Coordinate of Trajectory")
plt.ylabel("Z-Coordinate of Predicted Trajectory")
plt.show()

It looks like we are starting to get the right trend for some of the coordinates, but more training is definitely needed, especially for the z-coordinate, which should always be zero. However, as stated, we want any model that uses this data to be equivariant in 3D space. Let's check the equivariances now.

In [None]:
# checking for rotation equivariance
import scipy.spatial.transform as trans

# rotate around x coordinate by 80 degrees
rot = trans.Rotation.from_euler("x", 80, degrees=True)

input_point = jnp.asarray(np.random.normal(size=(12, 3)))
w_test1 = jnp.asarray(np.random.normal(size=(3, 3)))

input_rot = rot.apply(input_point)
output_1 = baseline_model(input_rot, w_test1, b)
output_prerot = baseline_model(input_point, w_test1, b)
output_rot = []
for xyz in output_prerot:
    coord = rot.apply(xyz)
    output_rot.append(coord)
output_rot = jnp.array(output_rot)

print("rotated first: \n", output_1)
print("rotated last: \n", output_rot)
print("\033[1m" + "difference: " + "\033[0m", mse(output_1, output_rot))

So it doesn't look like our baseline model is rotation-equivariant. This is important, because if we give our model coordinates that are rotated, we expect the output to be rotated in the same way. Likewise, we need translation equivariance. Let's check that now.
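
To see why a generic linear map fails this check, here is a minimal NumPy sketch. It leaves out the bias term and uses hypothetical helper names (`linear`, `equivariance_gap`), so treat it as an illustration rather than the baseline itself. A map `x @ w` is rotation-equivariant only when `w` commutes with the rotation matrix, which holds for a multiple of the identity but not for a typical weight matrix.

```python
import numpy as np

# rotation by 80 degrees about the x-axis, written out as a matrix
theta = np.deg2rad(80.0)
c, s = np.cos(theta), np.sin(theta)
R = np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])

rng = np.random.default_rng(0)
x = rng.normal(size=(12, 3))  # 12 random points


def linear(x, w):
    """Bias-free linear map applied to each point (row) of x."""
    return x @ w


def equivariance_gap(w):
    """Largest difference between rotate-then-model and model-then-rotate."""
    rotate_first = linear(x @ R.T, w)
    rotate_last = linear(x, w) @ R.T
    return np.abs(rotate_first - rotate_last).max()


gap_random = equivariance_gap(rng.normal(size=(3, 3)))  # generic weights
gap_scaled_identity = equivariance_gap(2.5 * np.eye(3))  # commutes with every rotation

print("random w:   ", gap_random)
print("w = 2.5 * I:", gap_scaled_identity)
```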

In [None]:
# checking for translation equivariance
# a rigid translation shifts every point by the same vector
random_trans = jnp.asarray(np.random.normal(size=(1, 3)))

input_trans = input_point + random_trans
output_2 = baseline_model(input_trans, w_test1, b)
output_trans = random_trans + baseline_model(input_point, w_test1, b)

print("translated first: ", output_2)
print("translated last: ", output_trans)
print("\033[1m" + "difference: " + "\033[0m", mse(output_2, output_trans))

As expected, our model isn't translation-equivariant either. We can solve this problem a few ways. One way is to augment our data in order to teach our model equivariance. This requires more training and data storage, so let's look at a more compact approach.
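
For reference, the augmentation route could be sketched as below. This is a hedged illustration with hypothetical helpers (`random_rotation`, `augment_pair`) and synthetic frames. It is not used anywhere else in this chapter. The key point is that the same random rotation and translation must be applied to a frame and to its label.

```python
import numpy as np


def random_rotation(rng):
    """Draw a random proper rotation matrix via a QR decomposition."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))  # fix the sign ambiguity of the factorization
    if np.linalg.det(q) < 0:  # flip one axis so det = +1 (no reflection)
        q[:, 0] = -q[:, 0]
    return q


def augment_pair(frame, next_frame, rng, shift_scale=1.0):
    """Apply one random rotation and translation to a (frame, next_frame) pair."""
    R = random_rotation(rng)
    t = rng.normal(scale=shift_scale, size=(1, 3))
    return frame @ R.T + t, next_frame @ R.T + t


rng = np.random.default_rng(0)
frame = rng.normal(size=(12, 3))  # synthetic stand-in for one frame
next_frame = frame + 0.01 * rng.normal(size=(12, 3))
aug_frame, aug_next = augment_pair(frame, next_frame, rng)
```

Each augmented copy is a new training example, which is where the extra training time and storage come from.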

## E3NN Basics

E3NN is a library for creating equivariant neural networks, specifically for E(3). It is built for spatial equivariance in 3D space, giving us equivariance with respect to the E(3) group of rotations, inversions, and translations. As discussed before, the time-dependent trajectory points do not change order, so we do not need to worry about permutation equivariance/invariance in this case; we only need E(3)-equivariance. E3NN is a great tool for this problem because we have 3-dimensional points in space, and if we transform them in space, we want the output to transform the same way.

E3NN works through the use of irreducible representations (irreps). In general, a representation tells you how data transforms under the action of the group, and the irreducible representations are the smallest building blocks: representations that cannot be decomposed any further {cite}`geiger2022e3nn`. When creating a model, we give the model the irreps so that it knows how to handle the data we will give it during training. It's not necessary to understand what the irreps are; instead, just know that they are the smallest representations, which are similar to, and transform the same way as, the spherical harmonics. Any (reducible) representation can be decomposed into irreducible representations. If you want to know more, you can check out more on the E3NN documentation website [@e3nn]. Let's take a look at how the irreps are used in this context.

For this group (O(3), which includes parity), we need to find the L and d for each piece of data, where $d = 2L + 1$ (d = dimension). Look at the table below. 

| **parity** | **L** | **d** | **name**      |
|------------|-------|-------|---------------|
| even       | 0     | 1     | scalar        |
| odd        | 0     | 1     | pseudo scalar |
| even       | 1     | 3     | pseudo vector |
| odd        | 1     | 3     | vector        |
| even       | 2     | 5     |       -       |
| odd        | 2     | 5     |       -       |

The general notation is **MxLp**, where M is the multiplicity (how many copies of that irrep you have), L is the L from the table above, and p corresponds to the parity (e: even, o: odd).
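
The parity labels in the table can be checked numerically. The sketch below is a plain NumPy illustration that does not use E3NN. It applies an inversion (every point x goes to -x) and looks at four familiar quantities: a dot product behaves like an even scalar (0e), a triple product like a pseudo scalar (0o), a cross product like a pseudo vector (1e), and an ordinary position vector like a vector (1o).

```python
import numpy as np

rng = np.random.default_rng(0)
a, b, c = rng.normal(size=(3, 3))  # three ordinary vectors (1o)

# inversion through the origin sends every vector v to -v
a_inv, b_inv, c_inv = -a, -b, -c

# scalar (0e): the dot product is unchanged
scalar_unchanged = np.isclose(np.dot(a_inv, b_inv), np.dot(a, b))

# pseudo scalar (0o): the triple product changes sign
triple = np.dot(a, np.cross(b, c))
triple_inv = np.dot(a_inv, np.cross(b_inv, c_inv))
pseudoscalar_flips = np.isclose(triple_inv, -triple)

# pseudo vector (1e): the cross product is unchanged
pseudovector_unchanged = np.allclose(np.cross(a_inv, b_inv), np.cross(a, b))

# vector (1o): the vector itself changes sign, by construction
vector_flips = np.allclose(a_inv, -a)

print(scalar_unchanged, pseudoscalar_flips, pseudovector_unchanged, vector_flips)
```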

For example, if you wanted to portray "12 scalars, 4 vectors" in this format, you would write **"12x0e + 4x1o"**. Take a minute to make sure you understand how to use this notation, as it's essential for E3NN. E3NN deals with equivariance by receiving the irreps as a model parameter. This allows the E3NN framework to know how each input feature/output transforms under symmetry, so that it can treat each piece appropriately. As a side note, the output of an E3NN model must always be of equal or higher symmetry than your input.
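
As a quick check of the notation, here is a small plain-Python helper that adds up the dimension of an irreps string using $d = 2L + 1$. The function `irreps_dim` is hypothetical and written only for this illustration. In practice E3NN parses these strings itself through `o3.Irreps`.

```python
def irreps_dim(irreps):
    """Total dimension of an irreps string such as "12x0e + 4x1o"."""
    total = 0
    for term in irreps.split("+"):
        term = term.strip()
        if "x" in term:
            mult_str, rest = term.split("x")
            mult = int(mult_str)
        else:
            mult, rest = 1, term  # a bare "1o" means multiplicity 1
        parity = rest[-1]
        if parity not in ("e", "o"):
            raise ValueError(f"unknown parity in term {term!r}")
        L = int(rest[:-1])
        total += mult * (2 * L + 1)  # each copy contributes d = 2L + 1 numbers
    return total


print(irreps_dim("12x0e + 4x1o"))  # 12 scalars and 4 vectors
print(irreps_dim("1x1o"))  # a single vector
```

Twelve scalars contribute 12 numbers and four vectors contribute 4 x 3 = 12 numbers, so the first string describes 24 numbers in total.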

Because E3NN is built to handle 3D spatial data, we do not need to tell the model that we are going to give it 3D coordinates; it's implicit and **required**. The irreps_in, instead, correspond to the input node features. In this example, we don't have input features, but you can imagine a case where we want our model to predict the next set of coordinates given the initial coordinates and the corresponding atom types. In that case, our irreps_in would describe the atom types (one scalar per entry of a one-hot vector).

Since we don't have input features, we'll pass `None` for that parameter, and we want our output to be the same shape as the input: 12 vectors. However, since we are trying to predict 12 vectors out for 12 vectors in, we only need to tell the model to predict 1 vector per input point: **"1x1o"**. Take a minute to make sure you understand why this is the case. You can think of the model recognizing 12 input points and predicting a vector for each. Again, E3NN expects coordinate inputs, so we don't specify this for the input.

## E3NN Model

E3NN has several models within their library, which can be found [on the E3NN github page here](https://github.com/e3nn/e3nn/tree/main/e3nn/nn/models). For this example, we will use one of these models. 

To use the E3NN model, we need to turn our data into a torch_geometric dataset. We'll do that now. Then we can split our data into training and testing sets.

Also, instead of directly computing the next frame, we'll change it here to predict the displacement to the next frame. This is a small change, but having data centered nearer zero can be better for training. We'll need to undo this when we look at the frames later.
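
A minimal sketch of this change of target, using a small synthetic trajectory (`toy_traj` is a made-up stand-in, not the chapter's data): the targets become the differences between consecutive frames, and adding them back to the inputs recovers the next frames exactly.

```python
import numpy as np

rng = np.random.default_rng(0)
# synthetic stand-in for a trajectory: 5 frames, 12 points, 3 coordinates,
# deliberately placed far from the origin
toy_traj = 50.0 + np.cumsum(0.1 * rng.normal(size=(5, 12, 3)), axis=0)

toy_features = toy_traj[:-1]  # every frame except the last
toy_labels = toy_traj[1:]  # every frame except the first
toy_displacements = toy_labels - toy_features  # new training targets

# undoing the change: add the displacement back onto the input frame
recovered = toy_features + toy_displacements

print("mean |label|:       ", np.abs(toy_labels).mean())
print("mean |displacement|:", np.abs(toy_displacements).mean())
```

In this toy case the displacements are much closer to zero than the raw coordinates, which is the motivation given above.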

In [None]:
feat = torch.from_numpy(features)  # convert to pytorch tensors
ys = torch.from_numpy(labels)  # convert to pytorch tensors
traj_data = []
distances = ys - feat  # displacement vectors from each frame to the next


# make torch_geometric dataset
# we want this to be an iterable list
# x = None because we have no input features
for frame, label in zip(feat, distances):
    traj_data += [
        torch_geometric.data.Data(
            x=None, pos=frame.to(torch.float32), y=label.to(torch.float32)
        )
    ]

train_split = 1637
train_loader = torch_geometric.data.DataLoader(
    traj_data[:train_split], batch_size=1, shuffle=False
)

test_split = 1842
test_loader = torch_geometric.data.DataLoader(
    traj_data[train_split:test_split], batch_size=1, shuffle=False
)

Great! Now we're ready to define our model. Since this is a pre-built model in E3NN, we just need to import it and define the model parameters. Note that the model keeps its trained weights between runs of the training cell, so you will need to reinitialize the model every time you want to start training from scratch. To see how these models work you can look at [this preprint](https://arxiv.org/abs/2207.09453) or [this video series](https://www.youtube.com/watch?v=q9EwZsHY1sk&list=PLx3xbphkO3qIlBoESkbafXaDtr0tq5iRd).

In [None]:
from e3nn.nn.models.gate_points_2101 import Network
from e3nn import o3

model_kwargs = {
    "irreps_in": None,  # no input features
    "irreps_hidden": o3.Irreps("5x0e + 5x0o + 5x1e + 5x1o"),  # hyperparameter
    "irreps_out": "1x1o",  # 12 vectors out, but only 1 vector out per input
    "irreps_node_attr": None,
    "irreps_edge_attr": o3.Irreps.spherical_harmonics(3),
    "layers": 3,  # hyperparameter
    "max_radius": 3.5,
    "number_of_basis": 10,
    "radial_layers": 1,
    "radial_neurons": 128,
    "num_neighbors": 11,  # average number of neighbors w/in max_radius
    "num_nodes": 12,  # not important unless reduce_output is True
    "reduce_output": False,  # setting this to true would give us one scalar as an output.
}

model = Network(**model_kwargs)  # initializing model with parameters above
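
The comment on `num_neighbors` describes it as the average number of neighbors within `max_radius`. One rough way to estimate such a number from a single frame is sketched below. The helper `average_neighbors` is hypothetical, and the strict cutoff and the exclusion of each point from its own neighbor list are assumptions of this sketch rather than a statement of how E3NN counts neighbors.

```python
import numpy as np


def average_neighbors(frame, max_radius):
    """Mean number of other points lying within max_radius of each point."""
    diff = frame[:, None, :] - frame[None, :, :]  # pairwise difference vectors
    dist = np.linalg.norm(diff, axis=-1)  # pairwise distances
    not_self = ~np.eye(len(frame), dtype=bool)  # a point is not its own neighbor
    within = (dist < max_radius) & not_self
    return within.sum(axis=1).mean()


# four points on a line, spaced 1.0 apart
line = np.array([[0.0, 0, 0], [1.0, 0, 0], [2.0, 0, 0], [3.0, 0, 0]])
avg_small = average_neighbors(line, max_radius=1.5)
avg_large = average_neighbors(line, max_radius=10.0)
print(avg_small, avg_large)
```

With a cutoff large enough to include every pair, each of N points has N - 1 neighbors, which for 12 points gives 11.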

In [None]:
import torch.optim as optim

eta = 1e-6
optimizer = torch.optim.Adam(model.parameters(), lr=eta)
optimizer.zero_grad()

In [None]:
# this will print an outline of the model architecture!
print(model)

In [None]:
epochs = 16

val_loss = [0.0 for _ in range(epochs)]

for epoch in range(epochs):
    for step, data in enumerate(train_loader):
        yhat = model(data)
        loss_1 = torch.mean((yhat - data.y) ** 2)
        loss_1.backward()
        optimizer.step()
        optimizer.zero_grad()

    with torch.no_grad():
        for step, data in enumerate(test_loader):
            yhat = model(data)
            loss2 = torch.mean((yhat - data.y) ** 2)
            val_loss[epoch] += (loss2).detach()
    val_loss[epoch] = val_loss[epoch] / 205

In [None]:
v_loss = torch.tensor(val_loss)

plt.plot(v_loss, label="Validation Loss")
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()

In [None]:
print("final loss value: ", val_loss[-1])

Let's run each model on the last frame to get an extrapolated frame. Remember that for the E3NN model, we are predicting displacements, so we'll need to add the predicted displacement back to the coordinates of the final frame.
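
The same idea extends to several steps: feed each predicted frame back in and keep adding displacements. The sketch below is a hedged illustration with a hypothetical `rollout` helper and a made-up `toy_displacement` function standing in for a trained model.

```python
import numpy as np


def rollout(predict_displacement, start_frame, n_steps):
    """Repeatedly add predicted displacements to extrapolate n_steps frames."""
    frames = [start_frame]
    for _ in range(n_steps):
        current = frames[-1]
        frames.append(current + predict_displacement(current))
    return np.stack(frames)  # shape: (n_steps + 1, n_points, 3)


def toy_displacement(frame):
    """Hypothetical model: every point drifts by +0.1 in x at each step."""
    disp = np.zeros_like(frame)
    disp[:, 0] = 0.1
    return disp


start = np.zeros((12, 3))
frames = rollout(toy_displacement, start, n_steps=3)
print(frames.shape)
```

With a real model, errors accumulate over such a rollout, so the later frames should be trusted less than the first extrapolated one.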

In [None]:
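# the true final frame of the trajectory (the last label)
last_frame_bl = valid_labelsbl[-1]
extrp_bl = baseline_model(last_frame_bl, w, b)

# use the same true final frame for the E3NN model
last_frame_e3nn = torch.from_numpy(labels[-1]).to(torch.float32)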
lf_e3nn = []

# format as torch geometric dataset (dummy y values)
lf_e3nn += [
    torch_geometric.data.Data(
        x=None,
        pos=last_frame_e3nn.to(torch.float32),
        y=last_frame_e3nn.to(torch.float32),
    )
]
lf_loader = torch_geometric.data.DataLoader(lf_e3nn, batch_size=1, shuffle=False)

# run through model
for i in lf_loader:
    extrp_e3nn = model(i)

# add extrapolated displacements back into last frame
extrp_e3nn += last_frame_e3nn

In [None]:
extrp_bl = np.array(extrp_bl)
extrp_e3nn = extrp_e3nn.detach().numpy()

np.set_printoptions(suppress=True)
print(extrp_bl)
print(extrp_e3nn)

These look pretty good! Note, however, that the baseline model has still not learned that the z-coordinate is always zero.

Now we can check for equivariance in the same way that we did before with the baseline model. Let's take a random set of 12 points, rotate it by 80 degrees about the x-axis, and compare the model output for the rotated input with the rotated output for the original input.

In [None]:
# checking for rotation equivariance
import scipy.spatial.transform as trans

# rotate around x coordinate by 80 degrees
rot = trans.Rotation.from_euler("x", 80, degrees=True)

input_point = np.asarray(np.random.normal(size=(12, 3)))
input_rot = rot.apply(input_point)
input_point = torch.from_numpy(input_point)
input_rot = torch.from_numpy(input_rot)

# format as torch geometric dataset (dummy y values)
rot_first = []
rot_first += [
    torch_geometric.data.Data(
        x=None, pos=input_rot.to(torch.float32), y=input_rot.to(torch.float32)
    )
]
rf_loader = torch_geometric.data.DataLoader(rot_first, batch_size=1, shuffle=False)
# run through model
for i in rf_loader:
    output_1 = model(i)


# format as torch geometric dataset (dummy y values)
rot_last = []
rot_last += [
    torch_geometric.data.Data(
        x=None, pos=input_point.to(torch.float32), y=input_point.to(torch.float32)
    )
]
rl_loader = torch_geometric.data.DataLoader(rot_last, batch_size=1, shuffle=False)
# run through model
for i in rl_loader:
    output_2 = model(i)

output_2 = output_2.detach().numpy()
output_1 = output_1.detach().numpy()

output_rot = []
for xyz in output_2:
    coord = rot.apply(xyz)
    output_rot.append(coord)
output_rot = np.array(output_rot)

print("rotated first: \n", output_1)
print("rotated last: \n", output_rot)

Perfect! Our random array, when rotated first, gives the same results as when we rotated last, so the model is rotation-equivariant. I won't go further to test translation equivariance; I will leave that as an exercise. The E3NN model outperforms the baseline significantly, and it is E(3)-equivariant, unlike our baseline model!
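
As a starting point for that exercise, note that for a model that predicts displacements, translation equivariance of the predicted frame is the same as translation invariance of the predicted displacements. The sketch below uses a hypothetical `translation_gap` helper and two toy models, not the trained E3NN network, to show what a passing and a failing check look like.

```python
import numpy as np


def translation_gap(predict_displacement, points, shift):
    """Largest change in predicted displacements after shifting all points."""
    shifted = predict_displacement(points + shift)
    original = predict_displacement(points)
    return np.abs(shifted - original).max()


def relative_model(points):
    """Toy model using only relative positions: pull toward the centroid."""
    return 0.1 * (points.mean(axis=0, keepdims=True) - points)


def absolute_model(points):
    """Toy model using absolute positions: pull toward the origin."""
    return -0.1 * points


rng = np.random.default_rng(0)
points = rng.normal(size=(12, 3))
shift = np.array([10.0, 0.0, 0.0])  # translate by 10 in the x-direction

gap_relative = translation_gap(relative_model, points, shift)
gap_absolute = translation_gap(absolute_model, points, shift)
print("relative model gap:", gap_relative)
print("absolute model gap:", gap_absolute)
```

To test the trained network, the toy functions would be replaced by a wrapper that builds a `torch_geometric` data object from the points and calls the model, as in the rotation check above.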

## Cited References

```{bibliography}
:style: unsrtalpha
:filter: docname in docnames
```