## Load data

In [None]:
import numpy as np
import _pickle as pkl
import torch
import torch.nn as nn
%matplotlib inline 
import matplotlib.pyplot as plt
import os, sys, time
sys.path.append('../..')
from utils import set_seed_torch, rgb2gray, load_vh_models, frame_stack
set_seed_torch(3)
from train import encode
from argparse import Namespace
import json
import gzip

In [None]:
class ObjectView(object):
    """Expose the keys of a dict as attributes (``view.key`` instead of ``d['key']``).

    The dict is adopted as the instance's ``__dict__`` directly, without
    copying, so attribute writes on the view are visible in the dict that was
    passed in and vice versa.
    """

    def __init__(self, d):
        # Share storage with the caller's dict rather than copying it.
        self.__dict__ = d
        
# Notebook configuration, read as attributes (args.device, args.n_examples, ...).
#
# NOTE: later cells read args.dataset_path, args.models_dir and args.n_pred,
# but all three are commented out below, so a fresh "Restart & Run All" stops
# with AttributeError at the data-loading loop. Uncomment exactly ONE of the
# two preset groups (and point the absolute paths at your own checkout)
# before running.
args = ObjectView({
 'res': 64,  # not read by any of the cells visible in this notebook
# --- Preset 1: three vha*_2D_len16 datasets, models from "ablation" ---
#  'dataset_path': ['/home/olimoyo/visual-haptic-dynamics/experiments/data/datasets/{}'
#                    .format("vha1_2D_len16_oscxy_withGT_0B7AB071F98942578ABDA66879290F2F.pkl"),
#                   '/home/olimoyo/visual-haptic-dynamics/experiments/data/datasets/{}'
#                    .format("vha2_2D_len16_oscxy_withGT_3502DE81F7C343FB8B57FA92FDECF4DA.pkl"),
#                   '/home/olimoyo/visual-haptic-dynamics/experiments/data/datasets/{}'
#                    .format("vha3_2D_len16_oscxy_withGT_5DB32B21A6AA4E5892D2F6B8F40EF9E6.pkl")
#                  ],
#  'models_dir': '/home/olimoyo/visual-haptic-dynamics/saved_models/{}'
#                    .format("ablation"),
#  'n_pred': 13,
# --- Preset 2: mit_push len48 dataset, models from "new_rec/none/priorexpert_false" ---
#  'dataset_path': ['/home/olimoyo/visual-haptic-dynamics/experiments/data/datasets/mit_push/{}'
#                    .format("rng-initial_min-tr2.5_min-rot0.5_len48.pkl")
#                  ],
#  'models_dir': '/home/olimoyo/visual-haptic-dynamics/saved_models/{}'
#                    .format("new_rec/none/priorexpert_false"),
#  'n_pred': 45,
 'device': 'cuda:1',  # torch device that data tensors and models are placed on
 'n_examples': 4,  # number of validation sequences in the randomly chosen test batch
 'n_initial': 2  # number of leading frames that are encoded before prediction starts
})

def load_models_dir(models_dir):
    """Load the saved hyperparameters of every trained model under ``models_dir``.

    Every immediate sub-directory that contains a ``hyperparameters.txt`` file
    (JSON-encoded) is treated as one trained model. Entries without that file
    -- plain files, or unrelated directories such as ``.ipynb_checkpoints`` --
    are skipped instead of aborting the whole scan.

    Args:
        models_dir: Path of the directory holding one sub-directory per model.

    Returns:
        dict mapping each model directory path (``models_dir`` joined with the
        sub-directory name) to an ``argparse.Namespace`` built from its
        hyperparameters. Entries are inserted in sorted name order so that
        iteration order is reproducible across machines and runs.

    Raises:
        OSError: If ``models_dir`` itself cannot be listed.
        json.JSONDecodeError: If a ``hyperparameters.txt`` is not valid JSON.
    """
    dict_of_models = {}
    # os.listdir returns entries in arbitrary order; sort for determinism.
    for filedir in sorted(os.listdir(models_dir)):
        fullpath = os.path.join(models_dir, filedir)
        hyperparameters_path = os.path.join(fullpath, 'hyperparameters.txt')
        # isfile() is False both for non-directories and for directories
        # that do not hold a hyperparameter file.
        if not os.path.isfile(hyperparameters_path):
            continue
        with open(hyperparameters_path, 'r') as fp:
            dict_of_models[fullpath] = Namespace(**json.load(fp))
    return dict_of_models

def is_gz_file(filepath):
    """Return True when the file at ``filepath`` starts with the gzip magic bytes.

    Only the first two bytes are inspected; the file extension is ignored.
    """
    gzip_magic = b'\x1f\x8b'
    with open(filepath, 'rb') as stream:
        header = stream.read(len(gzip_magic))
    return header == gzip_magic
    
# Per-modality accumulators: one tensor per dataset file, concatenated below.
# (The RGB image stream 'img_rgb' is currently disabled.)
data = {modality: [] for modality in ('img_gray', 'haptic', 'arm', 'actions')}

