In [1]:
import os

import tensorflow as tf

In [3]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [7]:
import numpy as np
from tensorflow.keras import layers,models

Train Directory 

In [2]:
# Root folder of the training images, with one sub-directory per class (the layout
# flow_from_directory expects). Set the OSTEO_TRAIN_DIR environment variable to point
# elsewhere instead of editing this absolute, machine-specific default.
train_dir = os.environ.get(
    "OSTEO_TRAIN_DIR",
    '/Users/uvaishnav/osteoscarcoma_evaluation_project/artifacts/train',
)

Defining fixed Variables

In [11]:
# Resolution every image is resized to when loaded (target_size below).
img_height, img_width = 224, 224
# Number of images per batch yielded by the data generator.
batch_size = 32
# Number of classes; should match the class sub-directories found under train_dir.
num_classes = 3

Load Images

In [6]:
def scale_to_tanh_range(img):
    """Map pixel values from [0, 255] to [-1, 1]."""
    return img / 127.5 - 1.0


# Real images are scaled to [-1, 1] so they match the range of the generator's
# final tanh activation; the previous 1/255 rescale produced [0, 1] instead.
datagen = ImageDataGenerator(
    preprocessing_function=scale_to_tanh_range
)

train_generator = datagen.flow_from_directory(
    train_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',  # one-hot labels, one column per class sub-directory
    shuffle=True,
)

Found 800 images belonging to 3 classes.


Adaptive Normalization

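Instead of plain BatchNormalization, which normalizes each feature map using statistics computed over the entire batch and then applies one learned scale and shift shared by every sample, we use an adaptive normalization technique similar to SPADE.
This lets the normalization parameters (scale and shift) depend on the class labels, so they adapt to the semantic content of each input. Unlike the original SPADE, which predicts spatially varying parameters from a segmentation map, here they are predicted from a class-label vector and are therefore the same at every spatial position.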

In [14]:
class SPADE(layers.Layer):
    """Class-conditional adaptive normalization layer, in the spirit of SPADE.

    The feature map is first normalized with a parameter-free BatchNormalization
    (no learned scale/shift of its own). The result is then modulated by a
    per-sample scale and shift predicted from the conditioning vector:

        out = normalize(x) * (1 + gamma(cls_info)) + beta(cls_info)

    Call as ``SPADE()([x, cls_info])``. ``x`` is expected to be a 4-D
    channels-last feature map ``(batch, height, width, channels)``. ``cls_info``
    is a 2-D ``(batch, features)`` conditioning tensor (the class labels in this
    notebook). gamma and beta are broadcast over every spatial position.
    """

    def __init__(self, **kwargs):
        # Forward standard Layer kwargs (name, dtype, trainable, ...) to the base
        # class so SPADE can be named and configured like a built-in layer.
        super().__init__(**kwargs)

    def build(self, input_shape):
        """Create the normalization and the scale/shift projections.

        Args:
            input_shape: pair of shapes ``[x_shape, cls_info_shape]``.

        Raises:
            ValueError: if the channel (last) dimension of ``x`` is unknown.
        """
        # TensorShape() accepts both TensorShape and tuple shapes.
        channels = tf.TensorShape(input_shape[0])[-1]
        if channels is None:
            raise ValueError("SPADE requires the channel dimension of x to be defined.")
        # center/scale disabled: the affine part comes from gamma/beta below.
        self.param_free_norm = layers.BatchNormalization(center=False, scale=False)
        self.gamma_fc = layers.Dense(channels, use_bias=False)  # scale parameter
        self.beta_fc = layers.Dense(channels, use_bias=False)   # shift parameter
        super().build(input_shape)

    def call(self, inputs, training=None):
        """Normalize ``x`` and modulate it with label-dependent gamma/beta.

        Args:
            inputs: list ``[x, cls_info]``.
            training: forwarded to BatchNormalization (batch statistics while
                training, moving averages at inference).

        Returns:
            Tensor with the same shape as ``x``.
        """
        x, cls_info = inputs
        normalized = self.param_free_norm(x, training=training)
        # (batch, channels) -> (batch, 1, 1, channels) to broadcast over H and W.
        gamma = self.gamma_fc(cls_info)[:, None, None, :]
        beta = self.beta_fc(cls_info)[:, None, None, :]
        # (1 + gamma): the Dense kernel is zero-mean at init (glorot_uniform), so
        # the modulation starts near identity instead of shrinking activations.
        return normalized * (1.0 + gamma) + beta

Define Generator

In [23]:
def build_generator(latent_dim=100,num_classes=num_classes):
    """Build the class-conditional image generator.

    The noise vector and the class-label vector are concatenated, projected to a
    7x7x512 tensor, then upsampled by five stride-2 transposed convolutions
    (7 -> 14 -> 28 -> 56 -> 112 -> 224). After the projection and after every
    upsampling stage, a SPADE layer conditions the features on the labels,
    followed by LeakyReLU(0.2). A final 7x7 transposed convolution maps to
    3 channels with a tanh activation.

    Args:
        latent_dim: length of the noise input vector.
        num_classes: length of the class-label input vector.

    Returns:
        A ``models.Model`` taking ``[noise, labels]`` and producing images of
        shape (224, 224, 3) with values in [-1, 1].
    """
    noise = layers.Input(shape=(latent_dim,))
    labels = layers.Input(shape=(num_classes,))
    merged = layers.concatenate([noise, labels])

    # (filters, kernel_size) for each stride-2 upsampling stage, in order.
    upsampling_stages = [(512, 5), (256, 3), (256, 3), (64, 3), (32, 4)]

    def condition_and_activate(features):
        # Label-conditioned modulation followed by the shared activation.
        features = SPADE()([features, labels])
        return layers.LeakyReLU(0.2)(features)

    h = layers.Dense(7 * 7 * 512, use_bias=False)(merged)
    h = layers.Reshape((7, 7, 512))(h)
    h = condition_and_activate(h)

    for filters, kernel_size in upsampling_stages:
        h = layers.Conv2DTranspose(
            filters, kernel_size, strides=2, padding='same', use_bias=False
        )(h)
        h = condition_and_activate(h)

    image = layers.Conv2DTranspose(
        3, 7, padding='same', activation='tanh', use_bias=False
    )(h)
    return models.Model([noise, labels], image)




Getting the generator summary (test)

In [24]:
# Instantiate the generator with its default arguments only to inspect the
# layer graph and parameter counts.
test_gen = build_generator(latent_dim=100)
test_gen.summary()

Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_11 (InputLayer)       [(None, 100)]                0         []                            
                                                                                                  
 input_12 (InputLayer)       [(None, 3)]                  0         []                            
                                                                                                  
 concatenate_3 (Concatenate  (None, 103)                  0         ['input_11[0][0]',            
 )                                                                   'input_12[0][0]']            
                                                                                                  
 dense_4 (Dense)             (None, 25088)                2584064   ['concatenate_3[0][0]'] 