diff --git a/.circleci/test.sh b/.circleci/test.sh index 44ce0f4510b0..8a47a1ddfce5 100755 --- a/.circleci/test.sh +++ b/.circleci/test.sh @@ -19,7 +19,7 @@ echo "Running Python Tests" ./test/run_tests.sh # echo "Running MNIST Test" -# python test/test_train_mnist.py --tidy +# python test/test_train_mp_mnist.py --tidy # if [ -x "$(command -v nvidia-smi)" ]; then # python test/test_train_mp_mnist_amp.py --fake_data # fi diff --git a/README.md b/README.md index 5b049a5db5d3..46515080a5a7 100644 --- a/README.md +++ b/README.md @@ -200,7 +200,7 @@ Training on pods can be broken down to largely 3 different steps: If you prefer to not use an [instance group](#create-your-instance-group), you can decide to use a list of VM instances that you may have already created (or can create individually). Make sure that you create all the VM instances in the same zone as the TPU node, and also make sure that the VMs have the same configuration (datasets, VM size, disk size, etc.). Then you can [start distributed training](#start-distributed-training) after creating your TPU pod. The difference is in the `python -m torch_xla.distributed.xla_dist` command. For example, to use a list of VMs run the following command (ex. conda with v3-32): ``` (torch-xla-1.7)$ cd /usr/share/torch-xla-1.7/pytorch/xla -(torch-xla-1.7)$ python -m torch_xla.distributed.xla_dist --tpu=$TPU_POD_NAME --vm $VM1 --vm $VM2 --vm $VM3 --vm $VM4 --conda-env=torch-xla-1.7 --env=XLA_USE_BF16=1 -- python test/test_train_imagenet.py --fake_data +(torch-xla-1.7)$ python -m torch_xla.distributed.xla_dist --tpu=$TPU_POD_NAME --vm $VM1 --vm $VM2 --vm $VM3 --vm $VM4 --conda-env=torch-xla-1.7 --env=XLA_USE_BF16=1 -- python test/test_train_mp_imagenet.py --fake_data ``` ### Datasets for distributed training diff --git a/docker/common.sh b/docker/common.sh index af4cab6ccaf1..4c3e503138cc 100755 --- a/docker/common.sh +++ b/docker/common.sh @@ -5,7 +5,7 @@ function run_deployment_tests() { export XRT_WORKERS="localservice:0;grpc://localhost:40934" export CC=clang-8 CXX=clang++-8 - time python /pytorch/xla/test/test_train_mnist.py + time python /pytorch/xla/test/test_train_mp_mnist.py --fake_data time bash /pytorch/xla/test/run_tests.sh time bash /pytorch/xla/test/cpp/run_tests.sh } diff --git a/test/test_train_cifar.py b/test/test_train_cifar.py deleted file mode 100644 index c1fe097c7699..000000000000 --- a/test/test_train_cifar.py +++ /dev/null @@ -1,266 +0,0 @@ -import args_parse - -MODEL_OPTS = { - '--use_torchvision': { - 'default': False, - 'type': bool, - }, -} -FLAGS = args_parse.parse_common_options( - datadir='/tmp/cifar-data', - batch_size=128, - num_epochs=25, - momentum=0.9, - lr=0.1, - target_accuracy=80.0, - opts=MODEL_OPTS.items()) - -import os -from statistics import mean -import shutil -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -import torch_xla -import torch_xla.debug.metrics as met -import torch_xla.distributed.data_parallel as dp -import torch_xla.utils.utils as xu -import torch_xla.core.xla_model as xm -import torch_xla.test.test_utils as test_utils -import torchvision -import torchvision.transforms as transforms -import unittest - - -class BasicBlock(nn.Module): - expansion = 1 - - def __init__(self, in_planes, planes, stride=1): - super(BasicBlock, self).__init__() - self.conv1 = nn.Conv2d( - in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d( - planes, planes, kernel_size=3, stride=1, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - - self.shortcut = nn.Sequential() - if stride != 1 or in_planes != self.expansion * planes: - self.shortcut = nn.Sequential( - nn.Conv2d( - in_planes, - self.expansion * planes, - kernel_size=1, - stride=stride, - bias=False), nn.BatchNorm2d(self.expansion * planes)) - - def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) - out = self.bn2(self.conv2(out)) - out += self.shortcut(x) - out = F.relu(out) - return out - - -class ResNet(nn.Module): - - def __init__(self, block, num_blocks, num_classes=10): - super(ResNet, self).__init__() - self.in_planes = 64 - - self.conv1 = nn.Conv2d( - 3, 64, kernel_size=3, stride=1, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(64) - self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) - self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) - self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) - self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) - self.linear = nn.Linear(512 * block.expansion, num_classes) - - def _make_layer(self, block, planes, num_blocks, stride): - strides = [stride] + [1] * (num_blocks - 1) - layers = [] - for stride in strides: - layers.append(block(self.in_planes, planes, stride)) - self.in_planes = planes * block.expansion - return nn.Sequential(*layers) - - def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) - out = self.layer1(out) - out = self.layer2(out) - out = self.layer3(out) - out = self.layer4(out) - out = F.avg_pool2d(out, 4) - out = torch.flatten(out, 1) - out = self.linear(out) - return F.log_softmax(out, dim=1) - - -def ResNet18(): - return ResNet(BasicBlock, [2, 2, 2, 2]) - - -def train_cifar(): - print('==> Preparing data..') - - if FLAGS.fake_data: - train_dataset_len = 50000 # Number of example in CIFAR train set. - train_loader = xu.SampleGenerator( - data=(torch.zeros(FLAGS.batch_size, 3, 32, - 32), torch.zeros(FLAGS.batch_size, - dtype=torch.int64)), - sample_count=train_dataset_len // FLAGS.batch_size // - xm.xrt_world_size()) - test_loader = xu.SampleGenerator( - data=(torch.zeros(FLAGS.batch_size, 3, 32, - 32), torch.zeros(FLAGS.batch_size, - dtype=torch.int64)), - sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size()) - else: - transform_train = transforms.Compose([ - transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), - (0.2023, 0.1994, 0.2010)), - ]) - - transform_test = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), - (0.2023, 0.1994, 0.2010)), - ]) - - train_dataset = torchvision.datasets.CIFAR10( - root=FLAGS.datadir, - train=True, - download=True, - transform=transform_train) - train_dataset_len = len(train_dataset) - test_dataset = torchvision.datasets.CIFAR10( - root=FLAGS.datadir, - train=False, - download=True, - transform=transform_test) - train_sampler = None - test_sampler = None - if xm.xrt_world_size() > 1: - train_sampler = torch.utils.data.distributed.DistributedSampler( - train_dataset, - num_replicas=xm.xrt_world_size(), - rank=xm.get_ordinal(), - shuffle=True) - test_sampler = torch.utils.data.distributed.DistributedSampler( - test_dataset, - num_replicas=xm.xrt_world_size(), - rank=xm.get_ordinal(), - shuffle=False) - train_loader = torch.utils.data.DataLoader( - train_dataset, - batch_size=FLAGS.batch_size, - sampler=train_sampler, - drop_last=FLAGS.drop_last, - shuffle=False if train_sampler else True, - num_workers=FLAGS.num_workers) - test_loader = torch.utils.data.DataLoader( - test_dataset, - batch_size=FLAGS.batch_size, - sampler=test_sampler, - drop_last=FLAGS.drop_last, - shuffle=False, - num_workers=FLAGS.num_workers) - - torch.manual_seed(42) - - devices = ( - xm.get_xla_supported_devices( - max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else []) - # Pass [] as device_ids to run using the PyTorch/CPU engine. - model = torchvision.models.resnet18 if FLAGS.use_torchvision else ResNet18 - model_parallel = dp.DataParallel(model, device_ids=devices) - writer = test_utils.get_summary_writer(FLAGS.logdir) - - def train_loop_fn(model, loader, device, context): - loss_fn = nn.CrossEntropyLoss() - optimizer = context.getattr_or( - 'optimizer', lambda: optim.SGD( - model.parameters(), - lr=FLAGS.lr, - momentum=FLAGS.momentum, - weight_decay=5e-4)) - tracker = xm.RateTracker() - - model.train() - for x, (data, target) in enumerate(loader): - optimizer.zero_grad() - output = model(data) - loss = loss_fn(output, target) - loss.backward() - xm.optimizer_step(optimizer) - tracker.add(FLAGS.batch_size) - if x % FLAGS.log_steps == 0: - test_utils.print_training_update( - device, - x, - loss.item(), - tracker.rate(), - tracker.global_rate(), - summary_writer=writer) - - def test_loop_fn(model, loader, device, context): - total_samples = 0 - correct = 0 - model.eval() - for data, target in loader: - output = model(data) - pred = output.max(1, keepdim=True)[1] - correct += pred.eq(target.view_as(pred)).sum().item() - total_samples += data.size()[0] - - accuracy = 100.0 * correct / total_samples - test_utils.print_test_update(device, accuracy) - return accuracy - - accuracy = 0.0 - num_devices = len( - xm.xla_replication_devices(devices)) if len(devices) > 1 else 1 - num_training_steps_per_epoch = train_dataset_len // ( - FLAGS.batch_size * num_devices) - max_accuracy = 0.0 - for epoch in range(1, FLAGS.num_epochs + 1): - model_parallel(train_loop_fn, train_loader) - accuracies = model_parallel(test_loop_fn, test_loader) - accuracy = mean(accuracies) - max_accuracy = max(accuracy, max_accuracy) - print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy)) - global_step = (epoch - 1) * num_training_steps_per_epoch - test_utils.write_to_summary( - writer, - global_step, - dict_to_write={'Accuracy/test': accuracy}, - write_xla_metrics=True) - if FLAGS.metrics_debug: - xm.master_print(met.metrics_report()) - - test_utils.close_summary_writer(writer) - print('Max Accuracy: {:.2f}%'.format(max_accuracy)) - return max_accuracy - - -class TrainCIFAR10(unittest.TestCase): - - def tearDown(self): - super(TrainCIFAR10, self).tearDown() - if FLAGS.tidy and os.path.isdir(FLAGS.datadir): - shutil.rmtree(FLAGS.datadir) - - def test_accurracy(self): - self.assertGreaterEqual(train_cifar(), FLAGS.target_accuracy) - - -# Run the tests. -if __name__ == '__main__': - torch.set_default_tensor_type('torch.FloatTensor') - unittest.main() diff --git a/test/test_train_imagenet.py b/test/test_train_imagenet.py deleted file mode 100644 index 6e79234af484..000000000000 --- a/test/test_train_imagenet.py +++ /dev/null @@ -1,265 +0,0 @@ -import args_parse - -SUPPORTED_MODELS = [ - 'alexnet', 'densenet121', 'densenet161', 'densenet169', 'densenet201', - 'inception_v3', 'resnet101', 'resnet152', 'resnet18', 'resnet34', - 'resnet50', 'squeezenet1_0', 'squeezenet1_1', 'vgg11', 'vgg11_bn', 'vgg13', - 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn' -] - -MODEL_OPTS = { - '--model': { - 'choices': SUPPORTED_MODELS, - 'default': 'resnet50', - }, - '--test_set_batch_size': { - 'type': int, - }, - '--lr_scheduler_type': { - 'type': str, - }, - '--lr_scheduler_divide_every_n_epochs': { - 'type': int, - }, - '--lr_scheduler_divisor': { - 'type': int, - }, -} - -FLAGS = args_parse.parse_common_options( - datadir='/tmp/imagenet', - batch_size=None, - num_epochs=None, - momentum=None, - lr=None, - target_accuracy=None, - opts=MODEL_OPTS.items(), -) - -import os -import schedulers -from statistics import mean -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -import torchvision -import torchvision.transforms as transforms -import torch_xla -import torch_xla.debug.metrics as met -import torch_xla.distributed.data_parallel as dp -import torch_xla.utils.utils as xu -import torch_xla.core.xla_model as xm -import torch_xla.test.test_utils as test_utils -import unittest - -DEFAULT_KWARGS = dict( - batch_size=128, - test_set_batch_size=64, - num_epochs=18, - momentum=0.9, - lr=0.1, - target_accuracy=0.0, -) -MODEL_SPECIFIC_DEFAULTS = { - # Override some of the args in DEFAULT_KWARGS, or add them to the dict - # if they don't exist. - 'resnet50': - dict( - DEFAULT_KWARGS, **{ - 'lr': 0.8, - 'lr_scheduler_divide_every_n_epochs': 20, - 'lr_scheduler_divisor': 5, - 'lr_scheduler_type': 'WarmupAndExponentialDecayScheduler', - }) -} - -# Set any args that were not explicitly given by the user. -default_value_dict = MODEL_SPECIFIC_DEFAULTS.get(FLAGS.model, DEFAULT_KWARGS) -for arg, value in default_value_dict.items(): - if getattr(FLAGS, arg) is None: - setattr(FLAGS, arg, value) - -MODEL_PROPERTIES = { - 'inception_v3': { - 'img_dim': 299, - 'model_fn': lambda: torchvision.models.inception_v3(aux_logits=False) - }, - 'DEFAULT': { - 'img_dim': 224, - 'model_fn': getattr(torchvision.models, FLAGS.model) - } -} - - -def get_model_property(key): - return MODEL_PROPERTIES.get(FLAGS.model, MODEL_PROPERTIES['DEFAULT'])[key] - - -def train_imagenet(): - print('==> Preparing data..') - img_dim = get_model_property('img_dim') - if FLAGS.fake_data: - train_dataset_len = 1200000 # Roughly the size of Imagenet dataset. - train_loader = xu.SampleGenerator( - data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim), - torch.zeros(FLAGS.batch_size, dtype=torch.int64)), - sample_count=train_dataset_len // FLAGS.batch_size // - xm.xrt_world_size()) - test_loader = xu.SampleGenerator( - data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim), - torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)), - sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size()) - else: - normalize = transforms.Normalize( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - train_dataset = torchvision.datasets.ImageFolder( - os.path.join(FLAGS.datadir, 'train'), - transforms.Compose([ - transforms.RandomResizedCrop(img_dim), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - normalize, - ])) - train_dataset_len = len(train_dataset.imgs) - resize_dim = max(img_dim, 256) - test_dataset = torchvision.datasets.ImageFolder( - os.path.join(FLAGS.datadir, 'val'), - # Matches Torchvision's eval transforms except Torchvision uses size - # 256 resize for all models both here and in the train loader. Their - # version crashes during training on 299x299 images, e.g. inception. - transforms.Compose([ - transforms.Resize(resize_dim), - transforms.CenterCrop(img_dim), - transforms.ToTensor(), - normalize, - ])) - - train_sampler = None - test_sampler = None - if xm.xrt_world_size() > 1: - train_sampler = torch.utils.data.distributed.DistributedSampler( - train_dataset, - num_replicas=xm.xrt_world_size(), - rank=xm.get_ordinal(), - shuffle=True) - test_sampler = torch.utils.data.distributed.DistributedSampler( - test_dataset, - num_replicas=xm.xrt_world_size(), - rank=xm.get_ordinal(), - shuffle=False) - train_loader = torch.utils.data.DataLoader( - train_dataset, - batch_size=FLAGS.batch_size, - sampler=train_sampler, - drop_last=FLAGS.drop_last, - shuffle=False if train_sampler else True, - num_workers=FLAGS.num_workers) - test_loader = torch.utils.data.DataLoader( - test_dataset, - batch_size=FLAGS.test_set_batch_size, - sampler=test_sampler, - drop_last=FLAGS.drop_last, - shuffle=False, - num_workers=FLAGS.num_workers) - - torch.manual_seed(42) - - devices = ( - xm.get_xla_supported_devices( - max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else []) - # Pass [] as device_ids to run using the PyTorch/CPU engine. - torchvision_model = get_model_property('model_fn') - model_parallel = dp.DataParallel(torchvision_model, device_ids=devices) - writer = test_utils.get_summary_writer(FLAGS.logdir) - - def train_loop_fn(model, loader, device, context): - loss_fn = nn.CrossEntropyLoss() - optimizer = context.getattr_or( - 'optimizer', lambda: optim.SGD( - model.parameters(), - lr=FLAGS.lr, - momentum=FLAGS.momentum, - weight_decay=1e-4)) - lr_scheduler = context.getattr_or( - 'lr_scheduler', lambda: schedulers.wrap_optimizer_with_scheduler( - optimizer, - scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None), - scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None), - scheduler_divide_every_n_epochs=getattr( - FLAGS, 'lr_scheduler_divide_every_n_epochs', None), - num_steps_per_epoch=num_training_steps_per_epoch, - summary_writer=writer if xm.is_master_ordinal() else None)) - tracker = xm.RateTracker() - model.train() - for x, (data, target) in enumerate(loader): - optimizer.zero_grad() - output = model(data) - loss = loss_fn(output, target) - loss.backward() - xm.optimizer_step(optimizer) - tracker.add(FLAGS.batch_size) - if x % FLAGS.log_steps == 0: - test_utils.print_training_update( - device, - x, - loss.item(), - tracker.rate(), - tracker.global_rate(), - summary_writer=writer) - if lr_scheduler: - lr_scheduler.step() - - def test_loop_fn(model, loader, device, context): - total_samples = 0 - correct = 0 - model.eval() - for data, target in loader: - output = model(data) - pred = output.max(1, keepdim=True)[1] - correct += pred.eq(target.view_as(pred)).sum().item() - total_samples += data.size()[0] - - accuracy = 100.0 * correct / total_samples - test_utils.print_test_update(device, accuracy) - return accuracy - - accuracy = 0.0 - num_devices = len( - xm.xla_replication_devices(devices)) if len(devices) > 1 else 1 - num_training_steps_per_epoch = train_dataset_len // ( - FLAGS.batch_size * num_devices) - max_accuracy = 0.0 - for epoch in range(1, FLAGS.num_epochs + 1): - model_parallel(train_loop_fn, train_loader) - accuracies = model_parallel(test_loop_fn, test_loader) - accuracy = mean(accuracies) - max_accuracy = max(accuracy, max_accuracy) - print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy)) - global_step = (epoch - 1) * num_training_steps_per_epoch - test_utils.write_to_summary( - writer, - global_step, - dict_to_write={'Accuracy/test': accuracy}, - write_xla_metrics=True) - if FLAGS.metrics_debug: - xm.master_print(met.metrics_report()) - - test_utils.close_summary_writer(writer) - print('Max Accuracy: {:.2f}%'.format(max_accuracy)) - return max_accuracy - - -class TrainImageNet(unittest.TestCase): - - def tearDown(self): - super(TrainImageNet, self).tearDown() - - def test_accurracy(self): - self.assertGreaterEqual(train_imagenet(), FLAGS.target_accuracy) - - -# Run the tests. -if __name__ == '__main__': - torch.set_default_tensor_type('torch.FloatTensor') - unittest.main() diff --git a/test/test_train_mnist.py b/test/test_train_mnist.py deleted file mode 100644 index 949e6cc77cd3..000000000000 --- a/test/test_train_mnist.py +++ /dev/null @@ -1,192 +0,0 @@ -import args_parse - -FLAGS = args_parse.parse_common_options( - datadir='/tmp/mnist-data', - batch_size=128, - momentum=0.5, - lr=0.01, - target_accuracy=98.0, - num_epochs=18) - -import os -from statistics import mean -import shutil -import time -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -from torchvision import datasets, transforms -import torch_xla -import torch_xla.distributed.data_parallel as dp -import torch_xla.debug.metrics as met -import torch_xla.utils.utils as xu -import torch_xla.core.xla_model as xm -import torch_xla.test.test_utils as test_utils -import unittest - - -class MNIST(nn.Module): - - def __init__(self): - super(MNIST, self).__init__() - self.conv1 = nn.Conv2d(1, 10, kernel_size=5) - self.bn1 = nn.BatchNorm2d(10) - self.conv2 = nn.Conv2d(10, 20, kernel_size=5) - self.bn2 = nn.BatchNorm2d(20) - self.fc1 = nn.Linear(320, 50) - self.fc2 = nn.Linear(50, 10) - - def forward(self, x): - x = F.relu(F.max_pool2d(self.conv1(x), 2)) - x = self.bn1(x) - x = F.relu(F.max_pool2d(self.conv2(x), 2)) - x = self.bn2(x) - x = torch.flatten(x, 1) - x = F.relu(self.fc1(x)) - x = self.fc2(x) - return F.log_softmax(x, dim=1) - - -def train_mnist(): - torch.manual_seed(1) - - if FLAGS.fake_data: - train_dataset_len = 60000 # Number of images in MNIST dataset. - train_loader = xu.SampleGenerator( - data=(torch.zeros(FLAGS.batch_size, 1, 28, - 28), torch.zeros(FLAGS.batch_size, - dtype=torch.int64)), - sample_count=train_dataset_len // FLAGS.batch_size // - xm.xrt_world_size()) - test_loader = xu.SampleGenerator( - data=(torch.zeros(FLAGS.batch_size, 1, 28, - 28), torch.zeros(FLAGS.batch_size, - dtype=torch.int64)), - sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size()) - else: - train_dataset = datasets.MNIST( - FLAGS.datadir, - train=True, - download=True, - transform=transforms.Compose( - [transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,))])) - train_dataset_len = len(train_dataset) - test_dataset = datasets.MNIST( - FLAGS.datadir, - train=False, - transform=transforms.Compose( - [transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,))])) - train_sampler = None - test_sampler = None - if xm.xrt_world_size() > 1: - train_sampler = torch.utils.data.distributed.DistributedSampler( - train_dataset, - num_replicas=xm.xrt_world_size(), - rank=xm.get_ordinal(), - shuffle=True) - test_sampler = torch.utils.data.distributed.DistributedSampler( - test_dataset, - num_replicas=xm.xrt_world_size(), - rank=xm.get_ordinal(), - shuffle=False) - train_loader = torch.utils.data.DataLoader( - train_dataset, - batch_size=FLAGS.batch_size, - sampler=train_sampler, - drop_last=FLAGS.drop_last, - shuffle=False if train_sampler else True, - num_workers=FLAGS.num_workers) - test_loader = torch.utils.data.DataLoader( - test_dataset, - batch_size=FLAGS.batch_size, - sampler=test_sampler, - drop_last=FLAGS.drop_last, - shuffle=False, - num_workers=FLAGS.num_workers) - - devices = ( - xm.get_xla_supported_devices( - max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else []) - # Scale learning rate to num cores - lr = FLAGS.lr * max(len(devices), 1) - # Pass [] as device_ids to run using the PyTorch/CPU engine. - model_parallel = dp.DataParallel(MNIST, device_ids=devices) - - def train_loop_fn(model, loader, device, context): - loss_fn = nn.NLLLoss() - optimizer = context.getattr_or( - 'optimizer', - lambda: optim.SGD(model.parameters(), lr=lr, momentum=FLAGS.momentum)) - tracker = xm.RateTracker() - - model.train() - for x, (data, target) in enumerate(loader): - optimizer.zero_grad() - output = model(data) - loss = loss_fn(output, target) - loss.backward() - xm.optimizer_step(optimizer) - tracker.add(FLAGS.batch_size) - if x % FLAGS.log_steps == 0: - test_utils.print_training_update(device, x, loss.item(), tracker.rate(), - tracker.global_rate()) - - def test_loop_fn(model, loader, device, context): - total_samples = 0 - correct = 0 - model.eval() - for data, target in loader: - output = model(data) - pred = output.max(1, keepdim=True)[1] - correct += pred.eq(target.view_as(pred)).sum().item() - total_samples += data.size()[0] - - accuracy = 100.0 * correct / total_samples - test_utils.print_test_update(device, accuracy) - return accuracy - - accuracy = 0.0 - writer = test_utils.get_summary_writer(FLAGS.logdir) - num_devices = len( - xm.xla_replication_devices(devices)) if len(devices) > 1 else 1 - num_training_steps_per_epoch = train_dataset_len // ( - FLAGS.batch_size * num_devices) - max_accuracy = 0.0 - for epoch in range(1, FLAGS.num_epochs + 1): - model_parallel(train_loop_fn, train_loader) - accuracies = model_parallel(test_loop_fn, test_loader) - accuracy = mean(accuracies) - max_accuracy = max(accuracy, max_accuracy) - print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy)) - global_step = (epoch - 1) * num_training_steps_per_epoch - test_utils.write_to_summary( - writer, - global_step, - dict_to_write={'Accuracy/test': accuracy}, - write_xla_metrics=True) - if FLAGS.metrics_debug: - xm.master_print(met.metrics_report()) - - test_utils.close_summary_writer(writer) - print('Max Accuracy: {:.2f}%'.format(max_accuracy)) - return max_accuracy - - -class TrainMnist(unittest.TestCase): - - def tearDown(self): - super(TrainMnist, self).tearDown() - if FLAGS.tidy and os.path.isdir(FLAGS.datadir): - shutil.rmtree(FLAGS.datadir) - - def test_accurracy(self): - self.assertGreaterEqual(train_mnist(), FLAGS.target_accuracy) - - -# Run the tests. -if __name__ == '__main__': - torch.set_default_tensor_type('torch.FloatTensor') - unittest.main()