# **Imports**

In [None]:
import os
import shutil
import json
from contextlib import redirect_stdout

import matplotlib.pyplot as plt
import numpy as np

import tensorflow as tf

In [None]:
# Enable on-demand GPU memory allocation instead of reserving all memory up
# front. Every visible GPU is configured, and a machine without any GPU simply
# skips the loop instead of failing on `gpu_devices[0]`.
gpu_devices = tf.config.list_physical_devices('GPU')
print(gpu_devices)
for gpu_device in gpu_devices:
    tf.config.experimental.set_memory_growth(gpu_device, True)

# **Dataset**

In [None]:
# Root directory of the dataset; the "train", "test" and "val" split
# directories are resolved relative to it.
dataset_path = "./speech_commands_v2_spectrograms/"
# Root directory for everything this notebook writes to disk.
output_path = "./save_files/"
# Sub-directory of `output_path` used for this model's files.
save_folder = "vit"

In [None]:
# Resolve the directory of each dataset split underneath `dataset_path`.
train_dir, test_dir, val_dir = (
    os.path.join(dataset_path, split_name)
    for split_name in ("train", "test", "val")
)

In [None]:
# Each sub-directory of the training split is treated as one label; the names
# are kept in sorted order.
label_names = np.array(sorted(
    entry
    for entry in os.listdir(train_dir)
    if os.path.isdir(os.path.join(train_dir, entry))
))
num_labels = len(label_names)

print(num_labels, "labels:\n", label_names)

In [None]:
IMG_SIZE = (128, 101)
BATCH_SIZE = 32

# Images are loaded as single-channel (grayscale), hence the trailing 1.
input_shape = (*IMG_SIZE, 1)

# Loader options shared by all three splits.
_common_ds_options = dict(
    color_mode="grayscale",
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE,
)

train_ds = tf.keras.utils.image_dataset_from_directory(
    train_dir, shuffle=True, **_common_ds_options
)

val_ds = tf.keras.utils.image_dataset_from_directory(
    val_dir, shuffle=True, **_common_ds_options
)

# The test split is the only one loaded without shuffling.
test_ds = tf.keras.utils.image_dataset_from_directory(
    test_dir, shuffle=False, **_common_ds_options
)

In [None]:
# Multiplies pixel values by 1/255.
rescale = tf.keras.layers.Rescaling(scale=1./255)


def rescale_ds(x, y):
    """Apply `rescale` to the images `x`; the labels `y` pass through unchanged."""
    scaled_images = rescale(x)
    return scaled_images, y


# Apply the rescaling to every split.
train_ds, test_ds, val_ds = (
    split_ds.map(rescale_ds) for split_ds in (train_ds, test_ds, val_ds)
)

In [None]:
# Prefetch the training and validation datasets, letting tf.data choose the
# buffer size (AUTOTUNE).
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
val_ds = val_ds.prefetch(tf.data.AUTOTUNE)
# Prefetching of the test dataset is disabled (left commented out).
# test_ds = test_ds.prefetch(tf.data.AUTOTUNE)

# **Save Helpers**

In [None]:
# Directory that receives everything saved for this model; it is created if
# it does not exist yet.
model_save_loc =  os.path.join(output_path, save_folder)
print("Model save location:", model_save_loc)
os.makedirs(model_save_loc, exist_ok=True)

In [None]:
def get_prev_save_file_name(model_save_loc):
    """Return the name of the most recent per-epoch weights file.

    Only entries of `model_save_loc` carrying the ".h5" extension are
    considered, and the "best" copies (names ending in "best.h5") are
    excluded. Requiring the dot prevents unrelated names that merely end in
    the characters "h5" from being mistaken for a checkpoint.

    Args:
        model_save_loc: Directory that contains the saved weight files.

    Returns:
        The file name (not the full path) that sorts last lexicographically,
        or "" when the directory holds no matching file.
    """
    epoch_files = [
        name
        for name in os.listdir(model_save_loc)
        if name.endswith(".h5") and not name.endswith("best.h5")
    ]
    # The checkpoint callback zero-pads the epoch number in the file name, so
    # the lexicographic maximum is the latest epoch.
    return max(epoch_files, default="")


