In [1]:
%load_ext autoreload
%autoreload 2

import os
import matplotlib.pyplot as plt

# Change working directory to project root.
# NOTE: this chdir is relative, so the cell is not idempotent — re-running it in the
# same kernel moves the working directory up two more levels.
os.chdir('../../')

import torch
import numpy as np
from pathlib import Path
from tqdm import tqdm
import json
import imageio
from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines
from hmr4d.dataset.rich.rich_utils import get_w2az_sahmr, parse_seqname_info
from hmr4d.utils.geo_transform import apply_T_on_points
from hmr4d.utils.eval.wham.eval_utils import convert_joints22_to_24

In [2]:
# Load the dataset
# NOTE(review): torch.load and np.load(allow_pickle=True) both unpickle arbitrary Python
# objects; only point them at files from a trusted source.
# Indexed later as wham_rich[vid]["cam_motion3d"].
wham_rich = torch.load("inputs/RICH/eval_support/wham_rich.pt")
# Indexed later with the keys "vid" and "joints3D".
labels = np.load("inputs/RICH/eval_support/rich_test_vit.pth", allow_pickle=True)
# Position of every video id inside labels["vid"], used to look up that video's labels.
vid2labelindex = {vid: i for i, vid in enumerate(labels["vid"])}
# Indexed later as wham_rich_world[vid][s:e].
wham_rich_world = torch.load("inputs/RICH/eval_support/sm_rich_mocap_input.pt")["vid_to_pred_j3d_ayfz1"]

In [3]:
from hmr4d.utils.geo_transform import compute_T_ayfz2ay, compute_T_ayf2az
from hmr4d.dataset.rich.rich_utils import RichVid2Tc2az

# Callable invoked below with a video id; presumably returns that video's camera -> "az"
# transform (per its name) — confirm against RichVid2Tc2az.
compute_Tc2az_from_vid = RichVid2Tc2az()

# Iterated below as a sequence of per-clip dicts; the loop reads the keys
# "meta", "pred_ayfz_motion" and "pred_T_ayfz2c".
all_dump = torch.load("dump.pt")
# vid_to_indices = defaultdict(list)
# for i, dump in enumerate(all_dump):
#     vid = dump["meta"][0]
#     vid_to_indices[vid].append(i)
# print("Number of clips: ", len(all_dump))
# print("Number of videos: ", len(vid_to_indices))
# print(all_dump[0].keys())

# Three dictionaries keyed by "<vid>_<start>_<end>"; each value holds the clip's joints
# under "c" and "w" (presumably camera frame and world/AYFZ frame — confirm).
gt_vidse2motion = {}
wham_vidse2motion = {}
ours_vidse2motion = {}
for dump in all_dump:
    # meta unpacks as (video id, (start, end)).
    vid, (s, e) = dump["meta"]
    vidse = f"{vid}_{int(s)}_{int(e)}"

    # gt
    j = vid2labelindex[vid]
    # NOTE(review): the leading [1:] drops the first frame before the (s, e) slice is taken;
    # the reason is not visible in this notebook — confirm the frame alignment.
    gt_c = labels["joints3D"][j][1:][s:e]  # (L, 24, 3)
    T_c2az = compute_Tc2az_from_vid(vid)
    gt_az = apply_T_on_points(gt_c, T_c2az)
    # The transform is derived from the first frame only (gt_az[:1]) and applied to the whole clip.
    T_az2ayfz = compute_T_ayf2az(gt_az[:1], inverse=True)[0]  # (4, 4)
    gt_ayfz = apply_T_on_points(gt_az, T_az2ayfz)  # (L, 24, 3)
    gt_vidse2motion[vidse] = {
        "c": gt_c,
        "w": gt_ayfz,
    }

    # wham
    wham_ay = wham_rich_world[vid][s:e]
    # As for the ground truth, the clip is re-anchored using its own first frame.
    T_ay2ayfz = compute_T_ayfz2ay(wham_ay[:1], inverse=True)[0]  # (4, 4)
    wham_w = apply_T_on_points(wham_ay, T_ay2ayfz)  # (F, 22, 3)
    wham_vidse2motion[vidse] = {
        "c": wham_rich[vid]["cam_motion3d"][s:e],
        "w": wham_w,
    }

    # ours
    # The dumped prediction is extended from 22 to 24 joints, then mapped with the dumped
    # pred_T_ayfz2c transform to obtain the "c" version.
    pred_ayfz = convert_joints22_to_24(dump["pred_ayfz_motion"])
    pred_T_ayfz2c = dump["pred_T_ayfz2c"]
    pred_c = apply_T_on_points(pred_ayfz, pred_T_ayfz2c)
    ours_vidse2motion[vidse] = {
        "c": pred_c,
        "w": pred_ayfz,
    }

