In [20]:
import pickle as pkl
import numpy as np
from celluloid import Camera
import pdb
import scipy
import ipdb
import math
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from scipy.spatial.transform import Rotation as R
from matplotlib.patches import Rectangle
from IPython.display import HTML

In [21]:

def get_angle(rot):
    """Return a planar heading angle, in degrees, for a quaternion rotation.

    ``rot`` is converted with ``Rotation.from_quat`` (scalar-last ``[x, y, z, w]``)
    and decomposed into extrinsic ``'xzy'`` Euler angles. A heading vector is
    built from the z angle (``euler[1]``) and the y angle (``euler[2]``), and
    ``arctan2(x, y)`` of it is returned. For a pure rotation of theta about the
    y axis this equals ``90 - theta`` degrees, wrapped to (-180, 180].

    Args:
        rot: quaternion in scalar-last order, as accepted by ``Rotation.from_quat``.

    Returns:
        The heading angle in degrees (numpy float).
    """
    euler = R.from_quat(rot).as_euler('xzy')
    x = np.cos(euler[2]) * np.cos(euler[1])
    y = np.sin(euler[2]) * np.cos(euler[1])
    # NOTE: the argument order is (x, y), not the conventional (y, x). It is kept
    # because the plotting code adds a fixed 270-degree marker offset on top of it.
    return np.degrees(np.arctan2(x, y))
    
def plot_graph_2d(graph, ax, goal_ids, belief_ids=None, container_classes=None):
    """Draw a top-down (x/z plane) view of an environment graph on ``ax``.

    Drawing order:
    1. Rooms as translucent filled boxes.
    2. Nodes whose class is in ``container_classes`` as blue outlines; those with
       the ``'OPEN'`` state are drawn again as orange outlines.
    3. Goal-class nodes and belief nodes as hollow orange / blue circles at their
       bounding-box centers.

    When the graph has rooms, the view is set to a square window around all of them.

    Args:
        graph: dict with a ``'nodes'`` list. Each node is read for ``'id'``,
            ``'class_name'``, ``'category'``, ``'states'`` and ``'bounding_box'``.
        ax: matplotlib Axes to draw on.
        goal_ids: collection of class names to mark as goals.
        belief_ids: collection of node ids to mark as belief objects (default: none).
        container_classes: class names drawn as containers/surfaces. Defaults to
            kitchentable, cabinet, kitchencabinet, fridge, bathroomcabinet and stove.
    """
    if belief_ids is None:
        belief_ids = ()
    if container_classes is None:
        container_classes = ('kitchentable', 'cabinet', 'kitchencabinet', 'fridge',
                             'bathroomcabinet', 'stove')

    nodes = graph['nodes']
    goals = [node for node in nodes if node['class_name'] in goal_ids]
    belief_obj = [node for node in nodes if node['id'] in belief_ids]
    container_and_surface = [node for node in nodes if node['class_name'] in container_classes]
    container_open = [node for node in container_and_surface if 'OPEN' in node.get('states', ())]
    rooms = [node for node in nodes if node['category'] == 'Rooms']

    if rooms:
        add_boxes(rooms, ax, points=None, rect={'alpha': 0.1})
    if container_and_surface:
        add_boxes(container_and_surface, ax, points=None,
                  rect={'fill': False, 'edgecolor': 'blue', 'alpha': 0.3})
    if container_open:
        add_boxes(container_open, ax, points=None,
                  rect={'fill': False, 'edgecolor': 'orange', 'alpha': 1.0})
    if goals:
        add_boxes(goals, ax, points={'s': 40.0, 'alpha': 1.0, 'edgecolors': 'orange',
                                     'facecolors': 'none', 'linewidth': 1.0})
    if belief_obj:
        add_boxes(belief_obj, ax, points={'s': 30.0, 'alpha': 1.0, 'edgecolors': 'blue',
                                          'facecolors': 'none', 'linewidth': 1.0})

    ax.set_aspect('equal')
    if rooms:
        # Pad the shorter side so the window is square and centered on the rooms.
        bx, by = get_bounds([room['bounding_box'] for room in rooms])
        width, depth = bx[1] - bx[0], by[1] - by[0]
        side = max(width, depth)
        gapx = (side - width) / 2.
        gapy = (side - depth) / 2.
        ax.set_xlim(bx[0] - gapx, bx[1] + gapx)
        ax.set_ylim(by[0] - gapy, by[1] + gapy)
    ax.apply_aspect()
    
