# Omniglot Tests

In [None]:
import numpy as np
import os
import random
from sys import platform as sys_pf
import matplotlib.pyplot as plt

In [None]:
%matplotlib notebook

### Basic Testing, Visualization & Debugging of the Dataset

In [None]:
def plot_motor_to_image(I, drawing, lw=2):
    """Overlay a character's stroke trajectories on its image.

    Parameters
    ----------
    I : array
        Image shown with ``plt.imshow`` in grayscale.
    drawing : list of np.ndarray
        One array per stroke (as returned by ``load_motor``). Only the first
        two columns (x, y) are used; the third (timing) column is dropped.
    lw : float
        Line width of the plotted strokes.

    The caller's stroke arrays are left untouched: each stroke's x/y columns
    are copied before being flipped into image space.
    """
    # Copy the x/y slice: space_motor_to_img flips column 1 in place, and the
    # slice d[:, 0:2] is a view that would write through to the caller's data.
    strokes = [space_motor_to_img(d[:, 0:2].copy()) for d in drawing]
    plt.imshow(I, cmap='gray')
    for sid, stroke in enumerate(strokes):  # one color per stroke index
        plot_traj(stroke, get_color(sid), lw)
    plt.xticks([])
    plt.yticks([])

def plot_traj(stk, color, lw):
    """Draw a single stroke on the current axes.

    stk   : array whose column 0 is x and column 1 is y, one row per point.
    color : matplotlib color used for the stroke.
    lw    : line width.

    A stroke with one point is drawn as a '.' marker instead of a line.
    """
    if stk.shape[0] <= 1:
        # There is no segment to draw for a single point, so mark it instead.
        plt.plot(stk[0, 0], stk[0, 1], color=color, linewidth=lw, marker='.')
        return
    plt.plot(stk[:, 0], stk[:, 1], color=color, linewidth=lw)

def get_color(k):
    """Return the matplotlib color code for the stroke with index ``k``.

    Strokes 0-4 get 'r', 'g', 'b', 'm', 'c' in order; any later stroke
    reuses the last color, 'c'.
    """
    palette = ('r', 'g', 'b', 'm', 'c')
    if k < len(palette):
        return palette[k]
    return palette[-1]

def num2str(idx):
    """Return ``str(idx)``, prefixed with '0' when ``idx`` is less than 10.

    Used to build two-digit names such as 'character02' or '<base>_07.png'.
    """
    text = str(idx)
    return '0' + text if idx < 10 else text

def load_img(fn):
    """Read the character image file ``fn`` and return it as a boolean array.

    Pixels with a nonzero value become True and zero-valued pixels False.
    """
    return np.array(plt.imread(fn), dtype=bool)

def load_motor(fn):
    """Parse a stroke file into a list of per-stroke point arrays.

    Each line of ``fn`` is stripped and interpreted as:
      * 'START' - begins a character (resets the current stroke),
      * 'BREAK' - ends the current stroke,
      * anything else - a comma-separated row of numbers (x, y, time)
        appended to the current stroke.

    Returns a list with one 2-D float array per stroke (one row per point).
    Blank lines are ignored. A final stroke that is not closed by 'BREAK'
    is still returned.
    """
    motor = []
    # Initialized here so a file that does not begin with 'START' cannot
    # hit a NameError on the first data row.
    stk = []
    with open(fn, 'r') as fid:
        lines = [l.strip() for l in fid]
    for myline in lines:
        if not myline:
            # A blank line would append an empty row and make np.array(stk)
            # ragged, so skip it.
            continue
        if myline == 'START':  # beginning of character
            stk = []
        elif myline == 'BREAK':  # break between strokes
            motor.append(np.array(stk))
            stk = []
        else:
            stk.append(np.fromstring(myline, dtype=float, sep=','))
    if stk:
        # Keep a trailing stroke that the file did not close with 'BREAK'.
        motor.append(np.array(stk))
    return motor

def space_motor_to_img(pt):
    """Convert points from motor space to image space by negating column 1 (y).

    The flip is applied in place to ``pt``, and the same array is returned.
    """
    y = pt[:, 1]
    y[...] = -y
    return pt
def space_img_to_motor(pt):
    """Convert points from image space back to motor space by negating column 1 (y).

    This is the inverse of ``space_motor_to_img``. The flip is applied in
    place to ``pt``, and the same array is returned so that both conversions
    can be used the same way in expressions.
    """
    pt[:, 1] = -pt[:, 1]
    return pt

