In [1]:
import numpy as np
from matplotlib import pyplot as plt
import cv2
import os
from tensorpack import dataflow
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from IPython.display import clear_output

%matplotlib inline

Failed to import tensorflow.


In [2]:
USE_CUDA = torch.cuda.is_available()

FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor

if USE_CUDA:
    print('Using CUDA')
else:
    print('Using CPU')

Using CUDA


In [3]:
DATA_PATH = './data'
if not os.path.exists(DATA_PATH):
    print('Can\'t find DATA_PATH:', DATA_PATH) 
    quit(1)

In [4]:
DATASET_SIZE = len(os.listdir(DATA_PATH))
if DATASET_SIZE == 0:
    print('No dataset found')
    quit(1)

In [5]:
def path_iter():
    filenames = os.listdir(DATA_PATH)
    
    while True:
        random.shuffle(filenames)
        
        for fn in filenames:
            yield os.path.join(DATA_PATH, fn)

In [6]:
IMAGE_SIDE = 64

In [7]:
def prepare_image(path):
    image = cv2.imread(path)
    
    if image is None:
        return None
    
    image = cv2.resize(image, (IMAGE_SIDE, IMAGE_SIDE))
    
    image = np.swapaxes(image, 2, 0)
    
    image = ((image / 255.) - .5) * 2
        
    return image

In [8]:
image_iter = dataflow.MapData(path_iter(), func=prepare_image)

In [9]:
PRIOR_DIM = 100

ngf = 64
nc = 3

In [10]:
class G(nn.Module):
    def __init__(self):
        super(G, self).__init__()
        
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(PRIOR_DIM, ngf * 8, 4, 1, 0, bias=False),
            nn.ReLU(True),
            #nn.BatchNorm2d(ngf * 8),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.ReLU(True),
            #nn.BatchNorm2d(ngf * 4),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.ReLU(True),
            #nn.BatchNorm2d(ngf * 2),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.ReLU(True),
            #nn.BatchNorm2d(ngf),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )
        
        if USE_CUDA:
            self.cuda()
            
    def forward(self, x):
        x = torch.reshape(x, (-1, PRIOR_DIM, 1, 1))
        
        x = self.main(x)
        
        return x


In [11]:
ndf = 64
CLIP = 0.01

In [12]:
class D(nn.Module):
    def __init__(self):
        super(D, self).__init__()
        
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            #nn.BatchNorm2d(ndf),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            #nn.BatchNorm2d(ndf * 2),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            #nn.BatchNorm2d(ndf * 4),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            #nn.BatchNorm2d(ndf * 8),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )
        
        if USE_CUDA:
            self.cuda()
        
    def forward(self, x):
        x = self.main(x)
        
        return x.view((-1, 1))
    
    def clip(self):
        self.main[0].weight.data.clamp_(min=-CLIP, max=CLIP)
        self.main[2].weight.data.clamp_(min=-CLIP, max=CLIP)
        self.main[4].weight.data.clamp_(min=-CLIP, max=CLIP)
        self.main[6].weight.data.clamp_(min=-CLIP, max=CLIP)
        self.main[8].weight.data.clamp_(min=-CLIP, max=CLIP)

In [13]:
INIT_STD = 0.02

In [14]:
def init_weights(m):
    if type(m) in [nn.Conv2d, nn.Linear, nn.ConvTranspose2d]:
        torch.nn.init.normal_(m.weight.data, mean=0, std=INIT_STD)
        if m.bias is not None:
            torch.nn.init.normal_(m.bias.data, mean=0, std=INIT_STD)
    elif type(m) in [nn.BatchNorm2d, nn.LeakyReLU, nn.ReLU, nn.Sequential, nn.Tanh]:
        return
    else:
        print('Couldn\'t init wieghts of layer with type:', type(m))


In [15]:
generator = G()
discriminator = D()

generator.apply(init_weights)
discriminator.apply(init_weights);

