diff --git a/torch_xla_py/data_parallel.py b/torch_xla_py/data_parallel.py
index fe488e72fdd2..92f3afb14e48 100644
--- a/torch_xla_py/data_parallel.py
+++ b/torch_xla_py/data_parallel.py
@@ -3,6 +3,7 @@
 import os
 from six import iteritems, itervalues
+from copy import deepcopy
 import sys
 import threading
 import torch
 
@@ -172,9 +173,10 @@ def __init__(self, network, device_ids=None, batchdim=0, drop_last=False):
         xm.Replication(self._device_ids, replication_devices)
         if replication_devices else None)
     self._models = []
+    # Accept either an already-built nn.Module or a zero-arg factory callable.
+    module = network if isinstance(network, torch.nn.Module) else network()
     for device in device_ids:
-      module = network().to(device=torch.device(device))
-      self._models.append(module)
+      # Each replica gets an independent deep copy moved to its own device.
+      device_module = deepcopy(module).to(device=torch.device(device))
+      self._models.append(device_module)
     if not self._models:
       # No XLA device, push a vanilla network in.
-      self._models.append(network())
+      # Reuse the resolved module: calling network() here would invoke
+      # Module.forward() when a Module instance was passed in.
+      self._models.append(module)