<a href="https://colab.research.google.com/github/wileyw/DeepLearningDemos/blob/master/SinGAN/SinGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SinGAN

[Official SinGAN Repository](https://github.com/tamarott/SinGAN)

In this notebook, we will implement SinGAN and turn it into a homework assignment so that others can learn how to implement SinGAN as well.

In [0]:
# Clone the SinGAN fork into /content and switch to its `experimental` branch.
%cd /content/
!git clone https://github.com/mswang12/SinGAN.git
%cd /content/SinGAN/
!git checkout experimental
# List the bundled input images so one can be picked for training.
%cd /content/SinGAN/Input/Images/
!ls

In [0]:
import cv2
import glob
from google.colab.patches import cv2_imshow

print('original image')
original_img_path = '/content/SinGAN/Input/Images/trees3.png'
img = cv2.imread(original_img_path)
# cv2.imread reports an unreadable/missing file by returning None instead of
# raising, so fail here with a clear message rather than inside cv2_imshow.
if img is None:
  raise FileNotFoundError('could not read image: %s' % original_img_path)
cv2_imshow(img)

# Let's train SinGAN here

## Notes
1. mode: "rand" vs. "rec" - "rand" generates noise on the fly and uses Z_opt only for its size; "rec" uses the recorded Z_opt without changing it.
2. z_opt is a unique, fixed noise sample; it is used to monitor the training results against the same noise.
3. Things to figure out: sampling around the images for cropping.
4. draw_concat() creates a new image with the inputs (noise + previous image, previous image)

In [0]:
# Let's pull out the important functions we want to reimplement
def train_single_scale2(netD,netG,reals,Gs,Zs,in_s,NoiseAmp,opt,centers=None,num_epochs=100):
    """Train the generator/discriminator pair of one SinGAN scale.

    The scale being trained is ``len(Gs)`` and its target image is
    ``reals[len(Gs)]``. Every epoch runs ``opt.Dsteps`` discriminator updates
    (loss on the real image, loss on a generated image, gradient penalty)
    followed by ``opt.Gsteps`` generator updates (adversarial loss plus, when
    ``opt.alpha != 0``, an MSE reconstruction loss weighted by ``opt.alpha``).

    Relies on the module-level names ``torch``, ``nn``, ``optim``, ``plt``,
    ``functions`` and ``draw_concat`` being available in the notebook
    namespace (presumably provided by the SinGAN imports of the training
    cell -- verify when moving this function elsewhere).

    Args:
        netD: discriminator of this scale, called as ``netD(image)``.
        netG: generator of this scale, called as ``netG(noise, prev)``.
        reals: indexable collection of real images, one entry per scale.
        Gs, Zs, NoiseAmp: state of the previously trained scales; only
            forwarded to ``draw_concat`` (``len(Gs)`` also selects the scale).
        in_s: input forwarded to ``draw_concat``; replaced by an all-zero
            tensor when training the first scale outside of ``'SR_train'``.
        opt: options object. Reads ``ker_size``, ``num_layer``, ``stride``,
            ``mode``, ``alpha``, ``nc_z``, ``device``, ``lr_d``, ``lr_g``,
            ``beta1``, ``gamma``, ``Dsteps``, ``Gsteps``, ``lambda_grad``,
            ``noise_amp_init`` and ``outf``; writes ``nzx``, ``nzy``,
            ``receptive_field`` and ``noise_amp``.
        centers: forwarded to ``functions.quant2centers``; only used when
            ``opt.mode == 'paint_train'``.
        num_epochs: number of training epochs. Defaults to 100, the shortened
            schedule this notebook uses to speed things up (the commented-out
            original was ``int(opt.niter / 2)``).

    Returns:
        Tuple ``(z_opt, in_s, netG)``.

    Side effects:
        Writes sample images and ``z_opt.pth`` into ``opt.outf`` and calls
        ``functions.save_networks`` once training has finished.
    """
    real = reals[len(Gs)]
    opt.nzx = real.shape[2]
    opt.nzy = real.shape[3]
    opt.receptive_field = opt.ker_size + ((opt.ker_size-1)*(opt.num_layer-1))*opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        # In this mode the noise itself is enlarged instead of being zero-padded.
        opt.nzx = real.shape[2]+(opt.ker_size-1)*(opt.num_layer)
        opt.nzy = real.shape[3]+(opt.ker_size-1)*(opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    # fixed_noise is only used for its shape; z_opt starts as float zeros.
    # torch.zeros is used instead of torch.full(..., 0) because the latter
    # infers an integer dtype from the fill value on newer PyTorch releases.
    fixed_noise = functions.generate_noise([opt.nc_z,opt.nzx,opt.nzy],device=opt.device)
    z_opt = torch.zeros(fixed_noise.shape, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    # The schedulers are stepped once per epoch, so with fewer than 1600 epochs
    # (e.g. the default num_epochs=100) the milestone is never reached.
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,milestones=[1600],gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,milestones=[1600],gamma=opt.gamma)

    # Per-epoch history; collected here but not returned.
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    last_epoch = num_epochs - 1

    for epoch in range(num_epochs):
        if (Gs == []) and (opt.mode != 'SR_train'):
            # First scale: a single noise channel is expanded to 3 channels,
            # for both the fixed reconstruction noise and the random noise.
            z_opt = functions.generate_noise([1,opt.nzx,opt.nzy], device=opt.device)
            z_opt = m_noise(z_opt.expand(1,3,opt.nzx,opt.nzy))
            noise_ = functions.generate_noise([1,opt.nzx,opt.nzy], device=opt.device)
            noise_ = m_noise(noise_.expand(1,3,opt.nzx,opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z,opt.nzx,opt.nzy], device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(int(opt.Dsteps)):
            # train with real
            netD.zero_grad()

            output = netD(real).to(opt.device)
            errD_real = -output.mean()
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) and (epoch == 0):
                # Very first discriminator step: build prev / z_prev and set
                # opt.noise_amp for this scale.
                if (Gs == []) and (opt.mode != 'SR_train'):
                    prev = torch.zeros([1,opt.nc_z,opt.nzx,opt.nzy], device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.zeros([1,opt.nc_z,opt.nzx,opt.nzy], device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs,Zs,reals,NoiseAmp,in_s,'rand',m_noise,m_image,opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs,Zs,reals,NoiseAmp,in_s,'rec',m_noise,m_image,opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init*RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs,Zs,reals,NoiseAmp,in_s,'rand',m_noise,m_image,opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev,centers)
                plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)

            if (Gs == []) and (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp*noise_+prev

            fake = netG(noise.detach(),prev)
            output = netD(fake.detach())
            # NOTE: netD outputs a tensor. The Discriminator is fully convolution and does not depend on the size of the image.
            # An image is real or fake depending on the mean of the output tensor.
            # Maybe we can talk about this in our Blog post?
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################

        # NOTE(review): `fake` was generated before the optimizer steps and its
        # graph is reused for every generator step (retain_graph=True); newer
        # PyTorch releases may reject this with an in-place modification
        # error -- TODO confirm against the installed version.
        for j in range(int(opt.Gsteps)):
            netG.zero_grad()
            output = netD(fake)
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha!=0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
                # Reconstruction: the fixed noise z_opt must map back to the real image.
                Z_opt = opt.noise_amp*z_opt+z_prev
                rec_loss = alpha*loss(netG(Z_opt.detach(),z_prev),real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errG2plot.append(errG.detach()+rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        # Progress and snapshots are keyed to the actual loop length so that
        # the final epoch is always reported and saved.
        if epoch % 25 == 0 or epoch == last_epoch:
            print('scale %d:[%d/%d]' % (len(Gs), epoch, num_epochs))

        if epoch % 500 == 0 or epoch == last_epoch:
            plt.imsave('%s/fake_sample.png' %  (opt.outf), functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt).png'    % (opt.outf),  functions.convert_image_np(netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)

            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG,netD,z_opt,opt)
    return z_opt,in_s,netG

In [0]:
# Define the Networks here
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

class ConvBlock(nn.Sequential):
  """Sequential block made of Conv2d, BatchNorm2d and LeakyReLU(0.2).

  The sub-modules are registered under the names 'conv', 'norm' and
  'LeakyRelu', in that order, so normalization runs before the activation.
  NOTE: open question carried over from the original notebook -- is there a
  reason why BatchNorm2d is placed before and not after the LeakyReLU?

  Args:
    in_channel: number of input channels of the convolution.
    out_channel: number of output channels (also the BatchNorm2d feature count).
    ker_size: convolution kernel size.
    padd: convolution padding.
    stride: convolution stride.
  """

  def __init__(self, in_channel, out_channel, ker_size, padd, stride):
    super().__init__()
    named_layers = (
        ('conv', nn.Conv2d(in_channel, out_channel, kernel_size=ker_size, stride=stride, padding=padd)),
        ('norm', nn.BatchNorm2d(out_channel)),
        # inplace=True: the activation overwrites its input tensor.
        ('LeakyRelu', nn.LeakyReLU(0.2, inplace=True)),
    )
    for layer_name, layer in named_layers:
      self.add_module(layer_name, layer)

def weights_init(m):
  """Re-initialize the parameters of a single module in place.

  The module is dispatched on its class name:
    * names containing 'Conv2d': weight drawn from a normal distribution with
      mean 0.0 and std 0.02 (the bias is left untouched);
    * names containing 'Norm': weight drawn from a normal distribution with
      mean 1.0 and std 0.01, bias filled with 0;
    * anything else is left unchanged.

  On the 'Norm' branch: `normal_(mean, std)` takes a mean and a standard
  deviation, so the weight (the layer's learnable scale) starts close to one
  and the bias (its learnable shift) starts at zero. These are the layer's
  affine parameters, not the statistics of its output.
  See https://forums.fast.ai/t/how-is-batch-norm-initialized/39818

  Args:
    m: the module to initialize.
  """
  layer_type = type(m).__name__
  if 'Conv2d' in layer_type:
    m.weight.data.normal_(0.0, 0.02)
    return
  if 'Norm' in layer_type:
    m.weight.data.normal_(1.0, 0.01)
    m.bias.data.fill_(0)

class WDiscriminator2(nn.Module):
    """Fully convolutional discriminator: head ConvBlock, body ConvBlocks, conv tail.

    The head maps ``opt.nc_im`` channels to ``opt.nfc``. Each of the
    ``opt.num_layer - 2`` body blocks halves the nominal width
    (``opt.nfc / 2**k``), never going below ``opt.min_nfc``. The tail is a
    plain ``Conv2d`` producing a single-channel output map.

    Args:
        opt: options object providing ``nfc``, ``nc_im``, ``ker_size``,
            ``padd_size``, ``num_layer`` and ``min_nfc``.
    """

    def __init__(self, opt):
        super().__init__()
        self.is_cuda = torch.cuda.is_available()
        min_width = opt.min_nfc
        width = int(opt.nfc)
        self.head = ConvBlock(opt.nc_im, width, opt.ker_size, opt.padd_size, 1)
        self.body = nn.Sequential()
        for depth in range(1, opt.num_layer - 1):
            # Nominal width after `depth` halvings; 2*width is the previous one.
            width = int(opt.nfc / pow(2, depth))
            in_width = max(2 * width, min_width)
            out_width = max(width, min_width)
            self.body.add_module(
                'block%d' % depth,
                ConvBlock(in_width, out_width, opt.ker_size, opt.padd_size, 1),
            )
        self.tail = nn.Conv2d(
            max(width, min_width), 1,
            kernel_size=opt.ker_size, stride=1, padding=opt.padd_size,
        )

    def forward(self, x):
        """Return the single-channel score map for image batch ``x``."""
        return self.tail(self.body(self.head(x)))

class GeneratorConcatSkip2CleanAdd2(nn.Module):
    """Convolutional generator whose output is added to a second input.

    The network is a head ConvBlock, ``opt.num_layer - 2`` body ConvBlocks
    whose nominal width halves at every block (never below ``opt.min_nfc``),
    and a tail of ``Conv2d`` followed by ``Tanh`` producing ``opt.nc_im``
    channels. ``forward(x, y)`` returns ``network(x) + y``, with ``y`` cropped
    to the network output size.

    Args:
        opt: options object providing ``nfc``, ``nc_im``, ``ker_size``,
            ``padd_size``, ``num_layer`` and ``min_nfc``.
    """

    def __init__(self, opt):
        super().__init__()
        self.is_cuda = torch.cuda.is_available()
        min_width = opt.min_nfc
        width = opt.nfc
        self.head = ConvBlock(opt.nc_im, width, opt.ker_size, opt.padd_size, 1)
        self.body = nn.Sequential()
        for depth in range(1, opt.num_layer - 1):
            # Nominal width after `depth` halvings; 2*width is the previous one.
            width = int(opt.nfc / pow(2, depth))
            in_width = max(2 * width, min_width)
            out_width = max(width, min_width)
            self.body.add_module(
                'block%d' % depth,
                ConvBlock(in_width, out_width, opt.ker_size, opt.padd_size, 1),
            )
        tail_conv = nn.Conv2d(
            max(width, min_width), opt.nc_im,
            kernel_size=opt.ker_size, stride=1, padding=opt.padd_size,
        )
        self.tail = nn.Sequential(tail_conv, nn.Tanh())

    def forward(self, x, y):
        """Run the network on ``x`` and add the (cropped) tensor ``y``.

        No resampling happens here: ``y`` is cropped by the same ``border``
        on every side of dims 2 and 3. ``border`` is half the size difference
        along dim 2 between ``y`` and the network output.
        """
        generated = self.tail(self.body(self.head(x)))
        border = int((y.shape[2] - generated.shape[2]) / 2)
        y_cropped = y[:, :, border:(y.shape[2] - border), border:(y.shape[3] - border)]
        return generated + y_cropped

class DummyOpt:
  """Minimal stand-in for the options object, for building example networks.

  Only carries the attributes the network constructors in this notebook
  read: nfc, nc_im, ker_size, padd_size, num_layer and min_nfc.
  """

  def __init__(self):
    example_values = (
        ('nfc', 32),
        ('nc_im', 3),
        ('ker_size', 3),
        ('padd_size', 1),
        ('num_layer', 5),
        ('min_nfc', 3),
    )
    for attr_name, attr_value in example_values:
      setattr(self, attr_name, attr_value)
# Build one discriminator and one generator from the dummy options and print
# their layer structure as a quick sanity check of the definitions above.
opt_example = DummyOpt()
D_example = WDiscriminator2(opt_example)
G_example = GeneratorConcatSkip2CleanAdd2(opt_example)

print(D_example, G_example, sep='\n')

In [0]:
# Work from the repository root on the `experimental` branch.
%cd /content/SinGAN
!git checkout experimental
# Delete the TrainedModels directory if it exists (presumably so that the
# "trained model already exists" check below does not skip training -- verify
# that dir2save lives under TrainedModels).
![ -d TrainedModels ] && rm -r TrainedModels
import torch
import torch.nn as nn
import sys
import os

# Import help functions from SinGAN
import config
from config import get_arguments
from SinGAN.manipulate import *
from SinGAN.training import *
import SinGAN.functions as functions

print('Implement SinGAN here...')

# Import the submodules explicitly so that the `SinGAN` package name and the
# two submodules patched below are guaranteed to be bound, instead of relying
# on whatever names the star imports above happen to leak.
import SinGAN.models
import SinGAN.training

# Replace the specific functions we want to reimplement
SinGAN.training.train_single_scale = train_single_scale2
SinGAN.models.WDiscriminator = WDiscriminator2
SinGAN.models.GeneratorConcatSkip2CleanAdd = GeneratorConcatSkip2CleanAdd2

# Replace the kernel's own command line with a bare program name so that
# parse_args() below falls back to the defaults declared on the parser.
sys.argv[:] = ['main_train.py']

parser = get_arguments()
parser.add_argument('--input_dir', help='input image dir', default='Input/Images')
parser.add_argument('--input_name', help='input image name', default='birds.png')
parser.add_argument('--mode', help='task to be done', default='train')

opt = parser.parse_args()
opt = functions.post_config(opt)

# Per-scale training state; passed to train() and then to SinGAN_generate().
Gs = []
Zs = []
reals = []
NoiseAmp = []
dir2save = functions.generate_dir2save(opt)

if os.path.exists(dir2save):
  # An existing output directory is treated as "already trained": skip.
  print('trained model already exists: {}'.format(dir2save))
else:
  # exist_ok=True tolerates the directory appearing in the meantime while
  # still surfacing real failures (e.g. permissions) instead of hiding them.
  os.makedirs(dir2save, exist_ok=True)
  real = functions.read_image(opt)
  functions.adjust_scales2image(real, opt)
  train(opt, Gs, Zs, reals, NoiseAmp)
  SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt)

# Evaluation: Let's generate some SinGAN images and look at the results

In [0]:
# Run the repository's random_samples.py script for birds.png in
# random_samples_arbitrary_sizes mode with both scale factors set to 1.
!python3 random_samples.py --input_name birds.png --mode random_samples_arbitrary_sizes --scale_h 1 --scale_v 1
# Inspect the working directory and the output folders that the cell below
# reads the generated images from.
!ls
!ls -l Output/RandomSamples/birds
!ls -l Output/RandomSamples/birds/gen_start_scale=0

In [0]:
import cv2
import glob
from google.colab.patches import cv2_imshow

print('original image')
original_img_path = 'Input/Images/birds.png'
img = cv2.imread(original_img_path)
# cv2.imread reports an unreadable/missing file by returning None instead of
# raising, so fail here with a clear message rather than inside cv2_imshow.
if img is None:
  raise FileNotFoundError('could not read image: %s' % original_img_path)
cv2_imshow(img)

# Get generated images. glob returns paths in arbitrary order, so sort them to
# show the same samples on every run.
img_paths = sorted(glob.glob('Output/RandomSamples/birds/gen_start_scale=0/*.png'))
if not img_paths:
  print('no random samples found in Output/RandomSamples/birds/gen_start_scale=0')

# Show up to three samples; slicing avoids an IndexError when fewer exist.
for sample_path in img_paths[:3]:
  print('random sample')
  img = cv2.imread(sample_path)
  cv2_imshow(img)