In [19]:
import os
from glob import glob
import torch
from torch.utils.data import random_split, DataLoader, Dataset
from torchvision import transforms
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from PIL import Image
import numpy as np

In [20]:
# Tunable settings for the experiment, kept together so they can be changed in one place.
hparams = dict(
    batch_size=2,
    image_size=64,
    n_channels=3,
    latent_vector_size=100,
    generator_features=64,
    discriminator_features=64,
    epochs=5,
    learning_rate=2e-4,
    beta1=0.5,
    gpus=0,
)

In [21]:
# List every .jpg directly under the CelebA image folder (the path is relative to the
# notebook's working directory) and peek at the first match as a sanity check.
# glob makes no ordering guarantee, so which file shows up first can vary.
image_files = glob('../data/raw/img_align_celeba/*.jpg')
image_files[0]

'../data/raw/img_align_celeba/052628.jpg'

In [22]:
class CelebDataset(Dataset):
    """Dataset of face images served as normalized float32 tensors.

    Every ``*.jpg`` file directly inside ``data_dir`` is one sample. Each image is
    converted to RGB, resized so that its shorter side equals ``image_size``,
    center-cropped to ``image_size`` x ``image_size``, turned into a tensor and
    normalized per channel with mean 0.5 and std 0.5 (i.e. mapped to ``[-1, 1]``).

    Args:
        data_dir: Directory containing the ``.jpg`` images.
        image_size: Side length, in pixels, of the square output image.

    Raises:
        AssertionError: If ``data_dir`` is not an existing directory.
    """

    def __init__(self, data_dir: str, image_size: int = 64):
        super().__init__()

        self.data_dir = data_dir
        assert os.path.isdir(self.data_dir), f'self.data_dir is not a directory: {self.data_dir}'
        self.transformation_stack = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        # glob returns matches in arbitrary, filesystem-dependent order; sorting makes
        # the index -> file mapping (and any seeded split built on it) reproducible.
        self.image_files = sorted(glob(os.path.join(self.data_dir, '*.jpg')))

    def __len__(self):
        """Return the number of image files found in ``data_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load and transform the image at position ``idx``.

        Args:
            idx: Integer index, or a tensor holding one (converted via ``tolist()``).

        Returns:
            A ``torch.float32`` tensor of shape ``(3, image_size, image_size)``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # The context manager closes the underlying file handle. convert('RGB') reads
        # the pixel data while the file is still open and guarantees the three
        # channels that Normalize expects, even for grayscale or CMYK JPEGs.
        with Image.open(self.image_files[idx]) as raw_image:
            image = raw_image.convert('RGB')

        # ToTensor already yields a float32 torch tensor, so return it directly.
        # A NumPy round-trip with astype('float') would upcast to float64, which
        # does not match float32 model weights and doubles memory use.
        return self.transformation_stack(image)

In [23]:
# Smoke-test the dataset: build it with the default image_size and pull the first item.
# NOTE: despite its name, `batch` is a single image tensor with no batch dimension.
dataset = CelebDataset(data_dir='../data/raw/img_align_celeba')
batch = next(iter(dataset))
batch.shape

torch.Size([3, 64, 64])

In [31]:
class CelebDataModule(LightningDataModule):
    """Lightning data module that serves train/test/validation loaders of face images.

    The images found in ``data_dir`` are split at random into 80% training,
    15% testing and the remainder (roughly 5%) validation.

    Args:
        data_dir: Directory containing the ``.jpg`` images.
        batch_size: Number of images per batch for every dataloader.
        num_workers: Number of worker processes used by every dataloader.
        image_size: Side length, in pixels, of the square images produced.
    """

    def __init__(
        self,
        data_dir: str,
        batch_size: int = 4,
        num_workers: int = 1,
        image_size: int = 64,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.image_size = image_size

        # Filled in by setup(); None means the split has not been made yet.
        self.train_dataset = None
        self.test_dataset = None
        self.val_dataset = None

    def prepare_data(self):
        """Check that the image directory exists before any loading is attempted.

        Raises:
            AssertionError: If ``data_dir`` is not an existing directory.
        """
        assert os.path.isdir(self.data_dir), f'self.data_dir is not a directory: {self.data_dir}'

    def setup(self, stage=None):
        """Create the dataset and split it into train/test/validation subsets.

        Args:
            stage: Stage name supplied by the Lightning ``Trainer`` (for example
                ``'fit'`` or ``'test'``). It is accepted so the Trainer can call
                this hook, but ignored: all three subsets are always created.
                Defaults to ``None`` so ``setup()`` can also be called by hand.
        """
        # Lightning invokes setup() once per stage. Splitting again would draw a new
        # random partition and let training images leak into the test/val subsets.
        if self.train_dataset is not None:
            return

        dataset = CelebDataset(self.data_dir, self.image_size)
        total_images = len(dataset)
        train_count = int(total_images * 0.8)
        test_count = int(total_images * 0.15)
        # Validation takes whatever is left so the three sizes always sum to the total.
        val_count = total_images - train_count - test_count

        self.train_dataset, self.test_dataset, self.val_dataset = random_split(
            dataset,
            (train_count, test_count, val_count),
        )

    def train_dataloader(self):
        """Return the training loader; batches are reshuffled every epoch."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        """Return the test loader (fixed order)."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

In [37]:
# Build the data module and call setup() by hand (no Trainer is involved here) so the
# train/test/val subsets exist and can be inspected in the next cell.
data_module = CelebDataModule(data_dir='../data/raw/img_align_celeba', image_size=64, batch_size=1)
data_module.setup()

In [38]:
# Shape of the first training sample, to confirm the split yields image tensors.
data_module.train_dataset[0].shape

torch.Size([3, 64, 64])