Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18364,6 +18364,32 @@ def _check_lazy_conv_state(self, gen_module, gen_lazy_module,
with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'):
module.load_state_dict(lazy_module.state_dict())

def test_lazy_forward_keyword_arguments(self):
class TestModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
def __init__(self):
super().__init__()
self.c = 100

def initialize_parameters(self, a, b=1):
if b == 1:
self.c = 100
else:
self.c = 200

def forward(self, a, b=1):
return a + b + self.c

m = TestModule()
self.assertEqual(102, m(a=1))
self.assertEqual(103, m(a=1, b=2))

m = TestModule()
self.assertEqual(203, m(1, 2))
self.assertEqual(202, m(a=1, b=1))

m = TestModule()
self.assertEqual(203, m(a=1, b=2))
self.assertEqual(202, m(1, 1))

def test_lazy_pre_forward_hook(self):
"""
Expand Down
6 changes: 3 additions & 3 deletions torch/nn/modules/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def __init__(self: _LazyProtocol, *args, **kwargs):
# Mypy doesn't like this super call in a mixin
super().__init__(*args, **kwargs) # type: ignore[misc]
self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook)
self._initialize_hook = self.register_forward_pre_hook(self._infer_parameters)
self._initialize_hook = self.register_forward_pre_hook(self._infer_parameters, with_kwargs=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be BC-breaking no?
Even for simple things like the LazyLinear, this will make it fail if any argument is passed as a kwargs.

warnings.warn('Lazy modules are a new feature under heavy development '
'so changes to the API or functionality can happen at any moment.')

Expand Down Expand Up @@ -235,7 +235,7 @@ def has_uninitialized_params(self: _LazyProtocol):
return True
return False

def _infer_parameters(self: _LazyProtocol, module, input):
def _infer_parameters(self: _LazyProtocol, module, input, kwargs):
r"""Infers the size and initializes the parameters according to the
provided input batch.
Given a module that contains parameters that were declared inferrable
Expand All @@ -245,7 +245,7 @@ def _infer_parameters(self: _LazyProtocol, module, input):
The module is set into evaluation mode before running the forward pass in order
to avoid saving statistics or calculating gradients
"""
module.initialize_parameters(*input)
module.initialize_parameters(*input, **kwargs)
if module.has_uninitialized_params():
raise RuntimeError('module {} has not been fully initialized'.format(self._get_name()))
module._initialize_hook.remove()
Expand Down
59 changes: 47 additions & 12 deletions torch/nn/modules/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,41 @@ def __repr__(self):
__str__ = __repr__


class _PreForwardHookWrapper:
def __init__(self, hook: Callable[..., None], with_kwargs=False) -> None:
self.hook = hook
self.with_kwargs = with_kwargs

def _call_impl(self, module, input, kwargs):
if self.with_kwargs:
return self.hook(module, input, kwargs)
else:
return self.hook(module, input)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You will need to fix the return type here to make sure both the updated input and kwargs are returned.


__call__ : Callable[..., Any] = _call_impl


class _ForwardHookWrapper:
def __init__(self, hook: Callable[..., None], with_kwargs=False) -> None:
self.hook = hook
self.with_kwargs = with_kwargs

def _call_impl(self, module, input, kwargs, result):
print("self module input kwargs result")
print(self)
print(module)
print(input)
print(kwargs)
print(result)
if self.with_kwargs:
return self.hook(module, input, kwargs, result)
else:
return self.hook(module, input, result)

__call__ : Callable[..., Any] = _call_impl



def _addindent(s_, numSpaces):
s = s_.split('\n')
# don't do anything for single-line stuff
Expand All @@ -49,7 +84,7 @@ def _addindent(s_, numSpaces):
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'


def register_module_forward_pre_hook(hook: Callable[..., None]) -> RemovableHandle:
def register_module_forward_pre_hook(hook: Callable[..., None], with_kwargs=False) -> RemovableHandle:
r"""Registers a forward pre-hook common to all modules.

.. warning ::
Expand Down Expand Up @@ -77,11 +112,11 @@ def register_module_forward_pre_hook(hook: Callable[..., None]) -> RemovableHand
``handle.remove()``
"""
handle = hooks.RemovableHandle(_global_forward_pre_hooks)
_global_forward_pre_hooks[handle.id] = hook
_global_forward_pre_hooks[handle.id] = _PreForwardHookWrapper(hook, with_kwargs)
return handle


def register_module_forward_hook(hook: Callable[..., None]) -> RemovableHandle:
def register_module_forward_hook(hook: Callable[..., None], with_kwargs=False) -> RemovableHandle:
r"""Registers a global forward hook for all the modules

.. warning ::
Expand Down Expand Up @@ -109,7 +144,7 @@ def register_module_forward_hook(hook: Callable[..., None]) -> RemovableHandle:
``register_forward_hook``.
"""
handle = hooks.RemovableHandle(_global_forward_hooks)
_global_forward_hooks[handle.id] = hook
_global_forward_hooks[handle.id] = _ForwardHookWrapper(hook, with_kwargs)
return handle

def register_module_backward_hook(
Expand Down Expand Up @@ -1027,7 +1062,7 @@ def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn):
"some grad_input. Please use register_full_backward_hook to get the documented "
"behavior.")

def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle:
def register_forward_pre_hook(self, hook: Callable[..., None], with_kwargs=False) -> RemovableHandle:
r"""Registers a forward pre-hook on the module.

The hook will be called every time before :func:`forward` is invoked.
Expand All @@ -1047,10 +1082,10 @@ def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandl
``handle.remove()``
"""
handle = hooks.RemovableHandle(self._forward_pre_hooks)
self._forward_pre_hooks[handle.id] = hook
self._forward_pre_hooks[handle.id] = _PreForwardHookWrapper(hook, with_kwargs)
return handle

def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle:
def register_forward_hook(self, hook: Callable[..., None], with_kwargs=False) -> RemovableHandle:
r"""Registers a forward hook on the module.

The hook will be called every time after :func:`forward` has computed an output.
Expand All @@ -1070,7 +1105,7 @@ def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle:
``handle.remove()``
"""
handle = hooks.RemovableHandle(self._forward_hooks)
self._forward_hooks[handle.id] = hook
self._forward_hooks[handle.id] = _ForwardHookWrapper(hook, with_kwargs)
return handle

def _slow_forward(self, *input, **kwargs):
Expand Down Expand Up @@ -1105,8 +1140,8 @@ def _call_impl(self, *input, **kwargs):
if self._backward_hooks or _global_backward_hooks:
full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()
if _global_forward_pre_hooks or self._forward_pre_hooks:
for hook in (*_global_forward_pre_hooks.values(), *self._forward_pre_hooks.values()):
result = hook(self, input)
for hook_wrapper in (*_global_forward_pre_hooks.values(), *self._forward_pre_hooks.values()):
result = hook_wrapper(self, input, kwargs)
if result is not None:
if not isinstance(result, tuple):
result = (result,)
Expand All @@ -1119,8 +1154,8 @@ def _call_impl(self, *input, **kwargs):

result = forward_call(*input, **kwargs)
if _global_forward_hooks or self._forward_hooks:
for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
hook_result = hook(self, input, result)
for hook_wrapper in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
hook_result = hook_wrapper(self, input, kwargs, result)
if hook_result is not None:
result = hook_result

Expand Down