In [None]:
from torch import nn
import torch

In [None]:
def crop(encoder_layer: torch.Tensor, decoder_layer: torch.Tensor) -> torch.Tensor:
    """Center-crop ``encoder_layer`` spatially to the size of ``decoder_layer``.

    Only needed when the input side is not a power of two, so that repeated
    stride-2 down- and up-sampling no longer round-trips to the same size.
    The last two dims are treated as (height, width); height and width are
    cropped independently, so non-square feature maps are handled correctly.

    Args:
        encoder_layer: skip-connection feature map, at least as large as
            ``decoder_layer`` in both spatial dims.
        decoder_layer: feature map whose spatial size is the crop target.

    Returns:
        ``encoder_layer`` itself when the spatial sizes already match,
        otherwise a center-cropped view of it.

    Raises:
        AssertionError: if ``encoder_layer`` is smaller than ``decoder_layer``
            in either spatial dim.
    """
    enc_h, enc_w = encoder_layer.shape[-2:]
    dec_h, dec_w = decoder_layer.shape[-2:]

    if (enc_h, enc_w) == (dec_h, dec_w):
        return encoder_layer

    assert enc_h >= dec_h and enc_w >= dec_w, (
        f"cannot crop {tuple(encoder_layer.shape)} to {tuple(decoder_layer.shape)}")

    # same start offset as the former ((L - l) // 2); the end offset
    # top + dec_h equals the former ((L + l) // 2).
    top = (enc_h - dec_h) // 2
    left = (enc_w - dec_w) // 2
    return encoder_layer[..., top:top + dec_h, left:left + dec_w]

In [None]:
# test crop
def test_crop():
  """Smoke-test ``crop`` on two equally sized maps and print the shapes involved."""
  enc = torch.randn(2, 128, 280, 200)
  dec = torch.randn(2, 128, 280, 200)

  enc_cropped = crop(enc, dec)
  merged = torch.cat([enc_cropped, dec], 1)  # channel-wise concat, as in the U-Net skips

  print(merged.shape)
  print(enc.shape, enc_cropped.shape, dec.shape)
  print(enc.type())

# test_crop()


In [None]:
class DownBlock(nn.Module):
  """Encoder stage: LeakyReLU(0.2) -> 4x4 stride-2 conv -> optional BatchNorm.

  The conv (kernel 4, stride 2, padding 1) halves H and W and maps ``in_`` to
  ``out`` channels. The BatchNorm module is always created but only applied
  when ``batchnorm`` is True.

  NOTE(review): the LeakyReLU runs in place, so the tensor passed to
  ``forward`` is overwritten with its activated values. In ``Unet`` those
  inputs (d1..d7) are reused afterwards as skip connections.
  """

  def __init__(self, in_, out, batchnorm=True):
    super().__init__()
    # registration order (conv, batchnorm, activation) is kept unchanged
    self.conv = nn.Conv2d(in_, out, kernel_size=4, stride=2, padding=1)
    self.batchnorm = nn.BatchNorm2d(out)
    self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    # flag read in forward(); the module above exists either way
    self.batchnorm_key = batchnorm

  def forward(self, x):
    h = self.conv(self.leaky_relu(x))
    return self.batchnorm(h) if self.batchnorm_key else h



class UpBlock(nn.Module):
  """Decoder stage: ReLU -> 4x4 stride-2 transposed conv -> optional BatchNorm -> optional Dropout(0.5).

  The transposed conv (kernel 4, stride 2, padding 1) doubles H and W and maps
  ``in_`` to ``out`` channels. BatchNorm and Dropout modules are always
  created; the ``batchnorm`` / ``dropout`` flags decide whether they are applied.

  NOTE(review): the ReLU runs in place, so the tensor passed to ``forward``
  is overwritten with its activated values.
  """

  def __init__(self, in_, out, batchnorm=True, dropout=False):
    super().__init__()
    self.conv_transpose = nn.ConvTranspose2d(in_channels=in_, out_channels=out, kernel_size=4, stride=2, padding=1)
    self.batchnorm = nn.BatchNorm2d(out)
    self.dropout = nn.Dropout(p=0.5)
    self.relu = nn.ReLU(True)

    self.batchnorm_key = batchnorm
    self.dropout_key = dropout

  def forward(self, x):
    x = self.relu(x)
    x = self.conv_transpose(x)
    if self.batchnorm_key:
      x = self.batchnorm(x)
    if self.dropout_key:
      # nn.Dropout is out-of-place: its result must be assigned, otherwise
      # dropout silently has no effect (the previous behavior).
      x = self.dropout(x)
    return x