# Number of distinct clip keys collected.
print(len(ours_vidse2motion))

# wis3d = make_wis3d() 
# add_motion_as_lines(gt_vidse2motion[vidse]['w'], wis3d, name='gt_w')
# add_motion_as_lines(wham_vidse2motion[vidse]['w'], wis3d, name='wham_w')
# add_motion_as_lines(ours_vidse2motion[vidse]['w'], wis3d, name='ours_w')

996


In [36]:
# NOTE(review): `F` conventionally aliases torch.nn.functional, but this imports
# torch.functional; `F` is not used in this cell — confirm the intent.
import torch.functional as F 
from pytorch3d.transforms import axis_angle_to_matrix
#  ======= TRACE ======== #
# NOTE(review): hard-coded absolute path; the notebook only runs where this mount exists.
trace_dumps = torch.load("/nas/share/hmr4d/comp_exp/TRACE@RICH.pt")
trace_vidse2motion = {}
# Rotation built from the axis-angle vector (pi, 0, 0), with a leading batch dimension of 1.
yup2ydown = axis_angle_to_matrix(torch.tensor([[np.pi, 0, 0]])).float()

# trace_dumps is indexed as trace_dumps[vid][j] with the keys 'start', 'end' and 'motion3d'.
for vid in trace_dumps:
    for j in range(len(trace_dumps[vid])):
        s = trace_dumps[vid][j]['start'] 
        e = trace_dumps[vid][j]['end']

        # Same "<vid>_<start>_<end>" key format as the other *_vidse2motion dictionaries.
        vidse = f"{vid}_{int(s)}_{int(e)}"
        motion3d = torch.from_numpy(trace_dumps[vid][j]['motion3d']).float()  # (F, 24, 3)
        # Contracts the last axis of both operands, i.e. motion3d @ yup2ydown^T: the rotation
        # is applied to every joint of every frame.
        motion3d = torch.einsum("...jk,...lk->...jl", motion3d, yup2ydown)
        # NOTE(review): stored under "w" (the key the global-metric cell reads) although the
        # original comment says camera coordinates — confirm which frame this is.
        trace_vidse2motion[vidse] = {"w": motion3d}  # Cam coordinate
print("Number of clips: ", len(trace_vidse2motion))

Number of clips:  236


### Local Metric

In [38]:
from hmr4d.utils.eval.wham.eval_utils import batch_compute_similarity_transform_torch, compute_error_accel
from collections import defaultdict

# Scale factor applied to the position errors (metres -> millimetres).
m2mm = 1000
accumulator = defaultdict(list)

