From 105a0fb200808e66968a0cc8bdb8f0daad2c31f0 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 20 Sep 2021 12:29:00 +0200 Subject: [PATCH 1/9] add prototype for CIFAR datasets Conflicts: torchvision/prototype/datasets/_builtin/__init__.py torchvision/prototype/datasets/utils/_internal.py --- torchvision/prototype/datasets/_api.py | 11 + .../prototype/datasets/_builtin/__init__.py | 1 + .../prototype/datasets/_builtin/cifar.py | 232 ++++++++++++++++++ .../datasets/_builtin/cifar10.categories | 10 + .../datasets/_builtin/cifar100.categories | 100 ++++++++ torchvision/prototype/datasets/_folder.py | 40 ++- .../prototype/datasets/utils/_dataset.py | 8 +- .../prototype/datasets/utils/_internal.py | 61 ++++- 8 files changed, 454 insertions(+), 9 deletions(-) create mode 100644 torchvision/prototype/datasets/_builtin/__init__.py create mode 100644 torchvision/prototype/datasets/_builtin/cifar.py create mode 100644 torchvision/prototype/datasets/_builtin/cifar10.categories create mode 100644 torchvision/prototype/datasets/_builtin/cifar100.categories diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py index 29dce26dd0c..d0bf2455f36 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -8,6 +8,7 @@ from torchvision.prototype.datasets.decoder import pil from torchvision.prototype.datasets.utils import Dataset, DatasetInfo from torchvision.prototype.datasets.utils._internal import add_suggestion +from . import _builtin DATASETS: Dict[str, Dataset] = {} @@ -17,6 +18,16 @@ def register(dataset: Dataset) -> None: DATASETS[dataset.name] = dataset +for name, obj in _builtin.__dict__.items(): + if ( + not name.startswith("_") + and isinstance(obj, type) + and issubclass(obj, Dataset) + and obj is not Dataset + ): + register(obj()) + + # This is exposed as 'list', but we avoid that here to not shadow the built-in 'list' def _list() -> List[str]: return sorted(DATASETS.keys()) diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py new file mode 100644 index 00000000000..13983748a60 --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/__init__.py @@ -0,0 +1 @@ +from .cifar import Cifar10, Cifar100 diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py new file mode 100644 index 00000000000..3cc0e47b75f --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/cifar.py @@ -0,0 +1,232 @@ +import abc +import functools +import io +import os.path +import pathlib +import pickle +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, TypeVar + +import numpy as np + +import torch +from torch.utils.data import IterDataPipe +from torch.utils.data.datapipes.iter import ( + Demultiplexer, + Filter, + Mapper, + TarArchiveReader, + Shuffler, +) +from torchdata.datapipes.iter import KeyZipper + +from torchvision.prototype.datasets.utils import ( + Dataset, + DatasetConfig, + DatasetInfo, + HttpResource, + OnlineResource, +) +from torchvision.prototype.datasets.utils._internal import ( + create_categories_file, + MappingIterator, + SequenceIterator, + INFINITE_BUFFER_SIZE, + image_buffer_from_array, + Enumerator, +) + +__all__ = ["Cifar10", "Cifar100"] + +HERE = pathlib.Path(__file__).parent + +D = TypeVar("D") + + +class _CifarBase(Dataset): + @abc.abstractmethod + def _is_data_file( + self, data: Tuple[str, io.IOBase], *, config: DatasetConfig + ) -> Optional[int]: + pass + + 
@abc.abstractmethod + def _split_data_file(self, data: Tuple[str, Any]) -> Optional[int]: + pass + + def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]: + _, file = data + return pickle.load(file, encoding="latin1") + + def _remove_data_dict_key(self, data: Tuple[str, D]) -> D: + return data[1] + + def _key_fn(self, data: Tuple[int, Any]) -> int: + return data[0] + + def _collate_and_decode( + self, data: Tuple[Tuple[int, int], Tuple[int, np.ndarray]], + *, + decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + ) -> Dict[str, Any]: + (_, category_idx), (_, image_array_flat) = data + + image_array = image_array_flat.reshape((3, 32, 32)).transpose(1, 2, 0) + image_buffer = image_buffer_from_array(image_array) + + category = self.categories[category_idx] + label = torch.tensor(category_idx) + + return dict(image=decoder(image_buffer) if decoder else image_buffer, label=label, category=category) + + def _make_datapipe( + self, + resource_dps: List[IterDataPipe], + *, + config: DatasetConfig, + decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + ) -> IterDataPipe[Dict[str, Any]]: + archive_dp = resource_dps[0] + archive_dp = TarArchiveReader(archive_dp) + archive_dp = Filter( + archive_dp, functools.partial(self._is_data_file, config=config) + ) + archive_dp = Mapper(archive_dp, self._unpickle) + archive_dp = MappingIterator(archive_dp) + images_dp, labels_dp = Demultiplexer( + archive_dp, + 2, + self._split_data_file, # type: ignore[arg-type] + drop_none=True, + buffer_size=INFINITE_BUFFER_SIZE, + ) + + labels_dp = Mapper(labels_dp, self._remove_data_dict_key) + labels_dp = SequenceIterator(labels_dp) + labels_dp = Enumerator(labels_dp) + labels_dp = Shuffler(labels_dp, buffer_size=INFINITE_BUFFER_SIZE) + + images_dp = Mapper(images_dp, self._remove_data_dict_key) + images_dp = SequenceIterator(images_dp) + images_dp = Enumerator(images_dp) + + dp = KeyZipper(labels_dp, images_dp, self._key_fn, buffer_size=INFINITE_BUFFER_SIZE) + return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder)) + + @property + @abc.abstractmethod + def _meta_file_name(self) -> str: + pass + + @property + @abc.abstractmethod + def _categories_key(self) -> str: + pass + + def _is_meta_file(self, data: Tuple[str, Any]) -> bool: + path = pathlib.Path(data[0]) + return path.name == self._meta_file_name + + def generate_categories_file( + self, root: Union[str, pathlib.Path] + ) -> None: + dp = self.resources(self.default_config)[0].to_datapipe( + pathlib.Path(root) / self.name + ) + dp = TarArchiveReader(dp) + dp = Filter(dp, self._is_meta_file) + dp = Mapper(dp, self._unpickle) + categories = next(iter(dp))[self._categories_key] + create_categories_file(HERE, self.name, categories) + + +class Cifar10(_CifarBase): + @property + def info(self) -> DatasetInfo: + return DatasetInfo( + "cifar10", + categories=HERE / "cifar10.categories", + homepage="https://www.cs.toronto.edu/~kriz/cifar.html", + ) + + def resources(self, config: DatasetConfig) -> List[OnlineResource]: + return [ + HttpResource( + "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz", + sha256="6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce", + ) + ] + + def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> bool: + path, _ = data + name = os.path.basename(path) + return name.startswith("data" if config.split == "train" else "test") + + def _split_data_file(self, data: Tuple[str, Any]) -> Optional[int]: + key, _ = data + if key == "data": + return 0 + elif key == 
"labels": + return 1 + else: + return None + + @property + def _meta_file_name(self) -> str: + return "batches.meta" + + @property + def _categories_key(self) -> str: + return "label_names" + + +class Cifar100(_CifarBase): + @property + def info(self) -> DatasetInfo: + return DatasetInfo( + "cifar100", + categories=HERE / "cifar100.categories", + homepage="https://www.cs.toronto.edu/~kriz/cifar.html", + valid_options=dict( + split=("train", "test"), + ), + ) + + def resources(self, config: DatasetConfig) -> List[OnlineResource]: + return [ + HttpResource( + "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz", + sha256="85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7", + ) + ] + + def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> bool: + path, _ = data + name = os.path.basename(path) + return name == config.split + + def _split_data_file(self, data: Tuple[str, Any]) -> Optional[int]: + key, _ = data + if key == "data": + return 0 + elif key == "fine_labels": + return 1 + else: + return None + + @property + def _meta_file_name(self) -> str: + return "meta" + + @property + def _categories_key(self) -> str: + return "fine_label_names" + + +if __name__ == "__main__": + from torchvision.prototype.datasets import home + + home("~/datasets") + + root = home() + Cifar10().generate_categories_file(root) + Cifar100().generate_categories_file(root) diff --git a/torchvision/prototype/datasets/_builtin/cifar10.categories b/torchvision/prototype/datasets/_builtin/cifar10.categories new file mode 100644 index 00000000000..fa30c22b95d --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/cifar10.categories @@ -0,0 +1,10 @@ +airplane +automobile +bird +cat +deer +dog +frog +horse +ship +truck diff --git a/torchvision/prototype/datasets/_builtin/cifar100.categories b/torchvision/prototype/datasets/_builtin/cifar100.categories new file mode 100644 index 00000000000..7f7bf51d1ab --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/cifar100.categories @@ -0,0 +1,100 @@ +apple +aquarium_fish +baby +bear +beaver +bed +bee +beetle +bicycle +bottle +bowl +boy +bridge +bus +butterfly +camel +can +castle +caterpillar +cattle +chair +chimpanzee +clock +cloud +cockroach +couch +crab +crocodile +cup +dinosaur +dolphin +elephant +flatfish +forest +fox +girl +hamster +house +kangaroo +keyboard +lamp +lawn_mower +leopard +lion +lizard +lobster +man +maple_tree +motorcycle +mountain +mouse +mushroom +oak_tree +orange +orchid +otter +palm_tree +pear +pickup_truck +pine_tree +plain +plate +poppy +porcupine +possum +rabbit +raccoon +ray +road +rocket +rose +sea +seal +shark +shrew +skunk +skyscraper +snail +snake +spider +squirrel +streetcar +sunflower +sweet_pepper +table +tank +telephone +television +tiger +tractor +train +trout +tulip +turtle +wardrobe +whale +willow_tree +wolf +woman +worm diff --git a/torchvision/prototype/datasets/_folder.py b/torchvision/prototype/datasets/_folder.py index 5626f68650f..240484e839f 100644 --- a/torchvision/prototype/datasets/_folder.py +++ b/torchvision/prototype/datasets/_folder.py @@ -7,7 +7,13 @@ import torch from torch.utils.data import IterDataPipe -from torch.utils.data.datapipes.iter import FileLister, FileLoader, Mapper, Shuffler, Filter +from torch.utils.data.datapipes.iter import ( + FileLister, + FileLoader, + Mapper, + Shuffler, + Filter, +) from torchvision.prototype.datasets.decoder import pil from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE @@ -49,13 +55,19 @@ def 
from_data_folder( ) -> Tuple[IterDataPipe, List[str]]: root = pathlib.Path(root).expanduser().resolve() categories = sorted(entry.name for entry in os.scandir(root) if entry.is_dir()) - masks: Union[List[str], str] = [f"*.{ext}" for ext in valid_extensions] if valid_extensions is not None else "" + masks: Union[List[str], str] = ( + [f"*.{ext}" for ext in valid_extensions] if valid_extensions is not None else "" + ) dp: IterDataPipe = FileLister(str(root), recursive=recursive, masks=masks) dp = Filter(dp, _is_not_top_level_file, fn_kwargs=dict(root=root)) dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) dp = FileLoader(dp) return ( - Mapper(dp, _collate_and_decode_data, fn_kwargs=dict(root=root, categories=categories, decoder=decoder)), + Mapper( + dp, + _collate_and_decode_data, + fn_kwargs=dict(root=root, categories=categories, decoder=decoder), + ), categories, ) @@ -69,9 +81,25 @@ def from_image_folder( root: Union[str, pathlib.Path], *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = pil, - valid_extensions: Collection[str] = ("jpg", "jpeg", "png", "ppm", "bmp", "pgm", "tif", "tiff", "webp"), + valid_extensions: Collection[str] = ( + "jpg", + "jpeg", + "png", + "ppm", + "bmp", + "pgm", + "tif", + "tiff", + "webp", + ), **kwargs: Any, ) -> Tuple[IterDataPipe, List[str]]: - valid_extensions = [valid_extension for ext in valid_extensions for valid_extension in (ext.lower(), ext.upper())] - dp, categories = from_data_folder(root, decoder=decoder, valid_extensions=valid_extensions, **kwargs) + valid_extensions = [ + valid_extension + for ext in valid_extensions + for valid_extension in (ext.lower(), ext.upper()) + ] + dp, categories = from_data_folder( + root, decoder=decoder, valid_extensions=valid_extensions, **kwargs + ) return Mapper(dp, _data_to_image_key), categories diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index b3cf53afc8d..04c71be2983 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -112,8 +112,8 @@ def __init__( categories = [str(label) for label in range(categories)] elif isinstance(categories, (str, pathlib.Path)): with open(pathlib.Path(categories).expanduser().resolve(), "r") as fh: - categories = fh.readlines() - self.categories = categories + categories = [line.strip() for line in fh] + self.categories = tuple(categories) self.citation = citation self.homepage = homepage @@ -190,6 +190,10 @@ def name(self) -> str: def default_config(self) -> DatasetConfig: return self.info.default_config + @property + def categories(self) -> Tuple[str, ...]: + return self.info.categories + @abc.abstractmethod def resources(self, config: DatasetConfig) -> List[OnlineResource]: pass diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py index ad4f70145d5..f25ee67f881 100644 --- a/torchvision/prototype/datasets/utils/_internal.py +++ b/torchvision/prototype/datasets/utils/_internal.py @@ -1,14 +1,30 @@ import collections.abc import difflib -from typing import Collection, Sequence, Callable +import io +import pathlib +from typing import Collection, Sequence, Callable, Union, Iterator, Tuple, TypeVar, Dict + +import numpy as np +import PIL.Image +from torch.utils.data import IterDataPipe __all__ = [ "INFINITE_BUFFER_SIZE", "sequence_to_str", "add_suggestion", + "create_categories_file", + "image_buffer_from_array", + "SequenceIterator", + "MappingIterator", + "Enumerator", ] + +K = 
TypeVar("K") +D = TypeVar("D") + + # pseudo-infinite until a true infinite buffer is supported by all datapipes INFINITE_BUFFER_SIZE = 1_000_000_000 @@ -44,3 +60,46 @@ def add_suggestion( else alternative_hint(possibilities) ) return f"{msg.strip()} {hint}" + + +def create_categories_file( + root: Union[str, pathlib.Path], name: str, categories: Sequence[str] +) -> None: + with open(pathlib.Path(root) / f"{name}.categories", "w") as fh: + fh.write("\n".join(categories) + "\n") + + +def image_buffer_from_array(array: np.array, *, format: str = "png") -> io.BytesIO: + image = PIL.Image.fromarray(array) + buffer = io.BytesIO() + image.save(buffer, format=format) + buffer.seek(0) + return buffer + + +class SequenceIterator(IterDataPipe[D]): + def __init__(self, datapipe: IterDataPipe[Sequence[D]]): + self.datapipe = datapipe + + def __iter__(self) -> Iterator[D]: + for sequence in self.datapipe: + yield from iter(sequence) + + +class MappingIterator(IterDataPipe[Union[Tuple[K, D], D]]): + def __init__(self, datapipe: IterDataPipe[Dict[K, D]], *, drop_key: bool = False) -> None: + self.datapipe = datapipe + self.drop_key = drop_key + + def __iter__(self) -> Iterator[Union[Tuple[K, D], D]]: + for mapping in self.datapipe: + yield from iter(mapping.values() if self.drop_key else mapping.items()) # type: ignore[call-overload] + + +class Enumerator(IterDataPipe[Tuple[int, D]]): + def __init__(self, datapipe: IterDataPipe[D], start: int = 0) -> None: + self.datapipe = datapipe + self.start = start + + def __iter__(self) -> Iterator[Tuple[int, D]]: + yield from enumerate(self.datapipe, self.start) From 07d28d81b59861e2259ae6164597901a4157d8d2 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 30 Sep 2021 14:22:59 +0200 Subject: [PATCH 2/9] fix mypy --- torchvision/prototype/datasets/utils/_internal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py index f25ee67f881..647f91c9655 100644 --- a/torchvision/prototype/datasets/utils/_internal.py +++ b/torchvision/prototype/datasets/utils/_internal.py @@ -69,7 +69,7 @@ def create_categories_file( fh.write("\n".join(categories) + "\n") -def image_buffer_from_array(array: np.array, *, format: str = "png") -> io.BytesIO: +def image_buffer_from_array(array: np.ndarray, *, format: str = "png") -> io.BytesIO: image = PIL.Image.fromarray(array) buffer = io.BytesIO() image.save(buffer, format=format) From 3341792fefbd4a173e914ddd35090142f84800a3 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 30 Sep 2021 14:41:17 +0200 Subject: [PATCH 3/9] cleanup --- main.py | 9 --------- torchvision/prototype/datasets/_builtin/cifar.py | 2 -- 2 files changed, 11 deletions(-) delete mode 100644 main.py diff --git a/main.py b/main.py deleted file mode 100644 index 2b1e4f34cf7..00000000000 --- a/main.py +++ /dev/null @@ -1,9 +0,0 @@ -from torchvision.prototype import datasets -import tqdm - -datasets.home("~/datasets") - -dataset = datasets.load("caltech101", decoder=None) - -for sample in tqdm.tqdm(dataset): - pass diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py index 3cc0e47b75f..94e5c41c45b 100644 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ b/torchvision/prototype/datasets/_builtin/cifar.py @@ -225,8 +225,6 @@ def _categories_key(self) -> str: if __name__ == "__main__": from torchvision.prototype.datasets import home - home("~/datasets") - root = home() 
Cifar10().generate_categories_file(root) Cifar100().generate_categories_file(root) From 58b37b128c26471b3792b47d7155cbe0b524f097 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 30 Sep 2021 14:57:30 +0200 Subject: [PATCH 4/9] more cleanup --- torchvision/prototype/datasets/_builtin/cifar.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py index 94e5c41c45b..e3d825d2988 100644 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ b/torchvision/prototype/datasets/_builtin/cifar.py @@ -1,7 +1,6 @@ import abc import functools import io -import os.path import pathlib import pickle from typing import Any, Callable, Dict, List, Optional, Tuple, Union, TypeVar @@ -156,10 +155,9 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: ) ] - def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> bool: - path, _ = data - name = os.path.basename(path) - return name.startswith("data" if config.split == "train" else "test") + def _is_data_file(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool: + path = pathlib.Path(data[0]) + return path.name.startswith("data" if config.split == "train" else "test") def _split_data_file(self, data: Tuple[str, Any]) -> Optional[int]: key, _ = data @@ -200,9 +198,8 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: ] def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> bool: - path, _ = data - name = os.path.basename(path) - return name == config.split + path = pathlib.Path(data[0]) + return path.name == config.split def _split_data_file(self, data: Tuple[str, Any]) -> Optional[int]: key, _ = data From 341130ca60c0b4354445a2254f2d183fcc2ec6dc Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 1 Oct 2021 19:33:19 +0200 Subject: [PATCH 5/9] revert unrelated changes --- torchvision/prototype/datasets/_folder.py | 40 ++++------------------- 1 file changed, 6 insertions(+), 34 deletions(-) diff --git a/torchvision/prototype/datasets/_folder.py b/torchvision/prototype/datasets/_folder.py index 240484e839f..5626f68650f 100644 --- a/torchvision/prototype/datasets/_folder.py +++ b/torchvision/prototype/datasets/_folder.py @@ -7,13 +7,7 @@ import torch from torch.utils.data import IterDataPipe -from torch.utils.data.datapipes.iter import ( - FileLister, - FileLoader, - Mapper, - Shuffler, - Filter, -) +from torch.utils.data.datapipes.iter import FileLister, FileLoader, Mapper, Shuffler, Filter from torchvision.prototype.datasets.decoder import pil from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE @@ -55,19 +49,13 @@ def from_data_folder( ) -> Tuple[IterDataPipe, List[str]]: root = pathlib.Path(root).expanduser().resolve() categories = sorted(entry.name for entry in os.scandir(root) if entry.is_dir()) - masks: Union[List[str], str] = ( - [f"*.{ext}" for ext in valid_extensions] if valid_extensions is not None else "" - ) + masks: Union[List[str], str] = [f"*.{ext}" for ext in valid_extensions] if valid_extensions is not None else "" dp: IterDataPipe = FileLister(str(root), recursive=recursive, masks=masks) dp = Filter(dp, _is_not_top_level_file, fn_kwargs=dict(root=root)) dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) dp = FileLoader(dp) return ( - Mapper( - dp, - _collate_and_decode_data, - fn_kwargs=dict(root=root, categories=categories, decoder=decoder), - ), + Mapper(dp, _collate_and_decode_data, 
fn_kwargs=dict(root=root, categories=categories, decoder=decoder)), categories, ) @@ -81,25 +69,9 @@ def from_image_folder( root: Union[str, pathlib.Path], *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = pil, - valid_extensions: Collection[str] = ( - "jpg", - "jpeg", - "png", - "ppm", - "bmp", - "pgm", - "tif", - "tiff", - "webp", - ), + valid_extensions: Collection[str] = ("jpg", "jpeg", "png", "ppm", "bmp", "pgm", "tif", "tiff", "webp"), **kwargs: Any, ) -> Tuple[IterDataPipe, List[str]]: - valid_extensions = [ - valid_extension - for ext in valid_extensions - for valid_extension in (ext.lower(), ext.upper()) - ] - dp, categories = from_data_folder( - root, decoder=decoder, valid_extensions=valid_extensions, **kwargs - ) + valid_extensions = [valid_extension for ext in valid_extensions for valid_extension in (ext.lower(), ext.upper())] + dp, categories = from_data_folder(root, decoder=decoder, valid_extensions=valid_extensions, **kwargs) return Mapper(dp, _data_to_image_key), categories From df455c5e2faf3eebe451cd3149b61b08fcdaad37 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 5 Oct 2021 08:25:48 +0200 Subject: [PATCH 6/9] fix code format --- .../prototype/datasets/_builtin/cifar.py | 21 ++++++------------- .../prototype/datasets/utils/_internal.py | 1 - 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py index e3d825d2988..c9fe70463de 100644 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ b/torchvision/prototype/datasets/_builtin/cifar.py @@ -6,7 +6,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, TypeVar import numpy as np - import torch from torch.utils.data import IterDataPipe from torch.utils.data.datapipes.iter import ( @@ -17,7 +16,6 @@ Shuffler, ) from torchdata.datapipes.iter import KeyZipper - from torchvision.prototype.datasets.utils import ( Dataset, DatasetConfig, @@ -43,9 +41,7 @@ class _CifarBase(Dataset): @abc.abstractmethod - def _is_data_file( - self, data: Tuple[str, io.IOBase], *, config: DatasetConfig - ) -> Optional[int]: + def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> Optional[int]: pass @abc.abstractmethod @@ -63,7 +59,8 @@ def _key_fn(self, data: Tuple[int, Any]) -> int: return data[0] def _collate_and_decode( - self, data: Tuple[Tuple[int, int], Tuple[int, np.ndarray]], + self, + data: Tuple[Tuple[int, int], Tuple[int, np.ndarray]], *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> Dict[str, Any]: @@ -86,9 +83,7 @@ def _make_datapipe( ) -> IterDataPipe[Dict[str, Any]]: archive_dp = resource_dps[0] archive_dp = TarArchiveReader(archive_dp) - archive_dp = Filter( - archive_dp, functools.partial(self._is_data_file, config=config) - ) + archive_dp = Filter(archive_dp, functools.partial(self._is_data_file, config=config)) archive_dp = Mapper(archive_dp, self._unpickle) archive_dp = MappingIterator(archive_dp) images_dp, labels_dp = Demultiplexer( @@ -125,12 +120,8 @@ def _is_meta_file(self, data: Tuple[str, Any]) -> bool: path = pathlib.Path(data[0]) return path.name == self._meta_file_name - def generate_categories_file( - self, root: Union[str, pathlib.Path] - ) -> None: - dp = self.resources(self.default_config)[0].to_datapipe( - pathlib.Path(root) / self.name - ) + def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None: + dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name) dp = 
TarArchiveReader(dp) dp = Filter(dp, self._is_meta_file) dp = Mapper(dp, self._unpickle) diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py index c61f577af2f..3db014ef4ee 100644 --- a/torchvision/prototype/datasets/utils/_internal.py +++ b/torchvision/prototype/datasets/utils/_internal.py @@ -6,7 +6,6 @@ import numpy as np import PIL.Image - from torch.utils.data import IterDataPipe From fab6cd62fdbcac392942263787c204f845997281 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 5 Oct 2021 17:09:29 +0200 Subject: [PATCH 7/9] avoid decoding twice by default --- torchvision/prototype/datasets/__init__.py | 5 ++--- torchvision/prototype/datasets/_api.py | 14 ++++++++++++-- torchvision/prototype/datasets/_builtin/caltech.py | 3 +++ torchvision/prototype/datasets/_builtin/cifar.py | 12 +++++++++--- torchvision/prototype/datasets/utils/__init__.py | 2 +- torchvision/prototype/datasets/utils/_dataset.py | 8 ++++++++ 6 files changed, 35 insertions(+), 9 deletions(-) diff --git a/torchvision/prototype/datasets/__init__.py b/torchvision/prototype/datasets/__init__.py index c677cff0878..b08ddb93c31 100644 --- a/torchvision/prototype/datasets/__init__.py +++ b/torchvision/prototype/datasets/__init__.py @@ -7,10 +7,9 @@ "Note that you cannot install it with `pip install torchdata`, since this is another package." ) from error - from . import decoder, utils +from ._home import home # Load this last, since some parts depend on the above being loaded first -from ._api import register, _list as list, info, load +from ._api import register, _list as list, info, load # usort: skip from ._folder import from_data_folder, from_image_folder -from ._home import home diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py index 5c613035e2b..5ad8215e25f 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -5,7 +5,7 @@ from torch.utils.data import IterDataPipe from torchvision.prototype.datasets import home from torchvision.prototype.datasets.decoder import pil -from torchvision.prototype.datasets.utils import Dataset, DatasetInfo +from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetType from torchvision.prototype.datasets.utils._internal import add_suggestion from . 
import _builtin @@ -48,15 +48,25 @@ def info(name: str) -> DatasetInfo: return find(name).info +default = object() + +DEFAULT_DECODER: Dict[DatasetType, Callable[[io.IOBase], torch.Tensor]] = { + DatasetType.IMAGE: pil, +} + + def load( name: str, *, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = pil, + decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = default, # type: ignore[assignment] split: str = "train", **options: Any, ) -> IterDataPipe[Dict[str, Any]]: dataset = find(name) + if decoder is default: + decoder = DEFAULT_DECODER.get(dataset.info.type) + config = dataset.info.make_config(split=split, **options) root = home() / name diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index d2ce41c0d0f..c0e65b3c745 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -19,6 +19,7 @@ DatasetInfo, HttpResource, OnlineResource, + DatasetType, ) from torchvision.prototype.datasets.utils._internal import create_categories_file, INFINITE_BUFFER_SIZE, read_mat @@ -30,6 +31,7 @@ class Caltech101(Dataset): def info(self) -> DatasetInfo: return DatasetInfo( "caltech101", + type=DatasetType.IMAGE, categories=HERE / "caltech101.categories", homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101", ) @@ -146,6 +148,7 @@ class Caltech256(Dataset): def info(self) -> DatasetInfo: return DatasetInfo( "caltech256", + type=DatasetType.IMAGE, categories=HERE / "caltech256.categories", homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech256", ) diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py index c9fe70463de..0e945e4e379 100644 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ b/torchvision/prototype/datasets/_builtin/cifar.py @@ -22,6 +22,7 @@ DatasetInfo, HttpResource, OnlineResource, + DatasetType, ) from torchvision.prototype.datasets.utils._internal import ( create_categories_file, @@ -66,13 +67,16 @@ def _collate_and_decode( ) -> Dict[str, Any]: (_, category_idx), (_, image_array_flat) = data - image_array = image_array_flat.reshape((3, 32, 32)).transpose(1, 2, 0) - image_buffer = image_buffer_from_array(image_array) + image_array = image_array_flat.reshape((3, 32, 32)) + if decoder: + image = decoder(image_buffer_from_array(image_array.transpose(1, 2, 0))) + else: + image = torch.from_numpy(image_array) category = self.categories[category_idx] label = torch.tensor(category_idx) - return dict(image=decoder(image_buffer) if decoder else image_buffer, label=label, category=category) + return dict(image=image, label=label, category=category) def _make_datapipe( self, @@ -134,6 +138,7 @@ class Cifar10(_CifarBase): def info(self) -> DatasetInfo: return DatasetInfo( "cifar10", + type=DatasetType.PRE_DECODED, categories=HERE / "cifar10.categories", homepage="https://www.cs.toronto.edu/~kriz/cifar.html", ) @@ -173,6 +178,7 @@ class Cifar100(_CifarBase): def info(self) -> DatasetInfo: return DatasetInfo( "cifar100", + type=DatasetType.PRE_DECODED, categories=HERE / "cifar100.categories", homepage="https://www.cs.toronto.edu/~kriz/cifar.html", valid_options=dict( diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py index 48e7541eba5..018553e0908 100644 --- a/torchvision/prototype/datasets/utils/__init__.py +++ b/torchvision/prototype/datasets/utils/__init__.py @@ -1,3 +1,3 @@ from . 
import _internal -from ._dataset import DatasetConfig, DatasetInfo, Dataset +from ._dataset import DatasetType, DatasetConfig, DatasetInfo, Dataset from ._resource import LocalResource, OnlineResource, HttpResource, GDriveResource diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index f3d022dcebc..0998635938e 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -1,4 +1,5 @@ import abc +import enum import io import os import pathlib @@ -45,6 +46,11 @@ def to_str(sep: str) -> str: return f"{prefix}\n{body}\n{postfix}" +class DatasetType(enum.Enum): + PRE_DECODED = enum.auto() + IMAGE = enum.auto() + + class DatasetConfig(Mapping): def __init__(self, *args, **kwargs): data = dict(*args, **kwargs) @@ -96,6 +102,7 @@ def __init__( self, name: str, *, + type: Union[str, DatasetType], categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None, citation: Optional[str] = None, homepage: Optional[str] = None, @@ -103,6 +110,7 @@ def __init__( valid_options: Optional[Dict[str, Sequence]] = None, ) -> None: self.name = name.lower() + self.type = DatasetType[type.upper()] if isinstance(type, str) else type if categories is None: categories = [] From bb91402db64aeac896618dbe941fad56129a50eb Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 5 Oct 2021 17:10:43 +0200 Subject: [PATCH 8/9] revert unrelated change --- torchvision/prototype/datasets/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/datasets/__init__.py b/torchvision/prototype/datasets/__init__.py index b08ddb93c31..c677cff0878 100644 --- a/torchvision/prototype/datasets/__init__.py +++ b/torchvision/prototype/datasets/__init__.py @@ -7,9 +7,10 @@ "Note that you cannot install it with `pip install torchdata`, since this is another package." ) from error + from . 
import decoder, utils -from ._home import home # Load this last, since some parts depend on the above being loaded first -from ._api import register, _list as list, info, load # usort: skip +from ._api import register, _list as list, info, load from ._folder import from_data_folder, from_image_folder +from ._home import home From 50fc909df611b3e2108a088b596f3e3fbe6838fa Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 5 Oct 2021 18:10:31 +0200 Subject: [PATCH 9/9] cleanup --- torchvision/prototype/datasets/_api.py | 3 +- .../prototype/datasets/_builtin/cifar.py | 37 ++++++++++--------- torchvision/prototype/datasets/decoder.py | 6 ++- .../prototype/datasets/utils/_dataset.py | 2 +- 4 files changed, 28 insertions(+), 20 deletions(-) diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py index 5ad8215e25f..8d6796b2c32 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -4,7 +4,7 @@ import torch from torch.utils.data import IterDataPipe from torchvision.prototype.datasets import home -from torchvision.prototype.datasets.decoder import pil +from torchvision.prototype.datasets.decoder import raw, pil from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetType from torchvision.prototype.datasets.utils._internal import add_suggestion @@ -51,6 +51,7 @@ def info(name: str) -> DatasetInfo: default = object() DEFAULT_DECODER: Dict[DatasetType, Callable[[io.IOBase], torch.Tensor]] = { + DatasetType.RAW: raw, DatasetType.IMAGE: pil, } diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py index 0e945e4e379..016a6fbe915 100644 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ b/torchvision/prototype/datasets/_builtin/cifar.py @@ -16,6 +16,7 @@ Shuffler, ) from torchdata.datapipes.iter import KeyZipper +from torchvision.prototype.datasets.decoder import raw from torchvision.prototype.datasets.utils import ( Dataset, DatasetConfig, @@ -67,16 +68,18 @@ def _collate_and_decode( ) -> Dict[str, Any]: (_, category_idx), (_, image_array_flat) = data - image_array = image_array_flat.reshape((3, 32, 32)) - if decoder: - image = decoder(image_buffer_from_array(image_array.transpose(1, 2, 0))) - else: - image = torch.from_numpy(image_array) - category = self.categories[category_idx] label = torch.tensor(category_idx) - return dict(image=image, label=label, category=category) + image_array = image_array_flat.reshape((3, 32, 32)) + image: Union[torch.Tensor, io.BytesIO] + if decoder is raw: + image = torch.from_numpy(image_array) + else: + image_buffer = image_buffer_from_array(image_array.transpose(1, 2, 0)) + image = decoder(image_buffer) if decoder else image_buffer + + return dict(label=label, category=category, image=image) def _make_datapipe( self, @@ -87,8 +90,8 @@ def _make_datapipe( ) -> IterDataPipe[Dict[str, Any]]: archive_dp = resource_dps[0] archive_dp = TarArchiveReader(archive_dp) - archive_dp = Filter(archive_dp, functools.partial(self._is_data_file, config=config)) - archive_dp = Mapper(archive_dp, self._unpickle) + archive_dp: IterDataPipe = Filter(archive_dp, functools.partial(self._is_data_file, config=config)) + archive_dp: IterDataPipe = Mapper(archive_dp, self._unpickle) archive_dp = MappingIterator(archive_dp) images_dp, labels_dp = Demultiplexer( archive_dp, @@ -98,13 +101,13 @@ def _make_datapipe( buffer_size=INFINITE_BUFFER_SIZE, ) - labels_dp = Mapper(labels_dp, self._remove_data_dict_key) - labels_dp = 
SequenceIterator(labels_dp) + labels_dp: IterDataPipe = Mapper(labels_dp, self._remove_data_dict_key) + labels_dp: IterDataPipe = SequenceIterator(labels_dp) labels_dp = Enumerator(labels_dp) labels_dp = Shuffler(labels_dp, buffer_size=INFINITE_BUFFER_SIZE) - images_dp = Mapper(images_dp, self._remove_data_dict_key) - images_dp = SequenceIterator(images_dp) + images_dp: IterDataPipe = Mapper(images_dp, self._remove_data_dict_key) + images_dp: IterDataPipe = SequenceIterator(images_dp) images_dp = Enumerator(images_dp) dp = KeyZipper(labels_dp, images_dp, self._key_fn, buffer_size=INFINITE_BUFFER_SIZE) @@ -127,8 +130,8 @@ def _is_meta_file(self, data: Tuple[str, Any]) -> bool: def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None: dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name) dp = TarArchiveReader(dp) - dp = Filter(dp, self._is_meta_file) - dp = Mapper(dp, self._unpickle) + dp: IterDataPipe = Filter(dp, self._is_meta_file) + dp: IterDataPipe = Mapper(dp, self._unpickle) categories = next(iter(dp))[self._categories_key] create_categories_file(HERE, self.name, categories) @@ -138,7 +141,7 @@ class Cifar10(_CifarBase): def info(self) -> DatasetInfo: return DatasetInfo( "cifar10", - type=DatasetType.PRE_DECODED, + type=DatasetType.RAW, categories=HERE / "cifar10.categories", homepage="https://www.cs.toronto.edu/~kriz/cifar.html", ) @@ -178,7 +181,7 @@ class Cifar100(_CifarBase): def info(self) -> DatasetInfo: return DatasetInfo( "cifar100", - type=DatasetType.PRE_DECODED, + type=DatasetType.RAW, categories=HERE / "cifar100.categories", homepage="https://www.cs.toronto.edu/~kriz/cifar.html", valid_options=dict( diff --git a/torchvision/prototype/datasets/decoder.py b/torchvision/prototype/datasets/decoder.py index 64cea43e5f0..bbe046bd0bb 100644 --- a/torchvision/prototype/datasets/decoder.py +++ b/torchvision/prototype/datasets/decoder.py @@ -4,7 +4,11 @@ import torch from torchvision.transforms.functional import pil_to_tensor -__all__ = ["pil"] +__all__ = ["raw", "pil"] + + +def raw(buffer: io.IOBase) -> torch.Tensor: + raise RuntimeError("This is just a sentinel and should never be called.") def pil(buffer: io.IOBase, mode: str = "RGB") -> torch.Tensor: diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index 0998635938e..61e41a061e4 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -47,7 +47,7 @@ def to_str(sep: str) -> str: class DatasetType(enum.Enum): - PRE_DECODED = enum.auto() + RAW = enum.auto() IMAGE = enum.auto()
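
---

Usage sketch (illustrative, not part of the patch series): with all nine patches applied, the prototype API can be exercised much like the main.py that patch 3 deletes. This is a minimal sketch under the assumptions that torchdata is installed and that datasets.home() accepts a root path as shown in that deleted file; the "cifar100" name, the split option, and the sample keys all come from the Cifar100 code above.

    # Minimal sketch of the prototype datasets API introduced in this series.
    # Assumes torchdata is installed; names follow the code in the patches.
    from torchvision.prototype import datasets

    # Set the root directory where dataset archives are downloaded/cached.
    datasets.home("~/datasets")

    # With no explicit decoder, CIFAR (DatasetType.RAW) maps to the `raw`
    # sentinel, so `image` is a uint8 tensor of shape (3, 32, 32) instead
    # of an encoded image buffer.
    dataset = datasets.load("cifar100", split="train")

    sample = next(iter(dataset))
    print(sorted(sample))            # ['category', 'image', 'label']
    print(sample["image"].shape)     # torch.Size([3, 32, 32])
    print(sample["label"], sample["category"])

Passing decoder=None instead yields the undecoded io.BytesIO PNG buffer produced by image_buffer_from_array — the behavior that, before patch 7, applied unconditionally and forced an encode/decode round-trip on data that is already raw arrays.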