## Load data

In [1]:
# Imports and global setup for this notebook.
import numpy as np
# C implementation of the stdlib pickle module; used to load the dataset file.
import _pickle as pkl
import torch
import torch.nn as nn
from torchvision import transforms
%matplotlib inline 
import matplotlib.pyplot as plt
import os, sys, time
# Add the directory two levels above the working directory to the import path,
# presumably where the project's `utils` module lives (imported below).
sys.path.append('../..')
from utils import set_seed_torch, rgb2gray
# Project seeding helper. NOTE(review): presumably seeds torch (and maybe numpy)
# RNGs — confirm, since the random batch selection below uses np.random.
set_seed_torch(3)
from argparse import Namespace
import json
from utils import load_models, load_vh_models, frame_stack

In [2]:
class ObjectView(object):
    """Attribute-style view over a plain dict.

    The instance namespace *is* the given dict (it is aliased, not copied):
    ``view.key`` reads ``d['key']`` and assigning ``view.key = v`` writes
    into ``d``.
    """

    def __init__(self, d):
        # Aliasing the dict as __dict__ gives attribute access to its keys.
        self.__dict__ = d
        
# Root of the visual-haptic-dynamics checkout. The default is the original
# author's absolute path; set the VHD_ROOT environment variable to run this
# notebook on another machine without editing the cell.
VHD_ROOT = os.environ.get('VHD_ROOT', '/home/olimoyo/visual-haptic-dynamics')

args = ObjectView({
 # NOTE(review): presumably the image resolution; not read by the cells shown here.
 'res': 64,
 # Pickled dataset file loaded in the "Load data" section.
 'dataset_path': os.path.join(
     VHD_ROOT, 'experiments', 'data', 'datasets',
     "visual_haptic_2D_len16_osc_withGT_8C12919B740845539C0E75B5CBAF7965.pkl"),
 # Each sub-directory holding a hyperparameters.txt is evaluated as one model.
 # (The trailing '' keeps the trailing slash of the original path.)
 'models_dir': os.path.join(VHD_ROOT, 'saved_models', 'vaughan', 'osc', ''),
 # Torch device for data and models; override with VHD_DEVICE (e.g. 'cuda:0', 'cpu').
 'device': os.environ.get('VHD_DEVICE', 'cuda:1'),
 'n_examples': 3,   # sequences per randomly chosen test batch
 'n_pred': 8,       # number of open-loop prediction steps
 'n_initial': 2     # observed frames used to initialise the latent state
})

def load_models_dir(models_dir):
    """Load the saved hyperparameters of every trained model under ``models_dir``.

    Each immediate sub-directory of ``models_dir`` that contains a
    ``hyperparameters.txt`` file (JSON) is treated as one trained model.
    Plain files and sub-directories without that file are skipped.

    Args:
        models_dir: directory whose sub-directories hold trained models.

    Returns:
        dict mapping ``os.path.join(models_dir, <sub-directory>)`` to an
        ``argparse.Namespace`` built from that model's hyperparameters,
        inserted in sorted order of sub-directory name.
    """
    dict_of_models = {}
    # os.listdir returns entries in arbitrary order. Sort them so the models
    # (and the seeded random batch drawn per model) are reproducible.
    for filedir in sorted(os.listdir(models_dir)):
        fullpath = os.path.join(models_dir, filedir)
        hparams_path = os.path.join(fullpath, 'hyperparameters.txt')
        # Skip stray files and non-model folders instead of crashing with
        # FileNotFoundError.
        if not os.path.isfile(hparams_path):
            continue
        with open(hparams_path, 'r') as fp:
            dict_of_models[fullpath] = Namespace(**json.load(fp))
    return dict_of_models

In [3]:
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load dataset files from a trusted source.
with open(args.dataset_path, 'rb') as f:
    raw_data = pkl.load(f)

def _to_device(array):
    """Convert a numpy array to a float32 tensor on ``args.device``."""
    return torch.from_numpy(array).float().to(device=args.device)

data = {}
# Image arrays are permuted with axes (0, 1, 4, 2, 3), moving the last axis
# to position 2. This assumes raw images are (N, L, H, W, C) -> (N, L, C, H, W);
# TODO confirm.
# RGB is converted to float like every other input. These tensors are fed
# directly to the image encoder, and the previous `.int()` cast both truncated
# float pixel values and produced a dtype float conv layers do not accept.
data['img_rgb'] = _to_device(raw_data["img"].transpose(0, 1, 4, 2, 3))
data['img_gray'] = _to_device(rgb2gray(raw_data["img"]).transpose(0, 1, 4, 2, 3))
# Force/torque readings divided by a fixed 100.0.
# NOTE(review): presumably the same scaling used at training time — confirm.
data['haptic'] = _to_device(raw_data['ft']) / 100.0
data['arm'] = _to_device(raw_data['arm'])

