In [None]:
import sys
from pathlib import Path

# Make the parent of the working directory importable (presumably where the
# local `layers` package imported later in this notebook lives).
current_path = Path.cwd()
root = current_path.parent
# Re-running this cell must not pile up duplicate entries on sys.path.
if str(root) not in sys.path:
    sys.path.append(str(root))

import tensorflow as tf

import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)

# Expose only the first detected GPU to TensorFlow, and turn on memory growth
# for every detected GPU.
gpus = tf.config.list_physical_devices('GPU')
if len(gpus) > 0:
    primary_gpu = gpus[0]
    tf.config.set_visible_devices(primary_gpu, 'GPU')
    for device in gpus:
        tf.config.experimental.set_memory_growth(device, True)
print(gpus)

In [2]:
from tensorflow import keras
from tensorflow.keras.layers import Normalization, RandomFlip, RandomRotation, RandomZoom, Dense, Dropout, Embedding

# Load MNIST through the Keras dataset helper; the result is unpacked into
# (inputs, labels) pairs for the training and the test split.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

In [3]:
# Import from `tensorflow.keras` like the other layer imports in this notebook:
# `keras` below is `tensorflow.keras`, and mixing in layers from a separately
# installed `keras` package can fail when the two versions differ.
from tensorflow.keras.layers import Reshape, Rescaling

# Preprocessing / augmentation pipeline for 28x28 inputs reshaped to one channel.
# NOTE(review): nothing else in the visible notebook calls this pipeline —
# confirm whether it is meant to be applied before training.
# NOTE(review): `Normalization()` is never adapted here and `Rescaling(1./255)`
# is applied after it — confirm the intended scaling.
# NOTE(review): `input_shape` is passed to layers that are not first in the
# stack; presumably it has no effect there — TODO confirm.
data_augmentation = keras.Sequential(
    [
        Normalization(),
        Reshape((28, 28, 1), input_shape=(28, 28)),
        RandomFlip("horizontal"),
        RandomRotation(factor=0.02),
        RandomZoom(height_factor=0.2, width_factor=0.2),
        Rescaling(1./255, input_shape=(28, 28, 1))
    ],
    name="data_augmentation",
)

In [4]:
import numpy as np
from layers.selective_attention import SelectiveAttention

