diff --git a/torchvision/datasets/celeba.py b/torchvision/datasets/celeba.py index 302f75087b7..2c954c4d719 100644 --- a/torchvision/datasets/celeba.py +++ b/torchvision/datasets/celeba.py @@ -6,7 +6,7 @@ import PIL import torch -from .utils import download_file_from_google_drive, check_integrity, verify_str_arg +from .utils import download_file_from_google_drive, check_integrity, verify_str_arg, extract_archive from .vision import VisionDataset CSV = namedtuple("CSV", ["header", "index", "data"]) @@ -142,8 +142,6 @@ def _check_integrity(self) -> bool: return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba")) def download(self) -> None: - import zipfile - if self._check_integrity(): print("Files already downloaded and verified") return @@ -151,8 +149,7 @@ def download(self) -> None: for (file_id, md5, filename) in self.file_list: download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5) - with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f: - f.extractall(os.path.join(self.root, self.base_folder)) + extract_archive(os.path.join(self.root, self.base_folder, "img_align_celeba.zip")) def __getitem__(self, index: int) -> Tuple[Any, Any]: X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))