From 3b916e874326cbc36719680dd96317e61a70dfc9 Mon Sep 17 00:00:00 2001
From: Kaiyu Shi
Date: Tue, 10 Jul 2018 19:48:40 +0800
Subject: [PATCH] Grad clip for parameters on different devices

---
 torch/nn/utils/clip_grad.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torch/nn/utils/clip_grad.py b/torch/nn/utils/clip_grad.py
index a81415d75d3f9..db808adcf70b2 100644
--- a/torch/nn/utils/clip_grad.py
+++ b/torch/nn/utils/clip_grad.py
@@ -29,12 +29,12 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2):
         total_norm = 0
         for p in parameters:
             param_norm = p.grad.data.norm(norm_type)
-            total_norm += param_norm ** norm_type
+            total_norm += param_norm.item() ** norm_type
         total_norm = total_norm ** (1. / norm_type)
     clip_coef = max_norm / (total_norm + 1e-6)
     if clip_coef < 1:
         for p in parameters:
-            p.grad.data.mul_(clip_coef.item())
+            p.grad.data.mul_(clip_coef)
     return total_norm
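
Note on the change: the sketch below (not part of the patch) illustrates the
patched logic as a standalone helper so it can be tried outside the PyTorch
source tree. The name clip_grad_norm_sketch and the toy parameters are
illustrative only; the real API is torch.nn.utils.clip_grad_norm_. The idea is
that accumulating per-parameter norms via .item() keeps total_norm a plain
Python float, so gradients living on different devices are never added
together as tensors, and clip_coef ends up as a float that can be multiplied
into a gradient on any device without a further .item() call.

    # Minimal sketch of the patched accumulation, under the assumptions above.
    import torch

    def clip_grad_norm_sketch(parameters, max_norm, norm_type=2):
        parameters = [p for p in parameters if p.grad is not None]
        total_norm = 0.0
        for p in parameters:
            param_norm = p.grad.data.norm(norm_type)
            # .item() converts the per-parameter norm to a Python float, so the
            # sum below never mixes tensors from different devices.
            total_norm += param_norm.item() ** norm_type
        total_norm = total_norm ** (1. / norm_type)
        clip_coef = max_norm / (total_norm + 1e-6)
        if clip_coef < 1:
            for p in parameters:
                # clip_coef is already a float, so the in-place multiply works
                # for a gradient stored on any device.
                p.grad.data.mul_(clip_coef)
        return total_norm

    # Toy usage: two parameters that, in a model-parallel setup, could sit on
    # different devices (e.g. the second moved to "cuda:0" when available).
    a = torch.nn.Parameter(torch.randn(3))
    b = torch.nn.Parameter(torch.randn(3))
    (a.sum() + b.sum()).backward()
    print(clip_grad_norm_sketch([a, b], max_norm=1.0))

One trade-off worth noting: calling .item() on every parameter norm forces a
device-to-host synchronization per parameter, which is the price paid here for
device-agnostic accumulation.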