In [1]:
import os
import cv2
import numpy as np
from glob import glob
import tensorflow as tf
from matplotlib import pyplot
from sklearn.utils import shuffle
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

In [2]:
# Image geometry fed to both networks: height, width and channel count.
# Set IMG_C = 3 for RGB images or IMG_C = 1 for greyscale.
IMG_H = 256
IMG_W = 256
IMG_C = 3

# Initializer for every kernel built by deconv_block / conv_block:
# a zero-mean normal distribution with standard deviation 0.02.
w_init = tf.keras.initializers.RandomNormal(stddev=0.02, mean=0.0)

In [3]:
def load_image(image_path):
    """Read one PNG file and convert it into a normalised training tensor.

    Args:
        image_path: scalar string tensor (or str) with the path of the image.

    Returns:
        float32 tensor of shape (IMG_H, IMG_W, IMG_C) scaled to [-1, 1], the
        same range as the generator's tanh output.
    """
    raw_bytes = tf.io.read_file(image_path)
    # Force the channel count. Without `channels`, decode_png keeps whatever the
    # file stores (greyscale, RGB or RGBA), which clashes with the
    # discriminator's fixed (IMG_H, IMG_W, IMG_C) input and leaves the channel
    # dimension statically unknown.
    img = tf.io.decode_png(raw_bytes, channels=IMG_C)
    # Centre-crop or zero-pad to the target size (no interpolation).
    img = tf.image.resize_with_crop_or_pad(img, IMG_H, IMG_W)
    img = tf.cast(img, tf.float32)
    # Map the uint8 range [0, 255] to [-1, 1].
    return (img - 127.5) / 127.5

def tf_dataset(images_path, batch_size, shuffle_buffer=10240):
    """Build a shuffled, batched, prefetched tf.data pipeline of images.

    Args:
        images_path: list of image file paths.
        batch_size: images per batch; the final batch may be smaller.
        shuffle_buffer: shuffle buffer size (default 10240, the previous
            hard-coded value). Shuffling is applied to the file paths before
            decoding, so passing len(images_path) gives a full uniform shuffle
            at little memory cost.

    Returns:
        tf.data.Dataset yielding float32 image batches produced by load_image.
    """
    autotune = tf.data.experimental.AUTOTUNE
    return (
        tf.data.Dataset.from_tensor_slices(images_path)
        .shuffle(buffer_size=shuffle_buffer)
        .map(load_image, num_parallel_calls=autotune)
        .batch(batch_size)
        .prefetch(buffer_size=autotune)
    )

def deconv_block(inputs, num_filters, kernel_size, strides, bn=True, activation=True):
    """Upsampling block: Conv2DTranspose -> [BatchNorm] -> [LeakyReLU(0.2)].

    Args:
        inputs: input tensor.
        num_filters: number of output channels.
        kernel_size: transposed-convolution kernel size.
        strides: upsampling stride; with "same" padding the spatial size is
            multiplied by `strides`.
        bn: apply BatchNormalization after the transposed convolution.
        activation: apply LeakyReLU(0.2) at the end of the block.

    Returns:
        Output tensor of the block.
    """
    X = Conv2DTranspose(
        filters=num_filters,
        kernel_size=kernel_size,
        kernel_initializer=w_init,
        padding="same",
        strides=strides,
        use_bias=False,
    )(inputs)

    if bn:
        X = BatchNormalization()(X)
    # The activation used to be nested under `if bn:`, so bn=False silently
    # removed the non-linearity too. The two options are now independent.
    if activation:
        X = LeakyReLU(alpha=0.2)(X)
    return X

def conv_block(inputs, num_filters, kernel_size, padding="same", strides=2, activation=True):
    """Convolution block: Conv2D, optionally followed by LeakyReLU(0.2) and Dropout(0.3).

    Args:
        inputs: input tensor.
        num_filters: number of output channels.
        kernel_size: convolution kernel size.
        padding: Conv2D padding mode.
        strides: convolution stride.
        activation: when True, append LeakyReLU(0.2) then Dropout(0.3);
            when False, return the raw (linear) convolution output.

    Returns:
        Output tensor of the block.
    """
    conv = Conv2D(
        filters=num_filters,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        kernel_initializer=w_init,
    )
    out = conv(inputs)
    if not activation:
        return out
    out = LeakyReLU(alpha=0.2)(out)
    return Dropout(0.3)(out)

