In [None]:
from vit_keras import vit
import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.models import Sequential


# 1) Build the pretrained ViT-B/32 backbone, requested without its top layer
#    (include_top=False). `img_size` must already be defined by an earlier cell.
base_model = vit.vit_b32(
    image_size=img_size,
    pretrained=True,
    include_top=False,
    pretrained_top=False
)

# 2) Freeze every backbone layer individually.
#    The flag is set per layer (not via base_model.trainable = False) on purpose:
#    the fine-tuning stage re-enables single layers, and Keras ignores per-layer
#    flags whenever the enclosing model itself is marked non-trainable.
for layer in base_model.layers:
    layer.trainable = False

# 3) Stack the classifier head on top of the frozen backbone.
#    The output layer has one softmax unit per entry of `emotion_labels`.
model = Sequential()
model.add(base_model)
model.add(Dense(1024, activation='gelu'))
model.add(Dropout(0.3))
model.add(Dense(len(emotion_labels), activation='softmax'))

# 4) Compile with base_model frozen - only the Dense head trains
model.compile(
    optimizer=tf.keras.optimizers.AdamW(learning_rate=1e-4, weight_decay=1e-5),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# 5) Callbacks
#    ReduceLROnPlateau must use a shorter patience than EarlyStopping. With both
#    at 2, EarlyStopping (which runs first in this list) flags the run to stop
#    in the same epoch the learning rate would be halved, so the reduced rate
#    is never actually trained with. patience=1 leaves one epoch at the lower
#    rate before the stop decision.
callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=2, restore_best_weights=True),
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=1),
    tf.keras.callbacks.ModelCheckpoint(filepath='ViT_pretrained_head_only.keras',
                                       monitor='val_loss',
                                       save_best_only=True)
]

# 6) Train just the Dense head for a few epochs
history_head = model.fit(
    train_generator,
    epochs=5,
    validation_data=validation_generator,
    callbacks=callbacks
)

In [None]:
# 7) Re-enable training for the last 8 layers of the backbone; everything
#    before that slice keeps whatever trainable flag it already had.
layers_to_unfreeze = base_model.layers[-8:]
for layer in layers_to_unfreeze:
    layer.trainable = True

# 8) Re-compile after changing the trainable flags, this time with a smaller
#    learning rate (1e-5) for the fine-tuning stage.
finetune_optimizer = tf.keras.optimizers.AdamW(learning_rate=1e-5, weight_decay=1e-5)
model.compile(
    optimizer=finetune_optimizer,
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# 9) Callbacks for the fine-tuning stage, all driven by validation loss.
#    The list order (stop, LR schedule, checkpoint) is the order Keras calls them.
finetune_early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
)
finetune_lr_schedule = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=2,
)
finetune_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath='ViT_pretrained_finetuned.keras',
    monitor='val_loss',
    save_best_only=True,
)
callbacks_finetune = [finetune_early_stop, finetune_lr_schedule, finetune_checkpoint]

# 10) Fine-tune the model (head plus the unfrozen backbone layers)
history_finetune = model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=15,
    callbacks=callbacks_finetune,
)