In [None]:
omniglot_path = os.path.expanduser('~/tbp/data/omniglot/python/')
img_dir = os.path.join(omniglot_path, 'images_background')
stroke_dir = os.path.join(omniglot_path, 'strokes_background')
nreps = 20 # number of renditions for each character

# Alphabet folder names, skipping hidden entries such as '.DS_Store'.
# os.listdir returns entries in arbitrary order, so sort them; otherwise the
# alphabet picked by index (a_id) could differ from machine to machine.
alphabet_names = sorted(a for a in os.listdir(img_dir) if not a.startswith('.'))
print(alphabet_names)

In [None]:
a_id = 12
character_id = 2
alpha_name = alphabet_names[a_id]

img_char_dir = os.path.join(img_dir,alpha_name,'character'+num2str(character_id))
stroke_char_dir = os.path.join(stroke_dir,alpha_name,'character'+num2str(character_id))

# get base file name for this character (image files are '<base>_<rendition>.png').
# Hidden files such as '.DS_Store' are skipped; os.listdir could otherwise
# return one of them first and yield a bogus base name.
fn_example = sorted(f for f in os.listdir(img_char_dir) if not f.startswith('.'))[0]
fn_base = fn_example[:fn_example.find('_')]

In [None]:
# Plot all renditions of the selected character, each with its strokes overlaid.
# NOTE: the 4x5 subplot grid assumes nreps == 20.
# fn_stk and I are left set to the last rendition and are reused by the cells below.
plt.figure(figsize=(10,8))
for r in range(1,nreps+1): # for each rendition
    plt.subplot(4,5,r)
    fn_stk = stroke_char_dir + '/' + fn_base + '_' + num2str(r) + '.txt'
    fn_img = img_char_dir + '/' + fn_base + '_' + num2str(r) + '.png'
    motor = load_motor(fn_stk)
    I = load_img(fn_img)
    plot_motor_to_image(I,motor)
    if r==1:
        # Only the first subplot carries the (truncated) alphabet name and character id.
        plt.title(alpha_name[:15] + '\n character ' + str(character_id))
plt.tight_layout()
plt.show()

In [None]:
# Collect the x/y points of every stroke of the last rendition (fn_stk) into
# one (num_points, 2) array in image coordinates.
motor = load_motor(fn_stk)
motor = [d[:,0:2] for d in motor] # strip off the timing data (third column)
motor = [space_motor_to_img(d) for d in motor] # convert to image space
# One concatenate instead of growing the array with vstack per stroke (which
# re-copies everything each time); an empty drawing still yields a (0, 2) array.
locations = np.concatenate(motor, axis=0) if motor else np.zeros((0, 2))

In [None]:
# Show the last rendition's image (default colormap, since no cmap is given)
# with all stroke points drawn in red on top.
plt.figure()
plt.imshow(I)
plt.scatter(locations[:,0],locations[:,1], s=1, c='r')
plt.show()

In [None]:
def get_image_patch(img, loc, patch_size):
    """Return the square patch of ``img`` centered on ``loc``.

    Parameters
    ----------
    img : np.ndarray
        Image indexed as ``img[row, col]``; any extra trailing axes (e.g.
        channels) are carried along unchanged.
    loc : sequence of 2 numbers
        ``(x, y)`` point: ``x`` selects columns and ``y`` selects rows.
        Values are truncated to int.
    patch_size : int
        Patch side length. Rows/columns run from ``loc - patch_size//2``
        (inclusive) to ``loc + patch_size//2`` (exclusive), so an odd size
        yields ``patch_size - 1`` pixels per side.

    When the window extends past the image border, the image is padded by
    repeating its edge pixels, so the patch keeps the same shape. Plain
    slicing would wrap negative start indices around and return an empty or
    truncated patch.
    """
    loc = np.array(loc, dtype=int)
    half = patch_size // 2
    row_start, row_stop = loc[1] - half, loc[1] + half
    col_start, col_stop = loc[0] - half, loc[0] + half

    n_rows, n_cols = img.shape[:2]
    pad_top = max(0, -row_start)
    pad_bottom = max(0, row_stop - n_rows)
    pad_left = max(0, -col_start)
    pad_right = max(0, col_stop - n_cols)
    if pad_top or pad_bottom or pad_left or pad_right:
        pad_width = [(pad_top, pad_bottom), (pad_left, pad_right)]
        pad_width += [(0, 0)] * (img.ndim - 2)
        img = np.pad(img, pad_width, mode='edge')
        # Shift the window into the padded image's coordinates.
        row_start += pad_top
        row_stop += pad_top
        col_start += pad_left
        col_stop += pad_left
    return img[row_start:row_stop, col_start:col_stop]

