In [1]:
import os
import semiparametrictransfer

import glob

In [2]:
from semiparametrictransfer.data_sets.data_loader import FixLenVideoDataset
import os
import matplotlib.pyplot as plt
from semiparametrictransfer.utils.general_utils import AttrDict
import numpy as np
import pdb
from tqdm import tqdm
import h5py
from collections import OrderedDict 


from semiparametrictransfer.utils.construct_html import save_gif_list_direct
from semiparametrictransfer.utils.construct_html import fill_template, save_html_direct

  from ._conv import register_converters as _register_converters
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])

  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])

  _np_qint16 = np.dtype([("qint16", np.int16, 1)])

  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])

  _np_qint32 = np.dtype([("qint32", np.int32, 1)])

  np_resource = np.dtype([("resource", np.ubyte, 1)])



In [None]:
data_dir = os.environ['DATA'] + '/spt_trainingdata' + '/sim/tabletop-texture'
hp = AttrDict(img_sz=(48, 64),
              sel_len=-1,
              T=31)

loader = FixLenVideoDataset(data_dir, hp).get_data_loader(32)

num_batch = 100
all_images = []
all_object_qpos = []

for i_batch, sample_batched in enumerate(loader):
    images = np.asarray(sample_batched['demo_seq_images'])

    images = (images + 1) / 2
    images = np.transpose(images, [0, 1, 3, 4, 2])  # convert to channel-first
    images = (images*255).astype(np.uint8)
    actions = np.asarray(sample_batched['actions'])
    states = np.asarray(sample_batched['states'])
    
    n_objects = 3
    object_qpos = states[:, :, 9:15].reshape(states.shape[0], states.shape[1], n_objects, 2)
    
    if i_batch == num_batch - 1:
        break
        
    all_images.append(images)
    all_object_qpos.append(object_qpos)
        
#     plt.imshow(np.asarray(images[0, 0]))
#     plt.show()
    

all_images = np.concatenate(all_images, 0)
all_object_qpos = np.concatenate(all_object_qpos, 0)

In [11]:
def read_traj(path):
    with h5py.File(path, 'r') as F:
        ex_index = 0
        key = 'traj{}'.format(ex_index)

        data_dict = {}
        # Fetch data into a dict
        for name in F[key].keys():
            if name in ['states', 'actions']:
                data_dict[name] = F[key + '/' + name].value.astype(np.float32)
        states = data_dict['states']
        n_objects = 3
        data_dict['object_qpos'] = states[:, 9:15].reshape(states.shape[0], n_objects, 2)
        data_dict['images'] = F[key + '/images'].value
    return data_dict


def _get_filenames(data_dir, phase):
    assert 'hdf5' not in data_dir, "hdf5 most not be containted in the data dir!"
    filenames = sorted(glob.glob(os.path.join(data_dir, os.path.join('hdf5', phase) + '/*')))
    if not filenames:
        raise RuntimeError('No filenames found in {}'.format(data_dir))
    return filenames

def load_all_data(orig_path):
    phases = ["train", "val", "test"]
    all_data_dict = OrderedDict()
    for phase in phases:
        filenames = _get_filenames(orig_path, phase)
        print('found {} traj for {}'.format(len(filenames), phase))
        print('loading files')
        for path in tqdm(filenames):
            single_filename = str.split(path, '/')[-1]
            all_data_dict[single_filename] = read_traj(path)
    return all_data_dict

all_data_dict = load_all_data(os.environ['DATA'] + '/spt_trainingdata' + '/sim/tabletop-texture')

  1%|          | 70/8978 [00:00<00:12, 695.81it/s]

found 8978 traj for train
loading files


100%|██████████| 8978/8978 [00:11<00:00, 774.77it/s]
  2%|▏         | 8/499 [00:00<00:06, 71.73it/s]

found 499 traj for val
loading files


100%|██████████| 499/499 [00:05<00:00, 87.72it/s]
  2%|▏         | 9/523 [00:00<00:06, 83.20it/s]

found 523 traj for test
loading files


100%|██████████| 523/523 [00:06<00:00, 85.29it/s]


In [12]:

