In [4]:
import os

import numpy as np
import tensorflow as tf

In [5]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [6]:

from tensorflow.keras import layers,models,optimizers

Training data directory

In [7]:
# Root folder of the training images; flow_from_directory (below) expects one
# sub-folder per class. Set the OSTEO_TRAIN_DIR environment variable to point the
# notebook at another machine's copy instead of editing this absolute path.
train_dir = os.environ.get(
    'OSTEO_TRAIN_DIR',
    '/Users/uvaishnav/osteoscarcoma_evaluation_project/artifacts/train',
)

Define fixed hyperparameters

In [8]:
# Every image is resized to this spatial size on load; it matches the generator's
# 224x224 output and the critic's default input shape.
img_height, img_width = 224, 224
batch_size = 32    # images per training batch
num_classes = 3    # number of classes (one-hot label length)
latent_dim = 100   # length of the generator's noise vector

Load Images

In [9]:
def _scale_to_tanh_range(image):
    """Map pixel values from [0, 255] to [-1, 1]."""
    return image / 127.5 - 1.0


# Real images must share the value range of the generator's tanh output ([-1, 1]);
# generate_and_display_images also un-scales samples with `* 0.5 + 0.5`, i.e. it
# assumes [-1, 1]. Scaling with `rescale=1/255.0` would put real images in [0, 1],
# letting the critic separate real from fake on intensity range alone.
datagen = ImageDataGenerator(
    preprocessing_function=_scale_to_tanh_range
)

train_generator = datagen.flow_from_directory(
    train_dir,
    target_size = (img_height,img_width),
    batch_size = batch_size,
    class_mode = 'categorical',   # labels are one-hot vectors
    shuffle = True
)

Found 800 images belonging to 3 classes.


Combine one-hot labels with the noise (generator input) and with the images (critic input)

In [10]:
def combine_noise_labels_gen(generator_input, labels_input):
    """Join the noise vector and the one-hot label into one conditioning vector.

    The two inputs are concatenated along their last axis, so the result has
    length latent_dim + num_classes.
    """
    merge = layers.Concatenate(axis=-1)
    return merge([generator_input, labels_input])

def combine_img_labels_critic(critic_input, labels_input,input_shape):
    """Append the one-hot label to the image as extra constant channels.

    Args:
        critic_input: image tensor of shape (batch, H, W, C).
        labels_input: one-hot label tensor of shape (batch, n_labels).
        input_shape: (H, W, C) of the images; only H and W are used.

    Returns:
        Tensor of shape (batch, H, W, C + n_labels).
    """
    # -1 lets Reshape take the label width from labels_input itself, so the critic
    # works for any num_classes passed to build_critic, not only the global value.
    label_map = layers.Reshape((1, 1, -1))(labels_input)
    # Resizing a 1x1 map to HxW repeats the label vector at every pixel.
    label_map = tf.image.resize(label_map, (input_shape[0], input_shape[1]))

    return layers.Concatenate(axis=-1)([critic_input, label_map])
    

Define Generator

In [11]:
def build_generator(latent_dim=latent_dim,num_classes=num_classes):
    """Build the conditional generator.

    Inputs are a noise vector of length `latent_dim` and a one-hot label of length
    `num_classes`. The output is a 224x224x3 image whose values lie in [-1, 1]
    (tanh activation).

    Returns:
        A keras Model mapping [noise, one_hot_labels] -> images.
    """
    noise_in = layers.Input(shape=(latent_dim,))
    label_in = layers.Input(shape=(num_classes,))
    h = combine_noise_labels_gen(noise_in, label_in)

    # Project the conditioning vector to a 7x7x512 feature map.
    h = layers.Dense(7 * 7 * 512, use_bias=False)(h)
    h = layers.Reshape((7, 7, 512))(h)
    h = layers.BatchNormalization()(h)
    h = layers.LeakyReLU(0.2)(h)

    # Five stride-2 transposed convolutions: 7 -> 14 -> 28 -> 56 -> 112 -> 224.
    upsampling_stages = ((512, 5), (256, 3), (256, 3), (64, 3), (32, 4))
    for filters, kernel in upsampling_stages:
        h = layers.Conv2DTranspose(filters, kernel, strides=2, padding='same', use_bias=False)(h)
        h = layers.BatchNormalization()(h)
        h = layers.LeakyReLU(0.2)(h)

    # Map to 3 channels at full resolution; tanh bounds pixel values to [-1, 1].
    image_out = layers.Conv2DTranspose(3, 7, padding='same', activation='tanh', use_bias=False)(h)

    return models.Model([noise_in, label_in], image_out)




