In [1]:
import torch
import torch.nn as nn
from torchsummary import summary

In [2]:
def weights_init(m):
    """Draw the weights of a convolution layer from a normal distribution N(0, 0.02).

    Meant to be handed to ``nn.Module.apply``, which calls it once for every
    submodule. Any module that is not a (transposed) convolution is left
    untouched.

    Both the 2-D and the 3-D convolution classes are matched: the networks in
    this file are assembled from ``nn.Conv2d`` / ``nn.ConvTranspose2d``, so a
    check limited to the 3-D classes would never fire for them.

    Args:
        m: the module currently being visited.
    """
    conv_types = (nn.Conv2d, nn.ConvTranspose2d, nn.Conv3d, nn.ConvTranspose3d)
    if isinstance(m, conv_types):
        # In-place fill; nn.init functions do not record autograd history.
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        # Biases are not touched here and keep PyTorch's default initialisation.

In [3]:
class DownConv(nn.Module):
  """Convolution stage: Conv2d -> InstanceNorm2d -> activation.

  The convolution always uses padding=1; kernel size and stride come from
  the caller.

  Args:
    in_channels: number of channels of the incoming tensor.
    out_channels: number of channels produced by the convolution.
    kernel: kernel size forwarded to ``nn.Conv2d``.
    stride: stride forwarded to ``nn.Conv2d``.
    activation: key into ``self.activation``; either 'relu' or 'leakyRelu'
      (LeakyReLU with negative slope 0.2). Any other key raises ``KeyError``.
  """
  def __init__(self,in_channels,out_channels,kernel,stride, activation='relu'):
    super().__init__()

    # Lookup table of the supported non-linearities, keyed by name.
    self.activation = dict(relu=nn.ReLU(), leakyRelu=nn.LeakyReLU(0.2))

    conv = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel,
        stride=stride,
        padding=1,
    )
    norm = nn.InstanceNorm2d(out_channels)
    nonlinearity = self.activation[activation]
    self.layers = nn.Sequential(conv, norm, nonlinearity)

    # Run the shared weight initialiser over every submodule.
    self.apply(weights_init)

  def forward(self,x):
    """Apply convolution, normalisation and activation to ``x``."""
    return self.layers(x)

In [4]:
class Residual_block(nn.Module):
  """Residual unit: Conv2d -> InstanceNorm2d -> ReLU -> Conv2d -> InstanceNorm2d, plus skip.

  The block output is added to the block input, so the two tensors must have
  compatible shapes for the addition.

  Args:
    in_channels: channels of the incoming tensor (input of the first conv).
    out_channels: channels produced by both convolutions.
    kernel: kernel size used by both convolutions.
    stride: stride used by both convolutions.
    padding: padding used by both convolutions.
  """
  def __init__(self, in_channels, out_channels, kernel, stride, padding):
    super().__init__()

    # Both convolutions share the same spatial hyper-parameters.
    conv_kwargs = dict(kernel_size=kernel, stride=stride, padding=padding)
    first_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **conv_kwargs)
    first_norm = nn.InstanceNorm2d(out_channels)
    second_conv = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, **conv_kwargs)
    second_norm = nn.InstanceNorm2d(out_channels)

    self.block = nn.Sequential(first_conv, first_norm, nn.ReLU(), second_conv, second_norm)

    # Run the shared weight initialiser over every submodule.
    self.apply(weights_init)

  def forward(self,x):
    """Return ``block(x) + x`` (no activation after the addition)."""
    return self.block(x) + x


