In [1]:
import env
import gym
from TD3 import TD3

class ActionRepeat(object):
    """Environment wrapper that repeats every action ``amount`` times.

    ``step`` sums the rewards of the repeated inner steps and returns the
    observation, ``done`` flag and ``info`` of the last inner step taken.
    Attribute lookups not found on the wrapper (``action_space``, ``sim``,
    ``seed``, ...) are forwarded to the wrapped environment.

    ``_max_episode_steps`` is exposed on the wrapper as the wrapped env's limit
    integer-divided by ``amount``, i.e. the number of wrapper steps per episode.
    The wrapped env's own attribute is left untouched.
    """

    def __init__(self, env, amount):
        """
        Args:
            env: environment with the 4-tuple ``step`` API and a
                ``_max_episode_steps`` attribute.
            amount: number of inner steps per wrapper step; must be >= 1.

        Raises:
            ValueError: if ``amount`` < 1.
        """
        if amount < 1:
            raise ValueError('amount must be >= 1, got {!r}'.format(amount))
        self._env = env
        self._amount = amount
        # Keep the reduced limit on the wrapper instead of overwriting the
        # wrapped env's attribute: a time limit that counts *inner* steps (as
        # gym's TimeLimit wrapper does) would otherwise fire after 1/amount of
        # the episode.
        self._max_episode_steps = env._max_episode_steps // amount

    def __getattr__(self, name):
        # Only called for attributes not found normally. Refuse '_env' itself
        # so half-constructed instances (e.g. during copy/unpickling) raise
        # AttributeError instead of recursing forever.
        if name == '_env':
            raise AttributeError(name)
        return getattr(self._env, name)

    def step(self, action):
        """Apply ``action`` up to ``amount`` times.

        Returns:
            (obs, total_reward, done, info) from the last inner step taken.
            Repetition stops early as soon as the wrapped env reports ``done``,
            so it is never stepped past the end of an episode.
        """
        total_reward = 0

        for _ in range(self._amount):
            obs, reward, done, info = self._env.step(action)
            total_reward += reward
            if done:
                break

        return obs, total_reward, done, info

    def reset(self, *args, **kwargs):
        """Reset the wrapped env and return whatever its ``reset`` returns."""
        return self._env.reset(*args, **kwargs)

# Evaluation environment: each agent step repeats its action 4 times.
# NOTE(review): this rebinds `env`, shadowing the `env` module imported above
# (presumably imported only for its side effect of registering
# 'MyHalfCheetah-v2' with gym — confirm).
env_name = 'MyHalfCheetah-v2'
env_str = 'half_cheetah'
env = ActionRepeat(gym.make(env_name), 4)

In [2]:
# Size the TD3 networks from the wrapped env's spaces. max_action is read from
# the first action dimension only (assumes a uniform action box — TODO confirm).
state_dim, action_dim = env.observation_space.shape[0], env.action_space.shape[0]
max_action = float(env.action_space.high[0])

# Restore the pretrained policy saved under save/TD3 for this env.
policy = TD3(state_dim, action_dim, max_action)
policy.load(env_name, 'save/TD3')

In [3]:
import numpy as np

def print_rollout_stats(obs, acts, reward_sum):
    """Print the return of one rollout plus min/max/mean/std of its arrays.

    Args:
        obs: observations array (anything exposing .min/.max/.mean/.std).
        acts: actions array, summarised the same way.
        reward_sum: cumulative reward, printed as-is.
    """
    print("Cumulative reward ", reward_sum)
    summaries = (
        ("Action min {}, max {}, mean {}, std {}", acts),
        ("Obs min {}, max {}, mean {}, std {}", obs),
    )
    for template, values in summaries:
        print(template.format(values.min(), values.max(), values.mean(), values.std()))

def sample_rollout(env, policy):
    """Run one fixed-length episode of ``policy`` in ``env``.

    Exactly ``env._max_episode_steps`` steps are taken; the ``done`` flag
    returned by ``env.step`` is ignored.

    Args:
        env: environment with ``reset()``, a 4-tuple ``step(action)`` and
            ``_max_episode_steps``.
        policy: object exposing ``act(observation) -> action``.

    Returns:
        (observations, actions, reward_sum). ``observations`` has one more row
        than ``actions`` because it includes the initial reset observation.
    """
    current = env.reset()
    observations = [current]
    actions = []
    reward_sum = 0

    for _ in range(env._max_episode_steps):
        action = policy.act(current)
        actions.append(action)
        current, reward, _, _ = env.step(action)
        observations.append(current)
        reward_sum += reward

    return np.array(observations), np.array(actions), reward_sum
    
    
