
1. GAN with resblock backbone 
2. cGAN with projection


## todo

+ Does the projection cGAN work with a continuous conditioning variable?


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm

from torchvision.datasets import MNIST
import torchvision.transforms as tv_transforms
import torchvision.utils as tv_utils

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm


class ConditionalBatchNorm2d(nn.Module):
    """Batch normalisation with a per-class learned scale and shift.

    The input is normalised by a parameter-free ``nn.BatchNorm2d``; the
    affine part is looked up from a single embedding table indexed by the
    class label. Each embedding row stores ``gamma`` in its first
    ``num_features`` entries and ``beta`` in the remaining ones.

    Adapted from
    https://github.com/pytorch/pytorch/issues/8985#issuecomment-405080775
    """

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        self.embed = nn.Embedding(num_classes, 2 * num_features)
        with torch.no_grad():
            # gamma columns start near 1 (N(1, 0.02)), beta columns start at 0,
            # so the layer initially behaves like plain batch norm.
            self.embed.weight[:, :num_features].normal_(1, 0.02)
            self.embed.weight[:, num_features:].zero_()

    def forward(self, x, c):
        """Normalise ``x`` and modulate it with the affine terms of labels ``c``.

        ``c`` is flattened before the embedding lookup, so any shape holding
        one integer label per sample is accepted.
        """
        normalised = self.bn(x)
        affine = self.embed(c.view(-1))
        gamma, beta = torch.chunk(affine, 2, dim=1)
        # Reshape to (batch, channels, 1, 1) so the terms broadcast spatially.
        broadcast_shape = (-1, self.num_features, 1, 1)
        return gamma.view(broadcast_shape) * normalised + beta.view(broadcast_shape)


def apply_sn(module, use_sn):
    """Return ``module`` wrapped with spectral normalisation if ``use_sn`` is
    truthy; otherwise return ``module`` itself, untouched.
    """
    return spectral_norm(module) if use_sn else module

def conv3x3(in_planes, out_planes, stride=1, dilation=1, use_sn=False):
    """Build a bias-free 3x3 convolution whose padding equals its dilation,
    optionally wrapped with spectral normalisation.
    """
    conv = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=False,
    )
    return apply_sn(conv, use_sn)


def conv1x1(in_planes, out_planes, stride=1, use_sn=False):
    """Build a bias-free 1x1 convolution, optionally wrapped with spectral
    normalisation.
    """
    conv = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return apply_sn(conv, use_sn)


