# Object Similarity Analysis

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import json
from scipy.cluster.hierarchy import dendrogram, linkage

from tbp.monty.frameworks.utils.logging_utils import load_stats, deserialize_json_chunks
from tbp.monty.frameworks.utils.plot_utils import plot_graph

In [None]:
# Select matplotlib's interactive notebook backend (IPython magic; only valid in Jupyter/IPython).
%matplotlib notebook

In [None]:
# --- Paths -------------------------------------------------------------------
# Root directories, all resolved relative to the user's home directory.
pretrain_path = os.path.expanduser("~/tbp/results/monty/pretrained_models/")
log_path = os.path.expanduser("~/tbp/results/monty/projects/evidence_eval_runs/logs/")
figure_path = os.path.expanduser("~/tbp/results/monty/figures/")

# Experiment whose logs are analysed. Other runs that have been used here:
#   exp_name = "base_config_all_objects_1rot_elm/"
#   exp_name = "sampling_learns5_infs6_all_objects_1rot_elm/"
exp_name = "all_objects_1_rotation_elm/"

# Paths derived from the roots above.
pretrained_dict = pretrain_path + "pretrained_ycb/supervised_pre_training_all_objects/pretrained/"
exp_path = log_path + exp_name
save_path = exp_path + '/figures/'
# Alternative: save_path = os.path.expanduser("~/tbp/results/monty/figures/evidenceLM/stepwise_examples/"+exp_name)

# --- Load logged statistics --------------------------------------------------
# Evaluation and detailed stats are loaded; the first return value is discarded
# (load_train=False, so no training stats are requested).
_, eval_stats, detailed_stats, models = load_stats(
    exp_path,
    load_train=False,
    load_eval=True,
    load_detailed=True,
    pretrained_dict=pretrained_dict,
)

In [None]:
# Display the evaluation statistics returned by load_stats.
eval_stats

In [None]:
# Display which episode keys are present in the detailed stats.
detailed_stats.keys()

In [None]:
# List every statistic recorded for learning module LM_0 in episode '0'.
print('STATS KEYS:')
for stat_name in detailed_stats['0']['LM_0']:
    print(stat_name)

