In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
%matplotlib notebook

from tbp.monty.frameworks.utils.logging_utils import (load_stats,
                                                         print_overall_stats,
                                                         print_unsupervised_stats)
from tbp.monty.frameworks.utils.plot_utils import (plot_graph,
                                                         show_initial_hypotheses, 
                                                         plot_evidence_at_step)

In [None]:
# Locations of the pretrained models and of the evaluation-run logs (with ~ expanded).
pretrain_path = os.path.expanduser("~/tbp/results/monty/pretrained_models/")
pretrained_dict = os.path.join(
    pretrain_path, "pretrained_ycb_v3/supervised_pre_training_base/pretrained/"
)

log_path = os.path.expanduser("~/tbp/results/monty/projects/evidence_eval_runs/logs/")

# Experiment to inspect. The trailing "/" in exp_name is preserved by
# os.path.join, so exp_path still ends with a separator.
exp_name = "randomrot_rawnoise_10distobj_touch/"
exp_path = os.path.join(log_path, exp_name)

# os.path.join avoids the doubled separator that
# exp_path + '/stepwise_examples/' produced ("..._touch//stepwise_examples/").
save_path = os.path.join(exp_path, 'stepwise_examples/')

# Only the detailed per-step stats are loaded; the cells below read detailed_stats.
train_stats, eval_stats, detailed_stats, lm_models = load_stats(exp_path,
                                                                load_train=False,
                                                                load_eval=False,
                                                                load_detailed=True,
                                                                load_models=False,
                                                                pretrained_dict=pretrained_dict,
                                                                )

In [None]:
# Length of the first row of the depth observation at index 10 of SM_0's
# raw observations, for key '0' of detailed_stats (presumably episode 0).
# NOTE(review): presumably this is the depth-patch width in pixels — confirm.
step10_depth = detailed_stats['0']['SM_0']['raw_observations'][10]['depth']
print(len(step10_depth[0]))

In [None]:
def plot_image_grid(images, ncols=None, cmap='gray'):
    '''Plot a list of images on a grid of subplots.

    Args:
        images: Sequence of array-likes. Arrays of shape (H, W, 1) are
            squeezed before plotting; every other shape is passed to
            ``imshow`` unchanged. ``None`` entries leave their cell blank.
        ncols: Number of grid columns. If falsy, the middle divisor of
            ``len(images)`` is used, so the grid is filled exactly.
        cmap: Colormap forwarded to ``imshow`` (matplotlib ignores it for
            RGB/RGBA images).

    Returns:
        None. The figure is created with ``plt.subplots``. An empty
        ``images`` draws nothing.
    '''
    n_images = len(images)
    if n_images == 0:
        # Nothing to draw; plt.subplots(0, ...) would raise.
        return
    if not ncols:
        factors = [i for i in range(1, n_images + 1) if n_images % i == 0]
        ncols = factors[len(factors) // 2]
    # Ceiling division: just enough rows for every image. The old formula
    # (n // ncols + n % ncols) added the remainder and over-allocated rows.
    nrows = (n_images + ncols - 1) // ncols
    # squeeze=False always returns a 2-D array of Axes, so .flatten() also
    # works for a 1x1 grid, where subplots would otherwise return a bare Axes.
    _, axes = plt.subplots(nrows, ncols, figsize=(3 * ncols, 2 * nrows),
                           squeeze=False)
    for i, ax in enumerate(axes.flatten()):
        # `is None` rather than np.any(): an all-zero (black) image is valid.
        if i >= n_images or images[i] is None:
            # Padding cell or explicit None: hide the empty frame.
            ax.axis('off')
            continue
        img = np.asarray(images[i])
        if img.ndim > 2 and img.shape[2] == 1:
            img = img.squeeze()
        ax.imshow(img, cmap=cmap)

In [None]:
# Hand-picked observation indices whose SM_0 RGBA patches (for key '0' of
# detailed_stats) are shown side by side.
idx_array = [10, 23, 56, 73, 68, 95, 123, 300]

sm0_observations = detailed_stats['0']['SM_0']['raw_observations']
img_list = [np.array(sm0_observations[step]['rgba']) for step in idx_array]

plot_image_grid(img_list, ncols=None, cmap='gray')

In [None]:
# Depth observation at index 1 of SM_0's raw observations, for key '1' of
# detailed_stats, rendered as a 3-D surface.
depth_image = np.array(detailed_stats['1']['SM_0']['raw_observations'][1]['depth'])
# Print a summary instead of dumping the whole array into the cell output.
print(f"depth_image shape={depth_image.shape}, "
      f"min={depth_image.min():.5f}, max={depth_image.max():.5f}")

# Pixel-index grid derived from the image itself instead of assuming 64x64.
# meshgrid's default 'xy' indexing yields arrays of shape (n_rows, n_cols),
# matching depth_image.
n_rows, n_cols = depth_image.shape
X, Y = np.meshgrid(np.arange(n_cols), np.arange(n_rows))

fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
surf = ax.plot_surface(X, Y, depth_image, cmap=cm.coolwarm,
                       linewidth=0, antialiased=False)
ax.set(xlabel="column (px)", ylabel="row (px)", zlabel="depth")
# NOTE(review): hard-coded z-range, presumably tuned to this particular
# observation — adjust (or drop) when plotting other episodes/steps.
# The trailing ';' suppresses the returned-limits repr in the cell output.
ax.set_zlim(0.0245, 0.0270);