# ⚡ Train and visualize a Lightning Pose model ⚡

This notebook uses a toy dataset (a.k.a. "mirror-mouse") with 90 labeled images from Warren et al., 2022 (eLife), and covers the following steps:
* [Environment setup](#Environment-setup)
* [Train (via PyTorch Lightning)](#Training)
* [Monitor optimization in real time (via TensorBoard UI)](#Monitor-training)
* [Video predictions and diagnostics](#Plot-video-predictions-and-unsupervised-losses)


**Make sure to use a GPU runtime!**

To do so, in the upper right corner of this notebook:
* click the "Connect" button (or select "Connect to a hosted runtime" from the drop-down)
* confirm you are connected to a GPU: click the down arrow, select "View resources" from the menu, and check that you see "Python 3 Google Compute Engine backend (GPU)"
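
As an optional sanity check after connecting, a sketch like the one below reports whether a GPU is visible from Python. It assumes the `nvidia-smi` utility is on the `PATH` whenever a GPU backend is attached, which is an assumption about the runtime image; `gpu_visible` is a hypothetical helper and not part of Lightning Pose.

```python
import shutil
import subprocess


def gpu_visible() -> bool:
    """Return True if `nvidia-smi -L` lists at least one GPU (hypothetical helper)."""
    exe = shutil.which("nvidia-smi")
    if exe is None:
        # utility not installed, so we cannot see a GPU from here
        return False
    try:
        result = subprocess.run(
            [exe, "-L"], capture_output=True, text=True, timeout=30
        )
    except (OSError, subprocess.TimeoutExpired):
        return False
    # `nvidia-smi -L` prints one line per device, each starting with "GPU"
    return result.returncode == 0 and "GPU" in result.stdout


print("GPU visible:", gpu_visible())
```

If this prints `GPU visible: False`, revisit the runtime settings before continuing.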

## Environment setup

In [None]:
# download the lightning-pose repository into /content/lightning-pose
!git clone https://github.com/paninski-lab/lightning-pose.git

In [None]:
# step into that directory
%cd /content/lightning-pose

In [None]:
# check which CUDA toolkit version is installed
!nvcc --version

In [None]:
# install lightning-pose as a package, including all its requirements (specified in setup.py)

# NOTE: you may see the following error:
#     ERROR: pip's dependency resolver does not currently take into account all the packages that are installed.
# This is fine and can be ignored

!pip install -e .


#### RESTART THE RUNTIME
Go to `Runtime > Restart session` to finish package installations.
After restarting, proceed to the next cell.

In [None]:
# step into lightning-pose
%cd /content/lightning-pose

## Training

In [None]:
import os
from omegaconf import OmegaConf

# read hydra configuration file from lightning-pose/scripts/configs/config_mirror-mouse-example.yaml
# this config file contains all the necessary information for training and evaluating a
# Lightning Pose model
# https://lightning-pose.readthedocs.io/en/latest/source/user_guide/config_file.html
cfg = OmegaConf.load("scripts/configs/config_mirror-mouse-example.yaml")

# set absolute data and video directories for the toy dataset
cfg.data.data_dir = "/content/lightning-pose/data/mirror-mouse-example"
cfg.data.video_dir = os.path.join(cfg.data.data_dir, "videos")

assert os.path.isdir(cfg.data.data_dir), "data_dir not a valid directory"
assert os.path.isdir(cfg.data.video_dir), "video_dir not a valid directory"

# keep training short for this demo (we usually train for 300 epochs)
cfg.training.min_epochs = 100
cfg.training.max_epochs = 150
cfg.training.batch_size = 32

# directory we'll save the model in
model_dir = '/content/outputs/semi-super-model'


## Monitor training

In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

In [None]:
# Launch tensorboard before launching training (happens in next cell).
# If you receive a 403 error, be sure to enable all cookies for this site in your browser.
# To see the losses during training, select TIME SERIES and hit the refresh button (circle arrow) on the top right.

# The two most important diagnostics are:
# - `train_supervised_rmse`: root mean square error (rmse) of predictions on training data
# - `val_supervised_rmse`: rmse on validation data

%tensorboard --logdir $model_dir
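
For readers unfamiliar with the metric, here is a minimal sketch of the textbook root mean square error over paired (x, y) points. The `rmse` function and the sample points are made up for illustration; the values that Lightning Pose logs may aggregate errors differently (for example, per keypoint), so treat this as a definition of the general idea rather than the library's exact computation.

```python
import math


def rmse(predicted, target):
    """Textbook RMSE over all coordinates of paired (x, y) points."""
    squared = [
        (p - t) ** 2
        for pred_pt, true_pt in zip(predicted, target)
        for p, t in zip(pred_pt, true_pt)
    ]
    return math.sqrt(sum(squared) / len(squared))


# made-up example: one keypoint is off by (3, 4) pixels, the other is exact
target = [(0.0, 0.0), (0.0, 0.0)]
predicted = [(3.0, 4.0), (0.0, 0.0)]
print(rmse(predicted, target))
```

A validation curve that stays well above the training curve is the usual sign of overfitting.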

In [None]:
from lightning_pose.train import train

# Save the model artifacts here (logs, weights, predictions);
# this is the same path that was passed to TensorBoard above
model_dir = '/content/outputs/semi-super-model'

# Train the model (approx. 15-20 mins on a T4 GPU)
# This function will also:
# - evaluate the model on train, validation, and test sets
# - evaluate the model on a test video, and compute unsupervised losses
model = train(cfg, model_dir=model_dir)


In [None]:
# list the top-level contents of the model directory
artifacts = sorted(os.listdir(model_dir))
print("Generated the following model artifacts:")
print(artifacts)
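
Because `os.listdir` only shows the top level, files inside subdirectories such as `video_preds` are not listed. A minimal sketch of a recursive listing grouped by file extension follows. `group_artifacts_by_suffix` is a hypothetical helper, and the file names in the demo are placeholders rather than the actual training outputs.

```python
import tempfile
from collections import defaultdict
from pathlib import Path


def group_artifacts_by_suffix(root):
    """Recursively collect files under `root`, keyed by file extension."""
    grouped = defaultdict(list)
    for path in sorted(Path(root).rglob("*")):
        if path.is_file():
            # store paths relative to root so the listing stays readable
            key = path.suffix or "(no extension)"
            grouped[key].append(path.relative_to(root).as_posix())
    return dict(grouped)


# demo on a throwaway directory with placeholder file names
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "video_preds").mkdir()
    (Path(tmp) / "predictions.csv").write_text("x,y\n")
    (Path(tmp) / "video_preds" / "test_vid.csv").write_text("x,y\n")
    demo = group_artifacts_by_suffix(tmp)

print(demo)
```

In this notebook, you would call `group_artifacts_by_suffix(model_dir)` after training finishes.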

### Display the short labeled video
The video is overlaid with the network predictions.
Embedding a large video in the notebook may cause memory issues, so make sure yours is short.
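
One way to guard against oversized videos is to check the file size before base64-encoding it, since the encoded string is roughly a third larger than the file itself. The sketch below is a hedged illustration: `video_data_url` is a hypothetical helper, and the 50 MB limit is an arbitrary placeholder rather than a documented Colab threshold.

```python
import os
import tempfile
from base64 import b64encode

MAX_EMBED_BYTES = 50 * 1024 * 1024  # arbitrary placeholder limit; tune as needed


def video_data_url(path, max_bytes=MAX_EMBED_BYTES):
    """Return a base64 data URL for an mp4, or raise if the file is too large."""
    size = os.path.getsize(path)
    if size > max_bytes:
        raise ValueError(f"{path} is {size} bytes; limit is {max_bytes}")
    with open(path, "rb") as f:
        return "data:video/mp4;base64," + b64encode(f.read()).decode()


# demo on a tiny stand-in file (not a real video)
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
    tmp.write(b"abc")
demo_url = video_data_url(tmp.name)
os.remove(tmp.name)
print(demo_url)
```

The returned string can be substituted into the `<source src=...>` attribute in the same way as `data_url` in the next cell.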


In [None]:
from IPython.display import HTML
from base64 import b64encode

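labeled_vid_dir = os.path.join(model_dir, "video_preds/labeled_videos")
# keep only mp4 files so that vids[0] is guaranteed to be a video
vids = sorted(f for f in os.listdir(labeled_vid_dir) if f.endswith(".mp4"))
with open(os.path.join(labeled_vid_dir, vids[0]), "rb") as f:
    mp4 = f.read()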
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML("""
<video width=400 controls>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)


In [None]:
# download vids to your local machine if desired
from google.colab import files
for vid in vids:
    if vid.endswith(".mp4"):
        files.download(os.path.join(labeled_vid_dir, vid))

## Plot video predictions and unsupervised losses

### Load data

In [None]:
from collections import defaultdict
import pandas as pd
from pathlib import Path

from lightning_pose.apps.utils import build_precomputed_metrics_df, get_col_names, concat_dfs
from lightning_pose.apps.utils import update_vid_metric_files_list
from lightning_pose.apps.utils import get_model_folders, get_model_folders_vis
from lightning_pose.apps.plots import plot_precomputed_traces

# select which model(s) to use
model_folders = get_model_folders("/content")

# get the last two levels of each path to be presented to user
model_names = get_model_folders_vis(model_folders)

# get prediction files for each model
prediction_files = update_vid_metric_files_list(video="test_vid", model_preds_folders=model_folders)

# load data
dframes_metrics = defaultdict(dict)
dframes_traces = {}
for p, model_pred_files in enumerate(prediction_files):
    model_name = model_names[p]
    model_folder = model_folders[p]
    for model_pred_file in model_pred_files:
        model_pred_file_path = os.path.join(model_folder, "video_preds", model_pred_file)
        if not isinstance(model_pred_file, Path):
            model_pred_file.seek(0)  # reset buffer after reading
        if "pca" in str(model_pred_file) or "temporal" in str(model_pred_file) or "pixel" in str(model_pred_file):
            dframe = pd.read_csv(model_pred_file_path, index_col=None)
            dframes_metrics[model_name][str(model_pred_file)] = dframe
        else:
            dframe = pd.read_csv(model_pred_file_path, header=[1, 2], index_col=0)
            dframes_traces[model_name] = dframe
            dframes_metrics[model_name]["confidence"] = dframe
        data_types = dframe.iloc[:, -1].unique()

# compute metrics
# concatenate dataframes, collapsing the column hierarchy and making the dataframe wider
df_concat, keypoint_names = concat_dfs(dframes_traces)
df_metrics = build_precomputed_metrics_df(
    dframes=dframes_metrics, keypoint_names=keypoint_names)
metric_options = list(df_metrics.keys())

# print keypoint names; select one of these to plot below
print(keypoint_names)

# NOTE: you can ignore all errors and warnings of the type:
#    No runtime found, using MemoryCacheStorageManager
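
The `header=[1, 2], index_col=0` arguments used above suggest that each predictions CSV has three header rows, of which the second and third name the keypoint and the coordinate. The sketch below parses a tiny made-up file in that layout with the standard library only. The layout (a DeepLabCut-style `scorer` / `bodyparts` / `coords` header) and the `paw` keypoint are assumptions for illustration, not contents of the actual prediction files.

```python
import csv
import io

# made-up file contents in the assumed three-header-row layout
text = """scorer,tracker,tracker,tracker
bodyparts,paw,paw,paw
coords,x,y,likelihood
0,10.0,20.0,0.9
1,11.0,21.0,0.8
"""

rows = list(csv.reader(io.StringIO(text)))

# rows 1 and 2 hold keypoint and coordinate names; skip the index column
bodyparts, coords = rows[1][1:], rows[2][1:]
columns = list(zip(bodyparts, coords))  # e.g. ("paw", "x")

# map frame index -> {(keypoint, coord): value}
frames = {
    int(row[0]): {col: float(val) for col, val in zip(columns, row[1:])}
    for row in rows[3:]
}

print(columns)
print(frames[0][("paw", "x")])
```

The `(keypoint, coord)` pairs play the same role as the two-level column index that pandas builds when given `header=[1, 2]`.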

### Plot video traces

In [None]:
# rerun this cell each time you want to update the keypoint

from IPython.display import display, clear_output
import ipywidgets as widgets

def on_change(change):
    if change["type"] == "change" and change["name"] == "value":
        clear_output()
        cols = get_col_names(change["new"], "x", dframes_metrics.keys())
        fig_traces = plot_precomputed_traces(df_metrics, df_concat, cols)
        fig_traces.show()

# create a Dropdown widget
dropdown = widgets.Dropdown(
    options=keypoint_names,
    value=None,  # no keypoint is selected until the user picks one
    description="Select keypoint:",
)

# update plot upon change
dropdown.observe(on_change)

# display widget
display(dropdown)