In [1]:
import env
import gym
from TD3 import TD3

class ActionRepeat(object):
    """Environment wrapper that applies each action ``amount`` times in a row.

    Attributes that the wrapper does not define itself (for example
    ``observation_space`` or ``_max_episode_steps``) are looked up on the
    wrapped environment.

    Note that construction modifies the wrapped environment in place: its
    ``_max_episode_steps`` is floor-divided by ``amount``.

    Args:
        env: Environment exposing ``step``, ``reset`` and a
            ``_max_episode_steps`` attribute.
        amount: Number of times every action is repeated; must be >= 1.

    Raises:
        ValueError: If ``amount`` is smaller than 1.
    """

    def __init__(self, env, amount):
        # Validate before touching ``env`` so that a bad value can neither
        # corrupt the wrapped env's step budget nor make ``step`` fail later
        # with an unbound ``obs``.
        if amount < 1:
            raise ValueError(
                "amount must be a positive integer, got {!r}".format(amount))
        self._env = env
        self._amount = amount
        self._env._max_episode_steps = self._env._max_episode_steps // amount

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If ``_env`` itself is
        # missing (instance created without __init__, e.g. while being copied
        # or unpickled) delegating would recurse forever, so fail cleanly.
        if name == '_env':
            raise AttributeError(name)
        return getattr(self._env, name)

    def step(self, action):
        """Apply ``action`` ``amount`` times and sum the rewards.

        Returns:
            ``(obs, total_reward, False, {})`` where ``obs`` is the observation
            of the last inner step. The done flag and info dict returned by the
            wrapped environment are discarded: the wrapper always reports
            ``False`` and an empty dict.
        """
        total_reward = 0

        for _ in range(self._amount):
            obs, reward, _done, _info = self._env.step(action)
            total_reward += reward

        return obs, total_reward, False, {}

    def reset(self, *args, **kwargs):
        """Forward ``reset`` (with all arguments) to the wrapped environment."""
        return self._env.reset(*args, **kwargs)

# Gym id of the task, and the short tag used in the file names on disk.
env_name = 'MyHalfCheetah-v2'
env_str = 'half_cheetah'
# Build the environment and wrap it so every action is applied 4 times per step.
env = ActionRepeat(gym.make(env_name), 4)

In [2]:
# Sizes of the observation and action vectors, read from the environment's spaces.
state_dim, action_dim = env.observation_space.shape[0], env.action_space.shape[0]
# First entry of the action space's upper bound, passed to TD3 as max_action.
max_action = float(env.action_space.high[0])

# Build the TD3 agent and call its load() with the env name and the 'save/TD3'
# directory (presumably restores previously saved weights -- confirm in TD3.py).
policy = TD3(state_dim, action_dim, max_action)
policy.load(env_name, 'save/TD3')

In [3]:
import numpy as np

def print_rollout_stats(obs, acts, reward_sum):
    """Print the cumulative reward and summary statistics of one rollout.

    Args:
        obs: Array-like with ``min``/``max``/``mean``/``std`` methods holding
            the observations.
        acts: Array-like with the same methods holding the actions.
        reward_sum: Value printed as the cumulative reward.
    """
    def _summary(values):
        # Global (un-axised) statistics in the order the messages expect.
        return values.min(), values.max(), values.mean(), values.std()

    print("Cumulative reward ", reward_sum)

    action_line = "Action min {}, max {}, mean {}, std {}"
    print(action_line.format(*_summary(acts)))

    obs_line = "Obs min {}, max {}, mean {}, std {}"
    print(obs_line.format(*_summary(obs)))

def sample_rollout(env, policy):
    """Run ``policy`` in ``env`` for ``env._max_episode_steps`` steps.

    The done flag and info dict returned by ``env.step`` are ignored, so the
    rollout always has exactly ``env._max_episode_steps`` transitions.

    Args:
        env: Object with ``reset()``, ``step(action)`` and ``_max_episode_steps``.
        policy: Object with ``act(observation)`` returning an action.

    Returns:
        Tuple ``(observations, actions, reward_sum)``: the observations as an
        array (initial observation plus one per step), the actions as an array
        (one per step), and the sum of the per-step rewards.
    """
    current = env.reset()
    visited = [current]
    taken = []
    total = 0

    for _step in range(env._max_episode_steps):
        chosen = policy.act(current)
        taken.append(chosen)
        current, gained, _done, _info = env.step(chosen)
        visited.append(current)
        total += gained

    return np.array(visited), np.array(taken), total
    
    
# Collect 20 rollouts of the loaded policy; O gathers the observation arrays
# and A the action arrays, one entry per rollout.
O, A = [], []
for _ in range(20):
    # reward_sum is returned by sample_rollout but not used further in this cell.
    obs, acts, reward_sum = sample_rollout(env, policy)
    O.append(obs)
    A.append(acts)

