In [1]:
%load_ext autoreload
%autoreload 2

import os
import matplotlib.pyplot as plt

# Change working directory to project root.
# Guarded so that re-running this cell (or starting the kernel from the project
# root) does not climb two more directories: the project root is recognised by
# the `hmr4d` directory that the relative paths used later in this notebook
# (e.g. "hmr4d/utils/body_model/...") depend on.
if not os.path.isdir("hmr4d"):
    os.chdir('../../')

import torch
import cv2
import numpy as np
from pathlib import Path
from tqdm import tqdm
import json
import imageio
import decord
import joblib
from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines
from hmr4d.dataset.rich.rich_utils import get_w2az_sahmr, parse_seqname_info
from hmr4d.utils.geo_transform import apply_T_on_points
from hmr4d.utils.vis.vis_kpts import draw_kpts_cv2

# Select decord's "torch" bridge (per decord's documentation, decoded frames are
# then returned as torch tensors rather than decord's own NDArray type).
decord.bridge.set_bridge("torch")

## Usage
这个文件用于得到SuperMotion的Mocap测试案例的初始化

In [2]:
from hmr4d.utils.smplx_utils import make_smplx
from hmr4d.dataset.rich.rich_utils import get_cam2params, get_cam_key_wham_vid

# All paths below are relative to the current working directory.
# NOTE: joblib.load / torch.load unpickle their input; only point them at trusted files.

# dataset
# Later cells index `labels` by the keys "vid", "bbox" and "frame_id", each holding
# one entry per sequence.
labels = joblib.load("inputs/RICH/eval_support/rich_test_vit.pth")

# 3D initialization (WHAM prediction)
# `wham_outputs` is looked up by vid; each entry provides "poses_body",
# "poses_root_world", "betas" and "trans_world".
wham_outputs = torch.load(Path("inputs/RICH/eval_support/wham_output.pt"))
# Neutral-gender SMPL body model, placed on the GPU in eval mode.
smpl_model = make_smplx("smpl", gender="neutral").cuda().eval()
# Regressor that is matrix-multiplied with SMPL vertices to obtain joints.
smpl_J_regressor = torch.load("hmr4d/utils/body_model/smpl_neutral_J_regressor.pt")

# 2D initialization (HybrIK prediction)
# Looked up by vid; entries are converted with torch.from_numpy later on.
# (The name is a typo of "hybrik_outputs" but is kept because later cells use it.)
hybrik_outpus = torch.load("inputs/RICH/eval_support/hybrik_pred_kpts.pth")
# Looked up by camera key; each value unpacks into two items, the second being K.
cam2params = get_cam2params()

In [3]:
from hmr4d.utils.geo_transform import transform_mat, compute_T_ayf2az
from hmr4d.utils.geo_transform import ransac_PnP_batch, transform_mat, homo_points

# Per-sequence results, each sub-dict keyed by vid.
outputs = {
    "vid_to_c_p2d_cv": {},
    "vid_to_T_ayfz2c": {},
    "vid_to_pred_j3d_ayfz": {},  # every frame is ayfz
    "vid_to_pred_j3d_ayfz1": {},  # only first frame is ayfz
}

# ydown (WHAM) -> zup axis swap. It is the same for every sequence, so it is built
# once here instead of once per loop iteration.
R_ydown2zup = torch.tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]], dtype=torch.float32)

