In [2]:
import rlil
import time
import pickle
import os
import torch
import numpy as np
from gym import wrappers
from copy import deepcopy
from rlil.environments import GymEnvironment, Action
from rlil.presets.continuous import td3, noisy_td3
from rlil.initializer import set_device
from utils.pendulum_render import PendulumRender
# Select the device rlil uses for its tensors/networks; "cpu" presumably keeps
# this visualisation notebook runnable without a GPU — confirm in rlil.initializer.
set_device("cpu")

In [4]:
# set env
env = GymEnvironment("Pendulum-v0")

# One semi-transparent pendulum for the greedy policy (red) and one per noisy
# rollout (green, with the green channel decreasing as the index grows).
renderer = PendulumRender()
renderer.add_pendulum("greedy", color=(.8, .3, .3, 0.5))
num_noise = 5  # number of exploring (noisy) rollouts drawn alongside the greedy one
for i in range(num_noise):
    # Clamp so the colour component never goes negative when num_noise > 10;
    # for num_noise <= 10 this is identical to the unclamped 1.0 - i * 0.1.
    green = max(0.0, 1.0 - i * 0.1)
    renderer.add_pendulum("noise{}".format(i), color=(.2, green, .2, 0.5))

# load agent
# NOTE: no weights are loaded below (the agent.load call is commented out), so the
# rollout shows a freshly constructed agent. To visualise a trained policy, switch
# to the preset it was trained with (e.g. noisy_td3()) and uncomment the load.
agent_fn = td3()  # alternative preset: noisy_td3()
agent = agent_fn(env)
# agent_dir = "../runs/noisy_net/Pendulum-v0/noisy-td3_ca7c79d_2020-05-13_12:25:22.382949"
# agent.load(agent_dir)

-----ON_POLICY_MODE: False-----
-----N step: 1-----
-----Discount factor: 0.99-----


## Observation

Type: Box(3)

Num | Observation                  | Min  | Max
----|------------------------------|------|-----
0   | cos(theta)                   | -1.0 | 1.0
1   | sin(theta)                   | -1.0 | 1.0
2   | theta dot (angular velocity) | -8.0 | 8.0


## Actions

Type: Box(1)

Num | Action                | Min  | Max
----|-----------------------|------|-----
0   | Joint effort (torque) | -2.0 | 2.0


In [5]:
def get_theta_thetadots(state):
    """Recover the pendulum angle and angular velocity from one observation.

    Args:
        state: object whose ``features`` tensor is indexed as
            ``[:, 0] = cos(theta)``, ``[:, 1] = sin(theta)``, ``[:, -1] = theta_dot``.
            ``.item()`` is used on each column, so it must hold exactly one row.

    Returns:
        ``(theta, theta_dot)`` as Python floats; ``theta`` is in ``[-pi, pi]``.
    """
    features = state.features
    cos_theta = features[:, 0]
    sin_theta = features[:, 1]
    angle = torch.atan2(sin_theta, cos_theta)
    angular_velocity = features[:, -1]
    return angle.item(), angular_velocity.item()

In [7]:
# render predicted pendulum
# Roll out one greedy (evaluation-mode) agent and `num_noise` exploring agents
# side by side from the same initial state, drawing every pendulum each frame.
NUM_STEPS = 100        # length of the rendered rollout, in environment steps
FRAME_DELAY = 1 / 60   # seconds to pause after drawing each frame

env.reset()
# Each rollout gets its own copy of the environment so stepping one does not
# affect the others; all copies start from the state produced by the reset above.
greedy_env = deepcopy(env)
envs = [deepcopy(env) for _ in range(num_noise)]

# evaluation=True is presumably the noise-free policy (drawn as "greedy") and
# evaluation=False the exploring one — confirm against rlil's make_lazy_agent.
greedy_agent = agent.make_lazy_agent(evaluation=True)
greedy_agent.set_replay_buffer(env)
agents = [agent.make_lazy_agent(evaluation=False) for _ in range(num_noise)]
for noisy_agent in agents:
    noisy_agent.set_replay_buffer(env)

# Must match the names registered with renderer.add_pendulum in the setup cell.
noise_names = ["noise{}".format(k) for k in range(num_noise)]

# Distinct loop variables: the original reused `i` for both the step loop and
# the per-pendulum loops, shadowing the outer counter.
for _ in range(NUM_STEPS):
    greedy_action = greedy_agent.act(greedy_env.state, greedy_env.reward)
    actions = [a.act(e.state, e.reward) for a, e in zip(agents, envs)]

    # render: current raw gym state of each env plus the action about to be applied
    renderer.render("greedy", greedy_env._env.state, greedy_action.features.item())
    for name, noisy_env, action in zip(noise_names, envs, actions):
        renderer.render(name, noisy_env._env.state, action.features.item())

    time.sleep(FRAME_DELAY)

    # step oracle: advance every environment with the action its agent chose
    greedy_env.step(greedy_action)
    for noisy_env, action in zip(envs, actions):
        noisy_env.step(action)