# Data Visualization
This notebook contains visualizations for data samples of Muscles in Time (MinT) as well as Muscles in Action (MIA).

### MIA Samples
The first half of the notebook visualizes MIA data samples together with predictions of a network which was pretrained on MinT and then partially fine-tuned on the first and last transformer block on MIA.

### MinT Samples
The second half of the notebook visualizes data samples together with motions from AMASS, as well as prediction of the same network trained and tested on MinT itself.

In [6]:
%load_ext autoreload
%autoreload 2
%matplotlib notebook
%matplotlib inline

import os
from os.path import join as opj
import sys
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch.nn.functional as F
from tqdm import tqdm

# Simple fix for smplx numpy compatibility in recent numpy versions.
np.bool = np.bool_
np.int = np.int_
np.float = np.float_
np.complex = np.complex_
np.object = np.object_
np.unicode = np.unicode_
np.str = np.str_

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.gridspec import GridSpec

import seaborn as sns

from IPython.display import HTML, display, Video

from scipy.spatial.transform import Rotation as R
from scipy.spatial.transform import Slerp
from scipy.interpolate import interp1d

from smplx.body_models import SMPL, SMPLH

os.environ['PYOPENGL_PLATFORM'] = 'egl'

from musint.benchmarks.muscle_sets import MIA_MUSCLES
from musint.benchmarks.muscle_sets import MUSCLE_SUBSETS

import mia_utils as miau
import amass_utils as amau
from plotting_utils import visualize_pose, plot_emg_data

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
# Set global plot parameters for consistent styling
plt.rcParams.update(
    {
        "text.usetex": False,
        "font.family": "serif",
        "font.serif": ["DejaVu Serif"],
        "font.sans-serif": ["Helvetica"],
        "figure.figsize": (6, 4.8),  # Figure size
        "font.size": 12,  # Global font size
        "axes.titlesize": 12,  # Title font size
        "axes.labelsize": 12,  # Axes labels font size
        "xtick.labelsize": 12,  # X-tick labels font size
        "ytick.labelsize": 12,  # Y-tick labels font size
        "legend.fontsize": 12,  # Legend font size
        "figure.titlesize": 12,  # Figure title font size
    }
)

In [8]:
data_dir_mia = "notebook_resources/mia_root"

In [9]:
mia_metadata = miau.index_mia_dataset(data_dir_mia)
mia_metadata["id"] = mia_metadata["path"].apply(lambda p: "/".join(p.split(os.path.sep)[-4:]))
mia_metadata.head()

24663it [01:04, 380.35it/s]


notebook_resources/mia_root/train/Subject3/HighKick/848/emgvalues.npy
24340


Unnamed: 0,path,split,subject,activity,repetition,id
0,notebook_resources/mia_root/train/Subject3/Hig...,train,Subject3,HighKick,848,train/Subject3/HighKick/848
1,notebook_resources/mia_root/train/Subject3/Hig...,train,Subject3,HighKick,886,train/Subject3/HighKick/886
2,notebook_resources/mia_root/train/Subject3/Hig...,train,Subject3,HighKick,836,train/Subject3/HighKick/836
3,notebook_resources/mia_root/train/Subject3/Hig...,train,Subject3,HighKick,878,train/Subject3/HighKick/878
4,notebook_resources/mia_root/train/Subject3/Hig...,train,Subject3,HighKick,874,train/Subject3/HighKick/874


In [10]:
mia_emg_data = miau.load_emg_data(map(lambda p: opj(p, "emgvalues.npy"), mia_metadata["path"]))
print(f"MIA data shape: {mia_emg_data.shape}")

100%|██████████| 24340/24340 [00:20<00:00, 1214.96it/s]


MIA data shape: torch.Size([24340, 30, 8])


In [11]:
# Compute the 99th percentile of the data along the first dimension, scaling without outliers.
pct99_norm = np.percentile(mia_emg_data.reshape((-1, 8)), 99, axis=0)

