In [20]:
import torch
import matplotlib.pyplot as plt
from common import *
import time
import copy
import gym

import xarray as xr
import numpy as np

import os
import matplotlib

# Rollout Fns ==================================================
def do_rollout_mesh(env, policy, mesh, render=False):
    """Roll out ``policy`` in ``env``, recording visited observations in ``mesh``.

    The rollout stops when the environment reports ``done`` or as soon as the
    observation returned by ``env.step`` is already contained in ``mesh``
    (checked with ``obs in mesh``).  Every new observation returned by
    ``env.step`` is registered with ``mesh[obs] = 1``; the initial observation
    from ``env.reset()`` is not inserted.

    Autograd is disabled only for the duration of the call.  The caller's grad
    mode is restored afterwards, also when ``policy`` or ``env`` raises.

    Args:
        env: gym-style environment exposing ``reset()``, ``step(action)`` and,
            when ``render`` is true, ``render()``.
        policy: callable mapping a float32 observation tensor to an action
            tensor (``.numpy()`` and ``.clone()`` are called on the result).
        mesh: container supporting ``in`` and item assignment, keyed by the raw
            observation returned from ``env.step``.  It is mutated in place.
        render: if true, render every step and sleep 10 ms between frames.

    Returns:
        Tuple ``(ep_obs, ep_act, ep_rew, mesh)``: stacked observations fed to
        the policy, stacked actions, rewards reshaped to ``(-1, 1)`` as
        float32, and the same ``mesh`` object that was passed in.
    """
    dtype = torch.float32
    act_list = []
    obs_list = []
    rew_list = []

    # The context manager restores the previous grad mode on exit (including
    # on exceptions) instead of unconditionally re-enabling autograd globally.
    with torch.no_grad():
        obs = env.reset()
        done = False

        while not done:
            obs = torch.as_tensor(obs, dtype=dtype).detach()
            obs_list.append(obs.clone())

            act = policy(obs)
            obs, rew, done, _ = env.step(act.numpy())
            if render:
                env.render()
                time.sleep(.01)

            act_list.append(torch.as_tensor(act.clone()))
            rew_list.append(rew)

            # Stop once the rollout lands on something the mesh already holds;
            # the transition that led there has already been recorded above.
            if obs in mesh:
                break
            mesh[obs] = 1

        ep_obs = torch.stack(obs_list)
        ep_act = torch.stack(act_list)
        ep_rew = torch.tensor(rew_list, dtype=dtype).reshape(-1, 1)

    return ep_obs, ep_act, ep_rew, mesh

def do_rollout_pert(env, policy, initial_pos, num_steps=10):
    """Roll out ``policy`` for a fixed number of steps from a chosen start state.

    The environment is initialised with ``my_reset(env, initial_pos)`` and then
    stepped exactly ``num_steps`` times; the ``done`` flag returned by
    ``env.step`` is ignored.  Because observations are recorded *before* each
    action, ``ep_obs[-1]`` is the state reached after ``num_steps - 1`` steps.

    Autograd is disabled only for the duration of the call.  The caller's grad
    mode is restored afterwards, also when ``policy`` or ``env`` raises.

    Args:
        env: gym-style environment accepted by ``my_reset`` and exposing
            ``step(action)``.
        policy: callable mapping a float32 observation tensor to an action
            tensor (``.numpy()`` and ``.clone()`` are called on the result).
        initial_pos: start state forwarded unchanged to ``my_reset``.
        num_steps: number of environment steps to take.

    Returns:
        Tuple ``(ep_obs, ep_act, ep_rew)``: stacked observations fed to the
        policy, stacked actions, and rewards reshaped to ``(-1, 1)`` as float32.
    """
    dtype = torch.float32
    act_list = []
    obs_list = []
    rew_list = []

    # The context manager restores the previous grad mode on exit (including
    # on exceptions) instead of unconditionally re-enabling autograd globally.
    with torch.no_grad():
        obs = my_reset(env, initial_pos)

        for _ in range(num_steps):
            obs = torch.as_tensor(obs, dtype=dtype).detach()
            obs_list.append(obs.clone())

            act = policy(obs)
            # Episode termination is deliberately not acted upon here.
            obs, rew, done, _info = env.step(act.numpy())

            act_list.append(torch.as_tensor(act.clone()))
            rew_list.append(rew)

        ep_obs = torch.stack(obs_list)
        ep_act = torch.stack(act_list)
        ep_rew = torch.tensor(rew_list, dtype=dtype).reshape(-1, 1)

    return ep_obs, ep_act, ep_rew


