In [2]:
import torch
import torch.nn as nn

In [4]:
# input: (N, channel_noise, 1, 1) noise, e.g. (N, 100, 1, 1)
# output: (N, img_channel, 64, 64), e.g. (N, 3, 64, 64) for RGB images
class Gen(nn.Module):
    """DCGAN-style generator that maps a noise tensor to a 64x64 image.

    Args:
        channel_noise: number of channels in the input noise tensor.
        img_channel: number of channels in the generated image.
        feature_d: base width; the hidden blocks use 8x, 4x, 2x and 1x this
            many feature maps.

    Shapes (follow from the layer configuration below):
        input:  (N, channel_noise, 1, 1), or (N, channel_noise), which is
                reshaped to (N, channel_noise, 1, 1)
        output: (N, img_channel, 64, 64), values in [-1, 1] from the final Tanh
    """

    def __init__(self, channel_noise, img_channel, feature_d) -> None:
        super().__init__()

        self.channel_noise = channel_noise
        self.img_channel = img_channel
        self.feature_d = feature_d

        x = self.feature_d
        # One row per ConvTranspose2d -> BatchNorm2d -> ReLU block:
        # [in_channels, out_channels, kernel_size, stride, padding]
        self.net_config = [
            [self.channel_noise, x * 8, 4, 1, 0],
            [x * 8, x * 4, 4, 2, 1],
            [x * 4, x * 2, 4, 2, 1],
            [x * 2, x, 4, 2, 1],
        ]
        # Shape trace per sample (f = feature_d):
        # (channel_noise, 1, 1) ->
        # (8f, 4, 4) ->
        # (4f, 8, 8) ->
        # (2f, 16, 16) ->
        # (f, 32, 32) ->
        # (img_channel, 64, 64)   <- output ConvTranspose2d + Tanh

        # Built in config order so parameter initialization (and therefore
        # seeded results) and the state_dict keys stay the same.
        conv_layers = [self._conv_T_block(*cfg) for cfg in self.net_config]

        # The output layer consumes whatever the last block produces, so it
        # stays consistent if net_config is edited. It has no BatchNorm and
        # keeps the ConvTranspose2d default bias.
        last_out_channels = self.net_config[-1][1]
        conv_layers.append(
            nn.ConvTranspose2d(last_out_channels, self.img_channel, 4, 2, 1))
        conv_layers.append(nn.Tanh())
        self.net = nn.Sequential(*conv_layers)

    def _conv_T_block(self, inchannel, outchannel, kernel_size, stride, padding):
        """Build one upsampling block: ConvTranspose2d -> BatchNorm2d -> ReLU.

        The convolution has no bias because the BatchNorm2d that follows
        already provides a learnable per-channel shift.
        """
        block = nn.Sequential(
            nn.ConvTranspose2d(in_channels=inchannel,
                               out_channels=outchannel,
                               kernel_size=kernel_size,
                               stride=stride,
                               padding=padding,
                               bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(),
        )
        return block

    def forward(self, x):
        """Generate images from noise.

        Args:
            x: noise of shape (N, channel_noise, 1, 1), or a flat latent
               batch of shape (N, channel_noise).

        Returns:
            Tensor of shape (N, img_channel, 64, 64).
        """
        if x.dim() == 2:
            # Accept flat latent vectors such as torch.randn(N, channel_noise).
            x = x.reshape(*x.shape, 1, 1)
        return self.net(x)

In [5]:
# 100-channel noise -> 3-channel (RGB) 64x64 images, base width 64.
gen = Gen(channel_noise=100, img_channel=3, feature_d=64)

In [6]:
# The root module's repr already renders the full nested architecture.
# Iterating gen.modules() would print that tree and then re-print every
# submodule on its own, duplicating the output many times over.
gen

Gen(
  (net): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): Sequential(
      (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (4): ConvTranspose2d(64, 3, kernel_size=