Merged
124 changes: 60 additions & 64 deletions test/builtin_dataset_mocks.py
@@ -10,6 +10,7 @@
import pickle
import random
import tempfile
import unittest.mock
import xml.etree.ElementTree as ET
from collections import defaultdict, Counter, UserDict

@@ -21,7 +22,8 @@
from torch.nn.functional import one_hot
from torch.testing import make_tensor as _make_tensor
from torchvision.prototype import datasets
from torchvision.prototype.datasets._api import DEFAULT_DECODER_MAP, DEFAULT_DECODER, find
from torchvision.prototype.datasets._api import find
from torchvision.prototype.utils._internal import sequence_to_str

make_tensor = functools.partial(_make_tensor, device="cpu")
make_scalar = functools.partial(make_tensor, ())
@@ -49,7 +51,7 @@ class DatasetMock:
def __init__(self, name, mock_data_fn, *, configs=None):
self.dataset = find(name)
self.root = TEST_HOME / self.dataset.name
self.mock_data_fn = self._parse_mock_data(mock_data_fn)
self.mock_data_fn = mock_data_fn
self.configs = configs or self.info._configs
self._cache = {}

@@ -61,77 +63,71 @@ def info(self):
def name(self):
return self.info.name

def _parse_mock_data(self, mock_data_fn):
def wrapper(info, root, config):
mock_infos = mock_data_fn(info, root, config)
def _parse_mock_data(self, config, mock_infos):
Contributor Author: This is just GitHub not picking up on the indentation change. Before, this was a decorator, but I changed it into a regular method; the actual body has not changed.

    if mock_infos is None:
        raise pytest.UsageError(
            f"The mock data function for dataset '{self.name}' returned nothing. It needs to at least return an "
            f"integer indicating the number of samples for the current `config`."
        )

    key_types = set(type(key) for key in mock_infos) if isinstance(mock_infos, dict) else {}
    if datasets.utils.DatasetConfig not in key_types:
        mock_infos = {config: mock_infos}
    elif len(key_types) > 1:
        raise pytest.UsageError(
            f"Unable to handle the returned dictionary of the mock data function for dataset {self.name}. If "
            f"returned dictionary uses `DatasetConfig` as key type, all keys should be of that type."
        )

    for config_, mock_info in list(mock_infos.items()):
        if config_ in self._cache:
            raise pytest.UsageError(
                f"The mock info for config {config_} of dataset {self.name} generated for config {config} "
                f"already exists in the cache."
            )
        if isinstance(mock_info, int):
            mock_infos[config_] = dict(num_samples=mock_info)
        elif not isinstance(mock_info, dict):
            raise pytest.UsageError(
                f"The mock data function for dataset '{self.name}' returned a {type(mock_infos)} for `config` "
                f"{config_}. The returned object should be a dictionary containing at least the number of "
                f"samples for the key `'num_samples'`. If no additional information is required for specific "
                f"tests, the number of samples can also be returned as an integer."
            )
        elif "num_samples" not in mock_info:
            raise pytest.UsageError(
                f"The dictionary returned by the mock data function for dataset '{self.name}' and config "
                f"{config_} has to contain a `'num_samples'` entry indicating the number of samples."
            )

    return mock_infos
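
For context, this validation means a mock data function may report its samples in any of the shapes below. A minimal sketch; the function and field names are made up for illustration:

def mock_data_simple(info, root, config):
    ...  # create the mock files under `root`
    return 4  # plain int: the number of samples for this config

def mock_data_with_extras(info, root, config):
    ...  # create the mock files under `root`
    return dict(num_samples=4, categories=["a", "b"])  # dict must contain "num_samples"

def mock_data_all_configs(info, root, config):
    ...  # create the mock files for every config at once
    return {config_: 5 for config_ in info._configs}  # keyed by DatasetConfig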

def _load_mock(self, config):
def _prepare_resources(self, config):
with contextlib.suppress(KeyError):
return self._cache[config]

self.root.mkdir(exist_ok=True)
for config_, mock_info in self.mock_data_fn(self.info, self.root, config).items():
mock_resources = [
ResourceMock(dataset_name=self.name, dataset_config=config_, file_name=resource.file_name)
for resource in self.dataset.resources(config_)
]
self._cache[config_] = (mock_resources, mock_info)
mock_infos = self._parse_mock_data(config, self.mock_data_fn(self.info, self.root, config))

available_file_names = {path.name for path in self.root.glob("*")}
for config_, mock_info in mock_infos.items():
required_file_names = {resource.file_name for resource in self.dataset.resources(config_)}
missing_file_names = required_file_names - available_file_names
if missing_file_names:
raise pytest.UsageError(
f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} "
f"for {config_}, but they were not created by the mock data function."
)

self._cache[config_] = mock_info

return self._cache[config]

def load(self, config, *, decoder=DEFAULT_DECODER):
try:
self.info.check_dependencies()
except ModuleNotFoundError as error:
pytest.skip(str(error))
Comment on lines -123 to -126 (Contributor Author): This should not have been removed. The tests should be runnable even if none or only some of the third-party dependencies are installed.


mock_resources, mock_info = self._load_mock(config)
datapipe = self.dataset._make_datapipe(
[resource.load(self.root) for resource in mock_resources],
config=config,
decoder=DEFAULT_DECODER_MAP.get(self.info.type) if decoder is DEFAULT_DECODER else decoder,
)
return datapipe, mock_info
@contextlib.contextmanager
def prepare(self, config):
mock_info = self._prepare_resources(config)
with unittest.mock.patch("torchvision.prototype.datasets._api.home", return_value=str(TEST_HOME)):
yield mock_info
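
A sketch of the intended usage, mirroring the tests in test_prototype_builtin_datasets.py below; the final assertion is only an example:

with dataset_mock.prepare(config) as mock_info:
    # While home() is patched, datasets.load() resolves the data root to
    # TEST_HOME and therefore picks up the mock files generated above.
    dataset = datasets.load(dataset_mock.name, **config)
    assert sum(1 for _ in dataset) == mock_info["num_samples"]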


def config_id(name, config):
@@ -1000,7 +996,7 @@ def dtd(info, root, _):
def fer2013(info, root, config):
num_samples = 5 if config.split == "train" else 3

path = root / f"{config.split}.txt"
path = root / f"{config.split}.csv"
Contributor Author: This change and everything below in this file are actual bugs in our mock data generation that were hidden by our custom loading logic.

Member: Do you know why the tests were passing before even though the files had the wrong name?

Contributor Author: Before, the resources were not collected by the regular logic, but rather "hand-fed" to _make_datapipe. So as long as the data in mock_resources[0] corresponded to Dataset.resources(...)[0], the test suite didn't notice.
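
Roughly, the old flow paired mock files with resources by position instead of looking them up on disk. A simplified sketch, not the literal removed code:

# Old flow (simplified): the mock resources were built from the dataset's own
# resource list and handed straight to _make_datapipe, so the file names on
# disk were never checked.
mock_resources = [
    ResourceMock(dataset_name=name, dataset_config=config, file_name=resource.file_name)
    for resource in dataset.resources(config)
]
datapipe = dataset._make_datapipe(
    [resource.load(root) for resource in mock_resources], config=config, decoder=...
)

# New flow: prepare() patches the data root and datasets.load() collects the
# resources itself, so a missing or misnamed mock file now raises a UsageError.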

with open(path, "w", newline="") as file:
field_names = ["emotion"] if config.split == "train" else []
field_names.append("pixels")
Expand Down Expand Up @@ -1061,7 +1057,7 @@ def clevr(info, root, config):
file,
)

make_zip(root, f"{data_folder.name}.zip")
make_zip(root, f"{data_folder.name}.zip", data_folder)

return {config_: num_samples_map[config_.split] for config_ in info._configs}

Expand Down Expand Up @@ -1121,8 +1117,8 @@ def generate(self, root):
for path in segmentation_files:
path.with_name(f".{path.name}").touch()

make_tar(root, "images.tar")
make_tar(root, anns_folder.with_suffix(".tar").name)
make_tar(root, "images.tar.gz", compression="gz")
make_tar(root, anns_folder.with_suffix(".tar.gz").name, compression="gz")

return num_samples_map

Expand Down Expand Up @@ -1211,7 +1207,7 @@ def _make_segmentations(cls, root, image_files):
size=[1, *make_tensor((2,), low=3, dtype=torch.int).tolist()],
)

make_tar(root, segmentations_folder.with_suffix(".tgz").name)
make_tar(root, segmentations_folder.with_suffix(".tgz").name, compression="gz")

@classmethod
def generate(cls, root):
10 changes: 7 additions & 3 deletions test/datasets_utils.py
@@ -868,9 +868,13 @@ def _split_files_or_dirs(root, *files_or_dirs):
def _make_archive(root, name, *files_or_dirs, opener, adder, remove=True):
archive = pathlib.Path(root) / name
if not files_or_dirs:
dir = archive.with_suffix("")
if dir.exists() and dir.is_dir():
files_or_dirs = (dir,)
# We need to invoke `Path.with_suffix("")` multiple times, since each call only removes the last suffix. For
# example, `pathlib.Path("foo.tar.gz").with_suffix("")` results in `foo.tar`.
file_or_dir = archive
for _ in range(len(archive.suffixes)):
file_or_dir = file_or_dir.with_suffix("")
if file_or_dir.exists():
files_or_dirs = (file_or_dir,)
else:
raise ValueError("No file or dir provided.")
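
As a standalone reference for the loop above (plain pathlib behavior, independent of the test helpers):

import pathlib

archive = pathlib.Path("foo.tar.gz")
assert archive.with_suffix("") == pathlib.Path("foo.tar")  # only the last suffix is removed

stripped = archive
for _ in range(len(archive.suffixes)):  # suffixes == [".tar", ".gz"], so two iterations
    stripped = stripped.with_suffix("")
assert stripped == pathlib.Path("foo")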

28 changes: 19 additions & 9 deletions test/test_prototype_builtin_datasets.py
@@ -23,13 +23,16 @@ def test_coverage():
class TestCommon:
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_smoke(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config)

if not isinstance(dataset, IterDataPipe):
raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.")

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_sample(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config)

try:
sample = next(iter(dataset))
@@ -44,7 +47,8 @@ def test_sample(self, dataset_mock, config):

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_num_samples(self, dataset_mock, config):
dataset, mock_info = dataset_mock.load(config)
with dataset_mock.prepare(config) as mock_info:
dataset = datasets.load(dataset_mock.name, **config)

num_samples = 0
for _ in dataset:
@@ -54,7 +58,8 @@ def test_decoding(self, dataset_mock, config):

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_decoding(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config)

undecoded_features = {key for key, value in next(iter(dataset)).items() if isinstance(value, io.IOBase)}
if undecoded_features:
@@ -65,7 +70,8 @@ def test_no_vanilla_tensors(self, dataset_mock, config):

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_no_vanilla_tensors(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config)

vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
if vanilla_tensors:
@@ -76,7 +82,8 @@ def test_transformable(self, dataset_mock, config):

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_transformable(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config)

next(iter(dataset.map(transforms.Identity())))

@@ -89,7 +96,8 @@
},
)
def test_traversable(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config)

traverse(dataset)

@@ -108,7 +116,8 @@ def scan(graph):
yield node
yield from scan(sub_graph)

dataset, _ = dataset_mock.load(config)
with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config)

for dp in scan(traverse(dataset)):
if type(dp) is annotation_dp_type:
@@ -120,7 +129,8 @@ def scan(graph):
@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
class TestQMNIST:
def test_extra_label(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
with dataset_mock.prepare(config):
dataset = datasets.load(dataset_mock.name, **config)

sample = next(iter(dataset))
for key, type in (