#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import warnings
from functools import partial
from typing import List, Tuple
import torch
import torch.nn as nn
from opacus.grad_sample.functorch import ft_compute_per_sample_gradient, prepare_layer
from opacus.grad_sample.gsm_base import AbstractGradSampleModule
from opacus.layers.dp_rnn import DPGRU, DPLSTM, DPRNN, RNNLinear
from opacus.utils.module_utils import (
requires_grad,
trainable_modules,
trainable_parameters,
)
logger = logging.getLogger(__name__)
def create_or_accumulate_grad_sample(
*, param: torch.Tensor, grad_sample: torch.Tensor, max_batch_len: int
) -> None:
"""
Creates a ``_current_grad_sample`` attribute in the given parameter, or adds to it
if the ``_current_grad_sample`` attribute already exists.
Args:
param: Parameter to which ``grad_sample`` will be added
grad_sample: Per-sample gradients tensor. Must be of the same
            shape as ``param`` with an extra leading batch dimension
        max_batch_len: Maximum batch size; the ``_current_grad_sample`` buffer is
            allocated with this leading dimension
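
    A sketch of the accumulation semantics (shapes are illustrative)::

        first call : writes ``grad_sample`` into a zero buffer of shape
                     ``[max_batch_len, *grad_sample.shape[1:]]``
        later calls: add subsequent per-sample gradients into that buffer in place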
"""
if param.requires_grad:
if hasattr(param, "_current_grad_sample"):
param._current_grad_sample[: grad_sample.shape[0]] += grad_sample
else:
param._current_grad_sample = torch.zeros(
torch.Size([max_batch_len]) + grad_sample.shape[1:],
device=grad_sample.device,
dtype=grad_sample.dtype,
)
param._current_grad_sample[: grad_sample.shape[0]] = grad_sample
def promote_current_grad_sample(p: nn.Parameter) -> None:
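    """
    Moves the accumulated ``p._current_grad_sample`` onto ``p.grad_sample``.
    If ``p.grad_sample`` is already populated (e.g. when gradients are accumulated
    over several batches), the new tensor is appended instead, turning
    ``p.grad_sample`` into a list of Tensors.
    """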
if p.requires_grad:
if p.grad_sample is not None:
if isinstance(p.grad_sample, list):
p.grad_sample.append(p._current_grad_sample)
else:
p.grad_sample = [p.grad_sample, p._current_grad_sample]
else:
p.grad_sample = p._current_grad_sample
del p._current_grad_sample
class GradSampleModule(AbstractGradSampleModule):
"""
Hooks-based implementation of AbstractGradSampleModule
Computes per-sample gradients using custom-written methods for each layer.
See README.md for more details
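    Illustrative usage (a sketch; assumes ``batch_first=True`` and a summed loss)::

        model = nn.Linear(8, 2)
        gs_model = GradSampleModule(model, loss_reduction="sum")
        out = gs_model(torch.randn(16, 8))
        out.sum().backward()
        # each parameter now holds one gradient row per input example
        model.weight.grad_sample.shape  # torch.Size([16, 2, 8])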
"""
GRAD_SAMPLERS = {}
def __init__(
self,
m: nn.Module,
*,
batch_first=True,
loss_reduction="mean",
strict: bool = True,
force_functorch=False,
):
"""
Args:
m: nn.Module to be wrapped
batch_first: Flag to indicate if the input tensor to the corresponding module
has the first dimension representing the batch. If set to True, dimensions on
            input tensor are expected to be ``[batch_size, ...]``, otherwise
``[K, batch_size, ...]``
loss_reduction: Indicates if the loss reduction (for aggregating the gradients)
is a sum or a mean operation. Can take values "sum" or "mean"
strict: If set to ``True``, the input module will be validated to check that
``GradSampleModule`` has grad sampler functions for all submodules of
the input module (i.e. if it knows how to calculate per sample gradients)
for all model parameters. If set to ``False``, per sample gradients will
            be computed on a best-effort basis: they will be available where
possible and set to None otherwise. This is not recommended, because
some unsupported modules (e.g. BatchNorm) affect other parameters and
invalidate the concept of per sample gradients for the entire model.
force_functorch: If set to ``True``, will use functorch to compute
all per sample gradients. Otherwise, functorch will be used only
for layers without registered grad sampler methods.
Raises:
NotImplementedError
If ``strict`` is set to ``True`` and module ``m`` (or any of its
submodules) doesn't have a registered grad sampler function.
"""
super().__init__(
m,
batch_first=batch_first,
loss_reduction=loss_reduction,
)
errors = self.validate(module=m, strict=strict)
if errors and not strict:
logger.info(
f"GradSampleModule found the following errors: {errors}."
"Using non-strict mode, continuing"
)
self.hooks_enabled = False
self.batch_first = batch_first
self.loss_reduction = loss_reduction
self.force_functorch = force_functorch
self.add_hooks(
loss_reduction=loss_reduction,
batch_first=batch_first,
force_functorch=force_functorch,
)
def forward(self, *args, **kwargs):
return self._module(*args, **kwargs)
def add_hooks(
self,
*,
loss_reduction: str = "mean",
batch_first: bool = True,
force_functorch: bool = False,
) -> None:
"""
Adds hooks to model to save activations and backprop values.
The hooks will
1. save activations into param.activations during forward pass
        2. compute per-sample gradients in ``param.grad_sample`` during backward pass.
Call ``remove_hooks(model)`` to disable this.
Args:
batch_first: Flag to indicate if the input tensor to the corresponding module
has the first dimension representing the batch. If set to True, dimensions on
            input tensor are expected to be ``[batch_size, ...]``, otherwise
``[K, batch_size, ...]``
loss_reduction: Indicates if the loss reduction (for aggregating the gradients)
is a sum or a mean operation. Can take values "sum" or "mean"
force_functorch: If set to ``True``, will use functorch to compute all per sample gradients.
Otherwise, functorch will be used only for layers without registered grad sampler methods.
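
        Note that ``add_hooks`` is already called from ``__init__``; calling it
        again on the same wrapped module raises ``ValueError``.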
"""
if hasattr(self._module, "autograd_grad_sample_hooks"):
raise ValueError("Trying to add hooks twice to the same model")
else:
self._module.autograd_grad_sample_hooks = []
self.autograd_grad_sample_hooks = self._module.autograd_grad_sample_hooks
for _module_name, module in trainable_modules(self._module):
            # Do not add hooks to DPRNN, DPLSTM or DPGRU: hooks are instead handled by the inner `RNNLinear` layers
if type(module) in [DPRNN, DPLSTM, DPGRU]:
continue
            if force_functorch or type(module) not in self.GRAD_SAMPLERS:
prepare_layer(module, batch_first=batch_first)
self.autograd_grad_sample_hooks.append(
module.register_forward_hook(self.capture_activations_hook)
)
self.autograd_grad_sample_hooks.append(
module.register_backward_hook(
partial(
self.capture_backprops_hook,
loss_reduction=loss_reduction,
batch_first=batch_first,
)
)
)
self.enable_hooks()
def remove_hooks(self) -> None:
"""
Removes hooks added by ``add_hooks()``
"""
self.disable_hooks()
for p in self.parameters():
if hasattr(p, "ddp_hooks"):
while p.ddp_hooks:
handle = p.ddp_hooks.pop()
handle.remove()
delattr(p, "ddp_hooks")
if not hasattr(self, "autograd_grad_sample_hooks"):
raise ValueError("Asked to remove hooks, but no hooks found")
else:
while self.autograd_grad_sample_hooks:
handle = self.autograd_grad_sample_hooks.pop()
handle.remove()
delattr(self, "autograd_grad_sample_hooks")
delattr(self._module, "autograd_grad_sample_hooks")
# Remove functorch hooks
for _module_name, module in trainable_modules(self._module):
if hasattr(module, "ft_compute_sample_grad"):
delattr(module, "ft_compute_sample_grad")
def disable_hooks(self) -> None:
r"""
Globally disable all hooks installed by this library.
Why is this needed? As per https://github.com/pytorch/pytorch/issues/25723, there is
a bug in Autograd that makes removing hooks do nothing if the graph was already
constructed. For this reason, we have this method to at least turn them off.
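
        A sketch of typical use (``gs_model`` and ``val_batch`` are hypothetical)::

            gs_model.disable_hooks()   # e.g. around an evaluation-only forward pass
            _ = gs_model(val_batch)    # activations are not captured
            gs_model.enable_hooks()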
"""
self.hooks_enabled = False
def enable_hooks(self) -> None:
r"""
        The opposite of ``disable_hooks()``. Hooks are always enabled unless you
        explicitly disable them, so you only need to call this to re-enable them.
"""
self.hooks_enabled = True
def _close(self):
super()._close()
self.remove_hooks()
def capture_activations_hook(
self,
module: nn.Module,
forward_input: List[torch.Tensor],
_forward_output: torch.Tensor,
):
if (
not requires_grad(module)
or not module.training
or not torch.is_grad_enabled()
):
return
if not self.hooks_enabled:
return
if not hasattr(module, "activations"):
module.activations = []
module.activations.append(forward_input[0].detach()) # pyre-ignore
for _, p in trainable_parameters(module):
p._forward_counter += 1
def capture_backprops_hook(
self,
module: nn.Module,
_forward_input: torch.Tensor,
forward_output: torch.Tensor,
loss_reduction: str,
batch_first: bool,
):
"""
Computes per sample gradients given the current backprops and activations
stored by the associated forward hook. Computed per sample gradients are
stored in ``grad_sample`` field in each parameter.
For non-recurrent layers the process is straightforward: for each
        ``loss.backward()`` call this hook will be called exactly once. For recurrent
layers, however, this is more complicated and the hook will be called multiple
times, while still processing the same batch of data.
For this reason we first accumulate the gradients from *the same batch* in
        ``p._current_grad_sample`` and then, when we detect the end of a full backward
        pass, we store the accumulated result in ``p.grad_sample``.
        From there, ``p.grad_sample`` could be either a Tensor or a list of Tensors,
        if accumulated over multiple batches.
Args:
            module: the module for which per-sample gradients are being computed
            _forward_input: gradients w.r.t. module inputs (unused)
            forward_output: gradients w.r.t. module outputs; the captured backprops
            loss_reduction: either "mean" or "sum", matching how the loss was reduced
            batch_first: True if the batch dimension is the first dimension
"""
if not self.hooks_enabled:
return
backprops = forward_output[0].detach()
activations, backprops = self.rearrange_grad_samples(
module=module,
backprops=backprops,
loss_reduction=loss_reduction,
batch_first=batch_first,
)
if not self.force_functorch and type(module) in self.GRAD_SAMPLERS:
grad_sampler_fn = self.GRAD_SAMPLERS[type(module)]
else:
grad_sampler_fn = ft_compute_per_sample_gradient
grad_samples = grad_sampler_fn(module, activations, backprops)
for param, gs in grad_samples.items():
create_or_accumulate_grad_sample(
param=param, grad_sample=gs, max_batch_len=module.max_batch_len
)
# Detect end of current batch processing and switch accumulation
# mode from sum to stacking. Used for RNNs and tied parameters
# (See #417 for details)
for _, p in trainable_parameters(module):
p._forward_counter -= 1
if p._forward_counter == 0:
promote_current_grad_sample(p)
if len(module.activations) == 0:
if hasattr(module, "max_batch_len"):
del module.max_batch_len
def rearrange_grad_samples(
self,
*,
module: nn.Module,
backprops: torch.Tensor,
loss_reduction: str,
batch_first: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Rearrange activations and grad_samples based on loss reduction and batch dim
Args:
module: the module for which per-sample gradients are computed
backprops: the captured backprops
loss_reduction: either "mean" or "sum" depending on whether backpropped
loss was averaged or summed over batch
            batch_first: True if the batch dimension is first
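
        For example, with ``batch_first=False`` an activation of shape ``[T, B, D]``
        is permuted to ``[B, T, D]`` before the grad sampler runs.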
"""
if not hasattr(module, "activations"):
raise ValueError(
f"No activations detected for {type(module)},"
" run forward after add_hooks(model)"
)
batch_dim = 0 if batch_first or type(module) is RNNLinear else 1
activations = module.activations.pop()
if not hasattr(module, "max_batch_len"):
# For packed sequences, max_batch_len is set in the forward of the model (e.g. the LSTM)
# Otherwise we infer it here
module.max_batch_len = _get_batch_size(
module=module,
grad_sample=activations,
batch_dim=batch_dim,
)
n = module.max_batch_len
if loss_reduction == "mean":
backprops = backprops * n
elif loss_reduction == "sum":
backprops = backprops
else:
raise ValueError(
f"loss_reduction = {loss_reduction}. Only 'sum' and 'mean' losses are supported"
)
        # No matter where the batch dimension was, .grad_sample will *always* put it in the first dim
if batch_dim != 0:
activations = activations.permute(
[batch_dim] + [x for x in range(activations.dim()) if x != batch_dim]
)
backprops = backprops.permute(
[batch_dim] + [x for x in range(backprops.dim()) if x != batch_dim]
)
return activations, backprops
@classmethod
def is_supported(cls, module: nn.Module) -> bool:
"""
        Checks if this individual module is supported (i.e. has a registered
grad sampler function)
Notes:
Note that this method does not check submodules
Args:
module: nn.Module to be checked
Returns:
``True`` if grad sampler is found, ``False`` otherwise
"""
warnings.warn(
"GradSampleModule.is_supported is deprecated, as all layers can now be used with functorch.",
DeprecationWarning,
)
return True
@classmethod
def validate(
cls, module: nn.Module, *, strict: bool = False
) -> List[NotImplementedError]:
"""
Check if per sample gradients can be fully computed for a given model
Args:
module: nn.Module to be checked
            strict: Behaviour in case of a negative check result. Will
                return the list of exceptions if set to ``False``, and throw otherwise
Returns:
            Empty list if validation is successful.
            List of validation errors if ``strict=False`` and
            unsupported modules are found
        Raises:
            NotImplementedError
                If ``strict=True`` and unsupported modules are found
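
        A sketch (BatchNorm carries non-private buffers, so it gets flagged)::

            errors = GradSampleModule.validate(nn.BatchNorm1d(4), strict=False)
            assert len(errors) > 0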
"""
errors = []
errors.extend(
[
NotImplementedError(
f"Model contains a trainable layer "
f"that Opacus doesn't currently support({m_name}:{m}). "
f"Please implement and register grad sampler for this layer. "
f"(See opacus.grad_sample.utils.register_grad_sampler)"
)
for m_name, m in trainable_modules(module)
# With functorch, all modules are trainable
            # We still want to avoid modules that have buffers (e.g. BatchNorm)
# as the buffers are not private
if len(list(m.buffers())) > 0
]
)
# raise or return errors as needed
if strict and len(errors) > 0:
raise NotImplementedError(errors)
else:
return errors
def _get_batch_size(
*, module: nn.Module, grad_sample: torch.Tensor, batch_dim: int
) -> int:
"""
Computes and returns the maximum batch size which is the maximum of the dimension values
along 'batch_dim' axis over module.activations + [grad_sample], where module.activations is
a list.
    If module.activations is not a list, then return grad_sample.shape[batch_dim].
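
    For example (illustrative): if ``module.activations`` holds tensors with
    batch sizes ``[3, 5]`` and ``grad_sample`` has batch size ``4``, the
    function returns ``5``.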
Args:
module: input module
grad_sample: per sample gradient tensor
batch_dim: batch dimension
Returns:
        Maximum batch size across ``module.activations`` and ``grad_sample``
"""
max_batch_len = 0
for out in module.activations:
if out.shape[batch_dim] > max_batch_len:
max_batch_len = out.shape[batch_dim]
max_batch_len = max(max_batch_len, grad_sample.shape[batch_dim])
return max_batch_len