In [1]:
%load_ext autoreload
%autoreload 2

import os
import matplotlib.pyplot as plt

# Change working directory to the project root. The resolved root is remembered on the first run,
# so re-running this cell does not keep climbing two more directory levels each time.
PROJECT_ROOT = globals().get("PROJECT_ROOT") or os.path.abspath("../../")
os.chdir(PROJECT_ROOT)

import torch
import numpy as np
from pathlib import Path
from tqdm import tqdm
import json
import imageio
from hmr4d.utils.wis3d_utils import make_wis3d, add_motion_as_lines, get_const_colors
from hmr4d.dataset.rich.rich_utils import get_w2az_sahmr, parse_seqname_info
from hmr4d.utils.geo_transform import apply_T_on_points

In [2]:
wis3d_dir = "outputs/wis3d_triag"  # output directory for the wis3d viewer logs
mid_logged = {}  # cache: sequence id -> wis3d handle, filled by the visualization cells

# NOTE: torch.load unpickles its input; only load trusted files.
info = torch.load("info.pt")
print(info.keys())

dict_keys(['pred_ayfz_motion', 'w_cam_p2d_ray', 'T_ayfz2c', 'obs_c_p2d', 'length', 'meta', 'gt_ayfz_motion'])


In [3]:
# Re-draw what the training callback logs, to inspect the initialization:
# ground truth (green), the raw prediction, and the camera -> 2D-keypoint rays (orange).
meta = info["meta"]
B = len(meta)
gt_ayfz_motion = info["gt_ayfz_motion"]
pred_ayfz_motion = info["pred_ayfz_motion"][:, 0]  # keep index 0 along dim 1

for b in range(B):
    # One wis3d handle per sequence id, cached in mid_logged and reused by later cells.
    mid = meta[b][0].replace("/", "-")
    try:
        wis3d = mid_logged[mid]
    except KeyError:
        wis3d = mid_logged[mid] = make_wis3d(output_dir=wis3d_dir, name=mid)

    # Slice to the first `end - start` frames; scene ids are offset by `start`.
    start, end = meta[b][1]
    l = end - start
    gt = gt_ayfz_motion[b, :l]
    pred = pred_ayfz_motion[b, :l]
    add_motion_as_lines(gt, wis3d, name="gt_ayfz_motion", const_color="green", offset=start)
    add_motion_as_lines(pred, wis3d, name="pred", offset=start)

    # se_lines[f, 0] and se_lines[f, 1] are passed as the two endpoint arrays of add_lines
    # (presumably ray start / end points -- confirm against the dataset code).
    se_lines = info["w_cam_p2d_ray"][b, :l]
    color = get_const_colors("orange", (se_lines.size(-2),))[:, :3] * 255
    for f in range(l):
        wis3d.set_scene_id(start + f)
        wis3d.add_lines(se_lines[f, 0], se_lines[f, 1], color, name="cam-p2d-ray")

In [20]:
from einops import einsum, rearrange


def R_from_wy(w):
    """Build rotation matrices about the y axis.

    Args:
        w: tensor of angles in radians, of any shape ``S``.

    Returns:
        Tensor of shape ``(*S, 3, 3)`` holding
        ``[[cos w, 0, sin w], [0, 1, 0], [-sin w, 0, cos w]]`` for every element of ``w``.
    """
    cos_w, sin_w = torch.cos(w), torch.sin(w)
    zero, one = torch.zeros_like(w), torch.ones_like(w)
    rows = (
        torch.stack([cos_w, zero, sin_w], dim=-1),
        torch.stack([zero, one, zero], dim=-1),
        torch.stack([-sin_w, zero, cos_w], dim=-1),
    )
    # Stacking the rows along dim -2 gives (..., row, col), i.e. row-major 3x3 matrices.
    return torch.stack(rows, dim=-2)