In [None]:
# Cut a patch of size 10 around stroke point i, then show where it was taken
# (green dot among the red stroke points) next to the patch itself.
# Later cells read `i` and `patch`.
i = 1
patch = get_image_patch(I, locations[i], 10)

plt.figure()
plt.subplot(1,2,1)
plt.imshow(I)
plt.scatter(locations[:,0],locations[:,1], s=1, c='r')
plt.scatter(locations[i,0],locations[i,1], s=10, c='g')
plt.axis('off')
plt.title('Location of Patch (green)')
plt.subplot(1,2,2)
plt.imshow(patch)
plt.axis('off')
plt.title('Patch Observation')
plt.show()

### Testing Data Formatting For Monty Compatibility

In [None]:
from tbp.monty.frameworks.environment_utils.transforms import DepthTo3DLocations
from tbp.monty.frameworks.environment_utils.graph_utils import get_point_normal,get_curvature_at_point
import quaternion as qt
from scipy.ndimage import gaussian_filter

In [None]:
# Build a synthetic observation and agent state for one patch, in the nested
# dict layout passed to the DepthTo3DLocations transform below.
agent_id = 'agent_01'
sensor_id = 'patch_01'
# Fake depth: ~patch is 1 where the image pixel is False (presumably the dark
# stroke pixels - confirm). After blurring, those pixels get a depth below
# the background value of 1.2.
depth = 1.2 - gaussian_filter(np.array(~patch,dtype=float), sigma=0.5)
# 'semantic' marks the same ~patch pixels with 1 and all others with 0.
obs = {agent_id:{sensor_id:{"depth":depth, 
                            "semantic":np.array(~patch,dtype=int)}}}
# Rotation vector [pi/2, 0, 0] (pi/2 about the x-axis); used for both the
# sensor and the agent.
rotation = qt.from_rotation_vector([np.pi / 2, 0.0, 0.0])
loc = locations[i]
# Sensor placed at the patch's (x, y) image location, at z = 0.
sensor_position = np.array([loc[0],loc[1],0])
state = {agent_id:{"sensors":{sensor_id + ".depth":{"rotation":rotation, 
                                                    "position":sensor_position}},
                  "rotation":rotation, "position":np.array([0,0,0])}}

In [None]:
# Blurred on-object mask; the plot below shows 1.2 - d2 as the depth image.
d2 = gaussian_filter(np.array(~patch,dtype=float), sigma=0.5)

In [None]:
# Side by side: the synthetic depth (1.2 - d2, the same values as obs' depth
# while `patch` is unchanged) and the on-object 'semantic' mask.
plt.figure()
plt.subplot(1,2,1)
plt.imshow(1.2-d2)
plt.title("depth")
plt.axis('off')
plt.colorbar()
plt.subplot(1,2,2)
plt.imshow(obs[agent_id][sensor_id]['semantic'])
plt.colorbar()
plt.title("on object")
plt.axis('off')
plt.show()

#### Apply Habitat Transform

In [None]:
# Transform that turns the depth patch into 3D points; its output is read
# below as obs[agent_id][sensor_id]['semantic_3d'].
# NOTE(review): resolutions=[[10,10]] has to match the patch size passed to
# get_image_patch (10). depth_clip_sensors/clip_value presumably clip depth
# above 1.1 for sensor index 0, i.e. the 1.2 background - confirm against
# DepthTo3DLocations.
transform = DepthTo3DLocations(
                agent_id=agent_id,
                sensor_ids=[sensor_id],
                resolutions=[[10,10]],
                world_coord=True,
                zooms=1,
                get_all_points=True,
                use_semantic_sensor=True,
                depth_clip_sensors=(0,),
                clip_value=1.1,
            )
    

In [None]:
# Apply the transform. obs is rebound to its output, so re-running this cell
# without first re-running the observation cell feeds in an already-transformed obs.
obs = transform(obs, state=state)

In [None]:
# Estimate the surface normal and principal curvature directions at the patch
# center from the transformed 3D points, then plot them.
obs_3d = obs[agent_id][sensor_id]['semantic_3d']
# Keep rows whose 4th column equals 1 (presumably the on-object flag - confirm).
locs = obs_3d[obs_3d[:,3] == 1]

