Skip to content

Commit

Permalink
[WIP] adding torch.jit.isinstance
Browse files Browse the repository at this point in the history
ghstack-source-id: 79618038d10e321a2c3716c8272041ff9bfb790c
Pull Request resolved: #46062
  • Loading branch information
Lilyjjo committed Oct 13, 2020
1 parent 9fb8e33 commit ad02538
Show file tree
Hide file tree
Showing 6 changed files with 358 additions and 1 deletion.
246 changes: 246 additions & 0 deletions test/jit/test_isinstance.py
@@ -0,0 +1,246 @@
import os
import sys

import torch
from typing import List, Any, Dict, Tuple, Optional

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase

# Guard against running this file directly; these tests must be launched
# through test/test_jit.py so the JIT test harness is set up correctly.
if __name__ == "__main__":
    msg = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(msg)

# Tests for torch.jit.isinstance
class TestIsinstance(JitTestCase):
    """Tests for ``torch.jit.isinstance``.

    ``torch.jit.isinstance`` performs runtime type checks on values typed as
    ``Any`` inside scripted functions, including checks against containers
    with contained types (e.g. ``List[int]``), which builtin ``isinstance``
    cannot express.  Each test scripts a small function via
    ``self.checkScript`` so the behavior is verified in both eager mode and
    TorchScript.
    """

    def test_int(self):
        def int_test(x: Any):
            assert torch.jit.isinstance(x, int)
            assert not torch.jit.isinstance(x, float)

        x = 1
        self.checkScript(int_test, (x,))

    def test_float(self):
        def float_test(x: Any):
            assert torch.jit.isinstance(x, float)
            assert not torch.jit.isinstance(x, int)

        x = 1.0
        self.checkScript(float_test, (x,))

    def test_bool(self):
        def bool_test(x: Any):
            assert torch.jit.isinstance(x, bool)
            assert not torch.jit.isinstance(x, float)

        x = False
        self.checkScript(bool_test, (x,))

    def test_list(self):
        def list_str_test(x: Any):
            assert torch.jit.isinstance(x, List[str])
            assert not torch.jit.isinstance(x, List[int])

        x = ["1", "2", "3"]
        self.checkScript(list_str_test, (x,))

    def test_dict(self):
        def dict_str_int_test(x: Any):
            assert torch.jit.isinstance(x, Dict[str, int])
            assert not torch.jit.isinstance(x, Dict[int, str])

        x = {"a": 1, "b": 2}
        self.checkScript(dict_str_int_test, (x,))

    def test_tuple(self):
        def tuple_test(x: Any):
            assert torch.jit.isinstance(x, Tuple[str, int, str])
            assert not torch.jit.isinstance(x, Tuple[int, str, str])
            # Arity must match too, not just element types.
            assert not torch.jit.isinstance(x, Tuple[str])

        x = ("a", 1, "b")
        self.checkScript(tuple_test, (x,))

    def test_optional(self):
        def optional_test(x: Any):
            assert torch.jit.isinstance(x, Optional[torch.Tensor])
            assert not torch.jit.isinstance(x, Optional[str])
            # TODO: does a successful torch.jit.isinstance check refine
            # (set) the type of x? Investigate.

        x = torch.ones(3, 3)
        self.checkScript(optional_test, (x,))

    def test_optional_none(self):
        def optional_test_none(x: Any):
            assert torch.jit.isinstance(x, Optional[torch.Tensor])
            # assert not torch.jit.isinstance(x, Optional[str])
            # TODO: above line fails in TS interpreter need to investigate

        x = None
        self.checkScript(optional_test_none, (x,))

    def test_list_nested(self):
        def list_nested(x: Any):
            assert torch.jit.isinstance(x, List[Dict[str, int]])
            assert not torch.jit.isinstance(x, List[List[str]])

        x = [{"a": 1, "b": 2}, {"aa": 11, "bb": 22}]
        self.checkScript(list_nested, (x,))

    def test_dict_nested(self):
        def dict_nested(x: Any):
            assert torch.jit.isinstance(x, Dict[str, Tuple[str, str, str]])
            assert not torch.jit.isinstance(x, Dict[str, Tuple[int, int, int]])

        x = {"a": ("aa", "aa", "aa"), "b": ("bb", "bb", "bb")}
        self.checkScript(dict_nested, (x,))

    def test_tuple_nested(self):
        def tuple_nested(x: Any):
            assert torch.jit.isinstance(
                x, Tuple[Dict[str, Tuple[str, str, str]], List[bool], Optional[str]]
            )
            assert not torch.jit.isinstance(x, Dict[str, Tuple[int, int, int]])
            assert not torch.jit.isinstance(x, Tuple[str])

        x = (
            {"a": ("aa", "aa", "aa"), "b": ("bb", "bb", "bb")},
            [True, False, True],
            None,
        )
        self.checkScript(tuple_nested, (x,))

    def test_optional_nested(self):
        def optional_nested(x: Any):
            assert torch.jit.isinstance(x, Optional[List[str]])

        x = ["a", "b", "c"]
        self.checkScript(optional_nested, (x,))

    def test_list_tensor_type_true(self):
        def list_tensor_type_true(x: Any):
            assert torch.jit.isinstance(x, List[torch.Tensor])

        x = [torch.rand(3, 3), torch.rand(4, 3)]
        self.checkScript(list_tensor_type_true, (x,))

    def test_tensor_type_false(self):
        def list_tensor_type_false(x: Any):
            assert not torch.jit.isinstance(x, List[torch.Tensor])

        x = [1, 2, 3]
        self.checkScript(list_tensor_type_false, (x,))

    def test_in_if(self):
        def list_in_if(x: Any):
            if torch.jit.isinstance(x, List[int]):
                assert True
            if torch.jit.isinstance(x, List[str]):
                # This branch must never be taken for a List[int] input.
                assert not True

        x = [1, 2, 3]
        self.checkScript(list_in_if, (x,))

    def test_if_else(self):
        def list_in_if_else(x: Any):
            if torch.jit.isinstance(x, Tuple[str, str, str]):
                assert True
            else:
                assert not True

        x = ("a", "b", "c")
        self.checkScript(list_in_if_else, (x,))

    def test_in_while_loop(self):
        def list_in_while_loop(x: Any):
            count = 0
            # The loop condition should be True exactly once (count <= 0).
            while torch.jit.isinstance(x, List[Dict[str, int]]) and count <= 0:
                count = count + 1
            assert count == 1

        x = [{"a": 1, "b": 2}, {"aa": 11, "bb": 22}]
        self.checkScript(list_in_while_loop, (x,))

    def test_type_refinement(self):
        def type_refinement(obj: Any):
            hit = False
            if torch.jit.isinstance(obj, List[torch.Tensor]):
                hit = not hit
                for el in obj:
                    # perform some tensor operation
                    y = el.clamp(0, 0.5)
            if torch.jit.isinstance(obj, Dict[str, str]):
                hit = not hit
                str_cat = ""
                for val in obj.values():
                    str_cat = str_cat + val
                assert "111222" == str_cat
            assert hit

        x = [torch.rand(3, 3), torch.rand(4, 3)]
        self.checkScript(type_refinement, (x,))
        x = {"1": "111", "2": "222"}
        self.checkScript(type_refinement, (x,))

    def test_list_no_contained_type(self):
        # Using a bare ``List`` (no contained type) is a compile-time error;
        # scripting fails before the function can ever be called.
        def list_no_contained_type(x: Any):
            assert torch.jit.isinstance(x, List)

        with self.assertRaisesRegex(
            RuntimeError,
            "Attempted to use List without a "
            "contained type. Please add a contained type, e.g. "
            r"List\[int\]",
        ):
            torch.jit.script(list_no_contained_type)

    def test_tuple_no_contained_type(self):
        # Bare ``Tuple`` is likewise rejected at script time.
        def tuple_no_contained_type(x: Any):
            assert torch.jit.isinstance(x, Tuple)

        with self.assertRaisesRegex(
            RuntimeError,
            "Attempted to use Tuple without a "
            "contained type. Please add a contained type, e.g. "
            r"Tuple\[int\]",
        ):
            torch.jit.script(tuple_no_contained_type)

    def test_optional_no_contained_type(self):
        # Bare ``Optional`` is likewise rejected at script time.
        def optional_no_contained_type(x: Any):
            assert torch.jit.isinstance(x, Optional)

        with self.assertRaisesRegex(
            RuntimeError,
            "Attempted to use Optional without a "
            "contained type. Please add a contained type, e.g. "
            r"Optional\[int\]",
        ):
            torch.jit.script(optional_no_contained_type)

    def test_dict_no_contained_type(self):
        # Bare ``Dict`` is likewise rejected at script time.
        def dict_no_contained_type(x: Any):
            assert torch.jit.isinstance(x, Dict)

        with self.assertRaisesRegex(
            RuntimeError,
            "Attempted to use Dict without "
            "contained types. Please add contained type, e.g. "
            r"Dict\[int, int\]",
        ):
            torch.jit.script(dict_no_contained_type)