def build_generator(latent_dim):
    """Build the generator that maps latent noise to an image.

    The network works in four stages:
    1. A Dense layer projects the noise to a (IMG_H/16, IMG_W/16, 16*filters)
       feature map.
    2. Four stride-2 deconv blocks upsample that map by 2**4 = 16 to
       (IMG_H, IMG_W).
    3. A stride-1 convolution produces IMG_C channels.
    4. A tanh maps the output into [-1, 1].

    Args:
        latent_dim: length of the input noise vector.

    Returns:
        keras Model named "generator".

    Raises:
        ValueError: if IMG_H or IMG_W is not divisible by the total upsampling.
    """
    num_upsamples = 4
    filters = 32
    # Channel multipliers, widest first: [16, 8, 4, 2, 1].
    f = [2**i for i in range(num_upsamples + 1)][::-1]
    output_strides = 2**num_upsamples  # total upsampling of the deconv stack
    if IMG_H % output_strides or IMG_W % output_strides:
        raise ValueError(
            f"IMG_H and IMG_W must be divisible by {output_strides}, got {IMG_H}x{IMG_W}"
        )
    h_output = IMG_H // output_strides
    w_output = IMG_W // output_strides

    noise = Input(shape=(latent_dim,), name="generator_noise_input")

    X = Dense(f[0] * filters * h_output * w_output, use_bias=False)(noise)
    X = BatchNormalization()(X)
    X = LeakyReLU(alpha=0.2)(X)
    # Use f[0] (not a literal 16) so the Reshape always matches the Dense width.
    X = Reshape((h_output, w_output, f[0] * filters))(X)

    for i in range(1, num_upsamples + 1):
        X = deconv_block(X,
            num_filters=f[i] * filters,
            kernel_size=5,
            strides=2,
            bn=True
        )

    # Output channel count follows IMG_C (was hard-coded to 3, which broke
    # greyscale configurations).
    X = conv_block(X,
        num_filters=IMG_C,
        kernel_size=5,
        strides=1,
        activation=False
    )
    fake_output = Activation("tanh")(X)

    return Model(noise, fake_output, name="generator")

def build_discriminator():
    """Build the discriminator: four stride-2 conv blocks followed by one logit.

    Filter counts double per block (64, 128, 256, 512). The final Dense(1) has
    no activation, so the output is an unbounded score. It is expected to be
    paired with a from_logits=True loss.

    Returns:
        keras Model named "discriminator" taking (IMG_H, IMG_W, IMG_C) images.
    """
    filters = 64
    f = [2**i for i in range(4)]  # channel multipliers: [1, 2, 4, 8]

    image_input = Input(shape=(IMG_H, IMG_W, IMG_C))
    X = image_input
    for mult in f:
        X = conv_block(X,
            num_filters=mult * filters,
            kernel_size=5,
            strides=2
        )
    X = Flatten()(X)
    X = Dense(1)(X)

    return Model(image_input, X, name="discriminator")

