In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [0]:
class Down(nn.Module):
    """Encoder stage: Conv2d -> BatchNorm2d -> LeakyReLU(0.2) [-> Dropout].

    With the default ``kernel_size=4, stride=2, padding=1`` the spatial size
    of the input is halved.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Forwarded to ``nn.Conv2d``.
        stride: Forwarded to ``nn.Conv2d``.
        padding: Forwarded to ``nn.Conv2d``.
        dropout: Dropout probability applied after the activation. ``0``
            (the default) disables dropout.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, dropout=0):
        super(Down, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=stride,
            padding=padding)
        self.norm = nn.BatchNorm2d(out_channels)
        self.activation = nn.LeakyReLU(0.2)
        # Identity keeps the dropout=0 path a no-op. Neither Identity nor
        # Dropout owns parameters or buffers, so state_dict keys are unchanged.
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = self.activation(x)
        x = self.dropout(x)
        return x

In [0]:
class Up(nn.Module):
    """Decoder stage: upsample ``x1`` and concatenate the skip tensor ``x2``.

    ``x1`` passes through ConvTranspose2d -> BatchNorm2d -> ReLU [-> Dropout].
    The result is then concatenated with ``x2`` along dim 1 (channels). With
    the default ``kernel_size=4, stride=2, padding=1`` the spatial size of
    ``x1`` is doubled, so ``x2`` must have exactly that spatial size.

    Args:
        in_channels: Channels of ``x1``.
        out_channels: Channels produced by the transposed convolution.
        kernel_size: Forwarded to ``nn.ConvTranspose2d``.
        stride: Forwarded to ``nn.ConvTranspose2d``.
        padding: Forwarded to ``nn.ConvTranspose2d``.
        dropout: Dropout probability applied after the activation. ``0``
            (the default) disables dropout.

    Returns:
        Tensor with ``out_channels + x2.size(1)`` channels.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, dropout=0):
        super(Up, self).__init__()
        self.conv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=stride,
            padding=padding)
        self.norm = nn.BatchNorm2d(out_channels)
        # The class is `nn.ReLU`; `nn.Relu` does not exist.
        self.activation = nn.ReLU()
        # Identity keeps the dropout=0 path a no-op and adds no state_dict keys.
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x1, x2):
        x1 = self.conv(x1)
        x1 = self.norm(x1)
        x1 = self.activation(x1)
        x1 = self.dropout(x1)

        # Skip connection: stack the upsampled features with x2 on channels.
        return torch.cat([x1, x2], 1)

