From 0fee2ee6fbda14829d8530c35ca38e444262ad29 Mon Sep 17 00:00:00 2001
From: Prabhat Roy
Date: Wed, 6 Oct 2021 16:47:18 +0100
Subject: [PATCH 1/3] Updated classification reference script to use torch.cuda.amp

---
 references/classification/train.py | 40 +++++++++++-------------------
 1 file changed, 14 insertions(+), 26 deletions(-)

diff --git a/references/classification/train.py b/references/classification/train.py
index 48ab75bc2c1..282c9d24eb6 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -13,14 +13,9 @@
 import transforms
 import utils
 
-try:
-    from apex import amp
-except ImportError:
-    amp = None
-
 
 def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch,
-                    print_freq, apex=False, model_ema=None):
+                    print_freq, amp=False, model_ema=None, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter=" ")
     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
@@ -31,13 +26,16 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch,
         start_time = time.time()
         image, target = image.to(device), target.to(device)
         output = model(image)
-        loss = criterion(output, target)
 
         optimizer.zero_grad()
-        if apex:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
-                scaled_loss.backward()
+        if amp:
+            with torch.cuda.amp.autocast():
+                loss = criterion(output, target)
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
         else:
+            loss = criterion(output, target)
             loss.backward()
             optimizer.step()
 
@@ -149,10 +147,6 @@ def load_data(traindir, valdir, args):
 
 
 def main(args):
-    if args.apex and amp is None:
-        raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
-                           "to enable mixed-precision training.")
-
     if args.output_dir:
         utils.mkdir(args.output_dir)
 
@@ -205,10 +199,8 @@ def main(args):
     else:
         raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))
 
-    if args.apex:
-        model, optimizer = amp.initialize(model, optimizer,
-                                          opt_level=args.apex_opt_level
-                                          )
+    if args.amp:
+        scaler = torch.cuda.amp.GradScaler()
 
     args.lr_scheduler = args.lr_scheduler.lower()
     if args.lr_scheduler == 'steplr':
@@ -267,7 +259,8 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex, model_ema)
+        train_one_epoch(
+            model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.amp, model_ema, scaler)
         lr_scheduler.step()
         evaluate(model, criterion, data_loader_test, device=device)
         if model_ema:
@@ -357,13 +350,8 @@ def get_args_parser(add_help=True):
     parser.add_argument('--random-erase', default=0.0, type=float, help='random erasing probability (default: 0.0)')
 
     # Mixed precision training parameters
-    parser.add_argument('--apex', action='store_true',
-                        help='Use apex for mixed precision training')
-    parser.add_argument('--apex-opt-level', default='O1', type=str,
-                        help='For apex mixed precision training'
-                             'O0 for FP32 training, O1 for mixed precision training.'
-                             'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet'
-                        )
+    parser.add_argument('--amp', action='store_true',
+                        help='Use torch.cuda.amp for mixed precision training')
 
     # distributed training parameters
     parser.add_argument('--world-size', default=1, type=int,

From 5b2fdc77aa9300d182bb5524857154b69ea7464d Mon Sep 17 00:00:00 2001
From: Prabhat Roy
Date: Wed, 6 Oct 2021 16:58:57 +0100
Subject: [PATCH 2/3] Assigned scaler to None if amp is False

---
 references/classification/train.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/references/classification/train.py b/references/classification/train.py
index b754654617d..21028541429 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -221,8 +221,7 @@ def main(args):
     else:
         raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))
 
-    if args.amp:
-        scaler = torch.cuda.amp.GradScaler()
+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
 
     args.lr_scheduler = args.lr_scheduler.lower()
     if args.lr_scheduler == "steplr":

From 70a222f93b476a52ee6db63fd05029ad4a14746a Mon Sep 17 00:00:00 2001
From: Prabhat Roy
Date: Thu, 7 Oct 2021 11:34:56 +0100
Subject: [PATCH 3/3] Fixed linter errors

---
 references/classification/train.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/references/classification/train.py b/references/classification/train.py
index 21028541429..a71d337a1b4 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -13,8 +13,9 @@
 from torchvision.transforms.functional import InterpolationMode
 
 
-def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch,
-                    print_freq, amp=False, model_ema=None, scaler=None):
+def train_one_epoch(
+    model, criterion, optimizer, data_loader, device, epoch, print_freq, amp=False, model_ema=None, scaler=None
+):
     model.train()
     metric_logger = utils.MetricLogger(delimiter=" ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
@@ -285,7 +286,8 @@ def main(args):
         if args.distributed:
             train_sampler.set_epoch(epoch)
         train_one_epoch(
-            model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.amp, model_ema, scaler)
+            model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.amp, model_ema, scaler
+        )
         lr_scheduler.step()
         evaluate(model, criterion, data_loader_test, device=device)
         if model_ema:
@@ -378,8 +380,7 @@ def get_args_parser(add_help=True):
     parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")
 
     # Mixed precision training parameters
-    parser.add_argument("--amp", action="store_true",
-                        help="Use torch.cuda.amp for mixed precision training")
+    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
 
     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
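
Note: below is a minimal, self-contained sketch of the torch.cuda.amp pattern these patches adopt, i.e. an autocast context around the loss computation and a GradScaler driving the scaled backward/step/update, with the scaler created only when AMP is requested (as in PATCH 2/3). It is not part of the patches: the model, criterion, optimizer, and batch are toy placeholders rather than the reference script's, and, unlike the diff above, the forward pass here also runs under autocast, which is the more common torch.cuda.amp usage.

import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # torch.cuda.amp only applies on CUDA devices

# Toy stand-ins for the reference script's model, criterion, and optimizer.
model = nn.Linear(10, 2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# As in PATCH 2/3: the scaler only exists when AMP is enabled.
scaler = torch.cuda.amp.GradScaler() if use_amp else None

# One synthetic batch in place of the ImageNet data loader.
image = torch.randn(8, 10, device=device)
target = torch.randint(0, 2, (8,), device=device)

optimizer.zero_grad()
if use_amp:
    # Forward and loss in mixed precision, then a scaled backward pass and
    # optimizer step driven by the GradScaler.
    with torch.cuda.amp.autocast():
        output = model(image)
        loss = criterion(output, target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
else:
    output = model(image)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()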