diff --git a/torchvision/prototype/datasets/_builtin/dtd.py b/torchvision/prototype/datasets/_builtin/dtd.py index e78ab88da27..88f3277de58 100644 --- a/torchvision/prototype/datasets/_builtin/dtd.py +++ b/torchvision/prototype/datasets/_builtin/dtd.py @@ -1,3 +1,4 @@ +import enum import io import pathlib from typing import Any, Callable, Dict, List, Optional, Tuple @@ -30,6 +31,12 @@ from torchvision.prototype.features import Label +class DTDDemux(enum.IntEnum): + SPLIT = 0 + JOINT_CATEGORIES = 1 + IMAGES = 2 + + class DTD(Dataset): def _make_info(self) -> DatasetInfo: return DatasetInfo( @@ -54,11 +61,11 @@ def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: path = pathlib.Path(data[0]) if path.parent.name == "labels": if path.name == "labels_joint_anno.txt": - return 1 + return DTDDemux.JOINT_CATEGORIES - return 0 + return DTDDemux.SPLIT elif path.parents[1].name == "images": - return 2 + return DTDDemux.IMAGES else: return None @@ -122,7 +129,7 @@ def _make_datapipe( return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) def _filter_images(self, data: Tuple[str, Any]) -> bool: - return self._classify_archive(data) == 2 + return self._classify_archive(data) == DTDDemux.IMAGES def _generate_categories(self, root: pathlib.Path) -> List[str]: dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name) diff --git a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py index 4e43613715e..99b1f643b61 100644 --- a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py +++ b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py @@ -1,3 +1,4 @@ +import enum import functools import io import pathlib @@ -24,6 +25,11 @@ from torchvision.prototype.features import Label +class OxfordIITPetDemux(enum.IntEnum): + SPLIT_AND_CLASSIFICATION = 0 + SEGMENTATIONS = 1 + + class OxfordIITPet(Dataset): def _make_info(self) -> DatasetInfo: return DatasetInfo( @@ -51,8 +57,8 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: def _classify_anns(self, data: Tuple[str, Any]) -> Optional[int]: return { - "annotations": 0, - "trimaps": 1, + "annotations": OxfordIITPetDemux.SPLIT_AND_CLASSIFICATION, + "trimaps": OxfordIITPetDemux.SEGMENTATIONS, }.get(pathlib.Path(data[0]).parent.name) def _filter_images(self, data: Tuple[str, Any]) -> bool: @@ -135,7 +141,7 @@ def _make_datapipe( return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder)) def _filter_split_and_classification_anns(self, data: Tuple[str, Any]) -> bool: - return self._classify_anns(data) == 0 + return self._classify_anns(data) == OxfordIITPetDemux.SPLIT_AND_CLASSIFICATION def _generate_categories(self, root: pathlib.Path) -> List[str]: config = self.default_config