In [1]:
import numpy as np
import _pickle as pkl
import torch
from torchvision import transforms
%matplotlib inline 
import matplotlib.pyplot as plt
import os, sys, time
sys.path.append('../..')
from utils import set_seed_torch, rgb2gray
set_seed_torch(3)

In [2]:
class ObjectView(object):
    def __init__(self, d): self.__dict__ = d
        
args = ObjectView({'res': 64,
 'dataset_path': '/home/olimoyo/visual-haptic-dynamics/experiments/data/datasets/{}'
                   .format("visual_haptic_2D_len16_569B46785E3F45BCA172AE53EA070D5E.pkl"),
 'models_dir': '/home/olimoyo/visual-haptic-dynamics/saved_models/{}'
                   .format("test"),
 'device': 'cuda:0',
 'up_to_n_pred': 6,
})

In [3]:
with open(args.dataset_path, 'rb') as f:
    data = pkl.load(f)

x = {}
x['img_rgb'] = torch.from_numpy(data["img"].transpose(0, 1, 4, 2, 3)).int().to(device=args.device)
x['img_gray'] = torch.from_numpy(rgb2gray(data["img"]).transpose(0, 1, 4, 2, 3)).float().to(device=args.device)
x['haptic'] = torch.from_numpy(data['ft']).float().to(device=args.device)
x['arm'] = torch.from_numpy(data['arm']).float().to(device=args.device)

actions = torch.from_numpy(data["action"]).to(device=args.device).float()

In [4]:
print(x['img_rgb'].shape)
print(x['img_gray'].shape)
print(x['haptic'].shape)
print(x['arm'].shape)
print(actions.shape)

torch.Size([1408, 16, 3, 64, 64])
torch.Size([1408, 16, 1, 64, 64])
torch.Size([1408, 16, 32, 6])
torch.Size([1408, 16, 32, 6])
torch.Size([1408, 16, 2])


In [5]:
from utils import load_vh_models, frame_stack
from argparse import Namespace
import json
import torch.nn as nn

In [13]:
models = {}

# Load hyperparameters from trained model
for filedir in os.listdir(args.models_dir):
    fullpath = os.path.join(args.models_dir, filedir)
    if os.path.isdir(fullpath):
        models[fullpath] = {}
        with open(os.path.join(fullpath, 'hyperparameters.txt'), 'r') as fp:
            models[fullpath]['hyperparameters'] = Namespace(**json.load(fp))

with torch.no_grad():
    for path, model in models.items():
        model_args = model['hyperparameters']
        models = load_vh_models(path, model_args, mode='eval', device=args.device)

        if model_args.dim_x[0] == 1:
            imgs = x['img_gray']
        elif model_args.dim_x[0] == 3:
            imgs = x['img_rgb']
        
        total_len = imgs.shape[1]
        T = model_args.frame_stacks
                    
        data = {}
        for n_pred in range(1, args.up_to_n_pred):
            data[f"{n_pred}_pred"] = {}
            data[f"{n_pred}_pred"]["MSE"] = 0
            data[f"{n_pred}_pred"]["SSIM"] = 0
            data[f"{n_pred}_pred"]["PSNR"] = 0
            
            for ii in range(T, total_len - n_pred):
                x_i = imgs[:, (ii - 1):(ii + 1)] 
                x_i = frame_stack(x_i, frames=T)  
                
                x_ft_i = x['haptic'][:, ii:(ii + 1)] 
                x_arm_i = x['arm'][:, ii:(ii + 1)] 
                
                x_gt = imgs[:, (ii + n_pred)]
                u = actions[:, (ii + 1):(ii + n_pred + 1)]
                
                # Encode
                z_all = []
                if model_args.use_img_enc:
                    z_all.append(models["img_enc"](x_i.reshape(-1, *x_i.shape[2:])))

                if model_args.use_joint_enc:
                    joint_inp = torch.cat((
                        x_ft_i.reshape(-1, *x_ft_i.shape[2:]), 
                        x_arm_i.reshape(-1, *x_arm_i.shape[2:])), 
                        dim=-1
                    )
                    z_all.append(models["joint_enc"](joint_inp)[:, -1])
                else:
                    if model_args.use_haptic_enc:
                        z_all.append(models["haptic_enc"](x_ft_i.reshape(-1, *x_ft_i.shape[2:]))[:, -1])
                    if model_args.use_arm_enc:
                        z_all.append(models["arm_enc"](x_arm_i.reshape(-1, *x_arm_i.shape[2:]))[:, -1])
                
                z_cat_i = torch.cat(z_all, dim=1)
                z_i, mu_z_i, logvar_z_i = models["mix"](z_cat_i)
                var_z_i = torch.diag_embed(torch.exp(logvar_z_i))
                
                # Predict
                h_i = None
                for jj in range(n_pred):
                    z_ip1, mu_z_ip1, var_z_ip1, h_ip1 = models["dyn"](
                        z_t=z_i, 
                        mu_t=mu_z_i, 
                        var_t=var_z_i, 
                        u=u[:, jj], 
                        h=h_i, 
                        single=True
                    )
                    z_i, mu_z_i, var_z_i, h_i = z_ip1, mu_z_ip1, var_z_ip1, h_ip1    

                # Decode 
                x_hat = models["img_dec"](z_ip1)

                # TODO: Calculate MSE/SSIM/PSNR for single prediction
                x_hat = x_hat[:, 0:1]

                data[f"{n_pred}_pred"]["MSE"] += 0
                data[f"{n_pred}_pred"]["SSIM"] += 0
                data[f"{n_pred}_pred"]["PSNR"] += 0
                
            
            # Average MSE/SSIM/PSNR
            for k in data[f"{n_pred}_pred"]:
                data[f"{n_pred}_pred"][k] /= total_len - n_pred - 1
    print("DONE")
            

Loading models in path:  /home/olimoyo/visual-haptic-dynamics/saved_models/test/2D_lightTCN_base_gru_lm_vha
DONE


In [14]:
print(data)

{'1_pred': {'MSE': 0.0, 'SSIM': 0.0, 'PSNR': 0.0}, '2_pred': {'MSE': 0.0, 'SSIM': 0.0, 'PSNR': 0.0}, '3_pred': {'MSE': 0.0, 'SSIM': 0.0, 'PSNR': 0.0}, '4_pred': {'MSE': 0.0, 'SSIM': 0.0, 'PSNR': 0.0}, '5_pred': {'MSE': 0.0, 'SSIM': 0.0, 'PSNR': 0.0}}
