Foreach gradient clipping #91846
Changes from all commits: 78e98f9, 1f4d702, 29ce4e4, 88703fc, 2e6b473, bcf4315, 30ef374, de6f9d0, f23c708, 07c6e9b, 9b89a86, 7e948fe, c3e6f97, c6f0e94, a00a89b, dc37718, 71fddce, 9477b07, d00dafa
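Judging from the test changes below, the PR threads a `foreach` keyword through `torch.nn.utils.clip_grad_norm_` and `clip_grad_value_` so callers can opt into fused multi-tensor kernels instead of the per-parameter loop. A minimal usage sketch under that assumption (`model` is a placeholder; the default value of `foreach` is not shown in this diff):

```python
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_, clip_grad_value_

model = nn.Linear(10, 10)                      # placeholder model
model(torch.randn(4, 10)).sum().backward()     # populate .grad on every parameter

# foreach=True requests the fused multi-tensor path the new tests exercise;
# foreach=False keeps the original per-tensor implementation.
total_norm = clip_grad_norm_(model.parameters(), max_norm=2.0, foreach=True)
clip_grad_value_(model.parameters(), clip_value=2.5, foreach=True)
```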
One changed file updates a pinned commit hash:

@@ -1 +1 @@
-3ab8494305810d3c943f670bc6b028514942c7a0
+eac4e547138ab22a9b41c6f96208613fd7dd19d5
The remaining hunks modify the tests. First, SkipTest is imported so the device-generic tests can bail out on unsupported backends:

@@ -13,6 +13,7 @@
 from functools import partial
 from collections import OrderedDict
 from tempfile import NamedTemporaryFile
+from unittest import SkipTest

 import torch

@@ -1654,85 +1655,6 @@ def assign_weight():
         # This should work though
         l2.weight = Parameter(torch.randn(10, 10))

-    def test_clip_grad_norm(self):
-        l = nn.Linear(10, 10)
-        max_norm = 2
-
-        def compute_norm(norm_type):
-            norm_type = float(norm_type)
-            if norm_type != inf:
-                total_norm = 0
-                for p in l.parameters():
-                    total_norm += p.grad.data.abs().pow(norm_type).sum()
-                return pow(total_norm, 1. / norm_type)
-            else:
-                return max(p.grad.data.abs().max() for p in l.parameters())
-
-        def compare_scaling(grads):
-            p_scale = [p.grad.data.div(g).view(-1) for p, g in zip(l.parameters(), grads)]
-            scale = torch.cat(p_scale)
-            self.assertEqual(scale.std(), 0)
-            return scale[0]
-
-        grads = torch.arange(1., 101).view(10, 10), torch.ones(10).div(1000)
-        for norm_type in [0.5, 1.5, 2, 4, 'inf']:
-            for p, g in zip(l.parameters(), grads):
-                p._grad = g.clone().view_as(p.data)
-            norm_before = compute_norm(norm_type)
-            norm = clip_grad_norm_(l.parameters(), max_norm, norm_type=norm_type)
-            norm_after = compute_norm(norm_type)
-            self.assertEqual(norm, norm_before)
-            self.assertEqual(norm_after, max_norm)
-            self.assertLessEqual(norm_after, norm_before)
-            compare_scaling(grads)
-
-        # Small gradients should be left unchanged
-        grads = torch.rand(10, 10).div(10000), torch.ones(10).div(500)
-        for norm_type in [0.5, 1.5, 2, 4, 'inf']:
-            for p, g in zip(l.parameters(), grads):
-                p.grad.data.copy_(g)
-            norm_before = compute_norm(norm_type)
-            norm = clip_grad_norm_(l.parameters(), max_norm, norm_type=norm_type)
-            norm_after = compute_norm(norm_type)
-            self.assertEqual(norm, norm_before)
-            self.assertEqual(norm_before, norm_after)
-            self.assertLessEqual(norm_after, max_norm)
-            scale = compare_scaling(grads)
-            self.assertEqual(scale, 1)
-
-        # Should accept a single Tensor as input
-        p1, p2 = torch.randn(10, 10), torch.randn(10, 10)
-        g = torch.arange(1., 101).view(10, 10)
-        p1._grad = g.clone()
-        p2._grad = g.clone()
-        for norm_type in [0.5, 1.5, 2, 4, 'inf']:
-            clip_grad_norm_(p1, max_norm, norm_type=norm_type)
-            clip_grad_norm_([p2], max_norm, norm_type=norm_type)
-            self.assertEqual(p1.grad, p2.grad)
-
-    def test_clip_grad_value(self):
-        l = nn.Linear(10, 10)
-        clip_value = 2.5
-
-        grad_w, grad_b = torch.arange(-50., 50).view(10, 10).div_(5), torch.ones(10).mul_(2)
-        for grad_list in [[grad_w, grad_b], [grad_w, None]]:
-            for p, g in zip(l.parameters(), grad_list):
-                p._grad = g.clone().view_as(p.data) if g is not None else g
-
-            clip_grad_value_(l.parameters(), clip_value)
-            for p in filter(lambda p: p.grad is not None, l.parameters()):
-                self.assertLessEqual(p.grad.data.max(), clip_value)
-                self.assertGreaterEqual(p.grad.data.min(), -clip_value)
-
-        # Should accept a single Tensor as input
-        p1, p2 = torch.randn(10, 10), torch.randn(10, 10)
-        g = torch.arange(-50., 50).view(10, 10).div_(5)
-        p1._grad = g.clone()
-        p2._grad = g.clone()
-        clip_grad_value_(p1, clip_value)
-        clip_grad_value_([p2], clip_value)
-        self.assertEqual(p1.grad, p2.grad)
-
     def test_parameters_to_vector(self):
         conv1 = nn.Conv2d(3, 10, 5)
         fc1 = nn.Linear(10, 20)

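For reference, the `compute_norm` helper in the removed block (re-added below in device-generic form) evaluates the same quantity that `clip_grad_norm_` returns: the `norm_type`-norm of all gradient elements taken together, with `'inf'` meaning the largest absolute entry. A standalone restatement of that check (a sketch, not code from the PR; `total_grad_norm` is a hypothetical helper name):

```python
import torch

def total_grad_norm(parameters, norm_type=2.0):
    # Same value as the tests' compute_norm(): view all gradients as one flat
    # vector and take its norm; the inf-norm is the max absolute element.
    grads = [p.grad.detach().flatten() for p in parameters if p.grad is not None]
    if norm_type == float('inf'):
        return torch.stack([g.abs().max() for g in grads]).max()
    return torch.cat(grads).norm(norm_type)
```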
@@ -11473,7 +11395,8 @@ def run_test_case(norm_type, error_if_nonfinite, scalar, grad_only_one_elem, pre

     @onlyCUDA
     @deviceCountAtLeast(2)
-    def test_clip_grad_norm_multi_device(self, devices):
+    @parametrize_test('foreach', (False, True))
+    def test_clip_grad_norm_multi_device(self, devices, foreach):
         class TestModel(nn.Module):
             def __init__(self):
                 super(TestModel, self).__init__()

@@ -11489,8 +11412,8 @@ def __init__(self):
                 p.grad = torch.ones_like(p)
             for p in ref_model.parameters():
                 p.grad = torch.ones_like(p)
-            norm = clip_grad_norm_(test_model.parameters(), 0.5, norm_type=norm_type)
-            expected = clip_grad_norm_(ref_model.parameters(), 0.5, norm_type=norm_type)
+            norm = clip_grad_norm_(test_model.parameters(), 0.5, norm_type=norm_type, foreach=foreach)
+            expected = clip_grad_norm_(ref_model.parameters(), 0.5, norm_type=norm_type, foreach=foreach)
             self.assertEqual(norm, expected)
             for p, pe in zip(test_model.parameters(), ref_model.parameters()):
                 self.assertEqual(p.grad.to(devices[0]), pe.grad)

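The multi-device test above now runs for both `foreach` settings and checks them against a single-device reference. The same kind of agreement check can be written directly, clipping identical gradients with each code path and comparing the outputs (an illustrative sketch, not part of this PR; it assumes the `foreach` keyword shown in the diff):

```python
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

def check_foreach_matches_loop(norm_type=2.0, max_norm=0.5):
    a, b = nn.Linear(10, 10), nn.Linear(10, 10)
    b.load_state_dict(a.state_dict())            # identical weights; identical grads below
    for p, q in zip(a.parameters(), b.parameters()):
        p.grad = torch.ones_like(p)
        q.grad = torch.ones_like(q)
    n_loop = clip_grad_norm_(a.parameters(), max_norm, norm_type=norm_type, foreach=False)
    n_fused = clip_grad_norm_(b.parameters(), max_norm, norm_type=norm_type, foreach=True)
    torch.testing.assert_close(n_loop, n_fused)               # same reported total norm
    for p, q in zip(a.parameters(), b.parameters()):
        torch.testing.assert_close(p.grad, q.grad)            # same clipped gradients
```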
@@ -12042,6 +11965,91 @@ def perm_fn(x):
            with cm:
                _test(activation=activation, batch_first=batch_first, training=training)

+    @parametrize_test('foreach', (False, True))
+    def test_clip_grad_value(self, foreach, device):
+        if torch.device(device).type == 'xla' and foreach:
+            raise SkipTest('foreach not supported on XLA')
+
+        l = nn.Linear(10, 10).to(device)
+        clip_value = 2.5
+
+        grad_w, grad_b = torch.arange(-50., 50, device=device).view(10, 10).div_(5), torch.ones(10, device=device).mul_(2)
+        for grad_list in [[grad_w, grad_b], [grad_w, None]]:
+            for p, g in zip(l.parameters(), grad_list):
+                p._grad = g.clone().view_as(p.data) if g is not None else g
+
+            clip_grad_value_(l.parameters(), clip_value, foreach=foreach)
+            for p in filter(lambda p: p.grad is not None, l.parameters()):
+                self.assertLessEqual(p.grad.data.max(), clip_value)
+                self.assertGreaterEqual(p.grad.data.min(), -clip_value)
+
+        # Should accept a single Tensor as input
+        p1, p2 = torch.randn(10, 10, device=device), torch.randn(10, 10, device=device)
+        g = torch.arange(-50., 50, device=device).view(10, 10).div_(5)
+        p1._grad = g.clone()
+        p2._grad = g.clone()
+        clip_grad_value_(p1, clip_value, foreach=foreach)
+        clip_grad_value_([p2], clip_value, foreach=foreach)
+        self.assertEqual(p1.grad, p2.grad)
+
+    @parametrize_test('foreach', (False, True))
+    @parametrize_test('norm_type', (0.5, 1.5, 2, 4, 'inf'))
+    def test_clip_grad_norm(self, norm_type, foreach, device):
+        if torch.device(device).type == 'xla' and foreach:
+            raise SkipTest('foreach not supported on XLA')
+
+        l = nn.Linear(10, 10).to(device)
+        max_norm = 2
+
+        def compute_norm(norm_type):
+            norm_type = float(norm_type)
+            if norm_type != inf:
+                total_norm = 0
+                for p in l.parameters():
+                    total_norm += p.grad.data.abs().pow(norm_type).sum()
+                return pow(total_norm, 1. / norm_type)
+            else:
+                return max(p.grad.data.abs().max() for p in l.parameters())
+
+        def compare_scaling(grads):
+            p_scale = [p.grad.data.div(g).view(-1) for p, g in zip(l.parameters(), grads)]
+            scale = torch.cat(p_scale)
+            self.assertEqual(scale.std(), 0)
+            return scale[0]
+
+        grads = torch.arange(1., 101, device=device).view(10, 10), torch.ones(10, device=device).div(1000)
+        for p, g in zip(l.parameters(), grads):
+            p._grad = g.clone().view_as(p.data)
+        norm_before = compute_norm(norm_type)
+        norm = clip_grad_norm_(l.parameters(), max_norm, norm_type=norm_type, foreach=foreach)
+        norm_after = compute_norm(norm_type)
+        self.assertEqual(norm, norm_before)
+        self.assertEqual(norm_after, max_norm)
+        self.assertLessEqual(norm_after, norm_before)
+        compare_scaling(grads)
+
+        # Small gradients should be left unchanged
+        grads = torch.rand(10, 10, device=device).div(10000), torch.ones(10, device=device).div(500)
+        for p, g in zip(l.parameters(), grads):
+            p.grad.data.copy_(g)
+        norm_before = compute_norm(norm_type)
+        norm = clip_grad_norm_(l.parameters(), max_norm, norm_type=norm_type, foreach=foreach)
+        norm_after = compute_norm(norm_type)
+        self.assertEqual(norm, norm_before)
+        self.assertEqual(norm_before, norm_after)
+        self.assertLessEqual(norm_after, max_norm)
+        scale = compare_scaling(grads)
+        self.assertEqual(scale, 1)
+
+        # Should accept a single Tensor as input
+        p1, p2 = torch.randn(10, 10, device=device), torch.randn(10, 10, device=device)
+        g = torch.arange(1., 101, device=device).view(10, 10)
+        p1._grad = g.clone()
+        p2._grad = g.clone()
+        clip_grad_norm_(p1, max_norm, norm_type=norm_type, foreach=foreach)
+        clip_grad_norm_([p2], max_norm, norm_type=norm_type, foreach=foreach)
+        self.assertEqual(p1.grad, p2.grad)
+
+
 class TestFunctionalPickle(TestCase):
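For context on what `foreach=True` buys: rather than launching one norm/scale kernel per gradient, a foreach-style implementation hands the whole gradient list to multi-tensor ops. The sketch below is only an illustration of that idea using the private `torch._foreach_norm` / `torch._foreach_mul_` helpers; it is not the code added by this PR, assumes all gradients share one device and dtype, and covers only the L2 case:

```python
import torch

def clip_grad_norm_foreach_sketch(parameters, max_norm, eps=1e-6):
    # Illustrative only: single device/dtype, L2 norm, no error_if_nonfinite handling.
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.0)
    per_tensor = torch._foreach_norm(grads, 2)               # one fused call for all per-tensor norms
    total_norm = torch.linalg.vector_norm(torch.stack(per_tensor), 2)
    clip_coef = float(max_norm) / (float(total_norm) + eps)
    if clip_coef < 1.0:
        torch._foreach_mul_(grads, clip_coef)                # one fused in-place scale of every grad
    return total_norm
```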