From a7fa1e65f0b88da5c07981f8eabd95d23a04255e Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Sat, 6 Nov 2021 00:23:20 +0800
Subject: [PATCH 1/3] fix bug in amp

---
 references/classification/train.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/references/classification/train.py b/references/classification/train.py
index 220cf001d60..0d531646fb2 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -30,17 +30,16 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
     for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
         start_time = time.time()
         image, target = image.to(device), target.to(device)
-        output = model(image)
+        with torch.cuda.amp.autocast(enabled=args.amp):
+            output = model(image)
+            loss = criterion(output, target)

         optimizer.zero_grad()
         if args.amp:
-            with torch.cuda.amp.autocast():
-                loss = criterion(output, target)
             scaler.scale(loss).backward()
             scaler.step(optimizer)
             scaler.update()
         else:
-            loss = criterion(output, target)
             loss.backward()

         if args.clip_grad_norm is not None:

From 4e5f2b4a4bbddcf3208c2afbae0d446db4fb802a Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Sat, 6 Nov 2021 10:35:49 +0800
Subject: [PATCH 2/3] fix bug in training by amp

---
 references/classification/train.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/references/classification/train.py b/references/classification/train.py
index 0d531646fb2..a39a5a50f9e 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -41,11 +41,9 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
             scaler.update()
         else:
             loss.backward()
-
-        if args.clip_grad_norm is not None:
-            nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
-
-        optimizer.step()
+            if args.clip_grad_norm is not None:
+                nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
+            optimizer.step()

         if model_ema and i % args.model_ema_steps == 0:
             model_ema.update_parameters(model)

From 88cc0b405ac6c4458ced33f3a0e93a90247ed6ab Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Tue, 9 Nov 2021 10:28:51 +0800
Subject: [PATCH 3/3] support use gradient clipping when amp is enabled

---
 references/classification/train.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/references/classification/train.py b/references/classification/train.py
index a39a5a50f9e..0b855d105c9 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -37,6 +37,10 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
         optimizer.zero_grad()
         if args.amp:
             scaler.scale(loss).backward()
+            if args.clip_grad_norm is not None:
+                # we should unscale the gradients of optimizer's assigned params if do gradient clipping
+                scaler.unscale_(optimizer)
+                nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
             scaler.step(optimizer)
             scaler.update()
         else:
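
For context, the pattern these three patches converge on is the standard torch.cuda.amp recipe: run the forward pass under autocast, scale the loss before backward, unscale the gradients before clipping, and let the GradScaler drive the optimizer step. Below is a minimal, self-contained sketch of that recipe, not the reference script itself: the toy nn.Linear model, random batch, and the max_norm constant standing in for args.clip_grad_norm are illustrative assumptions, and it clips model.parameters() rather than the utils.get_optimizer_params(optimizer) helper used in train.py.

    import torch
    import torch.nn as nn

    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_amp = device == "cuda"   # stands in for args.amp
    max_norm = 1.0               # stands in for args.clip_grad_norm (hypothetical value)

    model = nn.Linear(10, 2).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    image = torch.randn(4, 10, device=device)
    target = torch.randint(0, 2, (4,), device=device)

    # forward pass and loss under autocast for both AMP and non-AMP runs (patch 1)
    with torch.cuda.amp.autocast(enabled=use_amp):
        output = model(image)
        loss = criterion(output, target)

    optimizer.zero_grad()
    if use_amp:
        scaler.scale(loss).backward()
        if max_norm is not None:
            # unscale first so the norm is computed on the true gradients (patch 3)
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        if max_norm is not None:
            nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        # optimizer.step() only on the non-AMP path; scaler.step handles the AMP path (patch 2)
        optimizer.step()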