25 changes: 18 additions & 7 deletions torch_xla/distributed/data_parallel.py
@@ -45,13 +45,14 @@ class DataParallel(object):
       will be run on PyTorch CPU device.
   """
 
-  def __init__(self, network, device_ids=None):
+  def __init__(self, network, device_ids=None, **kwargs):
     if device_ids is None:
       device_ids = xm.get_xla_supported_devices()
     self._device_ids = [str(x) for x in device_ids]
     self._native_run = False
     self._models = []
     self._contexts = []
+    self._kwargs = kwargs
     module = network if isinstance(network, torch.nn.Module) else network()
     for device in device_ids:
       device_module = deepcopy(module).to(device=torch.device(device))
@@ -95,15 +96,22 @@ def _handle_runner_exception(self, device, e):
     # device.
     os._exit(17)
 
-  def _module_runner(self, loop_fn, device, module, loader, context, result):
+  def _module_runner(self, loop_fn, device, module, loader, context, result,
+                     **kwargs):
     xm.set_replication(device, self._device_ids)
     try:
-      result.result = loop_fn(module, loader, torch.device(device), context)
+      result.result = loop_fn(module, loader, torch.device(device), context,
+                              **kwargs)
     except Exception as e:
       result.result = e
       self._handle_runner_exception(device, e)
 
-  def __call__(self, loop_fn, loader, fixed_batch_size=False, batchdim=0):
+  def __call__(self,
+               loop_fn,
+               loader,
+               fixed_batch_size=False,
+               batchdim=0,
+               **kwargs):
     """Runs one EPOCH of training/test.
 
     Args:
@@ -130,15 +138,17 @@ def __call__(self, loop_fn, loader, fixed_batch_size=False, batchdim=0):
       ## This is called without XLA devices available. Run in normal mode.
       return [
           loop_fn(self._models[0], enumerate(loader),
-                  torch.device(self._device_ids[0]), self._contexts[0])
+                  torch.device(self._device_ids[0]), self._contexts[0],
+                  self._kwargs)
       ]
 
     xm.wait_device_ops()
     para_loader = pl.ParallelLoader(
         loader,
         self._device_ids,
         batchdim=batchdim,
-        fixed_batch_size=fixed_batch_size)
+        fixed_batch_size=fixed_batch_size,
+        **self._kwargs)
     threads = []
     results = []
     for module, device, context in zip(self._models, self._device_ids,
@@ -147,7 +157,8 @@ def __call__(self, loop_fn, loader, fixed_batch_size=False, batchdim=0):
       loader = para_loader.per_device_loader(device)
       thread = threading.Thread(
           target=self._module_runner,
-          args=(loop_fn, device, module, loader, context, result))
+          args=(loop_fn, device, module, loader, context, result),
+          kwargs=kwargs)
       thread.daemon = True
       thread.start()
       threads.append(thread)
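Note (not part of the diff): a minimal usage sketch of the new plumbing. Constructor **kwargs are stored in self._kwargs and forwarded to pl.ParallelLoader, while **kwargs passed to __call__ reach loop_fn through the per-device threads on the XLA-device path. The loader_prefetch_size keyword is assumed to be accepted by ParallelLoader in this torch_xla version, log_every is a made-up loop_fn argument, and the module paths may differ between releases.

import torch
import torch.nn as nn
# Module path assumed for this torch_xla release; it may differ in other versions.
import torch_xla.distributed.data_parallel as dp

class ToyNet(nn.Module):

  def __init__(self):
    super(ToyNet, self).__init__()
    self.fc = nn.Linear(8, 2)

  def forward(self, x):
    return self.fc(x)

def train_loop_fn(model, loader, device, context, log_every=10):
  # `log_every` is a hypothetical argument; it arrives through the **kwargs
  # given to DataParallel.__call__ (thread kwargs -> _module_runner -> loop_fn).
  for step, (data, target) in loader:
    if step % log_every == 0:
      print('device {} step {}'.format(device, step))

# A toy iterable yielding (data, target) batches in place of a real DataLoader.
train_loader = [(torch.randn(4, 8), torch.zeros(4, dtype=torch.long))] * 100

# Constructor **kwargs (loader_prefetch_size, assumed to be a ParallelLoader
# keyword) end up in self._kwargs and are forwarded to pl.ParallelLoader.
model_parallel = dp.DataParallel(ToyNet, loader_prefetch_size=8)
# __call__ **kwargs are forwarded to loop_fn on each per-device thread.
results = model_parallel(train_loop_fn, train_loader, log_every=20)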
4 changes: 2 additions & 2 deletions torch_xla/distributed/xla_multiprocessing.py
@@ -55,7 +55,7 @@ def _parse_workers_config(config):
   # XRT_WORKERS='worker:0;ismz9:25822'
   workers = collections.OrderedDict()
   for worker in config.split('|'):
-    m = re.match(r'(\w+):(\d+);((grpc://)?[\w.]+:\d+)', worker)
+    m = re.match(r'(\w+):(\d+);((grpc://)?[a-zA-Z0-9_\-\.]+:\d+)', worker)
     if not m:
       raise ValueError('Bad worker syntax: {}'.format(worker))
     workers['{}:{}'.format(m.group(1), m.group(2))] = WorkerConfigEntry(
@@ -67,7 +67,7 @@ def _parse_tpu_config(config):
   # XRT_TPU_CONFIG='tpu_worker;0;ismz9:25822'
   workers = collections.OrderedDict()
   for worker in config.split('|'):
-    m = re.match(r'(\w+);(\d+);([\w.]+:\d+)', worker)
+    m = re.match(r'(\w+);(\d+);([a-zA-Z0-9_\-\.]+:\d+)', worker)
     if not m:
       raise ValueError('Bad worker syntax: {}'.format(worker))
     workers['{}:{}'.format(m.group(1), m.group(2))] = WorkerConfigEntry(
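Note (not part of the diff): the widened character class is what lets hyphenated host names parse; [\w.] matches letters, digits, underscore and dot but not '-'. A quick before/after check against a made-up XRT_WORKERS entry (the same reasoning applies to the XRT_TPU_CONFIG pattern in _parse_tpu_config):

import re

OLD = r'(\w+):(\d+);((grpc://)?[\w.]+:\d+)'
NEW = r'(\w+):(\d+);((grpc://)?[a-zA-Z0-9_\-\.]+:\d+)'

# Hypothetical worker entry whose host name contains hyphens.
worker = 'worker:0;grpc://my-tpu-host.internal:25822'

print(re.match(OLD, worker))  # None: the old class rejects '-'
print(re.match(NEW, worker))  # match object: hyphens are now accepted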