# Reconstructed from the mangled git patch "[WIP] adding torch.jit.isinstance"
# (https://github.com/pytorch/pytorch/pull/46062): test/jit/test_isinstance.py.
# NOTE(review): these source lines also carry non-Python hunks of the same
# patch (test/test_jit.py registration, ir_emitter.cpp and
# python_sugared_value.cpp special-form wiring); that is patch plumbing and
# is not reproduced here.
import os
import sys

from typing import Any, Dict, List, Optional, Tuple

import torch

# Make the helper files in test/ importable.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestIsinstance(JitTestCase):
    """Tests for ``torch.jit.isinstance``.

    Each case defines a small function that probes an ``Any``-typed value
    with ``torch.jit.isinstance`` and uses ``checkScript`` to verify that
    the scripted result agrees with eager execution.
    """

    def test_int(self):
        def int_test(x: Any):
            assert torch.jit.isinstance(x, int)
            assert not torch.jit.isinstance(x, float)

        self.checkScript(int_test, (1,))

    def test_float(self):
        def float_test(x: Any):
            assert torch.jit.isinstance(x, float)
            assert not torch.jit.isinstance(x, int)

        self.checkScript(float_test, (1.0,))

    def test_bool(self):
        def bool_test(x: Any):
            assert torch.jit.isinstance(x, bool)
            assert not torch.jit.isinstance(x, float)

        self.checkScript(bool_test, (False,))

    def test_list(self):
        def list_str_test(x: Any):
            assert torch.jit.isinstance(x, List[str])
            assert not torch.jit.isinstance(x, List[int])

        self.checkScript(list_str_test, (["1", "2", "3"],))

    def test_dict(self):
        def dict_str_int_test(x: Any):
            assert torch.jit.isinstance(x, Dict[str, int])
            assert not torch.jit.isinstance(x, Dict[int, str])

        self.checkScript(dict_str_int_test, ({"a": 1, "b": 2},))

    def test_tuple(self):
        def tuple_test(x: Any):
            assert torch.jit.isinstance(x, Tuple[str, int, str])
            assert not torch.jit.isinstance(x, Tuple[int, str, str])
            assert not torch.jit.isinstance(x, Tuple[str])

        self.checkScript(tuple_test, (("a", 1, "b"),))

    def test_optional(self):
        def optional_test(x: Any):
            assert torch.jit.isinstance(x, Optional[torch.Tensor])
            assert not torch.jit.isinstance(x, Optional[str])
            # TODO: does a successful torch.jit.isinstance refine the type?

        self.checkScript(optional_test, (torch.ones(3, 3),))

    def test_optional_none(self):
        def optional_test_none(x: Any):
            assert torch.jit.isinstance(x, Optional[torch.Tensor])
            # assert not torch.jit.isinstance(x, Optional[str])
            # TODO: the line above fails in the TS interpreter; investigate.

        self.checkScript(optional_test_none, (None,))

    def test_list_nested(self):
        def list_nested(x: Any):
            assert torch.jit.isinstance(x, List[Dict[str, int]])
            assert not torch.jit.isinstance(x, List[List[str]])

        self.checkScript(list_nested, ([{"a": 1, "b": 2}, {"aa": 11, "bb": 22}],))

    def test_dict_nested(self):
        def dict_nested(x: Any):
            assert torch.jit.isinstance(x, Dict[str, Tuple[str, str, str]])
            assert not torch.jit.isinstance(x, Dict[str, Tuple[int, int, int]])

        self.checkScript(
            dict_nested, ({"a": ("aa", "aa", "aa"), "b": ("bb", "bb", "bb")},)
        )

    def test_tuple_nested(self):
        def tuple_nested(x: Any):
            assert torch.jit.isinstance(
                x, Tuple[Dict[str, Tuple[str, str, str]], List[bool], Optional[str]]
            )
            assert not torch.jit.isinstance(x, Dict[str, Tuple[int, int, int]])
            assert not torch.jit.isinstance(x, Tuple[str])

        arg = (
            {"a": ("aa", "aa", "aa"), "b": ("bb", "bb", "bb")},
            [True, False, True],
            None,
        )
        self.checkScript(tuple_nested, (arg,))

    def test_optional_nested(self):
        def optional_nested(x: Any):
            assert torch.jit.isinstance(x, Optional[List[str]])

        self.checkScript(optional_nested, (["a", "b", "c"],))

    def test_list_tensor_type_true(self):
        def list_tensor_type_true(x: Any):
            assert torch.jit.isinstance(x, List[torch.Tensor])

        self.checkScript(
            list_tensor_type_true, ([torch.rand(3, 3), torch.rand(4, 3)],)
        )

    def test_tensor_type_false(self):
        def list_tensor_type_false(x: Any):
            assert not torch.jit.isinstance(x, List[torch.Tensor])

        self.checkScript(list_tensor_type_false, ([1, 2, 3],))

    def test_in_if(self):
        def list_in_if(x: Any):
            if torch.jit.isinstance(x, List[int]):
                assert True
            if torch.jit.isinstance(x, List[str]):
                assert not True

        self.checkScript(list_in_if, ([1, 2, 3],))

    def test_if_else(self):
        def list_in_if_else(x: Any):
            if torch.jit.isinstance(x, Tuple[str, str, str]):
                assert True
            else:
                assert not True

        self.checkScript(list_in_if_else, (("a", "b", "c"),))

    def test_in_while_loop(self):
        def list_in_while_loop(x: Any):
            count = 0
            while torch.jit.isinstance(x, List[Dict[str, int]]) and count <= 0:
                count = count + 1
            assert count == 1

        self.checkScript(
            list_in_while_loop, ([{"a": 1, "b": 2}, {"aa": 11, "bb": 22}],)
        )

    def test_switch_on_type(self):
        def list_switch_on_type(obj: Any):
            hit = False
            if torch.jit.isinstance(obj, List[torch.Tensor]):
                hit = not hit
                for el in obj:
                    # perform some tensor operation
                    y = el.clamp(0, 0.5)
            if torch.jit.isinstance(obj, Dict[str, str]):
                hit = not hit
                str_cat = ""
                for val in obj.values():
                    str_cat = str_cat + val
                assert "111222" == str_cat
            assert hit

        self.checkScript(list_switch_on_type, ([torch.rand(3, 3), torch.rand(4, 3)],))
        self.checkScript(list_switch_on_type, ({"1": "111", "2": "222"},))

    # NOTE(review): in the four tests below, if no RuntimeError is raised the
    # try/except silently passes — behavior preserved from the original WIP
    # patch (a stricter version would use assertRaisesRegex).

    def test_list_no_contained_type(self):
        def list_no_contained_type(x: Any):
            assert torch.jit.isinstance(x, List)

        try:
            torch.jit.script(list_no_contained_type)
        except RuntimeError as e:
            self.assertEqual(
                str(e),
                "Attempted to use List without a "
                "contained type. Please add a contained type, e.g. "
                "List[int]",
            )

        try:
            list_no_contained_type(["1", "2", "3"])
        except RuntimeError as e:
            self.assertEqual(
                str(e),
                "Attempted to use List without a "
                "contained type. Please add a contained type, e.g. "
                "List[int]",
            )

    def test_tuple_no_contained_type(self):
        def tuple_no_contained_type(x: Any):
            assert torch.jit.isinstance(x, Tuple)

        try:
            torch.jit.script(tuple_no_contained_type)
        except RuntimeError as e:
            self.assertEqual(
                str(e),
                "Attempted to use Tuple without a "
                "contained type. Please add a contained type, e.g. "
                "Tuple[int]",
            )

        try:
            tuple_no_contained_type(("1", "2", "3"))
        except RuntimeError as e:
            self.assertEqual(
                str(e),
                "Attempted to use Tuple without a "
                "contained type. Please add a contained type, e.g. "
                "Tuple[int]",
            )

    def test_optional_no_contained_type(self):
        def optional_no_contained_type(x: Any):
            assert torch.jit.isinstance(x, Optional)

        try:
            torch.jit.script(optional_no_contained_type)
        except RuntimeError as e:
            self.assertEqual(
                str(e),
                "Attempted to use Optional without a "
                "contained type. Please add a contained type, e.g. "
                "Optional[int]",
            )

        try:
            optional_no_contained_type(("1", "2", "3"))
        except RuntimeError as e:
            self.assertEqual(
                str(e),
                "Attempted to use Optional without a "
                "contained type. Please add a contained type, e.g. "
                "Optional[int]",
            )

    def test_dict_no_contained_type(self):
        def dict_no_contained_type(x: Any):
            assert torch.jit.isinstance(x, Dict)

        try:
            torch.jit.script(dict_no_contained_type)
        except RuntimeError as e:
            self.assertEqual(
                str(e),
                "Attempted to use Dict without "
                "contained types. Please add contained type, e.g. "
                "Dict[int, int]",
            )

        try:
            dict_no_contained_type({"a": "aa"})
        except RuntimeError as e:
            self.assertEqual(
                str(e),
                "Attempted to use Dict without "
                "contained types. Please add contained type, e.g. "
                "Dict[int, int]",
            )
} else if ( diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index fd61228f3379..f03824ac399e 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -1,4 +1,8 @@ import torch._C +from torch.jit._isinstance import _isinstance +from typing import List, Dict, Optional, Tuple, Union +import typing +from sys import version_info from torch.utils import set_module @@ -70,5 +74,9 @@ def annotate(the_type, the_value): return the_value +def isinstance(the_obj, the_type) -> bool: + return _isinstance(the_obj, the_type) + + if not torch._C._jit_init(): raise RuntimeError("JIT initialization failed") diff --git a/torch/jit/_isinstance.py b/torch/jit/_isinstance.py new file mode 100644 index 000000000000..c18f1c0f8979 --- /dev/null +++ b/torch/jit/_isinstance.py @@ -0,0 +1,109 @@ +from typing import List, Dict, Tuple, Union, Optional + + +def get_origin(the_type): + return getattr(the_type, "__origin__", None) + + +def get_args(the_type): + return getattr(the_type, "__args__", None) + + +def check_args_exist(the_type): + if the_type is List or the_type is list: + raise RuntimeError( + "Attempted to use List without a " + "contained type. Please add a contained type, e.g. " + "List[int]" + ) + elif the_type is Tuple or the_type is tuple: + raise RuntimeError( + "Attempted to use Tuple without a " + "contained type. Please add a contained type, e.g. " + "Tuple[int]" + ) + elif the_type is Dict or the_type is dict: + raise RuntimeError( + "Attempted to use Dict without " + "contained types. Please add contained type, e.g. " + "Dict[int, int]" + ) + elif the_type is None or the_type is Optional: + raise RuntimeError( + "Attempted to use Optional without a " + "contained type. Please add a contained type, e.g. 
" + "Optional[int]" + ) + + +def generics_checker(the_obj, the_type): + origin_type = get_origin(the_type) + check_args_exist(the_type) + if origin_type is None: + pass + elif origin_type is list or origin_type is List: + if isinstance(the_obj, list): + for el in the_obj: + # check if nested generics, ex: List[List[str]] + arg_type = get_args(the_type)[0] + arg_origin = get_origin(arg_type) + if arg_origin: # processes nested generics, ex: List[List[str]] + if not generics_checker(el, arg_type): + return False + elif not isinstance(el, arg_type): + return False + else: + return False + elif origin_type is dict or origin_type is Dict: + if isinstance(the_obj, dict): + key_type = get_args(the_type)[0] + val_type = get_args(the_type)[1] + for key, val in the_obj.items(): + # check if keys are of right type + if not isinstance(key, key_type): + return False + val_origin = get_origin(val_type) + if val_origin: + if not generics_checker(val, val_type): + return False + elif not isinstance(val, val_type): + return False + else: + return False + elif origin_type is Union: # TODO actually handles Optional Case + if the_obj is None: # check before recursion because None is always fine + return True + optional_type = get_args(the_type)[0] + optional_origin = get_origin(optional_type) + if optional_origin: + return generics_checker(the_obj, optional_type) + elif isinstance(the_obj, optional_type): + return True + else: + return False + elif origin_type is tuple or Tuple: + if isinstance(the_obj, tuple): + arg_types = get_args(the_type) + if len(the_obj) != len(arg_types): + return False + for el, el_type in zip(the_obj, arg_types): + el_origin = get_origin(el_type) + if el_origin: + if not generics_checker(el, el_type): + return False + elif not isinstance(el, el_type): + return False + else: + return False + return True + + +def _isinstance(the_obj, the_type) -> bool: + origin_type = get_origin(the_type) + if origin_type: + return generics_checker(the_obj, the_type) + # 
handle odd case of non typed optional origin returning as none + if origin_type is None and the_type is Optional: + check_args_exist(the_type) + # handle non-generics + return isinstance(the_obj, the_type)