In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Dataset

In [2]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import glob
import os
import torchvision.transforms as transforms
from PIL import Image

In [65]:
lr = 0.001
batch_size = 16
beta1 = 0
beta2 = 0.99
criterion = nn.BCELoss()

In [4]:
class CelebA(Dataset):
  
  def __init__(self, root, transform=None):
    self.files = glob.glob(os.path.join(root, "*.jpg"))
    self.transform = transform

  def __getitem__(self, index):

    image = Image.open(self.files[index]) 

    if self.transform is not None:
      return self.transform(image)
    
    return image

  def __len__(self):
    return len(self.files)

In [5]:
def loaderFunc(transform):
  train_loader = DataLoader(
    CelebA("/content/drive/MyDrive/haircolors/images/img_align_celeba/", transform=transform),
    batch_size=batch_size,
    shuffle=True,
  )

  return train_loader

In [6]:
def sample_data(image_size=4):
    
  transform = transforms.Compose([
      transforms.Resize(image_size),
      transforms.CenterCrop(image_size),
      transforms.ToTensor(),
      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  ])


  loader = loaderFunc(transform)

  for img in loader:
      yield img

<img src="https://machinelearningmastery.com/wp-content/uploads/2019/06/Tables-Showing-Generator-and-Discriminator-Configuration-for-the-Progressive-Growing-GAN.png"  width="1024"/>

In [11]:
def GeneratorBlock(in_channel, out_channel, output_size, first_block=False):
  # Growing gradually to 1024 * 1024 is done by incrementally adding blocks
  # in this function we get specification of the block and return it
  # for example input would be 4 * 4 and output_size would be 8 * 8
  if first_block:

    model = nn.Sequential(
      nn.Conv2d(in_channel, out_channel, kernel_size=4, padding=3),
      nn.LeakyReLU(0.2),

      nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
      nn.LeakyReLU(0.2),
    )
  else:

    model = nn.Sequential(
    
      nn.Upsample((output_size, output_size), mode='bilinear', align_corners=True),
      nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1),
      nn.LeakyReLU(0.2),

      nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
      nn.LeakyReLU(0.2),
    )

  return model  

In [12]:
class Generator(nn.Module):

  def __init__(self):
    
    super().__init__()

    self.block4_4 = GeneratorBlock(512, 512, 8, first_block=True)
    self.block8_8 = GeneratorBlock(512, 512, 8)
    self.block16_16 = GeneratorBlock(512, 512, 16)
    self.block32_32 = GeneratorBlock(512, 512, 32)
    self.block64_64 = GeneratorBlock(512, 256, 64)
    self.block128_128 = GeneratorBlock(256, 128, 128)

    self.blocks = nn.ModuleList([
        self.block4_4,
        self.block8_8,
        self.block16_16,
        self.block32_32,
        self.block64_64,
        self.block128_128
    ])

    self.to_rgbs = nn.ModuleList([
      nn.Conv2d(512, 3, 1),
      nn.Conv2d(512, 3, 1),
      nn.Conv2d(512, 3, 1),
      nn.Conv2d(512, 3, 1),
      nn.Conv2d(256, 3, 1),
      nn.Conv2d(128, 3, 1),
    ])


  def forward(self, x, step, alpha):
    # we have six steps toward progressively increase the output
    # alpha is the weight of output of new block compared to upsampled input
    if step == 1: # no need to average
      out = self.blocks[0](x)

    elif step > 1:

      for block in self.blocks[:step - 2]: # assuming all previous blocks have been trained completely
        x = block(x)

      x_small_block = self.blocks[step-2](x) # 512 * 32 * 32
      x_small_image = self.to_rgbs[step-2](x_small_block) # 3 * 32 * 32

      x_large_block = self.blocks[step-1](x_small_block) # 256 * 64 * 64
      x_large_image = self.to_rgbs[step-1](x_large_block) # 3 * 64 * 64


      x_small_upsample = F.interpolate(x_small_image, x_large_image.shape[-2:]) # 3 * 64 * 64

      out = (alpha *  x_large_image) + (1 - alpha) * (x_small_upsample)


    return out


In [13]:
gen = Generator()

In [42]:
gen(torch.randn(16, 512, 1, 1), 6, 0.2).shape

torch.Size([16, 3, 128, 128])