# Calulcate and save normalization statistics for the dataset.
save_stats = False

# Concatenate all the arrays from the data list into a single array for processing.
Mean = mia_emg_data.mean(
    axis=0
)  # Compute the mean of the concatenated data along the first dimension (across all samples).
Std = mia_emg_data.std(axis=0)  # Compute the standard deviation of the concatenated data along the first dimension.

if save_stats:
    save_dir_mia = "notebook_resources/mia_root"

    np.save(os.path.join(save_dir_mia, "Mean.npy"), Mean)
    np.save(os.path.join(save_dir_mia, "Std.npy"), Std)

    np.save(os.path.join(save_dir_mia, "99pct.npy"), np.percentile(mia_emg_data, 99, axis=0))

In [12]:
mia_emg_dict = {mia_metadata.iloc[i]["id"]: mia_emg_data[i] for i in range(len(mia_metadata))}
next(iter(mia_emg_dict.items()))

('train/Subject3/HighKick/848',
 tensor([[43., 78., 25., 11., 26., 19.,  5., 11.],
         [40., 73., 22., 12., 25., 15.,  4., 11.],
         [49., 53., 20., 11., 25., 13.,  4., 11.],
         [49., 38., 18., 12., 18., 13.,  4., 11.],
         [50., 21., 17., 12., 14., 11.,  3., 12.],
         [47., 21., 16., 10., 14., 12.,  4., 11.],
         [42., 15., 17., 10.,  8., 10.,  4., 10.],
         [37., 12., 18., 10.,  7.,  8.,  4., 11.],
         [31., 10., 15.,  9.,  7.,  7.,  4., 11.],
         [26., 10., 14.,  9.,  5.,  7.,  4., 10.],
         [21.,  7., 12., 10.,  4.,  5.,  5., 10.],
         [16.,  8., 11., 10.,  4.,  5.,  8., 12.],
         [12., 25., 11., 10.,  4., 13.,  6., 12.],
         [10., 49.,  9.,  9., 10., 18.,  5., 11.],
         [12., 49., 11., 10., 10., 18.,  5., 11.],
         [20., 40., 11., 12., 17., 38.,  4., 13.],
         [38., 25., 12., 11., 19., 38.,  4., 12.],
         [49., 25., 13., 11., 19., 34.,  4., 13.],
         [35., 19., 11., 12., 22., 28.,  3., 11.],

In [13]:
mia_metadata["activity"].unique()

array(['HighKick', 'HookPunch', 'LegCross', 'Squat', 'SlowSkater',
       'ElbowPunch', 'KneeKick', 'Shuffle', 'JumpingJack', 'SideLunges',
       'Running', 'FrontPunch', 'LegBack', 'FrontKick', 'RonddeJambe'],
      dtype=object)

**Please download the SMPL-H model to the respective location.**
- https://mano.is.tue.mpg.de/download.php
- https://github.com/vchoutas/smplx/blob/main/tools/README.md

In [14]:
smpl_h_path = "./body_models/smpl/SMPLH_NEUTRAL_AMASS_MERGED.pkl"
# Each VIBE sample in MIA has 30 frames at 10 fps. We convert to 20 fps by interpolating intermediate frames
# Since we only interpolate between two frames we end up with 59 frames as result.
bm = SMPLH(model_path=smpl_h_path, num_betas=10, use_pca=False, batch_size=59).cuda()



In [15]:
preds_root = "notebook_resources/mia_predictions"

mia_preds = pd.read_csv(opj(preds_root, "metadata.csv"))
mia_preds.set_index("name", inplace=True)
mia_preds["gt_name"] = mia_preds["gt_name"].apply(lambda x: x.replace("__I__", "/"))
mia_preds["pred_name"] = mia_preds["pred_name"].apply(lambda x: x.replace("__I__", "/"))
mia_preds.head()

Unnamed: 0_level_0,gt_name,pred_name,time_start
name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
val/Subject5/JumpingJack/81,val/Subject5/JumpingJack/81_0_gt.npy,val/Subject5/JumpingJack/81_0_pred.npy,0.0
val/Subject5/JumpingJack/81,val/Subject5/JumpingJack/81_1_gt.npy,val/Subject5/JumpingJack/81_1_pred.npy,1.4
val/Subject5/JumpingJack/98,val/Subject5/JumpingJack/98_2_gt.npy,val/Subject5/JumpingJack/98_2_pred.npy,0.0
val/Subject5/JumpingJack/98,val/Subject5/JumpingJack/98_3_gt.npy,val/Subject5/JumpingJack/98_3_pred.npy,1.4
val/Subject5/JumpingJack/92,val/Subject5/JumpingJack/92_4_gt.npy,val/Subject5/JumpingJack/92_4_pred.npy,0.0


In [16]:
random.seed(42)

for i in range(3):
    # Setting up the following visualization. Selecting a random sample with that activity.
    # activity = "LegBack"

    mia_sub_metadata = mia_metadata[(mia_metadata["split"] == "val")]
    assert len(mia_sub_metadata), f"No samples found.'"

    for j in range(20):
        mia_idx = random.randint(0, len(mia_sub_metadata))
        mia_sample = mia_sub_metadata.iloc[mia_idx]
        if mia_sample["id"] not in mia_preds.index:
            if j < 9:
                print(f"Did not find emg data for {mia_sample['id']}")
                continue
        else:
            assert mia_sample["id"] in mia_preds.index, "Could repeatedly not find emg data for sample"
            break

    print(mia_sample["id"])

    body = miau.mia_to_smpl_body(mia_sample["path"], bm)

    vertices = body.vertices.detach().cpu().numpy()

    gts, preds, start, end = miau.load_and_concat_mia_dat(mia_sample["id"], mia_preds, preds_root, pad=60)

    # gts = mia_emg_data[mia_sample["id"]]
    # gts /= pct99_norm
    # Repeat each frame twice
    # gts = np.repeat(gts, 2, axis=0)

    MINT_MIA_MUS = [1, 5, 0, 4]
    num_frames = min(len(vertices), len(gts), len(preds), 60)

    gts = gts[start:num_frames, MINT_MIA_MUS]

    preds = preds[start:num_frames, MINT_MIA_MUS]

    vertices = vertices[start:num_frames]

    names = ["Quad. Fem. (L)", "Hamstring (L)", "Quad. Fem. (R)", "Hamstring (R)"]

    fig = plt.figure(figsize=(8, 12))

    gs = GridSpec(2, 1, height_ratios=[5, 2])

    ax = fig.add_subplot(gs[0], projection="3d")

    lines = plot_emg_data(gts, preds, names, fig=fig, gridspec=gs)

    def update(frame):
        ax.clear()
        visualize_pose(vertices[frame], frame, ax, title=mia_sample["id"])

        for line in lines:
            line.set_xdata(frame)

        return [ax, lines]

    plt.tight_layout()

    ani = FuncAnimation(fig, update, frames=len(vertices), interval=50)  # 20 fps = 50ms interval

    path_id = mia_sample["id"].replace("/", "__I__")

    ani.save(f"notebook_output/mia/{path_id}.mp4", writer="ffmpeg", fps=20)
    display(Video(f"notebook_output/mia/{path_id}.mp4", embed=True))

    plt.close(fig)  # Prevents the static plot from showing

Did not find emg data for val/Subject0/LegCross/1180
val/Subject3/Running/3


  line.set_xdata(frame)


Did not find emg data for val/Subject0/RonddeJambe/994
Did not find emg data for val/Subject0/LegBack/412
Did not find emg data for val/Subject0/Running/1252
Did not find emg data for val/Subject0/SlowSkater/1442
Did not find emg data for val/Subject0/LegCross/1114
val/Subject7/HighKick/107


  line.set_xdata(frame)


Did not find emg data for val/Subject0/HighKick/1645
val/Subject8/FrontPunch/490


  line.set_xdata(frame)