Generator summary (sanity check)

In [12]:
# Build a generator instance only to inspect its architecture; the model that is
# trained is built separately further down.
test_gen = build_generator()
test_gen.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 100)]                0         []                            
                                                                                                  
 input_2 (InputLayer)        [(None, 3)]                  0         []                            
                                                                                                  
 concatenate (Concatenate)   (None, 103)                  0         ['input_1[0][0]',             
                                                                     'input_2[0][0]']             
                                                                                                  
 dense (Dense)               (None, 25088)                2584064   ['concatenate[0][0]']     

Define Discriminator (Critic)

In [13]:
def build_critic(input_shape=(224,224,3),num_classes=num_classes):
    """Build the conditional WGAN-GP critic (the "discriminator").

    The one-hot label is expanded to an image-sized map and stacked onto the image
    channels. Five stride-2 conv blocks then shrink 224 -> 7, and a small dense head
    produces one unbounded score per (image, label) pair.

    Args:
        input_shape: (height, width, channels) of the images to score.
        num_classes: length of the one-hot label vector.

    Returns:
        A keras Model mapping [images, one_hot_labels] -> scores of shape (batch, 1).
    """
    critic_input = layers.Input(shape=input_shape)
    labels_input = layers.Input(shape=(num_classes,))

    combined_input = combine_img_labels_critic(critic_input, labels_input, input_shape)

    # First conv block: no normalization.
    x = layers.Conv2D(64, 3, strides=2, padding='same')(combined_input)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Dropout(0.5)(x)

    # The remaining blocks use LayerNormalization instead of BatchNormalization.
    # The gradient penalty constrains the critic's gradient for each input
    # independently. BatchNorm makes each score depend on the other samples in the
    # batch, which breaks that per-sample assumption. The WGAN-GP paper
    # (Gulrajani et al., 2017) recommends layer normalization as the replacement.
    for filters in (128, 256, 512, 512):
        x = layers.Conv2D(filters, 3, strides=2, padding='same')(x)
        x = layers.LayerNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)
        x = layers.Dropout(0.5)(x)

    x = layers.Flatten()(x)

    x = layers.Dense(120, activation='relu')(x)
    x = layers.Dense(84, activation='relu')(x)
    x = layers.Dropout(0.2)(x)
    # Linear output: a Wasserstein critic score, not a probability.
    scores = layers.Dense(1)(x)

    critic = models.Model([critic_input, labels_input], scores)

    return critic
    

Discriminator summary (sanity check)

In [14]:
# Build a critic instance only to inspect its architecture; the model that is
# trained is built separately further down.
disc_test = build_critic()
disc_test.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_4 (InputLayer)        [(None, 3)]                  0         []                            
                                                                                                  
 reshape_1 (Reshape)         (None, 1, 1, 3)              0         ['input_4[0][0]']             
                                                                                                  
 input_3 (InputLayer)        [(None, 224, 224, 3)]        0         []                            
                                                                                                  
 tf.image.resize (TFOpLambd  (None, 224, 224, 3)          0         ['reshape_1[0][0]']           
 a)                                                                                         

Display generated images to check the generator's progress

In [15]:
import matplotlib.pyplot as plt

In [25]:
# Class names ordered by their integer label index. generate_and_display_images
# treats position i as class index i, so order by index explicitly rather than
# relying on the insertion order of the class_indices dict.
class_names = sorted(train_generator.class_indices, key=train_generator.class_indices.get)
print(class_names)

dict_keys(['Non-Tumor', 'Non-Viable-Tumor', 'Viable'])


In [26]:
# for label,class_name in enumerate(class_names):
#     print(label, class_name)
#     lable_one_hot = tf.one_hot([label],num_classes)
#     print(lable_one_hot)

In [None]:
def generate_and_display_images(generator, epoch, latent_dim=latent_dim, class_names=class_names):
    """Generate one sample per class and plot them side by side.

    Args:
        generator: conditional generator, called as `generator([noise, one_hot_label])`.
        epoch: epoch number shown in the figure title.
        latent_dim: length of the noise vector fed to the generator.
        class_names: iterable of class names; position i is used as class index i.
    """
    class_names = list(class_names)
    fig, axes = plt.subplots(1, len(class_names), figsize=(10, 10), squeeze=False)

    for class_label, (class_name, ax) in enumerate(zip(class_names, axes[0])):
        # One-hot label for the current class.
        label = tf.one_hot([class_label], num_classes)

        # One fake image from random noise conditioned on that label.
        noise = tf.random.normal(shape=(1, latent_dim))
        generated_image = generator([noise, label], training=False)

        # The generator's tanh output lies in [-1, 1]; imshow expects floats in [0, 1].
        ax.imshow((generated_image[0] * 0.5 + 0.5).numpy())
        ax.set_title(class_name)
        ax.axis("off")

    # Figure-level title (pyplot has `suptitle`; `plt.subtitle` raises AttributeError).
    fig.suptitle("Generated Images at Epoch {}".format(epoch))
    plt.show()
        

