# Install PyTorch
This exercise requires [PyTorch](https://pytorch.org). To install it:

1. Activate your virtual environment:
   ```sh
   conda activate flygym
   ```
2. Install PyTorch with pip:
   ```sh
   pip install torch
   ```

For details on how to use PyTorch, refer to the [PyTorch tutorials](https://pytorch.org/tutorials/).
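
After installing, a quick check like the one below can confirm that the package is importable from the active environment. This is a minimal sketch that uses only the standard library; the helper name `is_installed` is made up for illustration.

```python
import importlib.util


def is_installed(package_name: str) -> bool:
    """Return True if `package_name` can be imported in this environment."""
    return importlib.util.find_spec(package_name) is not None


# PyTorch is imported under the name "torch"
print(f"torch installed: {is_installed('torch')}")
```

If this prints `False`, the most likely cause is that the `flygym` environment was not active when `pip install torch` was run.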

# Exercise 2: Neural network
The basic center-of-mass tracking method that we used in the [vision tutorial](1_vision.ipynb) might fail for objects that blend in with the background. In this exercise, we will introduce a second fly into the arena and train a neural network to track its location.

For the neural network to learn, we first need a dataset of images paired with the precise locations of this second fly. This has been implemented in [generate_dataset.py](generate_dataset.py). The script takes a few minutes to run. To save time, the dataset has already been generated and saved in [data/data.npz](data/data.npz).

Let's load the dataset:

In [None]:
import numpy as np

data = np.load("data/data.npz")
images = data["images"]
# The images are stored in a 4D array (n_images, n_eyes, n_rows, n_cols)
print(f"Shape of images: {images.shape}")
# The positions of the second fly are stored in polar coordinates (r, theta)
r = data["r"]
theta = data["theta"]
data.close()
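
If the `.npz` format is unfamiliar, the sketch below shows how such an archive can be written and read back. The array names mirror the ones used above, but the shapes and values are made up for illustration and are not those of the real dataset.

```python
import os
import tempfile

import numpy as np

# Toy arrays standing in for the real dataset (shapes are illustrative only)
toy_images = np.zeros((4, 2, 8, 8), dtype=np.float32)
toy_r = np.linspace(1.0, 4.0, 4)
toy_theta = np.linspace(-np.pi / 2, np.pi / 2, 4)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy.npz")
    # Each keyword argument becomes a named array in the archive
    np.savez(path, images=toy_images, r=toy_r, theta=toy_theta)

    # Using np.load as a context manager closes the file automatically
    with np.load(path) as toy_data:
        print(toy_data.files)
        loaded_images = toy_data["images"]
        loaded_r = toy_data["r"]

print(loaded_images.shape, loaded_r.shape)
```

The `with` form is equivalent to calling `data.close()` by hand, as done in the cell above.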

We will convert the position of the fly from polar coordinates $(r, \theta)$ into a Cartesian coordinate system that has undergone rotation and inversion. The two basis vectors point at 45° to the left and right of the fly, respectively.

<img src="images/coordinates.png" width="600">

Implement the conversion below. What is the advantage of this encoding?

In [None]:
################################################################
# TODO: Convert the polar coordinates to cartesian coordinates
...
coords_lr = ...
################################################################

# The shape should be (# of samples, 2)
assert coords_lr.shape == (len(images), 2)

Next, we can create the data loaders:

In [None]:
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split

dataset = TensorDataset(torch.tensor(images), torch.tensor(coords_lr))
datasets = dict(
    zip(["train", "val", "test"], random_split(dataset, [8000, 1000, 1000]))
)
batch_size = 32
loaders = {
    name: DataLoader(dataset, batch_size=batch_size, shuffle=name == "train")
    for name, dataset in datasets.items()
}
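
Note that `random_split` draws a different split on every run unless a seeded generator is passed. The sketch below uses a small synthetic dataset (its shapes are made up) to show one way to make the split reproducible and to check what a batch looks like. The `toy_*` names are for illustration only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Synthetic stand-in data: 10 samples, 2 eyes, 8x8 pixels (illustrative shapes)
toy_images = torch.rand(10, 2, 8, 8)
toy_coords = torch.rand(10, 2)
toy_dataset = TensorDataset(toy_images, toy_coords)

# A seeded generator makes the split identical across runs
generator = torch.Generator().manual_seed(0)
toy_train, toy_val = random_split(toy_dataset, [8, 2], generator=generator)

toy_loader = DataLoader(toy_train, batch_size=4, shuffle=False)
inputs, labels = next(iter(toy_loader))
print(inputs.shape, labels.shape)
```

Each batch has a leading dimension of `batch_size`, followed by the per-sample shape of the images or labels.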

Construct your neural network model in the following cell. A simple convolutional neural network (CNN) should be enough to achieve good performance on this dataset.

In [None]:
from torch import nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        #########################################
        # TODO: Define your layers here
        # e.g., self.conv1 = nn.Conv2d(...)

        #########################################

    def forward(self, x):
        #########################################
        # TODO: Define the forward pass

        return ...
        #########################################

We can now train the model. The loop below minimizes the mean squared error with the Adam optimizer for 100 epochs and keeps the weights that achieve the lowest validation loss.

In [None]:
from copy import deepcopy
import torch.optim as optim

model = Model()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
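best_val_loss = float("inf")
# Fallback in case the validation loss never improves (e.g., if it becomes NaN)
best_model_state = deepcopy(model.state_dict())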

for epoch in range(100):
    model.train()
    train_loss = 0

    for inputs, labels in loaders["train"]:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    train_loss /= len(loaders["train"])

    model.eval()

    with torch.no_grad():
        val_loss = 0
        for inputs, labels in loaders["val"]:
            outputs = model(inputs)
            val_loss += criterion(outputs, labels).item()

    val_loss /= len(loaders["val"])

    print(f"{epoch=}, {train_loss=:0.4f}, {val_loss=:0.4f}")

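    # Keep a copy of the weights that give the lowest validation loss so far
    if val_loss < best_val_loss:
        best_model_state = deepcopy(model.state_dict())
        best_val_loss = val_loss

# Restore the best weights before evaluating on the test set
model.load_state_dict(best_model_state)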

Finally, let's run the model on the test data and visualize the predictions:

In [None]:
import matplotlib.pyplot as plt

test_indices = datasets["test"].indices
with torch.no_grad():
    coords_lr_pred = model(torch.tensor(images[test_indices])).numpy()

# Convert the predicted (x_L, x_R) coordinates back to an angle
theta_pred = np.angle(coords_lr_pred @ (1, -1j) * np.exp(1j * np.pi / 4))

fig, axs = plt.subplots(1, 3, figsize=(9, 3))
for i in range(2):
    axs[i].scatter(coords_lr[test_indices, i], coords_lr_pred[:, i], alpha=0.1)

axs[2].scatter(theta[test_indices], theta_pred, alpha=0.1)
axs[0].set_ylabel("Prediction")
axs[1].set_xlabel("Ground truth")

for ax, title in zip(axs, ["$x_L$", "$x_R$", "$\\theta$"]):
    ax.set_title(title)

Can you estimate the visual field of the fly based on the plots?
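
To complement the scatter plots with a single number, one option is the mean absolute angular error. The sketch below is one possible way to compute it, not part of the original exercise; the helper `angular_error` and the toy angles are made up for illustration. The difference is wrapped so that angles on either side of the ±π boundary are not counted as far apart.

```python
import numpy as np


def angular_error(theta_true, theta_pred):
    """Signed difference theta_pred - theta_true, wrapped to (-pi, pi]."""
    diff = np.asarray(theta_pred) - np.asarray(theta_true)
    return np.angle(np.exp(1j * diff))


# Made-up angles for illustration; the second pair straddles the +/-pi boundary
toy_true = np.array([0.0, np.pi - 0.1])
toy_pred = np.array([0.2, -np.pi + 0.1])
errors = angular_error(toy_true, toy_pred)
print(f"Mean absolute angular error: {np.abs(errors).mean():.3f} rad")
```

In this notebook, the same helper could be applied as `angular_error(theta[test_indices], theta_pred)`.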

In this exercise, we have learned how to:
- Create image datasets with the NeuroMechFly simulation
- Use neural networks to estimate an object's position