In [1]:
%load_ext autoreload
%autoreload 2

import os
import matplotlib.pyplot as plt

# Change working directory to project root
os.chdir('../../')

import torch
import cv2
import numpy as np
from pathlib import Path
from tqdm import tqdm
import json
import imageio
import decord
import joblib
from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines
from hmr4d.dataset.rich.rich_utils import get_w2az_sahmr, parse_seqname_info
from hmr4d.utils.geo_transform import apply_T_on_points
from hmr4d.utils.vis.vis_kpts import draw_kpts_cv2

decord.bridge.set_bridge("torch")

In [2]:
# Paths
# rich_test_vit_path = "/home/shenzehong/Code/WHAM/dataset/parsed_data/rich_test_vit.pth"
rich_test_vit_path = "inputs/RICH/eval_support/rich_test_vit.pth"
rich_video_dir = Path("inputs/RICH/hmr4d_support/video")
wham_rich_output = Path("inputs/RICH/eval_support/wham_output.pt")

# Load
labels = joblib.load(rich_test_vit_path)
print(labels.keys())
N_seq = len(labels["vid"])
print("Num sequences in RICH:", N_seq)
wham_outputs = torch.load(wham_rich_output)
print("WHAM output keys in one item:", wham_outputs["test/Gym_010_cooking1/cam_01"].keys())

dict_keys(['bbox', 'res', 'vid', 'pose', 'betas', 'transl', 'kp2d', 'joints3D', 'frame_id', 'cam_poses', 'features', 'gender', 'init_kp3d', 'init_pose'])
Num sequences in RICH: 191
WHAM output keys in one item: dict_keys(['poses_body', 'betas', 'poses_root_world', 'trans_world', 'weak_joints2d'])


### Eval global-metric @ RICH

In [3]:
from hmr4d.utils.smplx_utils import make_smplx

smplx_models = {
    "male": make_smplx("rich-smplx", gender="male").cuda().eval(),
    "female": make_smplx("rich-smplx", gender="female").cuda().eval(),
}
smpl_models = {"neutral": make_smplx("smpl", gender="neutral").cuda().eval()}

smplx2smpl = torch.load("hmr4d/utils/body_model/smplx2smpl_sparse.pt").to_dense()
smpl_J_regressor = torch.load("hmr4d/utils/body_model/smpl_neutral_J_regressor.pt")

In [4]:
from pytorch3d.transforms import matrix_to_axis_angle, axis_angle_to_matrix
pred_j3d_globs, target_j3d_globs = [], []
for index in tqdm(range(N_seq)):
    vid = labels["vid"][index]

    # # <======= Build groundtruth SMPL (from SMPLX)    =======> We can use Joints3D
    # pose = labels["pose"][index][1:]  # (F, 72)
    # betas = labels["betas"][index][1:]  # (F, 10)
    # gender = labels["gender"][index]  # str
    # transl = labels["transl"][index][1:]  # (F, 3), global
    # cam_poses = labels["cam_poses"][index][1:]  # (F, 3, 4)
    # # joints3d = labels["joints3D"][index][1:]  # (F, 24, 3)

    # # forward
    # params = {
    #     "body_pose": pose[:, 3:-6].reshape(-1, 63),
    #     "global_orient": matrix_to_axis_angle(cam_poses[:, :3, :3].transpose(2, 1) @ axis_angle_to_matrix(pose[:, :3])),
    #     "betas": betas,
    #     "transl": transl,
    # }
    # params = {k: v.cuda() for k, v in params.items()}
    # smplx_out = smplx_models[gender](**params)
    # smpl_verts = torch.matmul(smplx2smpl, smplx_out.vertices.cpu())
    # target_j3d_glob = smpl_J_regressor @ smpl_verts  # (F, 24, 3)
    # target_j3d_globs.append(target_j3d_glob)

    # <======= Prediction
    wham_pred = wham_outputs[vid]
    # print(wham_pred.keys())

    params = {
        "body_pose": wham_pred['poses_body'].reshape(-1, 23, 3, 3),
        "global_orient": wham_pred['poses_root_world'].reshape(-1, 1, 3, 3),
        "betas": wham_pred['betas'].reshape(-1, 10),
        "transl": wham_pred['trans_world'].reshape(-1, 3),
    }
    params = {k: v.cuda() for k, v in params.items()}
    pred_smpl_out = smpl_models['neutral'].forward(**params, pose2rot=False)
    pred_j3d_glob = smpl_J_regressor @ pred_smpl_out.vertices.cpu()  # (F, 24, 3)
    pred_j3d_globs.append(pred_j3d_glob)

  0%|          | 0/191 [00:00<?, ?it/s]

