In [None]:
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import pprint
from scipy.spatial.transform import Rotation

from tbp.monty.frameworks.utils.logging_utils import (load_stats,
                                                        check_rotation_accuracy,
                                                        deserialize_json_chunks)
from tbp.monty.frameworks.utils.plot_utils import (
    plot_graph, 
    plot_feature_matching_animation,
    show_one_step,
    plot_evidence_at_step,
    plot_sample_animation,
    PolicyPlot,
    plot_learned_graph,
    plot_hotspots,
    plot_sample_animation_multiobj,
    plot_evidence_transitions,
    plot_rotation_stat_animation,
    plot_detection_stat_animation,
)

from tbp.monty.frameworks.utils.transform_utils import numpy_to_scipy_quat
from tbp.monty.frameworks.utils.graph_matching_utils import (
    detect_new_object_exponential,
    detect_new_object_k_steps,
)

In [None]:
# Episode to inspect; detailed_stats is indexed by str(episode_num) in the cells below.
episode_num = 0

In [None]:
# Interactive figure backend (lets the 3D plots be rotated); the notebook switches to
# `%matplotlib inline` further down for the evidence plots.
%matplotlib notebook

In [None]:
# Pretraining run whose point-cloud models are loaded and visualized below.
# Base directories (expanded relative to the user's home directory):
pretrain_path = os.path.expanduser("~/tbp/results/monty/pretrained_models/")
log_path = os.path.expanduser("~/tbp/results/monty/projects/")

# Specific pretrained model directory.
# Alternative: "pretrained_ycb_v4/surf_agent_1lm_10similarobj/pretrained/"
pretrained_dict = os.path.join(
    pretrain_path,
    "pretrained_ycb_v7/surf_agent_1lm_10distinctobj/pretrained/",
)

In [None]:
# Log directory of the evaluation run analyzed in the rest of this notebook
# (passed to load_stats below).
exp_path = log_path + "evidence_eval_runs/logs/base_10multi_distinctobj_dist_agent/"


In [None]:
# # Load some detailed stats
# train_stats, eval_stats, _, lm_models = load_stats(exp_path,
#                                                                 load_train=False,
#                                                                 load_eval=True,
#                                                                 load_detailed=False,
#                                                                 pretrained_dict=pretrained_dict,
#                                                                )

# # Load just a single episode from detailed stats
# det_path = os.path.join(exp_path, "detailed_run_stats.json")
# detailed_stats = deserialize_json_chunks(json_file=det_path, episodes=[episode_num])

In [None]:
# Load eval stats, the full per-episode detailed stats, and the LM models
# (via pretrained_dict); training stats are skipped.
train_stats, eval_stats, detailed_stats, lm_models = load_stats(
    exp_path,
    load_train=False,
    load_eval=True,
    load_detailed=True,
    pretrained_dict=pretrained_dict,
)


In [None]:
# Display the evaluation stats returned by load_stats.
eval_stats

In [None]:
# for key in detailed_stats.keys():
#     print(key)

In [None]:
# # All keys for a particular episode
# for key in detailed_stats[str(episode_num)].keys():
#     print(key)

In [None]:
# # Motor system keys
# for key in detailed_stats[str(episode_num)]["motor_system"].keys():
#     print(key)

In [None]:
# print(detailed_stats[str(episode_num)]["motor_system"]["action_details"])

In [None]:
# print(len(detailed_stats[str(episode_num)]["motor_system"]["action_sequence"]))

In [None]:
# print((detailed_stats[str(episode_num)]["motor_system"]["action_sequence"]))

In [None]:
# # SM Keys
# for key in detailed_stats[str(episode_num)]["SM_1"].keys():  # All keys for a particular episode
#     print(key)

In [None]:
# print(np.shape(detailed_stats[str(episode_num)]["SM_0"]["processed_observations"]))

In [None]:
# print(detailed_stats[str(episode_num)]["LM_0"]["lm_processed_steps"])

In [None]:
# # SM Keys
# for item in detailed_stats[str(episode_num)]["SM_0"]["raw_observations"][100]:  # All keys for a particular episode
#     print(item)

