[JIT] Add __prepare_scriptable__ duck typing to allow replacing nn.modules with scriptable preparations (#45645)

Summary:
Fixes #45072

As discussed with zdevito, gchanan, cpuhrsch, and suo, this change allows developers to define custom preparations for their modules before scripting. This is done by adding a `__prepare_scriptable__` method to a module that returns the prepared, scriptable module out-of-place. It does not expand the API surface for end users.
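
For illustration, a minimal sketch of the hook (the module name `EagerOnlySELU` is hypothetical, not part of this change):

import torch
import torch.nn as nn

class EagerOnlySELU(nn.SELU):
    def __prepare_scriptable__(self):
        # Return a scriptable replacement out-of-place; `self` is untouched.
        return nn.ReLU()

m = EagerOnlySELU()
sm = torch.jit.script(m)  # compiles the returned nn.ReLU
# m still behaves as SELU in eager mode; sm behaves as ReLU.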

Prior art by jamesr66a: #42244

cc: zhangguanheng66

Reviewed By: dongreenberg, ngimel

Differential Revision: D24039990

Pulled By: zhangguanheng66

fbshipit-source-id: 4ddff2d353124af9c2ef22db037df7e3d26efe65

ghstack-source-id: 8cf5a7723a1fbde418800d55389ac2f588bb05bd
Pull Request resolved: #49242
Meghan Lele committed Dec 11, 2020
1 parent f4226b5 commit 7f4f975
Showing 3 changed files with 84 additions and 0 deletions.
44 changes: 44 additions & 0 deletions test/jit/test_recursive_script.py
@@ -495,6 +495,50 @@ def forward(self, x):

self.checkModule(M(), (torch.randn(5, 5),))

def test_prepare_scriptable_basic(self):
class SeluButReluWhenScripted(torch.nn.SELU):
def __prepare_scriptable__(self):
return nn.ReLU()

t = torch.randn(5, 5)
m = SeluButReluWhenScripted()
sm = torch.jit.script(m)
eager_out = m(t)
script_out = sm(t)
self.assertNotEqual(eager_out, script_out)

def test_prepare_scriptable_iterable_modules(self):
class SeluButReluWhenScripted(torch.nn.SELU):
def __prepare_scriptable__(self):
return nn.ReLU()

class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
shared = SeluButReluWhenScripted()
self.sequential = nn.Sequential(
SeluButReluWhenScripted(),
SeluButReluWhenScripted(),
nn.Sequential(SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()),
shared,
)
self.module_list = nn.ModuleList([SeluButReluWhenScripted(),
shared,
SeluButReluWhenScripted()])

def forward(self, x):
for mod in self.module_list:
x += mod(x)
x += self.sequential(x)
return x

t = torch.randn(5, 5)
m = M()
eager_out = m(t.clone())
sm = torch.jit.script(m)
script_out = sm(t.clone())
self.assertNotEqual(eager_out, script_out)

def test_attributes(self):
@torch.jit.script
class Inner2(object):
26 changes: 26 additions & 0 deletions test/jit/test_torchbind.py
@@ -62,6 +62,32 @@ def f():
return ss1.pop() + ss2.pop()
test_equality(f, lambda x: x)

# Test an nn.Module that defines a __prepare_scriptable__ method
class NonJitableClass(object):
def __init__(self, int1, int2):
self.int1 = int1
self.int2 = int2

def return_vals(self):
return self.int1, self.int2

class CustomWrapper(torch.nn.Module):
def __init__(self, foo):
super(CustomWrapper, self).__init__()
self.foo = foo

def forward(self) -> None:
self.foo.increment(1)
return

def __prepare_scriptable__(self):
int1, int2 = self.foo.return_vals()
foo = torch.classes._TorchScriptTesting._Foo(int1, int2)
return CustomWrapper(foo)

foo = CustomWrapper(NonJitableClass(1, 2))
jit_foo = torch.jit.script(foo)

def test_torchbind_take_as_arg(self):
global StackString # see [local resolution in python]
StackString = torch.classes._TorchScriptTesting._StackString
14 changes: 14 additions & 0 deletions torch/jit/_script.py
@@ -741,6 +741,19 @@ class RecursiveScriptModule(ScriptModule): # type: ignore
def __init__(self, arg=None):
super().__init__()

def call_prepare_scriptable_func(obj):
    # Only nn.Modules participate in the duck-typed protocol.
    if not isinstance(obj, torch.nn.Module):
        return obj
    # Let the module return a scriptable replacement for itself, out-of-place.
    obj = obj.__prepare_scriptable__() if hasattr(obj, '__prepare_scriptable__') else obj  # type: ignore
    # Recurse into attributes, including the child modules stored in _modules.
    for name in obj.__dict__:
        sub_module = obj.__dict__.get(name)
        if name == '_modules':
            for k, v in sub_module.items():
                sub_module[k] = call_prepare_scriptable_func(v)
            obj.__setattr__(name, sub_module)
        elif isinstance(sub_module, torch.nn.Module) and not isinstance(sub_module, ScriptModule):
            obj.__setattr__(name, call_prepare_scriptable_func(sub_module))
    return obj

def script(obj, optimize=None, _frames_up=0, _rcb=None):
r"""
@@ -894,6 +907,7 @@ def forward(self, input):
return obj

if isinstance(obj, torch.nn.Module):
obj = call_prepare_scriptable_func(obj)
return torch.jit._recursive.create_script_module(
obj, torch.jit._recursive.infer_methods_to_compile
)
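
With this change, torch.jit.script first runs the hook over the module tree, so a swap declared on a submodule takes effect before compilation. A minimal sketch of the resulting behavior (module names hypothetical):

import torch
import torch.nn as nn

class SwapsToReLU(nn.SELU):
    def __prepare_scriptable__(self):
        return nn.ReLU()

parent = nn.Sequential(nn.Linear(4, 4), SwapsToReLU())
scripted = torch.jit.script(parent)
# scripted contains a compiled nn.ReLU where SwapsToReLU was registered.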