100%|██████████| 191/191 [00:06<00:00, 28.57it/s]


In [10]:
from hmr4d.utils.eval.wham.eval_utils import first_align_joints, global_align_joints, compute_jpe

m2mm = 1000
chunk_length = 100

pa_mpjpes, w_mpjpes, wa_mpjpes = [], [], []


# for pred_j3d_glob, target_j3d_glob in zip(pred_j3d_globs, target_j3d_globs):
for pred_j3d_glob, target_j3d_glob in zip(pred_j3d_globs, labels["joints3D"]):
    w_mpjpe, wa_mpjpe = [], []
    seq_length = len(pred_j3d_glob)
    target_j3d_glob = target_j3d_glob[1:]
    for start in range(0, seq_length - chunk_length, chunk_length):
        end = start + chunk_length
        if start + 2 * chunk_length > seq_length:  # last one
            end = seq_length

        target_j3d = target_j3d_glob[start:end].clone().cpu()
        pred_j3d = pred_j3d_glob[start:end].clone().cpu()

        w_j3d = first_align_joints(target_j3d, pred_j3d)
        wa_j3d = global_align_joints(target_j3d, pred_j3d)

        w_jpe = compute_jpe(target_j3d, w_j3d)
        wa_jpe = compute_jpe(target_j3d, wa_j3d)
        w_mpjpe.append(w_jpe)
        wa_mpjpe.append(wa_jpe)

    w_mpjpe = np.concatenate(w_mpjpe) * m2mm
    wa_mpjpe = np.concatenate(wa_mpjpe) * m2mm
    w_mpjpes.append(w_mpjpe.mean())
    wa_mpjpes.append(wa_mpjpe.mean())

print(np.mean(w_mpjpe))
print(np.mean(wa_mpjpe))

221.78987
116.9217


In [7]:
# Visualize
wis3d = make_wis3d(name="wham_global_metric")
add_motion_as_lines(labels['joints3D'][0], wis3d, "gt_c1")
add_motion_as_lines(pred_j3d_globs[0], wis3d, "pred_w")

### 可视化WHAM的输出结果

- joints2d is not perfectly aligned with image
- the first frame of joints2d should be removed (use the second frame to replace it)

#### Joints2D

In [3]:
wham_outputs = torch.load("wham_rich.pt")  # latest version

for index in range(191):
    vid = labels["vid"][index]
    if "ParkingLot2_017_burpeejump1" in vid:
        print(index)

61
77
93


In [7]:
index = 77
vid = labels["vid"][index]
print(vid)
wham_pred = wham_outputs[vid]
print(wham_pred.keys())

# From dataset
# cam_poses = labels["cam_poses"][index][1:]  # (F, 3, 4)
# joints3d = labels["joints3D"][index][1:]  # (F, 24, 3)

frame_id = labels["frame_id"][index][1:]  # (F, )
vr = decord.VideoReader(str(rich_video_dir / vid / "video.mp4"))
frames = vr.get_batch(list(frame_id.numpy()))  # (F, 752, 1024, 3)

test/ParkingLot2_017_burpeejump1/cam_01
dict_keys(['cam_motion3d', 'i_joints2d'])


In [12]:
joints2d = wham_pred["i_joints2d"].numpy()
joints2d = joints2d / 4
joints2d[:1] = joints2d[[1]]  # change 1


decord.bridge.set_bridge("torch")
video_path = rich_video_dir / labels["vid"][index] / "video.mp4"
vr = decord.VideoReader(str(video_path))
frames = vr.get_batch(list(frame_id.numpy()))  # (F, 752, 1024, 3)

imgs = []
for i in range(len(frames)):
    img = frames[i].numpy().copy()
    img = draw_kpts_cv2(img, joints2d[i])
    imgs.append(img)
imageio.mimsave("tmp_weakjoints2d.mp4", imgs, fps=30, quality=6)

#### Joints3D

