In [1]:
import keras
import numpy as np
import matplotlib.pyplot as plt
from keras.layers import Dense, Dropout, Input, Flatten, Activation, Reshape, BatchNormalization
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model, load_model
from keras.optimizers import Adam
from keras.layers.advanced_activations import LeakyReLU
from PIL import Image
import os
from tqdm import tqdm

Using TensorFlow backend.


In [10]:
class Gan:
    """DCGAN-style generative adversarial network built with Keras.

    Three models are created in ``__init__``:
      * ``self.dis`` - discriminator mapping an image to a sigmoid "real" score,
        compiled with binary cross-entropy.
      * ``self.gen`` - generator mapping a noise vector of length
        ``random_noise_dimension`` to an image with tanh outputs in [-1, 1].
      * ``self.c``   - generator stacked with the (frozen) discriminator, used
        to train the generator.

    The generator upsamples ``len(_GEN_FILTERS)`` times by 2, so image width
    and height must be positive multiples of ``_UPSCALE`` (16).
    """

    # Filters of each UpSampling2D -> Conv2D -> BatchNorm -> ReLU generator stage.
    _GEN_FILTERS = (256, 256, 128, 128)
    # Total spatial upscaling factor of the generator (2 per stage).
    _UPSCALE = 2 ** len(_GEN_FILTERS)
    # (filters, strides) of each Dropout -> Conv2D -> BatchNorm -> LeakyReLU
    # discriminator stage that follows the first convolution.
    _DIS_STAGES = ((64, 2), (128, 2), (256, 1), (512, 1))
    # PIL mode used to normalise loaded images for a given channel count.
    _PIL_MODES = {1: 'L', 2: 'LA', 3: 'RGB', 4: 'RGBA'}

    def __init__(self, w, h, channels):
        """Build and compile the discriminator, generator and combined model.

        Args:
            w: image width in pixels (positive multiple of 16).
            h: image height in pixels (positive multiple of 16).
            channels: number of colour channels (e.g. 1 = grayscale, 3 = RGB).

        Raises:
            ValueError: if ``w`` or ``h`` is not a positive multiple of 16.
        """
        scale = self._UPSCALE
        if w <= 0 or h <= 0 or w % scale or h % scale:
            raise ValueError(
                "w and h must be positive multiples of {}, got w={}, h={}".format(scale, w, h))
        self.img_w = w
        self.img_h = h
        self.random_noise_dimension = 100
        self.channels = channels
        # Image tensors are (rows, cols, channels) = (h, w, c): PIL's resize takes
        # (w, h) but np.asarray on the result yields an (h, w, c) array.
        self.image_shape = (self.img_h, self.img_w, self.channels)

        # Separate optimizer instances so the discriminator and the combined model
        # keep independent Adam state (moment estimates / iteration counters).
        self.dis = self.discriminator()
        self.dis.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])
        self.gen = self.generator()

        random_input = Input(shape=(self.random_noise_dimension,))
        gen_img = self.gen(random_input)

        # Freeze the discriminator inside the combined model only; self.dis was
        # compiled above while trainable (Keras 2 captures `trainable` at compile
        # time), so self.dis.train_on_batch still updates its weights.
        self.dis.trainable = False
        result = self.dis(gen_img)

        self.c = Model(random_input, result)
        self.c.compile(loss='binary_crossentropy', optimizer=Adam())

    def load_data(self, data_dir):
        """Load every readable image in ``data_dir`` as one uint8 array batch.

        Each image is converted to the PIL mode matching ``self.channels`` (so
        RGBA / palette / grayscale files do not break the reshape) and resized to
        ``(self.img_w, self.img_h)``. Subdirectories are ignored; files PIL cannot
        open are skipped and reported.

        Args:
            data_dir: directory containing the training images.

        Returns:
            numpy array of shape ``(n_images, img_h, img_w, channels)``
            (``n_images`` may be 0).
        """
        mode = self._PIL_MODES.get(self.channels)
        training_data = []
        skipped = []
        for file in tqdm(sorted(os.listdir(data_dir))):
            path = os.path.join(data_dir, file)
            if not os.path.isfile(path):
                continue
            try:
                # The context manager closes the underlying file handle promptly.
                with Image.open(path) as image:
                    if mode is not None:
                        image = image.convert(mode)
                    image = image.resize((self.img_w, self.img_h))
                    training_data.append(np.asarray(image))
            except OSError:
                # PIL raises UnidentifiedImageError (an OSError) for non-images,
                # e.g. stray .DS_Store / Thumbs.db files.
                skipped.append(file)
        if skipped:
            print("Skipped {} unreadable file(s), e.g. {}".format(len(skipped), skipped[0]))
        return np.reshape(training_data, (-1,) + self.image_shape)

    def generator(self):
        """Build the generator: noise vector -> image with tanh outputs in [-1, 1].

        A dense layer produces an (h/16, w/16, 256) feature map that is upsampled
        by 2 per stage (Conv2D + BatchNorm + ReLU after each) up to (h, w), then
        projected to ``self.channels`` with a tanh activation.

        Returns:
            keras ``Model`` taking ``(batch, random_noise_dimension)`` noise.
        """
        start_h = self.img_h // self._UPSCALE
        start_w = self.img_w // self._UPSCALE

        model = Sequential()
        model.add(Dense(256 * start_h * start_w, activation='relu', input_dim=self.random_noise_dimension))
        model.add(Reshape((start_h, start_w, 256)))

        for filters in self._GEN_FILTERS:
            model.add(UpSampling2D())
            model.add(Conv2D(filters, kernel_size=3, padding='same'))
            model.add(BatchNormalization(momentum=0.8))
            model.add(Activation("relu"))

        model.add(Conv2D(self.channels, kernel_size=3, padding='same'))
        model.add(Activation('tanh'))
        model.summary()

        noise_input = Input(shape=(self.random_noise_dimension,))
        return Model(noise_input, model(noise_input))

    def discriminator(self):
        """Build the discriminator: image -> sigmoid score that the image is real.

        Returns:
            keras ``Model`` taking images of shape ``self.image_shape``.
        """
        model = Sequential()
        model.add(Conv2D(32, kernel_size=3, strides=1, input_shape=self.image_shape, padding='same'))
        model.add(LeakyReLU(alpha=0.2))

        for filters, strides in self._DIS_STAGES:
            model.add(Dropout(0.25))
            model.add(Conv2D(filters, kernel_size=3, strides=strides, padding='same'))
            model.add(BatchNormalization(momentum=0.8))
            model.add(LeakyReLU(alpha=0.2))

        model.add(Dropout(0.25))
        model.add(Flatten())
        model.add(Dense(1, activation='sigmoid'))

        input_img = Input(shape=self.image_shape)
        return Model(input_img, model(input_img))

    def train(self, data_dir, epochs, batch_size, save_interval=10):
        """Train the GAN on the images in ``data_dir``.

        Note: each "epoch" here is a single training step on one random
        minibatch (sampled with replacement), not a full pass over the data.

        Args:
            data_dir: directory of training images (see ``load_data``).
            epochs: number of training steps.
            batch_size: number of real and of generated images per step.
            save_interval: save a sample grid every this many steps
                (0 or None disables saving).

        Raises:
            ValueError: if no images could be loaded from ``data_dir``.
        """
        training_data = self.load_data(data_dir)
        if len(training_data) == 0:
            raise ValueError("no training images found in {!r}".format(data_dir))
        # Scale uint8 pixels from [0, 255] to [-1, 1] to match the generator's tanh range.
        training_data = (training_data / 127.5 - 1).astype(np.float32)

        labels_real_img = np.ones((batch_size, 1))
        labels_fake_img = np.zeros((batch_size, 1))

        for epoch in range(epochs):
            indexes = np.random.randint(0, training_data.shape[0], batch_size)
            real_img = training_data[indexes]

            random_noise = np.random.normal(0, 1, (batch_size, self.random_noise_dimension))
            gen_img = self.gen.predict(random_noise)

            dis_loss_real = self.dis.train_on_batch(real_img, labels_real_img)
            dis_loss_fake = self.dis.train_on_batch(gen_img, labels_fake_img)
            # [loss, accuracy] averaged over the real and the fake batch.
            dis_loss = 0.5 * np.add(dis_loss_real, dis_loss_fake)

            # The generator learns to make the frozen discriminator answer "real",
            # so the combined model is trained against the REAL labels.
            gen_loss = self.c.train_on_batch(random_noise, labels_real_img)
            print("Dis loss {} : acc : {}, Gen loss {}".format(dis_loss[0], 100 * dis_loss[1], gen_loss))

            if save_interval and epoch % save_interval == 0:
                self.save_images(epoch)

    def save_images(self, epoch=0, output_dir="gen_images"):
        """Save a 5x5 grid of generated samples to ``<output_dir>/gen_<epoch>.png``.

        Args:
            epoch: training step number used in the file name.
            output_dir: directory for the PNG; created if it does not exist.
        """
        rows, cols = 5, 5
        noise = np.random.normal(0, 1, (rows * cols, self.random_noise_dimension))
        # Generator outputs lie in [-1, 1] (tanh); imshow expects floats in [0, 1].
        gen_images = 0.5 * self.gen.predict(noise) + 0.5
        figure, axis = plt.subplots(rows, cols)
        for count, ax in enumerate(axis.flat):
            image = gen_images[count]
            if image.shape[-1] == 1:
                # imshow needs a 2-D array for single-channel images.
                ax.imshow(image[..., 0], cmap='gray')
            else:
                ax.imshow(image)
            ax.axis('off')
        os.makedirs(output_dir, exist_ok=True)
        figure.savefig(os.path.join(output_dir, "gen_{}.png".format(epoch)))
        plt.close(figure)

    def test_data(self):
        """Placeholder for evaluation on held-out data; currently does nothing."""
        pass

In [11]:
# Build a GAN for 64x64 RGB images and train it on the files in ./images
# (100 single-minibatch steps of 32 images each).
faceGan = Gan(w=64, h=64, channels=3)
faceGan.train(data_dir='images', epochs=100, batch_size=32)

Model: "sequential_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_8 (Dense)              (None, 4096)              413696    
_________________________________________________________________
reshape_3 (Reshape)          (None, 4, 4, 256)         0         
_________________________________________________________________
up_sampling2d_9 (UpSampling2 (None, 8, 8, 256)         0         
_________________________________________________________________
conv2d_31 (Conv2D)           (None, 8, 8, 256)         590080    
_________________________________________________________________
batch_normalization_25 (Batc (None, 8, 8, 256)         1024      
_________________________________________________________________
activation_11 (Activation)   (None, 8, 8, 256)         0         
_________________________________________________________________
up_sampling2d_10 (UpSampling (None, 16, 16, 256)      

KeyboardInterrupt: 