def get_prev_best_save_file_name(model_save_loc):
    """Return the name of the latest "best" weights file in `model_save_loc`.

    A file counts as a best checkpoint when its name ends in "best.h5".

    Args:
        model_save_loc: Directory that contains the saved weight files.

    Returns:
        The matching file name that comes first in reverse lexicographic
        order, or "" when there is none.
    """
    candidates = sorted(
        (entry for entry in os.listdir(model_save_loc) if entry.endswith("best.h5")),
        reverse=True,
    )
    if not candidates:
        return ""
    return candidates[0]



class CustomModelCheckPoint(tf.keras.callbacks.Callback):
    """Checkpoint callback that keeps one epoch file and one best file on disk.

    After every epoch the model weights are written to
    "<model_name>_<epoch>-<val_accuracy>.h5" and the file of the previous
    epoch is deleted. Whenever the validation accuracy exceeds the best value
    seen so far, the new file is also copied to a "..._best.h5" file and the
    previous best copy is deleted.

    Args:
        model_save_loc: Directory in which the weight files are stored.
        prev_save_file: Name of an existing epoch file to delete once the
            first new epoch file has been written ("" for none).
        prev_best_file: Name of an existing best file to delete once a better
            epoch has been saved ("" for none).
        prev_best_acc: Validation accuracy a new epoch has to exceed to count
            as the new best.
        model_name: Prefix of the generated file names.
        **kargs: Forwarded to `tf.keras.callbacks.Callback`.
    """

    def __init__(self, model_save_loc, prev_save_file="", prev_best_file="", prev_best_acc=0, model_name="epoch", **kargs):
        super().__init__(**kargs)
        self.model_save_loc = model_save_loc
        self.model_name = model_name
        self.prev_save_file = prev_save_file
        self.prev_best_file = prev_best_file
        self.prev_best_acc = prev_best_acc

    def _discard(self, filename):
        """Blank and then delete `filename` inside the save directory."""
        target = os.path.join(self.model_save_loc, filename)
        # Truncate the file to zero bytes before unlinking it, exactly as the
        # original inline code did.
        with open(target, 'w'):
            pass
        os.remove(target)

    def on_epoch_end(self, epoch, logs=None):
        """Save this epoch's weights and rotate the epoch/best files.

        Args:
            epoch: Zero-based epoch index supplied by Keras; file names use
                `epoch + 1`.
            logs: Metrics of the finished epoch; "val_accuracy" is read.
        """
        val_acc = (logs or {}).get("val_accuracy")
        if val_acc is None:
            # Without a validation accuracy neither the file name nor the
            # best-model comparison can be produced, so skip this epoch
            # instead of failing inside the format string.
            print("val_accuracy missing from logs; checkpoint skipped for epoch", epoch + 1)
            return

        filename = f"{self.model_name}_{(epoch+1):03d}-{val_acc:.3f}.h5"
        self.model.save_weights(os.path.join(self.model_save_loc, filename))  # save the model

        # Remove the previous epoch's file, but never the one just written.
        if self.prev_save_file and self.prev_save_file != filename:
            self._discard(self.prev_save_file)
        self.prev_save_file = filename

        # Keep a separate copy of the best model so far.
        if val_acc > self.prev_best_acc:
            best_filename = filename[:-3] + "_best.h5"
            shutil.copy(
                os.path.join(self.model_save_loc, filename),
                os.path.join(self.model_save_loc, best_filename),
            )

            if self.prev_best_file and self.prev_best_file != best_filename:
                self._discard(self.prev_best_file)

            self.prev_best_acc = val_acc
            self.prev_best_file = best_filename

# **Model**

## **Layers**

