From ec46af71ffa4d90a423d2dc3af89e9eb64146bca Mon Sep 17 00:00:00 2001 From: Vivswan Shah <58091053+Vivswan@users.noreply.github.com> Date: Fri, 6 Jan 2023 14:23:47 -0500 Subject: [PATCH 1/9] Added super init to Module Added super init to Module for user modules derived from multiple python classes --- torch/nn/modules/module.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 5275c9abd29c1..34a74c578ee8b 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -461,6 +461,7 @@ def __init__(self) -> None: super().__setattr__('_load_state_dict_pre_hooks', OrderedDict()) super().__setattr__('_load_state_dict_post_hooks', OrderedDict()) super().__setattr__('_modules', OrderedDict()) + super().__init__() forward: Callable[..., Any] = _forward_unimplemented From e60b22dede072ef04999ca9bd199898976f161ed Mon Sep 17 00:00:00 2001 From: Vivswan Shah <58091053+Vivswan@users.noreply.github.com> Date: Wed, 11 Jan 2023 15:43:52 -0500 Subject: [PATCH 2/9] added args, kwargs to __init__ --- torch/nn/modules/module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 34a74c578ee8b..e64d299a988b7 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -433,7 +433,7 @@ def forward(self, x): _load_state_dict_post_hooks: Dict[int, Callable] _modules: Dict[str, Optional['Module']] - def __init__(self) -> None: + def __init__(self, *args, **kwargs) -> None: """ Initializes internal Module state, shared by both nn.Module and ScriptModule. """ @@ -461,7 +461,7 @@ def __init__(self) -> None: super().__setattr__('_load_state_dict_pre_hooks', OrderedDict()) super().__setattr__('_load_state_dict_post_hooks', OrderedDict()) super().__setattr__('_modules', OrderedDict()) - super().__init__() + super().__init__(*args, **kwargs) forward: Callable[..., Any] = _forward_unimplemented From 5e44d5633a4e7f20f7b341cd4a72074cddad356a Mon Sep 17 00:00:00 2001 From: Vivswan Shah <58091053+Vivswan@users.noreply.github.com> Date: Wed, 11 Jan 2023 16:40:46 -0500 Subject: [PATCH 3/9] added call_super_init default value for call_super_init is False --- torch/nn/modules/module.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index e64d299a988b7..589fb46586d09 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -432,7 +432,8 @@ def forward(self, x): _state_dict_pre_hooks: Dict[int, Callable] _load_state_dict_post_hooks: Dict[int, Callable] _modules: Dict[str, Optional['Module']] - + call_super_init: bool = False + def __init__(self, *args, **kwargs) -> None: """ Initializes internal Module state, shared by both nn.Module and ScriptModule. @@ -461,7 +462,9 @@ def __init__(self, *args, **kwargs) -> None: super().__setattr__('_load_state_dict_pre_hooks', OrderedDict()) super().__setattr__('_load_state_dict_post_hooks', OrderedDict()) super().__setattr__('_modules', OrderedDict()) - super().__init__(*args, **kwargs) + + if Module.call_super_init: + super().__init__(*args, **kwargs) forward: Callable[..., Any] = _forward_unimplemented From 03b24eb3c90bf51af38eade867f2d1d8c531200e Mon Sep 17 00:00:00 2001 From: Vivswan Shah <58091053+Vivswan@users.noreply.github.com> Date: Wed, 11 Jan 2023 21:10:32 -0500 Subject: [PATCH 4/9] Module to self.call_super_init --- torch/nn/modules/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 589fb46586d09..778f4efe9a6ac 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -463,7 +463,7 @@ def __init__(self, *args, **kwargs) -> None: super().__setattr__('_load_state_dict_post_hooks', OrderedDict()) super().__setattr__('_modules', OrderedDict()) - if Module.call_super_init: + if self.call_super_init: super().__init__(*args, **kwargs) forward: Callable[..., Any] = _forward_unimplemented From 3f5049a93e9369752639858b4c0357b48e33ed67 Mon Sep 17 00:00:00 2001 From: Vivswan Shah <58091053+Vivswan@users.noreply.github.com> Date: Wed, 11 Jan 2023 21:27:10 -0500 Subject: [PATCH 5/9] to follow the convention --- torch/nn/modules/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 778f4efe9a6ac..55ce2506b5174 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -464,7 +464,7 @@ def __init__(self, *args, **kwargs) -> None: super().__setattr__('_modules', OrderedDict()) if self.call_super_init: - super().__init__(*args, **kwargs) + super(Module, self).__init__(*args, **kwargs) forward: Callable[..., Any] = _forward_unimplemented From 8376d9efba6b06c8ce5bc97cb5f38472442619cd Mon Sep 17 00:00:00 2001 From: Vivswan Shah <58091053+Vivswan@users.noreply.github.com> Date: Thu, 12 Jan 2023 14:26:38 -0500 Subject: [PATCH 6/9] lint: removed trailing spaces --- torch/nn/modules/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 55ce2506b5174..411be092aad71 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -433,7 +433,7 @@ def forward(self, x): _load_state_dict_post_hooks: Dict[int, Callable] _modules: Dict[str, Optional['Module']] call_super_init: bool = False - + def __init__(self, *args, **kwargs) -> None: """ Initializes internal Module state, shared by both nn.Module and ScriptModule. From 0258b3b30a73b2883b2b2dd1074d52a358251b22 Mon Sep 17 00:00:00 2001 From: Vivswan Shah <58091053+Vivswan@users.noreply.github.com> Date: Fri, 20 Jan 2023 13:14:02 -0500 Subject: [PATCH 7/9] added test_module_super_init --- test/test_nn.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/test/test_nn.py b/test/test_nn.py index 4118a5cac71a9..31a967296de0a 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -164,6 +164,35 @@ def test_module_backcompat(self): input = torch.randn(2, 3, dtype=torch.float) self.assertEqual(m(input).size(), (2, 5)) + def test_module_super_init(self): + class MyMixin: + def __init__(self, *a, **kw): + super().__init__(*a, **kw) + self.mixin_init = True + + class MyModuleWithMixinBefore(MyMixin, nn.Module): + def __init__(self): + super().__init__() + + class MyModuleWithMixinAfter(nn.Module, MyMixin): + def __init__(self): + super().__init__() + + self.assertTrue(hasattr(MyModuleWithMixinBefore(), 'mixin_init')) + self.assertFalse(hasattr(MyModuleWithMixinAfter(), 'mixin_init')) + + nn.Module.call_super_init = True + self.assertTrue(hasattr(MyModuleWithMixinBefore(), 'mixin_init')) + self.assertTrue(hasattr(MyModuleWithMixinAfter(), 'mixin_init')) + nn.Module.call_super_init = False + + MyModuleWithMixinBefore.call_super_init = True + MyModuleWithMixinAfter.call_super_init = True + self.assertTrue(hasattr(MyModuleWithMixinBefore(), 'mixin_init')) + self.assertTrue(hasattr(MyModuleWithMixinAfter(), 'mixin_init')) + MyModuleWithMixinBefore.call_super_init = False + MyModuleWithMixinAfter.call_super_init = False + def test_share_memory(self): class Net(nn.Module): def __init__(self): From 4e475fd2c9433f44e255eb85fe5a54f5f2f0716a Mon Sep 17 00:00:00 2001 From: Vivswan Shah <58091053+Vivswan@users.noreply.github.com> Date: Tue, 24 Jan 2023 12:33:13 -0500 Subject: [PATCH 8/9] added TypeError for BC --- torch/nn/modules/module.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 411be092aad71..ff1548deb264c 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -440,6 +440,13 @@ def __init__(self, *args, **kwargs) -> None: """ torch._C._log_api_usage_once("python.nn_module") + # Backward compatibility: no args used to be allowed when call_super_init=False + if self.call_super_init is False and bool(kwargs): + raise TypeError("{}.__init__() got an unexpected keyword argument '{}'".format(type(self).__name__, next(iter(kwargs)))) + + if self.call_super_init is False and bool(args): + raise TypeError("{}.__init__() takes 1 positional argument but {} were given".format(type(self).__name__, len(args) + 1)) + """ Calls super().__setattr__('a', a) instead of the typical self.a = a to avoid Module.__setattr__ overhead. Module's __setattr__ has special From 7fe1b15d14e52a47421bf06cc3827b06d1ef36ca Mon Sep 17 00:00:00 2001 From: Vivswan Shah <58091053+Vivswan@users.noreply.github.com> Date: Tue, 24 Jan 2023 12:38:19 -0500 Subject: [PATCH 9/9] corrected line width --- torch/nn/modules/module.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index ff1548deb264c..59e0336234372 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -442,10 +442,12 @@ def __init__(self, *args, **kwargs) -> None: # Backward compatibility: no args used to be allowed when call_super_init=False if self.call_super_init is False and bool(kwargs): - raise TypeError("{}.__init__() got an unexpected keyword argument '{}'".format(type(self).__name__, next(iter(kwargs)))) + raise TypeError("{}.__init__() got an unexpected keyword argument '{}'" + "".format(type(self).__name__, next(iter(kwargs)))) if self.call_super_init is False and bool(args): - raise TypeError("{}.__init__() takes 1 positional argument but {} were given".format(type(self).__name__, len(args) + 1)) + raise TypeError("{}.__init__() takes 1 positional argument but {} were" + " given".format(type(self).__name__, len(args) + 1)) """ Calls super().__setattr__('a', a) instead of the typical self.a = a