In [2]:
import torch
import torch.nn as nn

In [29]:
class Disc(nn.Module):
    """PatchGAN-style discriminator over a channel-wise concatenation of two images.

    Input:  (N, 2 * img_channel, H, W). The caller concatenates the two images
            along dim 1, e.g. 6 channels for two RGB images.
    Output: (N, 1, H', W') raw, unnormalised patch scores (logits).
            For H = W = 256 this is (N, 1, 30, 30).

    The last layer is a bare convolution, so scores can be negative. Pair it
    with ``nn.BCEWithLogitsLoss`` (or apply a sigmoid) in the training loop.
    """

    def __init__(self, img_channel) -> None:
        super().__init__()
        # Conv blocks: [in_channels, out_channels, kernel_size, stride, padding].
        # Comments show the activation shape after each block for a
        # (2 * img_channel, 256, 256) input.
        self.config_lst = [
            [img_channel * 2, 64, 4, 2, 1],  # -> (64, 128, 128)
            [64, 128, 4, 2, 1],              # -> (128, 64, 64)
            [128, 256, 4, 2, 1],             # -> (256, 32, 32)
            [256, 512, 4, 1, 1],             # -> (512, 31, 31)
            [512, 1, 4, 1, 1],               # -> (1, 30, 30)
        ]
        self.conv_layers = self._create_conv_layers()

    def _create_conv_layers(self):
        """Build the conv stack from ``self.config_lst``.

        Every layer except the last is Conv -> BatchNorm -> ReLU. The last one
        is a bare Conv2d that produces the patch logits. A trailing BatchNorm +
        ReLU there would clamp every score to >= 0 and renormalise the single
        output channel per batch. The bare conv is still wrapped in a Sequential,
        so its parameters keep the ``conv_layers.<i>.0.*`` naming used by the
        other layers.
        """
        *hidden_cfgs, final_cfg = self.config_lst
        layers = [self._conv_block(*cfg) for cfg in hidden_cfgs]
        layers.append(nn.Sequential(nn.Conv2d(*final_cfg)))
        return nn.Sequential(*layers)

    def _conv_block(self, inchannel, outchannel, k_s, s, p):
        """Conv2d(inchannel -> outchannel, kernel k_s, stride s, padding p) -> BatchNorm -> ReLU."""
        # NOTE(review): pix2pix uses LeakyReLU(0.2) and omits BatchNorm on the
        # first layer; ReLU is kept here to preserve the original architecture.
        return nn.Sequential(
            nn.Conv2d(inchannel, outchannel, k_s, s, p),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(),
        )

    def forward(self, x):
        """Return patch logits for the already-concatenated input ``x``."""
        return self.conv_layers(x)



class Gen(nn.Module):
    """Encoder -> residual stack -> decoder image-to-image generator.

    Input:  (N, img_channel, H, W). H and W should be divisible by 4 so the two
            stride-2 down-samplings are undone exactly by the up-samplings.
    Output: (N, img_channel, H, W), squashed to [-1, 1] by a final Tanh
            (assumes target images are normalised to [-1, 1] — confirm in the
            data pipeline).
    """

    def __init__(self, img_channel) -> None:
        super().__init__()
        # Up/down block configs: [in_channels, out_channels, kernel_size, stride, padding].
        # Down blocks read each row as (in, out). Up blocks walk the rows in
        # reverse and read them as (out, in), mirroring the encoder.
        self.config_up_down = [[img_channel, 9, 7, 1, 3],
                               [9, 18, 4, 2, 1],
                               [18, 36, 4, 2, 1]]
        # Residual block config:
        # [in_channels, out_channels, kernel_size, stride, padding, repeat_count].
        self.config_residual = [36, 36, 3, 1, 1, 6]
        self.down_block = self._down_block()
        self.residual_blocks = self._residual_blocks()
        self.up_block = self._up_block()

    def _conv_block(self, inchannel, outchannel, k_s, s, p):
        """Conv2d(inchannel -> outchannel, kernel k_s, stride s, padding p) -> BatchNorm -> ReLU."""
        return nn.Sequential(
            nn.Conv2d(inchannel, outchannel, k_s, s, p),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(),
        )

    def _deconv_block(self, inchannel, outchannel, k_s, s, p):
        """ConvTranspose2d(inchannel -> outchannel, kernel k_s, stride s, padding p) -> BatchNorm -> ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(inchannel, outchannel, k_s, s, p),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(),
        )

    def _down_block(self):
        """Encoder: one conv block per row of ``config_up_down``, in order."""
        return nn.Sequential(*(self._conv_block(*cfg) for cfg in self.config_up_down))

    def _residual_blocks(self):
        """Build the residual bodies once, as registered submodules.

        Each entry is a Conv -> BatchNorm -> ReLU body; ``forward`` adds the
        skip connection. Creating them here, rather than inside ``forward``,
        means their weights are created once and appear in ``parameters()``
        and ``state_dict``, so the optimizer trains them and ``.to(device)``
        moves them.

        Raises:
            ValueError: if the config would change the tensor shape, which
                would make ``x + block(x)`` invalid.
        """
        c_in, c_out, k_s, s, p, repeat = self.config_residual
        if c_in != c_out or s != 1 or 2 * p != k_s - 1:
            raise ValueError(
                "residual blocks must preserve shape "
                "(in_channels == out_channels, stride 1, padding == (kernel_size - 1) / 2)"
            )
        return nn.ModuleList([self._conv_block(c_in, c_out, k_s, s, p) for _ in range(repeat)])

    def _up_block(self):
        """Decoder: mirror of the encoder built from ``config_up_down`` in reverse.

        Hidden layers are ConvTranspose -> BatchNorm -> ReLU. The final layer,
        which maps back to ``img_channel``, is ConvTranspose -> Tanh. BatchNorm +
        ReLU on the generated image would zero roughly half the pixels and
        prevent negative values.
        """
        up_layers = []
        for i in reversed(range(len(self.config_up_down))):
            c_out, c_in, k_s, s, p = self.config_up_down[i]
            if i == 0:
                up_layers.append(nn.Sequential(
                    nn.ConvTranspose2d(c_in, c_out, k_s, s, p),
                    nn.Tanh(),
                ))
            else:
                up_layers.append(self._deconv_block(c_in, c_out, k_s, s, p))
        return nn.Sequential(*up_layers)

    def forward(self, x):
        """Map an (N, img_channel, H, W) batch to a generated batch of the same shape."""
        x = self.down_block(x)
        for block in self.residual_blocks:
            x = x + block(x)  # skip connection
        return self.up_block(x)

In [None]:
# Smoke test: seeded, gradient-free forward pass through both models.
# Checks the shapes instead of dumping the full output tensors.
torch.manual_seed(0)

x = torch.randn(4, 3, 256, 256)
gen = Gen(3)
disc = Disc(3)
with torch.no_grad():
    fake = gen(x)
    # the discriminator scores the (input, output) pair concatenated on channels
    patch_scores = disc(torch.cat([x, fake], dim=1))

assert fake.shape == x.shape, fake.shape
assert patch_scores.shape == (4, 1, 30, 30), patch_scores.shape
fake.shape, patch_scores.shape