diff --git a/test/mobile/test_lite_script_module.py b/test/mobile/test_lite_script_module.py
index 033ab6a15011..cd651281b261 100644
--- a/test/mobile/test_lite_script_module.py
+++ b/test/mobile/test_lite_script_module.py
@@ -5,7 +5,7 @@
 from typing import NamedTuple
 from collections import namedtuple
 
-from torch.jit.mobile import _load_for_lite_interpreter
+from torch.jit.mobile import _load_for_lite_interpreter, _export_operator_list
 from torch.testing._internal.common_utils import TestCase, run_tests
 
 class TestLiteScriptModule(TestCase):
@@ -286,5 +286,47 @@ def forward(self):
             r"use a combination of list\, dictionary\, and single types\.$"):
             script_module._save_to_buffer_for_lite_interpreter()
 
+    def test_module_export_operator_list(self):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super(Foo, self).__init__()
+                self.weight = torch.ones((20, 1, 5, 5))
+                self.bias = torch.ones(20)
+
+            def forward(self, input):
+                x1 = torch.zeros(2, 2)
+                x2 = torch.empty_like(torch.empty(2, 2))
+                x3 = torch._convolution(
+                    input,
+                    self.weight,
+                    self.bias,
+                    [1, 1],
+                    [0, 0],
+                    [1, 1],
+                    False,
+                    [0, 0],
+                    1,
+                    False,
+                    False,
+                    True,
+                    True,
+                )
+                return (x1, x2, x3)
+
+        m = torch.jit.script(Foo())
+
+        buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
+        buffer.seek(0)
+        mobile_module = _load_for_lite_interpreter(buffer)
+
+        expected_ops = {
+            "aten::_convolution",
+            "aten::empty.memory_format",
+            "aten::empty_like",
+            "aten::zeros",
+        }
+        actual_ops = _export_operator_list(mobile_module)
+        self.assertEqual(actual_ops, expected_ops)
+
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp
index 356c91f2a03b..7cf53b6c6e11 100644
--- a/torch/csrc/jit/python/script_init.cpp
+++ b/torch/csrc/jit/python/script_init.cpp
@@ -639,6 +639,14 @@ py::list debugMakeNamedList(const T& list) {
   }
   return result;
 }
+template <typename T>
+py::set debugMakeSet(const T& list) {
+  py::set result;
+  for (const auto& elem : list) {
+    result.add(py::cast(elem));
+  }
+  return result;
+}
 
 static py::dict _jit_debug_module_iterators(Module& module) {
   py::dict result;
@@ -1544,6 +1552,9 @@
         }
         return _load_for_mobile(in, optional_device);
       });
+  m.def("_export_operator_list", [](torch::jit::mobile::Module& sm) {
+    return debugMakeSet(torch::jit::mobile::_export_operator_list(sm));
+  });
 
   m.def("_jit_set_emit_hooks", setEmitHooks);
   m.def("_jit_get_emit_hooks", getEmitHooks);
diff --git a/torch/jit/mobile/__init__.py b/torch/jit/mobile/__init__.py
index 4356400cb447..a3b1aa1e4e64 100644
--- a/torch/jit/mobile/__init__.py
+++ b/torch/jit/mobile/__init__.py
@@ -51,7 +51,6 @@ def _load_for_lite_interpreter(f, map_location=None):
 
     return LiteScriptModule(cpp_module)
 
-
 class LiteScriptModule(object):
     def __init__(self, cpp_module):
         self._c = cpp_module
@@ -68,3 +67,11 @@ def forward(self, *input):
 
     def run_method(self, method_name, *input):
         return self._c.run_method(method_name, input)
+
+def _export_operator_list(module: LiteScriptModule):
+    r"""
+    return a set of root operator names (with overload name) that are used by any method
+    in this mobile module.
+    """
+    # TODO fix mypy here
+    return torch._C._export_operator_list(module._c)  # type: ignore