In [13]:
joints3d = wham_pred["cam_motion3d"]  # (F, 24, 3)

from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines
wis3d = make_wis3d(name="wham_cam_motion3d")
add_motion_as_lines(joints3d, wis3d, name="pred_cam_motion3d" )

### Check groundtruth from WHAM-provided-pth (YES)

In [None]:
# for index in range(N_seq):
#     vid = labels["vid"][index]
#     video_path = rich_video_dir / labels["vid"][index] / "video.mp4"

#     meta_txt = rich_video_dir / vid / "meta.txt"  # 表示images_ds4里，第一个和最后一个jpg的文件名，比如[2, 255]一共有254帧，一个seq的不同cam可能不一样
#     sfid, efid = [int(x) for x in meta_txt.open("r").readline().split()]

#     if sfid != 1:
#         print(vid)
#         print(index)
#         break
print(rich_video_dir / labels["vid"][37])  #  jpg: 1, 255
print(labels["frame_id"][37][[0, 1, -1]])  # tensor([  4,   4, 203])
print(len(labels["joints3D"][37]))  # 201， 相机系下的
print(rich_video_dir / labels["vid"][6])  # jpg: 2, 255，jpg就少一张，但是一个jpg-id对应的local-pose肯定是一样的
print(
    labels["frame_id"][6][[0, 1, -1]]
)  # tensor([  3,   3,  202])  # 因为labels["pose"]的local部分是一样的，说明这个frame-id是不是真的frame-id，需要加offset
print(len(labels["joints3D"][6]))  # 201, 相机系下的

# labels["pose"]的local部分是一样的
print((labels["pose"][6][:, 3:] != labels["pose"][37][:, 3:]).sum())
# global_orient是到相机系下的，transl可能是原本的世界系下的，因为这里两个是一样的
print((labels["transl"][6][1:] != labels["transl"][37][1:]).sum())

# # joints3D是cam-coord下的
# wis3d = make_wis3d(name="rich_test")
# add_motion_as_lines(labels['joints3D'][6], wis3d, name='index6', skeleton_type='smpl22')
# add_motion_as_lines(labels['joints3D'][37], wis3d, name='index37', skeleton_type='smpl22')

inputs/RICH/hmr4d_support/video/test/Gym_012_lunge1/cam_04
tensor([  4,   4, 203])
201
inputs/RICH/hmr4d_support/video/test/Gym_012_lunge1/cam_05
tensor([  3,   3, 202])
201
tensor(0)
tensor(0)


In [13]:
from hmr4d.dataset.rich.rich_utils import get_cam2params, get_cam_key_wham_vid

cam2params = get_cam2params()

In [14]:
cam_key = get_cam_key_wham_vid(vid)
T_w2c, K = cam2params[cam_key]

frame_id = labels["frame_id"][index][1:]  # (F, )

decord.bridge.set_bridge("torch")
vr = decord.VideoReader(str(video_path))
frames = vr.get_batch(list(frame_id.numpy()))  # (F, 752, 1024, 3)

In [16]:
from hmr4d.utils.smplx_utils import make_smplx
from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle


print(index)
pose = labels["pose"][index][1:]  # (F, 72)
betas = labels["betas"][index][1:]  # (F, 10)
gender = labels["gender"][index]  # str

transl = labels["transl"][index][1:]  # (F, 3), global
cam_poses = labels["cam_poses"][index][1:]  # (F, 3, 4)
joints3d = labels["joints3D"][index][1:]  # (F, 24, 3)

body_pose = pose[:, 3:-6].reshape(-1, 63)
global_orient = matrix_to_axis_angle(cam_poses[:, :3, :3].transpose(2, 1) @ axis_angle_to_matrix(pose[:, :3]))


# <======= Build groundtruth SMPL (from SMPLX)
smplx2smpl = torch.load("hmr4d/utils/body_model/smplx2smpl_sparse.pt").to_dense()
smpl_J_regressor = torch.load("hmr4d/utils/body_model/smpl_neutral_J_regressor.pt")

