In [1]:
import random
import math
import pandas as pd
import numpy as np
# for visualizations
import matplotlib.pyplot as plt
import seaborn as sns
# for image related operations
import PIL
# for warnings
import warnings
warnings.filterwarnings("ignore")

In [3]:
import tensorflow as tf

# Dataset location and the batch size used when streaming it
# (32 or 128 are reasonable alternatives to 64).
batch_size = 64
data_path = "../input/alzheimer-augmented/AG_Alzheimer_s Dataset"

# image_dataset_from_directory builds a tf.data.Dataset from the image files
# under `data_path`; no labels are requested and images are resized to 64x64.
data = tf.keras.preprocessing.image_dataset_from_directory(
    directory=data_path,
    label_mode=None,
    batch_size=batch_size,
    image_size=(64, 64),
)

In [4]:
# Show the Python type of the dataset object returned by the loader.
type(data)

In [6]:
# Preview the first ten images of a single batch on a 6x5 subplot grid.
plt.figure(figsize=(15, 15))  # figsize is the size of the whole figure, not of each image
for batch in data.take(1):
    for slot in range(1, 11):
        # subplot slots are 1-based while image indices are 0-based
        axis = plt.subplot(6, 5, slot)

        # convert the tensor to a NumPy array and cast to uint8 (0-255)
        # so that imshow renders it as an ordinary image
        pixels = batch[slot - 1].numpy().astype("uint8")
        axis.imshow(pixels)
        axis.axis("off")

In [7]:
# Rescale pixel values from the 0-255 range to [-1, 1] so that real images share
# the output range of the generator, whose final activation is tanh. Feeding the
# discriminator real images in [0, 1] and fake images in [-1, 1] would let it
# separate them by value range alone.
# NOTE: this cell rebinds `data`; run it once per kernel session, because
# re-running it rescales the already-rescaled dataset again.
data = data.map(lambda x: (x - 127.5) / 127.5)
data

In [8]:
from PIL import Image
import tensorflow  as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Reshape, UpSampling2D, Conv2D, BatchNormalization
from tensorflow.keras.layers import LeakyReLU, Dropout, ZeroPadding2D, Flatten, Activation
from tensorflow.keras.optimizers import Adam

In [9]:
# Length of the random noise vector fed to the generator.
latent_dim = 100

# Generator: project the noise vector to a 4x4x256 feature map, then double the
# spatial resolution four times (4 -> 8 -> 16 -> 32 -> 64) before mapping to 3 channels.
generator = Sequential()
generator.add(Dense(4 * 4 * 256, activation="relu", input_dim=latent_dim))
generator.add(Reshape((4, 4, 256)))

# Each stage: upsample, 3x3 convolution, batch normalization, ReLU.
for filters in (256, 256, 256, 128):
    generator.add(UpSampling2D())
    generator.add(Conv2D(filters, kernel_size=3, padding="same"))
    generator.add(BatchNormalization(momentum=0.8))
    generator.add(Activation("relu"))

# Output head: 3-channel image squashed into [-1, 1] by tanh.
generator.add(Conv2D(3, kernel_size=3, padding="same"))
generator.add(Activation("tanh"))

generator.summary()

In [10]:
# creating a random noise and output from generator
noise = tf.random.normal([1, latent_dim]) # 1 image of latent_dim size = 100
Generated_image = generator(noise, training=False) # generate image from the random noise

# plotting the image output of generator without training 
plt.imshow(Generated_image[0, :, :, 0])
plt.axis("off")


In [11]:
# building the discriminator
discriminator = Sequential()
discriminator.add(Conv2D(32, kernel_size=3, strides=2, input_shape=(64,64,3), padding="same"))
discriminator.add(LeakyReLU(alpha=0.2))
discriminator.add(Dropout(0.25))
discriminator.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
discriminator.add(ZeroPadding2D(padding=((0,1),(0,1))))
discriminator.add(BatchNormalization(momentum=0.8))
discriminator.add(LeakyReLU(alpha=0.2))
discriminator.add(Dropout(0.25)) # can experiment by removing Dropout layer. I got better performance with it hence using it
discriminator.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
discriminator.add(BatchNormalization(momentum=0.8))
discriminator.add(LeakyReLU(alpha=0.2))
discriminator.add(Dropout(0.25))
discriminator.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
discriminator.add(BatchNormalization(momentum=0.8))
discriminator.add(LeakyReLU(alpha=0.2))
discriminator.add(Dropout(0.25))
discriminator.add(Conv2D(512, kernel_size=3, strides=1, padding="same"))
discriminator.add(BatchNormalization(momentum=0.8))
discriminator.add(LeakyReLU(alpha=0.2))
discriminator.add(Dropout(0.25))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation="sigmoid"))

discriminator.summary()