for dataset_path in args.dataset_path:
    # Dataset pickles may or may not be gzip-compressed; choose the opener to match.
    # NOTE: unpickling executes arbitrary code -- only load trusted dataset files.
    opener = gzip.open if is_gz_file(dataset_path) else open
    with opener(dataset_path, 'rb') as f:
        raw_data = pkl.load(f)

    # Move the last axis of the 5-D grayscale array to position 2
    # (presumably channels-last -> channels-first; confirm against rgb2gray).
    gray_frames = rgb2gray(raw_data["img"]).transpose(0, 1, 4, 2, 3)
    data['img_gray'].append(torch.from_numpy(gray_frames).float().to(device=args.device))

    haptic_seq = torch.from_numpy(raw_data['ft'])
    data['haptic'].append(haptic_seq.float().to(device=args.device))

    arm_seq = torch.from_numpy(raw_data['arm'])
    data['arm'].append(arm_seq.float().to(device=args.device))

    action_seq = torch.from_numpy(raw_data["action"])
    data['actions'].append(action_seq.to(device=args.device).float())

# Stack the per-file tensors along the first (sequence/example) dimension.
data = {modality: torch.cat(chunks, dim=0) for modality, chunks in data.items()}

## Visualize

In [None]:
# Evaluate every trained model found in args.models_dir on one shared batch of
# validation sequences: encode the first args.n_initial frames, roll the latent
# dynamics forward for args.n_pred steps, decode, and store the results in
# analysis_data[path] under the keys "x_img", "x_hat" and "mse".
dict_of_models = load_models_dir(args.models_dir)
val_idx = None
analysis_data = {}

with torch.no_grad():
    for path, model_args in dict_of_models.items():
        nets = load_vh_models(path=path, args=model_args, mode='eval', device=args.device)
        analysis_data[path] = {}

        # Only the grayscale stream is loaded into `data`, so selecting the
        # image key from the model's channel count is disabled:
