Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 0 additions & 82 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
#include "tensorflow/compiler/xla/xla_client/util.h"
#include "torch/csrc/autograd/utils/wrap_outputs.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/lower_tuples.h"
#include "torch/csrc/jit/python_tracer.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/aten_xla_type.h"
#include "torch_xla/csrc/device.h"
Expand Down Expand Up @@ -546,91 +543,12 @@ void InitXlaTensorBindings(py::module m) {
py::arg("count_include_pad") = true);
}

// FIXME: copied from torch/csrc/jit/python_tracer.cpp since it's only available
// in libtorch_python.so.
std::shared_ptr<torch::jit::Graph> CreateGraphByTracing(
const py::function& func, torch::jit::Stack trace_inputs,
const py::function& var_name_lookup_fn, bool force_outplace,
const c10::optional<size_t>& num_real_inputs = c10::nullopt) {
size_t num_func_inputs = num_real_inputs.value_or(trace_inputs.size());
auto enter_info = torch::jit::tracer::enter(std::move(trace_inputs));
torch::jit::tracer::getTracingState()->lookup_var_name_fn =
[var_name_lookup_fn](
const torch::autograd::Variable& var) -> std::string {
AutoGIL ag;
return py::cast<std::string>(var_name_lookup_fn(var));
};
torch::jit::tracer::getTracingState()->force_outplace = force_outplace;
try {
py::tuple py_inputs(num_func_inputs);
for (size_t i = 0; i < num_func_inputs; ++i) {
py_inputs[i] = py::cast(enter_info.second[i]);
}
auto out = func(*py_inputs);
if (out.ptr() == Py_None) {
AT_ERROR(
"The traced function didn't return any values! Side-effects are not "
"captured in traces, so it would be a no-op.");
}
torch::jit::tracer::exit({torch::jit::toIValue(out)});
auto graph = enter_info.first->graph;
EliminateDeadCode(graph);
LowerSimpleTuples(graph);

return graph;
} catch (...) {
torch::jit::tracer::abandon();
throw;
}
}

void InitGraphExecutorBindings(py::module m) {
py::class_<torch::jit::GraphExecutor>(m, "GraphExecutor", py::dynamic_attr())
.def(py::init([](py::function func, py::tuple inputs,
py::function var_name_lookup_fn, bool optimize,
bool _force_outplace) {
auto graph =
CreateGraphByTracing(func, torch::jit::toStack(inputs),
var_name_lookup_fn, _force_outplace);
return torch::jit::GraphExecutor(graph, optimize);
}),
py::arg("func"), py::arg("inputs"), py::arg("var_name_lookup_fn"),
py::arg("optimize") = true, py::arg("_force_outplace") = false)
.def(
py::init([](std::shared_ptr<torch::jit::Graph> graph, bool optimize) {
return torch::jit::GraphExecutor(std::move(graph), optimize);
}),
py::arg("graph"), py::arg("optimize") = true)
.def(
"graph_for",
[](torch::jit::GraphExecutor& ge, py::args args) {
return ge.graphFor(torch::jit::evilDeprecatedBadCreateStackDoNotUse(
args, ge.graph()->inputs()));
})
.def_property_readonly(
"graph", [](torch::jit::GraphExecutor& ge) { return ge.graph(); })
.def("get_debug_state",
[](torch::jit::GraphExecutor& ge) { return ge.getDebugState(); })
.def("__call__",
[](torch::jit::GraphExecutor& ge, py::args args) -> py::object {
const auto& graph = ge.graph();
auto stack = torch::jit::evilDeprecatedBadCreateStackDoNotUse(
args, graph->inputs());
{
AutoNoGIL no_gil_guard;
ge.run(stack);
}
return torch::jit::createPyObjectForStack(std::move(stack));
});
}

} // namespace

void InitXlaBindings(py::module m) {
InitXlaModuleBindings(m);
InitXlaPassesBindings(m);
InitXlaTensorBindings(m);
InitGraphExecutorBindings(m);
}

} // namespace torch_xla
Expand Down