obs_dim = int(np.sqrt(obs_3d.shape[0]))  # assumes a square patch
half_obs_dim = obs_dim // 2
center_id = half_obs_dim + obs_dim * half_obs_dim  # flat row index of the center pixel

point_normal = get_point_normal(
                obs_3d, center_id, sensor_location=sensor_position
            )
k1, k2, dir1, dir2 = get_curvature_at_point(obs_3d, center_id, point_normal)

center_loc = obs_3d[center_id,:3]

plt.figure()
ax = plt.subplot(1,1,1,projection='3d')
ax.scatter(locs[:,0], locs[:,1], locs[:,2],c='black')
ax.scatter(center_loc[0], center_loc[1], center_loc[2],s=50,c='green')
linelen = 0.8
colors = ['green','red','orange']
# Normal in green, curvature directions in red and orange. The loop does not
# use `i` as its variable, so the patch index `i` chosen earlier stays intact.
for color, line in zip(colors, [point_normal, dir1, dir2]):
    ax.plot([center_loc[0], center_loc[0] + line[0]*linelen],
            [center_loc[1], center_loc[1] + line[1]*linelen],
            [center_loc[2], center_loc[2] + line[2]*linelen],
            color=color)
ax.set_aspect('equal')
plt.show()

#### Test Formatting on All Points

In [None]:
# Run the patch -> transform -> curvature pipeline at every stroke point and
# plot each on-object center (colored by k1) with its normal and curvature
# directions. Relies on transform, center_id, agent_id and sensor_id from the
# cells above.
# we know the point normal faces the camera on the 2D image here
point_normal = np.array([0,0,-1])
colors = ['green','red','orange']
linelen = 10
rotation = qt.from_rotation_vector([np.pi / 2, 0.0, 0.0])  # identical for every patch
center_locs = []  # patch centers that lie on the object
center_k1 = []    # first principal curvature at each of those centers
plt.figure()
ax = plt.subplot(1,1,1,projection='3d')
for loc in locations:
    patch = get_image_patch(I, loc, 10)
    depth = 1.2 - gaussian_filter(np.array(~patch,dtype=float), sigma=0.5)
    obs = {agent_id:{sensor_id:{"depth":depth,
                                "semantic":np.array(~patch,dtype=int)}}}
    sensor_position = np.array([loc[0],loc[1],0])
    state = {agent_id:{"sensors":{sensor_id + ".depth":{"rotation":rotation,
                                                        "position":sensor_position}},
                      "rotation":rotation, "position":np.array([0,0,0])}}
    obs = transform(obs, state=state)
    obs_3d = obs[agent_id][sensor_id]['semantic_3d']
    # If the center of the patch is on the object, get curvature and plot
    if obs_3d[center_id, 3] > 0:
        k1, k2, dir1, dir2 = get_curvature_at_point(obs_3d, center_id, point_normal)
        center_loc = obs_3d[center_id,:3]
        center_locs.append(center_loc)
        center_k1.append(k1)
        for color, line in zip(colors, [point_normal, dir1, dir2]):
            ax.plot([center_loc[0], center_loc[0] + line[0]*linelen],
                    [center_loc[1], center_loc[1] + line[1]*linelen],
                    [center_loc[2], center_loc[2] + line[2]*linelen],
                    color=color)
# Scatter all centers in a single call so the k1 values share one color scale.
# A separate scatter per point would normalize each lone k1 value on its own
# and map every point to the same color.
if center_locs:
    center_locs = np.array(center_locs)
    sc = ax.scatter(center_locs[:,0], center_locs[:,1], center_locs[:,2], c=center_k1)
    plt.colorbar(sc, ax=ax, label='k1')
ax.set_aspect('equal')
plt.show()

# Load Experiment Data & Visualize

In [None]:
from tbp.monty.frameworks.utils.logging_utils import load_stats
from tbp.monty.frameworks.utils.plot_utils import (show_initial_hypotheses, 
                                                         plot_evidence_at_step,
                                                        plot_graph)