def add_box(nodes, args_rect):
    """Project node bounding boxes onto the x/z plane.

    Args:
        nodes: iterable of nodes, each with ``'bounding_box'`` holding ``'center'``
            and ``'size'`` sequences (components 0 and 2 are used).
        args_rect: keyword arguments for ``Rectangle``. When None, no rectangles
            are built and only centers are collected.

    Returns:
        ``(rectangles, [xs, zs])``: the list of ``Rectangle`` patches (empty when
        ``args_rect`` is None) and the box centers as two parallel lists.
    """
    patches = []
    xs, zs = [], []
    for node in nodes:
        box = node['bounding_box']
        center_x, center_z = box['center'][0], box['center'][2]
        width, depth = box['size'][0], box['size'][2]
        xs.append(center_x)
        zs.append(center_z)
        if args_rect is None:
            continue
        corner = (center_x - width / 2., center_z - depth / 2.)
        patches.append(Rectangle(corner, width, depth, **args_rect))
    return patches, [xs, zs]


def add_boxes(nodes, ax, points=None, rect=None):
    """Draw the bounding boxes and/or center points of ``nodes`` on ``ax``.

    Args:
        nodes: graph nodes with a ``'bounding_box'`` entry (projected by ``add_box``).
        ax: matplotlib Axes to draw on.
        points: keyword arguments for ``ax.scatter`` of the box centers; None draws no points.
        rect: keyword arguments for every ``Rectangle``; None draws no boxes.
    """
    rectangles, centers = add_box(nodes, rect)
    if points is not None:
        ax.scatter(centers[0], centers[1], **points)
    if rect is not None and rectangles:
        # Add the boxes only through the collection. Adding a rectangle individually
        # as well would draw it twice and double its effective alpha.
        ax.add_collection(PatchCollection(rectangles, match_original=True))
        
def get_bounds(bounds):
    """Return the x and z extents enclosing every bounding box in ``bounds``.

    Each box is a dict with ``'center'`` and ``'size'`` sequences; only
    components 0 (x) and 2 (z) are used.

    Returns:
        ``((min_x, max_x), (min_z, max_z))``. Every entry is None when ``bounds`` is empty.
    """
    lows_x, highs_x, lows_z, highs_z = [], [], [], []
    for box in bounds:
        center, size = box['center'], box['size']
        highs_x.append(center[0] + size[0] / 2.)
        lows_x.append(center[0] - size[0] / 2.)
        highs_z.append(center[2] + size[2] / 2.)
        lows_z.append(center[2] - size[2] / 2.)
    x_range = (min(lows_x, default=None), max(highs_x, default=None))
    z_range = (min(lows_z, default=None), max(highs_z, default=None))
    return x_range, z_range

In [22]:
def visualize_trajectory(file_path=None, gen_vid=False, plot_belief=False, belief_id=None):
    """Plot, or animate, the character trajectory stored in a pickled episode log.

    The log must contain ``'goals'``, ``'obs'`` and ``'graph'``, plus ``'belief'``
    and ``'init_unity_graph'`` when the belief panel is drawn. The character (node
    id 1) is located in every graph, and the x/z center of its bounding box forms
    the trajectory.

    Args:
        file_path: path to the ``.pik`` log. It is read with ``pickle``, which can
            execute arbitrary code, so only open trusted files.
        gen_vid: when True, draw one frame per step with celluloid and save
            ``<parent dir>_<file name without .pik>.mp4`` in the working directory.
            When False, only the final step is drawn.
        plot_belief: add a bar chart of the container belief for ``belief_id``.
            Ignored when ``belief_id`` is None.
        belief_id: node id highlighted on the map and used for the belief panel.

    Returns:
        The celluloid animation when ``gen_vid`` is True, otherwise None.

    Raises:
        ValueError: if the log contains no graphs.
    """
    char_id = 1  # node id of the character in the graphs
    if belief_id is None:
        plot_belief = False
    with open(file_path, 'rb') as f:
        content = pkl.load(f)

    # The second '_'-separated field of each goal key is the class name to highlight;
    # only goals with a positive count are kept.
    goal_objs = [key.split('_')[1] for key, count in content['goals'][0].items() if count > 0]
    print(len(content['obs']))

    graphs = content['graph']
    if len(graphs) == 0:
        raise ValueError('no graphs to plot in {}'.format(file_path))
    char_boxes = [[node['bounding_box'] for node in graph['nodes'] if node['id'] == char_id][0]
                  for graph in graphs]
    xy = np.array([[box['center'][0], box['center'][2]] for box in char_boxes])
    n_steps = len(xy)

    # Use at least 250 colours, so short episodes share one fixed palette, but extend
    # the palette for longer episodes so every step still gets a colour.
    colors = plt.cm.jet(np.linspace(0, 1, max(250, n_steps)))

    if plot_belief:
        fig = plt.figure(figsize=(12, 6))
        grid = plt.GridSpec(2, 3, wspace=0.1, hspace=0.1)
        ax = fig.add_subplot(grid[:, :2])
        ax.axis('off')
        ax_belief = fig.add_subplot(grid[1, 2])
        id2name = {node['id']: node['class_name'] for node in content['init_unity_graph']['nodes']}
    else:
        fig = plt.figure(figsize=(6, 6))
        ax = plt.axes()
        plt.axis('off')

    belief_ids = [] if belief_id is None else [belief_id]
    camera = Camera(fig) if gen_vid else None
    # Without a video, only the final step is drawn.
    steps_total = range(n_steps) if gen_vid else [n_steps - 1]

    for step in tqdm(steps_total):
        plot_graph_2d({'nodes': graphs[step]['nodes']}, ax, goal_objs, belief_ids)
        if plot_belief:
            do_plot_belief(content['belief'][0][step], id2name, belief_id, ax_belief)
        # Current position. The graphs carry no heading, so the marker is never rotated.
        current = xy[step:step + 1, :]
        ax.scatter(current[:, 0], current[:, 1], color=colors[step], s=50, marker=(3, 0, 270))
        # A celluloid snap captures only the artists added since the previous snap,
        # so the whole path up to this step is redrawn for every frame.
        for it in range(1, step + 1):
            segment = xy[it - 1:it + 1, :]
            ax.plot(segment[:, 0], segment[:, 1], '--', color=colors[it])
        if camera is not None:
            camera.snap()

    if camera is None:
        return None
    dir_name, fname = file_path.split('/')[-2:]
    fn = '{}_{}.mp4'.format(dir_name, fname.replace('.pik', ''))
    final = camera.animate()
    final.save(fn)
    return final

