In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

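<img src="https://machinelearningmastery.com/wp-content/uploads/2019/06/Tables-Showing-Generator-and-Discriminator-Configuration-for-the-Progressive-Growing-GAN.png" alt="Tables showing the generator and discriminator configuration for the Progressive Growing GAN" width="1024"/>

*Generator and discriminator layer configuration of the Progressive Growing GAN (Karras et al., 2018). The `Generator` implemented below covers the first six stages (4×4 → 128×128).*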

In [63]:
class Generator(nn.Module):
  """Progressive-growing (PGGAN-style) generator producing RGB images.

  The network is a stack of six ``GeneratorBlock`` stages whose outputs are
  4x4 (for a 1x1 latent input), 8x8, 16x16, 32x32, 64x64 and 128x128. Every
  stage has its own 1x1 ``to_rgb`` convolution, so an RGB image can be read
  out at any resolution while the network is being grown.

  NOTE(review): ``__init__`` calls ``GeneratorBlock``, which is defined in a
  later cell of this notebook. That cell must run before ``Generator()`` is
  instantiated, otherwise a fresh "Restart & Run All" raises ``NameError``.
  """

  def __init__(self):

    super().__init__()

    # The first block ignores ``output_size`` (its padded 4x4 conv turns a 1x1
    # input into 4x4); 4 is passed only so the argument matches the stage name.
    self.block4_4 = GeneratorBlock(512, 512, 4, first_block=True)
    self.block8_8 = GeneratorBlock(512, 512, 8)
    self.block16_16 = GeneratorBlock(512, 512, 16)
    self.block32_32 = GeneratorBlock(512, 512, 32)
    self.block64_64 = GeneratorBlock(512, 256, 64)
    self.block128_128 = GeneratorBlock(256, 128, 128)

    # Same modules as above, in growth order, so they can be indexed by step.
    self.blocks = nn.ModuleList([
        self.block4_4,
        self.block8_8,
        self.block16_16,
        self.block32_32,
        self.block64_64,
        self.block128_128
    ])

    # One 1x1 projection to 3 channels per stage; input widths match each
    # block's ``out_channel``.
    self.to_rgbs = nn.ModuleList([
      nn.Conv2d(512, 3, 1),
      nn.Conv2d(512, 3, 1),
      nn.Conv2d(512, 3, 1),
      nn.Conv2d(512, 3, 1),
      nn.Conv2d(256, 3, 1),
      nn.Conv2d(128, 3, 1),
    ])


  def forward(self, x, step, alpha):
    """Generate images at the resolution of growth stage ``step``.

    Args:
      x: latent tensor fed to the first block. The demo in this notebook uses
        shape (batch, 512, 1, 1), which the first block maps to 4x4.
      step: growth stage in ``1..len(self.blocks)``. Stage ``k`` outputs images
        at the resolution produced by ``self.blocks[k - 1]``.
      alpha: fade-in weight of the newest stage's RGB output. The remaining
        ``1 - alpha`` comes from the previous stage's RGB output, upsampled
        with nearest-neighbour interpolation. Ignored when ``step == 1``.

    Returns:
      A (batch, 3, H, W) image tensor.

    Raises:
      ValueError: if ``step`` is outside ``1..len(self.blocks)``.
    """
    n_steps = len(self.blocks)
    if not 1 <= step <= n_steps:
      raise ValueError(f"step must be in [1, {n_steps}], got {step}")

    if step == 1:
      # Only one stage exists, so there is nothing to fade in. Project the
      # features to RGB so this stage returns an image like every other stage.
      return self.to_rgbs[0](self.blocks[0](x))

    # Stages below the two that take part in the fade-in are assumed to be
    # fully trained; just run them.
    for block in self.blocks[:step - 2]:
      x = block(x)

    # Previous resolution: features and their RGB read-out.
    small_features = self.blocks[step - 2](x)
    small_image = self.to_rgbs[step - 2](small_features)

    # Newest resolution, built on top of the previous stage's features.
    large_features = self.blocks[step - 1](small_features)
    large_image = self.to_rgbs[step - 1](large_features)

    # F.interpolate defaults to nearest-neighbour upsampling.
    small_upsampled = F.interpolate(small_image, size=large_image.shape[-2:])

    return alpha * large_image + (1 - alpha) * small_upsampled


In [65]:
# NOTE(review): Generator.__init__ calls GeneratorBlock, which is defined in a
# later cell; run that cell first (or move it above) so Restart & Run All works.
# Seed so the weight initialisation (and any random tensors drawn after it) are reproducible.
torch.manual_seed(0)
gen = Generator()

In [67]:
# Shape smoke test: 16 latents of shape (512, 1, 1) at the last growth stage
# (step 6 -> 128x128) with fade-in weight alpha=0.2. A shape check needs no
# gradients, so skip building the autograd graph.
with torch.no_grad():
  fake_images = gen(torch.randn(16, 512, 1, 1), 6, 0.2)
fake_images.shape

torch.Size([16, 3, 128, 128])

In [64]:
def GeneratorBlock(in_channel, out_channel, output_size, first_block=False):
  """Build one growth stage of the progressive generator.

  Args:
    in_channel: number of channels of the incoming feature map.
    out_channel: number of channels produced by both convolutions.
    output_size: side length the input is resized to (bilinear) before the
      convolutions, e.g. 8 for a 4x4 -> 8x8 stage. Unused when
      ``first_block`` is True.
    first_block: if True, build the initial stage. Instead of upsampling, it
      applies a 4x4 convolution with padding 3, which maps a 1x1 input to 4x4.

  Returns:
    nn.Sequential of ``[Upsample,] Conv2d, LeakyReLU(0.2), Conv2d(3x3, pad 1),
    LeakyReLU(0.2)``.
  """
  if first_block:
    # Spatial size: (1 + 2*3 - 4) + 1 = 4 for a 1x1 input.
    head = [nn.Conv2d(in_channel, out_channel, kernel_size=4, padding=3)]
  else:
    head = [
      nn.Upsample((output_size, output_size), mode='bilinear', align_corners=True),
      nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1),
    ]

  # Shared tail: activation, a size-preserving 3x3 conv, activation.
  tail = [
    nn.LeakyReLU(0.2),
    nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
    nn.LeakyReLU(0.2),
  ]

  return nn.Sequential(*head, *tail)


  

In [None]:
# Clone the reference implementation only if it is not already present, so the
# cell can be re-run (a second `git clone` into an existing directory fails).
# The explicit destination matches the `%cd` target used in the next cell.
![ -d /content/progressive-gan-pytorch ] || git clone "https://github.com/rosinality/progressive-gan-pytorch.git" /content/progressive-gan-pytorch

Cloning into 'progressive-gan-pytorch'...
remote: Enumerating objects: 22, done.[K
remote: Total 22 (delta 0), reused 0 (delta 0), pack-reused 22[K
Unpacking objects: 100% (22/22), done.


In [None]:
# Make the cloned repository the working directory so the following cells can
# import its `model` module. Absolute Colab path.
%cd /content/progressive-gan-pytorch

/content/progressive-gan-pytorch


In [None]:
# NOTE(review): this rebinds `Generator`, shadowing the Generator class defined
# earlier in this notebook; `gen` still refers to an instance of the earlier class.
from model import Generator, Discriminator

In [None]:
batch_size = 16

# Reference generator. Presumably the label embedding occupies the remaining
# n_label input channels so that code_size + n_label == 512 — verify in model.py.
n_label = 1
code_size = 512 - n_label
generator = Generator(code_size, n_label)

In [None]:
# One latent code and an all-zeros int32 label per sample.
latents = torch.randn(batch_size, code_size)
labels = torch.zeros(batch_size, dtype=torch.int32)
generator(latents, labels).shape

torch.Size([16, 1])


torch.Size([16, 3, 4, 4])

In [None]:
from model import ConvBlock

In [None]:
# Reference ConvBlock. Judging by its printed repr, the positional args are
# (in, out, kernel 4, padding 3, kernel 3, padding 1) — confirm against model.py.
cb = ConvBlock(512, 512, 4, 3, 3, 1)

In [None]:
# Display the reference block's layer structure.
cb

ConvBlock(
  (conv): Sequential(
    (0): EqualConv2d(
      (conv): Conv2d(512, 512, kernel_size=(4, 4), stride=(1, 1), padding=(3, 3))
    )
    (1): PixelNorm()
    (2): LeakyReLU(negative_slope=0.2)
    (3): EqualConv2d(
      (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (4): PixelNorm()
    (5): LeakyReLU(negative_slope=0.2)
  )
)

In [None]:
# Apply only the first layer: a 4x4 conv with padding 3 on a 1x1 input.
first_conv = cb.conv[0]
first_conv(torch.randn(16, 512, 1, 1)).shape

torch.Size([16, 512, 4, 4])

In [None]:
# Apply the whole reference block to a batch of 1x1 inputs.
latent = torch.randn(16, 512, 1, 1)
cb(latent).shape

torch.Size([16, 512, 4, 4])

In [None]:
# Plain Conv2d with the same geometry as the block's first layer
# (stride defaults to 1; scalar kernel/padding expand to (4, 4) / (3, 3)).
con2 = torch.nn.Conv2d(512, 512, kernel_size=4, padding=3)

In [None]:
# 1x1 input -> (1 + 2*3 - 4) + 1 = 4x4 output.
probe = torch.randn(16, 512, 1, 1)
con2(probe).shape

torch.Size([16, 512, 4, 4])