From ba2f991020480d49ad34369c98f2018af59a73d8 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Wed, 10 Nov 2021 09:01:53 +0000
Subject: [PATCH] Simplify the gradient clipping code.

---
 references/classification/train.py | 4 ++--
 references/classification/utils.py | 8 --------
 2 files changed, 2 insertions(+), 10 deletions(-)

diff --git a/references/classification/train.py b/references/classification/train.py
index 0b855d105c9..8fcc9d132dd 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -40,13 +40,13 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
             if args.clip_grad_norm is not None:
                 # we should unscale the gradients of optimizer's assigned params if do gradient clipping
                 scaler.unscale_(optimizer)
-                nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
+                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
             scaler.step(optimizer)
             scaler.update()
         else:
             loss.backward()
             if args.clip_grad_norm is not None:
-                nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
+                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
             optimizer.step()
 
         if model_ema and i % args.model_ema_steps == 0:
diff --git a/references/classification/utils.py b/references/classification/utils.py
index ac09bd69d86..473684fe162 100644
--- a/references/classification/utils.py
+++ b/references/classification/utils.py
@@ -409,11 +409,3 @@ def reduce_across_processes(val):
     dist.barrier()
     dist.all_reduce(t)
     return t
-
-
-def get_optimizer_params(optimizer):
-    """Generator to iterate over all parameters in the optimizer param_groups."""
-
-    for group in optimizer.param_groups:
-        for p in group["params"]:
-            yield p
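
Note (not part of the patch): a minimal sketch of the training-step pattern this change simplifies. With torch.cuda.amp.GradScaler, gradients must be unscaled via scaler.unscale_(optimizer) before nn.utils.clip_grad_norm_ so the clipping threshold applies to the true gradient magnitudes; clipping over model.parameters() then covers the same tensors the removed get_optimizer_params helper iterated over, assuming the optimizer was built from model.parameters(). The names training_step, images, targets and max_norm below are illustrative placeholders, not identifiers from the reference scripts.

import torch
import torch.nn as nn

def training_step(model, criterion, optimizer, images, targets, max_norm=None, scaler=None):
    optimizer.zero_grad()
    if scaler is not None:
        # Mixed-precision path: scale the loss, then unscale before clipping.
        with torch.cuda.amp.autocast():
            loss = criterion(model(images), targets)
        scaler.scale(loss).backward()
        if max_norm is not None:
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        scaler.step(optimizer)
        scaler.update()
    else:
        # Full-precision path: clip directly on the raw gradients.
        loss = criterion(model(images), targets)
        loss.backward()
        if max_norm is not None:
            nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()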