# ColorHandPose3D Backward Test

Configure and test the training behavior for the CHP3D model.

In [None]:
import os
import sys
import math

import torch
import torchvision
import torch.nn.functional as F
import pickle
import tensorflow as tf
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from PIL import Image

# Put the parent of the current working directory on the import path (once)
# so the `colorhandpose3d` package imported below can be found.
model_path = os.path.abspath(os.pardir)
if sys.path.count(model_path) == 0:
    sys.path.append(model_path)
    
from colorhandpose3d.model.ColorHandPose3D import ColorHandPose3D
from colorhandpose3d.model.HandSegNet import HandSegNet
from colorhandpose3d.model.PoseNet import PoseNet
from colorhandpose3d.model.PosePrior import PosePrior
from colorhandpose3d.model.ViewPoint import ViewPoint
from colorhandpose3d.utils.general import *
from colorhandpose3d.utils.transforms import *

# HandSegNet

Start with the first module: the hand segmentation network.

## Load the weights

In [None]:
# Build the PyTorch HandSegNet and copy in the weights stored in the pickle
# file (variable names follow the 'HandSegNet/<layer>/{weights,biases}' scheme).
handsegnet = HandSegNet()

file_name = '../weights/handsegnet-rhd.pickle'
exclude_var_list = list()

# No TensorFlow session is opened here: the weights are read directly from
# the pickle, and a session that is never used or closed only holds on to
# resources for the lifetime of the kernel.

# Read the stored variables from the pickle file.
# NOTE(review): pickle.load can execute arbitrary code; only load weight
# files from a trusted source.
with open(file_name, 'rb') as fi:
    weight_dict = pickle.load(fi)

# Drop every variable whose name contains one of the excluded substrings.
weight_dict = {k: v for k, v in weight_dict.items()
               if not any(x in k for x in exclude_var_list)}

# Sorted names of the HandSegNet variables, kept around for inspection.
keys = sorted(k for k in weight_dict if 'HandSegNet' in k)

for name, module in handsegnet.named_children():
    key = 'HandSegNet/{0}/'.format(name)
    # Children without a stored bias entry are left untouched.
    if key + 'biases' in weight_dict:
        b = torch.Tensor(weight_dict[key + 'biases'])
        w = torch.Tensor(weight_dict[key + 'weights'])
        # tf conv2d is [kH x kW x inputC x outputC]
        # pytorch conv2d is [outputC x inputC x kH x kW]
        w = w.permute((3, 2, 0, 1))
        module.weight.data = w
        module.bias.data = b

# torch.save(handsegnet.state_dict(), '/home/ajdillhoff/dev/projects/colorhandpose3d-pytorch/saved/handsegnet.pth.tar')

## Load and run an example through the network

In [None]:
# Conversion helpers: tensor -> PIL image, PIL image -> tensor, and resize.
transform0 = torchvision.transforms.ToPILImage()
transform1 = torchvision.transforms.ToTensor()
transform2 = torchvision.transforms.Resize(256)
img = Image.open('../data/RHD_v1-1/RHD_published_v2/training/color/00007.png')

# Add a leading batch dimension, then shift the values by -0.5 before
# feeding them to the network.
sample_original = transform1(img).unsqueeze(0)
sample = sample_original - 0.5

# Call the module itself rather than .forward() so that any registered hooks
# run; this also matches how the other networks in this notebook are invoked.
hand_scoremap = handsegnet(sample)

# Index of the highest-scoring channel along dim 1, stored as float so it can
# be turned into an image. The mask is already float, so no second cast is
# needed before the conversion.
mask = hand_scoremap.argmax(1).to(torch.float)
mask_img = transform0(mask)

## Test the backward pass

# PoseNet

Next, the data moves through PoseNet.

## Load the weights

In [None]:
# Build the PyTorch PoseNet and copy in the weights stored in the pickle
# file (variable names follow the 'PoseNet2D/<layer>/{weights,biases}' scheme).
posenet = PoseNet()

file_name = '../weights/posenet3d-rhd-stb-slr-finetuned.pickle'
exclude_var_list = list()

# No TensorFlow session is opened here: the weights are read directly from
# the pickle, and a session that is never used or closed only holds on to
# resources for the lifetime of the kernel.

