From 7cc7ed1322405ba3c627b9c5661a330f92c4183d Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Sat, 23 Mar 2019 22:54:36 -0700 Subject: [PATCH] Specialize optional tensor inputs to graphs in the JIT (#18360) Summary: This specializes optional tensor inputs to either a DimensionedTensorType or, when None is passed, UndefinedTensor (aka AutogradZeroTensorType). This works because we already have different specs and thus separate plans for the two cases. It enhances the shape analysis - because now unwrapped optional tensors will have DimensionedTensorType with appropriate shape and required grad etc. Also, when combined with "if-pruning" (which I understand #18259 works towards), we actually get much nicer concrete graphs, too. Pull Request resolved: https://github.com/pytorch/pytorch/pull/18360 Differential Revision: D14590577 Pulled By: soumith fbshipit-source-id: cac204a506d1d38b15703cbcc67a6b75fd4979f4 --- test/test_jit.py | 17 +++++++++++++++++ torch/csrc/jit/argument_spec.h | 3 ++- torch/csrc/jit/passes/shape_analysis.cpp | 17 +++++++++++++++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/test/test_jit.py b/test/test_jit.py index 8145c8819da5..17965a20043d 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -4986,6 +4986,23 @@ def named_var_and(x, y): if y is not None and x_none: print(x + y) # noqa: T484 + def test_optional_tensor(self): + @torch.jit.script + def fn(x): + # type: (Optional[Tensor]) -> int + if x is None: + return 1 + else: + return 0 + + fn(None) + g = fn.graph_for(None) + self.assertEqual(list(g.inputs())[0].type().str(), 'UndefinedTensor') + t = torch.ones(1) + fn(t) + g = fn.graph_for(t) + self.assertEqual(list(g.inputs())[0].type().kind(), 'DimensionedTensorType') + def test_while_write_outer_then_read(self): def func(a, b): while bool(a < 10): diff --git a/torch/csrc/jit/argument_spec.h b/torch/csrc/jit/argument_spec.h index a345c6bdc0c4..74e9f07bb7a9 100644 --- a/torch/csrc/jit/argument_spec.h +++ 
b/torch/csrc/jit/argument_spec.h @@ -153,7 +153,8 @@ struct ArgumentSpec { private: TypePtr fillType(TypePtr original, size_t& offset) const { - if (original->isSubtypeOf(TensorType::get())) { + if (original->isSubtypeOf(TensorType::get()) + || original->isSubtypeOf(OptionalType::ofTensor())) { auto& arg = args.at(offset++); if (!arg.defined()) return AutogradZeroTensorType::get(); diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 8ac159b6e063..2ada059a1c41 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -513,6 +513,16 @@ class ShapePropagator { } return; } + case prim::unchecked_unwrap_optional: { + // we know we cannot have None as input, so we can always pass + // on the type. + if(auto ot = node->input()->type()->cast<OptionalType>()) { + node->output()->setType(ot->getElementType()); + } else { + node->output()->setType(node->input()->type()); + } + return; + } case prim::ConstantChunk: { Value* tensor = node->input(); if (auto type = tensor->type()->cast<DimensionedTensorType>()) { @@ -529,10 +539,17 @@ class ShapePropagator { return; } case aten::_unwrap_optional: { + // if we have None as input, we need to leave the output alone auto input_ivalue = toIValue(node->input()); if (input_ivalue && input_ivalue->isNone()) { return; } + if(auto ot = node->input()->type()->cast<OptionalType>()) { + node->output()->setType(ot->getElementType()); + } else { + node->output()->setType(node->input()->type()); + } + return; + } default: break; // fall-through