Skip to content

Commit

Permalink
Merge CompilationUnit from torch._C and torch.jit (#50614)
Browse files Browse the repository at this point in the history
Summary:
This simplifies our handling and allows passing CompilationUnits from Python to C++ defined functions via PyBind easily.

Discussed on Slack with SplitInfinity

Pull Request resolved: #50614

Reviewed By: anjali411

Differential Revision: D25938005

Pulled By: SplitInfinity

fbshipit-source-id: 94aadf0c063ddfef7ca9ea17bfa998d8e7b367ad
  • Loading branch information
t-vi authored and facebook-github-bot committed Jan 25, 2021
1 parent 5e79b8e commit ac0a3cc
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 28 deletions.
37 changes: 37 additions & 0 deletions test/jit/test_python_bindings.py
@@ -0,0 +1,37 @@
import torch
from torch.testing._internal.jit_utils import JitTestCase

if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TestPythonBindings\n\n"
"instead."
)


class TestPythonBindings(JitTestCase):
def test_cu_get_functions(self):
@torch.jit.script
def test_get_python_cu_fn(x: torch.Tensor):
return 2 * x

cu = torch.jit._state._python_cu
self.assertTrue(
"test_get_python_cu_fn" in (str(fn.name) for fn in cu.get_functions())
)

def test_cu_create_function(self):
@torch.jit.script
def fn(x: torch.Tensor):
return 2 * x

cu = torch._C.CompilationUnit()
cu.create_function("test_fn", fn.graph)

inp = torch.randn(5)

self.assertEqual(inp * 2, cu.find_function("test_fn")(inp))
self.assertEqual(cu.find_function("doesnt_exist"), None)
self.assertEqual(inp * 2, cu.test_fn(inp))
with self.assertRaises(AttributeError):
cu.doesnt_exist(inp)
1 change: 1 addition & 0 deletions test/test_jit.py
Expand Up @@ -23,6 +23,7 @@
from jit.test_peephole import TestPeephole # noqa: F401
from jit.test_save_load import TestSaveLoad # noqa: F401
from jit.test_module_containers import TestModuleContainers # noqa: F401
from jit.test_python_bindings import TestPythonBindings # noqa: F401
from jit.test_python_ir import TestPythonIr # noqa: F401
from jit.test_functional_blocks import TestFunctionalBlocks # noqa: F401
from jit.test_remove_mutation import TestRemoveMutation # noqa: F401
Expand Down
8 changes: 6 additions & 2 deletions torch/_C/__init__.pyi.in
Expand Up @@ -415,10 +415,13 @@ class ErrorReport:
def call_stack() -> str: ...

class CompilationUnit:
def __init__(self) -> None: ...
def __init__(self, lang: str=..., _frames_up: _int=...) -> None: ...
def find_function(self, name: str) -> ScriptFunction: ...
def define(self, script: str, rcb: ResolutionCallback): ...
def __getattr__(self, name: str) -> ScriptFunction: ...
def define(self, script: str, rcb: ResolutionCallback=..., _frames_up: _int=...): ...
def get_interface(self, name: str) -> InterfaceType: ...
def get_functions(self) -> List[ScriptFunction]: ...
def create_function(self, name: str, graph: Graph, shouldMangle: _bool=...) -> ScriptFunction: ...

class ScriptModule:
def setattr(self, name: str, value: Any): ...
Expand All @@ -429,6 +432,7 @@ class ScriptFunction:
def __call__(self, *args, **kwargs) -> Tensor: ...
def save(self, filename: str, _extra_files: Dict[str, bytes]) -> None: ...
def save_to_buffer(self, _extra_files: Dict[str, bytes]) -> bytes: ...
@property
def graph(self) -> Graph: ...
def inlined_graph(self) -> Graph: ...
def schema(self) -> FunctionSchema: ...
Expand Down
83 changes: 75 additions & 8 deletions torch/csrc/jit/python/script_init.cpp
Expand Up @@ -710,6 +710,22 @@ void extra_files_to_python(const ExtraFilesMap& m, const py::dict& pydict) {
}
}

