add deprecation warning to nn stateless functional_call (#87367)
Same as the release version, but just for master.

Pull Request resolved: #87367
Approved by: https://github.com/albanD, https://github.com/atalman
samdow authored and pytorchmergebot committed Oct 20, 2022
1 parent 9b88dcf commit bc8cf33
Showing 2 changed files with 47 additions and 2 deletions.
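
For context, the pattern being deprecated: when a module ties two names to the same underlying tensor (e.g. a bias registered under a second attribute), `functional_call` used to silently accept different replacement values for the tied names. A minimal sketch of the distinction; the `TiedModule` class here is hypothetical, standing in for the test suite's `MockModule`:

```python
import torch
from torch.nn.utils import stateless

class TiedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(1, 1)
        self.tied_bias = self.l1.bias  # two names, one underlying tensor

module = TiedModule()
x = torch.randn(1, 1)
bias = torch.tensor([0.0])

# Fine: both tied names map to the *same* replacement tensor, so no warning.
stateless.functional_call(module, {'l1.bias': bias, 'tied_bias': bias}, x)

# Deprecated: the tied names map to *different* tensors. After this commit
# this emits a UserWarning ("functional_call was passed multiple values for
# tied weights"); a future release will make it an error.
stateless.functional_call(
    module,
    {'l1.bias': torch.tensor([0.0]), 'tied_bias': torch.tensor([5.0])},
    x)
```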
test/test_stateless.py (31 additions, 0 deletions)
@@ -176,6 +176,37 @@ def test_reparamertize_module_fail_reset_to_original(self):
         self.assertEqual(orig_sn_weight, module.l1.weight)
 
 
+    def test_tied_weights_warns(self):
+        module = MockModule()
+        module.tied_bias = module.l1.bias
+        module.register_buffer("tied_buffer", module.buffer)
+        weight = torch.tensor([[1.0]],)
+        bias = torch.tensor([0.0])
+        buffer = torch.tensor([0.0])
+
+        parameters = {'l1.weight': weight,
+                      'l1.bias': bias,
+                      'buffer': buffer}
+        x = torch.randn(1, 1)
+        self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x))
+
+        # if tied values are the same tensors, shouldn't warn
+        parameters['tied_bias'] = bias
+        parameters['tied_buffer'] = buffer
+        self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x))
+        del parameters['tied_bias']
+        del parameters['tied_buffer']
+
+        with self.assertWarnsOnceRegex(UserWarning, "functional_call was passed multiple values"):
+            parameters['tied_bias'] = torch.tensor([5.0])
+            stateless.functional_call(module, parameters, x)
+        del parameters['tied_bias']
+
+        with self.assertWarnsOnceRegex(UserWarning, "functional_call was passed multiple values"):
+            parameters['tied_buffer'] = torch.tensor([5.0])
+            stateless.functional_call(module, parameters, x)
+
+
     def test_setattr(self):
         class Foo(torch.nn.Module):
             def __init__(self):
torch/nn/utils/stateless.py (16 additions, 2 deletions)
@@ -1,3 +1,4 @@
+import warnings
 import contextlib
 from typing import Any, Callable, Dict, Iterator, List, Tuple
 
@@ -35,10 +36,22 @@ def _setattr(self, name: str, value: Any) -> None:
     module.__class__ = param_cls
     module._orig_class = cls
 
-def _create_swap_params(params_and_buffers):
+
+def _check_tied_val_already_replaced(old_val, new_val, replaced_tensors_map):
+    if old_val not in replaced_tensors_map:
+        replaced_tensors_map[old_val] = new_val
+    elif replaced_tensors_map[old_val] is not new_val:
+        warnings.warn("functional_call was passed multiple values for tied weights. "
+                      "This behavior is deprecated and will be an error in future versions")
+
+
+def _create_swap_params(params_and_buffers, replaced_tensors_map):
     def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Tensor) -> None:
         # Changes the module class to get a new __getattr__ dunder method
         # that looks for the reparametrized tensor
+        if hasattr(module, tensor_name):
+            old_val = getattr(module, tensor_name)
+            _check_tied_val_already_replaced(old_val, tensor, replaced_tensors_map)
         if hasattr(module, "_attr_to_path"):
             module._attr_to_path[tensor_name] = full_path
         else:
@@ -60,9 +73,10 @@ def _reparametrize_module(
     module: 'torch.nn.Module',
     parameters_and_buffers: Dict[str, Tensor],
 ) -> Iterator[None]:
+    orig_tensors_to_replacements: Dict[Tensor, Tensor] = {}
     for name, tensor in parameters_and_buffers.items():
         _apply_func_submodules(
-            _create_swap_params(parameters_and_buffers),
+            _create_swap_params(parameters_and_buffers, orig_tensors_to_replacements),
             module, name.split("."), name, (tensor,))
     try:
         yield
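
Since this diff only warns today but promises an error in a future version, callers who want to fail fast can promote the warning themselves. A sketch using the standard-library `warnings` filter, reusing `module`, `parameters`, and `x` from the sketch above; the message pattern matches the string added in this diff:

```python
import warnings

# Surface tied-weight mismatches as errors now, rather than after
# the deprecation window closes.
with warnings.catch_warnings():
    warnings.filterwarnings(
        "error",
        message="functional_call was passed multiple values",
        category=UserWarning,
    )
    out = stateless.functional_call(module, parameters, x)
```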

