[jit] Fix recursive method compilation #21862

Closed
wants to merge 4 commits
10 changes: 2 additions & 8 deletions test/jit_utils.py
@@ -53,12 +53,6 @@ def tearDown(self):
         self.clearHooks()
         torch._C._jit_clear_class_registry()
 
-    @contextmanager
-    def disableEmitHook(self):
-        self.clearHooks()
-        yield None
-        self.setHooks()
-
     def _isHookExceptionOk(self, e):
         se = str(e)
         allowed = ("Could not export Python function",
@@ -73,7 +67,7 @@ def emitFunctionHook(self, func):
         if func.name == "<lambda>" or "aten::" in func.name or not _inline_everything:
             return
         # disable the hook while we parse code, otherwise we will re-enter the hook
-        with self.disableEmitHook():
+        with torch.jit._disable_emit_hooks():
             try:
                 src, constants = _jit_python_print(func)
                 cu = torch.jit.CompilationUnit()._import(src, constants)
@@ -98,7 +92,7 @@ def copy_structure_and_params(m):
             return c
 
         # disable the hook while we parse code, otherwise we will re-enter the hook
-        with self.disableEmitHook():
+        with torch.jit._disable_emit_hooks():
             try:
                 if len(module.code) == 0:
                     # short-circuit if this is an empty module
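The guard above is needed because emitFunctionHook itself compiles code: it round-trips the just-emitted function through _jit_python_print and CompilationUnit()._import, and that second compilation would fire the emit hook again. Switching from the test-local disableEmitHook helper to torch.jit._disable_emit_hooks() (added in torch/jit/__init__.py later in this diff) keeps that guard while making it usable outside the test class. A minimal sketch of the re-entry hazard, using a hypothetical hook and assuming the context manager added by this PR:

import torch

def nosy_function_hook(fn):
    # This hook compiles code itself; without the guard, that inner
    # compilation would invoke the hook again and recurse.
    with torch.jit._disable_emit_hooks():
        torch.jit.CompilationUnit("def helper(x):\n    return x\n")

torch._C._jit_set_emit_hooks(None, nosy_function_hook)

@torch.jit.script
def f(x):
    return x + 1  # the hook runs once here; the guard stops it from re-entering

torch._C._jit_set_emit_hooks(None, None)  # remove the hook again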
22 changes: 16 additions & 6 deletions test/test_jit.py
@@ -3176,7 +3176,7 @@ def test_annoying_doubles(self):
         mod.ninf = float("-inf")
         mod.nan = float("nan")
 
-        with self.disableEmitHook():
+        with torch.jit._disable_emit_hooks():
             @torch.jit.script
             def foo():
                 return math.pi, 0.1, mod.inf, mod.ninf, 2.225073858507201e-308, mod.nan
@@ -11966,7 +11966,7 @@ def foo(a, b):
         self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
 
     def test_bool_dispatch(self):
-        with self.disableEmitHook():  # TODO: Python print broadcasting list
+        with torch.jit._disable_emit_hooks():  # TODO: Python print broadcasting list
             def kwarg_false(x):
                 # type: (Tensor) -> Tensor
                 return F.max_pool1d(x, 1, 1, return_indices=False)
@@ -12843,7 +12843,7 @@ def forward(self, key):
                 # type: (str) -> Tensor
                 return self.table[key] + self.x
 
-        with self.disableEmitHook():
+        with torch.jit._disable_emit_hooks():
             # TODO: re-enable module hook when Python printing of attributes is
             # supported
             m = M({char : torch.ones(1) + ord(char) - ord("a") for char in "abcdefg"})
@@ -13299,6 +13299,16 @@ def forward(self, t):
 
         self.checkModule(M(), (torch.randn(2, 2),))
 
+    def test_method_call(self):
+        class M(nn.Module):
+            def test(self, x):
+                return x
+
+            def forward(self, z):
+                y = self.test(z)
+                return z + 20 + y
+
+        self.checkModule(M(), (torch.randn(2, 2),))
 
     def test_script_basic(self):
         def a_python_fn(a, b, c):
@@ -13857,7 +13867,7 @@ def forward(self, input):
                 return self.seq.forward(input)
 
         # disabled due to a jitter issues that will be fixed by using load/store in the compiler
-        with self.disableEmitHook():
+        with torch.jit._disable_emit_hooks():
             # TODO: toggle export_import once above issues are fixed
             self.checkTrace(Traced(), (torch.rand(3, 4),),
                             export_import=False)
@@ -14988,7 +14998,7 @@ def run_test():
             self.assertAutodiffNode(script_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
 
         if test_name in EXCLUDE_PYTHON_PRINT:
-            with self.disableEmitHook():
+            with torch.jit._disable_emit_hooks():
                 run_test()
         else:
             run_test()
@@ -15069,7 +15079,7 @@ def make_module(script):
 
         # module cannot be imported / exported
         if module_name in EXCLUDE_MODULE_EXPORT_IMPORT:
-            with self.disableEmitHook():
+            with torch.jit._disable_emit_hooks():
                 module = make_module(script)
                 create_script_module.last_graph = module.graph
                 mod = module(*args)
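test_method_call above is the regression test for the fix: forward calls another method on self, so test has to be compiled while forward's own graph is still being built, which is exactly where the emit hooks must stay quiet. Outside the test harness the same scenario looks roughly like this — an illustrative snippet, not part of the diff, assuming torch.jit.script is applied to the module instance the way checkModule does:

import torch
import torch.nn as nn

class M(nn.Module):
    def test(self, x):
        return x

    def forward(self, z):
        # forward calls another method on self; compiling this call is the
        # recursive method compilation the PR title refers to
        y = self.test(z)
        return z + 20 + y

scripted = torch.jit.script(M())
z = torch.randn(2, 2)
assert torch.allclose(scripted(z), z + 20 + z)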
13 changes: 7 additions & 6 deletions torch/csrc/jit/hooks_for_testing.cpp
@@ -4,25 +4,26 @@
 namespace torch {
 namespace jit {
 
-static std::function<void(std::shared_ptr<script::Module> module)>
-    emit_module_callback;
+static ModuleHook emit_module_callback;
 void didFinishEmitModule(std::shared_ptr<script::Module> module) {
   if (emit_module_callback) {
     emit_module_callback(std::move(module));
   }
 }
-static std::function<void(std::shared_ptr<Function> fn)> emit_function_callback;
+static FunctionHook emit_function_callback;
 void didFinishEmitFunction(std::shared_ptr<Function> fn) {
   if (emit_function_callback) {
     emit_function_callback(fn);
   }
 }
-void setEmitHooks(
-    std::function<void(std::shared_ptr<script::Module> module)> for_mod,
-    std::function<void(std::shared_ptr<Function> for_fn)> for_fn) {
+void setEmitHooks(ModuleHook for_mod, FunctionHook for_fn) {
   emit_module_callback = std::move(for_mod);
   emit_function_callback = std::move(for_fn);
 }
 
+std::pair<ModuleHook, FunctionHook> getEmitHooks() {
+  return std::make_pair(emit_module_callback, emit_function_callback);
+}
+
 } // namespace jit
 } // namespace torch
9 changes: 7 additions & 2 deletions torch/csrc/jit/hooks_for_testing.h
@@ -9,10 +9,15 @@ struct Function;
 namespace script {
 struct Module;
 }
+
+using ModuleHook = std::function<void(std::shared_ptr<script::Module> module)>;
+using FunctionHook = std::function<void(std::shared_ptr<Function> function)>;
+
 TORCH_API void didFinishEmitModule(std::shared_ptr<script::Module> module);
 TORCH_API void didFinishEmitFunction(std::shared_ptr<Function> defined);
 TORCH_API void setEmitHooks(
-    std::function<void(std::shared_ptr<script::Module> module)> for_module,
-    std::function<void(std::shared_ptr<Function> fn)> for_fn);
+    ModuleHook for_module,
+    FunctionHook for_fn);
+TORCH_API std::pair<ModuleHook, FunctionHook> getEmitHooks();
 } // namespace jit
 } // namespace torch
1 change: 1 addition & 0 deletions torch/csrc/jit/script/init.cpp
@@ -725,6 +725,7 @@ void initJitScriptBindings(PyObject* module) {
       });
 
   m.def("_jit_set_emit_hooks", setEmitHooks);
+  m.def("_jit_get_emit_hooks", getEmitHooks);
   m.def("_jit_clear_class_registry", CompilationUnit::_clear_python_cu);
   m.def(
       "_debug_set_autodiff_subgraph_inlining",
14 changes: 13 additions & 1 deletion torch/jit/__init__.py
@@ -1007,11 +1007,23 @@ def _try_compile_fn(fn):
     return torch.jit.script(fn, _rcb=rcb)
 
 
+@contextlib.contextmanager
+def _disable_emit_hooks():
+    hooks = torch._C._jit_get_emit_hooks()
+    torch._C._jit_set_emit_hooks(None, None)
+    yield
+    torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
+
+
 def _create_method_from_fn(module, fn):
     if _jit_internal.is_ignored_fn(fn):
         return None
     stub = script_method(fn, createResolutionCallbackFromClosure(fn))
-    _create_methods_from_stubs(self, (stub,))
+    with _disable_emit_hooks():
+        # We don't want to call the hooks here since the graph that is calling
+        # this function is not yet complete
+        _create_methods_from_stubs(module, (stub,))
     return stub
 
 
 # ScriptClasses must be new-style classes because we construct them using their
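_disable_emit_hooks saves whichever hooks are currently installed, clears them for the duration of the block, and reinstalls them on exit; _create_method_from_fn wraps its compilation in it because the graph that is calling the newly created method is not complete yet (and the call now correctly passes module instead of the undefined self). A rough demonstration of the save/clear/restore behaviour, with hypothetical hook functions and the bindings added in this diff:

import torch

emitted = []

def module_hook(mod):
    emitted.append("module")

def function_hook(fn):
    emitted.append(fn.name)

torch._C._jit_set_emit_hooks(module_hook, function_hook)

with torch.jit._disable_emit_hooks():
    @torch.jit.script
    def f(x):
        return x + 1  # compiled with the hooks cleared; nothing is recorded

# the saved hooks are reinstalled once the block exits
@torch.jit.script
def g(x):
    return x * 2  # function_hook fires for this compilation

print(emitted)  # expected to contain "g" but not "f"
torch._C._jit_set_emit_hooks(None, None)  # clean up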