diff --git a/docs/source/jit.rst b/docs/source/jit.rst index 45ba6fa18d80..e201929f9263 100644 --- a/docs/source/jit.rst +++ b/docs/source/jit.rst @@ -54,6 +54,7 @@ Creating TorchScript Code load ignore unused + isinstance Mixing Tracing and Scripting ---------------------------- diff --git a/test/jit/test_isinstance.py b/test/jit/test_isinstance.py new file mode 100644 index 000000000000..f0d6df66ee5c --- /dev/null +++ b/test/jit/test_isinstance.py @@ -0,0 +1,251 @@ +import os +import sys + +import torch +from typing import List, Any, Dict, Tuple, Optional + +# Make the helper files in test/ importable +pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(pytorch_test_dir) +from torch.testing._internal.jit_utils import JitTestCase + +if __name__ == "__main__": + raise RuntimeError( + "This test file is not meant to be run directly, use:\n\n" + "\tpython test/test_jit.py TESTNAME\n\n" + "instead." + ) + +# Tests for torch.jit.isinstance +class TestIsinstance(JitTestCase): + def test_int(self): + def int_test(x: Any): + assert torch.jit.isinstance(x, int) + assert not torch.jit.isinstance(x, float) + + x = 1 + self.checkScript(int_test, (x,)) + + def test_float(self): + def float_test(x: Any): + assert torch.jit.isinstance(x, float) + assert not torch.jit.isinstance(x, int) + + x = 1.0 + self.checkScript(float_test, (x,)) + + def test_bool(self): + def bool_test(x: Any): + assert torch.jit.isinstance(x, bool) + assert not torch.jit.isinstance(x, float) + + x = False + self.checkScript(bool_test, (x,)) + + def test_list(self): + def list_str_test(x: Any): + assert torch.jit.isinstance(x, List[str]) + assert not torch.jit.isinstance(x, List[int]) + assert not torch.jit.isinstance(x, Tuple[int]) + + x = ["1", "2", "3"] + self.checkScript(list_str_test, (x,)) + + def test_dict(self): + def dict_str_int_test(x: Any): + assert torch.jit.isinstance(x, Dict[str, int]) + assert not torch.jit.isinstance(x, Dict[int, str]) + assert 
not torch.jit.isinstance(x, Dict[str, str]) + + x = {"a": 1, "b": 2} + self.checkScript(dict_str_int_test, (x,)) + + def test_tuple(self): + def tuple_test(x: Any): + assert torch.jit.isinstance(x, Tuple[str, int, str]) + assert not torch.jit.isinstance(x, Tuple[int, str, str]) + assert not torch.jit.isinstance(x, Tuple[str]) + + x = ("a", 1, "b") + self.checkScript(tuple_test, (x,)) + + def test_optional(self): + def optional_test(x: Any): + assert torch.jit.isinstance(x, Optional[torch.Tensor]) + assert not torch.jit.isinstance(x, Optional[str]) + + x = torch.ones(3, 3) + self.checkScript(optional_test, (x,)) + + def test_optional_none(self): + def optional_test_none(x: Any): + assert torch.jit.isinstance(x, Optional[torch.Tensor]) + # assert torch.jit.isinstance(x, Optional[str]) + # TODO: above line in eager will evaluate to True while in + # the TS interpreter will evaluate to False as the + # first torch.jit.isinstance refines the 'None' type + + x = None + self.checkScript(optional_test_none, (x,)) + + def test_list_nested(self): + def list_nested(x: Any): + assert torch.jit.isinstance(x, List[Dict[str, int]]) + assert not torch.jit.isinstance(x, List[List[str]]) + + x = [{"a": 1, "b": 2}, {"aa": 11, "bb": 22}] + self.checkScript(list_nested, (x,)) + + def test_dict_nested(self): + def dict_nested(x: Any): + assert torch.jit.isinstance(x, Dict[str, Tuple[str, str, str]]) + assert not torch.jit.isinstance(x, Dict[str, Tuple[int, int, int]]) + + x = {"a": ("aa", "aa", "aa"), "b": ("bb", "bb", "bb")} + self.checkScript(dict_nested, (x,)) + + def test_tuple_nested(self): + def tuple_nested(x: Any): + assert torch.jit.isinstance( + x, Tuple[Dict[str, Tuple[str, str, str]], List[bool], Optional[str]] + ) + assert not torch.jit.isinstance(x, Dict[str, Tuple[int, int, int]]) + assert not torch.jit.isinstance(x, Tuple[str]) + assert not torch.jit.isinstance(x, Tuple[List[bool], List[str], List[int]]) + + x = ( + {"a": ("aa", "aa", "aa"), "b": ("bb", "bb", "bb")}, + 
[True, False, True], + None, + ) + self.checkScript(tuple_nested, (x,)) + + def test_optional_nested(self): + def optional_nested(x: Any): + assert torch.jit.isinstance(x, Optional[List[str]]) + + x = ["a", "b", "c"] + self.checkScript(optional_nested, (x,)) + + def test_list_tensor_type_true(self): + def list_tensor_type_true(x: Any): + assert torch.jit.isinstance(x, List[torch.Tensor]) + + x = [torch.rand(3, 3), torch.rand(4, 3)] + self.checkScript(list_tensor_type_true, (x,)) + + def test_tensor_type_false(self): + def list_tensor_type_false(x: Any): + assert not torch.jit.isinstance(x, List[torch.Tensor]) + + x = [1, 2, 3] + self.checkScript(list_tensor_type_false, (x,)) + + def test_in_if(self): + def list_in_if(x: Any): + if torch.jit.isinstance(x, List[int]): + assert True + if torch.jit.isinstance(x, List[str]): + assert not True + + x = [1, 2, 3] + self.checkScript(list_in_if, (x,)) + + def test_if_else(self): + def list_in_if_else(x: Any): + if torch.jit.isinstance(x, Tuple[str, str, str]): + assert True + else: + assert not True + + x = ("a", "b", "c") + self.checkScript(list_in_if_else, (x,)) + + def test_in_while_loop(self): + def list_in_while_loop(x: Any): + count = 0 + while torch.jit.isinstance(x, List[Dict[str, int]]) and count <= 0: + count = count + 1 + assert count == 1 + + x = [{"a": 1, "b": 2}, {"aa": 11, "bb": 22}] + self.checkScript(list_in_while_loop, (x,)) + + def test_type_refinement(self): + def type_refinement(obj: Any): + hit = False + if torch.jit.isinstance(obj, List[torch.Tensor]): + hit = not hit + for el in obj: + # perform some tensor operation + y = el.clamp(0, 0.5) + if torch.jit.isinstance(obj, Dict[str, str]): + hit = not hit + str_cat = "" + for val in obj.values(): + str_cat = str_cat + val + assert "111222" == str_cat + assert hit + + x = [torch.rand(3, 3), torch.rand(4, 3)] + self.checkScript(type_refinement, (x,)) + x = {"1": "111", "2": "222"} + self.checkScript(type_refinement, (x,)) + + def 
test_list_no_contained_type(self): + def list_no_contained_type(x: Any): + assert torch.jit.isinstance(x, List) + + x = ["1", "2", "3"] + + err_msg = "Attempted to use List without a contained type. " \ + r"Please add a contained type, e.g. List\[int\]" + + with self.assertRaisesRegex(RuntimeError, err_msg,): + torch.jit.script(list_no_contained_type) + with self.assertRaisesRegex(RuntimeError, err_msg,): + list_no_contained_type(x) + + + + def test_tuple_no_contained_type(self): + def tuple_no_contained_type(x: Any): + assert torch.jit.isinstance(x, Tuple) + + x = ("1", "2", "3") + + err_msg = "Attempted to use Tuple without a contained type. " \ + r"Please add a contained type, e.g. Tuple\[int\]" + + with self.assertRaisesRegex(RuntimeError, err_msg,): + torch.jit.script(tuple_no_contained_type) + with self.assertRaisesRegex(RuntimeError, err_msg,): + tuple_no_contained_type(x) + + def test_optional_no_contained_type(self): + def optional_no_contained_type(x: Any): + assert torch.jit.isinstance(x, Optional) + + x = ("1", "2", "3") + + err_msg = "Attempted to use Optional without a contained type. " \ + r"Please add a contained type, e.g. Optional\[int\]" + + with self.assertRaisesRegex(RuntimeError, err_msg,): + torch.jit.script(optional_no_contained_type) + with self.assertRaisesRegex(RuntimeError, err_msg,): + optional_no_contained_type(x) + + def test_dict_no_contained_type(self): + def dict_no_contained_type(x: Any): + assert torch.jit.isinstance(x, Dict) + + x = {"a": "aa"} + + err_msg = "Attempted to use Dict without contained types. " \ + r"Please add contained type, e.g. 
Dict\[int, int\]" + + with self.assertRaisesRegex(RuntimeError, err_msg,): + torch.jit.script(dict_no_contained_type) + with self.assertRaisesRegex(RuntimeError, err_msg,): + dict_no_contained_type(x) diff --git a/test/test_jit.py b/test/test_jit.py index 01bf1339bcf7..b9d37821186b 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -33,6 +33,7 @@ from jit.test_profiler import TestProfiler # noqa: F401 from jit.test_slice import TestSlice # noqa: F401 from jit.test_warn import TestWarn # noqa: F401 +from jit.test_isinstance import TestIsinstance # noqa: F401 # Torch from torch import Tensor diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index e9fb21c5e854..62783086b645 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -637,11 +637,7 @@ def _get_overloaded_methods(method, mod_class): def is_tuple(ann): if ann is Tuple: - raise RuntimeError( - "Attempted to use Tuple without a " - "contained type. Please add a contained type, e.g. " - "Tuple[int]" - ) + raise_error_container_parameter_missing("Tuple") # For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule if not hasattr(ann, '__module__'): @@ -652,11 +648,7 @@ def is_tuple(ann): def is_list(ann): if ann is List: - raise RuntimeError( - "Attempted to use List without a " - "contained type. Please add a contained type, e.g. " - "List[int]" - ) + raise_error_container_parameter_missing("List") if not hasattr(ann, '__module__'): return False @@ -666,11 +658,7 @@ def is_list(ann): def is_dict(ann): if ann is Dict: - raise RuntimeError( - "Attempted to use Dict without " - "contained types. Please add contained type, e.g. " - "Dict[int, int]" - ) + raise_error_container_parameter_missing("Dict") if not hasattr(ann, '__module__'): return False @@ -680,11 +668,7 @@ def is_dict(ann): def is_optional(ann): if ann is Optional: - raise RuntimeError( - "Attempted to use Optional without a " - "contained type. Please add a contained type, e.g. 
" - "Optional[int]" - ) + raise_error_container_parameter_missing("Optional") # Optional[T] is just shorthand for Union[T, None], so check for both def safe_is_subclass(the_type, super_type): @@ -885,3 +869,110 @@ def _is_exception(obj): if not inspect.isclass(obj): return False return issubclass(obj, Exception) + +def raise_error_container_parameter_missing(target_type): + if target_type == 'Dict': + raise RuntimeError( + "Attempted to use Dict without " + "contained types. Please add contained type, e.g. " + "Dict[int, int]" + ) + raise RuntimeError( + f"Attempted to use {target_type} without a " + "contained type. Please add a contained type, e.g. " + f"{target_type}[int]" + ) + + +def get_origin(target_type): + return getattr(target_type, "__origin__", None) + + +def get_args(target_type): + return getattr(target_type, "__args__", None) + + +def check_args_exist(target_type): + if target_type is List or target_type is list: + raise_error_container_parameter_missing("List") + elif target_type is Tuple or target_type is tuple: + raise_error_container_parameter_missing("Tuple") + elif target_type is Dict or target_type is dict: + raise_error_container_parameter_missing("Dict") + elif target_type is None or target_type is Optional: + raise_error_container_parameter_missing("Optional") + +# supports List/Dict/Tuple and Optional types +# TODO support future +def container_checker(obj, target_type): + origin_type = get_origin(target_type) + check_args_exist(target_type) + if origin_type is list or origin_type is List: + if not isinstance(obj, list): + return False + arg_type = get_args(target_type)[0] + arg_origin = get_origin(arg_type) + for el in obj: + # check if nested container, ex: List[List[str]] + if arg_origin: # processes nested container, ex: List[List[str]] + if not container_checker(el, arg_type): + return False + elif not isinstance(el, arg_type): + return False + return True + elif origin_type is Dict or origin_type is dict: + if not isinstance(obj, 
dict): + return False + key_type = get_args(target_type)[0] + val_type = get_args(target_type)[1] + for key, val in obj.items(): + # check if keys are of right type + if not isinstance(key, key_type): + return False + val_origin = get_origin(val_type) + if val_origin: + if not container_checker(val, val_type): + return False + elif not isinstance(val, val_type): + return False + return True + elif origin_type is Tuple or origin_type is tuple: + if not isinstance(obj, tuple): + return False + arg_types = get_args(target_type) + if len(obj) != len(arg_types): + return False + for el, el_type in zip(obj, arg_types): + el_origin = get_origin(el_type) + if el_origin: + if not container_checker(el, el_type): + return False + elif not isinstance(el, el_type): + return False + return True + elif origin_type is Union: # actually handles Optional Case + if obj is None: # check before recursion because None is always fine + return True + optional_type = get_args(target_type)[0] + optional_origin = get_origin(optional_type) + if optional_origin: + return container_checker(obj, optional_type) + elif isinstance(obj, optional_type): + return True + return False + + +def _isinstance(obj, target_type) -> bool: + origin_type = get_origin(target_type) + if origin_type: + return container_checker(obj, target_type) + + # Check to handle weird python type behaviors + # 1. python 3.6 returns None for origin of containers without + # contained type (instead of returning outer container type) + # 2. 
non-typed optional origin returns as none instead + of as optional in 3.6-3.8 + check_args_exist(target_type) + + # handle non-containers + return isinstance(obj, target_type) diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index fce8cd314c49..9aaa90538018 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -1196,6 +1196,14 @@ struct to_ir { return emitHasAttr(apply.inputs()[0], apply.inputs()[1]); } } + auto sv = emitSugaredExpr(apply.callee(), 1); + auto loc = apply.callee().range(); + if (auto special_form = dynamic_cast<SpecialFormValue*>(sv.get())) { + if (special_form->form() == prim::isinstance) { + checkApplyNumInputs(apply, 2); + return emitIsInstance(apply.inputs()[0], apply.inputs()[1]); + } + } } auto expr_out = emitToBool(expr.range(), emitExpr(expr)); c10::optional<bool> static_if = c10::nullopt; diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index a99f706bce68..15d151b761ef 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -918,6 +918,9 @@ std::shared_ptr<SugaredValue> toSugaredValue( } else if ( obj.ptr() == py::module::import("torch.jit").attr("annotate").ptr()) { return SpecialFormValue::create(prim::annotate); + } else if ( + obj.ptr() == py::module::import("torch.jit").attr("isinstance").ptr()) { + return SpecialFormValue::create(prim::isinstance); #ifdef USE_RPC // RPC module is only avaialble when build flag "USE_DISTRIBUTED" is on. 
} else if ( diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index fd61228f3379..4e87c12023f3 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -9,6 +9,7 @@ _overload, _overload_method, ignore, + _isinstance, is_scripting, export, unused, @@ -70,5 +71,48 @@ def annotate(the_type, the_value): return the_value +# for torch.jit.isinstance +def isinstance(obj, target_type): + """ + This function provides for container type refinement in TorchScript. It can refine + parameterized containers of the List, Dict, Tuple, and Optional types. E.g. ``List[str]``, + ``Dict[str, List[torch.Tensor]]``, ``Optional[Tuple[int,str,int]]``. It can also + refine basic types such as bools and ints that are available in TorchScript. + + Arguments: + obj: object to refine the type of + target_type: type to try to refine obj to + Returns: + ``bool``: True if obj was successfully refined to the type of target_type, + False otherwise with no new type refinement + + + Example (using ``torch.jit.isinstance`` for type refinement): + .. testcode:: + + import torch + from typing import Any, Dict, List + + class MyModule(torch.nn.Module): + def __init__(self): + super(MyModule, self).__init__() + + def forward(self, input: Any): # note the Any type + if torch.jit.isinstance(input, List[torch.Tensor]): + for t in input: + y = t.clamp(0, 0.5) + elif torch.jit.isinstance(input, Dict[str, str]): + for val in input.values(): + print(val) + + m = torch.jit.script(MyModule()) + x = [torch.rand(3,3), torch.rand(4,3)] + m(x) + y = {"key1":"val1","key2":"val2"} + m(y) + """ + return _isinstance(obj, target_type) + + if not torch._C._jit_init(): raise RuntimeError("JIT initialization failed")