In [None]:
# step_num_temp=100
# print(detailed_stats[str(episode_num)]["SM_0"]["raw_observations"][step_num_temp]["rgba"])
# print(len(detailed_stats[str(episode_num)]["SM_0"]["raw_observations"]))

In [None]:
# # Visualize available keys in LM
# for key in detailed_stats["1"]["LM_0"].keys():
#     print(key)
    
# # Locations : these are the locations of the sensor module *taking into account depth*

In [None]:
# # Visualize available keys
# for key in lm_models["pretrained"][0]["mug"]:
#     print(key)

In [None]:
# def get_lm_model_stats(
#     lm_models,
#     episode,
#     object_id,
#     lm_index=0,
# ):
#     """
#     Get some basic stats about e.g. how many points each LM has in its object graph
#     """

#     lm = "LM_" + str(lm_index)

#     # Use point-cloud model of ground-truth object that is in the environment
#     # This is based on the *LM's model*, but always getting the ground-truth object,
#     # i.e. regardless of whether the LM is successfully recognizing the object or not
#     # Thus we can see if e.g. there is a difference in exploration depending on how
#     # well known areas on the model are
#     learned_model_cloud = lm_models["pretrained"][lm_index][object_id].pos
    
#     return len(learned_model_cloud)

In [None]:

# for episode_iter in range(0,6):
#     num_points = get_lm_model_stats(lm_models, episode=episode_iter,
#                              object_id=detailed_stats[str(episode_iter)]["target"]["target_object"])
#     print("\nObject: " + str(detailed_stats[str(episode_iter)]["target"]["target_object"]))
#     print(num_points)

### Plot learned graphs (no policy)

In [None]:
# # Gather "hot-spot" data

# for current_episode in range(3):
#     # Mask based on successful jumps; in principle don't need to, but these will be more
#     # interesting, because these are the locations we aim for when our pose estimates tend 
#     # to (at least presumably, although not by definition) be good
#     # Need to mask based on what the MLH object was at the time (this information is available
#     # already in the LM data)
#     # --> ?need to transform based on the estimated pose at the time --> no shouldn't
#     # need to because the target location will simply be in the reference frame of the object
#     print(detailed_stats[str(current_episode)]["motor_system"]["action_details"])