# Read the stored variables from the pickle file.
# NOTE(review): pickle.load can execute arbitrary code; only load weight
# files from a trusted source.
with open(file_name, 'rb') as fi:
    weight_dict = pickle.load(fi)

# Drop every variable whose name contains one of the excluded substrings.
weight_dict = {k: v for k, v in weight_dict.items()
               if not any(x in k for x in exclude_var_list)}

# Sorted names of the PoseNet2D variables, kept around for inspection.
keys = sorted(k for k in weight_dict if 'PoseNet2D' in k)

for name, module in posenet.named_children():
    key = 'PoseNet2D/{0}/'.format(name)
    # Children without a stored bias entry are left untouched.
    if key + 'biases' in weight_dict:
        b = torch.Tensor(weight_dict[key + 'biases'])
        w = torch.Tensor(weight_dict[key + 'weights'])
        # tf conv2d is [kH x kW x inputC x outputC]
        # pytorch conv2d is [outputC x inputC x kH x kW]
        w = w.permute((3, 2, 0, 1))
        module.weight.data = w
        module.bias.data = b

# torch.save(posenet.state_dict(), '/home/ajdillhoff/dev/projects/colorhandpose3d-pytorch/saved/posenet.pth.tar')

In [None]:
# Reduce the segmentation scores to a single-object mask and get its centre
# and crop value from the bounding-box helper.
hand_mask = single_obj_scoremap(hand_scoremap)
centers, _, crops = calc_center_bb(hand_mask)

# Scale the crop value by 1.25. Done out of place so the tensor returned by
# calc_center_bb is never modified behind its back.
crops = crops.to(torch.float32) * 1.25

# Keep the zoom factor within [0.25, 5.0]. torch.clamp works elementwise,
# whereas the builtin min()/max() need to turn the tensor into a single bool
# and raise as soon as it holds more than one element.
# NOTE(review): crop_image_from_xy already receives a tensor here whenever
# the value is not clamped; confirm it has no float-only code path.
scale_crop = torch.clamp(256. / crops, min=0.25, max=5.0)
image_crop = crop_image_from_xy(sample, centers, 256, scale_crop)

# 2D keypoint score maps for the cropped image.
keypoints_scoremap = posenet(image_crop)

## Test the backward pass

# PosePrior

The third module lifts the 2D predictions to 3D.

## Load the weights

In [None]:
# Build the PyTorch PosePrior and copy in the weights stored in the pickle
# file (variable names follow the 'PosePrior/<layer>/{weights,biases}' scheme).
poseprior = PosePrior()

file_name = '../weights/posenet3d-rhd-stb-slr-finetuned.pickle'
exclude_var_list = list()

# No TensorFlow session is opened here: the weights are read directly from
# the pickle, and a session that is never used or closed only holds on to
# resources for the lifetime of the kernel.

# Read the stored variables from the pickle file.
# NOTE(review): pickle.load can execute arbitrary code; only load weight
# files from a trusted source.
with open(file_name, 'rb') as fi:
    weight_dict = pickle.load(fi)

# Drop every variable whose name contains one of the excluded substrings.
weight_dict = {k: v for k, v in weight_dict.items()
               if not any(x in k for x in exclude_var_list)}

# Sorted names of the PosePrior variables, kept around for inspection.
keys = sorted(k for k in weight_dict if 'PosePrior' in k)

for name, module in poseprior.named_children():
    key = 'PosePrior/{0}/'.format(name)
    # Children without a stored bias entry are left untouched.
    if key + 'biases' in weight_dict:
        b = torch.Tensor(weight_dict[key + 'biases'])
        w = torch.Tensor(weight_dict[key + 'weights'])

        # tf conv2d is [kH x kW x inputC x outputC]
        # pytorch conv2d is [outputC x inputC x kH x kW]
        # tf fully connected is [inputC x outputC]
        # pytorch fully connected is [outputC x inputC]
        if len(w.shape) == 4:
            w = w.permute((3, 2, 0, 1))
        else:
            w = w.t()
        module.weight.data = w
        module.bias.data = b

# torch.save(poseprior.state_dict(), '/home/ajdillhoff/dev/projects/colorhandpose3d-pytorch/saved/poseprior.pth.tar')

