From b29abad616a46dc2562079914d038f2a1339be1f Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Fri, 26 Nov 2021 14:16:14 +0800
Subject: [PATCH 1/2] support amp training for segmentation models

---
 references/segmentation/train.py | 27 +++++++++++++++++++++------
 1 file changed, 21 insertions(+), 6 deletions(-)

diff --git a/references/segmentation/train.py b/references/segmentation/train.py
index 7df5adcd73f..f3f99e3f911 100644
--- a/references/segmentation/train.py
+++ b/references/segmentation/train.py
@@ -72,19 +72,25 @@ def evaluate(model, data_loader, device, num_classes):
     return confmat
 
 
-def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq):
+def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
     header = f"Epoch: [{epoch}]"
     for image, target in metric_logger.log_every(data_loader, print_freq, header):
         image, target = image.to(device), target.to(device)
-        output = model(image)
-        loss = criterion(output, target)
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            output = model(image)
+            loss = criterion(output, target)
 
         optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
+        if scaler is not None:
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            loss.backward()
+            optimizer.step()
 
         lr_scheduler.step()
 
@@ -152,6 +158,8 @@ def main(args):
         params = [p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad]
         params_to_optimize.append({"params": params, "lr": args.lr * 10})
     optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+    
+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
 
     iters_per_epoch = len(data_loader)
     main_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
@@ -186,6 +194,8 @@ def main(args):
         optimizer.load_state_dict(checkpoint["optimizer"])
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
         args.start_epoch = checkpoint["epoch"] + 1
+        if args.amp:
+            scaler.load_state_dict(checkpoint["scaler"])
 
     if args.test_only:
         confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
@@ -196,7 +206,7 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq)
+        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq, scaler)
         confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
         print(confmat)
         checkpoint = {
@@ -206,6 +216,8 @@ def main(args):
             "epoch": epoch,
             "args": args,
         }
+        if args.amp:
+            checkpoint["scaler"] = scaler.state_dict()
         utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
         utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
 
@@ -269,6 +281,9 @@ def get_args_parser(add_help=True):
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
 
+    # Mixed precision training parameters
+    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
+
     return parser
 

From 9a500e25ac213578fe23c4a8ed9704376893d32c Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Fri, 26 Nov 2021 14:21:02 +0800
Subject: [PATCH 2/2] fix lint

---
 references/segmentation/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/references/segmentation/train.py b/references/segmentation/train.py
index f3f99e3f911..2dbb962fe2f 100644
--- a/references/segmentation/train.py
+++ b/references/segmentation/train.py
@@ -158,7 +158,7 @@ def main(args):
         params = [p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad]
         params_to_optimize.append({"params": params, "lr": args.lr * 10})
     optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
-    
+
     scaler = torch.cuda.amp.GradScaler() if args.amp else None
 
     iters_per_epoch = len(data_loader)
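
Note: the loop introduced by PATCH 1/2 follows the standard torch.cuda.amp recipe: run the forward pass and loss under autocast, and route the backward pass and optimizer step through a GradScaler so small fp16 gradients do not underflow. Below is a minimal, self-contained sketch of the same pattern for reference; it is illustrative only, and the train_step and use_amp names are placeholders rather than part of the patch.

import torch

def train_step(model, criterion, optimizer, image, target, scaler=None):
    # Run the forward pass and loss in mixed precision only when a GradScaler is provided.
    with torch.cuda.amp.autocast(enabled=scaler is not None):
        output = model(image)
        loss = criterion(output, target)

    optimizer.zero_grad()
    if scaler is not None:
        # Scale the loss before backward so small fp16 gradients stay representable;
        # scaler.step() unscales gradients and skips the update if inf/NaN values appear.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        optimizer.step()
    return loss.item()

# use_amp here mirrors the --amp flag added by the patch (hypothetical variable name).
# scaler = torch.cuda.amp.GradScaler() if use_amp else None

With the patch applied, mixed precision is enabled by passing --amp to references/segmentation/train.py; the GradScaler state is also saved in, and restored from, the training checkpoint so resumed runs keep the current loss scale.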