In [1]:
import rlil
import time
import pickle
import os
import torch
import numpy as np
from copy import deepcopy
from rlil.environments import GymEnvironment, Action
from rlil.presets.continuous import rs_mpc
from rlil.initializer import set_device
from utils.pendulum_render import PendulumRender
set_device("cpu")

In [2]:
# set env
env = GymEnvironment("Pendulum-v0")
renderer = PendulumRender()
renderer.add_pendulum("ground truth", color=(.8, .3, .3, 0.5))
renderer.add_pendulum("nn", color=(.2, .9, .2, 0.5))

# load agent
agent_fn = rs_mpc()
agent = agent_fn(env)
agent_dir = "../runs/rs_mpc/Pendulum-v0/rs-mpc_5746978_2020-05-09_10:41:06.054844"
agent.load(agent_dir)

-----ON_POLICY_MODE: False-----




In [9]:
from gym.wrappers import TimeLimit
isinstance(env._env, TimeLimit)

True

In [26]:
env.reset()
env._env._elapsed_steps
env._env._max_episode_steps

200

In [23]:
env._env.observation_space.low
obs_space = type(env._env.observation_space)

In [25]:
env._env.observation_space.low + 

array([-1., -1., -8.], dtype=float32)

## Observation

Type: Box(3)

Num | Observation  | Min | Max  
----|--------------|-----|----   
0   | cos(theta)   | -1.0| 1.0
1   | sin(theta)   | -1.0| 1.0
2   | theta dot    | -8.0| 8.0


## Actions

Type: Box(1)

Num | Action  | Min | Max  
----|--------------|-----|----   
0   | Joint effort | -2.0| 2.0


In [3]:
def get_theta_thetadots(state):
    theta = torch.atan2(state.features[:, 1], state.features[:, 0]).item()
    theta_dot = state.features[:, -1].item()
    return theta, theta_dot

In [8]:
# render predicted pendulum
state = env.reset()
pred_state = state

for i in range(20):
    action = Action(torch.FloatTensor(Action.action_space().sample()).unsqueeze(0))
    # render 
    theta, theta_dot = get_theta_thetadots(pred_state)
    renderer.render("ground truth", env._env.state, action.features.item())
    renderer.render("nn", [theta, theta_dot], action.features.item())
    time.sleep(1/20)
    
    # predict next state
    pred_next_state = agent.dynamics(pred_state, action)
    pred_state = pred_next_state
    
    # step oracle
    env.step(action)