Couldn't init wieghts of layer with type: <class '__main__.G'>
Couldn't init wieghts of layer with type: <class 'torch.nn.modules.activation.Sigmoid'>
Couldn't init wieghts of layer with type: <class '__main__.D'>


In [16]:
lr = 0.00005

In [17]:
G_optim = optim.RMSprop(generator.parameters(), lr = lr)
D_optim = optim.RMSprop(discriminator.parameters(), lr = lr)

In [18]:
def prior():
    return np.random.multivariate_normal(np.zeros(PRIOR_DIM), np.identity(PRIOR_DIM))

In [19]:
EPOCH_COUNT = 200000
BATCH_SIZE = 32
BATCH_DIVISOR = 1
DISCRIMINATOR_LEARNING_REPEATS = 5

MINIBATCH_SIZE = BATCH_SIZE // BATCH_DIVISOR
EPOCH_LEN = int(DATASET_SIZE / (BATCH_SIZE * (DISCRIMINATOR_LEARNING_REPEATS + 2)))

In [20]:
def get_data_minibatch():
    i = 0
    ret = []
    for image in image_iter:
        ret.append(image)
        
        i += 1
        if i == MINIBATCH_SIZE:
            break
    
    return np.stack(ret, axis=0)

In [21]:
def save(fn=None):
    if fn is None:
        fn = 'checkpoint'

    torch.save({
        'gen': generator.state_dict(),
        'dis': discriminator.state_dict(),
        'g_opt': G_optim.state_dict(),
        'd_opt': D_optim.state_dict()
    }, fn)

In [22]:
def load(fn):
    if fn is None:
        fn = 'checkpoint'
    state = torch.load(fn)

    generator.load_state_dict(state['gen'])
    discriminator.load_state_dict(state['dis'])
    G_optim.load_state_dict(state['g_opt'])
    D_optim.load_state_dict(state['d_opt'])

In [23]:
if not os.path.exists('generated'):
    os.mkdir('generated')

load('checkpoint')

for epoch_num in range(EPOCH_COUNT):
    print('===============================')
    print('Epoch', epoch_num, 'started!')
    print('===============================')

    print('Saving...')
    save('checkpoint')
    
    for t in range(EPOCH_LEN):
        clear_output(wait=True)

        print('t:', t)
        
        generator.eval()
        discriminator.train()
        
        for k in range(DISCRIMINATOR_LEARNING_REPEATS):
            D_optim.zero_grad()

            for sample_num in range(BATCH_DIVISOR):
                data = FloatTensor(get_data_minibatch())
                p = FloatTensor(np.stack([prior() for i in range(MINIBATCH_SIZE)], axis=0))

                D_of_x = discriminator(data)
                D_of_G_of_z = discriminator(generator(p))

                loss = D_of_x - D_of_G_of_z
                loss = torch.mean(loss, dim = 0)
                loss = -loss
                loss = loss / BATCH_DIVISOR

                loss.backward()

                print('D loss:', loss.detach().cpu().numpy()[0])
                
            D_optim.step()
            
            discriminator.clip()
            
        #=========================================
        
        discriminator.eval()
        generator.train()
        
        G_optim.zero_grad()

        for sample_num in range(BATCH_DIVISOR):
            p = FloatTensor(np.stack([prior() for i in range(MINIBATCH_SIZE)], axis=0))

            D_of_G_of_z = discriminator(generator(p))

            loss = -D_of_G_of_z
            loss = torch.mean(loss, dim = 0)
            loss = loss / BATCH_DIVISOR

            loss.backward()

            print('G loss:', loss.detach().cpu().numpy()[0])

        G_optim.step()
        
        generator.eval()        
        image = (generator(FloatTensor(np.expand_dims(prior(), 0))).detach().cpu().numpy()[0] + 1) / 2
        image = (np.swapaxes(image, 0, 2) * 255).astype(np.uint8)
        cv2.imwrite(os.path.join('generated', str(t) + '.jpg'), image)

KeyboardInterrupt: 