In [1]:
# imports
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Sequential
import numpy as np
from PIL import Image
import tensorflowjs as tfjs
from split import split_keras_model

2022-10-13 10:52:54.719555: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [17]:
# Parameters in Here
image_size = 224
n_channels = 3
patch_size = 32
n_patches = (image_size // patch_size) ** 2
assert (image_size % patch_size == 0), 'image_size must be divisible by patch_size'
projection_dim = 8
transformer_layers = 2
num_heads = 8
transformer_units = [
    projection_dim * 2,
    projection_dim,
] 
num_classes = 2
mlp_head_units = [2048, 1024]

In [11]:
def plot_models(model, i):
    if i == 0: tf.keras.utils.plot_model(model, 'ViT_architecture.png')

    try: 
        for layer in model.layers:
            try: 
                tf.keras.utils.plot_model(layer, 'model_%i.png'%i)
                i += 1
                plot_models(layer, i)

            except: 
                i += 1
    except: 
        return

In [12]:
def save_models(model, saveModel=False, saveSplit=False, saveTFJS=False):
    if saveSplit or saveTFJS:
        submodel_1, submodel_2 = split_keras_model(model, 4)
        print('Model Successfuly Split')

    if saveModel:
        model.save('/savedModels/ViT/tf/fullModel')
        print('Model Successfuly Saved')


    if saveSplit:
        submodel_1.save('/savedModels/ViT/tf/split/subModel_1/')
        print('Submodel 1 Successfully Saved')
        submodel_2.save('/savedModels/ViT/tf/split/subModel_2/')
        print('Submodel 2 Successfully Saved')

    if saveTFJS:
        try: 
            tfjs.converters.save_keras_model(submodel_2, '/savedModels/ViT/tfjs/subModel_2')
            print('Submodel 2 Successfully Converted and Saved')

        except: 
            print('Submodel Conversion Failed')
            print('Ensure model is split at valid index for TFJS conversion')


In [2]:
# MLP Definition Function
def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

In [19]:
# Data Augmentation Block
def data_augmentation_block(image_size):
    model = tf.keras.Sequential(
        [
            layers.Normalization(),
            layers.Resizing(image_size, image_size),
            layers.RandomFlip("horizontal"),
            layers.RandomRotation(factor=0.02),
            layers.RandomZoom(
                height_factor=0.2, width_factor=0.2
            ),
        ],name="data_augmentation",)

    return model

In [4]:
# Patch Creation Block
class Patches(layers.Layer):
    def __init__(self, patch_size):
        super(Patches, self).__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

In [6]:
# Patch Encoder Block
class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super(PatchEncoder, self).__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

In [7]:
# MSA Creation 
def create_msa_block(n_heads, projection_dim, key_dim, name, dropout=0.1):
    main_input = keras.Input(shape=(49, projection_dim), name='MSA_Main_Input')

    heads_output = []
    for index in range(n_heads):

        head_name = 'MSA_Head_%i' %index

        # create the queries, keys, and values by passing the embedded patches 
        # through distinct linear projections 
        value_dense = layers.Dense(units=key_dim, name='Value_Dense_%i'%index)(main_input)
        query_dense = layers.Dense(units=key_dim, name='Query_Dense_%i'%index)(main_input)
        key_dense = layers.Dense(units=key_dim, name='Key_Dense_%i'%index)(main_input)

        # perform scaled dot-product attention on the queries, keys, and values
        # using tfjs compatible layers

        dot_layer_1 = layers.Dot(axes=1, name='Dot_Layer_1_%i'%index)([query_dense, key_dense])
        softmax_layer = layers.Softmax(name='Softmax_%i'%index)(dot_layer_1)
        dot_layer_2 = layers.Dot(axes=2, name='Dot_Layer_2_%i'%index)([value_dense, softmax_layer])

        heads_output.append(dot_layer_2)
    
    concatenate_layer = layers.Concatenate(axis=2, name='Concatenate_%i'%index)(heads_output)
    output = layers.Dense(units=key_dim, name='Final_MSA_Linear_%i'%index)(concatenate_layer)

    MSA_block = tf.keras.Model(inputs=main_input, outputs=output, name=name)

    return MSA_block

In [8]:
# Transformer Encoder Creation
def create_transformer_block(input_shape, transformer_layers):
    # input to this block is patch embeddings
    # input_shape is the shape of the patch embeddings. Must be a tuple
    input = tf.keras.Input(shape=input_shape, name='Transformer_Block_Input')

    for _ in range(transformer_layers):
        # MSA Block
        if _ == 0: 
            layer_norm_1 = layers.LayerNormalization(epsilon=1e-6, name='Layer_Norm_1_%i'%_)(input)
            MSA = create_msa_block(num_heads, projection_dim, projection_dim, name='MSA_Block_%i'%_, dropout=0.1)(layer_norm_1)
            residual_connection_1 = layers.Add(name='Residual_Connection_%i'%_)([input, MSA])
        else: 
            layer_norm_1 = layers.LayerNormalization(epsilon=1e-6, name='Layer_Norm_1_%i'%_)(output)
            MSA = create_msa_block(num_heads, projection_dim, projection_dim, name='MSA_Block_%i'%_, dropout=0.1)(layer_norm_1)
            residual_connection_1 = layers.Add(name='Residual_Connection_%i'%_)([output, MSA])

        # MLP Block
        dropout_rate = 0.1        
        layer_norm_2 = layers.LayerNormalization(epsilon=1e-6, name='Layer_Norm_2_%i'%_)(residual_connection_1)
        # create as many sections in transformer_units
        for index, units in enumerate(transformer_units):
            try: 
                mlp = layers.Dense(units, activation=tf.nn.gelu, name='Dense_%i'%index + '_%i'%_)(layer_norm_2)
            except: 
                mlp = layers.Dense(units, activation=tf.nn.gelu, name='Dense_%i'%index + '_%i'%_)(mlp)
                
            mlp = layers.Dropout(dropout_rate, name='Dropout_%i'%index + '_%i'%_)(mlp)

        output = layers.Add(name='Transformer_Encoder_Output_%i'%_)([residual_connection_1, mlp])

    transformer_encoder_block = keras.Model(inputs=input, outputs=output, name='transformer_encoder_block')
    return transformer_encoder_block


In [9]:
# MLP Classifier Creation
def mlp_head():
    model = tf.keras.Sequential([
        layers.LayerNormalization(epsilon=1e-6),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(mlp_head_units[0], activation=tf.nn.gelu),
        layers.Dropout(0.5),
        layers.Dense(num_classes)
        ], name='MLP_head')
    
    return model

In [22]:
# ViT Instantiation
def create_vit_classifier_custom():

    model = tf.keras.Sequential([
        layers.Input(shape=(image_size, image_size, n_channels)),
        data_augmentation_block(image_size), 
        Patches(patch_size),
        PatchEncoder(n_patches, projection_dim),
        create_transformer_block((n_patches, projection_dim), transformer_layers),
        mlp_head()
    ], name='Vision_Transformer')

    return model

In [23]:
model = create_vit_classifier_custom()
model.summary()

# plot_models(model, 0)
save_models(model, saveModel=True, saveSplit=True, saveTFJS=True)


Model: "Vision_Transformer"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 data_augmentation (Sequenti  (None, 224, 224, 3)      7         
 al)                                                             
                                                                 
 patches (Patches)           (None, None, 3072)        0         
                                                                 
 patch_encoder (PatchEncoder  (None, 49, 8)            24976     
 )                                                               
                                                                 
 transformer_encoder_block (  (None, 49, 8)            4704      
 Functional)                                                     
                                                                 
 MLP_head (Sequential)       (None, 2)                 808978    
                                                



































































































































































































































PermissionDeniedError: /savedModels; Read-only file system