From 4a44223991b8fdae87d6fd3bc45371f3b4a13947 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 7 Oct 2021 10:51:58 +0200 Subject: [PATCH] use helper function to extract archive in CelebA --- torchvision/datasets/celeba.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/torchvision/datasets/celeba.py b/torchvision/datasets/celeba.py index 302f75087b7..2c954c4d719 100644 --- a/torchvision/datasets/celeba.py +++ b/torchvision/datasets/celeba.py @@ -6,7 +6,7 @@ import PIL import torch -from .utils import download_file_from_google_drive, check_integrity, verify_str_arg +from .utils import download_file_from_google_drive, check_integrity, verify_str_arg, extract_archive from .vision import VisionDataset CSV = namedtuple("CSV", ["header", "index", "data"]) @@ -142,8 +142,6 @@ def _check_integrity(self) -> bool: return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba")) def download(self) -> None: - import zipfile - if self._check_integrity(): print("Files already downloaded and verified") return @@ -151,8 +149,7 @@ def download(self) -> None: for (file_id, md5, filename) in self.file_list: download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5) - with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f: - f.extractall(os.path.join(self.root, self.base_folder)) + extract_archive(os.path.join(self.root, self.base_folder, "img_align_celeba.zip")) def __getitem__(self, index: int) -> Tuple[Any, Any]: X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))