In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Sequential
import numpy as np
from PIL import Image
import tensorflowjs as tfjs


2022-10-13 11:23:00.265905: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

In [3]:
# Parameters in Here
image_size = 224
n_channels = 3
patch_size = 32
n_patches = (image_size // patch_size) ** 2
assert (image_size % patch_size == 0), 'image_size must be divisible by patch_size'
projection_dim = 8
transformer_layers = 2
num_heads = 8
transformer_units = [
    projection_dim * 2,
    projection_dim,
] 
num_classes = 2
mlp_head_units = [2048, 1024]

In [4]:
np_image = Image.open('../data/image.jpeg')
np_image = np_image.resize((image_size, image_size), Image.Resampling.BILINEAR)
np_image = np.asarray(np_image)
np_image_batch = []
i = 0
while i < 10:
    np_image_batch.append(np_image)
    i += 1

np_image_batch = np.asarray(np_image_batch)
tensor_image_batch = tf.convert_to_tensor(np_image_batch)


In [5]:
data_augmentation = tf.keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(factor=0.02),
        layers.RandomZoom(
            height_factor=0.2, width_factor=0.2
        ),
    ],
    name="data_augmentation",
)

In [6]:
class Patches(layers.Layer):
    def __init__(self, patch_size):
        super(Patches, self).__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

In [7]:
def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

In [8]:
class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super(PatchEncoder, self).__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

In [9]:
def create_msa_block(n_heads, projection_dim, key_dim, dropout=0.1):
    main_input = keras.Input(shape=(49, projection_dim), name='MSA_Main_Input')

    heads_output = []
    for index in range(n_heads):

        head_name = 'MSA_Head_%i' %index

        # create the queries, keys, and values by passing the embedded patches 
        # through distinct linear projections 
        value_dense = layers.Dense(units=key_dim, name='Value_Dense_%i'%index)(main_input)
        query_dense = layers.Dense(units=key_dim, name='Query_Dense_%i'%index)(main_input)
        key_dense = layers.Dense(units=key_dim, name='Key_Dense_%i'%index)(main_input)

        # perform scaled dot-product attention on the queries, keys, and values
        # using tfjs compatible layers

        dot_layer_1 = layers.Dot(axes=1, name='Dot_Layer_1_%i'%index)([query_dense, key_dense])
        softmax_layer = layers.Softmax(name='Softmax_%i'%index)(dot_layer_1)
        dot_layer_2 = layers.Dot(axes=2, name='Dot_Layer_2_%i'%index)([value_dense, softmax_layer])

        heads_output.append(dot_layer_2)
    
    concatenate_layer = layers.Concatenate(axis=2, name='Concatenate_%i'%index)(heads_output)
    output = layers.Dense(units=key_dim, name='Final_MSA_Linear_%i'%index)(concatenate_layer)

    MSA_block = tf.keras.Model(inputs=main_input, outputs=output, name='MSA_Block')

    return MSA_block


In [10]:
def create_transformer_block(input_shape, transformer_layers):
    # input to this block is patch embeddings
    # input_shape is the shape of the patch embeddings. Must be a tuple
    input = tf.keras.Input(shape=input_shape, name='Transformer_Block_Input')

    for _ in range(transformer_layers):
        # MSA Block
        if _ == 0: 
            layer_norm_1 = layers.LayerNormalization(epsilon=1e-6, name='Layer_Norm_1_%i'%_)(input)
            MSA = create_msa_block(num_heads, projection_dim, projection_dim, dropout=0.1)(layer_norm_1)
            residual_connection_1 = layers.Add(name='Residual_Connection_%i'%_)([input, MSA])
        else: 
            layer_norm_1 = layers.LayerNormalization(epsilon=1e-6, name='Layer_Norm_1_%i'%_)(output)
            MSA = create_msa_block(num_heads, projection_dim, projection_dim, dropout=0.1)(layer_norm_1)
            residual_connection_1 = layers.Add(name='Residual_Connection_%i'%_)([output, MSA])

        # MLP Block
        dropout_rate = 0.1        
        layer_norm_2 = layers.LayerNormalization(epsilon=1e-6, name='Layer_Norm_2_%i'%_)(residual_connection_1)
        # create as many sections in transformer_units
        for index, units in enumerate(transformer_units):
            try: 
                mlp = layers.Dense(units, activation=tf.nn.gelu, name='Dense_%i'%index + '_%i'%_)(mlp)
            except: 
                mlp = layers.Dense(units, activation=tf.nn.gelu, name='Dense_%i'%index + '_%i'%_)(layer_norm_2)
                
            mlp = layers.Dropout(dropout_rate, name='Dropout_%i'%index + '_%i'%_)(mlp)

        output = layers.Add(name='Transformer_Encoder_Output_%i'%_)([layer_norm_1, mlp])

    transformer_encoder_block = keras.Model(inputs=input, outputs=output, name='transformer_encoder_block')
    return transformer_encoder_block