smplx_models = {
    "male": make_smplx("rich-smplx", gender="male"),
    "female": make_smplx("rich-smplx", gender="female"),
}
smplx_out = smplx_models[gender](
    body_pose=pose[:, 3:-6].reshape(-1, 63),
    global_orient=global_orient,
    transl=transl,
    betas=betas,
)
smpl_verts = torch.matmul(smplx2smpl, smplx_out.vertices)
smpl_verts_c = apply_T_on_points(smpl_verts, T_w2c)
smpl_joints_c = smpl_J_regressor @ smpl_verts_c  # (F, 24, 3)
(smpl_joints_c - joints3d).abs().max()  # 2.8e-6

# from hmr4d.utils.geo_transform import project_p2d
# from hmr4d.utils.vis.vis_kpts import draw_kpts_cv2
# p2d = project_p2d(smpl_joints_c, K, True) / 4  # we use ds4 image as visualization

# imgs = []
# for i in range(len(frames)):
#     img = frames[i].numpy().copy()
#     img = draw_kpts_cv2(img, p2d[i])
#     imgs.append(img)
# imageio.mimsave("tmp2.mp4", imgs, fps=30, quality=6)

6


In [31]:
wis3d = make_wis3d(name="debug_wham_rich_gt")

for i in range(len(smpl_verts)):
    wis3d.set_scene_id(i)
    wis3d.add_point_cloud(smpl_verts[i], name="smpl_verts")

In [8]:
labels["vid"][index]

'test/Gym_012_lunge1/cam_05'

#### Check GT from RICH dataset
可以只用WHAM提供的pth，就能拿到我想要的结果！

In [4]:
RICH_test_body_dir = Path("/mnt/data/Datasets/RICH/bodies/test_body")
seqname = vid.split("/")[1]
sub_id = seqname.split("_")[1]
pkl_paths = list(sorted((RICH_test_body_dir / seqname).glob(f"*/{sub_id}.pkl")))
m_sid = int(pkl_paths[0].parent.name)  # 5
m_eid = int(pkl_paths[-1].parent.name)  # 204

rich_video_dir = Path("inputs/RICH/hmr4d_support/video")
meta_txt = rich_video_dir / vid / "meta.txt"  # 表示images_ds4里，第一个和最后一个jpg的文件名，比如[2, 255]一共有254帧，一个seq的不同cam可能不一样
v_sfid, v_efid = [int(x) for x in meta_txt.open("r").readline().split()]

# NOTE: Understand the frame-id in WHAM, the Motion, the Images
print(f"{vid}: MOTION start/end-id={m_sid}/{m_eid} @ VIDEO(JPEG) start/end-id={v_sfid}/{v_efid}")
frame_id = labels["frame_id"][index][1:]
print(f"{vid}: FRAME_ID start/end = {frame_id[0]}/{frame_id[-1]}")

test/Gym_012_lunge1/cam_05: MOTION start/end-id=5/204 @ VIDEO(JPEG) start/end-id=2/225
test/Gym_012_lunge1/cam_05: FRAME_ID start/end = 3/202


In [8]:
data = joblib.load(pkl_paths[0])
print("-------: the transl in labels is the same as that in original pkl")
print(data["transl"])
print(labels["transl"][index][1])
print("")

# Check pose
print("-------: the local pose in labels is the same as that in original pkl")
print("differences:", (labels["pose"][index][1][3:-6] - torch.from_numpy(data["body_pose"][0])).abs().sum())
print("")

# Global orient
print("-------: the global pose in labels @ cam-pose-R is the same as that in original pkl")
print(data["global_orient"])
# print(labels["pose"][index][1][:3])
from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle

cam_pose = labels["cam_poses"][index][1][:3, :3]
gpose_label = axis_angle_to_matrix(labels["pose"][index][1][:3]).squeeze(0).numpy()
print(matrix_to_axis_angle(cam_pose.T @ gpose_label))
print("difference:", matrix_to_axis_angle(cam_pose.T @ gpose_label) - torch.from_numpy(data["global_orient"]))

-------: the transl in labels is the same as that in original pkl
[[ 0.93521154 -0.01478121  4.3353    ]]
tensor([ 0.9352, -0.0148,  4.3353])

-------: the local pose in labels is the same as that in original pkl
differences: tensor(0.)

-------: the global pose in labels @ cam-pose-R is the same as that in original pkl
[[-1.5348296  -0.30412704  2.5799594 ]]
tensor([-1.5348, -0.3041,  2.5800])
difference: tensor([[ 0.0000e+00,  2.3842e-07, -2.3842e-07]])
