diff --git a/test/test_jit.py b/test/test_jit.py index cadaafc0e8eb..eb6175c4634b 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -2244,6 +2244,32 @@ def forward(self, input, other=four): t = Test() self.assertEqual(t(torch.ones(1)), torch.ones(1) + 4) + def test_union_to_optional(self): + def test1(u: Union[int, None]) -> int: + if u is not None: + return u + else: + return 0 + scripted = torch.jit.script(test1) + self.assertEqual(scripted(10), test1(10)) + + def test2(u: Union[None, int]) -> int: + if u is not None: + return u + else: + return 0 + scripted = torch.jit.script(test2) + self.assertEqual(scripted(40), test2(40)) + + def test3(u: Union[float, int]) -> int: + if u is not None: + return u + else: + return 0 + expected_result = "General Union types are not currently supported" + with self.assertRaisesRegex(RuntimeError, expected_result): + torch.jit.script(test3) + def test_mutable_default_values(self): with self.assertRaisesRegex(Exception, "Mutable default parameters"): @torch.jit.script diff --git a/torch/csrc/jit/frontend/script_type_parser.cpp b/torch/csrc/jit/frontend/script_type_parser.cpp index 64837e68e881..6eed7f376d92 100644 --- a/torch/csrc/jit/frontend/script_type_parser.cpp +++ b/torch/csrc/jit/frontend/script_type_parser.cpp @@ -70,6 +70,30 @@ TypePtr ScriptTypeParser::subscriptToType( auto elem_type = parseTypeFromExprImpl(*subscript.subscript_exprs().begin()); return RRefType::create(elem_type); + } else if (typeName == "Union") { + // In Python 3.9+, Union[NoneType, T] or Union[T, NoneType] are + // treated as Optional[T]. Adding the same support for Union in Torchscript. + const char* const err = + "General Union types are not currently supported." + " Only Union[T, NoneType] (i.e. Optional[T]) is " + "supported."; + if (subscript.subscript_exprs().size() != 2) { + throw ErrorReport(subscript) << (err); + } + auto first_type = parseTypeFromExprImpl(subscript.subscript_exprs()[0]); + auto second_type = parseTypeFromExprImpl(subscript.subscript_exprs()[1]); + + bool first_none = first_type == NoneType::get(); + bool second_none = second_type == NoneType::get(); + + if (first_none && !second_none) { + return OptionalType::create(second_type); + } else if (!first_none && second_none) { + return OptionalType::create(first_type); + } else { + throw ErrorReport(subscript.range()) << err; + } + } else if (typeName == "Dict") { if (subscript.subscript_exprs().size() != 2) { throw ErrorReport(subscript)