In [None]:
# Grid of one RGBA view (sensor module SM_1) per episode, titled with that
# episode's target object and saved to figure_path/ycb_objects.png.
# NOTE(review): ycb_object_views is only kept in memory here, while the
# "Generalization to New Objects" section loads ycb_object_views.json from
# figure_path. This notebook does not visibly write that file; confirm where it
# comes from.
ycb_object_views = dict()
episodes = list(detailed_stats.keys())
# Subplot slots are addressed by int(episode) + 1, so size the grid from the
# largest episode index. Keep at least the original 8 x 10 layout, but grow
# the grid instead of raising once there are more than 80 episodes.
n_cols = 10
n_slots = max((int(episode) for episode in episodes), default=-1) + 1
n_rows = max(8, -(-n_slots // n_cols))  # ceiling division
plt.figure(figsize=(20,15))
for episode in episodes:
    obj_name = eval_stats['target_object'][int(episode)]
    view = detailed_stats[episode]['SM_1']['rgba']
    ycb_object_views[obj_name] = view
    plt.subplot(n_rows, n_cols, int(episode) + 1)
    plt.imshow(view)
    plt.title(obj_name)
    plt.axis('off')
# savefig does not create missing directories.
os.makedirs(figure_path, exist_ok=True)
plt.savefig(figure_path + 'ycb_objects.png')

In [None]:
# Save one standalone image per target object to figure_path/objects/<name>.png,
# reusing a single 7x7 figure for all of them.
objects_dir = os.path.join(figure_path, 'objects')
os.makedirs(objects_dir, exist_ok=True)  # savefig does not create directories
fig, ax = plt.subplots(figsize=(7,7))
for episode in detailed_stats.keys():
    obj_name = eval_stats['target_object'][int(episode)]
    # Clear the previous object first. Reusing the axes without clearing stacks
    # every imshow artist on top of the last, so transparent RGBA pixels let
    # earlier objects show through, and each savefig redraws all of them.
    ax.clear()
    ax.imshow(detailed_stats[episode]['SM_1']['rgba'])
    ax.set_title(obj_name, fontsize=40)
    ax.axis('off')
    fig.savefig(os.path.join(objects_dir, f'{obj_name}.png'))

In [None]:
# Build rel_obj_evidence_matrix[row, col]:
#   row = int(episode)
#   col = object position in all_objects
#   value = (max evidence of that object) - (max evidence of the object named
#           in the episode's last LM_0 'current_mlh' entry)
# The subtracted object is presumably the final most-likely hypothesis.
# NOTE(review): a square num_obj x num_obj matrix assumes exactly one episode
# per object, numbered 0..num_obj-1.
all_objects = list(detailed_stats['0']['LM_0']['evidences_ls'].keys())
num_obj = len(all_objects)
rel_obj_evidence_matrix = np.zeros((num_obj, num_obj))
for episode, episode_stats in detailed_stats.items():
    lm_stats = episode_stats['LM_0']
    evidences = lm_stats['evidences_ls']
    detected_object = lm_stats['current_mlh'][-1]['graph_id']
    detected_evidence = np.max(evidences[detected_object])
    row = int(episode)
    for object_id, object_name in enumerate(all_objects):
        rel_obj_evidence_matrix[row, object_id] = np.max(evidences[object_name]) - detected_evidence


In [None]:
# Maximum evidence value of each object (in all_objects order) for episode '39'.
episode_evidences = detailed_stats['39']['LM_0']['evidences_ls']
max_evs = [np.max(episode_evidences[obj]) for obj in all_objects]

In [None]:
# Histogram of the per-object maxima at fine (1000-bin) resolution.
hist_ax = plt.figure().add_subplot()
hist_ax.hist(max_evs, bins=1000)
plt.show()

In [None]:
# The same per-object maxima with matplotlib's default binning.
hist_ax = plt.figure().add_subplot()
hist_ax.hist(max_evs)
plt.show()

In [None]:
# Divide each row by its own sum. Rows that sum to exactly zero are divided by
# 1 instead, so they stay zero rather than becoming NaN.
sums = np.sum(rel_obj_evidence_matrix, axis=1, keepdims=True)
sums = np.where(sums == 0, 1, sums)
rel_obj_evidence_matrix_normed = rel_obj_evidence_matrix / sums

In [None]:
# Inspect the row-normalized relative-evidence matrix.
rel_obj_evidence_matrix_normed

In [None]:
# Wrap the normalized matrix in a DataFrame. Columns are the object names;
# rows keep pandas' default integer index.
rel_obj_evidence_df = pd.DataFrame(data=rel_obj_evidence_matrix_normed, columns=all_objects)

In [None]:
# Inspect the DataFrame form of the matrix (columns labelled by object name).
rel_obj_evidence_df

In [None]:
# Heatmap of the row-normalized relative evidence, with every object named on
# both axes.
f, ax = plt.subplots(figsize=(18, 15))
sns.heatmap(rel_obj_evidence_df, ax=ax)
# seaborn draws cell i over [i, i + 1] on each axis, so labels belong at the
# cell centres i + 0.5; integer positions put each label on a cell border.
# The tick count comes from the labels instead of a hard-coded 77 objects.
label_ticks = np.arange(len(all_objects)) + 0.5
ax.set_xticks(label_ticks, all_objects)
# NOTE(review): rows are episodes. Labelling them with all_objects assumes the
# target of episode i is all_objects[i]; confirm the episode order.
ax.set_yticks(label_ticks, all_objects)
plt.yticks(rotation=0)

cbar = ax.collections[0].colorbar
cbar.ax.tick_params(labelsize=15)
# NOTE(review): the matrix is built relative to the object named in the last
# 'current_mlh' entry. That equals the target only when recognition succeeded.
cbar.set_label('Evidence rel. Target (Normalized)', rotation=270,labelpad=20,fontsize=18)
plt.tight_layout()
plt.show()

In [None]:
# Cluster the rows of the normalized matrix with average linkage and draw the
# dendrogram, leaves labelled with the object names.
Z = linkage(rel_obj_evidence_matrix_normed, method='average')
fig, ax = plt.subplots(figsize=(9, 6))
dn = dendrogram(Z, labels=all_objects)
ax.set_ylabel('Cluster Distance', fontsize=15)
sns.despine(left=False, bottom=False, right=True)
fig.tight_layout()
plt.show()

### Generalization to New Objects

In [None]:
# Load the stored per-object views (object name -> image data) and list the
# available names. The context manager closes the file handle; the previous
# bare open() leaked it.
with open(os.path.join(figure_path, "ycb_object_views.json"), "r", encoding="utf-8") as views_file:
    ycb_object_views = json.load(views_file)
print(ycb_object_views.keys())

In [None]:
# Evaluation stats of the generalization run (no detailed stats loaded).
# The trailing slash matches the form of the main experiment path
# (log_path + exp_name, where exp_name ends in "/"), which load_stats is known
# to accept. Later code that appends file names to exp_path then also builds
# valid paths.
exp_path = log_path + "evidence_generalization_maxS100/"
_, eval_stats, _, _ = load_stats(exp_path,
                                        load_train=False,
                                        load_eval=True,
                                        load_detailed=False,
                                       )

In [None]:
# For each evaluated target object, save a side-by-side figure of that object
# and the object reported as most_likely_object, to
# figure_path/new_objects/<target>.png.
new_objects_dir = os.path.join(figure_path, 'new_objects')
os.makedirs(new_objects_dir, exist_ok=True)  # savefig does not create directories
for i, target_obj in enumerate(eval_stats['target_object']):
    most_likely_obj = eval_stats['most_likely_object'][i]
    fig = plt.figure()
    plt.subplot(1,2,1)
    plt.imshow(ycb_object_views[target_obj])
    plt.title(f'New Object:\n{target_obj}')
    plt.axis('off')
    plt.subplot(1,2,2)
    plt.imshow(ycb_object_views[most_likely_obj])
    plt.title(f'Most Likely Object:\n{most_likely_obj}')
    plt.axis('off')
    fig.savefig(os.path.join(new_objects_dir, f'{target_obj}.png'), bbox_inches='tight')
    # Close each figure once it is saved. Otherwise one figure per episode stays
    # open, growing memory and triggering matplotlib's "too many figures" warning.
    plt.close(fig)

## Miscellaneous Analysis

In [None]:
# Load only episode 0 of the detailed stats for the current exp_path.
# exp_path is not guaranteed to end in a separator. os.path.join avoids the
# ".../<run>detailed_run_stats.json" path that plain concatenation would build.
detailed_stats = deserialize_json_chunks(
    json_file=os.path.join(exp_path, 'detailed_run_stats.json'), episodes=[0]
)

In [None]:
# Episode 0: collect each object's name and maximum evidence, and overlay one
# evidence histogram per object in a single figure.
lm_evidences = detailed_stats['0']['LM_0']['evidences_ls']
all_objects = list(lm_evidences.keys())
max_evs = [np.max(evs) for evs in lm_evidences.values()]
plt.figure()
for evs in lm_evidences.values():
    plt.hist(evs)
plt.show()

In [None]:
# Summary statistics of the per-object maximum evidences, one per line.
max_evs_arr = np.array(max_evs)
for stat_label, stat_fn in (("max", np.max), ("mean", np.mean), ("median", np.median), ("std", np.std)):
    print(f"{stat_label}: {stat_fn(max_evs_arr)}")

In [None]:
# Print the graph_id of every 'current_mlh' entry logged by LM_0 in episode 0.
mlh_steps = detailed_stats['0']['LM_0']['current_mlh']
for mlh_step in mlh_steps:
    print(mlh_step['graph_id'])

In [None]:
# Scatter each object's maximum evidence and annotate the point with the
# object's name. The x positions come from the data length instead of a
# hard-coded 77, which raised IndexError or dropped points when the object
# count differed.
plt.figure(figsize=(15,7))
plt.scatter(np.arange(len(max_evs)), max_evs)
for i, (max_ev, obj_name) in enumerate(zip(max_evs, all_objects)):
    plt.text(i, max_ev, obj_name)
plt.show()