diff --git a/test/jit/test_python_bindings.py b/test/jit/test_python_bindings.py
new file mode 100644
index 000000000000..090efad55edd
--- /dev/null
+++ b/test/jit/test_python_bindings.py
@@ -0,0 +1,37 @@
+import torch
+from torch.testing._internal.jit_utils import JitTestCase
+
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test file is not meant to be run directly, use:\n\n"
+        "\tpython test/test_jit.py TestPythonBindings\n\n"
+        "instead."
+    )
+
+
+class TestPythonBindings(JitTestCase):
+    def test_cu_get_functions(self):
+        @torch.jit.script
+        def test_get_python_cu_fn(x: torch.Tensor):
+            return 2 * x
+
+        cu = torch.jit._state._python_cu
+        self.assertTrue(
+            "test_get_python_cu_fn" in (str(fn.name) for fn in cu.get_functions())
+        )
+
+    def test_cu_create_function(self):
+        @torch.jit.script
+        def fn(x: torch.Tensor):
+            return 2 * x
+
+        cu = torch._C.CompilationUnit()
+        cu.create_function("test_fn", fn.graph)
+
+        inp = torch.randn(5)
+
+        self.assertEqual(inp * 2, cu.find_function("test_fn")(inp))
+        self.assertEqual(cu.find_function("doesnt_exist"), None)
+        self.assertEqual(inp * 2, cu.test_fn(inp))
+        with self.assertRaises(AttributeError):
+            cu.doesnt_exist(inp)
diff --git a/test/test_jit.py b/test/test_jit.py
index b0924dd09148..4d37cd0a3ef9 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -23,6 +23,7 @@
 from jit.test_peephole import TestPeephole  # noqa: F401
 from jit.test_save_load import TestSaveLoad  # noqa: F401
 from jit.test_module_containers import TestModuleContainers  # noqa: F401
+from jit.test_python_bindings import TestPythonBindings  # noqa: F401
 from jit.test_python_ir import TestPythonIr  # noqa: F401
 from jit.test_functional_blocks import TestFunctionalBlocks  # noqa: F401
 from jit.test_remove_mutation import TestRemoveMutation  # noqa: F401
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 1f4d7a070d53..6c8c68ee7c70 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -415,10 +415,13 @@ class ErrorReport:
     def call_stack() -> str: ...
 
 class CompilationUnit:
-    def __init__(self) -> None: ...
+    def __init__(self, lang: str=..., _frames_up: _int=...) -> None: ...
     def find_function(self, name: str) -> ScriptFunction: ...
-    def define(self, script: str, rcb: ResolutionCallback): ...
+    def __getattr__(self, name: str) -> ScriptFunction: ...
+    def define(self, script: str, rcb: ResolutionCallback=..., _frames_up: _int=...): ...
     def get_interface(self, name: str) -> InterfaceType: ...
+    def get_functions(self) -> List[ScriptFunction]: ...
+    def create_function(self, name: str, graph: Graph, shouldMangle: _bool=...) -> ScriptFunction: ...
 
 class ScriptModule:
     def setattr(self, name: str, value: Any): ...
@@ -429,6 +432,7 @@ class ScriptFunction:
    def __call__(self, *args, **kwargs) -> Tensor: ...
    def save(self, filename: str, _extra_files: Dict[str, bytes]) -> None: ...
    def save_to_buffer(self, _extra_files: Dict[str, bytes]) -> bytes: ...
+   @property
    def graph(self) -> Graph: ...
    def inlined_graph(self) -> Graph: ...
    def schema(self) -> FunctionSchema: ...
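Taken together, the new test and the updated stubs describe the Python surface this patch exposes. A minimal usage sketch, relying only on what the diff above adds (create_function, find_function returning None on a miss, and attribute-style lookup via __getattr__):

import torch

@torch.jit.script
def double(x: torch.Tensor):
    return 2 * x

cu = torch._C.CompilationUnit()
# create_function registers the scripted graph under a new name
# and hands back a ScriptFunction.
fn = cu.create_function("double", double.graph)
print(fn(torch.ones(3)))            # tensor([2., 2., 2.])
print(cu.find_function("missing"))  # None rather than an exception
print(cu.double(torch.ones(3)))     # __getattr__; raises AttributeError on a miss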
diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp
index f7dd766d5da7..e29322b9a6b1 100644
--- a/torch/csrc/jit/python/script_init.cpp
+++ b/torch/csrc/jit/python/script_init.cpp
@@ -710,6 +710,22 @@ void extra_files_to_python(const ExtraFilesMap& m, const py::dict& pydict) {
   }
 }
 
+void pyCompilationUnitDefine(
+    CompilationUnit& cu,
+    const std::string& src,
+    const ResolutionCallback* rcb,
+    const uint32_t _frames_up) {
+  if (rcb && *rcb) {
+    cu.define(c10::nullopt, src, pythonResolver(*rcb), nullptr);
+  } else {
+    py::object py_default_rcb =
+        py::module::import("torch._jit_internal")
+            .attr("createResolutionCallbackFromFrame")(_frames_up);
+    auto default_rcb = py_default_rcb.cast<ResolutionCallback>();
+    cu.define(c10::nullopt, src, pythonResolver(default_rcb), nullptr);
+  }
+}
+
 void initJitScriptBindings(PyObject* module) {
   auto m = py::handle(module).cast<py::module>();
 
@@ -1114,21 +1130,72 @@ void initJitScriptBindings(PyObject* module) {
 
   py::class_<CompilationUnit, std::shared_ptr<CompilationUnit>>(
       m, "CompilationUnit")
-      .def(py::init<>())
+      .def(
+          py::init([](const std::string& lang, const uint32_t _frames_up) {
+            auto cu = std::make_shared<CompilationUnit>();
+            if (lang.size() > 0) {
+              pyCompilationUnitDefine(*cu, lang, nullptr, _frames_up);
+            }
+            return cu;
+          }),
+          py::arg("lang") = "",
+          py::arg("_frames_up") = 0)
+
       .def(
           "find_function",
           [](std::shared_ptr<CompilationUnit> self, const std::string& name) {
-            auto& fn = self->get_function(QualifiedName(name));
-            return StrongFunctionPtr(std::move(self), &fn);
+            auto fn = self->find_function(QualifiedName(name));
+            if (fn) {
+              return c10::optional<StrongFunctionPtr>(
+                  StrongFunctionPtr(std::move(self), fn));
+            } else {
+              return c10::optional<StrongFunctionPtr>(c10::nullopt);
+            }
+          })
+      .def(
+          "__getattr__",
+          [](std::shared_ptr<CompilationUnit> self, const std::string& name) {
+            auto fn = self->find_function(QualifiedName(name));
+            if (fn) {
+              return StrongFunctionPtr(std::move(self), fn);
+            } else {
+              throw AttributeError(
+                  "'CompilationUnit' has no attribute '%s'", name.c_str());
+            }
+          })
+      .def(
+          "get_functions",
+          [](const std::shared_ptr<CompilationUnit>& self) {
+            auto raw_functions = self->get_functions();
+            std::vector<StrongFunctionPtr> functions;
+            functions.reserve(raw_functions.size());
+            for (auto fn : raw_functions) {
+              if (fn) {
+                functions.emplace_back(self, fn);
+              }
+            }
+            return functions;
           })
       .def("set_optimized", &CompilationUnit::set_optimized)
       .def(
           "define",
-          [](CompilationUnit& cu,
-             const std::string& src,
-             const ResolutionCallback& rcb) {
-            cu.define(c10::nullopt, src, pythonResolver(rcb), nullptr);
-          })
+          pyCompilationUnitDefine,
+          py::arg("src"),
+          py::arg("rcb") = nullptr,
+          py::arg("_frames_up") = 0)
+      .def(
+          "create_function",
+          [](std::shared_ptr<CompilationUnit>& self,
+             const std::string& qualified_name,
+             std::shared_ptr<Graph> graph,
+             bool should_mangle) {
+            Function* fn = self->create_function(
+                qualified_name, std::move(graph), should_mangle);
+            return StrongFunctionPtr(std::move(self), fn);
+          },
+          py::arg("qualified_name"),
+          py::arg("graph"),
+          py::arg("should_mangle") = false)
       .def(
           "get_interface",
           [](const std::shared_ptr<CompilationUnit>& self,
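The binding above routes both the new string-accepting constructor and define() through pyCompilationUnitDefine, which falls back to a resolution callback built from the caller's frame when no rcb is supplied. A hedged sketch of what that enables from Python (the function names and TorchScript source are illustrative only):

import torch

# Source passed to the constructor is compiled immediately; the
# resolution callback is created from this calling frame.
cu = torch._C.CompilationUnit('''
def foo(x: Tensor) -> Tensor:
    return x + 1
''')
print(cu.foo(torch.zeros(2)))  # tensor([1., 1.])

# define() works the same way when called later without an rcb.
cu.define('''
def bar(x: Tensor) -> Tensor:
    return x - 1
''')
print(cu.bar(torch.zeros(2)))  # tensor([-1., -1.])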
diff --git a/torch/jit/_script.py b/torch/jit/_script.py
index 57b83241fa26..f9d6c33192f2 100644
--- a/torch/jit/_script.py
+++ b/torch/jit/_script.py
@@ -1095,24 +1095,8 @@ def _recursive_compile_class(obj, loc):
     rcb = _jit_internal.createResolutionCallbackForClassMethods(obj)
     _compile_and_register_class(obj, rcb, _qual_name)
 
-
-class CompilationUnit(object):
-    def __init__(self, lang=None, _frames_up=0):
-        self._c = torch._C.CompilationUnit()
-        if lang is not None:
-            self.define(lang, _frames_up=_frames_up + 1)
-
-    def define(self, lang, rcb=None, _frames_up=0):
-        if not rcb:
-            rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1)
-        self._c.define(lang, rcb)
-
-    def __getattr__(self, attr):
-        r = self._c.find_function(attr)
-        if r is None:
-            raise AttributeError("'CompilationUnit' has no attribute '{}'".format(attr))
-        return r
-
+CompilationUnit = torch._C.CompilationUnit
+set_module(CompilationUnit, "torch.jit")
 
 def _unwrap_optional(x):
     assert x is not None, "Unwrapping null optional"
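With the pure-Python wrapper deleted, torch.jit.CompilationUnit is now the torch._C.CompilationUnit binding itself; set_module only adjusts the class's reported __module__ to torch.jit. A quick sketch of what existing callers should still see (assuming the public import path is otherwise unchanged):

import torch

# The alias and the binding are now the same class object.
assert torch.jit.CompilationUnit is torch._C.CompilationUnit

cu = torch.jit.CompilationUnit('''
def f(x: Tensor) -> Tensor:
    return 2 * x
''')
print(cu.f(torch.ones(2)))  # tensor([2., 2.])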