In [1]:
import numpy as np
from PIL import Image
import math
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Reshape, Flatten, BatchNormalization, Activation, UpSampling2D, Conv2D, MaxPooling2D
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.datasets import mnist
import os


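## Generator model

Maps a 100-dimensional noise vector to a 28×28×1 image whose pixels lie in [-1, 1] (the final activation is `tanh`).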

In [2]:
def generator_model():
    """Build the DCGAN generator: 100-d noise -> 28x28x1 image in [-1, 1].

    The stack is two dense layers (1024 units, then 7*7*128 units with batch
    norm), a reshape to a 7x7x128 feature map, and two upsample + 5x5 conv
    stages (7 -> 14 -> 28). Every stage uses tanh, so the output is in [-1, 1].
    """
    stack = [
        # Fully connected projection of the latent vector.
        Dense(1024, input_dim=100),
        Activation('tanh'),
        Dense(128 * 7 * 7),
        BatchNormalization(),
        Activation('tanh'),
        # Reinterpret the flat vector as a 7x7 map with 128 channels.
        Reshape((7, 7, 128)),
        # 7x7 -> 14x14
        UpSampling2D(size=(2, 2)),
        Conv2D(64, (5, 5), padding='same'),
        Activation('tanh'),
        # 14x14 -> 28x28 with a single output channel.
        UpSampling2D(size=(2, 2)),
        Conv2D(1, (5, 5), padding='same'),
        Activation('tanh'),
    ]
    return Sequential(stack)

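## Discriminator model

Classifies a 28×28×1 image as real (target 1) or generated (target 0) through a single sigmoid output.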

In [3]:
def discriminator_model():
    """Build the DCGAN discriminator: 28x28x1 image -> probability of "real".

    Spatial flow: 28x28 conv(same) -> pool 14x14 -> 5x5 conv(valid) 10x10 ->
    pool 5x5 -> flatten -> dense 1024 -> dense 1 with sigmoid.
    """
    stack = [
        Conv2D(64, (5, 5), padding='same', input_shape=(28, 28, 1)),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        # 'valid' padding (the default) shrinks 14x14 to 10x10.
        Conv2D(128, (5, 5)),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        Flatten(),
        Dense(1024),
        Activation('tanh'),
        # Single logit squashed to (0, 1): the estimated probability the input is real.
        Dense(1),
        Activation('sigmoid'),
    ]
    return Sequential(stack)

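## Combined GAN model

Stacks the generator and the discriminator into one model, with the discriminator frozen, so that training the combined model updates only the generator.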

In [4]:
def generator_and_discriminator(generator, discriminator):
    """Chain generator -> discriminator into a single model for the generator update.

    Side effect: sets ``discriminator.trainable = False`` on the discriminator
    that is passed in. Code that later trains the discriminator on its own must
    set it back to True before compiling or training it.
    """
    # Freeze D before stacking so a model compiled from this stack only updates G.
    discriminator.trainable = False
    return Sequential([generator, discriminator])

In [5]:
def combine_images(generated_images):
    """Tile a batch of single-channel images into one 2-D grid.

    The grid has ``isqrt(N)`` columns and as many rows as needed. Images are
    placed row-major. Unused cells at the end are left as zeros.

    Args:
        generated_images: array of shape (N, H, W, C) or (N, H, W). For 4-D
            input only the first channel is used.

    Returns:
        2-D array of shape (rows * H, cols * W) with the input's dtype. An
        empty batch returns an empty (0, 0) array instead of dividing by zero.

    Raises:
        ValueError: if the input is not 3-D or 4-D.
    """
    images = np.asarray(generated_images)
    if images.ndim == 4:
        images = images[..., 0]  # keep the first (only) channel
    elif images.ndim != 3:
        raise ValueError(
            f"expected a batch of shape (N, H, W) or (N, H, W, C), got {images.shape}"
        )

    num, rows, cols = images.shape
    if num == 0:
        return np.zeros((0, 0), dtype=images.dtype)

    width = math.isqrt(num)       # columns: exact integer square root
    height = -(-num // width)     # rows: ceil(num / width)
    grid = np.zeros((height * rows, width * cols), dtype=images.dtype)
    for index, img in enumerate(images):
        i, j = divmod(index, width)
        grid[i * rows:(i + 1) * rows, j * cols:(j + 1) * cols] = img
    return grid


In [6]:
def train(BATCH_SIZE, epochs=1, image_dir="images", sample_every=20, shuffle=True):
    """Train the DCGAN on MNIST, saving sample grids and weight checkpoints.

    Each step does two updates:
    - The discriminator is trained on ``BATCH_SIZE`` real plus ``BATCH_SIZE``
      generated images (labels 1 and 0).
    - The generator is trained through the combined model, with the
      discriminator frozen, to make fresh samples be labelled 1.

    Args:
        BATCH_SIZE: number of real images (and of generated images) per step.
        epochs: number of passes over the MNIST training images.
        image_dir: directory for sample grids; created if it does not exist.
        sample_every: save a sample grid every ``sample_every`` batches (> 0).
        shuffle: reshuffle the training images at the start of every epoch.

    Side effects:
        - Writes ``{image_dir}/{epoch}_{index}.png`` sample grids.
        - At the end of every epoch, overwrites ``generator.weights.h5`` and
          ``discriminator.weights.h5`` in the working directory.
    """
    (X_train, _), (_, _) = mnist.load_data()
    # Scale pixels to [-1, 1] to match the generator's tanh output range.
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = np.expand_dims(X_train, axis=-1)  # add the channel axis

    # Sample grids are written here; create it so a fresh run does not fail.
    os.makedirs(image_dir, exist_ok=True)

    discriminator = discriminator_model()
    generator = generator_model()
    d_on_g = generator_and_discriminator(generator, discriminator)

    d_opt = SGD(learning_rate=0.0005, momentum=0.9, nesterov=True)
    g_opt = SGD(learning_rate=0.0005, momentum=0.9, nesterov=True)

    # Compile the combined model while the discriminator is frozen, so its
    # optimizer only updates generator weights. Then unfreeze the discriminator
    # and compile it on its own. The generator is never trained standalone,
    # so it does not need its own compile.
    d_on_g.compile(loss='binary_crossentropy', optimizer=g_opt)
    discriminator.trainable = True
    discriminator.compile(loss='binary_crossentropy', optimizer=d_opt)

    num_images = X_train.shape[0]
    num_batches = num_images // BATCH_SIZE  # a trailing partial batch is dropped
    # Targets are the same at every step: real = 1, generated = 0.
    d_labels = np.concatenate((np.ones(BATCH_SIZE), np.zeros(BATCH_SIZE)))
    g_labels = np.ones(BATCH_SIZE)

    for epoch in range(epochs):
        print(f"Epoch {epoch}")
        print(f"Number of batches: {num_batches}")
        order = np.random.permutation(num_images) if shuffle else np.arange(num_images)

        for index in range(num_batches):
            image_batch = X_train[order[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]]
            noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
            # Same inference-mode forward pass as predict(), without per-call dataset overhead.
            generated_images = generator.predict_on_batch(noise)

            if index % sample_every == 0:
                grid = combine_images(generated_images)
                grid = (grid * 127.5) + 127.5  # [-1, 1] -> [0, 255]
                Image.fromarray(grid.astype(np.uint8)).save(
                    os.path.join(image_dir, f"{epoch}_{index}.png")
                )

            X = np.concatenate((image_batch, generated_images))
            d_loss = discriminator.train_on_batch(X, d_labels)
            print(f"Batch {index} D Loss: {d_loss:.4f}")

            # Generator step on fresh noise; keep D frozen while it runs.
            noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
            discriminator.trainable = False
            g_loss = d_on_g.train_on_batch(noise, g_labels)
            discriminator.trainable = True
            print(f"Batch {index} G Loss: {g_loss:.4f}")

        # Checkpoint after every epoch (overwrites the previous checkpoint).
        generator.save_weights("generator.weights.h5")
        discriminator.save_weights("discriminator.weights.h5")


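## Generating images

Load the saved generator weights, sample a batch of noise vectors, and save the resulting images as a single grid PNG.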

In [7]:
def generate(BATCH_SIZE, nice=False, output_path="images/generated_image.png"):
    """Sample images from the saved generator and write them as one grid PNG.

    Args:
        BATCH_SIZE: number of images to sample. Noise vectors are drawn from U(-1, 1).
        nice: currently unused; accepted only so existing calls keep working.
        output_path: where the grid PNG is written. Its directory is created if missing.

    Requires ``generator.weights.h5`` in the working directory (written by training).
    """
    generator = generator_model()
    generator.load_weights("generator.weights.h5")
    # No compile needed: the model is only used for inference here.

    noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
    generated_images = generator.predict(noise, verbose=1)
    image = combine_images(generated_images)
    # tanh output [-1, 1] -> [0, 255]; clip guards against float round-off before the uint8 cast.
    image = np.clip((image * 127.5) + 127.5, 0, 255)

    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    Image.fromarray(image.astype(np.uint8)).save(output_path)
    print(f"Image saved at {output_path}")


In [8]:
# Train for 20 epochs; each step uses 128 real and 128 generated images.
train(epochs=20, BATCH_SIZE=128)

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 0
Number of batches: 468
Batch 0 D Loss: 0.6714
Batch 0 G Loss: 0.6040
Batch 1 D Loss: 0.6654
Batch 1 G Loss: 0.6009
Batch 2 D Loss: 0.6608
Batch 2 G Loss: 0.5982
Batch 3 D Loss: 0.6569
Batch 3 G Loss: 0.5947
Batch 4 D Loss: 0.6521
Batch 4 G Loss: 0.5911
Batch 5 D Loss: 0.6464
Batch 5 G Loss: 0.5859
Batch 6 D Loss: 0.6401
Batch 6 G Loss: 0.5813
Batch 7 D Loss: 0.6339
Batch 7 G Loss: 0.5763
Batch 8 D Loss: 0.6272
Batch 8 G Loss: 0.5717
Batch 9 D Loss: 0.6206
Batch 9 G Loss: 0.5675
Batch 10 D Loss: 0.6146
Batch 10 G Loss: 0.5630
Batch 11 D Loss: 0.6083
Batch 11 G Loss: 0.5585
Batch 12 D Loss: 0.6020
Batch 12 G Loss: 0.5541
Batch 13 D Loss: 0.5961
Batch 13 G Loss: 0.5500
Batch 14 D Loss: 0.5904
Batch 14 G Loss: 0.5455
Batch 15 D Loss: 0.5851
Batch 15 G Loss: 0.5417
Batch 16 D Loss: 0.5800
Batch 16 G Loss: 0.5381
Batch 17 D Loss: 0.5749
Batch 17 G Loss: 0.5343
Batch 18 D Loss: 0.5703
Batch 18 G Loss: 0.5307
Batch 19 D Loss: 0.5660
Batch 19 G Loss: 0.5274
Batch 20 D Loss: 0.5618
Batch

In [9]:
# Sample 128 images from the saved generator and save them as one grid.
generate(128)

[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step
Image saved at images/generated_image.png


In [10]:
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input
from scipy.linalg import sqrtm
import numpy as np
from tensorflow.image import resize

# InceptionV3 feature extractor for FID: classifier head removed, global average
# pooling applied, so each 299x299x3 input yields one pooled feature vector.
inception_model = InceptionV3(
    input_shape=(299, 299, 3),
    include_top=False,
    pooling='avg',
)

def preprocess_images_for_inception(images):
    images_resized = resize(images, [299, 299])  
    images_rgb = tf.image.grayscale_to_rgb(images_resized)  
    images_preprocessed = preprocess_input(images_rgb * 255.0) 
    return images_preprocessed


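## Fréchet Inception Distance (FID)

FID compares the Gaussian statistics (mean $\mu$ and covariance $\Sigma$) of InceptionV3 features extracted from real ($r$) and generated ($g$) images; lower is better:

$$\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \operatorname{Tr}\left(\Sigma_r + \Sigma_g - 2\left(\Sigma_r \Sigma_g\right)^{1/2}\right)$$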

In [11]:
def calculate_fid(real_images, generated_images):
    
    real_images_preprocessed = preprocess_images_for_inception(real_images)
    generated_images_preprocessed = preprocess_images_for_inception(generated_images)

    act1 = inception_model.predict(real_images_preprocessed)
    act2 = inception_model.predict(generated_images_preprocessed)

    mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)
    mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)

    diff = np.sum((mu1 - mu2) ** 2)
    
    covmean = sqrtm(sigma1.dot(sigma2))
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff + np.trace(sigma1 + sigma2 - 2 * covmean)
    return fid


In [12]:
# Number of real and generated images compared, and the seed for drawing them.
# A fixed seed makes the reported FID reproducible across Restart & Run All.
FID_SAMPLES = 1000
FID_SEED = 42
rng = np.random.default_rng(FID_SEED)

(X_train, _), (_, _) = mnist.load_data()
# Same [-1, 1] scaling as used for training, matching the generator's tanh output.
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=-1)

generator = generator_model()
generator.load_weights("generator.weights.h5")

noise = rng.uniform(-1, 1, (FID_SAMPLES, 100))
generated_images = generator.predict(noise)

# Real reference images: a random subset of the training split, without replacement.
real_images = X_train[rng.choice(X_train.shape[0], FID_SAMPLES, replace=False)]

fid_score = calculate_fid(real_images, generated_images)
print(f"FID score: {fid_score:.2f}")


[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 28ms/step
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m86s[0m 3s/step
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m86s[0m 3s/step
FID score: 16.69
