Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Union[NoneType, T] as input type #51606

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 26 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2244,6 +2244,32 @@ def forward(self, input, other=four):
t = Test()
self.assertEqual(t(torch.ones(1)), torch.ones(1) + 4)

def test_union_to_optional(self):
def test1(u: Union[int, None]) -> int:
if u is not None:
return u
else:
return 0
scripted = torch.jit.script(test1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional:

Use self.checkScript

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was told self.checkScript is the old way in one of my earlier PR's.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, didn't know about that. What's wrong with it and what's the alternative?

self.assertEqual(scripted(10), test1(10))

def test2(u: Union[None, int]) -> int:
if u is not None:
return u
else:
return 0
scripted = torch.jit.script(test2)
self.assertEqual(scripted(40), test2(40))

def test3(u: Union[float, int]) -> int:
if u is not None:
return u
else:
return 0
expected_result = "General Union types are not currently supported"
with self.assertRaisesRegex(RuntimeError, expected_result):
torch.jit.script(test3)

def test_mutable_default_values(self):
with self.assertRaisesRegex(Exception, "Mutable default parameters"):
@torch.jit.script
Expand Down
24 changes: 24 additions & 0 deletions torch/csrc/jit/frontend/script_type_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,30 @@ TypePtr ScriptTypeParser::subscriptToType(
auto elem_type =
parseTypeFromExprImpl(*subscript.subscript_exprs().begin());
return RRefType::create(elem_type);
} else if (typeName == "Union") {
// In Python 3.9+, Union[NoneType, T] or Union[T, NoneType] are
// treated as Optional[T]. Adding the same support for Union in Torchscript.
const char* const err =
"General Union types are not currently supported."
" Only Union[T, NoneType] (i.e. Optional[T]) is "
"supported.";
if (subscript.subscript_exprs().size() != 2) {
throw ErrorReport(subscript) << (err);
}
auto first_type = parseTypeFromExprImpl(subscript.subscript_exprs()[0]);
auto second_type = parseTypeFromExprImpl(subscript.subscript_exprs()[1]);

bool first_none = first_type == NoneType::get();
bool second_none = second_type == NoneType::get();

if (first_none && !second_none) {
return OptionalType::create(second_type);
} else if (!first_none && second_none) {
return OptionalType::create(first_type);
} else {
throw ErrorReport(subscript.range()) << err;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to be clear, Union[None, None] would trigger this line and error out, is this intended?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. That should be handled during the regular Union.

}

} else if (typeName == "Dict") {
if (subscript.subscript_exprs().size() != 2) {
throw ErrorReport(subscript)
Expand Down