Skip to content

Commit

Permalink
Add support for files with periods in name
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Jun 22, 2021
1 parent 183a722 commit c32d5ac
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 37 deletions.
11 changes: 3 additions & 8 deletions test/test_datasets_utils.py
Expand Up @@ -58,6 +58,9 @@ def test_detect_file_type(self):
("foo.gz", (".gz", None, ".gz")),
("foo.zip", (".zip", ".zip", None)),
("foo.xz", (".xz", None, ".xz")),
("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")),
("foo.bar.gz", (".gz", None, ".gz")),
("foo.bar.zip", (".zip", ".zip", None)),
]:
with self.subTest(file=file):
self.assertSequenceEqual(utils._detect_file_type(file), expected)
Expand All @@ -66,14 +69,6 @@ def test_detect_file_type_no_ext(self):
with self.assertRaises(RuntimeError):
utils._detect_file_type("foo")

def test_detect_file_type_to_many_exts(self):
with self.assertRaises(RuntimeError):
utils._detect_file_type("foo.bar.tar.gz")

def test_detect_file_type_unknown_archive_type(self):
with self.assertRaises(RuntimeError):
utils._detect_file_type("foo.bar.gz")

def test_detect_file_type_unknown_compression(self):
with self.assertRaises(RuntimeError):
utils._detect_file_type("foo.tar.baz")
Expand Down
52 changes: 23 additions & 29 deletions torchvision/datasets/utils.py
Expand Up @@ -281,53 +281,47 @@ def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> No
_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {".tgz": (".tar", ".gz")}


def _verify_archive_type(archive_type: str) -> None:
if archive_type not in _ARCHIVE_EXTRACTORS.keys():
valid_types = "', '".join(_ARCHIVE_EXTRACTORS.keys())
raise RuntimeError(f"Unknown archive type '{archive_type}'. Known archive types are '{valid_types}'.")

def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
"""Detect the archive type and/or compression of a file.
def _verify_compression(compression: str) -> None:
if compression not in _COMPRESSED_FILE_OPENERS.keys():
valid_types = "', '".join(_COMPRESSED_FILE_OPENERS.keys())
raise RuntimeError(f"Unknown compression '{compression}'. Known compressions are '{valid_types}'.")
Args:
file (str): the filename
Returns:
(tuple): tuple of suffix, archive type, and compression
def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
path = pathlib.Path(file)
suffix = path.suffix
Raises:
RuntimeError: if file has no suffix or suffix is not supported
"""
suffixes = pathlib.Path(file).suffixes
if not suffixes:
raise RuntimeError(
f"File '{file}' has no suffixes that could be used to detect the archive type and compression."
)
elif len(suffixes) > 2:
raise RuntimeError(
"Archive type and compression detection only works for 1 or 2 suffixes. " f"Got {len(suffixes)} instead."
)
elif len(suffixes) == 2:
# if we have exactly two suffixes we assume the first one is the archive type and the second on is the
# compression
archive_type, compression = suffixes
_verify_archive_type(archive_type)
_verify_compression(compression)
return "".join(suffixes), archive_type, compression
suffix = suffixes[-1]

# check if the suffix is a known alias
with contextlib.suppress(KeyError):
if suffix in _FILE_TYPE_ALIASES:
return (suffix, *_FILE_TYPE_ALIASES[suffix])

# check if the suffix is an archive type
with contextlib.suppress(RuntimeError):
_verify_archive_type(suffix)
if suffix in _ARCHIVE_EXTRACTORS:
return suffix, suffix, None

# check if the suffix is a compression
with contextlib.suppress(RuntimeError):
_verify_compression(suffix)
if suffix in _COMPRESSED_FILE_OPENERS:
# check for suffix hierarchy
if len(suffixes) > 1:
suffix2 = suffixes[-2]

# check if the suffix2 is an archive type
if suffix2 in _ARCHIVE_EXTRACTORS:
return suffix2 + suffix, suffix2, suffix

return suffix, None, suffix

raise RuntimeError(f"Suffix '{suffix}' is neither recognized as archive type nor as compression.")
valid_suffixes = set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS)
raise RuntimeError(f"Unknown compression or archive type: '{suffix}'. Known suffixes are: '{valid_suffixes}'.")


def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
Expand Down

0 comments on commit c32d5ac

Please sign in to comment.