Support RaySampler + GPU
This assumes that all workers will fit on one GPU. If the intent is to
distribute sampling across multiple GPUs, we'll need to take the number
of GPUs as an argument to RaySampler.
krzentner committed Sep 16, 2020
1 parent a2e8f9e commit f2516a8
Showing 1 changed file with 8 additions and 1 deletion.
src/garage/sampler/ray_sampler.py
@@ -8,6 +8,7 @@
 """
 from collections import defaultdict
 import itertools
+import math
 
 import click
 import cloudpickle
@@ -33,7 +34,13 @@ def __init__(self, worker_factory, agents, envs):
         ray.init(log_to_driver=False,
                  ignore_reinit_error=True,
                  include_dashboard=False)
-        self._sampler_worker = ray.remote(SamplerWorker)
+        # Assume the user has a big enough GPU to fit all workers, if they're
+        # using GPU.
+        # Avoid floating point rounding issues by rounding number of workers up
+        # to a power of 2:
+        n_workers_pow_2 = 2**math.ceil(math.log2(worker_factory.n_workers))
+        remote_wrapper = ray.remote(num_gpus=1 / n_workers_pow_2)
+        self._sampler_worker = remote_wrapper(SamplerWorker)
         self._worker_factory = worker_factory
         self._agents = agents
         self._envs = self._worker_factory.prepare_worker_messages(envs)
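
For concreteness, here is a minimal standalone sketch (not part of the commit) of how the fractional allocation behaves. The worker class and counts are illustrative, and it assumes a machine with at least one GPU, since an actor that requests a GPU fraction will not be scheduled without one:

import math

import ray

ray.init(log_to_driver=False, ignore_reinit_error=True,
         include_dashboard=False)

n_workers = 6  # hypothetical worker count
# Rounding 6 up to 8 gives each worker a share of 1/8 = 0.125, which is
# exactly representable in binary floating point (unlike 1/6), so the
# shares can never sum past 1.0 due to rounding error.
n_workers_pow_2 = 2**math.ceil(math.log2(n_workers))
per_worker_gpu = 1 / n_workers_pow_2


@ray.remote(num_gpus=per_worker_gpu)
class EchoWorker:
    """Stand-in for SamplerWorker; just reports its GPU assignment."""

    def gpu_ids(self):
        # Ray sets CUDA_VISIBLE_DEVICES for each actor from its share.
        return ray.get_gpu_ids()


workers = [EchoWorker.remote() for _ in range(n_workers)]
print(ray.get([w.gpu_ids.remote() for w in workers]))
# All six actors fit on one device: 6 * 0.125 = 0.75 <= 1.0.

One consequence of requesting num_gpus unconditionally is that on a machine with no GPU these actors never schedule, which is what the "if they're using GPU" comment in the diff alludes to.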
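The commit message also anticipates taking the number of GPUs as an argument to RaySampler. A hypothetical sketch of that follow-up, reusing the same power-of-2 rounding (neither gpu_share_per_worker nor an n_gpus parameter exists in this commit):

import math

import ray


def gpu_share_per_worker(n_workers, n_gpus):
    """Hypothetical helper: equal GPU share per worker over n_gpus devices.

    Assumes n_gpus <= n_workers, so each worker requests at most one whole
    GPU; rounding the worker count up to a power of 2 keeps the share
    exact in binary floating point.
    """
    n_workers_pow_2 = 2**math.ceil(math.log2(n_workers))
    return n_gpus / n_workers_pow_2


# e.g. 12 workers over 2 GPUs: each actor requests 2/16 = 0.125 GPU, so
# Ray packs at most 8 actors per device and spreads the 12 across both.
remote_wrapper = ray.remote(num_gpus=gpu_share_per_worker(12, 2))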