In [None]:
def plot_learned_graph_with_bb(
    detailed_stats,
    lm_models,
    episode,
    object_id,
    view=None,
    noise_amount=0.0,
    lm_index=0,
    save_fig=False,
    save_path="./",
    original_corners=None,
    new_corners=None,
    corner_offset=(0, 1.5, 0),
):
    """Plot an LM's learned graph in environment coordinates, with bounding-box corners.

    Unlike plot_graph, the point cloud comes from the graph stored in an LM's memory
    (``lm_models["pretrained"][lm_index][object_id].pos``). It is rotated and
    translated by the episode's primary-target rotation/position, so that it appears
    where the object was in the environment during ``episode``. Optional Gaussian
    noise can be added, e.g. to visualize the effect of noise in the location feature.

    :param detailed_stats: detailed run stats, indexed by ``str(episode)``.
    :param lm_models: loaded LM models, indexed ``["pretrained"][lm_index][object_id]``.
    :param episode: episode whose primary-target pose is used to place the model.
    :param object_id: name of the object model to plot.
    :param view: optional (elevation, azimuth) to initialize the 3D view with.
    :param noise_amount: standard deviation of zero-mean Gaussian noise added to every
        coordinate; 0 (default) adds no noise.
    :param lm_index: index of the LM whose model is plotted.
    :param save_fig: if True, save the figure as ``<save_path>/<episode>.png``.
    :param save_path: directory the figure is saved into (created if missing).
    :param original_corners: optional sequence of two 3D points (plotted as
        min/max, red); ``corner_offset`` is added to them before plotting.
    :param new_corners: optional sequence of two 3D points (plotted as min/max,
        green), plotted unchanged.
    :param corner_offset: displacement added to ``original_corners``; defaults to the
        previously hard-coded (0, 1.5, 0).
    """
    # Model of `object_id` as stored in the LM's memory, regardless of what the LM
    # actually recognized during the episode.
    learned_model_cloud = lm_models["pretrained"][lm_index][object_id].pos

    # Move the learned model into environment coordinates using the episode's
    # primary-target rotation and position.
    target = detailed_stats[str(episode)]["target"]
    object_rot = Rotation.from_quat(
        numpy_to_scipy_quat(target["primary_target_rotation_quat"])
    )
    learned_model_cloud = object_rot.apply(learned_model_cloud) + np.asarray(
        target["primary_target_position"]
    )

    # Optional noise, e.g. to see how significant noise in the sensed locations is.
    if noise_amount > 0:
        learned_model_cloud = learned_model_cloud + np.random.normal(
            0, noise_amount, size=np.shape(learned_model_cloud)
        )

    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(1, 1, 1, projection="3d")

    # Learned graph of the object, placed where it actually is in the environment.
    ax.scatter(
        learned_model_cloud[:, 0],
        learned_model_cloud[:, 1],
        learned_model_cloud[:, 2],
        c=learned_model_cloud[:, 2],
        alpha=0.3,
    )

    # (corners, colour, label suffix) for every corner set that was provided.
    corner_sets = []
    if original_corners is not None and len(original_corners) > 0:
        shifted_corners = np.asarray(original_corners, dtype=float) + np.asarray(
            corner_offset, dtype=float
        )
        corner_sets.append((shifted_corners, "red", "original"))
    if new_corners is not None and len(new_corners) > 0:
        corner_sets.append((np.asarray(new_corners, dtype=float), "green", "transformed"))

    for corners, color, suffix in corner_sets:
        # Only the first two points are plotted; they are labelled min/max.
        for prefix, corner in zip(("min", "max"), corners[:2]):
            ax.scatter(
                corner[0],
                corner[1],
                corner[2],
                c=color,
                alpha=0.8,
                label=f"{prefix}_{suffix}",
            )

    ax.set_aspect("equal")
    if view is not None:
        ax.view_init(view[0], view[1])
    # Only add a legend when labelled artists exist (avoids matplotlib's warning).
    if corner_sets:
        ax.legend()

    if save_fig:
        os.makedirs(save_path, exist_ok=True)
        fig_file = os.path.join(save_path, f"{episode}.png")
        fig.savefig(fig_file, bbox_inches="tight", dpi=300)
        print("figure saved at " + fig_file)

In [None]:
# Candidate (elevation, azimuth) views per object; the call below passes view=None,
# so the default matplotlib view is used.
view_dic = {"mug": [-45, 45], "spoon": [-45, 90]}

# Two bounding-box corners before and after transformation (labelled min/max in the plot).
original_corners = [[-0.093004, -0.010495, -0.041427], [0.093004, 0.010495, 0.041427]]
new_corners = [[0.04471411, 1.5414902, 0.0821906], [-0.04471411, 1.4585098, -0.0821906]]

episode_num = 0  # 0 is mug, 6 is spoon
target_id = detailed_stats[str(episode_num)]["target"]["primary_target_object"]

# Plot just the learned graph with the corners overlaid.
sns.set(font_scale=1.0)
plot_learned_graph_with_bb(
    detailed_stats,
    lm_models,
    episode=episode_num,
    view=None,
    object_id=target_id,
    save_fig=True,
    original_corners=original_corners,
    new_corners=new_corners,
)

### Plot hot-spot areas visited by top-down policy

In [None]:
# Preferred (elevation, azimuth) view per object for the hot-spot plot.
view_dic = dict(
    mug=[25, -65],
    spoon=[-45, 90],
    fork=[-45, 110],
)

episode_num = 0
# Use the primary target, consistent with the other cells analysing this
# multi-object run (the "target_object" key may not be logged here).
target_id = detailed_stats[str(episode_num)]["target"]["primary_target_object"]

# Plot the areas visited by the policy (hot-spots) on the learned graph.
sns.set(font_scale=1.0)
plot_hotspots(
    detailed_stats,
    lm_models,
    episode_range=2,
    view=view_dic[target_id],
    object_id=target_id,
)

### Plot policy *animation*

In [None]:
# Per-object view / zoom presets, kept for the optional PolicyPlot kwargs noted below.
view_dic = dict(
    mug=[35, 25],  # alternative: [45, -55]
    spoon=[45, -75],
    golf_ball=[45, 45],
)
zoom_dic = dict(
    mug=1.35,
    spoon=1.35,
    golf_ball=0.75,
)

