In [None]:
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import pprint
from scipy.spatial.transform import Rotation

from tbp.monty.frameworks.utils.logging_utils import (load_stats, 
                                                        check_rotation_accuracy,
                                                        deserialize_json_chunks)
from tbp.monty.frameworks.utils.plot_utils import (
    plot_graph, 
    plot_feature_matching_animation,
    show_one_step,
    plot_evidence_at_step,
    plot_sample_animation,
    PolicyPlot,
    plot_learned_graph,
    plot_hotspots,
plot_rotation_stat_animation,
plot_detection_stat_animation,)

from tbp.monty.frameworks.utils.transform_utils import numpy_to_scipy_quat

In [None]:
# Episode index used by the inspection cells below; several later cells reassign it.
episode_num = 0

In [None]:
# Interactive matplotlib backend for the figures/animations below.
# NOTE(review): the `notebook` backend targets the classic Notebook; in JupyterLab /
# Notebook 7 `%matplotlib widget` (ipympl) is likely required instead — confirm.
%matplotlib notebook

In [None]:
# Root folders (expanded relative to the user's home directory) for pretrained
# models and for experiment logs.
pretrain_path, log_path = map(
    os.path.expanduser,
    ["~/tbp/results/monty/pretrained_models/", "~/tbp/results/monty/projects/"],
)

# Pretraining run whose point-cloud models are visualized in this notebook.
pretrained_dict = f"{pretrain_path}pretrained_ycb_v4/touch_1lm_numenta_lab_obj/pretrained/"

In [None]:
# Experiment whose logged stats are loaded and analysed below.
exp_path = f"{log_path}evidence_eval_runs/logs/base_config_monty_world/"


In [None]:
# # Load some detailed stats
# train_stats, eval_stats, _, lm_models = load_stats(exp_path,
#                                                                 load_train=False,
#                                                                 load_eval=True,
#                                                                 load_detailed=False,
#                                                                 pretrained_dict=pretrained_dict,
#                                                                )

# # Load just a single episode from detailed stats
# det_path = os.path.join(exp_path, "detailed_run_stats.json")
# detailed_stats = deserialize_json_chunks(json_file=det_path, episodes=[episode_num])

In [None]:
# Load eval stats, the per-episode detailed stats, and the pretrained LM models
# (training stats are skipped).
train_stats, eval_stats, detailed_stats, lm_models = load_stats(
    exp_path,
    load_train=False,
    load_eval=True,
    load_detailed=True,
    pretrained_dict=pretrained_dict,
)


In [None]:
# Display the evaluation stats returned by load_stats in the cell above.
eval_stats

In [None]:
# for key in detailed_stats.keys():
#     print(key)

In [None]:
# # All keys for a particular episode
# for key in detailed_stats[str(episode_num)].keys():
#     print(key)

In [None]:
# # Motor system keys
# for key in detailed_stats[str(episode_num)]["motor_system"].keys():
#     print(key)

In [None]:
# print(detailed_stats[str(episode_num)]["motor_system"]["action_details"])

In [None]:
# print(len(detailed_stats[str(episode_num)]["motor_system"]["action_sequence"]))

In [None]:
# print((detailed_stats[str(episode_num)]["motor_system"]["action_sequence"]))

In [None]:
# # SM Keys
# for key in detailed_stats[str(episode_num)]["SM_1"].keys():  # All keys for a particular episode
#     print(key)

In [None]:
# pprint.pprint(detailed_stats[str(episode_num)]["SM_0"]["processed_observations"])

In [None]:
# print(np.shape(detailed_stats[str(episode_num)]["SM_0"]["processed_observations"]))

In [None]:
# print(detailed_stats[str(episode_num)]["LM_0"]["lm_processed_steps"])

In [None]:
# # SM Keys
# for item in detailed_stats[str(episode_num)]["SM_0"]["raw_observations"][100]:  # All keys for a particular episode
#     print(item)

In [None]:
# step_num_temp=100
# print(detailed_stats[str(episode_num)]["SM_0"]["raw_observations"][step_num_temp]["rgba"])
# print(len(detailed_stats[str(episode_num)]["SM_0"]["raw_observations"]))

In [None]:
# # Visualize available keys in LM
# for key in detailed_stats["1"]["LM_0"].keys():
#     print(key)
    
# # Locations : these are the locations of the sensor module *taking into account depth*

In [None]:
# # Visualize available keys
# for key in lm_models["pretrained"][0]["mug"]:
#     print(key)

### Visualize Object-Classification Biases

In [None]:
# Classification results from a separate full-YCB (HSV) run.
# NOTE: this rebinds `eval_stats`, replacing the stats loaded from `exp_path` above.
eval_stats = pd.read_csv(
    os.path.expanduser(
        "~/tbp/results/monty/projects/monty_world/trained_on_full_ycb/"
        "fullycb_trained_config_monty_world_full_hsv.csv"
    )
)

all_objects = ["numenta_mug", "red_mug"]

