In [2]:
import torch
import torch.nn as nn

In [3]:
from torchinfo import summary

In [4]:
class discriminator_block(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU(negative_slope=0.2) building block.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution (and normalized by
            the batch norm).
        kernel_size, stride, padding: passed unchanged to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        # Attribute names (and creation order) are kept stable: they define the
        # state_dict keys and the order in which parameters are initialized.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.LeakyReLU = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Apply convolution, batch normalization and leaky ReLU in that order."""
        return self.LeakyReLU(self.bn(self.conv(x)))

class Discriminator(nn.Module):
    """Image discriminator: strided conv feature extractor + fully-connected head.

    Args:
        in_channels: number of channels in the input images (default 3).
        image_size: side length of the square input images that the
            fully-connected head is sized for (default 256, which reproduces
            the original ``512 * 16 * 16`` head exactly).

    ``forward(x)`` expects ``x`` of shape ``(N, in_channels, image_size,
    image_size)`` and returns a tensor of shape ``(N, 1)``. The final layer is
    a plain ``nn.Linear``; no sigmoid is applied.

    Raises:
        ValueError: if ``image_size`` is not positive.
    """

    # Number of stride-2 convolutions in ``disc_conv``. With kernel 3 and
    # padding 1, each maps a side length n to floor((n - 1) / 2) + 1 == ceil(n / 2).
    _NUM_DOWNSAMPLES = 4

    def __init__(self, in_channels=3, image_size=256):
        super().__init__()
        if image_size < 1:
            raise ValueError(f"image_size must be positive, got {image_size}")

        self.disc_conv = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            discriminator_block(64, 64, kernel_size=3, stride=2, padding=1),
            discriminator_block(64, 128, kernel_size=3, stride=1, padding=1),
            discriminator_block(128, 128, kernel_size=3, stride=2, padding=1),
            discriminator_block(128, 256, kernel_size=3, stride=1, padding=1),
            discriminator_block(256, 256, kernel_size=3, stride=2, padding=1),
            discriminator_block(256, 512, kernel_size=3, stride=1, padding=1),
            discriminator_block(512, 512, kernel_size=3, stride=2, padding=1),
        )

        # Side length of the final feature map (16 for image_size=256), used to
        # size the first Linear layer instead of hard-coding it.
        feat_size = image_size
        for _ in range(self._NUM_DOWNSAMPLES):
            feat_size = (feat_size + 1) // 2

        # NOTE(review): with the default image_size=256 the first Linear maps
        # 512*16*16 = 131,072 features to 8,192 outputs, i.e. ~1.07e9 weights
        # (~4 GiB in float32). Consider pooling before the head if memory matters.
        self.fc = nn.Sequential(
            nn.Linear(512 * feat_size * feat_size, 512 * 4 * 4),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(512 * 4 * 4, 1024),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(1024, 1)
        )

    def forward(self, x):
        """Return one unnormalized score per image, shape ``(N, 1)``."""
        x = self.disc_conv(x)
        # (N, 512, h, w) -> (N, 512 * h * w); unlike view(N, -1) this also
        # works for an empty batch.
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)

In [None]:
# Build the discriminator and show a layer-by-layer summary for a single
# 256x256, 3-channel image laid out as (batch, channels, height, width).
# NOTE(review): the default model's first Linear layer alone has ~1.07e9
# weights (512*16*16 -> 512*4*4), so this cell needs several GB of RAM.
sample_shape = (1, 3, 256, 256)
model = Discriminator()
summary(model, input_size=sample_shape)

Layer (type:depth-idx)                   Output Shape              Param #
Discriminator                            [64, 1]                   --
├─Sequential: 1-1                        [64, 512, 16, 16]         --
│    └─Conv2d: 2-1                       [64, 64, 256, 256]        1,792
│    └─LeakyReLU: 2-2                    [64, 64, 256, 256]        --
│    └─discriminator_block: 2-3          [64, 64, 128, 128]        --
│    │    └─Conv2d: 3-1                  [64, 64, 128, 128]        36,928
│    │    └─BatchNorm2d: 3-2             [64, 64, 128, 128]        128
│    │    └─LeakyReLU: 3-3               [64, 64, 128, 128]        --
│    └─discriminator_block: 2-4          [64, 128, 128, 128]       --
│    │    └─Conv2d: 3-4                  [64, 128, 128, 128]       73,856
│    │    └─BatchNorm2d: 3-5             [64, 128, 128, 128]       256
│    │    └─LeakyReLU: 3-6               [64, 128, 128, 128]       --
│    └─discriminator_block: 2-5          [64, 128, 64, 64]         --
│ 