# Roll out the trained policy 20 times. sample_rollout always takes
# env._max_episode_steps steps, so the per-rollout arrays stack into
# O: (n_rollouts, T + 1, obs_dim) and A: (n_rollouts, T, act_dim), the same
# layout as the expert demonstrations loaded further down.
# NOTE(review): no seed is set, so re-running presumably yields different
# rollouts if env.reset() is stochastic — confirm.
O, A = [], []
for _ in range(20):
    obs, acts, reward_sum = sample_rollout(env, policy)
    O += [obs]
    A += [acts]

O = np.array(O)
A = np.array(A)
np.save(f'save/TD3/{env_str}_obs.npy', O)
np.save(f'save/TD3/{env_str}_act.npy', A)

In [4]:
import numpy as np

# Reload the cached TD3 rollouts; print their shapes, then min/max/mean/std of
# observations and of actions.
O = np.load(f'save/TD3/{env_str}_obs.npy')
A = np.load(f'save/TD3/{env_str}_act.npy')
print(np.shape(O), np.shape(A))
print(np.min(O), np.max(O), np.mean(O), np.std(O))
print(np.min(A), np.max(A), np.mean(A), np.std(A))

(20, 251, 18) (20, 250, 6)
-11.519016382693518 15.365782248512922 0.7856723205389691 2.126600898625022
-1.0 1.0 -0.51907396 0.7513538


In [5]:
# Same summary for the expert demonstrations of this env, for comparison with
# the TD3 rollouts. In the captured output the expert actions extend beyond
# [-1, 1], unlike the TD3 actions — compare against action_bound.
O = np.load('expert_demonstrations/%s/expert_obs.npy' % env_str)
A = np.load('expert_demonstrations/%s/expert_act.npy' % env_str)
action_bound = env.action_space.high[0]
print(np.shape(O), np.shape(A))
print(np.min(O), np.max(O), np.mean(O), np.std(O))
print(np.min(A), np.max(A), np.mean(A), np.std(A))

(20, 251, 18) (20, 250, 6)
-17.2598140331567 466.09138757351496 11.581367552568011 51.45506850905242
-2.1904756893737196 2.157680768107007 0.0727562300912073 0.8969060018290144


In [6]:
# Per-timestep min/max across the observation vector of the first TD3 rollout.
O = np.load(f'save/TD3/{env_str}_obs.npy')
for obs in O[0, :, :]:
    print(np.min(obs), np.max(obs))

-0.09804260334570808 0.1559661790163903
-2.6340514909348984 4.6405004505544944
-1.0575151070359623 5.066807272272785
-1.382633799240185 4.898223931977299
-2.5348665453880024 3.915392920930322
-2.647612501865823 5.656928204332139
-0.8777368518847419 1.3594089235225022
-2.085849455193783 2.740507086202558
-2.2620593275360266 2.0983582607531703
-4.572560348039407 7.428424377198532
-0.6746623306931585 2.9078062461945553
-1.7835740064940742 4.695351795484368
-0.6250201383660005 3.75637585882617
-3.0448046625410328 12.595524151475392
-1.1394309579209667 3.1526581349782212
-2.9568302105865034 4.929334900004054
-0.66221056241228 4.8972869925883336
-2.5568969031869866 4.757351509077795
-0.7688104814047204 4.8683755661258115
-2.6922112998467425 4.724784999663313
-0.6078059553458048 4.8086374425229295
-3.3669890609953725 4.943730124048877
-1.8926318288680335 5.569460454957294
-3.287783088382407 5.322906897885755
-2.076487847644674 5.681576761010888
-3.0332677332291618 5.245708414846281
-1.2347926

In [7]:
# Per-timestep min/max of the second expert rollout, plus observation dim 2.
# In the captured output, dim 2 later reaches values in the hundreds and equals
# the row max, unlike anything in the TD3 rollouts. This suggests the expert
# observations contain an unbounded (position-like?) coordinate — TODO confirm.
O = np.load('expert_demonstrations/%s/expert_obs.npy' % env_str)
for obs in O[1, :, :]:
    print(np.min(obs), np.max(obs), obs[2])