In [11]:
transformer_block = create_transformer_block((n_patches, projection_dim), 2)
transformer_block.summary()

Model: "transformer_encoder_block"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 Transformer_Block_Input (Input  [(None, 49, 8)]     0           []                               
 Layer)                                                                                           
                                                                                                  
 Layer_Norm_1_0 (LayerNormaliza  (None, 49, 8)       16          ['Transformer_Block_Input[0][0]']
 tion)                                                                                            
                                                                                                  
 MSA_Block (Functional)         (None, 49, 8)        2248        ['Layer_Norm_1_0[0][0]']         
                                                                          

In [12]:
msa = create_msa_block(2, 8, 4, dropout=0.1)
dummmy_projection = tf.ones([2, 49, 8])
msa.summary()
print(msa(dummmy_projection))

Model: "MSA_Block"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 MSA_Main_Input (InputLayer)    [(None, 49, 8)]      0           []                               
                                                                                                  
 Query_Dense_0 (Dense)          (None, 49, 4)        36          ['MSA_Main_Input[0][0]']         
                                                                                                  
 Key_Dense_0 (Dense)            (None, 49, 4)        36          ['MSA_Main_Input[0][0]']         
                                                                                                  
 Query_Dense_1 (Dense)          (None, 49, 4)        36          ['MSA_Main_Input[0][0]']         
                                                                                          

In [13]:
def mlp_classifier():
    model = tf.keras.Sequential()
    model.add(tf.keras.Input())
    model.add(layers.LayerNormalization(epsilon=1e-6))
    model.add(layers.Flatten())
    x = layers.Dropout(0.5)
    model.add(x)
    layers.Dense(mlp_head_units[0], activation=tf.nn.gelu)(x)
    layers.Dropout(0.5)(x)
    model.add(layers.Dense(num_classes))
    
    return model

In [14]:
# Custom
def create_vit_classifier_custom():

    model = tf.keras.Sequential([
        layers.Input(shape=(image_size, image_size, n_channels)),
        data_augmentation, 
        Patches(patch_size),
        PatchEncoder(n_patches, projection_dim),
        create_transformer_block((n_patches, projection_dim), transformer_layers),
        layers.LayerNormalization(epsilon=1e-6),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(mlp_head_units[0], activation=tf.nn.gelu),
        layers.Dropout(0.5),
        layers.Dense(num_classes)
    ])

    # # Augment data.
    # augmented = data_augmentation(inputs)
    # # Create patches.
    # patches = Patches(patch_size)(augmented)
    # # Encode patches.
    # encoded_patches = PatchEncoder(n_patches, projection_dim)(patches)

    # # Create multiple layers of the Transformer block.
    # transformer_block = create_transformer_block((n_patches, projection_dim), 2)(encoded_patches)
    
    # # Create a [batch_size, projection_dim] tensor.
    # representation = layers.LayerNormalization(epsilon=1e-6)(transformer_block)
    # representation = layers.Flatten()(representation)
    # representation = layers.Dropout(0.5)(representation)
    # # Add MLP.
    # features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    # # Classify outputs.
    # logits = layers.Dense(num_classes)(features)
    # # Create the Keras model.
    # model = tf.keras.Model(inputs=inputs, outputs=logits)

    return model

