In [None]:
# Released experiments: each entry pairs a config file (looked up under
# ../configs/release/) with the .ckpt weights to evaluate.
EXPERIMENTS = [
    dict(
        config="DexYCB_HandMvNet.yaml",
        weights="../weights/dexycb/HandMvNet/lightning_logs/version_1191880/checkpoints/epoch=150-step=60702-val_mpjpe=5.974.ckpt",
    ),
    dict(
        config="HO3D_HandMvNet_HR_wo_cam.yaml",
        weights="../weights/ho3d/HandMvNet-HR_wo_cam/lightning_logs/version_1256682/checkpoints/epoch=20-step=5061-val_mpjpe=14.263.ckpt",
    ),
    dict(
        config="MVHand_HandMvNet_HR_wo_cam.yaml",
        weights="../weights/mvhand/HandMvNet-HR_wo_cam/lightning_logs/version_1259582/checkpoints/epoch=98-step=23760-val_mpjpe=1.763.ckpt",
    ),
]

# Index into EXPERIMENTS: 0 = DexYCB, 1 = HO3D, 2 = MVHand.
SELECTED_EXP = EXPERIMENTS[2]


import os
import sys
import json
from collections import OrderedDict
# These lines must run before the project imports below:
# - PYOPENGL_PLATFORM selects the headless EGL backend; PyOpenGL reads it when first
#   imported (presumably pulled in by the vis.* rendering modules).
# - ../src holds the project packages (config, vis, utils, datasets, models).
# - config.py presumably parses sys.argv for --config at import time, so argv is
#   faked here to point at the selected release config.
os.environ["PYOPENGL_PLATFORM"] = "egl"
_SRC_DIR = os.path.abspath('../src')
if _SRC_DIR not in sys.path:  # keep the cell idempotent: no duplicate entries on re-run
    sys.path.append(_SRC_DIR)
sys.argv = ["config.py", "--config", f"../configs/release/{SELECTED_EXP['config']}"]

import cv2
import torch
import numpy as np
from tqdm import tqdm
import lightning as L
from matplotlib import pyplot as plt

from config import cfg

from vis.visualizer import HandPoseVisualizer
from vis.utils import reverse_transform

from utils.camera import transform_joints_between_cameras, get_2d_joints_from_3d_joints, create_intrinsics_matrix

from datasets.mvhand import MVHandDataModule
from datasets.dexycb import DexYCBDataModule
from datasets.ho3d import HO3DDataModule
from datasets.utils import batch_joints_img_to_cropped_joints

from models.handmvnet import HandMvNet as Model
from models.joints_to_vertices import JointsToVertices


# Visualize one sample per batch with a single loader worker.
cfg["data"]["batch_size"] = 1
cfg["data"]["num_workers"] = 1

# fix relative paths: config paths are written relative to the repo root, while this
# notebook runs one level below it (see ../src, ../configs). The dataset_dir check
# keeps a re-run of this cell from prefixing "../" twice.
needs_rebase = not cfg["data"]["dataset_dir"].startswith("../")
if needs_rebase:
    for section, key in (
        (cfg["data"], "dataset_dir"),
        (cfg["data"], "mano_models_dir"),
        (cfg, "base_output_dir"),
    ):
        section[key] = os.path.join("..", section[key])

SKIP_EVERY = 1            # draw only every N-th test sample
NUM_SAMPLES_TO_SAVE = 5   # stop after this many drawn samples
SAVE_INDIVIDUAL = False   # additionally write one PNG per view (only when not VIEW_ONLY)
SAVE_GT = False           # additionally write ground-truth panels
VIEW_ONLY = True          # show figures inline instead of writing PNGs
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # replace with "cpu" to force CPU
OUTPUT_DIR = os.path.join(cfg["base_output_dir"], "infer")

print(json.dumps(cfg, indent=2))

In [2]:
def is_legacy_version(state_dict):
    """
    Return True if `state_dict` uses the legacy key layout.

    A checkpoint is treated as legacy when it contains either of the old
    parameter names ``pose_net.conv.0.weight`` or ``sample_net.conv.0.weight``.
    """
    legacy_markers = ("pose_net.conv.0.weight", "sample_net.conv.0.weight")
    return any(marker in state_dict for marker in legacy_markers)


def load_checkpoint_with_legacy_fix(checkpoint_path, model, device='cpu'):
    """
    Load the ``'state_dict'`` entry of a checkpoint file into `model` (strict).

    Checkpoints detected as legacy by `is_legacy_version` have their keys renamed
    first: ``pose_net.conv.`` -> ``pose_net.`` and ``sample_net.`` -> ``sample_nets.0.``.

    Args:
        checkpoint_path: path to the checkpoint file passed to `torch.load`.
        model: module whose weights are overwritten in place.
        device: `map_location` used when deserializing tensors.

    Returns:
        The same `model` object, with weights loaded.
    """
    state_dict = torch.load(checkpoint_path, map_location=device)['state_dict']

    legacy = is_legacy_version(state_dict)
    if legacy:
        print("[warning] Legacy version detected. Remapping keys...")
        state_dict = OrderedDict(
            (
                key.replace('pose_net.conv.', 'pose_net.').replace('sample_net.', 'sample_nets.0.'),
                tensor,
            )
            for key, tensor in state_dict.items()
        )

    # strict=True: any leftover key mismatch raises instead of silently skipping weights.
    model.load_state_dict(state_dict, strict=True)
    if legacy:
        print("[info] legacy model loaded successfully.")

    return model

