
Commit 60b7522

EnayatUllah authored and facebook-github-bot committed on Feb 11, 2025
Fix freezing modules in Ghost Clipping (#729)
Summary:

Freezing modules with ghost clipping throws an error because the corresponding per-sample norms are not calculated. Fix: keep in memory the list of all parameters and check whether the corresponding `requires_grad` is True when calculating norms. Further, unfreezing modules (with and without ghost clipping) wasn't supported because the hooks aren't present for the corresponding modules. Fix: override `requires_grad_` to add the hooks back.

Facebook

We initially used `trainable_parameters(module)` to traverse the list of trainable modules upon norm computation. It was slow because `trainable_parameters(module)` is a generator, so it traverses the neural network graph every time it is called. We replaced it with a list of trainable parameters fixed at model creation time. This is what led to the issues with freezing modules, as that list is never updated. Fix: use the list of **all parameters** -- not a generator, so no traversal happens. Further, we check `requires_grad` when calculating the per-sample norm to decide whether to compute it or not. This is the same check the (non-private) [optimizer](https://github.com/pytorch/pytorch/blob/5725462cd8679dd1dea8a469b1bf2e71f226b664/torch/optim/optimizer.py#L963) uses to determine which parameters are frozen.

Differential Revision: D68656459
1 parent 0d186a4 commit 60b7522
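For context, below is a minimal, hypothetical sketch of the freezing workflow this change is meant to support. It assumes the constructor arguments visible in the diff below (max_grad_norm, use_ghost_clipping); the toy model, sizes, and values are illustrative and not part of the commit.

import torch.nn as nn

from opacus.grad_sample.grad_sample_module_fast_gradient_clipping import (
    GradSampleModuleFastGradientClipping,
)

# Toy model; layer sizes are arbitrary.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 2))

wrapped = GradSampleModuleFastGradientClipping(
    model,
    max_grad_norm=1.0,
    use_ghost_clipping=True,
)

# Freeze one sub-layer. Its parameters now have requires_grad=False, so with
# this fix get_norm_sample() skips them when stacking per-sample norms instead
# of failing on a missing `_norm_sample` attribute.
model[0].requires_grad_(False)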

File tree

2 files changed (+21, -2 lines)
 

opacus/grad_sample/grad_sample_module.py

+14

@@ -145,6 +145,20 @@ def __init__(
             force_functorch=force_functorch,
         )
 
+    def requires_grad_(self, requires_grad: bool = True) -> nn.Module:
+        """Override requires_grad_ to add/remove hooks based on the requires_grad value."""
+        if requires_grad:
+            # Attach hooks to the module
+            self.add_hooks(
+                loss_reduction=self.loss_reduction,
+                batch_first=self.batch_first,
+                force_functorch=self.force_functorch,
+            )
+        else:
+            # Remove hooks
+            self.remove_hooks()
+        return super().requires_grad_(requires_grad)
+
     def forward(self, *args, **kwargs):
         return self._module(*args, **kwargs)
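A short, hedged usage sketch of the override above, assuming the public GradSampleModule wrapper with its default arguments; the toy module and shapes are illustrative only.

import torch
import torch.nn as nn

from opacus import GradSampleModule

# Toy wrapped module; sizes are arbitrary.
gs = GradSampleModule(nn.Linear(8, 4))

# Hooks added at construction capture per-sample gradients as usual.
out = gs(torch.randn(2, 8))
out.sum().backward()
print(gs._module.weight.grad_sample.shape)  # torch.Size([2, 4, 8])

# With this change, freezing through the wrapper also removes the hooks ...
gs.requires_grad_(False)

# ... and unfreezing goes through the overridden requires_grad_, which re-runs
# add_hooks -- plain nn.Module.requires_grad_ knows nothing about Opacus hooks.
gs.requires_grad_(True)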

opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py

+7 -2

@@ -117,7 +117,7 @@ def __init__(
             strict=strict,
             force_functorch=force_functorch,
         )
-        self.trainable_parameters = [p for _, p in trainable_parameters(self._module)]
+        self.all_parameters = [p for p in self.parameters()]
         self.max_grad_norm = max_grad_norm
         self.use_ghost_clipping = use_ghost_clipping
         self._per_sample_gradient_norms = None
@@ -130,7 +130,12 @@ def get_clipping_coef(self) -> torch.Tensor:
     def get_norm_sample(self) -> torch.Tensor:
         """Get per-example gradient norms."""
         norm_sample = torch.stack(
-            [param._norm_sample for param in self.trainable_parameters], dim=0
+            [
+                param._norm_sample
+                for param in self.all_parameters
+                if param.requires_grad
+            ],
+            dim=0,
         ).norm(2, dim=0)
         self.per_sample_gradient_norms = norm_sample
         return norm_sample
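To see why the requires_grad filter fixes the crash, here is a self-contained toy illustration in plain PyTorch. In the real code _norm_sample is populated by the ghost clipping hooks; the values below are made up.

import torch

batch_size = 4
params = [torch.nn.Parameter(torch.randn(3)) for _ in range(3)]
params[1].requires_grad_(False)  # frozen: the hooks never attach a _norm_sample to it

# Stand-in for what the hooks produce for each trainable parameter.
for p in params:
    if p.requires_grad:
        p._norm_sample = torch.rand(batch_size)

# Mirrors the new get_norm_sample(): only parameters with requires_grad=True
# contribute, so the frozen parameter no longer triggers an AttributeError.
norm_sample = torch.stack(
    [p._norm_sample for p in params if p.requires_grad], dim=0
).norm(2, dim=0)

print(norm_sample.shape)  # torch.Size([4]) -- one norm per example in the batch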
