From 36902b15a8094d4fb852fad714e082d233a8831c Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 29 Sep 2021 13:26:37 +0200 Subject: [PATCH 1/2] add prototype for `Caltech256` dataset --- setup.py | 2 +- torchvision/prototype/datasets/_api.py | 12 +- .../prototype/datasets/_builtin/__init__.py | 1 + .../prototype/datasets/_builtin/caltech.py | 209 ++++++++++++++ .../datasets/_builtin/caltech101.categories | 101 +++++++ .../datasets/_builtin/caltech256.categories | 257 ++++++++++++++++++ torchvision/prototype/datasets/decoder.py | 4 +- .../prototype/datasets/utils/_dataset.py | 2 +- .../prototype/datasets/utils/_internal.py | 24 +- 9 files changed, 606 insertions(+), 6 deletions(-) create mode 100644 torchvision/prototype/datasets/_builtin/__init__.py create mode 100644 torchvision/prototype/datasets/_builtin/caltech.py create mode 100644 torchvision/prototype/datasets/_builtin/caltech101.categories create mode 100644 torchvision/prototype/datasets/_builtin/caltech256.categories diff --git a/setup.py b/setup.py index 981c7e33563..4c9e734f31b 100644 --- a/setup.py +++ b/setup.py @@ -495,7 +495,7 @@ def run(self): # Package info packages=find_packages(exclude=('test',)), package_data={ - package_name: ['*.dll', '*.dylib', '*.so'] + package_name: ['*.dll', '*.dylib', '*.so', '*.categories'] }, zip_safe=False, install_requires=requirements, diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py index 29dce26dd0c..97d95eef6c1 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -8,7 +8,7 @@ from torchvision.prototype.datasets.decoder import pil from torchvision.prototype.datasets.utils import Dataset, DatasetInfo from torchvision.prototype.datasets.utils._internal import add_suggestion - +from . 
import _builtin DATASETS: Dict[str, Dataset] = {} @@ -17,6 +17,16 @@ def register(dataset: Dataset) -> None: DATASETS[dataset.name] = dataset +for name, obj in _builtin.__dict__.items(): + if ( + not name.startswith("_") + and isinstance(obj, type) + and issubclass(obj, Dataset) + and obj is not Dataset + ): + register(obj()) + + # This is exposed as 'list', but we avoid that here to not shadow the built-in 'list' def _list() -> List[str]: return sorted(DATASETS.keys()) diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py new file mode 100644 index 00000000000..7d6961fa920 --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/__init__.py @@ -0,0 +1 @@ +from .caltech import Caltech101, Caltech256 diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py new file mode 100644 index 00000000000..3d747e88fef --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -0,0 +1,209 @@ +import io +import pathlib +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import re + +import numpy as np + +import torch +from torch.utils.data import IterDataPipe +from torch.utils.data.datapipes.iter import ( + Mapper, + TarArchiveReader, + Shuffler, + Filter, +) + +from torchdata.datapipes.iter import KeyZipper +from torchvision.prototype.datasets.utils import ( + Dataset, + DatasetConfig, + DatasetInfo, + HttpResource, + OnlineResource, +) +from torchvision.prototype.datasets.utils._internal import create_categories_file, INFINITE_BUFFER_SIZE, read_mat + +HERE = pathlib.Path(__file__).parent + + +class Caltech101(Dataset): + @property + def info(self) -> DatasetInfo: + return DatasetInfo( + "caltech101", + categories=HERE / "caltech101.categories", + homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101", + ) + + def resources(self, config: DatasetConfig) -> List[OnlineResource]: + images = HttpResource( + 
"http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz", + sha256="af6ece2f339791ca20f855943d8b55dd60892c0a25105fcd631ee3d6430f9926", + ) + anns = HttpResource( + "http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar", + sha256="1717f4e10aa837b05956e3f4c94456527b143eec0d95e935028b30aff40663d8", + ) + return [images, anns] + + _IMAGES_NAME_PATTERN = re.compile(r"image_(?P<id>\d+)[.]jpg") + _ANNS_NAME_PATTERN = re.compile(r"annotation_(?P<id>\d+)[.]mat") + _ANNS_CATEGORY_MAP = { + "Faces_2": "Faces", + "Faces_3": "Faces_easy", + "Motorbikes_16": "Motorbikes", + "Airplanes_Side_2": "airplanes", + } + + def _is_not_background_image(self, data: Tuple[str, Any]) -> bool: + path = pathlib.Path(data[0]) + return path.parent.name != "BACKGROUND_Google" + + def _is_ann(self, data: Tuple[str, Any]) -> bool: + path = pathlib.Path(data[0]) + return bool(self._ANNS_NAME_PATTERN.match(path.name)) + + def _images_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]: + path = pathlib.Path(data[0]) + + category = path.parent.name + id = self._IMAGES_NAME_PATTERN.match(path.name).group("id") + + return category, id + + def _anns_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]: + path = pathlib.Path(data[0]) + + category = path.parent.name + if category in self._ANNS_CATEGORY_MAP: + category = self._ANNS_CATEGORY_MAP[category] + + id = self._ANNS_NAME_PATTERN.match(path.name).group("id") + + return category, id + + def _collate_and_decode_sample( + self, data, *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]] + ) -> Dict[str, Any]: + key, image_data, ann_data = data + category, _ = key + image_path, image_buffer = image_data + ann_path, ann_buffer = ann_data + + label = self.info.categories.index(category) + + image = decoder(image_buffer) if decoder else image_buffer + + ann = read_mat(ann_buffer) + bbox = torch.as_tensor(ann["box_coord"].astype(np.int64)) + contour = torch.as_tensor(ann["obj_contour"]) + + return dict( + 
category=category, + label=label, + image=image, + image_path=image_path, + bbox=bbox, + contour=contour, + ann_path=ann_path, + ) + + def _make_datapipe( + self, + resource_dps: List[IterDataPipe], + *, + config: DatasetConfig, + decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + ) -> IterDataPipe[Dict[str, Any]]: + images_dp, anns_dp = resource_dps + + images_dp = TarArchiveReader(images_dp) + images_dp = Filter(images_dp, self._is_not_background_image) + # FIXME: add this after https://github.com/pytorch/pytorch/issues/65808 is resolved + # images_dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE) + + anns_dp = TarArchiveReader(anns_dp) + anns_dp = Filter(anns_dp, self._is_ann) + + dp = KeyZipper( + images_dp, + anns_dp, + key_fn=self._images_key_fn, + ref_key_fn=self._anns_key_fn, + buffer_size=INFINITE_BUFFER_SIZE, + keep_key=True, + ) + return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) + + def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None: + dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name) + dp = TarArchiveReader(dp) + dp = Filter(dp, self._is_not_background_image) + dir_names = {pathlib.Path(path).parent.name for path, _ in dp} + create_categories_file(HERE, self.name, sorted(dir_names)) + + +class Caltech256(Dataset): + @property + def info(self) -> DatasetInfo: + return DatasetInfo( + "caltech256", + categories=HERE / "caltech256.categories", + homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech256", + ) + + def resources(self, config: DatasetConfig) -> List[OnlineResource]: + return [ + HttpResource( + "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar", + sha256="08ff01b03c65566014ae88eb0490dbe4419fc7ac4de726ee1163e39fd809543e", + ) + ] + + def _is_not_rogue_file(self, data: Tuple[str, Any]) -> bool: + path = pathlib.Path(data[0]) + return path.name != "RENAME2" + + def _collate_and_decode_sample( + 
self, + data: Tuple[str, io.IOBase], + *, + decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + ) -> Dict[str, Any]: + path, buffer = data + + dir_name = pathlib.Path(path).parent.name + label_str, category = dir_name.split(".") + label = torch.tensor(int(label_str)) + + return dict(label=label, category=category, image=decoder(buffer) if decoder else buffer) + + def _make_datapipe( + self, + resource_dps: List[IterDataPipe], + *, + config: DatasetConfig, + decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + ) -> IterDataPipe[Dict[str, Any]]: + dp = resource_dps[0] + dp = TarArchiveReader(dp) + dp = Filter(dp, self._is_not_rogue_file) + # FIXME: add this after https://github.com/pytorch/pytorch/issues/65808 is resolved + # dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) + return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) + + def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None: + dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name) + dp = TarArchiveReader(dp) + dir_names = {pathlib.Path(path).parent.name for path, _ in dp} + categories = [name.split(".")[1] for name in sorted(dir_names)] + create_categories_file(HERE, self.name, categories) + + +if __name__ == "__main__": + from torchvision.prototype.datasets import home + + root = home() + Caltech101().generate_categories_file(root) + Caltech256().generate_categories_file(root) diff --git a/torchvision/prototype/datasets/_builtin/caltech101.categories b/torchvision/prototype/datasets/_builtin/caltech101.categories new file mode 100644 index 00000000000..d5c18654b4e --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/caltech101.categories @@ -0,0 +1,101 @@ +Faces +Faces_easy +Leopards +Motorbikes +accordion +airplanes +anchor +ant +barrel +bass +beaver +binocular +bonsai +brain +brontosaurus +buddha +butterfly +camera +cannon +car_side +ceiling_fan +cellphone +chair +chandelier +cougar_body 
+cougar_face +crab +crayfish +crocodile +crocodile_head +cup +dalmatian +dollar_bill +dolphin +dragonfly +electric_guitar +elephant +emu +euphonium +ewer +ferry +flamingo +flamingo_head +garfield +gerenuk +gramophone +grand_piano +hawksbill +headphone +hedgehog +helicopter +ibis +inline_skate +joshua_tree +kangaroo +ketch +lamp +laptop +llama +lobster +lotus +mandolin +mayfly +menorah +metronome +minaret +nautilus +octopus +okapi +pagoda +panda +pigeon +pizza +platypus +pyramid +revolver +rhino +rooster +saxophone +schooner +scissors +scorpion +sea_horse +snoopy +soccer_ball +stapler +starfish +stegosaurus +stop_sign +strawberry +sunflower +tick +trilobite +umbrella +watch +water_lilly +wheelchair +wild_cat +windsor_chair +wrench +yin_yang diff --git a/torchvision/prototype/datasets/_builtin/caltech256.categories b/torchvision/prototype/datasets/_builtin/caltech256.categories new file mode 100644 index 00000000000..82128efba97 --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/caltech256.categories @@ -0,0 +1,257 @@ +ak47 +american-flag +backpack +baseball-bat +baseball-glove +basketball-hoop +bat +bathtub +bear +beer-mug +billiards +binoculars +birdbath +blimp +bonsai-101 +boom-box +bowling-ball +bowling-pin +boxing-glove +brain-101 +breadmaker +buddha-101 +bulldozer +butterfly +cactus +cake +calculator +camel +cannon +canoe +car-tire +cartman +cd +centipede +cereal-box +chandelier-101 +chess-board +chimp +chopsticks +cockroach +coffee-mug +coffin +coin +comet +computer-keyboard +computer-monitor +computer-mouse +conch +cormorant +covered-wagon +cowboy-hat +crab-101 +desk-globe +diamond-ring +dice +dog +dolphin-101 +doorknob +drinking-straw +duck +dumb-bell +eiffel-tower +electric-guitar-101 +elephant-101 +elk +ewer-101 +eyeglasses +fern +fighter-jet +fire-extinguisher +fire-hydrant +fire-truck +fireworks +flashlight +floppy-disk +football-helmet +french-horn +fried-egg +frisbee +frog +frying-pan +galaxy +gas-pump +giraffe +goat +golden-gate-bridge 
+goldfish +golf-ball +goose +gorilla +grand-piano-101 +grapes +grasshopper +guitar-pick +hamburger +hammock +harmonica +harp +harpsichord +hawksbill-101 +head-phones +helicopter-101 +hibiscus +homer-simpson +horse +horseshoe-crab +hot-air-balloon +hot-dog +hot-tub +hourglass +house-fly +human-skeleton +hummingbird +ibis-101 +ice-cream-cone +iguana +ipod +iris +jesus-christ +joy-stick +kangaroo-101 +kayak +ketch-101 +killer-whale +knife +ladder +laptop-101 +lathe +leopards-101 +license-plate +lightbulb +light-house +lightning +llama-101 +mailbox +mandolin +mars +mattress +megaphone +menorah-101 +microscope +microwave +minaret +minotaur +motorbikes-101 +mountain-bike +mushroom +mussels +necktie +octopus +ostrich +owl +palm-pilot +palm-tree +paperclip +paper-shredder +pci-card +penguin +people +pez-dispenser +photocopier +picnic-table +playing-card +porcupine +pram +praying-mantis +pyramid +raccoon +radio-telescope +rainbow +refrigerator +revolver-101 +rifle +rotary-phone +roulette-wheel +saddle +saturn +school-bus +scorpion-101 +screwdriver +segway +self-propelled-lawn-mower +sextant +sheet-music +skateboard +skunk +skyscraper +smokestack +snail +snake +sneaker +snowmobile +soccer-ball +socks +soda-can +spaghetti +speed-boat +spider +spoon +stained-glass +starfish-101 +steering-wheel +stirrups +sunflower-101 +superman +sushi +swan +swiss-army-knife +sword +syringe +tambourine +teapot +teddy-bear +teepee +telephone-box +tennis-ball +tennis-court +tennis-racket +theodolite +toaster +tomato +tombstone +top-hat +touring-bike +tower-pisa +traffic-light +treadmill +triceratops +tricycle +trilobite-101 +tripod +t-shirt +tuning-fork +tweezer +umbrella-101 +unicorn +vcr +video-projector +washing-machine +watch-101 +waterfall +watermelon +welding-mask +wheelbarrow +windmill +wine-bottle +xylophone +yarmulke +yo-yo +zebra +airplanes-101 +car-side-101 +faces-easy-101 +greyhound +tennis-shoes +toad +clutter diff --git a/torchvision/prototype/datasets/decoder.py 
b/torchvision/prototype/datasets/decoder.py index 4c10cff1035..d4897bebc91 100644 --- a/torchvision/prototype/datasets/decoder.py +++ b/torchvision/prototype/datasets/decoder.py @@ -8,5 +8,5 @@ __all__ = ["pil"] -def pil(file: io.IOBase, mode: str = "RGB") -> torch.Tensor: - return pil_to_tensor(PIL.Image.open(file).convert(mode.upper())) +def pil(buffer: io.IOBase, mode: str = "RGB") -> torch.Tensor: + return pil_to_tensor(PIL.Image.open(buffer).convert(mode.upper())) diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index b3cf53afc8d..19fb3b1d596 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -112,7 +112,7 @@ def __init__( categories = [str(label) for label in range(categories)] elif isinstance(categories, (str, pathlib.Path)): with open(pathlib.Path(categories).expanduser().resolve(), "r") as fh: - categories = fh.readlines() + categories = [line.strip() for line in fh] self.categories = categories self.citation = citation diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py index ad4f70145d5..56c9a2d8c07 100644 --- a/torchvision/prototype/datasets/utils/_internal.py +++ b/torchvision/prototype/datasets/utils/_internal.py @@ -1,12 +1,16 @@ import collections.abc import difflib -from typing import Collection, Sequence, Callable +import io +import pathlib +from typing import Collection, Sequence, Callable, Union, Any __all__ = [ "INFINITE_BUFFER_SIZE", "sequence_to_str", "add_suggestion", + "create_categories_file", + "read_mat" ] # pseudo-infinite until a true infinite buffer is supported by all datapipes @@ -44,3 +48,21 @@ def add_suggestion( else alternative_hint(possibilities) ) return f"{msg.strip()} {hint}" + + +def create_categories_file( + root: Union[str, pathlib.Path], name: str, categories: Sequence[str] +) -> None: + with open(pathlib.Path(root) / 
f"{name}.categories", "w") as fh: + fh.write("\n".join(categories) + "\n") + + +def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any: + try: + import scipy.io as sio + except ImportError as error: + raise ModuleNotFoundError( + "Package `scipy` is required to be installed to read .mat files." + ) from error + + return sio.loadmat(buffer, **kwargs) From cd32de60ec362363a547c3942328dd756ecd8595 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 30 Sep 2021 14:16:56 +0200 Subject: [PATCH 2/2] silence mypy --- torchvision/prototype/datasets/_builtin/caltech.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index 3d747e88fef..7f6021522c8 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -68,7 +68,7 @@ def _images_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]: path = pathlib.Path(data[0]) category = path.parent.name - id = self._IMAGES_NAME_PATTERN.match(path.name).group("id") + id = self._IMAGES_NAME_PATTERN.match(path.name).group("id") # type: ignore[union-attr] return category, id @@ -79,7 +79,7 @@ def _anns_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]: if category in self._ANNS_CATEGORY_MAP: category = self._ANNS_CATEGORY_MAP[category] - id = self._ANNS_NAME_PATTERN.match(path.name).group("id") + id = self._ANNS_NAME_PATTERN.match(path.name).group("id") # type: ignore[union-attr] return category, id