In [1]:
from torchsummary import summary

#!/usr/bin/env python
# coding: utf-8

# In[2]:


import random
import time
import datetime
import sys
from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from torchvision.utils import save_image



def weights_init_normal(m):
    """Initialise the parameters of a single module in place.

    The module is classified by its class name:

    * name contains ``"Conv"``: weight drawn from N(0.0, 0.02), bias set to 0.
    * name contains ``"BatchNorm2d"``: weight drawn from N(1.0, 0.02), bias set to 0.
    * anything else: left untouched.

    A parameter that is missing or ``None`` is skipped rather than
    dereferenced. That covers ``BatchNorm2d(affine=False)`` (whose weight and
    bias are ``None``), bias-free convolutions, and container classes whose
    name merely contains "Conv" but which own no weight themselves.

    Args:
        m: the module to initialise. The function takes exactly one module,
            so it can be handed to ``nn.Module.apply``.

    Returns:
        None. Parameters are modified in place.
    """
    classname = m.__class__.__name__
    # getattr with a default: nn.Module raises AttributeError for unknown names.
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv" in classname:
        weight_mean = 0.0
    elif "BatchNorm2d" in classname:
        weight_mean = 1.0
    else:
        return

    if weight is not None:
        torch.nn.init.normal_(weight.data, weight_mean, 0.02)
    if bias is not None:
        torch.nn.init.constant_(bias.data, 0.0)


##############################
#           RESNET
##############################


class ResidualBlock(nn.Module):
    """Residual unit: two 3x3 convolution stages wrapped by an identity skip.

    Each stage is ReflectionPad2d(1) -> Conv2d(3x3) -> InstanceNorm2d, with a
    ReLU between the two stages. The channel count is ``in_features``
    throughout, and the padding of 1 offsets the 3x3 kernel so the output has
    the same shape as the input and can be added to it.
    """

    def __init__(self, in_features):
        super().__init__()

        def conv_stage():
            # One pad -> conv -> norm group; channel count is preserved.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),
                nn.InstanceNorm2d(in_features),
            ]

        first_stage = conv_stage()
        activation = nn.ReLU(inplace=True)
        second_stage = conv_stage()

        self.block = nn.Sequential(*first_stage, activation, *second_stage)

    def forward(self, x):
        """Return ``x`` plus the output of the two-stage branch applied to ``x``."""
        branch = self.block(x)
        return x + branch