In [None]:
def generate_patches(inputs, patch_size, patch_overlap=0, hidden_size=None):
    """Embed (optionally overlapping) patches of `inputs` as a token sequence.

    A Conv2D layer named 'embedding' with kernel size `patch_size`, stride
    `patch_size - patch_overlap` and 'valid' padding produces one
    `hidden_size`-dimensional vector per patch; its spatial grid is then
    flattened into a sequence.

    Args:
        inputs: 4-D tensor fed to the convolution.
        patch_size: Side length of a patch (the convolution kernel size).
        patch_overlap: Amount by which neighbouring patches overlap.
        hidden_size: Embedding size per patch; defaults to the squared stride.

    Returns:
        Tensor reshaped to (-1, number_of_patches, hidden_size).
    """
    stride = patch_size - patch_overlap
    if hidden_size is None:
        hidden_size = stride * stride

    embedding_conv = tf.keras.layers.Conv2D(
        filters=hidden_size,
        kernel_size=patch_size,
        strides=stride,
        padding='valid',
        name='embedding',
    )
    patch_grid = embedding_conv(inputs)

    # The two spatial dimensions of the conv output give the patch grid size.
    grid_dim_1, grid_dim_2 = patch_grid.shape[1], patch_grid.shape[2]
    num_patches = grid_dim_1 * grid_dim_2
    return tf.reshape(patch_grid, [-1, num_patches, hidden_size])

In [None]:
# taken from https://github.com/tensorflow/models/blob/master/official/vision/modeling/backbones/vit.py
@tf.keras.utils.register_keras_serializable()
class TokenLayer(tf.keras.layers.Layer):
    """Prepend a zero-initialised token weight to axis 1 of the inputs."""

    def build(self, inputs_shape):
        # Keyword arguments are used because the positional order of
        # `add_weight` is not the same in every Keras version.
        self.cls = self.add_weight(
            name='cls',
            shape=(1, 1, inputs_shape[-1]),
            initializer='zeros',
        )

    def call(self, inputs):
        """Return `inputs` with the token concatenated in front along axis 1."""
        cls = tf.cast(self.cls, inputs.dtype)
        # Adding zeros shaped like the first sequence element broadcasts the
        # single token across the batch dimension.
        cls = cls + tf.zeros_like(inputs[:, 0:1])  # A hacky way to tile.
        x = tf.concat([cls, inputs], axis=1)
        return x

    def get_config(self):
        """Return the base layer config; this layer adds no options of its own."""
        config = super().get_config()
        return config

    @classmethod
    def from_config(cls, config):
        """Re-create the layer from a config produced by `get_config`."""
        return cls(**config)

In [None]:
@tf.keras.utils.register_keras_serializable()
class AddPositionEmbs(tf.keras.layers.Layer):
    """Adds a positional-embedding weight to the inputs."""

    def build(self, inputs_shape):
        # One embedding vector per sequence position, shared across the batch.
        pos_emb_shape = (1, inputs_shape[1], inputs_shape[2])
        # Keyword arguments are used because the positional order of
        # `add_weight` is not the same in every Keras version.
        self.pos_embedding = self.add_weight(
            name='pos_embedding',
            shape=pos_emb_shape,
            initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
        )

    def call(self, inputs, inputs_positions=None):
        """Return `inputs` plus the positional embedding.

        `inputs_positions` is accepted but not used by this implementation.
        """
        # inputs.shape is (batch_size, seq_len, emb_dim).
        pos_embedding = tf.cast(self.pos_embedding, inputs.dtype)

        return inputs + pos_embedding

    def get_config(self):
        """Return the base layer config; this layer adds no options of its own."""
        config = super().get_config()
        return config

    @classmethod
    def from_config(cls, config):
        """Re-create the layer from a config produced by `get_config`."""
        return cls(**config)

In [None]:
def mlp_block(inputs, mlp_dim, dropout_rate, activation=tf.nn.gelu):
    """Two-layer feed-forward block: widen to `mlp_dim`, project back.

    Each Dense layer is followed by Dropout when `dropout_rate` is positive.
    The second Dense layer restores the last dimension of `inputs`.

    NOTE(review): `activation` is applied to the output projection as well as
    the hidden layer - confirm this is intended.
    """
    use_dropout = dropout_rate > 0
    x = inputs
    for width in (mlp_dim, inputs.shape[-1]):
        x = tf.keras.layers.Dense(units=width, activation=activation)(x)
        if use_dropout:
            x = tf.keras.layers.Dropout(rate=dropout_rate)(x)
    return x

