diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py index 5c613035e2b..8d6796b2c32 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -4,8 +4,8 @@ import torch from torch.utils.data import IterDataPipe from torchvision.prototype.datasets import home -from torchvision.prototype.datasets.decoder import pil -from torchvision.prototype.datasets.utils import Dataset, DatasetInfo +from torchvision.prototype.datasets.decoder import raw, pil +from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetType from torchvision.prototype.datasets.utils._internal import add_suggestion from . import _builtin @@ -48,15 +48,26 @@ def info(name: str) -> DatasetInfo: return find(name).info +default = object() + +DEFAULT_DECODER: Dict[DatasetType, Callable[[io.IOBase], torch.Tensor]] = { + DatasetType.RAW: raw, + DatasetType.IMAGE: pil, +} + + def load( name: str, *, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = pil, + decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = default, # type: ignore[assignment] split: str = "train", **options: Any, ) -> IterDataPipe[Dict[str, Any]]: dataset = find(name) + if decoder is default: + decoder = DEFAULT_DECODER.get(dataset.info.type) + config = dataset.info.make_config(split=split, **options) root = home() / name diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py index 7d6961fa920..9bccf1849bb 100644 --- a/torchvision/prototype/datasets/_builtin/__init__.py +++ b/torchvision/prototype/datasets/_builtin/__init__.py @@ -1 +1,2 @@ from .caltech import Caltech101, Caltech256 +from .cifar import Cifar10, Cifar100 diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index b1d9970bd94..0adebd959a2 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -19,6 +19,7 @@ DatasetInfo, HttpResource, OnlineResource, + DatasetType, ) from torchvision.prototype.datasets.utils._internal import create_categories_file, INFINITE_BUFFER_SIZE, read_mat @@ -30,6 +31,7 @@ class Caltech101(Dataset): def info(self) -> DatasetInfo: return DatasetInfo( "caltech101", + type=DatasetType.IMAGE, categories=HERE / "caltech101.categories", homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101", ) @@ -146,6 +148,7 @@ class Caltech256(Dataset): def info(self) -> DatasetInfo: return DatasetInfo( "caltech256", + type=DatasetType.IMAGE, categories=HERE / "caltech256.categories", homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech256", ) diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py new file mode 100644 index 00000000000..016a6fbe915 --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/cifar.py @@ -0,0 +1,227 @@ +import abc +import functools +import io +import pathlib +import pickle +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, TypeVar + +import numpy as np +import torch +from torch.utils.data import IterDataPipe +from torch.utils.data.datapipes.iter import ( + Demultiplexer, + Filter, + Mapper, + TarArchiveReader, + Shuffler, +) +from torchdata.datapipes.iter import KeyZipper +from torchvision.prototype.datasets.decoder import raw +from torchvision.prototype.datasets.utils import ( + Dataset, + DatasetConfig, + DatasetInfo, + HttpResource, + OnlineResource, + DatasetType, +) +from torchvision.prototype.datasets.utils._internal import ( + create_categories_file, + MappingIterator, + SequenceIterator, + INFINITE_BUFFER_SIZE, + image_buffer_from_array, + Enumerator, +) + +__all__ = ["Cifar10", "Cifar100"] + +HERE = pathlib.Path(__file__).parent + +D = TypeVar("D") + + +class _CifarBase(Dataset): + @abc.abstractmethod + def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> Optional[int]: + pass + + @abc.abstractmethod + def _split_data_file(self, data: Tuple[str, Any]) -> Optional[int]: + pass + + def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]: + _, file = data + return pickle.load(file, encoding="latin1") + + def _remove_data_dict_key(self, data: Tuple[str, D]) -> D: + return data[1] + + def _key_fn(self, data: Tuple[int, Any]) -> int: + return data[0] + + def _collate_and_decode( + self, + data: Tuple[Tuple[int, int], Tuple[int, np.ndarray]], + *, + decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + ) -> Dict[str, Any]: + (_, category_idx), (_, image_array_flat) = data + + category = self.categories[category_idx] + label = torch.tensor(category_idx) + + image_array = image_array_flat.reshape((3, 32, 32)) + image: Union[torch.Tensor, io.BytesIO] + if decoder is raw: + image = torch.from_numpy(image_array) + else: + image_buffer = image_buffer_from_array(image_array.transpose(1, 2, 0)) + image = decoder(image_buffer) if decoder else image_buffer + + return dict(label=label, category=category, image=image) + + def _make_datapipe( + self, + resource_dps: List[IterDataPipe], + *, + config: DatasetConfig, + decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + ) -> IterDataPipe[Dict[str, Any]]: + archive_dp = resource_dps[0] + archive_dp = TarArchiveReader(archive_dp) + archive_dp: IterDataPipe = Filter(archive_dp, functools.partial(self._is_data_file, config=config)) + archive_dp: IterDataPipe = Mapper(archive_dp, self._unpickle) + archive_dp = MappingIterator(archive_dp) + images_dp, labels_dp = Demultiplexer( + archive_dp, + 2, + self._split_data_file, # type: ignore[arg-type] + drop_none=True, + buffer_size=INFINITE_BUFFER_SIZE, + ) + + labels_dp: IterDataPipe = Mapper(labels_dp, self._remove_data_dict_key) + labels_dp: IterDataPipe = SequenceIterator(labels_dp) + labels_dp = Enumerator(labels_dp) + labels_dp = Shuffler(labels_dp, buffer_size=INFINITE_BUFFER_SIZE) + + images_dp: IterDataPipe = Mapper(images_dp, self._remove_data_dict_key) + images_dp: IterDataPipe = SequenceIterator(images_dp) + images_dp = Enumerator(images_dp) + + dp = KeyZipper(labels_dp, images_dp, self._key_fn, buffer_size=INFINITE_BUFFER_SIZE) + return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder)) + + @property + @abc.abstractmethod + def _meta_file_name(self) -> str: + pass + + @property + @abc.abstractmethod + def _categories_key(self) -> str: + pass + + def _is_meta_file(self, data: Tuple[str, Any]) -> bool: + path = pathlib.Path(data[0]) + return path.name == self._meta_file_name + + def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None: + dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name) + dp = TarArchiveReader(dp) + dp: IterDataPipe = Filter(dp, self._is_meta_file) + dp: IterDataPipe = Mapper(dp, self._unpickle) + categories = next(iter(dp))[self._categories_key] + create_categories_file(HERE, self.name, categories) + + +class Cifar10(_CifarBase): + @property + def info(self) -> DatasetInfo: + return DatasetInfo( + "cifar10", + type=DatasetType.RAW, + categories=HERE / "cifar10.categories", + homepage="https://www.cs.toronto.edu/~kriz/cifar.html", + ) + + def resources(self, config: DatasetConfig) -> List[OnlineResource]: + return [ + HttpResource( + "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz", + sha256="6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce", + ) + ] + + def _is_data_file(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool: + path = pathlib.Path(data[0]) + return path.name.startswith("data" if config.split == "train" else "test") + + def _split_data_file(self, data: Tuple[str, Any]) -> Optional[int]: + key, _ = data + if key == "data": + return 0 + elif key == "labels": + return 1 + else: + return None + + @property + def _meta_file_name(self) -> str: + return "batches.meta" + + @property + def _categories_key(self) -> str: + return "label_names" + + +class Cifar100(_CifarBase): + @property + def info(self) -> DatasetInfo: + return DatasetInfo( + "cifar100", + type=DatasetType.RAW, + categories=HERE / "cifar100.categories", + homepage="https://www.cs.toronto.edu/~kriz/cifar.html", + valid_options=dict( + split=("train", "test"), + ), + ) + + def resources(self, config: DatasetConfig) -> List[OnlineResource]: + return [ + HttpResource( + "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz", + sha256="85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7", + ) + ] + + def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> bool: + path = pathlib.Path(data[0]) + return path.name == config.split + + def _split_data_file(self, data: Tuple[str, Any]) -> Optional[int]: + key, _ = data + if key == "data": + return 0 + elif key == "fine_labels": + return 1 + else: + return None + + @property + def _meta_file_name(self) -> str: + return "meta" + + @property + def _categories_key(self) -> str: + return "fine_label_names" + + +if __name__ == "__main__": + from torchvision.prototype.datasets import home + + root = home() + Cifar10().generate_categories_file(root) + Cifar100().generate_categories_file(root) diff --git a/torchvision/prototype/datasets/_builtin/cifar10.categories b/torchvision/prototype/datasets/_builtin/cifar10.categories new file mode 100644 index 00000000000..fa30c22b95d --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/cifar10.categories @@ -0,0 +1,10 @@ +airplane +automobile +bird +cat +deer +dog +frog +horse +ship +truck diff --git a/torchvision/prototype/datasets/_builtin/cifar100.categories b/torchvision/prototype/datasets/_builtin/cifar100.categories new file mode 100644 index 00000000000..7f7bf51d1ab --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/cifar100.categories @@ -0,0 +1,100 @@ +apple +aquarium_fish +baby +bear +beaver +bed +bee +beetle +bicycle +bottle +bowl +boy +bridge +bus +butterfly +camel +can +castle +caterpillar +cattle +chair +chimpanzee +clock +cloud +cockroach +couch +crab +crocodile +cup +dinosaur +dolphin +elephant +flatfish +forest +fox +girl +hamster +house +kangaroo +keyboard +lamp +lawn_mower +leopard +lion +lizard +lobster +man +maple_tree +motorcycle +mountain +mouse +mushroom +oak_tree +orange +orchid +otter +palm_tree +pear +pickup_truck +pine_tree +plain +plate +poppy +porcupine +possum +rabbit +raccoon +ray +road +rocket +rose +sea +seal +shark +shrew +skunk +skyscraper +snail +snake +spider +squirrel +streetcar +sunflower +sweet_pepper +table +tank +telephone +television +tiger +tractor +train +trout +tulip +turtle +wardrobe +whale +willow_tree +wolf +woman +worm diff --git a/torchvision/prototype/datasets/decoder.py b/torchvision/prototype/datasets/decoder.py index 64cea43e5f0..bbe046bd0bb 100644 --- a/torchvision/prototype/datasets/decoder.py +++ b/torchvision/prototype/datasets/decoder.py @@ -4,7 +4,11 @@ import torch from torchvision.transforms.functional import pil_to_tensor -__all__ = ["pil"] +__all__ = ["raw", "pil"] + + +def raw(buffer: io.IOBase) -> torch.Tensor: + raise RuntimeError("This is just a sentinel and should never be called.") def pil(buffer: io.IOBase, mode: str = "RGB") -> torch.Tensor: diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py index 48e7541eba5..018553e0908 100644 --- a/torchvision/prototype/datasets/utils/__init__.py +++ b/torchvision/prototype/datasets/utils/__init__.py @@ -1,3 +1,3 @@ from . import _internal -from ._dataset import DatasetConfig, DatasetInfo, Dataset +from ._dataset import DatasetType, DatasetConfig, DatasetInfo, Dataset from ._resource import LocalResource, OnlineResource, HttpResource, GDriveResource diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index b43dc3fc4c4..61e41a061e4 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -1,4 +1,5 @@ import abc +import enum import io import os import pathlib @@ -45,6 +46,11 @@ def to_str(sep: str) -> str: return f"{prefix}\n{body}\n{postfix}" +class DatasetType(enum.Enum): + RAW = enum.auto() + IMAGE = enum.auto() + + class DatasetConfig(Mapping): def __init__(self, *args, **kwargs): data = dict(*args, **kwargs) @@ -96,6 +102,7 @@ def __init__( self, name: str, *, + type: Union[str, DatasetType], categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None, citation: Optional[str] = None, homepage: Optional[str] = None, @@ -103,6 +110,7 @@ def __init__( valid_options: Optional[Dict[str, Sequence]] = None, ) -> None: self.name = name.lower() + self.type = DatasetType[type.upper()] if isinstance(type, str) else type if categories is None: categories = [] @@ -111,7 +119,7 @@ def __init__( elif isinstance(categories, (str, pathlib.Path)): with open(pathlib.Path(categories).expanduser().resolve(), "r") as fh: categories = [line.strip() for line in fh] - self.categories = categories + self.categories = tuple(categories) self.citation = citation self.homepage = homepage @@ -181,6 +189,10 @@ def name(self) -> str: def default_config(self) -> DatasetConfig: return self.info.default_config + @property + def categories(self) -> Tuple[str, ...]: + return self.info.categories + @abc.abstractmethod def resources(self, config: DatasetConfig) -> List[OnlineResource]: pass diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py index 7a1d34ffa0e..3db014ef4ee 100644 --- a/torchvision/prototype/datasets/utils/_internal.py +++ b/torchvision/prototype/datasets/utils/_internal.py @@ -2,10 +2,29 @@ import difflib import io import pathlib -from typing import Collection, Sequence, Callable, Union, Any +from typing import Collection, Sequence, Callable, Union, Iterator, Tuple, TypeVar, Dict, Any +import numpy as np +import PIL.Image +from torch.utils.data import IterDataPipe + + +__all__ = [ + "INFINITE_BUFFER_SIZE", + "sequence_to_str", + "add_suggestion", + "create_categories_file", + "read_mat", + "image_buffer_from_array", + "SequenceIterator", + "MappingIterator", + "Enumerator", +] + + +K = TypeVar("K") +D = TypeVar("D") -__all__ = ["INFINITE_BUFFER_SIZE", "sequence_to_str", "add_suggestion", "create_categories_file", "read_mat"] # pseudo-infinite until a true infinite buffer is supported by all datapipes INFINITE_BUFFER_SIZE = 1_000_000_000 @@ -47,3 +66,39 @@ def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any: raise ModuleNotFoundError("Package `scipy` is required to be installed to read .mat files.") from error return sio.loadmat(buffer, **kwargs) + + +def image_buffer_from_array(array: np.ndarray, *, format: str = "png") -> io.BytesIO: + image = PIL.Image.fromarray(array) + buffer = io.BytesIO() + image.save(buffer, format=format) + buffer.seek(0) + return buffer + + +class SequenceIterator(IterDataPipe[D]): + def __init__(self, datapipe: IterDataPipe[Sequence[D]]): + self.datapipe = datapipe + + def __iter__(self) -> Iterator[D]: + for sequence in self.datapipe: + yield from iter(sequence) + + +class MappingIterator(IterDataPipe[Union[Tuple[K, D], D]]): + def __init__(self, datapipe: IterDataPipe[Dict[K, D]], *, drop_key: bool = False) -> None: + self.datapipe = datapipe + self.drop_key = drop_key + + def __iter__(self) -> Iterator[Union[Tuple[K, D], D]]: + for mapping in self.datapipe: + yield from iter(mapping.values() if self.drop_key else mapping.items()) # type: ignore[call-overload] + + +class Enumerator(IterDataPipe[Tuple[int, D]]): + def __init__(self, datapipe: IterDataPipe[D], start: int = 0) -> None: + self.datapipe = datapipe + self.start = start + + def __iter__(self) -> Iterator[Tuple[int, D]]: + yield from enumerate(self.datapipe, self.start)