Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions test/jit/test_schema_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,13 @@ def test_schema_check_mode_functionality_nested_training_op(self):
actual = batch(actual)
self.assertEqual(expected, actual)

# Tests that SchemaCheckMode wraps Torch.tensor with empty list input
def test_schema_check_mode_empty_list_input(self):
expected = torch.atleast_1d([])
with enable_torch_dispatch_mode(SchemaCheckMode()):
actual = torch.atleast_1d([])
self.assertEqual(expected, actual)

# Tests that an exception is raised for a mismatching mutation
def test_mutation_check_fail(self):
with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"):
Expand Down Expand Up @@ -368,15 +375,22 @@ def test_alias_check_fail_outputs_unexpectedly_aliasing(self):
IncorrectAliasTensor(x).aminmax(dim=0)

# Tests that is_alias_of returns as expected
def test_is_alias_of(self):
def test_is_alias_of_basic(self):
x = torch.rand((3, 3), requires_grad=True)
y = torch.rand((3, 3), requires_grad=True)
y = x.add(x, alpha=2)
self.assertTrue(torch._C._is_alias_of(x, x))
self.assertFalse(torch._C._is_alias_of(x, y))

# Tests that is_alias_of returns as expected with empty containers
def test_is_alias_of_empty_container(self):
x = []
y = torch.rand((3, 3), requires_grad=True)
self.assertFalse(torch._C._is_alias_of(x, x))
self.assertFalse(torch._C._is_alias_of(x, y))

# Tests that overlaps returns as expected
def test_overlaps(self):
def test_overlaps_basic(self):
x = torch.rand((3, 3), requires_grad=True)
y = torch.rand((3, 3), requires_grad=True)
z = [x, y]
Expand All @@ -385,8 +399,16 @@ def test_overlaps(self):
self.assertTrue(torch._C._overlaps(z, x))
self.assertTrue(torch._C._overlaps(z, y))

# Tests that overlaps returns correctly with empty containers
def test_overlaps_empty_container(self):
x = []
y = [torch.rand((3, 3), requires_grad=True)]
# Anything overlaps nothing
self.assertTrue(torch._C._overlaps(y, x))
self.assertTrue(torch._C._overlaps(y, y))

# Tests that SchemaInfo Bindings work as expected
def test_schema_info_bind(self):
def test_schema_info_bind_basic(self):
class SchemaInfoBindTestMode(TorchDispatchMode):
def __init__(self, test_self):
self.test_self = test_self
Expand Down
18 changes: 18 additions & 0 deletions torch/csrc/jit/python/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,12 @@ bool loadPythonClasses() {

return true;
}

bool isEmptyContainer(const py::handle self) {
bool is_empty_list =
PySequence_Check(self.ptr()) && !PySequence_Size(self.ptr());
return is_empty_list;
}
} // anonymous namespace

#if !defined(USE_ROCM)
Expand Down Expand Up @@ -1701,6 +1707,9 @@ void initJITBindings(PyObject* module) {
[](SchemaInfo& self,
const std::string& name,
const py::object& value) {
if (isEmptyContainer(value)) {
return;
}
// For normalization purposes there is an inconsistency within
// torch.fx that turns all arguments named "self" into "input". Thus
// this check ensures that those arguments are checked correctly.
Expand All @@ -1714,6 +1723,9 @@ void initJITBindings(PyObject* module) {
std::unordered_map<std::string, IValue> value_map;
for (const auto& key_pair : values) {
IValue key = toTypeInferredIValue(key_pair.first);
if (isEmptyContainer(key_pair.second)) {
continue;
}
IValue value = toTypeInferredIValue(key_pair.second);
TORCH_INTERNAL_ASSERT(
key.isString(),
Expand Down Expand Up @@ -1898,9 +1910,15 @@ void initJITBindings(PyObject* module) {
}),
py::call_guard<py::gil_scoped_release>());
m.def("_is_alias_of", [](const py::object& self, const py::object& other) {
if (isEmptyContainer(self) || isEmptyContainer(other)) {
return false;
}
return toTypeInferredIValue(self).isAliasOf(toTypeInferredIValue(other));
});
m.def("_overlaps", [](const py::object& self, const py::object& other) {
if (isEmptyContainer(self) || isEmptyContainer(other)) {
return true;
}
return toTypeInferredIValue(self).overlaps(toTypeInferredIValue(other));
});
m.def("fork", [](const py::args& args, const py::kwargs& kwargs) {
Expand Down