From c99e44a4783ef31bccfcaafafa29828bf904c3cc Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 19 Jan 2022 12:14:39 +0100 Subject: [PATCH 1/2] fix and add test for sequence_to_str --- test/test_prototype_utils.py | 17 +++++++++++++++++ torchvision/prototype/utils/_internal.py | 7 ++++++- 2 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 test/test_prototype_utils.py diff --git a/test/test_prototype_utils.py b/test/test_prototype_utils.py new file mode 100644 index 00000000000..80e494a9c03 --- /dev/null +++ b/test/test_prototype_utils.py @@ -0,0 +1,17 @@ +import pytest +from torchvision.prototype.utils._internal import sequence_to_str + + +@pytest.mark.parametrize( + ("seq", "separate_last", "expected"), + [ + pytest.param([], "", "", id="empty"), + pytest.param(["foo"], "", "'foo'", id="single"), + pytest.param(["foo", "bar"], "", "'foo', 'bar'", id="double"), + pytest.param(["foo", "bar"], "and ", "'foo' and 'bar'", id="double-separate_last"), + pytest.param(["foo", "bar", "baz"], "", "'foo', 'bar', 'baz'", id="multi"), + pytest.param(["foo", "bar", "baz"], "and ", "'foo', 'bar', and 'baz'", id="multi-separate_last"), + ], +) +def test_sequence_to_str(seq, separate_last, expected): + assert sequence_to_str(seq, separate_last=separate_last) == expected diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py index 68128a2b381..9468dcf08a9 100644 --- a/torchvision/prototype/utils/_internal.py +++ b/torchvision/prototype/utils/_internal.py @@ -30,10 +30,15 @@ class StrEnum(enum.Enum, metaclass=StrEnumMeta): def sequence_to_str(seq: Sequence, separate_last: str = "") -> str: + if not seq: + return "" if len(seq) == 1: return f"'{seq[0]}'" - return f"""'{"', '".join([str(item) for item in seq[:-1]])}', {separate_last}'{seq[-1]}'""" + head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'" + tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'" + + return head + tail def add_suggestion( From c9d82cfb87ef11314e87d22a5bd0f8a5248cb5d9 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 19 Jan 2022 20:30:50 +0100 Subject: [PATCH 2/2] remove manual ids --- test/test_prototype_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/test_prototype_utils.py b/test/test_prototype_utils.py index 80e494a9c03..712debb607a 100644 --- a/test/test_prototype_utils.py +++ b/test/test_prototype_utils.py @@ -5,12 +5,12 @@ @pytest.mark.parametrize( ("seq", "separate_last", "expected"), [ - pytest.param([], "", "", id="empty"), - pytest.param(["foo"], "", "'foo'", id="single"), - pytest.param(["foo", "bar"], "", "'foo', 'bar'", id="double"), - pytest.param(["foo", "bar"], "and ", "'foo' and 'bar'", id="double-separate_last"), - pytest.param(["foo", "bar", "baz"], "", "'foo', 'bar', 'baz'", id="multi"), - pytest.param(["foo", "bar", "baz"], "and ", "'foo', 'bar', and 'baz'", id="multi-separate_last"), + ([], "", ""), + (["foo"], "", "'foo'"), + (["foo", "bar"], "", "'foo', 'bar'"), + (["foo", "bar"], "and ", "'foo' and 'bar'"), + (["foo", "bar", "baz"], "", "'foo', 'bar', 'baz'"), + (["foo", "bar", "baz"], "and ", "'foo', 'bar', and 'baz'"), ], ) def test_sequence_to_str(seq, separate_last, expected):