In [None]:
def Encoder1Dblock(inputs, num_heads, mlp_dim, dropout_rate, attention_dropout_rate):
    """One encoder block: normalised self-attention and MLP, each with a residual.

    Args:
        inputs: Sequence tensor entering the block.
        num_heads: Number of attention heads.
        mlp_dim: Hidden width handed to `mlp_block`.
        dropout_rate: Dropout rate handed to `mlp_block`.
        attention_dropout_rate: Dropout rate of the attention layer.

    Returns:
        The block output (the MLP output added to its own input).
    """
    normed_inputs = tf.keras.layers.LayerNormalization(dtype=inputs.dtype)(inputs)
    self_attention = tf.keras.layers.MultiHeadAttention(
        num_heads=num_heads,
        key_dim=inputs.shape[-1],
        dropout=attention_dropout_rate,
    )
    # Query and value are the same tensor: multi-head self attention.
    attended = self_attention(normed_inputs, normed_inputs)
    # First residual connection, around the attention sub-layer.
    attention_residual = tf.keras.layers.Add()([attended, inputs])

    normed_residual = tf.keras.layers.LayerNormalization(
        dtype=attention_residual.dtype
    )(attention_residual)
    mlp_output = mlp_block(normed_residual, mlp_dim, dropout_rate)
    # Second residual connection, around the MLP sub-layer.
    return tf.keras.layers.Add()([mlp_output, attention_residual])

In [None]:
def Encoder(inputs, num_layers, mlp_dim, num_heads, dropout_rate, attention_dropout_rate):
    """Stack of `num_layers` encoder blocks on position-embedded inputs.

    Positional embeddings are added first (followed by Dropout when
    `dropout_rate` is positive), then the blocks run in sequence, and a final
    LayerNormalization named 'encoder_norm' produces the result.
    """
    hidden = AddPositionEmbs(name='posembed_input')(inputs)
    if dropout_rate > 0:
        hidden = tf.keras.layers.Dropout(rate=dropout_rate)(hidden)

    remaining_blocks = num_layers
    while remaining_blocks > 0:
        hidden = Encoder1Dblock(
            hidden, num_heads, mlp_dim, dropout_rate, attention_dropout_rate
        )
        remaining_blocks -= 1

    return tf.keras.layers.LayerNormalization(name='encoder_norm')(hidden)

## **Create Model**

In [None]:
def vision_transformer(
        input_shape,
        classes,
        patch_size=16,
        patch_overlap=8,
        hidden_size=64,
        num_transformer_layers=5,
        num_heads=10,
        mlp_dim=128,
        dropout_rate=0.4,
        attention_dropout_rate=0.01
    ):
    """Assemble the classifier: patches -> token -> encoder -> softmax head.

    Args:
        input_shape: Shape of one input sample (without the batch dimension).
        classes: Number of output classes.
        patch_size: Patch side length passed to `generate_patches`.
        patch_overlap: Patch overlap passed to `generate_patches`.
        hidden_size: Embedding size of every token.
        num_transformer_layers: Number of encoder blocks.
        num_heads: Attention heads per encoder block.
        mlp_dim: Hidden width of the MLP in every encoder block.
        dropout_rate: Dropout rate used in the encoder.
        attention_dropout_rate: Dropout rate inside the attention layers.

    Returns:
        A `tf.keras.Model` mapping the input to softmax class probabilities.
    """
    model_input = tf.keras.layers.Input(shape=input_shape)

    # Turn the input into a sequence of patch embeddings.
    patch_tokens = generate_patches(
        model_input,
        patch_size,
        patch_overlap,
        hidden_size
    )

    # Put the token in front of the patch sequence.
    tokens = TokenLayer(name='cls')(patch_tokens)

    # Run the transformer encoder.
    encoded = Encoder(
        tokens,
        num_transformer_layers,
        mlp_dim,
        num_heads,
        dropout_rate,
        attention_dropout_rate,
    )

    # Only the first sequence position (the prepended token) feeds the head.
    cls_output = encoded[:, 0]

    classifier = tf.keras.layers.Dense(classes, name='predictions', activation='softmax')
    return tf.keras.Model(inputs=model_input, outputs=classifier(cls_output))

