In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm

from torchvision.datasets import MNIST
import torchvision.transforms as tv_transforms
import torchvision.utils as tv_utils

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm


def apply_sn(module, use_sn):
    """Optionally wrap ``module`` with spectral normalization.

    Returns ``spectral_norm(module)`` when ``use_sn`` is truthy; otherwise the
    module is handed back untouched.
    """
    return spectral_norm(module) if use_sn else module

def conv3x3(in_planes, out_planes, stride=1, dilation=1, use_sn=False):
    """Build a bias-free 3x3 ``nn.Conv2d``, optionally spectrally normalized.

    The padding is set equal to ``dilation``. The layer is passed through
    ``apply_sn`` so that spectral normalization is added when ``use_sn`` is
    true.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=dilation,
                       dilation=dilation, bias=False)
    return apply_sn(nn.Conv2d(in_planes, out_planes, **conv_kwargs), use_sn)


def conv1x1(in_planes, out_planes, stride=1, use_sn=False):
    """Build a bias-free 1x1 ``nn.Conv2d``, optionally spectrally normalized.

    The layer is passed through ``apply_sn`` so that spectral normalization is
    added when ``use_sn`` is true.
    """
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return apply_sn(pointwise, use_sn)


class UpsampleConv(nn.Module):
    """ Nearest-neighbour upsampling followed by a reflection-padded convolution.

    Used here as the alternative to ConvTranspose2d; see
            https://distill.pub/2016/deconv-checkerboard/
            https://github.com/pytorch/examples/blob/master/fast_neural_style/neural_style/transformer_net.py
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, scale_factor, use_sn=False):
        super().__init__()
        self.scale_factor = scale_factor
        # Reflection padding of half the kernel size on every side.
        self.pad = nn.ReflectionPad2d(kernel_size // 2)
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
        # Spectral normalization is added only when `use_sn` is true.
        self.conv = apply_sn(conv, use_sn)

    def forward(self, x):
        """Upsample ``x`` by ``scale_factor``, pad it, then convolve."""
        upsampled = F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')
        return self.conv(self.pad(upsampled))
    

def deconv3x3(in_planes, out_planes, stride=1, dilation=1, use_sn=False):
    """Return an ``UpsampleConv`` with a 3x3 kernel and ``scale_factor=2``.

    ``stride`` and ``dilation`` are accepted but not used: the underlying
    convolution is always built with stride 1.
    """
    upsample_block = UpsampleConv(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=1,
        scale_factor=2,
        use_sn=use_sn,
    )
    return upsample_block


def deconv1x1(in_planes, out_planes, stride=1, dilation=1, use_sn=False):
    """Return an ``UpsampleConv`` with a 1x1 kernel and ``scale_factor=2``.

    ``stride`` and ``dilation`` are accepted but not used: the underlying
    convolution is always built with stride 1.
    """
    upsample_block = UpsampleConv(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=1,
        scale_factor=2,
        use_sn=use_sn,
    )
    return upsample_block


class ResidualBlock(nn.Module):
    """ Pre-activation Residual Block
            BN, nonlinearity, conv3x3, BN, nonlinearity, conv3x3

    The output is ``residual(x) + shortcut(x)``. The shortcut is the identity
    when the channel count is unchanged and no resampling is requested;
    otherwise it is ``norm_layer`` followed by a 1x1 convolution that applies
    the same resampling as the residual branch.

    References:
        https://arxiv.org/abs/1512.03385
        http://torch.ch/blog/2016/02/04/resnets.html

        ResBlock
            https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
        WGAN-GP resnet architecture
            https://github.com/igul222/improved_wgan_training/blob/fa66c574a54c4916d27c55441d33753dcc78f6bc/gan_cifar_resnet.py#L159
            Generator: BN, ReLU, conv3x3, Tanh -> out
            PreactivationResblock: https://github.com/igul222/improved_wgan_training/blob/master/gan_64x64.py
        SNGAN/Projection cGAN architecture
            https://github.com/pfnet-research/sngan_projection/blob/master/dis_models/resblocks.py
    """

    def __init__(self, in_channels, out_channels, 
                 resample=None,
                 norm_layer=nn.Identity,
                 nonlinearity=nn.ReLU(inplace=True),
                 resblk_1st=False,
                 use_sn=False):
        """
            resample
                one of {None, 'up', 'dn'}
            norm_layer
                one of {nn.Identity, nn.BatchNorm2d}; called as
                norm_layer(num_channels)
            nonlinearity
                an activation module, either
                    nn.ReLU(inplace=True)
                    nn.LeakyReLU(negative_slope=0.2)
            resblk_1st
                if True, no nonlinearity before first `conv_1`
            use_sn
                Apply spectral normalization for each linear/conv layers
        """
        # Local import: only needed for the in-place guard further down.
        import copy

        super(ResidualBlock, self).__init__()

        # Both convolutions are always built, in this order, so that the
        # sequence of random initializations does not depend on whether the
        # shortcut convolution ends up being used.
        if resample == 'dn':
            residual_conv_resample = conv3x3(in_channels, out_channels, 2, use_sn=use_sn)
            shortcut_conv_resample = conv1x1(in_channels, out_channels, 2, use_sn=use_sn)
        elif resample == 'up':
            residual_conv_resample = deconv3x3(in_channels, out_channels, use_sn=use_sn)
            shortcut_conv_resample = deconv1x1(in_channels, out_channels, use_sn=use_sn)
        else:
            residual_conv_resample = conv3x3(in_channels, out_channels, 1, use_sn=use_sn)
            shortcut_conv_resample = conv1x1(in_channels, out_channels, 1, use_sn=use_sn)

        norm_1 = norm_layer(in_channels)
        if resblk_1st:
            nonlinearity_1 = nn.Identity()
        elif isinstance(norm_1, nn.Identity) and getattr(nonlinearity, 'inplace', False):
            # nn.Identity hands the block input straight to the activation.
            # An in-place activation would then overwrite that input, so the
            # shortcut branch (and the caller) would see the activated tensor
            # instead of the original one.  Use an out-of-place copy here.
            nonlinearity_1 = copy.deepcopy(nonlinearity)
            nonlinearity_1.inplace = False
        else:
            nonlinearity_1 = nonlinearity

        self.residual = nn.Sequential()
        self.residual.add_module('Normalization_1', norm_1)
        self.residual.add_module('Nonlinearity_1', nonlinearity_1)
        self.residual.add_module('Conv_1', residual_conv_resample)
        self.residual.add_module('Normalization_2', norm_layer(out_channels))
        self.residual.add_module('Nonlinearity_2', nonlinearity)
        self.residual.add_module('Conv_2', conv3x3(out_channels, out_channels, use_sn=use_sn))

        if in_channels == out_channels and resample is None:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential()
            self.shortcut.add_module('Normalization_1', norm_layer(in_channels))
            self.shortcut.add_module('Conv_1', shortcut_conv_resample)

    def forward(self, x):
        """Return the sum of the residual branch and the shortcut branch."""
        return self.residual(x) + self.shortcut(x)
    
    
class Generator(nn.Module):
    """Residual generator.

    Pipeline: linear projection of the latent vector, reshape to a
    ``bottom_width x bottom_width`` feature map, a stack of ``ResidualBlock``s
    (with batch normalization), then BN -> ReLU -> 3x3 conv -> Tanh.
    """

    def __init__(self, conv_channels = None, conv_upsample = None, dim_z = 128, im_channnels = 3):
        """
            conv_channels
                channel count before/after every residual block, e.g.
                [1024, 1024, 512, 256, 128, 64]
                     c1    c2   c3   c4   c5
            conv_upsample
                one flag per residual block; True means the block upsamples
                4x4 -> 128x128    [True, True, True, True, True]
                4x4 -> 64x64      [True, True, True, True]
                4x4 -> 32x32      [True, True, True]
            dim_z
                size of the latent input vector
            im_channnels
                3 for color image
                1 for grayscale image
        """
        super().__init__()

        n_blocks = len(conv_channels) - 1
        assert n_blocks > 0
        assert n_blocks == len(conv_upsample)

        self.bottom_width = 4
        # One activation module, reused by every block and by the output head.
        relu = nn.ReLU(inplace=True)

        self.Linear = nn.Linear(dim_z, (self.bottom_width**2) * conv_channels[0])

        self.ResidualBlocks = nn.Sequential()
        for idx, upsample in enumerate(conv_upsample):
            tag = "Up" if upsample else ""
            block = ResidualBlock(conv_channels[idx], conv_channels[idx + 1],
                                  resample="up" if upsample else None,
                                  norm_layer=nn.BatchNorm2d,
                                  nonlinearity=relu)
            self.ResidualBlocks.add_module(f'ResBlock{tag}_{idx}', block)

        self.NormalizationFinal = nn.BatchNorm2d(conv_channels[-1])
        self.NonlinearityFinal = relu
        self.ConvFinal = conv3x3(conv_channels[-1], im_channnels)
        self.Tanh = nn.Tanh()

    def forward(self, x):
        """Map a batch of latent vectors to a batch of images.

        The shape notes below are for the 128x128 example configuration given
        in the constructor docstring.
        """
        # latent: dim_z
        feat = self.Linear(x)
        # -> conv_channels[0] x 4 x 4 (e.g. 1024x4x4)
        feat = feat.view(feat.shape[0], -1, self.bottom_width, self.bottom_width)
        # -> conv_channels[-1] x H x W (e.g. 64x128x128)
        feat = self.ResidualBlocks(feat)
        feat = self.NonlinearityFinal(self.NormalizationFinal(feat))
        # -> im_channnels x H x W (e.g. 3x128x128)
        return self.Tanh(self.ConvFinal(feat))

        
class Discriminator(nn.Module):
    """Residual discriminator.

    Pipeline: a stack of ``ResidualBlock``s without normalization layers,
    LeakyReLU, sum over the spatial dimensions, then a linear layer that
    produces one raw score per image (no sigmoid is applied).
    """

    def __init__(self, conv_channels = None, conv_dnsample = None, use_sn=True):
        """
            conv_channels
                channel count before/after every residual block, e.g.
                [3, 64, 128, 256, 512, 1024, 1024]
                  c1  c2   c3   c4   c5    c6
            conv_dnsample
                one flag per residual block; True means the block downsamples
                128x128 -> 4x4    [True, True, True, True, True, False]
                64x64 -> 4x4      [True, True, True, True]
                32x32 -> 4x4      [True, True, True]
            use_sn
                apply spectral normalization to the conv and linear layers

        Projection cGAN 
            conv_channels = [3, 64, 128, 256, 512, 1024, 1024]
            conv_dnsample = [True, True, True, True, True, False]
        """
        super().__init__()

        n_blocks = len(conv_channels) - 1
        assert n_blocks > 0
        assert n_blocks == len(conv_dnsample)

        # One activation module, reused by every block and by the output head.
        leaky_relu = nn.LeakyReLU(0.2, inplace=True)

        self.ResidualBlocks = nn.Sequential()
        for idx, downsample in enumerate(conv_dnsample):
            tag = "Dn" if downsample else ""
            block = ResidualBlock(conv_channels[idx], conv_channels[idx + 1],
                                  resample="dn" if downsample else None,
                                  norm_layer=nn.Identity,
                                  nonlinearity=leaky_relu,
                                  # the very first block skips its leading activation
                                  resblk_1st=(idx == 0),
                                  use_sn=use_sn)
            self.ResidualBlocks.add_module(f'ResBlock{tag}_{idx}', block)

        self.NonlinearityFinal = leaky_relu
        self.LinearFinal = apply_sn(nn.Linear(conv_channels[-1], 1), use_sn)

    def forward(self, x):
        """Map a batch of images to one raw score each.

        The shape notes below are for the 128x128 example configuration given
        in the constructor docstring.
        """
        # input: conv_channels[0] x H x W (e.g. 3x128x128)
        feat = self.ResidualBlocks(x)
        # -> conv_channels[-1] x h x w (e.g. 1024x4x4)
        feat = self.NonlinearityFinal(feat)
        # global sum pooling over the two spatial axes -> conv_channels[-1]
        pooled = torch.sum(feat, dim=(2, 3))
        # -> 1
        return self.LinearFinal(pooled)

In [17]:
##############################
## resblock
##############################

# Shape smoke tests.  Inputs are drawn with torch.randn rather than
# torch.empty: empty() returns uninitialized memory, which may contain
# NaN/Inf/denormal values and would push garbage through the layers.

blk = ResidualBlock(16, 32, resblk_1st=True)

# print(blk)

x = torch.randn((32, 16, 28, 28))
out = blk(x)

print(x.shape, out.shape)

##############################
## G
##############################

dim_z = 50
conv_channels = [256, 256, 128, 64]
conv_upsample = [True, True, True]
G = Generator(conv_channels, conv_upsample, dim_z, im_channnels = 1)

# print(G)

x = torch.randn((32, dim_z))
out = G(x)

print(x.shape, out.shape)

##############################
## D
##############################

conv_channels = [3, 64, 128, 256]
conv_dnsample = [True, True, True]
D = Discriminator(conv_channels, conv_dnsample)

# print(D)

x = torch.randn((50, 3, 32, 32))
out = D(x)

print(x.shape, out.shape)

torch.Size([32, 16, 28, 28]) torch.Size([32, 32, 28, 28])
torch.Size([32, 50]) torch.Size([32, 1, 32, 32])
torch.Size([50, 3, 32, 32]) torch.Size([50, 1])


In [None]:
import os

In [18]:
# Run settings read by the cells below.
# Prefix of the sample-image file names written during training.
model_name = 'ResGAN'
# Seed passed to torch.manual_seed / torch.cuda.manual_seed.
seed = 0
# Value written to the CUDA_VISIBLE_DEVICES environment variable.
gpu_id = '2'
# MNIST images are resized to this size before being fed to the model.
image_size = 32
# Size of the latent vector fed to the generator.
dim_z = 50
# Intended mini-batch size for training.
batch_size = 32
# Adam learning rate, shared by both optimizers.
lr = 0.0002
# First Adam beta coefficient, shared by both optimizers.
beta1 = 0.5
# Number of passes over the training loader.
n_epochs = 20
# Directory that receives the generated sample images.
figure_root = 'figures/'
# Print losses and save samples once every `log_interval` iterations.
log_interval = 20
# Whether the discriminator uses spectral normalization.
use_sn = True

In [None]:
os.makedirs(figure_root, exist_ok=True)

# CUDA_VISIBLE_DEVICES has to be set before the first CUDA query: the CUDA
# runtime reads it once at initialization, so assigning it after
# torch.cuda.is_available() has run may leave the GPU selection ignored.
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id

torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST training split, resized to `image_size` and scaled from [0, 1] to
# [-1, 1] by Normalize((0.5,), (0.5,)).
train_loader = torch.utils.data.DataLoader(
    MNIST(
        root='./data', download=True, train=True, transform=tv_transforms.Compose([
            tv_transforms.Resize(image_size),
            tv_transforms.ToTensor(),
            tv_transforms.Normalize((0.5,), (0.5,)),
        ])),
    # Use the configured batch size instead of a second hard-coded copy.
    batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

In [None]:
# Generator: three upsampling residual blocks, single-channel output.
conv_upsample = [True] * 3
conv_channels = [256, 256, 128, 64]
G = Generator(conv_channels=conv_channels,
              conv_upsample=conv_upsample,
              dim_z=dim_z,
              im_channnels=1)
G = G.to(device)

# Discriminator: three downsampling residual blocks, single-channel input.
conv_dnsample = [True] * 3
conv_channels = [1, 64, 128, 256]
D = Discriminator(conv_channels=conv_channels,
                  conv_dnsample=conv_dnsample,
                  use_sn=use_sn)
D = D.to(device)

In [None]:
# Binary cross-entropy applied to the discriminator's raw scores.
criterion = nn.BCEWithLogitsLoss()

# Fixed batch of 64 latent vectors, reused every time samples are saved
# during training.
fixed_z = torch.randn(64, dim_z, device=device)

# label flipping helps with training G!
# (targets are swapped relative to the usual convention: real -> 0, fake -> 1)
real_label = 0
fake_label = 1

# One Adam optimizer per network, both with the same lr and betas.
optimizerD = torch.optim.Adam(D.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = torch.optim.Adam(G.parameters(), lr=lr, betas=(beta1, 0.999))

In [None]:

for epoch in range(n_epochs):

    for it, (x_real, _) in enumerate(train_loader):

        # The last batch might be smaller, so the targets are sized per batch.
        # A separate name is used so the `batch_size` setting is not overwritten.
        n_samples = x_real.size(0)
        # BCEWithLogitsLoss needs floating-point targets.  torch.full infers an
        # integer dtype from an integer fill value, so request float explicitly.
        real_labels = torch.full((n_samples, 1), real_label, dtype=torch.float, device=device)
        fake_labels = torch.full((n_samples, 1), fake_label, dtype=torch.float, device=device)

        ##############################################################
        # Update Discriminator: Maximize E[log(D(x))] + E[log(1 - D(G(z)))]
        ##############################################################

        D.zero_grad()

        # a minibatch of samples from data distribution
        x_real = x_real.to(device)

        y = D(x_real)
        loss_D_real = criterion(y, real_labels)
        loss_D_real.backward()

        # mean raw discriminator score on the real batch
        D_x = y.mean().item()

        # a minibatch of samples from the model distribution
        z = torch.randn(n_samples, dim_z, device=device)

        x_fake = G(z)
        # https://github.com/pytorch/examples/issues/116
        # If we do not detach, then, although x_fake is not needed for gradient update of D,
        #   as a consequence of backward pass which clears all the variables in the graph
        #   graph for G will not be available for gradient update of G
        # Also for performance considerations, detaching x_fake will prevent computing 
        #   gradients for parameters in G
        y = D(x_fake.detach())
        loss_D_fake = criterion(y, fake_labels)
        loss_D_fake.backward()

        loss_D = loss_D_real + loss_D_fake

        optimizerD.step()

        ##############################################################
        # Update Generator: Minimize E[log(1 - D(G(z)))] => Maximize E[log(D(G(z))))]
        ##############################################################

        G.zero_grad()

        # Score the fakes again with the updated D, this time keeping the
        # graph through G so its parameters receive gradients.
        y = D(x_fake)
        loss_G = criterion(y, real_labels)
        loss_G.backward()

        optimizerG.step()

        ##############################################################
        # write/print
        ##############################################################
        
        loss_D = loss_D.item()
        loss_G = loss_G.item()
        loss_total = loss_D + loss_G
        
        global_step = epoch*len(train_loader)+it
        
        if it % log_interval == log_interval-1:
            print(f'[{epoch+1}/{n_epochs}][{it+1}/{len(train_loader)}]'
                f'loss: {loss_total:.4}\t'
                f'loss_D: {loss_D:.4}\t'
                f'loss_G: {loss_G:.4}')
            # Sampling is for inspection only: skip building an autograd graph.
            with torch.no_grad():
                x_fake = G(fixed_z)
            # G ends in Tanh, so samples lie in [-1, 1].  save_image clamps to
            # [0, 1], so map the samples into that range first; otherwise every
            # negative pixel is written as black.
            tv_utils.save_image(x_fake.mul(0.5).add(0.5),
                os.path.join(figure_root,
                    f'{model_name}_fake_samples_epoch={epoch}_it={it}.png'))

In [None]:
# Bare expression: in a notebook cell this displays the selected torch.device.
device