# Map each target object -> list of objects it was classified as (one entry per
# episode). Targets missing from `all_objects` get their own entry rather than
# raising KeyError.
results_dic = {obj: [] for obj in all_objects}
num_episodes = eval_stats.shape[0]
for target, predicted in zip(
    eval_stats["target_object"], eval_stats["most_likely_object"]
):
    results_dic.setdefault(target, []).append(predicted)


In [None]:
# Overall classification counts, pooled over all target objects.
category_results = eval_stats['most_likely_object'].value_counts()
print(category_results)

# The same counts broken down per target object, so that confusions between
# specific objects (e.g. which target gets mistaken for which) are visible.
per_target_results = (
    eval_stats.groupby('target_object')['most_likely_object'].value_counts()
)
print(per_target_results)


### Analyze Learned Models

In [None]:
def get_lm_model_stats(
    lm_models,
    episode,
    object_id,
    lm_index=0,
):
    """Return how many points an LM's pretrained graph of an object contains.

    The point cloud comes from the LM's *pretrained* model of ``object_id``
    (typically the ground-truth object in the environment), independent of
    whether the LM recognized the object in a given episode. This makes it
    possible to compare exploration against how well different areas of the
    model are known.

    Args:
        lm_models: Model dict as returned by ``load_stats``; indexed as
            ``lm_models["pretrained"][lm_index][object_id]``.
        episode: Not used by this function; kept so existing call sites work.
        object_id: Name of the object whose graph is measured.
        lm_index: Index of the learning module to inspect (default 0).

    Returns:
        int: Number of entries in the graph's ``pos`` attribute.
    """
    learned_model_cloud = lm_models["pretrained"][lm_index][object_id].pos
    return len(learned_model_cloud)

In [None]:

# for episode_iter in range(0,2):
#     num_points = get_lm_model_stats(lm_models, episode=episode_iter,
#                              object_id=detailed_stats[str(episode_iter)]["target"]["target_object"])
#     print("\nObject: " + str(detailed_stats[str(episode_iter)]["target"]["target_object"]))
#     print(num_points)

### Plot learned graphs (no policy)

In [None]:
# # Gather "hot-spot" data

# for current_episode in range(3):
#     # Mask based on successful jumps; in principle don't need to, but these will be more
#     # interesting, because these are the locations we aim for when our pose estimates tend 
#     # to (at least presumably, although not by definition) be good
#     # Need to mask based on what the MLH object was at the time (this information is available
#     # already in the LM data)
#     # --> ?need to transform based on the estimated pose at the time --> no shouldn't
#     # need to because the target location will simply be in the reference frame of the object
#     print(detailed_stats[str(current_episode)]["motor_system"]["action_details"])

In [None]:
# Hand-tuned values for the `view` argument of plot_learned_graph, per object.
view_dic = dict(
    mug=[-45, 45],
    spoon=[-45, 90],
    numenta_mug=[-45, 25],
    red_mug=[180, -45],
)

# Episode to visualize; its ground-truth target object is read from the stats.
episode_num = 1
target_id = detailed_stats[str(episode_num)]["target"]["target_object"]

# Plot just the learned graph
sns.set(font_scale=1.0)
plot_learned_graph(
    detailed_stats,
    lm_models,
    episode=episode_num,
    # Objects without a hand-tuned view fall back to the mug view instead of
    # raising KeyError.
    view=view_dic.get(target_id, view_dic["mug"]),
    object_id=target_id,
    save_fig=True,
)

### Plot hot-spot areas visited by top-down policy

In [None]:
# Hand-tuned values for the `view` argument of plot_hotspots, per object.
view_dic = dict(
    mug=[25, -65],
    spoon=[-45, 90],
    fork=[-45, 110],
)

# Episode whose ground-truth target object selects the model to plot.
episode_num = 0
target_id = detailed_stats[str(episode_num)]["target"]["target_object"]

# Plot the hot-spot locations over the learned graph
sns.set(font_scale=1.0)
plot_hotspots(
    detailed_stats,
    lm_models,
    episode_range=2,
    # Objects without a hand-tuned view (e.g. numenta_mug, red_mug) fall back to
    # the mug view instead of raising KeyError.
    view=view_dic.get(target_id, view_dic["mug"]),
    object_id=target_id,
)

### Plot policy *animation*

In [None]:
# Per-object view angles and zoom levels for the policy animation below.
# NOTE(review): targets other than mug / spoon / golf_ball raise KeyError in the
# view_dic / zoom_dic lookups below.
view_dic = dict(
    mug = [35,25],  # [45,-55]
    spoon = [45,-75],
    golf_ball = [45, 45],
)
zoom_dic = dict(
    mug = 1.35,
    spoon = 1.35,
    golf_ball = 0.75,
)

episode_num = 0

# NOTE(review): `plot_policy_animation` is not imported in the import cell at the
# top of this notebook, so this raises NameError on a fresh kernel — confirm which
# module provides it and add it to the imports.
plot_policy_animation(detailed_stats, 
                                lm_models,  
                                episode = episode_num,
                                step=-1,
                                agent_type="vision",
                                jumps_used=True,
                                object_id=detailed_stats[str(episode_num)]["target"]["target_object"],
                                view=view_dic[detailed_stats[str(episode_num)]["target"]["target_object"]],
                                zoom=zoom_dic[detailed_stats[str(episode_num)]["target"]["target_object"]],
                                extra_vis="sensor_pose",  # sensor_pose or lm_processed
                                #agent_step=40,
                        )

