[ONNX] Support tuples in ScriptModule inputs/outputs #20784

Closed
wants to merge 6 commits
23 changes: 23 additions & 0 deletions test/onnx/test_pytorch_onnx_caffe2.py
@@ -1651,6 +1651,29 @@ def forward(self, lstm_in):

        self.run_model_test(MyModel(), train=False, input=lstm_in, batch_size=3, use_gpu=False)

    def test_tuple_input_output(self):
        class TupleModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a):
                # type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]
                return a

        x = (torch.randn(3, 4), torch.randn(4, 3))
        self.run_model_test(TupleModel(), train=False, input=(x,), batch_size=BATCH_SIZE,
                            example_outputs=(x,))

    def test_nested_tuple_input_output(self):
        class NestedTupleModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a, b):
                # type: (Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]) -> Tensor
                return a + b[0] + b[1][0] + b[1][1]

        x = torch.randn(4, 5)
        y = (torch.randn(4, 5), (torch.randn(4, 5), torch.randn(4, 5)))
        self.run_model_test(NestedTupleModel(), train=False, input=(x, y), batch_size=BATCH_SIZE,
                            example_outputs=x + y[0] + y[1][0] + y[1][1])

    def test_topk(self):
        class TopKModel(torch.nn.Module):
            def forward(self, input):
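For reference, a minimal usage sketch of the feature these tests exercise: exporting a ScriptModule with tuple inputs/outputs directly to ONNX. It assumes torch.onnx.export accepts the example_outputs keyword for ScriptModules, as it does in the PyTorch version this PR targets; it is an illustration, not part of the patch.

import io
import torch

class TupleModel(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, a):
        # type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]
        return a

x = (torch.randn(3, 4), torch.randn(4, 3))
f = io.BytesIO()
# example_outputs lets the exporter assign complete output shapes
# without tracing the module
torch.onnx.export(TupleModel(), (x,), f, example_outputs=(x,))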
98 changes: 71 additions & 27 deletions torch/csrc/jit/script/init.cpp
@@ -178,11 +178,58 @@ static Self moduleSelf(
  };
}

static void setInputTensorTypes(Graph& g, const Stack& stack) {
  AT_ASSERT(stack.size() == g.inputs().size());
  for (size_t i = 0; i < stack.size(); ++i) {
    g.inputs().at(i)->setType(
        DimensionedTensorType::create(stack.at(i).toTensor()));
static TypePtr getTensorType(
    const at::Tensor& t,
    const TypeKind type_kind) {
  switch (type_kind) {
    case TypeKind::DimensionedTensorType:
      return DimensionedTensorType::create(t);
    case TypeKind::CompleteTensorType: {
      auto scalar_type = t.scalar_type();
      auto sizes = t.sizes();
      return CompleteTensorType::create(scalar_type, at::kCPU, sizes);
    }
    default:
      throw std::runtime_error(
          "Attempted to call getTensorType for type kind other than DimensionedTensorType or CompleteTensorType.");
  }
}

static TupleTypePtr getTupleTensorType(
    const Stack::const_iterator& s_iter,
    const Stack::const_iterator& s_iter_end,
    const TypePtr& tupleType,
    const TypeKind type_kind) {
  AT_ASSERT(tupleType->kind() == TupleType::Kind);
  AT_ASSERT(s_iter != s_iter_end);

  std::vector<TypePtr> types;
  for (const auto& subType : tupleType->containedTypes()) {
    if (subType->kind() == TupleType::Kind) {
      types.push_back(getTupleTensorType(s_iter + 1, s_iter_end, subType, type_kind));
    } else {
      types.push_back(getTensorType(s_iter->toTensor(), type_kind));
    }
  }
  return TupleType::create(types);
}

static void setInputTensorTypes(
    Graph& g,
    const Stack& stack,
    const TypeKind type_kind = TypeKind::DimensionedTensorType) {
  at::ArrayRef<Value*> input_values = g.inputs();
  auto s_iter = stack.begin();
  for (auto v : input_values) {
    AT_ASSERT(s_iter != stack.end());
    if (v->type()->kind() == TupleType::Kind) {
      AT_ASSERT(v->node()->kind() == prim::Param);
      v->setType(
          getTupleTensorType(s_iter, stack.end(), v->type(), type_kind));
    } else {
      v->setType(getTensorType(s_iter->toTensor(), type_kind));
      s_iter++;
    }
  }
}

@@ -197,38 +244,32 @@ static std::shared_ptr<Graph> _propagate_shapes(
  return retval;
}

static std::shared_ptr<Graph> _propagate_and_assign_input_and_output_shapes(
static std::shared_ptr<Graph> _propagate_and_assign_input_shapes(
    Graph& graph,
    std::vector<at::Tensor> inputs,
    std::vector<at::Tensor> outputs,
    const std::vector<at::Tensor>& inputs,
    bool with_grad = false,
    bool propagate = true) {
  auto retval = graph.copy();
  if (propagate) {
    setInputTensorTypes(*retval, fmap<IValue>(inputs));
    setInputTensorTypes(*retval, fmap<IValue>(inputs), TypeKind::DimensionedTensorType);
    PropagateInputShapes(retval);
  }
  AT_ASSERT(retval->inputs().size() == inputs.size());
  for (size_t i = 0; i < retval->inputs().size(); ++i) {
    auto scalar_type = inputs[i].scalar_type();
    auto sizes = inputs[i].sizes();
    auto type =
        torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
    retval->inputs()[i]->setType(type);
  }
  at::ArrayRef<Value*> output_values = retval->outputs();
  // patch this to still work if we are returning a tuple of multiple values
  if (output_values.at(0)->type()->kind() == TupleType::Kind) {
    AT_ASSERT(output_values.at(0)->node()->kind() == prim::TupleConstruct);
    output_values = output_values.at(0)->node()->inputs();
  }
  AT_ASSERT(output_values.size() == outputs.size());
  setInputTensorTypes(*retval, fmap<IValue>(inputs), TypeKind::CompleteTensorType);

  return retval;
}

static std::shared_ptr<Graph> _assign_output_shapes(
    Graph& graph,
    std::vector<at::Tensor> outputs) {
  auto retval = graph.copy();
  AT_ASSERT(retval->outputs().size() == outputs.size());
  for (size_t i = 0; i < outputs.size(); ++i) {
    auto scalar_type = outputs[i].scalar_type();
    auto sizes = outputs[i].sizes();
    auto type =
        torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
    output_values[i]->setType(type);
    retval->outputs()[i]->setType(type);
  }
  return retval;
}
@@ -679,8 +720,11 @@ void initJitScriptBindings(PyObject* module) {
      debugSetAutodiffSubgraphInlining);
  m.def("_propagate_shapes", _propagate_shapes);
  m.def(
      "_propagate_and_assign_input_and_output_shapes",
      _propagate_and_assign_input_and_output_shapes);
      "_propagate_and_assign_input_shapes",
      _propagate_and_assign_input_shapes);
  m.def(
      "_assign_output_shapes",
      _assign_output_shapes);
  m.def("_jit_python_print", [](py::object obj) {
    std::ostringstream ss;
    std::vector<at::Tensor> constants;
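The recursive walk in getTupleTensorType mirrors the depth-first flattening that torch.jit._flatten performs on the Python side (see the utils.py changes below). A small sketch of that flattening, assuming _flatten keeps its (flat_tensors, descriptor) return convention:

import torch

x = torch.randn(4, 5)
y = (torch.randn(4, 5), (torch.randn(4, 5), torch.randn(4, 5)))

# nested tuples are flattened depth-first into a flat sequence of tensors
flat, desc = torch.jit._flatten((x, y))
print(len(flat))  # 4
print(desc)       # records the nesting so the structure can be rebuilt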
20 changes: 13 additions & 7 deletions torch/onnx/utils.py
@@ -16,7 +16,7 @@
from torch._six import string_classes
from torch.jit import _unique_state_dict
from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes
from torch._C import ListType, _propagate_and_assign_input_and_output_shapes
from torch._C import ListType, _propagate_and_assign_input_shapes, _assign_output_shapes


# the flag to tell the user whether it's in the middle of ONNX export or not
@@ -214,10 +214,10 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=False):

    # onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0
    torch._C._jit_pass_prepare_division_for_onnx(graph)
    # onnx only supports tensors, so we turn all the number types into tensors
    torch._C._jit_pass_erase_number_types(graph)
    # onnx does not support tuples, so try to remove them
    torch._C._jit_pass_lower_all_tuples(graph)
    # onnx only supports tensors, so we turn all the number types into tensors
    torch._C._jit_pass_erase_number_types(graph)
    torch._C._jit_pass_peephole(graph, True)
    torch._C._jit_pass_lint(graph)
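The reorder above runs tuple lowering before number-type erasure (previously the other way around). A hedged sketch of driving these private passes by hand; the pass names and one-graph-argument signatures are version-specific and worth verifying against your build:

import torch

@torch.jit.script
def f(a, b):
    # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor]
    return a + b, a - b

g = f.graph
torch._C._jit_pass_lower_all_tuples(g)    # removes prim::TupleConstruct/TupleUnpack in place
torch._C._jit_pass_erase_number_types(g)  # then turns remaining number types into tensors
print(g)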

@@ -288,16 +288,18 @@ def _model_to_graph(model, args, verbose=False, training=False,
        assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule"
        try:
            method_graph, params = model.forward._lowered_graph()
            graph = _propagate_and_assign_input_and_output_shapes(
                method_graph, tuple(args) + tuple(params), example_outputs, False, propagate)
            in_vars, in_desc = torch.jit._flatten(tuple(args) + tuple(params))
            graph = _propagate_and_assign_input_shapes(
                method_graph, tuple(in_vars), False, propagate)
        except AttributeError:
            raise RuntimeError('\'forward\' method must be a script method')
    elif isinstance(model, torch.jit.Function):
        assert example_outputs is not None, "example_outputs must be provided when exporting a TorchScript Function"
        method = model
        params = ()
        graph = _propagate_and_assign_input_and_output_shapes(
            model.graph, tuple(args), example_outputs, False, propagate)
        in_vars, in_desc = torch.jit._flatten(tuple(args))
        graph = _propagate_and_assign_input_shapes(
            model.graph, tuple(in_vars), False, propagate)
    else:
        graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
        state_dict = _unique_state_dict(model)
@@ -313,6 +315,10 @@
    graph = _optimize_graph(graph, operator_export_type,
                            _disable_torch_constant_prop=_disable_torch_constant_prop)

    if isinstance(model, torch.jit.ScriptModule) or isinstance(model, torch.jit.Function):
        out_vars, _ = torch.jit._flatten(tuple(example_outputs))
        graph = _assign_output_shapes(graph, out_vars)

    # NB: ONNX requires complete information about output types, which might be
    # erased by some optimizations, so we need to set it explicitly again.
    if torch_out is not None:
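As an end-to-end sanity check of the new path, a hedged sketch using the nested-tuple model from the tests (assumes the onnx Python package is installed; exporter-assigned input names and ordering may vary):

import io
import onnx
import torch

class NestedTupleModel(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, a, b):
        # type: (Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]) -> Tensor
        return a + b[0] + b[1][0] + b[1][1]

x = torch.randn(4, 5)
y = (torch.randn(4, 5), (torch.randn(4, 5), torch.randn(4, 5)))

f = io.BytesIO()
torch.onnx.export(NestedTupleModel(), (x, y), f,
                  example_outputs=x + y[0] + y[1][0] + y[1][1])

model = onnx.load_from_string(f.getvalue())
# the nested tuple input is flattened, so the graph should expose 4 inputs
print(len(model.graph.input))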