keys_to_evaluate = list(gt_vidse2motion.keys())
for vidse in keys_to_evaluate:
    gt_c = gt_vidse2motion[vidse]["c"]
    pred_c = wham_vidse2motion[vidse]["c"]
    # pred_c = ours_vidse2motion[vidse]["c"]

    # Root-align both skeletons on the per-frame midpoint of joints 1 and 2.
    gt_root = gt_c[:, [1, 2]].mean(-2, keepdim=True)
    pred_root = pred_c[:, [1, 2]].mean(-2, keepdim=True)
    gt_cr = gt_c - gt_root
    pred_cr = pred_c - pred_root

    # MPJPE: per-frame mean Euclidean joint distance after root alignment.
    sq_dist = ((pred_cr - gt_cr) ** 2).sum(dim=-1)
    mpjpe = torch.sqrt(sq_dist).mean(dim=-1).numpy() * m2mm

    # PA-MPJPE: same distance after a similarity alignment of the prediction to the GT.
    S1_hat = batch_compute_similarity_transform_torch(pred_cr, gt_cr)
    pa_sq_dist = ((S1_hat - gt_cr) ** 2).sum(dim=-1)
    pa_mpjpe = torch.sqrt(pa_sq_dist).mean(dim=-1).numpy() * m2mm

    # Acceleration error, with the first and last entries discarded.
    accel = compute_error_accel(joints_pred=pred_cr, joints_gt=gt_cr)[1:-1]
    accel = accel * (30**2)  # per frame^2 to per s^2 (30 frames per second)

    # Insertion order fixes the order of the metrics in the printed summary.
    for name, values in (("pa_mpjpe", pa_mpjpe), ("mpjpe", mpjpe), ("accel", accel)):
        accumulator[name].append(values)

# Collapse every list of per-clip arrays into one mean over all frames.
for k in list(accumulator):
    accumulator[k] = np.concatenate(accumulator[k]).mean()

log_str = "Local Evaluation on RICH, " + " ".join(f"{k.upper()}: {v:.4f}," for k, v in accumulator.items())
print(log_str)

Local Evaluation on RICH, PA_MPJPE: 43.4920, MPJPE: 81.6215, ACCEL: 5.4840,


### Global Metric

In [37]:
from hmr4d.utils.eval.wham.eval_utils import batch_compute_similarity_transform_torch, compute_error_accel
from hmr4d.utils.eval.wham.eval_utils import first_align_joints, global_align_joints, compute_jpe
from collections import defaultdict

# Scale factor applied to the position errors (metres -> millimetres).
m2mm = 1000
accumulator = defaultdict(list)

# Method under evaluation; swap the assignment to score another method.
pred_vidse2motion = trace_vidse2motion
# pred_vidse2motion = wham_vidse2motion
# pred_vidse2motion = ours_vidse2motion

# Build the clip list in this cell instead of reusing `keys_to_evaluate` left behind by
# another cell: that hidden state breaks Restart & Run All, and a method that covers only
# part of the GT clips would raise KeyError on the first clip it lacks.
keys_to_evaluate = [vidse for vidse in gt_vidse2motion if vidse in pred_vidse2motion]
assert keys_to_evaluate, "No clip is shared between the ground truth and the selected predictions"

for vidse in keys_to_evaluate:
    gt_w = gt_vidse2motion[vidse]["w"]
    pred_w = pred_vidse2motion[vidse]["w"]

    # MPJPE: root-align both skeletons on the per-frame midpoint of joints 1 and 2.
    gt_wr = gt_w - gt_w[:, [1, 2]].mean(-2, keepdim=True)
    pred_wr = pred_w - pred_w[:, [1, 2]].mean(-2, keepdim=True)
    mpjpe = compute_jpe(gt_wr, pred_wr) * m2mm

    # PA-MPJPE: error after a similarity alignment of the prediction to the GT.
    S1_hat = batch_compute_similarity_transform_torch(pred_wr, gt_wr)
    pa_mpjpe = compute_jpe(S1_hat, gt_wr) * m2mm

    # WA2-MPJPE and WAA-MPJPE: error of the un-centred trajectories after alignment.
    pred_w_a2 = first_align_joints(gt_w, pred_w)  # align first 2
    pred_w_aa = global_align_joints(gt_w, pred_w)  # align all
    wa2_jpe = compute_jpe(gt_w, pred_w_a2) * m2mm
    waa_jpe = compute_jpe(gt_w, pred_w_aa) * m2mm

    # Accumulate (insertion order fixes the order of the printed summary)
    accumulator["mpjpe"].append(mpjpe)
    accumulator["pa-mpjpe"].append(pa_mpjpe)
    accumulator["wa2_mpjpe"].append(wa2_jpe)
    accumulator["waa_mpjpe"].append(waa_jpe)

