Skip to content

Commit

Permalink
Use a helper function instead of a pass
Browse files Browse the repository at this point in the history
  • Loading branch information
KsenijaS committed Dec 17, 2020
1 parent 2763af5 commit 723b446
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 67 deletions.
62 changes: 2 additions & 60 deletions torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <aten/src/ATen/InitialTensorOptions.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/csrc/jit/passes/onnx/peephole.h>

namespace torch {
Expand All @@ -13,38 +14,6 @@ using namespace ::c10::onnx;

namespace {
const int ONNX_OPSET_13 = 13;
const int ONNX_TYPE_BOOL = 9;

// Creates (but does not insert) an onnx::Cast node that casts `val` to the
// ONNX bool tensor type. The caller is responsible for inserting the
// returned node at the right position in the graph.
Node* CreateCastToBoolNode(Value* val, Graph* graph) {
Node* cast_node = graph->create(onnx::Cast);
cast_node->addInput(val);
// ONNX_TYPE_BOOL (9) is the ONNX TensorProto data-type code for bool.
cast_node->i_(attr::to, ONNX_TYPE_BOOL);
cast_node->output()->setType(BoolType::get());
return cast_node;
}

// Inserts a Cast-to-bool between `cond_val` and `consumer_node`, and
// returns the newly created Cast node.
Node* InsertCastForCond(Value* cond_val, Graph* graph, Node* consumer_node) {
// prev: cond_val -> consumer_node
// after: cond_val -> cast -> consumer_node
// NOTE: The cast is required because operators like PyTorch Greater/Less
// return tensor in type torch.uint8. However the type for condition
// input in ONNX Loop must be bool.
Node* cast_node = CreateCastToBoolNode(cond_val, graph);
cast_node->insertBefore(consumer_node);

consumer_node->replaceInputWith(cond_val, cast_node->output());
return cast_node;
}

// Returns true when `cond_val` is not already statically typed as bool
// (either a bool-scalar tensor or a BoolType value), i.e. when a Cast
// must be inserted before it can serve as an ONNX Loop condition.
bool IsCondCastRequired(Value* cond_val) {
const auto& type = cond_val->type();
if (auto tt = type->cast<TensorType>()) {
if (auto scalar_type = tt->scalarType()) {
return *scalar_type != c10::kBool;
}
}
// Non-tensor (or tensor with unknown dtype): fall back to the static type.
return !type->isSubtypeOf(BoolType::get());
}

bool IsErasableSequence(const Node* loop_node, size_t i) {
TORCH_INTERNAL_ASSERT(loop_node->blocks().size() == 1);
Expand Down Expand Up @@ -192,36 +161,9 @@ void ConvertSequenceDependencies(Block* block, int opset_version) {
}
} // anonymous namespace

// Normalizes the inputs of an onnx::Loop node to satisfy the ONNX spec:
// casts the initial condition (outside the loop) and the per-iteration
// condition (inside the body) to bool where needed, adds the mandatory
// condition input to the body block, and types the trip-count input as an
// int tensor. No-op for any node that is not an onnx::Loop.
void FixupONNXLoopNodeInputs(Node* node) {
if (node->kind() != ::c10::onnx::Loop) {
return;
}

auto* graph = node->owningGraph();

// add cast to condition input outside the loop.
Value* cond_val = node->inputs()[1];
if (IsCondCastRequired(cond_val))
InsertCastForCond(cond_val, graph, node);

// Setup Loop input cond and i.
TORCH_INTERNAL_ASSERT(node->blocks().size() == 1);
auto* sub_block = node->blocks()[0];
Value* cond = sub_block->insertInput(1, "cond");
cond->setType(BoolType::create());

Value* i = sub_block->inputs()[0];
i->setType(TensorType::fromNumberType(IntType::get()));

// add cast to condition input inside the loop.
// The body's first output is the next-iteration condition, consumed by the
// block's return node.
Value* next_cond_val = sub_block->outputs()[0];
if (IsCondCastRequired(next_cond_val))
InsertCastForCond(next_cond_val, graph, sub_block->return_node());
}

std::vector<Value*> FixupONNXLoopNode(Node* node, int opset_version) {
auto output_size = node->outputs().size();
FixupONNXLoopNodeInputs(node);
updateLoopNodeInputs(node);
auto new_outputs = ConvertSequenceDependencies(node, opset_version);
TORCH_INTERNAL_ASSERT(output_size == new_outputs.size());
return new_outputs;
Expand Down
1 change: 0 additions & 1 deletion torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
namespace torch {
namespace jit {

void FixupONNXLoopNodeInputs(Node* node);
std::vector<Value*> FixupONNXControlflowNode(Node* n, int opset_version);

} // namespace jit
Expand Down
58 changes: 58 additions & 0 deletions torch/csrc/jit/passes/onnx/helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,5 +96,63 @@ Value* addInputToBlock(Block* block) {
return block->addInput();
}

// Builds (without inserting) an onnx::Cast node that converts `val` to the
// ONNX bool tensor type. The caller must insert the returned node into the
// graph at the desired position.
Node* CreateCastToBoolNode(Value* val, Graph* graph) {
  auto* bool_cast = graph->create(onnx::Cast);
  bool_cast->addInput(val);
  // ONNX_TYPE_BOOL (9) is the ONNX TensorProto data-type code for bool.
  bool_cast->i_(attr::to, ONNX_TYPE_BOOL);
  bool_cast->output()->setType(BoolType::get());
  return bool_cast;
}

// Reroutes `cond_val` through a Cast-to-bool before it reaches
// `consumer_node`:
//   before: cond_val -> consumer_node
//   after:  cond_val -> Cast -> consumer_node
// The cast is needed because PyTorch comparison ops (Greater/Less, ...)
// produce torch.uint8 tensors, while ONNX Loop requires a bool condition.
// Returns the inserted Cast node.
Node* InsertCastForCond(Value* cond_val, Graph* graph, Node* consumer_node) {
  auto* bool_cast = CreateCastToBoolNode(cond_val, graph);
  bool_cast->insertBefore(consumer_node);
  consumer_node->replaceInputWith(cond_val, bool_cast->output());
  return bool_cast;
}

// Returns true when `cond_val` is not already statically known to be bool
// (either a bool-scalar tensor or a BoolType value), i.e. when a Cast must
// be inserted before it can serve as an ONNX Loop condition.
bool IsCondCastRequired(Value* cond_val) {
  const auto& type = cond_val->type();
  auto tensor_type = type->cast<TensorType>();
  if (tensor_type) {
    if (const auto scalar_type = tensor_type->scalarType()) {
      return c10::kBool != *scalar_type;
    }
  }
  // Non-tensor (or tensor with unknown dtype): decide from the static type.
  return !type->isSubtypeOf(BoolType::get());
}

// Normalizes an onnx::Loop node's inputs to satisfy the ONNX spec:
// casts the initial condition (outside the loop) and the per-iteration
// condition (produced inside the body) to bool where needed, adds the
// mandatory condition input to the body block, and types the trip-count
// block input as an int tensor. No-op for any node that is not a Loop.
void updateLoopNodeInputs(Node* node) {
  if (node->kind() != ::c10::onnx::Loop) {
    return;
  }

  auto* graph = node->owningGraph();

  // Cast the initial condition (second Loop input), outside the loop.
  Value* cond_val = node->inputs()[1];
  if (IsCondCastRequired(cond_val)) {
    InsertCastForCond(cond_val, graph, node);
  }

  // The Loop body block carries (i, cond, ...) as its inputs.
  TORCH_INTERNAL_ASSERT(node->blocks().size() == 1);
  auto* sub_block = node->blocks()[0];
  sub_block->insertInput(1, "cond")->setType(BoolType::create());
  sub_block->inputs()[0]->setType(TensorType::fromNumberType(IntType::get()));

  // Cast the next-iteration condition (first block output), inside the loop.
  Value* next_cond_val = sub_block->outputs()[0];
  if (IsCondCastRequired(next_cond_val)) {
    InsertCastForCond(next_cond_val, graph, sub_block->return_node());
  }
}

} // namespace jit
} // namespace torch
6 changes: 6 additions & 0 deletions torch/csrc/jit/passes/onnx/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ static const int OPSET_VERSION_10 = 10;
static const int OPSET_VERSION_11 = 11;
static const int OPSET_VERSION_12 = 12;

static const int ONNX_TYPE_BOOL = 9;

using ValueToParamPairMap = std::map<Value*, std::pair<std::string, IValue>>;

using ParamMap = std::map<std::string, IValue>;
Expand All @@ -32,6 +34,10 @@ Node* addNodeToBlock(Block* block, Symbol kind, ArrayRef<Value*> inputs);

Value* addInputToBlock(Block* block);

void updateLoopNodeInputs(Node* node);

Node* CreateCastToBoolNode(Value* val, Graph* graph);

TORCH_API c10::optional<at::ScalarType> ONNXTypeToATenType(int32_t onnx_type);
} // namespace jit
} // namespace torch
1 change: 0 additions & 1 deletion torch/csrc/jit/python/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,6 @@ void initJITBindings(PyObject* module) {
})
.def("_jit_pass_onnx_block", BlockToONNX)
.def("_jit_pass_fixup_onnx_controlflow_node", FixupONNXControlflowNode)
.def("_jit_pass_fixup_onnx_loop_node_inputs", FixupONNXLoopNodeInputs)
.def("_jit_pass_canonicalize_graph_fuser_ops", CanonicalizeOps)
.def("_jit_pass_decompose_ops", DecomposeOps)
.def("_jit_pass_specialize_autogradzero", specializeAutogradZero)
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/jit/python/python_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,7 @@ void initPythonIRBindings(PyObject* module_) {
"Find all nodes",
py::arg("kind"),
py::arg("recurse") = true)
.def("updateLoopInputs", [](Node& n) { return updateLoopNodeInputs(&n); })
.def("input", [](Node& n) { return n.input(); })
.def("output", [](Node& n) { return n.output(); })
.NS(addInput)
Expand Down
5 changes: 2 additions & 3 deletions torch/onnx/symbolic_opset11.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch.onnx.symbolic_helper import parse_args, _unimplemented, _is_tensor_list
from torch.onnx.symbolic_opset9 import expand, unused
from torch.nn.modules.utils import _single, _pair, _triple
from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block
from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block, _update_loop_inputs

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
Expand Down Expand Up @@ -943,8 +943,7 @@ def embedding_bag(g,
_add_output_to_block(loop_block, loop_condition)
_add_output_to_block(loop_block, embeddings)
# This pass does all required type casting for loop inputs (condition and iter)
torch._C._jit_pass_fixup_onnx_loop_node_inputs(loop.node())

_update_loop_inputs(loop.node())
# aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
# But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
return loop.node().output(), None, None, None
Expand Down
4 changes: 2 additions & 2 deletions torch/onnx/symbolic_opset12.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented
from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block
from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block, _update_loop_inputs


# EDITING THIS FILE? READ THIS FIRST!
Expand Down Expand Up @@ -138,7 +138,7 @@ def unfold(g, input, dimension, size, step):

_add_output_to_block(loop_block, loop_condition)
_add_output_to_block(loop_block, concat)
torch._C._jit_pass_fixup_onnx_loop_node_inputs(loop.node())
_update_loop_inputs(loop.node())

loop_output = loop.node().output()
perm = [0, 1, 2, 3, 4]
Expand Down
2 changes: 2 additions & 0 deletions torch/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,8 @@ def _add_output_to_block(block, value):
new_output = block.registerOutput(value)
return new_output

def _update_loop_inputs(node):
    # Thin wrapper over the C++ ``Node.updateLoopInputs`` binding, which
    # fixes up an ONNX Loop node's inputs (casts conditions to bool and
    # types the trip-count input). No-op for non-Loop nodes.
    node.updateLoopInputs()

# Note [Export inplace]
# ~~~~~~~~~~~~~~~~~~~~~
Expand Down

0 comments on commit 723b446

Please sign in to comment.