In [1]:
import env
import gym
from model_free import TD3

class ActionRepeat(object):
    """Environment wrapper that applies every action several times in a row.

    Attributes that the wrapper does not define itself are looked up on the
    wrapped environment.

    Args:
        env: Environment to wrap. It must provide `_max_episode_steps`,
            `step(action)` returning a 4-tuple, and `reset()`.
        amount: How many times each action is applied per `step` call.
            Must be at least 1.

    Raises:
        ValueError: If `amount` is smaller than 1.
    """

    def __init__(self, env, amount):
        # With fewer than one repeat, `step` would never call the wrapped
        # environment and would have no observation to return; reject it
        # before touching the wrapped environment.
        if amount < 1:
            raise ValueError(
                "amount must be a positive integer, got {!r}".format(amount))

        self._env = env
        self._amount = amount
        # NOTE: this overwrites the step limit stored on the wrapped
        # environment itself, dividing it by the repeat count.
        self._env._max_episode_steps = self._env._max_episode_steps // amount

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If `_env` itself is
        # missing (instance created without running __init__), forwarding
        # through `self._env` would recurse without bound, so fail cleanly.
        if name == '_env':
            raise AttributeError(name)
        return getattr(self._env, name)

    def step(self, action):
        """Apply `action` `amount` times and sum the rewards.

        Returns:
            A tuple `(obs, total_reward, False, {})` where `obs` is the
            observation from the last repeated step. The done flag and info
            dict of the wrapped environment are discarded.
        """
        total_reward = 0

        for _ in range(self._amount):
            obs, reward, _done, _info = self._env.step(action)
            total_reward += reward

        return obs, total_reward, False, {}

    def reset(self, *args, **kwargs):
        """Reset the wrapped environment, forwarding all arguments."""
        return self._env.reset(*args, **kwargs)

# Gym id of the environment and the directory name used for its expert data.
env_name = 'MyHalfCheetah-v2'
env_str = 'half_cheetah'

# Every action is repeated 4 times per step. Note that this rebinds `env`,
# which until now referred to the imported `env` module.
env = ActionRepeat(gym.make(env_name), 4)

In [2]:
# Sizes of the observation and action vectors, read from the env spaces.
state_dim, action_dim = (env.observation_space.shape[0],
                         env.action_space.shape[0])
# Only the first entry of the upper action bound is used.
max_action = float(env.action_space.high[0])

# Build the TD3 agent and load its saved parameters from 'save/TD3'.
policy = TD3(state_dim, action_dim, max_action)
policy.load(env_name, 'save/TD3')

In [3]:
import numpy as np

def print_rollout_stats(obs, acts, reward_sum):
    """Print the return of a rollout plus summary statistics of its arrays.

    Args:
        obs: Array of observations; must support min/max/mean/std.
        acts: Array of actions; must support min/max/mean/std.
        reward_sum: Cumulative reward of the rollout.
    """
    print("Cumulative reward ", reward_sum)

    action_stats = (acts.min(), acts.max(), acts.mean(), acts.std())
    print("Action min {}, max {}, mean {}, std {}".format(*action_stats))

    obs_stats = (obs.min(), obs.max(), obs.mean(), obs.std())
    print("Obs min {}, max {}, mean {}, std {}".format(*obs_stats))

def sample_rollout(env, policy):
    """Run `policy` in `env` for exactly `env._max_episode_steps` steps.

    The done flag and info returned by `env.step` are ignored, so the rollout
    always has the full length.

    Args:
        env: Environment providing `reset()`, `step(action)` and
            `_max_episode_steps`.
        policy: Object providing `act(observation)`.

    Returns:
        A tuple `(observations, actions, reward_sum)`. `observations` holds
        the reset observation followed by one observation per step, so it has
        one more entry than `actions`.
    """
    current_obs = env.reset()
    obs_trace = [current_obs]
    action_trace = []
    episode_return = 0

    for _ in range(env._max_episode_steps):
        action = policy.act(current_obs)
        action_trace.append(action)

        current_obs, reward, _done, _info = env.step(action)
        obs_trace.append(current_obs)
        episode_return += reward

    return np.array(obs_trace), np.array(action_trace), episode_return
    
    
