diff --git a/test/test_operations.py b/test/test_operations.py index 00deac6320f5..f24fc798e92d 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -874,8 +874,8 @@ def checkGrad(self, ############################################################## # Run forward and backwarg graphs via jit interpreter - exec_f = torch._C.GraphExecutor(gradient.f, False) - exec_df = torch._C.GraphExecutor(gradient.df, False) + exec_f = torch_xla._XLAC.GraphExecutor(gradient.f, False) + exec_df = torch_xla._XLAC.GraphExecutor(gradient.df, False) # forward function raw_outputs = exec_f(*inputs_params_buffers) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 4f8d68adaacc..02b53f270054 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -12,6 +12,9 @@ #include "tensorflow/compiler/xla/xla_client/util.h" #include "torch/csrc/autograd/utils/wrap_outputs.h" #include "torch/csrc/autograd/variable.h" +#include "torch/csrc/jit/passes/dead_code_elimination.h" +#include "torch/csrc/jit/passes/lower_tuples.h" +#include "torch/csrc/jit/python_tracer.h" #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/aten_xla_type.h" #include "torch_xla/csrc/device.h" @@ -543,12 +546,91 @@ void InitXlaTensorBindings(py::module m) { py::arg("count_include_pad") = true); } +// FIXME: copied from torch/csrc/jit/python_tracer.cpp since it's only available +// in libtorch_python.so. +std::shared_ptr CreateGraphByTracing( + const py::function& func, torch::jit::Stack trace_inputs, + const py::function& var_name_lookup_fn, bool force_outplace, + const c10::optional& num_real_inputs = c10::nullopt) { + size_t num_func_inputs = num_real_inputs.value_or(trace_inputs.size()); + auto enter_info = torch::jit::tracer::enter(std::move(trace_inputs)); + torch::jit::tracer::getTracingState()->lookup_var_name_fn = + [var_name_lookup_fn]( + const torch::autograd::Variable& var) -> std::string { + AutoGIL ag; + return py::cast(var_name_lookup_fn(var)); + }; + torch::jit::tracer::getTracingState()->force_outplace = force_outplace; + try { + py::tuple py_inputs(num_func_inputs); + for (size_t i = 0; i < num_func_inputs; ++i) { + py_inputs[i] = py::cast(enter_info.second[i]); + } + auto out = func(*py_inputs); + if (out.ptr() == Py_None) { + AT_ERROR( + "The traced function didn't return any values! Side-effects are not " + "captured in traces, so it would be a no-op."); + } + torch::jit::tracer::exit({torch::jit::toIValue(out)}); + auto graph = enter_info.first->graph; + EliminateDeadCode(graph); + LowerSimpleTuples(graph); + + return graph; + } catch (...) { + torch::jit::tracer::abandon(); + throw; + } +} + +void InitGraphExecutorBindings(py::module m) { + py::class_(m, "GraphExecutor", py::dynamic_attr()) + .def(py::init([](py::function func, py::tuple inputs, + py::function var_name_lookup_fn, bool optimize, + bool _force_outplace) { + auto graph = + CreateGraphByTracing(func, torch::jit::toStack(inputs), + var_name_lookup_fn, _force_outplace); + return torch::jit::GraphExecutor(graph, optimize); + }), + py::arg("func"), py::arg("inputs"), py::arg("var_name_lookup_fn"), + py::arg("optimize") = true, py::arg("_force_outplace") = false) + .def( + py::init([](std::shared_ptr graph, bool optimize) { + return torch::jit::GraphExecutor(std::move(graph), optimize); + }), + py::arg("graph"), py::arg("optimize") = true) + .def( + "graph_for", + [](torch::jit::GraphExecutor& ge, py::args args) { + return ge.graphFor(torch::jit::evilDeprecatedBadCreateStackDoNotUse( + args, ge.graph()->inputs())); + }) + .def_property_readonly( + "graph", [](torch::jit::GraphExecutor& ge) { return ge.graph(); }) + .def("get_debug_state", + [](torch::jit::GraphExecutor& ge) { return ge.getDebugState(); }) + .def("__call__", + [](torch::jit::GraphExecutor& ge, py::args args) -> py::object { + const auto& graph = ge.graph(); + auto stack = torch::jit::evilDeprecatedBadCreateStackDoNotUse( + args, graph->inputs()); + { + AutoNoGIL no_gil_guard; + ge.run(stack); + } + return torch::jit::createPyObjectForStack(std::move(stack)); + }); +} + } // namespace void InitXlaBindings(py::module m) { InitXlaModuleBindings(m); InitXlaPassesBindings(m); InitXlaTensorBindings(m); + InitGraphExecutorBindings(m); } } // namespace torch_xla