In [1]:
from __future__ import absolute_import, division, print_function

# Import TensorFlow >= 1.10 and enable eager execution
import tensorflow as tf
tf.enable_eager_execution()

import os
import time
import numpy as np
import glob
import matplotlib.pyplot as plt
import PIL
from IPython import display

In [2]:
# Hyperparameters shared by the cells below.
n_critic = 5            # critic updates per generator update in the WGAN schedule
lr_d = lr_g = 1e-4      # Adam learning rates for the critic (lr_d) and generator (lr_g)
img_size = [64] * 2     # image side lengths in pixels
BUFFER_SIZE = 142       # shuffle buffer size handed to tf.data
BATCH_SIZE = 16         # samples per training batch

In [3]:
def create_data(data_dir='data2', size=(64, 64)):
    """Load every file in `data_dir` as a resized RGB image.

    Args:
        data_dir: directory to read. Every entry is opened as an image, so it
            should contain image files only.
        size: (width, height) each image is resized to.

    Returns:
        A float32 numpy array of shape (num_images, height, width, 3) holding
        raw pixel values in [0, 255]; no normalisation is applied here.

    Raises:
        ValueError: if `data_dir` contains no entries.
    """
    # `import PIL` alone does not guarantee the Image submodule is loaded;
    # importing it explicitly avoids depending on another library's imports.
    import PIL.Image

    print('converting images...')
    images = []
    # os.listdir order is arbitrary; sorting keeps the array order reproducible.
    for filename in sorted(os.listdir(data_dir)):
        path = os.path.join(data_dir, filename)
        # The context manager releases the file handle as soon as the pixels
        # have been copied into the array.
        with PIL.Image.open(path) as img:
            images.append(np.asarray(img.resize(size).convert('RGB')))
    if not images:
        raise ValueError('no images found in {!r}'.format(data_dir))

    images = np.asarray(images).astype('float32')
    display.clear_output(wait=True)
    print('done converting')
    # TODO: labels (the file name stem before the first '.') are not returned yet.
    return images

In [4]:
# Read the raw training images from disk into one float32 array.
images = create_data()

done converting


In [5]:
# Rescale pixel values from [0, 255] to [-1, 1], the range of the generator's
# tanh output. This rebinds `images` in place of the raw data, so the cell is
# not idempotent: reload the data before running it a second time.
images = np.divide(np.subtract(images, 127.5), 127.5)

In [6]:
# Slice the image array into single examples, shuffle them through a
# BUFFER_SIZE-element buffer and group them into batches of BATCH_SIZE.
# The remainder is not dropped, so the last batch of a pass can be smaller.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(images)
    .shuffle(buffer_size=BUFFER_SIZE)
    .batch(batch_size=BATCH_SIZE)
)

In [7]:
class Generator(tf.keras.Model):
    """Maps a batch of noise vectors to a batch of 64x64 RGB images.

    A bias-free dense layer projects the noise to a 4x4x512 feature map; four
    stride-2 transposed convolutions then double the spatial size each time
    (4 -> 8 -> 16 -> 32 -> 64). Every stage except the last is followed by
    batch normalisation and leaky ReLU; the last stage ends in tanh, so the
    output lies in [-1, 1].
    """

    def __init__(self):
        super(Generator, self).__init__()
        # Settings shared by every transposed convolution.
        upsample = dict(kernel_size=(5, 5), strides=(2, 2),
                        padding='same', use_bias=False)

        self.fc1 = tf.keras.layers.Dense(4 * 4 * 512, use_bias=False)
        self.batchnorm1 = tf.keras.layers.BatchNormalization()

        self.conv1 = tf.keras.layers.Conv2DTranspose(256, **upsample)
        self.batchnorm2 = tf.keras.layers.BatchNormalization()

        self.conv2 = tf.keras.layers.Conv2DTranspose(128, **upsample)
        self.batchnorm3 = tf.keras.layers.BatchNormalization()

        self.conv3 = tf.keras.layers.Conv2DTranspose(64, **upsample)
        self.batchnorm4 = tf.keras.layers.BatchNormalization()

        # Output layer: 3 channels, no batch norm. (There is no `conv4`.)
        self.conv5 = tf.keras.layers.Conv2DTranspose(3, **upsample)

    def call(self, x, training=True):
        """Generate images from noise.

        Args:
            x: batch of noise vectors.
            training: forwarded to the batch-norm layers to select batch
                statistics (True) or moving averages (False).

        Returns:
            Tensor of shape (batch, 64, 64, 3) with values in [-1, 1].
        """
        hidden = self.batchnorm1(self.fc1(x), training=training)
        hidden = tf.reshape(tf.nn.leaky_relu(hidden), shape=(-1, 4, 4, 512))

        # Each stage: upsample -> normalise -> leaky ReLU.
        stages = (
            (self.conv1, self.batchnorm2),
            (self.conv2, self.batchnorm3),
            (self.conv3, self.batchnorm4),
        )
        for upsample, normalise in stages:
            hidden = tf.nn.leaky_relu(normalise(upsample(hidden), training=training))

        return tf.nn.tanh(self.conv5(hidden))

