From be181258bc437cae5cf62107ec7113662b701be9 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 14 Mar 2022 13:01:52 +0000 Subject: [PATCH 1/3] Moving `sequence_to_str` to `torchvision._utils` --- test/builtin_dataset_mocks.py | 2 +- test/test_prototype_builtin_datasets.py | 2 +- test/test_prototype_utils.py | 2 +- torchvision/_utils.py | 14 +++++++++++++- torchvision/prototype/datasets/utils/_dataset.py | 3 ++- torchvision/prototype/utils/_internal.py | 15 ++------------- 6 files changed, 20 insertions(+), 18 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 82bbea5494b..478474fef9d 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -20,7 +20,7 @@ from torch.nn.functional import one_hot from torch.testing import make_tensor as _make_tensor from torchvision.prototype.datasets._api import find -from torchvision.prototype.utils._internal import sequence_to_str +from torchvision._utils import sequence_to_str make_tensor = functools.partial(_make_tensor, device="cpu") make_scalar = functools.partial(make_tensor, ()) diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 673158b00cd..94c46104317 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -11,7 +11,7 @@ from torch.utils.data.graph import traverse from torchdata.datapipes.iter import IterDataPipe, Shuffler from torchvision.prototype import transforms, datasets -from torchvision.prototype.utils._internal import sequence_to_str +from torchvision._utils import sequence_to_str assert_samples_equal = functools.partial( diff --git a/test/test_prototype_utils.py b/test/test_prototype_utils.py index 712debb607a..f5f8a040db9 100644 --- a/test/test_prototype_utils.py +++ b/test/test_prototype_utils.py @@ -1,5 +1,5 @@ import pytest -from torchvision.prototype.utils._internal import sequence_to_str +from torchvision._utils import sequence_to_str @pytest.mark.parametrize( diff --git a/torchvision/_utils.py b/torchvision/_utils.py index da0eb923f75..8e8fe1b8a83 100644 --- a/torchvision/_utils.py +++ b/torchvision/_utils.py @@ -1,5 +1,5 @@ import enum -from typing import TypeVar, Type +from typing import Sequence, TypeVar, Type T = TypeVar("T", bound=enum.Enum) @@ -18,3 +18,15 @@ def from_str(self: Type[T], member: str) -> T: # type: ignore[misc] class StrEnum(enum.Enum, metaclass=StrEnumMeta): pass + + +def sequence_to_str(seq: Sequence, separate_last: str = "") -> str: + if not seq: + return "" + if len(seq) == 1: + return f"'{seq[0]}'" + + head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'" + tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'" + + return head + tail diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index 5ee7c5ccc60..b5c6d7acb60 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -7,7 +7,8 @@ from typing import Any, Dict, List, Optional, Sequence, Union, Tuple, Collection from torch.utils.data import IterDataPipe -from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion, sequence_to_str +from torchvision._utils import sequence_to_str +from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion from .._home import use_sharded_dataset from ._internal import BUILTIN_DIR, _make_sharded_datapipe diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py index 864bff9ce02..147a7f0ff4c 100644 --- a/torchvision/prototype/utils/_internal.py +++ b/torchvision/prototype/utils/_internal.py @@ -28,9 +28,10 @@ import numpy as np import torch +from torchvision._utils import sequence_to_str + __all__ = [ - "sequence_to_str", "add_suggestion", "FrozenMapping", "make_repr", @@ -43,18 +44,6 @@ ] -def sequence_to_str(seq: Sequence, separate_last: str = "") -> str: - if not seq: - return "" - if len(seq) == 1: - return f"'{seq[0]}'" - - head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'" - tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'" - - return head + tail - - def add_suggestion( msg: str, *, From 57280e7909e99c981d3b1f7dbe9c8d5d4638fa4b Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 14 Mar 2022 13:13:04 +0000 Subject: [PATCH 2/3] Fix linter --- test/builtin_dataset_mocks.py | 2 +- test/test_prototype_builtin_datasets.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 478474fef9d..62259a604a0 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -19,8 +19,8 @@ from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file from torch.nn.functional import one_hot from torch.testing import make_tensor as _make_tensor -from torchvision.prototype.datasets._api import find from torchvision._utils import sequence_to_str +from torchvision.prototype.datasets._api import find make_tensor = functools.partial(_make_tensor, device="cpu") make_scalar = functools.partial(make_tensor, ()) diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 94c46104317..f7c40d432a4 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -10,8 +10,8 @@ from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter from torch.utils.data.graph import traverse from torchdata.datapipes.iter import IterDataPipe, Shuffler -from torchvision.prototype import transforms, datasets from torchvision._utils import sequence_to_str +from torchvision.prototype import transforms, datasets assert_samples_equal = functools.partial( From 4ccc7f1c923c38ab99efb961016be3df9bc7bb7f Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 14 Mar 2022 13:14:06 +0000 Subject: [PATCH 3/3] Rename test_prototype_utils test to test_internal_utils --- test/{test_prototype_utils.py => test_internal_utils.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/{test_prototype_utils.py => test_internal_utils.py} (100%) diff --git a/test/test_prototype_utils.py b/test/test_internal_utils.py similarity index 100% rename from test/test_prototype_utils.py rename to test/test_internal_utils.py