void pyCompilationUnitDefine(
CompilationUnit& cu,
const std::string& src,
const ResolutionCallback* rcb,
const uint32_t _frames_up) {
if (rcb && *rcb) {
cu.define(c10::nullopt, src, pythonResolver(*rcb), nullptr);
} else {
py::object py_default_rcb =
py::module::import("torch._jit_internal")
.attr("createResolutionCallbackFromFrame")(_frames_up);
auto default_rcb = py_default_rcb.cast<ResolutionCallback>();
cu.define(c10::nullopt, src, pythonResolver(default_rcb), nullptr);
}
}

void initJitScriptBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();

Expand Down Expand Up @@ -1114,21 +1130,72 @@ void initJitScriptBindings(PyObject* module) {

py::class_<CompilationUnit, std::shared_ptr<CompilationUnit>>(
m, "CompilationUnit")
.def(py::init<>())
.def(
py::init([](const std::string& lang, const uint32_t _frames_up) {
auto cu = std::make_shared<CompilationUnit>();
if (lang.size() > 0) {
pyCompilationUnitDefine(*cu, lang, nullptr, _frames_up);
}
return cu;
}),
py::arg("lang") = "",
py::arg("_frames_up") = 0)

.def(
"find_function",
[](std::shared_ptr<CompilationUnit> self, const std::string& name) {
auto& fn = self->get_function(QualifiedName(name));
return StrongFunctionPtr(std::move(self), &fn);
auto fn = self->find_function(QualifiedName(name));
if (fn) {
return c10::optional<StrongFunctionPtr>(
StrongFunctionPtr(std::move(self), fn));
} else {
return c10::optional<StrongFunctionPtr>(c10::nullopt);
}
})
.def(
"__getattr__",
[](std::shared_ptr<CompilationUnit> self, const std::string& name) {
auto fn = self->find_function(QualifiedName(name));
if (fn) {
return StrongFunctionPtr(std::move(self), fn);
} else {
throw AttributeError(
"'CompilationUnit' has no attribute '%s'", name.c_str());
}
})
.def(
"get_functions",
[](const std::shared_ptr<CompilationUnit>& self) {
auto raw_functions = self->get_functions();
std::vector<StrongFunctionPtr> functions;
functions.reserve(raw_functions.size());
for (auto fn : raw_functions) {
if (fn) {
functions.emplace_back(self, fn);
}
}
return functions;
})
.def("set_optimized", &CompilationUnit::set_optimized)
.def(
"define",
[](CompilationUnit& cu,
const std::string& src,
const ResolutionCallback& rcb) {
cu.define(c10::nullopt, src, pythonResolver(rcb), nullptr);
})
pyCompilationUnitDefine,
py::arg("src"),
py::arg("rcb") = nullptr,
py::arg("_frames_up") = 0)
.def(
"create_function",
[](std::shared_ptr<CompilationUnit>& self,
const std::string& qualified_name,
std::shared_ptr<Graph> graph,
bool should_mangle) {
Function* fn = self->create_function(
qualified_name, std::move(graph), should_mangle);
return StrongFunctionPtr(std::move(self), fn);
},
py::arg("qualified_name"),
py::arg("graph"),
py::arg("should_mangle") = false)
.def(
"get_interface",
[](const std::shared_ptr<CompilationUnit>& self,
Expand Down
20 changes: 2 additions & 18 deletions torch/jit/_script.py
Expand Up @@ -1095,24 +1095,8 @@ def _recursive_compile_class(obj, loc):
rcb = _jit_internal.createResolutionCallbackForClassMethods(obj)
_compile_and_register_class(obj, rcb, _qual_name)


class CompilationUnit(object):
def __init__(self, lang=None, _frames_up=0):
self._c = torch._C.CompilationUnit()
if lang is not None:
self.define(lang, _frames_up=_frames_up + 1)

def define(self, lang, rcb=None, _frames_up=0):
if not rcb:
rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1)
self._c.define(lang, rcb)

def __getattr__(self, attr):
r = self._c.find_function(attr)
if r is None:
raise AttributeError("'CompilationUnit' has no attribute '{}'".format(attr))
return r

CompilationUnit = torch._C.CompilationUnit
set_module(CompilationUnit, "torch.jit")

def _unwrap_optional(x):
assert x is not None, "Unwrapping null optional"
Expand Down

0 comments on commit ac0a3cc

Please sign in to comment.