#         if model_args.dim_x[0] == 1:
#             img_key = 'img_gray'
#         elif model_args.dim_x[0] == 3:
#             img_key = 'img_rgb'
        img_key = 'img_gray'
        
        # XXX: Assume same validation indices for all models, then we can compare on the same examples.
        # The indices and the test batch are therefore taken from the FIRST model only.
        if val_idx is None:
            # Use validation indices only
            # NOTE: unpickling executes arbitrary code -- only load trusted files.
            with open(os.path.join(path, "val_idx.pkl"), 'rb') as f:
                val_idx = pkl.load(f)
            data_val = {k:v[val_idx] for k,v in data.items()}

            # Use a random batch to test: pick one of the consecutive,
            # non-overlapping groups of args.n_examples validation sequences.
            ii = np.random.randint(data_val[img_key].shape[0] // args.n_examples)
            batch_range = range(args.n_examples*ii, args.n_examples*(ii+1))
            test_batch = {k:v[batch_range] for k,v in data_val.items()}

        # Frame-stack setting taken from this model's saved hyperparameters.
        T = model_args.frame_stacks
        
        # The sequences must be long enough for the encoded prefix plus the rollout,
        # and the prefix must be longer than the frame stack.
        assert args.n_initial + args.n_pred <= data_val[img_key].shape[1]
        assert args.n_initial > T
        
        # Set up data for batch (truncate every observation stream to the evaluated horizon)
        x_img = test_batch[img_key][:, :(args.n_initial + args.n_pred)]
        x_ft = test_batch['haptic'][:, :(args.n_initial + args.n_pred)]
        x_arm = test_batch['arm'][:, :(args.n_initial + args.n_pred)]
        u = test_batch['actions']
        x_i = {}
        
        # Sequence of initial images
        x_img_i = x_img[:, :args.n_initial]
        x_img_i = frame_stack(x_img_i, frames=T)
        # n: batch size, l: number of encoded steps left after frame stacking
        n, l = x_img_i.shape[0], x_img_i.shape[1] 
        x_i["img"] = x_img_i
        
        # Sequence of extra modalities
        x_ft_i = x_ft[:, :args.n_initial] / model_args.ft_normalization
        x_arm_i = x_arm[:, :args.n_initial]
        u_i = u[:, T:args.n_initial]

        # Optional context input, built from force/torque and/or arm readings
        if model_args.context_modality != "none":
            if model_args.context_modality == "joint":
                x_i["context"] = torch.cat((x_ft_i, x_arm_i), dim=-1)
            elif model_args.context_modality == "ft":
                x_i["context"] = x_ft_i
            elif model_args.context_modality == "arm":
                x_i["context"] = x_arm_i
                
            if model_args.use_context_frame_stack:
                x_i['context'] = frame_stack(x_i['context'], frames=T)
            else:
                # Drop the first T steps instead of stacking them
                x_i["context"] = x_i["context"][:, T:]
            x_i["context"] = x_i["context"].transpose(-1, -2)
            
        # Fold the (batch, step) dimensions into one for the encoder
        x_i = {k:v.reshape(-1, *v.shape[2:]) for k, v in x_i.items()}

        # Encode (with a prior expert, encode returns two extra values that are not needed here)
        if model_args.use_prior_expert:
            q_z_i, _, _ = encode(nets, model_args, x_i, u_i, device=args.device)
        else:
            q_z_i = encode(nets, model_args, x_i, u_i, device=args.device)

        # Group and prepare for prediction: unfold to (n, l, ...) and make the step dimension leading
        q_z_i = {k:v.reshape(n, l, *v.shape[1:]).transpose(1,0) for k, v in q_z_i.items()}
        u = u.transpose(1,0)

        # Latent trajectory: l encoded means followed by args.n_pred predicted means
        z_hat = torch.zeros(((l + args.n_pred), n, model_args.dim_z)).to(device=args.device)
        z_hat[0:l] = q_z_i["mu"]
        
        # First run: feed the whole encoded prefix together with the actions shifted by one step
        z_i, mu_z_i, var_z_i = q_z_i["z"], q_z_i["mu"], q_z_i["cov"]
        u_pred = u[(T + 1):(1 + args.n_initial)]
        h_i = None

        # Predict
        for jj in range(0, args.n_pred):
            z_ip1, mu_z_ip1, var_z_ip1, h_ip1 = nets["dyn"](
                z_t=z_i, 
                mu_t=mu_z_i, 
                var_t=var_z_i, 
                u=u_pred, 
                h_0=h_i, 
                single=False
            )
            # Keep only the newest predicted mean and carry the last step forward
            z_hat[jj + l] = mu_z_ip1[-1]
            z_i, mu_z_i, var_z_i, h_i = z_ip1[-1:], mu_z_ip1[-1:], var_z_ip1[-1:], h_ip1
            # Fetch the next action only when another prediction step follows.
            # After the final step the action would never be used, and indexing
            # it raises IndexError when the action sequence is exactly
            # args.n_initial + args.n_pred steps long.
            if jj + 1 < args.n_pred:
                u_pred = u[1 + args.n_initial + jj][None]
                        
        # Decode every latent (encoded prefix and predictions) back to image space
        z_hat = z_hat.transpose(1, 0)
        x_hat = nets["img_dec"](z_hat.reshape(-1, *z_hat.shape[2:]))
        x_hat = x_hat.reshape(n, (l + args.n_pred), *x_hat.shape[1:])

        # Move to cpu, np
        x_hat = x_hat.cpu().numpy()
        x_img = x_img.cpu().numpy()

        # Per-example, per-predicted-frame error on channel 0.
        # NOTE: despite the name this is the SUM of squared pixel differences, not a mean.
        mse = np.sum((x_hat[:, l:, 0].reshape(n, args.n_pred, -1) - 
                      x_img[:, args.n_initial:, 0].reshape(n, args.n_pred, -1))**2, axis=2)
        
        analysis_data[path]["x_img"] = x_img
        analysis_data[path]["x_hat"] = x_hat
        analysis_data[path]["mse"] = mse

In [None]:
# Use a serif typeface (Computer Modern Roman) for all text in the figures below.
font = dict(family='serif', serif=['computer modern roman'])
plt.rc('font', **font)

def plot(x_img, x_hat, mse):
    """Show ground-truth frames above model predictions, one figure per example.

    Each figure has two rows of ``args.n_initial + args.n_pred`` columns: the
    top row holds channel 0 of the ground-truth frames, the bottom row holds
    channel 0 of the decoded frames. The first ``T + 1`` cells of the bottom
    row are left empty. The error values for the example are printed before
    its figure is shown.

    Besides its arguments the function reads the module-level ``args``
    (``n_initial``, ``n_pred``) and ``T``, which must have been set by an
    earlier cell.

    Args:
        x_img: Ground-truth images, indexed as ``[example, step, channel, :, :]``.
        x_hat: Decoded images, indexed the same way.
        mse: Per-example error values, indexed by example.
    """
    # Take the number of examples from the data being plotted instead of
    # relying on a global `n` left behind by another cell.
    n_examples = x_hat.shape[0]
    columns = args.n_initial + args.n_pred
    rows = 2

    # Plotting
    for bb in range(n_examples):
        fig=plt.figure(figsize=(16, 2))
        for ii in range(columns*rows):
            empty=False
            if ii < columns:
                # Top row: ground truth
                img = x_img[bb,ii,0,:,:]
            else:
                # Bottom row: column index within the row
                idx = ii - columns
                if idx < (T + 1):
                    empty = True
                else:
                    # NOTE(review): the fixed "-1" offset looks tied to T == 1 -- confirm for other T.
                    img = x_hat[bb,idx-1,0,:,:]
            if not empty:
                ax = fig.add_subplot(rows, columns, ii+1)
                plt.imshow(img, cmap="gray")
            plt.axis('off')
        print("MSE of predictions: ", mse[bb])
    #     fig.tight_layout()
        plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0.0, hspace=0.1725)
        plt.show()

# One set of figures per evaluated model, preceded by the model's directory path.
for model_path, results in analysis_data.items():
    print(model_path)
    plot(results["x_img"], results["x_hat"], results["mse"])