# Getting Started

This notebook showcases basic functionality of the code base.

Here, we load the metadata, an example dataset, and run inference using a pre-trained model. 

We also show how to visualize the joint angle predictions using a hand mesh (requires the UmeTrack package -- see README.md).

In [None]:
%load_ext autoreload
%autoreload 2

In [6]:
from pathlib import Path
DATA_DOWNLOAD_DIR = "/home/xiziheng/develop/emg2pose/data"

## Download Dataset Metadata

In [None]:
!cd {DATA_DOWNLOAD_DIR} && curl https://fb-ctrl-oss.s3.amazonaws.com/emg2pose/emg2pose_metadata.csv -o emg2pose_metadata.csv

In [7]:
import pandas as pd

metadata_df = pd.read_csv(DATA_DOWNLOAD_DIR / "emg2pose_metadata.csv")
metadata_df.head(5)

TypeError: unsupported operand type(s) for /: 'str' and 'str'

## Download a Smaller (~600 MiB) Version of the Dataset

In [None]:
!cd {DATA_DOWNLOAD_DIR} && curl "https://fb-ctrl-oss.s3.amazonaws.com/emg2pose/emg2pose_dataset_mini.tar" -o emg2pose_dataset_mini.tar

# Unpack the tar to ~/emg2pose_dataset_mini
!tar -xvf emg2pose_dataset_mini.tar

In [8]:
import glob
import os

sessions = sorted(glob.glob(os.path.join(DATA_DOWNLOAD_DIR, "emg2pose_dataset_mini/*.hdf5")))
sessions