1 change: 1 addition & 0 deletions test/test_jit.py
Expand Up @@ -33,6 +33,7 @@
from jit.test_profiler import TestProfiler # noqa: F401
from jit.test_slice import TestSlice # noqa: F401
from jit.test_warn import TestWarn # noqa: F401
from jit.test_isinstance import TestIsinstance # noqa: F401

# Torch
from torch import Tensor
Expand Down
8 changes: 8 additions & 0 deletions torch/csrc/jit/frontend/ir_emitter.cpp
Expand Up @@ -1196,6 +1196,14 @@ struct to_ir {
return emitHasAttr(apply.inputs()[0], apply.inputs()[1]);
}
}
auto sv = emitSugaredExpr(apply.callee(), 1);
auto loc = apply.callee().range();
if (auto special_form = dynamic_cast<SpecialFormValue*>(sv.get())) {
if (special_form->form() == prim::isinstance) {
checkApplyNumInputs(apply, 2);
return emitIsInstance(apply.inputs()[0], apply.inputs()[1]);
}
}
}
auto expr_out = emitToBool(expr.range(), emitExpr(expr));
c10::optional<bool> static_if = c10::nullopt;
Expand Down
3 changes: 3 additions & 0 deletions torch/csrc/jit/python/python_sugared_value.cpp
Expand Up @@ -918,6 +918,9 @@ std::shared_ptr<SugaredValue> toSugaredValue(
} else if (
obj.ptr() == py::module::import("torch.jit").attr("annotate").ptr()) {
return SpecialFormValue::create(prim::annotate);
} else if (
obj.ptr() == py::module::import("torch.jit").attr("isinstance").ptr()) {
return SpecialFormValue::create(prim::isinstance);
#ifdef USE_RPC
// RPC module is only available when build flag "USE_DISTRIBUTED" is on.
} else if (
Expand Down
5 changes: 5 additions & 0 deletions torch/jit/__init__.py
Expand Up @@ -21,6 +21,7 @@
RecursiveScriptModule,
ScriptWarning,
interface,
_isinstance,
CompilationUnit,
ScriptFunction,
_unwrap_optional,
Expand Down Expand Up @@ -70,5 +71,9 @@ def annotate(the_type, the_value):
return the_value


# for torch.jit.isinstance
isinstance = _isinstance


if not torch._C._jit_init():
raise RuntimeError("JIT initialization failed")

0 comments on commit ad02538

Please sign in to comment.