# Sweep over `train_frames` and `losses_to_use`

In [1]:
import hydra
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
import torch
from omegaconf import DictConfig

from lightning_pose.utils.io import return_absolute_data_paths
from lightning_pose.utils.scripts import get_imgaug_transform, get_dataset, get_data_module, get_loss_factories
from lightning_pose.losses.losses import PCALoss

import sys
sys.path.append('/home/jovyan/tracking-diagnostics')
from diagnostics.handler import ModelHandler
from diagnostics.io import get_base_config, get_keypoint_names

In [2]:
# %% get config
dataset_name = "rick-configs-1"
base_config_dir = "/home/jovyan/rick-configs-1"
base_save_dir = "/home/jovyan/"
cfg = get_base_config(config_dir=base_config_dir, config_name="config")

In [3]:

# load ground truth labels
csv_file = os.path.join(cfg.data.data_dir, cfg.data.csv_file)
csv_data = pd.read_csv(csv_file, header=list(cfg.data.header_rows))
keypoints_gt = csv_data.iloc[:, 1:].to_numpy().reshape(csv_data.shape[0], -1, 2)

keypoint_names = get_keypoint_names(csv_data, cfg.data.header_rows)

Start by looping over `train_frames` and individual `losses_to_use`. Later build in complication.

In [14]:
save_dir = "/home/jovyan/lightning-pose"
loss_types = [[], ["pca_multiview"], ["pca_singleview"], ["temporal"], ["unimodal_mse"]] # TODO: order matters
log_weight_list = [] # TODO: order matters
model_names = ["train_frames_sweep"]*len(loss_types)
supervised_model_name = model_names[0] # they all have the same name
train_frames_list = [50,75,100,125]
model_type = "heatmap"
handlers = []
for tr_fr_idx, train_frames in enumerate(train_frames_list):
    for loss_idx, loss in enumerate(loss_types):
        print("searching for train_frames: {}, loss_type: {}".format(train_frames, loss))
        model_cfg = cfg.copy()
        model_cfg.training.train_frames = train_frames
        model_cfg.model.losses_to_use = loss # assume loss is already a list, [] if supervised
        model_cfg.model.model_name = model_names[loss_idx]
        model_cfg.model.model_type = model_type
        # specific arguments to "train_frames_sweep" models. TODO: change if needed
        model_cfg.training.train_prob=0.2
        model_cfg.training.val_prob=0.2
        model_cfg.training.min_epochs=125
        model_cfg.training.max_epochs=2000
        if len(loss) == 0:
            # support for uniquely-named supervised models
            model_cfg.model.model_name = supervised_model_name
        else:
            # loop over the sub losses
            if len(log_weight_list)>0:
                for sub_loss_idx,sub_loss in enumerate(loss):
                    model_cfg.losses[sub_loss].log_weight = log_weight_list[loss_idx][sub_loss_idx]
        
        try:
            handlers.append(ModelHandler(save_dir, model_cfg, verbose=False))
            print("Found: {}".format(model_cfg.model.model_name))
            print("In: {}".format(handlers[-1].model_dir))
        except FileNotFoundError:
            print('did not find %s model for train_frames=%i' % (loss, train_frames))
            continue
# report on the models found
if len(handlers) == 0:
    print("No models found")
elif len(handlers) == len(train_frames_list)*len(loss_types):
    print("Found all models")
else:
    print("Found {} models out of {}".format(len(handlers), len(train_frames_list)*len(loss_types)))

searching for train_frames: 50, loss_type: []
Found: train_frames_sweep
In: /home/jovyan/lightning-pose/multirun/2022-04-02/00-31-26/1
searching for train_frames: 50, loss_type: ['pca_multiview']
Found: train_frames_sweep
In: /home/jovyan/lightning-pose/multirun/2022-04-02/00-31-26/16
searching for train_frames: 50, loss_type: ['pca_singleview']
Found: train_frames_sweep
In: /home/jovyan/lightning-pose/multirun/2022-04-02/13-32-43/0
searching for train_frames: 50, loss_type: ['temporal']
Found: train_frames_sweep
In: /home/jovyan/lightning-pose/multirun/2022-04-02/00-31-26/6
searching for train_frames: 50, loss_type: ['unimodal_mse']
Found: train_frames_sweep
In: /home/jovyan/lightning-pose/multirun/2022-04-02/00-31-26/11
searching for train_frames: 75, loss_type: []
Found: train_frames_sweep
In: /home/jovyan/lightning-pose/multirun/2022-04-02/00-31-26/2
searching for train_frames: 75, loss_type: ['pca_multiview']
Found: train_frames_sweep
In: /home/jovyan/lightning-pose/multirun/2022-

In [17]:
# loop over handlers and compute metrics
to_compute = "rmse" # | "rmse" | "pca_singleview" | "unimodal_mse"
keypoint_names = get_keypoint_names(csv_data, cfg.data.header_rows)
model_cfg.model.losses_to_use = [loss_type] # assume loss is already a list, [] f
error_metric = "reprojection_error" # only for PCA
pca_loss = None
data_module = None
if to_compute == 'rmse':
    y_label = 'RMSE per bodypart'
