In [1]:
import numpy as np
import _pickle as pkl
import torch
%matplotlib inline 
import matplotlib.pyplot as plt
import os, sys, time
sys.path.append('../..')
from utils import set_seed_torch, rgb2gray
set_seed_torch(3)

In [2]:
class ObjectView(object):
    def __init__(self, d): self.__dict__ = d
        
args = ObjectView({'res': 64,
 'dataset_path': '/home/olimoyo/visual-haptic-dynamics/experiments/data/datasets/{}'
                   .format("visual_haptic_2D_len16_withGT_3D9E4376CF4746EEA20DCD520218038D.pkl"),
 'models_dir': '/home/olimoyo/visual-haptic-dynamics/saved_models/{}'
                   .format("monolith"),
 'device': 'cuda:0',
 'up_to_n_pred': 8,
})

def load_models_dir(models_dir):
    """Load hyperparameters from trained model."""
    dict_of_models = {}
    for filedir in os.listdir(models_dir):
        fullpath = os.path.join(models_dir, filedir)
        if os.path.isdir(fullpath):
            with open(os.path.join(fullpath, 'hyperparameters.txt'), 'r') as fp:
                dict_of_models[fullpath] = Namespace(**json.load(fp))
    return dict_of_models

In [3]:
with open(args.dataset_path, 'rb') as f:
    data = pkl.load(f)

x = {}
x['img_rgb'] = torch.from_numpy(data["img"].transpose(0, 1, 4, 2, 3)).int().to(device=args.device)
x['img_gray'] = torch.from_numpy(rgb2gray(data["img"]).transpose(0, 1, 4, 2, 3)).float().to(device=args.device)
x['haptic'] = torch.from_numpy(data['ft']).float().to(device=args.device)
x['arm'] = torch.from_numpy(data['arm']).float().to(device=args.device)

actions = torch.from_numpy(data["action"]).to(device=args.device).float()

In [4]:
from utils import load_vh_models, frame_stack
from argparse import Namespace
import json
import torch.nn as nn
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

In [5]:
dict_of_models = load_models_dir(args.models_dir)
analysis_data = {}

with torch.no_grad():
    for path, hyperparams in dict_of_models.items():
        model_args = hyperparams
        model_name = path.split("/")[-1]
        models = load_vh_models(path, model_args, mode='eval', device=args.device)

        if model_args.dim_x[0] == 1:
            imgs = x['img_gray']
        elif model_args.dim_x[0] == 3:
            imgs = x['img_rgb']
        
        total_len = imgs.shape[1]
        T = model_args.frame_stacks
        
        analysis_data[model_name] = {}
        for n_pred in range(1, args.up_to_n_pred + 1):
            analysis_data[model_name][f"{n_pred}_pred"] = {}
            analysis_data[model_name][f"{n_pred}_pred"]["MSE_x"] = 0
            analysis_data[model_name][f"{n_pred}_pred"]["MSE_z"] = 0
            analysis_data[model_name][f"{n_pred}_pred"]["SSIM"] = 0
            analysis_data[model_name][f"{n_pred}_pred"]["PSNR"] = 0

            for ii in range(T, total_len - n_pred):
                x_i = imgs[:, (ii - 1):(ii + 1)] 
                x_i = frame_stack(x_i, frames=T)[:, 0]

                x_ft_i = x['haptic'][:, ii:(ii + 1)] 
                x_arm_i = x['arm'][:, ii:(ii + 1)] 
                
                x_gt = imgs[:, (ii + n_pred - 1):(ii + n_pred + 1)]
                x_gt = frame_stack(x_gt, frames=T)[:, 0]

                z_gt = models["img_enc"](x_gt)
                u = actions[:, (ii + 1):(ii + n_pred + 1)]
                joint_inp = torch.cat((
                        x_ft_i, 
                        x_arm_i), 
                        dim=-1
                    )
                
                # Encode
                z_all = []
                if model_args.use_img_enc:
                    z_all.append(models["img_enc"](x_i))
                if model_args.use_joint_enc:
                    joint_inp = torch.cat((
                        x_ft_i, 
                        x_arm_i), 
                        dim=-1
                    )
                    joint_inp = joint_inp.transpose(-1, -2)
                    z_all.append(models["joint_enc"](joint_inp.reshape(-1, *joint_inp.shape[2:])))
                else:
                    if model_args.use_haptic_enc:
                        z_all.append(models["haptic_enc"](x_ft_i.reshape(-1, *x_ft_i.shape[2:])))
                    if model_args.use_arm_enc:
                        z_all.append(models["arm_enc"](x_arm_i.reshape(-1, *x_arm_i.shape[2:])))
                
                z_cat_i = torch.cat(z_all, dim=1)
                z_i, mu_z_i, logvar_z_i = models["mix"](z_cat_i)
                var_z_i = torch.diag_embed(torch.exp(logvar_z_i))
                
                # Predict
                h_i = None
                for jj in range(n_pred):
                    z_ip1, mu_z_ip1, var_z_ip1, h_ip1 = models["dyn"](
                        z_t=z_i, 
                        mu_t=mu_z_i, 
                        var_t=var_z_i, 
                        u=u[:, jj], 
                        h_0=h_i, 
                        single=True
                    )
                    z_i, mu_z_i, var_z_i, h_i = z_ip1, mu_z_ip1, var_z_ip1, h_ip1    

                # Decode 
                z_hat = mu_z_ip1
                x_hat = models["img_dec"](z_hat)

                x_hat = x_hat[:, 0:1].view(x_hat.shape[0], -1).cpu().numpy()
                x_gt = x_gt[:, 0:1].view(x_gt.shape[0], -1).cpu().numpy()
                z_hat = z_hat.cpu().numpy()
                z_gt = z_gt.cpu().numpy()

                n = x_gt.shape[0]
                
                image_error = np.sum(((x_gt - x_hat)**2), axis=-1)
                print("prediction length", n_pred, "starting position", ii, np.average(image_error))
                
                analysis_data[model_name][f"{n_pred}_pred"]["MSE_x"] += (np.sum((x_gt - x_hat)**2) / n)
                analysis_data[model_name][f"{n_pred}_pred"]["MSE_z"] += (np.sum((z_gt - z_hat)**2) / n)

                batch_ssim = 0
                batch_psnr = 0
                for jj in range(n):
                    batch_ssim += ssim(
                        x_gt[jj].reshape(64,64), 
                        x_hat[jj].reshape(64,64), 
                        data_range=1.0
                    )
                    analysis_data[model_name][f"{n_pred}_pred"]["PSNR"] += psnr(
                        x_gt[jj].reshape(64,64), 
                        x_hat[jj].reshape(64,64), 
                        data_range=1.0
                    )
                    
                analysis_data[model_name][f"{n_pred}_pred"]["SSIM"] += (batch_ssim / n)
                analysis_data[model_name][f"{n_pred}_pred"]["PSNR"] += (batch_psnr / n)
    
            # Average MSE/SSIM/PSNR per image 
            for k in analysis_data[model_name][f"{n_pred}_pred"]:
                analysis_data[model_name][f"{n_pred}_pred"][k] /= (total_len - n_pred - 1)

    print("DONE!")

