diff --git a/references/segmentation/train.py b/references/segmentation/train.py
index 7df5adcd73f..2dbb962fe2f 100644
--- a/references/segmentation/train.py
+++ b/references/segmentation/train.py
@@ -72,19 +72,25 @@ def evaluate(model, data_loader, device, num_classes):
     return confmat
 
 
-def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq):
+def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
     header = f"Epoch: [{epoch}]"
     for image, target in metric_logger.log_every(data_loader, print_freq, header):
         image, target = image.to(device), target.to(device)
-        output = model(image)
-        loss = criterion(output, target)
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            output = model(image)
+            loss = criterion(output, target)
 
         optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
+        if scaler is not None:
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            loss.backward()
+            optimizer.step()
 
         lr_scheduler.step()
 
@@ -153,6 +159,8 @@ def main(args):
         params_to_optimize.append({"params": params, "lr": args.lr * 10})
     optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
 
+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
+
     iters_per_epoch = len(data_loader)
     main_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
         optimizer, lambda x: (1 - x / (iters_per_epoch * (args.epochs - args.lr_warmup_epochs))) ** 0.9
@@ -186,6 +194,8 @@ def main(args):
         optimizer.load_state_dict(checkpoint["optimizer"])
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
         args.start_epoch = checkpoint["epoch"] + 1
+        if args.amp:
+            scaler.load_state_dict(checkpoint["scaler"])
 
     if args.test_only:
         confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
@@ -196,7 +206,7 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq)
+        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq, scaler)
         confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
         print(confmat)
         checkpoint = {
@@ -206,6 +216,8 @@ def main(args):
             "epoch": epoch,
             "args": args,
         }
+        if args.amp:
+            checkpoint["scaler"] = scaler.state_dict()
         utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
         utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
 
@@ -269,6 +281,9 @@ def get_args_parser(add_help=True):
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
 
+    # Mixed precision training parameters
+    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
+
    return parser
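
For reference, below is a minimal standalone sketch of the pattern this diff applies: an optional torch.cuda.amp.GradScaler created only when AMP is requested, with autocast wrapping just the forward pass and loss computation. The toy model, criterion, optimizer, and tensors are placeholders for illustration only; they are not part of the torchvision script.

import torch

# Stand-ins for args.amp and the real segmentation model / loss / batch.
use_amp = True
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(16, 4).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# The scaler exists only when AMP is enabled; otherwise the plain fp32 path runs.
scaler = torch.cuda.amp.GradScaler() if (use_amp and device == "cuda") else None

image = torch.randn(8, 16, device=device)
target = torch.randint(0, 4, (8,), device=device)

# autocast(enabled=False) is a no-op, so the same block serves both paths.
with torch.cuda.amp.autocast(enabled=scaler is not None):
    output = model(image)
    loss = criterion(output, target)

optimizer.zero_grad()
if scaler is not None:
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales gradients; skips the step if infs/NaNs appear
    scaler.update()                # adjusts the scale factor for the next iteration
else:
    loss.backward()
    optimizer.step()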