In [None]:
# generator: modified Unet

class Unet(nn.Module):
  """U-Net generator: eight stride-2 encoder stages and eight stride-2 decoder stages.

  Takes a 3-channel input and returns a 3-channel output squashed by tanh
  into (-1, 1). The outputs of decoder stages 1-7 are concatenated along the
  channel dim with the mirrored encoder output. That encoder output is
  center-cropped via ``crop``, so inputs whose side is not a power of two
  still line up.
  """

  def __init__(self):
    super().__init__()

    # encoder: in_-C64-C128-C256-C512-C512-C512-C512-C512
    # stage 1 is a bare conv: no activation before it, no BatchNorm after it
    self.down1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=4, stride=2, padding=1)
    self.down2 = DownBlock(64, 128)
    self.down3 = DownBlock(128, 256)
    self.down4 = DownBlock(256, 512)
    self.down5 = DownBlock(512, 512)
    self.down6 = DownBlock(512, 512)
    self.down7 = DownBlock(512, 512)
    # innermost stage without BatchNorm (for a 256x256 input its output is 1x1)
    self.down8 = DownBlock(512, 512, batchnorm=False)

    # decoder: in_-C512-C512-C512-C512-C256-C128-C64-C3
    # from up2 on, in_channels are doubled because each input is [decoder, skip]
    self.up1 = UpBlock(512, 512, dropout=True)
    self.up2 = UpBlock(1024, 512, dropout=True)
    self.up3 = UpBlock(1024, 512, dropout=True)
    self.up4 = UpBlock(1024, 512)
    self.up5 = UpBlock(1024, 256)
    self.up6 = UpBlock(512, 128)
    self.up7 = UpBlock(256, 64)
    self.up8 = UpBlock(128, 3, batchnorm=False)

    # last activation
    self.tanh = nn.Tanh()

  def weight_init(self, mean, std):
    """Initialise every conv / transposed conv in the network with ``normal_init``.

    ``self.modules()`` walks the whole module tree. Iterating
    ``self._modules`` (direct children only) would skip the convs nested
    inside the DownBlock/UpBlock stages and initialise just ``down1``.
    """
    for m in self.modules():
      normal_init(m, mean, std)

  def forward(self, x):
    encoder = (self.down1, self.down2, self.down3, self.down4,
               self.down5, self.down6, self.down7)
    decoder = (self.up1, self.up2, self.up3, self.up4,
               self.up5, self.up6, self.up7)

    # downsampling: keep d1..d7 for the skip connections
    # NOTE(review): each following DownBlock applies an in-place LeakyReLU to
    # its input, so the stored skips end up holding activated values.
    skips = []
    for down in encoder:
      x = down(x)
      skips.append(x)
    x = self.down8(x)

    # upsampling: up_k pairs with d_(8-k), i.e. the skips in reverse order
    for up, skip in zip(decoder, reversed(skips)):
      x = up(x)
      x = torch.cat([x, crop(skip, x)], 1)

    return self.tanh(self.up8(x))

In [None]:
# test Unet 
def test_Unet():
  """Build the generator and print its output shape for one random 1x3x256x256 input."""
  model = Unet()
  batch = torch.randn(1, 3, 256, 256)
  print(model(batch).shape)

# test_Unet()


