Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into benchmarks-ci
- Loading branch information
Showing
26 changed files
with
798 additions
and
672 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
torch==1.8.1 | ||
torch | ||
torchvision>=0.9.1 | ||
tqdm>=4.40 | ||
requests>=2.25.1 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from opacus.layers.dp_rnn import RNNLinear | ||
|
||
|
||
def prepare_layer(layer, batch_first=True):
    """
    Attach a vectorized per-sample gradient function to ``layer``.

    The per-sample gradients are obtained by re-running the forward and
    backward passes on a functional (stateless) version of the module,
    vmapped over the batch dimension. The resulting callable is stored on
    the layer as ``layer.ft_compute_sample_grad``.

    Args:
        layer: the layer to prepare
        batch_first: whether the layer's native input layout is batch-first
    """
    from functorch import grad, make_functional, vmap

    # Buffers are state the functional rewrite would not track correctly.
    if list(layer.buffers()):
        raise NotImplementedError(
            "This layer has buffers and is not supported by Opacus"
        )
    functional_layer, _ = make_functional(layer)

    def _stateless_loss(params, activations, backprops):
        # Re-insert a dummy batch dimension (vmap strips it): at dim 0 for
        # batch-first layers and RNNLinear, otherwise at dim 1.
        dim = 0 if batch_first or type(layer) is RNNLinear else 1
        output = functional_layer(params, activations.unsqueeze(dim))
        return (output * backprops.unsqueeze(dim)).sum()

    per_sample_grad_fn = grad(_stateless_loss)
    # vmap always maps over dim 0, regardless of batch_first: the
    # activations and backprops handed over by GradSampleModule are
    # always batch-first.
    layer.ft_compute_sample_grad = vmap(per_sample_grad_fn, in_dims=(None, 0, 0))
|
||
|
||
def ft_compute_per_sample_gradient(layer, activations, backprops, batch_first=True):
    """
    Compute the per-sample gradients of ``layer`` using its functorch-built
    vectorized gradient function, preparing the layer lazily on first use.

    Args:
        layer: the layer on which to compute the gradient
        activations: the input to the layer
        backprops: the gradient of the loss w.r.t. outputs of the layer
        batch_first: layout of the layer's native inputs; forwarded to
            ``prepare_layer`` on lazy setup. Defaults to True, preserving
            the previous behavior for existing callers.

    Returns:
        A dict mapping each parameter of ``layer`` to its per-sample
        gradient tensor (batch dimension first).
    """
    parameters = list(layer.parameters())
    if not hasattr(layer, "ft_compute_sample_grad"):
        # Bug fix: the original always called prepare_layer(layer) with its
        # default batch_first=True, silently mis-preparing non-batch-first
        # layers that were not prepared in advance. Forward the flag instead.
        prepare_layer(layer, batch_first=batch_first)

    per_sample_grads = layer.ft_compute_sample_grad(parameters, activations, backprops)

    # Parameters and their vmapped gradients are positionally aligned.
    return dict(zip(parameters, per_sample_grads))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.