Training hyperparameters

In [None]:
# Hyperparameters for WGAN-GP training.
c_lambda = 10             # weight of the gradient-penalty term in the critic loss
num_epochs = 20           # passes over train_generator
learning_rate = 2e-5      # Adam step size
critic_train_steps = 5    # critic updates per generator update

Wasserstein Losses

In [None]:
def compute_generator_loss(critic_fake_score):
    """WGAN generator loss: the negated mean critic score of the generated images.

    Args:
        critic_fake_score: critic scores for a batch of fake images.

    Returns:
        Scalar tensor; minimizing it pushes the critic's fake scores up.
    """
    return tf.negative(tf.reduce_mean(critic_fake_score))

def compute_critic_loss(critic_fake_score,critic_real_score,gradient_penality,c_lambda=c_lambda):
    """WGAN-GP critic loss.

    Args:
        critic_fake_score: critic scores for a batch of fake images.
        critic_real_score: critic scores for a batch of real images.
        gradient_penality: unweighted gradient penalty (scalar).
        c_lambda: weight applied to the gradient penalty.

    Returns:
        mean(fake scores) - mean(real scores) + c_lambda * gradient_penality
    """
    score_gap = tf.reduce_mean(critic_fake_score) - tf.reduce_mean(critic_real_score)
    weighted_penalty = c_lambda * gradient_penality
    return score_gap + weighted_penalty

    

Gradient penalty

1) Compute the gradient of the critic's score with respect to the images
2) Compute the gradient penalty from that gradient

-  The gradient is computed on a mixed (interpolated) image.
-  The mixed image is a weighted sum of a real and a fake image, using one random weight $\epsilon \sim U[0, 1]$ per image: $\hat{x} = \epsilon\, x_{\text{real}} + (1 - \epsilon)\, x_{\text{fake}}$.
-  Get the critic's output (the critic score) on the mixed image.
-  Compute the gradient of the critic's score with respect to the mixed image. The penalty is then $\mathbb{E}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}, y) \rVert_2 - 1)^2\big]$.

In [None]:
def get_gradient(critic, real, fake, labels_one_hot, epsilon):
    """Gradient of the critic's score with respect to real/fake interpolations.

    Args:
        critic: the critic model, called as `critic([images, labels_one_hot])`.
        real: batch of real images.
        fake: batch of fake images.
        labels_one_hot: one-hot labels passed to the critic with the mixed images.
        epsilon: per-image mixing weights; the mixed image is
            epsilon * real + (1 - epsilon) * fake.

    Returns:
        Tensor with the same shape as the mixed images.

    Note: the critic is called here without an explicit `training` flag. When this
    runs inside an outer GradientTape (as in the training loop), that tape also
    records the gradient computation, so the penalty can be differentiated with
    respect to the critic's weights.
    """
    interpolated = epsilon * real + (1 - epsilon) * fake
    with tf.GradientTape() as inner_tape:
        inner_tape.watch(interpolated)
        interpolated_scores = critic([interpolated, labels_one_hot])
    return inner_tape.gradient(interpolated_scores, interpolated)

In [None]:
def compute_gradient_penalty(gradient, eps=1e-12):
    """Gradient-penalty term of WGAN-GP.

    Args:
        gradient: gradient of the critic score with respect to the mixed images.
            Axis 0 is the batch axis; all other axes are flattened per sample.
        eps: small constant added under the square root so the norm stays
            differentiable when a sample's gradient is exactly zero.

    Returns:
        Scalar tensor: batch mean of (||gradient_i||_2 - 1)^2.
    """
    per_sample = tf.reshape(gradient, (tf.shape(gradient)[0], -1))

    # tf.norm(..., ord=2) is sqrt(sum(x**2)). Its derivative at an all-zero row is
    # 0 * inf = NaN, which would poison the critic update through the outer tape
    # (e.g. when the ReLU/dropout head outputs a constant for some sample).
    # Adding eps under the root keeps the derivative finite and shifts the norm by
    # at most sqrt(eps).
    gradient_norm = tf.sqrt(tf.reduce_sum(tf.square(per_sample), axis=1) + eps)

    penalty = tf.reduce_mean((gradient_norm - 1.0) ** 2)

    return penalty

