diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py
index fc980326307..ae33f396694 100644
--- a/test/builtin_dataset_mocks.py
+++ b/test/builtin_dataset_mocks.py
@@ -10,6 +10,7 @@
 import pickle
 import random
 import tempfile
+import unittest.mock
 import xml.etree.ElementTree as ET
 from collections import defaultdict, Counter, UserDict
 
@@ -21,7 +22,8 @@
 from torch.nn.functional import one_hot
 from torch.testing import make_tensor as _make_tensor
 from torchvision.prototype import datasets
-from torchvision.prototype.datasets._api import DEFAULT_DECODER_MAP, DEFAULT_DECODER, find
+from torchvision.prototype.datasets._api import find
+from torchvision.prototype.utils._internal import sequence_to_str
 
 make_tensor = functools.partial(_make_tensor, device="cpu")
 make_scalar = functools.partial(make_tensor, ())
@@ -49,7 +51,7 @@ class DatasetMock:
     def __init__(self, name, mock_data_fn, *, configs=None):
         self.dataset = find(name)
         self.root = TEST_HOME / self.dataset.name
-        self.mock_data_fn = self._parse_mock_data(mock_data_fn)
+        self.mock_data_fn = mock_data_fn
         self.configs = configs or self.info._configs
         self._cache = {}
 
@@ -61,77 +63,71 @@ def info(self):
     def name(self):
         return self.info.name
 
-    def _parse_mock_data(self, mock_data_fn):
-        def wrapper(info, root, config):
-            mock_infos = mock_data_fn(info, root, config)
-
-            if mock_infos is None:
-                raise pytest.UsageError(
-                    f"The mock data function for dataset '{self.name}' returned nothing. It needs to at least return an "
-                    f"integer indicating the number of samples for the current `config`."
-                )
-
-            key_types = set(type(key) for key in mock_infos) if isinstance(mock_infos, dict) else {}
-            if datasets.utils.DatasetConfig not in key_types:
-                mock_infos = {config: mock_infos}
-            elif len(key_types) > 1:
-                raise pytest.UsageError(
-                    f"Unable to handle the returned dictionary of the mock data function for dataset {self.name}. If "
-                    f"returned dictionary uses `DatasetConfig` as key type, all keys should be of that type."
-                )
-
-            for config_, mock_info in list(mock_infos.items()):
-                if config_ in self._cache:
-                    raise pytest.UsageError(
-                        f"The mock info for config {config_} of dataset {self.name} generated for config {config} "
-                        f"already exists in the cache."
-                    )
-                if isinstance(mock_info, int):
-                    mock_infos[config_] = dict(num_samples=mock_info)
-                elif not isinstance(mock_info, dict):
-                    raise pytest.UsageError(
-                        f"The mock data function for dataset '{self.name}' returned a {type(mock_infos)} for `config` "
-                        f"{config_}. The returned object should be a dictionary containing at least the number of "
-                        f"samples for the key `'num_samples'`. If no additional information is required for specific "
-                        f"tests, the number of samples can also be returned as an integer."
-                    )
-                elif "num_samples" not in mock_info:
-                    raise pytest.UsageError(
-                        f"The dictionary returned by the mock data function for dataset '{self.name}' and config "
-                        f"{config_} has to contain a `'num_samples'` entry indicating the number of samples."
-                    )
-
-            return mock_infos
-
-        return wrapper
+    def _parse_mock_data(self, config, mock_infos):
+        if mock_infos is None:
+            raise pytest.UsageError(
+                f"The mock data function for dataset '{self.name}' returned nothing. It needs to at least return an "
+                f"integer indicating the number of samples for the current `config`."
+            )
+
+        key_types = set(type(key) for key in mock_infos) if isinstance(mock_infos, dict) else {}
+        if datasets.utils.DatasetConfig not in key_types:
+            mock_infos = {config: mock_infos}
+        elif len(key_types) > 1:
+            raise pytest.UsageError(
+                f"Unable to handle the returned dictionary of the mock data function for dataset {self.name}. If "
+                f"the returned dictionary uses `DatasetConfig` as key type, all keys must be of that type."
+            )
+
+        for config_, mock_info in list(mock_infos.items()):
+            if config_ in self._cache:
+                raise pytest.UsageError(
+                    f"The mock info for config {config_} of dataset {self.name} generated for config {config} "
+                    f"already exists in the cache."
+                )
+            if isinstance(mock_info, int):
+                mock_infos[config_] = dict(num_samples=mock_info)
+            elif not isinstance(mock_info, dict):
+                raise pytest.UsageError(
+                    f"The mock data function for dataset '{self.name}' returned a {type(mock_info)} for `config` "
+                    f"{config_}. The returned object should be a dictionary containing at least the number of "
+                    f"samples for the key `'num_samples'`. If no additional information is required for specific "
+                    f"tests, the number of samples can also be returned as an integer."
+                )
+            elif "num_samples" not in mock_info:
+                raise pytest.UsageError(
+                    f"The dictionary returned by the mock data function for dataset '{self.name}' and config "
+                    f"{config_} has to contain a `'num_samples'` entry indicating the number of samples."
+                )
+
+        return mock_infos
 
-    def _load_mock(self, config):
+    def _prepare_resources(self, config):
         with contextlib.suppress(KeyError):
             return self._cache[config]
 
         self.root.mkdir(exist_ok=True)
-        for config_, mock_info in self.mock_data_fn(self.info, self.root, config).items():
-            mock_resources = [
-                ResourceMock(dataset_name=self.name, dataset_config=config_, file_name=resource.file_name)
-                for resource in self.dataset.resources(config_)
-            ]
-            self._cache[config_] = (mock_resources, mock_info)
+        mock_infos = self._parse_mock_data(config, self.mock_data_fn(self.info, self.root, config))
+
+        available_file_names = {path.name for path in self.root.glob("*")}
+        for config_, mock_info in mock_infos.items():
+            required_file_names = {resource.file_name for resource in self.dataset.resources(config_)}
+            missing_file_names = required_file_names - available_file_names
+            if missing_file_names:
+                raise pytest.UsageError(
+                    f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} "
+                    f"for {config_}, but they were not created by the mock data function."
+                )
+
+            self._cache[config_] = mock_info
 
         return self._cache[config]
 
-    def load(self, config, *, decoder=DEFAULT_DECODER):
-        try:
-            self.info.check_dependencies()
-        except ModuleNotFoundError as error:
-            pytest.skip(str(error))
-
-        mock_resources, mock_info = self._load_mock(config)
-        datapipe = self.dataset._make_datapipe(
-            [resource.load(self.root) for resource in mock_resources],
-            config=config,
-            decoder=DEFAULT_DECODER_MAP.get(self.info.type) if decoder is DEFAULT_DECODER else decoder,
-        )
-        return datapipe, mock_info
+    @contextlib.contextmanager
+    def prepare(self, config):
+        mock_info = self._prepare_resources(config)
+        with unittest.mock.patch("torchvision.prototype.datasets._api.home", return_value=str(TEST_HOME)):
+            yield mock_info
 
 
 def config_id(name, config):
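Aside: the `_parse_mock_data` rewrite above normalizes three possible return shapes from a mock data function. A minimal sketch of that contract (the function names, bodies, and sample counts are made up for illustration; `info.make_config` is the prototype `DatasetInfo` helper already used by the mocks in this file):

```python
def some_mock_data_fn(info, root, config):
    ...  # write the mock files for `config` into `root`
    return 5  # shorthand for dict(num_samples=5)


def some_multi_config_mock_data_fn(info, root, config):
    ...  # write the mock files for *all* configs in one call
    return {
        info.make_config(split="train"): 5,
        info.make_config(split="test"): dict(num_samples=3),
    }
```

Anything other than an `int`, a `dict` containing `'num_samples'`, or a `DatasetConfig`-keyed `dict` of those now raises `pytest.UsageError`.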
@@ -1000,7 +996,7 @@ def dtd(info, root, _):
 def fer2013(info, root, config):
     num_samples = 5 if config.split == "train" else 3
 
-    path = root / f"{config.split}.txt"
+    path = root / f"{config.split}.csv"
     with open(path, "w", newline="") as file:
         field_names = ["emotion"] if config.split == "train" else []
         field_names.append("pixels")
@@ -1061,7 +1057,7 @@ def clevr(info, root, config):
             file,
         )
 
-    make_zip(root, f"{data_folder.name}.zip")
+    make_zip(root, f"{data_folder.name}.zip", data_folder)
 
     return {config_: num_samples_map[config_.split] for config_ in info._configs}
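To make the `.txt` to `.csv` rename concrete, this is roughly the file the fer2013 mock writes; a sketch only, since the real mock fills in random pixel values (the helper name and the zero-pixel rows are invented here):

```python
import csv


def write_fer2013_mock_csv(path, split, num_samples):
    # only the train split carries labels; both splits have a 48*48 pixel string
    field_names = (["emotion"] if split == "train" else []) + ["pixels"]
    with open(path, "w", newline="") as file:
        writer = csv.DictWriter(file, fieldnames=field_names)
        writer.writeheader()
        for _ in range(num_samples):
            row = {"pixels": " ".join("0" for _ in range(48 * 48))}
            if split == "train":
                row["emotion"] = 0
            writer.writerow(row)
```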
@@ -1121,8 +1117,8 @@ def generate(self, root):
         for path in segmentation_files:
             path.with_name(f".{path.name}").touch()
 
-        make_tar(root, "images.tar")
-        make_tar(root, anns_folder.with_suffix(".tar").name)
+        make_tar(root, "images.tar.gz", compression="gz")
+        make_tar(root, anns_folder.with_suffix(".tar.gz").name, compression="gz")
 
         return num_samples_map
 
@@ -1211,7 +1207,7 @@ def _make_segmentations(cls, root, image_files):
                 size=[1, *make_tensor((2,), low=3, dtype=torch.int).tolist()],
             )
 
-        make_tar(root, segmentations_folder.with_suffix(".tgz").name)
+        make_tar(root, segmentations_folder.with_suffix(".tgz").name, compression="gz")
 
     @classmethod
     def generate(cls, root):
diff --git a/test/datasets_utils.py b/test/datasets_utils.py
index b87d50ca3db..5cb43680cda 100644
--- a/test/datasets_utils.py
+++ b/test/datasets_utils.py
@@ -868,9 +868,13 @@ def _split_files_or_dirs(root, *files_or_dirs):
 def _make_archive(root, name, *files_or_dirs, opener, adder, remove=True):
     archive = pathlib.Path(root) / name
     if not files_or_dirs:
-        dir = archive.with_suffix("")
-        if dir.exists() and dir.is_dir():
-            files_or_dirs = (dir,)
+        # `Path.with_suffix("")` only strips the last suffix, so we need to invoke it once per suffix if multiple
+        # suffixes are present. For example, `pathlib.Path("foo.tar.gz").with_suffix("")` results in `foo.tar`.
+        file_or_dir = archive
+        for _ in range(len(archive.suffixes)):
+            file_or_dir = file_or_dir.with_suffix("")
+        if file_or_dir.exists():
+            files_or_dirs = (file_or_dir,)
         else:
             raise ValueError("No file or dir provided.")
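The `pathlib` pitfall that the new comment describes can be verified standalone:

```python
from pathlib import Path

archive = Path("images.tar.gz")
print(archive.suffixes)         # ['.tar', '.gz']
print(archive.with_suffix(""))  # images.tar -- only the last suffix is stripped

# strip every suffix, as the patched _make_archive does
file_or_dir = archive
for _ in range(len(archive.suffixes)):
    file_or_dir = file_or_dir.with_suffix("")
print(file_or_dir)              # images
```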
diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py
index bebeaccaadd..5697e5ce224 100644
--- a/test/test_prototype_builtin_datasets.py
+++ b/test/test_prototype_builtin_datasets.py
@@ -23,13 +23,16 @@ def test_coverage():
 class TestCommon:
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_smoke(self, dataset_mock, config):
-        dataset, _ = dataset_mock.load(config)
+        with dataset_mock.prepare(config):
+            dataset = datasets.load(dataset_mock.name, **config)
+
         if not isinstance(dataset, IterDataPipe):
             raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.")
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_sample(self, dataset_mock, config):
-        dataset, _ = dataset_mock.load(config)
+        with dataset_mock.prepare(config):
+            dataset = datasets.load(dataset_mock.name, **config)
 
         try:
             sample = next(iter(dataset))
@@ -44,7 +47,8 @@ def test_sample(self, dataset_mock, config):
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_num_samples(self, dataset_mock, config):
-        dataset, mock_info = dataset_mock.load(config)
+        with dataset_mock.prepare(config) as mock_info:
+            dataset = datasets.load(dataset_mock.name, **config)
 
         num_samples = 0
         for _ in dataset:
@@ -54,7 +58,8 @@ def test_num_samples(self, dataset_mock, config):
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_decoding(self, dataset_mock, config):
-        dataset, _ = dataset_mock.load(config)
+        with dataset_mock.prepare(config):
+            dataset = datasets.load(dataset_mock.name, **config)
 
         undecoded_features = {key for key, value in next(iter(dataset)).items() if isinstance(value, io.IOBase)}
         if undecoded_features:
@@ -65,7 +70,8 @@ def test_decoding(self, dataset_mock, config):
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_no_vanilla_tensors(self, dataset_mock, config):
-        dataset, _ = dataset_mock.load(config)
+        with dataset_mock.prepare(config):
+            dataset = datasets.load(dataset_mock.name, **config)
 
         vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
         if vanilla_tensors:
@@ -76,7 +82,8 @@ def test_no_vanilla_tensors(self, dataset_mock, config):
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_transformable(self, dataset_mock, config):
-        dataset, _ = dataset_mock.load(config)
+        with dataset_mock.prepare(config):
+            dataset = datasets.load(dataset_mock.name, **config)
 
         next(iter(dataset.map(transforms.Identity())))
 
@@ -89,7 +96,8 @@ def test_transformable(self, dataset_mock, config):
         },
     )
     def test_traversable(self, dataset_mock, config):
-        dataset, _ = dataset_mock.load(config)
+        with dataset_mock.prepare(config):
+            dataset = datasets.load(dataset_mock.name, **config)
 
         traverse(dataset)
 
@@ -108,7 +116,8 @@ def scan(graph):
                 yield node
                 yield from scan(sub_graph)
 
-        dataset, _ = dataset_mock.load(config)
+        with dataset_mock.prepare(config):
+            dataset = datasets.load(dataset_mock.name, **config)
 
         for dp in scan(traverse(dataset)):
             if type(dp) is annotation_dp_type:
@@ -120,7 +129,8 @@ def scan(graph):
 @parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
 class TestQMNIST:
     def test_extra_label(self, dataset_mock, config):
-        dataset, _ = dataset_mock.load(config)
+        with dataset_mock.prepare(config):
+            dataset = datasets.load(dataset_mock.name, **config)
 
         sample = next(iter(dataset))
         for key, type in (
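Why `prepare` is enough for the tests above: `datasets.load` resolves its data root via `torchvision.prototype.datasets._api.home`, so patching that function while the mock files sit under `TEST_HOME` reroutes loading to them. A condensed sketch of the same mechanism outside the `DatasetMock` class (the helper name is invented; the patch target and the `datasets.load` call mirror the diff):

```python
import unittest.mock

from torchvision.prototype import datasets


def load_with_mock_root(mock_root, name, **config):
    # while the patch is active, datasets.load() resolves its root to
    # mock_root instead of the real download directory
    with unittest.mock.patch(
        "torchvision.prototype.datasets._api.home", return_value=str(mock_root)
    ):
        return datasets.load(name, **config)
```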