class UpsampleConv(nn.Module):
    """Nearest-neighbour upsampling followed by a reflection-padded convolution.

    Used in place of ``ConvTranspose2d``; see
        https://distill.pub/2016/deconv-checkerboard/
        https://github.com/pytorch/examples/blob/master/fast_neural_style/neural_style/transformer_net.py
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, scale_factor, use_sn=False):
        super().__init__()
        self.scale_factor = scale_factor
        # Reflection padding of kernel_size // 2 on every side.
        self.pad = nn.ReflectionPad2d(kernel_size // 2)
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
        self.conv = apply_sn(conv, use_sn)

    def forward(self, x):
        """Upsample ``x`` by ``scale_factor``, pad it, then convolve."""
        upsampled = F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')
        return self.conv(self.pad(upsampled))
    

def deconv3x3(in_planes, out_planes, stride=1, dilation=1, use_sn=False):
    """Build a 2x ``UpsampleConv`` with a 3x3 kernel.

    ``stride`` and ``dilation`` are accepted but not used by the body.
    """
    return UpsampleConv(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=1,
        scale_factor=2,
        use_sn=use_sn,
    )


def deconv1x1(in_planes, out_planes, stride=1, dilation=1, use_sn=False):
    """Build a 2x ``UpsampleConv`` with a 1x1 kernel.

    ``stride`` and ``dilation` are accepted but not used by the body.
    """
    return UpsampleConv(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=1,
        scale_factor=2,
        use_sn=use_sn,
    )


class ResidualBlock(nn.Module):
    """ Pre-activation Residual Block
            BN, nonlinearity, conv3x3, BN, nonlinearity, conv3x3

    The output is ``residual(x) + shortcut(x)``. The shortcut is the identity
    when the channel count is unchanged and no resampling is requested;
    otherwise it is a normalisation followed by a (re)sampling 1x1 convolution.

    References:
        https://arxiv.org/abs/1512.03385
        http://torch.ch/blog/2016/02/04/resnets.html

        ResBlock
            https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
        WGAN-GP resnet architecture
            https://github.com/igul222/improved_wgan_training/blob/fa66c574a54c4916d27c55441d33753dcc78f6bc/gan_cifar_resnet.py#L159
            Generator: BN, ReLU, conv3x3, Tanh -> out
            PreactivationResblock: https://github.com/igul222/improved_wgan_training/blob/master/gan_64x64.py
        SNGAN/Projection cGAN architecture
            https://github.com/pfnet-research/sngan_projection/blob/master/dis_models/resblocks.py
    """

    def __init__(self, in_channels, out_channels, 
                 resample=None,
                 norm_layer=nn.Identity,
                 nonlinearity=nn.ReLU(),
                 resblk_1st=False,
                 use_sn=False):
        """
            in_channels, out_channels
                channel counts of the block input and output
            resample
                one of {None, 'up', 'dn'}:
                    None  keep the spatial resolution
                    'up'  2x upsampling (deconv3x3 / deconv1x1)
                    'dn'  stride-2 downsampling (conv3x3 / conv1x1)
            norm_layer
                callable invoked as `norm_layer(num_channels)` that returns a
                module, e.g. nn.Identity or nn.BatchNorm2d
            nonlinearity
                activation module instance, e.g.
                    nn.ReLU(inplace=True)
                    nn.LeakyReLU(0.2)
                the same instance is reused for both activations
            resblk_1st
                if True, no nonlinearity before first `conv_1`
            use_sn
                Apply spectral normalization for each linear/conv layers

        Raises:
            ValueError: if `resample` is not None, 'up' or 'dn'.
        """
        super(ResidualBlock, self).__init__()

        # An unknown value used to fall through silently to "no resampling".
        if resample not in (None, 'up', 'dn'):
            raise ValueError(
                f"resample must be None, 'up' or 'dn', got {resample!r}")

        self.identity_shortcut = (in_channels == out_channels and resample is None)

        # Both convolutions are always constructed (in this order), even when
        # the shortcut one ends up unused, so module construction order is the
        # same as before.
        if   resample == 'dn':
            residual_conv_resample = conv3x3(in_channels, out_channels, 2, use_sn=use_sn)
            shortcut_conv_resample = conv1x1(in_channels, out_channels, 2, use_sn=use_sn)
        elif resample == 'up':
            residual_conv_resample = deconv3x3(in_channels, out_channels, use_sn=use_sn)
            shortcut_conv_resample = deconv1x1(in_channels, out_channels, use_sn=use_sn)
        else:
            residual_conv_resample = conv3x3(in_channels, out_channels, 1, use_sn=use_sn)
            shortcut_conv_resample = conv1x1(in_channels, out_channels, 1, use_sn=use_sn)

        self.residual_normalization_1 = norm_layer(in_channels)
        self.residual_nonlinearity_1 = nn.Identity() if resblk_1st else nonlinearity
        self.residual_conv_1 = residual_conv_resample
        self.residual_normalization_2 = norm_layer(out_channels)
        self.residual_nonlinearity_2 = nonlinearity
        self.residual_conv_2 = conv3x3(out_channels, out_channels, use_sn=use_sn)

        # The shortcut modules are only registered when they are actually used.
        if not self.identity_shortcut:
            self.shortcut_normalization_1 = norm_layer(in_channels)
            self.shortcut_conv_1 = shortcut_conv_resample


    def forward(self, x):
        """Return `residual(x) + shortcut(x)`."""

        identity = x

        # The shortcut is evaluated before the residual path. This ordering
        # matters if `nonlinearity` is in-place and the first normalisation
        # returns its input unchanged (e.g. nn.Identity).
        if self.identity_shortcut:
            s = identity
        else:
            s = self.shortcut_normalization_1(identity)
            s = self.shortcut_conv_1(s)

        x = self.residual_normalization_1(x)
        x = self.residual_nonlinearity_1(x)
        x = self.residual_conv_1(x)
        x = self.residual_normalization_2(x)
        x = self.residual_nonlinearity_2(x)
        x = self.residual_conv_2(x)

        return x + s
    
    
    
class ConditionalResidualBlock(ResidualBlock):
    """ Residual block whose normalisation layers also receive a condition `c`.

    Construction is delegated unchanged to `ResidualBlock`; only `forward`
    differs, passing `c` to every normalisation layer.
    """

    def __init__(self, in_channels, out_channels,
                 resample=None,
                 norm_layer=nn.BatchNorm2d,
                 nonlinearity=nn.ReLU(inplace=True),
                 resblk_1st=False,
                 use_sn=False):
        """
            norm_layer
                callable invoked as `norm_layer(num_features)`; the module it
                returns is called as `module(x, c)` in `forward`.
                NOTE: the default nn.BatchNorm2d takes a single input, so a
                conditional norm (e.g. a factory for ConditionalBatchNorm2d)
                is expected to be passed in.

            All other arguments are forwarded to `ResidualBlock`.
        """
        super().__init__(
            in_channels,
            out_channels,
            resample=resample,
            norm_layer=norm_layer,
            nonlinearity=nonlinearity,
            resblk_1st=resblk_1st,
            use_sn=use_sn,
        )

    def forward(self, x, c):
        """Return `residual(x, c) + shortcut(x, c)`."""

        # Shortcut path first (same evaluation order as the unconditional block).
        shortcut = x
        if not self.identity_shortcut:
            shortcut = self.shortcut_normalization_1(shortcut, c)
            shortcut = self.shortcut_conv_1(shortcut)

        # Residual path: (norm, act, conv) twice.
        residual = self.residual_normalization_1(x, c)
        residual = self.residual_nonlinearity_1(residual)
        residual = self.residual_conv_1(residual)
        residual = self.residual_normalization_2(residual, c)
        residual = self.residual_nonlinearity_2(residual)
        residual = self.residual_conv_2(residual)

        return residual + shortcut
        
        
class Generator(nn.Module):
    """ Residual generator: linear projection to a 4x4 map, a stack of
        residual blocks, then norm -> ReLU -> conv3x3 -> tanh.
    """

    def __init__(self, conv_channels, conv_upsample,
                 resblk_cls = ResidualBlock,
                 norm_layer = nn.BatchNorm2d,
                 dim_z = 128,
                 im_channels = 3):
        """
            conv_channels
                channel count before/after each residual block, e.g.
                [1024, 1024, 512, 256, 128, 64]
                     c1    c2   c3   c4   c5
            conv_upsample
                one flag per residual block; True -> the block upsamples
                4x4 -> 128x128    [True, True, True, True, True]
                4x4 -> 64x64      [True, True, True, True]
                4x4 -> 32x32      [True, True, True]
            resblk_cls
                class used to build each residual block
            norm_layer
                callable invoked as `norm_layer(num_features)`
            dim_z
                size of the input to the first linear layer
            im_channels
                3 for color image
                1 for grayscale image
        """
        super().__init__()

        n_convs = len(conv_channels) - 1
        assert n_convs > 0
        assert n_convs == len(conv_upsample)

        self.n_convs = n_convs
        self.bottom_width = 4
        self.nonlinearity = nn.ReLU()

        # Projects the input to a (conv_channels[0], 4, 4) feature map (flattened).
        bottom_features = (self.bottom_width ** 2) * conv_channels[0]
        self.linear = nn.Linear(dim_z, bottom_features)

        # Blocks are registered as attributes `residual_block_{i}`.
        stages = zip(conv_channels[:-1], conv_channels[1:], conv_upsample)
        for idx, (c_in, c_out, upsample) in enumerate(stages):
            block = resblk_cls(c_in, c_out,
                               resample = "up" if upsample else None,
                               norm_layer = norm_layer,
                               nonlinearity = self.nonlinearity)
            self.add_module(f'residual_block_{idx}', block)

        out_features = conv_channels[-1]
        self.normalization_final = norm_layer(out_features)
        self.conv_final = conv3x3(out_features, im_channels)
        self.nonlinearity_final = nn.Tanh()

    def forward(self, x):
        """Map `x` to an image in the tanh range.

        Shape walk-through for
            bottom_width = 4
            conv_channels = [1024, 1024, 512, 256, 128, 64]
            conv_upsample = [True, True, True, True, True]
            im_channels = 3
        """
        # 128
        h = self.linear(x)
        h = h.view(h.shape[0], -1, self.bottom_width, self.bottom_width)
        # 1024x4x4
        for idx in range(self.n_convs):
            block = getattr(self, f'residual_block_{idx}')
            h = block(h)
        # 64x128x128
        h = self.normalization_final(h)
        h = self.nonlinearity(h)
        h = self.conv_final(h)
        # 3x128x128
        return self.nonlinearity_final(h)
    
    
class ConditionalGenerator(Generator):
    """ Generator built from `ConditionalResidualBlock`s with
        `ConditionalBatchNorm2d` as the normalisation layer.
    """

    def __init__(self, conv_channels, conv_upsample, num_classes,
                 dim_z = 128,
                 im_channels = 3):
        """
            num_classes
                number of classes of each ConditionalBatchNorm2d

            The remaining arguments are forwarded to `Generator`.
        """
        def conditional_norm(num_features):
            return ConditionalBatchNorm2d(num_features, num_classes)

        super().__init__(
            conv_channels,
            conv_upsample,
            resblk_cls=ConditionalResidualBlock,
            norm_layer=conditional_norm,
            dim_z=dim_z,
            im_channels=im_channels,
        )

    def forward(self, x, c):
        """ x    input to the first linear layer (presumably batch_size x dim_z)
            c    one class label per sample; flattened to batch_size
        """
        labels = c.view(-1)
        #
        # 128
        h = self.linear(x)
        h = h.view(h.shape[0], -1, self.bottom_width, self.bottom_width)
        # 1024x4x4
        for idx in range(self.n_convs):
            block = getattr(self, f'residual_block_{idx}')
            h = block(h, labels)
        # 64x128x128
        h = self.normalization_final(h, labels)
        h = self.nonlinearity(h)
        h = self.conv_final(h)
        # 3x128x128
        return self.nonlinearity_final(h)

    

class Discriminator(nn.Module):
    """ Residual discriminator: a stack of normalisation-free residual blocks,
        LeakyReLU, global sum pooling, and a linear layer producing one logit.
    """

    def __init__(self, conv_channels, conv_dnsample, use_sn=True):
        """
            conv_channels
                channel count before/after each residual block, e.g.
                [3, 64, 128, 256, 512, 1024, 1024]
                  c1  c2   c3   c4   c5    c6
            conv_dnsample
                one flag per residual block; True -> the block downsamples
                128x128 -> 4x4    [True, True, True, True, True, False]
                64x64 -> 4x4      [True, True, True, True]
                32x32 -> 4x4      [True, True, True]
            use_sn
                apply spectral normalization to the conv and linear layers

        Projection cGAN 
            conv_channels = [3, 64, 128, 256, 512, 1024, 1024]
            conv_dnsample = [True, True, True, True, True, False]
        """
        super().__init__()

        n_convs = len(conv_channels) - 1
        assert n_convs > 0
        assert n_convs == len(conv_dnsample)

        self.nonlinearity = nn.LeakyReLU(0.2)

        blocks = nn.Sequential()
        for idx in range(n_convs):
            block = ResidualBlock(
                conv_channels[idx], conv_channels[idx + 1],
                resample="dn" if conv_dnsample[idx] else None,
                norm_layer=nn.Identity,
                nonlinearity=self.nonlinearity,
                # the first block applies no activation before its first conv
                resblk_1st=(idx == 0),
                use_sn=use_sn)
            blocks.add_module(f'residual_block_{idx}', block)
        self.residual_blocks = blocks

        self.linear = apply_sn(nn.Linear(conv_channels[-1], 1), use_sn)

    def forward(self, x):
        """Return one logit per sample.

        Shape walk-through for
            conv_channels = [3, 64, 128, 256, 512, 1024, 1024]
            conv_dnsample = [True, True, True, True, True, False]
        """
        # 3x128x128
        features = self.residual_blocks(x)
        # 1024x4x4
        features = self.nonlinearity(features)
        pooled = features.sum(dim=(2, 3))   # (global sum pooling)
        # 1024
        logits = self.linear(pooled)
        # 1
        return logits
    
    
class ConditionalDiscriminator(Discriminator):
    """ Discriminator with a projection term: the logit is the linear output
        plus the inner product of a class embedding with the pooled features.
    """

    def __init__(self, conv_channels, conv_dnsample, num_classes, use_sn=True):
        super().__init__(conv_channels, conv_dnsample, use_sn=use_sn)

        # One embedding row per class, as wide as the pooled feature vector.
        class_embedding = nn.Embedding(num_classes, conv_channels[-1])
        self.c_embed = apply_sn(class_embedding, use_sn)

    def forward(self, x, c):
        """ x    batch_size x im_channels x h x w
            c    one class label per sample; flattened to batch_size

        Shape walk-through for
            conv_channels = [3, 64, 128, 256, 512, 1024, 1024]
            conv_dnsample = [True, True, True, True, True, False]
        """
        labels = c.view(-1)
        # 3x128x128
        features = self.residual_blocks(x)
        # 1024x4x4
        features = self.nonlinearity(features)
        pooled = features.sum(dim=(2, 3))   # (global sum pooling)
        # 1024

        # sigmoid^-1(p(real/fake|x,c)) =
        #     log(p_data(x)/p_model(x)) + 
        #     log(p_data(c|x)/p_model(c|x))
        unconditional = self.linear(pooled)
        projection = (self.c_embed(labels) * pooled).sum(dim=1, keepdim=True)
        # 1
        return unconditional + projection
        

In [None]:
# Shape smoke tests for the modules defined above.
# Inputs are drawn with torch.randn rather than torch.empty: uninitialised
# memory may contain NaN/inf and makes the checks non-deterministic.

##############################
## resblock
##############################

blk = ResidualBlock(16, 32, resblk_1st=True)

# print(blk)

x = torch.randn(32, 16, 28, 28)
out = blk(x)

print(x.shape, out.shape)

##############################
## G
##############################

dim_z = 50
conv_channels = [256, 256, 128, 64]
conv_upsample = [True, True, True]
G = Generator(conv_channels, conv_upsample, dim_z = dim_z, im_channels = 1)

# print(G)

x = torch.randn(32, dim_z)
out = G(x)

print(x.shape, out.shape)

##############################
## D
##############################

conv_channels = [3, 64, 128, 256]
conv_dnsample = [True, True, True]
D = Discriminator(conv_channels, conv_dnsample)

# print(D)

x = torch.randn(50, 3, 32, 32)
out = D(x)

print(x.shape, out.shape)


##############################
## conditional D
##############################

num_classes = 10

conv_channels = [3, 64, 128, 256]
conv_dnsample = [True, True, True]
D = ConditionalDiscriminator(conv_channels, conv_dnsample, num_classes)


x = torch.randn(50, 3, 32, 32)
# labels given as a (batch, 1) column; the model flattens them
c = torch.randint(0, num_classes, (50, 1))
out = D(x, c)

print(x.shape, out.shape)



##############################
## conditional G
##############################

num_classes = 10

dim_z = 50
conv_channels = [256, 256, 128, 64]
conv_upsample = [True, True, True]
# Pass the upsample flags defined for this test (previously the
# discriminator's conv_dnsample was passed by mistake).
G = ConditionalGenerator(conv_channels, conv_upsample, num_classes, dim_z=dim_z)


x = torch.randn(50, dim_z)
c = torch.randint(0, num_classes, (50,))
out = G(x, c)

print(x.shape, out.shape)


In [None]:
import os

In [None]:
# ---- run configuration -------------------------------------------------
model_name = 'resgan_jpt'
seed = 0
gpu_id = '2'                # value written to CUDA_VISIBLE_DEVICES
image_size = 32             # MNIST images are resized to this size
dim_z = 50                  # size of the generator input vector
batch_size = 32             # (was assigned twice; the duplicate is removed)
lr = 0.0002
beta1 = 0.5                 # first Adam beta
n_epochs = 20
figure_root = f'figures/{model_name}'
log_interval = 100          # log / save samples every this many iterations
use_sn = False              # spectral normalization in the discriminator
n_workers = 8               # DataLoader worker processes

# ---- conditioning ------------------------------------------------------
conditional_G = False
conditional_D = True
num_classes = 10
# If True, the last `num_classes` entries of z are replaced by the one-hot label.
include_c_in_z = True

In [None]:
# Restrict the visible GPUs BEFORE the first torch.cuda call: the variable is
# read when CUDA is initialised, so setting it after torch.cuda.is_available()
# (as was done previously) may have no effect.
# NOTE(review): this still has no effect if an earlier cell already touched CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id

os.makedirs(figure_root, exist_ok=True)

torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST training set, resized to image_size and scaled to [-1, 1]
# to match the generator's tanh output range.
train_loader = torch.utils.data.DataLoader(
    MNIST(
        root='./data', download=True, train=True, transform=tv_transforms.Compose([
            tv_transforms.Resize(image_size),
            tv_transforms.ToTensor(),
            tv_transforms.Normalize((0.5,), (0.5,)),
        ])),
    batch_size=batch_size, shuffle=True, num_workers=n_workers, pin_memory=True)

In [None]:
# ---- generator ---------------------------------------------------------
# G must be built BEFORE `conv_channels` is rebound to the discriminator's
# widths below; previously the generator list was overwritten first and G was
# constructed with the discriminator's channel list.
conv_channels = [256, 256, 128, 64]
conv_upsample = [True, True, True]

if conditional_G:
    G = ConditionalGenerator(conv_channels, conv_upsample, num_classes=num_classes, dim_z=dim_z, im_channels=1).to(device)
else:
    G = Generator(conv_channels, conv_upsample, dim_z=dim_z, im_channels = 1).to(device)

# ---- discriminator -----------------------------------------------------
conv_channels = [1, 64, 128, 256]
conv_dnsample = [True, True, True]

if conditional_D:
    D = ConditionalDiscriminator(conv_channels, conv_dnsample, num_classes, use_sn=use_sn).to(device)
else:
    D = Discriminator(conv_channels, conv_dnsample, use_sn=use_sn).to(device)

In [None]:
# D outputs raw logits, so the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss()

# Fixed inputs for the periodic sample grid: the labels cycle through every
# class, `n_samples_per_class` times. Sized from `num_classes` (previously a
# hard-coded 10) so it stays consistent with the rest of the script.
n_samples_per_class = 10
fixed_z = torch.randn(n_samples_per_class * num_classes, dim_z, device=device)
fixed_c = torch.arange(num_classes).repeat(n_samples_per_class).to(device)

if include_c_in_z:
    # Overwrite the last `num_classes` entries of z with the one-hot label.
    # one_hot returns an integer tensor; cast it to z's dtype explicitly.
    fixed_z[:, -num_classes:] = F.one_hot(fixed_c, num_classes).to(fixed_z.dtype)

# label flipping helps with training G!
real_label = 0
fake_label = 1

optimizerD = torch.optim.Adam(D.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = torch.optim.Adam(G.parameters(), lr=lr, betas=(beta1, 0.999))

In [None]:
# Debugging aid: makes autograd report the forward op behind a failing
# backward. It adds overhead; consider disabling it for long runs.
torch.autograd.set_detect_anomaly(True)

for epoch in range(n_epochs):

    for it, (x_real, c_real) in enumerate(train_loader):

        # The last batch might be smaller. A loop-local name is used so the
        # configured `batch_size` is not overwritten.
        cur_batch_size = x_real.size(0)
        # Targets must be floating point for BCEWithLogitsLoss; torch.full
        # with an int fill value would otherwise produce an integer tensor.
        real_labels = torch.full((cur_batch_size, 1), real_label, dtype=torch.float, device=device)
        fake_labels = torch.full((cur_batch_size, 1), fake_label, dtype=torch.float, device=device)

        ##############################################################
        # Update Discriminator: Maximize E[log(D(x))] + E[log(1 - D(G(z)))]
        ##############################################################

        D.zero_grad()

        # a minibatch of samples from data distribution
        x_real, c_real = x_real.to(device), c_real.to(device)

        y = D(x_real, c_real) if conditional_D else D(x_real)
        loss_D_real = criterion(y, real_labels)
        loss_D_real.backward()

        # a minibatch of samples from the model distribution
        z = torch.randn(cur_batch_size, dim_z, device=device)
        c_fake = torch.empty(cur_batch_size, dtype=torch.long).random_(0, num_classes).to(device)

        if include_c_in_z:
            # Overwrite the last `num_classes` entries of z with the one-hot label.
            z[:, -num_classes:] = F.one_hot(c_fake, num_classes).to(z.dtype)

        x_fake = G(z, c_fake) if conditional_G else G(z)
        # https://github.com/pytorch/examples/issues/116
        # If we do not detach, then, although x_fake is not needed for gradient update of D,
        #   as a consequence of backward pass which clears all the variables in the graph
        #   graph for G will not be available for gradient update of G
        # Also for performance considerations, detaching x_fake will prevent computing 
        #   gradients for parameters in G
        y = D(x_fake.detach(), c_fake) if conditional_D else D(x_fake.detach())
        loss_D_fake = criterion(y, fake_labels)
        loss_D_fake.backward()

        optimizerD.step()

        ##############################################################
        # Update Generator: Minimize E[log(1 - D(G(z)))] => Maximize E[log(D(G(z))))]
        ##############################################################

        G.zero_grad()

        # D's gradients from this backward pass are discarded by the
        # D.zero_grad() at the start of the next iteration.
        y = D(x_fake, c_fake) if conditional_D else D(x_fake)
        loss_G = criterion(y, real_labels)
        loss_G.backward()

        optimizerG.step()

        ##############################################################
        # write/print
        ##############################################################

        loss_D = (loss_D_real + loss_D_fake).item()
        loss_G = loss_G.item()
        loss_total = loss_D + loss_G

        if it % log_interval == log_interval-1:
            print(f'[{epoch+1}/{n_epochs}][{it+1}/{len(train_loader)}]'
                f'loss: {loss_total:.4}\t'
                f'loss_D: {loss_D:.4}\t'
                f'loss_G: {loss_G:.4}')
            # No autograd graph is needed for the sample grid.
            # NOTE(review): G is not switched to eval mode here, so its
            # BatchNorm layers use batch statistics for these samples.
            with torch.no_grad():
                x_sample = G(fixed_z, fixed_c) if conditional_G else G(fixed_z)
            tv_utils.save_image(x_sample.detach(),
                os.path.join(figure_root,
                    f'{model_name}_epoch={epoch}_it={it}.png'), nrow=10)