In [None]:
import os, sys
os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'
from pathlib import Path
import pickle as pkl
import json
from tqdm import tqdm
import numpy as np
from einops import repeat
import pandas as pd
import cv2
from PIL import Image
import lovely_tensors as lt
lt.monkey_patch()
import imageio.v3 as iio
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, default_collate
from timm.models.vision_transformer import Mlp

from hmr4d.dataset.h36m.h36m import H36mSmplDataset #6,196 samples
from hmr4d.network.gvhmr.relative_transformer import NetworkEncoderRoPE
from hmr4d.model.gvhmr.utils.endecoder import EnDecoder
from hmr4d.model.gvhmr.utils import stats_compose
from hmr4d.utils.body_model.smplx_lite import SmplxLiteV437Coco17
from hmr4d.utils.geo.hmr_cam import compute_bbox_info_bedlam, compute_transl_full_cam

# Device used by every `.to(device)` call in this notebook. The index is relative to
# CUDA_VISIBLE_DEVICES (set to '2,3' at the top of this cell), so 'cuda:0' is the
# first of those two visible GPUs.
device = 'cuda:0'

In [15]:
# Main : model.gvhmr.gvhmr_pl.GvhmrPL
    # Pipeline : model.gvhmr.pipeline.gvhmr_pipeline.Pipeline
        # denoiser3d : network.gvhmr.relative_transformer.NetworkEncoderRoPE
        # endecoder : model.gvhmr.utils.endecoder.EnDecoder
    # Optimizer : adamw_2e-4

denoiser3d = NetworkEncoderRoPE().eval().to(device)

# Deserialize on CPU: without map_location the tensors are restored onto the device
# they were saved from, which may not exist (or may be the wrong GPU) under the
# CUDA_VISIBLE_DEVICES remapping. load_state_dict copies them into the model's device.
weights = torch.load("inputs/checkpoints/gvhmr/gvhmr_siga24_release.ckpt", map_location="cpu")

# Keep only the denoiser weights and strip exactly the prefix they are stored under,
# so the filter and the renaming cannot disagree.
_denoiser_prefix = 'pipeline.denoiser3d.'
denoiser3d.load_state_dict({
    k[len(_denoiser_prefix):]: v
    for k, v in weights['state_dict'].items()
    if k.startswith(_denoiser_prefix)
})

endecoder = EnDecoder(stats_name="MM_V1_AMASS_LOCAL_BEDLAM_CAM").eval().to(device)

# Mean / std used to standardize the camera angular-velocity condition.
cam_angvel_stats = stats_compose.cam_angvel["manual"]
cam_angvel_mean = torch.tensor(cam_angvel_stats["mean"], device=device)
cam_angvel_std = torch.tensor(cam_angvel_stats["std"], device=device)

# Optimize only the trainable denoiser parameters.
params = [p for p in denoiser3d.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(params=params, lr=2e-4)
# Halve the learning rate at scheduler steps 200 and 350.
# (Variable name spelling kept as-is because other cells may reference it.)
scehedule_lr = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[200, 350], gamma=0.5)

smplx = SmplxLiteV437Coco17().to(device)