episode_num = 0
# Primary target, consistent with the other cells analysing this multi-object run.
target_id = detailed_stats[str(episode_num)]["target"]["primary_target_object"]

policy_plotter = PolicyPlot(
    detailed_stats,
    lm_models,
    episode=episode_num,
    object_id=target_id,
    agent_type="distant",
    jumps_used=True,
    extra_vis="sensor_pose",
    lm_index=0,
)

policy_plotter.plot_animation()
# policy_plotter.visualize_plot()

# Optional PolicyPlot kwargs (currently unused):
#     view=view_dic[target_id],
#     zoom=zoom_dic[target_id]


### Plot graph with policy, and *agent movements*

In [None]:
# Static plot of the policy (and agent movements) across the learned model.
# Good views for the mug with the basic policy: [25, -150] and [0, 0].
#
# NOTE(review): this cell previously called `plot_policy_across_model`, which is
# neither imported nor defined in this notebook (NameError on a fresh kernel).
# It now uses PolicyPlot, which is imported above and used in the animation cell,
# with the same constructor arguments as there. The old call also requested
# step=30; confirm whether PolicyPlot / visualize_plot can restrict the plot to a step.
sns.set(font_scale=1.0)
episode_num = 0
policy_plotter = PolicyPlot(
    detailed_stats,
    lm_models,
    episode=episode_num,
    object_id=detailed_stats[str(episode_num)]["target"]["primary_target_object"],
    agent_type="distant",
    jumps_used=True,
    extra_vis="sensor_pose",
    lm_index=0,
)
policy_plotter.visualize_plot()

### Camera observations

In [None]:
def plot_sample_animation(patch_obs, viz_obs, semantic_obs, primary_target="", save_bool=False):
    """Animate the view-finder and sensor-patch observations of an episode side by side.

    Note: this definition shadows ``plot_sample_animation`` imported from
    ``tbp.monty.frameworks.utils.plot_utils`` at the top of the notebook.

    :param patch_obs: per-step sensor-patch images (right panel).
    :param viz_obs: per-step view-finder images (left panel), marked with a blue
        square at their centre; ``len(viz_obs)`` sets the number of frames.
        Assumes RGBA images, since the marker is written as a 4-channel colour.
    :param semantic_obs: per-step target labels shown in the right panel's title.
    :param primary_target: name of the episode's primary target, shown in the title.
    :param save_bool: if True, also save the animation to "viewfinder_gif.gif".
    """
    from IPython import display
    from matplotlib import animation

    def _mark_center(frame, half_size=3):
        # Copy so the logged observation is untouched, then paint a blue square at
        # the image centre. For a 64x64 frame this covers rows/cols 29:35 (the
        # previously hard-coded marker).
        marked = np.array(frame, copy=True)
        row_c, col_c = marked.shape[0] // 2, marked.shape[1] // 2
        marked[
            row_c - half_size : row_c + half_size,
            col_c - half_size : col_c + half_size,
        ] = [0, 0, 255, 255]
        return marked

    fig = plt.figure(figsize=(8, 4))
    try:
        ax1 = fig.add_subplot(1, 2, 1)
        im1 = ax1.imshow(_mark_center(viz_obs[0]))
        ax1.set_xticks([])
        ax1.set_yticks([])
        ax1.set_title("Overview (Zoomed out)")

        ax2 = fig.add_subplot(1, 2, 2)
        im2 = ax2.imshow(patch_obs[0])
        ax2.set_xticks([])
        ax2.set_yticks([])

        def init():
            # Deliberately empty: without an init_func, FuncAnimation draws frame 0
            # once for initialisation and again as the first frame.
            pass

        def animate(i):
            im1.set_array(_mark_center(viz_obs[i]))
            im2.set_array(patch_obs[i])
            ax2.set_title(
                "primary_target: "
                + str(primary_target)
                + "\nstepwise_target: "
                + str(semantic_obs[i])
            )
            return (im1, im2)

        anim = animation.FuncAnimation(
            fig, animate, frames=len(viz_obs), init_func=init
        )
        display.display(display.HTML(anim.to_html5_video()))

        if save_bool:
            anim.save(
                "viewfinder_gif.gif",
                writer="imagemagick",
                dpi=300,
            )
    finally:
        # Always release this figure, even if video encoding fails.
        plt.close(fig)

