# Visualize Predictions of SkeletonDiffusion

Select a trained model and a dataset, and visualize a selected example.

Models trained on AMASS can run on `amass` and `3dpw`, models trained on `amass-mano`, `h36m`, or `freeman` only on their resprective datasets. 

In [1]:
# Pick a checkpoint, and a dataset
checkpoint_path = './trained_models/hmp/amass-mano/diffusion/checkpoints/cvpr_release.pt'
dataset_name = 'amass-mano'
num_samples = 50 # number of generated samples


In [2]:
%env CUDA_VISIBLE_DEVICES=0
import torch
from omegaconf import OmegaConf
OmegaConf.register_new_resolver("eval", eval)


from src.eval_prepare_model import  prepare_model, get_prediction, process_evaluation_pair
from src.eval_utils import prepare_eval_dataset
from src.metrics.ranking import get_closest_and_nfurthest_maxapd
from src.inference_utils import quick_cfg_for_inference_no_hydra

env: CUDA_VISIBLE_DEVICES=0


# Choose a pretrained model and a dataset

In [3]:
cfg = quick_cfg_for_inference_no_hydra(checkpoint_path, dataset_name=dataset_name, num_samples=num_samples)

In [4]:
data_loader, dataset, skeleton = prepare_eval_dataset(cfg, split="valid", drop_last=False, num_workers=0, batch_size=cfg['batch_size'], dataset=None, stats_mode=cfg['stats_mode'])

Constructing AMASSDataset for split  valid


Loading datasets:  ['HumanEva', 'HDM05', 'SFU', 'MoSh'] all
Constructed AMASSDataset for split valid with a total of 605848 samples


In [5]:
model, device, *_ = prepare_model(cfg, skeleton, **cfg)

> GPU 0 ready: Quadro RTX 5000
Loading Autoencoder checkpoint: ./trained_models/hmp/amass-mano/autoencoder/checkpoints/cvpr_release.pt ...


Diffusion is_ddim_sampling:  False
Loading Diffusion checkpoint: ./trained_models/hmp/amass-mano/diffusion/checkpoints/cvpr_release.pt ...


# Generate Predictions for a specific Dataset sample

In [6]:
# Select one segment from the dataset
IDX = 7275
data, target, extra = data_loader.dataset[IDX]
data, target = data.unsqueeze(0).to(device), target.unsqueeze(0).to(device)
pred = get_prediction(data, model, extra=extra, **cfg) # [batch_size, n_samples, seq_length, num_joints, features]
target, pred, _, data = process_evaluation_pair(skeleton, target=target, pred_dict={'pred': pred, 'obs': data})
target, pred, data = target[0], pred[0], data[0]

In [7]:
# get closest and n furthest predictions
# First we pick the prediction that it is closest to the GT, and then 5 predictions that are farstest away from the closest one.
# Determining these most diverse predictions is doing by choosing predictions that maximize the diversity of the ensemble
pred_closest, sorted_preds, SAMPLE_IDXS = get_closest_and_nfurthest_maxapd(pred, target, nsamples=6-1)
sorted_preds =  torch.cat([pred_closest.unsqueeze(0), sorted_preds], dim=0) 
len(sorted_preds)

6

## Create plots

In [8]:
%matplotlib widget
import matplotlib.pyplot as plt
from matplotlib import animation
from src.utils.plot_parallel import create_plot_canvas, get_drawing_funct

plt.close('all')

kpts3d_all = sorted_preds.cpu().numpy()
kpts3d_gt = target.cpu().numpy()
kpts3d_obs = data.cpu().numpy()

#create skeleton plots and draw first frame
fig, axes, plots = create_plot_canvas(kpts3d_obs, plot_titles=SAMPLE_IDXS, figsize=6, skeleton=skeleton)

ModuleNotFoundError: No module named 'ipympl'

In [None]:
# Define animation function
# blit=True re-draws only the parts that have changed.
anim = animation.FuncAnimation(fig, get_drawing_funct(kpts3d_obs, kpts3d_gt, kpts3d_all, plots, axes, skeleton), frames=len(kpts3d_obs) + len(kpts3d_all[0]), interval=30, blit=True)

In [None]:
# Draw all frames
from IPython.display import HTML
HTML(anim.to_html5_video())