# Collapse every list of per-clip arrays into one mean over all frames.
for k, v in accumulator.items():
    accumulator[k] = np.concatenate(v).mean()

log_str = "Global Evaluation on RICH, "
log_str += " ".join([f"{k.upper()}: {v:.4f}," for k, v in accumulator.items()])
print(log_str)

Global Evaluation on RICH, MPJPE: 343.5977, PA-MPJPE: 77.2557, WA2_MPJPE: 1489.9246, WAA_MPJPE: 254.7058,


### Add WHAM output

In [None]:
from hmr4d.utils.geo_transform import compute_T_ayfz2ay
from hmr4d.utils.eval.wham.eval_utils import first_align_joints, global_align_joints, compute_jpe

# NOTE(review): `vids` and `vid2gt` are used below but are not defined in any cell shown in
# this notebook; they must already exist in the kernel for this cell to run.
prepared_conds = torch.load("inputs/RICH/eval_support/sm_rich_mocap_input.pt")

# When True, per-clip errors are collected in the four lists below and summarised at the end.
COMPUTE_METRIC = True
gmpjpe = []
mpjpe = []
wmpjpe = []
wampjpe = []

for vid in vids:
    vid_ = "-".join(vid.split("/"))
    # wis3d = make_wis3d(output_dir=wis3d_dir, name=f"{vid_}")

    seq_length = len(vid2gt[vid])
    chunk_length = 100
    wham_motion_clips = []
    # Clips start every chunk_length frames. A trailing remainder shorter than chunk_length
    # is merged into the last clip, and a sequence shorter than chunk_length yields no clip.
    for start in range(0, seq_length, chunk_length):
        end = start + chunk_length
        if seq_length - end < chunk_length:
            end = seq_length
        if end - start < chunk_length:
            break

        # Load motion
        init_motion_ay = prepared_conds["vid_to_pred_j3d_ayfz1"][vid][start:end, :22]  # not AYFZ when start!=0
        # Re-anchor the clip using its own first frame.
        T_ay2ayfz = compute_T_ayfz2ay(init_motion_ay[:1], inverse=True)[0]  # (4, 4)
        init_motion_ayfz = apply_T_on_points(init_motion_ay, T_ay2ayfz)  # (F, 22, 3)
        wham_motion_clips.append(init_motion_ayfz)

        # Let's compute metric here
        if COMPUTE_METRIC:
            gt_ayfz_motion = vid2gt[vid][start:end]  # (F, 22, 3)
            # gmpjpe: joint distance with no alignment at all
            error = (init_motion_ayfz - gt_ayfz_motion).pow(2).sum(-1).sqrt()  # (F, 22)
            gmpjpe.append(error.mean())
            # mpjpe: joint distance after subtracting joint 0 from every frame
            gt_ = gt_ayfz_motion - gt_ayfz_motion[:, :1]
            pred_ = init_motion_ayfz - init_motion_ayfz[:, :1]
            error = (pred_ - gt_).pow(2).sum(-1).sqrt()  # (F, 22)
            mpjpe.append(error.mean())

            # wmpjpe and wampjpe
            w_j3d = first_align_joints(gt_ayfz_motion, init_motion_ayfz)
            wa_j3d = global_align_joints(gt_ayfz_motion, init_motion_ayfz)
            w_jpe = compute_jpe(gt_ayfz_motion, w_j3d)
            wa_jpe = compute_jpe(gt_ayfz_motion, wa_j3d)
            wmpjpe.append(w_jpe.mean())
            wampjpe.append(wa_jpe.mean())

    # wham_motion = torch.cat(wham_motion_clips, dim=0)  # (F, 22, 3)
    # add_motion_as_lines(wham_motion, wis3d, name="pred_wham")