In [None]:
# Episode whose camera observations are collected and animated below.
episode_num=0

In [None]:
# print(detailed_stats[str(episode_num)]["SM_0"]["raw_observations"][step_num_temp]["rgba"])
# print(len(detailed_stats[str(episode_num)]["SM_0"]["raw_observations"]))

# Collect per-step RGBA frames for the animation: SM_1's raw observations feed the
# zoomed-out overview panel, SM_0's the sensor-patch panel.
episode_stats = detailed_stats[str(episode_num)]
patch_raw = episode_stats["SM_0"]["raw_observations"]
viz_raw = episode_stats["SM_1"]["raw_observations"]

# The two streams are paired by step index. If their lengths differ, say so and keep
# only the aligned prefix, instead of failing with a bare IndexError.
n_steps = min(len(patch_raw), len(viz_raw))
if len(patch_raw) != len(viz_raw):
    print(
        f"Different number of obs: SM_0 has {len(patch_raw)}, SM_1 has "
        f"{len(viz_raw)}; using the first {n_steps} steps."
    )

viz_obs = np.array([obs["rgba"] for obs in viz_raw[:n_steps]])
patch_obs = np.array([obs["rgba"] for obs in patch_raw[:n_steps]])
semantic_obs = episode_stats["LM_0"]["stepwise_targets_list"]

In [None]:
# Animate the collected frames with the per-step targets and save the result.
plot_sample_animation_multiobj(
    patch_obs,
    viz_obs,
    semantic_obs,
    primary_target=detailed_stats[str(episode_num)]["target"]["primary_target_object"],
    save_bool=True,
)

### Visualize Evidence Across Object Transitions

In [None]:
# Keys logged for LM_0 in this episode.
detailed_stats[str(episode_num)]["LM_0"].keys()

In [None]:
# Stepwise target labels logged by LM_0 for this episode. Only a cell's last
# expression is displayed, so show the step count and the labels together.
stepwise_targets = detailed_stats[str(episode_num)]["LM_0"]["stepwise_targets_list"]
len(stepwise_targets), stepwise_targets

In [None]:
# NOTE(review): step_num is not referenced by any visible cell; kept for manual inspection.
step_num=4

In [None]:
# Shape of the evidences logged by LM_0 for this episode.
np.shape(detailed_stats[str(episode_num)]["LM_0"]["evidences"])

In [None]:
# Plot colour per object name; passed as color_mapping to plot_evidence_transitions.
color_mapping = dict(
    golf_ball="grey",
    dice="black",
    spoon="steelblue",
    strawberry="red",
    banana="yellow",
    bowl="darkviolet",
    potted_meat_can="lightblue",
    mug="sienna",
    c_lego_duplo="hotpink",
    mustard_bottle="gold",
)

# Colour per detection outcome (not referenced by the visible cells).
detection_cmapping = dict(
    true_positive="blue",
    false_positive="red",
    false_negative="grey",
)

In [None]:
# Switch to static inline figures for the evidence / ROC plots below.
%matplotlib inline

In [None]:
# K-absolute method (detect_new_object_k_steps, k=2, no reset at positive jumps).
detection_params_dict = {
    "detection_threshold": -1.0,
    "k": 2,
    "reset_at_positive_jump": False,
}

plot_evidence_transitions(
    episode_num,
    detailed_stats[str(episode_num)]["LM_0"],
    detection_fun=detect_new_object_k_steps,
    detection_params_dict=detection_params_dict,
    primary_target=detailed_stats[str(episode_num)]["target"]["primary_target_object"],
    color_mapping=color_mapping,
)

In [None]:
# Exponential method (detect_new_object_exponential, decay_rate=2, no reset at positive jumps).
detection_params_dict = {
    "detection_threshold": -1.0,
    "decay_rate": 2,
    "reset_at_positive_jump": False,
}
plot_evidence_transitions(
    episode_num,
    detailed_stats[str(episode_num)]["LM_0"],
    detection_fun=detect_new_object_exponential,
    detection_params_dict=detection_params_dict,
    primary_target=detailed_stats[str(episode_num)]["target"]["primary_target_object"],
    color_mapping=color_mapping,
)

