From 860a8470cccdc2977f0af6e206f9fac6b37c9f8d Mon Sep 17 00:00:00 2001
From: Adam Paszke <adam.paszke@gmail.com>
Date: Thu, 20 Jul 2017 11:35:58 -0700
Subject: [PATCH] Add an option to perform distributed ImageNet training

---
 imagenet/main.py | 48 ++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 38 insertions(+), 10 deletions(-)

diff --git a/imagenet/main.py b/imagenet/main.py
index 2a1540ba13..310dc21e99 100644
--- a/imagenet/main.py
+++ b/imagenet/main.py
@@ -7,8 +7,10 @@
 import torch.nn as nn
 import torch.nn.parallel
 import torch.backends.cudnn as cudnn
+import torch.distributed as dist
 import torch.optim
 import torch.utils.data
+import torch.utils.data.distributed
 import torchvision.transforms as transforms
 import torchvision.datasets as datasets
 import torchvision.models as models
@@ -49,6 +51,12 @@
                     help='evaluate model on validation set')
 parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                     help='use pre-trained model')
+parser.add_argument('--world-size', default=1, type=int,
+                    help='number of distributed processes')
+parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
+                    help='url used to set up distributed training')
+parser.add_argument('--dist-backend', default='gloo', type=str,
+                    help='distributed backend')
 
 best_prec1 = 0
 
@@ -57,6 +65,12 @@ def main():
     global args, best_prec1
     args = parser.parse_args()
 
+    args.distributed = args.world_size > 1
+
+    if args.distributed:
+        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                world_size=args.world_size)
+
     # create model
     if args.pretrained:
         print("=> using pre-trained model '{}'".format(args.arch))
@@ -65,11 +79,15 @@
         print("=> creating model '{}'".format(args.arch))
         model = models.__dict__[args.arch]()
 
-    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
-        model.features = torch.nn.DataParallel(model.features)
-        model.cuda()
+    if not args.distributed:
+        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
+            model.features = torch.nn.DataParallel(model.features)
+            model.cuda()
+        else:
+            model = torch.nn.DataParallel(model).cuda()
     else:
-        model = torch.nn.DataParallel(model).cuda()
+        model.cuda()
+        model = torch.nn.parallel.DistributedDataParallel(model)
 
     # define loss function (criterion) and optimizer
     criterion = nn.CrossEntropyLoss().cuda()
@@ -78,7 +96,7 @@
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
 
-    # optionally resume from a checkpoint 
+    # optionally resume from a checkpoint
     if args.resume:
         if os.path.isfile(args.resume):
             print("=> loading checkpoint '{}'".format(args.resume))
@@ -100,15 +118,23 @@
     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])
 
-    train_loader = torch.utils.data.DataLoader(
-        datasets.ImageFolder(traindir, transforms.Compose([
+    train_dataset = datasets.ImageFolder(
+        traindir,
+        transforms.Compose([
             transforms.RandomSizedCrop(224),
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(),
             normalize,
-        ])),
-        batch_size=args.batch_size, shuffle=True,
-        num_workers=args.workers, pin_memory=True)
+        ]))
+
+    if args.distributed:
+        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
+    else:
+        train_sampler = None
+
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
+        num_workers=args.workers, pin_memory=True, sampler=train_sampler)
 
     val_loader = torch.utils.data.DataLoader(
         datasets.ImageFolder(valdir, transforms.Compose([
@@ -125,6 +151,8 @@
         return
 
     for epoch in range(args.start_epoch, args.epochs):
+        if args.distributed:
+            train_sampler.set_epoch(epoch)
         adjust_learning_rate(optimizer, epoch)
 
         # train for one epoch