diff --git a/opacus/tests/grad_samples/common.py b/opacus/tests/grad_samples/common.py index 41e46321..fb6de5d0 100644 --- a/opacus/tests/grad_samples/common.py +++ b/opacus/tests/grad_samples/common.py @@ -13,18 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import io import unittest -from typing import Dict, Iterable, List, Tuple, Union +from typing import Tuple, Union -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from opacus.grad_sample import wrap_model -from opacus.utils.module_utils import trainable_parameters -from opacus.utils.packed_sequences import compute_seq_lengths -from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence +from opacus.utils.per_sample_gradients_utils import ( + compute_grad_samples_microbatch_and_opacus, + compute_opacus_grad_sample, + is_batch_empty, +) +from torch.nn.utils.rnn import PackedSequence from torch.testing import assert_close @@ -36,196 +36,12 @@ def shrinker(x, factor: int = 2): return max(1, x // factor) # if avoid returning 0 for x == 1 -def is_batch_empty(batch: Union[torch.Tensor, Iterable[torch.Tensor]]): - if type(batch) is torch.Tensor: - return batch.numel() == 0 - else: - return batch[0].numel() == 0 - - -class ModelWithLoss(nn.Module): - """ - To test the gradients of a module, we need to have a loss. - This module makes it easy to get a loss from any nn.Module, and automatically generates - a target y vector for it in the forward (of all zeros of the correct size). - This reduces boilerplate while testing. - """ - - supported_reductions = ["mean", "sum"] - - def __init__(self, module: nn.Module, loss_reduction: str = "mean"): - """ - Instantiates this module. - - Args: - module: The nn.Module you want to test. - loss_reduction: What reduction to apply to the loss. Defaults to "mean". - - Raises: - ValueError: If ``loss_reduction`` is not among those supported. - """ - super().__init__() - self.wrapped_module = module - - if loss_reduction not in self.supported_reductions: - raise ValueError( - f"Passed loss_reduction={loss_reduction}. Only {self.supported_reductions} supported." - ) - self.criterion = nn.L1Loss(reduction=loss_reduction) - - def forward(self, x): - if type(x) is tuple: - x = self.wrapped_module(*x) - else: - x = self.wrapped_module(x) - if type(x) is PackedSequence: - loss = _compute_loss_packedsequences(self.criterion, x) - else: - y = torch.zeros_like(x) - loss = self.criterion(x, y) - return loss - - -def clone_module(module: nn.Module) -> nn.Module: - """ - Handy utility to clone an nn.Module. PyTorch doesn't always support copy.deepcopy(), so it is - just easier to serialize the model to a BytesIO and read it from there. - - Args: - module: The module to clone - - Returns: - The clone of ``module`` - """ - with io.BytesIO() as bytesio: - torch.save(module, bytesio) - bytesio.seek(0) - module_copy = torch.load(bytesio) - return module_copy - - class GradSampleHooks_test(unittest.TestCase): """ Set of common testing utils. It is meant to be subclassed by your test. See other tests as an example of how this is done. """ - def compute_microbatch_grad_sample( - self, - x: Union[torch.Tensor, List[torch.Tensor]], - module: nn.Module, - batch_first=True, - loss_reduction="mean", - chunk_method=iter, - ) -> Dict[str, torch.tensor]: - """ - Computes per-sample gradients with the microbatch method, i.e. by computing normal gradients - with batch_size set to 1, and manually accumulating them. 
This is our reference for testing - as this method is obviously correct, but slow. - - Args: - x: The tensor in input to the ``module`` - module: The ``ModelWithLoss`` that wraps the nn.Module you want to test. - batch_first: Whether batch size is the first dimension (as opposed to the second). - Defaults to True. - loss_reduction: What reduction to apply to the loss. Defaults to "mean". - chunk_method: The method to use to split the batch into microbatches. Defaults to ``iter``. - - Returns: - Dictionary mapping parameter_name -> per-sample-gradient for that parameter - """ - torch.use_deterministic_algorithms(True) - torch.manual_seed(0) - np.random.seed(0) - - module = ModelWithLoss(clone_module(module), loss_reduction) - - for _, p in trainable_parameters(module): - p.microbatch_grad_sample = [] - - if not batch_first and type(x) is not list: - # This allows us to iterate with x_i - x = x.transpose(0, 1) - - # Invariant: x is [B, T, ...] - - for x_i in chunk_method(x): - # x_i is [T, ...] - module.zero_grad() - if type(x_i) is not tuple: - # EmbeddingBag provides tuples - x_i = x_i.unsqueeze( - 0 if batch_first else 1 - ) # x_i of size [1, T, ...] if batch_first, else [T, 1, ...] - loss_i = module(x_i) - loss_i.backward() - for p in module.parameters(): - p.microbatch_grad_sample.append(p.grad.detach().clone()) - - for _, p in trainable_parameters(module): - if batch_first: - p.microbatch_grad_sample = torch.stack( - p.microbatch_grad_sample, dim=0 # [B, T, ...] - ) - else: - p.microbatch_grad_sample = torch.stack( - p.microbatch_grad_sample, dim=1 # [T, B, ...] - ).transpose( - 0, 1 - ) # Opacus's semantics is that grad_samples are ALWAYS batch_first: [B, T, ...] - - microbatch_grad_samples = { - name: p.microbatch_grad_sample - for name, p in trainable_parameters(module.wrapped_module) - } - return microbatch_grad_samples - - def compute_opacus_grad_sample( - self, - x: Union[torch.Tensor, PackedSequence], - module: nn.Module, - batch_first=True, - loss_reduction="mean", - grad_sample_mode="hooks", - ) -> Dict[str, torch.tensor]: - """ - Runs Opacus to compute per-sample gradients and return them for testing purposes. - - Args: - x: The tensor in input to the ``module`` - module: The ``ModelWithLoss`` that wraps the nn.Module you want to test. - batch_first: Whether batch size is the first dimension (as opposed to the second). - Defaults to True. - loss_reduction: What reduction to apply to the loss. Defaults to "mean". 
- - Returns: - Dictionary mapping parameter_name -> per-sample-gradient for that parameter - """ - torch.use_deterministic_algorithms(True) - torch.manual_seed(0) - np.random.seed(0) - - gs_module = wrap_model( - model=clone_module(module), - grad_sample_mode=grad_sample_mode, - batch_first=batch_first, - loss_reduction=loss_reduction, - ) - grad_sample_module = ModelWithLoss(gs_module, loss_reduction) - - grad_sample_module.zero_grad() - loss = grad_sample_module(x) - loss.backward() - - opacus_grad_samples = { - name: p.grad_sample - for name, p in trainable_parameters( - grad_sample_module.wrapped_module._module - ) - } - - return opacus_grad_samples - def run_test( self, x: Union[torch.Tensor, PackedSequence, Tuple], @@ -237,18 +53,17 @@ def run_test( chunk_method=iter, ): grad_sample_modes = ["hooks", "functorch"] - try: - import functorch # noqa - except ImportError: - grad_sample_modes = ["hooks"] if type(module) is nn.EmbeddingBag or ( type(x) is not PackedSequence and is_batch_empty(x) ): grad_sample_modes = ["hooks"] - for grad_sample_mode in grad_sample_modes: - for loss_reduction in ["sum", "mean"]: + if ew_compatible and batch_first and torch.__version__ >= (1, 13): + grad_sample_modes += ["ew"] + + for loss_reduction in ["sum", "mean"]: + for grad_sample_mode in grad_sample_modes: with self.subTest( grad_sample_mode=grad_sample_mode, loss_reduction=loss_reduction ): @@ -262,17 +77,6 @@ def run_test( grad_sample_mode=grad_sample_mode, chunk_method=chunk_method, ) - if ew_compatible and batch_first and torch.__version__ >= (1, 13): - self.run_test_with_reduction( - x, - module, - batch_first=batch_first, - loss_reduction="sum", - atol=atol, - rtol=rtol, - grad_sample_mode="ew", - chunk_method=chunk_method, - ) def run_test_with_reduction( self, @@ -285,40 +89,27 @@ def run_test_with_reduction( grad_sample_mode="hooks", chunk_method=iter, ): - opacus_grad_samples = self.compute_opacus_grad_sample( - x, - module, - batch_first=batch_first, - loss_reduction=loss_reduction, - grad_sample_mode=grad_sample_mode, - ) - - if type(x) is PackedSequence: - x_unpacked = _unpack_packedsequences(x) - microbatch_grad_samples = self.compute_microbatch_grad_sample( - x_unpacked, - module, - batch_first=batch_first, - loss_reduction=loss_reduction, - ) - elif not is_batch_empty(x): - microbatch_grad_samples = self.compute_microbatch_grad_sample( + if not type(x) is PackedSequence and is_batch_empty(x): + _ = compute_opacus_grad_sample( x, module, batch_first=batch_first, loss_reduction=loss_reduction, - chunk_method=chunk_method, + grad_sample_mode=grad_sample_mode, ) - else: # We've checked opacus can handle 0-sized batch. Microbatch doesn't make sense return - - if microbatch_grad_samples.keys() != opacus_grad_samples.keys(): - raise ValueError( - "Keys not matching! 
" - f"Keys only in microbatch: {microbatch_grad_samples.keys() - opacus_grad_samples.keys()}; " - f"Keys only in Opacus: {opacus_grad_samples.keys() - microbatch_grad_samples.keys()}" - ) + ( + microbatch_grad_samples, + opacus_grad_samples, + ) = compute_grad_samples_microbatch_and_opacus( + x, + module, + batch_first=batch_first, + loss_reduction=loss_reduction, + grad_sample_mode=grad_sample_mode, + chunk_method=chunk_method, + ) self.check_shapes(microbatch_grad_samples, opacus_grad_samples, loss_reduction) self.check_values( @@ -388,59 +179,3 @@ def check_values( f"A total of {len(failed)} values do not match " f"for loss_reduction={loss_reduction}: \n\t{failed_str}" ) - - -def _unpack_packedsequences(X: PackedSequence) -> List[torch.Tensor]: - r""" - Produces a list of tensors from X (PackedSequence) such that this list was used to create X with batch_first=True - - Args: - X: A PackedSequence from which the output list of tensors will be produced. - - Returns: - unpacked_data: The list of tensors produced from X. - """ - - X_padded = pad_packed_sequence(X) - X_padded = X_padded[0].permute((1, 0, 2)) - - if X.sorted_indices is not None: - X_padded = X_padded[X.sorted_indices] - - seq_lens = compute_seq_lengths(X.batch_sizes) - unpacked_data = [0] * len(seq_lens) - for idx, length in enumerate(seq_lens): - unpacked_data[idx] = X_padded[idx][:length, :] - - return unpacked_data - - -def _compute_loss_packedsequences( - criterion: nn.L1Loss, x: PackedSequence -) -> torch.Tensor: - r""" - This function computes the loss in a different way for 'mean' reduced L1 loss while for 'sum' reduced L1 loss, - it computes the same way as with non-packed data. For 'mean' reduced L1 loss, it transforms x (PackedSequence) - into a list of tensors such that this list of tensors was used to create this PackedSequence in the first - place using batch_first=True and then takes the mean of the loss values produced from applying criterion on - each sequence sample. - - Args: - criterion: An L1 loss function with reduction either set to 'sum' or 'mean'. - x: Data in the form of a PackedSequence. - - Returns: - A loss variable, reduced either using summation or averaging from L1 errors. 
- """ - - if criterion.reduction == "sum": - y = torch.zeros_like(x[0]) - return criterion(x[0], y) - elif criterion.reduction == "mean": - x = _unpack_packedsequences(x) - loss_sum = 0 - for x_i in x: - y_i = torch.zeros_like(x_i) - loss_sum += criterion(x_i, y_i) - loss_mean = loss_sum / len(x) - return loss_mean diff --git a/opacus/tests/grad_samples/dp_multihead_attention_test.py b/opacus/tests/grad_samples/dp_multihead_attention_test.py index aef5abab..dc1270b6 100644 --- a/opacus/tests/grad_samples/dp_multihead_attention_test.py +++ b/opacus/tests/grad_samples/dp_multihead_attention_test.py @@ -53,6 +53,7 @@ class MultiHeadAttention_test(GradSampleHooks_test): add_bias_kv=st.booleans(), add_zero_attn=st.booleans(), kv_dim=st.booleans(), + test_or_check=st.integers(1, 2), ) @settings(deadline=10000) def test_multihead_attention( @@ -65,6 +66,7 @@ def test_multihead_attention( add_bias_kv: bool, add_zero_attn: bool, kv_dim: bool, + test_or_check: int, ): if kv_dim: kdim, vdim = D, D diff --git a/opacus/tests/grad_samples/dp_rnn_test.py b/opacus/tests/grad_samples/dp_rnn_test.py index b6f64bb0..c72f8be7 100644 --- a/opacus/tests/grad_samples/dp_rnn_test.py +++ b/opacus/tests/grad_samples/dp_rnn_test.py @@ -93,4 +93,5 @@ def test_rnn( x = torch.randn([N, T, D]) else: x = torch.randn([T, N, D]) + self.run_test(x, rnn, batch_first=batch_first, ew_compatible=False) diff --git a/opacus/tests/grad_samples/instance_norm1d_test.py b/opacus/tests/grad_samples/instance_norm1d_test.py index ec7aea80..e2e284f4 100644 --- a/opacus/tests/grad_samples/instance_norm1d_test.py +++ b/opacus/tests/grad_samples/instance_norm1d_test.py @@ -22,18 +22,11 @@ class InstanceNorm1d_test(GradSampleHooks_test): - @given( - N=st.integers(1, 4), - C=st.integers(1, 3), - W=st.integers(5, 10), - ) + @given(N=st.integers(1, 4), C=st.integers(1, 3), W=st.integers(5, 10)) @settings(deadline=10000) - def test_3d_input( - self, - N: int, - C: int, - W: int, - ): + def test_3d_input(self, N: int, C: int, W: int): + x = torch.randn([N, C, W]) norm = nn.InstanceNorm1d(num_features=C, affine=True, track_running_stats=False) + self.run_test(x, norm, batch_first=True) diff --git a/opacus/tests/grad_samples/instance_norm2d_test.py b/opacus/tests/grad_samples/instance_norm2d_test.py index 278832bc..81fae806 100644 --- a/opacus/tests/grad_samples/instance_norm2d_test.py +++ b/opacus/tests/grad_samples/instance_norm2d_test.py @@ -29,13 +29,8 @@ class InstanceNorm2d_test(GradSampleHooks_test): H=st.integers(4, 8), ) @settings(deadline=10000) - def test_4d_input( - self, - N: int, - C: int, - W: int, - H: int, - ): + def test_4d_input(self, N: int, C: int, W: int, H: int): + x = torch.randn([N, C, H, W]) norm = nn.InstanceNorm2d(num_features=C, affine=True, track_running_stats=False) self.run_test(x, norm, batch_first=True) diff --git a/opacus/tests/grad_samples/instance_norm3d_test.py b/opacus/tests/grad_samples/instance_norm3d_test.py index 1b3b3de3..d46050ec 100644 --- a/opacus/tests/grad_samples/instance_norm3d_test.py +++ b/opacus/tests/grad_samples/instance_norm3d_test.py @@ -30,14 +30,7 @@ class InstanceNorm3d_test(GradSampleHooks_test): Z=st.integers(1, 4), ) @settings(deadline=10000) - def test_5d_input( - self, - N: int, - C: int, - W: int, - H: int, - Z: int, - ): + def test_5d_input(self, N: int, C: int, W: int, H: int, Z: int): x = torch.randn([N, C, Z, H, W]) norm = nn.InstanceNorm3d(num_features=C, affine=True, track_running_stats=False) self.run_test(x, norm, batch_first=True) diff --git 
a/opacus/tests/grad_samples/layer_norm_test.py b/opacus/tests/grad_samples/layer_norm_test.py index cb7190de..065de61c 100644 --- a/opacus/tests/grad_samples/layer_norm_test.py +++ b/opacus/tests/grad_samples/layer_norm_test.py @@ -32,16 +32,20 @@ class LayerNorm_test(GradSampleHooks_test): ) @settings(deadline=10000) def test_input_norm( - self, - N: int, - Z: int, - W: int, - H: int, - input_dim: int, - norm_dim: int, + self, N: int, Z: int, W: int, H: int, input_dim: int, norm_dim: int ): if norm_dim >= input_dim: return + normalized_shape, x_shape = self.get_x_shape_and_norm_shape( + H, N, W, Z, input_dim, norm_dim + ) + + norm = nn.LayerNorm(normalized_shape, elementwise_affine=True) + x = torch.randn(x_shape) + self.run_test(x, norm, batch_first=True) + + @staticmethod + def get_x_shape_and_norm_shape(H, N, W, Z, input_dim, norm_dim): if norm_dim == 1: normalized_shape = W if input_dim == 2: @@ -60,7 +64,4 @@ def test_input_norm( elif norm_dim == 3: normalized_shape = [Z, H, W] x_shape = [N, Z, H, W] - - norm = nn.LayerNorm(normalized_shape, elementwise_affine=True) - x = torch.randn(x_shape) - self.run_test(x, norm, batch_first=True) + return normalized_shape, x_shape diff --git a/opacus/tests/grad_samples/sequence_bias_test.py b/opacus/tests/grad_samples/sequence_bias_test.py index de346951..9c069eac 100644 --- a/opacus/tests/grad_samples/sequence_bias_test.py +++ b/opacus/tests/grad_samples/sequence_bias_test.py @@ -29,13 +29,8 @@ class SequenceBias_test(GradSampleHooks_test): batch_first=st.booleans(), ) @settings(deadline=10000) - def test_batch_second( - self, - N: int, - T: int, - D: int, - batch_first: bool, - ): + def test_batch_second(self, N: int, T: int, D: int, batch_first: bool): + seqbias = SequenceBias(D, batch_first) if batch_first: x = torch.randn([N, T, D]) diff --git a/opacus/tests/per_sample_gradients_utils_test.py b/opacus/tests/per_sample_gradients_utils_test.py new file mode 100644 index 00000000..074bc9fb --- /dev/null +++ b/opacus/tests/per_sample_gradients_utils_test.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
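+
+# Sanity tests for opacus.utils.per_sample_gradients_utils: they run
+# check_per_sample_gradients_are_correct on representative Conv1d and Linear
+# modules across the grad sample modes returned by get_grad_sample_modes.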
+ +import unittest +from typing import Callable + +import hypothesis.strategies as st +import torch +from hypothesis import given, settings + +from .grad_samples.common import expander, shrinker +from opacus.utils.per_sample_gradients_utils import ( + check_per_sample_gradients_are_correct, + get_grad_sample_modes, +) +from torch import nn + + +class PerSampleGradientsUtilsTest(unittest.TestCase): + def per_sample_grads_utils_test( + self, + x, + model, + grad_sample_mode, + is_empty=False, + atol=10e-5, + rtol=10e-4, + ): + if is_empty: + with self.assertRaises(RuntimeError): + check_per_sample_gradients_are_correct( + x, + model, + batch_first=True, + atol=atol, + rtol=rtol, + grad_sample_mode=grad_sample_mode, + ) + return + + assert check_per_sample_gradients_are_correct( + x, + model, + batch_first=True, + atol=atol, + rtol=rtol, + grad_sample_mode=grad_sample_mode, + ) + + @given( + N=st.integers(0, 4), + C=st.sampled_from([1, 3, 32]), + W=st.integers(6, 10), + out_channels_mapper=st.sampled_from([expander, shrinker]), + kernel_size=st.integers(2, 3), + stride=st.integers(1, 2), + padding=st.sampled_from([0, 1, 2, "same", "valid"]), + dilation=st.integers(1, 2), + groups=st.integers(1, 12), + grad_sample_mode=st.sampled_from(get_grad_sample_modes(use_ew=True)), + ) + @settings(deadline=10000) + def test_conv1d( + self, + N: int, + C: int, + W: int, + out_channels_mapper: Callable[[int], int], + kernel_size: int, + stride: int, + padding: int, + dilation: int, + groups: int, + grad_sample_mode: str, + ): + if padding == "same" and stride != 1: + return + out_channels = out_channels_mapper(C) + if ( + C % groups != 0 or out_channels % groups != 0 + ): # since in_channels and out_channels must be divisible by groups + return + + x = torch.randn([N, C, W]) + conv = nn.Conv1d( + in_channels=C, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + ew_compatible = N > 0 + + if not ew_compatible and grad_sample_mode == "ew": + return + + self.per_sample_grads_utils_test(x, conv, grad_sample_mode, N == 0) + + @given( + N=st.integers(0, 4), + Z=st.integers(1, 4), + H=st.integers(1, 3), + W=st.integers(10, 17), + input_dim=st.integers(2, 4), + bias=st.booleans(), + batch_first=st.booleans(), + grad_sample_mode=st.sampled_from(get_grad_sample_modes(use_ew=True)), + ) + @settings(deadline=10000) + def test_linear( + self, + N: int, + Z: int, + H: int, + W: int, + input_dim: int, + bias: bool, + batch_first: bool, + grad_sample_mode: str, + ): + + if input_dim == 2: + if not batch_first: + return # see https://github.com/pytorch/opacus/pull/265 + else: + x_shape = [N, W] + if input_dim == 3: + x_shape = [N, Z, W] + if input_dim == 4: + x_shape = [N, Z, H, W] + + linear = nn.Linear(W, W + 2, bias=bias) + x = torch.randn(x_shape) + if not batch_first: + x = x.transpose(0, 1) + ew_compatible = N > 0 + + if not ew_compatible and grad_sample_mode == "ew": + return + + self.per_sample_grads_utils_test(x, linear, grad_sample_mode, N == 0) diff --git a/opacus/utils/per_sample_gradients_utils.py b/opacus/utils/per_sample_gradients_utils.py new file mode 100644 index 00000000..754d86d6 --- /dev/null +++ b/opacus/utils/per_sample_gradients_utils.py @@ -0,0 +1,429 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +from typing import Callable, Dict, Iterable, List, Union + +import numpy as np +import torch +import torch.nn as nn +from opacus.grad_sample import wrap_model +from opacus.utils.module_utils import trainable_parameters +from opacus.utils.packed_sequences import compute_seq_lengths +from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence + + +def clone_module(module: nn.Module) -> nn.Module: + """ + Handy utility to clone an nn.Module. PyTorch doesn't always support copy.deepcopy(), so it is + just easier to serialize the model to a BytesIO and read it from there. + + Args: + module: The module to clone + + Returns: + The clone of ``module`` + """ + with io.BytesIO() as bytesio: + torch.save(module, bytesio) + bytesio.seek(0) + module_copy = torch.load(bytesio) + return module_copy + + +def is_batch_empty(batch: Union[torch.Tensor, Iterable[torch.Tensor]]): + if type(batch) is torch.Tensor: + return batch.numel() == 0 + else: + return batch[0].numel() == 0 + + +class ModelWithLoss(nn.Module): + """ + To test the gradients of a module, we need to have a loss. + This module makes it easy to get a loss from any nn.Module, and automatically generates + a target y vector for it in the forward (of all zeros of the correct size). + This reduces boilerplate while testing. + """ + + supported_reductions = ["mean", "sum"] + + def __init__(self, module: nn.Module, loss_reduction: str = "mean"): + """ + Instantiates this module. + + Args: + module: The nn.Module you want to test. + loss_reduction: What reduction to apply to the loss. Defaults to "mean". + + Raises: + ValueError: If ``loss_reduction`` is not among those supported. + """ + super().__init__() + self.wrapped_module = module + + if loss_reduction not in self.supported_reductions: + raise ValueError( + f"Passed loss_reduction={loss_reduction}. Only {self.supported_reductions} supported." + ) + self.criterion = nn.L1Loss(reduction=loss_reduction) + + def forward(self, x): + if type(x) is tuple: + x = self.wrapped_module(*x) + else: + x = self.wrapped_module(x) + if type(x) is PackedSequence: + loss = _compute_loss_packedsequences(self.criterion, x) + else: + y = torch.zeros_like(x) + loss = self.criterion(x, y) + return loss + + +def compute_microbatch_grad_sample( + x: Union[torch.Tensor, List[torch.Tensor]], + module: nn.Module, + batch_first: bool = True, + loss_reduction: str = "mean", + chunk_method: Callable = iter, +) -> Dict[str, torch.tensor]: + """ + Computes per-sample gradients with the microbatch method, i.e. by computing normal gradients + with batch_size set to 1, and manually accumulating them. This is our reference for testing + as this method is obviously correct, but slow. + + Args: + x: Sample input batch + module: The nn.Module you want to test. + batch_first: Whether batch size is the first dimension (as opposed to the second). + Defaults to True. + loss_reduction: What reduction to apply to the loss. Defaults to "mean". + chunk_method: The method to use to split the batch into microbatches. Defaults to ``iter``. 
+ + Returns: + Dictionary mapping parameter_name -> per-sample-gradient for that parameter + """ + torch.use_deterministic_algorithms(True) + torch.manual_seed(0) + np.random.seed(0) + + module = ModelWithLoss(clone_module(module), loss_reduction) + + for _, p in trainable_parameters(module): + p.microbatch_grad_sample = [] + + if not batch_first and type(x) is not list: + # This allows us to iterate with x_i + x = x.transpose(0, 1) + + # Invariant: x is [B, T, ...] + + for x_i in chunk_method(x): + # x_i is [T, ...] + module.zero_grad() + if type(x_i) is not tuple: + # EmbeddingBag provides tuples + x_i = x_i.unsqueeze( + 0 if batch_first else 1 + ) # x_i of size [1, T, ...] if batch_first, else [T, 1, ...] + loss_i = module(x_i) + loss_i.backward() + for p in module.parameters(): + p.microbatch_grad_sample.append(p.grad.detach().clone()) + + for _, p in trainable_parameters(module): + if batch_first: + p.microbatch_grad_sample = torch.stack( + p.microbatch_grad_sample, dim=0 # [B, T, ...] + ) + else: + p.microbatch_grad_sample = torch.stack( + p.microbatch_grad_sample, dim=1 # [T, B, ...] + ).transpose( + 0, 1 + ) # Opacus's semantics is that grad_samples are ALWAYS batch_first: [B, T, ...] + + microbatch_grad_samples = { + name: p.microbatch_grad_sample + for name, p in trainable_parameters(module.wrapped_module) + } + return microbatch_grad_samples + + +def compute_opacus_grad_sample( + x: Union[torch.Tensor, PackedSequence], + module: nn.Module, + batch_first: bool = True, + loss_reduction: str = "mean", + grad_sample_mode: str = "hooks", +) -> Dict[str, torch.tensor]: + """ + Runs Opacus to compute per-sample gradients and return them for testing purposes. + + Args: + x: Sample input batch + module: The nn.Module you want to test. + batch_first: Whether batch size is the first dimension (as opposed to the second). + Defaults to True. + loss_reduction: What reduction to apply to the loss. Defaults to "mean". + grad_sample_mode: What sampling method to use to get gradients. + + Returns: + Dictionary mapping parameter_name -> per-sample-gradient for that parameter + """ + torch.use_deterministic_algorithms(True) + torch.manual_seed(0) + np.random.seed(0) + + gs_module = wrap_model( + model=clone_module(module), + grad_sample_mode=grad_sample_mode, + batch_first=batch_first, + loss_reduction=loss_reduction, + ) + grad_sample_module = ModelWithLoss(gs_module, loss_reduction) + + grad_sample_module.zero_grad() + loss = grad_sample_module(x) + loss.backward() + + opacus_grad_samples = { + name: p.grad_sample + for name, p in trainable_parameters(grad_sample_module.wrapped_module._module) + } + + return opacus_grad_samples + + +def check_torch_version_for_ew_sample() -> bool: + return torch.__version__ >= (1, 13) + + +def get_grad_sample_modes(use_ew: bool = False): + grad_sample_modes = ["hooks", "functorch"] + if use_ew and check_torch_version_for_ew_sample(): + grad_sample_modes.append("ew") + return grad_sample_modes + + +def check_per_sample_gradients_are_correct( + x: Union[torch.Tensor, PackedSequence], + module: nn.Module, + *, + batch_first: bool = True, + atol: float = 10e-6, + rtol: float = 10e-5, + grad_sample_mode: str = "hooks", +) -> bool: + """ + A utility to check whether per sample gradients are computed correctly with a particular model. + The check is performed by comparing the result of the slow but reliable micro-batch method `compute_microbatch_grad_sample` + with the result of optimized opacus method. 
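+    Both "sum" and "mean" loss reductions are checked; the check passes only if, for
+    every parameter, the per-sample gradients match the micro-batch reference within
+    the given tolerances.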
+
+    Args:
+        x: Sample input batch
+        module: The nn.Module you want to check.
+        batch_first: Whether batch size is the first dimension (as opposed to the second).
+            Defaults to True.
+        atol: The absolute tolerance parameter (torch.allclose).
+        rtol: The relative tolerance parameter (torch.allclose).
+        grad_sample_mode: What sampling method to use to get gradients.
+
+    Returns: True if per sample gradients were computed correctly. False otherwise.
+
+    Example:
+        >>> N, Z, W = 100, 10, 10
+        >>> x_shape = [N, Z, W]
+        >>> x = torch.randn(x_shape)
+        >>> model = nn.Linear(W, W + 2)
+        >>> assert check_per_sample_gradients_are_correct(
+        ...     x,
+        ...     model
+        ... )  # This will fail only if the Opacus per-sample gradients do not match the micro-batch gradients.
+    """
+    reductions = ["sum", "mean"]
+    if grad_sample_mode == "ew":
+        if not batch_first:
+            raise RuntimeError("Batch should be first dimension.")
+        if not check_torch_version_for_ew_sample():
+            raise RuntimeError(f"Unsupported torch version: {torch.__version__}.")
+
+    for loss_reduction in reductions:
+        if not _check_per_sample_gradients_are_correct_with_reduction(
+            x,
+            module,
+            batch_first=batch_first,
+            loss_reduction=loss_reduction,
+            atol=atol,
+            rtol=rtol,
+            grad_sample_mode=grad_sample_mode,
+        ):
+            return False
+
+    return True
+
+
+def compute_microbatch_grad_sample_tensor_or_seq(
+    x: Union[torch.Tensor, PackedSequence],
+    module: nn.Module,
+    batch_first: bool = True,
+    loss_reduction: str = "mean",
+):
+    if type(x) is PackedSequence:
+        x_unpacked = unpack_packedsequences(x)
+        microbatch_grad_samples = compute_microbatch_grad_sample(
+            x_unpacked,
+            module,
+            batch_first=batch_first,
+            loss_reduction=loss_reduction,
+        )
+    else:
+        microbatch_grad_samples = compute_microbatch_grad_sample(
+            x, module, batch_first=batch_first, loss_reduction=loss_reduction
+        )
+
+    return microbatch_grad_samples
+
+
+def compute_grad_samples_microbatch_and_opacus(
+    x: Union[torch.Tensor, PackedSequence],
+    module: nn.Module,
+    batch_first: bool = True,
+    loss_reduction: str = "mean",
+    grad_sample_mode: str = "hooks",
+    chunk_method: Callable = iter,
+):
+    if type(x) is PackedSequence:
+        x_unpacked = unpack_packedsequences(x)
+        microbatch_grad_samples = compute_microbatch_grad_sample(
+            x_unpacked,
+            module,
+            batch_first=batch_first,
+            loss_reduction=loss_reduction,
+            chunk_method=chunk_method,
+        )
+    elif not is_batch_empty(x):
+        microbatch_grad_samples = compute_microbatch_grad_sample(
+            x,
+            module,
+            batch_first=batch_first,
+            loss_reduction=loss_reduction,
+            chunk_method=chunk_method,
+        )
+    else:
+        raise RuntimeError("x is expected to be non-empty.")
+
+    opacus_grad_samples = compute_opacus_grad_sample(
+        x,
+        module,
+        batch_first=batch_first,
+        loss_reduction=loss_reduction,
+        grad_sample_mode=grad_sample_mode,
+    )
+
+    if microbatch_grad_samples.keys() != opacus_grad_samples.keys():
+        raise ValueError(
+            "Keys not matching! "
+            f"Keys only in microbatch: {microbatch_grad_samples.keys() - opacus_grad_samples.keys()}; "
+            f"Keys only in Opacus: {opacus_grad_samples.keys() - microbatch_grad_samples.keys()}"
+        )
+
+    return microbatch_grad_samples, opacus_grad_samples
+
+
+def _check_per_sample_gradients_are_correct_with_reduction(
+    x: Union[torch.Tensor, PackedSequence],
+    module: nn.Module,
+    batch_first: bool = True,
+    loss_reduction: str = "mean",
+    atol: float = 10e-6,
+    rtol: float = 10e-5,
+    grad_sample_mode: str = "hooks",
+) -> bool:
+    (
+        microbatch_grad_samples,
+        opacus_grad_samples,
+    ) = compute_grad_samples_microbatch_and_opacus(
+        x,
+        module,
+        batch_first=batch_first,
+        loss_reduction=loss_reduction,
+        grad_sample_mode=grad_sample_mode,
+    )
+
+    for name, opacus_grad_sample in opacus_grad_samples.items():
+        microbatch_grad_sample = microbatch_grad_samples[name]
+        if not opacus_grad_sample.shape == microbatch_grad_sample.shape:
+            return False
+        if not torch.allclose(
+            microbatch_grad_sample, opacus_grad_sample, atol=atol, rtol=rtol
+        ):
+            return False
+    return True
+
+
+def unpack_packedsequences(X: PackedSequence) -> List[torch.Tensor]:
+    r"""
+    Produces a list of tensors from X (PackedSequence) such that this list was used to create X with batch_first=True
+
+    Args:
+        X: A PackedSequence from which the output list of tensors will be produced.
+
+    Returns:
+        unpacked_data: The list of tensors produced from X.
+    """
+
+    X_padded = pad_packed_sequence(X)
+    X_padded = X_padded[0].permute((1, 0, 2))
+
+    if X.sorted_indices is not None:
+        X_padded = X_padded[X.sorted_indices]
+
+    seq_lens = compute_seq_lengths(X.batch_sizes)
+    unpacked_data = [0] * len(seq_lens)
+    for idx, length in enumerate(seq_lens):
+        unpacked_data[idx] = X_padded[idx][:length, :]
+
+    return unpacked_data
+
+
+def _compute_loss_packedsequences(
+    criterion: nn.L1Loss, x: PackedSequence
+) -> torch.Tensor:
+    r"""
+    For 'sum'-reduced L1 loss, this function computes the loss the same way as with non-packed data.
+    For 'mean'-reduced L1 loss, it first unpacks x (PackedSequence) into the list of tensors that was
+    used to create it (with batch_first=True) and then averages the loss values obtained by applying
+    the criterion to each sequence sample.
+
+    Args:
+        criterion: An L1 loss function with reduction either set to 'sum' or 'mean'.
+        x: Data in the form of a PackedSequence.
+
+    Returns:
+        A loss variable, reduced either by summation or by averaging of the L1 errors.
+    """
+
+    if criterion.reduction == "sum":
+        y = torch.zeros_like(x[0])
+        return criterion(x[0], y)
+    elif criterion.reduction == "mean":
+        x = unpack_packedsequences(x)
+        loss_sum = 0
+        for x_i in x:
+            y_i = torch.zeros_like(x_i)
+            loss_sum += criterion(x_i, y_i)
+        loss_mean = loss_sum / len(x)
+        return loss_mean
diff --git a/tutorials/guide_to_grad_sampler.ipynb b/tutorials/guide_to_grad_sampler.ipynb
index 94f95f84..0076cd74 100644
--- a/tutorials/guide_to_grad_sampler.ipynb
+++ b/tutorials/guide_to_grad_sampler.ipynb
@@ -246,6 +246,33 @@
     "\n",
     "If you have any questions or comments, please don't hesitate to post them on our [forum](https://discuss.pytorch.org/c/opacus/29)."
] + }, + { + "cell_type": "markdown", + "source": [ + "### Per-sample-gradients correctness utility\n", + "[Here](https://github.com/pytorch/opacus/blob/main/opacus/utils/per_sample_gradients_utils.py) you can find a simple utility function `check_per_sample_gradients_are_correct` that checks if the gradient sampler works correctly with a particular module." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "x_shape = [N, Z, W]\n", + "x = torch.randn(x_shape)\n", + "model = nn.Linear(W, W + 2)\n", + "assert check_per_sample_gradients_are_correct(\n", + " x,\n", + " model\n", + " ) # This will fail only if the opacus per sample gradients do not match the micro-batch gradients." + ], + "metadata": { + "collapsed": false + } } ], "metadata": {