In [8]:
class Critic(tf.keras.Model):
    """Scores a batch of images with one unbounded scalar per image.

    Four stride-2 convolutions (64, 128, 256, 512 filters), each followed by
    leaky ReLU and dropout, feed a flatten and a single-unit dense layer. No
    activation is applied to the final score.
    """

    def __init__(self):
        super(Critic, self).__init__()
        # Settings shared by every convolution.
        downsample = dict(kernel_size=(5, 5), strides=(2, 2), padding='same')

        self.conv1 = tf.keras.layers.Conv2D(64, **downsample)
        self.conv2 = tf.keras.layers.Conv2D(128, **downsample)
        self.conv3 = tf.keras.layers.Conv2D(256, **downsample)
        self.conv4 = tf.keras.layers.Conv2D(512, **downsample)

        # One dropout layer (rate 0.3) is reused after every convolution.
        self.dropout = tf.keras.layers.Dropout(0.3)
        self.flatten = tf.keras.layers.Flatten()
        self.fc1 = tf.keras.layers.Dense(1)

    def call(self, x, training = True):
        """Return the critic score for each image in `x`.

        Args:
            x: batch of images.
            training: forwarded to dropout; dropout is active only when True.

        Returns:
            Tensor of shape (batch, 1).
        """
        features = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = self.dropout(tf.nn.leaky_relu(conv(features)),
                                    training=training)
        return self.fc1(self.flatten(features))

In [9]:
# Instantiate the two networks; their weights are created lazily on first call.
generator, critic = Generator(), Critic()

In [10]:
# Wrap both forward passes with tf.contrib.eager.defun so they are executed
# as compiled graph functions instead of op-by-op.
generator.call, critic.call = (
    tf.contrib.eager.defun(generator.call),
    tf.contrib.eager.defun(critic.call),
)

In [11]:
def get_critic_loss(real_output, generated_output):
    """Return mean(real_output) - mean(generated_output) as a scalar.

    Minimising this value pushes the critic's scores down on real samples and
    up on generated ones.

    Args:
        real_output: critic scores for real samples.
        generated_output: critic scores for generated samples.
    """
    mean_real = tf.reduce_mean(real_output)
    mean_generated = tf.reduce_mean(generated_output)
    return mean_real - mean_generated

In [12]:
def get_generator_loss(generated_output):
    """Return the mean of `generated_output` as a scalar.

    Args:
        generated_output: tensor to average. Judging by the parameter name this
            is meant to be the critic's scores for generated samples — verify
            against the caller.
    """
    loss = tf.reduce_mean(generated_output)
    return loss

In [13]:
# One Adam optimizer per network.
discriminator_optimizer = tf.train.AdamOptimizer(learning_rate=lr_d)
generator_optimizer = tf.train.AdamOptimizer(learning_rate=lr_g)

# Weight clipping for WGAN: wrap the critic's optimizer so the listed
# variables are clipped to 0.01 after each update.
# NOTE(review): the critic has not been called before this cell, so
# `critic.variables` is presumably still empty here and the wrapper would then
# clip nothing — confirm.
discriminator_optimizer_clipped = tf.contrib.gan.features.clip_variables(
    discriminator_optimizer,
    critic.variables,
    0.01,
)

In [14]:
# Checkpoints are written as 'checkpoint-wgan/ckpt-<n>'.
checkpoint_dir = 'checkpoint-wgan/'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')

# Track both networks and both (unwrapped) optimizers so training can resume.
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    critic=critic,
)

In [15]:
EPOCHS = 2000                    # passes over the training dataset
noise_dim = 100                  # length of each generator input vector
num_examples_to_generate = 16    # images drawn in each preview figure

# The preview noise is sampled once and reused for every preview, so the
# saved figures show how the same inputs evolve as the GAN trains.
random_vector_for_generation = tf.random_normal(
    shape=[num_examples_to_generate, noise_dim])

In [16]:
def generate_and_save_images(model, epoch):
    """Render the fixed preview noise through `model`, then save and show it.

    Reads the module-level `random_vector_for_generation`. The figure is
    written to 'generated/wgan-gp_image_at_epoch_<epoch>.png'; the directory
    is created when it does not exist yet.

    Args:
        model: generator, called as model(noise, training=False).
        epoch: integer used in the output file name (zero-padded to 4 digits).
    """
    # training=False: batch norm must use its moving statistics at inference
    # time rather than the statistics of this preview batch.
    predictions = model(random_vector_for_generation, training=False)
    # Undo the [-1, 1] normalisation to get displayable [0, 255] pixel values.
    pixels = predictions * 127.5 + 127.5

    num_images = int(predictions.shape[0])
    # Smallest square grid that fits all images (4x4 for 16 examples).
    grid_side = int(np.ceil(np.sqrt(num_images)))

    fig = plt.figure(figsize=(4, 4))
    for i in range(num_images):
        plt.subplot(grid_side, grid_side, i + 1)
        plt.imshow(np.uint8(pixels[i]))
        plt.axis('off')

    out_path = 'generated/wgan-gp_image_at_epoch_{:04d}.png'.format(epoch)
    out_dir = os.path.dirname(out_path)
    # savefig does not create missing directories.
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    plt.savefig(out_path)
    plt.show()
    # Release the figure explicitly; this function runs once per epoch and
    # open figures would otherwise accumulate on backends that keep them.
    plt.close(fig)

