# WGAN-GP on CelebA Dataset

**This notebook is adapted from the one accompanying David Foster's Generative Deep Learning, 2nd Edition.**

- Book: [Amazon](https://www.amazon.com/Generative-Deep-Learning-Teaching-Machines/dp/1098134184/ref=sr_1_1?keywords=generative+deep+learning%2C+2nd+edition&qid=1684708209&sprefix=generative+de%2Caps%2C93&sr=8-1)
- Original notebook (TensorFlow and Keras): [GitHub](https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition/blob/main/notebooks/04_gan/02_wgan_gp/wgan_gp.ipynb)
- Dataset: [Kaggle](https://www.kaggle.com/datasets/jessicali9530/celeba-dataset)

In [5]:
import os
from PIL import Image

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as Transforms

## 0. Training parameters

In [2]:
DATA_DIR = '../../data/CelebFaces/img_align_celeba/img_align_celeba/'
IMAGE_SIZE = 64
CHANNELS = 3
BATCH_SIZE = 512
Z_DIM = 128
LR = 2e-4
ADAM_BETA_1 = 0.5
ADAM_BETA_2 = 0.9
EPOCHS = 200
CRITIC_STEPS = 3
GP_WEIGHT = 10.0
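In WGAN-GP the critic is usually updated several times (here `CRITIC_STEPS`) for every generator update, and each critic loss adds a gradient penalty scaled by `GP_WEIGHT`. The penalty, from Gulrajani et al. (2017), pushes the norm of the critic's gradient towards 1 at random points between real and fake images. The following is a minimal sketch, assuming a critic that returns one score per image. The function name `gradient_penalty` and the toy linear critic are hypothetical illustrations, not code from the original notebook.

```python
import torch
from torch import nn


def gradient_penalty(critic, real, fake):
    """Hypothetical helper: WGAN-GP penalty for a critic mapping images to one score each."""
    batch_size = real.size(0)
    # One interpolation coefficient per sample, broadcast over channels, height and width.
    alpha = torch.rand(batch_size, 1, 1, 1, device=real.device)
    # Detach the inputs so the interpolated batch is a fresh leaf we can differentiate against.
    interpolated = (alpha * real.detach() + (1 - alpha) * fake.detach()).requires_grad_(True)
    scores = critic(interpolated)
    # Gradient of the critic scores with respect to the interpolated images.
    # create_graph=True keeps the graph so the penalty can be backpropagated into the critic.
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )[0]
    # Per-sample L2 norm of the flattened gradient.
    grad_norm = grads.reshape(batch_size, -1).norm(2, dim=1)
    # Penalise any deviation of the gradient norm from 1.
    return ((grad_norm - 1) ** 2).mean()


# Usage with a toy linear critic (assumption: the real critic is a convolutional network).
torch.manual_seed(0)
GP_WEIGHT = 10.0
critic = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 1))
real = torch.randn(4, 3, 64, 64)
fake = torch.randn(4, 3, 64, 64)
gp = gradient_penalty(critic, real, fake)
penalty = GP_WEIGHT * gp  # this term would be added to the critic's Wasserstein loss
```

For a linear critic the input gradient equals the weight vector for every sample, so the penalty reduces to `(||w|| - 1) ** 2`, which makes the sketch easy to sanity-check.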

## 1. Prepare dataset

In [6]:
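class CelebA(Dataset):
    """Loads the aligned CelebA images from a single flat directory."""

    def __init__(self, image_dir):
        super().__init__()

        # Resize the PIL image first, then convert it to a [0, 1] tensor and rescale to [-1, 1].
        self.transform = Transforms.Compose([
                Transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
                Transforms.ToTensor(),
                Transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        self.dir = image_dir
        # Sort the file names so the index-to-image mapping is deterministic.
        self.imgs = sorted(os.listdir(self.dir))
        self.length = len(self.imgs)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Convert to RGB in case any file is stored in a different mode.
        img = Image.open(os.path.join(self.dir, self.imgs[index])).convert('RGB')
        output_img = self.transform(img)
        return output_img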