def positional_encoding_2D(length, width, depth):
    """Build a fixed 2-D sinusoidal positional encoding.

    Every grid position gets `depth` channels: the first half is a sin/cos
    encoding of its index along the first axis, the second half is the same
    encoding of its index along the second axis.

    Args:
        length: Number of positions along the first axis of the result.
        width: Number of positions along the second axis of the result.
        depth: Channels per position. Must be a multiple of 4 (two axes, each
            split into a sin half and a cos half).

    Returns:
        A float32 tensor of shape (length, width, depth).

    Raises:
        ValueError: If `depth` is not a multiple of 4; the encoding would
            otherwise come out with a channel count different from `depth`.
    """
    if depth % 4 != 0:
        raise ValueError(
            f"depth must be a multiple of 4 for a 2-D sin/cos encoding, got {depth}"
        )
    quarter = int(depth // 4)
    half = 2 * quarter

    # Frequencies shared by both axes.
    depths = np.arange(quarter)[np.newaxis, :] / quarter
    angle_rates = 1 / (10000 ** depths)

    def _axis_encoding(n_positions):
        # (n_positions, half): sin block followed by cos block.
        angle_rads = np.arange(n_positions)[:, np.newaxis] * angle_rates
        return np.concatenate([np.sin(angle_rads), np.cos(angle_rads)], axis=-1)

    first_axis = _axis_encoding(length)
    second_axis = _axis_encoding(width)

    # The result is laid out as (length, width, depth) so that it lines up with
    # the `[:length, :width]` slice taken by PositionalEmbedding2D. For square
    # grids the values are the same as with the former (width, length) layout.
    grid_shape = (length, width, half)
    first_grid = np.broadcast_to(first_axis[:, np.newaxis, :], grid_shape)
    second_grid = np.broadcast_to(second_axis[np.newaxis, :, :], grid_shape)
    pos_encoding = np.concatenate([first_grid, second_grid], axis=-1)

    return tf.cast(pos_encoding, dtype=tf.float32)

class PositionalEmbedding2D(tf.keras.layers.Layer):
    """Embeds integer inputs and adds a fixed 2-D positional encoding.

    The embedding output is scaled by sqrt(d_model) before the encoding from
    `positional_encoding_2D` is added.

    NOTE(review): the embedding table is sized `length * width` and index 0 is
    masked (`mask_zero=True`) — confirm this matches the value range of the
    inputs fed to this layer.
    """

    def __init__(self, length, width, d_model):
        super().__init__()
        self._length = length
        self._width = width
        self._input_dim = length * width
        self._d_model = d_model
        self._embedding = tf.keras.layers.Embedding(
            input_dim=self._input_dim,
            output_dim=self._d_model,
            mask_zero=True,
        )
        self._pos_encoding = positional_encoding_2D(
            length=self._length,
            width=self._width,
            depth=self._d_model,
        )

    def call(self, x):
        embedded = self._embedding(x)
        scale = tf.math.sqrt(tf.cast(self._d_model, tf.float32))
        scaled = embedded * scale
        # Leading axis added so the encoding broadcasts across the first axis of x.
        grid_encoding = self._pos_encoding[tf.newaxis, :self._length, :self._width, :]
        return scaled + grid_encoding

class BaseAttention(tf.keras.layers.Layer):
    """Common parts of an attention sub-layer.

    Creates a `SelectiveAttention` layer together with the `Add` and
    `LayerNormalization` layers; this class defines no `call` of its own.
    """

    def __init__(self, **kwargs):
        # Every keyword argument configures SelectiveAttention, not this wrapper.
        super().__init__()
        keras_layers = tf.keras.layers
        self._sa = SelectiveAttention(**kwargs)
        self._layer_norm = keras_layers.LayerNormalization()
        self._add = keras_layers.Add()

# Automatic attention, triggered by stimuli.
class BottomUpAttention(BaseAttention):
    """Self-attention sub-layer: the input serves as query, key and value."""

    def call(self, x):
        attended, scores = self._sa(
            query=x,
            key=x,
            value=x,
            return_attention_scores=True,
        )

        # Residual sum followed by layer normalisation.
        normalized = self._layer_norm(self._add([x, attended]))
        # Keep the scores from the most recent call so callers can read them.
        self._last_att_scores = scores
        return normalized
    
class FeedForward(tf.keras.layers.Layer):
    """Dense(relu) + Dropout sub-layer with a residual sum and layer normalisation.

    Args:
        d_model: Number of units of the dense layer.
        dropout_rate: Rate passed to the dropout layer.
    """

    def __init__(self, d_model, dropout_rate=0.1):
        super().__init__()
        keras_layers = tf.keras.layers
        projection = keras_layers.Dense(d_model, activation="relu")
        dropout = keras_layers.Dropout(dropout_rate)
        self._seq = tf.keras.Sequential([projection, dropout])
        self._add = keras_layers.Add()
        self._layer_norm = keras_layers.LayerNormalization()

    def call(self, x):
        transformed = self._seq(x)
        residual = self._add([x, transformed])
        return self._layer_norm(residual)
    
class FinalLayer(tf.keras.layers.Layer):
    """Pools the inner axes, applies a residual dense block and a final dense layer.

    Args:
        hidden_units: Accepted but not used by the current implementation.
        d_model: Number of units of the residual dense layer.
        hidden_rate: Accepted but not used by the current implementation.
        num_cat: Number of units of the final dense layer.
        dropout_rate: Rate passed to the dropout layer.
    """

    def __init__(self, hidden_units, d_model, hidden_rate, num_cat, dropout_rate=0.1):
        super().__init__()
        keras_layers = tf.keras.layers
        # Created but not referenced by `call`.
        self._flatten = keras_layers.Flatten()
        self._seq = tf.keras.Sequential([
            keras_layers.Dense(d_model, activation="relu"),
            keras_layers.Dropout(dropout_rate),
        ])
        self._add = keras_layers.Add()
        self._layer_norm = keras_layers.LayerNormalization()
        self._final = keras_layers.Dense(num_cat)

    def call(self, x):
        # Average over every axis strictly between axis 0 and the last axis.
        inner_axes = range(1, x.shape.rank - 1)
        pooled = tf.reduce_mean(x, axis=inner_axes, keepdims=False)
        residual = self._add([pooled, self._seq(pooled)])
        normalized = self._layer_norm(residual)
        return self._final(normalized)
    
class AttentionLayer(tf.keras.layers.Layer):
    """Bottom-up self-attention followed by a feed-forward sub-layer.

    Args:
        d_model: Passed as `key_dim` to the attention sub-layer and as
            `d_model` to the feed-forward sub-layer.
        dropout_rate: Dropout rate used by both sub-layers.
        **kwargs: Options for the base Keras layer (e.g. `name`); they are
            forwarded to it rather than silently discarded.
    """

    def __init__(self, d_model, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self._bottom_up_attention = BottomUpAttention(
            key_dim=d_model,
            dropout=dropout_rate
        )

        self._ffn = FeedForward(d_model=d_model, dropout_rate=dropout_rate)

    def call(self, x):
        x = self._bottom_up_attention(x)
        x = self._ffn(x)
        # Mirror the attention scores of the latest call on this layer.
        self._last_att_scores = self._bottom_up_attention._last_att_scores
        return x
    
    
class Transformer(tf.keras.Model):
    """Stack of attention layers on top of a 2-D positional embedding.

    Args:
        num_layers: Number of `AttentionLayer` blocks.
        d_model: Channel size used by the embedding and every sub-layer.
        length: Passed to `PositionalEmbedding2D` as `length`.
        width: Passed to `PositionalEmbedding2D` as `width`.
        hidden_units: Forwarded to `FinalLayer`.
        num_cat: Number of output units (logits) of the final layer.
        hidden_rate: Forwarded to `FinalLayer`.
        dropout_rate: Dropout rate used after the embedding and in sub-layers.
    """

    def __init__(self,
                 *,
                 num_layers,
                 d_model,
                 length,
                 width,
                 hidden_units,
                 num_cat,
                 hidden_rate=0.5,
                 dropout_rate=0.1
                ):
        super().__init__()
        self._num_layers = num_layers
        self._hidden_units = hidden_units
        self._num_cat = num_cat
        self._pos_embedding  = PositionalEmbedding2D(length, width, d_model)
        self._attention_layer = [
            AttentionLayer(
                d_model=d_model,
                dropout_rate=dropout_rate
            )
            for _ in range(num_layers)
        ]
        self._dropout = tf.keras.layers.Dropout(dropout_rate)
        self._final_layer = FinalLayer(hidden_units=hidden_units, d_model=d_model,
                                       hidden_rate=hidden_rate, num_cat=num_cat, dropout_rate=dropout_rate)
        # Attention scores of the last attention layer from the latest call.
        self._last_att_score = None

    def call(self, x):
        x = self._pos_embedding(x)
        x = self._dropout(x)

        for attention_layer in self._attention_layer:
            x = attention_layer(x)

        # With zero attention layers there are no scores to expose; indexing
        # [-1] on the empty list would raise IndexError.
        if self._attention_layer:
            self._last_att_score = self._attention_layer[-1]._last_att_scores

        logits = self._final_layer(x)
        return logits


In [5]:
# Model hyperparameters.
num_layers = 12
d_model = 16          # positional_encoding_2D splits this into 4 equal parts
length = 28
width = 28
hidden_units = 1024   # forwarded to FinalLayer, which does not use it
num_cat = 10
hidden_rate = 0.5     # forwarded to FinalLayer, which does not use it
dropout_rate = 0.1

model = Transformer(
    num_layers=num_layers,
    d_model=d_model,
    length=length,
    width=width,
    hidden_units=hidden_units,
    num_cat=num_cat,
    hidden_rate=hidden_rate,
    dropout_rate=dropout_rate
)

optimizer = tf.keras.optimizers.Adam(learning_rate=0.005, beta_1=0.9, beta_2=0.98,
                                     epsilon=1e-9)

# `compile` configures the model in place and returns None, so its result is not
# stored; `history` is only defined by the later `model.fit` call.
# NOTE(review): run_eagerly=True is presumably needed because the layers keep
# attention scores on Python attributes during `call` — confirm.
model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
    run_eagerly=True
)

In [6]:
class CustomCallback(tf.keras.callbacks.Callback):
    """Multiplies the learning rate by 1.5 when training accuracy stagnates.

    After `patience` consecutive epochs without a new best `"accuracy"` value
    in the epoch logs, the optimizer's learning rate is scaled by 1.5 and the
    stagnation counter restarts.

    NOTE(review): this raises the learning rate on stagnation, while the
    ReduceLROnPlateau callback used alongside it lowers it — confirm that the
    combination is intended.

    Args:
        patience: Number of non-improving epochs tolerated before the
            learning rate is changed.
    """

    def __init__(self, patience=5):
        super().__init__()
        self.patience = patience
        self.best_accuracy = 0
        self.stagnation_counter = 0

    def on_epoch_end(self, epoch, logs=None):
        # Keras only invokes the correctly spelled hook name; `logs` may be
        # None or may lack the metric, in which case there is nothing to do.
        current_accuracy = (logs or {}).get("accuracy")
        if current_accuracy is None:
            return

        if current_accuracy > self.best_accuracy:
            self.best_accuracy = current_accuracy
            self.stagnation_counter = 0
        else:
            self.stagnation_counter += 1

        if self.stagnation_counter >= self.patience:
            optimizer = self.model.optimizer
            # `learning_rate` is used instead of the `lr` alias, which is not
            # available on every Keras version.
            current_lr = float(np.asarray(optimizer.learning_rate))
            optimizer.learning_rate = current_lr * 1.5
            self.stagnation_counter = 0

    # Alias kept so the original (misspelled) method name still resolves.
    on_epooch_end = on_epoch_end

checkpoint_filepath = "./tmp/vsa12.weights.h5"

# Stop after 10 epochs without val_accuracy improvement and restore the best weights.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_accuracy",
    patience=10,
    restore_best_weights=True,
)