In [72]:
class PixelNorm(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input / torch.sqrt(torch.mean(input ** 2, dim=1, keepdim=True)
                                  + 1e-8)

In [78]:
px = PixelNorm()
px(torch.randn(16, 512))

tensor([[ 1.5158,  0.6708,  0.8473,  ...,  0.4025, -0.3423,  1.4040],
        [-0.0140, -0.9758, -1.4042,  ..., -0.5419,  1.6229,  0.8273],
        [-0.3933, -0.6400,  0.5148,  ...,  0.8503,  0.5920,  0.0959],
        ...,
        [ 0.2634,  0.3666, -1.0588,  ..., -2.2999, -1.0695, -0.8408],
        [ 0.2018, -0.3985,  0.2101,  ..., -0.8172, -0.1167, -1.4461],
        [-1.1806, -1.9183, -0.2650,  ..., -0.7813,  0.5524, -0.2943]])

In [73]:
def DiscriminatorBlock(in_channel, out_channel, output_size, last_block=False):
  # Growing gradually to 1024 * 1024 is done by incrementally adding blocks
  # in this function we get specification of the block and return it
  # for example input would be 4 * 4 and output_size would be 8 * 8
  if last_block:

    model = nn.Sequential(
      nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1),
      nn.LeakyReLU(0.2),

      nn.Conv2d(out_channel, out_channel, kernel_size=4, padding=0),
      nn.LeakyReLU(0.2),
      nn.Flatten(start_dim=1),
      nn.Linear(512, 1),
    )
  else:

    model = nn.Sequential(
    
      nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1),
      PixelNorm(),
      nn.LeakyReLU(0.2),

      nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
      PixelNorm(),
      nn.LeakyReLU(0.2),

      nn.AvgPool2d(kernel_size=2),
    )

  return model

In [74]:
class Discriminator(nn.Module):

  def __init__(self):
    
    super().__init__()

    self.block64_64 = DiscriminatorBlock(128, 256, 64)
    self.block32_32 = DiscriminatorBlock(256, 512, 32)
    self.blcok16_16 = DiscriminatorBlock(512, 512, 16)
    self.block8_8 = DiscriminatorBlock(512, 512, 8)
    self.block4_4 = DiscriminatorBlock(512, 512, 4)
    self.block1_1 = DiscriminatorBlock(512, 512, 1, last_block=True)

    self.blocks = nn.ModuleList([
      self.block64_64,
      self.block32_32,
      self.blcok16_16,
      self.block8_8,
      self.block4_4,
      self.block1_1
    ])

    self.from_rgbs = nn.ModuleList([
      nn.Conv2d(3, 128, 1),
      nn.Conv2d(3, 256, 1),
      nn.Conv2d(3, 512, 1),
      nn.Conv2d(3, 512, 1),
      nn.Conv2d(3, 512, 1),
      nn.Conv2d(3, 512, 1),
    ])
  

  def forward(self, x_large, step, alpha):
    
    if step == 1:
      print(self.blocks[-step])
      out = self.blocks[-step](x_large)
    
    elif step > 1:
      
      x_large_feature_map = self.from_rgbs[(6-step)](x_large) # input : 3 * 8 * 8 and output is 512 * 8 * 8
      x_small = self.blocks[-step](x_large_feature_map) # last layer output is 512 * 4 * 4
      
      x_large_downsampled = F.avg_pool2d(self.from_rgbs[(6 - step) + 1](x_large), kernel_size=2) # input 3 * 8 * 8 => output is 512 * 8 * 8 => downsampled to 512 * 4 * 64

      out = (1 - alpha) * x_large_downsampled + (alpha * x_small)

      for block in self.blocks[(6 - step) + 1:]:
        out = block(out)
    
    return out

In [75]:
disc = Discriminator()

In [76]:
gen_optim = torch.optim.Adam(gen.parameters(), lr=lr , betas=(beta1, beta2))
disc_optim = torch.optim.Adam(disc.parameters(), lr=lr , betas=(beta1, beta2))

In [77]:
# step and alpha
step = 1
step_size = 0.001
alpha = 0
latent_space_size = 512

for real in sample_data(2 ** (step + 1)):

  alpha += step_size

  if alpha == 1 and step < 6:
    step += 1
    alpha = 0

  # Training Generator
  gen_optim.zero_grad()

  fake = gen(torch.randn(batch_size, latent_space_size, 1, 1), step, alpha)
  pred_fake = disc(fake, step, alpha).reshape(-1)
  gen_loss = criterion(pred_fake, torch.ones_like(pred_fake).float())

  gen_loss.backward()

  gen_optim.step()


  # Training Discriminator
  disc_optim.zero_grad()

  fake = gen(torch.randn(batch_size, latent_space_size, 1, 1), step, alpha).detach()
  pred_fake = disc(fake, step, alpha)

  pred_real = disc(real, step, alpha)

  disc_loss = (1/2) * (criterion(pred_fake, torch.zeros(pred_fake)) + criterion(pred_real, torch.ones(pred_real)))

  disc_loss.backward()

  disc_optim.step()

Sequential(
  (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): LeakyReLU(negative_slope=0.2)
  (2): Conv2d(512, 512, kernel_size=(4, 4), stride=(1, 1))
  (3): LeakyReLU(negative_slope=0.2)
  (4): Flatten(start_dim=1, end_dim=-1)
  (5): Linear(in_features=512, out_features=1, bias=True)
)


RuntimeError: ignored