# DCGAN PyTorch implementation
Based on the official PyTorch DCGAN tutorial: https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

In [1]:
from __future__ import print_function
%matplotlib inline
import argparse
import os
import time
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

## Constants

In [2]:
# Global configuration shared by every cell below.

# Root directory of the training images; passed to ImageFolder in create_dataloader.
# NOTE(review): absolute, machine-specific path - adjust before running elsewhere.
dataroot = "I:/GAN-Art/datasets/wikiart_25/Expressionism_"

# Number of channels in the training images. For color images this is 3
nc = 3

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

# Base number of feature maps in the generator; the layer widths in `zoo`
# are multiples of this value.
ngf = 96

# Base number of feature maps in the discriminator; the layer widths in `zoo`
# are multiples of this value.
ndf = 36

# Dimension of the latent vector fed to the generator
nz = 100

# Compute device: the first CUDA GPU when one is available and ngpu > 0, else the CPU
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Binary cross entropy loss, used for both the discriminator and the generator update
criterion = nn.BCELoss()

# Fixed batch of 64 latent vectors; the training loop feeds it to the generator
# repeatedly to produce preview images. No seed is set, so it differs between runs.
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Target values passed to the loss for real and generated images.
# The real target is 0.9 rather than 1.0 (looks like one-sided label smoothing).
real_label = 0.9
fake_label = 0

In [3]:
# Models
#
# `zoo` maps a human-readable name to a pair of layer stacks:
#   index 0 -> generator layers     (latent vector nz x 1 x 1 -> image)
#   index 1 -> discriminator layers (image -> single sigmoid score)
# create_architectures() wraps index 0 in Generator and index 1 in Discriminator.
# The nn.Sequential objects are instantiated once, when this cell runs.
# The shape comments after each stage give the activation size as HxWxC.

zoo = {"256px output": (
    nn.Sequential(
        # Input: latent vector Z (nz x 1 x 1)
        nn.ConvTranspose2d(nz, ngf * 32, 4, 1, 0, bias=False),
        nn.BatchNorm2d(ngf * 32),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        # 4x4x(ngf*32)

        nn.ConvTranspose2d(ngf * 32, ngf * 16, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf * 16),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        # 8x8x(ngf*16)

        nn.ConvTranspose2d(ngf * 16, ngf * 8, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf * 8),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        # 16x16x(ngf*8)

        nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf * 4),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        # 32x32x(ngf*4)

        nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf * 2),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        # 64x64x(ngf*2)

        nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        # 128x128x(ngf)

        nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
        nn.Tanh()
        # Output: 256x256x3, values in [-1, 1] from the Tanh
    ),
    
    nn.Sequential(
        # Input: image, 256x256x3
        nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf),
        nn.LeakyReLU(0.2, inplace=True),
        # 128x128xndf

        nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 2),
        nn.LeakyReLU(0.2, inplace=True),
        # 64x64x(ndf * 2)

        nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 4),
        nn.LeakyReLU(0.2, inplace=True),
        # 32x32x(ndf * 4)

        nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 8),
        nn.LeakyReLU(0.2, inplace=True),
        # 16x16x(ndf * 8)

        nn.Conv2d(ndf * 8, ndf * 16, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 16),
        nn.LeakyReLU(0.2, inplace=True),
        # 8x8x(ndf * 16)

        nn.Conv2d(ndf * 16, ndf * 32, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 32),
        nn.LeakyReLU(0.2, inplace=True),
        # 4x4x(ndf * 32)

        nn.Conv2d(ndf * 32, 1, 4, 1, 0, bias=False),
        nn.Sigmoid()
        # Output: 1x1x1 score in (0, 1) from the Sigmoid
    )),
        "64px output": (
        nn.Sequential(
        # Input: latent vector Z (nz x 1 x 1)
        nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
        nn.BatchNorm2d(ngf * 8),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        # 4x4x(ngf*8)

        nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf * 4),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        # 8x8x(ngf*4)

        nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf * 2),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        # 16x16x(ngf*2)

        nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        # 32x32x(ngf)

        nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
        nn.Tanh()
        # Output: 64x64x3, values in [-1, 1] from the Tanh
    ),
    
    nn.Sequential(
        # Input: image, 64x64x3
        nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf),
        nn.LeakyReLU(0.2, inplace=True),
        # 32x32xndf

        nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 2),
        nn.LeakyReLU(0.2, inplace=True),
        # 16x16x(ndf * 2)

        nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 4),
        nn.LeakyReLU(0.2, inplace=True),
        # 8x8x(ndf * 4)

        nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 8),
        nn.LeakyReLU(0.2, inplace=True),
        # 4x4x(ndf * 8)

        nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
        nn.Sigmoid()
        # Output: 1x1x1 score in (0, 1) from the Sigmoid
    ))
}