Define Optimizer

In [None]:
# One Adam instance; the training loop uses it for both the critic and the generator.
optimizer = tf.keras.optimizers.legacy.Adam(learning_rate)

In [None]:
# count=0
# for i in range(50):
#     count = count+1
#     real_batch_imgs,labels = train_generator.next()
#     for img in real_batch_imgs:
#         print(img.shape)

# print(count)

In [None]:
# imgs,real_labels_batch = train_generator.next()
# real_labels_onehot = tf.one_hot(tf.argmax(real_labels_batch, axis=1), num_classes)

In [None]:
# The models that are actually trained (generator built first, then critic).
generator, critic = build_generator(), build_critic()

In [None]:
# noise = tf.random.normal((batch_size,latent_dim))
# imgs = generator([noise, real_labels_onehot])
# import cv2
# print(imgs.shape)

Training

In [27]:
# WGAN-GP training: for every real batch the critic is updated `critic_train_steps`
# times, then the generator is updated once.
# NOTE(review): a single `optimizer` instance serves both networks, so its iteration
# counter (used by Adam's bias correction) advances on critic and generator updates
# alike. Separate optimizers per network are the more common setup.
current_step = 0
display_step = 12  # print running-mean losses every `display_step` batches
generator_losses = []
critic_losses = []

for epoch in range(num_epochs):
    for batch_index in range(len(train_generator)):
        real_images_batch, labels = next(train_generator)
        # The last batch of an epoch can hold fewer than `batch_size` images, so size
        # the noise and the interpolation weights from the batch actually received.
        current_batch_size = len(real_images_batch)
        real_images_batch = tf.convert_to_tensor(real_images_batch)
        one_hot_labels = tf.one_hot(tf.argmax(labels, axis=1), num_classes)

        mean_iter_critic_loss = 0
        for critic_step in range(critic_train_steps):
            ## Update critic
            with tf.GradientTape() as critic_tape:
                fake_noise = tf.random.normal(shape=(current_batch_size, latent_dim))
                # training=True so the critic sees fakes produced in the same
                # BatchNorm mode (batch statistics) as in the generator update below.
                generated_images = generator([fake_noise, one_hot_labels], training=True)

                critic_fake_predictions = critic([generated_images, one_hot_labels], training=True)
                critic_real_predictions = critic([real_images_batch, one_hot_labels], training=True)

                # One mixing weight per image, broadcast over height, width and channels.
                epsilon = tf.random.uniform((current_batch_size, 1, 1, 1))
                gradient_mix = get_gradient(critic=critic, real=real_images_batch, fake=generated_images, labels_one_hot=one_hot_labels, epsilon=epsilon)
                gradient_penalty = compute_gradient_penalty(gradient_mix)
                critic_loss = compute_critic_loss(critic_fake_score=critic_fake_predictions, critic_real_score=critic_real_predictions, gradient_penality=gradient_penalty, c_lambda=c_lambda)

            # Keep track of the average critic loss in this batch
            mean_iter_critic_loss += critic_loss / critic_train_steps

            # Update critic weights
            critic_gradients = critic_tape.gradient(critic_loss, critic.trainable_variables)
            optimizer.apply_gradients(zip(critic_gradients, critic.trainable_variables))

        critic_losses.append(mean_iter_critic_loss.numpy())

        ## Update generator
        with tf.GradientTape() as generator_tape:
            fake_noise_2 = tf.random.normal(shape=(current_batch_size, latent_dim))
            generated_images_2 = generator([fake_noise_2, one_hot_labels], training=True)
            critic_fake_predictions_2 = critic([generated_images_2, one_hot_labels], training=True)
            generator_loss = compute_generator_loss(critic_fake_score=critic_fake_predictions_2)

        # Update generator weights
        generator_gradients = generator_tape.gradient(generator_loss, generator.trainable_variables)
        optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))

        generator_losses.append(generator_loss.numpy())

        if current_step % display_step == 0 and current_step > 0:
            gen_mean = sum(generator_losses[-display_step:]) / display_step
            crit_mean = sum(critic_losses[-display_step:]) / display_step
            print(f"Epoch {epoch}, step {current_step}: Generator loss: {gen_mean}, critic loss: {crit_mean}")

        current_step += 1

    # Show one sample per class every 10 epochs.
    if (epoch + 1) % 10 == 0:
        generate_and_display_images(generator=generator, epoch=epoch)




        


        
        