def update_pred_with_wst(pred, w, s, t):
    """Apply a per-frame yaw / scale / translation correction to ``pred``.

    Each frame is taken relative to a pivot: joint 0 with its y component set to 0. The frame is
    rotated about the y axis by ``w``, scaled by ``s``, shifted by ``t``, and moved back to the
    pivot. Only a y-axis rotation is used because the gravity direction of the 3D prior is
    trusted. Because the pivot has y == 0, heights are scaled relative to y = 0 rather than
    re-centered.

    Args:
        pred: (B, L, J, 3) joint positions.
        w: (B, L) yaw angles in radians.
        s: (B, L) per-frame scale factors.
        t: translation broadcastable to (B, L, J, 3); the caller in this notebook passes
            (B, L, 1, 3) with a zero y component.

    Returns:
        (B, L, J, 3) updated joint positions. ``pred`` itself is not modified.
    """
    # Pivot: joint 0 projected onto y = 0 (cloned so the in-place write doesn't touch pred).
    pivot = pred[..., [0], :].clone()  # (B, L, 1, 3)
    pivot[..., 1] = 0

    rot = R_from_wy(w)  # (B, L, 3, 3)
    centered = pred - pivot
    rotated = einsum(centered, rot, "b l n c, b l d c -> b l n d")  # p' = R @ p for every joint
    scaled = rotated * s[:, :, None, None]
    return scaled + t + pivot


def find_wst_transform(pred, T_ayfz2c, obs_c_p2d, length, lr=0.01, max_iter=100, verbose=False):
    """Fit a per-frame yaw / scale / ground-plane translation aligning ``pred`` with 2D observations.

    For every frame, Adam optimizes:
      - a rotation angle ``w`` about the y axis,
      - a scale ``s``,
      - an (x, z) translation; the y translation is fixed to 0.
    The goal is that ``update_pred_with_wst(pred, w, s, t)``, transformed by ``T_ayfz2c`` and
    divided by depth (x/z, y/z), matches ``obs_c_p2d``. A scale prior (s close to 1) and temporal
    smoothness terms on w, tx, tz are added to the reprojection loss.

    Args:
        pred: (B, L, J, 3) joint positions in the ayfz frame.
        T_ayfz2c: (B, 1 or L, 4, 4) transforms applied with ``apply_T_on_points``.
        obs_c_p2d: (B, L, J, 2) 2D observations. They are compared directly against x/z and y/z
            of the transformed points (no intrinsics are applied), so they must be expressed in
            those same z = 1 plane coordinates.
        length: (B,) number of valid frames per sequence, as a tensor or a sequence of ints.
            Frames at index >= length[b] are excluded from every loss term.
        lr: Adam learning rate.
        max_iter: number of optimization steps.
        verbose: if True, print the total loss every 10 iterations.

    Returns:
        ``(w, s, t)``, all detached. ``w`` and ``s`` are (B, L); ``t`` is (B, L, 1, 3) with
        ``t[..., 1] == 0``.
    """
    B, max_L, J = pred.shape[:3]
    assert len(pred.shape) == len(T_ayfz2c.shape) == len(obs_c_p2d.shape) == 4

    # Parameters live on pred's device/dtype. Default CPU float32 tensors would fail for
    # GPU or float64 inputs once they are mixed with pred below.
    factory = dict(device=pred.device, dtype=pred.dtype)
    w = torch.zeros(B, max_L, **factory).requires_grad_()
    s = torch.ones(B, max_L, **factory).requires_grad_()
    tx = torch.zeros(B, max_L, **factory).requires_grad_()
    tz = torch.zeros(B, max_L, **factory).requires_grad_()
    optimizer = torch.optim.Adam([w, s, tx, tz], lr=lr)

    # length_mask[b, i] is True for frames i < length[b]; length is moved to pred's device.
    length = torch.as_tensor(length, device=pred.device)
    length_mask = torch.arange(max_L, device=pred.device) < length.reshape(B, 1)
    # The difference between frames i-1 and i counts when frame i is valid.
    diff_mask = length_mask[:, 1:]

    def _build_t():
        # (B, L, 1, 3) translation with a fixed zero y component.
        return torch.stack([tx, torch.zeros_like(tx), tz], -1)[:, :, None, :]

    for i in range(max_iter):
        pred_update = update_pred_with_wst(pred, w, s, _build_t())  # (B, L, J, 3)

        # Transform into the camera frame and divide by depth.
        c_pred_update = apply_T_on_points(pred_update, T_ayfz2c)  # (B, L, J, 3)
        c_pred2d_update = c_pred_update[..., :2] / c_pred_update[..., 2:]  # (B, L, J, 2)

        error = (c_pred2d_update - obs_c_p2d).pow(2).sum((-1, -2))  # (B, L)
        loss_2d = error[length_mask].mean()

        # Regularization: s should stay close to 1 ...
        loss_s = (s - 1).pow(2)[length_mask].mean()
        # ... and w, tx, tz should not change too much between consecutive frames.
        loss_w = (w[:, 1:] - w[:, :-1]).pow(2)[diff_mask].mean()
        loss_tx = (tx[:, 1:] - tx[:, :-1]).pow(2)[diff_mask].mean()
        loss_tz = (tz[:, 1:] - tz[:, :-1]).pow(2)[diff_mask].mean()

        loss = loss_2d + (loss_s + loss_w + loss_tx + loss_tz)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if verbose and i % 10 == 0:
            print(f"{i} loss: {loss.item()}")

    # Rebuild t from the final tx / tz (after the last optimizer step).
    return w.detach(), s.detach(), _build_t().detach()