Loading models in path:  /home/olimoyo/visual-haptic-dynamics/saved_models/test/smallz_gru_lm_vha_4step
prediction length 1 starting position 1 278.93762
prediction length 1 starting position 2 118.995605
prediction length 1 starting position 3 73.40684
prediction length 1 starting position 4 27.665127
prediction length 1 starting position 5 13.319531
prediction length 1 starting position 6 8.862243
prediction length 1 starting position 7 4.6967487
prediction length 1 starting position 8 2.7539203
prediction length 1 starting position 9 3.7615943
prediction length 1 starting position 10 5.027911


KeyboardInterrupt: 

In [None]:
fig, axs = plt.subplots(2, 2, figsize=(16,12))

def add_data_to_plot(name, data):
    properties = name.split("_")
    
    # Map model properties to line properties here
    if "vha" in properties:
        color = "r"
    else:
        color = "b"
        
    if "4step" in properties:
        linestyle = "-."
    else:
        linestyle = "-"
        
    if "nl" in properties:
        marker = "o"
    else:
        marker = "v"
    
    plot_data = {'MSE_x': [], 'SSIM': [], 'PSNR': [], "MSE_z": []}
    for k, v in data.items():
        plot_data['MSE_x'].append(v['MSE_x'])
        plot_data['SSIM'].append(v['SSIM'])
        plot_data['PSNR'].append(v['PSNR'])
        plot_data['MSE_z'].append(v['MSE_z'])

    axs[0,0].plot(
        list(range(1, len(plot_data['MSE_x']) + 1)), 
        plot_data['MSE_x'],
        color=color, 
        linestyle=linestyle,
        marker=marker
    )
    axs[0,1].plot(
        list(range(1, len(plot_data['SSIM']) + 1)), 
        plot_data['SSIM'],
        color=color, 
        linestyle=linestyle,
        marker=marker
    )
    axs[1,0].plot(
        list(range(1, len(plot_data['PSNR']) + 1)), 
        plot_data['PSNR'],
        color=color,
        linestyle=linestyle,
        marker=marker
    )
    axs[1,1].plot(
        list(range(1, len(plot_data['MSE_z']) + 1)), 
        plot_data['MSE_z'],
        color=color,
        linestyle=linestyle,
        marker=marker
    )

axs[0,0].set_title("MSE (Image)")
axs[0,1].set_title("SSIM")
axs[1,0].set_title("PSNR")
axs[1,1].set_title("MSE (Latent)")

for k, v in analysis_data.items():
    add_data_to_plot(k, v)