diff --git a/torchvision/prototype/__init__.py b/torchvision/prototype/__init__.py
deleted file mode 100644
index 62a23dfafae..00000000000
--- a/torchvision/prototype/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from . import datasets
diff --git a/torchvision/prototype/datasets/__init__.py b/torchvision/prototype/datasets/__init__.py
deleted file mode 100644
index 6fa4d8dbc8f..00000000000
--- a/torchvision/prototype/datasets/__init__.py
+++ /dev/null
@@ -1,16 +0,0 @@
-try:
-    import torchdata
-except (ModuleNotFoundError, TypeError) as error:
-    raise ModuleNotFoundError(
-        "`torchvision.prototype.datasets` depends on PyTorch's `torchdata` (https://github.com/pytorch/data). "
-        "You can install it with `pip install git+https://github.com/pytorch/data.git`. "
-        "Note that you cannot install it with `pip install torchdata`, since this is another package."
-    ) from error
-
-
-from ._home import home
-from . import decoder, utils
-
-# Load this last, since some parts depend on the above being loaded first
-from ._api import register, _list as list, info, load
-from ._folder import from_data_folder, from_image_folder
diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py
deleted file mode 100644
index 97d95eef6c1..00000000000
--- a/torchvision/prototype/datasets/_api.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import io
-from typing import Any, Callable, Dict, List, Optional
-
-import torch
-from torch.utils.data import IterDataPipe
-
-from torchvision.prototype.datasets import home
-from torchvision.prototype.datasets.decoder import pil
-from torchvision.prototype.datasets.utils import Dataset, DatasetInfo
-from torchvision.prototype.datasets.utils._internal import add_suggestion
-from . import _builtin
-
-DATASETS: Dict[str, Dataset] = {}
-
-
-def register(dataset: Dataset) -> None:
-    DATASETS[dataset.name] = dataset
-
-
-for name, obj in _builtin.__dict__.items():
-    if (
-        not name.startswith("_")
-        and isinstance(obj, type)
-        and issubclass(obj, Dataset)
-        and obj is not Dataset
-    ):
-        register(obj())
-
-
-# This is exposed as 'list', but we avoid that here to not shadow the built-in 'list'
-def _list() -> List[str]:
-    return sorted(DATASETS.keys())
-
-
-def find(name: str) -> Dataset:
-    name = name.lower()
-    try:
-        return DATASETS[name]
-    except KeyError as error:
-        raise ValueError(
-            add_suggestion(
-                f"Unknown dataset '{name}'.",
-                word=name,
-                possibilities=DATASETS.keys(),
-                alternative_hint=lambda _: (
-                    "You can use torchvision.datasets.list() to get a list of all available datasets."
-                ),
-            )
-        ) from error
-
-
-def info(name: str) -> DatasetInfo:
-    return find(name).info
-
-
-def load(
-    name: str,
-    *,
-    decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = pil,
-    split: str = "train",
-    **options: Any,
-) -> IterDataPipe[Dict[str, Any]]:
-    dataset = find(name)
-
-    config = dataset.info.make_config(split=split, **options)
-    root = home() / name
-
-    return dataset.to_datapipe(root, config=config, decoder=decoder)
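For reference, the `_api.py` removed above exposed a small functional interface. A minimal usage sketch, based directly on the deleted `__init__.py` and `_api.py` (the dataset name and sample keys are taken from the Caltech datasets later in this diff); it assumes `torchdata` is installed and the archives are already present under `home()`:

    from torchvision.prototype import datasets

    # '_list' is re-exported as 'list' to avoid shadowing the built-in
    print(datasets.list())  # ['caltech101', 'caltech256']

    # load() validates the options against the dataset's DatasetInfo,
    # resolves the data root under home(), and returns an IterDataPipe
    datapipe = datasets.load("caltech256", split="train")
    for sample in datapipe:  # each sample is a Dict[str, Any]
        print(sample["category"], sample["image"].shape)
        break
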
diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py
deleted file mode 100644
index 7d6961fa920..00000000000
--- a/torchvision/prototype/datasets/_builtin/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .caltech import Caltech101, Caltech256
diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py
deleted file mode 100644
index 7f6021522c8..00000000000
--- a/torchvision/prototype/datasets/_builtin/caltech.py
+++ /dev/null
@@ -1,209 +0,0 @@
-import io
-import pathlib
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-import re
-
-import numpy as np
-
-import torch
-from torch.utils.data import IterDataPipe
-from torch.utils.data.datapipes.iter import (
-    Mapper,
-    TarArchiveReader,
-    Shuffler,
-    Filter,
-)
-
-from torchdata.datapipes.iter import KeyZipper
-from torchvision.prototype.datasets.utils import (
-    Dataset,
-    DatasetConfig,
-    DatasetInfo,
-    HttpResource,
-    OnlineResource,
-)
-from torchvision.prototype.datasets.utils._internal import create_categories_file, INFINITE_BUFFER_SIZE, read_mat
-
-HERE = pathlib.Path(__file__).parent
-
-
-class Caltech101(Dataset):
-    @property
-    def info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "caltech101",
-            categories=HERE / "caltech101.categories",
-            homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101",
-        )
-
-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
-        images = HttpResource(
-            "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
-            sha256="af6ece2f339791ca20f855943d8b55dd60892c0a25105fcd631ee3d6430f9926",
-        )
-        anns = HttpResource(
-            "http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
-            sha256="1717f4e10aa837b05956e3f4c94456527b143eec0d95e935028b30aff40663d8",
-        )
-        return [images, anns]
-
-    _IMAGES_NAME_PATTERN = re.compile(r"image_(?P<id>\d+)[.]jpg")
-    _ANNS_NAME_PATTERN = re.compile(r"annotation_(?P<id>\d+)[.]mat")
-    _ANNS_CATEGORY_MAP = {
-        "Faces_2": "Faces",
-        "Faces_3": "Faces_easy",
-        "Motorbikes_16": "Motorbikes",
-        "Airplanes_Side_2": "airplanes",
-    }
-
-    def _is_not_background_image(self, data: Tuple[str, Any]) -> bool:
-        path = pathlib.Path(data[0])
-        return path.parent.name != "BACKGROUND_Google"
-
-    def _is_ann(self, data: Tuple[str, Any]) -> bool:
-        path = pathlib.Path(data[0])
-        return bool(self._ANNS_NAME_PATTERN.match(path.name))
-
-    def _images_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]:
-        path = pathlib.Path(data[0])
-
-        category = path.parent.name
-        id = self._IMAGES_NAME_PATTERN.match(path.name).group("id")  # type: ignore[union-attr]
-
-        return category, id
-
-    def _anns_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]:
-        path = pathlib.Path(data[0])
-
-        category = path.parent.name
-        if category in self._ANNS_CATEGORY_MAP:
-            category = self._ANNS_CATEGORY_MAP[category]
-
-        id = self._ANNS_NAME_PATTERN.match(path.name).group("id")  # type: ignore[union-attr]
-
-        return category, id
-
-    def _collate_and_decode_sample(
-        self, data, *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]]
-    ) -> Dict[str, Any]:
-        key, image_data, ann_data = data
-        category, _ = key
-        image_path, image_buffer = image_data
-        ann_path, ann_buffer = ann_data
-
-        label = self.info.categories.index(category)
-
-        image = decoder(image_buffer) if decoder else image_buffer
-
-        ann = read_mat(ann_buffer)
-        bbox = torch.as_tensor(ann["box_coord"].astype(np.int64))
-        contour = torch.as_tensor(ann["obj_contour"])
-
-        return dict(
-            category=category,
-            label=label,
-            image=image,
-            image_path=image_path,
-            bbox=bbox,
-            contour=contour,
-            ann_path=ann_path,
-        )
-
-    def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> IterDataPipe[Dict[str, Any]]:
-        images_dp, anns_dp = resource_dps
-
-        images_dp = TarArchiveReader(images_dp)
-        images_dp = Filter(images_dp, self._is_not_background_image)
-        # FIXME: add this after https://github.com/pytorch/pytorch/issues/65808 is resolved
-        # images_dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE)
-
-        anns_dp = TarArchiveReader(anns_dp)
-        anns_dp = Filter(anns_dp, self._is_ann)
-
-        dp = KeyZipper(
-            images_dp,
-            anns_dp,
-            key_fn=self._images_key_fn,
-            ref_key_fn=self._anns_key_fn,
-            buffer_size=INFINITE_BUFFER_SIZE,
-            keep_key=True,
-        )
-        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
-
-    def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
-        dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
-        dp = TarArchiveReader(dp)
-        dp = Filter(dp, self._is_not_background_image)
-        dir_names = {pathlib.Path(path).parent.name for path, _ in dp}
-        create_categories_file(HERE, self.name, sorted(dir_names))
-
-
-class Caltech256(Dataset):
-    @property
-    def info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "caltech256",
-            categories=HERE / "caltech256.categories",
-            homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech256",
-        )
-
-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
-        return [
-            HttpResource(
-                "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
-                sha256="08ff01b03c65566014ae88eb0490dbe4419fc7ac4de726ee1163e39fd809543e",
-            )
-        ]
-
-    def _is_not_rogue_file(self, data: Tuple[str, Any]) -> bool:
-        path = pathlib.Path(data[0])
-        return path.name != "RENAME2"
-
-    def _collate_and_decode_sample(
-        self,
-        data: Tuple[str, io.IOBase],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
-        path, buffer = data
-
-        dir_name = pathlib.Path(path).parent.name
-        label_str, category = dir_name.split(".")
-        label = torch.tensor(int(label_str))
-
-        return dict(label=label, category=category, image=decoder(buffer) if decoder else buffer)
-
-    def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> IterDataPipe[Dict[str, Any]]:
-        dp = resource_dps[0]
-        dp = TarArchiveReader(dp)
-        dp = Filter(dp, self._is_not_rogue_file)
-        # FIXME: add this after https://github.com/pytorch/pytorch/issues/65808 is resolved
-        # dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
-        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
-
-    def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
-        dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
-        dp = TarArchiveReader(dp)
-        dir_names = {pathlib.Path(path).parent.name for path, _ in dp}
-        categories = [name.split(".")[1] for name in sorted(dir_names)]
-        create_categories_file(HERE, self.name, categories)
-
-
-if __name__ == "__main__":
-    from torchvision.prototype.datasets import home
-
-    root = home()
-    Caltech101().generate_categories_file(root)
-    Caltech256().generate_categories_file(root)
diff --git a/torchvision/prototype/datasets/_builtin/caltech101.categories b/torchvision/prototype/datasets/_builtin/caltech101.categories
deleted file mode 100644
index d5c18654b4e..00000000000
--- a/torchvision/prototype/datasets/_builtin/caltech101.categories
+++ /dev/null
@@ -1,101 +0,0 @@
-Faces
-Faces_easy
-Leopards
-Motorbikes
-accordion
-airplanes
-anchor
-ant
-barrel
-bass
-beaver
-binocular
-bonsai
-brain
-brontosaurus
-buddha
-butterfly
-camera
-cannon
-car_side
-ceiling_fan
-cellphone
-chair
-chandelier
-cougar_body
-cougar_face
-crab
-crayfish
-crocodile
-crocodile_head
-cup
-dalmatian
-dollar_bill
-dolphin
-dragonfly
-electric_guitar
-elephant
-emu
-euphonium
-ewer
-ferry
-flamingo
-flamingo_head
-garfield
-gerenuk
-gramophone
-grand_piano
-hawksbill
-headphone
-hedgehog
-helicopter
-ibis
-inline_skate
-joshua_tree
-kangaroo
-ketch
-lamp
-laptop
-llama
-lobster
-lotus
-mandolin
-mayfly
-menorah
-metronome
-minaret
-nautilus
-octopus
-okapi
-pagoda
-panda
-pigeon
-pizza
-platypus
-pyramid
-revolver
-rhino
-rooster
-saxophone
-schooner
-scissors
-scorpion
-sea_horse
-snoopy
-soccer_ball
-stapler
-starfish
-stegosaurus
-stop_sign
-strawberry
-sunflower
-tick
-trilobite
-umbrella
-watch
-water_lilly
-wheelchair
-wild_cat
-windsor_chair
-wrench
-yin_yang
diff --git a/torchvision/prototype/datasets/_builtin/caltech256.categories b/torchvision/prototype/datasets/_builtin/caltech256.categories
deleted file mode 100644
index 82128efba97..00000000000
--- a/torchvision/prototype/datasets/_builtin/caltech256.categories
+++ /dev/null
@@ -1,257 +0,0 @@
-ak47
-american-flag
-backpack
-baseball-bat
-baseball-glove
-basketball-hoop
-bat
-bathtub
-bear
-beer-mug
-billiards
-binoculars
-birdbath
-blimp
-bonsai-101
-boom-box
-bowling-ball
-bowling-pin
-boxing-glove
-brain-101
-breadmaker
-buddha-101
-bulldozer
-butterfly
-cactus
-cake
-calculator
-camel
-cannon
-canoe
-car-tire
-cartman
-cd
-centipede
-cereal-box
-chandelier-101
-chess-board
-chimp
-chopsticks
-cockroach
-coffee-mug
-coffin
-coin
-comet
-computer-keyboard
-computer-monitor
-computer-mouse
-conch
-cormorant
-covered-wagon
-cowboy-hat
-crab-101
-desk-globe
-diamond-ring
-dice
-dog
-dolphin-101
-doorknob
-drinking-straw
-duck
-dumb-bell
-eiffel-tower
-electric-guitar-101
-elephant-101
-elk
-ewer-101
-eyeglasses
-fern
-fighter-jet
-fire-extinguisher
-fire-hydrant
-fire-truck
-fireworks
-flashlight
-floppy-disk
-football-helmet
-french-horn
-fried-egg
-frisbee
-frog
-frying-pan
-galaxy
-gas-pump
-giraffe
-goat
-golden-gate-bridge
-goldfish
-golf-ball
-goose
-gorilla
-grand-piano-101
-grapes
-grasshopper
-guitar-pick
-hamburger
-hammock
-harmonica
-harp
-harpsichord
-hawksbill-101
-head-phones
-helicopter-101
-hibiscus
-homer-simpson
-horse
-horseshoe-crab
-hot-air-balloon
-hot-dog
-hot-tub
-hourglass
-house-fly
-human-skeleton
-hummingbird
-ibis-101
-ice-cream-cone
-iguana
-ipod
-iris
-jesus-christ
-joy-stick
-kangaroo-101
-kayak
-ketch-101
-killer-whale
-knife
-ladder
-laptop-101
-lathe
-leopards-101
-license-plate
-lightbulb
-light-house
-lightning
-llama-101
-mailbox
-mandolin
-mars
-mattress
-megaphone
-menorah-101
-microscope
-microwave
-minaret
-minotaur
-motorbikes-101
-mountain-bike
-mushroom
-mussels
-necktie
-octopus
-ostrich
-owl
-palm-pilot
-palm-tree
-paperclip
-paper-shredder
-pci-card
-penguin
-people
-pez-dispenser
-photocopier
-picnic-table
-playing-card
-porcupine
-pram
-praying-mantis
-pyramid
-raccoon
-radio-telescope
-rainbow
-refrigerator
-revolver-101
-rifle
-rotary-phone
-roulette-wheel
-saddle
-saturn
-school-bus
-scorpion-101
-screwdriver
-segway
-self-propelled-lawn-mower
-sextant
-sheet-music
-skateboard
-skunk
-skyscraper
-smokestack
-snail
-snake
-sneaker
-snowmobile
-soccer-ball
-socks
-soda-can
-spaghetti
-speed-boat
-spider
-spoon
-stained-glass
-starfish-101
-steering-wheel
-stirrups
-sunflower-101
-superman
-sushi
-swan
-swiss-army-knife
-sword
-syringe
-tambourine
-teapot
-teddy-bear
-teepee
-telephone-box
-tennis-ball
-tennis-court
-tennis-racket
-theodolite
-toaster
-tomato
-tombstone
-top-hat
-touring-bike
-tower-pisa
-traffic-light
-treadmill
-triceratops
-tricycle
-trilobite-101
-tripod
-t-shirt
-tuning-fork
-tweezer
-umbrella-101
-unicorn
-vcr
-video-projector
-washing-machine
-watch-101
-waterfall
-watermelon
-welding-mask
-wheelbarrow
-windmill
-wine-bottle
-xylophone
-yarmulke
-yo-yo
-zebra
-airplanes-101
-car-side-101
-faces-easy-101
-greyhound
-tennis-shoes
-toad
-clutter
diff --git a/torchvision/prototype/datasets/_folder.py b/torchvision/prototype/datasets/_folder.py
deleted file mode 100644
index 5626f68650f..00000000000
--- a/torchvision/prototype/datasets/_folder.py
+++ /dev/null
@@ -1,77 +0,0 @@
-import io
-import os
-import os.path
-import pathlib
-from typing import Callable, Optional, Collection
-from typing import Union, Tuple, List, Dict, Any
-
-import torch
-from torch.utils.data import IterDataPipe
-from torch.utils.data.datapipes.iter import FileLister, FileLoader, Mapper, Shuffler, Filter
-
-from torchvision.prototype.datasets.decoder import pil
-from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
-
-
-__all__ = ["from_data_folder", "from_image_folder"]
-
-
-def _is_not_top_level_file(path: str, *, root: pathlib.Path) -> bool:
-    rel_path = pathlib.Path(path).relative_to(root)
-    return rel_path.is_dir() or rel_path.parent != pathlib.Path(".")
-
-
-def _collate_and_decode_data(
-    data: Tuple[str, io.IOBase],
-    *,
-    root: pathlib.Path,
-    categories: List[str],
-    decoder,
-) -> Dict[str, Any]:
-    path, buffer = data
-    data = decoder(buffer) if decoder else buffer
-    category = pathlib.Path(path).relative_to(root).parts[0]
-    label = torch.tensor(categories.index(category))
-    return dict(
-        path=path,
-        data=data,
-        label=label,
-        category=category,
-    )
-
-
-def from_data_folder(
-    root: Union[str, pathlib.Path],
-    *,
-    decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None,
-    valid_extensions: Optional[Collection[str]] = None,
-    recursive: bool = True,
-) -> Tuple[IterDataPipe, List[str]]:
-    root = pathlib.Path(root).expanduser().resolve()
-    categories = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
-    masks: Union[List[str], str] = [f"*.{ext}" for ext in valid_extensions] if valid_extensions is not None else ""
-    dp: IterDataPipe = FileLister(str(root), recursive=recursive, masks=masks)
-    dp = Filter(dp, _is_not_top_level_file, fn_kwargs=dict(root=root))
-    dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
-    dp = FileLoader(dp)
-    return (
-        Mapper(dp, _collate_and_decode_data, fn_kwargs=dict(root=root, categories=categories, decoder=decoder)),
-        categories,
-    )
-
-
-def _data_to_image_key(sample: Dict[str, Any]) -> Dict[str, Any]:
-    sample["image"] = sample.pop("data")
-    return sample
-
-
-def from_image_folder(
-    root: Union[str, pathlib.Path],
-    *,
-    decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = pil,
-    valid_extensions: Collection[str] = ("jpg", "jpeg", "png", "ppm", "bmp", "pgm", "tif", "tiff", "webp"),
-    **kwargs: Any,
-) -> Tuple[IterDataPipe, List[str]]:
-    valid_extensions = [valid_extension for ext in valid_extensions for valid_extension in (ext.lower(), ext.upper())]
-    dp, categories = from_data_folder(root, decoder=decoder, valid_extensions=valid_extensions, **kwargs)
-    return Mapper(dp, _data_to_image_key), categories
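The folder helpers deleted above were the datapipe counterpart of `ImageFolder`. A minimal sketch of `from_image_folder`, assuming a hypothetical directory with one subdirectory per class (the sample keys follow the deleted `_collate_and_decode_data` plus the `data` -> `image` rename):

    from torchvision.prototype.datasets import from_image_folder

    # hypothetical layout: /data/pets/cat/1.jpg, /data/pets/dog/1.jpg, ...
    datapipe, categories = from_image_folder("/data/pets")
    print(categories)  # alphabetically sorted top-level directory names
    for sample in datapipe:
        # keys: 'path', 'image', 'label', 'category'; 'image' is a tensor
        # because the default decoder is pil
        print(sample["path"], sample["category"], sample["label"])
        break
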
sample["image"] = sample.pop("data") - return sample - - -def from_image_folder( - root: Union[str, pathlib.Path], - *, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = pil, - valid_extensions: Collection[str] = ("jpg", "jpeg", "png", "ppm", "bmp", "pgm", "tif", "tiff", "webp"), - **kwargs: Any, -) -> Tuple[IterDataPipe, List[str]]: - valid_extensions = [valid_extension for ext in valid_extensions for valid_extension in (ext.lower(), ext.upper())] - dp, categories = from_data_folder(root, decoder=decoder, valid_extensions=valid_extensions, **kwargs) - return Mapper(dp, _data_to_image_key), categories diff --git a/torchvision/prototype/datasets/_home.py b/torchvision/prototype/datasets/_home.py deleted file mode 100644 index 535d35294b9..00000000000 --- a/torchvision/prototype/datasets/_home.py +++ /dev/null @@ -1,20 +0,0 @@ -import os -import pathlib -from typing import Optional, Union - -from torch.hub import _get_torch_home - -HOME = pathlib.Path(_get_torch_home()) / "datasets" / "vision" - - -def home(root: Optional[Union[str, pathlib.Path]] = None) -> pathlib.Path: - global HOME - if root is not None: - HOME = pathlib.Path(root).expanduser().resolve() - return HOME - - root = os.getenv("TORCHVISION_DATASETS_HOME") - if root is not None: - return pathlib.Path(root) - - return HOME diff --git a/torchvision/prototype/datasets/decoder.py b/torchvision/prototype/datasets/decoder.py deleted file mode 100644 index d4897bebc91..00000000000 --- a/torchvision/prototype/datasets/decoder.py +++ /dev/null @@ -1,12 +0,0 @@ -import io - -import PIL.Image -import torch - -from torchvision.transforms.functional import pil_to_tensor - -__all__ = ["pil"] - - -def pil(buffer: io.IOBase, mode: str = "RGB") -> torch.Tensor: - return pil_to_tensor(PIL.Image.open(buffer).convert(mode.upper())) diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py deleted file mode 100644 index 48e7541eba5..00000000000 --- a/torchvision/prototype/datasets/utils/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from . 
import _internal -from ._dataset import DatasetConfig, DatasetInfo, Dataset -from ._resource import LocalResource, OnlineResource, HttpResource, GDriveResource diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py deleted file mode 100644 index 19fb3b1d596..00000000000 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ /dev/null @@ -1,220 +0,0 @@ -import abc -import io -import os -import pathlib -import textwrap -from collections import Mapping -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Sequence, - Union, - NoReturn, - Iterable, - Tuple, -) - -import torch -from torch.utils.data import IterDataPipe - -from torchvision.prototype.datasets.utils._internal import ( - add_suggestion, - sequence_to_str, -) -from ._resource import OnlineResource - - -def make_repr(name: str, items: Iterable[Tuple[str, Any]]): - def to_str(sep: str) -> str: - return sep.join([f"{key}={value}" for key, value in items]) - - prefix = f"{name}(" - postfix = ")" - body = to_str(", ") - - line_length = int(os.environ.get("COLUMNS", 80)) - body_too_long = (len(prefix) + len(body) + len(postfix)) > line_length - multiline_body = len(str(body).splitlines()) > 1 - if not (body_too_long or multiline_body): - return prefix + body + postfix - - body = textwrap.indent(to_str(",\n"), " " * 2) - return f"{prefix}\n{body}\n{postfix}" - - -class DatasetConfig(Mapping): - def __init__(self, *args, **kwargs): - data = dict(*args, **kwargs) - self.__dict__["__data__"] = data - self.__dict__["__final_hash__"] = hash(tuple(data.items())) - - def __getitem__(self, name: str) -> Any: - return self.__dict__["__data__"][name] - - def __iter__(self): - return iter(self.__dict__["__data__"].keys()) - - def __len__(self): - return len(self.__dict__["__data__"]) - - def __getattr__(self, name: str) -> Any: - try: - return self[name] - except KeyError as error: - raise AttributeError( - f"'{type(self).__name__}' object has no attribute '{name}'" - ) from error - - def __setitem__(self, key: Any, value: Any) -> NoReturn: - raise RuntimeError(f"'{type(self).__name__}' object is immutable") - - def __setattr__(self, key: Any, value: Any) -> NoReturn: - raise RuntimeError(f"'{type(self).__name__}' object is immutable") - - def __delitem__(self, key: Any) -> NoReturn: - raise RuntimeError(f"'{type(self).__name__}' object is immutable") - - def __delattr__(self, item: Any) -> NoReturn: - raise RuntimeError(f"'{type(self).__name__}' object is immutable") - - def __hash__(self) -> int: - return self.__dict__["__final_hash__"] - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, DatasetConfig): - return NotImplemented - - return hash(self) == hash(other) - - def __repr__(self) -> str: - return make_repr(type(self).__name__, self.items()) - - -class DatasetInfo: - def __init__( - self, - name: str, - *, - categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None, - citation: Optional[str] = None, - homepage: Optional[str] = None, - license: Optional[str] = None, - valid_options: Optional[Dict[str, Sequence]] = None, - ) -> None: - self.name = name.lower() - - if categories is None: - categories = [] - elif isinstance(categories, int): - categories = [str(label) for label in range(categories)] - elif isinstance(categories, (str, pathlib.Path)): - with open(pathlib.Path(categories).expanduser().resolve(), "r") as fh: - categories = [line.strip() for line in fh] - self.categories = categories - - self.citation = citation - 
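The `DatasetConfig`/`DatasetInfo` pair deleted above is self-contained enough to exercise directly. A sketch of the option validation, with the dataset name and options made up for illustration:

    from torchvision.prototype.datasets.utils import DatasetInfo

    info = DatasetInfo("example", valid_options=dict(split=["train", "test"]))

    config = info.make_config(split="test")
    print(config.split)         # 'test' -- attribute access on the immutable mapping
    print(info.default_config)  # DatasetConfig(split=train) -- the first valid value wins

    # an invalid value raises a ValueError with a difflib close-match hint
    info.make_config(split="tset")  # "... Did you mean 'test'?"
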
diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py
deleted file mode 100644
index 56c9a2d8c07..00000000000
--- a/torchvision/prototype/datasets/utils/_internal.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import collections.abc
-import difflib
-import io
-import pathlib
-from typing import Collection, Sequence, Callable, Union, Any
-
-
-__all__ = [
-    "INFINITE_BUFFER_SIZE",
-    "sequence_to_str",
-    "add_suggestion",
-    "create_categories_file",
-    "read_mat"
-]
-
-# pseudo-infinite until a true infinite buffer is supported by all datapipes
-INFINITE_BUFFER_SIZE = 1_000_000_000
-
-
-def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
-    if len(seq) == 1:
-        return f"'{seq[0]}'"
-
-    return (
-        f"""'{"', '".join([str(item) for item in seq[:-1]])}', """
-        f"""{separate_last}'{seq[-1]}'."""
-    )
-
-
-def add_suggestion(
-    msg: str,
-    *,
-    word: str,
-    possibilities: Collection[str],
-    close_match_hint: Callable[
-        [str], str
-    ] = lambda close_match: f"Did you mean '{close_match}'?",
-    alternative_hint: Callable[
-        [Sequence[str]], str
-    ] = lambda possibilities: f"Can be {sequence_to_str(possibilities, separate_last='or ')}.",
-) -> str:
-    if not isinstance(possibilities, collections.abc.Sequence):
-        possibilities = sorted(possibilities)
-    suggestions = difflib.get_close_matches(word, possibilities, 1)
-    hint = (
-        close_match_hint(suggestions[0])
-        if suggestions
-        else alternative_hint(possibilities)
-    )
-    return f"{msg.strip()} {hint}"
-
-
-def create_categories_file(
-    root: Union[str, pathlib.Path], name: str, categories: Sequence[str]
-) -> None:
-    with open(pathlib.Path(root) / f"{name}.categories", "w") as fh:
-        fh.write("\n".join(categories) + "\n")
-
-
-def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any:
-    try:
-        import scipy.io as sio
-    except ImportError as error:
-        raise ModuleNotFoundError(
-            "Package `scipy` is required to be installed to read .mat files."
-        ) from error
-
-    return sio.loadmat(buffer, **kwargs)
diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py
deleted file mode 100644
index 3f372d0f5b7..00000000000
--- a/torchvision/prototype/datasets/utils/_resource.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import os.path
-import pathlib
-from typing import Optional, Union
-from urllib.parse import urlparse
-
-from torch.utils.data import IterDataPipe
-from torch.utils.data.datapipes.iter import FileLoader, IterableWrapper
-
-
-# FIXME
-def compute_sha256(_) -> str:
-    return ""
-
-
-class LocalResource:
-    def __init__(
-        self, path: Union[str, pathlib.Path], *, sha256: Optional[str] = None
-    ) -> None:
-        self.path = pathlib.Path(path).expanduser().resolve()
-        self.file_name = self.path.name
-        self.sha256 = sha256 or compute_sha256(self.path)
-
-    def to_datapipe(self) -> IterDataPipe:
-        return FileLoader(IterableWrapper((str(self.path),)))
-
-
-class OnlineResource:
-    def __init__(self, url: str, *, sha256: str, file_name: str) -> None:
-        self.url = url
-        self.sha256 = sha256
-        self.file_name = file_name
-
-    def to_datapipe(self, root: Union[str, pathlib.Path]) -> IterDataPipe:
-        path = (pathlib.Path(root) / self.file_name).expanduser().resolve()
-        # FIXME
-        return FileLoader(IterableWrapper((str(path),)))
-
-
-# TODO: add support for mirrors
-# TODO: add support for http -> https
-class HttpResource(OnlineResource):
-    def __init__(
-        self, url: str, *, sha256: str, file_name: Optional[str] = None
-    ) -> None:
-        if not file_name:
-            file_name = os.path.basename(urlparse(url).path)
-        super().__init__(url, sha256=sha256, file_name=file_name)
-
-
-class GDriveResource(OnlineResource):
-    def __init__(self, id: str, *, sha256: str, file_name: str) -> None:
-        # TODO: can we maybe do a head request to extract the file name?
-        url = f"https://drive.google.com/file/d/{id}/view"
-        super().__init__(url, sha256=sha256, file_name=file_name)
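Finally, the resource abstraction above is what datasets like Caltech plug into. A sketch of constructing an `HttpResource`, reusing the URL and checksum from `caltech.py` earlier in this diff; note that downloading and checksum verification are still marked FIXME, so `to_datapipe` only wraps an already-present local copy:

    from torchvision.prototype.datasets.utils import HttpResource

    resource = HttpResource(
        "http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
        sha256="1717f4e10aa837b05956e3f4c94456527b143eec0d95e935028b30aff40663d8",
    )
    print(resource.file_name)  # 'Annotations.tar', derived from the URL path

    # assumes the archive was already placed under this (hypothetical) root
    datapipe = resource.to_datapipe("/path/to/root")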