In [2]:
import torch
import torch.nn as nn

# Discriminator

In [3]:
class Discriminator(nn.Module):
  """DCGAN-style discriminator mapping an image batch to per-sample scores.

  For an input of shape (N, channels_img, 64, 64) the network halves the
  spatial size at every convolution (64 -> 32 -> 16 -> 8 -> 4) while the
  channel count grows features_d -> 2x -> 4x -> 8x. A final 4x4 convolution
  with no padding collapses the 4x4 map to 1x1, and a sigmoid squashes the
  result into (0, 1). Output shape for a 64x64 input: (N, 1, 1, 1).

  Args:
    channels_img: number of channels in the input images.
    features_d: channel width of the first convolution; the later stages
      use 2x, 4x and 8x this value.
  """

  def __init__(self, channels_img, features_d):
    super().__init__()
    widths = [features_d * mult for mult in (1, 2, 4, 8)]

    # Stem: plain conv + LeakyReLU (no BatchNorm). 64x64 -> 32x32.
    layers = [
        nn.Conv2d(channels_img, widths[0], kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(0.2),
    ]
    # Three Conv-BN-LeakyReLU stages: 32x32 -> 16x16 -> 8x8 -> 4x4.
    for c_in, c_out in zip(widths[:-1], widths[1:]):
      layers.append(self._block(c_in, c_out, 4, 2, 1))
    # Head: a 4x4 kernel with padding 0 over a 4x4 map yields a 1x1 score.
    layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0))
    layers.append(nn.Sigmoid())

    self.disc = nn.Sequential(*layers)

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Build one Conv2d -> BatchNorm2d -> LeakyReLU(0.2) downsampling stage.

    The convolution is created with bias=False; the BatchNorm2d that follows
    carries its own learnable shift, which makes a conv bias redundant.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
    norm = nn.BatchNorm2d(out_channels)
    act = nn.LeakyReLU(0.2)
    return nn.Sequential(conv, norm, act)

  def forward(self, x):
    """Run the discriminator stack on ``x`` and return the sigmoid scores."""
    scores = self.disc(x)
    return scores

In [4]:
class Generator(nn.Module):
  """DCGAN-style generator mapping latent vectors to 64x64 images.

  Input shape: (N, z_dim, 1, 1). Each transposed convolution doubles the
  spatial size (1 -> 4 -> 8 -> 16 -> 32 -> 64) while the channel count
  shrinks features_g*16 -> 8x -> 4x -> 2x -> channel_img. The final Tanh
  bounds the output to [-1, 1]; presumably real training images are scaled
  to the same range (TODO: confirm in the training code).

  Args:
    z_dim: number of channels of the latent input.
    channel_img: number of channels of the generated images.
    features_g: base channel width; hidden stages use 16x, 8x, 4x and 2x
      this value.
  """

  def __init__(self, z_dim, channel_img, features_g):
    super(Generator, self).__init__()
    self.net = nn.Sequential(
        # input: N x z_dim x 1 x 1
        self._block(z_dim, features_g * 16, 4, 1, 0),           # N x f_g*16 x 4 x 4
        self._block(features_g * 16, features_g * 8, 4, 2, 1),  # N x f_g*8 x 8 x 8
        self._block(features_g * 8, features_g * 4, 4, 2, 1),   # N x f_g*4 x 16 x 16
        self._block(features_g * 4, features_g * 2, 4, 2, 1),   # N x f_g*2 x 32 x 32
        nn.ConvTranspose2d(
            features_g * 2, channel_img, kernel_size=4, stride=2, padding=1
        ),                                                       # N x channel_img x 64 x 64
        nn.Tanh(),  # bound outputs to [-1, 1]
    )

  def forward(self, x):
    """Generate images from latent tensor ``x`` of shape (N, z_dim, 1, 1)."""
    return self.net(x)

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Build one ConvTranspose2d -> BatchNorm2d -> ReLU upsampling stage.

    The transposed conv uses bias=False because the following BatchNorm2d
    provides its own learnable shift.
    """
    return nn.Sequential(
        nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=False,
        ),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
    )