-
Notifications
You must be signed in to change notification settings - Fork 308
/
test_trpo_with_model.py
73 lines (57 loc) · 2.31 KB
/
test_trpo_with_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""
This script creates a test that fails when garage.tf.algos.TRPO performance is
too low.
"""
import gym
from garage.envs import normalize
from garage.experiment import LocalRunner
from garage.experiment import snapshotter
from garage.np.baselines import LinearFeatureBaseline
from garage.tf.algos import TRPO
from garage.tf.envs import TfEnv
from garage.tf.optimizers import FiniteDifferenceHvp
from garage.tf.policies import CategoricalGRUPolicyWithModel
from garage.tf.policies import CategoricalLSTMPolicyWithModel
from tests.fixtures import TfGraphTestCase
class TestTRPO(TfGraphTestCase):
    """Benchmark tests asserting that TRPO with model-based recurrent
    policies reaches a minimum average return on CartPole-v1.

    Both tests share one training routine; they differ only in the policy
    class under test (LSTM vs. GRU).
    """

    def _run_trpo_cartpole(self, policy_cls):
        """Train TRPO on CartPole-v1 with ``policy_cls`` and assert return.

        Args:
            policy_cls: A categorical recurrent policy class accepting
                ``name`` and ``env_spec`` keyword arguments (e.g.
                CategoricalLSTMPolicyWithModel).

        Raises:
            AssertionError: If the last average return after 10 epochs is
                not greater than 80.
        """
        with LocalRunner(self.sess) as runner:
            env = TfEnv(normalize(gym.make('CartPole-v1')))
            policy = policy_cls(name='policy', env_spec=env.spec)
            baseline = LinearFeatureBaseline(env_spec=env.spec)
            # FiniteDifferenceHvp avoids symbolic Hessian-vector products,
            # which the recurrent policies here do not support cleanly.
            algo = TRPO(
                env_spec=env.spec,
                policy=policy,
                baseline=baseline,
                max_path_length=100,
                discount=0.99,
                max_kl_step=0.01,
                optimizer_args=dict(
                    hvp_approach=FiniteDifferenceHvp(base_eps=1e-5)))
            runner.setup(algo, env)
            last_avg_ret = runner.train(n_epochs=10, batch_size=2048)
            # 80 is a loose performance floor: CartPole-v1 caps episode
            # return at 500, so a working TRPO easily clears this.
            assert last_avg_ret > 80
            env.close()

    def test_trpo_lstm_cartpole(self):
        """TRPO + CategoricalLSTMPolicyWithModel reaches avg return > 80."""
        # NOTE(review): this global snapshot-dir mutation appears only in
        # the LSTM test, not the GRU one; preserved for behavior parity —
        # confirm whether it is intentional or leftover debugging.
        snapshotter.snapshot_dir = './'
        self._run_trpo_cartpole(CategoricalLSTMPolicyWithModel)

    # Mark as a long-running ("large") test for the CI test selector.
    test_trpo_lstm_cartpole.large = True

    def test_trpo_gru_cartpole(self):
        """TRPO + CategoricalGRUPolicyWithModel reaches avg return > 80."""
        self._run_trpo_cartpole(CategoricalGRUPolicyWithModel)

    # Mark as a long-running ("large") test for the CI test selector.
    test_trpo_gru_cartpole.large = True