In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.utils.data 
from torch.utils.data import Subset
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
import random
import cv2
import time

In [5]:
class Completion_Network(nn.Module):
    """Fully-convolutional completion (generator) network.

    Takes a 4-channel input (presumably RGB + a 1-channel mask -- confirm
    against the caller) and returns a 3-channel image squashed to [0, 1] by a
    sigmoid.

    Layout:
      * conv1-conv6    : feature extraction with two stride-2 downsamplings
      * dilated 7-10   : dilated 3x3 convs (dilation 2, 4, 8, 16) that widen
                         the receptive field without reducing resolution
      * conv11-conv12  : plain 3x3 convs
      * deconv13-conv17: two stride-2 transposed convs back up to the input
                         resolution, then projection to 3 channels

    The output has the same spatial size as the input whenever the input
    height and width are multiples of 4.
    """

    # Order in which the Conv->BN->ReLU stages are applied in forward().
    # conv17 is handled separately because it is followed by a sigmoid.
    _STAGE_ORDER = (
        "conv1", "conv2", "conv3", "conv4", "conv5", "conv6",
        "dilated_conv7", "dilated_conv8", "dilated_conv9", "dilated_conv10",
        "conv11", "conv12",
        "deconv13", "conv14",
        "deconv15", "conv16",
    )

    @staticmethod
    def _conv_block(in_ch, out_ch, kernel_size=3, stride=1, padding=1, dilation=1):
        """Conv2d -> BatchNorm2d -> ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride,
                      padding=padding, dilation=dilation),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    @staticmethod
    def _deconv_block(in_ch, out_ch):
        """ConvTranspose2d (4x4, stride 2, i.e. exact 2x upsampling) -> BatchNorm2d -> ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    def __init__(self):
        super().__init__()
        block, up = self._conv_block, self._deconv_block

        # --- encoder: full resolution -> 1/2 -> 1/4 ---
        self.conv1 = block(4, 64, kernel_size=5, padding=2)
        self.conv2 = block(64, 128, stride=2)
        self.conv3 = block(128, 128)
        self.conv4 = block(128, 256, stride=2)
        self.conv5 = block(256, 256)
        self.conv6 = block(256, 256)

        # --- dilated stack: padding == dilation keeps the spatial size ---
        self.dilated_conv7 = block(256, 256, padding=2, dilation=2)
        self.dilated_conv8 = block(256, 256, padding=4, dilation=4)
        self.dilated_conv9 = block(256, 256, padding=8, dilation=8)
        self.dilated_conv10 = block(256, 256, padding=16, dilation=16)
        self.conv11 = block(256, 256)
        self.conv12 = block(256, 256)

        # --- decoder: 1/4 -> 1/2 -> full resolution ---
        self.deconv13 = up(256, 128)
        self.conv14 = block(128, 128)

        self.deconv15 = up(128, 64)
        self.conv16 = block(64, 32)
        # Final projection to 3 channels; no BN/ReLU, the sigmoid is applied in forward().
        self.conv17 = nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Run the network on a (batch, 4, H, W) tensor; returns (batch, 3, H', W') in [0, 1]."""
        features = x
        for stage_name in self._STAGE_ORDER:
            features = getattr(self, stage_name)(features)
        return torch.sigmoid(self.conv17(features))

In [None]:
from torchsummary import summary

# Use the GPU when one is visible, otherwise stay on CPU, so this cell still
# runs after "Restart & Run All" on a CPU-only kernel (a bare .cuda() raises there).
summary_device = "cuda" if torch.cuda.is_available() else "cpu"
cn = Completion_Network().to(summary_device)
# Layer-by-layer summary for a 4-channel 160x160 input.
summary(cn, (4, 160, 160), batch_size=32, device=summary_device)

In [6]:
class Local_Discriminator(nn.Module):
    """Convolutional feature extractor for an image patch.

    Five 5x5 stride-2 Conv->BN->ReLU stages followed by a fully connected
    layer + ReLU that produces a 1024-dimensional feature vector per sample.

    Args:
        input_shape: (channels, height, width) of the tensors passed to
            forward(), without the batch dimension.
    """

    # Geometry shared by all five conv stages.
    _KERNEL, _STRIDE, _PAD, _NUM_STAGES = 5, 2, 2, 5

    @classmethod
    def _downsampled(cls, side):
        """Length of one spatial side after all conv stages.

        Uses the Conv2d output-size formula rather than `side // 32`: each
        stage maps n -> ceil(n / 2), so floor division under-counts whenever
        `side` is not a multiple of 32 and the Linear layer would be built
        with the wrong width.
        """
        for _ in range(cls._NUM_STAGES):
            side = (side + 2 * cls._PAD - cls._KERNEL) // cls._STRIDE + 1
        return side

    def __init__(self, input_shape):
        super(Local_Discriminator, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(input_shape[0], 64, kernel_size=5, stride=2, padding=2),
                                 nn.BatchNorm2d(64),
                                 nn.ReLU())
        self.conv2 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
                                 nn.BatchNorm2d(128),
                                 nn.ReLU())
        self.conv3 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
                                 nn.BatchNorm2d(256),
                                 nn.ReLU())
        self.conv4 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=5, stride=2, padding=2),
                                 nn.BatchNorm2d(512),
                                 nn.ReLU())
        self.conv5 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=5, stride=2, padding=2),
                                 nn.BatchNorm2d(512),
                                 nn.ReLU())
        # Flattened size of conv5's output: 512 channels x reduced height x reduced width.
        out_size = 512 * self._downsampled(input_shape[1]) * self._downsampled(input_shape[2])
        self.linear = nn.Linear(out_size, 1024)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a (batch, C, H, W) tensor to a (batch, 1024) non-negative feature tensor."""
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        # Flatten everything except the batch dimension for the linear layer.
        out = out.view(out.size(0), -1)
        out = self.relu(self.linear(out))
        return out

