diff --git a/test/test_jit.py b/test/test_jit.py
index 8145c8819da5..17965a20043d 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -4986,6 +4986,23 @@ def named_var_and(x, y):
             if y is not None and x_none:
                 print(x + y)  # noqa: T484
 
+    def test_optional_tensor(self):
+        @torch.jit.script
+        def fn(x):
+            # type: (Optional[Tensor]) -> int
+            if x is None:
+                return 1
+            else:
+                return 0
+
+        fn(None)
+        g = fn.graph_for(None)
+        self.assertEqual(list(g.inputs())[0].type().str(), 'UndefinedTensor')
+        t = torch.ones(1)
+        fn(t)
+        g = fn.graph_for(t)
+        self.assertEqual(list(g.inputs())[0].type().kind(), 'DimensionedTensorType')
+
     def test_while_write_outer_then_read(self):
         def func(a, b):
             while bool(a < 10):
diff --git a/torch/csrc/jit/argument_spec.h b/torch/csrc/jit/argument_spec.h
index a345c6bdc0c4..74e9f07bb7a9 100644
--- a/torch/csrc/jit/argument_spec.h
+++ b/torch/csrc/jit/argument_spec.h
@@ -153,7 +153,8 @@ struct ArgumentSpec {
 
  private:
   TypePtr fillType(TypePtr original, size_t& offset) const {
-    if (original->isSubtypeOf(TensorType::get())) {
+    if (original->isSubtypeOf(TensorType::get())
+        || original->isSubtypeOf(OptionalType::ofTensor())) {
       auto& arg = args.at(offset++);
       if (!arg.defined())
         return AutogradZeroTensorType::get();
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index 8ac159b6e063..2ada059a1c41 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -513,6 +513,16 @@ class ShapePropagator {
         }
         return;
       }
+      case prim::unchecked_unwrap_optional: {
+        // We know the input cannot be None, so we can always pass
+        // on the type.
+        if (auto ot = node->input()->type()->cast<OptionalType>()) {
+          node->output()->setType(ot->getElementType());
+        } else {
+          node->output()->setType(node->input()->type());
+        }
+        return;
+      }
       case prim::ConstantChunk: {
         Value* tensor = node->input();
         if (auto type = tensor->type()->cast<DimensionedTensorType>()) {
@@ -529,10 +539,17 @@ class ShapePropagator {
         return;
       }
       case aten::_unwrap_optional: {
+        // If the input is a None constant, leave the output type alone.
         auto input_ivalue = toIValue(node->input());
         if (input_ivalue && input_ivalue->isNone()) {
           return;
         }
+        if (auto ot = node->input()->type()->cast<OptionalType>()) {
+          node->output()->setType(ot->getElementType());
+        } else {
+          node->output()->setType(node->input()->type());
+        }
+        return;
       }
       default:
         break; // fall-through