[ONNX] Update Reducesum operator for opset 13 #50532

Merged
merged 20 commits into from
Jan 20, 2021
Changes from 14 commits
126 changes: 9 additions & 117 deletions test/onnx/test_pytorch_onnx_onnxruntime.py

Large diffs are not rendered by default.
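For context: in ONNX opset 13, ReduceSum's `axes` moved from a node attribute to an optional tensor input, and a `noop_with_empty_axes` attribute was added. The changes in this PR follow that attribute-to-input migration. A minimal sketch of the two node forms, written with the onnx Python helper API (tensor and value names here are illustrative, not taken from this PR):

    from onnx import helper, TensorProto

    # Opset <= 12: `axes` is an attribute on the ReduceSum node.
    reduce_12 = helper.make_node("ReduceSum", inputs=["x"], outputs=["y"], axes=[0])

    # Opset >= 13: `axes` is an optional second input holding a 1-D int64 tensor
    # (supplied as an initializer or the output of a Constant node).
    axes = helper.make_tensor("axes", TensorProto.INT64, dims=[1], vals=[0])
    reduce_13 = helper.make_node("ReduceSum", inputs=["x", "axes"], outputs=["y"])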

4 changes: 3 additions & 1 deletion test/onnx/test_utility_funs.py
@@ -5,7 +5,7 @@
from torch.onnx import utils, OperatorExportTypes, TrainingMode
from torch.onnx.symbolic_helper import _set_opset_version, _set_operator_export_type
import torch.utils.cpp_extension
-from test_pytorch_common import skipIfUnsupportedMinOpsetVersion
+from test_pytorch_common import skipIfUnsupportedMinOpsetVersion, skipIfUnsupportedOpsetVersion
import caffe2.python.onnx.backend as backend
from verify import verify

@@ -618,6 +618,8 @@ def forward(self, x):
assert next(iter).kind() == "aten::quantize_per_tensor"
assert next(iter).kind() == "aten::dequantize"

# prim::ListConstruct is exported as onnx::SequenceConstruct for opset >= 11
@skipIfUnsupportedOpsetVersion([11, 12])
def test_prim_fallthrough(self):
# Test prim op
class PrimModule(torch.jit.ScriptModule):
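The added skip follows from the comment above it: at opset >= 11 the exporter lowers a Tensor-list prim::ListConstruct to ONNX sequence ops instead of letting the prim op fall through. A rough sketch of the two lowerings, again with the onnx helper API (names illustrative):

    from onnx import helper

    # Non-empty Tensor list -> onnx::SequenceConstruct (opset >= 11).
    seq = helper.make_node("SequenceConstruct", inputs=["a", "b"], outputs=["s"])

    # Empty Tensor list -> onnx::SequenceEmpty.
    empty = helper.make_node("SequenceEmpty", inputs=[], outputs=["s"])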
23 changes: 23 additions & 0 deletions torch/csrc/jit/passes/onnx/helper.cpp
@@ -97,5 +97,28 @@ Value* addInputToBlock(Block* block) {
return block->addInput();
}

Node* createONNXUnsqueeze(
    Graph* graph,
    Node* n_to_insert_before,
    Value* input,
    int axis,
    int opset_version) {
  Node* unsqueeze_node = graph->create(onnx::Unsqueeze, 1);
  unsqueeze_node->addInput(input);
  unsqueeze_node->insertBefore(n_to_insert_before);
  if (opset_version >= OPSET_VERSION_13) {
    // ONNX spec sets `axes` as input for opset >= 13.
    Node* unsqueeze_axes = graph->create(onnx::Constant, 1);
    unsqueeze_axes->insertBefore(unsqueeze_node);
    unsqueeze_axes->t_(
        attr::value, at::unsqueeze(at::scalar_to_tensor(at::Scalar(axis)), 0));
    unsqueeze_node->addInput(unsqueeze_axes->output());
  } else {
    // ONNX spec sets `axes` as attribute for opset < 13.
    unsqueeze_node->is_(attr::axes, {axis});
  }
  return unsqueeze_node;
}

} // namespace jit
} // namespace torch
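createONNXUnsqueeze encapsulates the same attribute-to-input migration for Unsqueeze that opset 13 applies to ReduceSum: below opset 13 `axes` is an attribute, and from opset 13 on it is a second input fed by a Constant node, exactly as the two branches above do. A sketch of the two node shapes in the onnx helper API (names illustrative):

    from onnx import helper, TensorProto

    # Opset < 13: `axes` is an attribute.
    unsq_12 = helper.make_node("Unsqueeze", inputs=["x"], outputs=["y"], axes=[0])

    # Opset >= 13: a Constant node supplies `axes` as a second input.
    axes_val = helper.make_tensor("axes_val", TensorProto.INT64, dims=[1], vals=[0])
    axes_const = helper.make_node(
        "Constant", inputs=[], outputs=["axes"], value=axes_val)
    unsq_13 = helper.make_node("Unsqueeze", inputs=["x", "axes"], outputs=["y"])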
9 changes: 9 additions & 0 deletions torch/csrc/jit/passes/onnx/helper.h
@@ -13,6 +13,7 @@ static const int OPSET_VERSION_9 = 9;
static const int OPSET_VERSION_10 = 10;
static const int OPSET_VERSION_11 = 11;
static const int OPSET_VERSION_12 = 12;
static const int OPSET_VERSION_13 = 13;

using ValueToParamPairMap = std::map<Value*, std::pair<std::string, IValue>>;

@@ -33,5 +34,13 @@ Node* addNodeToBlock(Block* block, Symbol kind, ArrayRef<Value*> inputs);
Value* addInputToBlock(Block* block);

TORCH_API c10::optional<at::ScalarType> ONNXTypeToATenType(int32_t onnx_type);

Node* createONNXUnsqueeze(
    Graph* graph,
    Node* n_to_insert_before,
    Value* input,
    int axis,
    int opset_version);

} // namespace jit
} // namespace torch
137 changes: 65 additions & 72 deletions torch/csrc/jit/passes/onnx/peephole.cpp
@@ -416,10 +416,8 @@ void fixDefaultRNNState(
batch_size->addInput(shape_of_input->outputs()[0]);
batch_size->addInput(gather_indices->outputs()[0]);

-  Node* unsqueezed_batch_size = graph->create(onnx::Unsqueeze, 1);
-  unsqueezed_batch_size->insertBefore(n);
-  unsqueezed_batch_size->addInput(batch_size->outputs()[0]);
-  unsqueezed_batch_size->is_(attr::axes, {0});
+  Node* unsqueezed_batch_size =
+      createONNXUnsqueeze(graph, n, batch_size->outputs()[0], 0, opset_version);

Node* hidden_size = graph->create(onnx::Constant, 1);
hidden_size->insertBefore(n);
@@ -440,10 +438,8 @@
? 2
: 1)));

-  Node* unsqueezed_num_directions = graph->create(onnx::Unsqueeze, 1);
-  unsqueezed_num_directions->insertBefore(n);
-  unsqueezed_num_directions->addInput(num_directions->outputs()[0]);
-  unsqueezed_num_directions->is_(attr::axes, {0});
+  Node* unsqueezed_num_directions = createONNXUnsqueeze(
+      graph, n, num_directions->outputs()[0], 0, opset_version);

Node* concated_dims = graph->create(onnx::Concat, 1);
concated_dims->insertBefore(n);
@@ -555,6 +551,65 @@ static void replaceInputWithList(Node* node, size_t i, ArrayRef<Value*> to) {
}
}

static void eraseListConstruct(Block* block, int opset_version);

static void eraseListConstruct(Node* n, int opset_version) {
  for (auto b : n->blocks()) {
    eraseListConstruct(b, opset_version);
  }
  std::vector<std::tuple<size_t, std::vector<Value*>>> replacements;

  auto block = n->owningBlock();
  size_t i = 0;
  for (auto* input : n->inputs()) {
    if (input->node()->kind() == prim::ListConstruct) {
      auto* lc_node = input->node();
      TypePtr elem =
          lc_node->output()->type()->cast<ListType>()->getElementType();
      if (elem->cast<IntType>()) {
        // ListConstruct Int[] output case: transform to ONNX Concat so that
        // the output is a single (dynamic) tensor that downstream ops can
        // consume as an input.
        std::vector<Value*> unsqueezed;
        Graph* g = block->owningGraph();
        for (auto* input : lc_node->inputs()) {
          Node* unsqueezed_node =
              createONNXUnsqueeze(g, lc_node, input, 0, opset_version);
          unsqueezed.emplace_back(unsqueezed_node->output());
        }
        Node* concat_node = g->create(onnx::Concat, 1);
        concat_node->i_(attr::axis, 0);
        for (auto v : unsqueezed) {
          concat_node->addInput(v);
        }
        concat_node->insertBefore(lc_node);

        // Make the Concat output the new input; the ListConstruct then
        // becomes dead.
        replacements.emplace_back(
            i, std::vector<Value*>({concat_node->output()}));
      } else {
        if (opset_version >= OPSET_VERSION_11) {
          c10::Symbol seq_node_kind = lc_node->inputs().size() > 0
              ? onnx::SequenceConstruct
              : onnx::SequenceEmpty;
          Node* seq_node = block->owningGraph()->create(
              seq_node_kind, {lc_node->inputs()}, 1);
          seq_node->insertBefore(lc_node);
          seq_node->output()->copyMetadata(lc_node->output());
          lc_node->replaceAllUsesWith(seq_node);
        }
      }
    }
    i++;
  }

  for (auto ritr = replacements.rbegin(); ritr != replacements.rend(); ++ritr) {
    replaceInputWithList(n, std::get<0>(*ritr), std::get<1>(*ritr));
  }
}

static void eraseListConstruct(Block* block, int opset_version) {
// TODO: Fix this pass/maybe get rid of this part.
// Tensor lists might be used for meshgrid and such ops as well.
@@ -563,71 +618,9 @@ static void eraseListConstruct(Block* block, int opset_version) {
Node* n = *it;
++it;

-    for (auto b : n->blocks()) {
-      eraseListConstruct(b, opset_version);
-    }
-    std::vector<std::tuple<size_t, std::vector<Value*>>> replacements;
-
-    size_t i = 0;
-    for (auto* input : n->inputs()) {
-      if (input->node()->kind() == prim::ListConstruct) {
-        auto* lc_node = input->node();
-        TypePtr elem =
-            lc_node->output()->type()->cast<ListType>()->getElementType();
-        if (elem->cast<IntType>()) {
-          // ListConstruct Int[] output case, we need to transform to ONNX
-          // Concat to ensure the output is a single tensor(dynamic) type in
-          // order to be consumed as inputs
-          std::vector<Value*> unsqueezed;
-          Graph* g = block->owningGraph();
-          for (auto* input : lc_node->inputs()) {
-            Node* unsqueezed_node = g->create(onnx::Unsqueeze, 1);
-            unsqueezed_node->insertBefore(lc_node);
-            unsqueezed_node->addInput(input);
-            unsqueezed_node->is_(attr::axes, {0});
-            unsqueezed.emplace_back(unsqueezed_node->output());
-          }
-          Node* concat_node = g->create(onnx::Concat, 1);
-          concat_node->i_(attr::axis, 0);
-          for (auto v : unsqueezed) {
-            concat_node->addInput(v);
-          }
-          concat_node->insertBefore(lc_node);
-
-          // make concat node output as new input, then ListConstruct should
-          // become dead
-          replacements.emplace_back(
-              i, std::vector<Value*>({concat_node->output()}));
-
-        } else {
-          if (opset_version < OPSET_VERSION_11) {
-            // Tensor lists are used mostly for inputs to cat/stack. They are
-            // already handled in those symbolics, and should become dead
-            // afterwards.
-            replacements.emplace_back(
-                i,
-                std::vector<Value*>(
-                    lc_node->inputs().begin(), lc_node->inputs().end()));
-          } else {
-            c10::Symbol seq_node_kind = lc_node->inputs().size() > 0
-                ? onnx::SequenceConstruct
-                : onnx::SequenceEmpty;
-            Node* seq_node = block->owningGraph()->create(
-                seq_node_kind, {lc_node->inputs()}, 1);
-            seq_node->insertBefore(lc_node);
-            seq_node->output()->copyMetadata(lc_node->output());
-            lc_node->replaceAllUsesWith(seq_node);
-          }
-        }
-      }
-      i++;
-    }
-
-    for (auto ritr = replacements.rbegin(); ritr != replacements.rend();
-         ++ritr) {
-      replaceInputWithList(n, std::get<0>(*ritr), std::get<1>(*ritr));
-    }
+    eraseListConstruct(n, opset_version);
  }
+  eraseListConstruct(block->return_node(), opset_version);
}

// For ops such as meshgrid where output is a list of Tensors
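In the Int[] branch above, each scalar element is unsqueezed to a 1-D tensor and the results are concatenated along axis 0, so downstream ops receive one dynamic tensor instead of a list. A small PyTorch analogue of the rewrite's effect, assuming each list element is a 0-D tensor:

    import torch

    def int_list_to_tensor(elems):
        # Unsqueeze each 0-D element to shape [1], then concatenate along
        # axis 0 -- the same result as the Unsqueeze + Concat pair above.
        return torch.cat([e.unsqueeze(0) for e in elems], dim=0)

    print(int_list_to_tensor([torch.tensor(2), torch.tensor(3)]))  # tensor([2, 3])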
79 changes: 44 additions & 35 deletions torch/csrc/jit/passes/onnx/shape_type_inference.cpp
@@ -200,7 +200,10 @@ bool IsSupportedNode(const Node* n) {
return true;
}

-Value* CloneValueFromListConstruct(Value* v, std::shared_ptr<Graph> n_graph) {
+Value* CloneValueFromListConstruct(
+    Value* v,
+    std::shared_ptr<Graph> n_graph,
+    int opset_version) {
auto lc_node = v->node();
TORCH_INTERNAL_ASSERT(lc_node->kind() == ::c10::prim::ListConstruct);
// In jit/passes/onnx/peephole.cpp::eraseListConstruct,
@@ -220,12 +223,10 @@ Value* CloneValueFromListConstruct(Value* v, std::shared_ptr<Graph> n_graph) {
// order to be consumed as inputs
std::vector<Value*> unsqueezed;
for (auto* input : lc_node->inputs()) {
-    Node* unsqueezed_node =
-        n_graph->insertNode(n_graph->create(::c10::onnx::Unsqueeze, 1));
     auto new_input = n_graph->addInput();
     new_input->copyMetadata(input);
-    unsqueezed_node->addInput(new_input);
-    unsqueezed_node->is_(attr::axes, {0});
+    Node* unsqueezed_node = createONNXUnsqueeze(
+        n_graph.get(), n_graph->return_node(), new_input, 0, opset_version);
unsqueezed.emplace_back(unsqueezed_node->output());
}
Node* concat_node =
@@ -257,34 +258,38 @@ Value* CloneValueFromListConstruct(Value* v, std::shared_ptr<Graph> n_graph) {
}

// Clone the node n for the new graph.
-Node* CloneNodeToGraph(Node* n, std::shared_ptr<Graph> n_graph) {
-  auto clone_node = n_graph->createClone(n, [&n_graph](Value* v) {
-    auto v_n = v->node();
-    switch (v_n->kind()) {
-      case ::c10::onnx::Constant: {
-        // Clone the input if it is constant.
-        auto constant_n = n_graph->insertNode(
-            n_graph->createClone(v_n, [](Value* v) { return v; }));
-        return constant_n->output();
-      }
-      case ::c10::prim::ListConstruct: {
-        return CloneValueFromListConstruct(v, n_graph);
-      }
-      case ::c10::prim::PackPadded: {
-        auto input = n_graph->addInput();
-        input->copyMetadata(v_n->input(0));
-        return input;
-      }
-      default: {
-        // If the input is not constant, we cannot depend on its value
-        // in shape inference. Set it to graph input in the new graph,
-        // and copy over metadata, such as datatype and shape.
-        auto input = n_graph->addInput();
-        input->copyMetadata(v);
-        return input;
-      }
-    }
-  });
+Node* CloneNodeToGraph(
+    Node* n,
+    std::shared_ptr<Graph> n_graph,
+    int opset_version) {
+  auto clone_node =
+      n_graph->createClone(n, [&n_graph, opset_version](Value* v) {
+        auto v_n = v->node();
+        switch (v_n->kind()) {
+          case ::c10::onnx::Constant: {
+            // Clone the input if it is constant.
+            auto constant_n = n_graph->insertNode(
+                n_graph->createClone(v_n, [](Value* v) { return v; }));
+            return constant_n->output();
+          }
+          case ::c10::prim::ListConstruct: {
+            return CloneValueFromListConstruct(v, n_graph, opset_version);
+          }
+          case ::c10::prim::PackPadded: {
+            auto input = n_graph->addInput();
+            input->copyMetadata(v_n->input(0));
+            return input;
+          }
+          default: {
+            // If the input is not constant, we cannot depend on its value
+            // in shape inference. Set it to graph input in the new graph,
+            // and copy over metadata, such as datatype and shape.
+            auto input = n_graph->addInput();
+            input->copyMetadata(v);
+            return input;
+          }
+        }
+      });
return clone_node;
}

@@ -453,7 +458,7 @@ void ONNXShapeTypeInference(Node* n, int opset_version) {
// Create a Graph containing only the single node n.
// This graph is later converted to ONNX to run shape inference.
auto n_graph = std::make_shared<Graph>();
-  auto clone_node = CloneNodeToGraph(n, n_graph);
+  auto clone_node = CloneNodeToGraph(n, n_graph, opset_version);
n_graph->insertNode(clone_node);

// Register all node outputs as graph outputs.
@@ -484,12 +489,16 @@
} catch (std::runtime_error& ex) {
// TODO: include this as warning once we have a more consolidated warning
// system.
-    GRAPH_DEBUG(
-        "ONNX shape inference fails with: ",
-        ex.what(),
-        " on graph: ",
-        n_graph->toString());
+    const char shape_err[] = "ShapeInferenceError";
+    const char type_err[] = "TypeInferenceError";
+    if ((strstr(ex.what(), shape_err) == NULL) &&
+        (strstr(ex.what(), type_err) == NULL))
+      throw;
+    GRAPH_DEBUG("ONNX shape inference fails with: ", ex.what());
}
GRAPH_DEBUG(
"ONNX graph after shape inference: ", prettyPrint(*model_proto));
Expand Down