In [None]:
# Plot mean +/- std of the stored rewards over trials, one figure per environment.
env_names = ["HalfCheetah-v2", "Hopper-v2", "Walker2d-v2"]

matplotlib.style.use('default')

# 'normal' is not a font family matplotlib can resolve (it only produces
# "findfont: Font family ['normal'] not found" warnings and a fallback);
# 'sans-serif' is the generic family that fallback resolves to.
font = {'family' : 'sans-serif',
        'weight' : 'bold',
        'size'   : 16}

matplotlib.rc('font', **font)
init_names = ["identity", "madodiv", "identity"]

for env_name, init_name in zip(env_names, init_names):
    # NOTE(review): torch.load unpickles arbitrary objects -- only point this
    # at result files you produced yourself.
    init_data = torch.load(f"./data17/{env_name}.xr")
    init_policy_dict = init_data.policy_dict

    data = torch.load(f"./data_mcshdim4/{env_name}.xr")
    policy_dict = data.policy_dict
    rews = data.rews#/data.post_rews
    exp_names = [fn.__name__ for fn in data.attrs['post_fns']]
    num_seeds = len(policy_dict[exp_names[0]])

    # Aggregate over the "trial" dimension; what remains is iterated row by
    # row below (presumably one row per experiment -- confirm against the
    # layout of data.rews).
    means = rews.mean(dim="trial")
    stds = rews.std(dim="trial")

    # Open the figure at the top of the iteration so that no empty figure is
    # left behind after the last environment.
    plt.figure()
    plt.plot(means.T)
    plt.legend(['Conservative box dim', 'Box dim'], loc='lower right')
    ci = stds

    for mean, c in zip(means, ci):
        plt.fill_between(np.arange(len(mean)), (mean-c), (mean+c), alpha=.5)

    plt.title("Preprocessed Reward")
    plt.ylabel(r"Avg. return $\pm$ std")
    plt.xlabel(r"Epoch")
    plt.grid()

    
#     for exp_name in exp_names:  
#         plt.plot(rews.loc[exp_name].T[:,0:10])
#         plt.legend([i for i in range(10)])
#         plt.title(exp_name)
#         plt.figure()
    

In [5]:
# Set up the environment and the policy tables used by the analysis cells below.
env_name = "HalfCheetah-v2"
env = gym.make(env_name)

# Policies from the mesh-dimension experiments.
data = torch.load(f"./data_mcshdim4/{env_name}.xr")
policy_dict = data.policy_dict

# Add the 'identity' policies from the other result file so both sets can be
# indexed through the same dict.
init_data = torch.load(f"./data17/{env_name}.xr")
init_pol_dict = init_data.policy_dict
policy_dict['identity'] = init_pol_dict['identity']

# Experiment names are the names of the stored post-processing functions.
exp_names = list(fn.__name__ for fn in data.attrs['post_fns'])
first_exp_policies = policy_dict[exp_names[0]]
num_seeds = len(first_exp_policies)

# Filled in as mesh_sizes_dict[post][seed] by the mesh-growth cells.
mesh_sizes_dict = dict()

In [94]:
# HalfCheetah-v2
# Choose the policy to analyse: key into policy_dict and index within that entry.
post, seed = 'mdim_div', 2
policy = policy_dict[post][seed]