# Fit the per-frame yaw / scale / translation against the 2D observations, then apply it.
obs_c_p2d = info["obs_c_p2d"]  # (B, L, J, 2)
length = info["length"]  # (B)
T_ayfz2c = info["T_ayfz2c"].unsqueeze(1)  # (B, 1, 4, 4): one transform broadcast over all frames
pred_ayfz_motion = info["pred_ayfz_motion"][:, 0]  # (B, L, J, 3), index 0 along dim 1

w, s, t = find_wst_transform(pred_ayfz_motion, T_ayfz2c, obs_c_p2d, length)
pred_ayfz_motion_update = update_pred_with_wst(pred_ayfz_motion, w, s, t)

0 loss: 0.13491854071617126
10 loss: 0.07375528663396835
20 loss: 0.048857640475034714
30 loss: 0.03808439150452614
40 loss: 0.03326497599482536
50 loss: 0.02977130189538002
60 loss: 0.027103140950202942
70 loss: 0.02517838403582573
80 loss: 0.023773152381181717
90 loss: 0.022707728669047356


In [21]:
# Add the optimized prediction to the same per-sequence wis3d scenes.
for b in range(B):
    mid = meta[b][0].replace("/", "-")
    try:
        wis3d = mid_logged[mid]
    except KeyError:
        wis3d = mid_logged[mid] = make_wis3d(output_dir=wis3d_dir, name=mid)

    # Slice to the first `end - start` frames; scene ids are offset by `start`.
    start, end = meta[b][1]
    l = end - start
    pred_update = pred_ayfz_motion_update[b, :l]
    add_motion_as_lines(pred_update, wis3d, name="pred_update_wreg", offset=start)

### Focus on a single sample

test-Gym_010_pushup1-cam_03 442

In [5]:
# Locate the batch index `b` of one specific clip (sequence id + start frame) for closer inspection.
FOCUS_MID = "test-Gym_010_pushup1-cam_03"
FOCUS_START = 300
for b in range(B):
    mid = meta[b][0].replace("/", "-")
    if mid == FOCUS_MID:
        start, end = meta[b][1]
        if start == FOCUS_START:
            break
else:
    # Without this, `b` would silently be left at B - 1 and later cells would inspect the wrong clip.
    raise ValueError(f"No clip {FOCUS_MID!r} starting at frame {FOCUS_START} found in meta")