from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging

import ray
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
    MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.filter import RunningStat
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.memory import ray_get_and_free

logger = logging.getLogger(__name__)

class MicrobatchOptimizer(PolicyOptimizer):
    """A microbatching synchronous RL optimizer.

    This optimizer pulls sample batches from workers until the target
    microbatch size is reached. Then, it computes and accumulates the policy
    gradient in a local buffer. This process is repeated until the number of
    samples collected reaches the train batch size, at which point a single
    accumulated gradient update is applied.

    This allows training with effective batch sizes much larger than can fit
    in GPU or host memory.
    """

    def __init__(self, workers, train_batch_size=10000, microbatch_size=1000):
        PolicyOptimizer.__init__(self, workers)

        if train_batch_size <= microbatch_size:
            raise ValueError(
                "The microbatch size must be smaller than the train batch "
                "size, got {} vs {}".format(microbatch_size, train_batch_size))

        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()
        self.train_batch_size = train_batch_size
        self.microbatch_size = microbatch_size
        self.learner_stats = {}
        self.policies = dict(self.workers.local_worker()
                             .foreach_trainable_policy(lambda p, i: (i, p)))
        logger.debug("Policies to train: {}".format(self.policies))

    @override(PolicyOptimizer)
    def step(self):
        # Broadcast the latest local weights to all remote rollout workers
        # before collecting new samples.
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

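        # Accumulation state for this step: learner info fetches, per-policy
        # gradient sums, and the number of samples collected so far.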
        fetches = {}
        accumulated_gradients = {}
        samples_so_far = 0

        # Accumulate gradients over microbatches.
        i = 0
        while samples_so_far < self.train_batch_size:
            i += 1
            with self.sample_timer:
                samples = []
                while sum(s.count for s in samples) < self.microbatch_size:
                    if self.workers.remote_workers():
                        samples.extend(
                            ray_get_and_free([
                                e.sample.remote()
                                for e in self.workers.remote_workers()
                            ]))
                    else:
                        samples.append(self.workers.local_worker().sample())
                samples = SampleBatch.concat_samples(samples)
                self.sample_timer.push_units_processed(samples.count)
                samples_so_far += samples.count

            logger.info(
                "Computing gradients for microbatch {} ({}/{} samples)".format(
                    i, samples_so_far, self.train_batch_size))

            # Treat everything as a multi-agent batch so the single-agent and
            # multi-agent cases share the same code path below.
            if isinstance(samples, SampleBatch):
                samples = MultiAgentBatch({
                    DEFAULT_POLICY_ID: samples
                }, samples.count)

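            # Compute gradients for each trainable policy on the local
            # worker and add them into the per-policy accumulators.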
            with self.grad_timer:
                for policy_id, policy in self.policies.items():
                    if policy_id not in samples.policy_batches:
                        continue
                    batch = samples.policy_batches[policy_id]
                    grad_out, info_out = (
                        self.workers.local_worker().compute_gradients(
                            MultiAgentBatch({
                                policy_id: batch
                            }, batch.count)))
                    grad = grad_out[policy_id]
                    fetches.update(info_out)
                    if policy_id not in accumulated_gradients:
                        accumulated_gradients[policy_id] = grad
                    else:
                        grad_size = len(accumulated_gradients[policy_id])
                        assert grad_size == len(grad), (grad_size, len(grad))
                        c = []
                        for a, b in zip(accumulated_gradients[policy_id],
                                        grad):
                            c.append(a + b)
                        accumulated_gradients[policy_id] = c
                self.grad_timer.push_units_processed(samples.count)

        # Apply the accumulated gradient
        logger.info("Applying accumulated gradients ({} samples)".format(
            samples_so_far))
        self.workers.local_worker().apply_gradients(accumulated_gradients)

        if len(fetches) == 1 and DEFAULT_POLICY_ID in fetches:
            self.learner_stats = fetches[DEFAULT_POLICY_ID]
        else:
            self.learner_stats = fetches

        self.num_steps_sampled += samples_so_far
        self.num_steps_trained += samples_so_far
        return self.learner_stats

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(
                    1000 * self.update_weights_timer.mean, 3),
                "opt_peak_throughput": round(
                    self.grad_timer.mean_throughput, 3),
                "sample_peak_throughput": round(
                    self.sample_timer.mean_throughput, 3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })
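
# Hypothetical example of reading the metrics reported by stats() after a
# training step (the keys are defined in stats() above):
#
#     info = optimizer.stats()
#     print("grad time (ms):", info["grad_time_ms"])
#     print("samples per grad pass:", info["opt_samples"])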