In [1]:
# Enable IPython's autoreload extension. Mode 2 re-imports all modules before
# each cell executes, so edits to imported project modules take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import pandas as pd
from models import *

In [4]:
# Load MNIST and prepare model inputs: 3-channel images rescaled by 1/255 and
# one-hot encoded labels.
DATASET_NAME = 'MNIST'
num_classes = 10

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Append a channel axis and repeat it three times so each grayscale digit is
# presented as a 3-channel image, then divide by 255 to rescale pixel values.
# NOTE(review): the division presumably promotes the arrays to float64 —
# consider casting to float32 if memory becomes a concern.
train_images = np.repeat(np.expand_dims(x_train, -1), 3, -1) / 255
val_images = np.repeat(np.expand_dims(x_test, -1), 3, -1) / 255

# One-hot encode by indexing rows of an identity matrix. Unlike deriving the
# columns from the labels that happen to be present, this guarantees both
# arrays have exactly `num_classes` columns even if a split lacks a class.
train_labels = np.eye(num_classes, dtype=int)[y_train]
val_labels = np.eye(num_classes, dtype=int)[y_test]

In [5]:
# Display both array shapes side by side to confirm images and labels agree on
# the number of samples.
(train_images.shape, train_labels.shape)

((60000, 28, 28, 3), (60000, 10))

In [6]:
# Keep the first training image as a reference sample; the model's input shape
# is read directly from it.
imm = train_images[0]
input_shape = tuple(imm.shape)

In [7]:
# ViT configuration

# Side length images are resized to before entering the ViT (None = no resize).
resize_to = 224
# Effective side length seen by the model: the resize target when one is set,
# otherwise the second dimension of the reference sample.
img_size = resize_to if resize_to is not None else imm.shape[1]

# If True, the last layer has a softmax.
include_top = True

# Patch side length (16 recommended) and the resulting number of patches in
# the square grid that tiles the image.
patch_size = 16
num_patches = (img_size // patch_size) * (img_size // patch_size)

# Embedding size.
projection_dim = 512
# Number of transformer blocks.
transformer_layers = 12
# Multi-headed attention blocks.
num_heads = 4

# NOTE(review): presumably the width of the MLP head — confirm against the
# VisionTransformer definition.
mlp_head_units = 1024

In [8]:
# Instantiate the Vision Transformer from the configuration values defined
# above; MODEL_NAME is used later when composing the log directory.
MODEL_NAME = 'VIT'
VIT = VisionTransformer(
    input_shape=input_shape,
    patch_size=patch_size,
    num_patches=num_patches,
    projection_dim=projection_dim,
    transformer_layers=transformer_layers,
    num_heads=num_heads,
    mlp_head_units=mlp_head_units,
    include_top=include_top,
    num_classes=num_classes,
    resize_to=resize_to,
)

In [9]:
# Print the layer-by-layer summary of the model; the trailing `;` suppresses
# the echo of the call's return value in the cell output.
VIT.summary();

Model: "model"
________________________________________________________________________________________________________________________
 Layer (type)                                         Output Shape                                    Param #           
 input_1 (InputLayer)                                 [(None, 28, 28, 3)]                             0                 
                                                                                                                        
 resizing (Resizing)                                  (None, 224, 224, 3)                             0                 
                                                                                                                        
 patches (Patches)                                    (None, 14, 14, 512)                             393728            
                                                                                                                        
 reshape (Reshape

In [10]:
# Root directory for run logs; the per-run `log_dir` is built beneath it.
logs = 'logs'

In [11]:
from datetime import datetime
from callbacks import *
from tqdm.keras import TqdmCallback

In [13]:
# Training hyperparameters.
initial_lr = 0.0002
num_epochs = 10
patience = 5

VIT.compile(
    num_classes=num_classes,
    learning_rate=initial_lr,
    weight_decay=0,
)

# Per-run output directory: <logs>/<timestamp>/<dataset>/<model>/
# (the empty last element produces the trailing slash).
log_dir = "/".join(
    [logs, datetime.now().strftime("%Y%m%d-%H%M%S"), DATASET_NAME, MODEL_NAME, ""]
)

# verbose=0 is passed to fit and a TqdmCallback is supplied, presumably so that
# progress is reported through tqdm rather than the default console output.
history = VIT.fit(
    train_images,
    train_labels,
    epochs=num_epochs,
    validation_data=(val_images, val_labels),
    verbose=0,
    callbacks=[
        # DuplicatedModelCheck(model_log_dir=log_dir),
        SimpleLogger(log_dir=log_dir),
        # NOTE(review): no monitor is passed to EarlyStopping while
        # ModelCheckpoint tracks "val_metrics/accuracy" — confirm the two are
        # meant to follow different quantities.
        EarlyStopping(restore_best_weights=True, patience=patience),
        ModelCheckpoint(
            log_dir,
            monitor="val_metrics/accuracy",
            save_best_only=True,
            save_weights_only=True,
        ),
        PositionEmbeddingLogger(log_dir=log_dir),
        TqdmCallback(),
    ],
)