From f5ffddd675319dfa660244d9195bf69fd3db4239 Mon Sep 17 00:00:00 2001 From: Davide Libenzi Date: Thu, 14 Mar 2019 12:51:17 -0700 Subject: [PATCH] Fix data parallel to uniformly distribute over device queues. --- torch_xla_py/data_parallel.py | 21 +++++++++++++-------- torch_xla_py/keyd_queue.py | 4 ++-- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/torch_xla_py/data_parallel.py b/torch_xla_py/data_parallel.py index 987d0a8d1981..1ca163e72016 100644 --- a/torch_xla_py/data_parallel.py +++ b/torch_xla_py/data_parallel.py @@ -19,10 +19,11 @@ def __init__(self): class PerDeviceQueue(object): - def __init__(self, device, maxsize): + def __init__(self, device, loader_prefetch_size, device_prefetch_size): self.device = device self.batch_number = 0 - self.queue = kq.Queue(maxsize=maxsize) + self.loader_queue = kq.Queue(maxsize=loader_prefetch_size) + self.queue = kq.Queue(maxsize=device_prefetch_size) class PerDeviceLoader(object): @@ -58,10 +59,10 @@ def __init__(self, self._batchdim = batchdim self._done = False self._lock = threading.Lock() - self._loader_queue = kq.Queue(maxsize=loader_prefetch_size) self._queues = dict() for device in self._devices: - self._queues[device] = PerDeviceQueue(device, device_prefetch_size) + self._queues[device] = PerDeviceQueue(device, loader_prefetch_size, + device_prefetch_size) thread = threading.Thread(target=self._loader_worker) thread.daemon = True thread.start() @@ -81,7 +82,7 @@ def close(self): self._done = True for dqueue in itervalues(self._queues): dqueue.queue.close() - self._loader_queue.close() + dqueue.loader_queue.close() def _expand_sample_batch(self, data, target): # TODO: Expand last sample,target to batch size @@ -92,6 +93,7 @@ def _loader_worker(self): loader_batches = max(len(self._loader) - 1, 0) num_batches = (loader_batches // len(self._devices)) * len(self._devices) batch_number = 0 + queues = list(self._queues.values()) while batch_number < num_batches and not self._done: try: data, target = 
self._loader.next() @@ -104,13 +106,15 @@ def _loader_worker(self): self._batch_size = data.size()[self._batchdim] if data.size()[self._batchdim] != self._batch_size: data, target = self._expand_sample_batch(data, target) - self._loader_queue.put((batch_number, (data, target))) + queues[batch_number % len(queues)].loader_queue.put((batch_number, + (data, target))) batch_number += 1 - self._loader_queue.close_write() + for dqueue in queues: + dqueue.loader_queue.close_write() def _worker(self, dqueue): while True: - item = self._loader_queue.get() + item = dqueue.loader_queue.get() if item is None: break batch_number, (data, target) = item @@ -158,6 +162,7 @@ def __call__(self): loader = self._para_loader.per_device_loader(device) thread = threading.Thread( target=self._module_runner, args=(device, module, loader, result)) + thread.daemon = True thread.start() threads.append(thread) results.append(result) diff --git a/torch_xla_py/keyd_queue.py b/torch_xla_py/keyd_queue.py index df099cf0b259..2ee367cd9d8c 100644 --- a/torch_xla_py/keyd_queue.py +++ b/torch_xla_py/keyd_queue.py @@ -54,7 +54,7 @@ def get(self, key): self._ready_cv.wait() self._waited_keys.discard(key) item = self._items.pop(key, None) - if item is not None and not self._close_write: + if item is not None: self._space_available_cv.notify() return item @@ -78,6 +78,6 @@ def get(self): while not self._items and not self._close_write: self._ready_cv.wait() item = self._items.popleft() if self._items else None - if item is not None and not self._close_write: + if item is not None: self._space_available_cv.notify() return item