# Every clip contributes one value, so these are unweighted means over clips, scaled by 1000.
if COMPUTE_METRIC:
    print(f"gmpjpe: {np.mean(gmpjpe) * 1000}")
    print(f"mpjpe: {np.mean(mpjpe) * 1000}")
    print(f"wmpjpe: {np.mean(wmpjpe) * 1000}")
    print(f"wampjpe: {np.mean(wampjpe) * 1000}")

## Compare

In [None]:
# NOTE(review): `pt_dir` is not defined in any cell shown in this notebook; it must already
# exist in the kernel (the `/` operator suggests a pathlib.Path) — confirm.
vid2pred = torch.load(pt_dir / "vid2pred_ppp.pt")

In [None]:
from hmr4d.utils.geo_transform import compute_T_ayfz2ay
from hmr4d.utils.eval.wham.eval_utils import first_align_joints, global_align_joints, compute_jpe

# NOTE(review): `vids` and `vid2gt` are not defined in any cell shown in this notebook; they
# must already exist in the kernel.
prepared_conds = torch.load("inputs/RICH/eval_support/sm_rich_mocap_input.pt")

# Clip length in frames; loop-invariant, so it is set once here.
chunk_length = 100

# Full-length motion per video, rebuilt from clips that are each re-anchored on their own
# first frame.
vid2wham = {}
for vid in vids:
    seq_length = len(vid2gt[vid])
    pred_j3d = prepared_conds["vid_to_pred_j3d_ayfz1"][vid]
    wham_motion_clips = []
    # Clips start every chunk_length frames. A trailing remainder shorter than chunk_length
    # is merged into the last clip, and a sequence shorter than chunk_length yields no clip.
    for start in range(0, seq_length, chunk_length):
        end = start + chunk_length
        if seq_length - end < chunk_length:
            end = seq_length
        if end - start < chunk_length:
            break

        # Load motion
        init_motion_ay = pred_j3d[start:end, :22]  # not AYFZ when start!=0
        T_ay2ayfz = compute_T_ayfz2ay(init_motion_ay[:1], inverse=True)[0]  # (4, 4)
        init_motion_ayfz = apply_T_on_points(init_motion_ay, T_ay2ayfz)  # (F, 22, 3)
        wham_motion_clips.append(init_motion_ayfz)

    if wham_motion_clips:
        wham_motion = torch.cat(wham_motion_clips, dim=0)  # (F, 22, 3)
    else:
        # torch.cat rejects an empty list; a video too short to form a clip gets a
        # zero-frame motion so the key is still present for later lookups.
        wham_motion = pred_j3d[:0, :22]
    vid2wham[vid] = wham_motion