In [None]:
# Use the GPU when one is visible, otherwise stay on CPU, so this cell still
# runs on a CPU-only kernel (a bare .cuda() raises there).
summary_device = "cuda" if torch.cuda.is_available() else "cpu"
ld = Local_Discriminator((3, 96, 96)).to(summary_device)
# Layer-by-layer summary for a 3-channel 96x96 patch.
summary(ld, (3, 96, 96), batch_size=32, device=summary_device)

In [8]:
class Global_Discriminator(nn.Module):
    """Convolutional feature extractor for a whole image.

    Five 5x5 stride-2 Conv->BN->ReLU stages followed by a fully connected
    layer + ReLU that produces a 1024-dimensional feature vector per sample.

    Args:
        input_shape: (channels, height, width) of the tensors passed to
            forward(), without the batch dimension.
    """

    # Geometry shared by all five conv stages.
    _KERNEL, _STRIDE, _PAD, _NUM_STAGES = 5, 2, 2, 5

    @classmethod
    def _downsampled(cls, side):
        """Length of one spatial side after all conv stages.

        Uses the Conv2d output-size formula rather than `side // 32`: each
        stage maps n -> ceil(n / 2), so floor division under-counts whenever
        `side` is not a multiple of 32 and the Linear layer would be built
        with the wrong width.
        """
        for _ in range(cls._NUM_STAGES):
            side = (side + 2 * cls._PAD - cls._KERNEL) // cls._STRIDE + 1
        return side

    def __init__(self, input_shape):
        super(Global_Discriminator, self).__init__()

        self.conv1 = nn.Sequential(nn.Conv2d(input_shape[0], 64, kernel_size=5, stride=2, padding=2),
                                 nn.BatchNorm2d(64),
                                 nn.ReLU())
        self.conv2 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
                                 nn.BatchNorm2d(128),
                                 nn.ReLU())
        self.conv3 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
                                 nn.BatchNorm2d(256),
                                 nn.ReLU())
        self.conv4 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=5, stride=2, padding=2),
                                 nn.BatchNorm2d(512),
                                 nn.ReLU())
        self.conv5 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=5, stride=2, padding=2),
                                 nn.BatchNorm2d(512),
                                 nn.ReLU())

        # Flattened size of conv5's output: 512 channels x reduced height x reduced width.
        out_size = 512 * self._downsampled(input_shape[1]) * self._downsampled(input_shape[2])

        self.linear = nn.Linear(out_size, 1024)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a (batch, C, H, W) tensor to a (batch, 1024) non-negative feature tensor."""
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        # Flatten everything except the batch dimension for the linear layer.
        out = out.view(out.size(0), -1)
        out = self.relu(self.linear(out))
        return out

In [None]:
# Use the GPU when one is visible, otherwise stay on CPU, so this cell still
# runs on a CPU-only kernel (a bare .cuda() raises there).
summary_device = "cuda" if torch.cuda.is_available() else "cpu"
gd = Global_Discriminator((3, 160, 160)).to(summary_device)
# Layer-by-layer summary for a 3-channel 160x160 image.
summary(gd, (3, 160, 160), batch_size=32, device=summary_device)

In [9]:
class Context_Discriminators(nn.Module):
    """Joint discriminator that fuses a local and a global branch.

    Each branch turns its input into a feature vector; the two vectors are
    concatenated and passed through one linear layer and a sigmoid.

    Args:
        local_input_shape: (channels, height, width) expected by the local branch.
        global_input_shape: (channels, height, width) expected by the global branch.
    """

    def __init__(self, local_input_shape = (3,96,96), global_input_shape = (3,160,160)):
        super(Context_Discriminators, self).__init__()
        self.local_discriminator = Local_Discriminator(local_input_shape)
        self.global_discriminator = Global_Discriminator(global_input_shape)
        # Width of the concatenated feature vector. Derived from the branches
        # (1024 + 1024 = 2048 today) instead of being hard-coded, so a change
        # to either branch's output size cannot silently break this layer.
        out_size = (self.local_discriminator.linear.out_features
                    + self.global_discriminator.linear.out_features)
        self.linear = nn.Linear(out_size, 1)

    def forward(self, x):
        """Score a (local_patch, global_image) pair.

        Args:
            x: 2-element sequence `(x_local, x_global)` of batched tensors
                with matching batch size.

        Returns:
            Tensor of shape (batch, 1) with values in (0, 1): the probability
            that the input image is real.
        """
        x_local, x_global = x
        out_local = self.local_discriminator(x_local)
        out_global = self.global_discriminator(x_global)
        # Concatenate along the feature axis, keeping the batch axis intact.
        out = torch.cat((out_local, out_global), dim=1)
        out = self.linear(out)
        out = torch.sigmoid(out)
        return out   # probability that the input image is real.