In [12]:
class GAN(tf.keras.Model):
    """Keras model that trains a generator/discriminator pair adversarially.

    Label convention used by ``train_step``: 1 marks generated images and
    0 marks real images.
    """

    def __init__(self, discriminator, generator, latent_dim):
        """Store the two sub-models and the size of the generator's noise input."""
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Attach one optimizer per sub-model, the loss function and the loss trackers.

        Args:
            d_optimizer: Optimizer applied to the discriminator's weights.
            g_optimizer: Optimizer applied to the generator's weights.
            loss_fn: Callable ``loss_fn(labels, predictions)`` used for both updates.
        """
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn
        self.d_loss_metric = tf.keras.metrics.Mean(name="d_loss")
        self.g_loss_metric = tf.keras.metrics.Mean(name="g_loss")

    @property
    def metrics(self):
        """Loss trackers exposed to Keras (discriminator first, then generator)."""
        return [self.d_loss_metric, self.g_loss_metric]

    def train_step(self, real_images):
        """Run one discriminator update followed by one generator update.

        The sub-models are called with ``training=True`` explicitly so that
        Dropout and BatchNormalization run in training mode, rather than relying
        on Keras to infer the mode for calls made directly from ``train_step``.

        Args:
            real_images: One batch of real images from the dataset.

        Returns:
            dict: Running means of the discriminator and generator losses,
            keyed ``"d_loss"`` and ``"g_loss"``.
        """
        # sample random points in the latent space
        batch_size = tf.shape(real_images)[0]
        noise = tf.random.normal(shape=(batch_size, self.latent_dim))
        # decode them to fake images (outside any tape: the generator is not updated here)
        generated_images = self.generator(noise, training=True)

        # combine them with real images
        combined_images = tf.concat([generated_images, real_images], axis=0)
        # assemble labels discriminating real from fake images (1 = generated, 0 = real)
        labels = tf.concat([tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0)
        # add random noise to the labels - important trick!
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        # train the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images, training=True)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))

        # sample fresh random points in the latent space
        noise = tf.random.normal(shape=(batch_size, self.latent_dim))

        # assemble labels that say "all real images" (0 = real under the convention above)
        misleading_labels = tf.zeros((batch_size, 1))

        # train the generator; gradients are taken only w.r.t. the generator's
        # weights, so the discriminator's weights are not updated by the optimizer
        with tf.GradientTape() as tape:
            fake_images = self.generator(noise, training=True)
            predictions = self.discriminator(fake_images, training=True)
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # update metrics
        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)
        return {"d_loss": self.d_loss_metric.result(), "g_loss": self.g_loss_metric.result()}

In [13]:
# defining the number of epochs
epochs = 100

# the optimizers for generator and discriminator
discriminator_opt = tf.keras.optimizers.Adamax(1.5e-4,0.5)
generator_opt = tf.keras.optimizers.Adamax(1.5e-4,0.5)

# to compute cross entropy loss
loss_fn = tf.keras.losses.BinaryCrossentropy()

# defining GAN model
model = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)

# compiling GAN model
model.compile(d_optimizer=discriminator_opt, g_optimizer=generator_opt, loss_fn=loss_fn)

# fitting the GAN
history = model.fit(data, epochs=epochs)

In [14]:
num_img = 10


def Image_Generator(num_images=None, output_dir="."):
    """Generate images with the generator, save them as PNG files and return them.

    Args:
        num_images: How many images to generate. Defaults to the module-level
            ``num_img`` when omitted.
        output_dir: Directory the files ``image00.png``, ``image01.png``, ... are
            written to. Defaults to the current working directory.

    Returns:
        list: The generated PIL images, in the order they were saved.
    """
    import os

    if num_images is None:
        num_images = num_img

    noise = tf.random.normal([num_images, latent_dim])
    # run the generator in inference mode and move the batch to NumPy
    generated = generator(noise, training=False).numpy()

    Generated_images = []
    for index, pixels in enumerate(generated):
        # array_to_img (default scale=True) min-max rescales every image to 0-255,
        # so no manual multiplication by 255 is needed beforehand
        img = tf.keras.preprocessing.image.array_to_img(pixels)
        img.save(os.path.join(output_dir, "image{:02d}.png".format(index)))
        Generated_images.append(img)
    return Generated_images


# generating images
Images = Image_Generator()

In [15]:
# Build an unlabeled image dataset from the current working directory.
Generated_path = "./"
Images_generated = tf.keras.preprocessing.image_dataset_from_directory(
    directory=Generated_path,
    label_mode=None,
)

In [17]:
# Show the first ten images of one batch of the reloaded dataset on a 5x6 subplot grid.
plt.figure(figsize=(15, 15))
for batch in Images_generated.take(1):
    for slot in range(1, 11):
        # subplot slots are 1-based while image indices are 0-based
        axis = plt.subplot(5, 6, slot)
        pixels = batch[slot - 1].numpy().astype("uint8")
        axis.imshow(pixels)
        axis.axis("off")