In [None]:
# K-absolute method again, now with reset_at_positive_jump=True (reset when
# evidence jumps up).
detection_params_dict = {
    "detection_threshold": -1.0,
    "k": 2,
    "reset_at_positive_jump": True,
}
plot_evidence_transitions(
    episode_num,
    detailed_stats[str(episode_num)]["LM_0"],
    detection_fun=detect_new_object_k_steps,
    detection_params_dict=detection_params_dict,
    primary_target=detailed_stats[str(episode_num)]["target"]["primary_target_object"],
    color_mapping=color_mapping,
)

In [None]:
def evaluate_at_thresholds(
    detection_fun,
    detection_param,
    param_name=None,
    thresholds=None,
    episodes=range(10),
    reset_at_positive_jump=False,
):
    """Sweep detection thresholds and return mean (TPR, FPR) points for an ROC curve.

    For each threshold, plot_evidence_transitions is run on every episode in
    ``episodes`` with a ``detection_params_dict`` built from the threshold and the
    method's tuning parameter. This is the same keyword interface used by the
    single-episode cells above. The per-episode (tpr, fpr) values it returns are
    averaged, and a threshold contributes a point only if at least one episode
    yielded a TPR and at least one yielded an FPR.

    Relies on the notebook globals ``detailed_stats`` and ``color_mapping``.

    :param detection_fun: detection function, e.g. detect_new_object_exponential or
        detect_new_object_k_steps.
    :param detection_param: value of the method's tuning parameter.
    :param param_name: key used for ``detection_param`` in ``detection_params_dict``.
        If None it is inferred: "decay_rate" for detect_new_object_exponential and
        "k" for detect_new_object_k_steps.
    :param thresholds: detection thresholds to try; defaults to 50 values in [-5, 0].
    :param episodes: episode indices to average over; defaults to range(10).
    :param reset_at_positive_jump: forwarded in ``detection_params_dict``.
    :return: (tpr_results, fpr_results), two lists of equal length.
    :raises ValueError: if ``param_name`` is None and cannot be inferred.
    """
    if param_name is None:
        known_param_names = {
            detect_new_object_exponential: "decay_rate",
            detect_new_object_k_steps: "k",
        }
        if detection_fun not in known_param_names:
            raise ValueError(
                f"Cannot infer the parameter name for {detection_fun!r}; "
                "pass param_name explicitly."
            )
        param_name = known_param_names[detection_fun]

    if thresholds is None:
        thresholds = np.linspace(-5.0, 0, 50)

    tpr_results = []
    fpr_results = []
    for threshold in thresholds:
        detection_params_dict = {
            "detection_threshold": threshold,
            param_name: detection_param,
            "reset_at_positive_jump": reset_at_positive_jump,
        }
        episode_tprs = []
        episode_fprs = []
        for episode in episodes:
            # Assumes plot_evidence_transitions returns (tpr, fpr), either of which
            # may be None — TODO confirm against its current implementation.
            tpr, fpr = plot_evidence_transitions(
                episode,
                detailed_stats[str(episode)]["LM_0"],
                detection_fun=detection_fun,
                detection_params_dict=detection_params_dict,
                primary_target=detailed_stats[str(episode)]["target"]["primary_target_object"],
                color_mapping=color_mapping,
            )
            if tpr is not None:
                episode_tprs.append(tpr)
            if fpr is not None:
                episode_fprs.append(fpr)
        # A sweep makes len(thresholds) * len(episodes) calls; close any figures they
        # created so they do not pile up in memory or in the cell output.
        plt.close("all")

        # Average over episodes only when there is meaningful performance to log.
        if episode_tprs and episode_fprs:
            tpr_results.append(np.mean(episode_tprs))
            fpr_results.append(np.mean(episode_fprs))
    return tpr_results, fpr_results

