From 508d01652c0d438ba037bfdd679d892ed6f7c0e1 Mon Sep 17 00:00:00 2001
From: Konstantinos Bozas
Date: Thu, 2 Dec 2021 12:50:32 +0000
Subject: [PATCH 1/3] support amp training for video classification models

---
 references/video_classification/train.py | 53 +++++++++---------------
 1 file changed, 20 insertions(+), 33 deletions(-)

diff --git a/references/video_classification/train.py b/references/video_classification/train.py
index 1f363f57dad..773a0bfbab2 100644
--- a/references/video_classification/train.py
+++ b/references/video_classification/train.py
@@ -12,19 +12,13 @@
 from torch.utils.data.dataloader import default_collate
 from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
 
-try:
-    from apex import amp
-except ImportError:
-    amp = None
-
-
 try:
     from torchvision.prototype import models as PM
 except ImportError:
     PM = None
 
 
-def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, apex=False):
+def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
@@ -34,16 +28,19 @@ def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, devi
     for video, target in metric_logger.log_every(data_loader, print_freq, header):
         start_time = time.time()
         video, target = video.to(device), target.to(device)
-        output = model(video)
-        loss = criterion(output, target)
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            output = model(video)
+            loss = criterion(output, target)
 
         optimizer.zero_grad()
-        if apex:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
-                scaled_loss.backward()
+
+        if scaler is not None:
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
         else:
             loss.backward()
-        optimizer.step()
+            optimizer.step()
 
         acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
         batch_size = video.shape[0]
@@ -101,11 +98,6 @@ def collate_fn(batch):
 def main(args):
     if args.weights and PM is None:
         raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
-    if args.apex and amp is None:
-        raise RuntimeError(
-            "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
-            "to enable mixed-precision training."
-        )
 
     if args.output_dir:
         utils.mkdir(args.output_dir)
@@ -224,9 +216,7 @@ def main(args):
 
     lr = args.lr * args.world_size
     optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay)
-
-    if args.apex:
-        model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level)
+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
 
     # convert scheduler to be per iteration, not per epoch, for warmup that lasts
     # between different epochs
@@ -267,6 +257,8 @@ def main(args):
         optimizer.load_state_dict(checkpoint["optimizer"])
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
         args.start_epoch = checkpoint["epoch"] + 1
+        if args.amp:
+            scaler.load_state_dict(checkpoint["scaler"])
 
     if args.test_only:
         evaluate(model, criterion, data_loader_test, device=device)
@@ -278,7 +270,7 @@ def main(args):
         if args.distributed:
             train_sampler.set_epoch(epoch)
         train_one_epoch(
-            model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, args.apex
+            model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, args.amp
         )
         evaluate(model, criterion, data_loader_test, device=device)
         if args.output_dir:
@@ -289,6 +281,8 @@ def main(args):
                 "epoch": epoch,
                 "args": args,
             }
+            if args.amp:
+                checkpoint["scaler"] = scaler.state_dict()
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
 
@@ -363,17 +357,6 @@ def parse_args():
         action="store_true",
     )
 
-    # Mixed precision training parameters
-    parser.add_argument("--apex", action="store_true", help="Use apex for mixed precision training")
-    parser.add_argument(
-        "--apex-opt-level",
-        default="O1",
-        type=str,
-        help="For apex mixed precision training"
-        "O0 for FP32 training, O1 for mixed precision training."
- "For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet", - ) - # distributed training parameters parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") @@ -381,6 +364,9 @@ def parse_args(): # Prototype models only parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") + # Mixed precision training parameters + parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training") + args = parser.parse_args() return args @@ -389,3 +375,4 @@ def parse_args(): if __name__ == "__main__": args = parse_args() main(args) + From e4d49dd3d46d5ecf3cfc75f8d48c47d5d990fbab Mon Sep 17 00:00:00 2001 From: Konstantinos Bozas Date: Thu, 2 Dec 2021 14:36:39 +0000 Subject: [PATCH 2/3] Removed extra empty line and used scaler instead of args.amp as function argument --- references/video_classification/train.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/references/video_classification/train.py b/references/video_classification/train.py index 773a0bfbab2..665aedebad8 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -270,7 +270,7 @@ def main(args): if args.distributed: train_sampler.set_epoch(epoch) train_one_epoch( - model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, args.amp + model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, scaler ) evaluate(model, criterion, data_loader_test, device=device) if args.output_dir: @@ -375,4 +375,3 @@ def parse_args(): if __name__ == "__main__": args = parse_args() main(args) - From 4541dfad3b7803a146c36e58e1609f190cf0d8f0 Mon Sep 17 00:00:00 2001 From: Konstantinos Bozas Date: Thu, 2 Dec 2021 15:30:41 +0000 Subject: [PATCH 3/3] apply formating to pass lint tests --- references/video_classification/train.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/references/video_classification/train.py b/references/video_classification/train.py index 665aedebad8..0cd88e8022f 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -269,9 +269,7 @@ def main(args): for epoch in range(args.start_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) - train_one_epoch( - model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, scaler - ) + train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, scaler) evaluate(model, criterion, data_loader_test, device=device) if args.output_dir: checkpoint = {