# Sample 20 rollouts from the loaded policy, keeping observations and actions.
obs_per_rollout, acts_per_rollout = [], []
for _ in range(20):
    obs, acts, reward_sum = sample_rollout(env, policy)
    obs_per_rollout.append(obs)
    acts_per_rollout.append(acts)

# Stack the rollouts along a new leading axis and write them to disk.
O = np.array(obs_per_rollout)
A = np.array(acts_per_rollout)
np.save('TD3_obs.npy', O)
np.save('TD3_act.npy', A)

In [4]:
# Reload the saved rollouts and print their shapes and summary statistics.
O = np.load('TD3_obs.npy')
A = np.load('TD3_act.npy')
print(O.shape, A.shape)
for saved_array in (O, A):
    print(saved_array.min(), saved_array.max(),
          saved_array.mean(), saved_array.std())

(20, 251, 18) (20, 250, 6)
-12.98123377817222 14.52720278768601 0.7168978624829655 2.168804458354458
-1.0 1.0 -0.4431953 0.803009


In [5]:
# Load the expert demonstrations for this environment, report their value
# ranges, clip the actions to the environment's action bound and save the
# clipped actions under a new file name.
O = np.load('expert_demonstrations/%s/expert_obs.npy' % env_str)
A = np.load('expert_demonstrations/%s/expert_act.npy' % env_str)
# Only the first entry of the upper bound is used; the clipping below applies
# it symmetrically as [-action_bound, action_bound] to every action entry.
action_bound = env.action_space.high[0]
print(O.shape, A.shape)
print(O.min(), O.max(), O.mean(), O.std())
print(A.min(), A.max(), A.mean(), A.std())
# Fraction of action entries lying outside [-action_bound, action_bound].
print((A > action_bound).mean() + (A < -action_bound).mean())
# From here on `A` refers to the clipped copy, so the two prints below
# describe the clipped actions.
A = A.clip(-action_bound, action_bound)
print(A.min(), A.max(), A.mean(), A.std())
print((A > action_bound).mean() + (A < -action_bound).mean())
np.save('expert_demonstrations/%s/expert_act_clipped.npy' % env_str, A)

(20, 251, 18) (20, 250, 6)
-17.2598140331567 466.09138757351496 11.581367552568011 51.45506850905242
-2.1904756893737196 2.157680768107007 0.0727562300912073 0.8969060018290144
0.36443333333333333
-1.0 1.0 0.04962332196652116 0.7189090617904839
0.0


In [41]:
# Print the smallest and largest observation entry at every time step of the
# first saved TD3 rollout.
O = np.load('TD3_obs.npy')
first_rollout = O[0]
for obs in first_rollout:
    print(obs.min(), obs.max())

-0.13627010685927474 0.13594109953289377
-1.5200681459822558 3.0472569865740464
-1.8954799102550497 6.362054806135884
-1.3698743114653857 3.221319736395427
-2.116689976699 3.433234498732357
-1.7590560916592217 4.079603860795589
-2.3288711655148195 6.050030379018915
-0.6324927716817786 4.1212290390501245
-2.5908700043160797 5.56639165727859
-0.6541840556457906 3.533670777454736
-2.278135142379368 4.9053343266968845
-1.5988037805874817 3.7119425054245614
-1.8942338087096333 7.714145838480796
-1.958161577431589 3.9091588092767005
-2.2602766200601927 5.071505600701684
-0.6923203240907059 4.43323654092147
-3.51856661330583 7.687839850775591
-2.9176344316151672 4.2661055520490265
-3.1964736432430816 5.174205091193197
-2.1320126531873393 5.5214455750519695
-3.2022385588594844 5.786404499512896
-2.3294074866751746 5.636213750881199
-2.4966421154567597 5.1356961778484695
-2.237595206786545 5.083999130731591
-4.499425414650021 13.65864028557779
-2.886068151607702 5.511570129871021
-2.60771544308

