<a href="https://colab.research.google.com/github/taravatp/Multi_Spectral_Image_Segmentation/blob/main/networks/Discriminator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import math
import numpy as np

In [None]:
def convolution_module(in_channels, out_channels):
  """Build a double 3x3 convolution stage: (Conv2d -> BatchNorm2d -> ReLU) x 2.

  Both convolutions use kernel_size=3 with padding=1, so the spatial size of
  the feature map is preserved; only the channel count changes.

  Args:
    in_channels: channel count of the incoming feature map.
    out_channels: channel count produced by both convolutions.

  Returns:
    nn.Sequential of six layers: Conv2d, BatchNorm2d, ReLU, Conv2d,
    BatchNorm2d, ReLU.
  """
  layers = []
  # First triple maps in_channels -> out_channels, second keeps out_channels.
  for stage_input in (in_channels, out_channels):
    layers.extend([
        # bias=False: the BatchNorm2d that follows adds its own shift term.
        nn.Conv2d(stage_input, out_channels, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ])
  return nn.Sequential(*layers)

In [None]:
class discriminator(nn.Module):
  """Discriminator mapping a pair of images to one value in (0, 1) per sample.

  ``img_A`` and ``img_B`` are concatenated along dim 1 and passed through one
  ``convolution_module`` stage per entry of ``feat_channels``. A 2x2 max-pool
  follows every stage except the last, so the spatial size is halved (with
  floor rounding) ``len(feat_channels) - 1`` times. The final feature map is
  flattened, projected to 128 features by ``mlp1``, then to 1 feature by
  ``mlp2``, and squashed by a sigmoid.

  Args:
    in_channels: combined channel count of ``img_A`` and ``img_B``.
    feat_channels: output width of each convolution stage. The default
      (32, 64, 128, 256) reproduces the original four-stage network. Stages
      are registered as ``down_conv1``, ``down_conv2``, ... so checkpoints saved
      with the default layout keep loading.
    input_size: spatial size of the inputs, either an int (square images) or
      an ``(H, W)`` pair. It fixes the input width of ``mlp1``. The default 64
      gives 256 * 8 * 8 = 16384, the original hard-coded value.

  Raises:
    ValueError: if ``feat_channels`` is empty, or ``input_size`` is too small
      to survive the pooling stages.
  """

  def __init__(self, in_channels=13, feat_channels=(32, 64, 128, 256), input_size=64):

    super(discriminator, self).__init__()

    # Tuple default avoids a shared mutable default; converting also accepts a
    # list or any other iterable from callers.
    feat_channels = tuple(feat_channels)
    if not feat_channels:
      raise ValueError("feat_channels must contain at least one channel count")

    # Encoder convolutions, named down_conv1..down_convN. This keeps the original
    # state_dict keys and parameter creation order for the default layout.
    self._stage_names = tuple('down_conv%d' % index
                              for index in range(1, len(feat_channels) + 1))
    stage_in = in_channels
    for name, stage_out in zip(self._stage_names, feat_channels):
      setattr(self, name, convolution_module(stage_in, stage_out))
      stage_in = stage_out
    self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
    self.flat = nn.Flatten()

    # Size the first linear layer from the feature map left after the pools.
    # MaxPool2d(kernel_size=2, stride=2) floors odd sizes, hence //.
    try:
      height, width = input_size
    except TypeError:
      height = width = input_size
    for _ in range(len(feat_channels) - 1):
      height //= 2
      width //= 2
    if height < 1 or width < 1:
      raise ValueError(
          "input_size %r is too small for %d pooling stages"
          % (input_size, len(feat_channels) - 1))
    self.mlp1 = nn.Linear(feat_channels[-1] * height * width, 128)
    self.mlp2 = nn.Linear(128,1)
    self.activation = nn.Sigmoid()

  def forward(self,img_A,img_B):
    """Return sigmoid outputs of shape (batch, 1) for the image pair.

    Assumes both inputs are (batch, C, H, W) tensors with matching batch and
    spatial sizes. Their channels should sum to ``in_channels`` and (H, W)
    should match ``input_size``. TODO: confirm against callers.
    """
    x = torch.cat((img_A, img_B), 1)
    last = len(self._stage_names) - 1
    for position, name in enumerate(self._stage_names):
      x = getattr(self, name)(x)
      # Every stage except the last is followed by a 2x2 max-pool.
      if position < last:
        x = self.max_pool(x)
    x = self.flat(x)
    x = self.mlp1(x)
    # NOTE(review): no nonlinearity between mlp1 and mlp2, so the two layers
    # compose to a single linear map. Left unchanged so already-trained weights
    # produce identical outputs.
    x = self.mlp2(x)
    x = self.activation(x)

    return x

In [None]:
if __name__ == "__main__":

  # Smoke test: a 12-channel image plus a 1-channel image supplies the 13
  # input channels that discriminator() uses by default.
  batch, height, width = 8, 64, 64
  img_A = torch.rand((batch, 12, height, width))
  img_B = torch.rand((batch, 1, height, width))
  model = discriminator()
  output = model(img_A, img_B)
  print('output:', output.shape)

output: torch.Size([8, 1])
