In [None]:
%load_ext autoreload
%autoreload 2
import torch
from PIL import Image
from pathlib import Path
import numpy as np

from robopose.config import LOCAL_DATA_DIR
from robopose.datasets.datasets_cfg import make_scene_dataset
from robopose.rendering.bullet_scene_renderer import BulletSceneRenderer
from robopose.visualization.singleview_articulated import make_singleview_prediction_plots
from robopose.visualization.bokeh_utils import save_image_figure

import os
from tqdm import tqdm_notebook as tqdm
from bokeh.plotting import gridplot
from bokeh.io import show, output_notebook; output_notebook()
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [None]:
# Pick the robot model to render.
# Values used with this notebook: 'owi535', 'panda', 'baxter', 'iiwa7'.
# NOTE(review): presumably this should match the robot in the dataset selected below
# (e.g. 'owi535' for the craves.* datasets, 'panda' for dream.panda.*) — confirm.
ROBOT = 'owi535'
renderer = BulletSceneRenderer(ROBOT)

In [None]:
# Pick the dataset; the inference run (result_id) evaluated on it is looked up below.
RESULT_ID_BY_DATASET = {
    'craves.lab.real.test': 'craves-lab-unknownq--1804',
    'dream.panda.real.orb': 'dream-panda-unknownq--1804',
    'craves.youtube': 'craves-youtube-unknownq-focal=1000--1804',
}
dataset = 'craves.youtube'
result_id = RESULT_ID_BY_DATASET[dataset]

scene_ds = make_scene_dataset(dataset)

# NOTE: torch.load unpickles arbitrary Python objects — only load result files you trust.
results_path = LOCAL_DATA_DIR / 'results' / result_id / f'dataset={dataset}' / 'results.pth.tar'
results = torch.load(results_path)

In [None]:
# Generate visualizations for a few randomly chosen images.
# Each grid row shows the input image, then the predicted overlay for every
# prediction stage in `pred_keys` (initialization, then iterations 1 and 10).
type_joints = 'unknown'
dets = 'full_image_detections'
N_EXAMPLES = 5  # number of images (grid rows) to visualize
SEED = 0  # fixed so the figure indices referenced by `save_ids` are reproducible

init_key = f'{dets}/{type_joints}_joints/init'
pred_keys = [init_key]
pred_keys += [f'{dets}/{type_joints}_joints/iteration={K}' for K in (1, 10)]
print(pred_keys)

n_images = len(results['predictions'][init_key])
# Sample without replacement so no image appears in more than one row.
rng = np.random.default_rng(SEED)
pred_ids = rng.choice(n_images, size=min(N_EXAMPLES, n_images), replace=False)
fig_array = []
save_dir = Path(f'unknown_joints_dataset={dataset}')

all_figures = []
def add_figure(fig):
    """Title `fig` with its position in `all_figures` (``fig=N``) and append it.

    The ``N`` shown in the title is the index to use in `save_ids` when saving.
    """
    fig.title.text = f'fig={len(all_figures)}'
    all_figures.append(fig)

for pred_idx in tqdm(pred_ids):
    row = []
    for pred_key in pred_keys:
        if pred_key not in results['predictions']:
            continue
        # Index with a one-element list (rather than a scalar) — presumably so the
        # selection keeps a leading batch dimension of size 1.
        pred = results['predictions'][pred_key][[pred_idx]]
        figures = make_singleview_prediction_plots(scene_ds, renderer, pred)
        if not row:
            # Show the input image once per row, before the first available overlay,
            # even if the init predictions are missing.
            add_figure(figures['input_im'])
            row.append(figures['input_im'])
        add_figure(figures['pred_overlay'])
        row.append(figures['pred_overlay'])
    fig_array.append(row)
show(gridplot(fig_array, sizing_mode='scale_width'))

In [None]:
# Save selected figures; indices are the `fig=N` titles shown in the grid above.
save_ids = [0, 3]
out_dir = Path('images')
# Create the output directory up front; writing into a missing directory would fail
# unless save_image_figure happens to create it itself.
out_dir.mkdir(parents=True, exist_ok=True)
for idx in save_ids:
    save_image_figure(all_figures[idx], str(out_dir / f'examples_{idx}.jpg'))