Skip to content

Commit

Permalink
Infer types for arguments of methods not invoked directly by monkeyty…
Browse files Browse the repository at this point in the history
…pe (#57202)

Summary:
Support adding type annotations for class methods and nn.Module methods which are not invoked under the hood of MonkeyType

** Changes **
* This PR involves a slight change in how the example inputs are passed while scripting `class` and `nn.Module` objects.
* The example inputs passed to `_script_pdt` is of the following format:
     - example_inputs= [(obj.method1, (arg_list)), (obj.method2, (arg_list)),]
* For nn.Modules, to infer types for `forward` methods, example_inputs can be passed in two ways:
    - example_inputs= [(obj.forward, (arg_list, ))]
    - example_inputs = [(obj, (arg_list, ) )]

Pull Request resolved: #57202

Reviewed By: desertfire

Differential Revision: D28382827

Pulled By: nikithamalgifb

fbshipit-source-id: 5481467f3e909493bf3f439ee312056943508534
  • Loading branch information
nikithamalgi authored and facebook-github-bot committed May 12, 2021
1 parent 1de3525 commit 9063cb0
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 19 deletions.
119 changes: 116 additions & 3 deletions test/jit/test_pdt.py
Expand Up @@ -39,7 +39,8 @@ def forward(self, x) -> Any:

make_global(TestPDTModel)
pdt_model = TestPDTModel()
scripted_pdt_model = torch.jit._script_pdt(pdt_model, example_inputs=[(10, ), (10.80, ), (False, )])
inp: List[Tuple[Any, ...]] = [(20, ), (2.7, ), (False, ), ]
scripted_pdt_model = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model: inp})
self.assertEqual(scripted_pdt_model(50), pdt_model(50))
self.assertEqual(scripted_pdt_model(1.8), pdt_model(1.8))
self.assertTrue(scripted_pdt_model(True), pdt_model(True))
Expand All @@ -65,11 +66,41 @@ def forward(self, x):
make_global(NestedPDTInner, NestedModulePDTWrapper)
inner_pdt_model = NestedPDTInner()
wrapped_pdt_model = NestedModulePDTWrapper(inner_pdt_model)
scripted_pdt_model = torch.jit._script_pdt(wrapped_pdt_model, example_inputs=[(20, ), (2.7, ), (False, )])
inp: List[Tuple[Any, ...]] = [(20, ), (False, )]
scripted_pdt_model = torch.jit._script_pdt(wrapped_pdt_model, example_inputs={wrapped_pdt_model: inp})
self.assertEqual(scripted_pdt_model(30), wrapped_pdt_model(30))
self.assertEqual(scripted_pdt_model(1.9), wrapped_pdt_model(1.9))
self.assertTrue(scripted_pdt_model(True), wrapped_pdt_model(True))

def test_nested_nn_module_class_with_args(self):
class NestedModulePDTInner(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
if isinstance(x, int):
return x * 10 + y
return x

class NestedModulePDTOuter(torch.nn.Module):
def __init__(self, inner):
super().__init__()
self.inner = inner

def forward(self, x):
return self.inner(x, 20)

make_global(NestedModulePDTInner, NestedModulePDTOuter)
inner_pdt_model = NestedModulePDTInner()
outer_pdt_model = NestedModulePDTOuter(inner_pdt_model)
inner_input: List[Tuple[Any, ...]] = [(10, 10), (1.9, 20), ]
outer_input: List[Tuple[Any, ...]] = [(20, ), (False, )]
scripted_pdt_model = torch.jit._script_pdt(outer_pdt_model, example_inputs={inner_pdt_model: inner_input,
outer_pdt_model: outer_input, })
self.assertEqual(scripted_pdt_model(30), outer_pdt_model(30))
self.assertEqual(scripted_pdt_model(1.9), outer_pdt_model(1.9))
self.assertTrue(scripted_pdt_model(True), outer_pdt_model(True))

def test_nested_function_in_forward(self):
class NestedFunctionInForward(torch.nn.Module):
def __init__(self):
Expand All @@ -87,6 +118,88 @@ def fun(self, x):

make_global(NestedFunctionInForward)
pdt_model = NestedFunctionInForward()
scripted_pdt_model = torch.jit._script_pdt(pdt_model, example_inputs=[(20, ), (False, )])
inp: List[Tuple[Any, ...]] = [(-1, ), (False, )]
scripted_pdt_model = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model: inp})
self.assertEqual(scripted_pdt_model(30), pdt_model(30))
self.assertEqual(scripted_pdt_model(True), pdt_model(True))

