Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for files with periods in name #4099

Merged
merged 2 commits into from Jun 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 3 additions & 8 deletions test/test_datasets_utils.py
Expand Up @@ -63,6 +63,9 @@ def test_detect_file_type(self):
("foo.gz", (".gz", None, ".gz")),
("foo.zip", (".zip", ".zip", None)),
("foo.xz", (".xz", None, ".xz")),
("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")),
("foo.bar.gz", (".gz", None, ".gz")),
("foo.bar.zip", (".zip", ".zip", None)),
]:
with self.subTest(file=file):
self.assertSequenceEqual(utils._detect_file_type(file), expected)
Expand All @@ -71,14 +74,6 @@ def test_detect_file_type_no_ext(self):
with self.assertRaises(RuntimeError):
utils._detect_file_type("foo")

def test_detect_file_type_to_many_exts(self):
with self.assertRaises(RuntimeError):
utils._detect_file_type("foo.bar.tar.gz")

def test_detect_file_type_unknown_archive_type(self):
with self.assertRaises(RuntimeError):
utils._detect_file_type("foo.bar.gz")

def test_detect_file_type_unknown_compression(self):
with self.assertRaises(RuntimeError):
utils._detect_file_type("foo.tar.baz")
Expand Down
52 changes: 23 additions & 29 deletions torchvision/datasets/utils.py
Expand Up @@ -291,53 +291,47 @@ def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> No
}


def _verify_archive_type(archive_type: str) -> None:
if archive_type not in _ARCHIVE_EXTRACTORS.keys():
valid_types = "', '".join(_ARCHIVE_EXTRACTORS.keys())
raise RuntimeError(f"Unknown archive type '{archive_type}'. Known archive types are '{valid_types}'.")

def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
"""Detect the archive type and/or compression of a file.

def _verify_compression(compression: str) -> None:
if compression not in _COMPRESSED_FILE_OPENERS.keys():
valid_types = "', '".join(_COMPRESSED_FILE_OPENERS.keys())
raise RuntimeError(f"Unknown compression '{compression}'. Known compressions are '{valid_types}'.")
Args:
file (str): the filename

Returns:
(tuple): tuple of suffix, archive type, and compression

def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
path = pathlib.Path(file)
suffix = path.suffix
Raises:
RuntimeError: if file has no suffix or suffix is not supported
"""
suffixes = pathlib.Path(file).suffixes
if not suffixes:
raise RuntimeError(
f"File '{file}' has no suffixes that could be used to detect the archive type and compression."
)
elif len(suffixes) > 2:
raise RuntimeError(
"Archive type and compression detection only works for 1 or 2 suffixes. " f"Got {len(suffixes)} instead."
)
elif len(suffixes) == 2:
# if we have exactly two suffixes we assume the first one is the archive type and the second one is the
# compression
archive_type, compression = suffixes
_verify_archive_type(archive_type)
_verify_compression(compression)
return "".join(suffixes), archive_type, compression
suffix = suffixes[-1]

# check if the suffix is a known alias
with contextlib.suppress(KeyError):
if suffix in _FILE_TYPE_ALIASES:
return (suffix, *_FILE_TYPE_ALIASES[suffix])

# check if the suffix is an archive type
with contextlib.suppress(RuntimeError):
_verify_archive_type(suffix)
if suffix in _ARCHIVE_EXTRACTORS:
return suffix, suffix, None

# check if the suffix is a compression
with contextlib.suppress(RuntimeError):
_verify_compression(suffix)
if suffix in _COMPRESSED_FILE_OPENERS:
# check for suffix hierarchy
if len(suffixes) > 1:
pmeier marked this conversation as resolved.
Show resolved Hide resolved
suffix2 = suffixes[-2]

# check if the suffix2 is an archive type
if suffix2 in _ARCHIVE_EXTRACTORS:
return suffix2 + suffix, suffix2, suffix

return suffix, None, suffix

raise RuntimeError(f"Suffix '{suffix}' is neither recognized as archive type nor as compression.")
valid_suffixes = sorted(set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS))
raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")


def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
Expand Down