1 change: 1 addition & 0 deletions build_variables.bzl
@@ -894,6 +894,7 @@ libtorch_python_core_sources = [
"torch/csrc/jit/passes/onnx/function_extraction.cpp",
"torch/csrc/jit/passes/onnx/onnx_log.cpp",
"torch/csrc/jit/python/pybind_utils.cpp",
"torch/csrc/jit/passes/onnx/pattern_conversion/autograd_function_process.cpp",
"torch/csrc/jit/passes/onnx/pattern_conversion/common.cpp",
"torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.cpp",
"torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.cpp",
28 changes: 27 additions & 1 deletion docs/source/onnx.rst
@@ -410,7 +410,7 @@ See the ``symbolic_opset*.py`` files for more examples.
torch.autograd.Functions
^^^^^^^^^^^^^^^^^^^^^^^^

If the operator is a sub-class of :class:`torch.autograd.Function`, there are two ways
If the operator is a sub-class of :class:`torch.autograd.Function`, there are three ways
to export it.

Static Symbolic Method
@@ -488,6 +488,32 @@ The example below shows how you can access ``requires_grad`` via the ``Node`` ob
from torch.onnx import register_custom_op_symbolic
register_custom_op_symbolic("prim::PythonOp", symbolic_python_op, 1)

Inline Autograd Function
~~~~~~~~~~~~~~~~~~~~~~~~
If a static symbolic method is not provided for an autograd.Function, and no custom symbolic
function is registered for ``prim::PythonOp``, ``torch.onnx.export`` tries to inline the graph that
corresponds to that autograd.Function, breaking the function down into the individual operators
that were used within it. The export succeeds as long as these individual operators are
supported. For example::

class MyLogExp(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor) -> torch.Tensor:
ctx.save_for_backward(input)
h = input.exp()
return h.log().log()

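For illustration, a minimal sketch of how this function might be wrapped and exported; the
wrapper module, input shape, and file name below are assumptions rather than part of the
original example::

    class MyModule(torch.nn.Module):
        def forward(self, x):
            return MyLogExp.apply(x)

    # The default operator_export_type (ONNX) triggers inlining of the
    # autograd.Function, since no symbolic method is defined for it.
    torch.onnx.export(MyModule(), (torch.randn(1),), "my_log_exp.onnx")
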
Although no static symbolic method is defined for this function, it is exported as follows::

graph(%input : Float(1, strides=[1], requires_grad=0, device=cpu)):
%1 : float = onnx::Exp[](%input)
%2 : float = onnx::Log[](%1)
%3 : float = onnx::Log[](%2)
return (%3)

To avoid inlining of autograd.Functions, the model should be exported with
``operator_export_type`` set to ``ONNX_FALLTHROUGH`` or ``ONNX_ATEN_FALLBACK``, as shown below.
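
For illustration, a minimal sketch of such an export, reusing the hypothetical ``MyModule``
wrapper from the example above::

    torch.onnx.export(
        MyModule(),
        (torch.randn(1),),
        "my_log_exp.onnx",
        operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
    )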

Custom operators
^^^^^^^^^^^^^^^^

181 changes: 181 additions & 0 deletions test/onnx/test_autograd_funs.py
@@ -0,0 +1,181 @@
# Owner(s): ["module: onnx"]

import unittest

import torch

from onnx_test_common import run_model_test
from torch.onnx import OperatorExportTypes
from torch.onnx._globals import GLOBALS
from torch.onnx.utils import _model_to_graph


class TestAutogradFuns(unittest.TestCase):
opset_version = GLOBALS.export_onnx_opset_version
keep_initializers_as_inputs = False
onnx_shape_inference = True

def test_single_output(self):
class SingleOut(torch.autograd.Function):
@staticmethod
def forward(ctx, i):
result = i.exp()
result = result.log()
ctx.save_for_backward(result)
return result

@staticmethod
def backward(ctx, grad_output):
(result,) = ctx.saved_tensors
return grad_output * result

class Caller(torch.nn.Module):
def forward(self, input):
result = input + 5
return SingleOut.apply(result) + 3

model = Caller()
input = torch.ones(1)
run_model_test(self, model, input_args=(input,))

def test_multi_output(self):
class MultiOut(torch.autograd.Function):
@staticmethod
def forward(ctx, i):
result_exp = i.exp()
result_log = result_exp.log()
ctx.save_for_backward(result_exp, result_log)
return result_exp, result_log

@staticmethod
def backward(ctx, grad_output):
(result,) = ctx.saved_tensors
return grad_output * result

class Caller(torch.nn.Module):
def forward(self, input):
return MultiOut.apply(input)

model = Caller()
input = torch.ones(1, 5)
run_model_test(self, model, input_args=(input,))

def test_partial_output(self):
class PartialOut(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
values, indices = torch.topk(input, 3)
return values

class Caller(torch.nn.Module):
def forward(self, input):
return PartialOut.apply(input)

model = Caller()
input = torch.ones(1, 5)
run_model_test(self, model, input_args=(input,))

def test_nested_autograd(self):
class Child(torch.autograd.Function):
@staticmethod
def forward(ctx, i):
result = i.log()
result_log = result.log()
ctx.save_for_backward(result_log)
return result_log

@staticmethod
def backward(ctx, grad_output):
(result,) = ctx.saved_tensors
return grad_output * result

class Parent(torch.autograd.Function):
@staticmethod
def forward(ctx, i):
result_exp = i.exp()
result_log = Child.apply(result_exp)
ctx.save_for_backward(result_exp, result_log)
return result_exp, result_log

@staticmethod
def backward(ctx, grad_output):
(result,) = ctx.saved_tensors
return grad_output * result

class Caller(torch.nn.Module):
def forward(self, input):
return Parent.apply(input)

model = Caller()
input = torch.ones(1, 5)
run_model_test(self, model, input_args=(input,))

    # Run export in ONNX_FALLTHROUGH and ONNX_ATEN_FALLBACK modes,
    # as torch.special.erf() is not supported.
def test_aten_unsupported(self):
class Erf(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
erf_out = torch.special.erf(x)
ctx.save_for_backward(erf_out)
return erf_out

@staticmethod
def backward(ctx, grad_output):
                (result,) = ctx.saved_tensors
return torch.special.erfinv(result), None

class Caller(torch.nn.Module):
def forward(self, input):
return Erf.apply(input)

model = Caller()
input = torch.ones(1, 5)

        # Test ONNX_FALLTHROUGH mode
graph, _, _ = _model_to_graph(
model,
(input,),
operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
)
iter = graph.nodes()
self.assertEqual(next(iter).kind(), "prim::PythonOp")

        # Test ONNX_ATEN_FALLBACK mode
graph, _, _ = _model_to_graph(
model,
(input,),
operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
)
iter = graph.nodes()
self.assertEqual(next(iter).kind(), "prim::PythonOp")

def test_inline_and_symbolic(self):
class Exp(torch.autograd.Function):
@staticmethod
def forward(ctx, i):
                ctx.save_for_backward(i)
return i.exp()

@staticmethod
def symbolic(g, input):
return g.op("Exp", input)

class LogLog(torch.autograd.Function):
@staticmethod
def forward(ctx, i):
                ctx.save_for_backward(i)
return i.log().log()

class Caller(torch.nn.Module):
def forward(self, input):
exp_result = Exp.apply(input)
return LogLog.apply(exp_result)

model = Caller()
input = torch.ones(1)
run_model_test(self, model, input_args=(input,))


if __name__ == "__main__":
unittest.main()
6 changes: 5 additions & 1 deletion test/onnx/test_utility_funs.py
@@ -1229,7 +1229,11 @@ def forward(self, input):
batch = torch.FloatTensor(1, 3)

graph, _, _ = self._model_to_graph(
model, batch, input_names=["batch"], dynamic_axes={"batch": [0, 1]}
model,
batch,
operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
input_names=["batch"],
dynamic_axes={"batch": [0, 1]},
)
iter = graph.nodes()
autograd1 = next(iter)
1 change: 1 addition & 0 deletions torch/_C/__init__.pyi.in
@@ -335,6 +335,7 @@ def _jit_pass_onnx_remove_inplace_ops_for_onnx(graph: Graph, module: Module) ->
def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ...
def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ...
def _jit_pass_peephole(graph: Graph, disable_shape_peepholes: _bool = False) -> None: ...
def _jit_pass_onnx_autograd_function_process(graph: Graph) -> None: ...
def _jit_pass_fuse_addmm(graph: Graph) -> None: ...
def _jit_pass_onnx_preprocess(graph: Graph) -> None: ...
def _jit_pass_prepare_division_for_onnx(graph: Graph) -> None: ...
55 changes: 55 additions & 0 deletions torch/csrc/autograd/python_function.cpp
@@ -586,6 +586,52 @@ std::pair<UnpackedInput, InputFlags> unpack_input(PyObject* args) {
return std::make_pair(std::move(unpacked), std::move(flags));
}

// Given a prim::PythonOp node, _append_subgraph creates a subgraph such that:
// (1) It has the same inputs as the prim::PythonOp node.
// (2) The intermediate nodes used in the PythonOp are cloned and stored in
//     the subgraph.
// (3) trace_outputs stores the Value* objects before a new trace value is
//     assigned by the prim::PythonOp node, and helps to eventually route the
//     outputs of the subgraph correctly.
// The newly created subgraph is then added to the prim::PythonOp node as a
// subgraph attribute.
static void _append_subgraph(
torch::jit::Node* node,
torch::jit::Graph* graph,
std::vector<torch::jit::Value*> trace_outputs,
bool unpack_output) {
node->g_(
torch::jit::attr::Subgraph,
std::make_shared<torch::jit::Graph>(graph->current_scope()));
auto subgraph = node->g(torch::jit::attr::Subgraph);

std::unordered_map<Value*, Value*> value_map;
auto value_map_func = [&](Value* v) { return value_map.at(v); };
for (size_t i = 0; i < node->inputs().size(); ++i) {
auto subgraph_input = subgraph->addInput();
subgraph_input->copyMetadata(node->inputs().at(i));
value_map[node->inputs().at(i)] = subgraph_input;
}
  // Find the node's position in the graph; all subsequent nodes are added to
  // the subgraph.
auto it = std::find(graph->nodes().begin(), graph->nodes().end(), node);
// Skip TupleUnpack node if created
if (!unpack_output) {
it++;
}
for (it++; it != graph->nodes().end(); ++it) {
torch::jit::Node* node = *it;
auto* clone_node =
subgraph->insertNode(subgraph->createClone(node, value_map_func));
for (size_t i = 0; i < node->outputs().size(); ++i) {
value_map[node->outputs()[i]] = clone_node->outputs()[i];
auto trace_it = std::find(
trace_outputs.begin(), trace_outputs.end(), node->outputs()[i]);
if (trace_it != trace_outputs.end()) {
subgraph->registerOutput(clone_node->outputs()[i]);
}
}
}
}

static torch::jit::Node* _trace_pre_record(
PyObject* op_obj,
PyObject* input_objects,
Expand Down Expand Up @@ -650,17 +696,26 @@ static void _trace_post_record(
auto unpacked = graph->createTupleUnpack(node->output())->insertAfter(node);
node = unpacked;
}

std::vector<torch::jit::Value*> trace_outputs;
for (const auto i : c10::irange(num_outputs)) {
PyObject* obj = PyTuple_GET_ITEM(output_objects, i);
if (THPVariable_Check(obj)) {
Value* value = node->outputs()[i];
const auto& tensor = THPVariable_Unpack(obj);
if (tensor.defined()) {
value->inferTypeFrom(tensor);
trace_outputs.push_back(jit::tracer::getValueTrace(tensor));
jit::tracer::setValueTrace(tensor, value);
}
}
}
py::bool_ is_in_onnx_export =
py::module::import("torch.onnx.__init__").attr("is_in_onnx_export");
if (py::cast<bool>(is_in_onnx_export)) {
_append_subgraph(old_node, graph, trace_outputs, unpack_output);
}

// If TupleUnpack operator is created, we copy its output type back
// to the original tuple type.
if (!unpack_output) {
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/jit/ir/ir.cpp
@@ -318,6 +318,7 @@ std::ostream& Node::print(
if (kind() == prim::PythonOp) {
auto* pyOp = static_cast<const ::torch::jit::PythonOp*>(this);
out << "^" << pyOp->name();
printAttributes(out, /*ignore_subgraph=*/false);
pyOp->writeScalars(out);
} else if (hasAttribute(attr::Subgraph) && groups) {
out << kind().toQualString() << "_" << groups->size();
38 changes: 36 additions & 2 deletions torch/csrc/jit/passes/onnx.cpp
@@ -365,6 +365,21 @@ void NodeToONNX(
}
};

// Inline the prim::PythonOp sub-block nodes and append them to the onnx graph
auto inlineAutograd = [&](Node* PythonOpNode) {
for (auto subblock : PythonOpNode->blocks()) {
for (const auto i : c10::irange(PythonOpNode->inputs().size())) {
env[subblock->inputs()[i]] = env[PythonOpNode->inputs()[i]];
}
for (auto* node : subblock->nodes()) {
NodeToONNX(node, new_block, operator_export_type, env);
}
for (const auto i : c10::irange(PythonOpNode->outputs().size())) {
env[PythonOpNode->outputs()[i]] = env[subblock->outputs()[i]];
}
}
};

// Cast output of symbolic() python implementation
auto processSymbolicOutput = [&](const std::string& op_name,
Node* n,
@@ -441,11 +456,30 @@
"PythonOp", "prim", opset_version);
if (!py::hasattr(pyobj, "symbolic") &&
(!PyObject_IsTrue(is_registered_op.ptr()))) {
// Simply clone the node, unless either
      // Inline the subgraph within the prim::PythonOp unless either of these
      // conditions is satisfied:
      // 1. The torch.autograd.Function class of this node object has a
      //    `symbolic` method defined.
      // 2. A custom symbolic function is registered for prim::PythonOp.
cloneNode(op);
if (operator_export_type == ::torch::onnx::OperatorExportTypes::ONNX) {
try {
inlineAutograd(op);
} catch (const std::exception& ex) {
TORCH_WARN(
"Unable to inline PythonOp: ",
op->name(),
" due to the following exception\n",
ex.what(),
"prim::PythonOp will be exported as is and without being inlined\n",
"Try exporting with the following alternatives: \n",
"1) Set operator_export_type to ONNX_FALLTHROUGH mode\n",
"2) Register a symbolic method for the prim::PythonOp ",
op->name());
cloneNode(op);
}
} else {
cloneNode(op);
}
return;
}