In [19]:
def train(dataset, epochs, noise_dim):
    """Train the critic and generator with an alternating WGAN-GP schedule.

    Each batch drawn from `dataset` drives exactly one update: `n_critic`
    consecutive critic updates are followed by one generator update. After
    every epoch a preview figure is generated and saved, and every 15 epochs a
    checkpoint is written.

    Uses the module-level `generator`, `critic`, `generator_optimizer`,
    `discriminator_optimizer_clipped`, `checkpoint`, `checkpoint_prefix`,
    `n_critic` and `BATCH_SIZE`.

    Args:
        dataset: iterable yielding batches of real images scaled to [-1, 1].
        epochs: number of passes over `dataset`.
        noise_dim: length of each noise vector fed to the generator.
    """
    # Iterations per schedule cycle: n_critic critic steps + 1 generator step.
    cycle_length = n_critic + 1
    # Weight of the gradient penalty term in the critic loss.
    gp_weight = 10
    iteration = 1

    for epoch in range(epochs):
        start = time.time()

        for images in dataset:
            if iteration % cycle_length != 0:
                # ---- critic step ----
                # The last batch of an epoch can be smaller than BATCH_SIZE;
                # size the noise from the real batch so that real and
                # generated images can be interpolated element-wise.
                batch_size = int(images.shape[0])

                # Noise is drawn from a standard normal distribution. The
                # generator pass is kept off the tape because only critic
                # gradients are needed in this step.
                noise = tf.random_normal([batch_size, noise_dim])
                generated_images = generator(noise, training=True)

                with tf.GradientTape() as critic_tape:
                    real_output = critic(images, training=True)
                    generated_output = critic(generated_images, training=True)
                    critic_loss = get_critic_loss(real_output, generated_output)

                    # Gradient penalty: push the norm of the critic's gradient
                    # towards 1 at random points between real and generated
                    # images, using one interpolation coefficient per sample.
                    epsilon = tf.random_uniform([batch_size, 1, 1, 1], 0, 1)
                    xhat = epsilon * images + (1 - epsilon) * generated_images
                    with tf.GradientTape() as penalty_tape:
                        penalty_tape.watch(xhat)
                        dhat = critic(xhat, training=True)
                    # Computed inside critic_tape so the penalty itself can be
                    # differentiated with respect to the critic's weights.
                    dhat_grad = penalty_tape.gradient(dhat, xhat)
                    # Per-sample L2 norm over height, width and channels.
                    slopes = tf.sqrt(
                        tf.reduce_sum(tf.square(dhat_grad), axis=[1, 2, 3]))
                    gradient_penalty = gp_weight * tf.reduce_mean((slopes - 1.0) ** 2)
                    critic_loss += gradient_penalty

                gradients_of_critic = critic_tape.gradient(critic_loss,
                                                           critic.variables)
                # NOTE(review): this uses the weight-clipping wrapper on top of
                # the gradient penalty; WGAN-GP normally replaces clipping —
                # confirm that combining both is intended.
                discriminator_optimizer_clipped.apply_gradients(
                    zip(gradients_of_critic, critic.variables))
            else:
                # ---- generator step ----
                with tf.GradientTape() as generator_tape:
                    noise = tf.random_normal([BATCH_SIZE, noise_dim])
                    generated_images2 = generator(noise, training=True)
                    # The loss is defined on the critic's scores for the
                    # generated images, not on the images themselves.
                    generated_output2 = critic(generated_images2, training=True)
                    gen_loss = get_generator_loss(generated_output2)

                gradients_of_generator = generator_tape.gradient(
                    gen_loss, generator.variables)
                generator_optimizer.apply_gradients(
                    zip(gradients_of_generator, generator.variables))

            iteration += 1

        # Refresh the preview after every epoch.
        display.clear_output(wait=True)
        generate_and_save_images(generator, epoch + 1)

        # saving (checkpoint) the model every 15 epochs
        if (epoch + 1) % 15 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print('Time taken for epoch {} is {} sec'.format(epoch + 1,
                                                         time.time() - start))

    # generating after the final epoch
    display.clear_output(wait=True)
    generate_and_save_images(generator, epochs)

In [None]:
# Start the full training run.
train(dataset=train_dataset, epochs=EPOCHS, noise_dim=noise_dim)