21 changes: 13 additions & 8 deletions torch_xla_py/data_parallel.py
@@ -19,10 +19,11 @@ def __init__(self):

 class PerDeviceQueue(object):

-  def __init__(self, device, maxsize):
+  def __init__(self, device, loader_prefetch_size, device_prefetch_size):
     self.device = device
     self.batch_number = 0
-    self.queue = kq.Queue(maxsize=maxsize)
+    self.loader_queue = kq.Queue(maxsize=loader_prefetch_size)
+    self.queue = kq.Queue(maxsize=device_prefetch_size)


 class PerDeviceLoader(object):
@@ -58,10 +59,10 @@ def __init__(self,
     self._batchdim = batchdim
     self._done = False
     self._lock = threading.Lock()
-    self._loader_queue = kq.Queue(maxsize=loader_prefetch_size)
     self._queues = dict()
     for device in self._devices:
-      self._queues[device] = PerDeviceQueue(device, device_prefetch_size)
+      self._queues[device] = PerDeviceQueue(device, loader_prefetch_size,
+                                            device_prefetch_size)
     thread = threading.Thread(target=self._loader_worker)
     thread.daemon = True
     thread.start()
@@ -81,7 +82,7 @@ def close(self):
     self._done = True
     for dqueue in itervalues(self._queues):
       dqueue.queue.close()
-    self._loader_queue.close()
+      dqueue.loader_queue.close()

   def _expand_sample_batch(self, data, target):
     # TODO: Expand last sample,target to batch size
@@ -92,6 +93,7 @@ def _loader_worker(self):
     loader_batches = max(len(self._loader) - 1, 0)
     num_batches = (loader_batches // len(self._devices)) * len(self._devices)
     batch_number = 0
+    queues = list(self._queues.values())
     while batch_number < num_batches and not self._done:
       try:
         data, target = self._loader.next()
@@ -104,13 +106,15 @@ def _loader_worker(self):
         self._batch_size = data.size()[self._batchdim]
       if data.size()[self._batchdim] != self._batch_size:
         data, target = self._expand_sample_batch(data, target)
-      self._loader_queue.put((batch_number, (data, target)))
+      queues[batch_number % len(queues)].loader_queue.put((batch_number,
+                                                           (data, target)))
       batch_number += 1
-    self._loader_queue.close_write()
+    for dqueue in queues:
+      dqueue.loader_queue.close_write()

   def _worker(self, dqueue):
     while True:
-      item = self._loader_queue.get()
+      item = dqueue.loader_queue.get()
       if item is None:
         break
       batch_number, (data, target) = item
@@ -158,6 +162,7 @@ def __call__(self):
       loader = self._para_loader.per_device_loader(device)
       thread = threading.Thread(
           target=self._module_runner, args=(device, module, loader, result))
+      thread.daemon = True
       thread.start()
       threads.append(thread)
       results.append(result)
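The net effect of this file's changes: instead of one shared _loader_queue feeding every device worker, each PerDeviceQueue now carries its own bounded loader_queue, and the loader worker deals batches across them round-robin (batch_number % len(queues)). Below is a minimal, self-contained sketch of that pattern. It is illustrative only, not the PR's code: it substitutes the standard library's queue.Queue and a None sentinel for kq.Queue and close_write(), and the device names are invented.

import queue
import threading

class PerDeviceQueue(object):
  # Sketch only: each device owns a bounded loader-side queue, mirroring
  # the PerDeviceQueue.loader_queue added by this diff.

  def __init__(self, device, loader_prefetch_size):
    self.device = device
    self.loader_queue = queue.Queue(maxsize=loader_prefetch_size)

def loader_worker(batches, queues):
  # Round-robin: batch k lands on queue k % len(queues), so each device
  # receives an equal, in-order share of the batches.
  for batch_number, batch in enumerate(batches):
    queues[batch_number % len(queues)].loader_queue.put((batch_number, batch))
  for dqueue in queues:
    dqueue.loader_queue.put(None)  # sentinel standing in for close_write()

def device_worker(dqueue):
  # Each device worker drains only its own queue, so devices no longer
  # contend on a single shared loader queue.
  while True:
    item = dqueue.loader_queue.get()
    if item is None:
      break
    batch_number, batch = item
    print('device %s got batch %d' % (dqueue.device, batch_number))

if __name__ == '__main__':
  queues = [PerDeviceQueue('xla:%d' % i, loader_prefetch_size=8)
            for i in range(2)]
  workers = [threading.Thread(target=device_worker, args=(q,)) for q in queues]
  for w in workers:
    w.start()
  loader_worker(range(6), queues)
  for w in workers:
    w.join()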
4 changes: 2 additions & 2 deletions torch_xla_py/keyd_queue.py
@@ -54,7 +54,7 @@ def get(self, key):
         self._ready_cv.wait()
       self._waited_keys.discard(key)
       item = self._items.pop(key, None)
-      if item is not None and not self._close_write:
+      if item is not None:
         self._space_available_cv.notify()
       return item

@@ -78,6 +78,6 @@ def get(self):
       while not self._items and not self._close_write:
         self._ready_cv.wait()
       item = self._items.popleft() if self._items else None
-      if item is not None and not self._close_write:
+      if item is not None:
         self._space_available_cv.notify()
       return item
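In both get() variants the guard `and not self._close_write` is dropped, so a consumer that removes an item always signals _space_available_cv, even after the queue has been write-closed. The sketch below suggests why the unconditional notify is the safer choice; it is a simplified stand-in for the queue in torch_xla_py/keyd_queue.py, and the put() and close_write() bodies are assumptions made for illustration, not the library's code.

import collections
import threading

class ClosableQueue(object):
  # Simplified stand-in for the bounded, write-closable queue in
  # torch_xla_py/keyd_queue.py; only get() mirrors the diff exactly.

  def __init__(self, maxsize):
    self._maxsize = maxsize
    self._items = collections.deque()
    self._close_write = False
    self._lock = threading.Lock()
    self._ready_cv = threading.Condition(self._lock)
    self._space_available_cv = threading.Condition(self._lock)

  def put(self, item):
    with self._lock:
      while len(self._items) >= self._maxsize:
        # Blocks until some consumer signals that space freed up.
        self._space_available_cv.wait()
      self._items.append(item)
      self._ready_cv.notify()

  def close_write(self):
    with self._lock:
      self._close_write = True
      self._ready_cv.notify_all()

  def get(self):
    with self._lock:
      while not self._items and not self._close_write:
        self._ready_cv.wait()
      item = self._items.popleft() if self._items else None
      if item is not None:
        # Notify unconditionally. The old guard skipped this wakeup once
        # _close_write was set, which could leave a thread blocked in
        # put() on a full queue waiting forever.
        self._space_available_cv.notify()
      return item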