-0.2053105792766814 0.3152890439893641 -0.09742072774706148
-1.9424439216570777 4.547382704062731 0.05840000017390809
-5.218915311677941 2.3025003609833514 -0.2297794598333387
-1.182374213750419 6.541040138892428 0.0640540345311489
-1.9841267375849325 6.422739254409489 0.12993931611043574
-5.083465846841685 3.167157230179666 0.12449832061000521
-3.5909294725805676 3.7693890500541576 0.038199420846741695
-4.529144910707986 3.020943285533475 -0.635247914766367
-1.1560244067644538 3.738013726131424 -0.8290828870796162
-1.3635759255041453 3.3230021591852617 -0.9116556947412949
-1.7149855000678391 3.3218166347000984 -0.9181705738198307
-7.470108883966254 5.782503546022629 -1.081734121025653
-1.765018771978553 8.403980272802704 -0.8548397579852016
-1.8991789822394356 4.316145732702985 -0.7846905058346078
-1.0923209140361687 4.358449060287892 -0.6584023411812682
-4.9883031842640815 5.290330618324729 -0.5376304025276679
-0.7481184775467515 5.882225459690424 -0.19389073239515367
-4.026145426116

-2.871689618290466 251.4770583449806 251.4770583449806
-3.9159665319035666 254.41846684508283 254.41846684508283
-4.97535260670594 256.06463619963415 256.06463619963415
-2.4802410991533073 257.94395358633835 257.94395358633835
-3.4151925487647476 260.1039406761965 260.1039406761965
-3.546310199886344 262.4709107206666 262.4709107206666
-8.25057544721575 264.2136323493937 264.2136323493937
-6.062628411792211 266.7773014754938 266.7773014754938
-2.1221426077579206 268.4901328644235 268.4901328644235
-2.163276983934238 270.2827594064811 270.2827594064811
-3.2926259958955626 272.8430495304196 272.8430495304196
-6.301691200434821 274.51055809842296 274.51055809842296
-0.39684483905278695 276.18565320296153 276.18565320296153
-4.571798781516168 278.82261059670554 278.82261059670554
-4.236378674726077 280.24517093751035 280.24517093751035
-2.8067638487513626 281.64664450610513 281.64664450610513
-3.780827549747264 283.24282372670297 283.24282372670297
-7.169468416574424 285.6189881913536 285.

In [8]:
# Range of observation dims 2 and 0 for the TD3 rollouts (first two output
# lines) and the expert demonstrations (last two output lines). O is
# (rollouts, time, obs_dim), so O[..., k] selects dim k across all steps.
O = np.load(f'save/TD3/{env_str}_obs.npy')
print(np.min(O[..., 2]), np.max(O[..., 2]))
print(np.min(O[..., 0]), np.max(O[..., 0]))

O = np.load('expert_demonstrations/%s/expert_obs.npy' % env_str)
print(np.min(O[..., 2]), np.max(O[..., 2]))
print(np.min(O[..., 0]), np.max(O[..., 0]))

-0.8813040193569746 0.7321214789342055
-0.3536687767903407 5.870399348983142
-1.3738031153485324 466.09138757351496
0.0 20.19333942500907


In [9]:
# Probe the wrapped env with constant actions of different magnitudes, all from
# the same post-reset simulator state (restored before every step). In the
# captured output, actions 1 and 2 yield the same observation but different
# rewards. Presumably controls are clipped to the actuator range while the
# control cost uses the raw action — TODO confirm.
env.reset()
state = env.sim.get_state()
for scale in (1, 2, -1, -2, 0.5):
    env.sim.set_state(state)
    print(env.step(np.ones(env.action_space.shape) * scale))

(array([-0.29000253, -0.17488076, -0.08080397,  0.47824205,  0.539011  ,
        0.4931918 ,  0.71287697,  0.64726535,  0.51150485, -0.28327528,
       -1.71073711,  0.76098108, -2.09995542, -2.8207799 , -2.06848264,
       -0.12873733, -5.22789669, -0.05933   ]), -1.3761670869869853, False, {})
(array([-0.29000253, -0.17488076, -0.08080397,  0.47824205,  0.539011  ,
        0.4931918 ,  0.71287697,  0.64726535,  0.51150485, -0.28327528,
       -1.71073711,  0.76098108, -2.09995542, -2.8207799 , -2.06848264,
       -0.12873733, -5.22789669, -0.05933   ]), -8.576167086986986, False, {})
(array([ 0.19908008, -0.12150486,  0.24387173, -0.51716174, -0.51473253,
       -0.42001129, -0.67832286, -0.5972334 , -0.51072271,  0.18984958,
       -1.90494951, -1.04566123,  0.8080802 ,  4.44054711,  0.00982572,
        3.53654683,  3.58546877,  0.10498915]), -4.66900713483287, False, {})
(array([ 0.19908008, -0.12150486,  0.24387173, -0.51716174, -0.51473253,
       -0.42001129, -0.67832286, -0.597