From ad02538f94987b38da3b00223e8924001b19364f Mon Sep 17 00:00:00 2001 From: Lillian Johnson Date: Tue, 13 Oct 2020 13:51:02 -0700 Subject: [PATCH] [WIP] adding torch.jit.isinstance ghstack-source-id: 79618038d10e321a2c3716c8272041ff9bfb790c Pull Request resolved: https://github.com/pytorch/pytorch/pull/46062 --- test/jit/test_isinstance.py | 246 ++++++++++++++++++ test/test_jit.py | 1 + torch/csrc/jit/frontend/ir_emitter.cpp | 8 + .../csrc/jit/python/python_sugared_value.cpp | 3 + torch/jit/__init__.py | 5 + torch/jit/_script.py | 96 ++++++- 6 files changed, 358 insertions(+), 1 deletion(-) create mode 100644 test/jit/test_isinstance.py diff --git a/test/jit/test_isinstance.py b/test/jit/test_isinstance.py new file mode 100644 index 000000000000..f9d60b5d5dc2 --- /dev/null +++ b/test/jit/test_isinstance.py @@ -0,0 +1,246 @@ +import os +import sys + +import torch +from typing import List, Any, Dict, Tuple, Optional + +# Make the helper files in test/ importable +pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(pytorch_test_dir) +from torch.testing._internal.jit_utils import JitTestCase + +if __name__ == "__main__": + raise RuntimeError( + "This test file is not meant to be run directly, use:\n\n" + "\tpython test/test_jit.py TESTNAME\n\n" + "instead." + ) + +# Tests for torch.jit.isinstance +class TestIsinstance(JitTestCase): + def test_int(self): + def int_test(x: Any): + assert torch.jit.isinstance(x, int) + assert not torch.jit.isinstance(x, float) + + x = 1 + self.checkScript(int_test, (x,)) + + def test_float(self): + def float_test(x: Any): + assert torch.jit.isinstance(x, float) + assert not torch.jit.isinstance(x, int) + + x = 1.0 + self.checkScript(float_test, (x,)) + + def test_bool(self): + def bool_test(x: Any): + assert torch.jit.isinstance(x, bool) + assert not torch.jit.isinstance(x, float) + + x = False + self.checkScript(bool_test, (x,)) + + def test_list(self): + def list_str_test(x: Any): + assert torch.jit.isinstance(x, List[str]) + assert not torch.jit.isinstance(x, List[int]) + + x = ["1", "2", "3"] + self.checkScript(list_str_test, (x,)) + + def test_dict(self): + def dict_str_int_test(x: Any): + assert torch.jit.isinstance(x, Dict[str, int]) + assert not torch.jit.isinstance(x, Dict[int, str]) + + x = {"a": 1, "b": 2} + self.checkScript(dict_str_int_test, (x,)) + + def test_tuple(self): + def tuple_test(x: Any): + assert torch.jit.isinstance(x, Tuple[str, int, str]) + assert not torch.jit.isinstance(x, Tuple[int, str, str]) + assert not torch.jit.isinstance(x, Tuple[str]) + + x = ("a", 1, "b") + self.checkScript(tuple_test, (x,)) + + def test_optional(self): + def optional_test(x: Any): + assert torch.jit.isinstance(x, Optional[torch.Tensor]) + assert not torch.jit.isinstance(x, Optional[str]) + # TODO: successful torch.jit.isinstance makes sets type? + + x = torch.ones(3, 3) + self.checkScript(optional_test, (x,)) + + def test_optional_none(self): + def optional_test_none(x: Any): + assert torch.jit.isinstance(x, Optional[torch.Tensor]) + # assert not torch.jit.isinstance(x, Optional[str]) + # TODO: above line fails in TS interpreter need to investigate + + x = None + self.checkScript(optional_test_none, (x,)) + + def test_list_nested(self): + def list_nested(x: Any): + assert torch.jit.isinstance(x, List[Dict[str, int]]) + assert not torch.jit.isinstance(x, List[List[str]]) + + x = [{"a": 1, "b": 2}, {"aa": 11, "bb": 22}] + self.checkScript(list_nested, (x,)) + + def test_dict_nested(self): + def dict_nested(x: Any): + assert torch.jit.isinstance(x, Dict[str, Tuple[str, str, str]]) + assert not torch.jit.isinstance(x, Dict[str, Tuple[int, int, int]]) + + x = {"a": ("aa", "aa", "aa"), "b": ("bb", "bb", "bb")} + self.checkScript(dict_nested, (x,)) + + def test_tuple_nested(self): + def tuple_nested(x: Any): + assert torch.jit.isinstance( + x, Tuple[Dict[str, Tuple[str, str, str]], List[bool], Optional[str]] + ) + assert not torch.jit.isinstance(x, Dict[str, Tuple[int, int, int]]) + assert not torch.jit.isinstance(x, Tuple[str]) + + x = ( + {"a": ("aa", "aa", "aa"), "b": ("bb", "bb", "bb")}, + [True, False, True], + None, + ) + self.checkScript(tuple_nested, (x,)) + + def test_optional_nested(self): + def optional_nested(x: Any): + assert torch.jit.isinstance(x, Optional[List[str]]) + + x = ["a", "b", "c"] + self.checkScript(optional_nested, (x,)) + + def test_list_tensor_type_true(self): + def list_tensor_type_true(x: Any): + assert torch.jit.isinstance(x, List[torch.Tensor]) + + x = [torch.rand(3, 3), torch.rand(4, 3)] + self.checkScript(list_tensor_type_true, (x,)) + + def test_tensor_type_false(self): + def list_tensor_type_false(x: Any): + assert not torch.jit.isinstance(x, List[torch.Tensor]) + + x = [1, 2, 3] + self.checkScript(list_tensor_type_false, (x,)) + + def test_in_if(self): + def list_in_if(x: Any): + if torch.jit.isinstance(x, List[int]): + assert True + if torch.jit.isinstance(x, List[str]): + assert not True + + x = [1, 2, 3] + self.checkScript(list_in_if, (x,)) + + def test_if_else(self): + def list_in_if_else(x: Any): + if torch.jit.isinstance(x, Tuple[str, str, str]): + assert True + else: + assert not True + + x = ("a", "b", "c") + self.checkScript(list_in_if_else, (x,)) + + def test_in_while_loop(self): + def list_in_while_loop(x: Any): + count = 0 + while torch.jit.isinstance(x, List[Dict[str, int]]) and count <= 0: + count = count + 1 + assert count == 1 + + x = [{"a": 1, "b": 2}, {"aa": 11, "bb": 22}] + self.checkScript(list_in_while_loop, (x,)) + + def test_type_refinement(self): + def type_refinement(obj: Any): + hit = False + if torch.jit.isinstance(obj, List[torch.Tensor]): + hit = not hit + for el in obj: + # perform some tensor operation + y = el.clamp(0, 0.5) + if torch.jit.isinstance(obj, Dict[str, str]): + hit = not hit + str_cat = "" + for val in obj.values(): + str_cat = str_cat + val + assert "111222" == str_cat + assert hit + + x = [torch.rand(3, 3), torch.rand(4, 3)] + self.checkScript(type_refinement, (x,)) + x = {"1": "111", "2": "222"} + self.checkScript(type_refinement, (x,)) + + def test_list_no_contained_type(self): + def list_no_contained_type(x: Any): + assert torch.jit.isinstance(x, List) + + x = ["1", "2", "3"] + + with self.assertRaisesRegex( + RuntimeError, + "Attempted to use List without a " + "contained type. Please add a contained type, e.g. " + r"List\[int\]", + ): + torch.jit.script(list_no_contained_type) + + + def test_tuple_no_contained_type(self): + def tuple_no_contained_type(x: Any): + assert torch.jit.isinstance(x, Tuple) + + x = ("1", "2", "3") + + with self.assertRaisesRegex( + RuntimeError, + "Attempted to use Tuple without a " + "contained type. Please add a contained type, e.g. " + r"Tuple\[int\]" + ): + torch.jit.script(tuple_no_contained_type) + + def test_optional_no_contained_type(self): + def optional_no_contained_type(x: Any): + assert torch.jit.isinstance(x, Optional) + + x = ("1", "2", "3") + + with self.assertRaisesRegex( + RuntimeError, + "Attempted to use Optional without a " + "contained type. Please add a contained type, e.g. " + r"Optional\[int\]", + ): + torch.jit.script(optional_no_contained_type) + + def test_dict_no_contained_type(self): + def dict_no_contained_type(x: Any): + assert torch.jit.isinstance(x, Dict) + + x = {"a": "aa"} + + with self.assertRaisesRegex( + RuntimeError, + "Attempted to use Dict without " + "contained types. Please add contained type, e.g. " + r"Dict\[int, int\]", + ): + torch.jit.script(dict_no_contained_type) diff --git a/test/test_jit.py b/test/test_jit.py index 01bf1339bcf7..b9d37821186b 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -33,6 +33,7 @@ from jit.test_profiler import TestProfiler # noqa: F401 from jit.test_slice import TestSlice # noqa: F401 from jit.test_warn import TestWarn # noqa: F401 +from jit.test_isinstance import TestIsinstance # noqa: F401 # Torch from torch import Tensor diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index fce8cd314c49..9aaa90538018 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -1196,6 +1196,14 @@ struct to_ir { return emitHasAttr(apply.inputs()[0], apply.inputs()[1]); } } + auto sv = emitSugaredExpr(apply.callee(), 1); + auto loc = apply.callee().range(); + if (auto special_form = dynamic_cast(sv.get())) { + if (special_form->form() == prim::isinstance) { + checkApplyNumInputs(apply, 2); + return emitIsInstance(apply.inputs()[0], apply.inputs()[1]); + } + } } auto expr_out = emitToBool(expr.range(), emitExpr(expr)); c10::optional static_if = c10::nullopt; diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index a99f706bce68..15d151b761ef 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -918,6 +918,9 @@ std::shared_ptr toSugaredValue( } else if ( obj.ptr() == py::module::import("torch.jit").attr("annotate").ptr()) { return SpecialFormValue::create(prim::annotate); + } else if ( + obj.ptr() == py::module::import("torch.jit").attr("isinstance").ptr()) { + return SpecialFormValue::create(prim::isinstance); #ifdef USE_RPC // RPC module is only avaialble when build flag "USE_DISTRIBUTED" is on. } else if ( diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index fd61228f3379..bb728883f874 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -21,6 +21,7 @@ RecursiveScriptModule, ScriptWarning, interface, + _isinstance, CompilationUnit, ScriptFunction, _unwrap_optional, @@ -70,5 +71,9 @@ def annotate(the_type, the_value): return the_value +# for torch.jit.isinstance +isinstance = _isinstance + + if not torch._C._jit_init(): raise RuntimeError("JIT initialization failed") diff --git a/torch/jit/_script.py b/torch/jit/_script.py index 19cce3a86945..1531c9d54f8d 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -12,7 +12,7 @@ import copy import pickle import warnings -from typing import Any, Dict +from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -1079,5 +1079,99 @@ def _unwrap_optional(x): return x +def get_origin(target_type): + return getattr(target_type, "__origin__", None) + + +def get_args(target_type): + return getattr(target_type, "__args__", None) + + +def check_args_exist(target_type): + if target_type is List or target_type is list: + _jit_internal.is_list(target_type) + elif target_type is Tuple or target_type is tuple: + _jit_internal.is_tuple(target_type) + elif target_type is Dict or target_type is dict: + _jit_internal.is_dict(target_type) + elif target_type is None or target_type is Optional: + _jit_internal.is_optional(target_type) + +# supports List/Dict/Tuple and Optional types +# TODO support future +def generics_checker(obj, target_type): + origin_type = get_origin(target_type) + check_args_exist(target_type) + if origin_type is list or origin_type is List: + if not isinstance(obj, list): + return False + arg_type = get_args(target_type)[0] + arg_origin = get_origin(arg_type) + for el in obj: + # check if nested generics, ex: List[List[str]] + if arg_origin: # processes nested generics, ex: List[List[str]] + if not generics_checker(el, arg_type): + return False + elif not isinstance(el, arg_type): + return False + return True + elif origin_type is Dict or origin_type is dict: + if not isinstance(obj, dict): + return False + key_type = get_args(target_type)[0] + val_type = get_args(target_type)[1] + for key, val in obj.items(): + # check if keys are of right type + if not isinstance(key, key_type): + return False + val_origin = get_origin(val_type) + if val_origin: + if not generics_checker(val, val_type): + return False + elif not isinstance(val, val_type): + return False + return True + elif origin_type is Tuple or origin_type is tuple: + if not isinstance(obj, tuple): + return False + arg_types = get_args(target_type) + if len(obj) != len(arg_types): + return False + for el, el_type in zip(obj, arg_types): + el_origin = get_origin(el_type) + if el_origin: + if not generics_checker(el, el_type): + return False + elif not isinstance(el, el_type): + return False + return True + elif origin_type is Union: # actually handles Optional Case + if obj is None: # check before recursion because None is always fine + return True + optional_type = get_args(target_type)[0] + optional_origin = get_origin(optional_type) + if optional_origin: + return generics_checker(obj, optional_type) + elif isinstance(obj, optional_type): + return True + return False + + +def _isinstance(obj, target_type) -> bool: + origin_type = get_origin(target_type) + if origin_type: + return generics_checker(obj, target_type) + + # Check to handle weird python type behaviors + # 1. python 3.6 returns None for origin of containers without + # contained type (intead of returning outer container type) + # 2. non-typed optional origin returns as none instead + # of as optional in 3.6-3.8 + check_args_exist(target_type) + + # handle non-generics + return isinstance(obj, target_type) + + _register_builtin(_unwrap_optional, "aten::_unwrap_optional") _register_builtin(_jit_internal.is_scripting, "aten::is_scripting")