In [None]:
joints_to_vertices = JointsToVertices(device=DEVICE, mano_dir=cfg["data"]["mano_models_dir"])

# Setting the seed
L.seed_everything(42, workers=True)

print("Loading data module...")
# dataset name -> (data module class, camera views to draw, root_idx passed to
# get_2d_joints_from_3d_joints when reprojecting predictions)
DATASET_SPECS = {
    "dexycb": (DexYCBDataModule, [0, 1, 2, 3, 4, 5, 6, 7], 2),
    "mvhand": (MVHandDataModule, [0, 1, 2, 3], 3),
    "ho3d": (HO3DDataModule, [0, 1, 2, 3, 4], 0),
}
dataset_name = cfg["data"].get("name", "dexycb")
if dataset_name not in DATASET_SPECS:
    # Raise instead of exit(): in Jupyter exit() asks the kernel to quit rather than
    # cleanly stopping the cell, so later cells would run against stale state.
    raise ValueError(f"{dataset_name} dataset not found. Expected one of: {sorted(DATASET_SPECS)}")
data_module_cls, views, root_idx = DATASET_SPECS[dataset_name]
dm = data_module_cls(cfg["data"])

if not VIEW_ONLY:
    os.makedirs(f"{OUTPUT_DIR}/{dataset_name}", exist_ok=True)

CHECKPOINT_PATH = SELECTED_EXP["weights"]
if CHECKPOINT_PATH:
    print("\nLoading model from checkpoint:", CHECKPOINT_PATH)
    model = Model(
        train_params=cfg["train"],
        model_params=cfg["model"],
        data_params=cfg["data"]
    )

    # Check and fix mismatches if it's a legacy version
    model = load_checkpoint_with_legacy_fix(CHECKPOINT_PATH, model, device=DEVICE)
    model.eval()
    model.to(DEVICE)

else:
    print("Checkpoint not found at:", CHECKPOINT_PATH)
    print("[Warn] Only drawing groundtruths.")

In [None]:
def _render_view(visualizer, vertices_abs_mm, src_extr, dst_extr, intrinsic, img, img_orig, joints_crop):
    """Draw one camera view: a mesh overlay and a 2D-joints overlay.

    Args:
        visualizer: the per-sample HandPoseVisualizer.
        vertices_abs_mm: mesh vertices in mm, expressed in the camera of `src_extr`
            (assumed [778, 3] — TODO confirm). They are converted to metres for
            `transform_joints_between_cameras` and back to mm for rendering.
        src_extr, dst_extr: extrinsics of the source camera and of the view to draw.
        intrinsic: this view's intrinsics, already shifted into bbox coordinates.
        img: the network input crop for this view, as returned by reverse_transform.
        img_orig: `img` resized back to the bbox size (the mesh is rendered on it).
        joints_crop: 2D joints in crop coordinates, [21, 2].

    Returns:
        (mesh_on_img, joints_on_img): the mesh overlay resized to 256x256 and the
        joints drawn on a copy of `img`.
    """
    vertices = transform_joints_between_cameras(vertices_abs_mm / 1000, src_extr, dst_extr).numpy()
    vertices *= 1000
    # Vertices are already in the target camera frame, hence identity extrinsics.
    mesh_on_img, _depth = visualizer.generate_mesh_from_verts(vertices, np.eye(4), intrinsic, img_orig)
    mesh_on_img = cv2.resize(mesh_on_img, (256, 256))
    joints_on_img = visualizer._draw_joints_on_image(img.copy(), joints_crop, point_size=6, edge_width=3)
    return mesh_on_img, joints_on_img


dataloader = dm.test_dataloader()  # dm.test_loader, dm.val_loader
has_model = bool(CHECKPOINT_PATH)
# Local counter: decrementing the NUM_SAMPLES_TO_SAVE config constant would make a
# re-run of this cell stop after a single sample.
num_drawn = 0