data['actions'] = _to_device(raw_data["action"])

FileNotFoundError: [Errno 2] No such file or directory: '/home/olimoyo/visual-haptic-dynamics/experiments/data/datasets/visual_haptic_2D_len16_osc_withGT_8C12919B740845539C0E75B5CBAF7965.pkl'

## Visualize

In [None]:
dict_of_models = load_models_dir(args.models_dir)

with torch.no_grad():
    # For every trained model under args.models_dir:
    #   1. encode n_initial observed frames,
    #   2. roll the learned dynamics forward n_pred steps,
    #   3. decode, then plot ground truth (top row) against
    #      reconstructions/predictions (bottom row).
    for path, model_args in dict_of_models.items():

        nets = load_vh_models(path=path, args=model_args, mode='eval', device=args.device)

        def encode(nets, x_img, x_ft, x_arm, ctx_img):
            """Encode a batch of (frame-stacked) image sequences into latent states.

            Args:
                nets: dict of networks returned by ``load_vh_models``.
                x_img: image sequences, indexed as (n, l, ...) below.
                x_ft, x_arm: haptic / arm sequences aligned with ``x_img``; only
                    used when ``model_args.context_modality != "none"``.
                ctx_img: one context image per sequence (initial or goal frame),
                    or None.

            Returns:
                (z, mu_z, var_z, ret_context):
                    z: sampled latent state.
                    mu_z: its mean.
                    var_z: its diagonal covariance (exp of the log-variance).
                    ret_context: image-context information needed for
                        prediction/decoding. This is None, a context latent,
                        or an (rnn_output, rnn_hidden) pair.

            Raises:
                ValueError: if ``model_args.context_modality`` is not recognised.
            """
            if model_args.context_modality != "none":
                if model_args.context_modality == "joint":
                    ctx = torch.cat((x_ft, x_arm), dim=-1) # (n, l, f, 12)
                elif model_args.context_modality == "ft":
                    ctx = x_ft
                elif model_args.context_modality == "arm":
                    ctx = x_arm
                else:
                    # Fail loudly instead of hitting a NameError on `ctx` below.
                    raise ValueError(
                        "Unknown context_modality: {!r}".format(model_args.context_modality)
                    )
                ctx = ctx.float().to(device=args.device) # (n, l, f, 6)
                ctx = ctx.transpose(-1, -2)
                # Merge batch and time axes to match the flattened image batch.
                ctx = ctx.reshape(-1, *ctx.shape[2:])

            n, l = x_img.shape[0], x_img.shape[1]

            if model_args.context in ["initial_image", "goal_image"]:
                # ctx_img is a single frame per sequence (x_img[:, t] upstream),
                # i.e. one axis fewer than x_img. Insert the time axis before
                # tiling it over the l steps. Without unsqueeze, repeat() prepends
                # a size-1 axis and the concat fails whenever n > 1.
                ctx_img_seq = ctx_img.unsqueeze(1).repeat(1, l, 1, 1, 1)
                x_img = torch.cat((x_img, ctx_img_seq), dim=2)

            z_all_enc = []
            # Encode all n * l frames as one flat batch.
            z_img = nets["img_enc"](x_img.reshape(-1, *x_img.shape[2:]))
            z_all_enc.append(z_img)

            if model_args.context_modality != "none":
                z_context = nets["context_enc"](ctx)
                z_all_enc.append(z_context)

            if model_args.context in ["initial_latent_state", "goal_latent_state"]:
                z_img_context = nets["context_img_enc"](ctx_img)
                # Same context latent for every time step of a sequence.
                z_img_context_rep = z_img_context.unsqueeze(1).repeat(1, l, 1)
                z_all_enc.append(z_img_context_rep.reshape(-1, *z_img_context_rep.shape[2:]))
                ret_context = z_img_context
            elif model_args.context in ["all_past_states"]:
                if l > 1:
                    # The context for step t summarises frames < t. Run the RNN
                    # over frames 0..l-2 (time-major), then prepend a zero pad
                    # so the outputs shift right by one step.
                    z_img_context, h_img_context = nets["context_img_rnn_enc"](
                        z_img.reshape(n, l, *z_img.shape[1:])[:, :-1].transpose(1,0)
                    )
                    pad = torch.zeros((1, *z_img_context.shape[1:])).float().to(device=args.device)
                    # Also run over all l frames; this output and hidden state
                    # seed the context during the prediction rollout.
                    z_img_context_latest, h_img_context_latest = nets["context_img_rnn_enc"](
                        z_img.reshape(n, l, *z_img.shape[1:]).transpose(1,0)
                    )
                    z_img_context = torch.cat((pad, z_img_context), dim=0)
                    z_img_context = z_img_context.transpose(1, 0)
                    ret_context = (z_img_context_latest.transpose(1, 0), h_img_context_latest)
                else:
                    # NOTE(review): context width 16 is hard-coded, presumably the
                    # context RNN output size — confirm against model_args.
                    z_img_context = torch.zeros((n, l, 16)).float().to(device=args.device)
                    ret_context = (z_img_context, None)
                z_all_enc.append(z_img_context.reshape(-1, *z_img_context.shape[2:]))
            else:
                ret_context = None

            z_cat_enc = torch.cat(z_all_enc, dim=-1)
            z, mu_z, logvar_z = nets["mix"](z_cat_enc)
            # Diagonal covariance matrix from the predicted log-variance.
            var_z = torch.diag_embed(torch.exp(logvar_z))

            return z, mu_z, var_z, ret_context

        # Pick the image tensor that matches the model's input channel count.
        if model_args.dim_x[0] == 1:
            img_key = 'img_gray'
        elif model_args.dim_x[0] == 3:
            img_key = 'img_rgb'
        else:
            # Without this branch img_key would be unset, or silently stale
            # from the previous model in the loop.
            raise ValueError(
                "Unsupported number of image channels: {}".format(model_args.dim_x[0])
            )

        T = model_args.frame_stacks

        assert args.n_initial + args.n_pred <= data[img_key].shape[1]
        assert args.n_initial > T

        # Use a random batch to test
        ii = np.random.randint(data[img_key].shape[0] // args.n_examples)
        batch_range = range(args.n_examples*ii, args.n_examples*(ii+1))
        test_batch = {k:v[batch_range] for k,v in data.items()}

        # Ground truth images and controls
        x_img = test_batch[img_key][:, :(args.n_initial + args.n_pred)]
        u = test_batch['actions']

        # Sequence of initial images
        x_img_i = x_img[:, :args.n_initial]
        x_img_i = frame_stack(x_img_i, frames=T)

        # Sequence of ground truth images
        x_img_gt = x_img[:, -(T + 1):]
        x_img_gt = frame_stack(x_img_gt, frames=T)

        # Sequence of extra modalities. These start at index T, so they align
        # with the frame-stacked images (the first T frames are consumed by stacking).
        x_ft_i = test_batch['haptic'][:, T:args.n_initial]
        x_arm_i = test_batch['arm'][:, T:args.n_initial]

        n, l = x_img_i.shape[0], x_img_i.shape[1]

        if model_args.context in ["initial_latent_state", "initial_image"]:
            ctx_img = x_img_i[:, 0]
        elif model_args.context in ["goal_latent_state", "goal_image"]:
            ctx_img = x_img_gt[:, -1]
        else:
            ctx_img = None

        # Encode
        z_ini, mu_z_ini, var_z_ini, ret_context = encode(nets, x_img_i, x_ft_i, x_arm_i, ctx_img)

        if model_args.context in ["all_past_states"]:
            z_img_context, h_img_context = ret_context
        else:
            z_img_context = ret_context
        h_i = None

        # Group and prepare for prediction (time-major: (l, n, ...)).
        q_z_i = {"z": z_ini, "mu": mu_z_ini, "cov": var_z_ini}
        q_z_i = {k:v.reshape(n, l, *v.shape[1:]).transpose(1,0) for k, v in q_z_i.items()}
        u = u.transpose(1,0)

        # Latent means for the l encoded steps followed by the n_pred predicted steps.
        z_hat = torch.zeros(((l + args.n_pred), n, model_args.dim_z)).to(device=args.device)
        z_hat[0:l] = q_z_i["mu"]

        # First run
        z_i, mu_z_i, var_z_i = q_z_i["z"], q_z_i["mu"], q_z_i["cov"]
        u_i = u[(T + 1):(1 + args.n_initial)]

        # Predict
        for jj in range(0, args.n_pred):
            z_ip1, mu_z_ip1, var_z_ip1, h_ip1 = nets["dyn"](
                z_t=z_i,
                mu_t=mu_z_i,
                var_t=var_z_i,
                u=u_i,
                h_0=h_i,
                single=False
            )
            z_hat[jj + l] = mu_z_ip1[-1]
            z_i, mu_z_i, var_z_i, h_i = z_ip1[-1:], mu_z_ip1[-1:], var_z_ip1[-1:], h_ip1
            # Only fetch the next action when a next step exists. On the final
            # iteration this index would be n_initial + n_pred, past the end of u
            # when the sequence is exactly that long.
            if jj + 1 < args.n_pred:
                u_i = u[1 + args.n_initial + jj][None]
            if model_args.context in ["all_past_states"]:
                # Decode the prediction, re-encode it, and advance the context
                # RNN so the next step's context includes it.
                z_img_context_dec = z_img_context[:, -1]
                z_cat_single_dec = torch.cat((mu_z_ip1[-1], z_img_context_dec), dim=-1)
                x_hat_ip1 = nets["img_dec"](z_cat_single_dec)
                z_img_ip1 = nets["img_enc"](x_hat_ip1)
                z_img_context_ip1, h_img_context = nets["context_img_rnn_enc"](z_img_ip1.unsqueeze(0), h=h_img_context)
                z_img_context = torch.cat((z_img_context, z_img_context_ip1.transpose(1,0)), dim=1)

        z_hat = z_hat.transpose(1, 0)

        # Decode
        z_all_dec = []
        z_all_dec.append(z_hat)

        if model_args.context in ["initial_latent_state", "goal_latent_state"]:
            z_img_context_rep = z_img_context.unsqueeze(1).repeat(1, (l + args.n_pred), 1)
            z_all_dec.append(z_img_context_rep)
        elif model_args.context in ["all_past_states"]:
            # Shift right by one (zero pad in front, drop the last entry) so that
            # step t is decoded with the context of steps < t, as during encoding.
            z_img_context = z_img_context.transpose(1, 0)
            pad = torch.zeros((1, *z_img_context.shape[1:])).float().to(device=args.device)
            z_img_context = torch.cat((pad, z_img_context[:-1]), dim=0)
            z_img_context = z_img_context.transpose(1, 0)
            z_all_dec.append(z_img_context)

        z_cat_dec = torch.cat(z_all_dec, dim=-1)
        x_hat = nets["img_dec"](z_cat_dec.reshape(-1, *z_cat_dec.shape[2:]))
        x_hat = x_hat.reshape(n, (l + args.n_pred), *x_hat.shape[1:])

        # Move to cpu, np
        x_hat = x_hat.cpu().numpy()
        x_img = x_img.cpu().numpy()

        # Per-frame mean squared error over the pixels of channel 0 of each
        # predicted frame. x_hat[:, l + k] and x_img[:, n_initial + k] are the
        # same time step.
        # NOTE(review): only channel 0 is compared (for RGB / stacked outputs,
        # other channels are ignored).
        mse = np.mean((x_hat[:, l:, 0].reshape(n, args.n_pred, -1) -
                       x_img[:, args.n_initial:, 0].reshape(n, args.n_pred, -1))**2, axis=2)

        # Plotting. Top row: ground-truth frames 0..columns-1.
        # Bottom row: decoded latents. x_hat[k] corresponds to frame k + T,
        # because the first T frames are consumed by frame stacking, so those
        # T columns are left blank and column idx shows x_hat[idx - T].
        for bb in range(n):
            columns = args.n_initial + args.n_pred
            rows = 2
            fig = plt.figure(figsize=(16, 2))
            fig.suptitle('n_initial = {}, frame_stacks = {}, predictions = {}'.format(args.n_initial, T, args.n_pred))
            for ii in range(columns*rows):
                if ii < columns:
                    img = x_img[bb,ii,0,:,:]
                else:
                    idx = ii - columns
                    if idx < T:
                        img = np.zeros((model_args.dim_x[1], model_args.dim_x[2]))
                    else:
                        img = x_hat[bb,idx - T,0,:,:]
                fig.add_subplot(rows, columns, ii+1)
                plt.imshow(img, cmap="gray")

                plt.axis('off')
            print("MSE of predictions: ", mse[bb])
            plt.subplots_adjust(wspace=0.1, hspace=0.1)
            plt.show()