In [5]:
class UpConv(nn.Module):
  """Upsampling stage: ConvTranspose2d -> InstanceNorm2d -> ReLU.

  The transposed convolution always uses padding=1 and output_padding=1;
  kernel size and stride come from the caller.

  Args:
    in_channels: channels of the incoming tensor.
    out_channels: channels produced by the transposed convolution.
    kernel: kernel size forwarded to ``nn.ConvTranspose2d``.
    stride: stride forwarded to ``nn.ConvTranspose2d``.
  """
  def __init__(self,in_channels, out_channels, kernel, stride):
    super().__init__()

    deconv = nn.ConvTranspose2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel,
        stride=stride,
        padding=1,
        output_padding=1,
    )
    self.layers = nn.Sequential(deconv, nn.InstanceNorm2d(out_channels), nn.ReLU())

    # Run the shared weight initialiser over every submodule.
    self.apply(weights_init)

  def forward(self,x):
    """Apply transposed convolution, normalisation and ReLU to ``x``."""
    return self.layers(x)

In [6]:
class Generator(nn.Module):
  """Encoder / residual / decoder image generator with a tanh output.

  Layer sequence (channel counts as built in ``__init__``):
    * 7x7 'same' convolution, in_channels -> 128
    * two stride-2 ``DownConv`` stages, 128 -> 128 -> 256
    * nine ``Residual_block`` units at 256 channels
    * two stride-2 ``UpConv`` stages, 256 -> 128 -> 256
    * 7x7 'same' convolution, 256 -> 3, followed by ``tanh``

  NOTE(review): ``upconv2`` widens back to 256 channels instead of narrowing
  further before the output convolution - confirm this is intended.

  Args:
    in_channels: number of channels of the input tensor.
  """
  def __init__(self, in_channels):
    super().__init__()

    # Encoder.
    self.downconv1 = nn.Conv2d(in_channels=in_channels, out_channels=128, kernel_size=7, stride=1, padding='same')
    self.downconv2 = DownConv(in_channels=128, out_channels=128, kernel=3, stride=2, activation='relu')
    self.downconv3 = DownConv(in_channels=128, out_channels=256, kernel=3, stride=2, activation='relu')

    # Bottleneck: residual_block1 ... residual_block9, registered under
    # those exact attribute names and in that order.
    for index in range(1, 10):
      setattr(
          self,
          f"residual_block{index}",
          Residual_block(in_channels=256, out_channels=256, kernel=3, stride=1, padding=1),
      )

    # Decoder and output head.
    self.upconv1 = UpConv(in_channels=256, out_channels=128, kernel=3, stride=2)
    self.upconv2 = UpConv(in_channels=128, out_channels=256, kernel=3, stride=2)
    self.downconv4 = nn.Conv2d(in_channels=256, out_channels=3, kernel_size=7, stride=1, padding='same')
    self.tanh = nn.Tanh()

    # Run the shared weight initialiser over every submodule.
    self.apply(weights_init)

  def forward(self,x):
    """Run ``x`` through encoder, residual blocks, decoder and tanh, in that order."""
    stages = [self.downconv1, self.downconv2, self.downconv3]
    stages += [getattr(self, f"residual_block{index}") for index in range(1, 10)]
    stages += [self.upconv1, self.upconv2, self.downconv4, self.tanh]

    for stage in stages:
      x = stage(x)
    return x

In [7]:
class Discriminator(nn.Module):
  """Convolutional discriminator that outputs a single-channel score map.

  Layer sequence (all 4x4 kernels, padding 1):
    * Conv2d in_channels -> 64, stride 2, then LeakyReLU(0.2) (no normalisation)
    * ``DownConv`` 64 -> 128 and 128 -> 256, stride 2, leaky ReLU
    * ``DownConv`` 256 -> 512, stride 1, leaky ReLU
    * Conv2d 512 -> 1, stride 1

  The output of the last convolution is returned as-is; no sigmoid or other
  activation is applied to it.

  Args:
    in_channels: number of channels of the input tensor.
  """

  def __init__(self, in_channels):
    super().__init__()

    self.downconv1 = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, padding=1)
    self.activation = nn.LeakyReLU(0.2)

    leaky = 'leakyRelu'
    self.downconv2 = DownConv(in_channels=64, out_channels=128, kernel=4, stride=2, activation=leaky)
    self.downconv3 = DownConv(in_channels=128, out_channels=256, kernel=4, stride=2, activation=leaky)
    self.downconv4 = DownConv(in_channels=256, out_channels=512, kernel=4, stride=1, activation=leaky)

    self.downconv5 = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=1)

    # Run the shared weight initialiser over every submodule.
    self.apply(weights_init)

  def forward(self,x):
    """Return the raw score map for ``x``."""
    x = self.activation(self.downconv1(x))
    for stage in (self.downconv2, self.downconv3, self.downconv4, self.downconv5):
      x = stage(x)
    return x