In [45]:
# For every time step of the second expert demonstration, print the smallest
# and largest observation entry together with the entry at index 2.
O = np.load('expert_demonstrations/%s/expert_obs.npy' % env_str)
second_demo = O[1]
for obs in second_demo:
    print(obs.min(), obs.max(), obs[2])

-0.2053105792766814 0.3152890439893641 -0.09742072774706148
-1.9424439216570777 4.547382704062731 0.05840000017390809
-5.218915311677941 2.3025003609833514 -0.2297794598333387
-1.182374213750419 6.541040138892428 0.0640540345311489
-1.9841267375849325 6.422739254409489 0.12993931611043574
-5.083465846841685 3.167157230179666 0.12449832061000521
-3.5909294725805676 3.7693890500541576 0.038199420846741695
-4.529144910707986 3.020943285533475 -0.635247914766367
-1.1560244067644538 3.738013726131424 -0.8290828870796162
-1.3635759255041453 3.3230021591852617 -0.9116556947412949
-1.7149855000678391 3.3218166347000984 -0.9181705738198307
-7.470108883966254 5.782503546022629 -1.081734121025653
-1.765018771978553 8.403980272802704 -0.8548397579852016
-1.8991789822394356 4.316145732702985 -0.7846905058346078
-1.0923209140361687 4.358449060287892 -0.6584023411812682
-4.9883031842640815 5.290330618324729 -0.5376304025276679
-0.7481184775467515 5.882225459690424 -0.19389073239515367
-4.026145426116

In [48]:
# Compare the value ranges of observation dimensions 2 and 0 between the TD3
# rollouts and the expert demonstrations (TD3 first, then expert).
for obs_file in ('TD3_obs.npy',
                 'expert_demonstrations/%s/expert_obs.npy' % env_str):
    O = np.load(obs_file)
    for dim in (2, 0):
        print(O[:, :, dim].min(), O[:, :, dim].max())

-3.5650650119443785 1.544295410865241
-1.6676734511271984 6.136299896777473
-1.3738031153485324 466.09138757351496
0.0 20.19333942500907


In [47]:
# Probe how the environment reacts to differently scaled constant actions.
# The simulator is restored to the same post-reset state before every step so
# that the printed transitions are directly comparable.
env.reset()
state = env.sim.get_state()
for action_scale in (1, 2, -1, -2, 0.5):
    env.sim.set_state(state)
    print(env.step(np.ones(env.action_space.shape) * action_scale))

(array([-0.18469202, -0.20371873, -0.22429946,  0.48778801,  0.5522581 ,
        0.49347044,  0.71161846,  0.6262945 ,  0.51181153, -0.17734774,
       -1.73018543,  1.01645313, -2.55163852, -3.13103014, -2.37537727,
       -0.15312798, -4.85421291, -0.05402869]), -0.6825296588119623, False, {})
(array([-0.18469202, -0.20371873, -0.22429946,  0.48778801,  0.5522581 ,
        0.49347044,  0.71161846,  0.6262945 ,  0.51181153, -0.17734774,
       -1.73018543,  1.01645313, -2.55163852, -3.13103014, -2.37537727,
       -0.15312798, -4.85421291, -0.05402869]), -7.882529658811963, False, {})
(array([ 0.3406295 , -0.17450392,  0.10522791, -0.51278894, -0.50646177,
       -0.42003137, -0.69212984, -0.6142823 , -0.51219773,  0.32310625,
       -1.8421801 , -0.96284358,  0.8745152 ,  3.90620256,  0.00669594,
        4.1552773 ,  3.28736303,  0.09990396]), -4.099728112692602, False, {})
(array([ 0.3406295 , -0.17450392,  0.10522791, -0.51278894, -0.50646177,
       -0.42003137, -0.69212984, -0.61