# Welcome to the `track-mjx` Workshop!

In this document, I will guide you through setting up the running environment for `track-mjx`, either with a Docker image or with a conda environment.

> Make sure you have a Linux environment with an NVIDIA GPU set up before using this repo.
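
Before starting, a quick preflight check can save time. The sketch below is not part of `track-mjx`. It only reports which of the command-line tools used in this guide are on your `PATH` and, if `nvidia-smi` is present, which GPUs the driver can see. Which tools you actually need depends on the option you pick (conda or Docker).

```bash
# Preflight sketch (not part of track-mjx): report which tools from this guide are on PATH.
check_tool() {
  if command -v "$1" >/dev/null 2>&1; then
    echo "found: $1"
  else
    echo "missing: $1"
  fi
}

for tool in git conda nvidia-smi docker screen; do
  check_tool "$tool"
done

# nvidia-smi lists the GPUs visible to the NVIDIA driver; skip it if the tool is absent.
if command -v nvidia-smi >/dev/null 2>&1; then
  nvidia-smi --query-gpu=name,driver_version --format=csv,noheader
fi
```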

# Section 1: Environment Setup

## Option 1: Running `track-mjx` on a Bare-Metal Linux Machine

### Installation

1. Clone the repository with the following command:
    ```bash
    git clone https://github.com/talmolab/track-mjx.git && cd track-mjx
    ```
2. Create a new development environment via `conda`:
    ```bash
    conda env create -f environment.yml
    ```
    This will install the necessary dependencies and install the package in editable mode.
3. Test the environment.
    Activate the conda environment you just created:
    ```bash
    conda activate track_mjx
    ```
    Then run `jupyter lab` and execute the tests in [`notebooks/test_setup.ipynb`](notebooks/test_setup.ipynb). This will check if MuJoCo, GPU support and JAX appear to be working.

### Using VS Code or another IDE

> Make sure you have installed the Python and Jupyter extensions in VS Code.

After installing the `track_mjx` conda environment, open the command palette (<kbd>Shift</kbd> + <kbd>Cmd</kbd> + <kbd>P</kbd> on Mac or <kbd>Ctrl</kbd> + <kbd>Shift</kbd> + <kbd>P</kbd> on Windows) and select `Python: Select Interpreter`. Choose `track_mjx (Python 3.11.11)` as your Python environment. You may need to restart the terminal for the `track_mjx` environment to be activated, or you can activate it directly by running `conda activate track_mjx`.

When you open a `.ipynb` file, you also need to select the `track_mjx` environment as the kernel. To do so, click the kernel name in the top-right corner of the notebook and select `track_mjx`.

## Option 2: Running `track-mjx` using a Docker image

> Make sure the Docker daemon is running on your computer.

> To allow GPU passthrough to the running Docker container, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) if you haven't done so already.

> We have not yet figured out how to use the `egl` rendering backend with the Docker image on a bare-metal Ubuntu Linux system. The `egl` backend does work on Salk's Manticores clusters, which run a Kubernetes-based Linux system.

<!-- Need to re-test the docker system locally in Linux/Windows -->

Pull and run the docker image from the DockerHub registry:

```bash
docker run --gpus all -e NVIDIA_DRIVER_CAPABILITIES=all -p 8888:22 scottyang17/track-mjx:vscode-v1
```

`8888` is the host port that is forwarded to the container's SSH port (`22`). Pick a port that is not already in use, since you will use it later to connect to the container from VS Code.
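
If you are not sure which port is free, the hypothetical helper below (not part of `track-mjx`) prints the first port in a range that nothing on `127.0.0.1` is listening on. It is a rough check that relies on bash's `/dev/tcp` redirection, where opening the path only succeeds if a listener exists. The `docker run` line is left as a comment so you can review the port before starting the container.

```bash
# Hypothetical helper: print the first port in [start, end] with no local listener.
# Uses bash's /dev/tcp pseudo-path: the connection succeeds only if something is listening.
find_free_port() {
  local port
  for port in $(seq "${1:-8888}" "${2:-8898}"); do
    if ! (exec 3<>"/dev/tcp/127.0.0.1/$port") 2>/dev/null; then
      echo "$port"
      return 0
    fi
  done
  return 1
}

PORT=$(find_free_port 8888 8898)
echo "Free host port: $PORT"
# Then map that host port to the container's SSH port (22):
#   docker run --gpus all -e NVIDIA_DRIVER_CAPABILITIES=all -p "$PORT":22 scottyang17/track-mjx:vscode-v1
```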

See [here](https://github.com/talmolab/internal-dockerfiles/tree/3245903ec48b633ae205eeab0583d6413c32530b/remote-dev) for more information on our Remote Dev Docker image.


### Set Up VS Code Remote Dev

First, install the `Remote Development` extension (ID: `ms-vscode-remote.vscode-remote-extensionpack`) in VS Code. Open the command palette, then search for and choose `Remote-SSH: Connect to Host` -> `Configure SSH Hosts` -> `<your ssh config path>`, and add the following configuration:

```
Host track-mjx-remote-dev
    HostName <ip>
    Port <port>
    User root
```

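Use `localhost` for `<ip>` if the container runs on the same machine, or the IP address of the remote machine if it runs on a cluster. `<port>` is the host port you mapped in the `docker run` command (`8888` in the example above).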

Open the command palette again, choose `Remote-SSH: Connect to Host` -> `track-mjx-remote-dev`, and enter the password `root`. You are now connected to the running container.
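
If you prefer not to edit the SSH config by hand, the sketch below prints the same `Host` entry from three arguments: the alias, the IP or hostname, and the host port from `docker run -p`. The function name `ssh_host_block` is hypothetical. Appending to `~/.ssh/config` is left as a comment so you can review the output first.

```bash
# Hypothetical helper: print an SSH config entry for the track-mjx container.
# Arguments: host alias, IP/hostname, host port passed to `docker run -p <port>:22`.
ssh_host_block() {
  printf 'Host %s\n    HostName %s\n    Port %s\n    User root\n' "$1" "$2" "$3"
}

ssh_host_block track-mjx-remote-dev localhost 8888

# To persist it after checking the output:
#   ssh_host_block track-mjx-remote-dev localhost 8888 >> ~/.ssh/config
# A plain terminal can then use the same alias (password `root`):
#   ssh track-mjx-remote-dev
```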


## Testing your setup

After you have set up the environment, you can test it by running all the commands in the [`notebooks/test_setup.ipynb`](notebooks/test_setup.ipynb) notebook. This will check if MuJoCo, GPU support and JAX appear to be working.
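
For a quick check from a terminal (for example over SSH, without starting Jupyter), the sketch below imports `jax` and `mujoco` and, if JAX loads, prints the backend it selected (`gpu` means JAX can see the GPU). It is a rough complement to the notebook, not a replacement, and assumes the `track_mjx` environment is active or that you are inside the container.

```bash
# Command-line smoke test sketch; run with the track_mjx environment active.
smoke_test() {
  PYTHON_BIN=$(command -v python || command -v python3)
  "$PYTHON_BIN" - <<'EOF'
import importlib

# Report whether each core dependency imports cleanly.
for name in ("jax", "mujoco"):
    try:
        importlib.import_module(name)
        print(f"{name}: ok")
    except Exception as err:
        print(f"{name}: FAILED ({err})")

# If JAX imports, show the backend and devices it picked ("gpu" means CUDA is visible).
try:
    import jax
    print("jax backend:", jax.default_backend())
    print("jax devices:", jax.devices())
except Exception:
    pass
EOF
}

smoke_test
```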

# Section 2: Download the Datasets

Once the environment is set up, you can proceed to download the datasets containing the target trajectories for the agents to track. These datasets are available via a Google Drive link.

1. Install the `gdown` package by running the following command:
    ```bash
    pip install gdown
    ```

2. Make sure you are in the `track-mjx/data` directory, then run the following command to download the rodent trajectory dataset:
    ```bash
    gdown 1BHKe9agnupdPC8xExx4OaPtPJBKTTI2U
    ```

    TODO: fly trajectory dataset pending


# Section 3: Training the `track-mjx` model

After setting up the environment and downloading the datasets, you can proceed to train the `track-mjx` model. 

### Using a `screen`-based terminal

`screen` keeps your sessions running even if you get disconnected from the Docker container. See [this issue](https://github.com/talmolab/track-mjx/issues/8#issuecomment-2469376476) for a workflow description.

This is especially useful over SSH, where a disconnect would otherwise kill long-running processes such as training.

1. Start a new `screen` session:
    ```bash
    screen -S track-mjx-train
    ```
    This will automatically attach you to the new session.
2. Activate the conda environment:
    ```bash
    conda activate track_mjx
    ```
3. Run the training command:
    ```bash
    python -m track_mjx.train data_path="data/ReferenceClip.p"
    ```

You're good!

**To manually detach from the session:** Press <kbd>Ctrl</kbd>+<kbd>A</kbd> → <kbd>D</kbd> (even on Mac).

**To attach to a running session:**
```bash
screen -r track-mjx-train
```

**To see which screen sessions are running:**
```bash
screen -ls
```

**To kill a session and anything running in it:**
```bash
screen -S track-mjx-train -X quit
```

If you get disconnected from SSH, you just need to re-attach to the running session to view it. The running process itself won't be terminated.
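
The session, activation, and training steps can also be combined into one command that starts training in an already-detached session. The sketch below is an assumption-based shortcut, not the project's documented workflow: it uses `conda run -n track_mjx` instead of `conda activate`, because `conda activate` is usually not set up inside a non-interactive `bash -c` shell. The function name `start_training_session` and the `DRY_RUN` variable are hypothetical; with `DRY_RUN=1` it only prints the command.

```bash
# Hypothetical helper: launch training in a detached screen session.
# Arguments (optional): session name, data path passed to the Hydra override.
start_training_session() {
  local session="${1:-track-mjx-train}"
  local data_path="${2:-data/ReferenceClip.p}"
  local train_cmd="conda run --no-capture-output -n track_mjx python -m track_mjx.train data_path=$data_path"
  if [ "${DRY_RUN:-0}" = "1" ]; then
    # Only print what would be run.
    echo "screen -dmS $session bash -c '$train_cmd'"
    return 0
  fi
  # -dm creates the session already detached; reattach later with `screen -r <name>`.
  screen -dmS "$session" bash -c "$train_cmd"
}

# Preview the command without launching anything:
DRY_RUN=1 start_training_session
```

Drop `DRY_RUN=1` to actually start the session, then attach with `screen -r track-mjx-train` as described above.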

To train on `data/transform_snips.h5` with the `rodent-two-clips` config instead, run:

```bash
python -m track_mjx.train data_path="data/transform_snips.h5" +hydra.job.config_name="rodent-two-clips"
```

You may need to log in to Weights & Biases (`wandb`) to see the training progress. To do so, run:

```bash
wandb login
```
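
On a remote machine or inside the container, it can help to check whether `wandb` can authenticate without the interactive prompt. `WANDB_API_KEY` and `WANDB_MODE` are standard `wandb` environment variables. Whether the `track-mjx` training logs behave as expected in offline mode is an assumption that has not been verified here, and the helper name `wandb_auth_status` is hypothetical.

```bash
# Hypothetical helper: report whether wandb can authenticate non-interactively.
# Note: a previous `wandb login` stores credentials on disk, so "not set" does not
# necessarily mean you are logged out.
wandb_auth_status() {
  if [ -n "${WANDB_API_KEY:-}" ]; then
    echo "WANDB_API_KEY is set; wandb can authenticate non-interactively."
  else
    echo "WANDB_API_KEY is not set; run 'wandb login' if you have not on this machine, or export WANDB_MODE=offline to log locally only."
  fi
}

wandb_auth_status
```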