In [None]:
# Hyper-parameters of the model instance trained in this notebook.
vit_hyper_params = dict(
    patch_size=16,
    patch_overlap=8,
    hidden_size=64,
    num_transformer_layers=12,
    num_heads=12,
    mlp_dim=256,
    dropout_rate=0.2,
    attention_dropout_rate=0.2,
)

model = vision_transformer(input_shape, num_labels, **vit_hyper_params)

In [None]:
# Learning rate handed to the Adam optimizer.
LEARNING_RATE = 0.001

# compile model
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
# Sparse categorical cross-entropy loss, with accuracy tracked as the metric.
model.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

In [None]:
# Print the summary of the compiled model.
model.summary()

In [None]:
# Write the model summary to a text file once; an existing file is kept.
summary_path = os.path.join(model_save_loc, "model_summary.txt")
if not os.path.isfile(summary_path):
    # `model.summary()` prints to stdout, so stdout is redirected to the file.
    with open(summary_path, 'w') as f, redirect_stdout(f):
        model.summary()
    print("Saved model summary ...")

In [None]:
# Write the model's JSON description once; an existing file is kept.
model_json_path = os.path.join(model_save_loc, "model.json")
if not os.path.isfile(model_json_path):
    # Serialise before opening the file so a failure leaves no empty file.
    model_json = model.to_json()
    with open(model_json_path, "w") as json_file:
        json_file.write(model_json)
    print("Saved model json ...")

# **Train**

In [None]:
# Resume from the files of an earlier run, if there are any.
prev_save_file = get_prev_save_file_name(model_save_loc)
prev_best_file = get_prev_best_save_file_name(model_save_loc)
prev_epoch = 0
prev_best_acc = 0

# Checkpoint names look like "<name>_<epoch>-<val_acc>.h5", with "_best"
# inserted before the extension for the best copy. The fields are split on
# their separators rather than sliced at fixed offsets, so epoch numbers with
# more than three digits are still parsed correctly.
if prev_save_file:
    # An epoch file can exist without a best file; only parse the best
    # accuracy when there is a name to parse.
    if prev_best_file:
        print("Last best save file: ", prev_best_file)
        best_stem = prev_best_file[:-len("_best.h5")]
        prev_best_acc = float(best_stem.rpartition("-")[2])
        print("Last best acc: ", prev_best_acc)

    print("Last save file: ", prev_save_file)
    epoch_and_acc = prev_save_file[:-len(".h5")].rpartition("_")[2]
    prev_epoch = int(epoch_and_acc.partition("-")[0])
    print("prev epoch:", prev_epoch)

    print("Loading weights...")
    load_status = model.load_weights(os.path.join(model_save_loc, prev_save_file))
    # load_status.assert_consumed()

In [None]:
# Checkpoint callback, seeded with the file names and best accuracy found on
# disk so that the file rotation continues where the previous run stopped.
# NOTE(review): the file prefix is "vic" while `save_folder` is "vit" -
# confirm this is intended; changing it alters the checkpoint file names.
custom_checkpoint_callback = CustomModelCheckPoint(
    model_save_loc,
    prev_save_file=prev_save_file,
    prev_best_file=prev_best_file,
    prev_best_acc=prev_best_acc,
    model_name="vic",
)

In [None]:
# Log the per-epoch results to "logs.csv" in the save directory; append=True
# is set so an existing log file is extended.
csv_logger_callback = tf.keras.callbacks.CSVLogger(
        os.path.join(model_save_loc, "logs.csv"), separator=',', append=True
    )

In [None]:
num_epochs = 100
early_stopping_patience = 10

# Early stopping monitors the validation loss with the patience set above.
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=early_stopping_patience,
    verbose=1,
)
training_callbacks = [
    custom_checkpoint_callback,
    csv_logger_callback,
    early_stopping_callback,
]

# `initial_epoch` makes a resumed run continue its epoch numbering.
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=num_epochs,
    initial_epoch=prev_epoch,
    callbacks=training_callbacks,
)

# **Test**

In [None]:
# Evaluate the model on the test dataset; return_dict=True is requested so
# the results come back as a dict.
results = model.evaluate(test_ds, verbose=1, return_dict=True)
print(results)