Skip to content

Commit

Permalink
[Pytorch Mobile] Expose _export_operator_list to python
Browse files Browse the repository at this point in the history
Follow up to D24690094 exposing the api in python. Created matching unit test.

Differential Revision: [D26112765](https://our.internmc.facebook.com/intern/diff/D26112765/)

ghstack-source-id: 120611452
Pull Request resolved: #51312
  • Loading branch information
jakeszwe committed Jan 28, 2021
1 parent 9f6e0de commit 449b147
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 2 deletions.
44 changes: 43 additions & 1 deletion test/mobile/test_lite_script_module.py
Expand Up @@ -5,7 +5,7 @@
from typing import NamedTuple
from collections import namedtuple

from torch.jit.mobile import _load_for_lite_interpreter
from torch.jit.mobile import _load_for_lite_interpreter, _export_operator_list
from torch.testing._internal.common_utils import TestCase, run_tests

class TestLiteScriptModule(TestCase):
Expand Down Expand Up @@ -286,5 +286,47 @@ def forward(self):
r"use a combination of list\, dictionary\, and single types\.$"):
script_module._save_to_buffer_for_lite_interpreter()

def test_module_export_operator_list(self):
class Foo(torch.nn.Module):
def __init__(self):
super(Foo, self).__init__()
self.weight = torch.ones((20, 1, 5, 5))
self.bias = torch.ones(20)

def forward(self, input):
x1 = torch.zeros(2, 2)
x2 = torch.empty_like(torch.empty(2, 2))
x3 = torch._convolution(
input,
self.weight,
self.bias,
[1, 1],
[0, 0],
[1, 1],
False,
[0, 0],
1,
False,
False,
True,
True,
)
return (x1, x2, x3)

m = torch.jit.script(Foo())

buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)

expected_ops = {
"aten::_convolution",
"aten::empty.memory_format",
"aten::empty_like",
"aten::zeros",
}
actual_ops = _export_operator_list(mobile_module)
self.assertEqual(actual_ops, expected_ops)

if __name__ == '__main__':
run_tests()
11 changes: 11 additions & 0 deletions torch/csrc/jit/python/script_init.cpp
Expand Up @@ -639,6 +639,14 @@ py::list debugMakeNamedList(const T& list) {
}
return result;
}
template <typename T>
py::set debugMakeSet(const T& list) {
py::set result;
for (const auto& elem : list) {
result.add(py::cast(elem));
}
return result;
}

static py::dict _jit_debug_module_iterators(Module& module) {
py::dict result;
Expand Down Expand Up @@ -1544,6 +1552,9 @@ void initJitScriptBindings(PyObject* module) {
}
return _load_for_mobile(in, optional_device);
});
m.def("_export_operator_list", [](torch::jit::mobile::Module& sm) {
return debugMakeSet(torch::jit::mobile::_export_operator_list(sm));
});

m.def("_jit_set_emit_hooks", setEmitHooks);
m.def("_jit_get_emit_hooks", getEmitHooks);
Expand Down
9 changes: 8 additions & 1 deletion torch/jit/mobile/__init__.py
Expand Up @@ -51,7 +51,6 @@ def _load_for_lite_interpreter(f, map_location=None):

return LiteScriptModule(cpp_module)


class LiteScriptModule(object):
def __init__(self, cpp_module):
self._c = cpp_module
Expand All @@ -68,3 +67,11 @@ def forward(self, *input):

def run_method(self, method_name, *input):
return self._c.run_method(method_name, input)

def _export_operator_list(module: LiteScriptModule):
r"""
return a set of root operator names (with overload name) that are used by any method
in this mobile module.
"""
# TODO fix mypy here
return torch._C._export_operator_list(module._c) # type: ignore

0 comments on commit 449b147

Please sign in to comment.