diff --git a/test/test_operations.py b/test/test_operations.py
index 68ce0f477e30..9b2bb4302e84 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -274,7 +274,7 @@ def test(self):
                             28), torch.zeros(batch_size, dtype=torch.int64)),
           sample_count=sample_count * len(devices))
 
-    def loop_fn(model, loader):
+    def loop_fn(model, loader, device, context):
       loss_fn = nn.NLLLoss()
       optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
 
@@ -289,9 +289,8 @@ def loop_fn(model, loader):
             lambda: xm.optimizer_step(optimizer), msg='Step: ', printfn=None)
         self.assertLess(loss.cpu().item(), 3.0)
 
-    model_parallel = dp.DataParallel(
-        XlaMNIST, train_loader, loop_fn, device_ids=devices)
-    model_parallel()
+    model_parallel = dp.DataParallel(XlaMNIST, device_ids=devices)
+    model_parallel(loop_fn, train_loader)
     if xu.getenv_as('METRICS_DEBUG', bool, defval=False):
       print(torch_xla._XLAC._xla_metrics_report())
 
@@ -307,7 +306,7 @@ def test(self):
                              224), torch.zeros(batch_size, dtype=torch.int64)),
           sample_count=sample_count * len(devices))
 
-    def loop_fn(model, loader):
+    def loop_fn(model, loader, device, context):
       loss_fn = nn.NLLLoss()
       optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
 
@@ -323,8 +322,8 @@ def loop_fn(model, loader):
         self.assertLess(loss.cpu().item(), 3.0)
 
     model_parallel = dp.DataParallel(
-        torchvision.models.resnet18, train_loader, loop_fn, device_ids=devices)
-    model_parallel()
+        torchvision.models.resnet18, device_ids=devices)
+    model_parallel(loop_fn, train_loader)
     if xu.getenv_as('METRICS_DEBUG', bool, defval=False):
       print(torch_xla._XLAC._xla_metrics_report())
 
diff --git a/test/test_train_cifar.py b/test/test_train_cifar.py
new file mode 100644
index 000000000000..604957fa8671
--- /dev/null
+++ b/test/test_train_cifar.py
@@ -0,0 +1,206 @@
+import test_utils
+
+FLAGS = test_utils.parse_common_options(
+    datadir='/tmp/cifar-data',
+    batch_size=128,
+    num_epochs=15,
+    target_accuracy=80.0)
+
+from common_utils import TestCase, run_tests
+import os
+import shutil
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torch_xla
+import torch_xla_py.data_parallel as dp
+import torch_xla_py.utils as xu
+import torch_xla_py.xla_model as xm
+import torchvision
+import torchvision.transforms as transforms
+import unittest
+
+
+class BasicBlock(nn.Module):
+  expansion = 1
+
+  def __init__(self, in_planes, planes, stride=1):
+    super(BasicBlock, self).__init__()
+    self.conv1 = nn.Conv2d(
+        in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+    self.bn1 = nn.BatchNorm2d(planes)
+    self.conv2 = nn.Conv2d(
+        planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+    self.bn2 = nn.BatchNorm2d(planes)
+
+    self.shortcut = nn.Sequential()
+    if stride != 1 or in_planes != self.expansion * planes:
+      self.shortcut = nn.Sequential(
+          nn.Conv2d(
+              in_planes,
+              self.expansion * planes,
+              kernel_size=1,
+              stride=stride,
+              bias=False), nn.BatchNorm2d(self.expansion * planes))
+
+  def forward(self, x):
+    out = F.relu(self.bn1(self.conv1(x)))
+    out = self.bn2(self.conv2(out))
+    out += self.shortcut(x)
+    out = F.relu(out)
+    return out
+
+
+class ResNet(nn.Module):
+
+  def __init__(self, block, num_blocks, num_classes=10):
+    super(ResNet, self).__init__()
+    self.in_planes = 64
+
+    self.conv1 = nn.Conv2d(
+        3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+    self.bn1 = nn.BatchNorm2d(64)
+    self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
+    self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+    self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
+    self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+    self.linear = nn.Linear(512 * block.expansion, num_classes)
+
+  def _make_layer(self, block, planes, num_blocks, stride):
+    strides = [stride] + [1] * (num_blocks - 1)
+    layers = []
+    for stride in strides:
+      layers.append(block(self.in_planes, planes, stride))
+      self.in_planes = planes * block.expansion
+    return nn.Sequential(*layers)
+
+  def forward(self, x):
+    out = F.relu(self.bn1(self.conv1(x)))
+    out = self.layer1(out)
+    out = self.layer2(out)
+    out = self.layer3(out)
+    out = self.layer4(out)
+    out = F.avg_pool2d(out, 4)
+    out = out.view(out.size(0), -1)
+    out = self.linear(out)
+    return F.log_softmax(out, dim=1)
+
+
+def ResNet18():
+  return ResNet(BasicBlock, [2, 2, 2, 2])
+
+
+def train_cifar():
+  print('==> Preparing data..')
+
+  if FLAGS.fake_data:
+    train_loader = xu.SampleGenerator(
+        data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
+              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
+        sample_count=50000 // FLAGS.batch_size)
+    test_loader = xu.SampleGenerator(
+        data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
+              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
+        sample_count=10000 // FLAGS.batch_size)
+  else:
+    transform_train = transforms.Compose([
+        transforms.RandomCrop(32, padding=4),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465),
+                             (0.2023, 0.1994, 0.2010)),
+    ])
+
+    transform_test = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465),
+                             (0.2023, 0.1994, 0.2010)),
+    ])
+
+    trainset = torchvision.datasets.CIFAR10(
+        root=FLAGS.datadir,
+        train=True,
+        download=True,
+        transform=transform_train)
+    train_loader = torch.utils.data.DataLoader(
+        trainset,
+        batch_size=FLAGS.batch_size,
+        shuffle=True,
+        num_workers=FLAGS.num_workers)
+
+    testset = torchvision.datasets.CIFAR10(
+        root=FLAGS.datadir,
+        train=False,
+        download=True,
+        transform=transform_test)
+    test_loader = torch.utils.data.DataLoader(
+        testset,
+        batch_size=FLAGS.batch_size,
+        shuffle=False,
+        num_workers=FLAGS.num_workers)
+
+  torch.manual_seed(42)
+
+  momentum = 0.9
+  lr = 0.1
+
+  devices = xm.get_xla_supported_devices()
+  # Pass [] as device_ids to run using the PyTorch/CPU engine.
+  model_parallel = dp.DataParallel(ResNet18, device_ids=devices)
+
+  def train_loop_fn(model, loader, device, context):
+    loss_fn = nn.CrossEntropyLoss()
+    optimizer = optim.SGD(
+        model.parameters(), lr=lr, momentum=momentum, weight_decay=5e-4)
+    tracker = xm.RateTracker()
+
+    for x, (data, target) in loader:
+      optimizer.zero_grad()
+      output = model(data)
+      loss = loss_fn(output, target)
+      loss.backward()
+      xm.optimizer_step(optimizer)
+      tracker.add(FLAGS.batch_size)
+      print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(device, x, loss.item(),
+                                                      tracker.rate()))
+
+  def test_loop_fn(model, loader, device, context):
+    total_samples = 0
+    correct = 0
+    for x, (data, target) in loader:
+      output = model(data)
+      pred = output.max(1, keepdim=True)[1]
+      correct += pred.eq(target.view_as(pred)).sum().item()
+      total_samples += data.size()[0]
+
+    print('[{}] Accuracy={}'.format(device, correct / total_samples))
+    return correct / total_samples
+
+  accuracy = 0.0
+  for epoch in range(1, FLAGS.num_epochs + 1):
+    model_parallel(train_loop_fn, train_loader)
+    accuracies = model_parallel(test_loop_fn, test_loader)
+    accuracy = sum(accuracies) / len(devices)
+    if FLAGS.metrics_debug:
+      print(torch_xla._XLAC._xla_metrics_report())
+
+  return accuracy * 100.0
+
+
+class TrainCIFAR10(TestCase):
+
+  def tearDown(self):
+    super(TrainCIFAR10, self).tearDown()
+    if FLAGS.tidy and os.path.isdir(FLAGS.datadir):
+      shutil.rmtree(FLAGS.datadir)
+
+  def test_accuracy(self):
+    self.assertGreaterEqual(train_cifar(), FLAGS.target_accuracy)
+
+
+# Run the tests.
+torch.set_default_tensor_type('torch.FloatTensor')
+run_tests()
diff --git a/test/test_train_imagenet.py b/test/test_train_imagenet.py
new file mode 100644
index 000000000000..32e218e25e8f
--- /dev/null
+++ b/test/test_train_imagenet.py
@@ -0,0 +1,122 @@
+import test_utils
+
+FLAGS = test_utils.parse_common_options(
+    datadir='/tmp/imagenet', batch_size=128, num_epochs=15, target_accuracy=0.0)
+
+from common_utils import TestCase, run_tests
+import os
+import shutil
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torchvision
+import torchvision.transforms as transforms
+import torch_xla
+import torch_xla_py.data_parallel as dp
+import torch_xla_py.utils as xu
+import torch_xla_py.xla_model as xm
+import unittest
+
+
+def train_imagenet():
+  print('==> Preparing data..')
+  if FLAGS.fake_data:
+    train_loader = xu.SampleGenerator(
+        data=(torch.zeros(FLAGS.batch_size, 3, 224, 224),
+              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
+        sample_count=1200000 // FLAGS.batch_size)
+    test_loader = xu.SampleGenerator(
+        data=(torch.zeros(FLAGS.batch_size, 3, 224, 224),
+              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
+        sample_count=50000 // FLAGS.batch_size)
+  else:
+    normalize = transforms.Normalize(
+        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    train_dataset = torchvision.datasets.ImageFolder(
+        os.path.join(FLAGS.datadir, 'train'),
+        transforms.Compose([
+            transforms.RandomResizedCrop(224),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            normalize,
+        ]))
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=FLAGS.batch_size,
+        shuffle=True,
+        num_workers=FLAGS.num_workers)
+    test_dataset = torchvision.datasets.ImageFolder(
+        os.path.join(FLAGS.datadir, 'val'),
+        transforms.Compose([
+            transforms.RandomResizedCrop(224),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            normalize,
+        ]))
+    test_loader = torch.utils.data.DataLoader(
+        test_dataset,
+        batch_size=FLAGS.batch_size,
+        shuffle=True,
+        num_workers=FLAGS.num_workers)
+
+  torch.manual_seed(42)
+
+  momentum = 0.9
+  lr = 0.1
+  devices = xm.get_xla_supported_devices()
+  # Pass [] as device_ids to run using the PyTorch/CPU engine.
+  model_parallel = dp.DataParallel(
+      torchvision.models.resnet50, device_ids=devices)
+
+  def train_loop_fn(model, loader, device, context):
+    loss_fn = nn.CrossEntropyLoss()
+    optimizer = optim.SGD(
+        model.parameters(), lr=lr, momentum=momentum, weight_decay=5e-4)
+    tracker = xm.RateTracker()
+
+    for x, (data, target) in loader:
+      optimizer.zero_grad()
+      output = model(data)
+      loss = loss_fn(output, target)
+      loss.backward()
+      xm.optimizer_step(optimizer)
+      tracker.add(FLAGS.batch_size)
+      print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(device, x, loss.item(),
+                                                      tracker.rate()))
+
+  def test_loop_fn(model, loader, device, context):
+    total_samples = 0
+    correct = 0
+    for x, (data, target) in loader:
+      output = model(data)
+      pred = output.max(1, keepdim=True)[1]
+      correct += pred.eq(target.view_as(pred)).sum().item()
+      total_samples += data.size()[0]
+
+    print('[{}] Accuracy={}'.format(device, correct / total_samples))
+    return correct / total_samples
+
+  accuracy = 0.0
+  for epoch in range(1, FLAGS.num_epochs + 1):
+    model_parallel(train_loop_fn, train_loader)
+    accuracies = model_parallel(test_loop_fn, test_loader)
+    accuracy = sum(accuracies) / len(devices)
+    if FLAGS.metrics_debug:
+      print(torch_xla._XLAC._xla_metrics_report())
+
+  return accuracy * 100.0
+
+
+class TrainImageNet(TestCase):
+
+  def tearDown(self):
+    super(TrainImageNet, self).tearDown()
+
+  def test_accuracy(self):
+    self.assertGreaterEqual(train_imagenet(), FLAGS.target_accuracy)
+
+
+# Run the tests.
+torch.set_default_tensor_type('torch.FloatTensor')
+run_tests()
diff --git a/test/test_train_mnist.py b/test/test_train_mnist.py
index bb662e72a7b7..9cf2d902fe2b 100644
--- a/test/test_train_mnist.py
+++ b/test/test_train_mnist.py
@@ -13,6 +13,7 @@
 import torch.optim as optim
 from torchvision import datasets, transforms
 import torch_xla
+import torch_xla_py.data_parallel as dp
 import torch_xla_py.utils as xu
 import torch_xla_py.xla_model as xm
 import unittest
@@ -41,12 +42,10 @@ def forward(self, x):
 
 
 def train_mnist():
-  assert FLAGS.num_cores == 1
   torch.manual_seed(1)
   # Training settings
   lr = 0.01
   momentum = 0.5
-  log_interval = 5
 
   if FLAGS.fake_data:
     train_loader = xu.SampleGenerator(
@@ -84,65 +83,46 @@ def train_mnist():
         shuffle=True,
         num_workers=FLAGS.num_workers)
 
-  device = xm.xla_device()
-  torch_xla._XLAC._xla_set_default_device(str(device))
-  model = MNIST().to(device=device)
-  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
-  loss_fn = nn.NLLLoss()
-  accuracy = None
-  for epoch in range(1, FLAGS.num_epochs + 1):
-    # Training loop for epoch.
-    start_time = time.time()
-    processed = 0
-    for batch_idx, (data, target) in enumerate(train_loader):
-      if data.size()[0] != FLAGS.batch_size:
-        break
-      data = data.to(device=device)
-      target = target.to(device=device)
+  devices = xm.get_xla_supported_devices()
+  # Pass [] as device_ids to run using the PyTorch/CPU engine.
+  model_parallel = dp.DataParallel(MNIST, device_ids=devices)
+
+  def train_loop_fn(model, loader, device, context):
+    loss_fn = nn.NLLLoss()
+    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
+    tracker = xm.RateTracker()
+
+    for x, (data, target) in loader:
       optimizer.zero_grad()
-      y = model(data)
-      loss = loss_fn(y, target)
+      output = model(data)
+      loss = loss_fn(output, target)
       loss.backward()
       xm.optimizer_step(optimizer)
-
-      processed += FLAGS.batch_size
-      if batch_idx % log_interval == 0:
-        print('Train Epoch: {} [{}/{} ({:.0f}%)]\t'
-              'Loss: {:.6f}\tSamples/sec: {:.1f}'.format(
-                  epoch, processed,
-                  len(train_loader) * FLAGS.batch_size,
-                  100. * batch_idx / len(train_loader), loss,
-                  processed / (time.time() - start_time)))
-
-    # Eval loop for epoch.
-    start_time = time.time()
-    correct_count = 0
-    test_loss = 0
-    count = 0
-    for batch_idx, (data, target) in enumerate(test_loader):
-      if data.size()[0] != FLAGS.batch_size:
-        break
-      data = data.to(device=device)
-      target = target.to(device=device)
-
-      y = model(data)
-      test_loss += loss_fn(y, target).sum().item()
-      pred = y.max(1, keepdim=True)[1]
-      correct_count += pred.eq(target.view_as(pred)).sum().item()
-      count += FLAGS.batch_size
-
-    test_loss /= count
-    accuracy = 100.0 * correct_count / count
-    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%), '
-          'Samples/sec: {:.1f}\n'.format(test_loss, correct_count, count,
-                                         accuracy,
-                                         count / (time.time() - start_time)))
-    # Debug metric dumping.
+      tracker.add(FLAGS.batch_size)
+      print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(device, x, loss.item(),
+                                                      tracker.rate()))
+
+  def test_loop_fn(model, loader, device, context):
+    total_samples = 0
+    correct = 0
+    for x, (data, target) in loader:
+      output = model(data)
+      pred = output.max(1, keepdim=True)[1]
+      correct += pred.eq(target.view_as(pred)).sum().item()
+      total_samples += data.size()[0]
+
+    print('[{}] Accuracy={}'.format(device, correct / total_samples))
+    return correct / total_samples
+
+  accuracy = 0.0
+  for epoch in range(1, FLAGS.num_epochs + 1):
+    model_parallel(train_loop_fn, train_loader)
+    accuracies = model_parallel(test_loop_fn, test_loader)
+    accuracy = sum(accuracies) / len(devices)
     if FLAGS.metrics_debug:
       print(torch_xla._XLAC._xla_metrics_report())
 
-  return accuracy
+  return accuracy * 100.0
 
 
 class TrainMnist(TestCase):
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index ce9d15a3c98e..410e33bb14aa 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -20,6 +20,7 @@
 #include "torch_xla/csrc/ir_util.h"
 #include "torch_xla/csrc/python_util.h"
 #include "torch_xla/csrc/tensor_impl.h"
+#include "torch_xla/csrc/tensor_util.h"
 #include "torch_xla/csrc/torch_util.h"
 
 namespace torch_xla {
@@ -100,7 +101,7 @@ void InsertCrossReplicaSum(const std::vector<at::Tensor>& tensors, double scale,
   }
 }
 
-void SyncTensors(const std::vector<at::Tensor>& tensors) {
+void SyncTensors(const std::vector<at::Tensor>& tensors, bool wait) {
   std::vector<XLATensor> xtensors;
   for (auto& tensor : tensors) {
     auto xtensor = bridge::TryGetXlaTensor(ToTensor(tensor));
@@ -108,7 +109,7 @@ void SyncTensors(const std::vector<at::Tensor>& tensors) {
       xtensors.push_back(*xtensor);
     }
   }
-  XLATensor::SyncTensorsGraph(&xtensors, /*wait=*/false);
+  XLATensor::SyncTensorsGraph(&xtensors, wait);
 }
 
 void SyncLiveTensors(const std::string& device_str) {
@@ -163,6 +164,32 @@ std::string GetLiveTensorsReport(size_t nodes_threshold,
   return ss.str();
 }
 
+std::vector<at::Tensor> GetXlaTensorsFromAten(
+    const std::vector<at::Tensor>& aten_tensors,
+    const std::vector<std::string>& devices) {
+  std::vector<std::string> xla_devices;
+  xla_devices.reserve(devices.size());
+  for (auto& device_str : devices) {
+    Device device = bridge::AtenDeviceToXlaDevice(c10::Device(device_str));
+    xla_devices.emplace_back(device.ToString());
+  }
+  std::vector<at::Tensor> tensors;
+  tensors.reserve(aten_tensors.size());
+  for (auto& aten_tensor : aten_tensors) {
+    tensors.push_back(ToTensor(aten_tensor));
+  }
+
+  auto data_handles = CreateTensorsData(tensors, xla_devices);
+
+  std::vector<at::Tensor> xla_tensors;
+  xla_tensors.reserve(data_handles.size());
+  for (auto& data_handle : data_handles) {
+    XLATensor xla_tensor = XLATensor::Create(std::move(data_handle));
+    xla_tensors.push_back(bridge::AtenFromXlaTensor(std::move(xla_tensor)));
+  }
+  return xla_tensors;
+}
+
 void InitXlaModuleBindings(py::module m) {
   m.def("_initialize_aten_bindings",
         []() { AtenXlaType::InitializeAtenBindings(); });
@@ -189,6 +216,20 @@ void InitXlaModuleBindings(py::module m) {
       [](const std::vector<at::Tensor>& tensors) -> std::string {
         return GetTensorsHloGraph(tensors);
       });
+  m.def("_xla_tensors_from_aten", [](const std::vector<at::Tensor>& tensors,
+                                     const std::vector<std::string>& devices) {
+    std::vector<at::Tensor> result;
+    {
+      NoGilSection nogil;
+      std::vector<at::Tensor> xla_tensors =
+          GetXlaTensorsFromAten(tensors, devices);
+      result.reserve(xla_tensors.size());
+      for (auto& tensor : xla_tensors) {
+        result.push_back(torch::autograd::make_variable(tensor));
+      }
+    }
+    return result;
+  });
   m.def("_xla_get_devices",
         []() { return xla::ComputationClient::Get()->GetAvailableDevices(); });
   m.def("_xla_set_replication_devices",
@@ -211,10 +252,13 @@ void InitXlaModuleBindings(py::module m) {
   m.def("_xla_set_default_device",
         [](const std::string& device) { return SetCurrentDevice(device); });
   m.def("_xla_get_default_device", []() { return GetCurrentDevice(); });
-  m.def("_xla_sync_multi", [](const std::vector<at::Tensor>& tensors) {
-    NoGilSection nogil;
-    SyncTensors(tensors);
-  });
+  m.def(
+      "_xla_sync_multi",
+      [](const std::vector<at::Tensor>& tensors, bool wait) {
+        NoGilSection nogil;
+        SyncTensors(tensors, wait);
+      },
+      py::arg("tensors"), py::arg("wait") = true);
   m.def(
       "_xla_sync_live_tensors",
       [](const std::string& device) {
diff --git a/torch_xla_py/data_parallel.py b/torch_xla_py/data_parallel.py
index e9517ae9639f..937ac791923b 100644
--- a/torch_xla_py/data_parallel.py
+++ b/torch_xla_py/data_parallel.py
@@ -1,14 +1,18 @@
 from __future__ import division
 from __future__ import print_function
 
+import os
 from six import iteritems, itervalues
+import sys
 import threading
 import torch
+import torch.autograd
 import torch_xla
 import torch_xla_py.utils as xu
 import torch_xla_py.xla_model as xm
 import torch_xla_py.utils as xu
 import torch_xla_py.keyd_queue as kq
+import traceback
 
 
 class ThreadResult(object):
@@ -60,7 +64,6 @@ def __init__(self,
     self._batchdim = batchdim
     self._drop_last = drop_last
     self._done = False
-    self._lock = threading.Lock()
     self._queues = dict()
     for device in self._devices:
       self._queues[device] = PerDeviceQueue(device, loader_prefetch_size,
@@ -102,9 +105,8 @@ def fn(v):
   def _send_data_to(self, data, device):
 
     def convert_fn(tensors):
-      device_tensors = [x.to(device) for x in tensors]
-      torch_xla._XLAC._xla_sync_multi(device_tensors)
-      return device_tensors
+      devices = [str(device)] * len(tensors)
+      return torch_xla._XLAC._xla_tensors_from_aten(tensors, devices)
 
     def select_fn(v):
       return type(v) == torch.Tensor
@@ -112,9 +114,7 @@ def select_fn(v):
     return xm.ToXlaTensorArena(convert_fn, select_fn).transform(data)
 
   def _loader_worker(self):
-    # TODO: When _expand_sample_batch() is implemented, remove the -1 fixup.
-    loader_batches = max(len(self._loader) - 1, 0)
-    num_batches = (loader_batches // len(self._devices)) * len(self._devices)
+    num_batches = (len(self._loader) // len(self._devices)) * len(self._devices)
     batch_number = 0
     queues = list(self._queues.values())
     data_iter = enumerate(self._loader)
@@ -151,50 +151,67 @@ def _worker(self, dqueue):
 
 class DataParallel(object):
 
-  def __init__(self,
-               network,
-               loader,
-               loop_fn,
-               device_ids=None,
-               batchdim=0,
-               drop_last=False):
-    if not device_ids:
+  def __init__(self, network, device_ids=None, batchdim=0, drop_last=False):
+    if device_ids is None:
       device_ids = xm.get_xla_supported_devices()
-    self._loader = loader
-    self._loop_fn = loop_fn
     self._batchdim = batchdim
     self._drop_last = drop_last
     self._device_ids = list(device_ids)
-    self._modules = []
+    self._models = []
    for device in device_ids:
       module = network().to(device=torch.device(device))
-      self._modules.append(module)
-    if not self._modules:
+      self._models.append(module)
+    if not self._models:
       # No XLA device, push a vanilla network in.
-      self._modules.append(network())
-
-  def _module_runner(self, device, module, loader, result):
+      self._models.append(network())
+
+  def _get_model_device(self, model):
+    for p in model.parameters():
+      return p.device
+    return torch.device('cpu')
+
+  def _handle_runner_exception(self, device, e):
+    print(
+        'Exception in model function for device={}: {}'.format(device, str(e)),
+        file=sys.stderr)
+    traceback.print_exc(limit=16, file=sys.stderr)
+    # One exception in one thread is fatal, as the other ones (assuming they
+    # somehow did not generate the same exception) will be getting stuck in
+    # cross replica sum operations waiting for the defunct thread and its
+    # device.
+    os._exit(17)
+
+  def _module_runner(self, loop_fn, device, module, loader, context, result):
     torch_xla._XLAC._xla_set_default_device(device)
     torch_xla._XLAC._xla_set_replication_devices(self._device_ids)
-    result.result = self._loop_fn(module, loader)
-
-  def __call__(self):
+    try:
+      result.result = loop_fn(module, loader, torch.device(device), context)
+    except Exception as e:
+      result.result = e
+      self._handle_runner_exception(device, e)
+
+  def __call__(self, loop_fn, loader):
+    context = dict()
     if not self._device_ids:
       ## This is called without XLA devices available. Run in normal mode.
-      return self._loop_fn(self._modules[0], self._loader)
+      return [
+          loop_fn(self._models[0], enumerate(loader),
+                  self._get_model_device(self._models[0]), context)
+      ]
     para_loader = ParallelLoader(
-        self._loader,
+        loader,
        self._device_ids,
         batchdim=self._batchdim,
         drop_last=self._drop_last)
     threads = []
     results = []
-    for module, device in zip(self._modules, self._device_ids):
+    for module, device in zip(self._models, self._device_ids):
       result = ThreadResult()
       loader = para_loader.per_device_loader(device)
       thread = threading.Thread(
-          target=self._module_runner, args=(device, module, loader, result))
+          target=self._module_runner,
+          args=(loop_fn, device, module, loader, context, result))
       thread.daemon = True
       thread.start()
       threads.append(thread)
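
Usage note (editor's sketch, not part of the patch): with this change, DataParallel takes only the network constructor and device ids at build time, while the loop function and data loader are supplied per call; the loop function receives the per-device model replica, a prefetching per-device loader, the torch.device, and a context dict shared across the per-device threads. A minimal training sketch, assuming the MNIST module class and train_loader defined in the tests above:

    import torch.nn as nn
    import torch.optim as optim
    import torch_xla_py.data_parallel as dp
    import torch_xla_py.xla_model as xm

    devices = xm.get_xla_supported_devices()
    # Pass [] as device_ids to run using the PyTorch/CPU engine instead.
    model_parallel = dp.DataParallel(MNIST, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
      # One replica per XLA device; the loader yields (step, (data, target))
      # pairs with the tensors already placed on the device.
      loss_fn = nn.NLLLoss()
      optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
      for step, (data, target) in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(data), target)
        loss.backward()
        xm.optimizer_step(optimizer)

    # The loader is passed at call time, so the same DataParallel object can
    # run different loop functions (e.g. train and eval) over different
    # loaders; the call returns the per-device loop function results.
    results = model_parallel(train_loop_fn, train_loader)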