In [8]:
if __name__ == "__main__":
  # Smoke test: build both networks and print the shapes they produce for
  # all-zero inputs. The tensors are created without a batch dimension.
  num_channels, num_images, H, W = 3, 4, 256, 256

  # The generator input has num_images * num_channels channels.
  # The name `inpput_generator` (sic) is kept because the summary cell further
  # down refers to it.
  inpput_generator = torch.zeros([num_channels*num_images, H, W], dtype=torch.float32)
  print(f"input data shape: {inpput_generator.shape}")

  generator = Generator(inpput_generator.shape[0])
  # Only the output shape is inspected, so skip building the autograd graph;
  # this avoids keeping every intermediate activation alive.
  with torch.no_grad():
    generator_output = generator(inpput_generator)
  print(f"shape of generator's output: {generator_output.shape}")

  input_discriminator = torch.zeros([num_channels, H, W], dtype=torch.float32)
  discriminator = Discriminator(input_discriminator.shape[0])
  with torch.no_grad():
    discriminator_output = discriminator(input_discriminator)
  print(f"shape of discriminator output: {discriminator_output.shape}")



input data shape: torch.Size([12, 256, 256])
shape of generator's output: torch.Size([3, 256, 256])
shape of discriminator output: torch.Size([1, 30, 30])


In [9]:
if __name__ == "__main__":
  # Per-layer summary of the generator for a (C, H, W) input size.
  # torchsummary defaults to a CUDA device; pass the device the model actually
  # lives on so the probe input and the weights are never on different devices.
  summary(generator, inpput_generator.shape, device=next(generator.parameters()).device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1        [-1, 128, 256, 256]          75,392
            Conv2d-2        [-1, 128, 128, 128]         147,584
    InstanceNorm2d-3        [-1, 128, 128, 128]               0
              ReLU-4        [-1, 128, 128, 128]               0
          DownConv-5        [-1, 128, 128, 128]               0
            Conv2d-6          [-1, 256, 64, 64]         295,168
    InstanceNorm2d-7          [-1, 256, 64, 64]               0
              ReLU-8          [-1, 256, 64, 64]               0
          DownConv-9          [-1, 256, 64, 64]               0
           Conv2d-10          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-11          [-1, 256, 64, 64]               0
             ReLU-12          [-1, 256, 64, 64]               0
           Conv2d-13          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-14          [-1, 256,

In [10]:
if __name__ == "__main__":
  # Per-layer summary of the discriminator for a (C, H, W) input size.
  # torchsummary defaults to a CUDA device; pass the device the model actually
  # lives on so the probe input and the weights are never on different devices.
  summary(discriminator, input_discriminator.shape, device=next(discriminator.parameters()).device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]           3,136
         LeakyReLU-2         [-1, 64, 128, 128]               0
            Conv2d-3          [-1, 128, 64, 64]         131,200
    InstanceNorm2d-4          [-1, 128, 64, 64]               0
         LeakyReLU-5          [-1, 128, 64, 64]               0
          DownConv-6          [-1, 128, 64, 64]               0
            Conv2d-7          [-1, 256, 32, 32]         524,544
    InstanceNorm2d-8          [-1, 256, 32, 32]               0
         LeakyReLU-9          [-1, 256, 32, 32]               0
         DownConv-10          [-1, 256, 32, 32]               0
           Conv2d-11          [-1, 512, 31, 31]       2,097,664
   InstanceNorm2d-12          [-1, 512, 31, 31]               0
        LeakyReLU-13          [-1, 512, 31, 31]               0
         DownConv-14          [-1, 512,