## Data pipeline

In [4]:
def create_dataloader(img_size, batch_size, num_workers, plot_train):
    """Build a shuffled DataLoader over the image folder at ``dataroot``.

    Every image is resized to ``img_size`` x ``img_size``, center-cropped to
    the same size, converted to a tensor and normalised per channel with
    mean 0.5 and std 0.5.

    Args:
        img_size: Edge length of the square training images.
        batch_size: Number of images per batch.
        num_workers: Worker count handed to the DataLoader.
        plot_train: If truthy, draw a grid of up to 32 images from one batch.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    preprocessing = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    image_folder = dset.ImageFolder(root=dataroot, transform=preprocessing)

    loader = torch.utils.data.DataLoader(
        image_folder,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )

    if not plot_train:
        return loader

    # Preview: take one batch and show its first 32 images as a grid.
    images = next(iter(loader))[0]
    preview = vutils.make_grid(images.to(device)[:32], padding=2, normalize=True)

    plt.figure(figsize=(20, 20))
    plt.axis("off")
    plt.title("Training Images")
    # make_grid returns channels-first data; imshow expects channels last.
    plt.imshow(np.transpose(preview.cpu(), (1, 2, 0)))

    return loader
    

## Utils

In [5]:
# Custom weight initialisation, applied to netG and netD via Module.apply
def weights_init(m):
    """Initialise one module in place, chosen by its class name.

    Modules whose class name contains 'Conv' get weights drawn from
    N(0.0, 0.02). Otherwise, modules whose class name contains 'BatchNorm'
    get weights drawn from N(1.0, 0.02) and a zero bias. Every other module
    is left untouched.
    """
    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return

    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [6]:
def print_losses(G_losses, D_losses):
    """Plot the recorded generator and discriminator losses in one figure.

    Args:
        G_losses: Sequence of generator loss values.
        D_losses: Sequence of discriminator loss values.
    """
    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss During Training")

    # One curve per network, generator first so the legend order stays G, D.
    for curve, tag in ((G_losses, "G"), (D_losses, "D")):
        plt.plot(curve, label=tag)

    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()
    
        

In [7]:
def show_slider(img_list):
    """Display the collected image grids as an interactive JS animation.

    Args:
        img_list: Sequence of channels-first image grids (as appended by the
            training loop); each one becomes a frame of the animation.

    Returns:
        None. The player is rendered through IPython's display machinery.
    """
    # Function-scope import: the notebook's import cell only brings in HTML.
    from IPython.display import display

    # Nothing to animate - building an animation from zero frames would fail.
    if not img_list:
        print("No generated images to animate.")
        return

    fig = plt.figure(figsize=(8, 8))
    plt.axis("off")
    frames = [[plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)] for grid in img_list]

    ani = animation.ArtistAnimation(fig, frames, interval=1000, repeat_delay=1000, blit=True)
    player = HTML(ani.to_jshtml())

    # Close the figure so the inline backend does not also emit a static copy.
    plt.close(fig)

    # A bare HTML(...) expression inside a function is discarded; it only
    # renders when passed to display() explicitly.
    display(player)

In [8]:
def inference(foldername, num_samples=64, imgs_per_files = 16, check_path=None, 
              architecture=None, save_path="I:/GAN-Art/Jupyter/samples/", model=None):
    """Generate images with a generator and save them as numbered PNG grids.

    The generator is either ``model`` (takes precedence when given) or a
    network built from ``architecture`` and loaded from ``check_path``.

    Args:
        foldername: Sub-folder of ``save_path`` that receives the images.
        num_samples: Total number of images to generate.
        imgs_per_files: Number of images per saved grid file. A final,
            smaller grid is written when ``num_samples`` is not a multiple.
        check_path: Checkpoint file containing a ``"netG"`` state dict.
        architecture: Architecture pair matching the checkpoint.
        save_path: Parent directory for ``foldername``.
        model: Ready-to-use generator module.

    Raises:
        ValueError: If neither ``model`` nor both ``check_path`` and
            ``architecture`` are supplied.
    """
    have_checkpoint = check_path is not None and architecture is not None
    if model is None and not have_checkpoint:
        raise ValueError(
            "inference() needs either `model` or both `check_path` and `architecture`.")

    noise = torch.randn(num_samples, nz, 1, 1, device=device)

    if model is not None:
        gen = model
    else:
        # NOTE: torch.load unpickles the file - only load checkpoints you trust.
        state = torch.load(check_path, map_location="cpu")
        gen, _ = create_architectures(architecture, False)
        gen.load_state_dict(state["netG"])

    # Sampling needs no gradients; skipping the autograd graph saves memory.
    with torch.no_grad():
        samples = gen(noise)

    # Passing a chunk size (instead of a list of sizes) lets torch.split
    # return a shorter last chunk when num_samples is not evenly divisible.
    batches = torch.split(samples, imgs_per_files, dim=0)

    folder_path = save_path + foldername + "/"
    # makedirs also creates missing parents and tolerates an existing folder.
    os.makedirs(folder_path, exist_ok=True)

    # Continue numbering after the highest "<number>.<ext>" file already there.
    taken = []
    for entry in os.listdir(folder_path):
        if not os.path.isfile(os.path.join(folder_path, entry)):
            continue
        try:
            taken.append(int(os.path.splitext(entry)[0]))
        except ValueError:
            # Unrelated file with a non-numeric name - ignore it.
            continue
    high = max(taken) if taken else 0

    print("Saving generated images to...\n")
    for offset, batch in enumerate(batches, start=1):
        filename = folder_path + ("%s" % (high + offset)) + ".png"
        print(filename)
        # save_image creates the file itself; no need to open it beforehand.
        vutils.save_image(batch, filename, nrow=4, padding=0, normalize=True)
        

        
      
    

## Generator/Discriminator

In [9]:
class Generator(nn.Module):
    """Thin nn.Module wrapper around a generator layer stack."""

    def __init__(self, architecture):
        """Store ``architecture`` as the sub-module ``main``."""
        super().__init__()
        # The attribute name is part of the state_dict keys ("main.*").
        self.main = architecture

    def forward(self, input):
        """Pass ``input`` through the wrapped layers and return the result."""
        generated = self.main(input)
        return generated
    
    
class Discriminator(nn.Module):
    """Thin nn.Module wrapper around a discriminator layer stack."""

    def __init__(self, architecture):
        """Store ``architecture`` as the sub-module ``main``."""
        super().__init__()
        # The attribute name is part of the state_dict keys ("main.*").
        self.main = architecture

    def forward(self, input):
        """Pass ``input`` through the wrapped layers and return the result."""
        score = self.main(input)
        return score

In [10]:
def create_architectures(architecture, print_summary):
    """Build a fresh generator/discriminator pair from an architecture template.

    Args:
        architecture: Pair ``(generator_layers, discriminator_layers)``;
            index 0 is wrapped in ``Generator``, index 1 in ``Discriminator``.
        print_summary: If truthy, print both networks.

    Returns:
        Tuple ``(netG, netD)``, both moved to ``device`` and initialised
        with ``weights_init``.
    """
    # Function-scope import keeps this cell independent of the import cell.
    import copy

    # Work on deep copies of the templates. Wrapping the caller's module
    # objects directly would make every call share (and re-initialise) the
    # same parameters, so building a second pair would overwrite the weights
    # of networks created earlier from the same template.
    netG = Generator(copy.deepcopy(architecture[0])).to(device)
    netD = Discriminator(copy.deepcopy(architecture[1])).to(device)

    # Initialize weights in layers
    netG.apply(weights_init)
    netD.apply(weights_init)

    # print summary
    if print_summary:
        print(netG)
        print(netD)

    return netG, netD

## Model

In [11]:
def model(architecture, batch_size, image_size, num_epochs, lr, beta1, save_path=None, load_path=None, 
          save_state=False, save_epochs = 30, print_summary = False, num_workers=4, 
          plot_train=False, print_eval=True, genimgs_iter = 500):
    """Train a DCGAN, optionally resuming from and writing to a checkpoint.

    Args:
        architecture: Pair ``(generator_layers, discriminator_layers)``.
        batch_size: Batch size for the data loader.
        image_size: Edge length the training images are resized to.
        num_epochs: Number of epochs to train in this call.
        lr: Learning rate for both Adam optimizers.
        beta1: First Adam beta for both optimizers (second is 0.999).
        save_path: Checkpoint file to write; nothing is saved when None.
        load_path: Checkpoint file to resume from; fresh start when None.
        save_state: Enables checkpointing when truthy.
        save_epochs: A checkpoint is written whenever the (non-zero) epoch
            index is a multiple of this value, and after the last epoch.
        print_summary: Print both networks after construction.
        num_workers: Worker count for the data loader.
        plot_train: Plot a grid of training images before training.
        print_eval: After training, plot the losses and show the animation
            of preview grids.
        genimgs_iter: A preview grid is generated every this many iterations
            (and once more at the very last iteration).

    Returns:
        None.
    """
    # Checkpoint container; replaced by the loaded dict when resuming.
    state = {
        "epoch": 0,
        "netD" : None,
        "netG" : None,
        "optimizerD" : None,
        "optimizerG" : None,
        "img_list" : [],
        "G_losses" : [],
        "D_losses" : [] 
    }
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0
    epochs_trained = 0
    epochs_loaded = 0
    elapsed = 0

    # Dataloader
    dataloader = create_dataloader(image_size, batch_size, num_workers, plot_train)

    # Generator and discriminator
    netG, netD = create_architectures(architecture, print_summary)

    # Adam optimizers
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

    # Restore networks, optimizers and loss history when resuming.
    if load_path is not None:
        # NOTE: torch.load unpickles the file - only load checkpoints you trust.
        state = torch.load(load_path, map_location="cpu")

        epochs_trained = state["epochs_trained"]
        epochs_loaded = epochs_trained
        G_losses = state["G_losses"]
        D_losses = state["D_losses"]

        netD.load_state_dict(state["netD"])
        netG.load_state_dict(state["netG"])
        optimizerD.load_state_dict(state["optimizerD"])
        optimizerG.load_state_dict(state["optimizerG"])

        print("Continue training from model %s, with %d epochs already trained. Training for additional %d epochs.\n" 
             % (load_path, epochs_loaded, num_epochs))
    else:
        print("Training a new model for %d epochs.\n" % (num_epochs))

    for epoch in range(num_epochs):
        tic = time.time()
        for i, data in enumerate(dataloader, 0):
            #######################################
            # Train Discriminator
            # maximize log(D(x)) + log(1 - D(G(z)))
            #######################################

            ## First pass: real batch
            netD.zero_grad()
            real_batch = data[0].to(device)     # [batchsize, 3, H, W]
            b_size = real_batch.size(0)
            # Explicit float dtype so the targets match BCELoss's float input.
            labels = torch.full((b_size, ), real_label, dtype=torch.float, device=device)

            # classify real batch
            out = netD(real_batch).view(-1)

            # calculate loss
            errD_real = criterion(out, labels)

            # calculate gradients
            errD_real.backward()
            D_x = out.mean().item()

            ## Second pass: fake batch

            # generate latent samples
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            fake_batch = netG(noise)
            labels.fill_(fake_label)

            # classify fake batch; detach so no gradients flow into the generator
            out = netD(fake_batch.detach()).view(-1)

            # calculate loss
            errD_fake = criterion(out, labels)

            # calculate gradients (accumulated on top of the real-batch gradients)
            errD_fake.backward()
            D_G_z1 = out.mean().item()

            # total discriminator loss, used for reporting only
            errD = errD_real + errD_fake

            # update discriminator weights
            optimizerD.step()

            #######################################
            # Train Generator
            # maximize log(D(G(z)))
            #######################################

            netG.zero_grad()
            # The generator is rewarded when its fakes are classified as real.
            labels.fill_(real_label)

            # draw a new latent batch for the generator update
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            fake_batch = netG(noise)

            out = netD(fake_batch).view(-1)

            # calculate generator loss
            errG = criterion(out, labels)

            # calculate gradients
            errG.backward()
            D_G_z2 = out.mean().item()

            # update generator weights
            optimizerG.step()

            #######################################
            # Evaluation statistics
            #######################################

            if i % 10 == 0:
                # Timing is only known from the second epoch on and is shown
                # once per epoch, on its first batch.
                timing = ""
                if epoch != 0 and i == 0:
                    timing = ("EpochTime: %.2fmin, TimeLeft: %.2fmin"
                              % (elapsed, (num_epochs-epoch)*elapsed))
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f %s'
                      % (epochs_trained, epochs_loaded+num_epochs, i, len(dataloader),
                         errD.item(), errG.item(), D_x, D_G_z1, D_G_z2, timing))

            if i % 200 == 0:
                # Save losses for plotting later
                G_losses.append(errG.item())
                D_losses.append(errD.item())

            # Check how the generator is doing by saving G's output on fixed_noise
            if (iters % genimgs_iter == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
                with torch.no_grad():
                    fake = netG(fixed_noise).detach().cpu()

                # Show a random selection of 10 of the generated images in one row.
                picks = random.sample(range(fake.size(0)), 10)
                imgrid = vutils.make_grid(fake[picks], padding=2, normalize=True, nrow=10)
                img_list.append(imgrid)
                plt.figure(figsize=(20,20))
                plt.imshow(np.transpose(imgrid, (1, 2, 0)))
                plt.show()

            iters += 1

        # The epoch is complete: count it before checkpointing so the stored
        # "epochs_trained" includes the epoch that was just finished.
        epochs_trained += 1

        # Serialize every `save_epochs` epochs and always after the final
        # epoch, so the end result of the run is not lost.
        periodic_save = epoch != 0 and epoch % save_epochs == 0
        final_epoch = epoch == num_epochs - 1
        if save_state and save_path is not None and (periodic_save or final_epoch):

            state["epochs_trained"] = epochs_trained
            state["netD"] = netD.state_dict()
            state["netG"] = netG.state_dict()
            state["optimizerD"] = optimizerD.state_dict()
            state["optimizerG"] = optimizerG.state_dict()
            state["G_losses"] = G_losses
            state["D_losses"] = D_losses

            torch.save(state, save_path)
            print("Model saved after %d iters.\n" % (iters))

        # epoch timing, in minutes
        toc = time.time()
        elapsed = (toc-tic)/60

    if print_eval:
        print_losses(G_losses, D_losses)
        show_slider(img_list)
    
    
    
    
    
    
    

## Training

### Expressionism

In [None]:
# Expressionism, 256px architecture: resume from the checkpoint and train for
# another 100 epochs. (Original note: "expressionism 240+ epochs" - presumably
# the number of epochs already trained; confirm against the checkpoint.)
# load_path and save_path point to the same file, so the checkpoint being
# resumed is overwritten by the saves of this run, and the file must already
# exist because it is loaded before training starts.

model(architecture = zoo["256px output"], batch_size=32, image_size=256, num_epochs=100, lr=0.0002, beta1=0.5,
     save_path = "I:/GAN-Art/Jupyter/checkpoints/256x_300eps_expressionism.pt", 
     load_path = "I:/GAN-Art/Jupyter/checkpoints/256x_300eps_expressionism.pt",
     save_state=True, save_epochs = 20, genimgs_iter = 400)