In [None]:
def get_metrics(vid2gt, vid2pred, chunk_length=100):
    """Compute per-clip joint-position errors for every video in ``vid2gt``.

    Each video is cut into clips that start every ``chunk_length`` frames. A trailing
    remainder shorter than ``chunk_length`` is merged into the last clip, and a video
    shorter than ``chunk_length`` contributes no clip.

    Args:
        vid2gt: mapping from video id to ground-truth joints, indexed as [frame, joint, xyz].
        vid2pred: mapping from video id to predicted joints with the same layout; it is only
            read for the video ids found in ``vid2gt``.
        chunk_length: clip length in frames (default 100, the previously hard-coded value).

    Returns:
        ``(gmpjpe, mpjpe, wmpjpe, wampjpe, clip2vid)`` — five 1-D numpy arrays with one entry
        per clip. The four error arrays are multiplied by 1000 (millimetres if the inputs are
        in metres); ``clip2vid`` holds the video id each clip came from.

    Raises:
        ValueError: if ``chunk_length`` is not positive.
    """
    if chunk_length <= 0:
        raise ValueError(f"chunk_length must be positive, got {chunk_length}")

    gmpjpe = []
    mpjpe = []
    wmpjpe = []
    wampjpe = []
    clip2vid = []

    vids = list(vid2gt.keys())
    for vid in vids:
        seq_length = len(vid2gt[vid])
        for start in range(0, seq_length, chunk_length):
            end = start + chunk_length
            if seq_length - end < chunk_length:
                end = seq_length
            if end - start < chunk_length:
                break
            clip2vid.append(vid)
            gt = vid2gt[vid][start:end]  # (F, 22, 3)
            pred = vid2pred[vid][start:end]

            # Means are stored as Python floats so that np.array() below does not depend on
            # how numpy converts a list of 0-d tensors (it fails if they require grad).

            # gmpjpe: joint distance with no alignment at all
            error = (pred - gt).pow(2).sum(-1).sqrt()  # (F, 22)
            gmpjpe.append(float(error.mean()))
            # mpjpe: joint distance after subtracting joint 0 from every frame
            gt_ = gt - gt[:, :1]
            pred_ = pred - pred[:, :1]
            error = (pred_ - gt_).pow(2).sum(-1).sqrt()  # (F, 22)
            mpjpe.append(float(error.mean()))

            # wmpjpe and wampjpe: error after aligning the prediction to the ground truth
            w_j3d = first_align_joints(gt, pred)
            wa_j3d = global_align_joints(gt, pred)
            w_jpe = compute_jpe(gt, w_j3d)
            wa_jpe = compute_jpe(gt, wa_j3d)
            wmpjpe.append(float(w_jpe.mean()))
            wampjpe.append(float(wa_jpe.mean()))

    gmpjpe = np.array(gmpjpe) * 1000
    mpjpe = np.array(mpjpe) * 1000
    wmpjpe = np.array(wmpjpe) * 1000
    wampjpe = np.array(wampjpe) * 1000
    clip2vid = np.array(clip2vid)
    return gmpjpe, mpjpe, wmpjpe, wampjpe, clip2vid

In [None]:
# Per-clip metrics for our predictions, then the per-video mean of the unaligned error
# (gmpjpe), ordered like `vids`.
gmpjpe, mpjpe, wmpjpe, wampjpe, clip2vid = get_metrics(vid2gt, vid2pred)
gmpjpe_pred = np.array([gmpjpe[clip2vid == vid].mean() for vid in vids])

# Same for WHAM. NOTE: this second call overwrites gmpjpe / mpjpe / wmpjpe / wampjpe /
# clip2vid, so after this cell those names hold the WHAM results, not ours.
gmpjpe, mpjpe, wmpjpe, wampjpe, clip2vid = get_metrics(vid2gt, vid2wham)
gmpjpe_wham = np.array([gmpjpe[clip2vid == vid].mean() for vid in vids])


# print(f"gmpjpe: {np.mean(gmpjpe) * 1000}")
# print(f"mpjpe: {np.mean(mpjpe) * 1000}")
# print(f"wmpjpe: {np.mean(wmpjpe) * 1000}")
# print(f"wampjpe: {np.mean(wampjpe) * 1000}")

In [None]:
# Video indices sorted by our per-video gmpjpe, lowest error first.
pred_indices = np.argsort(gmpjpe_pred)
# WHAM error minus ours in that order, truncated to integers; positive means ours is lower.
print((gmpjpe_wham[pred_indices] - gmpjpe_pred[pred_indices]).astype(np.int32))

# wham_indices = np.argsort(gmpjpe_wham)

In [None]:
# Display the index ordering obtained from np.argsort(gmpjpe_pred).
pred_indices

In [None]:
# The ten videos with the smallest (ours - WHAM) gmpjpe difference, i.e. where our
# prediction gains the most over WHAM.
good_indices = np.argsort((gmpjpe_pred - gmpjpe_wham))[:10]

