In [3]:
import os
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader
from natsort import natsorted

In [4]:
# Default to the CPU and switch to the GPU only when PyTorch reports one.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [5]:
# --- Model / optimisation hyper-parameters --------------------------------
input_size = 100          # channel count of the latent noise tensor fed to the generator
n_filters = 32            # base number of convolution filters in both networks
lr_generator = 3e-4       # Adam learning rate for the generator
lr_discriminator = 2e-4   # Adam learning rate for the discriminator
mode_z = 'normal'         # latent distribution selector: 'normal' or 'uniform'
image_size = (64, 64)     # target image size used by the data pipeline
batch_size = 32
num_epochs = 100

# --- Paths -----------------------------------------------------------------
# Notebooks have no `__file__`, so the data directory is resolved against the
# kernel's working directory. (The previous `abspath(__name__)` resolved the
# *module name* relative to the cwd, which only yielded the cwd by accident.)
current_path = os.getcwd()
data_path = os.path.normpath(os.path.join(current_path, 'data/celeba/img_align_celeba/'))

In [7]:
class CelebADataset(Dataset):
  """Unlabelled image dataset backed by a single flat directory.

  Every entry returned by ``os.listdir(root_dir)`` is treated as an image
  file; entries are ordered with ``natsorted`` so that index ``i`` is stable
  across runs.

  Args:
    root_dir: Directory that contains the image files.
    transform: Optional callable applied to each loaded image.
    loader: Optional callable ``path -> image``. When omitted, torchvision's
      ``default_loader`` is used, which opens the file and converts it to RGB.
  """

  def __init__(self, root_dir, transform=None, loader=None):
    self.root_dir = root_dir
    self.transform = transform
    self.loader = loader
    self.image_names = natsorted(os.listdir(root_dir))

  def __len__(self):
    """Return the number of files found in ``root_dir``."""
    return len(self.image_names)

  def __getitem__(self, idx):
    """Load the ``idx``-th image and return it, transformed if requested."""
    img_path = os.path.join(self.root_dir, self.image_names[idx])
    loader = self.loader
    if loader is None:
      # Imported lazily so the class needs nothing beyond the packages the
      # notebook already depends on. The original code referenced
      # ``Image`` without ever importing it, which raised NameError here.
      from torchvision.datasets.folder import default_loader
      loader = default_loader
    img = loader(img_path)
    if self.transform:
      img = self.transform(img)
    return img

# Pre-processing pipeline.
# `Resize(int)` only rescales the *shorter* edge and keeps the aspect ratio, so
# on its own it does not produce `image_size` samples for non-square sources.
# Resizing the shorter edge to max(image_size) guarantees both edges are at
# least as large as the crop, and CenterCrop then yields exactly `image_size`.
transform = transforms.Compose([
    transforms.Resize(max(image_size)),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    # Map each RGB channel from [0, 1] to [-1, 1].
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), inplace=True)
])

dataset = CelebADataset(root_dir=data_path, transform=transform)
# drop_last=True keeps every batch at exactly `batch_size` samples.
dataset_dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=8, drop_last=True)

