# Starting creating our own model architecture

In [1]:
import pandas as pd
import tensorflow as tf
from CancerClassification.utils.utility import read_yaml, load_hyperparameters
from CancerClassification.constants import *
from CancerClassification.config.configuration import configManager

from CancerClassification.components.data_preparation import DataPreparation


In [2]:
config = read_yaml(CONFIG_FILE_PATH)
params = read_yaml(PARAMS_FILE_PATH)

In [3]:
params

ConfigBox({'IMAGE_SIZE': 512, 'NUM_CHANNELS': 3, 'PATCH_SIZE': 64, 'BATCH_SIZE': 32, 'LEARNING_RATE': '1e-4', 'EPOCHS': 30, 'NUM_CLASSES': 22, 'NUM_LAYERS': 12, 'HIDDEN_DIM': 512, 'MLP_DIM': 3072, 'NUM_HEADS': 12, 'DROPOUT_RATE': 0.1})

In [4]:
print("Initialising config manager...")
c = configManager()
dpc = c.get_data_preparation_config()
dp = DataPreparation(dpc)

print("Running DataPreparation...")
train, valid, test, class_names = dp.run()
print(class_names)

# Print a single batch shape
for i, j in train:
    print(i.shape, j.shape)
    break

Initialising config manager...
Running DataPreparation...
['brain_glioma', 'brain_menin', 'brain_tumor', 'breast_benign', 'breast_malignant', 'cervix_dyk', 'cervix_koc', 'cervix_mep', 'cervix_pab', 'cervix_sfi', 'kidney_normal', 'kidney_tumor', 'colon_aca', 'colon_bnt', 'lung_aca', 'lung_bnt', 'lung_scc', 'lymph_cll', 'lymph_fl', 'lymph_mcl', 'oral_normal', 'oral_scc']
(32, 64, 12288) (32, 22)


In [5]:
from tensorflow.keras import layers # type: ignore 
from tensorflow.keras.models import Model # type: ignore 
from tensorflow.keras import callbacks # type: ignore 

In [6]:
class ClassToken(layers.Layer):
    def __init__(self):
        super().__init__()

    def build(self, input_shape):
        w_init = tf.random_normal_initializer()
        self.w = tf.Variable(
            initial_value = w_init(shape=(1, 1, input_shape[-1]), dtype=tf.float32), 
            trainable = True
        )

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        hidden_dim = self.w.shape[-1]

        #reshape
        cls = tf.broadcast_to(self.w, [batch_size, 1, hidden_dim])
        #change data type
        cls = tf.cast(cls, dtype=inputs.dtype)
        return cls 

In [7]:
def mlp(x, cf):
    x = layers.Dense(cf['MLP_DIM'], activation='gelu')(x)
    x = layers.Dropout(cf['DROPOUT_RATE'])(x)
    x = layers.Dense(cf['HIDDEN_DIM'])(x)
    x = layers.Dropout(cf['DROPOUT_RATE'])(x)
    return x

In [8]:
def transformer_encoder(x, cf):
    skip_1 = x
    x = layers.LayerNormalization()(x)
    x = layers.MultiHeadAttention(num_heads=cf['NUM_HEADS'], key_dim=cf['HIDDEN_DIM'])(x,x)
    x = layers.Add()([x, skip_1])
    
    skip_2 = x
    x = layers.LayerNormalization()(x)
    x = mlp(x, cf)
    x = layers.Add()([x, skip_2])
    
    return x

In [9]:
def ViT(cf):
    input_shape = (cf['NUM_PATCHES'], cf['PATCH_SIZE']*cf['PATCH_SIZE']*cf['NUM_CHANNELS'])
    inputs = layers.Input(input_shape) #(None, 256, 3072)
    
    #patch + Position embedding
    patch_embed = layers.Dense(cf['HIDDEN_DIM'])(inputs) #(None, 256, 768)
    
    positions = tf.range(start=0, limit=cf['NUM_PATCHES'], delta=1)
    pos_emb = layers.Embedding(input_dim=cf['NUM_PATCHES'], output_dim=cf['HIDDEN_DIM'])(positions) #(256, 768)
    
    embed = patch_embed + pos_emb #(None, 256, 768)
    
    token = ClassToken()(embed)
    x = layers.Concatenate(axis=1)([token, embed]) #(None, 257, 768)
    
    for _ in range(cf['NUM_LAYERS']):
        x = transformer_encoder(x, cf)
        
    x = layers.LayerNormalization()(x)
    x = x[:, 0, :]
    x = layers.Dense(cf['NUM_CLASSES'], activation='softmax')(x)
    
    model = Model(inputs, x)
    return model

In [10]:
hp = load_hyperparameters(
    path=PARAMS_FILE_PATH,
    s3_bucket=dp.config.s3_bucket,
    data_folder=dp.config.class_structure
)

In [11]:
hp

{'IMAGE_SIZE': 512,
 'NUM_CHANNELS': 3,
 'PATCH_SIZE': 64,
 'BATCH_SIZE': 32,
 'LEARNING_RATE': '1e-4',
 'EPOCHS': 30,
 'NUM_CLASSES': 22,
 'NUM_LAYERS': 12,
 'HIDDEN_DIM': 512,
 'MLP_DIM': 3072,
 'NUM_HEADS': 12,
 'DROPOUT_RATE': 0.1,
 'NUM_PATCHES': 64,
 'FLAT_PATCHES_SHAPE': (64, 12288),
 'CLASS_NAMES': ['brain_glioma',
  'brain_menin',
  'brain_tumor',
  'breast_benign',
  'breast_malignant',
  'cervix_dyk',
  'cervix_koc',
  'cervix_mep',
  'cervix_pab',
  'cervix_sfi',
  'kidney_normal',
  'kidney_tumor',
  'colon_aca',
  'colon_bnt',
  'lung_aca',
  'lung_bnt',
  'lung_scc',
  'lymph_cll',
  'lymph_fl',
  'lymph_mcl',
  'oral_normal',
  'oral_scc']}

In [12]:
model = ViT(hp)
model.summary()