Functorch gradients: investigation and fix (#510)
Summary:
*The investigation part of this PR was done by alexandresablayrolles; thanks for figuring out why the tests were failing.*

## Background
The current implementation of functorch-based per-sample gradients fails on modules that have both trainable non-recursive parameters and standard submodules, e.g. the module below:
```
class LinearWithExtraParam(nn.Module):
    def __init__(self, in_features: int, out_features: int, hidden_dim: int = 8):
        super().__init__()
        self.fc = nn.Linear(in_features, hidden_dim)
        self.extra_param = nn.Parameter(torch.randn(hidden_dim, out_features))

    def forward(self, x):
        x = self.fc(x)
        x = x.matmul(self.extra_param)
        return x
```

The reason is that the functorch hook actually computes gradients for the recursive submodules too, while normal hooks are also attached to those same submodules. GradSampleModule then sees two grad_sample tensors for the same parameter, assumes it needs to accumulate, and adds them together.
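
For illustration, here is a minimal plain-PyTorch sketch (no Opacus involved) of why the overlap happens, reusing the module above: the parent's recursive parameter list already contains the child `nn.Linear`'s parameters, so a functorch hook on the parent and a standard hook on `fc` would each produce a per-sample gradient for `fc.weight` and `fc.bias`.

```
import torch
import torch.nn as nn


class LinearWithExtraParam(nn.Module):
    def __init__(self, in_features: int, out_features: int, hidden_dim: int = 8):
        super().__init__()
        self.fc = nn.Linear(in_features, hidden_dim)
        self.extra_param = nn.Parameter(torch.randn(hidden_dim, out_features))

    def forward(self, x):
        return self.fc(x).matmul(self.extra_param)


module = LinearWithExtraParam(4, 4)

# parameters(recurse=True) also returns the submodule's parameters...
print(sorted(n for n, _ in module.named_parameters(recurse=True)))
# ['extra_param', 'fc.bias', 'fc.weight']

# ...while recurse=False returns only the module's own (non-recursive) parameter.
print(sorted(n for n, _ in module.named_parameters(recurse=False)))
# ['extra_param']
```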

## Solution(s)

There are essentially two ways we can fix this: either make functorch compute per-sample gradients for non-recursive parameters only, or don't attach normal hooks to submodules whose parent module is handled by functorch.

This diff implements the latter option (reasoning below); for demo purposes, the former option can be seen in #531.

From a pure code perspective, the former option (let's call it "non-recursive functorch") is more appealing to me. It better fits the existing paradigm and matches the behaviour of normal hooks: all of the existing code deals only with immediate, non-recursive parameters.
However, it doesn't make much sense from an efficiency perspective: "non-recursive functorch" would do all the work to compute per-sample gradients for its submodules, only for them to be filtered out at the very last stage.
The alternative option (a.k.a. "functorch for subtrees") does involve slightly more convoluted hook-attachment logic, but it avoids computing per-sample gradients that are immediately thrown away.
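
For reference, a simplified sketch of the subtree-skipping traversal this option boils down to (the actual implementation is `iterate_submodules` in `grad_sample_module.py`, shown in the diff below; the `supported` argument here is a stand-in for the real `GRAD_SAMPLERS` registry and the DP-RNN special cases):

```
from typing import Iterable, Set, Type

import torch.nn as nn


def has_trainable_params(module: nn.Module) -> bool:
    # Only the module's own (non-recursive) parameters count here.
    return any(p.requires_grad for p in module.parameters(recurse=False))


def iterate_submodules(
    module: nn.Module, supported: Set[Type[nn.Module]]
) -> Iterable[nn.Module]:
    if has_trainable_params(module):
        yield module
        if type(module) not in supported:
            # No dedicated grad sampler: functorch handles the whole subtree
            # rooted here, so don't descend and don't attach standard hooks
            # to the children.
            return
    for child in module.children():
        yield from iterate_submodules(child, supported)
```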

This has a noticeable effect on performance.
Below are the results of MNIST benchmarks with different configurations. I've tested multiple configurations because, at the end of the day, the performance impact depends on how deep the affected subtrees are:

* Standard model - our model from the MNIST example, standard layers only (2 conv + 2 linear). No overhead expected; functorch doesn't kick in.
* Mid-level model - leaf nodes (two linear layers) have one extra param each and are computed with functorch. Overhead: 2x Linear hook.
* Extreme model - the root module has one extra param, so the whole model needs to be handled by functorch. Overhead: 2x Linear hook + 2x Conv hook. (A hypothetical sketch of this variant follows the results table.)

| Mode                   | non-recursive functorch | functorch for subtrees |
|:----------------------:|:-----------------------:|:----------------------:|
| Standard model (CPU)   | 138s                    | 136s                   |
| Standard model (GPU)   | 149s                    | 150s                   |
| Mid-level model (CPU)  | 157s                    | 150s                   |
| Mid-level model (GPU)  | 100s                    | 97s                    |
| Extreme model (CPU)    | 207s                    | 172s                   |
| Extreme model (GPU)    | 101s                    | 94s                    |
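
For context, a hypothetical reconstruction of the "extreme" benchmark variant: a standard MNIST-style ConvNet whose root module owns an extra `nn.Parameter`, forcing the whole tree onto the functorch path. Exact layer sizes are illustrative assumptions, not taken from the benchmark code.

```
import torch
import torch.nn as nn
import torch.nn.functional as F


class ExtremeSampleConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Standard 2 conv + 2 linear MNIST architecture (sizes assumed).
        self.conv1 = nn.Conv2d(1, 16, 8, 2, padding=3)
        self.conv2 = nn.Conv2d(16, 32, 4, 2)
        self.fc1 = nn.Linear(32 * 4 * 4, 32)
        self.fc2 = nn.Linear(32, 10)
        # Extra root-level parameter: the root now has a trainable
        # non-recursive parameter, so functorch handles the whole model.
        self.extra_param = nn.Parameter(torch.randn(10, 10))

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 1)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 1)
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x.matmul(self.extra_param)
```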

Pull Request resolved: #510

Reviewed By: alexandresablayrolles

Differential Revision: D39579487

Pulled By: ffuuugor

fbshipit-source-id: 1b089bd04ab110174a1f2ebb371380eb2ce76054
Igor Shilov authored and facebook-github-bot committed Oct 28, 2022
1 parent 9ff6839 commit 7393ae4
Showing 6 changed files with 151 additions and 66 deletions.
2 changes: 1 addition & 1 deletion opacus/grad_sample/functorch.py
@@ -48,7 +48,7 @@ def ft_compute_per_sample_gradient(layer, activations, backprops):
activations: the input to the layer
backprops: the gradient of the loss w.r.t. outputs of the layer
"""
parameters = list(layer.parameters())
parameters = list(layer.parameters(recurse=True))
if not hasattr(layer, "ft_compute_sample_grad"):
prepare_layer(layer)

20 changes: 18 additions & 2 deletions opacus/grad_sample/grad_sample_module.py
@@ -18,14 +18,15 @@
import logging
import warnings
from functools import partial
from typing import List, Tuple
from typing import Iterable, List, Tuple

import torch
import torch.nn as nn
from opacus.grad_sample.functorch import ft_compute_per_sample_gradient, prepare_layer
from opacus.grad_sample.gsm_base import AbstractGradSampleModule
from opacus.layers.dp_rnn import DPGRU, DPLSTM, DPRNN, RNNLinear
from opacus.utils.module_utils import (
has_trainable_params,
requires_grad,
trainable_modules,
trainable_parameters,
@@ -146,6 +147,21 @@ def __init__(
def forward(self, *args, **kwargs):
return self._module(*args, **kwargs)

def iterate_submodules(self, module: nn.Module) -> Iterable[nn.Module]:
if has_trainable_params(module):
yield module

# Don't recurse if module is handled by functorch
if (
has_trainable_params(module)
and type(module) not in self.GRAD_SAMPLERS
and type(module) not in [DPRNN, DPLSTM, DPGRU]
):
return

for m in module.children():
yield from self.iterate_submodules(m)

def add_hooks(
self,
*,
@@ -177,7 +193,7 @@ def add_hooks(
self._module.autograd_grad_sample_hooks = []
self.autograd_grad_sample_hooks = self._module.autograd_grad_sample_hooks

for _module_name, module in trainable_modules(self._module):
for module in self.iterate_submodules(self._module):
# Do not add hooks to DPRNN, DPLSTM or DPGRU as the hooks are handled by the `RNNLinear`
if type(module) in [DPRNN, DPLSTM, DPGRU]:
continue
85 changes: 71 additions & 14 deletions opacus/tests/privacy_engine_test.py
@@ -40,6 +40,18 @@
from torchvision import models, transforms
from torchvision.datasets import FakeData

from .utils import CustomLinearModule, LinearWithExtraParam


def _is_functorch_available():
try:
# flake8: noqa F401
import functorch

return True
except ImportError:
return False


def get_grad_sample_aggregated(tensor: torch.Tensor, loss_type: str = "mean"):
if tensor.grad_sample is None:
@@ -246,7 +258,7 @@ def _compare_to_vanilla(
# vanilla gradient is nearly zero: will match even with clipping
continue

atol = 1e-7 if max_steps == 1 else 1e-5
atol = 1e-7 if max_steps == 1 else 1e-4
self.assertEqual(
torch.allclose(vp, pp, atol=atol, rtol=1e-3),
expected_match,
@@ -265,10 +277,6 @@ def _compare_to_vanilla(
do_noise=st.booleans(),
use_closure=st.booleans(),
max_steps=st.sampled_from([1, 4]),
# do_clip=st.just(False),
# do_noise=st.just(False),
# use_closure=st.just(False),
# max_steps=st.sampled_from([4]),
)
@settings(deadline=None)
def test_compare_to_vanilla(
@@ -799,9 +807,7 @@ def _init_data(self):
)
return DataLoader(ds, batch_size=self.BATCH_SIZE, drop_last=False)

def _init_model(
self, private=False, state_dict=None, model=None, **privacy_engine_kwargs
):
def _init_model(self):
return SampleConvNet()


@@ -817,16 +823,21 @@ def _init_data(self):
)
return DataLoader(ds, batch_size=self.BATCH_SIZE, drop_last=False)

def _init_model(
self, private=False, state_dict=None, model=None, **privacy_engine_kwargs
):
def _init_model(self):
m = SampleConvNet()
for p in itertools.chain(m.conv1.parameters(), m.gnorm1.parameters()):
p.requires_grad = False

return m


@unittest.skipIf(not _is_functorch_available(), "not supported in this torch version")
class PrivacyEngineConvNetFrozenTestFunctorch(PrivacyEngineConvNetFrozenTest):
def setUp(self):
super().setUp()
self.GRAD_SAMPLE_MODE = "functorch"


@unittest.skipIf(
torch.__version__ < API_CUTOFF_VERSION, "not supported in this torch version"
)
@@ -840,6 +851,13 @@ def test_sample_grad_aggregation(self):
pass


@unittest.skipIf(not _is_functorch_available(), "not supported in this torch version")
class PrivacyEngineConvNetTestFunctorch(PrivacyEngineConvNetTest):
def setUp(self):
super().setUp()
self.GRAD_SAMPLE_MODE = "functorch"


class SampleAttnNet(nn.Module):
def __init__(self):
super().__init__()
@@ -919,6 +937,13 @@ def _init_model(
return SampleAttnNet()


@unittest.skipIf(not _is_functorch_available(), "not supported in this torch version")
class PrivacyEngineTextTestFunctorch(PrivacyEngineTextTest):
def setUp(self):
super().setUp()
self.GRAD_SAMPLE_MODE = "functorch"


class SampleTiedWeights(nn.Module):
def __init__(self, tie=True):
super().__init__()
@@ -958,7 +983,39 @@ def _init_data(self):
)
return DataLoader(ds, batch_size=self.BATCH_SIZE, drop_last=False)

def _init_model(
self, private=False, state_dict=None, model=None, **privacy_engine_kwargs
):
def _init_model(self):
return SampleTiedWeights(tie=True)


@unittest.skipIf(not _is_functorch_available(), "not supported in this torch version")
class PrivacyEngineTiedWeightsTestFunctorch(PrivacyEngineTiedWeightsTest):
def setUp(self):
super().setUp()
self.GRAD_SAMPLE_MODE = "functorch"


class ModelWithCustomLinear(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = CustomLinearModule(4, 8)
self.fc2 = LinearWithExtraParam(8, 4)
self.extra_param = nn.Parameter(torch.randn(4, 4))

def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
x = x.matmul(self.extra_param)
return x


@unittest.skipIf(not _is_functorch_available(), "not supported in this torch version")
class PrivacyEngineCustomLayerTest(BasePrivacyEngineTest, unittest.TestCase):
def _init_data(self):
ds = TensorDataset(
torch.randn(self.DATA_SIZE, 4),
torch.randint(low=0, high=3, size=(self.DATA_SIZE,)),
)
return DataLoader(ds, batch_size=self.BATCH_SIZE, drop_last=False)

def _init_model(self):
return ModelWithCustomLinear()
54 changes: 6 additions & 48 deletions opacus/tests/privacy_engine_validation_test.py
@@ -1,58 +1,16 @@
import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
from opacus import PrivacyEngine
from opacus.grad_sample.gsm_exp_weights import API_CUTOFF_VERSION
from torch.utils.data import DataLoader


class BasicSupportedModule(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv1d(in_channels=16, out_channels=8, kernel_size=2)
self.gn = nn.GroupNorm(num_groups=2, num_channels=8)
self.fc = nn.Linear(in_features=4, out_features=8)
self.ln = nn.LayerNorm([8, 8])

def forward(self, x):
x = self.conv(x)
x = self.gn(x)
x = self.fc(x)
x = self.ln(x)
return x


class CustomLinearModule(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self._weight = nn.Parameter(torch.randn(out_features, in_features))
self._bias = nn.Parameter(torch.randn(out_features))

def forward(self, x):
return F.linear(x, self._weight, self._bias)


class MatmulModule(nn.Module):
def __init__(self, input_features, output_features):
super().__init__()
self.weight = nn.Parameter(torch.randn(input_features, output_features))

def forward(self, x):
return torch.matmul(x, self.weight)


class LinearWithExtraParam(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.fc = nn.Linear(in_features, out_features)
self.extra_param = nn.Parameter(torch.randn(out_features, 2))

def forward(self, x):
x = self.fc(x)
x = x.matmul(self.extra_param)
return x
from .utils import (
BasicSupportedModule,
CustomLinearModule,
LinearWithExtraParam,
MatmulModule,
)


class PrivacyEngineValidationTest(unittest.TestCase):
50 changes: 50 additions & 0 deletions opacus/tests/utils.py
@@ -0,0 +1,50 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicSupportedModule(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv1d(in_channels=16, out_channels=8, kernel_size=2)
self.gn = nn.GroupNorm(num_groups=2, num_channels=8)
self.fc = nn.Linear(in_features=4, out_features=8)
self.ln = nn.LayerNorm([8, 8])

def forward(self, x):
x = self.conv(x)
x = self.gn(x)
x = self.fc(x)
x = self.ln(x)
return x


class CustomLinearModule(nn.Module):
def __init__(self, in_features: int, out_features: int):
super().__init__()
self._weight = nn.Parameter(torch.randn(out_features, in_features))
self._bias = nn.Parameter(torch.randn(out_features))

def forward(self, x):
return F.linear(x, self._weight, self._bias)


class MatmulModule(nn.Module):
def __init__(self, input_features: int, output_features: int):
super().__init__()
self.weight = nn.Parameter(torch.randn(input_features, output_features))

def forward(self, x):
return torch.matmul(x, self.weight)


class LinearWithExtraParam(nn.Module):
def __init__(self, in_features: int, out_features: int, hidden_dim: int = 8):
super().__init__()
self.fc = nn.Linear(in_features, hidden_dim)
self.extra_param = nn.Parameter(torch.randn(hidden_dim, out_features))

def forward(self, x):
x = self.fc(x)
x = x.matmul(self.extra_param)
return x
6 changes: 5 additions & 1 deletion opacus/utils/module_utils.py
@@ -31,7 +31,11 @@
logger.setLevel(level=logging.INFO)


def parametrized_modules(module: nn.Module) -> Iterable[nn.Module]:
def has_trainable_params(module: nn.Module) -> bool:
return any(p.requires_grad for p in module.parameters(recurse=False))


def parametrized_modules(module: nn.Module) -> Iterable[Tuple[str, nn.Module]]:
"""
Recursively iterates over all submodules, returning those that
have parameters (as opposed to "wrapper modules" that just organize modules).
