Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 40 additions & 40 deletions test/jit/test_pdt.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def forward(self, x) -> Any:
make_global(TestPDTModel)
pdt_model = TestPDTModel()
inp: List[Tuple[Any, ...]] = [(20, ), (2.7, ), (False, ), ]
scripted_pdt_model = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model: inp})
scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model: inp})
self.assertEqual(scripted_pdt_model(50), pdt_model(50))
self.assertEqual(scripted_pdt_model(1.8), pdt_model(1.8))
self.assertTrue(scripted_pdt_model(True), pdt_model(True))
Expand All @@ -67,7 +67,7 @@ def forward(self, x):
inner_pdt_model = NestedPDTInner()
wrapped_pdt_model = NestedModulePDTWrapper(inner_pdt_model)
inp: List[Tuple[Any, ...]] = [(20, ), (False, )]
scripted_pdt_model = torch.jit._script_pdt(wrapped_pdt_model, example_inputs={wrapped_pdt_model: inp})
scripted_pdt_model = torch.jit.script(wrapped_pdt_model, example_inputs={wrapped_pdt_model: inp})
self.assertEqual(scripted_pdt_model(30), wrapped_pdt_model(30))
self.assertEqual(scripted_pdt_model(1.9), wrapped_pdt_model(1.9))
self.assertTrue(scripted_pdt_model(True), wrapped_pdt_model(True))
Expand Down Expand Up @@ -95,8 +95,8 @@ def forward(self, x):
outer_pdt_model = NestedModulePDTOuter(inner_pdt_model)
inner_input: List[Tuple[Any, ...]] = [(10, 10), (1.9, 20), ]
outer_input: List[Tuple[Any, ...]] = [(20, ), (False, )]
scripted_pdt_model = torch.jit._script_pdt(outer_pdt_model, example_inputs={inner_pdt_model: inner_input,
outer_pdt_model: outer_input, })
scripted_pdt_model = torch.jit.script(outer_pdt_model, example_inputs={inner_pdt_model: inner_input,
outer_pdt_model: outer_input, })
self.assertEqual(scripted_pdt_model(30), outer_pdt_model(30))
self.assertEqual(scripted_pdt_model(1.9), outer_pdt_model(1.9))
self.assertTrue(scripted_pdt_model(True), outer_pdt_model(True))
Expand All @@ -119,7 +119,7 @@ def fun(self, x):
make_global(NestedFunctionInForward)
pdt_model = NestedFunctionInForward()
inp: List[Tuple[Any, ...]] = [(-1, ), (False, )]
scripted_pdt_model = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model: inp})
scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model: inp})
self.assertEqual(scripted_pdt_model(30), pdt_model(30))
self.assertEqual(scripted_pdt_model(True), pdt_model(True))

Expand All @@ -142,7 +142,7 @@ def fn(self, x, y) -> Any:
make_global(TestModelWithExport)
pdt_model = TestModelWithExport()
inp: List[Tuple[Any, ...]] = [(20, 10, ), (2.7, 8.9, ), ]
scripted_pdt_model = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model.fn: inp})
scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model.fn: inp})
self.assertEqual(scripted_pdt_model.fn(10, 90), pdt_model.fn(10, 90))
self.assertEqual(scripted_pdt_model.fn(1.8, 2.2), pdt_model.fn(1.8, 2.2))
self.assertTrue(scripted_pdt_model.fn(torch.ones(1), 2), pdt_model.fn(torch.ones(1), 2))
Expand All @@ -155,7 +155,7 @@ def test_sum(self, a):
make_global(PDTModel)
pdt_model = PDTModel()
inp: List[Tuple[Any, ...]] = [([10, 20, ], ), ]
scripted_pdt_model = torch.jit._script_pdt(PDTModel, example_inputs={pdt_model.test_sum: inp})
scripted_pdt_model = torch.jit.script(PDTModel, example_inputs={pdt_model.test_sum: inp})
script_model = scripted_pdt_model()
self.assertEqual(script_model.test_sum([10, 20, 30, ], ), pdt_model.test_sum([10, 20, 30, ], ))

