Skip to content

Commit

Permalink
add Imagenette dataset (#8139)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Dec 4, 2023
1 parent 3feb502 commit 30397d9
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Image classification
GTSRB
INaturalist
ImageNet
Imagenette
KMNIST
LFWPeople
LSUN
Expand Down
35 changes: 35 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3377,6 +3377,41 @@ def test_bad_input(self):
pass


class ImagenetteTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Imagenette
ADDITIONAL_CONFIGS = combinations_grid(split=["train", "val"], size=["full", "320px", "160px"])

_WNIDS = [
"n01440764",
"n02102040",
"n02979186",
"n03000684",
"n03028079",
"n03394916",
"n03417042",
"n03425413",
"n03445777",
"n03888257",
]

def inject_fake_data(self, tmpdir, config):
archive_root = "imagenette2"
if config["size"] != "full":
archive_root += f"-{config['size'].replace('px', '')}"
image_root = pathlib.Path(tmpdir) / archive_root / config["split"]

num_images_per_class = 3
for wnid in self._WNIDS:
datasets_utils.create_image_folder(
root=image_root,
name=wnid,
file_name_fn=lambda idx: f"{wnid}_{idx}.JPEG",
num_examples=num_images_per_class,
)

return num_images_per_class * len(self._WNIDS)


class TestDatasetWrapper:
def test_unknown_type(self):
unknown_object = object()
Expand Down
2 changes: 2 additions & 0 deletions torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .gtsrb import GTSRB
from .hmdb51 import HMDB51
from .imagenet import ImageNet
from .imagenette import Imagenette
from .inaturalist import INaturalist
from .kinetics import Kinetics
from .kitti import Kitti
Expand Down Expand Up @@ -128,6 +129,7 @@
"InStereo2k",
"ETH3DStereo",
"wrap_dataset_for_transforms_v2",
"Imagenette",
)


Expand Down
104 changes: 104 additions & 0 deletions torchvision/datasets/imagenette.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from pathlib import Path
from typing import Any, Callable, Optional, Tuple

from PIL import Image

from .folder import find_classes, make_dataset
from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset


class Imagenette(VisionDataset):
"""`Imagenette <https://github.com/fastai/imagenette#imagenette-1>`_ image classification dataset.
Args:
root (string): Root directory of the Imagenette dataset.
split (string, optional): The dataset split. Supports ``"train"`` (default), and ``"val"``.
size (string, optional): The image size. Supports ``"full"`` (default), ``"320px"``, and ``"160px"``.
download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already
downloaded archives are not downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
version, e.g. ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
Attributes:
classes (list): List of the class name tuples.
class_to_idx (dict): Dict with items (class name, class index).
wnids (list): List of the WordNet IDs.
wnid_to_idx (dict): Dict with items (WordNet ID, class index).
"""

_ARCHIVES = {
"full": ("https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz", "fe2fc210e6bb7c5664d602c3cd71e612"),
"320px": ("https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz", "3df6f0d01a2c9592104656642f5e78a3"),
"160px": ("https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz", "e793b78cc4c9e9a4ccc0c1155377a412"),
}
_WNID_TO_CLASS = {
"n01440764": ("tench", "Tinca tinca"),
"n02102040": ("English springer", "English springer spaniel"),
"n02979186": ("cassette player",),
"n03000684": ("chain saw", "chainsaw"),
"n03028079": ("church", "church building"),
"n03394916": ("French horn", "horn"),
"n03417042": ("garbage truck", "dustcart"),
"n03425413": ("gas pump", "gasoline pump", "petrol pump", "island dispenser"),
"n03445777": ("golf ball",),
"n03888257": ("parachute", "chute"),
}

def __init__(
self,
root: str,
split: str = "train",
size: str = "full",
download=False,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
super().__init__(root, transform=transform, target_transform=target_transform)

self._split = verify_str_arg(split, "split", ["train", "val"])
self._size = verify_str_arg(size, "size", ["full", "320px", "160px"])

self._url, self._md5 = self._ARCHIVES[self._size]
self._size_root = Path(self.root) / Path(self._url).stem
self._image_root = str(self._size_root / self._split)

if download:
self._download()
elif not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it.")

self.wnids, self.wnid_to_idx = find_classes(self._image_root)
self.classes = [self._WNID_TO_CLASS[wnid] for wnid in self.wnids]
self.class_to_idx = {
class_name: idx for wnid, idx in self.wnid_to_idx.items() for class_name in self._WNID_TO_CLASS[wnid]
}
self._samples = make_dataset(self._image_root, self.wnid_to_idx, extensions=".jpeg")

def _check_exists(self) -> bool:
return self._size_root.exists()

def _download(self):
if self._check_exists():
raise RuntimeError(
f"The directory {self._size_root} already exists. "
f"If you want to re-download or re-extract the images, delete the directory."
)

download_and_extract_archive(self._url, self.root, md5=self._md5)

def __getitem__(self, idx: int) -> Tuple[Any, Any]:
path, label = self._samples[idx]
image = Image.open(path).convert("RGB")

if self.transform is not None:
image = self.transform(image)

if self.target_transform is not None:
label = self.target_transform(label)

return image, label

def __len__(self) -> int:
return len(self._samples)
1 change: 1 addition & 0 deletions torchvision/tv_tensors/_dataset_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ def classification_wrapper_factory(dataset, target_keys):
datasets.GTSRB,
datasets.DatasetFolder,
datasets.ImageFolder,
datasets.Imagenette,
]:
WRAPPER_FACTORIES.register(dataset_cls)(classification_wrapper_factory)

Expand Down

0 comments on commit 30397d9

Please sign in to comment.