def my_reset(env, point):
    """Reset ``env`` and force its simulator state to the one encoded in ``point``.

    ``point`` is split at index 8: the first 8 entries, with a 0.0 prepended,
    become ``qpos`` and the remaining entries become ``qvel``.  The prepended
    0.0 presumably stands in for a coordinate that is absent from the
    observation vector -- confirm against the environment's ``_get_obs``.

    Args:
        env: gym environment whose ``unwrapped`` object provides
            ``set_state(qpos, qvel)`` and ``_get_obs()``.
        point: 1-D array-like supporting slicing.

    Returns:
        Whatever ``env.unwrapped._get_obs()`` returns after the state is set.
    """
    env.reset()

    raw_env = env.unwrapped
    leading_zero = np.array([0.0])
    positions, velocities = point[:8], point[8:]

    raw_env.set_state(np.concatenate((leading_zero, positions)), velocities)
    return raw_env._get_obs()


# Nominal trajectory; one state along it is used as the reference point.
nom_obs, nom_acts, nom_rews, _ = do_long_rollout(env,policy,10000)

nominal_state = nom_obs[500]

# Rollout length shared by the nominal and all perturbed rollouts.  Since
# do_rollout_pert records the observation before each action, obs[-1] is the
# state reached after num_pert_steps - 1 environment steps.
num_pert_steps = 11

obs, acts, rews = do_rollout_pert(env, policy, nominal_state, num_steps=num_pert_steps)
cmp_point = obs[-1]

delta = .1

# One (+delta, -delta) pair of initial conditions per state dimension.
ics = []
for i in range(len(nominal_state)):
    p_state = np.copy(nominal_state)
    m_state = np.copy(nominal_state)
    p_state[i] += delta
    m_state[i] -= delta

    ics.append((p_state, m_state))


# Sized from the state itself rather than a hard-coded 17 so the cell keeps
# working if nominal_state has a different length.
state_dim = len(nominal_state)
eig_mat = np.zeros((state_dim, state_dim))
for i, ic in enumerate(ics):
    p_state , m_state = ic
    pobs, pacts, prews = do_rollout_pert(env,policy,p_state,num_steps=num_pert_steps)
    mobs, macts, mrews = do_rollout_pert(env,policy,m_state,num_steps=num_pert_steps)
    # Column i: per-dimension separation of the two final states minus the
    # separation of the two initial states (positive => the gap grew).
    eig_mat[:,i] = np.abs(pobs[-1].numpy()-mobs[-1].numpy()) - np.abs(p_state - m_state)

In [21]:
# Pick the policy whose mesh growth is measured next.
post = 'mdim_div'
seed = 1

# Ensure a per-seed container exists for this policy_dict key.
mesh_sizes_dict.setdefault(post, {})

policy = policy_dict[post][seed]

# Fresh mesh and an empty size history for this (post, seed) run.
mesh = BoxMesh(.2)
mesh_sizes = []

#init_states = np.linspace(-.3, .3, num=2)
#init_conditions = np.meshgrid(*[init_states]*17)

In [None]:
# Register the history list before the long loop.  It is the same list object
# that gets appended to below, so the final contents are unchanged, but
#  * a missing mesh_sizes_dict[post] fails immediately instead of after all
#    rollouts have run, and
#  * interrupting the loop still leaves the partial history in the dict.
mesh_sizes_dict[post][seed] = mesh_sizes

num_rollouts = 1000000
for i in range(num_rollouts):
    obs, acts, rews, mesh = do_rollout_mesh(env, policy, mesh)
    # One entry per rollout: mesh size after that rollout.
    mesh_sizes.append(len(mesh))
plt.plot(mesh_sizes)

In [None]:
# Compare mesh growth of the two policies (index 1 of each entry); label the
# curves so the figure can be read on its own.
for post_name in ('identity', 'mdim_div'):
    plt.plot(mesh_sizes_dict[post_name][1], label=post_name)
plt.xlabel("Rollout")
plt.ylabel("Mesh size")
plt.legend();