diff --git a/torchvision/datasets/celeba.py b/torchvision/datasets/celeba.py index b3c5f9e9a6a..e9dd883b92e 100644 --- a/torchvision/datasets/celeba.py +++ b/torchvision/datasets/celeba.py @@ -1,13 +1,12 @@ import csv import os -import warnings from collections import namedtuple from typing import Any, Callable, List, Optional, Union, Tuple import PIL import torch -from .utils import check_integrity, verify_str_arg +from .utils import download_file_from_google_drive, check_integrity, verify_str_arg, extract_archive from .vision import VisionDataset CSV = namedtuple("CSV", ["header", "index", "data"]) @@ -36,17 +35,9 @@ class CelebA(VisionDataset): and returns a transformed version. E.g, ``transforms.PILToTensor`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. - download (bool, optional): Deprecated. - - .. warning:: - - Downloading CelebA is not supported anymore as of 0.13 and this - parameter will be removed in 0.15. See - `this issue `__ - for more details. - Please download the files from - https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract - them in ``root/celeba``. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. """ base_folder = "celeba" @@ -73,7 +64,7 @@ def __init__( target_type: Union[List[str], str] = "attr", transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, - download: bool = None, + download: bool = False, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) self.split = split @@ -85,15 +76,6 @@ def __init__( if not self.target_type and self.target_transform is not None: raise RuntimeError("target_transform is specified but target_type is empty") - if download is not None: - warnings.warn( - "Downloading CelebA is not supported anymore as of 0.13, and the " - "download parameter will be removed in 0.15. See " - "https://github.com/pytorch/vision/issues/5705 for more details. " - "Please download the files from " - "https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract them " - "in ``root/celeba``." - ) if download: self.download() @@ -164,14 +146,10 @@ def download(self) -> None: print("Files already downloaded and verified") return - raise ValueError( - "Downloading CelebA is not supported anymore as of 0.13, and the " - "download parameter will be removed in 0.15. See " - "https://github.com/pytorch/vision/issues/5705 for more details. " - "Please download the files from " - "https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract them " - "in ``root/celeba``." - ) + for (file_id, md5, filename) in self.file_list: + download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5) + + extract_archive(os.path.join(self.root, self.base_folder, "img_align_celeba.zip")) def __getitem__(self, index: int) -> Tuple[Any, Any]: X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])) diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py index db90aa057ff..46ccf8de6f7 100644 --- a/torchvision/prototype/datasets/_builtin/celeba.py +++ b/torchvision/prototype/datasets/_builtin/celeba.py @@ -11,7 +11,7 @@ ) from torchvision.prototype.datasets.utils import ( Dataset, - ManualDownloadResource, + GDriveResource, OnlineResource, ) from torchvision.prototype.datasets.utils._internal import ( @@ -85,34 +85,33 @@ def __init__( super().__init__(root, skip_integrity_check=skip_integrity_check) def _resources(self) -> List[OnlineResource]: - instructions = "Please download the file from https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html." - splits = ManualDownloadResource( - instructions=instructions, + splits = GDriveResource( + "0B7EVK8r0v71pY0NSMzRuSXJEVkk", sha256="fc955bcb3ef8fbdf7d5640d9a8693a8431b5f2ee291a5c1449a1549e7e073fe7", file_name="list_eval_partition.txt", ) - images = ManualDownloadResource( - instructions=instructions, + images = GDriveResource( + "0B7EVK8r0v71pZjFTYXZWM3FlRnM", sha256="46fb89443c578308acf364d7d379fe1b9efb793042c0af734b6112e4fd3a8c74", file_name="img_align_celeba.zip", ) - identities = ManualDownloadResource( - instructions=instructions, + identities = GDriveResource( + "1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", sha256="c6143857c3e2630ac2da9f782e9c1232e5e59be993a9d44e8a7916c78a6158c0", file_name="identity_CelebA.txt", ) - attributes = ManualDownloadResource( - instructions=instructions, + attributes = GDriveResource( + "0B7EVK8r0v71pblRyaVFSWGxPY0U", sha256="f0e5da289d5ccf75ffe8811132694922b60f2af59256ed362afa03fefba324d0", file_name="list_attr_celeba.txt", ) - bounding_boxes = ManualDownloadResource( - instructions=instructions, + bounding_boxes = GDriveResource( + "0B7EVK8r0v71pbThiMVRxWXZ4dU0", sha256="7487a82e57c4bb956c5445ae2df4a91ffa717e903c5fa22874ede0820c8ec41b", file_name="list_bbox_celeba.txt", ) - landmarks = ManualDownloadResource( - instructions=instructions, + landmarks = GDriveResource( + "0B7EVK8r0v71pd0FJY3Blby1HUTQ", sha256="6c02a87569907f6db2ba99019085697596730e8129f67a3d61659f198c48d43b", file_name="list_landmarks_align_celeba.txt", ) diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py index 9222c6e30a0..3c9b95cb498 100644 --- a/torchvision/prototype/datasets/utils/_resource.py +++ b/torchvision/prototype/datasets/utils/_resource.py @@ -216,9 +216,9 @@ def __init__(self, instructions: str, **kwargs: Any) -> None: def _download(self, root: pathlib.Path) -> NoReturn: raise RuntimeError( - f"The file {self.file_name} was not found, and cannot be downloaded automatically.\n\n" - f"{self.instructions.strip()}\n\n" - f"Once it is downloaded, please place the file in {root}." + f"The file {self.file_name} cannot be downloaded automatically. " + f"Please follow the instructions below and place it in {root}\n\n" + f"{self.instructions}" )