Per sample grad correctness util (#532)
Summary:
## Types of changes

- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [x] Docs change / refactoring / dependency upgrade

## Motivation and Context / Related issue

Implementation of the utility described in #484.

Refactored the code to avoid duplication.

## How Has This Been Tested (if it applies)

Added the new utility as a test case alongside the existing tests in `tests.grad_samples`.
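
For illustration, a minimal sketch of driving the comparison utility directly (the call signature mirrors its usage in `common.py` below; the model, input, and tolerances here are hypothetical):

```python
import torch
import torch.nn as nn
from opacus.utils.per_sample_gradients_utils import (
    compute_grad_samples_microbatch_and_opacus,
)

# Hypothetical module and batch; any supported nn.Module works the same way.
module = nn.Linear(16, 4)
x = torch.randn(24, 16)  # [B, ...] with batch_first=True

# Per the usage in common.py, this returns two dicts mapping parameter name
# -> per-sample gradients [B, ...]: a slow-but-trusted microbatch reference
# and Opacus's vectorized result.
microbatch_gs, opacus_gs = compute_grad_samples_microbatch_and_opacus(
    x,
    module,
    batch_first=True,
    loss_reduction="mean",
    grad_sample_mode="hooks",
)
for name, reference in microbatch_gs.items():
    torch.testing.assert_close(reference, opacus_gs[name], atol=1e-6, rtol=1e-4)
```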

## Checklist

- [x] The documentation is up-to-date with the changes I made.
- [x] I have read the **CONTRIBUTING** document and completed the CLA (see **CONTRIBUTING**).
- [x] All tests passed, and additional code has been covered with new tests.

Pull Request resolved: #532

Reviewed By: karthikprasad

Differential Revision: D40797432

Pulled By: ffuuugor

fbshipit-source-id: 923009d6f7f6d4c34bce9f4af39945fdf9ff9d57
psolikov authored and facebook-github-bot committed Feb 10, 2023
1 parent 1e661e8 commit a0a31ba
Showing 11 changed files with 661 additions and 335 deletions.
317 changes: 26 additions & 291 deletions opacus/tests/grad_samples/common.py
@@ -13,18 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import unittest
from typing import Dict, Iterable, List, Tuple, Union
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from opacus.grad_sample import wrap_model
from opacus.utils.module_utils import trainable_parameters
from opacus.utils.packed_sequences import compute_seq_lengths
from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence
from opacus.utils.per_sample_gradients_utils import (
compute_grad_samples_microbatch_and_opacus,
compute_opacus_grad_sample,
is_batch_empty,
)
from torch.nn.utils.rnn import PackedSequence
from torch.testing import assert_close


@@ -36,196 +36,12 @@ def shrinker(x, factor: int = 2):
return max(1, x // factor)  # avoid returning 0 for x == 1


def is_batch_empty(batch: Union[torch.Tensor, Iterable[torch.Tensor]]):
if type(batch) is torch.Tensor:
return batch.numel() == 0
else:
return batch[0].numel() == 0


class ModelWithLoss(nn.Module):
"""
To test the gradients of a module, we need to have a loss.
This module makes it easy to get a loss from any nn.Module, and automatically generates
a target y vector for it in the forward (of all zeros of the correct size).
This reduces boilerplate while testing.
"""

supported_reductions = ["mean", "sum"]

def __init__(self, module: nn.Module, loss_reduction: str = "mean"):
"""
Instantiates this module.
Args:
module: The nn.Module you want to test.
loss_reduction: What reduction to apply to the loss. Defaults to "mean".
Raises:
ValueError: If ``loss_reduction`` is not among those supported.
"""
super().__init__()
self.wrapped_module = module

if loss_reduction not in self.supported_reductions:
raise ValueError(
f"Passed loss_reduction={loss_reduction}. Only {self.supported_reductions} supported."
)
self.criterion = nn.L1Loss(reduction=loss_reduction)

def forward(self, x):
if type(x) is tuple:
x = self.wrapped_module(*x)
else:
x = self.wrapped_module(x)
if type(x) is PackedSequence:
loss = _compute_loss_packedsequences(self.criterion, x)
else:
y = torch.zeros_like(x)
loss = self.criterion(x, y)
return loss


def clone_module(module: nn.Module) -> nn.Module:
"""
Handy utility to clone an nn.Module. PyTorch doesn't always support copy.deepcopy(), so it is
just easier to serialize the model to a BytesIO and read it from there.
Args:
module: The module to clone
Returns:
The clone of ``module``
"""
with io.BytesIO() as bytesio:
torch.save(module, bytesio)
bytesio.seek(0)
module_copy = torch.load(bytesio)
return module_copy


class GradSampleHooks_test(unittest.TestCase):
"""
Set of common testing utils. It is meant to be subclassed by your test.
See other tests as an example of how this is done.
"""

def compute_microbatch_grad_sample(
self,
x: Union[torch.Tensor, List[torch.Tensor]],
module: nn.Module,
batch_first=True,
loss_reduction="mean",
chunk_method=iter,
) -> Dict[str, torch.tensor]:
"""
Computes per-sample gradients with the microbatch method, i.e. by computing normal gradients
with batch_size set to 1, and manually accumulating them. This is our reference for testing
as this method is obviously correct, but slow.
Args:
x: The tensor in input to the ``module``
module: The ``ModelWithLoss`` that wraps the nn.Module you want to test.
batch_first: Whether batch size is the first dimension (as opposed to the second).
Defaults to True.
loss_reduction: What reduction to apply to the loss. Defaults to "mean".
chunk_method: The method to use to split the batch into microbatches. Defaults to ``iter``.
Returns:
Dictionary mapping parameter_name -> per-sample-gradient for that parameter
"""
torch.use_deterministic_algorithms(True)
torch.manual_seed(0)
np.random.seed(0)

module = ModelWithLoss(clone_module(module), loss_reduction)

for _, p in trainable_parameters(module):
p.microbatch_grad_sample = []

if not batch_first and type(x) is not list:
# This allows us to iterate with x_i
x = x.transpose(0, 1)

# Invariant: x is [B, T, ...]

for x_i in chunk_method(x):
# x_i is [T, ...]
module.zero_grad()
if type(x_i) is not tuple:
# EmbeddingBag provides tuples
x_i = x_i.unsqueeze(
0 if batch_first else 1
) # x_i of size [1, T, ...] if batch_first, else [T, 1, ...]
loss_i = module(x_i)
loss_i.backward()
for p in module.parameters():
p.microbatch_grad_sample.append(p.grad.detach().clone())

for _, p in trainable_parameters(module):
if batch_first:
p.microbatch_grad_sample = torch.stack(
p.microbatch_grad_sample, dim=0 # [B, T, ...]
)
else:
p.microbatch_grad_sample = torch.stack(
p.microbatch_grad_sample, dim=1 # [T, B, ...]
).transpose(
0, 1
) # Opacus's semantics is that grad_samples are ALWAYS batch_first: [B, T, ...]

microbatch_grad_samples = {
name: p.microbatch_grad_sample
for name, p in trainable_parameters(module.wrapped_module)
}
return microbatch_grad_samples

def compute_opacus_grad_sample(
self,
x: Union[torch.Tensor, PackedSequence],
module: nn.Module,
batch_first=True,
loss_reduction="mean",
grad_sample_mode="hooks",
) -> Dict[str, torch.tensor]:
"""
Runs Opacus to compute per-sample gradients and return them for testing purposes.
Args:
x: The tensor in input to the ``module``
module: The ``ModelWithLoss`` that wraps the nn.Module you want to test.
batch_first: Whether batch size is the first dimension (as opposed to the second).
Defaults to True.
loss_reduction: What reduction to apply to the loss. Defaults to "mean".
Returns:
Dictionary mapping parameter_name -> per-sample-gradient for that parameter
"""
torch.use_deterministic_algorithms(True)
torch.manual_seed(0)
np.random.seed(0)

gs_module = wrap_model(
model=clone_module(module),
grad_sample_mode=grad_sample_mode,
batch_first=batch_first,
loss_reduction=loss_reduction,
)
grad_sample_module = ModelWithLoss(gs_module, loss_reduction)

grad_sample_module.zero_grad()
loss = grad_sample_module(x)
loss.backward()

opacus_grad_samples = {
name: p.grad_sample
for name, p in trainable_parameters(
grad_sample_module.wrapped_module._module
)
}

return opacus_grad_samples

def run_test(
self,
x: Union[torch.Tensor, PackedSequence, Tuple],
@@ -237,18 +53,17 @@ def run_test(
chunk_method=iter,
):
grad_sample_modes = ["hooks", "functorch"]
try:
import functorch # noqa
except ImportError:
grad_sample_modes = ["hooks"]

if type(module) is nn.EmbeddingBag or (
type(x) is not PackedSequence and is_batch_empty(x)
):
grad_sample_modes = ["hooks"]

for grad_sample_mode in grad_sample_modes:
for loss_reduction in ["sum", "mean"]:
if ew_compatible and batch_first and torch.__version__ >= (1, 13):
grad_sample_modes += ["ew"]

for loss_reduction in ["sum", "mean"]:
for grad_sample_mode in grad_sample_modes:
with self.subTest(
grad_sample_mode=grad_sample_mode, loss_reduction=loss_reduction
):
@@ -262,17 +77,6 @@ def run_test(
grad_sample_mode=grad_sample_mode,
chunk_method=chunk_method,
)
if ew_compatible and batch_first and torch.__version__ >= (1, 13):
self.run_test_with_reduction(
x,
module,
batch_first=batch_first,
loss_reduction="sum",
atol=atol,
rtol=rtol,
grad_sample_mode="ew",
chunk_method=chunk_method,
)

def run_test_with_reduction(
self,
@@ -285,40 +89,27 @@ def run_test_with_reduction(
grad_sample_mode="hooks",
chunk_method=iter,
):
opacus_grad_samples = self.compute_opacus_grad_sample(
x,
module,
batch_first=batch_first,
loss_reduction=loss_reduction,
grad_sample_mode=grad_sample_mode,
)

if type(x) is PackedSequence:
x_unpacked = _unpack_packedsequences(x)
microbatch_grad_samples = self.compute_microbatch_grad_sample(
x_unpacked,
module,
batch_first=batch_first,
loss_reduction=loss_reduction,
)
elif not is_batch_empty(x):
microbatch_grad_samples = self.compute_microbatch_grad_sample(
if type(x) is not PackedSequence and is_batch_empty(x):
_ = compute_opacus_grad_sample(
x,
module,
batch_first=batch_first,
loss_reduction=loss_reduction,
chunk_method=chunk_method,
grad_sample_mode=grad_sample_mode,
)
else:
# We've checked that Opacus can handle a 0-sized batch; a microbatch reference doesn't make sense here
return

if microbatch_grad_samples.keys() != opacus_grad_samples.keys():
raise ValueError(
"Keys not matching! "
f"Keys only in microbatch: {microbatch_grad_samples.keys() - opacus_grad_samples.keys()}; "
f"Keys only in Opacus: {opacus_grad_samples.keys() - microbatch_grad_samples.keys()}"
)
(
microbatch_grad_samples,
opacus_grad_samples,
) = compute_grad_samples_microbatch_and_opacus(
x,
module,
batch_first=batch_first,
loss_reduction=loss_reduction,
grad_sample_mode=grad_sample_mode,
chunk_method=chunk_method,
)

self.check_shapes(microbatch_grad_samples, opacus_grad_samples, loss_reduction)
self.check_values(
@@ -388,59 +179,3 @@ def check_values(
f"A total of {len(failed)} values do not match "
f"for loss_reduction={loss_reduction}: \n\t{failed_str}"
)


def _unpack_packedsequences(X: PackedSequence) -> List[torch.Tensor]:
r"""
Produces the list of tensors that, with batch_first=True, would have been used to create X (a PackedSequence).
Args:
X: A PackedSequence from which the output list of tensors will be produced.
Returns:
unpacked_data: The list of tensors produced from X.
"""

X_padded = pad_packed_sequence(X)
X_padded = X_padded[0].permute((1, 0, 2))

if X.sorted_indices is not None:
X_padded = X_padded[X.sorted_indices]

seq_lens = compute_seq_lengths(X.batch_sizes)
unpacked_data = [0] * len(seq_lens)
for idx, length in enumerate(seq_lens):
unpacked_data[idx] = X_padded[idx][:length, :]

return unpacked_data


def _compute_loss_packedsequences(
criterion: nn.L1Loss, x: PackedSequence
) -> torch.Tensor:
r"""
For 'sum'-reduced L1 loss, this function computes the loss the same way as with non-packed data. For 'mean'-reduced
L1 loss, it first unpacks x (a PackedSequence) into the list of tensors that originally created it with
batch_first=True, applies the criterion to each sequence sample, and returns the mean of the resulting loss values.
Args:
criterion: An L1 loss function with reduction either set to 'sum' or 'mean'.
x: Data in the form of a PackedSequence.
Returns:
A loss variable, reduced either using summation or averaging from L1 errors.
"""

if criterion.reduction == "sum":
y = torch.zeros_like(x[0])
return criterion(x[0], y)
elif criterion.reduction == "mean":
x = _unpack_packedsequences(x)
loss_sum = 0
for x_i in x:
y_i = torch.zeros_like(x_i)
loss_sum += criterion(x_i, y_i)
loss_mean = loss_sum / len(x)
return loss_mean
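
For context, the microbatch reference method these removed helpers implemented (now relocated to `opacus.utils.per_sample_gradients_utils`) reduces to: zero the grads, forward/backward one sample at a time, record each parameter's `p.grad`, and stack along the batch dimension. A self-contained sketch of the idea, with an illustrative model and loss rather than the utility's actual internals:

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 4)
criterion = nn.L1Loss(reduction="mean")
x = torch.randn(8, 16)  # [B, ...], batch_first

per_sample = {name: [] for name, p in model.named_parameters() if p.requires_grad}
for x_i in x:  # one sample at a time
    model.zero_grad()
    out = model(x_i.unsqueeze(0))  # batch of size 1
    loss = criterion(out, torch.zeros_like(out))
    loss.backward()
    for name, p in model.named_parameters():
        if p.requires_grad:
            per_sample[name].append(p.grad.detach().clone())

# Stack to [B, *param_shape]: the slow-but-correct reference against which
# Opacus's vectorized per-sample gradients are checked.
microbatch_grad_samples = {
    name: torch.stack(grads, dim=0) for name, grads in per_sample.items()
}
```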
2 changes: 2 additions & 0 deletions opacus/tests/grad_samples/dp_multihead_attention_test.py
@@ -53,6 +53,7 @@ class MultiHeadAttention_test(GradSampleHooks_test):
add_bias_kv=st.booleans(),
add_zero_attn=st.booleans(),
kv_dim=st.booleans(),
test_or_check=st.integers(1, 2),
)
@settings(deadline=10000)
def test_multihead_attention(
@@ -65,6 +66,7 @@ def test_multihead_attention(
add_bias_kv: bool,
add_zero_attn: bool,
kv_dim: bool,
test_or_check: int,
):
if kv_dim:
kdim, vdim = D, D