In [0]:
class Final(nn.Module):
    """Output head: [Dropout ->] Conv2d -> BatchNorm2d -> activation.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Forwarded to ``nn.Conv2d``.
        stride: Forwarded to ``nn.Conv2d``.
        padding: Forwarded to ``nn.Conv2d``.
        dropout: Dropout probability applied to the *input* of the conv.
            This placement avoids rescaling the activation's bounded output
            range (e.g. [-1, 1] for Tanh). ``0`` (the default) disables it.
        activation_fn: Either a zero-argument callable such as an
            ``nn.Module`` subclass (called to build the activation) or an
            already constructed ``nn.Module`` instance (used as-is).
            Defaults to ``nn.Tanh``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dropout=0, activation_fn=nn.Tanh):
        super(Final, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=stride,
            padding=padding)
        self.norm = nn.BatchNorm2d(out_channels)
        # Calling an nn.Module *instance* with no arguments invokes forward()
        # without input and raises, so accept instances directly.
        if isinstance(activation_fn, nn.Module):
            self.activation = activation_fn
        else:
            self.activation = activation_fn()
        # Identity keeps the dropout=0 path a no-op and adds no state_dict keys.
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x):
        x = self.dropout(x)
        x = self.conv(x)
        x = self.norm(x)
        x = self.activation(x)
        return x

In [0]:
class Generator(nn.Module):
    """Encoder-decoder generator with channel-concatenated skip connections.

    Maps a 1-channel input to a 2-channel Tanh output. The input is
    bilinearly resized to 35x35. ``down1`` (4x4 kernel, stride 1, no padding)
    maps it to 32x32, and ``down2``..``down5`` each halve it, down to 2x2.
    ``up1``..``up4`` each double the resolution again and concatenate the
    encoder output of matching size. The 1x1 ``final`` convolution therefore
    sees 128 channels at 32x32 and emits 2 channels.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, stride, padding); kernel_size is 4 for all.
        encoder_spec = (
            (1, 64, 1, 0),
            (64, 128, 2, 1),
            (128, 256, 2, 1),
            (256, 512, 2, 1),
            (512, 512, 2, 1),
        )
        for index, (c_in, c_out, step, pad) in enumerate(encoder_spec, start=1):
            setattr(self, f"down{index}",
                    Down(c_in, c_out, kernel_size=4, stride=step, padding=pad, dropout=0))

        # (in_channels, out_channels, dropout); kernel 4 / stride 2 / padding 1.
        # From up2 on, in_channels is twice the previous stage's out_channels.
        # Each stage's output is concatenated with an equally wide encoder map.
        decoder_spec = (
            (512, 512, 0.5),
            (1024, 256, 0.5),
            (512, 128, 0),
            (256, 64, 0),
        )
        for index, (c_in, c_out, drop) in enumerate(decoder_spec, start=1):
            setattr(self, f"up{index}",
                    Up(c_in, c_out, kernel_size=4, stride=2, padding=1, dropout=drop))

        self.final = Final(128, 2)

    def forward(self, x):
        h = F.interpolate(x, size=(35, 35), mode='bilinear', align_corners=True)

        # Keep every encoder output except the bottleneck for the skips.
        skips = []
        for stage in (self.down1, self.down2, self.down3, self.down4):
            h = stage(h)
            skips.append(h)
        h = self.down5(h)

        # Decode, pairing each stage with the skip of matching resolution
        # (deepest first).
        decoder = (self.up1, self.up2, self.up3, self.up4)
        for stage, skip in zip(decoder, reversed(skips)):
            h = stage(h, skip)

        return self.final(h)

In [0]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing one sigmoid score per sample.

    The input is resized to 35x35 and passed through four ``Down`` stages
    (35 -> 32 -> 16 -> 8 -> 4). A 4x4 unpadded ``Final`` convolution then
    collapses the 4x4 map to 1x1, and the result is flattened to (N, 1).
    The first stage expects 3 input channels (presumably the 1-channel
    generator input concatenated with its 2-channel output — TODO confirm).
    """

    # `self` was missing from the signature, so instantiation raised TypeError.
    def __init__(self):
        super(Discriminator, self).__init__()
        self.down1 = Down(3, 64, kernel_size=4, stride=1, padding=0, dropout=0)
        self.down2 = Down(64, 128, kernel_size=4, stride=2, padding=1, dropout=0)
        self.down3 = Down(128, 256, kernel_size=4, stride=2, padding=1, dropout=0)
        self.down4 = Down(256, 512, kernel_size=4, stride=2, padding=1, dropout=0)
        # Pass the Sigmoid *class*: Final builds its activation with
        # activation_fn().
        # NOTE(review): Final applies BatchNorm2d to this single-channel 1x1
        # map. In training mode a batch of size 1 makes BatchNorm raise
        # ("Expected more than 1 value per channel"), and normalizing the logit
        # before the sigmoid is unusual. Consider a norm-free output head.
        self.final = Final(
            512, 1, kernel_size=4, stride=1, padding=0, dropout=0,
            activation_fn=nn.Sigmoid,
        )

    def forward(self, x):
        x = F.interpolate(x, size=(35, 35), mode='bilinear', align_corners=True)
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)

        x = self.final(d4)
        # (N, 1, 1, 1) -> (N, 1)
        return x.view(x.size(0), -1)