In [None]:
# Load eval stats, detailed stats and the pretrained models for the
# evidence-on-Omniglot experiment.
pretrain_path = os.path.expanduser("~/tbp/results/monty/pretrained_models/")
pretrained_dict = os.path.join(pretrain_path, "pretrained_ycb/supervised_pre_training_on_omniglot/pretrained/")
log_path = os.path.expanduser("~/tbp/results/monty/projects/monty_runs/")
exp_name = "evidence_on_omniglot/"
exp_path = os.path.join(log_path, exp_name)
# exp_path already ends in '/', so join instead of concatenating '/...' to
# avoid a doubled separator; the trailing '/' on save_path is kept.
save_path = os.path.join(exp_path, 'stepwise_examples/')
train_stats, eval_stats, detailed_stats, lm_models = load_stats(exp_path,
                                                                load_train=False,
                                                                load_eval=True,
                                                                load_detailed=True,
                                                                pretrained_dict=pretrained_dict,
                                                               )

In [None]:
# Display the eval stats loaded above (the cell's last expression is rendered).
eval_stats

In [None]:
# Plot the pretrained graph model stored for 'Korean_1' in the first entry of
# lm_models['pretrained'].
plot_graph(lm_models['pretrained'][0]['Korean_1'])
plt.show()

In [None]:
# For one episode, show SM_1's first depth observation (presumably a view of
# the whole character - confirm) next to SM_0's depth observation at `step`.
step = 10
episode = 0
plt.figure()
plt.subplot(1,2,1)
plt.imshow(detailed_stats[str(episode)]['SM_1']['raw_observations'][0]['depth'])
plt.title('character')
plt.axis('off')
plt.subplot(1,2,2)
plt.imshow(detailed_stats[str(episode)]['SM_0']['raw_observations'][step]['depth'])
plt.title(f'observation at step {step}')
plt.axis('off')
plt.show()

In [None]:
# Show the learned graph for 'Korean_1': every stored location (grey) plus its
# stored point normal, drawn as a line of length norm_len. Each line's color
# maps the normal's (x, y, z) components from [-1, 1] to RGB in [0, 1]
# (assumes unit-length normals - confirm).
graph = lm_models['pretrained'][0]['Korean_1']
color_id = graph.feature_mapping['principal_curvatures_log'][0]  # used only by the commented-out color option below
curv_dir_ids = graph.feature_mapping['curvature_directions']  # used only by the commented-out line in the loop
norm_len = 3

fig = plt.figure()
ax = plt.subplot(1,1,1,projection='3d')
s = ax.scatter(
    graph.pos[:, 0],
    graph.pos[:, 1],
    graph.pos[:, 2],
    s=5,
    alpha=0.5,
    c='grey'#np.arctan(np.array(graph.x[:,color_id])/100)
)

ax.set_xticks([]), ax.set_yticks([]), ax.set_zticks([])
# ax_range = 0.07
# ax.set_xlim(0-ax_range, 0+ax_range)
# ax.set_ylim(0.05-ax_range,0.05+ax_range)
# ax.set_zlim(0-ax_range, 0+ax_range)
ax.set_title("Learned Locations and Norm")

for p_id, p in enumerate(np.array(graph.pos)):
    norm = np.array(graph.norm[p_id])
#     norm = graph.x[p_id][curv_dir_ids[0]:curv_dir_ids[0]+3]
    ax.plot([p[0], p[0] + norm[0] * norm_len],
            [p[1], p[1] + norm[1] * norm_len],
            [p[2], p[2] + norm[2] * norm_len],
#             c=[np.abs(norm[0]),np.abs(norm[1]),np.abs(norm[2])]#'lightblue',
            c=[(norm[0]+1)*0.5,(norm[1]+1)*0.5,(norm[2]+1)*0.5]
           )
ax.set_aspect("equal")
plt.show()

In [None]:
# At every learned location of `obj`, draw the two curvature directions stored
# in the graph features: the first in red (length norm_len) and the second in
# orange (half that length).
obj = 'Korean_1'
graph = lm_models['pretrained'][0][obj]
pc1ispc2 = graph.feature_mapping['pc1_is_pc2']  # used only by the commented-out check in the loop
curv_dir_ids = graph.feature_mapping['curvature_directions']
norm_len = 10

fig = plt.figure()
ax = plt.subplot(1,1,1,projection='3d')
s = ax.scatter(
    graph.pos[:, 0],
    graph.pos[:, 1],
    graph.pos[:, 2],
    s=5,
)

ax.set_xticks([]), ax.set_yticks([]), ax.set_zticks([])
# ax_range = 0.07
# ax.set_xlim(0-ax_range, 0+ax_range)
# ax.set_ylim(0.05-ax_range,0.05+ax_range)
# ax.set_zlim(0-ax_range, 0+ax_range)
ax.set_title("Principal Curvature Directions")