# Stack the per-rollout arrays and write them under save/TD3/, using env_str
# as the file-name prefix.
O, A = np.array(O), np.array(A)
np.save(f'save/TD3/{env_str}_obs.npy', O)
np.save(f'save/TD3/{env_str}_act.npy', A)

In [4]:
import numpy as np

# Read the TD3 rollout arrays back from save/TD3/ (same paths the collection
# cell writes to) and print their shapes.
O = np.load(f'save/TD3/{env_str}_obs.npy')
A = np.load(f'save/TD3/{env_str}_act.npy')
print(O.shape, A.shape)
# Global min / max / mean / std over all entries, observations then actions.
print(O.min(), O.max(), O.mean(), O.std())
print(A.min(), A.max(), A.mean(), A.std())

(20, 251, 18) (20, 250, 6)
-10.751935984528252 14.55392091816701 0.7812559760440648 2.1158196398542457
-1.0 1.0 -0.51787233 0.751424


In [10]:
# Load the expert demonstrations for this task; this rebinds O and A, which
# previously held the TD3 rollouts.
O = np.load('expert_demonstrations/%s/expert_obs.npy' % env_str)
A = np.load('expert_demonstrations/%s/expert_act.npy' % env_str)
# First entry of the action space's upper bound; the clipping cell uses it as a
# symmetric limit (-action_bound, action_bound).
action_bound = env.action_space.high[0]
print(O.shape, A.shape)
# Global min / max / mean / std over all entries, observations then actions.
print(O.min(), O.max(), O.mean(), O.std())
print(A.min(), A.max(), A.mean(), A.std())

(20, 251, 18) (20, 250, 6)
-17.2598140331567 466.09138757351496 11.581367552568011 51.45506850905242
-2.1904756893737196 2.157680768107007 0.0727562300912073 0.8969060018290144


In [16]:
# Relies on A and action_bound bound by the cell that loads the expert
# demonstrations.
# Fraction of action entries above the bound plus fraction below its negative,
# i.e. the share of entries outside [-action_bound, action_bound].
print((A > action_bound).mean() + (A < -action_bound).mean())
# Rebind A to the clipped copy; re-running this cell therefore starts from the
# already-clipped array unless the loading cell is run again first.
A = A.clip(-action_bound, action_bound)
print(A.min(), A.max(), A.mean(), A.std())
# Same out-of-range fraction, recomputed after clipping.
print((A > action_bound).mean() + (A < -action_bound).mean())
# Written next to the original expert_act.npy, under a separate file name.
np.save('expert_demonstrations/%s/expert_act_clipped.npy' % env_str, A)

0.36443333333333333
-1.0 1.0 0.04962332196652116 0.7189090617904839
0.0


In [6]:
# NOTE(review): this reads 'TD3_obs.npy' from the working directory, whereas the
# collection cell writes to f'save/TD3/{env_str}_obs.npy' -- confirm this is the
# intended file.
O = np.load('TD3_obs.npy')
# For each row of the first entry of O, print the smallest and largest value.
for obs in O[0]:
    print(obs.min(), obs.max())

-0.24045576919307665 0.10246831435064868
-2.2170024971922775 5.428058198923594
-3.216904688691702 3.2374590193142154
-3.475970241594433 4.55832939032996
-0.5964465649233486 2.1791627328297105
-2.563224409346569 5.3883672408106165
-1.4674314124858994 2.4448745912554504
-3.0431511933009046 5.744849395962814
-0.6323224360456183 2.908568615413971
-2.6517846084757144 5.343369674185851
-0.6441333978362928 3.431205649635576
-2.7323015264807093 11.023977562991574
-0.6443866452065128 3.3286229167830683
-2.876276913068025 4.928445119132387
-0.7080225937764661 4.079391933934787
-2.9427784037550166 5.086336779454876
-0.6627343397940233 5.086873087077348
-3.008141910064202 5.060714224991975
-0.7463834305499575 4.962658667902106
-2.700003842372651 4.971527900499699
-0.6286040269836732 4.908859054130623
-2.871715702782383 4.994414205881057
-2.8368584895721476 5.422262729199309
-3.337769562443696 5.279929409328857
-2.044341167922504 6.669318559534294
-2.5857807280474687 5.215893473252514
-2.3936615422

In [7]:
# Reload the expert observations (O was last bound to the TD3 array).
O = np.load('expert_demonstrations/%s/expert_obs.npy' % env_str)
# For each row of the second entry of O, print its min, its max and the value
# at index 2 of that row.
for obs in O[1]:
    print(obs.min(), obs.max(), obs[2])