Expand All @@ -174,8 +174,8 @@ def test_substring(self, a, b):
pdt_model = PDTModelWithManyMethods()
list_inp: List[Tuple[Any, ...]] = [([1.2, 2.3, ], ), ]
str_inp: List[Tuple[Any, ...]] = [("abc", "b", ), ]
scripted_pdt_model = torch.jit._script_pdt(PDTModelWithManyMethods, example_inputs={pdt_model.test_list_to_dict: list_inp,
pdt_model.test_substring: str_inp})
scripted_pdt_model = torch.jit.script(PDTModelWithManyMethods, example_inputs={pdt_model.test_list_to_dict: list_inp,
pdt_model.test_substring: str_inp})
script_model = scripted_pdt_model()
self.assertEqual(script_model.test_list_to_dict([1.1, 2.2, 3.3, ], ), pdt_model.test_list_to_dict([1.1, 2.2, 3.3, ], ))
self.assertEqual(script_model.test_substring("helloworld", "world", ), pdt_model.test_substring("helloworld", "world", ))
Expand All @@ -195,8 +195,8 @@ def test_find(self, a, b):
pdt_model_two = PDTModelTwo()
dict_inp: List[Tuple[Any, ...]] = [({1.2: True, 2.3: False, }, 1.2), ]
list_inp: List[Tuple[Any, ...]] = [(["abc", "b", ], "c"), ]
scripted_pdt_model_one = torch.jit._script_pdt(PDTModelOne, example_inputs={pdt_model_one.test_find: dict_inp})
scripted_pdt_model_two = torch.jit._script_pdt(PDTModelTwo, example_inputs={pdt_model_two.test_find: list_inp})
scripted_pdt_model_one = torch.jit.script(PDTModelOne, example_inputs={pdt_model_one.test_find: dict_inp})
scripted_pdt_model_two = torch.jit.script(PDTModelTwo, example_inputs={pdt_model_two.test_find: list_inp})

script_model_one, script_model_two = scripted_pdt_model_one(), scripted_pdt_model_two()
self.assertEqual(script_model_one.test_find({1.1: True, 2.2: True, 3.3: False, }, 4.4),
Expand All @@ -209,28 +209,28 @@ def test_sum(a, b):
return a + b

make_global(test_sum)
scripted_fn_add = torch.jit._script_pdt(test_sum, example_inputs=[(3, 4)])
scripted_fn_add = torch.jit.script(test_sum, example_inputs=[(3, 4)])
self.assertEqual(scripted_fn_add(10, 2), test_sum(10, 2))

def test_sub(a, b):
return a - b

make_global(test_sub)
scripted_fn_sub = torch.jit._script_pdt(test_sub, example_inputs=[(3.9, 4.10)])
scripted_fn_sub = torch.jit.script(test_sub, example_inputs=[(3.9, 4.10)])
self.assertEqual(scripted_fn_sub(6.5, 2.9), test_sub(6.5, 2.9))

def test_mul(a, b):
return a * b

make_global(test_mul)
scripted_fn_mul = torch.jit._script_pdt(test_mul, example_inputs=[(-10, 9)])
scripted_fn_mul = torch.jit.script(test_mul, example_inputs=[(-10, 9)])
self.assertEqual(scripted_fn_mul(-1, 3), test_mul(-1, 3))

def test_args_complex(real, img):
return torch.complex(real, img)

make_global(test_args_complex)
scripted_fn_complex = torch.jit._script_pdt(test_args_complex, example_inputs=[(torch.rand(3, 4), torch.rand(3, 4))])
scripted_fn_complex = torch.jit.script(test_args_complex, example_inputs=[(torch.rand(3, 4), torch.rand(3, 4))])
arg1, arg2 = torch.rand(3, 4), torch.rand(3, 4)
self.assertEqual(scripted_fn_complex(arg1, arg2), test_args_complex(arg1, arg2))

Expand All @@ -241,7 +241,7 @@ def test_bool(a):
return 0

make_global(test_bool)
scripted_fn_bool = torch.jit._script_pdt(test_bool, example_inputs=[(True,)])
scripted_fn_bool = torch.jit.script(test_bool, example_inputs=[(True,)])
self.assertEqual(scripted_fn_bool(True), test_bool(True))

def test_str(a):
Expand All @@ -251,7 +251,7 @@ def test_str(a):
return True

make_global(test_str)
scripted_fn_str = torch.jit._script_pdt(test_str, example_inputs=[("",)])
scripted_fn_str = torch.jit.script(test_str, example_inputs=[("",)])
self.assertEqual(scripted_fn_str("abc"), test_str("abc"))

def test_pdt_list_and_tuple(self):
Expand All @@ -260,24 +260,24 @@ def test_list_and_tuple(a):

make_global(test_list_and_tuple)

scripted_fn_float_list_input = torch.jit._script_pdt(test_list_and_tuple, example_inputs=[([4.9, 8.9],)])
scripted_fn_float_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([4.9, 8.9],)])
self.assertEqual(scripted_fn_float_list_input([11.9, 7.6]), test_list_and_tuple([11.9, 7.6]))