for index in tqdm(range(len(labels["vid"]))):
    vid = labels["vid"][index]
    bbox_xys = labels["bbox"][index][1:]  # (F, 3); the first entry is skipped
    cam_key = get_cam_key_wham_vid(vid)

    # <======= Load prediction
    wham_pred = wham_outputs[vid]
    params = {
        "body_pose": wham_pred["poses_body"].reshape(-1, 23, 3, 3),
        "global_orient": wham_pred["poses_root_world"].reshape(-1, 1, 3, 3),
        "betas": wham_pred["betas"].reshape(-1, 10),
        "transl": wham_pred["trans_world"].reshape(-1, 3),
    }
    params = {k: v.cuda() for k, v in params.items()}
    # Inference only: no autograd graph is needed for the SMPL forward pass.
    with torch.no_grad():
        pred_smpl_out = smpl_model.forward(**params, pose2rot=False)
    pred_smpl_verts = pred_smpl_out.vertices.cpu()  # (F, 6890, 3)
    pred_j3d_glob = smpl_J_regressor @ pred_smpl_verts  # (F, 24, 3)

    pred_01_kpts2d = torch.from_numpy(hybrik_outpus[vid])  # (F, 24, 3)
    _, K = cam2params[cam_key]
    # Keypoints in [0, 1] -> image coordinates: centre on 0.5, scale by the last
    # bbox component and shift by the first two bbox components.
    # (presumably bbox is (center_x, center_y, size) -- confirm against the dataset)
    pred_i_kpts = (pred_01_kpts2d[:, :, :2] - 0.5) * bbox_xys[:, None, [-1]] + bbox_xys[:, None, :2]

    # <======= Get pred_j3d_ayfz, T_ayfz2c, c_p2d_cv
    # ydown(WHAM) -> zup -> az (align-ground)
    # Ground height = mean over frames of the lowest vertex along the up axis.
    lowest_point = (pred_smpl_verts @ R_ydown2zup.T)[:, :, 2].min(dim=-1)[0]  # TODO: maybe use the first frames?
    lowest_point = torch.mean(lowest_point)
    t_zup2az = torch.tensor([0, 0, -lowest_point], dtype=torch.float32)

    T_ydown2az = transform_mat(R_ydown2zup, t_zup2az)
    pred_j3d_az = apply_T_on_points(pred_j3d_glob, T_ydown2az)  # (F, 24, 3)

    # TODO: this is very stupid, since every interval frames will have different T_az2ayfz
    # Wrong, but keep here:
    # T_az2ayfz = compute_T_ayf2az(pred_j3d_az[None, 0], inverse=True)[0]  # (4, 4)
    # pred_j3d_ayfz = apply_T_on_points(pred_j3d_az, T_az2ayfz)  # (F, 24, 3)
    # fit_R, fit_t = ransac_PnP_batch(K.numpy()[None], pred_i_kpts.numpy(), pred_j3d_ayfz.numpy(), err_thr=10)
    # T_ayfz2c = transform_mat(torch.FloatTensor(fit_R), torch.FloatTensor(fit_t))

    # j3d_ayfz is every frame ayfz V.S. j3d_ayfz1 is only first frame ayfz
    T_az2ayfz = compute_T_ayf2az(pred_j3d_az, inverse=True)  # (F, 4, 4)
    # Frame 0's transform applied to the whole sequence.
    pred_j3d_ayfz1 = apply_T_on_points(pred_j3d_az, T_az2ayfz[0])  # (F, 24, 3), only first frame is ayfz
    # Each frame transformed with its own matrix.
    pred_j3d_ayfz = apply_T_on_points(pred_j3d_az, T_az2ayfz)  # (F, 24, 3), every frame is an ayfz
    # One copy of the intrinsics per frame for the batched PnP solver.
    Ks = K[None].repeat(pred_j3d_ayfz.shape[0], 1, 1)
    fit_R, fit_t = ransac_PnP_batch(Ks.numpy(), pred_i_kpts.numpy(), pred_j3d_ayfz.numpy(), err_thr=10)
    T_ayfz2c = transform_mat(torch.FloatTensor(fit_R), torch.FloatTensor(fit_t))  # (F, 4, 4)

    # Get c_p2d_cv: subtract the principal point and divide by the focal lengths.
    c_p2d_cv = (pred_i_kpts - K[:2, 2]) / torch.tensor([K[0, 0], K[1, 1]])

    # Add to dictionary
    outputs["vid_to_c_p2d_cv"][vid] = c_p2d_cv.cpu()
    outputs["vid_to_T_ayfz2c"][vid] = T_ayfz2c.cpu()
    outputs["vid_to_pred_j3d_ayfz"][vid] = pred_j3d_ayfz.cpu()
    outputs["vid_to_pred_j3d_ayfz1"][vid] = pred_j3d_ayfz1.cpu()

  0%|          | 0/191 [00:00<?, ?it/s]

100%|██████████| 191/191 [00:07<00:00, 24.30it/s]


In [5]:
# Save pred_j3d_ayfz, T_ayfz2c, c_p2d_cv
# The path is relative to the project root (the working directory set in the
# first cell), like every input path in this notebook, rather than a
# machine-specific absolute path.
torch.save(outputs, "inputs/RICH/eval_support/sm_rich_mocap_input.pt")

