In [None]:
!wget https://github.com/iccvsubmission10189/iccvsubmission10189_supplementary/archive/refs/heads/main.zip
!unzip main.zip
!mv iccvsubmission10189_supplementary-main/*.py .

In [None]:
import random
import sys
import ipywidgets as widgets
import time
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import pickle
from model import DSPDH_temporal_future
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
import tqdm
from matplotlib import animation, rc
from IPython.display import HTML
import matplotlib.pyplot as plt
from IPython.display import clear_output, display
import io
from PIL import Image
from pathlib import Path
from utils import to_colormap, human_chains_ixs, chains_ixs, get_chains, get_human_chains, subplot_bones, JOINTS_NAMES, unravel_indices
import cv2
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from matplotlib import gridspec
# Render matplotlib animations inline as interactive JavaScript widgets.
plt.rcParams['animation.html'] = 'jshtml'

# Optional ffmpeg writer for exporting matplotlib animations to video files.
# The visualization below is saved as a GIF through PIL, so ffmpeg is not required:
# fall back to None instead of failing the whole setup cell when it is not installed
# (indexing animation.writers with an unavailable name raises RuntimeError).
if animation.writers.is_available('ffmpeg'):
    Writer = animation.writers['ffmpeg']
    writer = Writer(fps=15, metadata=dict(artist='Me'), bitrate=1800)
else:
    Writer = writer = None

# Seed every random number generator used in this notebook so runs are repeatable.
seed = 1234
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)

# cuDNN: with a GPU, force deterministic kernels and disable autotuning;
# without one, the flags are set to the opposite (fast, non-deterministic) values.
_use_cuda = torch.cuda.is_available()
torch.backends.cudnn.deterministic = _use_cuda
torch.backends.cudnn.benchmark = not _use_cuda
del _seed_fn, _use_cuda

# Select the compute device: the first CUDA GPU when one is visible, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
def plot_skeleton(type, joints, color, figname, viewangle):
    """Render a 3D skeleton off-screen and return the picture as an RGB array.

    Args:
        type: 'human' selects the human bone chains and axis limits; any other
            value (this notebook passes 'robot') selects the robot ones.
            The name shadows the builtin but is kept for backward compatibility.
        joints: 2D array indexed as joints[:, k] with columns 0, 1, 2 taken as
            x, y, z. Column 2 is drawn along matplotlib's y axis and column 1
            along the vertical (z) axis.
        color: matplotlib color used for both the joint markers and the bones.
        figname: unused; kept so existing callers keep working (saving to disk
            is disabled).
        viewangle: azimuth in degrees passed to ``view_init`` (elevation is fixed at 10).

    Returns:
        uint8 array of shape (height, width, 3): the rendered canvas without its
        alpha channel.
    """
    # Axis limits and bone chains depend on which kinematic model is drawn.
    if type == 'human':
        limits = ((-1.2, 1.2), (1.8, 3.7), (-1.1, 1.1))
        chains = get_human_chains(joints, *human_chains_ixs)
    else:
        limits = ((-0.65, 0.65), (1.8, 2.3), (-0.3, 0.5))
        chains = get_chains(joints, *chains_ixs)

    fig = plt.figure()
    try:
        ax = fig.add_subplot(projection='3d')
        ax.set_xlim(*limits[0])
        ax.set_ylim(*limits[1])
        ax.set_zlim(*limits[2])

        ax.scatter3D(joints[:, 0], joints[:, 2], joints[:, 1], c=color, depthshade=True)
        subplot_bones(chains, ax, c=color)

        # Hide tick labels and the background grid. grid(False) works on all
        # Matplotlib versions; the old `grid(b=None)` keyword was deprecated
        # (renamed to `visible`) in Matplotlib 3.5.
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_zticklabels([])
        ax.grid(False)
        ax.view_init(10, viewangle)

        # Rasterize, then copy the RGBA pixels out before the figure is released.
        fig.canvas.draw()
        image = np.array(fig.canvas.buffer_rgba())
    finally:
        # Always release the figure, even if drawing fails, so that calling this
        # once per frame does not accumulate open figures.
        plt.close(fig)
    return image[:, :, :3]


In [None]:
#Download pretrained model and data
!pip install gdown
!mkdir data

import gdown
output = './data/'


gdrive_urls = [
    'https://drive.google.com/file/d/1JdWVNfT-IrAS0GbSeZYZswmI1NS7ur0W/view?usp=share_link',
    'https://drive.google.com/file/d/1sgKQ9QnUaUhHFDD31QHlsbxaPfNynE9j/view?usp=share_link',
    'https://drive.google.com/file/d/1xURG1YpOWCp33283P3WzFySXIACQG1Rv/view?usp=share_link',
    'https://drive.google.com/file/d/1B3jgMlK0MAsrkOQS8qPghQ5FGSNtqyyj/view?usp=share_link',
    'https://drive.google.com/file/d/1BTge-TKqAVHRGqFz92vSWTu6kUzWUaWJ/view?usp=share_link',
    'https://drive.google.com/file/d/1Bjbd60BGhEjb4aYjVtNPXnSkJndjTXkS/view?usp=share_link',
    'https://drive.google.com/file/d/1U4tShglCG8rhC38RJNIAmt-Q56UjLf5P/view?usp=share_link',
    'https://drive.google.com/file/d/15dSkQzutMvtXror-tekFUI41HS6JIWkG/view?usp=share_link'
]

urls = [u.replace('/file/d/','/uc?id=').replace('/view?usp=share_link','') for u in gdrive_urls]

for n, u in enumerate(urls):
    print(f'downloading file {n+1}/{len(urls)}...')
    gdown.download(u + "&confirm=t", output, quiet=False)
gdown.download('https://www.dropbox.com/s/fao4vfui5ta871c/model_pretrained.pth?dl=1', output + '/model_pretrained.pth')

In [None]:
# Load the camera intrinsics and the pretrained network weights.
input_K = torch.Tensor(np.load('data/intrinsic.npy')).to(device)
resume = 'data/model_pretrained.pth'
model = DSPDH_temporal_future(c=32, joints_num=16, deltas=False, future_window_size=4)
# eval(): this notebook only runs inference, so layers whose behaviour differs between
# training and evaluation (e.g. BatchNorm/Dropout, if the model has any) must be in eval mode.
model = model.to(device).eval()
weights_dir = Path(resume)
# NOTE: torch.load unpickles the checkpoint -- only load weights from a trusted source.
checkpoint = torch.load(str(weights_dir), map_location=device)
model.load_state_dict(checkpoint["model"], strict=True)

In [None]:
# Choose the example sequence to visualize. Files available in data/:
#   center_cam_left_arm.pkl   left_cam_left_arm.pkl   right_cam_left_arm.pkl
#   center_cam_right_arm.pkl  left_cam_right_arm.pkl  right_cam_right_arm.pkl
example_file = "data/center_cam_left_arm.pkl"

# Use a context manager so the file handle is closed after loading.
# NOTE: pickle.load can execute arbitrary code -- only load files from a trusted source.
with open(example_file, "rb") as f:
    data = pickle.load(f)

In [None]:
# Network inputs for the chosen example, converted to float tensors, plus the ground truth.
depth_image, prev_joint = (torch.Tensor(data[key]) for key in ('xyz', 'buffer'))
joint_gt = data['joints_3d_gt']  # ground-truth 3D joints (not used by the cells below)

In [None]:
# Run the network on every frame and decode its heatmaps into 3D joint positions.

# UZ heatmaps are quantized in depth: row r corresponds to Z = Z_MIN_MM + r * DZ_MM millimetres.
# The middle value (3380, presumably the far end of the range) is not needed for decoding.
Z_MIN_MM, _Z_MAX_MM, DZ_MM = 500, 3380, 15


def decode_joints_3d(heatmap, inv_K):
    """Decode UV + UZ heatmaps into 3D joint coordinates.

    Args:
        heatmap: tensor of shape (B, C, H, W). The first C // 2 channels are UV
            heatmaps and the last C // 2 are UZ heatmaps (one of each per joint).
        inv_K: inverse camera intrinsics of shape (K, 3, 3), with K broadcastable
            against B (presumably K == B or K == 1).

    Returns:
        Tensor of shape (B, C // 2, 3) with back-projected joint coordinates in
        metres, with the Y axis inverted.
    """
    B, C, H, W = heatmap.shape
    J = C // 2
    # Homogeneous pixel coordinates (u, v, 1) from the arg-max of each UV heatmap.
    joints_xyz = torch.ones((B, J, 3), device=heatmap.device)
    max_uv = heatmap[:, :J].reshape(-1, H * W).argmax(1)
    joints_xyz[..., :2] = unravel_indices(max_uv, (H, W)).view(B, J, -1)
    # Presumably unravel_indices yields (row, col); swap to (u = col, v = row).
    joints_xyz[..., [0, 1]] = joints_xyz[..., [1, 0]]
    # Depth from the row of the arg-max of each UZ heatmap, converted mm -> m.
    max_uz = heatmap[:, J:].reshape(-1, H * W).argmax(1)
    z = unravel_indices(max_uz, (H, W)).view(B, J, -1)[..., 0:1]
    z = ((z * DZ_MM) + Z_MIN_MM) / 1000
    # Back-project: X = z * K^-1 (u, v, 1)^T. Broadcasting over the joint axis gives
    # the same product as repeating inv_K once per joint.
    joints_xyz = (inv_K.unsqueeze(1) @ joints_xyz[..., None]).squeeze(-1)
    joints_xyz *= z
    joints_xyz[..., 1] *= -1  # invert Y axis for left-handed reference frame
    return joints_xyz


# The intrinsics are the same for every frame, so invert them once rather than per iteration.
inv_K = torch.inverse(input_K)

joints_pred_total = []
heatmap_pred_total = []
with torch.no_grad():
    for i in tqdm.tqdm(range(len(depth_image))):
        depth_image_current = depth_image[i].to(device)
        prev_joint_current = prev_joint[i].to(device)
        heatmap_pred, heatmap_pred_fut = model(depth_image_current, prev_joint_current)
        # Rescale network outputs (presumably in [-1, 1]) to [0, 1].
        heatmap_pred = (heatmap_pred + 1) / 2.
        heatmap_pred_fut = (heatmap_pred_fut + 1) / 2.  # future-window heatmaps (not visualized below)

        joints_pred = decode_joints_3d(heatmap_pred, inv_K).cpu().numpy()
        joints_pred_total.append(joints_pred)
        heatmap_pred_total.append(heatmap_pred.cpu().numpy())
joints_pred_total = np.concatenate(joints_pred_total)
heatmap_pred_total = np.concatenate(heatmap_pred_total)

In [None]:
# Turn the predicted heatmaps into color images and render the predicted skeleton per frame.
depth = data['depth']

# The first half of the channels are UV heatmaps and the second half UZ heatmaps
# (one of each per joint); derive the split from the data instead of hard-coding 16.
n_joints = heatmap_pred_total.shape[1] // 2
heatmap_uv_pred = heatmap_pred_total[:, :n_joints]
heatmap_uz_pred = heatmap_pred_total[:, n_joints:]

# Crop of the rendered skeleton canvas; the numbers assume matplotlib's default
# 640x480 figure canvas (rows 80..400, columns 120..520).
SKELETON_CROP = (slice(80, 480 - 80), slice(120, 640 - 120))

heat_uv_total = []
heat_uz_total = []
joints_total = []
for heat_uv, heat_uz, joints_frame in tqdm.tqdm(
        zip(heatmap_uv_pred, heatmap_uz_pred, joints_pred_total), total=len(heatmap_uv_pred)):
    # to_colormap takes a leading batch axis; drop it and move channels last (H, W, C).
    heat_uv_total.append(to_colormap(heat_uv[None, ...])[0].transpose(1, 2, 0))
    heat_uz_total.append(to_colormap(heat_uz[None, ...])[0].transpose(1, 2, 0))
    joints_total.append(plot_skeleton('robot', joints_frame, '#27ae60', '', 280)[SKELETON_CROP])


In [None]:
import numpy as np
from PIL import Image


def pad_frames(frames):
    """Pad a (N, H, W, C) stack with a white border: 64 px top/bottom, 8 px left/right."""
    return np.pad(frames, pad_width=((0, 0), (64, 64), (8, 8), (0, 0)), constant_values=(255, 255))


a = np.array(depth)
b = np.array(heat_uv_total) * 255
c = np.array(heat_uz_total) * 255
# Reverse the channel order of the skeleton renders (RGB <-> BGR), presumably to match
# the channel order of the other panels.
d = np.array(joints_total)[:, :, :, ::-1]

a_pad = pad_frames(a)
b_pad = pad_frames(b)
c_pad = pad_frames(c)

# 2x2 grid: axis 1 is height, so each concatenation below stacks panels vertically into
# a column; the two columns are then joined side by side along axis 2 (width).
left_col = np.concatenate((a_pad, b_pad), axis=1)   # depth above UV heatmap
right_col = np.concatenate((d, c_pad), axis=1)      # skeleton above UZ heatmap
montage = np.concatenate((left_col, right_col), axis=2)

# Clip before the uint8 cast so out-of-range values saturate instead of wrapping around.
imgs = [Image.fromarray(np.uint8(np.clip(frame, 0, 255))) for frame in montage]
# duration is the display time of each frame in milliseconds: 50 ms -> 20 frames per second
imgs[0].save("visualization.gif", save_all=True, append_images=imgs[1:], duration=50, loop=0)

In [None]:
import base64

gif_file = "visualization.gif"
# Embed the GIF as a base64 data URI. A relative <img src> is resolved by the browser,
# not by the kernel, so it breaks whenever the notebook's working directory is not served
# at that URL (e.g. Colab). Embedding also keeps the animation in the saved notebook.
with open(gif_file, "rb") as f:
    gif_b64 = base64.b64encode(f.read()).decode("ascii")
display(HTML(f'<img src="data:image/gif;base64,{gif_b64}" width="750" align="center">'))