In [2]:
import torch.nn as nn
from torch.autograd import Variable
import torch
import torch.nn.functional as F

In [7]:
class Generator(nn.Module):
    """U-Net-style encoder/decoder with channel-concatenation skip connections.

    Maps an (N, 6, H, W) tensor to an (N, 3, H, W) tensor.

    * Encoder (conv1..conv4): each stage applies two size-preserving convs with
      LeakyReLU, then AvgPool2d(3, stride=2, padding=1), which halves even H/W.
    * Decoder (deconv4..deconv1): each stage applies a ConvTranspose2d/Conv2d
      pair (both padding=2) with LeakyReLU, then 2x bilinear upsampling. The
      decoder output is concatenated along dim 1 with the encoder output of the
      same resolution before entering the next decoder stage.

    Constraints on the input: H and W must be multiples of 16 so the pooled and
    upsampled sizes match at every concatenation. For the default
    ``kernel_size=3`` they must also be at least 48, because the bottleneck
    ConvTranspose2d(padding=2) shrinks a 3x3 map to 1x1 (and a 2x2 map to 0x0).

    NOTE(review): the last stage ends in LeakyReLU, so the output is not bounded
    to a fixed range (no tanh/sigmoid). Confirm this matches the loss and the
    target image scaling.

    Args:
        kernel_size: odd conv kernel size. Encoder convs use
            ``padding = kernel_size // 2`` (1 for the default 3, as before) so
            they preserve spatial size for any odd kernel.
    """

    def __init__(self, kernel_size=3):
        super(Generator, self).__init__()
        # Modules are created in the same order as before so that seeded
        # weight initialisation and state_dict keys are unchanged.
        # Shapes below assume a (6, 256, 256) input.
        self.conv1 = self._down_block(6, 32, kernel_size)      # -> (32, 128, 128)
        self.conv2 = self._down_block(32, 64, kernel_size)     # -> (64, 64, 64)
        self.conv3 = self._down_block(64, 128, kernel_size)    # -> (128, 32, 32)
        self.conv4 = self._down_block(128, 256, kernel_size)   # -> (256, 16, 16)
        # Decoder input channels double (except deconv4) because of the skip concat.
        self.deconv4 = self._up_block(256, 128, kernel_size)   # -> (128, 32, 32)
        self.deconv3 = self._up_block(256, 64, kernel_size)    # -> (64, 64, 64)
        self.deconv2 = self._up_block(128, 32, kernel_size)    # -> (32, 128, 128)
        self.deconv1 = self._up_block(64, 3, kernel_size)      # -> (3, 256, 256)

    @staticmethod
    def _down_block(in_ch, out_ch, kernel_size):
        """Two size-preserving conv+LeakyReLU layers, then AvgPool2d halving H and W."""
        padding = kernel_size // 2
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size, stride=1, padding=padding),
            nn.LeakyReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size, stride=1, padding=padding),
            nn.LeakyReLU(),
            nn.AvgPool2d(3, stride=2, padding=1))

    @staticmethod
    def _up_block(in_ch, out_ch, kernel_size):
        """ConvTranspose2d + Conv2d (net size-preserving) with LeakyReLU, then 2x bilinear upsample.

        With padding=2 on both layers, ConvTranspose2d gives H + k - 5 and
        Conv2d then gives (H + k - 5) + 5 - k = H, so the pair keeps the
        spatial size for any kernel size k (as long as H + k - 5 >= 1).
        """
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size, stride=1, padding=2),
            nn.LeakyReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size, stride=1, padding=2),
            nn.LeakyReLU(),
            # align_corners=False is already the bilinear default; passing it
            # explicitly silences the "Default upsampling behavior" UserWarning.
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False))

    def forward(self, inputs):
        """Encode, decode with skip concatenations, and return the (N, 3, H, W) output.

        Shape comments assume an (N, 6, 256, 256) input.
        """
        e1 = self.conv1(inputs)             # (N, 32, 128, 128)
        e2 = self.conv2(e1)                 # (N, 64, 64, 64)
        e3 = self.conv3(e2)                 # (N, 128, 32, 32)
        e4 = self.conv4(e3)                 # (N, 256, 16, 16)
        d4 = self.deconv4(e4)               # (N, 128, 32, 32)
        d3_in = torch.cat((e3, d4), 1)      # (N, 256, 32, 32)
        d3 = self.deconv3(d3_in)            # (N, 64, 64, 64)
        d2_in = torch.cat((e2, d3), 1)      # (N, 128, 64, 64)
        d2 = self.deconv2(d2_in)            # (N, 32, 128, 128)
        d1_in = torch.cat((e1, d2), 1)      # (N, 64, 128, 128)
        d1 = self.deconv1(d1_in)            # (N, 3, 256, 256)
        return d1

In [9]:
# Smoke-test the Generator: an (N, 6, 256, 256) batch should map to (N, 3, 256, 256).
torch.manual_seed(0)  # make the random input and weight init reproducible
inputs = torch.rand((10, 6, 256, 256))
model = Generator()
# Call the module (not .forward) so nn.Module hooks run; no_grad because this is
# only a shape check and building the autograd graph would waste memory.
with torch.no_grad():
    outputs = model(inputs)
assert outputs.shape == (10, 3, 256, 256), outputs.shape
print(outputs.size())

torch.Size([10, 3, 256, 256])


  "See the documentation of nn.Upsample for details.".format(mode))


In [None]:
class Discriminator(nn.Module):
    """Discriminator (work in progress): currently a single conv + LeakyReLU stage.

    The previous ``forward`` created a new ``nn.Conv2d`` on every call without
    applying it, passed that module to ``nn.LeakyReLU`` as its negative slope,
    and returned ``None``. The layer is now registered once in ``__init__`` so its
    parameters are trained, and it is actually applied.

    NOTE(review): the network does not yet reduce its features to a real/fake
    score; the remaining layers are still TODO.

    Args:
        kernel_size: odd conv kernel size. ``padding = kernel_size // 2`` keeps
            H and W unchanged (2 for the default 5, as before).
        dim: number of output channels of the first conv.
        in_channels: number of input channels. ``None`` (default) infers it from
            the first batch via ``nn.LazyConv2d`` (PyTorch >= 1.8), matching the
            original code, which read the channel count from ``x.shape``.
    """

    def __init__(self, kernel_size=5, dim=64, in_channels=None):
        super(Discriminator, self).__init__()
        self.kernel_size = kernel_size
        self.dim = dim
        padding = kernel_size // 2
        # Layers must be attributes created here (not inside forward) so that
        # nn.Module registers their parameters and they persist across calls.
        if in_channels is None:
            self.conv1 = nn.LazyConv2d(dim, kernel_size, stride=1, padding=padding)
        else:
            self.conv1 = nn.Conv2d(in_channels, dim, kernel_size, stride=1, padding=padding)

    def forward(self, x):
        """Apply the first conv and a LeakyReLU; returns an (N, dim, H, W) tensor."""
        output = self.conv1(x)
        # Functional form applies the activation; nn.LeakyReLU(t) would instead
        # construct a module whose negative_slope is t.
        output = F.leaky_relu(output)
        return output
        
      