In [15]:
model = create_vit_classifier_custom()
model.summary()
dummy_data = tf.ones((10, 224, 224, 3))
print(model(dummy_data))

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 data_augmentation (Sequenti  (None, 224, 224, 3)      7         
 al)                                                             
                                                                 
 patches (Patches)           (None, None, 3072)        0         
                                                                 
 patch_encoder (PatchEncoder  (None, 49, 8)            24976     
 )                                                               
                                                                 
 transformer_encoder_block (  (None, 49, 8)            2856      
 Functional)                                                     
                                                                 
 layer_normalization (LayerN  (None, 49, 8)            16        
 ormalization)                                          

In [16]:
def split_keras_model(model, index):

    layer_input_1 = keras.layers.Input(model.layers[0].input_shape[1:])
    layer_input_2 = keras.layers.Input(model.layers[index].input_shape[1:])
    x = layer_input_1
    y = layer_input_2

    for layer in model.layers[0:index]:
        x = layer(x)

    for layer in model.layers[index:]:
        y = layer(y)

    model1 = keras.Model(inputs=layer_input_1, outputs=x)
    model2 = keras.Model(inputs=layer_input_2, outputs=y)
    
    print('-- -- -- -- -- FULL MODEL -- -- -- -- --')
    model.summary()

    print('-- -- -- - FEATURE EXTRACTION - -- -- --')
    model1.summary()

    print('-- -- -- -- - DENSE LAYERS - -- -- -- --')
    model2.summary()

    return (model1, model2)

In [17]:
submodel_1, submodel_2 = split_keras_model(model, 3)
model.save('../savedModels/ViT/tf/fullModel')
submodel_1.save('../savedModels/ViT/tf/split/subModel_1')
submodel_2.save('../savedModels/ViT/tf/split/subModel_2')

-- -- -- -- -- FULL MODEL -- -- -- -- --
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 data_augmentation (Sequenti  (None, 224, 224, 3)      7         
 al)                                                             
                                                                 
 patches (Patches)           (None, None, 3072)        0         
                                                                 
 patch_encoder (PatchEncoder  (None, 49, 8)            24976     
 )                                                               
                                                                 
 transformer_encoder_block (  (None, 49, 8)            2856      
 Functional)                                                     
                                                                 
 layer_normalization (LayerN  (None, 49, 8)            16        
 ormalization) 



































































































































































































































INFO:tensorflow:Assets written to: ../savedModels/ViT/tf/fullModel/assets


INFO:tensorflow:Assets written to: ../savedModels/ViT/tf/fullModel/assets






































































































































































































































INFO:tensorflow:Assets written to: ../savedModels/ViT/tf/split/subModel_1/assets


INFO:tensorflow:Assets written to: ../savedModels/ViT/tf/split/subModel_1/assets






INFO:tensorflow:Assets written to: ../savedModels/ViT/tf/split/subModel_2/assets


INFO:tensorflow:Assets written to: ../savedModels/ViT/tf/split/subModel_2/assets


In [19]:
try: 
    tfjs.converters.save_keras_model(model, '../savedModels/ViT/tfjs/fullModel', )
except:
    try:
        tfjs.converters.save_keras_model(submodel_1, '../savedModels/ViT/tfjs/subModel_1', )
    except: 
        try:
            tfjs.converters.save_keras_model(submodel_2, '../savedModels/ViT/tfjs/subModel_2', )
        except: 
            print('All Roads Fail')

model.summary()
submodel_1.summary()
submodel_2.summary()
        













Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 data_augmentation (Sequenti  (None, 224, 224, 3)      7         
 al)                                                             
                                                                 
 patches (Patches)           (None, None, 3072)        0         
                                                                 
 patch_encoder (PatchEncoder  (None, 49, 8)            24976     
 )                                                               
                                                                 
 transformer_encoder_block (  (None, 49, 8)            2856      
 Functional)                                                     
                                                                 
 layer_normalization (LayerN  (None, 49, 8)            16        
 ormalization)                                          