In [None]:
class PatchGAN_70(nn.Module):
  """70x70 PatchGAN discriminator: C64-C128-C256-C512 followed by a 1-channel conv.

  ``forward(input, target)`` concatenates both tensors along dim 1. The first
  conv expects 6 channels in total, e.g. two RGB images. It returns a map of
  per-patch probabilities in (0, 1) via a sigmoid.
  """

  def __init__(self):
    super().__init__()
    # stride-2 stages; the first one has no normalisation
    self.conv1 = nn.Conv2d(6, 64, kernel_size=4, stride=2, padding=1)
    self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
    self.batchnorm2 = nn.BatchNorm2d(128)
    self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
    self.batchnorm3 = nn.BatchNorm2d(256)
    # stride-1 stages
    self.conv4 = nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1)
    self.batchnorm4 = nn.BatchNorm2d(512)
    self.conv5 = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)

    self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    self.sigmoid = nn.Sigmoid()

  def weight_init(self, mean, std):
    """Apply ``normal_init`` to every direct child module."""
    for child in self.children():
      normal_init(child, mean, std)

  def forward(self, input, target):
    h = self.leaky_relu(self.conv1(torch.cat([input, target], 1)))
    normed_stages = (
      (self.conv2, self.batchnorm2),
      (self.conv3, self.batchnorm3),
      (self.conv4, self.batchnorm4),
    )
    for conv, norm in normed_stages:
      h = self.leaky_relu(norm(conv(h)))
    return self.sigmoid(self.conv5(h))

In [None]:
# discriminator: 16x16 patchGAN

class PatchGAN_16(nn.Module):
  """16x16 PatchGAN discriminator: C64 (stride 2) -> C128 (stride 1) -> 1-channel conv.

  ``input`` and ``target`` are concatenated along dim 1. The first conv
  expects 6 channels in total, e.g. two RGB images. The output is a per-patch
  probability map in (0, 1).
  """

  def __init__(self):
    super().__init__()
    # first layer deliberately has no normalisation
    self.conv1 = nn.Conv2d(6, 64, kernel_size=4, stride=2, padding=1)
    self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=1, padding=1)
    self.batchnorm2 = nn.BatchNorm2d(128)
    self.conv3 = nn.Conv2d(128, 1, kernel_size=4, stride=1, padding=1)

    self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    self.sigmoid = nn.Sigmoid()

  def weight_init(self, mean, std):
    """Apply ``normal_init`` to every direct child module."""
    for child in self.children():
      normal_init(child, mean, std)

  def forward(self, input, target):
    pair = torch.cat([input, target], 1)
    h = self.leaky_relu(self.conv1(pair))
    h = self.leaky_relu(self.batchnorm2(self.conv2(h)))
    return self.sigmoid(self.conv3(h))

In [None]:
# discriminator: 1x1 pixelGAN