scripted_fn_bool_list_input = torch.jit._script_pdt(test_list_and_tuple, example_inputs=[([True, False, True],)])
scripted_fn_bool_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([True, False, True],)])
self.assertEqual(scripted_fn_bool_list_input([True, True, True]), test_list_and_tuple([True, True, True]))

scripted_fn_int_list_input = torch.jit._script_pdt(test_list_and_tuple, example_inputs=[([3, 4, 5], )])
scripted_fn_int_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([3, 4, 5], )])
self.assertEqual(scripted_fn_int_list_input([1, 2, 3]), test_list_and_tuple([1, 2, 3]))

scripted_fn_float_tuple_input = torch.jit._script_pdt(test_list_and_tuple, example_inputs=[((4.9, 8.9),)])
scripted_fn_float_tuple_input = torch.jit.script(test_list_and_tuple, example_inputs=[((4.9, 8.9),)])
self.assertEqual(scripted_fn_float_tuple_input((11.9, 7.6)), test_list_and_tuple((11.9, 7.6)))

scripted_fn_bool_tuple_input = torch.jit._script_pdt(test_list_and_tuple,
example_inputs=[((True, False, True),)])
scripted_fn_bool_tuple_input = torch.jit.script(test_list_and_tuple,
example_inputs=[((True, False, True),)])
self.assertEqual(scripted_fn_bool_tuple_input((True, True, True)),
test_list_and_tuple((True, True, True)))

scripted_fn_int_tuple_input = torch.jit._script_pdt(test_list_and_tuple, example_inputs=[((3, 4, 5), )])
scripted_fn_int_tuple_input = torch.jit.script(test_list_and_tuple, example_inputs=[((3, 4, 5), )])
self.assertEqual(scripted_fn_int_tuple_input((1, 2, 3)), test_list_and_tuple((1, 2, 3)))

def test_nested_list_and_tuple(self):
Expand All @@ -295,22 +295,22 @@ def test_nested_tuple(inp):
make_global(test_nested_list, test_nested_tuple)

list_inp = [[1, 2, 3, ], [5, 6, 7, ]]
scripted_fn = torch.jit._script_pdt(test_nested_list, example_inputs=[(list_inp, ), ])
scripted_fn = torch.jit.script(test_nested_list, example_inputs=[(list_inp, ), ])
inp = [[0, 4, 7, ], [8, 11, ], [6, -1, -20, ]]
self.assertEqual(scripted_fn(inp, ), test_nested_list(inp, ))

list_inp = ([1, 2, 3, ], [5, 6, 7, ])
scripted_fn = torch.jit._script_pdt(test_nested_list, example_inputs=[(list_inp, ), ])
scripted_fn = torch.jit.script(test_nested_list, example_inputs=[(list_inp, ), ])
inp = ([0, 4, 7, ], [8, 11, ], [6, -1, -20, ])
self.assertEqual(scripted_fn(inp, ), test_nested_list(inp, ))

tup_inp = [(1.0, 2.6, 3.7, ), (5.7, 6.1, 1.7, )]
scripted_fn = torch.jit._script_pdt(test_nested_tuple, example_inputs=[(tup_inp, ), ])
scripted_fn = torch.jit.script(test_nested_tuple, example_inputs=[(tup_inp, ), ])
inp = [(1.0, 4.1, 7.4, ), (4.8, 1.1, -1.2, ), (6.3, -1.3, -2.0, )]
self.assertEqual(scripted_fn(inp, ), test_nested_tuple(inp, ))

tup_inp = ((True, False, True, ), (False, False, False, ))
scripted_fn = torch.jit._script_pdt(test_nested_tuple, example_inputs=[(tup_inp, ), ])
scripted_fn = torch.jit.script(test_nested_tuple, example_inputs=[(tup_inp, ), ])
inp = ((True, True, True, ), (False, False, True, ))
self.assertEqual(scripted_fn(inp, ), test_nested_tuple(inp, ))

Expand All @@ -324,11 +324,11 @@ def test_dict_int_list(a):
make_global(test_dict, test_dict_int_list)

str_bool_inp = {'foo' : True, 'bar': False}
scripted_fn = torch.jit._script_pdt(test_dict, example_inputs=[(str_bool_inp,)])
scripted_fn = torch.jit.script(test_dict, example_inputs=[(str_bool_inp,)])
self.assertEqual(scripted_fn({'foo' : False, 'bar': True}, ), test_dict({'foo' : False, 'bar': True}, ))

str_list_inp = {0 : [True, False], 1: [False, True]}
scripted_fn = torch.jit._script_pdt(test_dict_int_list, example_inputs=[(str_list_inp,)])
scripted_fn = torch.jit.script(test_dict_int_list, example_inputs=[(str_list_inp,)])
self.assertEqual(scripted_fn({0 : [False, False], 1: [True, True]}, ),
test_dict_int_list({0 : [False, False], 1: [True, True]}, ))