[[36m01/01 20:03:27[0m][[32mINFO[0m] [EnDecoder] Use MM_V1_AMASS_LOCAL_BEDLAM_CAM for statistics![0m


In [3]:
def collate_fn(batch):
    """Merge a list of per-sample dicts into a single batch dict.

    The key set is read from the first sample (all samples are assumed to share it).
    Values whose key starts with "meta" are descriptive records and are returned as a
    plain per-sample list; every other value is stacked with ``default_collate``.
    The number of samples is stored under the extra key "B".
    """
    def _gather(key):
        column = [sample[key] for sample in batch]
        # data information, do not batch
        return column if key.startswith("meta") else default_collate(column)

    collated = {key: _gather(key) for key in batch[0]}
    collated["B"] = len(batch)
    return collated

dataset = H36mSmplDataset()
dataloader = DataLoader(
    dataset, shuffle=True, num_workers=8,
    persistent_workers=True, batch_size=32,
    drop_last=True, collate_fn=collate_fn, pin_memory=True
)

# Seed before the first iterator is created. The shuffle order and the per-worker
# seeds are drawn from torch's RNG, so seeding NumPy alone does not pin the batch.
np.random.seed(4)
torch.manual_seed(4)

# Pull one batch and move every tensor in it (one level of nesting) to `device`.
batch = next(iter(dataloader))
for k in batch.keys():
    if k.startswith("meta"):
        # Same rule as collate_fn: meta entries are plain Python lists, not tensors.
        continue
    if isinstance(batch[k], dict):
        for kk in batch[k].keys():
            batch[k][kk] = batch[k][kk].to(device)
    elif torch.is_tensor(batch[k]):
        # Non-tensor scalars such as the integer batch size "B" are left untouched.
        batch[k] = batch[k].to(device)
print(batch['meta'])

[[36m01/01 19:55:46[0m][[32mINFO[0m] [H36M] Loading from inputs/H36M/hmr4d_support/smplxpose_v1.pt ...[0m
[[36m01/01 19:55:46[0m][[32mINFO[0m] [H36M] 600 sequences. Elapsed: 0.59s[0m
[[36m01/01 19:55:46[0m][[32mINFO[0m] [H36M] Fully Loading to RAM ViT-Feat: inputs/H36M/hmr4d_support/vitfeat_h36m.pt[0m
[[36m01/01 19:55:47[0m][[32mINFO[0m] [H36M] Finished. Elapsed: 1.62s[0m
[[36m01/01 19:55:47[0m][[32mINFO[0m] [H36M] has 8.7 hours motion -> Resampled to 6196 samples.[0m


[{'data_name': 'h36m', 'idx': 3803, 'vid': 'S6@Sitting_1@58860488', 'start_end': [370, 490]}, {'data_name': 'h36m', 'idx': 3535, 'vid': 'S5@Directions_1@58860488', 'start_end': [2183, 2303]}, {'data_name': 'h36m', 'idx': 3780, 'vid': 'S6@Walking@58860488', 'start_end': [1591, 1711]}, {'data_name': 'h36m', 'idx': 5336, 'vid': 'S6@Walking@60457274', 'start_end': [1580, 1700]}, {'data_name': 'h36m', 'idx': 4341, 'vid': 'S7@Discussion@58860488', 'start_end': [1813, 1933]}, {'data_name': 'h36m', 'idx': 512, 'vid': 'S5@Posing_1@54138969', 'start_end': [703, 823]}, {'data_name': 'h36m', 'idx': 944, 'vid': 'S7@Photo@54138969', 'start_end': [833, 953]}, {'data_name': 'h36m', 'idx': 4452, 'vid': 'S8@Walking@58860488', 'start_end': [18, 138]}, {'data_name': 'h36m', 'idx': 5610, 'vid': 'S7@Purchases_1@60457274', 'start_end': [12, 132]}, {'data_name': 'h36m', 'idx': 5627, 'vid': 'S7@Sitting_1@60457274', 'start_end': [838, 958]}, {'data_name': 'h36m', 'idx': 856, 'vid': 'S6@Discussion@54138969', 'st

In [4]:
from hmr4d.utils.geo.hmr_cam import perspective_projection, normalize_kp2d, safely_render_x3d_K, get_bbx_xys
from hmr4d.utils.geo.augment_noisy_pose import (
    get_wham_aug_kp3d,
    get_visible_mask,
    get_invisible_legs_mask,
    randomly_modify_hands_legs,
)

In [23]:
B, _ = batch["smpl_params_c"]["body_pose"].shape[:2]

# Ground-truth geometry from the camera-frame SMPL parameters. The "cr" variants are
# re-centred on the mean of joints 11 and 12 (presumably the hips in COCO-17 order).
with torch.no_grad():
    gt_verts437, gt_j3d = smplx(**batch["smpl_params_c"])
    root_ = gt_j3d[:, :, [11, 12], :].mean(-2, keepdim=True)

    batch["gt_j3d"] = gt_j3d
    batch["gt_cr_coco17"] = gt_j3d - root_
    batch["gt_c_verts437"] = gt_verts437
    batch["gt_cr_verts437"] = gt_verts437 - root_

# Bounding boxes: wherever the dataset did not provide one (mask is False), fill in
# an augmented box derived from the projected ground-truth vertices.
i_x2d = safely_render_x3d_K(gt_verts437, batch["K_fullimg"], thr=0.3)
bbx_xys = get_bbx_xys(i_x2d, do_augment=True)
mask_bbx_xys = batch["mask"]["bbx_xys"]
batch["bbx_xys"][~mask_bbx_xys] = bbx_xys[~mask_bbx_xys]

# Synthetic 2D observation: perturb the 3D joints, then project them to the image.
noisy_j3d = gt_j3d + get_wham_aug_kp3d(gt_j3d.shape[:2])
noisy_j3d = randomly_modify_hands_legs(noisy_j3d)
obs_i_j2d = perspective_projection(noisy_j3d, batch["K_fullimg"])  # (B, L, J, 2)

# Visibility masks are moved to the device of the tensors they index. A bare .cuda()
# targets the *current* CUDA device, which mismatches when `device` is not the default.
j2d_visible_mask = get_visible_mask(gt_j3d.shape[:2]).to(gt_j3d.device)  # (B, L, J)
j2d_visible_mask[noisy_j3d[..., 2] < 0.3] = False  # Set close-to-image-plane points as invisible
legs_invisible_mask = get_invisible_legs_mask(gt_j3d.shape[:2]).to(gt_j3d.device)  # (B, L, J)
j2d_visible_mask[legs_invisible_mask] = False

# Append visibility as a third channel, normalize by the bounding box, and zero out
# both invisible joints and frames outside the valid range.
obs_kp2d = torch.cat([obs_i_j2d, j2d_visible_mask[:, :, :, None].float()], dim=-1)  # (B, L, J, 3)
obs = normalize_kp2d(obs_kp2d, batch["bbx_xys"])  # (B, L, J, 3)
obs[~j2d_visible_mask] = 0  # if not visible, set to (0,0,0)
batch["obs"] = obs
batch["obs"][~batch["mask"]["valid"]] = 0

In [None]:
length = batch["length"]

# *. Conditions fed to the denoiser alongside the 2D observation
cliff_cam = compute_bbox_info_bedlam(batch["bbx_xys"], batch["K_fullimg"])  # (B, L, 3)
f_cam_angvel = (batch["cam_angvel"] - cam_angvel_mean) / cam_angvel_std  # standardized
f_condition = dict(
    obs=batch["obs"],  # (B, L, 17, 3)
    f_cliffcam=cliff_cam,  # (B, L, 3)
    f_cam_angvel=f_cam_angvel,  # (B, L, C=6)
    f_imgseq=batch["f_imgseq"],  # (B, L, C=1024)
)

# Network forward, then map the normalized prediction back to SMPL parameters.
model_output = denoiser3d(length=length, **f_condition)  # pred_x, pred_cam, static_conf_logits
decode_dict = endecoder.decode(model_output["pred_x"])

# In-camera SMPL parameters: pose / shape / orientation come straight from the
# decoder, translation is recovered from the predicted camera and the bounding box.
outputs = {
    "model_output": model_output,
    "decode_dict": decode_dict,
    "pred_smpl_params_incam": {
        # body_pose (B, L, 63), betas (B, L, 10), global_orient (B, L, 3)
        **{name: decode_dict[name] for name in ("body_pose", "betas", "global_orient")},
        "transl": compute_transl_full_cam(model_output["pred_cam"], batch["bbx_xys"], batch["K_fullimg"]),
    },
}

In [42]:
from hmr4d.utils.net_utils import length_to_mask

# Step through the denoiser's sub-modules by hand, using only the "obs" condition.
with torch.no_grad():
    obs = f_condition["obs"].clone()
    B, L, J, C = obs.shape

    # A joint counts as visible when its third channel exceeds 0.5; all channels of
    # the remaining joints are zeroed.
    visible_mask = obs[..., [2]] > 0.5  # (B, L, J, 1)
    obs.masked_fill_(~visible_mask, 0)

    # Visible joints use the linear embedding of their 2D position, the others fall
    # back to the per-joint learned parameters.
    f_obs = denoiser3d.learned_pos_linear(obs[..., :2])  # (B, L, J, 32)
    f_obs = f_obs * visible_mask + denoiser3d.learned_pos_params.repeat(B, L, 1, 1) * ~visible_mask
    x = denoiser3d.embed_noisyobs(f_obs.view(B, L, -1))  # (B, L, J*32) -> (B, L, C)

    # Padding mask is the negation of the valid-length mask.
    pmask = ~length_to_mask(length, L)
    for block in denoiser3d.blocks:
        x = block(x, attn_mask=None, tgt_key_padding_mask=pmask)


In [45]:
# Display the `final_layer` submodule of the denoiser (its repr is the cell output).
denoiser3d.final_layer

Mlp(
  (fc1): Linear(in_features=512, out_features=512, bias=True)
  (act): GELU(approximate='none')
  (drop1): Dropout(p=0.0, inplace=False)
  (norm): Identity()
  (fc2): Linear(in_features=512, out_features=151, bias=True)
  (drop2): Dropout(p=0.0, inplace=False)
)

In [26]:
# ========== Compute Loss ========== #
total_loss = 0
mask = batch["mask"]["valid"]  # (B, L)

# 1. Simple loss: MSE between the network output and the encoded ground truth
pred_x = model_output["pred_x"]  # (B, L, C)
target_x = endecoder.encode(batch)  # (B, L, C)

# Per-element mask: valid frames only; for samples flagged "spv_incam_only" the
# channels from index 142 onwards are additionally excluded.
mask_simple = mask.unsqueeze(-1).expand(-1, -1, pred_x.shape[2]).clone()  # (B, L, C)
mask_simple[batch["mask"]["spv_incam_only"], :, 142:] = False  # 3dpw training

# Masked elements contribute zero but still count in the mean's denominator.
simple_loss = (F.mse_loss(pred_x, target_x, reduction="none") * mask_simple).mean()
total_loss = total_loss + simple_loss
print(simple_loss)

# # 2. Extra loss
# extra_funcs = [
#     compute_extra_incam_loss,
#     compute_extra_global_loss,
# ]
# for extra_func in extra_funcs:
#     extra_loss, extra_loss_dict = extra_func(batch, outputs, self)
#     total_loss += extra_loss
#     outputs.update(extra_loss_dict)

# outputs["loss"] = total_loss

tensor grad MeanBackward0 cuda:0 0.100


In [None]:
## Endecoder : gvhmr/v1_amass_local_bedlam_cam

import torch.nn as nn
from hmr4d.model.gvhmr.utils import stats_compose
from hmr4d.utils.body_model.smplx_lite import SmplxLiteV437Coco17


class EnDecoder(nn.Module):
    """Convert between SMPL parameters and the normalized 151-channel motion vector.

    Channel layout of the vector (see ``encode``):
        0:126    body_pose_r6d         21 joints x 6
        126:136  betas
        136:142  global_orient_r6d     in-camera
        142:148  global_orient_gv_r6d  gv frame
        148:151  local_transl_vel      smpl-coord

    ``mean`` and ``std`` are non-persistent buffers laid out in the same order.

    NOTE(review): several methods call ``gaussian_augment``, ``axis_angle_to_matrix``,
    ``matrix_to_rotation_6d``, ``rotation_6d_to_matrix``, ``matrix_to_axis_angle``,
    ``matrix`` and ``get_local_transl_vel``; none of these are imported in the visible
    cells, so they must already be in the kernel namespace - confirm before relying
    on this cell after a kernel restart.
    """

    def __init__(self, noise_pose_k=10):
        """
        Args:
            noise_pose_k: forwarded to ``gaussian_augment`` in ``get_noisyobs``.
        """
        super().__init__()
        # Load mean, std
        # tmp = stats_compose.MM_V1_AMASS_LOCAL_BEDLAM_CAM
        # One (stats table, source key) pair per channel group, in vector order.
        stat_sources = [
            (stats_compose.body_pose_r6d, "amass"),
            (stats_compose.betas, "amass"),
            (stats_compose.global_orient_c_r6d, "bedlam"),
            (stats_compose.global_orient_gv_r6d, "bedlam"),
            (stats_compose.local_transl_vel, "amass"),
        ]
        mean = torch.tensor([v for table, src in stat_sources for v in table[src]["mean"]]).float()
        std = torch.tensor([v for table, src in stat_sources for v in table[src]["std"]]).float()
        self.register_buffer("mean", mean, False)
        self.register_buffer("std", std, False)

        # option
        self.noise_pose_k = noise_pose_k

        # smpl
        self.smplx_model = SmplxLiteV437Coco17()
        parents = self.smplx_model.parents[:22]
        self.register_buffer("parents_tensor", parents, False)
        self.parents = parents.tolist()

    def get_noisyobs(self, data, return_type="r6d"):
        """
        Noisy observation contains local pose with noise
        Args:
            data (dict):
                body_pose: (B, L, J*3) or (B, L, J, 3)
            return_type: one of "R", "r6d", "aa"; selects which element of the
                ``gaussian_augment`` result is returned.
        Returns:
            noisy_bosy_pose: (B, L, J, 6) or (B, L, J, 3) or (B, L, 3, 3) depends on return_type
        Raises:
            KeyError: if ``return_type`` is not one of the supported names.
        """
        body_pose = data["body_pose"]  # (B, L, 63) or (B, L, J, 3)
        # Read only the leading two dims so the documented 4-D layout is accepted too.
        B, L = body_pose.shape[:2]
        body_pose = body_pose.reshape(B, L, -1, 3)

        # (B, L, J, C)
        return_mapping = {"R": 0, "r6d": 1, "aa": 2}
        return_id = return_mapping[return_type]
        noisy_bosy_pose = gaussian_augment(body_pose, self.noise_pose_k, to_R=True)[return_id]
        return noisy_bosy_pose

    def normalize_body_pose_r6d(self, body_pose_r6d):
        """body_pose_r6d: (B, L, {J*6}/{J, 6}) ->  (B, L, J*6)

        Standardizes with the first 126 entries of ``mean`` / ``std``. When the stats
        have a single channel the flattened input is returned unchanged.
        """
        B, L = body_pose_r6d.shape[:2]
        body_pose_r6d = body_pose_r6d.reshape(B, L, -1)
        if self.mean.shape[-1] == 1:  # no mean, std provided
            return body_pose_r6d
        body_pose_r6d = (body_pose_r6d - self.mean[:126]) / self.std[:126]  # (B, L, C)
        return body_pose_r6d

    def fk_v2(self, body_pose, betas, global_orient=None, transl=None, get_intermediate=False):
        """
        Forward kinematics over the first 22 joints.
        Args:
            body_pose: (B, L, 63)
            betas: (B, L, 10)
            global_orient: (B, L, 3); zeros are used when omitted
            transl: optional offset added to the root joint's local position
            get_intermediate: also return the local and accumulated 4x4 transforms
        Returns:
            joints: (B, L, 22, 3), or (joints, mat, fk_mat) if get_intermediate
        """
        B, L = body_pose.shape[:2]
        if global_orient is None:
            global_orient = torch.zeros((B, L, 3), device=body_pose.device)
        aa = torch.cat([global_orient, body_pose], dim=-1).reshape(B, L, -1, 3)
        rotmat = axis_angle_to_matrix(aa)  # (B, L, 22, 3, 3)

        # Parent-relative joint offsets (fresh tensor, safe to modify in place below).
        local_skeleton = self.get_local_pos(betas)  # (B, L, 22, 3)

        if transl is not None:
            local_skeleton[..., 0, :] += transl  # B, L, 22, 3

        mat = matrix.get_TRS(rotmat, local_skeleton)  # B, L, 22, 4, 4
        fk_mat = matrix.forward_kinematics(mat, self.parents)  # B, L, 22, 4, 4
        joints = matrix.get_position(fk_mat)  # B, L, 22, 3
        if not get_intermediate:
            return joints
        else:
            return joints, mat, fk_mat

    def get_local_pos(self, betas):
        """Return parent-relative positions of the first 22 skeleton joints.

        The root (index 0) keeps its own skeleton position; every other joint is
        expressed as an offset from the joint listed in ``parents_tensor``.
        """
        skeleton = self.smplx_model.get_skeleton(betas)[..., :22, :]  # (B, L, 22, 3)
        local_skeleton = skeleton - skeleton[:, :, self.parents_tensor]
        # The root has no parent, so its subtracted row is discarded and replaced.
        local_skeleton = torch.cat([skeleton[:, :, :1], local_skeleton[:, :, 1:]], dim=2)
        return local_skeleton

    def encode(self, inputs):
        """
        Build the normalized motion vector from a batch dict.

        Reads ``inputs["smpl_params_c"]`` (body_pose, betas, global_orient),
        ``inputs["R_c2gv"]`` and ``inputs["smpl_params_w"]`` (transl, global_orient).

        definition: {
                body_pose_r6d,  # (B, L, (J-1)*6) -> 0:126
                betas, # (B, L, 10) -> 126:136
                global_orient_r6d,  # (B, L, 6) -> 136:142  incam
                global_orient_gv_r6d: # (B, L, 6) -> 142:148  gv
                local_transl_vel,  # (B, L, 3) -> 148:151, smpl-coord
            }
        """
        B, L = inputs["smpl_params_c"]["body_pose"].shape[:2]
        # cam
        smpl_params_c = inputs["smpl_params_c"]
        body_pose = smpl_params_c["body_pose"].reshape(B, L, 21, 3)
        body_pose_r6d = matrix_to_rotation_6d(axis_angle_to_matrix(body_pose)).flatten(-2)
        betas = smpl_params_c["betas"]
        global_orient_R = axis_angle_to_matrix(smpl_params_c["global_orient"])
        global_orient_r6d = matrix_to_rotation_6d(global_orient_R)

        # global
        R_c2gv = inputs["R_c2gv"]  # (B, L, 3, 3)
        global_orient_gv_r6d = matrix_to_rotation_6d(R_c2gv @ global_orient_R)

        # local_transl_vel
        smpl_params_w = inputs["smpl_params_w"]
        local_transl_vel = get_local_transl_vel(smpl_params_w["transl"], smpl_params_w["global_orient"])
        # returns
        x = torch.cat([body_pose_r6d, betas, global_orient_r6d, global_orient_gv_r6d, local_transl_vel], dim=-1)
        x_norm = (x - self.mean) / self.std
        return x_norm

    def encode_translw(self, inputs):
        """
        Encode only the local translation velocity (the last 3 channels of the
        full vector), normalized with the last 3 entries of ``mean`` / ``std``.

        definition: {
                body_pose_r6d,  # (B, L, (J-1)*6) -> 0:126
                betas, # (B, L, 10) -> 126:136
                global_orient_r6d,  # (B, L, 6) -> 136:142  incam
                global_orient_gv_r6d: # (B, L, 6) -> 142:148  gv
                local_transl_vel,  # (B, L, 3) -> 148:151, smpl-coord
            }
        """
        # local_transl_vel
        smpl_params_w = inputs["smpl_params_w"]
        local_transl_vel = get_local_transl_vel(smpl_params_w["transl"], smpl_params_w["global_orient"])

        # returns
        x = local_transl_vel
        x_norm = (x - self.mean[-3:]) / self.std[-3:]
        return x_norm

    def decode_translw(self, x_norm):
        """Inverse of the normalization applied in ``encode_translw``."""
        return x_norm * self.std[-3:] + self.mean[-3:]

    def decode(self, x_norm):
        """x_norm: (B, L, C)

        De-normalizes the vector and splits it into axis-angle SMPL parameters.

        Returns:
            dict with body_pose, betas, global_orient (in-camera), global_orient_gv
            and local_transl_vel.
        """
        B, L, _ = x_norm.shape
        x = (x_norm * self.std) + self.mean

        body_pose_r6d = x[:, :, :126]
        betas = x[:, :, 126:136]
        global_orient_r6d = x[:, :, 136:142]
        global_orient_gv_r6d = x[:, :, 142:148]
        local_transl_vel = x[:, :, 148:151]

        body_pose = matrix_to_axis_angle(rotation_6d_to_matrix(body_pose_r6d.reshape(B, L, -1, 6)))
        body_pose = body_pose.flatten(-2)
        global_orient_c = matrix_to_axis_angle(rotation_6d_to_matrix(global_orient_r6d))
        global_orient_gv = matrix_to_axis_angle(rotation_6d_to_matrix(global_orient_gv_r6d))

        output = {
            "body_pose": body_pose,
            "betas": betas,
            "global_orient": global_orient_c,
            "global_orient_gv": global_orient_gv,
            "local_transl_vel": local_transl_vel,
        }

        return output

In [None]:
# These classes are not imported by any earlier cell, so import them here to keep
# the cell runnable on a fresh kernel.
from hmr4d.utils.body_model import BodyModelSMPLH, BodyModelSMPLX

# Neutral SMPL body model (10 shape coefficients, no internally created parameters).
smpl = BodyModelSMPLH(
    model_path="inputs/checkpoints/body_models", model_type="smpl",
    gender="neutral", num_betas=10, create_body_pose=False, 
    create_betas=False, create_global_orient=False, create_transl=False,
).to(device)
# NOTE(review): this rebinds `smplx`, which earlier cells use for the
# SmplxLiteV437Coco17 instance; re-running those cells after this one will call a
# different model - confirm this is intended.
smplx = BodyModelSMPLX(
    model_path="inputs/checkpoints/body_models", model_type="smplx",
    gender="neutral", num_pca_comps=12, flat_hand_mean=False,
).to(device)
# Deserialize on CPU first so loading does not depend on the device the tensors were
# saved from, then move to the working device.
smplx2smpl = torch.load("hmr4d/utils/body_model/smplx2smpl_sparse.pt", map_location="cpu").to(device)
# Face index arrays with a leading batch dimension of 1.
faces_smpl = torch.from_numpy((smpl.faces).astype("int")).unsqueeze(0).to(device)
faces_smplx = torch.from_numpy((smplx.faces).astype("int")).unsqueeze(0).to(device)
J_regressor = torch.load("hmr4d/utils/body_model/smpl_neutral_J_regressor.pt", map_location="cpu").to(device)