diff --git a/test/test_prototype_utils.py b/test/test_prototype_utils.py new file mode 100644 index 00000000000..712debb607a --- /dev/null +++ b/test/test_prototype_utils.py @@ -0,0 +1,17 @@ +import pytest +from torchvision.prototype.utils._internal import sequence_to_str + + +@pytest.mark.parametrize( + ("seq", "separate_last", "expected"), + [ + ([], "", ""), + (["foo"], "", "'foo'"), + (["foo", "bar"], "", "'foo', 'bar'"), + (["foo", "bar"], "and ", "'foo' and 'bar'"), + (["foo", "bar", "baz"], "", "'foo', 'bar', 'baz'"), + (["foo", "bar", "baz"], "and ", "'foo', 'bar', and 'baz'"), + ], +) +def test_sequence_to_str(seq, separate_last, expected): + assert sequence_to_str(seq, separate_last=separate_last) == expected diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py index 68128a2b381..9468dcf08a9 100644 --- a/torchvision/prototype/utils/_internal.py +++ b/torchvision/prototype/utils/_internal.py @@ -30,10 +30,15 @@ class StrEnum(enum.Enum, metaclass=StrEnumMeta): def sequence_to_str(seq: Sequence, separate_last: str = "") -> str: + if not seq: + return "" if len(seq) == 1: return f"'{seq[0]}'" - return f"""'{"', '".join([str(item) for item in seq[:-1]])}', {separate_last}'{seq[-1]}'""" + head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'" + tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'" + + return head + tail def add_suggestion(