Skip to content

Commit 8128275

Browse files
HuanyuZhangfacebook-github-bot
authored andcommitted
Validate/Replace nn.GRU by DPGRU
Summary: To solve Github Issue[#783]. Add support of nn.GRU of the moduel validator Reviewed By: iden-kalemaj Differential Revision: D82027020 fbshipit-source-id: eb84efadd1a4e7048e4abd2c255f0fe219e74f74
1 parent a077079 commit 8128275

File tree

2 files changed

+91
-0
lines changed

2 files changed

+91
-0
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import unittest
17+
18+
import torch.nn as nn
19+
from opacus.layers import DPGRU
20+
from opacus.utils.module_utils import are_state_dict_equal
21+
from opacus.validators.errors import ShouldReplaceModuleError
22+
from opacus.validators.module_validator import ModuleValidator
23+
24+
25+
class GRUValidator_test(unittest.TestCase):
26+
def setUp(self) -> None:
27+
self.gru = nn.GRU(8, 4)
28+
self.mv = ModuleValidator.VALIDATORS
29+
self.mf = ModuleValidator.FIXERS
30+
31+
def test_validate(self) -> None:
32+
val_gru = self.mv[type(self.gru)](self.gru)
33+
self.assertEqual(len(val_gru), 1)
34+
self.assertTrue(isinstance(val_gru[0], ShouldReplaceModuleError))
35+
36+
def test_fix(self) -> None:
37+
fix_gru = self.mf[type(self.gru)](self.gru)
38+
self.assertTrue(isinstance(fix_gru, DPGRU))
39+
self.assertTrue(
40+
are_state_dict_equal(self.gru.state_dict(), fix_gru.state_dict())
41+
)

opacus/validators/gru.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from typing import List
17+
18+
import torch.nn as nn
19+
from opacus.layers import DPGRU
20+
21+
from .errors import ShouldReplaceModuleError, UnsupportedModuleError
22+
from .utils import register_module_fixer, register_module_validator
23+
24+
25+
@register_module_validator(nn.GRU)
26+
def validate(module: nn.GRU) -> List[UnsupportedModuleError]:
27+
return [
28+
ShouldReplaceModuleError(
29+
"We do not support nn.GRU because its implementation uses special "
30+
"modules. We have written a GRU class that is a drop-in replacement "
31+
"which is compatible with our Grad Sample hooks. Please run the recommended "
32+
"replacement!"
33+
)
34+
]
35+
36+
37+
@register_module_fixer(nn.GRU)
38+
def fix(module: nn.GRU) -> DPGRU:
39+
dpgru = DPGRU(
40+
input_size=module.input_size,
41+
hidden_size=module.hidden_size,
42+
num_layers=module.num_layers,
43+
bias=module.bias,
44+
batch_first=module.batch_first,
45+
dropout=module.dropout,
46+
bidirectional=module.bidirectional,
47+
proj_size=module.proj_size,
48+
)
49+
dpgru.load_state_dict(module.state_dict())
50+
return dpgru

0 commit comments

Comments
 (0)