class GAN(Model):
    """Keras Model wrapper that trains a generator/discriminator pair adversarially.

    Each `train_step` runs `d_steps` discriminator rounds, then one generator
    update. Each discriminator round does one update on a freshly generated
    fake batch and one update on the real batch.

    Args:
        discriminator: model mapping images to one logit per image.
        generator: model mapping (batch, latent_dim) noise to images.
        latent_dim: length of the noise vector fed to the generator.
        d_steps: discriminator rounds per generator update (default 2, the
            previously hard-coded value); must be >= 1.
    """

    def __init__(self, discriminator, generator, latent_dim, d_steps=2):
        super(GAN, self).__init__()
        if d_steps < 1:
            raise ValueError(f"d_steps must be >= 1, got {d_steps}")
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_steps = d_steps

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Store one optimizer per network and the shared adversarial loss."""
        super(GAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def _sample_noise(self, batch_size):
        # Standard-normal latent vectors, one per image in the batch.
        return tf.random.normal(shape=(batch_size, self.latent_dim))

    def _discriminator_update(self, images, labels):
        # One optimizer step moving the discriminator's logits on `images`
        # toward `labels`; returns the loss before the update.
        with tf.GradientTape() as tape:
            predictions = self.discriminator(images, training=True)
            loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))
        return loss

    def train_step(self, real_images):
        """Run one adversarial training step on a batch of real images.

        Sub-models are called with training=True. A custom train_step does not
        set the training flag for nested models, so without it BatchNorm would
        run in inference mode and Dropout would be disabled.

        Returns:
            dict of losses that Keras reports as metrics:
            - "d1_loss": last fake-batch discriminator loss.
            - "d2_loss": last real-batch discriminator loss.
            - "g_loss": generator loss.
        """
        batch_size = tf.shape(real_images)[0]

        for _ in range(self.d_steps):
            # Discriminator on generated images: target label 0 ("fake").
            generated_images = self.generator(self._sample_noise(batch_size), training=True)
            d1_loss = self._discriminator_update(generated_images, tf.zeros((batch_size, 1)))

            # Discriminator on real images: target label 1 ("real").
            d2_loss = self._discriminator_update(real_images, tf.ones((batch_size, 1)))

        # Generator update: gradients flow through the discriminator, but only
        # generator weights are changed. Fakes are labelled 1 so the generator
        # is rewarded for fooling the discriminator.
        misleading_labels = tf.ones((batch_size, 1))
        with tf.GradientTape() as gtape:
            fake_images = self.generator(self._sample_noise(batch_size), training=True)
            predictions = self.discriminator(fake_images, training=True)
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = gtape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        return {"d1_loss": d1_loss, "d2_loss": d2_loss, "g_loss": g_loss}

def save_plot(examples, epoch, n):
    """Save an n x n grid of generated images as generated_plot_epoch-<epoch+1>.png.

    Args:
        examples: array of generator outputs in [-1, 1] with shape
            (num_images, H, W, C). Single-channel images are rendered in
            greyscale. At least n*n images are required.
        epoch: zero-based epoch index (the filename uses epoch + 1).
        n: number of rows and columns in the grid.

    Raises:
        ValueError: if fewer than n*n images are supplied.
    """
    if len(examples) < n * n:
        raise ValueError(
            f"need at least {n * n} images for a {n}x{n} grid, got {len(examples)}"
        )
    # Map the tanh range [-1, 1] back to [0, 1] for imshow.
    images = (np.asarray(examples) + 1) / 2.0
    cmap = None
    if images.shape[-1] == 1:
        # imshow rejects (H, W, 1) arrays; drop the channel axis and use greyscale.
        images = np.squeeze(images, axis=-1)
        cmap = "gray"

    fig, axes = pyplot.subplots(n, n, squeeze=False)
    for ax, image in zip(axes.flat, images):
        ax.axis("off")
        ax.imshow(image, cmap=cmap)
    filename = f"generated_plot_epoch-{epoch+1}.png"
    fig.savefig(filename)
    # Close this specific figure so repeated calls do not accumulate memory.
    pyplot.close(fig)

In [5]:
if __name__=="__main__":
    # Hyperparameters
    batch_size = 18
    latent_dim = 128
    num_epochs = 10
    seed = 42
    n_samples = 25  # preview images per epoch; must be a perfect square

    # Seed the RNGs used for latent sampling (training and previews) so runs
    # are repeatable.
    tf.random.set_seed(seed)
    np.random.seed(seed)

    # Dataset: set GAN_IMAGES_DIR to point elsewhere instead of editing a
    # machine-specific absolute path.
    images_dir = os.environ.get(
        "GAN_IMAGES_DIR",
        "/home/administrator/satvik/sandbox/molGAN/scripts/GAN/images/enamine",
    )
    images_path = glob(os.path.join(images_dir, "*"))
    print("Dataset Size: ", len(images_path))
    if not images_path:
        raise FileNotFoundError(f"no images found in {images_dir!r}")

    d_model = build_discriminator()
    g_model = build_generator(latent_dim)

    d_model.summary()
    g_model.summary()

    gan = GAN(d_model, g_model, latent_dim)

    bce_loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True, label_smoothing=0.1)
    d_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
    g_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
    gan.compile(d_optimizer, g_optimizer, bce_loss_fn)

    images_dataset = tf_dataset(images_path, batch_size)
    grid_n = int(np.sqrt(n_samples))

    for epoch in range(num_epochs):
        gan.fit(images_dataset, epochs=1)
        # Checkpoint after every epoch. The name reflects the epoch actually
        # completed; it was previously hard-coded to "epoch10" and overwritten
        # every epoch.
        g_model.save(f"g_model-epoch{epoch + 1}.h5")
        d_model.save(f"d_model-epoch{epoch + 1}.h5")

        noise = np.random.normal(size=(n_samples, latent_dim))
        examples = g_model.predict(noise)
        save_plot(examples, epoch, grid_n)

Dataset Size:  50240
Model: "discriminator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 256, 256, 3)]     0         
                                                                 
 conv2d_5 (Conv2D)           (None, 128, 128, 64)      4864      
                                                                 
 leaky_re_lu_9 (LeakyReLU)   (None, 128, 128, 64)      0         
                                                                 
 dropout_4 (Dropout)         (None, 128, 128, 64)      0         
                                                                 
 conv2d_6 (Conv2D)           (None, 64, 64, 128)       204928    
                                                                 
 leaky_re_lu_10 (LeakyReLU)  (None, 64, 64, 128)       0         
                                                                 
 dropout_5 (Dropout)         (No

