forked from facebookresearch/torchbeast
/
datasets.py
139 lines (111 loc) · 4.47 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# Copyright urw7rs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import os
import PIL
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import download_file_from_google_drive
from torchvision.datasets.utils import check_integrity, verify_str_arg
class CelebAHQ(VisionDataset):
base_folder = "celeba-hq"
# yapf: disable
file_list = [
# File ID MD5 Hash Filename
("1badu11NqxGf6qM3PTTooQDJvQbejgbTv", "b08032b342a8e0cf84c273db2b52eef3", "CelebAMask-HQ.zip"),
("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
]
# yapf: enable
def __init__(
self, root, split="train", target_type="attr", transform=None, download=False,
):
import pandas
super(CelebAHQ, self).__init__(root, transform=transform, target_transform=None)
self.split = split
if isinstance(target_type, list):
self.target_type = target_type
else:
self.target_type = [target_type]
if not self.target_type and self.target_transform is not None:
raise RuntimeError("target_transform is specified but target_type is empty")
if download:
self.download()
if not self._check_integrity():
raise RuntimeError(
"Dataset not found or corrupted."
+ " You can use download=True to download it"
)
split_map = {
"train": 0,
"valid": 1,
"test": 2,
"all": None,
}
split = split_map[
verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))
]
fn = partial(os.path.join, self.root, self.base_folder)
splits = pandas.read_csv(
fn("list_eval_partition.txt"),
delim_whitespace=True,
header=None,
index_col=0,
)
index = pandas.read_csv(
fn("CelebAMask-HQ", "CelebA-HQ-to-CelebA-mapping.txt"),
delim_whitespace=True,
header=0,
usecols=["idx", "orig_idx"],
)
splits = index["orig_idx"].apply(lambda i: splits.iloc[i])
index = index["idx"]
mask = slice(None) if split is None else (splits[1] == split)
self.filename = index[mask].apply(lambda s: str(s) + ".jpg").values
def _check_integrity(self):
for (_, md5, filename) in self.file_list:
fpath = os.path.join(self.root, self.base_folder, filename)
_, ext = os.path.splitext(filename)
# Allow original archive to be deleted (zip and 7z)
# Only need the extracted images
if ext != ".zip" and not check_integrity(fpath, md5):
return False
# Should check a hash of the images
return os.path.isdir(os.path.join(self.root, self.base_folder, "CelebAMask-HQ"))
def download(self):
import zipfile
if self._check_integrity():
print("Files already downloaded and verified")
return
try:
os.mkdir(os.path.join(self.root, self.base_folder))
except OSError as error:
print(error)
for (file_id, md5, filename) in self.file_list:
download_file_from_google_drive(
file_id, os.path.join(self.root, self.base_folder), filename, md5
)
with zipfile.ZipFile(
os.path.join(self.root, self.base_folder, "CelebAMask-HQ.zip"), "r"
) as f:
f.extractall(os.path.join(self.root, self.base_folder))
def __getitem__(self, index):
X = PIL.Image.open(
os.path.join(
self.root, self.base_folder, "CelebAMask-HQ", "CelebA-HQ-img", self.filename[index]
)
)
if self.transform is not None:
X = self.transform(X)
return X, 0
def __len__(self):
return len(self.filename)