In [12]:
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.7.2-py3-none-any.whl (22 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.7.2


In [13]:
import torch
from torchinfo import summary

In [2]:
class ConvBlock(torch.nn.Module):
    """Conv2d block with an optional LeakyReLU *before* the conv and an optional BatchNorm *after* it.

    Forward order: [LeakyReLU(0.2, inplace)] -> Conv2d -> [BatchNorm2d].

    With the default kernel_size=4, stride=2, padding=1 the spatial size goes
    from H to floor(H / 2) (e.g. 25 -> 12, 16 -> 8).

    Args:
        input_size: number of input channels of the convolution.
        output_size: number of output channels (also the BatchNorm width).
        kernel_size, stride, padding: forwarded positionally to torch.nn.Conv2d.
        activation: if True, apply LeakyReLU(0.2) to the input before the conv.
        batch_norm: if True, apply BatchNorm2d to the conv output.

    Notes:
        * The LeakyReLU is constructed with inplace=True, so when ``activation``
          is True the tensor passed to ``forward`` is overwritten in place.
        * ``self.bn`` is created even when ``batch_norm`` is False; its
          parameters then exist (and are counted/saved) but are never used.
    """

    def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, activation=True, batch_norm=True):
        super().__init__()
        # Plain (non-module) flags consulted in forward().
        self.activation = activation
        self.batch_norm = batch_norm
        # Submodules are registered in the same order as before: conv, lrelu, bn.
        self.conv = torch.nn.Conv2d(input_size, output_size, kernel_size, stride, padding)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.bn = torch.nn.BatchNorm2d(output_size)

    def forward(self, x):
        """Apply [LeakyReLU] -> conv -> [BatchNorm] according to the flags."""
        pre = self.lrelu(x) if self.activation else x
        feat = self.conv(pre)
        return self.bn(feat) if self.batch_norm else feat

In [28]:
class Discriminator(torch.nn.Module):
    """Convolutional discriminator: five ConvBlocks followed by a sigmoid.

    Layers (each ConvBlock applies [LeakyReLU] -> Conv2d(k=4, p=1) -> [BatchNorm]):
        conv1: input_dim -> num_filter,        stride 2, no activation, no BN
        conv2: num_filter -> 2*num_filter,     stride 2
        conv3: 2*num_filter -> 4*num_filter,   stride 2
        conv4: 4*num_filter -> 8*num_filter,   stride 1
        conv5: 8*num_filter -> output_dim,     stride 1, no BN
    For a 25x25 input the output map is 1x1; larger inputs give larger maps.

    Args:
        input_dim: channel count reaching conv1, i.e. channels of ``x`` plus
            channels of ``label`` when a label is passed to ``forward``.
        num_filter: base number of filters (scaled by 1, 2, 4, 8).
        output_dim: number of channels of the output score map.
    """

    def __init__(self, input_dim, num_filter, output_dim):
        super(Discriminator, self).__init__()

        self.conv1 = ConvBlock(input_dim, num_filter, activation=False, batch_norm=False)
        self.conv2 = ConvBlock(num_filter, num_filter * 2)
        self.conv3 = ConvBlock(num_filter * 2, num_filter * 4)
        self.conv4 = ConvBlock(num_filter * 4, num_filter * 8, stride=1)
        self.conv5 = ConvBlock(num_filter * 8, output_dim, stride=1, batch_norm=False)

    def forward(self, x, label=None):
        """Score ``x``, optionally conditioned on ``label``.

        Args:
            x: input image tensor; presumably (N, C, H, W) -- TODO confirm for callers.
            label: optional conditioning tensor, concatenated to ``x`` on dim 1
                (channels); its other dims must match ``x``. Defaults to None so
                single-input callers (e.g. torchinfo.summary with one
                ``input_size``) work.

        Returns:
            Sigmoid-activated score map with ``output_dim`` channels, values in [0, 1].
        """
        if label is not None:
            x = torch.cat([x, label], 1)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        # Functional sigmoid: same result as torch.nn.Sigmoid() without
        # allocating a new module on every call.
        return torch.sigmoid(x)

    def normal_weight_init(self, mean=0.0, std=0.02):
        """Re-draw the conv weight of every direct ConvBlock child from N(mean, std**2).

        Conv biases and BatchNorm parameters are left unchanged.
        """
        for m in self.children():
            if isinstance(m, ConvBlock):
                # normal_ is the in-place initializer; the non-underscore
                # torch.nn.init.normal alias is deprecated.
                torch.nn.init.normal_(m.conv.weight, mean, std)

In [29]:
# input_dim is the channel count reaching conv1 (x plus label, if a label is passed).
new_dis = Discriminator(input_dim=1, num_filter=64, output_dim=6)

In [31]:
summary(model=new_dis, input_size=(64, 1, 25, 25))

Layer (type:depth-idx)                   Output Shape              Param #
Discriminator                            [64, 6, 1, 1]             --
├─ConvBlock: 1-1                         [64, 64, 12, 12]          128
│    └─Conv2d: 2-1                       [64, 64, 12, 12]          1,088
├─ConvBlock: 1-2                         [64, 128, 6, 6]           --
│    └─LeakyReLU: 2-2                    [64, 64, 12, 12]          --
│    └─Conv2d: 2-3                       [64, 128, 6, 6]           131,200
│    └─BatchNorm2d: 2-4                  [64, 128, 6, 6]           256
├─ConvBlock: 1-3                         [64, 256, 3, 3]           --
│    └─LeakyReLU: 2-5                    [64, 128, 6, 6]           --
│    └─Conv2d: 2-6                       [64, 256, 3, 3]           524,544
│    └─BatchNorm2d: 2-7                  [64, 256, 3, 3]           512
├─ConvBlock: 1-4                         [64, 512, 2, 2]           --
│    └─LeakyReLU: 2-8                    [64, 256, 3, 3]           --