In [1]:
import numpy as np
import _pickle as pkl
import torch
from torchvision import transforms
import random
import os, sys, time
sys.path.append('..')
from utils import (set_seed_torch, Normalize)
# Fix the random seed (3) so stochastic parts of the notebook (e.g. encoder
# sampling) are reproducible. NOTE(review): presumably seeds torch, and possibly
# numpy/random as well -- confirm against utils.set_seed_torch.
set_seed_torch(3)

In [2]:
class ObjectView:
    """Attribute-style access to a plain dict.

    ``ObjectView(d).key`` reads ``d['key']``. The instance adopts ``d`` itself
    as its attribute namespace (it is not copied), so setting an attribute on
    the view writes into ``d``, and later changes to ``d`` are visible on the
    view.
    """

    def __init__(self, d):
        # Alias, don't copy: the dict *becomes* the instance namespace.
        self.__dict__ = d

# Run configuration. The defaults reproduce the original local CPU setup; to run
# elsewhere, set the environment variables below instead of editing this cell or
# commenting/uncommenting alternative configs. For reference, the GPU workstation
# used dataset root /media/m2-drive/datasets/pendulum-srl-sim, models root
# /home/olimoyo/latent-metric-control/saved_models, model set "test2" and
# device "cuda:0".
DATASET_ROOT = os.environ.get('PENDULUM_DATASET_ROOT',
                              '/Users/oliver/Datasets/pendulum-srl-sim')
DATASET_FILE = 'pendulum64_total_2048_traj_16_repeat_2_with_angle_train.pkl'
MODELS_ROOT = os.environ.get('LMC_MODELS_ROOT',
                             '/Users/oliver/latent-metric-control/saved_models')
MODEL_SET = os.environ.get('LMC_MODEL_SET', 'test')

args = ObjectView({
    'res': 64,              # image side length; frames are reshaped to res x res x 3
    'dataset_path': os.path.join(DATASET_ROOT, DATASET_FILE),
    'models_dir': os.path.join(MODELS_ROOT, MODEL_SET),
    'n_batches': 32,        # NOTE: used downstream as the batch *size* (trajectories per batch)
    'device': os.environ.get('LMC_DEVICE', 'cpu'),
    'n_trajs': 64,          # number of trajectories kept from the dataset
    'n_predictions': 11,    # number of open-loop roll-out steps
})


In [3]:
# NOTE(review): pickle.load can execute arbitrary code -- only load dataset
# files that come from a trusted source.
with open(args.dataset_path, 'rb') as f:
    data = pkl.load(f)

# Entries 0/1/2 of the pickle are images, actions and ground-truth states;
# keep only the first n_trajs trajectories of each.
imgs_cached = data[0][:args.n_trajs]
actions = data[1][:args.n_trajs]
gt_state = data[2][:args.n_trajs]

# Images -> (n_trajs, steps, res, res, 3); actions -> tensor on the target device.
n_traj, n_steps = imgs_cached.shape[:2]
imgs_cached = imgs_cached.reshape(n_traj, n_steps, args.res, args.res, 3)
actions = torch.from_numpy(actions).to(device=args.device)

# Per-frame preprocessing: RGB array -> single-channel normalised tensor.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    Normalize(mean=0.27, var=1.0 - 0.27),  # constants noted as being for the 64x64 data
])

# Preprocessed frames: (n_trajs, steps, 1, res, res), filled one frame at a time.
imgs = torch.empty((n_traj, n_steps, 1, args.res, args.res), device=args.device)
for ii, trajectory in enumerate(imgs_cached):
    for jj, frame in enumerate(trajectory):
        imgs[ii, jj] = transform(frame)

## Visualize the learned latent spaces

In [4]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from utils import load_models, frame_stack
from argparse import Namespace
import json
import torch.nn as nn

In [5]:
models = {}

# Each sub-directory of args.models_dir is one trained model and must contain the
# JSON `hyperparameters.txt` it was trained with. Sorted because os.listdir order
# is arbitrary; this makes the model (and figure) order reproducible.
for filedir in sorted(os.listdir(args.models_dir)):
    fullpath = os.path.join(args.models_dir, filedir)
    if os.path.isdir(fullpath):
        models[fullpath] = {}
        with open(os.path.join(fullpath, 'hyperparameters.txt'), 'r') as fp:
            models[fullpath]['hyperparameters'] = Namespace(**json.load(fp))

# Despite its name, args.n_batches is the number of trajectories per batch.
batch_size = args.n_batches
num_trajs = imgs.shape[0]
traj_len = imgs.shape[1]