In [41]:
class Generator(nn.Module):
    """Transposed-convolution generator mapping latent noise to an image.

    The stack is three ``ConvTranspose2d -> BatchNorm2d -> LeakyReLU(0.2)``
    stages followed by a final ``ConvTranspose2d -> Tanh``. For a latent input
    of shape ``(N, input_size, 1, 1)`` the spatial size grows
    1 -> 4 -> 7 -> 14 -> 28, so the output is ``(N, 1, 28, 28)`` with values
    in ``[-1, 1]``.

    NOTE(review): the output is single-channel 28x28; confirm this is the
    intended resolution / channel count for the data it is trained against.

    Args:
        input_size: Number of channels of the latent input.
        n_filters: Base channel width; hidden stages use 4x, 2x and 1x of it.
    """

    def __init__(self, input_size, n_filters):
        super().__init__()

        # (out_channels, kernel_size, stride, padding) of each hidden stage.
        hidden_stages = (
            (n_filters * 4, 4, 1, 0),
            (n_filters * 2, 3, 2, 1),
            (n_filters, 4, 2, 1),
        )

        layers = []
        channels_in = input_size
        for channels_out, kernel, stride, pad in hidden_stages:
            layers.append(nn.ConvTranspose2d(
                channels_in, channels_out,
                kernel_size=kernel, stride=stride, padding=pad, bias=False,
            ))
            layers.append(nn.BatchNorm2d(channels_out, track_running_stats=False))
            layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
            channels_in = channels_out

        # Output stage: collapse to one channel and squash into [-1, 1].
        layers.append(nn.ConvTranspose2d(
            channels_in, 1, kernel_size=4, stride=2, padding=1, bias=False,
        ))
        layers.append(nn.Tanh())

        self.network = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the result."""
        return self.network(input)

In [None]:
# Build the generator from the configured sizes and move it to the compute device.
generator = Generator(input_size=input_size, n_filters=n_filters)
generator = generator.to(device=device)

# Show the layer stack for a quick sanity check of the architecture.
print(generator.network)

Sequential(
  (0): ConvTranspose2d(100, 128, kernel_size=(4, 4), stride=(1, 1), bias=False)
  (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (2): LeakyReLU(negative_slope=0.2, inplace=True)
  (3): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (5): LeakyReLU(negative_slope=0.2, inplace=True)
  (6): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (7): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (8): LeakyReLU(negative_slope=0.2, inplace=True)
  (9): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (10): Tanh()
)


In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator scoring single-channel images.

    Three stride-2 convolutions (the first without batch norm, the next two
    with it), each followed by ``LeakyReLU(0.2)``, then a 4x4 valid
    convolution and a ``Sigmoid``. For an input of shape ``(N, 1, 28, 28)``
    the spatial size shrinks 28 -> 14 -> 7 -> 4 -> 1, so the output is
    ``(N, 1, 1, 1)`` with values in ``(0, 1)``.

    The output keeps its 4-D shape; a loss computed against it needs targets
    of that same shape (or the output flattened first).

    Args:
        n_filters: Base channel width; stages use 1x, 2x and 4x of it.
    """

    def __init__(self, n_filters):
        super().__init__()

        # (out_channels, kernel_size, followed_by_batchnorm) per down-sampling stage.
        down_stages = (
            (n_filters, 4, False),
            (n_filters * 2, 4, True),
            (n_filters * 4, 3, True),
        )

        layers = []
        channels_in = 1
        for channels_out, kernel, with_norm in down_stages:
            layers.append(nn.Conv2d(
                channels_in, channels_out,
                kernel_size=kernel, stride=2, padding=1, bias=False,
            ))
            if with_norm:
                layers.append(nn.BatchNorm2d(channels_out, track_running_stats=False))
            layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
            channels_in = channels_out

        # Scoring head: one probability per remaining spatial location.
        layers.append(nn.Conv2d(
            channels_in, 1, kernel_size=4, stride=1, padding=0, bias=False,
        ))
        layers.append(nn.Sigmoid())

        self.network = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the probabilities."""
        return self.network(input)

In [None]:
# Build the discriminator and move it to the compute device.
discriminator = Discriminator(n_filters=n_filters)
discriminator = discriminator.to(device=device)

# Show the layer stack for a quick sanity check of the architecture.
print(discriminator.network)

Sequential(
  (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (1): LeakyReLU(negative_slope=0.2, inplace=True)
  (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (4): LeakyReLU(negative_slope=0.2, inplace=True)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (7): LeakyReLU(negative_slope=0.2, inplace=True)
  (8): Conv2d(128, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
  (9): Sigmoid()
)


In [None]:
# Binary cross-entropy on probabilities (the discriminator ends in a Sigmoid).
loss_func = nn.BCELoss()

# One Adam optimizer per network, each with its own learning rate.
generator_optimizer = torch.optim.Adam(generator.parameters(), lr=lr_generator)
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr_discriminator)

In [None]:
def create_noise(batch_size, input_size, mode_z):
    """Draw a batch of latent vectors shaped ``(batch_size, input_size, 1, 1)``.

    Args:
        batch_size: Number of latent vectors to draw.
        input_size: Number of channels of each latent vector.
        mode_z: ``'uniform'`` for samples in ``[-1, 1)`` or ``'normal'`` for
            standard-normal samples.

    Returns:
        A CPU float tensor of shape ``(batch_size, input_size, 1, 1)``.

    Raises:
        ValueError: If ``mode_z`` is not one of the supported modes.
    """
    shape = (batch_size, input_size, 1, 1)
    if mode_z == 'uniform':
        # torch.rand is U[0, 1); rescale to U[-1, 1).
        return torch.rand(shape) * 2 - 1
    if mode_z == 'normal':
        # randn (not rand): the 'normal' mode must sample from N(0, 1).
        return torch.randn(shape)
    raise ValueError(
        f"Unsupported mode_z {mode_z!r}; expected 'uniform' or 'normal'"
    )

def train_discriminator(input: torch.Tensor):
    """Run one optimisation step of the discriminator.

    The discriminator is scored on the real batch ``input`` (target 1) and on
    a freshly generated fake batch of the same size (target 0); the two BCE
    losses are summed, back-propagated and applied with
    ``discriminator_optimizer``.

    Args:
        input: Batch of real images; moved to ``device`` inside the function.

    Returns:
        Tuple ``(loss, prob_real, prob_fake)`` where ``loss`` is a Python
        float and the two probability tensors are detached discriminator
        outputs for the real and fake batches.
    """
    discriminator.zero_grad()
    batch_size = input.size(0)
    input = input.to(device)

    # --- real batch ---
    discriminator_prob_real = discriminator(input)
    # Build targets from the output itself so their shape always matches;
    # BCELoss rejects input/target tensors of different shapes.
    discriminator_labels_real = torch.ones_like(discriminator_prob_real)
    discriminator_loss_real = loss_func(discriminator_prob_real, discriminator_labels_real)

    # --- fake batch ---
    # The noise must live on the same device as the generator's weights.
    input_z = create_noise(batch_size, input_size, mode_z).to(device)
    # Only the discriminator is updated here, so no graph is needed through
    # the generator.
    with torch.no_grad():
        generator_output = generator(input_z)
    discriminator_prob_fake = discriminator(generator_output)
    discriminator_labels_fake = torch.zeros_like(discriminator_prob_fake)
    discriminator_loss_fake = loss_func(discriminator_prob_fake, discriminator_labels_fake)

    discriminator_loss = discriminator_loss_real + discriminator_loss_fake
    discriminator_loss.backward()
    discriminator_optimizer.step()
    return (
        discriminator_loss.item(),
        discriminator_prob_real.detach(),
        discriminator_prob_fake.detach()
    )

def train_generator(input: torch.Tensor):
    """Run one optimisation step of the generator.

    A batch of fake images is generated and scored by the discriminator; the
    generator is trained to make the discriminator output 1 ("real") for
    them.

    Args:
        input: The current real batch. Only its first dimension is read, to
            generate the same number of fake samples.

    Returns:
        The generator loss for this step as a Python float.
    """
    generator.zero_grad()
    batch_size = input.size(0)
    input_z = create_noise(batch_size, input_size, mode_z).to(device)
    generator_output = generator(input_z)
    discriminator_prob_fake = discriminator(generator_output)
    # Targets are built from the discriminator output so their shape always
    # matches; BCELoss rejects input/target tensors of different shapes.
    generator_labels_real = torch.ones_like(discriminator_prob_fake)
    generator_loss = loss_func(discriminator_prob_fake, generator_labels_real)
    generator_loss.backward()
    generator_optimizer.step()
    return generator_loss.item()

In [None]:
def create_samples(generator, input_z):
    """Generate images from ``input_z`` and rescale them to ``[0, 1]``.

    The result shape is derived from the generator output itself rather than
    from notebook-level globals, so it works for any batch size and
    resolution: a single-channel output ``(N, 1, H, W)`` becomes
    ``(N, H, W)``; outputs with more channels keep their shape.

    Args:
        generator: Callable mapping latent noise to images in ``[-1, 1]``.
        input_z: Latent noise batch passed to ``generator``.

    Returns:
        Tensor of generated images mapped from ``[-1, 1]`` to ``[0, 1]``.
    """
    generator_output = generator(input_z)
    # squeeze(1) removes the channel axis only when it has size 1.
    images = generator_output.squeeze(1)
    return (images + 1) / 2.0


# Seed *before* drawing the fixed evaluation noise so that both the per-epoch
# sample grid and the training run are reproducible.
torch.manual_seed(1)
fixed_z = create_noise(batch_size, input_size, mode_z).to(device)

# Full per-step loss history across all epochs, plus one sample batch per epoch.
discriminator_losses = []
generator_losses = []
epoch_samples = []

for epoch in range(1, num_epochs+1):
    generator.train()
    # Index where this epoch's entries start in the loss histories.
    epoch_start = len(generator_losses)
    for batch in dataset_dataloader:
        # The dataset yields bare image tensors; also accept (images, labels)
        # batches so a labelled dataset can be dropped in without changes.
        real_images = batch[0] if isinstance(batch, (list, tuple)) else batch
        discriminator_loss, _, _ = train_discriminator(real_images)
        discriminator_losses.append(discriminator_loss)
        generator_loss = train_generator(real_images)
        generator_losses.append(generator_loss)
    # Report the mean over this epoch only; averaging the whole history would
    # hide how the losses evolve from one epoch to the next.
    epoch_generator_losses = generator_losses[epoch_start:]
    epoch_discriminator_losses = discriminator_losses[epoch_start:]
    print(f'Epoch {epoch:03d} | Avg Losses >>'
          f' G/D {torch.FloatTensor(epoch_generator_losses).mean():.4f}'
          f'/{torch.FloatTensor(epoch_discriminator_losses).mean():.4f}')
    generator.eval()
    # Sampling is evaluation only, so no autograd graph is needed.
    with torch.no_grad():
        samples = create_samples(generator, fixed_z)
    epoch_samples.append(samples.detach().cpu().numpy())