['/home/xiziheng/develop/emg2pose/data/emg2pose_dataset_mini/2022-12-06-1670313600-e3096-cv-emg-pose-train@2-recording-10_left.hdf5',
 '/home/xiziheng/develop/emg2pose/data/emg2pose_dataset_mini/2022-12-06-1670313600-e3096-cv-emg-pose-train@2-recording-10_right.hdf5',
 '/home/xiziheng/develop/emg2pose/data/emg2pose_dataset_mini/2022-12-06-1670313600-e3096-cv-emg-pose-train@2-recording-11_left.hdf5',
 '/home/xiziheng/develop/emg2pose/data/emg2pose_dataset_mini/2022-12-06-1670313600-e3096-cv-emg-pose-train@2-recording-11_right.hdf5',
 '/home/xiziheng/develop/emg2pose/data/emg2pose_dataset_mini/2022-12-06-1670313600-e3096-cv-emg-pose-train@2-recording-12_left.hdf5',
 '/home/xiziheng/develop/emg2pose/data/emg2pose_dataset_mini/2022-12-06-1670313600-e3096-cv-emg-pose-train@2-recording-12_right.hdf5',
 '/home/xiziheng/develop/emg2pose/data/emg2pose_dataset_mini/2022-12-06-1670313600-e3096-cv-emg-pose-train@2-recording-13_left.hdf5',
 '/home/xiziheng/develop/emg2pose/data/emg2pose_dataset_min

## Let's Look at a Dataset

In [9]:
from emg2pose.data import Emg2PoseSessionData

session = sessions[15]
data = Emg2PoseSessionData(hdf5_path=session)

In [None]:
print(data.fields)
print()

print(f"{'emg shape: ':<20} {data['emg'].shape}")
print(f"{'joint_angles shape: ':<20} {data['joint_angles'].shape}")
print(f"{'time shape: ':<20} {data['time'].shape}")

In [None]:
metadata_df[metadata_df["filename"] == data.metadata["filename"]]

In [None]:
import emg2pose.visualization as visualization

visualization.ik_failure_plot(data)

In [None]:
from emg2pose.utils import downsample
import numpy as np

joint_angles = data["joint_angles"]
joint_angles_30hz = downsample(joint_angles, native_fs=2000, target_fs=30)

assert not np.any(np.isnan(joint_angles_30hz))

visualization.plot_hand_mesh(joint_angles_30hz[100], auto_range=False)

In [None]:
import numpy as np


import matplotlib.pyplot as plt
%config InlineBackend.figure_format='retina'

In [None]:
visualization.get_plotly_animation_for_joint_angles(joint_angles_30hz[0:250])

### Render the Plotly Animation to Video Frames

In [None]:
import mediapy

frames = visualization.joint_angles_to_frames_parallel(joint_angles_30hz[0:250])
frames = visualization.remove_alpha_channel(frames)
mediapy.show_video(frames, width=800, fps=30, downsample=True)

## Let's Load a Checkpoint and Generate some Predictions

In [None]:
!cd {DATA_DOWNLOAD_DIR} \
&& curl "https://fb-ctrl-oss.s3.amazonaws.com/emg2pose/emg2pose_model_checkpoints.tar.gz" -o emg2pose_model_checkpoints.tar.gz && \
tar -xvzf emg2pose_model_checkpoints.tar.gz

In [1]:
from emg2pose.utils import generate_hydra_config_from_overrides
config = generate_hydra_config_from_overrides(
    overrides=[
        "experiment=tracking_vemg2pose",
        f"checkpoint=/home/xiziheng/develop/emg2pose/emg2pose_model_checkpoints/regression_vemg2pose.ckpt"
    ]
)

In [2]:
from emg2pose.lightning import Emg2PoseModule

module = Emg2PoseModule.load_from_checkpoint(
    config.checkpoint,
    network=config.network,
    optimizer=config.optimizer,
    lr_scheduler=config.lr_scheduler,
)

Lightning automatically upgraded your loaded checkpoint from v1.8.6 to v2.5.1.post0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../emg2pose_model_checkpoints/regression_vemg2pose.ckpt`


In [10]:
session = data
start_idx = 0
stop_idx = 10_000

In [None]:
import torch

session_window = session[start_idx:stop_idx]

# no_ik_failure is not a field so we slice separately
no_ik_failure_window = session.no_ik_failure[start_idx:stop_idx]

batch = {
    "emg": torch.Tensor([session_window["emg"].T]).cuda(),  # BCT
    "joint_angles": torch.Tensor([session_window["joint_angles"].T]).cuda(),  # BCT
    "no_ik_failure": torch.Tensor([no_ik_failure_window]).cuda(),  # BT
}

preds, joint_angles, no_ik_failure = module.forward(batch)

# Algorithms that use the initial state for ground truth will do poorly
# when the first joint angles are missing!
if (joint_angles[:, 0] == 0).all():
    print(
        "Warning! Ground truth not available at first time step!"
    )

# BCT --> TC (as numpy)
preds = preds[0].cpu().T.detach().numpy()
joint_angles = joint_angles[0].cpu().T.detach().numpy()

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

In [None]:
preds.shape

In [None]:
joint_angles.shape

In [None]:
joint_angles_30hz = downsample(joint_angles, native_fs = 2000, target_fs = 30)
visualization.get_plotly_animation_for_joint_angles(joint_angles_30hz[0:250], color="gray")

In [None]:
preds_30hz = downsample(preds, native_fs=2000, target_fs=30)
visualization.get_plotly_animation_for_joint_angles(preds_30hz[0:250], color="lightpink")

### Compare the Ground Truth and Predictions Side-by-Side

In [None]:
gt_frames = visualization.joint_angles_to_frames_parallel(joint_angles_30hz[0:250], color="gray")
pred_frames = visualization.joint_angles_to_frames_parallel(preds_30hz[0:250], color="lightpink")

gt_frames = visualization.remove_alpha_channel(gt_frames)
pred_frames = visualization.remove_alpha_channel(pred_frames)

In [None]:
mediapy.show_videos(dict(gt=gt_frames, pred=pred_frames), width=400, fps=30, downsample=True)

In [None]:
N_COLS = 2
N_ROWS = 10

fig, axs = plt.subplots(N_ROWS, N_COLS, figsize=(4*N_COLS, 2*N_ROWS))

axs_flattened = axs.flatten()
for i, ax in enumerate(axs_flattened):
    ax.set_title(f"Joint Angle {i}")
    ax.plot(joint_angles_30hz[:, i], label="gt")
    ax.plot(preds_30hz[:, i], label="pred")

    ax.legend()

fig.suptitle("Predicted vs. Ground Truth Joint Angles")

plt.tight_layout()
fig.subplots_adjust(top=0.95)

plt.show()