with torch.no_grad():
    for path, model in models.items():
        model_args = model['hyperparameters']

        enc, dec, dyn = load_models(path, model_args, mode='eval', device=args.device)
        T = model_args.frame_stacks

        # z_all: direct encoder embeddings, (traj_len - T) rows per trajectory.
        # z_hat_all: open-loop roll-out, n_predictions rows per trajectory.
        z_all = torch.zeros((num_trajs * (traj_len - T), model_args.dim_z), device=args.device)
        z_hat_all = torch.zeros((num_trajs * args.n_predictions, model_args.dim_z), device=args.device)
        # Ground-truth states stored alongside z_all / z_hat_all for colouring.
        gt_state_ = gt_state[:, T:]
        gt_state_hat = gt_state[:, (T + 1):(T + 1 + args.n_predictions)]

        # Running write offsets. Advancing by the rows actually produced keeps a
        # final partial batch (num_trajs not a multiple of batch_size) in place
        # instead of silently leaving its rows as zeros.
        z_row = 0
        z_hat_row = 0
        for start in range(0, num_trajs, batch_size):
            stop = min(start + batch_size, num_trajs)

            # Direct embedding of every frame-stacked window
            x = imgs[start:stop]
            x_s = frame_stack(x, frames=T)
            z, mu_z, logvar_z = enc(x_s.reshape(-1, *x_s.shape[2:]))
            z_all[z_row:z_row + z.shape[0]] = z
            z_row += z.shape[0]

            # Roll-out: encode the first T + 1 frames, then step the dynamics
            # model n_predictions times, feeding back its own outputs.
            z_hat = torch.zeros((stop - start, args.n_predictions, model_args.dim_z), device=args.device)
            u_f = actions[start:stop, (T + 1):(T + 1 + args.n_predictions)]

            x_i = imgs[start:stop, :(T + 1)]
            x_i_s = frame_stack(x_i, frames=T)
            z_i, mu_z_i, logvar_z_i = enc(x_i_s.reshape(-1, *x_i_s.shape[2:]))
            # exp(logvar) placed on the diagonal -> covariance passed as var_t.
            var_z_i = torch.diag_embed(torch.exp(logvar_z_i))
            h_i = None

            for jj in range(args.n_predictions):
                u_i = u_f[:, jj]
                z_ip1, mu_z_ip1, var_z_ip1, h_ip1 = dyn(z_t=z_i, mu_t=mu_z_i,
                                                        var_t=var_z_i, u=u_i, h=h_i, single=True)
                z_hat[:, jj] = z_ip1[0]
                z_i, mu_z_i, var_z_i, h_i = z_ip1[0], mu_z_ip1[0], var_z_ip1[0], h_ip1

            z_hat = z_hat.reshape(-1, *z_hat.shape[2:])
            z_hat_all[z_hat_row:z_hat_row + z_hat.shape[0]] = z_hat
            z_hat_row += z_hat.shape[0]

        # Every preallocated row must have been written; otherwise the latents
        # would be misaligned with the ground-truth states used for colouring.
        assert z_row == z_all.shape[0], \
            'encoded {} rows, expected {}'.format(z_row, z_all.shape[0])

        model['z_all'] = z_all.cpu().numpy()
        model['z_hat_all'] = z_hat_all.cpu().numpy()
        model['gt_state_hat'] = gt_state_hat.reshape(-1, *gt_state_hat.shape[2:])
        model['gt_state'] = gt_state_.reshape(-1, *gt_state_.shape[2:])

Loading models in path:  /Users/oliver/latent-metric-control/saved_models/test/basez3_softplus
Loading models in path:  /Users/oliver/latent-metric-control/saved_models/test/basez2_softplus
Loading models in path:  /Users/oliver/latent-metric-control/saved_models/test/basez3
Loading models in path:  /Users/oliver/latent-metric-control/saved_models/test/basez2
Loading models in path:  /Users/oliver/latent-metric-control/saved_models/test/base