for i, sample in enumerate(tqdm(dataloader)):  # tqdm on the loader keeps its length for the bar
    if i % SKIP_EVERY != 0:
        continue
    # The visualizer is built from the sample, so only build it for samples we draw.
    visualizer = HandPoseVisualizer(sample, mano_dir=cfg["data"]["mano_models_dir"])

    inputs = sample["data"]
    cam_params = sample["cam_params"]
    batch_size = inputs["rgb"].shape[0]

    ########### inference ###############
    if has_model:
        with torch.no_grad():  # inference only: skip building an autograd graph
            out = model.forward(inputs["rgb"].to(DEVICE), inputs["bboxes"].to(DEVICE), cam_params)
        out_joints_cam = out["joints_cam"].detach().cpu() * 1000  # [b, 21, 3] in mm
        out_vertices = torch.stack([torch.from_numpy(joints_to_vertices(j.numpy())).float() for j in out_joints_cam])
        # The network's own 2D heads are available as out["joints_crop_img"] ([b, v, 21, 2]);
        # below, the 3D prediction is reprojected instead.
    ######################################

    for b in range(batch_size):
        extr = cam_params["extrinsic"][b].float()  # [v, 4, 4]
        intr = cam_params["intrinsic"][b].float()  # [v, 4]
        crops = inputs["rgb"][b]  # [v, 3, 256, 256]
        bboxes_b = inputs["bboxes"][b]  # [v, 4]
        src_cam = inputs["root_idx"][b]  # camera the 3D GT/predictions are expressed in

        #### ground truths
        gt_joints_crop_2d = inputs["joints_crop_img"][b]  # [v, 21, 2]
        gt_root = inputs["root_joint"][b]  # [1, 3]
        gt_vertices_abs = inputs["vertices"][b] + gt_root  # [778, 3], mm

        ##### Predictions (root-relative outputs placed at the GT root)
        if has_model:
            pred_joints_cam_abs = out_joints_cam[b] + gt_root  # [21, 3]
            pred_vertices_abs = out_vertices[b] + gt_root  # [778, 3]
            pred_joints_img = get_2d_joints_from_3d_joints(pred_joints_cam_abs.unsqueeze(0) / 1000,
                                                           root_idx,
                                                           intr.unsqueeze(0),
                                                           extr.unsqueeze(0)).squeeze(0)
            pred_joints_crop_2d_proj = batch_joints_img_to_cropped_joints(pred_joints_img.view(-1, 21, 2),
                                                                          bboxes_b.view(-1, 4))

        gt_combine, pred_combine = [], []
        for v in views:
            bb = bboxes_b[v].numpy()
            # .copy(): .numpy() shares memory with the sample tensor, and the principal
            # point (entries 2, 3) is shifted into bbox coordinates just below.
            intrinsic = intr[v].numpy().copy()
            intrinsic[2] -= bb[0]
            intrinsic[3] -= bb[1]
            # cv2.resize needs an integer (width, height) dsize.
            bb_size = (int(round(bb[2] - bb[0])), int(round(bb[3] - bb[1])))

            img = reverse_transform(crops[v], denormalize=True, IMAGENET_TRANSFORM=True)
            img_orig = cv2.resize(img, bb_size)

            ############## Groundtruth Vis
            gt_mesh, gt_joints = _render_view(visualizer, gt_vertices_abs, extr[src_cam], extr[v],
                                              intrinsic, img, img_orig, gt_joints_crop_2d.numpy()[v])
            if SAVE_GT and SAVE_INDIVIDUAL and not VIEW_ONLY:
                cv2.imwrite(f'{OUTPUT_DIR}/{dataset_name}/gt_{i}_{b}_{v}.png', np.hstack([img, gt_joints, gt_mesh])[:,:,::-1])
            gt_combine.append(np.vstack([gt_mesh, gt_joints]))

            ############## Prediction Vis
            if has_model:
                pred_mesh, pred_joints = _render_view(visualizer, pred_vertices_abs, extr[src_cam], extr[v],
                                                      intrinsic, img, img_orig, pred_joints_crop_2d_proj.numpy()[v])
                if SAVE_INDIVIDUAL and not VIEW_ONLY:
                    cv2.imwrite(f'{OUTPUT_DIR}/{dataset_name}/pred_{i}_{b}_{v}.png', np.hstack([img, pred_joints, pred_mesh])[:,:,::-1])
                pred_combine.append(np.vstack([pred_mesh, pred_joints]))

        gt_vis = np.hstack(gt_combine)
        pred_vis = np.hstack(pred_combine) if has_model else None

        if VIEW_ONLY:
            if has_model:
                fig, axes = plt.subplots(2, 1, figsize=(5, 5))
                panels = list(zip(axes, [pred_vis, gt_vis], ["Prediction", "Groundtruth"]))
            else:
                fig, ax = plt.subplots()
                panels = [(ax, gt_vis, "Groundtruth")]
            for ax, vis_img, title in panels:
                ax.imshow(vis_img)
                ax.set_title(title)
                ax.axis('off')
            fig.tight_layout()
            plt.show()
        else:
            if has_model:
                cv2.imwrite(f'{OUTPUT_DIR}/{dataset_name}/pred_{i}_{b}.png', pred_vis[:,:,::-1])
            if SAVE_GT:
                cv2.imwrite(f'{OUTPUT_DIR}/{dataset_name}/gt_{i}_{b}.png', gt_vis[:,:,::-1])

    num_drawn += 1
    if num_drawn >= NUM_SAMPLES_TO_SAVE:
        break