def do_plot_belief(belief_object, id2name, id_object, currax):
    """Draw a bar chart of the belief over which container holds ``id_object``.

    Args:
        belief_object: mapping from object id to per-relation beliefs.
            ``belief_object[id_object]['INSIDE']`` is read as a pair
            ``(candidate_container_ids, scores)``; a ``None`` id is labelled ``'None'``.
        id2name: mapping from node id to class name, used for the tick labels.
            Ids missing from it are labelled with the id itself.
        id_object: id of the object whose belief is plotted.
        currax: matplotlib Axes to draw on.

    The scores are passed through a softmax before plotting, so they are treated
    as unnormalized log-scores. NOTE(review): confirm this against the belief code.
    """
    # A bare ``import scipy`` does not guarantee that ``scipy.special`` is loaded.
    from scipy.special import softmax

    belief_curr_object = belief_object[id_object]['INSIDE']
    container_ids, scores = belief_curr_object[0], belief_curr_object[1]
    names = []
    for cid in container_ids:
        name = 'None' if cid is None else id2name.get(cid, str(cid))
        # Shorten common room prefixes so the rotated tick labels fit.
        names.append(name.replace('bathroom', 'b.').replace('kitchen', 'k.'))
    probs = softmax(scores)
    x = np.arange(len(names))
    currax.bar(x, probs, color='blue')
    currax.set_ylabel("Prob")
    currax.set_xticks(x)
    currax.grid(axis='y')
    currax.set_ylim([0, 1])
    currax.set_xticklabels(names, rotation=40)
    

In [23]:
ls ../../data_scratch/train_env_task_set_20_full_reduced_tasks/

[0m[01;34m1_full_opencost0_closecostFalse_walkcost0.05_forgetrate0v9_particles_v2[0m/
[01;34m2_full_opencost500_closecostFalse_walkcost0.05_forgetrate0v9_particles_v2[0m/
[01;34m3_full_opencost-500_closecostFalse_walkcost0.05_forgetrate0v9_particles_v2[0m/
[01;34m4_partial_opencost0_closecostFalse_walkcost0.05_forgetrate0v9_particles_v2[0m/
[01;34m6_partial_opencost0_closecostFalse_walkcost0.05_forgetrate0.2v9_particles_v2[0m/
[01;34m8_partial_opencost-500_closecostFalse_walkcost0.05_forgetrate0.2v9_particles_v2[0m/


In [None]:
init_path = '../../data_scratch/train_env_task_set_20_full_reduced_tasks/'
# Episode logs to visualise. `init_path` already ends with '/', so each run
# directory is appended directly. Other logs that can be swapped in:
#   3_full_opencost-500_closecostFalse_walkcost0.05_forgetrate0v9_particles_v2/logs_episode.70_iter.0.pik
#   4_partial_opencost0_closecostFalse_walkcost0.05_forgetrate0v9_particles_v2/logs_episode.87_iter.0.pik
file_path = [
    f'{init_path}6_partial_opencost0_closecostFalse_walkcost0.05_forgetrate0.2v9_particles_v2/logs_episode.140_iter.0.pik',
    f'{init_path}8_partial_opencost-500_closecostFalse_walkcost0.05_forgetrate0.2v9_particles_v2/logs_episode.140_iter.0.pik',
]
# Node id of the object whose container belief is plotted.
# Candidate ids noted for this episode: 414, 415, 416.
BELIEF_OBJECT_ID = 414
vid = visualize_trajectory(file_path[0], gen_vid=True, plot_belief=True, belief_id=BELIEF_OBJECT_ID)

  1%|          | 3/250 [00:00<00:09, 27.05it/s]

250


 58%|█████▊    | 146/250 [00:19<00:13,  7.82it/s]

In [None]:
ls