def test_nn_module_with_export_function(self):
class TestModelWithExport(torch.nn.Module):
def __init__(self):
super().__init__()

@torch.jit.export
def fn(self, x, y) -> Any:
assert not (isinstance(x, bool) and isinstance(y, bool))
if isinstance(x, int) and isinstance(y, int):
return x + y
elif isinstance(x, float) and isinstance(y, float):
return x - y
else:
return -1


make_global(TestModelWithExport)
pdt_model = TestModelWithExport()
inp: List[Tuple[Any, ...]] = [(20, 10, ), (2.7, 8.9, ), ]
scripted_pdt_model = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model.fn: inp})
self.assertEqual(scripted_pdt_model.fn(10, 90), pdt_model.fn(10, 90))
self.assertEqual(scripted_pdt_model.fn(1.8, 2.2), pdt_model.fn(1.8, 2.2))
self.assertTrue(scripted_pdt_model.fn(torch.ones(1), 2), pdt_model.fn(torch.ones(1), 2))

def test_class_methods(self):
class PDTModel:
def test_sum(self, a):
return sum(a)

make_global(PDTModel)
pdt_model = PDTModel()
inp: List[Tuple[Any, ...]] = [([10, 20, ], ), ]
scripted_pdt_model = torch.jit._script_pdt(PDTModel, example_inputs={pdt_model.test_sum: inp})
script_model = scripted_pdt_model()
self.assertEqual(script_model.test_sum([10, 20, 30, ], ), pdt_model.test_sum([10, 20, 30, ], ))

def test_class_with_multiple_methods(self):
class PDTModelWithManyMethods:
def test_list_to_dict(self, a):
new_dictionary: Dict[float, bool] = {}
for element in a:
new_dictionary[element] = True
return new_dictionary

def test_substring(self, a, b):
return b in a

make_global(PDTModelWithManyMethods)
pdt_model = PDTModelWithManyMethods()
list_inp: List[Tuple[Any, ...]] = [([1.2, 2.3, ], ), ]
str_inp: List[Tuple[Any, ...]] = [("abc", "b", ), ]
scripted_pdt_model = torch.jit._script_pdt(PDTModelWithManyMethods, example_inputs={pdt_model.test_list_to_dict: list_inp,
pdt_model.test_substring: str_inp})
script_model = scripted_pdt_model()
self.assertEqual(script_model.test_list_to_dict([1.1, 2.2, 3.3, ], ), pdt_model.test_list_to_dict([1.1, 2.2, 3.3, ], ))
self.assertEqual(script_model.test_substring("helloworld", "world", ), pdt_model.test_substring("helloworld", "world", ))
self.assertEqual(script_model.test_substring("helloworld", "def", ), pdt_model.test_substring("helloworld", "def", ))

def test_multiple_class_with_same_method(self):
class PDTModelOne:
def test_find(self, a, b):
return b in a.keys()

class PDTModelTwo:
def test_find(self, a, b):
return b in a

make_global(PDTModelOne, PDTModelTwo)
pdt_model_one = PDTModelOne()
pdt_model_two = PDTModelTwo()
dict_inp: List[Tuple[Any, ...]] = [({1.2: True, 2.3: False, }, 1.2), ]
list_inp: List[Tuple[Any, ...]] = [(["abc", "b", ], "c"), ]
scripted_pdt_model_one = torch.jit._script_pdt(PDTModelOne, example_inputs={pdt_model_one.test_find: dict_inp})
scripted_pdt_model_two = torch.jit._script_pdt(PDTModelTwo, example_inputs={pdt_model_two.test_find: list_inp})

