# Equivariant Neural Network for Predicting Trajectories

In this example, we will train an equivariant neural network to predict the next frame in the trajectory alignment example (10.6). As stated in 10.3.8, for time-dependent trajectories we do not need to concern ourselves with permutation equivariance, because the order of the points does not change. Thus, we can treat each frame as a simple set of coordinates in 3D space. This means that any deep learning model we train on this data should be equivariant to rotations, mirrors (parity), and translations. Rotations and mirrors together form the group O(3) [@wikipedia_2021], and adding translations gives the Euclidean group E(3). In other words, our model should be O(3) equivariant and also translation equivariant, i.e., E(3) equivariant. E3NN [@e3nn] is a library built to create neural networks that are equivariant under this group, so it's a great choice for this problem.
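To make the equivariance requirement concrete, here is a small NumPy check. It is a sketch that is not part of the original workflow. `toy_model` is a hypothetical stand-in for a trained network: it nudges every point toward the frame's centroid. Because it depends only on positions relative to the centroid, transforming the input by an orthogonal matrix `Q` (a rotation, or a mirror) plus a translation `t` should transform its output in exactly the same way.

```python
import numpy as np

def toy_model(x):
    # Hypothetical "next frame" predictor: move every point 10% of the way
    # toward the centroid. It only uses positions relative to the centroid.
    centroid = x.mean(axis=0, keepdims=True)
    return x + 0.1 * (centroid - x)

rng = np.random.default_rng(0)
x = rng.normal(size=(12, 3))                               # one frame: 12 points in 3D
rotation_like, _ = np.linalg.qr(rng.normal(size=(3, 3)))   # random orthogonal matrix
mirror = np.diag([1.0, 1.0, -1.0])                         # reflection z -> -z
t = rng.normal(size=(1, 3))                                # random translation

for Q in (rotation_like, mirror):
    # Points are stored as rows, so applying Q to every point is x @ Q.T.
    transform_then_predict = toy_model(x @ Q.T + t)
    predict_then_transform = toy_model(x) @ Q.T + t
    print(np.allclose(transform_then_predict, predict_then_transform))  # True
```

A generic model, such as the linear baseline built in this notebook, has no reason to pass this check. The point of E3NN is that it builds this property into the architecture.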

We will use the trajectory data from that example to train our network.

## Retrieving Data from Trajectory Alignment Example

---



First, let's use the same imports and visualization used in Chapter 10 to download our data and view the first frame. 

In [None]:
# new imports
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import urllib.request
import numpy as np
import jax
import jax.numpy as jnp
import math

sns.set_context("notebook")
sns.set_style(
    "dark",
    {
        "xtick.bottom": True,
        "ytick.left": True,
        "xtick.color": "#666666",
        "ytick.color": "#666666",
        "axes.edgecolor": "#666666",
        "axes.linewidth": 0.8,
        "figure.dpi": 300,
    },
)
color_cycle = ["#1BBC9B", "#F06060", "#5C4B51", "#F3B562", "#6e5687"]
mpl.rcParams["axes.prop_cycle"] = mpl.cycler(color=color_cycle)

In [None]:
urllib.request.urlretrieve(
    "https://github.com/whitead/dmol-book/raw/master/data/paths.npz", "paths.npz"
)
paths = np.load("paths.npz")["arr"]
# plot the first frame (x vs. y of all 12 points)
plt.title("First Frame")
plt.plot(paths[0, :, 0], paths[0, :, 1], "o-")
plt.xticks([])
plt.yticks([])
plt.show()

## Additional Installations

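The following cells install the additional packages we need. These may take some time. In the PyG commands, replace `${CUDA}` with the wheel tag that matches your PyTorch build (for example `cpu` or `cu111`), and change `torch-1.10.0` if your installed PyTorch version differs.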

In [None]:
pip install torch torchvision

In [None]:
pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+${CUDA}.html

In [None]:
pip install torch-cluster -f https://data.pyg.org/whl/torch-1.10.0+${CUDA}.html

In [None]:
pip install torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+${CUDA}.html

In [None]:
pip install torch-geometric

In [None]:
pip install e3nn

## Baseline Model

In [None]:
import jax
import jax.numpy as jnp

Before we build our E3NN network, it's always a good idea to build a baseline model for comparison. 

First, let's discuss what the input and output should be for this model. The input should be the coordinates of the 12 points in one frame. What should the output be? We want the network to predict where each point goes next, i.e., the next frame, so our output should have the same form as our input.

Thus,

**Inputs:**
* 12 sets of coordinates

**Outputs:**
* 12 sets of coordinates

Note: since we are trying to build an O(3) equivariant neural network, which acts on points in 3D space, our coordinates need to be 3D. This is easy: we will just set the z-coordinates to zero. We'll do this now.

In [None]:
# append a column of zeros as the z-coordinate: (n_frames, 12, 2) -> (n_frames, 12, 3)
traj_3d = np.concatenate([paths, np.zeros(paths.shape[:2] + (1,))], axis=-1)
print(traj_3d.shape)

For this task, we want our prediction from one frame to match the following frame. So our features and labels will be the same array, offset by one frame.

For the features, we want to include everything except for the final frame, which has no "next frame" in our data. We can extrapolate with our model to predict this "next frame" as a final step if we want. 

For our labels, we want to include everything except for the first frame, which is not the "next frame" of anything in our data. 

In [None]:
features = traj_3d[:-1]
print(features.shape)

labels = traj_3d[1:]
print(labels.shape)

Now we can split our data into training, validation, and testing sets. Let's do approximately an 80:10:10 split here. 

We want to make sure not to shuffle our data, as we are predicting time-series data. 
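The split boundaries used below (1637 and 1842) follow from taking the first 80% and the next 10% of the 2047 (feature, label) pairs in time order. Here is a small sketch of that arithmetic; `chronological_split` is a hypothetical helper, not something the notebook defines.

```python
def chronological_split(n, train_frac=0.8, valid_frac=0.1):
    """Return (train_end, valid_end) for an ordered, unshuffled split.

    Hypothetical helper: frames [0, train_end) are training,
    [train_end, valid_end) validation and [valid_end, n) test, so every
    test frame comes after every training frame in time.
    """
    train_end = int(n * train_frac)
    valid_end = int(n * (train_frac + valid_frac))
    return train_end, valid_end

# 2048 frames give 2047 (feature, label) pairs after the one-frame offset.
print(chronological_split(2047))  # (1637, 1842)
```

This leaves 205 frames each for validation and testing.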

In [None]:
training_set = features[:1637]
training_labels = labels[:1637]

valid_set = features[1637:1842]
valid_labels = labels[1637:1842]

test_set = features[1842:]
test_labels = labels[1842:]

# convert to jnp arrays
training_setbl = jnp.asarray(training_set)
training_labelsbl = jnp.asarray(training_labels)

valid_setbl = jnp.asarray(valid_set)
valid_labelsbl = jnp.asarray(valid_labels)

Let's check to make sure our data matches up. Frame 2 in the features set should be the same as Frame 1 in the labels set.

In [None]:
print("features, frame 2: \n", features[1])
print("labels, frame 1: \n", labels[0])

Great, they match! Now we are ready to build our baseline model!


Perhaps we should first discuss our loss function. We are predicting 3 values (x, y, z) per point, so we can measure the error for each point as the Euclidean distance between its predicted and true positions. This is the standard distance formula in 3D:  $$L = \sqrt{(x_{2}-x_{1})^2 + (y_{2}-y_{1})^2 + (z_{2}-z_{1})^2}$$

Remember that we have 12 points per input/output, so we compute this distance for every point, which gives 12 values, and take their sum as the loss. This is the total distance between the predicted and true frames. When we train our baseline model, our loss function will call the function below, which computes the distance between each point's predicted and true coordinates and sums the distances.
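As a quick worked example of this loss, here is a NumPy sketch (`frame_distance` is an illustrative name, not part of the notebook) on two points. One prediction is off by a 3-4-5 right triangle and the other is off by 2 along z, so the total should be 5 + 2 = 7.

```python
import numpy as np

def frame_distance(pred, true):
    """Sum of per-point Euclidean distances between two (n_points, 3) frames."""
    per_point = np.sqrt(np.sum((pred - true) ** 2, axis=-1))  # shape (n_points,)
    return per_point.sum()

true = np.zeros((2, 3))
pred = np.array([[3.0, 4.0, 0.0],   # sqrt(3^2 + 4^2) = 5
                 [0.0, 0.0, 2.0]])  # sqrt(2^2) = 2
print(frame_distance(pred, true))  # 7.0
```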

In [None]:
@jax.jit
def dist(inputs, yhats):
    # Euclidean distance for each of the 12 points (norm over the xyz axis),
    # summed over the points to give the total distance between the frames
    return jnp.sum(jnp.sqrt(jnp.sum((inputs - yhats) ** 2, axis=-1)))


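Note that the shape of our weight tensor is (12, 3, 3): one 3×3 matrix per point. Multiplying a point's (3,) coordinates by its (3, 3) matrix gives a (3,) prediction, and stacking the 12 predictions gives the (12, 3) output we want. A single shared (3, 3) matrix would also produce a (12, 3) output, but we want each point to have its own set of parameters.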

To calculate $\hat{y}$, we apply the model to each point using that point's own set of weights.
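Applying a separate 3×3 matrix to each point is a batched matrix product, so it can also be written as a single `einsum`. The NumPy sketch below (with random illustrative weights, not trained ones) checks that the per-point loop and the `einsum` form agree. `jax.numpy.einsum` accepts the same subscript string.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(12, 3))       # one frame: 12 points
w = rng.normal(size=(12, 3, 3))    # one 3x3 matrix per point
b = 0.5                            # a single shared scalar bias, as in the baseline

# Loop form: point i uses its own matrix w[i].
loop = np.stack([x[i] @ w[i] + b for i in range(12)])

# Batched form: for each point i, contract its coordinate axis j with w[i].
batched = np.einsum("ij,ijk->ik", x, w) + b

print(loop.shape, np.allclose(loop, batched))  # (12, 3) True
```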

In [None]:
@jax.jit
def baseline_model(inputs, w, b):
    yhat = inputs @ w + b
    return yhat


def run_blm(inputs, w, b):
    # apply baseline_model to every point i with its own weights w[i];
    # the bias b is shared across points, so it is not mapped over (None)
    return jax.vmap(baseline_model, in_axes=(0, 0, None))(inputs, w, b)


@jax.jit
def baseline_loss(inputs, yhats, w, b):
    return dist(run_blm(inputs, w, b), yhats)


w = np.zeros((12, 3, 3))
w = jnp.asarray(w)
b = 0.0

In [None]:
epochs = 50
eta = 1e-6

# build and compile the gradient of the loss w.r.t. w and b once, outside the loop
baseline_grad = jax.jit(jax.grad(baseline_loss, (2, 3)))

n_train = len(training_setbl)
n_valid = len(valid_setbl)
baseline_val_loss = [0.0 for _ in range(epochs)]
baseline_tr_loss = [0.0 for _ in range(epochs)]
for epoch in range(epochs):
    # one gradient-descent step per training frame
    for d in range(n_train):
        inputs = training_setbl[d]
        yhats = training_labelsbl[d]
        baseline_tr_loss[epoch] += baseline_loss(inputs, yhats, w, b)
        grad_bl = baseline_grad(inputs, yhats, w, b)
        # update w & b
        w -= eta * grad_bl[0]
        b -= eta * grad_bl[1]
    baseline_tr_loss[epoch] = baseline_tr_loss[epoch] / n_train

    # mean validation loss with the weights at the end of this epoch
    for i in range(n_valid):
        inputs_v = valid_setbl[i]
        yhats_v = valid_labelsbl[i]
        baseline_val_loss[epoch] += baseline_loss(inputs_v, yhats_v, w, b)
    baseline_val_loss[epoch] = baseline_val_loss[epoch] / n_valid


plt.plot(baseline_val_loss, label="Validation")
plt.plot(baseline_tr_loss, label="Training")
plt.xlabel("Epoch")
plt.ylabel("Mean Loss")
plt.legend()
plt.show()

Now let's view a parity plot to see if we're learning the right trend here.

In [None]:
# predictions for every validation frame, shape (205, 12, 3)
preds = jnp.stack([run_blm(valid_setbl[i], w, b) for i in range(len(valid_setbl))])
ys = valid_labelsbl

# pool every coordinate of every point into one parity plot
plt.plot(ys.ravel(), ys.ravel(), "-")
plt.plot(ys.ravel(), preds.ravel(), ".")
plt.xlabel("True Coordinate")
plt.ylabel("Predicted Coordinate")
plt.show()

This is difficult to read, since our xyz coordinates are much different in magnitude. Instead, let's look at three plots, one for each coordinate. 

In [None]:
# predictions and labels for all validation frames, shape (205, 12, 3)
preds_v = jnp.stack([run_blm(valid_setbl[i], w, b) for i in range(len(valid_setbl))])
labels_v = valid_labelsbl

# one parity plot per coordinate, pooling all 12 points of every frame
for k, name in enumerate(["X", "Y", "Z"]):
    ys_k = labels_v[:, :, k].ravel()
    yh_k = preds_v[:, :, k].ravel()
    plt.plot(ys_k, ys_k, "-")
    plt.plot(ys_k, yh_k, ".")
    plt.xlabel(f"{name}-Coordinate of Trajectory")
    plt.ylabel(f"{name}-Coordinate of Predicted Trajectory")
    plt.show()

It looks like we're starting to get the right trend! This could definitely be improved with some tweaking and more training, but it is sufficient as a baseline. Let's move on to our real model.

## E3NN Basics

In [None]:
import torch
import e3nn
import math

E3NN is a library for creating neural networks that are equivariant to the symmetries of 3D Euclidean space. Its layers are equivariant to the O(3) group of rotations and inversions, and its point-cloud models (such as the one used below) work with relative positions between points, which makes them translation equivariant as well. As discussed before, the points in a time-dependent trajectory do not change order, so we do not need to worry about permutation equivariance/invariance in this case; we only need this spatial equivariance. E3NN is a great tool for this problem because we have 3-dimensional points in space, and if we transform them in space, we want the output to transform the same way.
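One way a point-cloud network can handle translations is to use only displacement vectors between points. As far as we can tell from its source, e3nn's `gate_points_2101` `Network` builds its edge features from differences of positions of nearby points. The NumPy sketch below, which is independent of e3nn, shows why that works. Translations cancel in the differences, while rotations and mirrors act on every displacement exactly as they act on the points.

```python
import numpy as np

rng = np.random.default_rng(1)
pos = rng.normal(size=(12, 3))                  # one frame: 12 points
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))    # random orthogonal matrix
t = np.array([[5.0, -2.0, 1.0]])                # a translation

def edge_vectors(p):
    # All pairwise displacements p[i] - p[j], shape (12, 12, 3).
    return p[:, None, :] - p[None, :, :]

e = edge_vectors(pos)
# Translation cancels in every difference ...
print(np.allclose(edge_vectors(pos + t), e))          # True
# ... while rotating/mirroring the points rotates/mirrors every edge vector.
print(np.allclose(edge_vectors(pos @ Q.T), e @ Q.T))  # True
```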

One note about E3NN is that the library is built on PyTorch, so we need to convert our data from NumPy arrays to PyTorch tensors. We'll do that in the cell below. Note that the shape of the data does not change.

In [None]:
train_x = torch.from_numpy(training_set)
train_y = torch.from_numpy(training_labels)

valid_x = torch.from_numpy(valid_set)
valid_y = torch.from_numpy(valid_labels)


test_x = torch.from_numpy(test_set)
test_y = torch.from_numpy(test_labels)

print(
    "original data type: ", type(training_set), ", original shape: ", training_set.shape
)
print("converted data type: ", type(train_x), ", converted shape: ", train_x.shape)

E3NN works through the use of irreducible representations (irreps). In general, a representation tells you how the data transform under each element of the group. When creating a model, we give it the irreps of its data so that it knows how to handle the data we will give it during training. It's not necessary to understand irreps in depth; just know that they are the smallest representations, and that an irrep of degree L transforms the same way as the spherical harmonics of degree L. Any (reducible) representation can be decomposed into irreducible representations. If you want to know more, you can check out the E3NN documentation website [@e3nn]. Let's take a look at how the irreps are used in this context. 

For this group, each irrep is labeled by its degree L and its parity, and it has dimension $d = 2L + 1$. Look at the table below. 

| **parity** | **L** | **d** | **name**      |
|------------|-------|-------|---------------|
| even       | 0     | 1     | scalar        |
| odd        | 0     | 1     | pseudoscalar  |
| even       | 1     | 3     | pseudovector  |
| odd        | 1     | 3     | vector        |
| even       | 2     | 5     |       -       |
| odd        | 2     | 5     |       -       |
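The parity column describes what happens under inversion, which sends every position-like vector $\mathbf{r}$ to $-\mathbf{r}$. The NumPy sketch below uses three arbitrary example vectors to check each row of the table with L ≤ 1. An ordinary vector flips sign (odd, L=1). A squared length does not change (even scalar). A cross product does not flip (even pseudovector). A triple product flips (odd pseudoscalar).

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])    # ordinary vectors (odd, L=1): they flip under inversion
b = np.array([0.0, -1.0, 4.0])
d = np.array([1.0, 0.0, 0.0])

def invert(v):
    # Parity / inversion sends every position-like vector r to -r.
    return -v

length2 = a @ a                   # squared length -> scalar
c = np.cross(a, b)                # cross product -> pseudovector
s = a @ np.cross(b, d)            # triple product -> pseudoscalar

length2_inv = invert(a) @ invert(a)
c_inv = np.cross(invert(a), invert(b))
s_inv = invert(a) @ np.cross(invert(b), invert(d))

print(length2_inv == length2)     # True: even, L=0
print(np.allclose(c_inv, c))      # True: even, L=1 (no sign flip)
print(s, s_inv)                   # 11.0 -11.0: odd, L=0 (sign flips)
```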

The general notation is **MxLp**, where M is the multiplicity (how many copies of that irrep there are), L is the L from the table above, and p is the parity (e: even, o: odd). 

For example, if you wanted to portray "12 scalars, 4 vectors" in this format, you would write **"12x0e + 4x1o"**. Take a minute to make sure you understand how to use this notation, as it's essential for E3NN. E3NN deals with equivariance by receiving the irreps as model parameters. This tells the E3NN framework how each input feature and output transforms under the symmetry operations, so that it can treat each piece appropriately. As a side note, the output of an equivariant model always has symmetry equal to or higher than that of its input: an equivariant model cannot break a symmetry that its input has.
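Since each term MxLp contributes M copies of a (2L + 1)-dimensional irrep, the total number of values an irreps string describes is the sum of M(2L + 1) over its terms. The sketch below computes that sum. `irreps_dim` is a hypothetical helper for illustration only; in E3NN itself, `o3.Irreps(...).dim` gives this number.

```python
import re

def irreps_dim(irreps: str) -> int:
    """Total dimension of an irreps string such as "12x0e + 4x1o".

    Hypothetical helper: each term MxLp contributes M * (2L + 1) values;
    a term written without "Mx" (e.g. "0e") counts once.
    """
    total = 0
    for term in irreps.split("+"):
        match = re.fullmatch(r"\s*(?:(\d+)x)?(\d+)([eo])\s*", term)
        if match is None:
            raise ValueError(f"cannot parse irreps term {term!r}")
        mult = int(match.group(1)) if match.group(1) else 1
        l = int(match.group(2))
        total += mult * (2 * l + 1)
    return total

print(irreps_dim("12x0e + 4x1o"))               # 24 = 12*1 + 4*3
print(irreps_dim("5x0e + 5x0o + 5x1e + 5x1o"))  # 40 = 5 + 5 + 15 + 15
```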

Again, E3NN is built to handle 3D spatial data, so we do not need to tell the model that we are going to give it 3D coordinates; it's implicit. The irreps_in, instead, correspond to the input features. In this example, we don't have input features. As an illustration, though, imagine we wanted our model to predict the next set of coordinates from the initial coordinates and the velocity. In that case, our irreps_in would describe the velocity. If we gave our velocity as vectors, we would have **"12x1o"** as our input features. If we just gave our model the magnitude of the velocity (the speed), we would represent our input features as **"12x0e"**. 

Since we don't have input features, we'll put None for that parameter, and we want our output to be the same shape as the input: **"12x1o"**. Take a minute to make sure you understand why this is the case. (One caveat: E3NN's graph-based models, like the one used below, treat each point as a node and apply `irreps_out` to every node separately, so a single predicted vector per point would be written "1x1o" per node.)

## E3NN Model

E3NN ships several ready-made models, which can be found in the e3nn/e3nn GitHub repository under e3nn/nn/models/. For this example, we will use one of these models. To use it, we need to put our data into torch_geometric `Data` objects, one graph per frame. We'll do that now.

In [None]:
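import torch_geometric
from torch_geometric.data import Data

feat = torch.from_numpy(features).to(torch.float32)
ys = torch.from_numpy(labels).to(torch.float32)

# one graph per frame: the 12 points are the nodes, `pos` holds their coordinates
# and `y` holds their coordinates in the next frame (the targets). The targets go
# in `y` rather than `x`, because the e3nn Network treats `x` as input node
# features, and we have none (irreps_in=None).
traj_data = [Data(pos=f, y=target) for f, target in zip(feat, ys)]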

In [None]:
from e3nn.nn.models.gate_points_2101 import Network
from e3nn import o3

model_kwargs = {
    "irreps_in": None,  # no input features
    "irreps_hidden": o3.Irreps("5x0e + 5x0o + 5x1e + 5x1o"),  # hyperparameter
    "irreps_out": "12x1o",  # 12 vectors out
    "irreps_node_attr": None,
    "irreps_edge_attr": o3.Irreps.spherical_harmonics(3),
    "layers": 3,  # hyperparameter
    "max_radius": 3.5,
    "number_of_basis": 10,
    "radial_layers": 1,
    "radial_neurons": 128,
    "num_neighbors": 12,  # average number of neighbors w/in max_radius
    "num_nodes": 12,  # not important unless reduce_output is True
    "reduce_output": False,
}

model = Network(**model_kwargs)  # initializing model with parameters above
model.to("cuda" if torch.cuda.is_available() else "cpu")  # fall back to CPU without a GPU

In [None]:
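# the network takes a dict-like input with a "pos" entry of shape (num_nodes, 3),
# not a bare tensor; here we pass one frame, whose 12 points become the nodes
device = next(model.parameters()).device  # put the data on the model's device
frame = {"pos": torch.as_tensor(features[0], dtype=torch.float32, device=device)}
output = model(frame)  # one row of irreps_out features per node
print(output.shape)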