In [None]:
# Second input to PosePrior: a fixed 1x2 indicator tensor.
# NOTE(review): presumably a one-hot hand side with index 1 = right hand,
# judging by how the same literal is used further down; confirm.
hand_side = torch.tensor([[0.0, 1.0]])
coord_can = poseprior(keypoints_scoremap, hand_side)

## Test the backward pass

# ViewPointNet

The last network predicts the viewpoint.

## Load the weights

In [None]:
# Build the PyTorch ViewPoint network and copy in the weights stored in the
# pickle file (variable names follow the
# 'ViewpointNet/<layer>/{weights,biases}' scheme).
viewpoint = ViewPoint()

file_name = '../weights/posenet3d-rhd-stb-slr-finetuned.pickle'
exclude_var_list = list()

# No TensorFlow session is opened here: the weights are read directly from
# the pickle, and a session that is never used or closed only holds on to
# resources for the lifetime of the kernel.

# Read the stored variables from the pickle file.
# NOTE(review): pickle.load can execute arbitrary code; only load weight
# files from a trusted source.
with open(file_name, 'rb') as fi:
    weight_dict = pickle.load(fi)

# Drop every variable whose name contains one of the excluded substrings.
weight_dict = {k: v for k, v in weight_dict.items()
               if not any(x in k for x in exclude_var_list)}

# Sorted names of the ViewpointNet variables, kept around for inspection.
keys = sorted(k for k in weight_dict if 'ViewpointNet' in k)

for name, module in viewpoint.named_children():
    key = 'ViewpointNet/{0}/'.format(name)
    # Children without a stored bias entry are left untouched.
    if key + 'biases' in weight_dict:
        b = torch.Tensor(weight_dict[key + 'biases'])
        w = torch.Tensor(weight_dict[key + 'weights'])

        # tf conv2d is [kH x kW x inputC x outputC]
        # pytorch conv2d is [outputC x inputC x kH x kW]
        # tf fully connected is [inputC x outputC]
        # pytorch fully connected is [outputC x inputC]
        if len(w.shape) == 4:
            w = w.permute((3, 2, 0, 1))
        else:
            w = w.t()
        module.weight.data = w
        module.bias.data = b

# torch.save(viewpoint.state_dict(), '/home/ajdillhoff/dev/projects/colorhandpose3d-pytorch/saved/viewpoint.pth.tar')

In [None]:
# Fixed 1x2 indicator values; a fresh tensor is built for each use.
# NOTE(review): presumably a one-hot hand side where index 1 means "right
# hand", going by the `cond_right` name below; confirm.
hand_side = [[0.0, 1.0]]

# Predict the rotation parameters and turn them into a rotation matrix.
rot_params = viewpoint(keypoints_scoremap, torch.tensor(hand_side))
rot_matrix = get_rotation_matrix(rot_params)

# Per-sample flag (argmax along dim 1 equals 1), tiled to shape (N, 21, 3)
# before it is handed to flip_right_hand.
cond_right = torch.tensor(hand_side).argmax(1) == 1
cond_right_all = cond_right.reshape(-1, 1, 1).repeat(1, 21, 3)

# Apply the conditional flip to the canonical coordinates, then rotate them.
coords_xyz_can_flip = flip_right_hand(coord_can, cond_right_all)
coords_xyz_rel_normed = torch.matmul(coords_xyz_can_flip, rot_matrix)

## Test the backward pass

In [None]:
# Smoke test for the backward pass through the whole pipeline.
# 'mean' is the current name of the averaging reduction; the older
# 'elementwise_mean' spelling is deprecated.
loss_fn = torch.nn.MSELoss(reduction='mean')

# Dummy objective: compare the prediction against half of itself. Both
# arguments derive from the same tensor, so gradients flow through each.
loss = loss_fn(coords_xyz_rel_normed, 0.5 * coords_xyz_rel_normed)
print('loss = {0}'.format(loss.item()))

# Ask autograd to keep .grad on these tensors so it can be inspected after
# backward().
coords_xyz_rel_normed.retain_grad()
loss.retain_grad()
loss.backward()
print(coords_xyz_rel_normed.grad.mean())