script_model_one, script_model_two = scripted_pdt_model_one(), scripted_pdt_model_two()
self.assertEqual(script_model_one.test_find({1.1: True, 2.2: True, 3.3: False, }, 4.4),
pdt_model_one.test_find({1.1: True, 2.2: True, 3.3: False, }, 4.4))
self.assertEqual(script_model_two.test_find(["hello", "world", ], "world"),
pdt_model_two.test_find(["hello", "world", ], "world"))
47 changes: 31 additions & 16 deletions torch/jit/_script.py
Expand Up @@ -13,7 +13,7 @@
import copy
import pickle
import warnings
from typing import Any, Dict, List, Tuple, Optional
from typing import Any, Dict, List, Tuple, Union, Callable


import torch
Expand Down Expand Up @@ -851,7 +851,8 @@ def call_prepare_scriptable_func(obj):
memo: Dict[int, torch.nn.Module] = {}
return call_prepare_scriptable_func_impl(obj, memo)

def _script_pdt(obj, optimize=None, _frames_up=0, _rcb=None, example_inputs: Optional[List[Tuple]] = None):
def _script_pdt(obj, optimize=None, _frames_up=0, _rcb=None,
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None):
# This is a private API, intended for internal use only. Usage of this API is only for experimental
# purposes only and is highly discouraged.
global type_trace_db
Expand All @@ -869,20 +870,34 @@ def _script_pdt(obj, optimize=None, _frames_up=0, _rcb=None, example_inputs: Opt
if isinstance(obj, ScriptFunction):
return obj

# If MonkeyType is installed, enable profile directed type annotation
# Check if example_inputs are defined and generate call traces
# for the method by running eager mode version of the method with
# the provide example inputs. This logs all the traces in type_trace_db
type_trace_db = JitTypeTraceStore()
if monkeytype_trace:
monkeytype_config = JitTypeTraceConfig(type_trace_db)
with monkeytype_trace(monkeytype_config):
for example_input in example_inputs: # type: ignore[union-attr]
obj(*example_input)
else:
warnings.warn("Warning: monkeytype is not installed. Please install https://github.com/Instagram/MonkeyType "
"to enable Profile-Directed Typing in TorchScript. Refer to "
"https://github.com/Instagram/MonkeyType/blob/master/README.rst to install MonkeyType. ")
if example_inputs:
# If MonkeyType is installed, enable profile directed type annotation
# Check if example_inputs are defined and generate call traces
# for the method by running eager mode version of the method with
# the provide example inputs. This logs all the traces in type_trace_db
type_trace_db = JitTypeTraceStore()
if monkeytype_trace:
monkeytype_config = JitTypeTraceConfig(type_trace_db)
with monkeytype_trace(monkeytype_config):
if isinstance(example_inputs, Dict):
# If the obj is an nn.Module or a class, then each method is
# executed with the arguments provided in the example inputs.
# example inputs here will be of type Dict(class.method, (arguments))
# This is used to infer type annotations for those methods
# which are not called directly under the hood of monkeytype.
for module, example_input in example_inputs.items():
for example in example_input:
module(*example)
elif isinstance(example_inputs, List):
for examples in example_inputs:
obj(*examples)
else:
warnings.warn("Error: Unable to infer types. Please format the inputs to type `List[Tuple]`"
" or `Dict[Callable, List[Tuple]]` to be run with MonkeyType.")
else:
warnings.warn("Warning: monkeytype is not installed. Please install https://github.com/Instagram/MonkeyType "
"to enable Profile-Directed Typing in TorchScript. Refer to "
"https://github.com/Instagram/MonkeyType/blob/master/README.rst to install MonkeyType. ")
return script(obj, optimize, _frames_up, _rcb)

def script(obj, optimize=None, _frames_up=0, _rcb=None):
Expand Down

0 comments on commit 9063cb0

Please sign in to comment.