# ColorHandPose3D Demo

This notebook demonstrates the ColorHandPose3D network described in "Learning to Estimate 3D Hand Pose from Single RGB Images" by Zimmermann et al. The authors' original code is available at [lmb-freiburg/hand3d](https://github.com/lmb-freiburg/hand3d).

In [None]:
import os
import sys
import math

import torch
import torchvision
import torch.nn.functional as F
import pickle
import tensorflow as tf
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from PIL import Image

model_path = os.path.abspath(os.path.join('..'))
if model_path not in sys.path:
    sys.path.append(model_path)
    
from colorhandpose3d.model.ColorHandPose3D import ColorHandPose3D
from colorhandpose3d.utils.general import *
from colorhandpose3d.utils.transforms import *

# The demo currently runs on the CPU even when CUDA is available; set
# FORCE_CPU = False to let it use the GPU when torch detects one.
FORCE_CPU = True
use_cuda = torch.cuda.is_available() and not FORCE_CPU
print(use_cuda)

## Initialize models and load weights

ColorHandPose3D consists of four networks:
- HandSegNet
- PoseNet
- PosePrior
- ViewPointNet

In [None]:
# Directory holding the saved weights that ColorHandPose3D loads.
weight_path = '../saved/'

# Build the full pipeline. NOTE(review): 224 matches the resize used when
# loading samples below; confirm it is the expected input resolution.
chp3d = ColorHandPose3D(weight_path, 224)

# nn.Module.cuda() moves the parameters in place.
if use_cuda:
    chp3d.cuda()

## Load and run sample

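The network was trained on the Rendered Hand Pose (RHD) dataset. The cell below loads one video frame, resizes it to 224×224, and runs it through the full pipeline.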

In [None]:
# transform0 converts tensors back to PIL for display; transform1/transform2
# turn a PIL image into a 224x224 float tensor for the network.
transform0 = torchvision.transforms.ToPILImage()
transform1 = torchvision.transforms.ToTensor()
transform2 = torchvision.transforms.Resize((224, 224))
# img = Image.open('/data/RHD_v1-1/RHD_published_v2/training/color/07689.png') # bad image

# Frames from ../outputs/in_video/ that have been tried with this demo.
# Only one is loaded; change the index to inspect a different frame.
frame_ids = ['0172', '0883', '0194', '0290', '0435', '0764', '0788']
frame_path = '../outputs/in_video/{}.png'.format(frame_ids[-1])

# Open inside a context manager so the file handle is released. convert()
# loads the pixel data before the file closes, and forcing RGB guarantees a
# 3-channel tensor even for PNGs with alpha or a palette.
with Image.open(frame_path) as raw_img:
    img = raw_img.convert('RGB')
print(img)

# Resize, convert to a tensor, and add a leading batch dimension.
sample_original = transform1(transform2(img)).unsqueeze(0)
print(sample_original.shape)

# ToTensor maps 8-bit pixels to [0, 1]; subtracting 0.5 centres them on
# [-0.5, 0.5]. NOTE(review): assumed to match the network's training
# normalisation -- confirm.
sample = sample_original - 0.5

# NOTE(review): presumably a one-hot flag saying which hand is shown;
# confirm the index order against ColorHandPose3D.
hand_side = torch.tensor([[1.0, 0.0]])

# Move the network inputs to the GPU when CUDA is enabled.
if use_cuda:
    sample, hand_side = sample.cuda(), hand_side.cuda()

# Run the full pipeline on the sample and time the forward pass.
import time

# NOTE(review): the first element is left as None; confirm what
# ColorHandPose3D expects in that slot.
inps = [None, sample]

# Inference only: no_grad skips building the autograd graph. The outputs are
# only .detach()ed later, so nothing downstream needs gradients.
with torch.no_grad():
    s = time.perf_counter()
    coords_xyz_rel_normed, keypoint_scoremap, image_crop, centers, scale_crop, hand_mask = chp3d(inps, hand_side)
    if use_cuda:
        # CUDA kernels run asynchronously; wait for them so the timing is real.
        torch.cuda.synchronize()
    e = time.perf_counter()
print('Total forward time: {}'.format(e - s))

# Bring every network output back to host memory for NumPy / matplotlib.
if use_cuda:
    (coords_xyz_rel_normed, keypoint_scoremap, image_crop,
     centers, scale_crop, hand_mask) = [
        tensor.cpu()
        for tensor in (coords_xyz_rel_normed, keypoint_scoremap, image_crop,
                       centers, scale_crop, hand_mask)
    ]

# 3D keypoint coordinates as a NumPy array with size-1 dimensions dropped.
keypoint_coords3d = coords_xyz_rel_normed.detach().numpy().squeeze()

In [None]:
# Display the predicted hand mask (batch dimension removed) as an image.
hand_mask_img = transform0(hand_mask.squeeze(0))
fig, ax1 = plt.subplots()
ax1.imshow(hand_mask_img)
plt.show()
plt.close(fig)

In [None]:
def fig2img(fig):
    """Render a Matplotlib figure into an in-memory PIL image.

    The figure is written with ``fig.savefig`` (tight bounding box, zero
    padding) into a ``BytesIO`` buffer. The buffer is rewound and handed to
    ``Image.open``.

    Args:
        fig: the Matplotlib figure to render.

    Returns:
        A PIL ``Image`` read from the in-memory buffer.
    """
    from io import BytesIO

    buffer = BytesIO()
    fig.savefig(buffer, bbox_inches='tight', pad_inches=0)
    buffer.seek(0)
    return Image.open(buffer)

## Visualizing the output

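Overlay the estimated 2D keypoints on the input image and save the result as `test.png`. The next cell also keeps a commented-out four-panel view (input, crop, score maps, 3D pose) for reference.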

In [None]:
# Find 2D keypoints in the score maps of the first (only) batch item.
scoremap_np = keypoint_scoremap[0].detach().numpy()
keypoint_coords_crop = detect_keypoints(scoremap_np)

# Map crop-space keypoints back to the image using the crop centre/scale.
# NOTE(review): presumably 224 is the full-image size; confirm against
# transform_cropped_coords.
keypoint_coords = transform_cropped_coords(keypoint_coords_crop, centers, scale_crop, 224)

# The resized input (before the -0.5 shift) as a PIL image, for plotting.
img = transform0(sample_original.squeeze())
print(img.size)

# visualize
# fig = plt.figure(1, figsize=(16, 16))
# ax1 = fig.add_subplot(221)
# ax2 = fig.add_subplot(222)
# ax3 = fig.add_subplot(223)
# ax4 = fig.add_subplot(224, projection='3d')
# ax1.imshow(img)
# plot_hand(keypoint_coords, ax1)
# ax2.imshow(transform0(image_crop[0] + 0.5))
# plot_hand(keypoint_coords_crop, ax2)
# ax3.imshow(np.argmax(keypoint_scoremap[0].detach().numpy(), 0))
# plot_hand_3d(keypoint_coords3d, ax4)
# ax4.view_init(azim=-90.0, elev=-90.0)  # aligns the 3d coord with the camera view
# ax4.set_xlim([-5, 5])
# ax4.set_ylim([-5, 5])
# ax4.set_zlim([-5, 5])
# plt.show()
# Separate name for the output resize, so the 224x224 input transform
# `transform2` from the loading cell is not silently replaced.
# Resize((240, 320)) is (height, width).
output_resize = torchvision.transforms.Resize((240, 320))

# Draw the 2D keypoints over the input image on a borderless figure.
fig, ax = plt.subplots()
ax.imshow(img)
plot_hand(keypoint_coords, ax)
ax.axis("off")
# Stretch the axes over the whole canvas so the saved image has no margins
# (this fully determines the layout, so tight_layout() is not needed).
fig.subplots_adjust(left=0, bottom=0, right=1, top=1, hspace=0, wspace=0)

# Rasterise the figure, resize it, and save it.
pil = output_resize(fig2img(fig))
pil.save('test.png')


In [None]:
 # Per-frame export helper. This cell previously held the body of a video-frame
# loop pasted without its `for` header: it used undefined `i` / `transform4`
# and had inconsistent indentation, so it raised on execution. It is now a
# function the loop can call with its frame index and output transform.
def save_overlay_frame(frame_idx, sample_original, keypoint_coords, out_transform,
                       out_dir="../outputs/out_video/ver0_3"):
    """Render the 2D keypoint overlay for one frame and save it as a PNG.

    Args:
        frame_idx: zero-based frame index; the file is named
            ``{frame_idx + 1:04d}.png``.
        sample_original: batched input image tensor with values in [0, 1],
            as produced by the loading cell (ToTensor output).
        keypoint_coords: 2D keypoints as returned by
            ``transform_cropped_coords``.
        out_transform: callable applied to the rendered PIL image before
            saving (the original fragment used ``transform4`` here).
        out_dir: directory the frame is written to; created if missing.

    Returns:
        The path of the written file.
    """
    img = transform0(sample_original.squeeze())
    fig, ax1 = plt.subplots()
    try:
        ax1.imshow(img)
        plot_hand(keypoint_coords, ax1)
        ax1.axis("off")
        pil = out_transform(fig2img(fig))
    finally:
        # Close explicitly: in a per-frame loop, open figures would accumulate.
        plt.close(fig)

    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, "{0:04d}.png".format(frame_idx + 1))
    pil.save(out_path)
    return out_path