### Plot graph with policy and *agent movements*

In [None]:
"""
Good views for mug basic policy:
[25,-150]
[0,0]
"""

sns.set(font_scale = 1.0)
episode_num=0
# NOTE(review): `temp_step` is not read anywhere visible; the call below hard-codes step=30.
temp_step=-1
# NOTE(review): `plot_policy_across_model` is not imported in the import cell at the
# top of this notebook, so this raises NameError on a fresh kernel — confirm which
# module provides it and add it to the imports.
plot_policy_across_model(detailed_stats, 
                                lm_models,  
                                episode = episode_num,
                                step=30,
                                object_id=detailed_stats[str(episode_num)]["target"]["target_object"],
                                #view=[45,-80],
                                #zoom=1.5,
                                #extra_vis="sensor_pose",
                                #agent_step=40,
                        )

### Camera observations

In [None]:
# Episode whose raw camera observations are collected and animated below.
episode_num = 8

In [None]:
# Collect the RGBA frames of both sensors for one episode so they can be
# animated side by side: SM_1 frames go into `viz_obs` (the zoomed-out
# overview) and SM_0 frames into `patch_obs` (the sensor patch).
episode_sm_stats = detailed_stats[str(episode_num)]
patch_raw_obs = episode_sm_stats["SM_0"]["raw_observations"]
viz_raw_obs = episode_sm_stats["SM_1"]["raw_observations"]

# Frames are paired by step index, so both sensors must have recorded the same
# number of observations; otherwise fail loudly instead of raising IndexError
# or silently truncating one stream.
assert len(patch_raw_obs) == len(viz_raw_obs), "Different number of obs"

viz_obs = np.array([obs["rgba"] for obs in viz_raw_obs])
patch_obs = np.array([obs["rgba"] for obs in patch_raw_obs])

In [None]:
def local_plot_sample_animation(patch_obs, viz_obs, object_id="", resolution=64):
    """Plot a video of sampled observations: overview and sensor patch side by side.

    The animation is saved as ``viewfinder_<object_id>.gif`` (imagemagick
    writer) and then displayed inline as an HTML5 video.

    Args:
        patch_obs: Per-step sensor-patch RGBA frames, indexable by step.
        viz_obs: Per-step overview RGBA frames; its length sets the number of
            animation frames, and each frame gets a blue square drawn over its
            centre to mark where the sensor patch is.
        object_id: Used only to build the output gif filename.
        resolution: Side length in pixels of the overview frames, used to
            place the centre marker. Pass ``None`` to infer it from
            ``viz_obs[0]``.
    """
    from IPython import display
    from matplotlib import animation

    if resolution is None:
        resolution = viz_obs[0].shape[0]

    # Half-width of the centre marker: 5% of the frame, but at least 1 px so
    # that small frames still get a visible marker (int(0.05 * res) is 0 for res < 20).
    pixel_window = max(1, int(resolution * 0.05))
    lower_end = int(resolution / 2) - pixel_window
    upper_end = int(resolution / 2) + pixel_window

    def mark_center(frame):
        """Return a copy of `frame` with an opaque blue square over its centre."""
        marked = frame.copy()
        marked[lower_end:upper_end, lower_end:upper_end] = [0, 0, 255, 255]
        return marked

    fig = plt.figure(figsize=(8, 4))
    try:
        ax1 = fig.add_subplot(1, 2, 1)
        im1 = ax1.imshow(mark_center(viz_obs[0]))
        ax1.set_xticks([])
        ax1.set_yticks([])
        ax1.set_title("Overview (Zoomed out)")

        ax2 = fig.add_subplot(1, 2, 2)
        im2 = ax2.imshow(patch_obs[0])
        ax2.set_title("Sensor View")
        ax2.set_xticks([])
        ax2.set_yticks([])

        def init():
            # No-op init so FuncAnimation does not draw frame 0 twice.
            pass

        def animate(i):
            im1.set_array(mark_center(viz_obs[i]))
            im2.set_array(patch_obs[i])
            return (im1, im2)

        anim = animation.FuncAnimation(
            fig, animate, frames=len(viz_obs), init_func=init
        )
        anim.save(
            "viewfinder_" + object_id + ".gif",
            writer="imagemagick",
            dpi=300,
        )
        video = anim.to_html5_video()
        display.display(display.HTML(video))
    finally:
        # Close this specific figure (not whichever is "current"), even if
        # saving or encoding fails.
        plt.close(fig)
    

In [None]:
# Animate the frames collected above for the episode's ground-truth target object.
local_plot_sample_animation(
    patch_obs,
    viz_obs,
    object_id=detailed_stats[str(episode_num)]["target"]["target_object"],
    resolution=256,
)