class PixelGAN(nn.Module):
  """1x1 PixelGAN discriminator: every layer is a 1x1 conv, so each output
  pixel depends only on the input pixel at the same location (in eval mode,
  where BatchNorm uses running statistics).

  ``input`` and ``target`` are concatenated along dim 1. The first conv
  expects 6 channels in total, e.g. two RGB images. The output is a
  per-pixel probability map in (0, 1).
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channels=6, out_channels=64, kernel_size=1, stride=1, padding=0)
    self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    # These convs were written with `bias=nn.InstanceNorm2d`. A class object is
    # merely truthy, so they always had a bias. That is spelled out explicitly
    # here, keeping the same parameters and state_dict keys.
    self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=1, stride=1, padding=0, bias=True)
    self.batchnorm2 = nn.BatchNorm2d(128)
    self.conv3 = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)

    self.sigmoid = nn.Sigmoid()

  def weight_init(self, mean, std):
    """Apply ``normal_init`` to every direct child module."""
    for m in self._modules:
      normal_init(self._modules[m], mean, std)

  def forward(self, input, target):
    x = torch.cat([input, target], 1)
    x = self.leaky_relu(self.conv1(x))
    x = self.leaky_relu(self.batchnorm2(self.conv2(x)))
    x = self.conv3(x)
    return self.sigmoid(x)

In [None]:
# test patchGAN

def test_patchGAN():
  """Smoke-test generator and 70x70 discriminator output shapes.

  Prints the generator output shape for a 572x572 input. That size is not a
  power of two, so the forward pass goes through ``crop``. It then prints the
  discriminator output shape for that output, and for a plain 256x256 pair.
  Runs on the GPU when available, otherwise on the CPU.
  """
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  G = Unet().to(device)
  D = PatchGAN_70().to(device)

  # forward passes only: no autograd graph needed for a shape check
  with torch.no_grad():
    x = torch.randn(1, 3, 572, 572).to(device)
    G_output = G(x)
    print(G_output.shape)

    # random target with the generator output's full spatial size (H and W)
    fake_target = torch.randn(1, 3, *G_output.shape[2:]).to(device)
    D_output = D(G_output, fake_target)
    print(D_output.shape)

    D_output_1 = D(torch.randn(1, 3, 256, 256).to(device), torch.randn(1, 3, 256, 256).to(device))
    print(D_output_1.shape)

# test_patchGAN()

In [None]:
def normal_init(m, mean, std):
    """Initialise a ``Conv2d`` / ``ConvTranspose2d`` in place.

    Weights are drawn from N(mean, std) and the bias is set to zero. Layers
    built with ``bias=False`` (``m.bias is None``) only get their weights
    drawn. Any other module type is left untouched, so this is safe to map
    over ``Module.modules()``.
    """
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        nn.init.normal_(m.weight, mean, std)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [None]:
# original Unet

class Block_original(nn.Module):
  """Double unpadded 3x3 conv block: (conv3x3 -> BatchNorm -> ReLU) x 2.

  Each conv has no padding, so H and W each shrink by 4 overall.

  Two call forms are accepted:
    * ``Block_original(in_, out)``: the form used by ``Unet_original``;
    * ``Block_original(in_, middle, out, key)``: the original 4-argument form.
  ``middle`` and ``key`` are unused (as before): both convs output ``out``
  channels.
  """

  def __init__(self, in_, middle, out=None, key=None):
    super().__init__()
    if out is None:
      # two-argument form: the second positional argument is the output width
      out = middle
    self.conv1 = nn.Conv2d(in_channels=in_, out_channels=out, kernel_size=3)
    self.conv2 = nn.Conv2d(in_channels=out, out_channels=out, kernel_size=3)
    # One BatchNorm per conv. A single shared module would mix the running
    # statistics and affine parameters of two different layers.
    self.batchnorm = nn.BatchNorm2d(out)
    self.batchnorm2 = nn.BatchNorm2d(out)
    self.activation = nn.ReLU()

  def forward(self, x):
    x = self.activation(self.batchnorm(self.conv1(x)))
    x = self.activation(self.batchnorm2(self.conv2(x)))
    return x



class Unet_original(nn.Module):
  """Classic U-Net built from unpadded double-3x3-conv blocks.

  Every 3x3 conv is unpadded, so the output is smaller than the input; a
  572x572 input yields a 388x388 output. Encoder features are center-cropped
  with ``crop`` before each skip concatenation.

  Args:
    in_: number of input channels.
    out_channels: channels produced by the final 1x1 conv. Defaults to 3, the
      value that was previously hard-coded.
  """

  def __init__(self, in_, out_channels=3):
    super().__init__()
    self.max_pooling = nn.MaxPool2d(2, 2)

    # 2x up-convolutions (kernel 2, stride 2), each halving the channel count
    self.conv_transpose1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)
    self.conv_transpose2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
    self.conv_transpose3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
    self.conv_transpose4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)

    # contracting path
    self.down_block1 = Block_original(in_, 64)
    self.down_block2 = Block_original(64, 128)
    self.down_block3 = Block_original(128, 256)
    self.down_block4 = Block_original(256, 512)

    # up_block1 is the bottleneck (512 -> 1024); up_block2..4 consume the
    # [skip, upsampled] concatenations, hence the doubled input widths
    self.up_block1 = Block_original(512, 1024)
    self.up_block2 = Block_original(1024, 512)
    self.up_block3 = Block_original(512, 256)
    self.up_block4 = Block_original(256, 128)

    self.last_block = Block_original(128, 64)
    self.last_conv = nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=1)

  def forward(self, x):
    # contracting path: keep each block's pre-pooling output for the skips
    skips = []
    for block in (self.down_block1, self.down_block2, self.down_block3, self.down_block4):
      x = block(x)
      skips.append(x)
      x = self.max_pooling(x)

    # expanding path: conv block -> 2x up-conv -> concat [cropped skip, upsampled]
    stages = (
      (self.up_block1, self.conv_transpose1),
      (self.up_block2, self.conv_transpose2),
      (self.up_block3, self.conv_transpose3),
      (self.up_block4, self.conv_transpose4),
    )
    for (block, upconv), skip in zip(stages, reversed(skips)):
      x = upconv(block(x))
      x = torch.cat([crop(skip, x), x], 1)

    # last convolution
    return self.last_conv(self.last_block(x))