# Halve the learning rate after 5 epochs without val_accuracy improvement.
lr_on_plateau = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_accuracy',
    factor=0.5,
    patience=5,
    min_lr=1e-6,
)

# Write weights only, and only when val_accuracy improves.
best_checkpoint = keras.callbacks.ModelCheckpoint(
    checkpoint_filepath,
    monitor="val_accuracy",
    save_best_only=True,
    save_weights_only=True,
)

stagnation_callback = CustomCallback(patience=5)

callback = [early_stopping, lr_on_plateau, best_checkpoint, stagnation_callback]

In [None]:

batch_size, num_epochs = 256, 50

# Train with 10% of the training data held out for validation.
history = model.fit(
    x=x_train,
    y=y_train,
    batch_size=batch_size,
    epochs=num_epochs,
    validation_split=0.1,
    callbacks=callback,
)

# Evaluate on the test split using the weights stored at the checkpoint path.
model.load_weights(checkpoint_filepath)
test_loss, accuracy = model.evaluate(x_test, y_test)
print(f"Test accuracy:{round(accuracy * 100, 2)}%")

In [None]:
def plot_history(history, item):
    """Plot a metric and its validation counterpart against the epoch index.

    Args:
        history: Object whose `history` attribute maps metric names to
            per-epoch values.
        item: Metric name; the key `"val_" + item` must be present as well.
    """
    for series_name in (item, "val_" + item):
        plt.plot(history.history[series_name], label=series_name)

    heading = "Train and Validation {} Over Epochs".format(item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title(heading, fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()

# One figure per tracked metric, each with its validation curve.
for metric in ("loss", "accuracy"):
    plot_history(history, metric)