class GeneratorResNet(nn.Module):
    """Image-to-image generator: stem, two downsamples, residual stack, two upsamples, head.

    Args:
        input_shape: sequence whose first element is the number of image
            channels; only that element is read here.
        num_residual_blocks: how many ``ResidualBlock`` modules to place
            between the downsampling and upsampling halves.

    The final layer is ``Tanh``, so outputs lie in [-1, 1], and the output has
    as many channels as the input.
    """

    def __init__(self, input_shape, num_residual_blocks):
        super().__init__()

        channels = input_shape[0]
        width = 64  # running channel count, starting at the stem's output

        # Stem: 7x7 convolution; reflection padding of 3 offsets the kernel.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(channels, width, 7),
            nn.InstanceNorm2d(width),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: two stride-2 convolutions, doubling channels each time.
        for _ in range(2):
            doubled = width * 2
            layers.extend(
                (
                    nn.Conv2d(width, doubled, 3, stride=2, padding=1),
                    nn.InstanceNorm2d(doubled),
                    nn.ReLU(inplace=True),
                )
            )
            width = doubled

        # Residual stack at the narrowest spatial resolution.
        layers.extend(ResidualBlock(width) for _ in range(num_residual_blocks))

        # Upsampling: x2 upsample followed by a 3x3 conv, halving channels each time.
        for _ in range(2):
            halved = width // 2
            layers.extend(
                (
                    nn.Upsample(scale_factor=2),
                    nn.Conv2d(width, halved, 3, stride=1, padding=1),
                    nn.InstanceNorm2d(halved),
                    nn.ReLU(inplace=True),
                )
            )
            width = halved

        # Head: project back to the image channel count and squash with tanh.
        layers.extend(
            (
                nn.ReflectionPad2d(3),
                nn.Conv2d(width, channels, 7),
                nn.Tanh(),
            )
        )

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the full layer sequence and return the result."""
        return self.model(x)


##############################
#        Discriminator
##############################


class Discriminator(nn.Module):
    """Convolutional discriminator that emits a grid of scores rather than one scalar.

    Args:
        input_shape: ``(channels, height, width)`` of the images to be scored.

    Attributes:
        output_shape: ``(1, height // 16, width // 16)``, the per-sample shape
            of the score map returned by ``forward``.
        model: the ``nn.Sequential`` holding every layer.
    """

    def __init__(self, input_shape):
        super().__init__()

        channels, height, width = input_shape

        # Four stride-2 blocks each halve H and W (16x overall); the padded
        # final convolution leaves the spatial size unchanged.
        self.output_shape = (1, height // 16, width // 16)

        # (output filters, use InstanceNorm?) for each downsampling block.
        block_specs = ((64, False), (128, True), (256, True), (512, True))

        layers = []
        in_filters = channels
        for out_filters, normalize in block_specs:
            layers.append(nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1))
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            in_filters = out_filters

        # One extra row/column of zeros on the top/left before the 4x4 scoring conv.
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(in_filters, 1, 4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return the single-channel score map for the image batch ``img``."""
        return self.model(img)



In [3]:
# (channels, height, width) of the images fed to every network below.
input_shape = (3, 224, 224)

# Use the GPU when one is available; otherwise fall back to the CPU so this
# cell still runs on a machine without CUDA instead of raising.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Two generators with 7 residual blocks each, and two discriminators.
G_AB = GeneratorResNet(input_shape, 7).to(device)
G_BA = GeneratorResNet(input_shape, 7).to(device)
D_A = Discriminator(input_shape).to(device)
D_B = Discriminator(input_shape).to(device)

In [4]:
# Print torchsummary's per-layer table (layer type, output shape, parameter
# count) for the G_AB generator, given one input of shape `input_shape`.
summary(G_AB, input_shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ReflectionPad2d-1          [-1, 3, 230, 230]               0
            Conv2d-2         [-1, 64, 224, 224]           9,472
    InstanceNorm2d-3         [-1, 64, 224, 224]               0
              ReLU-4         [-1, 64, 224, 224]               0
            Conv2d-5        [-1, 128, 112, 112]          73,856
    InstanceNorm2d-6        [-1, 128, 112, 112]               0
              ReLU-7        [-1, 128, 112, 112]               0
            Conv2d-8          [-1, 256, 56, 56]         295,168
    InstanceNorm2d-9          [-1, 256, 56, 56]               0
             ReLU-10          [-1, 256, 56, 56]               0
  ReflectionPad2d-11          [-1, 256, 58, 58]               0
           Conv2d-12          [-1, 256, 56, 56]         590,080
   InstanceNorm2d-13          [-1, 256, 56, 56]               0
             ReLU-14          [-1, 256,

In [5]:
# Print torchsummary's per-layer table (layer type, output shape, parameter
# count) for the D_A discriminator, given one input of shape `input_shape`.
summary(D_A, input_shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           3,136
         LeakyReLU-2         [-1, 64, 112, 112]               0
            Conv2d-3          [-1, 128, 56, 56]         131,200
    InstanceNorm2d-4          [-1, 128, 56, 56]               0
         LeakyReLU-5          [-1, 128, 56, 56]               0
            Conv2d-6          [-1, 256, 28, 28]         524,544
    InstanceNorm2d-7          [-1, 256, 28, 28]               0
         LeakyReLU-8          [-1, 256, 28, 28]               0
            Conv2d-9          [-1, 512, 14, 14]       2,097,664
   InstanceNorm2d-10          [-1, 512, 14, 14]               0
        LeakyReLU-11          [-1, 512, 14, 14]               0
        ZeroPad2d-12          [-1, 512, 15, 15]               0
           Conv2d-13            [-1, 1, 14, 14]           8,193
Total params: 2,764,737
Trainable param