In [None]:
# Debug visualisation of the joints in the "az" and "ayfz" frames.
# NOTE: pred_j3d_az / pred_j3d_ayfz are not recomputed here. When the notebook is
# run top to bottom they are the values left over from the last iteration of the
# processing loop, so only that one sequence is drawn.
wis3d = make_wis3d(name="mocap-setup")
# Disabled point-cloud variants (the first one refers to a name, `pred_`, that is
# not defined anywhere in the visible notebook):
# wis3d.add_point_cloud(pred_[0], name="zup")
# wis3d.add_point_cloud(pred_j3d_az[0], name='az')
add_motion_as_lines(pred_j3d_az, wis3d, name="az")
add_motion_as_lines(pred_j3d_ayfz, wis3d, name="ayfz")

## Prediction

In [None]:
# NOTE(review): earlier cells already used CUDA (`make_smplx(...).cuda()`), and
# CUDA_VISIBLE_DEVICES is normally only honoured before CUDA is initialised, so
# this assignment probably does not affect the running kernel. Set it before
# launching Jupyter (or before the first CUDA call) if device pinning matters.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from hydra import initialize_config_module, compose

# Compose the 'train' config from the hmr4d.configs module with the experiment,
# task and checkpoint overrides used for this test.
with initialize_config_module(version_base="1.3", config_module="hmr4d.configs"):
    overrides = [
        'exp=motion3d_prior/baseline/net_bertlike',
        "global/task=supermotion/test_mocap_rich_prior3d",
        'ckpt_path=inputs/checkpoints/trained/m3dp_bertlike_e099-s105300.ckpt',
    ]
    cfg = compose(config_name='train', overrides=overrides)

from hydra.utils import instantiate
from hmr4d.utils.net_utils import load_pretrained_model

# Instantiate the model described by the config, load the checkpoint named in
# the config, then switch to eval mode on the GPU.
model = instantiate(cfg.model, _recursive_=False)
load_pretrained_model(model, cfg.ckpt_path, cfg.ckpt_type)
model = model.eval().cuda()


## Dataset to align with WHAM
# Cut every sequence into windows [start, end) of `chunk_length` frames; a window
# that cannot be followed by another full window is stretched to the end of the
# sequence. Each window is recorded as (index, vid, start, end).
# NOTE(review): with this scheme a sequence of at most `chunk_length` frames yields
# no window, and a sequence whose length is an exact multiple of `chunk_length`
# loses its final `chunk_length` frames -- left as-is because the windows are
# meant to match WHAM; confirm that this is intended.
chunk_length = 100
idx2meta = []
for index, vid in enumerate(labels["vid"]):
    seq_length = len(labels["frame_id"][index][1:])
    window_starts = range(0, seq_length - chunk_length, chunk_length)
    for start in window_starts:
        is_last_window = start + 2 * chunk_length > seq_length
        end = seq_length if is_last_window else start + chunk_length  # [start, end)
        idx2meta.append((index, vid, start, end))

In [None]:
# Pick one window from idx2meta and assemble a single-sample batch for the model.
idx = 1
index, vid, start, end = idx2meta[idx]

c_p2d_cv = outputs["vid_to_c_p2d_cv"][vid][start:end]
# NOTE(review): unlike c_p2d_cv, T_ayfz2c is neither sliced to [start:end] nor given
# a leading batch axis, so it holds one matrix per frame of the whole sequence.
# Confirm which shape the model expects.
T_ayfz2c = outputs["vid_to_T_ayfz2c"][vid]
pred_j3d_ayfz = outputs["vid_to_pred_j3d_ayfz"][vid][start:end]

B = 1
batch = {
    "B": B,
    "gt_ayfz_motion": "",
    # Actual number of frames in this window: the last window of a sequence runs
    # to the end of the sequence and can therefore be longer than 100 frames.
    "length": torch.tensor([c_p2d_cv.shape[0]]),
    "text": "",
    # Placeholder image inputs: all-zero images with frame ids set to -1.
    "img_seq": torch.zeros((B, 15, 3, 224, 224)),
    "img_seq_fid": torch.ones(B, 15).long() * -1,
    "task": "CAP",
    "T_ayfz2c": T_ayfz2c,
    "c_p2d_cv": c_p2d_cv[None],
}

In [None]:
# Move tensor entries to the GPU; non-tensor entries (strings, ints) pass through.
batch = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
# Keep the model result under its own name. Rebinding `outputs` here would clobber
# the preprocessing dictionary that the batch-building cell reads and that the save
# cell writes to disk, so re-running either of them afterwards would fail or save
# the wrong data.
val_outputs = model.validation_step(batch, 0)
print(val_outputs.keys())