all_object_qpos = np.stack([all_data_dict[key]['object_qpos'] for key in all_data_dict.keys()])
all_images = np.stack([all_data_dict[key]['images'] for key in all_data_dict.keys()]).squeeze()
len(all_data_dict.keys())

10000

In [13]:
all_data_dict['traj_0to1.h5'].keys()

dict_keys(['actions', 'states', 'object_qpos', 'images'])

In [14]:
def compute_nearest_neighbors():
    nearest_ind = {} 
    
    obj_displacements = all_object_qpos[:, -1] - all_object_qpos[:, 0]
    obj_displacements_mag = np.linalg.norm(obj_displacements, axis=-1)
    largest_displacement_index = np.argmax(obj_displacements_mag, axis=1)
    
    num_vis_traj = 20
    num_traj = all_object_qpos.shape[0]
    vis_indices = [i for i in range(num_traj) if largest_displacement_index[i] > 0.01][:num_vis_traj]
    
    # get largest obj displacement per trajectory
    largest_displacement = np.stack([obj_displacements[i, ind] for i, ind in enumerate(largest_displacement_index)])  
    
    for i, k in tqdm(enumerate(all_data_dict.keys())):
        # compute the magnitude of differences between i-th displacement vector and all other displacements
        diff_mag = np.linalg.norm(obj_displacements - largest_displacement[i][None, None], axis=-1)
        
        # take the minimum difference among all 3 objects:
        diff_mag = np.min(diff_mag, axis=-1)
        # get the batch indices of the lowest dist:
        
        numbest_k = 128
        best_ind = np.argsort(diff_mag)[:numbest_k]
        
#         print('i {}: bestind {} largest disp {}'.format(i, best_ind[:10], largest_displacement[i]))
                
        nearest_ind[i] = best_ind
        all_data_dict[k]['nearest_ind'] = best_ind
    return nearest_ind, vis_indices
nearest_ind, vis_indices = compute_nearest_neighbors()

10000it [00:13, 766.73it/s]


In [15]:
# show only the nearest neighbors for the top 10
show_nn_gifs = 5
def save_gifs():
    all_inds = []
    
    for k in vis_indices:
        all_inds.append(k)
        all_inds.extend(nearest_ind[k][:show_nn_gifs])
            
    all_inds = set(all_inds)
    print('saving {} traj'.format(len(all_inds)))
    gif_list = [(ind, all_images[ind]) for  ind in all_inds]
    
    folder = os.environ['DATA'] + '/spt_trainingdata' + '/sim/tabletop-texture/visuals'
    name = 'sawyer'
    html_paths = save_gif_list_direct(folder, name, gif_list)
    return html_paths

html_paths = save_gifs()
print('done')

  1%|          | 1/99 [00:00<00:14,  6.54it/s]

saving 99 traj


100%|██████████| 99/99 [00:12<00:00,  7.90it/s]

done





In [16]:
def save_nearest_neighbor_gifs():
    itemdict = {}
    # show only the nearest neighbors for the top 10
    for i in vis_indices:
        nearest_ind_i = nearest_ind[i][:show_nn_gifs]
        nearest_paths = [html_paths[ind] for ind in nearest_ind_i]
#         import pdb; pdb.set_trace()
        itemdict['img{}'.format(i)] = [html_paths[i]] + nearest_paths
            
    html_page = fill_template(itemdict)
    save_html_direct(os.environ['DATA'] + '/spt_trainingdata' + '/sim/tabletop-texture/visuals/index.html', html_page)
    
save_nearest_neighbor_gifs()

In [18]:
def save_nearest_neighbors():
    from copy import deepcopy
    
    all_data_dict_ = deepcopy(all_data_dict)
    all_data_dict_noimages = {}
    for key in all_data_dict.keys():
        all_data_dict_[key].pop('images')  
        all_data_dict_noimages[key] = all_data_dict_[key]
    all_data_dict_noimages['traj_0to1.h5'].keys()
        
    import _pickle as pkl
    pkl.dump(all_data_dict_noimages,
             open(os.environ['DATA'] + '/spt_trainingdata' + '/sim/tabletop-texture/all_data_dict_noimages_trainvaltest.pkl', 'wb'))
save_nearest_neighbors()