elif to_compute == 'pca_multiview' or to_compute == 'pca_singleview':
    y_label = 'PCA reprojection error'
    from lightning_pose.utils.pca import KeypointPCA
    data_dir, video_dir = return_absolute_data_paths(data_cfg=model_cfg.data)
    imgaug_transform = get_imgaug_transform(cfg=model_cfg)
    dataset = get_dataset(cfg=model_cfg, data_dir=data_dir, imgaug_transform=imgaug_transform)
    data_module = get_data_module(cfg=model_cfg, dataset=dataset, video_dir=video_dir)
    data_module.setup()
    # compute pca params
    loss_factories = get_loss_factories(cfg=model_cfg, data_module=data_module)
    pca_loss = loss_factories["unsupervised"].loss_instance_dict[to_compute]

name_strs_to_plot = ['+'.join([l[:5] for l in loss]) if len(loss)>0 else 's' for loss in loss_types]
# store results here
if to_compute == "pca_singleview":
    # remove obstacle keypoints
    keypoint_names = [kp for kp in keypoint_names if kp not in ['obs_top','obsHigh_bot','obsLow_bot']]
    print(keypoint_names)

metrics_collected = {bp: [] for bp in keypoint_names}
    
for hand_idx, handler in enumerate(handlers):
    print(hand_idx)
    print("name: {}".format(handler.cfg.model.model_name))
    print("losses_to_use: {}".format(handler.cfg.model.losses_to_use))
    print(handler.model_dir)
    # compute metric
    try:
        result = handler.compute_metric(
            to_compute, 'predictions.csv',
            keypoints_true=keypoints_gt, pca_loss_obj=pca_loss, datamodule=data_module)
        print(result.shape)
    except FileNotFoundError:
        print('could not find model predictions')
        continue
    for b, bodypart in enumerate(keypoint_names):
        metrics_collected[bodypart].append(result[:, b])

0
name: train_frames_sweep
losses_to_use: []
/home/jovyan/lightning-pose/multirun/2022-04-02/00-31-26/1
Metric: rmse
Computing RMSE...
(1045, 17)
1
name: train_frames_sweep
losses_to_use: ['pca_multiview']
/home/jovyan/lightning-pose/multirun/2022-04-02/00-31-26/16
Metric: rmse
Computing RMSE...
(1045, 17)
2
name: train_frames_sweep
losses_to_use: ['pca_singleview']
/home/jovyan/lightning-pose/multirun/2022-04-02/13-32-43/0
Metric: rmse
Computing RMSE...
(1045, 17)
3
name: train_frames_sweep
losses_to_use: ['temporal']
/home/jovyan/lightning-pose/multirun/2022-04-02/00-31-26/6
Metric: rmse
Computing RMSE...
(1045, 17)
4
name: train_frames_sweep
losses_to_use: ['unimodal_mse']
/home/jovyan/lightning-pose/multirun/2022-04-02/00-31-26/11
Metric: rmse
Computing RMSE...
(1045, 17)
5
name: train_frames_sweep
losses_to_use: []
/home/jovyan/lightning-pose/multirun/2022-04-02/00-31-26/2
Metric: rmse
Computing RMSE...
(1045, 17)
6
name: train_frames_sweep
losses_to_use: ['pca_multiview']
/home/j

In [18]:
# collect results
# TODO: currently ignorant of the train_frames. either fix here or above.               
results_df = []
for bodypart in keypoint_names:
    dict_tmp = {
        'bodypart': bodypart,
        #'rng_seed': rng_seed,
        'eval_mode': handlers[-1].pred_df.iloc[:, -1].to_numpy(),
        'img_file': csv_data.iloc[:, 0], # TODO: fix, this is wrong. should be a str not a float
    }
    for col_name, metric in zip(name_strs_to_plot, metrics_collected[bodypart]):
        dict_tmp[col_name] = metric
    results_df.append(pd.DataFrame(dict_tmp))

results_df = pd.concat(results_df)

In [19]:
results_df.head()

Unnamed: 0,bodypart,eval_mode,img_file,s,pca_m,pca_s,tempo,unimo
0,paw1LH_top,unused,barObstacleScaling1/img1.png,87.153238,93.340255,39.708454,106.189319,71.652569
1,paw1LH_top,validation,barObstacleScaling1/img2.png,111.731015,115.790357,28.46511,60.46525,39.128012
2,paw1LH_top,test,barObstacleScaling1/img3.png,81.099913,8.681526,78.112809,18.369521,74.446077
3,paw1LH_top,test,barObstacleScaling1/img4.png,186.998488,2.664029,3.857219,2.859897,13.102761
4,paw1LH_top,test,barObstacleScaling1/img5.png,7.186884,7.548655,9.943737,6.150783,8.909237