-0.2053105792766814 0.3152890439893641 -0.09742072774706148
-1.9424439216570777 4.547382704062731 0.05840000017390809
-5.218915311677941 2.3025003609833514 -0.2297794598333387
-1.182374213750419 6.541040138892428 0.0640540345311489
-1.9841267375849325 6.422739254409489 0.12993931611043574
-5.083465846841685 3.167157230179666 0.12449832061000521
-3.5909294725805676 3.7693890500541576 0.038199420846741695
-4.529144910707986 3.020943285533475 -0.635247914766367
-1.1560244067644538 3.738013726131424 -0.8290828870796162
-1.3635759255041453 3.3230021591852617 -0.9116556947412949
-1.7149855000678391 3.3218166347000984 -0.9181705738198307
-7.470108883966254 5.782503546022629 -1.081734121025653
-1.765018771978553 8.403980272802704 -0.8548397579852016
-1.8991789822394356 4.316145732702985 -0.7846905058346078
-1.0923209140361687 4.358449060287892 -0.6584023411812682
-4.9883031842640815 5.290330618324729 -0.5376304025276679
-0.7481184775467515 5.882225459690424 -0.19389073239515367
-4.026145426116

-3.546310199886344 262.4709107206666 262.4709107206666
-8.25057544721575 264.2136323493937 264.2136323493937
-6.062628411792211 266.7773014754938 266.7773014754938
-2.1221426077579206 268.4901328644235 268.4901328644235
-2.163276983934238 270.2827594064811 270.2827594064811
-3.2926259958955626 272.8430495304196 272.8430495304196
-6.301691200434821 274.51055809842296 274.51055809842296
-0.39684483905278695 276.18565320296153 276.18565320296153
-4.571798781516168 278.82261059670554 278.82261059670554
-4.236378674726077 280.24517093751035 280.24517093751035
-2.8067638487513626 281.64664450610513 281.64664450610513
-3.780827549747264 283.24282372670297 283.24282372670297
-7.169468416574424 285.6189881913536 285.6189881913536
-9.47243244924545 287.58057369273837 287.58057369273837
-5.022067201640501 289.3383323626718 289.3383323626718
-3.257970986147442 291.9175623646338 291.9175623646338
-3.0265183497420116 293.39403293738644 293.39403293738644
-1.7049425712918282 294.8366233321074 294.836

In [8]:
# Range of the values at index 2 and at index 0 of the last axis, first for the
# TD3 observations ...
# NOTE(review): 'TD3_obs.npy' is read from the working directory, not from the
# save/TD3/ path the collection cell writes to -- confirm the file.
O = np.load('TD3_obs.npy')
print(O[:, :, 2].min(), O[:, :, 2].max())
print(O[:, :, 0].min(), O[:, :, 0].max())

# ... and then the same two ranges for the expert observations.
O = np.load('expert_demonstrations/%s/expert_obs.npy' % env_str)
print(O[:, :, 2].min(), O[:, :, 2].max())
print(O[:, :, 0].min(), O[:, :, 0].max())

-0.7705469651596416 1.4307908481762608
-1.3358717898444183 5.714052976240964
-1.3738031153485324 466.09138757351496
0.0 20.19333942500907


In [9]:
# Step the environment from one and the same simulator state with a constant
# action vector scaled by several factors, printing each step result so the
# outcomes can be compared side by side.
env.reset()
state = env.sim.get_state()
for scale in (1, 2, -1, -2, 0.5):
    # Rewind to the snapshot taken right after reset before every probe.
    env.sim.set_state(state)
    print(env.step(np.ones(env.action_space.shape) * scale))

(array([-0.10973165, -0.26292127, -0.20911861,  0.48408427,  0.56114922,
        0.50386735,  0.71298876,  0.63649483,  0.51290477, -0.09941798,
       -1.82167312,  1.21463277, -2.83597424, -3.25488955, -2.42139786,
       -0.1890662 , -5.88516302, -0.09741851]), 0.14336164837117615, False, {})
(array([-0.10973165, -0.26292127, -0.20911861,  0.48408427,  0.56114922,
        0.50386735,  0.71298876,  0.63649483,  0.51290477, -0.09941798,
       -1.82167312,  1.21463277, -2.83597424, -3.25488955, -2.42139786,
       -0.1890662 , -5.88516302, -0.09741851]), -7.056638351628825, False, {})
(array([ 0.39203376, -0.23628184,  0.12096416, -0.51895156, -0.50984129,
       -0.42012853, -0.68183407, -0.59155681, -0.50653466,  0.38330465,
       -1.98931535, -0.65213766,  0.66825206,  3.14894048,  0.00518223,
        3.33151433,  3.07422075,  0.10383124]), -3.2611133663823653, False, {})
(array([ 0.39203376, -0.23628184,  0.12096416, -0.51895156, -0.50984129,
       -0.42012853, -0.68183407, -0.5