Expand All @@ -349,14 +349,14 @@ def test_multiple_type_refinement(a):

make_global(test_multiple_types, test_multiple_type_refinement)

scripted_fn = torch.jit._script_pdt(test_multiple_types, example_inputs=[(1,), ("abc", ), (8.9,), ([3, 4, 5], )])
scripted_fn = torch.jit.script(test_multiple_types, example_inputs=[(1,), ("abc", ), (8.9,), ([3, 4, 5], )])
self.assertEqual(scripted_fn(10), test_multiple_types(10))
self.assertEqual(scripted_fn("def"), test_multiple_types("def"))
self.assertEqual(scripted_fn(7.89999), test_multiple_types(7.89999))
self.assertEqual(scripted_fn([10, 11, 14]), test_multiple_types([10, 11, 14]))

scripted_fn = torch.jit._script_pdt(test_multiple_type_refinement, example_inputs=[(1,), ("abc", ), (8.9,),
([3, 4, 5],), (True, ), ({"a": True}, ), ])
scripted_fn = torch.jit.script(test_multiple_type_refinement, example_inputs=[(1,), ("abc", ), (8.9,),
([3, 4, 5],), (True, ), ({"a": True}, ), ])
self.assertEqual(scripted_fn(10), test_multiple_type_refinement(10))
self.assertEqual(scripted_fn("def"), test_multiple_type_refinement("def"))
self.assertEqual(scripted_fn(7.89999), test_multiple_type_refinement(7.89999))
Expand All @@ -381,7 +381,7 @@ def test_model(a, m):
make_global(UserDefinedClass, test_model)

user_class = UserDefinedClass()
scripted_fn = torch.jit._script_pdt(test_model, example_inputs=[(10, user_class, ), (10.9, user_class, ), ])
scripted_fn = torch.jit.script(test_model, example_inputs=[(10, user_class, ), (10.9, user_class, ), ])
self.assertEqual(scripted_fn(100, user_class, ), test_model(100, user_class))
self.assertEqual(scripted_fn(1.9, user_class, ), test_model(1.9, user_class))

Expand All @@ -403,7 +403,7 @@ def test_model_with_args(a, m):
make_global(ClassWithArgs, test_model_with_args)

user_class = ClassWithArgs(False)
scripted_fn = torch.jit._script_pdt(test_model_with_args, example_inputs=[(10, user_class, ), (10.9, user_class, ), ])
scripted_fn = torch.jit.script(test_model_with_args, example_inputs=[(10, user_class, ), (10.9, user_class, ), ])
self.assertEqual(scripted_fn(100, ClassWithArgs(True), ), test_model_with_args(100, ClassWithArgs(True)))

def test_nn_parameter_as_arg(self):
Expand All @@ -420,7 +420,7 @@ def forward(self, y):

make_global(TestNNParameter)
pdt_model = TestNNParameter()
scripted_fn = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model: [(10, ), ], })
scripted_fn = torch.jit.script(pdt_model, example_inputs={pdt_model: [(10, ), ], })
self.assertEqual(scripted_fn(20), pdt_model(20))

def test_fx_tracing_with_typing(self):
Expand All @@ -434,7 +434,7 @@ def forward(self, a) -> FXModelOutput:

make_global(FXModel, FXModelOutput)
pdt_model = FXModel()
scripted_fn = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], })
scripted_fn = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], })
self.assertEqual(scripted_fn([20]), pdt_model([20]))

def test_nonetype_as_optional_of_type(self):
Expand All @@ -446,11 +446,11 @@ def test_none(a) -> Any:

make_global(test_none)

scripted_fn = torch.jit._script_pdt(test_none, example_inputs=[(None, ), (10.6, )])
scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (10.6, )])
self.assertEqual(scripted_fn(30.9, ), test_none(30.9, ))

scripted_fn = torch.jit._script_pdt(test_none, example_inputs=[(None, ), (10, )])
scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (10, )])
self.assertEqual(scripted_fn(2, ), test_none(2, ))

scripted_fn = torch.jit._script_pdt(test_none, example_inputs=[(None, ), (torch.Tensor(1), )])
scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (torch.Tensor(1), )])
self.assertEqual(scripted_fn(torch.ones(1), ), test_none(torch.ones(1), ))
1 change: 0 additions & 1 deletion torch/jit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
)
from torch.jit._script import (
script,
_script_pdt,
Attribute,
ScriptModule,
script_method,
Expand Down
Loading