17 changes: 14 additions & 3 deletions torchvision/prototype/datasets/_api.py
@@ -4,8 +4,8 @@
 import torch
 from torch.utils.data import IterDataPipe
 from torchvision.prototype.datasets import home
-from torchvision.prototype.datasets.decoder import pil
-from torchvision.prototype.datasets.utils import Dataset, DatasetInfo
+from torchvision.prototype.datasets.decoder import raw, pil
+from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetType
 from torchvision.prototype.datasets.utils._internal import add_suggestion
 
 from . import _builtin
@@ -48,15 +48,26 @@ def info(name: str) -> DatasetInfo:
     return find(name).info
 
 
+default = object()
+
+DEFAULT_DECODER: Dict[DatasetType, Callable[[io.IOBase], torch.Tensor]] = {
+    DatasetType.RAW: raw,
+    DatasetType.IMAGE: pil,
+}
+
+
 def load(
     name: str,
     *,
-    decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = pil,
+    decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = default,  # type: ignore[assignment]
     split: str = "train",
     **options: Any,
 ) -> IterDataPipe[Dict[str, Any]]:
     dataset = find(name)
 
+    if decoder is default:
+        decoder = DEFAULT_DECODER.get(dataset.info.type)
+
     config = dataset.info.make_config(split=split, **options)
     root = home() / name
 
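A quick, hypothetical usage sketch of the new dispatch (the `torchvision.prototype.datasets` entry point and dataset names below are assumptions based on this diff, not verified against the merged API):

from torchvision.prototype import datasets

# Omitting `decoder` now selects DEFAULT_DECODER[dataset.info.type]:
# `raw` for DatasetType.RAW (e.g. the new CIFAR datasets) and `pil`
# for DatasetType.IMAGE (e.g. the Caltech datasets).
cifar = datasets.load("cifar10")
caltech = datasets.load("caltech101")

# An explicit `decoder=None` still disables decoding entirely, since the
# sentinel check `decoder is default` only fires when the argument is omitted.
undecoded = datasets.load("caltech101", decoder=None)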
1 change: 1 addition & 0 deletions torchvision/prototype/datasets/_builtin/__init__.py
@@ -1 +1,2 @@
 from .caltech import Caltech101, Caltech256
+from .cifar import Cifar10, Cifar100
3 changes: 3 additions & 0 deletions torchvision/prototype/datasets/_builtin/caltech.py
@@ -19,6 +19,7 @@
     DatasetInfo,
     HttpResource,
     OnlineResource,
+    DatasetType,
 )
 from torchvision.prototype.datasets.utils._internal import create_categories_file, INFINITE_BUFFER_SIZE, read_mat
 
@@ -30,6 +31,7 @@ class Caltech101(Dataset):
     def info(self) -> DatasetInfo:
         return DatasetInfo(
             "caltech101",
+            type=DatasetType.IMAGE,
             categories=HERE / "caltech101.categories",
             homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101",
         )
@@ -146,6 +148,7 @@ class Caltech256(Dataset):
     def info(self) -> DatasetInfo:
         return DatasetInfo(
             "caltech256",
+            type=DatasetType.IMAGE,
             categories=HERE / "caltech256.categories",
             homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech256",
         )
227 changes: 227 additions & 0 deletions torchvision/prototype/datasets/_builtin/cifar.py
@@ -0,0 +1,227 @@
import abc
import functools
import io
import pathlib
import pickle
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, TypeVar

import numpy as np
import torch
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.iter import (
    Demultiplexer,
    Filter,
    Mapper,
    TarArchiveReader,
    Shuffler,
)
from torchdata.datapipes.iter import KeyZipper
from torchvision.prototype.datasets.decoder import raw
from torchvision.prototype.datasets.utils import (
    Dataset,
    DatasetConfig,
    DatasetInfo,
    HttpResource,
    OnlineResource,
    DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
    create_categories_file,
    MappingIterator,
    SequenceIterator,
    INFINITE_BUFFER_SIZE,
    image_buffer_from_array,
    Enumerator,
)

__all__ = ["Cifar10", "Cifar100"]

HERE = pathlib.Path(__file__).parent

D = TypeVar("D")


class _CifarBase(Dataset):
    @abc.abstractmethod
    def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> Optional[int]:
        pass

    @abc.abstractmethod
    def _split_data_file(self, data: Tuple[str, Any]) -> Optional[int]:
        pass

    def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
        _, file = data
        return pickle.load(file, encoding="latin1")

    def _remove_data_dict_key(self, data: Tuple[str, D]) -> D:
        return data[1]

    def _key_fn(self, data: Tuple[int, Any]) -> int:
        return data[0]

    def _collate_and_decode(
        self,
        data: Tuple[Tuple[int, int], Tuple[int, np.ndarray]],
        *,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> Dict[str, Any]:
        (_, category_idx), (_, image_array_flat) = data

        category = self.categories[category_idx]
        label = torch.tensor(category_idx)

        # CIFAR stores each image as a flat row with the channels first,
        # so reshaping to (3, 32, 32) yields a CHW array.
        image_array = image_array_flat.reshape((3, 32, 32))
        image: Union[torch.Tensor, io.BytesIO]
        if decoder is raw:
            image = torch.from_numpy(image_array)
        else:
            # Image encoders expect HWC, so transpose before writing the
            # array into an in-memory image buffer.
            image_buffer = image_buffer_from_array(image_array.transpose(1, 2, 0))
            image = decoder(image_buffer) if decoder else image_buffer

        return dict(label=label, category=category, image=image)

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        archive_dp = resource_dps[0]
        archive_dp = TarArchiveReader(archive_dp)
        archive_dp: IterDataPipe = Filter(archive_dp, functools.partial(self._is_data_file, config=config))
        archive_dp: IterDataPipe = Mapper(archive_dp, self._unpickle)
Comment on lines +92 to +94

Collaborator (Author): @ejguan any idea how to appease mypy here, without slapping `: IterDataPipe` everywhere? Otherwise I'm inclined to blanket-ignore var-annotated here.

ejguan (Contributor, Oct 6, 2021): The easiest way should be adding annotation to the variable at the beginning:

Suggested change:
-        archive_dp = TarArchiveReader(archive_dp)
-        archive_dp: IterDataPipe = Filter(archive_dp, functools.partial(self._is_data_file, config=config))
-        archive_dp: IterDataPipe = Mapper(archive_dp, self._unpickle)
+        archive_dp: IterDataPipe
+        archive_dp = resource_dps[0]
+        archive_dp = TarArchiveReader(archive_dp)
+        archive_dp = Filter(archive_dp, functools.partial(self._is_data_file, config=config))
+        archive_dp = Mapper(archive_dp, self._unpickle)
        # Each unpickled batch is a dict; MappingIterator flattens it into
        # (key, value) pairs, which Demultiplexer routes into an image
        # stream ("data") and a label stream ("labels"/"fine_labels").
        archive_dp = MappingIterator(archive_dp)
        images_dp, labels_dp = Demultiplexer(
            archive_dp,
            2,
            self._split_data_file,  # type: ignore[arg-type]
            drop_none=True,
            buffer_size=INFINITE_BUFFER_SIZE,
        )

        labels_dp: IterDataPipe = Mapper(labels_dp, self._remove_data_dict_key)
        labels_dp: IterDataPipe = SequenceIterator(labels_dp)
        labels_dp = Enumerator(labels_dp)
        labels_dp = Shuffler(labels_dp, buffer_size=INFINITE_BUFFER_SIZE)

        images_dp: IterDataPipe = Mapper(images_dp, self._remove_data_dict_key)
        images_dp: IterDataPipe = SequenceIterator(images_dp)
        images_dp = Enumerator(images_dp)

        # Both streams are enumerated, so KeyZipper can re-align the
        # shuffled labels with their images by index.
        dp = KeyZipper(labels_dp, images_dp, self._key_fn, buffer_size=INFINITE_BUFFER_SIZE)
        return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))

    @property
    @abc.abstractmethod
    def _meta_file_name(self) -> str:
        pass

    @property
    @abc.abstractmethod
    def _categories_key(self) -> str:
        pass

    def _is_meta_file(self, data: Tuple[str, Any]) -> bool:
        path = pathlib.Path(data[0])
        return path.name == self._meta_file_name

    def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
        dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
        dp = TarArchiveReader(dp)
        dp: IterDataPipe = Filter(dp, self._is_meta_file)
        dp: IterDataPipe = Mapper(dp, self._unpickle)
        categories = next(iter(dp))[self._categories_key]
        create_categories_file(HERE, self.name, categories)


class Cifar10(_CifarBase):
    @property
    def info(self) -> DatasetInfo:
        return DatasetInfo(
            "cifar10",
            type=DatasetType.RAW,
            categories=HERE / "cifar10.categories",
            homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
        )

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        return [
            HttpResource(
                "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz",
                sha256="6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce",
            )
        ]

    def _is_data_file(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool:
        path = pathlib.Path(data[0])
        return path.name.startswith("data" if config.split == "train" else "test")

    def _split_data_file(self, data: Tuple[str, Any]) -> Optional[int]:
        key, _ = data
        if key == "data":
            return 0
        elif key == "labels":
            return 1
        else:
            return None

    @property
    def _meta_file_name(self) -> str:
        return "batches.meta"

    @property
    def _categories_key(self) -> str:
        return "label_names"


class Cifar100(_CifarBase):
    @property
    def info(self) -> DatasetInfo:
        return DatasetInfo(
            "cifar100",
            type=DatasetType.RAW,
            categories=HERE / "cifar100.categories",
            homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
            valid_options=dict(
                split=("train", "test"),
            ),
        )

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        return [
            HttpResource(
                "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz",
                sha256="85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7",
            )
        ]

    def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> bool:
        path = pathlib.Path(data[0])
        return path.name == config.split

    def _split_data_file(self, data: Tuple[str, Any]) -> Optional[int]:
        key, _ = data
        if key == "data":
            return 0
        elif key == "fine_labels":
            return 1
        else:
            return None

    @property
    def _meta_file_name(self) -> str:
        return "meta"

    @property
    def _categories_key(self) -> str:
        return "fine_label_names"


if __name__ == "__main__":
    from torchvision.prototype.datasets import home

    root = home()
    Cifar10().generate_categories_file(root)
    Cifar100().generate_categories_file(root)
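To close the loop, a hedged usage sketch for the new datasets (assumes the archive is already downloaded into the prototype home directory; the entry point path is an assumption, as above):

from torchvision.prototype import datasets

dp = datasets.load("cifar10", split="train")
sample = next(iter(dp))

# With the default `raw` decoder the image skips any PIL round-trip and
# comes back as a CHW uint8 tensor built from the unpickled numpy array.
print(sample["image"].shape)   # expected: torch.Size([3, 32, 32])
print(sample["label"], sample["category"])  # e.g. tensor(6) 'frog'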
10 changes: 10 additions & 0 deletions torchvision/prototype/datasets/_builtin/cifar10.categories
@@ -0,0 +1,10 @@
airplane
automobile
bird
cat
deer
dog
frog
horse
ship
truck