From e8bc932c1e6a6a1f8a4b1b00f0f7ef134b5edf4d Mon Sep 17 00:00:00 2001 From: Igor Shilov Date: Tue, 18 Apr 2023 09:12:18 -0700 Subject: [PATCH] Clean-up unnecessary warnings (including update to PyTorch 2.0) (#581) Summary: This PR is a collection of smaller fixes that will save us some deprecation issues in the future ## 1. Updating to PyTorch 2.0 **Key files: grad_sample/functorch.py, requirements.txt** `functorch` has been a part of core PyTorch since 1.13. Now they're going a step further and changing the API, while deprecating the old one. There's a [guide](https://pytorch.org/docs/master/func.migrating.html) on how to migrate. TL;DR - `make_functional` will no longer be part of the API, with `torch.func.functional_call()` being (non drop-in) replacement. They key difference for us is `make_functional()` creates a fresh copy of the module, while `functional_call()` uses existing module. As a matter of fact, we need the fresh copy (otherwise all the hooks start firing and you enter nested madness), so I've copy-pasted a [gist](https://gist.github.com/zou3519/7769506acc899d83ef1464e28f22e6cf) from the official guide on how to get a full replacement for `make_functional`. ## 2. New mechanism for gradient accumulation detection **Key file: privacy_engine.py, grad_sample_module.py** As [reported](https://discuss.pytorch.org/t/gan-raises-userwarning-using-a-non-full-backward-hook-when-the-forward-contains-multiple/175638/2) on the forum, clients are still getting "non-full backward hook" warning even when using `grad_sample_mode="ew"`. Naturally, `functorch` and `hooks` modes rely on backward hooks and can't be migrated to full hooks because [reasons](https://github.com/pytorch/opacus/issues/328#issuecomment-1027361968). However, `ew` doesn't rely on hooks and it's unclear why the message should appear. The reason, however, is simple. If the client is using poisson sampling we add an extra check to prohibit gradient accumulation (two poisson batches combined is not a poisson batch), and we do that by the means of backward hooks. ~In this case, backward hook serves a simple purpose and there shouldn't be any problems with migrating to the new method, however that involved changing the checking method. That's because `register_backward_hook` is called *after* hooks on submodule, but `register_full_backward_hook` is called before.~ Strikethrough solution didn't work, because hook order execution is weird for complex graphs, e.g. for GANs. For example, if your forward call looks like this: ``` Discriminator(Generator(x)) ``` then top-level module hook will precede submodule's hooks for `Generator`, but not for `Discriminator` As such, I've realised that gradient accumulation is not even supported in `ExpandedWeights`, so we don't have to worry about that. And the other two modes are both hooks-based, so we can just check the accumulation in the existing backward hook, no need for an extra hook. Deleted some code, profit. ## 3. Refactoring `wrap_collate_with_empty` to please pickle Now here're two facts I didn't know before 1) You can't pickle a nested function, e.g. you can't do the following ```python def foo(): def bar(): <...> return bar pickle.dump(foo(), ...) ``` 2) Whether or not `multiprocessing` uses pickle is python- and platform- dependant. This affects our tests when we test `DataLoader` with multiple workers. As such, our data loaders tests: * Pass on CircleCI with python3.9 * Fail on my local machine with python3.9 * Pass on my local machine with python3.7 I'm not sure how cow common the issue is, but it's safer to just refactor `wrap_collate_with_empty` to avoid nested functions. ## 4. Fix benchmark tests We don't really run `benchmarks/tests` on a regular basis, and some of them were broken since we've upgraded to PyTorch 1.13 (`API_CUTOFF_VERSION` doesn't exist anymore) ## 4. Fix flake8 config Flake8 config no [longer support](https://flake8.pycqa.org/en/latest/user/configuration.html) inline comments, fix is due Pull Request resolved: https://github.com/pytorch/opacus/pull/581 Reviewed By: alexandresablayrolles Differential Revision: D44749760 Pulled By: ffuuugor fbshipit-source-id: cf225f4134c049da4ee2eef53e1af3ef54d090bf --- .circleci/config.yml | 60 ++++---- .circleci/flake8_config.ini | 157 ++++++++++++++------- benchmarks/tests/test_layers.py | 29 +--- examples/cifar10.py | 4 +- opacus/data_loader.py | 53 +++++-- opacus/grad_sample/__init__.py | 2 + opacus/grad_sample/functorch.py | 42 +++++- opacus/grad_sample/grad_sample_module.py | 15 ++ opacus/grad_sample/gsm_base.py | 23 +++ opacus/privacy_engine.py | 49 +------ opacus/tests/grad_samples/common.py | 2 +- opacus/tests/gradient_accumulation_test.py | 16 ++- requirements.txt | 2 +- 13 files changed, 277 insertions(+), 177 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 46019409..d1544d00 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -6,16 +6,16 @@ version: 2.1 commands: - py_3_7_setup: - description: "Install and switch to Python 3.7.5; also install pip and pytest." + py_3_9_setup: + description: "Install and switch to Python 3.9; also install pip and pytest." steps: - run: - name: "Setup Python v3.7.5 environment" + name: "Setup Python v3.9 environment" command: | cd /opt/circleci/.pyenv && git pull && cd - - pyenv install -s 3.7.5 - pyenv global 3.7.5 - pyenv local 3.7.5 + pyenv install -s 3.9.4 + pyenv global 3.9.4 + pyenv local 3.9.4 pyenv versions echo "In venv: $(pyenv local) - $(python -V), $(pip -V)" sudo "$(which python)" -m pip install --upgrade pip @@ -283,9 +283,9 @@ commands: jobs: - lint_py37_torch_release: + lint_py39_torch_release: docker: - - image: cimg/python:3.7.5 + - image: cimg/python:3.9 steps: - checkout - pip_dev_install @@ -293,14 +293,6 @@ jobs: - lint_black - isort - unittest_py37_torch_release: - docker: - - image: cimg/python:3.7.5 - steps: - - checkout - - pip_dev_install - - unit_tests - unittest_py38_torch_release: docker: - image: cimg/python:3.8 @@ -328,10 +320,10 @@ jobs: prv_accountant_values: docker: - - image: cimg/python:3.7.5 + - image: cimg/python:3.9 steps: - checkout - - py_3_7_setup + - py_3_9_setup - pip_dev_install - run: name: "Unit test prv accountant" @@ -339,23 +331,23 @@ jobs: command: | python -m unittest opacus.tests.prv_accountant - integrationtest_py37_torch_release_cpu: + integrationtest_py39_torch_release_cpu: docker: - - image: cimg/python:3.7.5 + - image: cimg/python:3.9 steps: - checkout - - py_3_7_setup + - py_3_9_setup - pip_dev_install - mnist_integration_test: device: "cpu" - integrationtest_py37_torch_release_cuda: + integrationtest_py39_torch_release_cuda: machine: resource_class: gpu.nvidia.small.multi image: ubuntu-2004-cuda-11.4:202110-01 steps: - checkout - - py_3_7_setup + - py_3_9_setup - pip_dev_install - run_nvidia_smi - mnist_integration_test: @@ -369,13 +361,13 @@ jobs: - dcgan_integration_test: device: "cuda" - micro_benchmarks_py37_torch_release_cuda: + micro_benchmarks_py39_torch_release_cuda: machine: resource_class: gpu.nvidia.small.multi image: ubuntu-2004-cuda-11.4:202110-01 steps: - checkout - - py_3_7_setup + - py_3_9_setup - pip_dev_install - run_nvidia_smi - benchmark_layers_integration_test: @@ -459,7 +451,7 @@ jobs: image: ubuntu-2004-cuda-11.4:202110-01 steps: - checkout - - py_3_7_setup + - py_3_9_setup - pip_dev_install - run_nvidia_smi - command_unit_tests_multi_gpu @@ -493,9 +485,7 @@ workflows: not: equal: [ scheduled_pipeline, << pipeline.trigger_source >> ] jobs: - - lint_py37_torch_release: - filters: *exclude_ghpages - - unittest_py37_torch_release: + - lint_py39_torch_release: filters: *exclude_ghpages - unittest_py38_torch_release: filters: *exclude_ghpages @@ -505,9 +495,9 @@ workflows: filters: *exclude_ghpages - unittest_multi_gpu: filters: *exclude_ghpages - - integrationtest_py37_torch_release_cpu: + - integrationtest_py39_torch_release_cpu: filters: *exclude_ghpages - - integrationtest_py37_torch_release_cuda: + - integrationtest_py39_torch_release_cuda: filters: *exclude_ghpages - prv_accountant_values: filters: *exclude_ghpages @@ -518,12 +508,12 @@ workflows: jobs: - unittest_py39_torch_nightly: filters: *exclude_ghpages - - integrationtest_py37_torch_release_cpu: + - integrationtest_py39_torch_release_cpu: filters: *exclude_ghpages - - integrationtest_py37_torch_release_cuda: + - integrationtest_py39_torch_release_cuda: filters: *exclude_ghpages - - lint_py37_torch_release: + - lint_py39_torch_release: filters: *exclude_ghpages - - micro_benchmarks_py37_torch_release_cuda: + - micro_benchmarks_py39_torch_release_cuda: filters: *exclude_ghpages diff --git a/.circleci/flake8_config.ini b/.circleci/flake8_config.ini index 896a3ef3..988ed7f3 100644 --- a/.circleci/flake8_config.ini +++ b/.circleci/flake8_config.ini @@ -5,60 +5,113 @@ max-line-length = 80 ignore = # Black conflicts and overlaps. # Found in https://github.com/psf/black/issues/429 - B950, # Line too long. (Use `arc lint`'s LINEWRAP instead) - E111, # Indentation is not a multiple of four. - E115, # Expected an indented block (comment). - E117, # Over-indented. - E121, # Continuation line under-indented for hanging indent. - E122, # Continuation line missing indentation or outdented. - E123, # Closing bracket does not match indentation of opening bracket's line. - E124, # Closing bracket does not match visual indentation. - E125, # Continuation line with same indent as next logical line. - E126, # Continuation line over-indented for hanging indent. - E127, # Continuation line over-indented for visual indent. - E128, # Continuation line under-indented for visual indent. - E129, # Visually indented line with same indent as next logical line. - E201, # Whitespace after '('. - E202, # Whitespace before ')'. - E203, # Whitespace before ':'. - E221, # Multiple spaces before operator. - E222, # Multiple spaces after operator. - E225, # Missing whitespace around operator. - E226, # Missing whitespace around arithmetic operator. - E227, # Missing whitespace around bitwise or shift operator. - E231, # Missing whitespace after ',', ';', or ':'. - E241, # Multiple spaces after ','. - E251, # Unexpected spaces around keyword / parameter equals. - E261, # At least two spaces before inline comment. - E262, # Inline comment should start with '# '. - E265, # Block comment should start with '# '. - E271, # Multiple spaces after keyword. - E272, # Multiple spaces before keyword. - E301, # Expected 1 blank line, found 0. - E302, # Expected 2 blank lines, found 0. - E303, # Too many blank lines (3). - E305, # Expected 2 blank lines after end of function or class. - E306, # Expected 1 blank line before a nested definition. - E501, # Line too long (82 > 79 characters). - E502, # The backslash is redundant between brackets. - E701, # Multiple statements on one line (colon). - E702, # Multiple statements on one line (semicolon). - E703, # Statement ends with a semicolon. - E704, # Multiple statements on one line (def). - W291, # Trailing whitespace. - W292, # No newline at end of file. - W293, # Blank line contains whitespace. - W391, # Blank line at end of file. + # B950: Line too long. (Use `arc lint`'s LINEWRAP instead) + B950, + # E111: Indentation is not a multiple of four. + E111, + # E115: Expected an indented block (comment). + E115, + # E117: Over-indented. + E117, + # E121: Continuation line under-indented for hanging indent. + E121, + # E122: Continuation line missing indentation or outdented. + E122, + # E123: Closing bracket does not match indentation of opening bracket's line. + E123, + # E124: Closing bracket does not match visual indentation. + E124, + # E125: Continuation line with same indent as next logical line. + E125, + # E126: Continuation line over-indented for hanging indent. + E126, + # E127: Continuation line over-indented for visual indent. + E127, + # E128: Continuation line under-indented for visual indent. + E128, + # E129: Visually indented line with same indent as next logical line. + E129, + # E201: Whitespace after '('. + E201, + # E202: Whitespace before ')'. + E202, + # E203: Whitespace before ':'. + E203, + # E221: Multiple spaces before operator. + E221, + # E222: Multiple spaces after operator. + E222, + # E225: Missing whitespace around operator. + E225, + # E226: Missing whitespace around arithmetic operator. + E226, + # E227: Missing whitespace around bitwise or shift operator. + E227, + # E231: Missing whitespace after ',', ';', or ':'. + E231, + # E241: Multiple spaces after ','. + E241, + # E251: Unexpected spaces around keyword / parameter equals. + E251, + # E261: At least two spaces before inline comment. + E261, + # E262: Inline comment should start with '# '. + E262, + # E265: Block comment should start with '# '. + E265, + # E271: Multiple spaces after keyword. + E271, + # E272: Multiple spaces before keyword. + E272, + # E301: Expected 1 blank line, found 0. + E301, + # E302: Expected 2 blank lines, found 0. + E302, + # E303: Too many blank lines (3). + E303, + # E305: Expected 2 blank lines after end of function or class. + E305, + # E306: Expected 1 blank line before a nested definition. + E306, + # E501: Line too long (82 > 79 characters). + E501, + # E502: The backslash is redundant between brackets. + E502, + # E701: Multiple statements on one line (colon). + E701, + # E702: Multiple statements on one line (semicolon). + E702, + # E703: Statement ends with a semicolon. + E703, + # E704: Multiple statements on one line (def). + E704, + # W291: Trailing whitespace. + W291, + # W292: No newline at end of file. + W292, + # W293: Blank line contains whitespace. + W293, + # W391: Blank line at end of file. + W391, # Too opinionated. - E265, # Block comment should start with '# '. - E266, # Too many leading '#' for block comment. - E402, # Module level import not at top of file. - E722, # Do not use bare except, specify exception instead. (Duplicate of B001) - F811, # Redefinition of unused name from line n. - P207, # (Duplicate of B003) - P208, # (Duplicate of C403) - W503 # Line break occurred before a binary operator. + # E265: Block comment should start with '# '. + E265, + # E266: Too many leading '#' for block comment. + E266, + # E402: Module level import not at top of file. + E402, + # E722: Do not use bare except, specify exception instead. (Duplicate of B001) + E722, + # F811: Redefinition of unused name from line n. + F811, + # P207: (Duplicate of B003) + P207, + # P208: (Duplicate of C403) + P208, + # W503: Line break occurred before a binary operator. + W503 + exclude = .hg, __pycache__, diff --git a/benchmarks/tests/test_layers.py b/benchmarks/tests/test_layers.py index 8bf7fade..593f515e 100644 --- a/benchmarks/tests/test_layers.py +++ b/benchmarks/tests/test_layers.py @@ -15,36 +15,21 @@ import copy import logging import random -from typing import Any, Dict, List, Set, Tuple, Type +from typing import Any, Dict, List, Tuple, Type import pytest import torch import torch.nn as nn from helpers import skipifnocuda from opacus.grad_sample import GradSampleModule -from opacus.grad_sample.gsm_exp_weights import ( - API_CUTOFF_VERSION, - GradSampleModuleExpandedWeights, -) +from opacus.grad_sample.gsm_exp_weights import GradSampleModuleExpandedWeights from opacus.layers import DPGRU, DPLSTM, DPRNN, DPMultiheadAttention from benchmarks.layers import LayerFactory from benchmarks.utils import reset_peak_memory_stats -def _gsm_modes() -> Set[str]: - ret = ["baseline", "hooks"] - try: - import functorch # noqa: F401, Checking for import errors - - ret += ["functorch"] - except ImportError: - pass - - if torch.__version__ >= API_CUTOFF_VERSION: - ret += ["ew"] - return set(ret) - +GSM_MODES = {"baseline", "hooks", "ew", "functorch"} PARAMETERS = [ ( @@ -121,7 +106,7 @@ def test_layer_modules( """ for layer_name, module, gsm_mode_blocklist in layer_list: - for gsm_mode in _gsm_modes() - set(gsm_mode_blocklist): + for gsm_mode in GSM_MODES - set(gsm_mode_blocklist): if gsm_mode in gsm_mode_blocklist: continue @@ -161,7 +146,7 @@ def test_to_device( assert reset_peak_memory_stats(cuda).cur_mem == 0 for layer_name, module, gsm_mode_blocklist in layer_list: - for gsm_mode in _gsm_modes() - set(gsm_mode_blocklist): + for gsm_mode in GSM_MODES - set(gsm_mode_blocklist): layer = LayerFactory.create( layer_name=layer_name, batch_size=64, @@ -207,7 +192,7 @@ def test_layer_outputs( } for layer_name, module, gsm_mode_blocklist in layer_list: - for gsm_mode in _gsm_modes() - set(gsm_mode_blocklist): + for gsm_mode in GSM_MODES - set(gsm_mode_blocklist): for random_seed in (random_seed_a, random_seed_b): logging.error(f"{gsm_mode}, {layer_name}") layer = LayerFactory.create( @@ -246,7 +231,7 @@ def test_forward_backward( layer_config: config for instantiating the layers in layer_list """ for layer_name, module, gsm_mode_blocklist in layer_list: - for gsm_mode in _gsm_modes() - set(gsm_mode_blocklist): + for gsm_mode in GSM_MODES - set(gsm_mode_blocklist): layer = LayerFactory.create( layer_name=layer_name, batch_size=64, diff --git a/examples/cifar10.py b/examples/cifar10.py index 970eb3e7..61eafc19 100644 --- a/examples/cifar10.py +++ b/examples/cifar10.py @@ -33,6 +33,8 @@ import torchvision.transforms as transforms from opacus import PrivacyEngine from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP +from opacus.grad_sample.functorch import make_functional +from torch.func import grad, grad_and_value, vmap from torch.nn.parallel import DistributedDataParallel as DDP from torchvision.datasets import CIFAR10 from tqdm import tqdm @@ -139,8 +141,6 @@ def train(args, model, train_loader, optimizer, privacy_engine, epoch, device): top1_acc = [] if args.grad_sample_mode == "no_op": - from functorch import grad_and_value, make_functional, vmap - # Functorch prepare fmodel, _fparams = make_functional(model) diff --git a/opacus/data_loader.py b/opacus/data_loader.py index 0dac2a75..f3b18233 100644 --- a/opacus/data_loader.py +++ b/opacus/data_loader.py @@ -13,7 +13,8 @@ # limitations under the License. import logging -from typing import Any, Optional, Sequence, Tuple, Type, Union +from functools import partial +from typing import Any, List, Optional, Sequence, Tuple, Type, Union import torch from opacus.utils.uniform_sampler import ( @@ -28,6 +29,40 @@ logger = logging.getLogger(__name__) +def collate( + batch: List[torch.Tensor], + *, + collate_fn: Optional[_collate_fn_t], + sample_empty_shapes: Sequence[Tuple], + dtypes: Sequence[Union[torch.dtype, Type]], +): + """ + Wraps `collate_fn` to handle empty batches. + + Default `collate_fn` implementations typically can't handle batches of length zero. + Since this is a possible case for poisson sampling, we need to wrap the collate + method, producing tensors with the correct shape and size (albeit the batch + dimension being zero-size) + + Args: + batch: List of tensort to be passed to collate_fn implementation + collate_fn: Collame method to be wrapped + sample_empty_shapes: Sample tensors with the expected shape + dtypes: Expected dtypes + + Returns: + Batch tensor(s) + """ + + if len(batch) > 0: + return collate_fn(batch) + else: + return [ + torch.zeros(shape, dtype=dtype) + for shape, dtype in zip(sample_empty_shapes, dtypes) + ] + + def wrap_collate_with_empty( *, collate_fn: Optional[_collate_fn_t], @@ -48,16 +83,12 @@ def wrap_collate_with_empty( the input batch is of size 0 """ - def collate(batch): - if len(batch) > 0: - return collate_fn(batch) - else: - return [ - torch.zeros(shape, dtype=dtype) - for shape, dtype in zip(sample_empty_shapes, dtypes) - ] - - return collate + return partial( + collate, + collate_fn=collate_fn, + sample_empty_shapes=sample_empty_shapes, + dtypes=dtypes, + ) def shape_safe(x: Any) -> Tuple: diff --git a/opacus/grad_sample/__init__.py b/opacus/grad_sample/__init__.py index a9b995f6..60b0403b 100644 --- a/opacus/grad_sample/__init__.py +++ b/opacus/grad_sample/__init__.py @@ -21,6 +21,7 @@ from .group_norm import compute_group_norm_grad_sample # noqa from .gsm_base import AbstractGradSampleModule from .gsm_exp_weights import GradSampleModuleExpandedWeights +from .gsm_no_op import GradSampleModuleNoOp from .instance_norm import compute_instance_norm_grad_sample # noqa from .layer_norm import compute_layer_norm_grad_sample # noqa from .linear import compute_linear_grad_sample # noqa @@ -30,6 +31,7 @@ __all__ = [ "GradSampleModule", "GradSampleModuleExpandedWeights", + "GradSampleModuleNoOp", "AbstractGradSampleModule", "register_grad_sampler", "create_or_accumulate_grad_sample", diff --git a/opacus/grad_sample/functorch.py b/opacus/grad_sample/functorch.py index 4a228e40..956141ff 100644 --- a/opacus/grad_sample/functorch.py +++ b/opacus/grad_sample/functorch.py @@ -1,5 +1,43 @@ +import copy + +import torch import torch.nn as nn from opacus.layers.dp_rnn import RNNLinear +from torch.func import grad, vmap + + +# https://gist.github.com/zou3519/7769506acc899d83ef1464e28f22e6cf +def make_functional(mod: nn.Module, disable_autograd_tracking: bool = False): + """ + Helper method to mimic deprecated `functorch.make_functional()` behaviour. See + https://pytorch.org/docs/master/func.migrating.html + + Args: + mod: module to be converted to functional + disable_autograd_tracking: + + Returns: + Tuple with cloned model and new params + """ + params_dict = dict(mod.named_parameters()) + params_names = params_dict.keys() + params_values = tuple(params_dict.values()) + + stateless_mod = copy.deepcopy(mod) + stateless_mod.to("meta") + + if hasattr(stateless_mod, "allow_grad_accumulation"): + stateless_mod.allow_grad_accumulation() + + def fmodel(new_params_values, *args, **kwargs): + new_params_dict = { + name: value for name, value in zip(params_names, new_params_values) + } + return torch.func.functional_call(stateless_mod, new_params_dict, args, kwargs) + + if disable_autograd_tracking: + params_values = torch.utils._pytree.tree_map(torch.Tensor.detach, params_values) + return fmodel, params_values def prepare_layer(layer, batch_first=True): @@ -12,14 +50,13 @@ def prepare_layer(layer, batch_first=True): layer: the layer to prepare batch_first: whether the input is batch_first or not """ - from functorch import grad, make_functional, vmap - if len(list(layer.buffers())) > 0: raise NotImplementedError( "This layer has buffers and is not supported by Opacus" ) if type(layer) is nn.EmbeddingBag: raise NotImplementedError("Functorch does not support EmbeddingBag yet") + flayer, _ = make_functional(layer) def compute_loss_stateless_model(params, activations, backprops): @@ -33,7 +70,6 @@ def compute_loss_stateless_model(params, activations, backprops): output = flayer(params, batched_activations) loss = (output * batched_backprops).sum() - return loss ft_compute_grad = grad(compute_loss_stateless_model) diff --git a/opacus/grad_sample/grad_sample_module.py b/opacus/grad_sample/grad_sample_module.py index 309ce27b..b7e07491 100644 --- a/opacus/grad_sample/grad_sample_module.py +++ b/opacus/grad_sample/grad_sample_module.py @@ -135,6 +135,7 @@ def __init__( ) self.hooks_enabled = False + self.grad_accumulation_allowed = True self.batch_first = batch_first self.loss_reduction = loss_reduction self.force_functorch = force_functorch @@ -348,6 +349,14 @@ def capture_backprops_hook( if p._forward_counter == 0: promote_current_grad_sample(p) + if not self.grad_accumulation_allowed: + if isinstance(p.grad_sample, list) and len(p.grad_sample) > 1: + raise ValueError( + "Poisson sampling is not compatible with grad accumulation. " + "You need to call optimizer.step() after every forward/backward pass " + "or consider using BatchMemoryManager" + ) + if len(module.activations) == 0: if hasattr(module, "max_batch_len"): del module.max_batch_len @@ -474,6 +483,12 @@ def validate( else: return errors + def forbid_grad_accumulation(self): + self.grad_accumulation_allowed = False + + def allow_grad_accumulation(self): + self.grad_accumulation_allowed = True + def _get_batch_size(*, module: nn.Module, batch_dim: int) -> int: """ diff --git a/opacus/grad_sample/gsm_base.py b/opacus/grad_sample/gsm_base.py index d3ea5af0..f6947fae 100644 --- a/opacus/grad_sample/gsm_base.py +++ b/opacus/grad_sample/gsm_base.py @@ -15,9 +15,11 @@ import logging from abc import ABC, abstractmethod +from typing import Optional import torch.nn as nn from opacus.utils.module_utils import trainable_parameters +from torch.utils.hooks import RemovableHandle logger = logging.getLogger(__name__) @@ -58,6 +60,7 @@ def __init__( self._module = m self.batch_first = batch_first self.loss_reduction = loss_reduction + self.grad_accumulation_hook: Optional[RemovableHandle] = None for _, p in trainable_parameters(self): p.grad_sample = None @@ -139,3 +142,23 @@ def _clean_up_attributes(self): for p in self.parameters(): if hasattr(p, attr): delattr(p, attr) + + def forbid_grad_accumulation(self): + """ + Sets a flag to detect gradient accumulation (multiple forward/backward passes + without an optimizer step or clearing out gradients). + + When set, GradSampleModule will throw a ValueError on the second backward pass. + :return: + """ + pass + + def allow_grad_accumulation(self): + """ + Unsets a flag to detect gradient accumulation (multiple forward/backward passes + without an optimizer step or clearing out gradients). + + When set, GradSampleModule will throw a ValueError on the second backward pass. + :return: + """ + pass diff --git a/opacus/privacy_engine.py b/opacus/privacy_engine.py index 13886f0a..e9f59eb3 100644 --- a/opacus/privacy_engine.py +++ b/opacus/privacy_engine.py @@ -30,58 +30,12 @@ ) from opacus.optimizers import DPOptimizer, get_optimizer_class from opacus.schedulers import _GradClipScheduler, _NoiseScheduler -from opacus.utils.module_utils import trainable_parameters from opacus.validators.module_validator import ModuleValidator from torch import nn, optim from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader -def forbid_accumulation_hook( - module: AbstractGradSampleModule, - _grad_input: torch.Tensor, - _grad_output: torch.Tensor, -): - """ - Model hook that detects repetitive forward/backward passes between optimizer steps. - - This is a backward hook that will be wrapped around the whole model using - `register_backward_hook`. We wish to detect a case where: - - `optimizer.zero_grad()` is not called before the backward pass; and - - `p.grad_sample` was updated in a *previous* iteration. - - To do so, we attach a backward hook to the model that runs *after* the computation - of `grad_sample` for the current step. We compute the number of accumulated iterations - like on `optimizers/optimizer.py` and check whether it's strictly larger than one. - - Args: - module: input module - _grad_input: module input gradient (not used here) - _grad_output: module output gradient (not used here) - - Raises: - ValueError - If the hook detected multiple forward/backward passes between optimizer steps - - """ - if not module.training: - return - - for _, p in trainable_parameters(module): - if p.grad_sample is not None: - if isinstance(p.grad_sample, torch.Tensor): - accumulated_iterations = 1 - elif isinstance(p.grad_sample, list): - accumulated_iterations = len(p.grad_sample) - - if accumulated_iterations > 1: - raise ValueError( - "Poisson sampling is not compatible with grad accumulation. " - "You need to call optimizer.step() after every forward/backward pass " - "or consider using BatchMemoryManager" - ) - - class PrivacyEngine: """ Main entry point to the Opacus API - use ``PrivacyEngine`` to enable differential @@ -126,7 +80,6 @@ def __init__(self, *, accountant: str = "prv", secure_mode: bool = False): self.secure_mode = secure_mode self.secure_rng = None self.dataset = None # only used to detect switching to a different dataset - if self.secure_mode: try: import torchcsprng as csprng @@ -403,7 +356,7 @@ def make_private( grad_sample_mode=grad_sample_mode, ) if poisson_sampling: - module.register_backward_hook(forbid_accumulation_hook) + module.forbid_grad_accumulation() data_loader = self._prepare_data_loader( data_loader, distributed=distributed, poisson_sampling=poisson_sampling diff --git a/opacus/tests/grad_samples/common.py b/opacus/tests/grad_samples/common.py index fb6de5d0..f16cf3ec 100644 --- a/opacus/tests/grad_samples/common.py +++ b/opacus/tests/grad_samples/common.py @@ -59,7 +59,7 @@ def run_test( ): grad_sample_modes = ["hooks"] - if ew_compatible and batch_first and torch.__version__ >= (1, 13): + if ew_compatible and batch_first: grad_sample_modes += ["ew"] for loss_reduction in ["sum", "mean"]: diff --git a/opacus/tests/gradient_accumulation_test.py b/opacus/tests/gradient_accumulation_test.py index 297450fa..f603b16b 100644 --- a/opacus/tests/gradient_accumulation_test.py +++ b/opacus/tests/gradient_accumulation_test.py @@ -20,7 +20,7 @@ import torch.nn as nn import torch.nn.functional as F from opacus import PrivacyEngine -from opacus.grad_sample import GradSampleModule +from opacus.grad_sample.utils import get_gsm_class from opacus.utils.batch_memory_manager import BatchMemoryManager from torch.utils.data import DataLoader from torchvision import transforms @@ -68,6 +68,7 @@ def setUp(self): self.LR = 0 # we want to call optimizer.step() without modifying the model self.ALPHAS = [1 + x / 10.0 for x in range(1, 100, 10)] self.criterion = nn.CrossEntropyLoss() + self.GRAD_SAMPLE_MODE = "hooks" self.setUp_data() self.setUp_model_and_optimizer() @@ -130,7 +131,8 @@ def test_grad_sample_accumulation(self): Calling loss.backward() multiple times should sum up the gradients in .grad and accumulate all the individual gradients in .grad-sample """ - grad_sample_module = GradSampleModule(self.model) + gsm_class = get_gsm_class(self.GRAD_SAMPLE_MODE) + grad_sample_module = gsm_class(self.model) data_iter = iter(self.dl) # 4 batches of size 4 each self.model_forward_backward(grad_sample_module, data_iter, num_steps=8) # should accumulate grads in .grad and .grad_sample @@ -181,6 +183,7 @@ def test_privacy_engine_poisson_accumulation(self): data_loader=self.dl, noise_multiplier=0.0, max_grad_norm=999, + grad_sample_mode=self.GRAD_SAMPLE_MODE, ) self.model_forward_backward(model, dl, num_steps=1) @@ -197,6 +200,7 @@ def test_privacy_engine_no_poisson_accumulation(self): noise_multiplier=0.0, max_grad_norm=999, poisson_sampling=False, + grad_sample_mode=self.GRAD_SAMPLE_MODE, ) self.model_forward_backward(model, dl, num_steps=8) @@ -230,6 +234,7 @@ def test_privacy_engine_zero_grad(self): noise_multiplier=1.0, max_grad_norm=1.0, poisson_sampling=False, + grad_sample_mode=self.GRAD_SAMPLE_MODE, ) # should work fine with zero_grad @@ -252,6 +257,7 @@ def test_batch_splitter_zero_grad(self): noise_multiplier=1.0, max_grad_norm=1.0, poisson_sampling=False, + grad_sample_mode=self.GRAD_SAMPLE_MODE, ) with BatchMemoryManager( @@ -265,3 +271,9 @@ def test_batch_splitter_zero_grad(self): self.model_forward_backward( model, new_data_loader, optimizer, num_steps=3, do_zero_grad=False ) + + +class GradientAccumulationTestFunctorch(GradientAccumulationTest): + def setUp(self): + super().setUp() + self.GRAD_SAMPLE_MODE = "functorch" diff --git a/requirements.txt b/requirements.txt index 94d687da..7a591a56 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ numpy>=1.15 -torch>=1.13 +torch>=2.0 scipy>=1.2 opt-einsum>=3.3.0