# batch_polopt.py -- forked from rll/rllab.
import time
from rllab.algos import RLAlgorithm
import rllab.misc.logger as logger
from sandbox.rocky.tf.policies.base import Policy
import tensorflow as tf
from sandbox.rocky.tf.samplers import BatchSampler
from sandbox.rocky.tf.samplers import VectorizedSampler
from rllab.sampler.utils import rollout
class BatchPolopt(RLAlgorithm):
    """
    Base class for batch sampling-based policy optimization methods.
    This includes various policy gradient methods like vpg, npg, ppo, trpo, etc.

    Subclasses must implement :meth:`init_opt`, :meth:`get_itr_snapshot` and
    :meth:`optimize_policy`. Note that :meth:`init_opt` is invoked at the end of
    ``__init__``, so any attribute it relies on must be set before the subclass
    calls the base constructor.
    """

    def __init__(
            self,
            env,
            policy,
            baseline,
            scope=None,
            n_itr=500,
            start_itr=0,
            batch_size=5000,
            max_path_length=500,
            discount=0.99,
            gae_lambda=1,
            plot=False,
            pause_for_plot=False,
            center_adv=True,
            positive_adv=False,
            store_paths=False,
            whole_paths=True,
            fixed_horizon=False,
            sampler_cls=None,
            sampler_args=None,
            force_batch_sampler=False,
            **kwargs
    ):
        """
        :param env: Environment
        :param policy: Policy
        :type policy: Policy
        :param baseline: Baseline
        :param scope: Scope for identifying the algorithm. Must be specified if running multiple algorithms
        simultaneously, each using different environments and policies
        :param n_itr: Number of iterations.
        :param start_itr: Starting iteration.
        :param batch_size: Number of samples per iteration.
        :param max_path_length: Maximum length of a single rollout.
        :param discount: Discount.
        :param gae_lambda: Lambda used for generalized advantage estimation.
        :param plot: Plot evaluation run after each iteration.
        :param pause_for_plot: Whether to pause before continuing when plotting.
        :param center_adv: Whether to rescale the advantages so that they have mean 0 and standard deviation 1.
        :param positive_adv: Whether to shift the advantages so that they are always positive. When used in
        conjunction with center_adv the advantages will be standardized before shifting.
        :param store_paths: Whether to save all paths data to the snapshot.
        :param whole_paths: Stored on the instance only; not read by this class (presumably consumed by the
        sampler -- TODO confirm).
        :param fixed_horizon: Stored on the instance only; not read by this class (presumably consumed by the
        sampler -- TODO confirm).
        :param sampler_cls: Sampler class, called as ``sampler_cls(self, **sampler_args)``. When None,
        VectorizedSampler is used if ``policy.vectorized`` is true and ``force_batch_sampler`` is false;
        otherwise BatchSampler is used.
        :param sampler_args: Keyword arguments forwarded to ``sampler_cls``. None means no extra arguments.
        :param force_batch_sampler: Select BatchSampler even when the policy is vectorized. Only consulted
        when ``sampler_cls`` is None.
        :param kwargs: Accepted for call-site compatibility and ignored.
        :return:
        """
        self.env = env
        self.policy = policy
        self.baseline = baseline
        self.scope = scope
        self.n_itr = n_itr
        self.start_itr = start_itr
        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.discount = discount
        self.gae_lambda = gae_lambda
        self.plot = plot
        self.pause_for_plot = pause_for_plot
        self.center_adv = center_adv
        self.positive_adv = positive_adv
        self.store_paths = store_paths
        self.whole_paths = whole_paths
        self.fixed_horizon = fixed_horizon
        if sampler_cls is None:
            if self.policy.vectorized and not force_batch_sampler:
                sampler_cls = VectorizedSampler
            else:
                sampler_cls = BatchSampler
        if sampler_args is None:
            sampler_args = dict()
        # The sampler receives the algorithm itself, so it can read the
        # attributes assigned above.
        self.sampler = sampler_cls(self, **sampler_args)
        self.init_opt()

    def start_worker(self):
        """Delegate to the sampler's ``start_worker``."""
        self.sampler.start_worker()

    def shutdown_worker(self):
        """Delegate to the sampler's ``shutdown_worker``."""
        self.sampler.shutdown_worker()

    def obtain_samples(self, itr):
        """Return the paths the sampler collects for iteration ``itr``."""
        return self.sampler.obtain_samples(itr)

    def process_samples(self, itr, paths):
        """Return the sampler's processed form of ``paths`` for iteration ``itr``."""
        return self.sampler.process_samples(itr, paths)

    def train(self, sess=None):
        """
        Run the optimization loop for iterations ``start_itr`` .. ``n_itr - 1``.

        Each iteration obtains samples, processes them, logs diagnostics,
        optimizes the policy, saves a snapshot and dumps the tabular log.
        When ``plot`` is set, an animated rollout is shown afterwards.

        :param sess: TensorFlow session to run in. When None, a session is
        created here and closed again before returning; a caller-provided
        session is never closed.
        """
        created_session = sess is None
        if created_session:
            sess = tf.Session()
            # Entered by hand instead of a ``with`` block; the matching
            # cleanup is the ``sess.close()`` in the outer ``finally``.
            sess.__enter__()
        try:
            # NOTE(review): this also runs for caller-provided sessions, so
            # variables the caller already initialised or restored are reset.
            sess.run(tf.global_variables_initializer())
            self.start_worker()
            try:
                start_time = time.time()
                for itr in range(self.start_itr, self.n_itr):
                    itr_start_time = time.time()
                    with logger.prefix('itr #%d | ' % itr):
                        logger.log("Obtaining samples...")
                        paths = self.obtain_samples(itr)
                        logger.log("Processing samples...")
                        samples_data = self.process_samples(itr, paths)
                        logger.log("Logging diagnostics...")
                        self.log_diagnostics(paths)
                        logger.log("Optimizing policy...")
                        self.optimize_policy(itr, samples_data)
                        logger.log("Saving snapshot...")
                        params = self.get_itr_snapshot(itr, samples_data)
                        if self.store_paths:
                            params["paths"] = samples_data["paths"]
                        logger.save_itr_params(itr, params)
                        logger.log("Saved")
                        logger.record_tabular('Time', time.time() - start_time)
                        logger.record_tabular('ItrTime', time.time() - itr_start_time)
                        logger.dump_tabular(with_prefix=False)
                        if self.plot:
                            rollout(
                                self.env,
                                self.policy,
                                animated=True,
                                max_path_length=self.max_path_length,
                            )
                            if self.pause_for_plot:
                                input("Plotting evaluation run: Press Enter to "
                                      "continue...")
            finally:
                # Stop the sampler workers even when an iteration raises.
                self.shutdown_worker()
        finally:
            # Only close a session this method created itself.
            if created_session:
                sess.close()

    def log_diagnostics(self, paths):
        """Forward ``paths`` to the env, policy and baseline ``log_diagnostics`` hooks, in that order."""
        self.env.log_diagnostics(paths)
        self.policy.log_diagnostics(paths)
        self.baseline.log_diagnostics(paths)

    def init_opt(self):
        """
        Initialize the optimization procedure. If using tensorflow, this may
        include declaring all the variables and compiling functions
        """
        raise NotImplementedError

    def get_itr_snapshot(self, itr, samples_data):
        """
        Returns all the data that should be saved in the snapshot for this
        iteration.
        """
        raise NotImplementedError

    def optimize_policy(self, itr, samples_data):
        """
        Update the policy for iteration ``itr`` from ``samples_data`` (the
        value returned by :meth:`process_samples`). Must be implemented by
        subclasses.
        """
        raise NotImplementedError