Use public API for loading in prototype datasets tests #5212
Changes from all commits:
- 97ade0f
- aeabd33
- 2c111aa
- e4957de
- 38aee72
- cbf128b
```diff
@@ -10,6 +10,7 @@
 import pickle
 import random
 import tempfile
+import unittest.mock
 import xml.etree.ElementTree as ET
 from collections import defaultdict, Counter, UserDict
```
```diff
@@ -21,7 +22,8 @@
 from torch.nn.functional import one_hot
 from torch.testing import make_tensor as _make_tensor
 from torchvision.prototype import datasets
-from torchvision.prototype.datasets._api import DEFAULT_DECODER_MAP, DEFAULT_DECODER, find
+from torchvision.prototype.datasets._api import find
+from torchvision.prototype.utils._internal import sequence_to_str

 make_tensor = functools.partial(_make_tensor, device="cpu")
 make_scalar = functools.partial(make_tensor, ())
```
```diff
@@ -49,7 +51,7 @@ class DatasetMock:
     def __init__(self, name, mock_data_fn, *, configs=None):
         self.dataset = find(name)
         self.root = TEST_HOME / self.dataset.name
-        self.mock_data_fn = self._parse_mock_data(mock_data_fn)
+        self.mock_data_fn = mock_data_fn
         self.configs = configs or self.info._configs
         self._cache = {}
```
```diff
@@ -61,77 +63,71 @@ def info(self):
     def name(self):
         return self.info.name

-    def _parse_mock_data(self, mock_data_fn):
-        def wrapper(info, root, config):
-            mock_infos = mock_data_fn(info, root, config)
-
-            if mock_infos is None:
-                raise pytest.UsageError(
-                    f"The mock data function for dataset '{self.name}' returned nothing. It needs to at least return an "
-                    f"integer indicating the number of samples for the current `config`."
-                )
-
-            key_types = set(type(key) for key in mock_infos) if isinstance(mock_infos, dict) else {}
-            if datasets.utils.DatasetConfig not in key_types:
-                mock_infos = {config: mock_infos}
-            elif len(key_types) > 1:
-                raise pytest.UsageError(
-                    f"Unable to handle the returned dictionary of the mock data function for dataset {self.name}. If "
-                    f"returned dictionary uses `DatasetConfig` as key type, all keys should be of that type."
-                )
-
-            for config_, mock_info in list(mock_infos.items()):
-                if config_ in self._cache:
-                    raise pytest.UsageError(
-                        f"The mock info for config {config_} of dataset {self.name} generated for config {config} "
-                        f"already exists in the cache."
-                    )
-                if isinstance(mock_info, int):
-                    mock_infos[config_] = dict(num_samples=mock_info)
-                elif not isinstance(mock_info, dict):
-                    raise pytest.UsageError(
-                        f"The mock data function for dataset '{self.name}' returned a {type(mock_infos)} for `config` "
-                        f"{config_}. The returned object should be a dictionary containing at least the number of "
-                        f"samples for the key `'num_samples'`. If no additional information is required for specific "
-                        f"tests, the number of samples can also be returned as an integer."
-                    )
-                elif "num_samples" not in mock_info:
-                    raise pytest.UsageError(
-                        f"The dictionary returned by the mock data function for dataset '{self.name}' and config "
-                        f"{config_} has to contain a `'num_samples'` entry indicating the number of samples."
-                    )
-
-            return mock_infos
-
-        return wrapper
+    def _parse_mock_data(self, config, mock_infos):
+        if mock_infos is None:
+            raise pytest.UsageError(
+                f"The mock data function for dataset '{self.name}' returned nothing. It needs to at least return an "
+                f"integer indicating the number of samples for the current `config`."
+            )
+
+        key_types = set(type(key) for key in mock_infos) if isinstance(mock_infos, dict) else {}
+        if datasets.utils.DatasetConfig not in key_types:
+            mock_infos = {config: mock_infos}
+        elif len(key_types) > 1:
+            raise pytest.UsageError(
+                f"Unable to handle the returned dictionary of the mock data function for dataset {self.name}. If "
+                f"returned dictionary uses `DatasetConfig` as key type, all keys should be of that type."
+            )
+
+        for config_, mock_info in list(mock_infos.items()):
+            if config_ in self._cache:
+                raise pytest.UsageError(
+                    f"The mock info for config {config_} of dataset {self.name} generated for config {config} "
+                    f"already exists in the cache."
+                )
+            if isinstance(mock_info, int):
+                mock_infos[config_] = dict(num_samples=mock_info)
+            elif not isinstance(mock_info, dict):
+                raise pytest.UsageError(
+                    f"The mock data function for dataset '{self.name}' returned a {type(mock_infos)} for `config` "
+                    f"{config_}. The returned object should be a dictionary containing at least the number of "
+                    f"samples for the key `'num_samples'`. If no additional information is required for specific "
+                    f"tests, the number of samples can also be returned as an integer."
+                )
+            elif "num_samples" not in mock_info:
+                raise pytest.UsageError(
+                    f"The dictionary returned by the mock data function for dataset '{self.name}' and config "
+                    f"{config_} has to contain a `'num_samples'` entry indicating the number of samples."
+                )
+
+        return mock_infos

-    def _load_mock(self, config):
+    def _prepare_resources(self, config):
         with contextlib.suppress(KeyError):
             return self._cache[config]

         self.root.mkdir(exist_ok=True)
-        for config_, mock_info in self.mock_data_fn(self.info, self.root, config).items():
-            mock_resources = [
-                ResourceMock(dataset_name=self.name, dataset_config=config_, file_name=resource.file_name)
-                for resource in self.dataset.resources(config_)
-            ]
-            self._cache[config_] = (mock_resources, mock_info)
+        mock_infos = self._parse_mock_data(config, self.mock_data_fn(self.info, self.root, config))
+
+        available_file_names = {path.name for path in self.root.glob("*")}
+        for config_, mock_info in mock_infos.items():
+            required_file_names = {resource.file_name for resource in self.dataset.resources(config_)}
+            missing_file_names = required_file_names - available_file_names
+            if missing_file_names:
+                raise pytest.UsageError(
+                    f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} "
+                    f"for {config_}, but they were not created by the mock data function."
+                )
+
+            self._cache[config_] = mock_info

         return self._cache[config]

-    def load(self, config, *, decoder=DEFAULT_DECODER):
-        try:
-            self.info.check_dependencies()
-        except ModuleNotFoundError as error:
-            pytest.skip(str(error))
```
Comment on lines -123 to -126:

This should not have been removed. The tests should be runnable even if none or not all of the third-party dependencies are installed.
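A minimal sketch of how that dependency check could be carried over into the new flow, assuming it simply moves into the `prepare` context manager shown in the continuation below (the placement is an assumption, not part of this diff):

```python
import contextlib
import unittest.mock

import pytest


# Sketch only: a DatasetMock.prepare() variant that keeps the removed
# check_dependencies()/pytest.skip() behavior. Everything except the try/except
# block is taken from the new code in this PR.
@contextlib.contextmanager
def prepare(self, config):
    try:
        self.info.check_dependencies()
    except ModuleNotFoundError as error:
        # skip instead of failing when optional third-party dependencies are missing
        pytest.skip(str(error))

    mock_info = self._prepare_resources(config)
    with unittest.mock.patch("torchvision.prototype.datasets._api.home", return_value=str(TEST_HOME)):
        yield mock_info
```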
```diff
-
-        mock_resources, mock_info = self._load_mock(config)
-        datapipe = self.dataset._make_datapipe(
-            [resource.load(self.root) for resource in mock_resources],
-            config=config,
-            decoder=DEFAULT_DECODER_MAP.get(self.info.type) if decoder is DEFAULT_DECODER else decoder,
-        )
-        return datapipe, mock_info
+    @contextlib.contextmanager
+    def prepare(self, config):
+        mock_info = self._prepare_resources(config)
+        with unittest.mock.patch("torchvision.prototype.datasets._api.home", return_value=str(TEST_HOME)):
+            yield mock_info


 def config_id(name, config):
```
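For context, a hedged sketch of how a test would consume this: the test function name and parametrization are assumptions for illustration; only `dataset_mock.prepare()` and the fact that loading now goes through the public `datasets.load` (with the data home patched to the mock root) come from this diff.

```python
# Illustrative usage sketch; the exact load() call signature is assumed.
def test_smoke(dataset_mock, config):
    with dataset_mock.prepare(config) as mock_info:
        # public API: resolves the mock files under the patched TEST_HOME
        dataset = datasets.load(dataset_mock.name, **config)
    assert len(list(dataset)) == mock_info["num_samples"]
```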
```diff
@@ -1000,7 +996,7 @@ def dtd(info, root, _):
 def fer2013(info, root, config):
     num_samples = 5 if config.split == "train" else 3

-    path = root / f"{config.split}.txt"
+    path = root / f"{config.split}.csv"
```
Comment thread on this line:

This change and everything below in this file are actual bugs in our mock data generation that were hidden by our custom loading logic.

Do you know why the tests were passing before despite the files having the wrong names?

Before, the resources were not collected by the regular logic, but rather "hand-fed" to `_make_datapipe` (see the removed `load` method above).
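To make that concrete, a small illustrative snippet of the check that now surfaces such mismatches in `_prepare_resources` (the file names are assumed example values):

```python
# Illustration only: the new validation compares the file names the dataset
# declares for its resources against what the mock data function wrote to disk.
required_file_names = {"train.csv"}   # from self.dataset.resources(config)
available_file_names = {"train.txt"}  # what the (buggy) mock data function created
missing_file_names = required_file_names - available_file_names
if missing_file_names:
    # under the new logic this raises pytest.UsageError instead of silently passing
    print(f"missing: {sorted(missing_file_names)}")
```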
```diff

     with open(path, "w", newline="") as file:
         field_names = ["emotion"] if config.split == "train" else []
         field_names.append("pixels")
```
```diff
@@ -1061,7 +1057,7 @@ def clevr(info, root, config):
             file,
         )

-    make_zip(root, f"{data_folder.name}.zip")
+    make_zip(root, f"{data_folder.name}.zip", data_folder)

     return {config_: num_samples_map[config_.split] for config_ in info._configs}
```
```diff
@@ -1121,8 +1117,8 @@ def generate(self, root):
         for path in segmentation_files:
             path.with_name(f".{path.name}").touch()

-        make_tar(root, "images.tar")
-        make_tar(root, anns_folder.with_suffix(".tar").name)
+        make_tar(root, "images.tar.gz", compression="gz")
+        make_tar(root, anns_folder.with_suffix(".tar.gz").name, compression="gz")

         return num_samples_map
```
```diff
@@ -1211,7 +1207,7 @@ def _make_segmentations(cls, root, image_files):
             size=[1, *make_tensor((2,), low=3, dtype=torch.int).tolist()],
         )

-        make_tar(root, segmentations_folder.with_suffix(".tgz").name)
+        make_tar(root, segmentations_folder.with_suffix(".tgz").name, compression="gz")

     @classmethod
     def generate(cls, root):
```
Comment:

This is just GitHub not picking up on the indentation change. Before, this was a decorator, but I changed it into a regular method. The actual body has not changed.
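As a rough illustration of why the diff looks larger than the change (the names below are assumptions; only the indentation shift matters):

```python
# Before (sketch): the helper lived inside a decorator/factory, so its body sat
# one indentation level deeper.
def mock_data_decorator(fn):                      # assumed shape, for illustration only
    def _make_segmentations(root, image_files):
        ...                                       # body
    return _make_segmentations


# After (sketch): the same body as a regular (class)method, one level shallower.
# Re-indenting identical lines makes git/GitHub mark every line as changed.
class MockDataGenerator:                          # assumed name, for illustration only
    @classmethod
    def _make_segmentations(cls, root, image_files):
        ...                                       # identical body, re-indented
```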