diff --git a/references/classification/train.py b/references/classification/train.py
index 38ac592237a..a71d337a1b4 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -12,13 +12,10 @@
 from torch.utils.data.dataloader import default_collate
 from torchvision.transforms.functional import InterpolationMode
 
-try:
-    from apex import amp
-except ImportError:
-    amp = None
-
 
-def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False, model_ema=None):
+def train_one_epoch(
+    model, criterion, optimizer, data_loader, device, epoch, print_freq, amp=False, model_ema=None, scaler=None
+):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
@@ -29,13 +26,16 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, pri
         start_time = time.time()
         image, target = image.to(device), target.to(device)
         output = model(image)
-        loss = criterion(output, target)
 
         optimizer.zero_grad()
-        if apex:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
-                scaled_loss.backward()
+        if amp:
+            with torch.cuda.amp.autocast():
+                loss = criterion(output, target)
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
         else:
+            loss = criterion(output, target)
             loss.backward()
             optimizer.step()
 
@@ -156,12 +156,6 @@ def load_data(traindir, valdir, args):
 
 
 def main(args):
-    if args.apex and amp is None:
-        raise RuntimeError(
-            "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
-            "to enable mixed-precision training."
-        )
-
     if args.output_dir:
         utils.mkdir(args.output_dir)
 
@@ -228,8 +222,7 @@ def main(args):
     else:
         raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))
 
-    if args.apex:
-        model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level)
+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
 
     args.lr_scheduler = args.lr_scheduler.lower()
     if args.lr_scheduler == "steplr":
@@ -292,7 +285,9 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex, model_ema)
+        train_one_epoch(
+            model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.amp, model_ema, scaler
+        )
         lr_scheduler.step()
         evaluate(model, criterion, data_loader_test, device=device)
         if model_ema:
@@ -385,15 +380,7 @@ def get_args_parser(add_help=True):
     parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")
 
     # Mixed precision training parameters
-    parser.add_argument("--apex", action="store_true", help="Use apex for mixed precision training")
-    parser.add_argument(
-        "--apex-opt-level",
-        default="O1",
-        type=str,
-        help="For apex mixed precision training"
-        "O0 for FP32 training, O1 for mixed precision training."
-        "For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet",
-    )
+    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
 
     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
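
For reference, below is a minimal, self-contained sketch of the autocast + GradScaler pattern the diff switches to. The toy model, optimizer, and data are placeholders for illustration only; they are not names from this repository, and the loop mirrors the structure of the rewritten `train_one_epoch` rather than reproducing it.

```python
# Sketch of native torch.cuda.amp mixed-precision training (toy example, not repo code).
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # autocast/GradScaler act as pass-throughs when disabled

model = torch.nn.Linear(16, 2).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for _ in range(10):
    image = torch.randn(8, 16, device=device)
    target = torch.randint(0, 2, (8,), device=device)

    optimizer.zero_grad()
    # Forward pass and loss run under autocast; the backward pass stays outside it,
    # and the scaler rescales gradients before stepping the optimizer.
    with torch.cuda.amp.autocast(enabled=use_amp):
        output = model(image)
        loss = criterion(output, target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```

Unlike apex's `amp.initialize`, this requires no model/optimizer rewrapping, so the scaler can simply be created in `main` when `--amp` is set and passed into the training loop.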