Support tuples in ScriptModule inputs/outputs (#20784)
Summary:
- [x] Add tests after #20256 is merged

- Support exporting ScriptModules whose inputs/outputs are arbitrarily constructed (including nested) tuples; see the export sketch after this list.

- Moved the assignment of output shapes to after the graph's conversion to ONNX is complete. By then, all tuples in the IR have already been lowered by the ```_jit_pass_lower_all_tuples``` pass. If output shapes had to be assigned before that, we would need to hand-parse the tuple structures in the graph and repeat the logic of ```_jit_pass_lower_all_tuples```. Handling inputs is easier because all tuple information is encoded in the input tensor types.

- Swapped the order of ```_jit_pass_lower_all_tuples``` and ```_jit_pass_erase_number_types```. Ops like ```prim::TupleIndex``` rely on the index being a scalar, and ```_jit_pass_erase_number_types``` would convert such scalars to tensors, so tuples must be lowered first (see the pass-order sketch below).
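
For illustration, here is a minimal export sketch. It is not part of this diff: it assumes a build that includes this patch, and it passes the 1.x-era ```example_outputs``` kwarg because a ScriptModule is not executed during export, so output shapes cannot be observed by tracing.

```python
import io

import torch


class TupleModel(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, a):
        # type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]
        return a


x = (torch.randn(3, 4), torch.randn(4, 3))
f = io.BytesIO()
# example_outputs supplies the output structure and shapes that running
# the model would otherwise provide; the exported ONNX graph sees the
# flattened tensors, since ONNX itself has no tuple type.
torch.onnx.export(TupleModel(), (x,), f, example_outputs=(x,))
```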
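
And a small sketch (also not part of the diff) of why the pass order matters: ```t[0]``` on a tuple compiles to ```prim::TupleIndex``` with a scalar index, so tuples must be lowered while that index is still a number. The function name ```first``` is made up; the two ```_jit_pass_*``` calls are the internal passes that ```_optimize_graph``` in ```torch/onnx/utils.py``` below now runs in this order, and handling of tuple-typed graph inputs by the lowering pass assumes this patch is applied.

```python
import torch


@torch.jit.script
def first(t):
    # type: (Tuple[Tensor, Tensor]) -> Tensor
    return t[0]  # compiles to prim::TupleIndex with a scalar index


g = first.graph
# Lower tuples first, while the TupleIndex index is still an int scalar...
torch._C._jit_pass_lower_all_tuples(g)
# ...then turn the remaining number types into tensors for ONNX.
torch._C._jit_pass_erase_number_types(g)
print(g)
```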
Pull Request resolved: #20784

Reviewed By: zrphercule

Differential Revision: D15484171

Pulled By: houseroad

fbshipit-source-id: 4767a84038244c929f5662758047af6cb92228d3
BowenBao authored and facebook-github-bot committed Jun 13, 2019
1 parent 4c03ac7 commit a3db284
Showing 3 changed files with 107 additions and 34 deletions.
23 changes: 23 additions & 0 deletions test/onnx/test_pytorch_onnx_caffe2.py
@@ -1651,6 +1651,29 @@ def forward(self, lstm_in):

        self.run_model_test(MyModel(), train=False, input=lstm_in, batch_size=3, use_gpu=False)

    def test_tuple_input_output(self):
        class TupleModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a):
                # type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]
                return a

        x = (torch.randn(3, 4), torch.randn(4, 3))
        self.run_model_test(TupleModel(), train=False, input=(x,), batch_size=BATCH_SIZE,
                            example_outputs=(x,))

    def test_nested_tuple_input_output(self):
        class NestedTupleModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a, b):
                # type: (Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]) -> Tensor
                return a + b[0] + b[1][0] + b[1][1]

        x = torch.randn(4, 5)
        y = (torch.randn(4, 5), (torch.randn(4, 5), torch.randn(4, 5)))
        self.run_model_test(NestedTupleModel(), train=False, input=(x, y), batch_size=BATCH_SIZE,
                            example_outputs=x + y[0] + y[1][0] + y[1][1])

    def test_topk(self):
        class TopKModel(torch.nn.Module):
            def forward(self, input):
98 changes: 71 additions & 27 deletions torch/csrc/jit/script/init.cpp
@@ -178,11 +178,58 @@ static Self moduleSelf(
};
}

static TypePtr getTensorType(
    const at::Tensor& t,
    const TypeKind type_kind) {
  switch (type_kind) {
    case TypeKind::DimensionedTensorType:
      return DimensionedTensorType::create(t);
    case TypeKind::CompleteTensorType: {
      auto scalar_type = t.scalar_type();
      auto sizes = t.sizes();
      return CompleteTensorType::create(scalar_type, at::kCPU, sizes);
    }
    default:
      throw std::runtime_error(
          "Attempted to call getTensorType for type kind other than DimensionedTensorType or CompleteTensorType.");
  }
}

static TupleTypePtr getTupleTensorType(
    const Stack::const_iterator& s_iter,
    const Stack::const_iterator& s_iter_end,
    const TypePtr& tupleType,
    const TypeKind type_kind) {
  AT_ASSERT(tupleType->kind() == TupleType::Kind);
  AT_ASSERT(s_iter != s_iter_end);

  std::vector<TypePtr> types;
  for (const auto& subType : tupleType->containedTypes()) {
    if (subType->kind() == TupleType::Kind) {
      types.push_back(
          getTupleTensorType(s_iter + 1, s_iter_end, subType, type_kind));
    } else {
      types.push_back(getTensorType(s_iter->toTensor(), type_kind));
    }
  }
  return TupleType::create(types);
}

static void setInputTensorTypes(
    Graph& g,
    const Stack& stack,
    const TypeKind type_kind = TypeKind::DimensionedTensorType) {
  at::ArrayRef<Value*> input_values = g.inputs();
  auto s_iter = stack.begin();
  for (auto v : input_values) {
    AT_ASSERT(s_iter != stack.end());
    if (v->type()->kind() == TupleType::Kind) {
      AT_ASSERT(v->node()->kind() == prim::Param);
      v->setType(
          getTupleTensorType(s_iter, stack.end(), v->type(), type_kind));
    } else {
      v->setType(getTensorType(s_iter->toTensor(), type_kind));
      s_iter++;
    }
  }
}

@@ -197,38 +244,32 @@ static std::shared_ptr<Graph> _propagate_shapes(
  return retval;
}

static std::shared_ptr<Graph> _propagate_and_assign_input_shapes(
    Graph& graph,
    const std::vector<at::Tensor>& inputs,
    bool with_grad = false,
    bool propagate = true) {
  auto retval = graph.copy();
  if (propagate) {
    setInputTensorTypes(*retval, fmap<IValue>(inputs), TypeKind::DimensionedTensorType);
    PropagateInputShapes(retval);
  }
  setInputTensorTypes(*retval, fmap<IValue>(inputs), TypeKind::CompleteTensorType);

  return retval;
}

static std::shared_ptr<Graph> _assign_output_shapes(
    Graph& graph,
    std::vector<at::Tensor> outputs) {
  auto retval = graph.copy();
  AT_ASSERT(retval->outputs().size() == outputs.size());
  for (size_t i = 0; i < outputs.size(); ++i) {
    auto scalar_type = outputs[i].scalar_type();
    auto sizes = outputs[i].sizes();
    auto type =
        torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
    retval->outputs()[i]->setType(type);
  }
  return retval;
}
@@ -679,8 +720,11 @@ void initJitScriptBindings(PyObject* module) {
      debugSetAutodiffSubgraphInlining);
  m.def("_propagate_shapes", _propagate_shapes);
  m.def(
      "_propagate_and_assign_input_shapes",
      _propagate_and_assign_input_shapes);
  m.def(
      "_assign_output_shapes",
      _assign_output_shapes);
  m.def("_jit_python_print", [](py::object obj) {
    std::ostringstream ss;
    std::vector<at::Tensor> constants;
20 changes: 13 additions & 7 deletions torch/onnx/utils.py
@@ -16,7 +16,7 @@
from torch._six import string_classes
from torch.jit import _unique_state_dict
from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes
from torch._C import ListType, _propagate_and_assign_input_shapes, _assign_output_shapes


# the flag to tell the user whether it's in the middle of ONNX export or not
@@ -214,10 +214,10 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa

    # onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0
    torch._C._jit_pass_prepare_division_for_onnx(graph)
    # onnx does not support tuples, so try to remove them
    torch._C._jit_pass_lower_all_tuples(graph)
    # onnx only supports tensors, so we turn all our number types into tensors
    torch._C._jit_pass_erase_number_types(graph)
    torch._C._jit_pass_peephole(graph, True)
    torch._C._jit_pass_lint(graph)

@@ -288,16 +288,18 @@ def _model_to_graph(model, args, verbose=False, training=False,
        assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule"
        try:
            method_graph, params = model.forward._lowered_graph()
            in_vars, in_desc = torch.jit._flatten(tuple(args) + tuple(params))
            graph = _propagate_and_assign_input_shapes(
                method_graph, tuple(in_vars), False, propagate)
        except AttributeError:
            raise RuntimeError('\'forward\' method must be a script method')
    elif isinstance(model, torch.jit.Function):
        assert example_outputs is not None, "example_outputs must be provided when exporting a TorchScript Function"
        method = model
        params = ()
        in_vars, in_desc = torch.jit._flatten(tuple(args))
        graph = _propagate_and_assign_input_shapes(
            model.graph, tuple(in_vars), False, propagate)
    else:
        graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
        state_dict = _unique_state_dict(model)
@@ -313,6 +315,10 @@
    graph = _optimize_graph(graph, operator_export_type,
                            _disable_torch_constant_prop=_disable_torch_constant_prop)

    if isinstance(model, torch.jit.ScriptModule) or isinstance(model, torch.jit.Function):
        out_vars, _ = torch.jit._flatten(tuple(example_outputs))
        graph = _assign_output_shapes(graph, out_vars)

    # NB: ONNX requires complete information about output types, which might be
    # erased by some optimizations, so we need to set it explicitly again.
    if torch_out is not None:
