from opacus.layers.dp_rnn import RNNLinear


def prepare_layer(layer, batch_first=True):
    """
    Prepare a layer to compute grad samples using functorch.
    The grad samples are computed by redoing the forward and
    backward passes on the functional version of the module.

    Args:
        layer: the layer to prepare
        batch_first: whether the input is batch_first or not
    """
    from functorch import grad, make_functional, vmap

    if len(list(layer.buffers())) > 0:
        raise NotImplementedError(
            "This layer has buffers and is not supported by Opacus"
        )

    # make_functional splits the module into a stateless callable and its
    # parameters, so the forward pass can be re-run with explicit params
    flayer, _ = make_functional(layer)
    def compute_loss_stateless_model(params, activations, backprops):
        # vmap strips the mapped (batch) dimension, so re-insert a dummy
        # batch dimension of size 1 before the forward pass
        if batch_first or type(layer) is RNNLinear:
            batched_activations = activations.unsqueeze(0)
            batched_backprops = backprops.unsqueeze(0)
        else:
            # If batch_first is False, the batch dimension is the second dimension
            batched_activations = activations.unsqueeze(1)
            batched_backprops = backprops.unsqueeze(1)

        output = flayer(params, batched_activations)
        # The gradient of this scalar w.r.t. params is the vector-Jacobian
        # product with backprops, i.e. exactly the per-sample gradient
        loss = (output * batched_backprops).sum()

        return loss
    ft_compute_grad = grad(compute_loss_stateless_model)

    # Note that the vmap is done on the first dimension, regardless of batch_first.
    # This is because the activations and backprops given by the GradSampleModule
    # are always batch_first=True
    layer.ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, 0, 0))
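

# A minimal sketch (not part of the Opacus API) of calling the vmapped
# gradient function that prepare_layer attaches. The nn.Linear layer and
# the shapes below are illustrative assumptions:
#
#     import torch
#     import torch.nn as nn
#
#     layer = nn.Linear(4, 2)
#     prepare_layer(layer)  # attaches layer.ft_compute_sample_grad
#
#     activations = torch.randn(8, 4)  # one input row per sample
#     backprops = torch.randn(8, 2)    # dL/d(output), one row per sample
#
#     params = list(layer.parameters())
#     weight_grads, bias_grads = layer.ft_compute_sample_grad(
#         params, activations, backprops
#     )
#     # weight_grads.shape == (8, 2, 4): one weight gradient per sample
#     # bias_grads.shape == (8, 2): one bias gradient per sample
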
def ft_compute_per_sample_gradient(layer, activations, backprops):
    """
    Compute the per-sample gradient of the layer.

    Args:
        layer: the layer on which to compute the gradient
        activations: the input to the layer
        backprops: the gradient of the loss w.r.t. outputs of the layer
    """
    parameters = list(layer.parameters())
    if not hasattr(layer, "ft_compute_sample_grad"):
        prepare_layer(layer)

    per_sample_grads = layer.ft_compute_sample_grad(
        parameters, activations, backprops
    )

    # Return a dict mapping each parameter to its per-sample gradients,
    # in the same order as layer.parameters()
    ret = {}
    for i_p, p in enumerate(parameters):
        ret[p] = per_sample_grads[i_p]

    return ret
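

# A minimal end-to-end sketch, assuming activations and backprops have already
# been captured (as GradSampleModule's hooks do). Names and shapes here are
# illustrative assumptions; the cross-check below uses plain autograd:
#
#     import torch
#     import torch.nn as nn
#
#     layer = nn.Linear(4, 2)
#     activations = torch.randn(8, 4)
#     backprops = torch.randn(8, 2)
#
#     grads = ft_compute_per_sample_gradient(layer, activations, backprops)
#     # grads maps each parameter to a tensor with a leading batch dimension:
#     # grads[layer.weight].shape == (8, 2, 4)
#     # grads[layer.bias].shape == (8, 2)
#
#     # Cross-check sample 0 against ordinary autograd
#     loss0 = (layer(activations[0]) * backprops[0]).sum()
#     loss0.backward()
#     assert torch.allclose(layer.weight.grad, grads[layer.weight][0])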