In [None]:
# Parameter values to sweep for each new-object detection method.
# decay_rate: 0 means no decay; 30 corresponds to only looking at the most recent step.
decay_rate_vals = [0, 0.1, 0.25, 0.5, 2.0, 30]
k_vals = [1, 2, 4, 6, 8, 10]
# For the reset / normalized k variants, k == 1 is equivalent to the plain version,
# so their sweeps start at 2.
k_reset_vals = [2, 4, 6, 8, 10, 20]
k_norm_vals = [2, 4, 6, 8, 10, 20]

# ROC points per parameter value, keyed "<method>_<value>_tpr" / "<method>_<value>_fpr".
param_results = {}
for decay_rate in decay_rate_vals:
    tpr_results, fpr_results = evaluate_at_thresholds(
        detection_fun=detect_new_object_exponential,
        detection_param=decay_rate,
    )
    param_results[f"decay_{decay_rate}_tpr"] = tpr_results
    param_results[f"decay_{decay_rate}_fpr"] = fpr_results

# Disabled: sweep over k for the k-steps method.
# for k in k_vals:
#     tpr_results, fpr_results = evaluate_at_thresholds(
#         detection_fun=detect_new_object_k_steps, detection_param=k
#     )
#     param_results[f"k_{k}_tpr"] = tpr_results
#     param_results[f"k_{k}_fpr"] = fpr_results
#
# The k_reset / k_norm sweeps used k_abs_with_reset_detect_new_object and
# k_norm_with_reset_detect_new_object, which are not imported in this notebook.


In [None]:
# ROC-style curves for a subset of the swept decay rates. A separate name is used so
# the sweep's `decay_rate_vals` is not overwritten by this plotting selection.
decay_rates_to_plot = [0.5]

fig, ax = plt.subplots()
for decay_rate in decay_rates_to_plot:
    ax.plot(
        param_results[f"decay_{decay_rate}_fpr"],
        param_results[f"decay_{decay_rate}_tpr"],
        label=f"Decay {decay_rate}",
        alpha=0.5,
    )

# Disabled: curves for the k sweep (enable together with the k sweep above).
# for k in k_vals:
#     ax.plot(param_results[f"k_{k}_fpr"], param_results[f"k_{k}_tpr"],
#             label=f"k-val {k}", linestyle="dotted", alpha=0.5)

# Diagonal = chance-level detector.
ax.plot([0, 1.0], [0, 1.0], color="grey", linestyle="--", alpha=0.5, label="random classifier")
ax.set(
    title="New-Object Detection",
    xlabel="FPR",
    ylabel="TPR",
    xlim=(0, 1.0),
    ylim=(0, 1.0),
)
ax.legend(loc="lower right");

In [None]:
# # Checking the equivalence of decay rate and k at appropriate parameters
# decay_rate=30
# plt.plot(param_results[str(decay_rate) + "_fpr"], param_results[str(decay_rate) + "_tpr"], label="Decay " + str(decay_rate), alpha=0.5)

# k=1
# plt.plot(param_results[str(k) + "_fpr"], param_results[str(k) + "_tpr"], label="k-val " + str(k), linestyle="dotted", alpha=0.5)
    
# plt.plot([0, 1.0], [0, 1.0], color="grey", linestyle="--", alpha=0.5, label="random classifier")
# plt.title("New-Object Detection")
# plt.ylabel("TPR")
# plt.xlabel("FPR")
# plt.xlim(0,1.0)
# plt.ylim(0,1.0)
# plt.legend(loc="lower right")
# #plt.plot(thresholds_to_try, thresholds_to_try, color="grey")

In [None]:
# Save evidence-transition plots for all episodes, using the exponential detector
# with decay_rate=30 (only the most recent step). Uses the same
# detection_params_dict keyword interface as the single-episode cells above.
final_detection_params = dict(
    detection_threshold=-1.0,
    decay_rate=30,
    reset_at_positive_jump=False,
)
for episode_num in range(10):
    plot_evidence_transitions(
        episode_num,
        detailed_stats[str(episode_num)]["LM_0"],
        detection_fun=detect_new_object_exponential,
        detection_params_dict=final_detection_params,
        primary_target=detailed_stats[str(episode_num)]["target"]["primary_target_object"],
        color_mapping=color_mapping,
        save_fig_path="./",  # NOTE(review): confirm this kwarg is still accepted
    )