In [11]:
def plot_3d(points3d, title, colour_scales, scale_labels):
    """Scatter 3-D points once per colour scale, in side-by-side 3-D scenes.

    Args:
        points3d: numpy array or torch tensor indexed as [:, 0..2]; those three
            columns are the scene x/y/z coordinates.
        title: overall figure title.
        colour_scales: sequence of per-point colour values, one scene each
            (numpy arrays, torch tensors, or anything plotly accepts as
            ``marker.color``).
        scale_labels: sequence of labels, one per colour scale, shown as the
            subplot titles.

    Returns:
        A plotly Figure with one scene per colour scale.
    """
    colour_scales = list(colour_scales)
    scale_labels = [str(label) for label in scale_labels]
    n = len(colour_scales)
    fig = make_subplots(rows=1, cols=n,
                        specs=[[{"type": "scene"} for _ in range(n)]],
                        subplot_titles=scale_labels[:n])

    assert isinstance(points3d, (np.ndarray, torch.Tensor))
    if isinstance(points3d, torch.Tensor):
        # detach() so tensors that still require grad can be converted too.
        points3d = points3d.detach().cpu().numpy()

    # One trace per colour scale; iterate over all of them (not zip with the
    # labels) so a short label list cannot leave scenes empty.
    for col, colour_scale in enumerate(colour_scales, start=1):
        if isinstance(colour_scale, torch.Tensor):
            colour_scale = colour_scale.detach().cpu().numpy()

        fig.add_trace(go.Scatter3d(x=points3d[:, 0], y=points3d[:, 1], z=points3d[:, 2],
                                   mode='markers',
                                   marker=dict(size=1.5,
                                               color=colour_scale,
                                               colorscale='Viridis',
                                               opacity=0.75)),
                      row=1, col=col)

    layout = go.Layout(title=title, title_x=0.5, showlegend=False,
                       xaxis=dict(zeroline=False, showgrid=True),
                       yaxis=dict(zeroline=False, showgrid=True))
    fig.update_layout(layout)

    # Same axis titles, hidden tick labels and cubic aspect for every scene.
    fig.update_scenes(xaxis=dict(title='z₁', showticklabels=False),
                      yaxis=dict(title='z₂', showticklabels=False),
                      zaxis=dict(title='z₃', showticklabels=False),
                      aspectmode='cube')

    return fig

def plot_2d(points2d, title, colour_scales, scale_labels):
    """Scatter 2-D points once per colour scale, in side-by-side panels.

    Args:
        points2d: numpy array or torch tensor indexed as [:, 0..1]; those two
            columns are the panel x/y coordinates.
        title: overall figure title.
        colour_scales: sequence of per-point colour values, one panel each
            (numpy arrays, torch tensors, or anything plotly accepts as
            ``marker.color``).
        scale_labels: sequence of labels, one per colour scale, shown as the
            subplot titles.

    Returns:
        A plotly Figure with one panel per colour scale; panels share axes.
    """
    colour_scales = list(colour_scales)
    scale_labels = [str(label) for label in scale_labels]
    n = len(colour_scales)
    fig = make_subplots(rows=1, cols=n, shared_xaxes=True, shared_yaxes=True,
                        specs=[[{"type": "xy"} for _ in range(n)]],
                        subplot_titles=scale_labels[:n])

    assert isinstance(points2d, (np.ndarray, torch.Tensor))
    if isinstance(points2d, torch.Tensor):
        points2d = points2d.detach().cpu().numpy()

    # One trace per colour scale; iterate over all of them (not zip with the
    # labels) so a short label list cannot leave panels empty.
    for col, colour_scale in enumerate(colour_scales, start=1):
        if isinstance(colour_scale, torch.Tensor):
            # detach() for consistency with points2d and grad-tracking tensors.
            colour_scale = colour_scale.detach().cpu().numpy()

        fig.add_trace(go.Scatter(x=points2d[:, 0], y=points2d[:, 1],
                                 mode='markers',
                                 marker=dict(size=2.5,
                                             color=colour_scale,
                                             colorscale='Viridis',
                                             opacity=0.75)),
                      row=1, col=col)

    fig.update_xaxes(dict(title='z₁', showticklabels=False))
    fig.update_yaxes(dict(title='z₂', showticklabels=False))

    # scaleanchor keeps a 1:1 aspect between the first panel's x and y axes.
    layout = go.Layout(title=title, title_x=0.5, showlegend=False,
                       xaxis=dict(zeroline=False, showgrid=True),
                       yaxis=dict(zeroline=False, showgrid=True, scaleanchor="x", scaleratio=1))
    fig.update_layout(layout)

    return fig

# One figure per model: direct latent embeddings coloured by the ground-truth
# angular position and angular velocity (state columns 0 and 1).
titles = ["Ang. Pos.", "Ang. Vel."]
for k, model in models.items():
    z = model['z_all']
    th = model['gt_state'][:, 0]
    thdot = model['gt_state'][:, 1]
    colour_scales = [th, thdot]
    dim_z = z.shape[-1]
    if dim_z == 3:
        fig = plot_3d(points3d=z, colour_scales=colour_scales, scale_labels=titles, title=k)
    elif dim_z == 2:
        fig = plot_2d(points2d=z, colour_scales=colour_scales, scale_labels=titles, title=k)
    else:
        # Previously skipped silently; make the omission visible.
        print('Skipping {}: no plot for a {}-D latent space (only 2-D and 3-D)'.format(k, dim_z))
        continue
    fig.show()