-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_utils.py
57 lines (45 loc) · 1.58 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
import zipfile
import gdown
import torch
from natsort import natsorted
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
def download_CelebA(data_root, url=None, *, force=False):
    """Download the aligned CelebA image archive and extract it under ``data_root``.

    The archive is saved as ``<data_root>/img_align_celeba.zip`` and its
    contents are extracted into ``<data_root>/img_align_celeba``. Re-running
    the function is safe: directories are created only if missing, and an
    already-downloaded, readable archive is reused unless ``force`` is set.

    Args:
        data_root (string): Directory that receives the archive and the
            extracted images. It is created, with any missing parents, if needed.
        url (string, optional): Google Drive URL passed to ``gdown.download``.
            Defaults to the CelebA aligned-images archive
            (file id ``1cNIac61PSA_LqDFYFUeyaQYekYPc75NH``).
        force (bool, optional): Download again even if a readable zip file
            already exists at the download path. Defaults to False.

    Raises:
        FileNotFoundError: If no archive exists at the download path after
            the download step (e.g. the download failed without raising).
        zipfile.BadZipFile: If the file at the download path is not a valid zip.
    """
    if url is None:
        # URL for the CelebA dataset
        url = 'https://drive.google.com/uc?id=1cNIac61PSA_LqDFYFUeyaQYekYPc75NH'
    dataset_folder = os.path.join(data_root, 'img_align_celeba')
    # Path to download the dataset to
    download_path = os.path.join(data_root, 'img_align_celeba.zip')

    # exist_ok makes this idempotent and also creates the image folder when
    # data_root already exists (a plain exists()-check would skip it).
    os.makedirs(dataset_folder, exist_ok=True)

    # Skip the large download when a readable archive is already present.
    # is_zipfile returns False for a missing file and typically for a file
    # truncated by an interrupted download, so those cases are fetched again.
    if force or not zipfile.is_zipfile(download_path):
        gdown.download(url, download_path, quiet=False)

    if not os.path.isfile(download_path):
        raise FileNotFoundError(
            f'CelebA archive was not downloaded to {download_path}')

    # Unzip the downloaded file
    with zipfile.ZipFile(download_path, 'r') as ziphandler:
        ziphandler.extractall(dataset_folder)
class CelebADataset(Dataset):
    """Map-style dataset over the image files found in a single directory.

    File names are ordered with ``natsorted`` (natural order, so ``2.jpg``
    precedes ``10.jpg``). Index ``i`` therefore always refers to the same file.
    Each item is the image converted to RGB and passed through ``transform``
    when one is given. Only the image is returned; there are no labels.
    """

    def __init__(self, root_dir, transform=None, extensions=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Transform applied to each image
                sample. Defaults to None (no transform).
            extensions (str or iterable of str, optional): If given, only
                files whose name ends with one of these suffixes are used,
                compared case-insensitively (e.g. ``('.jpg', '.png')``).
                By default every regular file in ``root_dir`` is used.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.image_names = natsorted(self._list_files(root_dir, extensions))

    @staticmethod
    def _list_files(root_dir, extensions=None):
        """Return names of regular files in root_dir, optionally filtered by suffix."""
        if isinstance(extensions, str):
            # A bare string would otherwise be iterated character by character.
            extensions = (extensions,)
        suffixes = (None if extensions is None
                    else tuple(ext.lower() for ext in extensions))
        return [
            name for name in os.listdir(root_dir)
            # Sub-directories would make Image.open fail later in __getitem__.
            if os.path.isfile(os.path.join(root_dir, name))
            and (suffixes is None or name.lower().endswith(suffixes))
        ]

    def __len__(self):
        """Return the number of image files in the dataset."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply ``transform`` if one was given."""
        img_path = os.path.join(self.root_dir, self.image_names[idx])
        # The context manager closes the file handle even if decoding fails.
        # convert() returns a new, loaded image that remains valid afterwards.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img