In [None]:
wis3d_dir = "outputs/wis3d_ws"  # visualization folder
# One wis3d scene per video for the first 30 entries of pred_indices, each showing the
# ground truth, WHAM and our prediction in a fixed colour.
for i in pred_indices[:30]:
    vid = vids[i]
    vid_ = vid.replace("/", "-")
    wis3d = make_wis3d(output_dir=wis3d_dir, name=vid_)
    for motions, label, color in (
        (vid2gt, "gt", "green"),
        (vid2wham, "wham", "red"),
        (vid2pred, "ours", "blue"),
    ):
        add_motion_as_lines(motions[vid], wis3d, name=label, const_color=color)

In [None]:


def _mean_per_vid(per_clip):
    """Average a per-clip metric array over the clips of each video, in `vids` order."""
    return [per_clip[clip2vid == vid].mean() for vid in vids]


# NOTE: each name is rebound from a per-clip array to a per-video list, so this cell cannot
# be run twice in a row without recomputing the per-clip arrays first.
gmpjpe = _mean_per_vid(gmpjpe)
mpjpe = _mean_per_vid(mpjpe)
wmpjpe = _mean_per_vid(wmpjpe)
wampjpe = _mean_per_vid(wampjpe)


In [None]:
# Display gmpjpe, which the preceding cell rebinds to a list of per-video means.
gmpjpe

## Archived Utilities
1. Convert Joints22 to Joints24

In [None]:
from smplx import SMPL
# Keyword arguments forwarded to the SMPL constructor.
bm_kwargs = {
    "model_path": "inputs/checkpoints/body_models/smpl",
    "gender": "neutral",
    "num_betas": 10,
}
with torch.no_grad():
    model = SMPL(**bm_kwargs)
    # Calling the model with no arguments; keep the first 24 joints of the first batch item.
    joints = model().joints[0, :24]


# Length of segment 21->23 relative to segment 19->21, and of 20->22 relative to 18->20.
# The printed values (0.3438, 0.3345) match the defaults of convert_joints22_to_24 below.
ratio2321 = (joints[23] - joints[21]).norm(2) / (joints[21] - joints[19]).norm(2)
ratio2220 = (joints[22] - joints[20]).norm(2) / (joints[20] - joints[18]).norm(2)
print(ratio2220, ratio2321)

def convert_joints22_to_24(joints22, ratio2220=0.3438, ratio2321=0.3345):
    """Extend a 22-joint skeleton to 24 joints by extrapolating two extra joints.

    Joint 22 is placed beyond joint 20 along the 18->20 segment, and joint 23 beyond joint 21
    along the 19->21 segment, each at the given fraction of that segment's length.

    The result keeps the dtype and device of ``joints22`` and accepts any number of leading
    batch dimensions.

    NOTE: the notebook also imports a function of this name from
    hmr4d.utils.eval.wham.eval_utils; running this cell rebinds the name to this version.

    Args:
        joints22: tensor of shape (..., J, 3) with J >= 22; only the first 22 joints are used.
        ratio2220: length of segment 20->22 relative to segment 18->20.
        ratio2321: length of segment 21->23 relative to segment 19->21.

    Returns:
        Tensor of shape (..., 24, 3): the first 22 joints followed by the two new joints.
    """
    joint22 = joints22[..., 20, :] + ratio2220 * (joints22[..., 20, :] - joints22[..., 18, :])
    joint23 = joints22[..., 21, :] + ratio2321 * (joints22[..., 21, :] - joints22[..., 19, :])
    # Concatenating (rather than filling a fresh torch.zeros buffer) inherits dtype and device.
    return torch.cat(
        [joints22[..., :22, :], joint22.unsqueeze(-2), joint23.unsqueeze(-2)],
        dim=-2,
    )
    

tensor(0.3438) tensor(0.3345)
