From 3a1d886d3f0ee2d3aa27c1b1cefadfb3b1e422a5 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 16 Dec 2021 09:12:11 +0100
Subject: [PATCH 1/8] [PoC] separate decoding from datasets

---
 torchvision/prototype/datasets/__init__.py  |  2 +-
 torchvision/prototype/datasets/_api.py      | 23 +-----
 .../prototype/datasets/_builtin/caltech.py  | 59 ++++++---------
 .../prototype/datasets/_builtin/celeba.py   | 66 ++++++++---------
 .../prototype/datasets/_builtin/cifar.py    | 34 ++-------
 .../prototype/datasets/_builtin/coco.py     | 34 +++------
 .../prototype/datasets/_builtin/imagenet.py | 35 +++------
 .../prototype/datasets/_builtin/mnist.py    | 69 ++++--------------
 .../prototype/datasets/_builtin/sbd.py      | 57 ++++-----------
 .../prototype/datasets/_builtin/semeion.py  | 32 +++-----
 .../datasets/_builtin/voc.categories        | 20 +++++
 .../prototype/datasets/_builtin/voc.py      | 73 ++++++++++++-------
 torchvision/prototype/datasets/_folder.py   | 34 ++++-----
 torchvision/prototype/datasets/decoder.py   | 16 ----
 .../prototype/datasets/utils/__init__.py    |  9 ++-
 .../prototype/datasets/utils/_dataset.py    | 16 +---
 .../prototype/datasets/utils/_decoder.py    | 47 ++++++++++++
 .../prototype/datasets/utils/_internal.py   | 24 ++----
 18 files changed, 269 insertions(+), 381 deletions(-)
 create mode 100644 torchvision/prototype/datasets/_builtin/voc.categories
 delete mode 100644 torchvision/prototype/datasets/decoder.py
 create mode 100644 torchvision/prototype/datasets/utils/_decoder.py

diff --git a/torchvision/prototype/datasets/__init__.py b/torchvision/prototype/datasets/__init__.py
index 1945b5a5d9e..28840081fe7 100644
--- a/torchvision/prototype/datasets/__init__.py
+++ b/torchvision/prototype/datasets/__init__.py
@@ -7,7 +7,7 @@
         "Note that you cannot install it with `pip install torchdata`, since this is another package."
     ) from error
 
-from . import decoder, utils
+from . import utils
 from ._home import home
 
 # Load this last, since some parts depend on the above being loaded first
diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py
index 8b534c85413..8b7d6c90e54 100644
--- a/torchvision/prototype/datasets/_api.py
+++ b/torchvision/prototype/datasets/_api.py
@@ -1,12 +1,9 @@
-import io
 import os
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict, List
 
-import torch
 from torch.utils.data import IterDataPipe
 from torchvision.prototype.datasets import home
-from torchvision.prototype.datasets.decoder import raw, pil
-from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetType
+from torchvision.prototype.datasets.utils import Dataset, DatasetInfo
 from torchvision.prototype.utils._internal import add_suggestion
 
 from . import _builtin
@@ -49,28 +46,14 @@ def info(name: str) -> DatasetInfo:
     return find(name).info
 
 
-DEFAULT_DECODER = object()
-
-DEFAULT_DECODER_MAP: Dict[DatasetType, Callable[[io.IOBase], torch.Tensor]] = {
-    DatasetType.RAW: raw,
-    DatasetType.IMAGE: pil,
-}
-
-
 def load(
     name: str,
     *,
-    decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = DEFAULT_DECODER,  # type: ignore[assignment]
     skip_integrity_check: bool = False,
     split: str = "train",
     **options: Any,
 ) -> IterDataPipe[Dict[str, Any]]:
     dataset = find(name)
-
-    if decoder is DEFAULT_DECODER:
-        decoder = DEFAULT_DECODER_MAP.get(dataset.info.type)
-
     config = dataset.info.make_config(split=split, **options)
     root = os.path.join(home(), dataset.name)
-
-    return dataset.load(root, config=config, decoder=decoder, skip_integrity_check=skip_integrity_check)
+    return dataset.load(root, config=config, skip_integrity_check=skip_integrity_check)
diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py
index 16dd8ec35c3..95ee3bef54a 100644
--- a/torchvision/prototype/datasets/_builtin/caltech.py
+++ b/torchvision/prototype/datasets/_builtin/caltech.py
@@ -1,10 +1,8 @@
-import io
 import pathlib
 import re
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Tuple, BinaryIO
 
 import numpy as np
-import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
@@ -18,7 +16,8 @@
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DatasetType,
+    DecodeableImageStreamWrapper,
+    DecodeableStreamWrapper,
 )
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat
 from torchvision.prototype.features import Label, BoundingBox, Feature
@@ -28,7 +27,6 @@ class Caltech101(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "caltech101",
-            type=DatasetType.IMAGE,
             dependencies=("scipy",),
             homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101",
         )
@@ -81,33 +79,28 @@ def _anns_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]:
 
         return category, id
 
-    def _collate_and_decode_sample(
-        self,
-        data: Tuple[Tuple[str, str], Tuple[Tuple[str, io.IOBase], Tuple[str, io.IOBase]]],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    def _decode_ann(self, data: BinaryIO) -> Dict[str, Any]:
+        ann = read_mat(data)
+        return dict(
+            bounding_box=BoundingBox(ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy"),
+            contour=Feature(ann["obj_contour"].T),
+        )
+
+    def _prepare_sample(
+        self, data: Tuple[Tuple[str, str], Tuple[Tuple[str, BinaryIO], Tuple[str, BinaryIO]]]
     ) -> Dict[str, Any]:
         key, (image_data, ann_data) = data
         category, _ = key
         image_path, image_buffer = image_data
         ann_path, ann_buffer = ann_data
 
-        label = self.info.categories.index(category)
-
-        image = decoder(image_buffer) if decoder else image_buffer
-
-        ann = read_mat(ann_buffer)
-        bbox = BoundingBox(ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy")
-        contour = Feature(ann["obj_contour"].T)
-
         return dict(
             category=category,
-            label=label,
-            image=image,
+            label=Label(self.info.categories.index(category), category=category),
             image_path=image_path,
-            bbox=bbox,
-            contour=contour,
+            image=DecodeableImageStreamWrapper(image_buffer),
             ann_path=ann_path,
+            ann=DecodeableStreamWrapper(ann_buffer, self._decode_ann),
         )
 
     def _make_datapipe(
@@ -115,7 +108,6 @@ def _make_datapipe(
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         images_dp, anns_dp = resource_dps
 
@@ -132,7 +124,7 @@ def _make_datapipe(
             buffer_size=INFINITE_BUFFER_SIZE,
             keep_key=True,
         )
-        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
+        return Mapper(dp, self._prepare_sample)
 
     def _generate_categories(self, root: pathlib.Path) -> List[str]:
         dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
@@ -144,7 +136,6 @@ class Caltech256(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "caltech256",
-            type=DatasetType.IMAGE,
             homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech256",
         )
 
@@ -160,31 +151,27 @@ def _is_not_rogue_file(self, data: Tuple[str, Any]) -> bool:
         path = pathlib.Path(data[0])
         return path.name != "RENAME2"
 
-    def _collate_and_decode_sample(
-        self,
-        data: Tuple[str, io.IOBase],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
        path, buffer = data
 
         dir_name = pathlib.Path(path).parent.name
         label_str, category = dir_name.split(".")
-        label = Label(int(label_str), category=category)
-
-        return dict(label=label, image=decoder(buffer) if decoder else buffer)
+        return dict(
+            path=path,
+            image=DecodeableImageStreamWrapper(buffer),
+            label=Label(int(label_str), category=category),
+        )
 
     def _make_datapipe(
         self,
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = Filter(dp, self._is_not_rogue_file)
         dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
-        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
+        return Mapper(dp, self._prepare_sample)
 
     def _generate_categories(self, root: pathlib.Path) -> List[str]:
         dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py
index d8b9137ecc2..e59e84a8b47 100644
--- a/torchvision/prototype/datasets/_builtin/celeba.py
+++ b/torchvision/prototype/datasets/_builtin/celeba.py
@@ -1,8 +1,6 @@
 import csv
-import io
-from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator, Sequence
+from typing import Any, Dict, List, Optional, Tuple, Iterator, Sequence, BinaryIO
 
-import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
@@ -17,9 +15,10 @@
     DatasetInfo,
     GDriveResource,
     OnlineResource,
-    DatasetType,
+    DecodeableImageStreamWrapper,
 )
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, getitem, path_accessor
+from torchvision.prototype.features import BoundingBox, Feature, Label
 
 csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
 
@@ -28,7 +27,7 @@ class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]):
     def __init__(
         self,
-        datapipe: IterDataPipe[Tuple[Any, io.IOBase]],
+        datapipe: IterDataPipe[Tuple[Any, BinaryIO]],
         *,
         fieldnames: Optional[Sequence[str]] = None,
     ) -> None:
@@ -60,7 +59,6 @@ class CelebA(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "celeba",
-            type=DatasetType.IMAGE,
             homepage="https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html",
         )
 
@@ -85,7 +83,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
             sha256="f0e5da289d5ccf75ffe8811132694922b60f2af59256ed362afa03fefba324d0",
             file_name="list_attr_celeba.txt",
         )
-        bboxes = GDriveResource(
+        bounding_boxes = GDriveResource(
            "0B7EVK8r0v71pbThiMVRxWXZ4dU0",
             sha256="7487a82e57c4bb956c5445ae2df4a91ffa717e903c5fa22874ede0820c8ec41b",
             file_name="list_bbox_celeba.txt",
@@ -95,7 +93,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
             sha256="6c02a87569907f6db2ba99019085697596730e8129f67a3d61659f198c48d43b",
             file_name="list_landmarks_align_celeba.txt",
         )
-        return [splits, images, identities, attributes, bboxes, landmarks]
+        return [splits, images, identities, attributes, bounding_boxes, landmarks]
 
     _SPLIT_ID_TO_NAME = {
         "0": "train",
@@ -106,38 +104,35 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
     def _filter_split(self, data: Tuple[str, Dict[str, str]], *, split: str) -> bool:
         return self._SPLIT_ID_TO_NAME[data[1]["split_id"]] == split
 
-    def _collate_anns(self, data: Tuple[Tuple[str, Dict[str, str]], ...]) -> Tuple[str, Dict[str, Dict[str, str]]]:
-        (image_id, identity), (_, attributes), (_, bbox), (_, landmarks) = data
-        return image_id, dict(identity=identity, attributes=attributes, bbox=bbox, landmarks=landmarks)
+    def _prepare_anns(self, data: Tuple[Tuple[str, Dict[str, str]], ...]) -> Tuple[str, Dict[str, Any]]:
+        (image_id, identity), (_, attributes), (_, bounding_box), (_, landmarks) = data
+        return image_id, dict(
+            identity=Label(int(identity["identity"])),
+            attributes={attr: value == "1" for attr, value in attributes.items()},
+            # FIXME: probe image_size from file
+            bounding_box=BoundingBox(
+                [int(bounding_box[key]) for key in ("x_1", "y_1", "width", "height")],
+                format="xywh",
+                image_size=(-1, -1),
+            ),
+            landmarks={
+                landmark: Feature((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"])))
+                for landmark in {key[:-2] for key in landmarks.keys()}
+            },
+        )
 
     def _collate_and_decode_sample(
-        self,
-        data: Tuple[Tuple[str, Tuple[str, List[str]], Tuple[str, io.IOBase]], Tuple[str, Dict[str, Any]]],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+        self, data: Tuple[Tuple[str, Tuple[str, List[str]], Tuple[str, BinaryIO]], Tuple[str, Dict[str, Any]]]
     ) -> Dict[str, Any]:
         split_and_image_data, ann_data = data
         _, _, image_data = split_and_image_data
         path, buffer = image_data
-        _, ann = ann_data
-
-        image = decoder(buffer) if decoder else buffer
-
-        identity = int(ann["identity"]["identity"])
-        attributes = {attr: value == "1" for attr, value in ann["attributes"].items()}
-        bbox = torch.tensor([int(ann["bbox"][key]) for key in ("x_1", "y_1", "width", "height")])
-        landmarks = {
-            landmark: torch.tensor((int(ann["landmarks"][f"{landmark}_x"]), int(ann["landmarks"][f"{landmark}_y"])))
-            for landmark in {key[:-2] for key in ann["landmarks"].keys()}
-        }
+        _, anns = ann_data
 
         return dict(
+            anns,
             path=path,
-            image=image,
-            identity=identity,
-            attributes=attributes,
-            bbox=bbox,
-            landmarks=landmarks,
+            image=DecodeableImageStreamWrapper(buffer),
         )
 
     def _make_datapipe(
@@ -145,9 +140,8 @@ def _make_datapipe(
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
-        splits_dp, images_dp, identities_dp, attributes_dp, bboxes_dp, landmarks_dp = resource_dps
+        splits_dp, images_dp, identities_dp, attributes_dp, bounding_boxes_dp, landmarks_dp = resource_dps
 
         splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
         splits_dp = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
@@ -159,12 +153,12 @@ def _make_datapipe(
                 for dp, fieldnames in (
                     (identities_dp, ("image_id", "identity")),
                     (attributes_dp, None),
-                    (bboxes_dp, None),
+                    (bounding_boxes_dp, None),
                     (landmarks_dp, None),
                 )
             ]
         )
-        anns_dp = Mapper(anns_dp, self._collate_anns)
+        anns_dp = Mapper(anns_dp, self._prepare_anns)
 
         dp = IterKeyZipper(
             splits_dp,
@@ -175,4 +169,4 @@ def _make_datapipe(
             keep_key=True,
         )
         dp = IterKeyZipper(dp, anns_dp, key_fn=getitem(0), buffer_size=INFINITE_BUFFER_SIZE)
-        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
+        return Mapper(dp, self._collate_and_decode_sample)
diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py
index 9695ea63e82..b221550e1b3 100644
--- a/torchvision/prototype/datasets/_builtin/cifar.py
+++ b/torchvision/prototype/datasets/_builtin/cifar.py
@@ -3,34 +3,28 @@
 import io
 import pathlib
 import pickle
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator, cast
+from typing import Any, Dict, List, Optional, Tuple, Iterator, cast, BinaryIO
 
 import numpy as np
-import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Filter,
     Mapper,
     Shuffler,
 )
-from torchvision.prototype.datasets.decoder import raw
 from torchvision.prototype.datasets.utils import (
     Dataset,
     DatasetConfig,
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DatasetType,
 )
 from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
-    image_buffer_from_array,
     path_comparator,
 )
 from torchvision.prototype.features import Label, Image
 
-__all__ = ["Cifar10", "Cifar100"]
-
 
 class CifarFileReader(IterDataPipe[Tuple[np.ndarray, int]]):
     def __init__(self, datapipe: IterDataPipe[Dict[str, Any]], *, labels_key: str) -> None:
@@ -50,45 +44,29 @@ class _CifarBase(Dataset):
     _CATEGORIES_KEY: str
 
     @abc.abstractmethod
-    def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> Optional[int]:
+    def _is_data_file(self, data: Tuple[str, BinaryIO], *, config: DatasetConfig) -> Optional[int]:
         pass
 
     def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
         _, file = data
         return cast(Dict[str, Any], pickle.load(file, encoding="latin1"))
 
-    def _collate_and_decode(
-        self,
-        data: Tuple[np.ndarray, int],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]:
         image_array, category_idx = data
-
-        image: Union[Image, io.BytesIO]
-        if decoder is raw:
-            image = Image(image_array)
-        else:
-            image_buffer = image_buffer_from_array(image_array.transpose((1, 2, 0)))
-            image = decoder(image_buffer) if decoder else image_buffer  # type: ignore[assignment]
-
-        label = Label(category_idx, category=self.categories[category_idx])
-
-        return dict(image=image, label=label)
+        return dict(image=Image(image_array), label=Label(category_idx, category=self.categories[category_idx]))
 
     def _make_datapipe(
         self,
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = Filter(dp, functools.partial(self._is_data_file, config=config))
         dp = Mapper(dp, self._unpickle)
         dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
         dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
-        return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))
+        return Mapper(dp, self._prepare_sample)
 
     def _generate_categories(self, root: pathlib.Path) -> List[str]:
         dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
@@ -109,7 +87,6 @@ def _is_data_file(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "cifar10",
-            type=DatasetType.RAW,
             homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
         )
 
@@ -134,7 +111,6 @@ def _is_data_file(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "cifar100",
-            type=DatasetType.RAW,
             homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
             valid_options=dict(
                 split=("train", "test"),
diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py
index 22d47d20afa..89811f7660d 100644
--- a/torchvision/prototype/datasets/_builtin/coco.py
+++ b/torchvision/prototype/datasets/_builtin/coco.py
@@ -1,8 +1,7 @@
-import io
 import pathlib
 import re
 from collections import OrderedDict
-from typing import Any, Callable, Dict, List, Optional, Tuple, cast
+from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO
 
 import torch
 from torchdata.datapipes.iter import (
@@ -22,7 +21,7 @@
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DatasetType,
+    DecodeableImageStreamWrapper,
 )
 from torchvision.prototype.datasets.utils._internal import (
     MappingIterator,
@@ -42,7 +41,6 @@ def _make_info(self) -> DatasetInfo:
 
         return DatasetInfo(
             name,
-            type=DatasetType.IMAGE,
             dependencies=("pycocotools",),
             categories=categories,
             homepage="https://cocodataset.org/",
@@ -106,7 +104,7 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[st
                 )
             ),
             areas=Feature([ann["area"] for ann in anns]),
-            crowds=Feature([ann["iscrowd"] for ann in anns], dtype=torch.bool),
+            crowds=Feature([ann["iscrowd"] for ann in anns], dtype=torch.bool),
             bounding_boxes=BoundingBox(
                 [ann["bbox"] for ann in anns],
                 format="xywh",
@@ -148,26 +146,21 @@ def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]:
         else:
             return None
 
-    def _collate_and_decode_image(
-        self, data: Tuple[str, io.IOBase], *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]]
-    ) -> Dict[str, Any]:
+    def _prepare_image(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
         path, buffer = data
-        return dict(path=path, image=decoder(buffer) if decoder else buffer)
+        return dict(path=path, image=DecodeableImageStreamWrapper(buffer))
 
-    def _collate_and_decode_sample(
+    def _prepare_sample(
         self,
-        data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, io.IOBase]],
+        data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, BinaryIO]],
         *,
-        annotations: Optional[str],
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+        annotations: str,
     ) -> Dict[str, Any]:
         ann_data, image_data = data
         anns, image_meta = ann_data
 
-        sample = self._collate_and_decode_image(image_data, decoder=decoder)
-        if annotations:
-            sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta))
-
+        sample = self._prepare_image(image_data)
+        sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta))
         return sample
 
     def _make_datapipe(
@@ -175,13 +168,12 @@ def _make_datapipe(
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         images_dp, meta_dp = resource_dps
 
         if config.annotations is None:
             dp = Shuffler(images_dp)
-            return Mapper(dp, self._collate_and_decode_image, fn_kwargs=dict(decoder=decoder))
+            return Mapper(dp, self._prepare_image)
 
         meta_dp = Filter(
             meta_dp,
@@ -222,9 +214,7 @@ def _make_datapipe(
             ref_key_fn=path_accessor("name"),
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-        return Mapper(
-            dp, self._collate_and_decode_sample, fn_kwargs=dict(annotations=config.annotations, decoder=decoder)
-        )
+        return Mapper(dp, self._prepare_sample, fn_kwargs=dict(annotations=config.annotations))
 
     def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]:
         config = self.default_config
diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py
index c34aa18398c..91f7588f81c 100644
--- a/torchvision/prototype/datasets/_builtin/imagenet.py
+++ b/torchvision/prototype/datasets/_builtin/imagenet.py
@@ -1,9 +1,7 @@
-import io
 import pathlib
 import re
-from typing import Any, Callable, Dict, List, Optional, Tuple, cast
+from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO
 
-import torch
 from torchdata.datapipes.iter import IterDataPipe, LineReader, IterKeyZipper, Mapper, TarArchiveReader, Filter, Shuffler
 from torchvision.prototype.datasets.utils import (
     Dataset,
@@ -11,7 +9,7 @@
     DatasetInfo,
     OnlineResource,
     ManualDownloadResource,
-    DatasetType,
+    DecodeableImageStreamWrapper,
 )
 from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
@@ -37,7 +35,6 @@ def _make_info(self) -> DatasetInfo:
 
         return DatasetInfo(
             name,
-            type=DatasetType.IMAGE,
             dependencies=("scipy",),
             categories=categories,
             homepage="https://www.image-net.org/",
@@ -85,7 +82,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
 
     _TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG")
 
-    def _collate_train_data(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[Label, str, str], Tuple[str, io.IOBase]]:
+    def _collate_train_data(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label, str, str], Tuple[str, BinaryIO]]:
         path = pathlib.Path(data[0])
         wnid = self._TRAIN_IMAGE_NAME_PATTERN.match(path.name).group("wnid")  # type: ignore[union-attr]
         category = self.wnid_to_category[wnid]
@@ -99,40 +96,30 @@ def _val_test_image_key(self, data: Tuple[str, Any]) -> int:
         return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name).group("id"))  # type: ignore[union-attr]
 
     def _collate_val_data(
-        self, data: Tuple[Tuple[int, int], Tuple[str, io.IOBase]]
-    ) -> Tuple[Tuple[Label, str, str], Tuple[str, io.IOBase]]:
+        self, data: Tuple[Tuple[int, int], Tuple[str, BinaryIO]]
+    ) -> Tuple[Tuple[Label, str, str], Tuple[str, BinaryIO]]:
         label_data, image_data = data
         _, label = label_data
         category = self.categories[label]
         wnid = self.category_to_wnid[category]
         return (Label(label), category, wnid), image_data
 
-    def _collate_test_data(self, data: Tuple[str, io.IOBase]) -> Tuple[None, Tuple[str, io.IOBase]]:
+    def _collate_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]:
         return None, data
 
-    def _collate_and_decode_sample(
+    def _prepare_sample(
         self,
-        data: Tuple[Optional[Tuple[Label, str, str]], Tuple[str, io.IOBase]],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+        data: Tuple[Optional[Tuple[Label, str, str]], Tuple[str, BinaryIO]],
     ) -> Dict[str, Any]:
         label_data, (path, buffer) = data
 
-        sample = dict(
-            path=path,
-            image=decoder(buffer) if decoder else buffer,
-        )
+        sample = dict(path=path, image=DecodeableImageStreamWrapper(buffer))
         if label_data:
             sample.update(dict(zip(("label", "category", "wnid"), label_data)))
-
         return sample
 
     def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
     ) -> IterDataPipe[Dict[str, Any]]:
         images_dp, devkit_dp = resource_dps
 
@@ -160,7 +147,7 @@ def _make_datapipe(
             dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE)
             dp = Mapper(dp, self._collate_test_data)
 
-        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
+        return Mapper(dp, self._prepare_sample)
 
     # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
     # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py
index d7b711049b3..2eeecc02226 100644
--- a/torchvision/prototype/datasets/_builtin/mnist.py
+++ b/torchvision/prototype/datasets/_builtin/mnist.py
@@ -1,10 +1,9 @@
 import abc
 import functools
-import io
 import operator
 import pathlib
 import string
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast, BinaryIO
+from typing import Any, Dict, Iterator, List, Optional, Tuple, cast, BinaryIO
 
 import torch
 from torchdata.datapipes.iter import (
@@ -14,17 +13,14 @@
     Zipper,
     Shuffler,
 )
-from torchvision.prototype.datasets.decoder import raw
 from torchvision.prototype.datasets.utils import (
     Dataset,
-    DatasetType,
     DatasetConfig,
     DatasetInfo,
     HttpResource,
     OnlineResource,
 )
 from torchvision.prototype.datasets.utils._internal import (
-    image_buffer_from_array,
     Decompressor,
     INFINITE_BUFFER_SIZE,
     fromfile,
@@ -98,31 +94,15 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
     def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional[int]]:
         return None, None
 
-    def _collate_and_decode(
-        self,
-        data: Tuple[torch.Tensor, torch.Tensor],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]:
         image, label = data
-
-        if decoder is raw:
-            image = Image(image)
-        else:
-            image_buffer = image_buffer_from_array(image.numpy())
-            image = decoder(image_buffer) if decoder else image_buffer  # type: ignore[assignment]
-
-        label = Label(label, dtype=torch.int64, category=self.info.categories[int(label)])
-
-        return dict(image=image, label=label)
+        return dict(
+            image=Image(image),
+            label=Label(label, dtype=torch.int64, category=self.info.categories[int(label)]),
+        )
 
     def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
     ) -> IterDataPipe[Dict[str, Any]]:
         images_dp, labels_dp = resource_dps
         start, stop = self.start_and_stop(config)
@@ -135,14 +115,13 @@ def _make_datapipe(
 
         dp = Zipper(images_dp, labels_dp)
         dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
-        return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(config=config, decoder=decoder))
+        return Mapper(dp, self._prepare_sample, fn_kwargs=dict(config=config))
 
 
 class MNIST(_MNISTBase):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "mnist",
-            type=DatasetType.RAW,
             categories=10,
             homepage="http://yann.lecun.com/exdb/mnist",
             valid_options=dict(
@@ -172,7 +151,6 @@ class FashionMNIST(MNIST):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "fashionmnist",
-            type=DatasetType.RAW,
             categories=(
                 "T-shirt/top",
                 "Trouser",
@@ -204,7 +182,6 @@ class KMNIST(MNIST):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "kmnist",
-            type=DatasetType.RAW,
             categories=["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"],
             homepage="http://codh.rois.ac.jp/kmnist/index.html.en",
             valid_options=dict(
@@ -225,7 +202,6 @@ class EMNIST(_MNISTBase):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "emnist",
-            type=DatasetType.RAW,
             categories=list(string.digits + string.ascii_uppercase + string.ascii_lowercase),
             homepage="https://www.westernsydney.edu.au/icns/reproducible_research/publication_support_materials/emnist",
             valid_options=dict(
@@ -280,13 +256,7 @@ def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) ->
         46: 9,
     }
 
-    def _collate_and_decode(
-        self,
-        data: Tuple[torch.Tensor, torch.Tensor],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]:
         # In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper).
         # That means for example that there is 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense,
         # i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create these gaps. For example,
@@ -299,14 +269,10 @@ def _collate_and_decode(
         image, label = data
         label += self._LABEL_OFFSETS.get(int(label), 0)
         data = (image, label)
-        return super()._collate_and_decode(data, config=config, decoder=decoder)
+        return super()._prepare_sample(data, config=config)
 
     def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
     ) -> IterDataPipe[Dict[str, Any]]:
         archive_dp = resource_dps[0]
         images_dp, labels_dp = Demultiplexer(
@@ -316,14 +282,13 @@ def _make_datapipe(
             drop_none=True,
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-        return super()._make_datapipe([images_dp, labels_dp], config=config, decoder=decoder)
+        return super()._make_datapipe([images_dp, labels_dp], config=config)
 
 
 class QMNIST(_MNISTBase):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "qmnist",
-            type=DatasetType.RAW,
             categories=10,
             homepage="https://github.com/facebookresearch/qmnist",
             valid_options=dict(
@@ -365,16 +330,10 @@ def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional
 
         return start, stop
 
-    def _collate_and_decode(
-        self,
-        data: Tuple[torch.Tensor, torch.Tensor],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]:
         image, ann = data
         label, *extra_anns = ann
-        sample = super()._collate_and_decode((image, label), config=config, decoder=decoder)
+        sample = super()._prepare_sample((image, label), config=config)
 
         sample.update(
             dict(
diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py
index 971644dcecd..6d2ae4eac25 100644
--- a/torchvision/prototype/datasets/_builtin/sbd.py
+++ b/torchvision/prototype/datasets/_builtin/sbd.py
@@ -1,10 +1,8 @@
-import io
 import pathlib
 import re
-from typing import Any, Callable, Dict, List, Optional, Tuple, cast
+from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO
 
 import numpy as np
-import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
@@ -20,7 +18,8 @@
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DatasetType,
+    DecodeableStreamWrapper,
+    DecodeableImageStreamWrapper,
 )
 from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
@@ -29,19 +28,17 @@
     path_accessor,
     path_comparator,
 )
+from torchvision.prototype.features import Feature
 
 
 class SBD(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "sbd",
-            type=DatasetType.IMAGE,
             dependencies=("scipy",),
             homepage="http://home.bharathh.info/pubs/codes/SBD/download.html",
             valid_options=dict(
                 split=("train", "val", "train_noval"),
-                boundaries=(True, False),
-                segmentation=(False, True),
             ),
         )
 
@@ -72,50 +69,25 @@ def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
         else:
             return None
 
-    def _decode_ann(
-        self, data: Dict[str, Any], *, decode_boundaries: bool, decode_segmentation: bool
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
-        raw_anns = data["GTcls"][0]
-        raw_boundaries = raw_anns["Boundaries"][0]
-        raw_segmentation = raw_anns["Segmentation"][0]
-
-        # the boundaries are stored in sparse CSC format, which is not supported by PyTorch
-        boundaries = (
-            torch.as_tensor(np.stack([raw_boundary[0].toarray() for raw_boundary in raw_boundaries]))
-            if decode_boundaries
-            else None
+    def _decode_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
+        raw_anns = read_mat(buffer)["GTcls"][0]
+        return dict(
+            # the boundaries are stored in sparse CSC format, which is not supported by PyTorch
+            boundaries=Feature(np.stack([raw_boundary[0].toarray() for raw_boundary in raw_anns["Boundaries"][0]])),
+            segmentation=Feature(raw_anns["Segmentation"][0]),
         )
-        segmentation = torch.as_tensor(raw_segmentation) if decode_segmentation else None
-
-        return boundaries, segmentation
 
-    def _collate_and_decode_sample(
-        self,
-        data: Tuple[Tuple[Any, Tuple[str, io.IOBase]], Tuple[str, io.IOBase]],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[str, BinaryIO]]) -> Dict[str, Any]:
         split_and_image_data, ann_data = data
         _, image_data = split_and_image_data
         image_path, image_buffer = image_data
         ann_path, ann_buffer = ann_data
 
-        image = decoder(image_buffer) if decoder else image_buffer
-
-        if config.boundaries or config.segmentation:
-            boundaries, segmentation = self._decode_ann(
-                read_mat(ann_buffer), decode_boundaries=config.boundaries, decode_segmentation=config.segmentation
-            )
-        else:
-            boundaries = segmentation = None
-
         return dict(
             image_path=image_path,
-            image=image,
+            image=DecodeableImageStreamWrapper(image_buffer),
             ann_path=ann_path,
-            boundaries=boundaries,
-            segmentation=segmentation,
+            ann=DecodeableStreamWrapper(ann_buffer, self._decode_ann),
         )
 
     def _make_datapipe(
@@ -123,7 +95,6 @@ def _make_datapipe(
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         archive_dp, extra_split_dp = resource_dps
 
@@ -150,7 +121,7 @@ def _make_datapipe(
             ref_key_fn=path_accessor("stem"),
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(config=config, decoder=decoder))
+        return Mapper(dp, self._prepare_sample)
 
     def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]:
         dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
diff --git a/torchvision/prototype/datasets/_builtin/semeion.py b/torchvision/prototype/datasets/_builtin/semeion.py
index 93280f1e0b7..d5468c8d401 100644
--- a/torchvision/prototype/datasets/_builtin/semeion.py
+++ b/torchvision/prototype/datasets/_builtin/semeion.py
@@ -1,5 +1,4 @@
-import io
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Tuple
 
 import torch
 from torchdata.datapipes.iter import (
@@ -8,23 +7,21 @@
     Shuffler,
     CSVParser,
 )
-from torchvision.prototype.datasets.decoder import raw
 from torchvision.prototype.datasets.utils import (
     Dataset,
     DatasetConfig,
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DatasetType,
 )
-from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, image_buffer_from_array
+from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
+from torchvision.prototype.features import Image, Label
 
 
 class SEMEION(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "semeion",
-            type=DatasetType.RAW,
             categories=10,
             homepage="https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit",
         )
@@ -36,34 +33,23 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
         )
         return [data]
 
-    def _collate_and_decode_sample(
-        self,
-        data: Tuple[str, ...],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[str, ...]) -> Dict[str, Any]:
         image_data = torch.tensor([float(pixel) for pixel in data[:256]], dtype=torch.uint8).reshape(16, 16)
         label_data = [int(label) for label in data[256:] if label]
 
-        if decoder is raw:
-            image = image_data.unsqueeze(0)
-        else:
-            image_buffer = image_buffer_from_array(image_data.numpy())
-            image = decoder(image_buffer) if decoder else image_buffer  # type: ignore[assignment]
-
-        label = next((idx for idx, one_hot_label in enumerate(label_data) if one_hot_label))
-        category = self.info.categories[label]
-        return dict(image=image, label=label, category=category)
+        label_idx = next((idx for idx, one_hot_label in enumerate(label_data) if one_hot_label))
+        return dict(
+            image=Image(image_data.unsqueeze(0)), label=Label(label_idx, category=self.info.categories[label_idx])
+        )
 
     def _make_datapipe(
         self,
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = CSVParser(dp, delimiter=" ")
         dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
-        dp = Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
+        dp = Mapper(dp, self._prepare_sample)
         return dp
diff --git a/torchvision/prototype/datasets/_builtin/voc.categories b/torchvision/prototype/datasets/_builtin/voc.categories
new file mode 100644
index 00000000000..8420ab35ede
--- /dev/null
+++ b/torchvision/prototype/datasets/_builtin/voc.categories
@@ -0,0 +1,20 @@
+aeroplane
+bicycle
+bird
+boat
+bottle
+bus
+car
+cat
+chair
+cow
+diningtable
+dog
+horse
+motorbike
+person
+pottedplant
+sheep
+sofa
+train
+tvmonitor
diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py
index 9e0476f081c..953af3b4b26 100644
--- a/torchvision/prototype/datasets/_builtin/voc.py
+++ b/torchvision/prototype/datasets/_builtin/voc.py
@@ -1,10 +1,8 @@
 import functools
-import io
 import pathlib
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, BinaryIO, cast
 from xml.etree import ElementTree
 
-import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
@@ -21,7 +19,8 @@
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DatasetType,
+    DecodeableImageStreamWrapper,
+    DecodeableStreamWrapper,
 )
 from torchvision.prototype.datasets.utils._internal import (
     path_accessor,
@@ -29,15 +28,13 @@
     INFINITE_BUFFER_SIZE,
     path_comparator,
 )
-
-HERE = pathlib.Path(__file__).parent
+from torchvision.prototype.features import BoundingBox, Label
 
 
 class VOC(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "voc",
-            type=DatasetType.IMAGE,
             homepage="http://host.robots.ox.ac.uk/pascal/VOC/",
             valid_options=dict(
                 split=("train", "val", "test"),
@@ -82,40 +79,51 @@ def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) ->
         else:
             return None
 
-    def _decode_detection_ann(self, buffer: io.IOBase) -> torch.Tensor:
-        result = VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())  # type: ignore[arg-type]
-        objects = result["annotation"]["object"]
-        bboxes = [obj["bndbox"] for obj in objects]
-        bboxes = [[int(bbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")] for bbox in bboxes]
-        return torch.tensor(bboxes)
+    def _parse_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
+        return cast(Dict[str, Any], VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())["annotation"])
+
+    def _decode_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
+        anns = self._parse_detection_ann(buffer)
+        instances = anns["object"]
+        return dict(
+            bounding_boxes=BoundingBox(
+                [
+                    [int(instance["bndbox"][part]) for part in ("xmin", "ymin", "xmax", "ymax")]
+                    for instance in instances
+                ],
+                format="xyxy",
+                image_size=tuple(int(anns["size"][dim]) for dim in ("height", "width")),
+            ),
+            labels=[
+                Label(self.info.categories.index(instance["name"]), category=instance["name"]) for instance in instances
+            ],
+        )
 
-    def _collate_and_decode_sample(
+    def _prepare_sample(
         self,
-        data: Tuple[Tuple[Tuple[str, str], Tuple[str, io.IOBase]], Tuple[str, io.IOBase]],
+        data: Tuple[Tuple[Tuple[str, str], Tuple[str, BinaryIO]], Tuple[str, BinaryIO]],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> Dict[str, Any]:
         split_and_image_data, ann_data = data
         _, image_data = split_and_image_data
         image_path, image_buffer = image_data
         ann_path, ann_buffer = ann_data
 
-        image = decoder(image_buffer) if decoder else image_buffer
-
-        if config.task == "detection":
-            ann = self._decode_detection_ann(ann_buffer)
-        else:  # config.task == "segmentation":
-            ann = decoder(ann_buffer) if decoder else ann_buffer  # type: ignore[assignment]
-
-        return dict(image_path=image_path, image=image, ann_path=ann_path, ann=ann)
+        return dict(
+            image_path=image_path,
+            image=DecodeableImageStreamWrapper(image_buffer),
+            ann_path=ann_path,
+            ann=DecodeableStreamWrapper(ann_buffer, self._decode_detection_ann)
+            if config.task == "detection"
+            else DecodeableImageStreamWrapper(ann_buffer),
+        )
 
     def _make_datapipe(
         self,
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         archive_dp = resource_dps[0]
         split_dp, images_dp, anns_dp = Demultiplexer(
@@ -140,4 +148,17 @@ def _make_datapipe(
             ref_key_fn=path_accessor("stem"),
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(config=config, decoder=decoder))
+        return Mapper(dp, self._prepare_sample, fn_kwargs=dict(config=config))
+
+    def _filter_detection_anns(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool:
+        return self._classify_archive(data, config=config) == 2
+
+    def _generate_categories(self, root: pathlib.Path) -> List[str]:
+        config = self.info.make_config(task="detection")
+
+        resource = self.resources(config)[0]
+        dp = resource.load(pathlib.Path(root) / self.name)
+        dp = Filter(dp, self._filter_detection_anns, fn_kwargs=dict(config=config))
+        dp = Mapper(dp, self._parse_detection_ann, input_col=1)
+
+        return sorted({instance["name"] for _, anns in dp for instance in anns["object"]})
diff --git a/torchvision/prototype/datasets/_folder.py b/torchvision/prototype/datasets/_folder.py
index 1411ee5895e..81766718c60 100644
--- a/torchvision/prototype/datasets/_folder.py
+++ b/torchvision/prototype/datasets/_folder.py
@@ -1,48 +1,46 @@
-import io
 import os
 import os.path
 import pathlib
-from typing import Callable, Optional, Collection
-from typing import Union, Tuple, List, Dict, Any
+from typing import BinaryIO, Callable, Optional, Collection, Union, Tuple, List, Dict, Any
 
-import torch
-from torch.utils.data import IterDataPipe
-from torch.utils.data.datapipes.iter import FileLister, FileLoader, Mapper, Shuffler, Filter
-from torchvision.prototype.datasets.decoder import pil
+from torchdata.datapipes.iter import IterDataPipe, FileLister, FileLoader, Mapper, Shuffler, Filter
+from torchvision.prototype.datasets.utils import DecodeableStreamWrapper, decode_image_with_pil
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
+from torchvision.prototype.features import Label
 
 __all__ = ["from_data_folder", "from_image_folder"]
 
 
+def _read_bytes(buffer: BinaryIO) -> Dict[str, Any]:
+    return dict(data=buffer.read())
+
+
 def _is_not_top_level_file(path: str, *, root: pathlib.Path) -> bool:
     rel_path = pathlib.Path(path).relative_to(root)
     return rel_path.is_dir() or rel_path.parent != pathlib.Path(".")
 
 
-def _collate_and_decode_data(
-    data: Tuple[str, io.IOBase],
+def _prepare_sample(
+    data: Tuple[str, BinaryIO],
     *,
     root: pathlib.Path,
     categories: List[str],
-    decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    decoder: Callable[[BinaryIO], Dict[str, Any]],
 ) -> Dict[str, Any]:
     path, buffer = data
-    data = decoder(buffer) if decoder else buffer
     category = pathlib.Path(path).relative_to(root).parts[0]
-    label = torch.tensor(categories.index(category))
     return dict(
         path=path,
-        data=data,
-        label=label,
-        category=category,
+        data=DecodeableStreamWrapper(buffer, decoder),
+        label=Label(categories.index(category), category=category),
    )
 
 
 def from_data_folder(
     root: Union[str, pathlib.Path],
     *,
-    decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None,
+    decoder: Callable[[BinaryIO], Dict[str, Any]] = _read_bytes,
     valid_extensions: Optional[Collection[str]] = None,
     recursive: bool = True,
 ) -> Tuple[IterDataPipe, List[str]]:
@@ -54,7 +52,7 @@ def from_data_folder(
     dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
     dp = FileLoader(dp)
     return (
-        Mapper(dp, _collate_and_decode_data, fn_kwargs=dict(root=root, categories=categories, decoder=decoder)),
+        Mapper(dp, _prepare_sample, fn_kwargs=dict(root=root, categories=categories, decoder=decoder)),
         categories,
     )
 
@@ -67,7 +65,7 @@ def _data_to_image_key(sample: Dict[str, Any]) -> Dict[str, Any]:
 def from_image_folder(
     root: Union[str, pathlib.Path],
     *,
-    decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = pil,
+    decoder: Callable[[BinaryIO], Dict[str, Any]] = decode_image_with_pil,
     valid_extensions: Collection[str] = ("jpg", "jpeg", "png", "ppm", "bmp", "pgm", "tif", "tiff", "webp"),
     **kwargs: Any,
 ) -> Tuple[IterDataPipe, List[str]]:
diff --git a/torchvision/prototype/datasets/decoder.py b/torchvision/prototype/datasets/decoder.py
deleted file mode 100644
index 530a357f239..00000000000
--- a/torchvision/prototype/datasets/decoder.py
+++ /dev/null
@@ -1,16 +0,0 @@
-import io
-
-import PIL.Image
-import torch
-from torchvision.prototype import features
-from torchvision.transforms.functional import pil_to_tensor
-
-__all__ = ["raw", "pil"]
-
-
-def raw(buffer: io.IOBase) -> torch.Tensor:
-    raise RuntimeError("This is just a sentinel and should never be called.")
-
-
-def pil(buffer: io.IOBase) -> features.Image:
-    return features.Image(pil_to_tensor(PIL.Image.open(buffer)))
diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py
index 92bcffc0cdb..da3c1ab142b 100644
--- a/torchvision/prototype/datasets/utils/__init__.py
+++ b/torchvision/prototype/datasets/utils/__init__.py
@@ -1,4 +1,11 @@
 from . import _internal
-from ._dataset import DatasetType, DatasetConfig, DatasetInfo, Dataset
+from ._dataset import DatasetConfig, DatasetInfo, Dataset
+from ._decoder import (
+    DecodeableStreamWrapper,
+    DecodeableImageStreamWrapper,
+    decode_sample,
+    SampleDecoder,
+    decode_image_with_pil,
+)
 from ._query import SampleQuery
 from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource
diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py
index 91aaf0af3fe..3a624daa813 100644
--- a/torchvision/prototype/datasets/utils/_dataset.py
+++ b/torchvision/prototype/datasets/utils/_dataset.py
@@ -1,14 +1,11 @@
 import abc
 import csv
-import enum
 import importlib
-import io
 import itertools
 import os
 import pathlib
-from typing import Any, Callable, Dict, List, Optional, Sequence, Union, Tuple
+from typing import Any, Dict, List, Optional, Sequence, Union, Tuple
 
-import torch
 from torch.utils.data import IterDataPipe
 from torchvision.prototype.utils._internal import FrozenBunch, make_repr
 from torchvision.prototype.utils._internal import add_suggestion, sequence_to_str
@@ -18,11 +15,6 @@
 from ._resource import OnlineResource
 
 
-class DatasetType(enum.Enum):
-    RAW = enum.auto()
-    IMAGE = enum.auto()
-
-
 class DatasetConfig(FrozenBunch):
     pass
 
@@ -32,7 +24,6 @@ def __init__(
         self,
         name: str,
         *,
-        type: Union[str, DatasetType],
         dependencies: Sequence[str] = (),
         categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None,
         citation: Optional[str] = None,
@@ -42,7 +33,6 @@ def __init__(
         extra: Optional[Dict[str, Any]] = None,
     ) -> None:
         self.name = name.lower()
-        self.type = DatasetType[type.upper()] if isinstance(type, str) else type
 
         self.dependecies = dependencies
@@ -165,7 +155,6 @@ def _make_datapipe(
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         pass
 
@@ -177,7 +166,6 @@ def load(
         root: Union[str, pathlib.Path],
         *,
         config: Optional[DatasetConfig] = None,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None,
         skip_integrity_check: bool = False,
     ) -> IterDataPipe[Dict[str, Any]]:
         if not config:
@@ -192,7 +180,7 @@ def load(
         resource_dps = [
             resource.load(root, skip_integrity_check=skip_integrity_check) for resource in self.resources(config)
         ]
-        return self._make_datapipe(resource_dps, config=config, decoder=decoder)
+        return self._make_datapipe(resource_dps, config=config)
 
     def _generate_categories(self, root: pathlib.Path) -> Sequence[Union[str, Sequence[str]]]:
         raise NotImplementedError
diff --git a/torchvision/prototype/datasets/utils/_decoder.py b/torchvision/prototype/datasets/utils/_decoder.py
new file mode 100644
index 00000000000..257583afe09
--- /dev/null
+++ b/torchvision/prototype/datasets/utils/_decoder.py
@@ -0,0 +1,47 @@
+from typing import Any, Dict, Callable
+from typing import BinaryIO
+
+import PIL.Image
+from torchdata.datapipes import functional_datapipe
+from torchdata.datapipes.iter import Mapper, IterDataPipe
+from torchvision.prototype import features
+from torchvision.transforms.functional import pil_to_tensor
+
+
+def decode_image_with_pil(buffer: BinaryIO) -> Dict[str, Any]:
+    return dict(image=features.Image(pil_to_tensor(PIL.Image.open(buffer))))
+
+
+class DecodeableStreamWrapper:
+    def __init__(self, stream: BinaryIO, decoder: Callable[[BinaryIO], Dict[str, Any]]) -> None:
+        self.__stream__ = stream
+        self.__decoder__ = decoder
+
+    # TODO: dispatch attribute access besides `decode` to `__stream__`
+
+    def decode(self) -> Dict[str, Any]:
+        return self.__decoder__(self.__stream__)
+
+    def unwrap(self) -> BinaryIO:
+        return self.__stream__
+
+
+class DecodeableImageStreamWrapper(DecodeableStreamWrapper):
+    def __init__(self, stream: BinaryIO, decoder: Callable[[BinaryIO], Dict[str, Any]] = decode_image_with_pil) -> None:
+        super().__init__(stream, decoder)
+
+
+def decode_sample(sample: Dict[str, Any]) -> Dict[str, Any]:
+    decoded_sample = dict()
+    for name, obj in sample.items():
+        if isinstance(obj, DecodeableStreamWrapper):
+            decoded_sample.update(obj.decode())
+        else:
+            decoded_sample[name] = obj
+    return decoded_sample
+
+
+@functional_datapipe("decode_samples")
+class SampleDecoder(Mapper[Dict[str, Any]]):
+    def __init__(self, datapipe: IterDataPipe[Dict[str, Any]]) -> None:
+        super().__init__(datapipe, decode_sample)
diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py
index c4b91b4a14b..17df8181762 100644
--- a/torchvision/prototype/datasets/utils/_internal.py
+++ b/torchvision/prototype/datasets/utils/_internal.py
@@ -26,7 +26,6 @@
 from typing import cast
 
 import numpy as np
-import PIL.Image
 import torch
 import torch.distributed as dist
 import torch.utils.data
@@ -39,7 +38,6 @@
     "INFINITE_BUFFER_SIZE",
     "BUILTIN_DIR",
     "read_mat",
-    "image_buffer_from_array",
     "SequenceIterator",
     "MappingIterator",
     "Enumerator",
@@ -60,7 +58,7 @@
 BUILTIN_DIR = pathlib.Path(__file__).parent.parent / "_builtin"
 
 
-def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any:
+def read_mat(buffer: BinaryIO, **kwargs: Any) -> Any:
     try:
         import scipy.io as sio
     except ImportError as error:
@@ -72,14 +70,6 @@ def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any:
     return sio.loadmat(buffer, **kwargs)
 
 
-def image_buffer_from_array(array: np.ndarray, *, format: str = "png") -> io.BytesIO:
-    image = PIL.Image.fromarray(array)
-    buffer = io.BytesIO()
-    image.save(buffer, format=format)
-    buffer.seek(0)
-    return buffer
-
-
 class SequenceIterator(IterDataPipe[D]):
     def __init__(self, datapipe: IterDataPipe[Sequence[D]]):
         self.datapipe = datapipe
@@ -146,17 +136,17 @@ class CompressionType(enum.Enum):
     LZMA = "lzma"
 
 
-class Decompressor(IterDataPipe[Tuple[str, io.IOBase]]):
+class Decompressor(IterDataPipe[Tuple[str, BinaryIO]]):
     types = CompressionType
 
-    _DECOMPRESSORS = {
-        types.GZIP: lambda file: gzip.GzipFile(fileobj=file),
-        types.LZMA: lambda file: lzma.LZMAFile(file),
+    _DECOMPRESSORS: Dict[CompressionType, Callable[[BinaryIO], BinaryIO]] = {
+        types.GZIP: lambda file: cast(BinaryIO, gzip.GzipFile(fileobj=file)),
+        types.LZMA: lambda file: cast(BinaryIO, lzma.LZMAFile(file)),
     }
 
     def __init__(
         self,
-        datapipe: IterDataPipe[Tuple[str, io.IOBase]],
+        datapipe: IterDataPipe[Tuple[str, BinaryIO]],
         *,
         type: Optional[Union[str, CompressionType]] = None,
     ) -> None:
@@ -178,7 +168,7 @@ def _detect_compression_type(self, path: str) -> CompressionType:
         else:
             raise RuntimeError("FIXME")
 
-    def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]:
+    def __iter__(self) -> Iterator[Tuple[str, BinaryIO]]:
         for path, file in self.datapipe:
             type = self._detect_compression_type(path)
             decompressor = self._DECOMPRESSORS[type]

From 8fca3a46fbc810712268ae8adb8bde0e55fa8ca8 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 16 Dec 2021 09:27:39 +0100
Subject: [PATCH 2/8] cleanup

---
 torchvision/prototype/datasets/_builtin/caltech.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py
index 95ee3bef54a..3e0f9ac48c1 100644
--- a/torchvision/prototype/datasets/_builtin/caltech.py
+++ b/torchvision/prototype/datasets/_builtin/caltech.py
@@ -95,7 +95,6 @@ def _prepare_sample(
         ann_path, ann_buffer = ann_data
 
         return dict(
-            category=category,
             label=Label(self.info.categories.index(category), category=category),
             image_path=image_path,
             image=DecodeableImageStreamWrapper(image_buffer),

From d4507a883846cf297322df59bbdc0654794af961 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 21 Dec 2021 14:31:29 +0100
Subject: [PATCH 3/8] refactor to use tensors as base for undecoded data

---
 .../prototype/datasets/_builtin/caltech.py  |  9 +-
 .../prototype/datasets/_builtin/celeba.py   |  4 +-
 .../prototype/datasets/_builtin/cifar.py    |  5 +-
 .../prototype/datasets/_builtin/coco.py     |  7 +-
 .../prototype/datasets/_builtin/imagenet.py |  7 +-
 .../prototype/datasets/_builtin/sbd.py      |  7 +-
 .../prototype/datasets/_builtin/semeion.py  |  6 +-
 .../prototype/datasets/_builtin/voc.py      | 11 ++-
 .../prototype/datasets/utils/__init__.py    | 12 ++-
 .../prototype/datasets/utils/_decoder.py    | 85 ++++++++++++-------
 .../prototype/datasets/utils/_internal.py   | 32 ++++++-
 11 files changed, 128 insertions(+), 57 deletions(-)

diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py
index 3e0f9ac48c1..14dabeb738a 100644
--- a/torchvision/prototype/datasets/_builtin/caltech.py
+++ b/torchvision/prototype/datasets/_builtin/caltech.py
@@ -16,8 +16,7 @@
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DecodeableImageStreamWrapper,
-    DecodeableStreamWrapper,
+    RawImage,
 )
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat
 from torchvision.prototype.features import Label, BoundingBox, Feature
@@ -95,11 +94,11 @@ def _prepare_sample(
         ann_path, ann_buffer = ann_data
 
         return dict(
+            self._decode_ann(ann_buffer),
             label=Label(self.info.categories.index(category), category=category),
             image_path=image_path,
-            image=DecodeableImageStreamWrapper(image_buffer),
+            image=RawImage.fromfile(image_buffer),
             ann_path=ann_path,
-            ann=DecodeableStreamWrapper(ann_buffer, self._decode_ann),
         )
 
     def _make_datapipe(
@@ -157,7 +156,7 @@ def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
         label_str, category = dir_name.split(".")
         return dict(
             path=path,
-            image=DecodeableImageStreamWrapper(buffer),
+            image=RawImage.fromfile(buffer),
             label=Label(int(label_str), category=category),
         )
diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py
index e59e84a8b47..49c6948b004 100644
--- a/torchvision/prototype/datasets/_builtin/celeba.py
+++ b/torchvision/prototype/datasets/_builtin/celeba.py
@@ -15,7 +15,7 @@
     DatasetInfo,
     GDriveResource,
     OnlineResource,
-    DecodeableImageStreamWrapper,
+    RawImage,
 )
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, getitem, path_accessor
 from torchvision.prototype.features import BoundingBox, Feature, Label
@@ -132,7 +132,7 @@ def _collate_and_decode_sample(
         return dict(
             anns,
             path=path,
-            image=DecodeableImageStreamWrapper(buffer),
+            image=RawImage.fromfile(buffer),
         )
 
     def _make_datapipe(
diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py
index b221550e1b3..744c0264067 100644
--- a/torchvision/prototype/datasets/_builtin/cifar.py
+++ b/torchvision/prototype/datasets/_builtin/cifar.py
@@ -53,7 +53,10 @@ def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
 
     def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]:
         image_array, category_idx = data
-        return dict(image=Image(image_array), label=Label(category_idx, category=self.categories[category_idx]))
+        return dict(
+            image=Image(image_array),
+            label=Label(category_idx, category=self.categories[category_idx]),
+        )
 
     def _make_datapipe(
         self,
diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py
index 89811f7660d..43d427f1e92 100644
--- a/torchvision/prototype/datasets/_builtin/coco.py
+++ b/torchvision/prototype/datasets/_builtin/coco.py
@@ -21,7 +21,7 @@
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DecodeableImageStreamWrapper,
+    RawImage,
 )
 from torchvision.prototype.datasets.utils._internal import (
     MappingIterator,
@@ -148,7 +148,10 @@ def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]:
 
     def _prepare_image(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
         path, buffer = data
-        return dict(path=path, image=DecodeableImageStreamWrapper(buffer))
+        return dict(
+            path=path,
+            image=RawImage.fromfile(buffer),
+        )
 
     def _prepare_sample(
         self,
diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py
index 91f7588f81c..be2c9dd68e4 100644
--- a/torchvision/prototype/datasets/_builtin/imagenet.py
+++ b/torchvision/prototype/datasets/_builtin/imagenet.py
@@ -9,7 +9,7 @@
     DatasetInfo,
     OnlineResource,
     ManualDownloadResource,
-    DecodeableImageStreamWrapper,
+    RawImage,
 )
 from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
@@ -113,7 +113,10 @@ def _prepare_sample(
     ) -> Dict[str, Any]:
         label_data, (path, buffer) = data
 
-        sample = dict(path=path, image=DecodeableImageStreamWrapper(buffer))
+        sample = dict(
+            path=path,
+            image=RawImage.fromfile(buffer),
+        )
         if label_data:
             sample.update(dict(zip(("label", "category", "wnid"), label_data)))
         return sample
diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py
index 6d2ae4eac25..75740a41bc6 100644
--- a/torchvision/prototype/datasets/_builtin/sbd.py
+++ b/torchvision/prototype/datasets/_builtin/sbd.py
@@ -18,8 +18,7 @@
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DecodeableStreamWrapper,
-    DecodeableImageStreamWrapper,
+    RawImage,
 )
 from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
@@ -84,10 +83,10 @@ def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[st
         ann_path, ann_buffer = ann_data
 
         return dict(
+            self._decode_ann(ann_buffer),
             image_path=image_path,
-            image=DecodeableImageStreamWrapper(image_buffer),
+            image=RawImage.fromfile(image_buffer),
             ann_path=ann_path,
-            ann=DecodeableStreamWrapper(ann_buffer, self._decode_ann),
         )
diff --git a/torchvision/prototype/datasets/_builtin/semeion.py b/torchvision/prototype/datasets/_builtin/semeion.py
index d5468c8d401..696e44e4888 100644
--- a/torchvision/prototype/datasets/_builtin/semeion.py
+++ b/torchvision/prototype/datasets/_builtin/semeion.py
@@ -39,7 +39,11 @@ def _prepare_sample(self, data: Tuple[str, ...]) -> Dict[str, Any]:
 
         label_idx = next((idx for idx, one_hot_label in enumerate(label_data) if one_hot_label))
         return dict(
-            image=Image(image_data.unsqueeze(0)), label=Label(label_idx, category=self.info.categories[label_idx])
+            image=Image(image_data.unsqueeze(0)),
+            label=Label(
+                label_idx,
+                category=self.info.categories[label_idx],
+            ),
         )
 
     def _make_datapipe(
diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py
index 953af3b4b26..bb23e1605c3 100644
--- a/torchvision/prototype/datasets/_builtin/voc.py
+++ b/torchvision/prototype/datasets/_builtin/voc.py
@@ -19,8 +19,7 @@
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DecodeableImageStreamWrapper,
-    DecodeableStreamWrapper,
+    RawImage,
 )
 from torchvision.prototype.datasets.utils._internal import (
     path_accessor,
@@ -111,12 +110,12 @@ def _prepare_sample(
         ann_path, ann_buffer = ann_data
 
         return dict(
+            self._decode_detection_ann(ann_buffer)
+            if config.task == "detection"
+            else dict(segmentation=RawImage.fromfile(ann_buffer)),
             image_path=image_path,
-            image=DecodeableImageStreamWrapper(image_buffer),
+            image=RawImage.fromfile(image_buffer),
             ann_path=ann_path,
-            ann=DecodeableStreamWrapper(ann_buffer, self._decode_detection_ann)
-            if config.task == "detection"
-            else DecodeableImageStreamWrapper(ann_buffer),
         )
diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py
index da3c1ab142b..57b957cf289 100644
--- a/torchvision/prototype/datasets/utils/__init__.py
+++ b/torchvision/prototype/datasets/utils/__init__.py
@@ -1,11 +1,15 @@
-from . import _internal
+from . import _internal  # usort: skip
 from ._dataset import DatasetConfig, DatasetInfo, Dataset
 from ._decoder import (
-    DecodeableStreamWrapper,
-    DecodeableImageStreamWrapper,
+    decode_images,
     decode_sample,
-    SampleDecoder,
     decode_image_with_pil,
+    RawImage,
+    RawData,
+    ReadOnlyTensorBuffer,
 )
 from ._query import SampleQuery
 from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource
+
+DecodeableImageStreamWrapper = None
+DecodeableStreamWrapper = None
diff --git a/torchvision/prototype/datasets/utils/_decoder.py b/torchvision/prototype/datasets/utils/_decoder.py
index 257583afe09..373a1787627 100644
--- a/torchvision/prototype/datasets/utils/_decoder.py
+++ b/torchvision/prototype/datasets/utils/_decoder.py
@@ -1,47 +1,74 @@
+import collections.abc
+import sys
 from typing import Any, Dict, Callable
+from typing import Type, TypeVar, cast, BinaryIO
 
 import PIL.Image
+import torch
+from torch._C import _TensorBase
 from torchvision.prototype import features
 from torchvision.transforms.functional import pil_to_tensor
 
+from ._internal import ReadOnlyTensorBuffer, fromfile
+
+
+class RawData(torch.Tensor):
+    def __new__(cls, data: torch.Tensor) -> "RawData":
+        # TODO: warn / bail out if we encounter a tensor with shape other than (N,) or with dtype other than uint8?
+        return torch.Tensor._make_subclass(
+            cast(_TensorBase, cls),
+            data,
+            False,  # requires_grad
+        )
+
+    @classmethod
+    def fromfile(cls, file: BinaryIO):
+        return cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder))
+
+
+class RawImage(RawData):
+    pass
+
+
+def decode_image_with_pil(raw_image: RawImage) -> Dict[str, Any]:
+    return dict(image=features.Image(pil_to_tensor(PIL.Image.open(ReadOnlyTensorBuffer(raw_image)))))
+
+
+D = TypeVar("D", bound=RawData)
+
+
+def decode_sample(
+    sample: Any, *, decoder_map: Dict[Type[D], Callable[[D], Dict[str, Any]]], inline_decoded: bool = True
+) -> Any:
+    # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
+    # "a" == "a"[0][0]...
+ if isinstance(sample, collections.abc.Sequence) and not isinstance(sample, str): + return [decode_sample(item, decoder_map=decoder_map, inline_decoded=inline_decoded) for item in sample] + elif isinstance(sample, collections.abc.Mapping): + decoded_sample = {} + for name, item in sample.items(): + decoded_item = decode_sample(item, decoder_map=decoder_map, inline_decoded=inline_decoded) + if inline_decoded and isinstance(item, RawData): + decoded_sample.update(decoded_item) + else: + decoded_sample[name] = decoded_item + return decoded_sample + else: + sample_type = type(sample) + if not issubclass(sample_type, RawData): + return sample -@functional_datapipe("decode_samples") -class SampleDecoder(Mapper[Dict[str, Any]]): - def __init__(self, datapipe: IterDataPipe[Dict[str, Any]]) -> None: - super().__init__(datapipe, decode_sample) + try: + return decoder_map[sample_type](cast(D, sample)) + except KeyError as error: + raise TypeError(f"Unknown type {sample_type}") from error + + +def decode_images(sample: Any, *, inline_decoded=True) -> Any: + return decode_sample( + sample, + decoder_map={ + RawImage: decode_image_with_pil, + }, + inline_decoded=inline_decoded, + ) diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py index 17df8181762..f888e053eae 100644 --- a/torchvision/prototype/datasets/utils/_internal.py +++ b/torchvision/prototype/datasets/utils/_internal.py @@ -47,6 +47,7 @@ "Decompressor", "fromfile", "read_flo", + "ReadOnlyTensorBuffer", ] K = TypeVar("K") @@ -301,7 +302,7 @@ def fromfile( buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :] # Reading from the memoryview does not advance the file cursor, so we have to do it manually. file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR)) - except (PermissionError, io.UnsupportedOperation): + except (AttributeError, PermissionError, io.UnsupportedOperation): buffer = _read_mutable_buffer_fallback(file, count, item_size) else: # On Windows just trying to call mmap.mmap() on a file that does not support it, may corrupt the internal state @@ -321,3 +322,32 @@ def read_flo(file: BinaryIO) -> torch.Tensor: width, height = fromfile(file, dtype=torch.int32, byte_order="little", count=2) flow = fromfile(file, dtype=torch.float32, byte_order="little", count=height * width * 2) return flow.reshape((height, width, 2)).permute((2, 0, 1)) + + +class ReadOnlyTensorBuffer: + def __init__(self, tensor: torch.Tensor) -> None: + self._memory = memoryview(tensor.numpy()) + self._cursor: int = 0 + + def tell(self) -> int: + return self._cursor + + def seek(self, offset: int, whence: int = io.SEEK_SET) -> int: + if whence == io.SEEK_SET: + self._cursor = offset + elif whence == io.SEEK_CUR: + self._cursor += offset + pass + elif whence == io.SEEK_END: + self._cursor = len(self._memory) + offset + else: + raise ValueError( + f"'whence' should be ``{io.SEEK_SET}``, ``{io.SEEK_CUR}``, or ``{io.SEEK_END}``, " + f"but got {repr(whence)} instead" + ) + return self.tell() + + def read(self, size=-1): + cursor = self.tell() + offset, whence = (0, io.SEEK_END) if size == -1 else (size, io.SEEK_CUR) + return self._memory[slice(cursor, self.seek(offset, whence))].tobytes() From 26a25a5304447710edc77033b63d2b744413b4d0 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 21 Dec 2021 15:01:51 +0100 Subject: [PATCH 4/8] cleanup --- .../prototype/datasets/_builtin/celeba.py | 4 ++-- .../prototype/datasets/_builtin/imagenet.py | 12 +++++----- 
torchvision/prototype/datasets/_folder.py | 19 +++++----------- .../prototype/datasets/utils/__init__.py | 3 --- .../prototype/datasets/utils/_decoder.py | 22 ++++++++++--------- .../prototype/datasets/utils/_internal.py | 2 +- 6 files changed, 27 insertions(+), 35 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py index 37fdccedde1..d2f2f0d7d81 100644 --- a/torchvision/prototype/datasets/_builtin/celeba.py +++ b/torchvision/prototype/datasets/_builtin/celeba.py @@ -127,7 +127,7 @@ def _prepare_anns(self, data: Tuple[Tuple[str, Dict[str, str]], ...]) -> Tuple[s }, ) - def _collate_and_decode_sample( + def _prepare_sample( self, data: Tuple[Tuple[str, Tuple[str, List[str]], Tuple[str, BinaryIO]], Tuple[str, Dict[str, Any]]] ) -> Dict[str, Any]: split_and_image_data, ann_data = data @@ -176,4 +176,4 @@ def _make_datapipe( keep_key=True, ) dp = IterKeyZipper(dp, anns_dp, key_fn=getitem(0), buffer_size=INFINITE_BUFFER_SIZE) - return Mapper(dp, self._collate_and_decode_sample) + return Mapper(dp, self._prepare_sample) diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index 2142f4ec4b6..8a78de429f7 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -84,7 +84,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: _TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?Pn\d{8})_\d+[.]JPEG") - def _collate_train_data(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label, str, str], Tuple[str, BinaryIO]]: + def _prepare_train_data(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label, str, str], Tuple[str, BinaryIO]]: path = pathlib.Path(data[0]) wnid = self._TRAIN_IMAGE_NAME_PATTERN.match(path.name).group("wnid") # type: ignore[union-attr] category = self.wnid_to_category[wnid] @@ -97,7 +97,7 @@ def _val_test_image_key(self, data: Tuple[str, Any]) -> int: path = pathlib.Path(data[0]) return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name).group("id")) # type: ignore[union-attr] - def _collate_val_data( + def _prepare_val_data( self, data: Tuple[Tuple[int, int], Tuple[str, BinaryIO]] ) -> Tuple[Tuple[Label, str, str], Tuple[str, BinaryIO]]: label_data, image_data = data @@ -106,7 +106,7 @@ def _collate_val_data( wnid = self.category_to_wnid[category] return (Label(label), category, wnid), image_data - def _collate_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]: + def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]: return None, data def _prepare_sample( @@ -133,7 +133,7 @@ def _make_datapipe( dp = TarArchiveReader(images_dp) dp = hint_sharding(dp) dp = hint_shuffling(dp) - dp = Mapper(dp, self._collate_train_data) + dp = Mapper(dp, self._prepare_train_data) elif config.split == "val": devkit_dp = Filter(devkit_dp, path_comparator("name", "ILSVRC2012_validation_ground_truth.txt")) devkit_dp = LineReader(devkit_dp, return_path=False) @@ -149,11 +149,11 @@ def _make_datapipe( ref_key_fn=self._val_test_image_key, buffer_size=INFINITE_BUFFER_SIZE, ) - dp = Mapper(dp, self._collate_val_data) + dp = Mapper(dp, self._prepare_val_data) else: # config.split == "test" dp = hint_sharding(images_dp) dp = hint_shuffling(dp) - dp = Mapper(dp, self._collate_test_data) + dp = Mapper(dp, self._prepare_test_data) return Mapper(dp, self._prepare_sample) # Although the WordNet IDs (wnids) are unique, the 
corresponding categories are not. For example, both n02012849 diff --git a/torchvision/prototype/datasets/_folder.py b/torchvision/prototype/datasets/_folder.py index d7937e9b9af..7c8a22a829b 100644 --- a/torchvision/prototype/datasets/_folder.py +++ b/torchvision/prototype/datasets/_folder.py @@ -2,11 +2,11 @@ import os import os.path import pathlib -from typing import BinaryIO, Callable, Optional, Collection, Union, Tuple, List, Dict, Any +from typing import BinaryIO, Optional, Collection, Union, Tuple, List, Dict, Any from torch.utils.data import IterDataPipe from torch.utils.data.datapipes.iter import FileLister, FileLoader, Mapper, Filter -from torchvision.prototype.datasets.utils import DecodeableStreamWrapper, decode_image_with_pil +from torchvision.prototype.datasets.utils import RawData, RawImage from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.features import Label @@ -14,10 +14,6 @@ __all__ = ["from_data_folder", "from_image_folder"] -def _read_bytes(buffer: BinaryIO) -> Dict[str, Any]: - return dict(data=buffer.read()) - - def _is_not_top_level_file(path: str, *, root: pathlib.Path) -> bool: rel_path = pathlib.Path(path).relative_to(root) return rel_path.is_dir() or rel_path.parent != pathlib.Path(".") @@ -28,13 +24,12 @@ def _prepare_sample( *, root: pathlib.Path, categories: List[str], - decoder: Callable[[BinaryIO], Dict[str, Any]], ) -> Dict[str, Any]: path, buffer = data category = pathlib.Path(path).relative_to(root).parts[0] return dict( path=path, - data=DecodeableStreamWrapper(buffer, decoder), + data=RawData.fromfile(buffer), label=Label(categories.index(category), category=category), ) @@ -42,7 +37,6 @@ def _prepare_sample( def from_data_folder( root: Union[str, pathlib.Path], *, - decoder: Callable[[BinaryIO], Dict[str, Any]] = _read_bytes, valid_extensions: Optional[Collection[str]] = None, recursive: bool = True, ) -> Tuple[IterDataPipe, List[str]]: @@ -54,21 +48,20 @@ def from_data_folder( dp = hint_sharding(dp) dp = hint_shuffling(dp) dp = FileLoader(dp) - return Mapper(dp, functools.partial(_prepare_sample, root=root, categories=categories, decoder=decoder)), categories + return Mapper(dp, functools.partial(_prepare_sample, root=root, categories=categories)), categories def _data_to_image_key(sample: Dict[str, Any]) -> Dict[str, Any]: - sample["image"] = sample.pop("data") + sample["image"] = RawImage(sample.pop("data").data) return sample def from_image_folder( root: Union[str, pathlib.Path], *, - decoder: Callable[[BinaryIO], Dict[str, Any]] = decode_image_with_pil, valid_extensions: Collection[str] = ("jpg", "jpeg", "png", "ppm", "bmp", "pgm", "tif", "tiff", "webp"), **kwargs: Any, ) -> Tuple[IterDataPipe, List[str]]: valid_extensions = [valid_extension for ext in valid_extensions for valid_extension in (ext.lower(), ext.upper())] - dp, categories = from_data_folder(root, decoder=decoder, valid_extensions=valid_extensions, **kwargs) + dp, categories = from_data_folder(root, valid_extensions=valid_extensions, **kwargs) return Mapper(dp, _data_to_image_key), categories diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py index 57b957cf289..6b3d07660a1 100644 --- a/torchvision/prototype/datasets/utils/__init__.py +++ b/torchvision/prototype/datasets/utils/__init__.py @@ -10,6 +10,3 @@ ) from ._query import SampleQuery from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource - 
-DecodeableImageStreamWrapper = None -DecodeableStreamWrapper = None diff --git a/torchvision/prototype/datasets/utils/_decoder.py b/torchvision/prototype/datasets/utils/_decoder.py index 373a1787627..f7ee52b4cbc 100644 --- a/torchvision/prototype/datasets/utils/_decoder.py +++ b/torchvision/prototype/datasets/utils/_decoder.py @@ -11,18 +11,23 @@ from ._internal import ReadOnlyTensorBuffer, fromfile +D = TypeVar("D", bound="RawData") + class RawData(torch.Tensor): def __new__(cls, data: torch.Tensor) -> "RawData": # TODO: warn / bail out if we encounter a tensor with shape other than (N,) or with dtype other than uint8? - return torch.Tensor._make_subclass( - cast(_TensorBase, cls), - data, - False, # requires_grad + return cast( + RawData, + torch.Tensor._make_subclass( + cast(_TensorBase, cls), + data, + False, # requires_grad + ), ) @classmethod - def fromfile(cls, file: BinaryIO): + def fromfile(cls: Type[D], file: BinaryIO) -> D: return cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder)) @@ -34,9 +39,6 @@ def decode_image_with_pil(raw_image: RawImage) -> Dict[str, Any]: return dict(image=features.Image(pil_to_tensor(PIL.Image.open(ReadOnlyTensorBuffer(raw_image))))) -D = TypeVar("D", bound=RawData) - - def decode_sample( sample: Any, *, decoder_map: Dict[Type[D], Callable[[D], Dict[str, Any]]], inline_decoded: bool = True ) -> Any: @@ -59,12 +61,12 @@ def decode_sample( return sample try: - return decoder_map[sample_type](cast(D, sample)) + return decoder_map[cast(Type[D], sample_type)](cast(D, sample)) except KeyError as error: raise TypeError(f"Unknown type {sample_type}") from error -def decode_images(sample: Any, *, inline_decoded=True) -> Any: +def decode_images(sample: Any, *, inline_decoded: bool = True) -> Any: return decode_sample( sample, decoder_map={ diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py index 8f866b4168c..c52255e602e 100644 --- a/torchvision/prototype/datasets/utils/_internal.py +++ b/torchvision/prototype/datasets/utils/_internal.py @@ -356,7 +356,7 @@ def seek(self, offset: int, whence: int = io.SEEK_SET) -> int: ) return self.tell() - def read(self, size=-1): + def read(self, size: int = -1) -> bytes: cursor = self.tell() offset, whence = (0, io.SEEK_END) if size == -1 else (size, io.SEEK_CUR) return self._memory[slice(cursor, self.seek(offset, whence))].tobytes() From 74f6a09dbae35bacb6026b796e50996d6fefe703 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 21 Dec 2021 17:30:54 +0100 Subject: [PATCH 5/8] fix celeba --- .../prototype/datasets/_builtin/celeba.py | 48 ++++++++++++++----- .../prototype/datasets/utils/_decoder.py | 10 ++-- 2 files changed, 43 insertions(+), 15 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py index d2f2f0d7d81..593f8af6fdd 100644 --- a/torchvision/prototype/datasets/_builtin/celeba.py +++ b/torchvision/prototype/datasets/_builtin/celeba.py @@ -110,16 +110,25 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: def _filter_split(self, data: Tuple[str, Dict[str, str]], *, split: str) -> bool: return self._SPLIT_ID_TO_NAME[data[1]["split_id"]] == split - def _prepare_anns(self, data: Tuple[Tuple[str, Dict[str, str]], ...]) -> Tuple[str, Dict[str, Any]]: - (image_id, identity), (_, attributes), (_, bounding_box), (_, landmarks) = data - return image_id, dict( + def _decode_anns( + self, + data: Tuple[ + Tuple[str, Dict[str, str]], + Tuple[str, Dict[str, 
str]], + Tuple[str, Dict[str, str]], + Tuple[str, Dict[str, str]], + ], + *, + image_size: Tuple[int, int], + ) -> Dict[str, Any]: + (_, identity), (_, attributes), (_, bounding_box), (_, landmarks) = data + return dict( identity=Label(int(identity["identity"])), attributes={attr: value == "1" for attr, value in attributes.items()}, - # FIXME: probe image_size from file bounding_box=BoundingBox( [int(bounding_box[key]) for key in ("x_1", "y_1", "width", "height")], format="xywh", - image_size=(-1, -1), + image_size=image_size, ), landmarks={ landmark: Feature((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"]))) @@ -128,17 +137,27 @@ def _prepare_anns(self, data: Tuple[Tuple[str, Dict[str, str]], ...]) -> Tuple[s ) def _prepare_sample( - self, data: Tuple[Tuple[str, Tuple[str, List[str]], Tuple[str, BinaryIO]], Tuple[str, Dict[str, Any]]] + self, + data: Tuple[ + Tuple[str, Tuple[Tuple[str, List[str]], Tuple[str, BinaryIO]]], + Tuple[ + Tuple[str, Dict[str, str]], + Tuple[str, Dict[str, str]], + Tuple[str, Dict[str, str]], + Tuple[str, Dict[str, str]], + ], + ], ) -> Dict[str, Any]: split_and_image_data, ann_data = data - _, _, image_data = split_and_image_data + _, (_, image_data) = split_and_image_data path, buffer = image_data - _, anns = ann_data + + image = RawImage.fromfile(buffer) return dict( - anns, + self._decode_anns(ann_data, image_size=image.probe_image_size()), path=path, - image=RawImage.fromfile(buffer), + image=image, ) def _make_datapipe( @@ -165,7 +184,6 @@ def _make_datapipe( ) ] ) - anns_dp = Mapper(anns_dp, self._prepare_anns) dp = IterKeyZipper( splits_dp, @@ -175,5 +193,11 @@ def _make_datapipe( buffer_size=INFINITE_BUFFER_SIZE, keep_key=True, ) - dp = IterKeyZipper(dp, anns_dp, key_fn=getitem(0), buffer_size=INFINITE_BUFFER_SIZE) + dp = IterKeyZipper( + dp, + anns_dp, + key_fn=getitem(0), + ref_key_fn=getitem(0, 0), + buffer_size=INFINITE_BUFFER_SIZE, + ) return Mapper(dp, self._prepare_sample) diff --git a/torchvision/prototype/datasets/utils/_decoder.py b/torchvision/prototype/datasets/utils/_decoder.py index f7ee52b4cbc..5a28e13e7b6 100644 --- a/torchvision/prototype/datasets/utils/_decoder.py +++ b/torchvision/prototype/datasets/utils/_decoder.py @@ -1,7 +1,6 @@ import collections.abc import sys -from typing import Any, Dict, Callable -from typing import Type, TypeVar, cast, BinaryIO +from typing import Any, Dict, Callable, Type, TypeVar, cast, BinaryIO, Tuple import PIL.Image import torch @@ -32,7 +31,12 @@ def fromfile(cls: Type[D], file: BinaryIO) -> D: class RawImage(RawData): - pass + def probe_image_size(self) -> Tuple[int, int]: + if not hasattr(self, "_image_size"): + image = PIL.Image.open(ReadOnlyTensorBuffer(self)) + self._image_size = image.height, image.width + + return self._image_size def decode_image_with_pil(raw_image: RawImage) -> Dict[str, Any]: From 375fefba632e172e378a2a253b341d8fcb08c932 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 22 Dec 2021 15:50:10 +0100 Subject: [PATCH 6/8] fix tests --- test/builtin_dataset_mocks.py | 13 ++----- test/test_prototype_builtin_datasets.py | 39 ++++++++++++++++--- .../prototype/datasets/_builtin/coco.py | 2 +- 3 files changed, 37 insertions(+), 17 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 153094fae07..6129f1892a1 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -16,7 +16,6 @@ from torch.testing import make_tensor as _make_tensor from torchdata.datapipes.iter import IterDataPipe from 
torchvision.prototype import datasets -from torchvision.prototype.datasets._api import DEFAULT_DECODER_MAP, DEFAULT_DECODER from torchvision.prototype.datasets._api import find from torchvision.prototype.utils._internal import add_suggestion @@ -109,21 +108,15 @@ def _get(self, dataset, config, root): self._cache[(name, config)] = mock_resources, mock_info return mock_resources, mock_info - def load( - self, name: str, decoder=DEFAULT_DECODER, split="train", **options: Any - ) -> Tuple[IterDataPipe, Dict[str, Any]]: + def load(self, name: str, **options: Any) -> Tuple[IterDataPipe, Dict[str, Any]]: dataset = find(name) - config = dataset.info.make_config(split=split, **options) + config = dataset.info.make_config(**options) root = self._tmp_home / name root.mkdir(exist_ok=True) resources, mock_info = self._get(dataset, config, root) - datapipe = dataset._make_datapipe( - [resource.load(root) for resource in resources], - config=config, - decoder=DEFAULT_DECODER_MAP.get(dataset.info.type) if decoder is DEFAULT_DECODER else decoder, - ) + datapipe = dataset._make_datapipe([resource.load(root) for resource in resources], config=config) return datapipe, mock_info diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 4248870176f..bacafea2048 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -1,18 +1,36 @@ +import functools import io import builtin_dataset_mocks import pytest import torch +from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair, UnsupportedInputs, ErrorMeta from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter from torch.utils.data.graph import traverse from torchdata.datapipes.iter import IterDataPipe, Shuffler from torchvision.prototype import datasets, transforms -from torchvision.prototype.datasets._api import DEFAULT_DECODER from torchvision.prototype.utils._internal import sequence_to_str -def to_bytes(file): - return file.read() +def patch(fn): + def wrapper(*args, **kwargs): + try: + return fn(*args, **kwargs) + except ErrorMeta as error: + if error.type is not ValueError: + raise error + + raise UnsupportedInputs() + + return wrapper + + +TensorLikePair._to_tensor = patch(TensorLikePair._to_tensor) + + +assert_samples_equal = functools.partial( + assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True +) def config_id(name, config): @@ -26,7 +44,7 @@ def config_id(name, config): return "-".join(parts) -def dataset_parametrization(*names, decoder=to_bytes): +def dataset_parametrization(*names): if not names: # TODO: Replace this with torchvision.prototype.datasets.list() as soon as all builtin datasets are supported names = ( @@ -46,7 +64,7 @@ def dataset_parametrization(*names, decoder=to_bytes): return pytest.mark.parametrize( ("dataset", "mock_info"), [ - pytest.param(*builtin_dataset_mocks.load(name, decoder=decoder, **config), id=config_id(name, config)) + pytest.param(*builtin_dataset_mocks.load(name, **config), id=config_id(name, config)) for name in names for config in datasets.info(name)._configs ], @@ -89,7 +107,7 @@ def test_decoding(self, dataset, mock_info): f"{sequence_to_str(sorted(undecoded_features), separate_last='and ')} were not decoded." 
) - @dataset_parametrization(decoder=DEFAULT_DECODER) + @dataset_parametrization() def test_no_vanilla_tensors(self, dataset, mock_info): vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor} if vanilla_tensors: @@ -120,6 +138,15 @@ def scan(graph): else: raise AssertionError(f"The dataset doesn't comprise a {annotation_dp_type.__name__}() datapipe.") + @dataset_parametrization() + def test_save_load(self, dataset, mock_info): + sample = next(iter(dataset)) + + with io.BytesIO() as buffer: + torch.save(sample, buffer) + buffer.seek(0) + assert_samples_equal(torch.load(buffer), sample) + class TestQMNIST: @pytest.mark.parametrize( diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py index 644b66e584a..7cf6cc3a3dd 100644 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ b/torchvision/prototype/datasets/_builtin/coco.py @@ -106,7 +106,7 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[st ) ), areas=Feature([ann["area"] for ann in anns]), - crowds=Feature([ann["crowd"] for ann in anns], dtype=torch.bool), + crowds=Feature([ann["iscrowd"] for ann in anns], dtype=torch.bool), bounding_boxes=BoundingBox( [ann["bbox"] for ann in anns], format="xywh", From 667ea7e924575f6867e98ddceb424cbf6fddfef2 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 22 Dec 2021 16:16:59 +0100 Subject: [PATCH 7/8] add todo --- test/test_prototype_builtin_datasets.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index bacafea2048..04f2dd90b66 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -12,6 +12,7 @@ from torchvision.prototype.utils._internal import sequence_to_str +# TODO: remove this patch after https://github.com/pytorch/pytorch/pull/70304 is merged def patch(fn): def wrapper(*args, **kwargs): try: From 1406bd3d846fb061a9c171b06ae83917ad4b702e Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 22 Dec 2021 16:18:42 +0100 Subject: [PATCH 8/8] fix api tests --- test/test_prototype_datasets_api.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/test/test_prototype_datasets_api.py b/test/test_prototype_datasets_api.py index ce50df123cc..33996db0cca 100644 --- a/test/test_prototype_datasets_api.py +++ b/test/test_prototype_datasets_api.py @@ -5,8 +5,8 @@ from torchvision.prototype.utils._internal import FrozenMapping, FrozenBunch -def make_minimal_dataset_info(name="name", type=datasets.utils.DatasetType.RAW, categories=None, **kwargs): - return datasets.utils.DatasetInfo(name, type=type, categories=categories or [], **kwargs) +def make_minimal_dataset_info(name="name", categories=None, **kwargs): + return datasets.utils.DatasetInfo(name, categories=categories or [], **kwargs) class TestFrozenMapping: @@ -188,7 +188,7 @@ def resources(self, config): # This method is just defined to appease the ABC, but will be overwritten at instantiation pass - def _make_datapipe(self, resource_dps, *, config, decoder): + def _make_datapipe(self, resource_dps, *, config): # This method is just defined to appease the ABC, but will be overwritten at instantiation pass @@ -241,12 +241,3 @@ def test_resources(self, mocker): (call_args, _) = dataset._make_datapipe.call_args assert call_args[0][0] is sentinel - - def test_decoder(self): - dataset = self.DatasetMock() - - sentinel = object() - dataset.load("", 
decoder=sentinel) - - (_, call_kwargs) = dataset._make_datapipe.call_args - assert call_kwargs["decoder"] is sentinel
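
Usage sketch for the state after this series (illustrative only; it assumes
the names introduced in the patches above -- datasets.load, RawImage,
decode_images, ReadOnlyTensorBuffer -- and an installed torchdata):

    from torchvision.prototype import datasets
    from torchvision.prototype.datasets.utils import RawImage, decode_images

    # Datasets now yield undecoded samples: every image is a RawImage, i.e. a
    # flat uint8 tensor holding the still-encoded file bytes.
    dp = datasets.load("caltech101", split="train")
    sample = next(iter(dp))
    assert isinstance(sample["image"], RawImage)

    # Decoding becomes an explicit post-processing step. decode_images() walks
    # the sample, maps every RawImage through decode_image_with_pil(), and
    # inlines the returned dict, so sample["image"] ends up a features.Image.
    decoded = decode_images(sample)

Because an undecoded sample contains only tensors and plain Python values
rather than open file handles, torch.save()/torch.load() can round-trip it,
which is what the new test_save_load test checks. ReadOnlyTensorBuffer is the
shim that lets PIL read from a RawImage: it exposes a seekable, read-only
file interface over the tensor's memory. A minimal, self-contained example
built only on the class as added to _internal.py:

    import io

    import torch
    from torchvision.prototype.datasets.utils import ReadOnlyTensorBuffer

    buffer = ReadOnlyTensorBuffer(torch.tensor(list(b"hello world"), dtype=torch.uint8))
    assert buffer.read(5) == b"hello"  # read advances the cursor
    buffer.seek(-5, io.SEEK_END)       # negative offsets seek from the end
    assert buffer.read() == b"world"   # size=-1 reads everything remaining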