
Commit 79d3bba

EnayatUllah authored and facebook-github-bot committed
Add per-sample gradient norm computation as a functionality
Summary: Per-sample gradient norms are already computed for Ghost Clipping, but they can be useful more generally, so this change exposes them as a public property:

```
...
loss.backward()
per_sample_norms = model.per_sample_gradient_norms
```

Differential Revision: D68634969
1 parent c7d6144 commit 79d3bba
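
For context, here is a slightly fuller version of the snippet from the summary. This is a hedged sketch rather than code from this commit: the PrivacyEngine / make_private wiring with grad_sample_mode="ghost" is recalled from the Opacus fast gradient clipping tutorial and its exact signature and return order may vary between versions, and the toy model, data, and hyperparameters are placeholders.

```python
# Hedged usage sketch (not part of this commit); assumes Opacus's ghost clipping path.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from opacus import PrivacyEngine

# Toy model, optimizer, loss, and data -- placeholders only.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,)))
data_loader = DataLoader(dataset, batch_size=8)

privacy_engine = PrivacyEngine()
# grad_sample_mode="ghost" wraps the model in GradSampleModuleFastGradientClipping;
# the return order shown here follows the Opacus tutorial and may differ by version.
model, optimizer, criterion, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    criterion=criterion,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    grad_sample_mode="ghost",
)

for data, labels in data_loader:
    optimizer.zero_grad()
    loss = criterion(model(data), labels)
    loss.backward()
    # New in this commit: per-sample gradient norms (one scalar per example in the
    # batch) are cached during the clipping pass and exposed as a property.
    per_sample_norms = model.per_sample_gradient_norms
    optimizer.step()
    break  # one step is enough for the illustration
```

Accessing model.per_sample_gradient_norms before a forward/backward pass raises the AttributeError added in the diff below; after backward, the cached tensor holds one gradient norm per example in the current batch.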

File tree

2 files changed: +19, -2 lines


opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py

Lines changed: 18 additions & 2 deletions
@@ -18,8 +18,6 @@
 import logging
 from typing import List
 
-import torch
-import torch.nn as nn
 from opacus.grad_sample.functorch import ft_compute_per_sample_gradient
 from opacus.grad_sample.grad_sample_module import (
     GradSampleModule,
@@ -28,6 +26,9 @@
 )
 from opacus.utils.module_utils import requires_grad, trainable_parameters
 
+import torch
+import torch.nn as nn
+
 
 logger = logging.getLogger(__name__)
 logger.disabled = True
@@ -120,6 +121,7 @@ def __init__(
         self.trainable_parameters = [p for _, p in trainable_parameters(self._module)]
         self.max_grad_norm = max_grad_norm
         self.use_ghost_clipping = use_ghost_clipping
+        self._per_sample_gradient_norms = None
 
     def get_clipping_coef(self) -> torch.Tensor:
         """Get per-example gradient scaling factor for clipping."""
@@ -131,6 +133,7 @@ def get_norm_sample(self) -> torch.Tensor:
         norm_sample = torch.stack(
             [param._norm_sample for param in self.trainable_parameters], dim=0
         ).norm(2, dim=0)
+        self.per_sample_gradient_norms = norm_sample
         return norm_sample
 
     def capture_activations_hook(
@@ -231,3 +234,16 @@ def capture_backprops_hook(
         if len(module.activations) == 0:
             if hasattr(module, "max_batch_len"):
                 del module.max_batch_len
+
+    @property
+    def per_sample_gradient_norms(self) -> torch.Tensor:
+        if self._per_sample_gradient_norms is not None:
+            return self._per_sample_gradient_norms
+        else:
+            raise AttributeError(
+                "per_sample_gradient_norms is not set. Please call forward and backward on the model before accessing this property."
+            )
+
+    @per_sample_gradient_norms.setter
+    def per_sample_gradient_norms(self, value):
+        self._per_sample_gradient_norms = value
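
Aside (editorial note, not part of the commit): the hunk in get_norm_sample above shows how the overall per-sample gradient norm is assembled, by stacking the per-parameter per-sample norms (param._norm_sample) and taking a 2-norm across the parameter dimension. A minimal standalone PyTorch check with made-up shapes, independent of Opacus, confirming that this equals the norm of the concatenated per-sample gradients:

```python
import torch

# Standalone sketch: 3 samples, 2 hypothetical parameter tensors.
torch.manual_seed(0)
batch_size = 3
g1 = torch.randn(batch_size, 4)      # per-sample gradients of parameter 1
g2 = torch.randn(batch_size, 2, 5)   # per-sample gradients of parameter 2

# Per-parameter per-sample norms, analogous to `param._norm_sample`.
norm1 = g1.flatten(1).norm(2, dim=1)
norm2 = g2.flatten(1).norm(2, dim=1)

# Stack and take the 2-norm across parameters, as in `get_norm_sample`.
norm_sample = torch.stack([norm1, norm2], dim=0).norm(2, dim=0)

# Same quantity computed directly from the concatenated flattened gradients.
direct = torch.cat([g1.flatten(1), g2.flatten(1)], dim=1).norm(2, dim=1)
assert torch.allclose(norm_sample, direct)
print(norm_sample)  # one gradient norm per sample, shape (3,)
```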

opacus/optimizers/optimizer_fast_gradient_clipping.py

Lines changed: 1 addition & 0 deletions
@@ -146,6 +146,7 @@ def zero_grad(self, set_to_none: bool = False):
 
         for p in self.params:
             p.grad_sample = None
+            p._per_sample_gradient_norms = None
 
             if not self._is_last_step_skipped:
                 p.summed_grad = None