for p_id, p in enumerate(np.array(graph.pos)):
    norm = graph.norm[p_id]  # used only by the commented-out normal plot below
#     ax.plot([p[0], p[0] + norm[0] * norm_len],
#                 [p[1], p[1] + norm[1] * norm_len],
#                 [p[2], p[2] + norm[2] * norm_len],
#                 c='blue',
#                )
#     if not graph.x[p_id, pc1ispc2[0]]:
    # The features are read as two consecutive 3-vectors starting at
    # curv_dir_ids[0]: the first direction, then the second.
    cd1 = graph.x[p_id][curv_dir_ids[0]:curv_dir_ids[0]+3]
    cd2 = graph.x[p_id][curv_dir_ids[0]+3:curv_dir_ids[0]+6]
    ax.plot([p[0], p[0] + cd1[0] * norm_len],
            [p[1], p[1] + cd1[1] * norm_len],
            [p[2], p[2] + cd1[2] * norm_len],
            c='red',
           )
    ax.plot([p[0], p[0] + cd2[0] * norm_len * 0.5],
            [p[1], p[1] + cd2[1] * norm_len * 0.5],
            [p[2], p[2] + cd2[2] * norm_len * 0.5],
            c='orange',
           )
ax.set_aspect("equal")
plt.show()

In [None]:
# Show the initial hypotheses for 'Korean_1' in episode 4.
show_initial_hypotheses(detailed_stats, episode=4, obj='Korean_1')

In [None]:
# Histogram of LM_0's evidence values for 'Korean_2' at entry 3 of episode 4's
# logged evidences (presumably step 3 - confirm). The episode is hard-coded as
# '4' here; the `episode` variable is not used.
plt.figure()
plt.hist(detailed_stats['4']['LM_0']['evidences'][3]['Korean_2'])
plt.show()

In [None]:
# Settings for the step-by-step evidence plots below.
objects = ['Korean_1','Korean_2','Korean_3','Gujarati_2']  # objects whose evidence is plotted

episode = 4
lm = 'LM_0'  # not referenced by the visible cells below
current_evidence_update_th = -1  # not referenced by the visible cells below

In [None]:
# Plot the evidence at every step of `episode`, saving each figure under save_path.
# NOTE(review): confirm that plot_evidence_at_step creates save_path if it is missing.
# This loop also rebinds the global `step`.
for step in range(eval_stats['num_steps'][episode]):
    plot_evidence_at_step(detailed_stats,
                          lm_models,
                              episode, 
                              step,
                              objects,
                              save_fig=True, 
                              save_path=save_path)

## Plot Longer Run Stats

In [None]:
# Load eval stats (no detailed stats) and the pretrained models for the larger
# evidence-on-Omniglot run.
pretrain_path = os.path.expanduser("~/tbp/results/monty/pretrained_models/")
pretrained_dict = os.path.join(pretrain_path, "pretrained_ycb/supervised_pre_training_on_omniglot_large/pretrained/")
log_path = os.path.expanduser("~/tbp/results/monty/projects/monty_runs/")
exp_name = "evidence_on_omniglot_large/"
exp_path = os.path.join(log_path, exp_name)
# exp_path already ends in '/', so join instead of concatenating '/...' to
# avoid a doubled separator; the trailing '/' on save_path is kept.
save_path = os.path.join(exp_path, 'stepwise_examples/')
train_stats, eval_stats, detailed_stats, lm_models = load_stats(exp_path,
                                                                load_train=False,
                                                                load_eval=True,
                                                                load_detailed=False,
                                                                pretrained_dict=pretrained_dict,
                                                               )

In [None]:
# Split per-episode performance by episode parity: even-indexed episodes go to
# same_v_perf, odd-indexed ones to new_v_perf.
# NOTE(review): assumes eval episodes alternate between a rendition seen in
# training and a new one, as the plot legend below suggests - confirm against
# the experiment config.
# Slicing replaces the manual loop, which also overwrote the global `i`.
performance = list(eval_stats['performance'])
same_v_perf = performance[0::2]
new_v_perf = performance[1::2]

In [None]:
# Side-by-side histogram of the performance values of the two episode groups
# split above (green: same version, orange: new version).
fig, ax = plt.subplots()
plt.hist([same_v_perf, new_v_perf], rwidth=0.6, color=['green','orange'])
plt.legend(['same version', 'new version'], fontsize=15)
plt.ylabel('count', fontsize=15)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.title('Performance on Omniglot (1 Alphabet)', fontsize=15)
plt.show()