In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.applications import EfficientNetV2B0


In [2]:
# Constants
IMG_SIZE = 224
BATCH_SIZE = 32
EPOCHS = 20

In [3]:
# Dataset Path
train_dir = 'dataset1/test1'
test_dir = 'dataset1/test1'

In [4]:
# Data Augmentation
train_datagen = ImageDataGenerator(rescale=1.0/255.0)
test_datagen = ImageDataGenerator(rescale=1.0/255.0)


In [5]:
# Data Generators
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical')

test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical')

print(train_generator.class_indices)
assert train_generator.num_classes == 23, f"Expected 23 classes, but got {train_generator.num_classes}"


Found 4002 images belonging to 23 classes.
Found 4002 images belonging to 23 classes.
{'Acne and Rosacea Photos': 0, 'Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions': 1, 'Atopic Dermatitis Photos': 2, 'Bullous Disease Photos': 3, 'Cellulitis Impetigo and other Bacterial Infections': 4, 'Eczema Photos': 5, 'Exanthems and Drug Eruptions': 6, 'Hair Loss Photos Alopecia and other Hair Diseases': 7, 'Herpes HPV and other STDs Photos': 8, 'Light Diseases and Disorders of Pigmentation': 9, 'Lupus and other Connective Tissue diseases': 10, 'Melanoma Skin Cancer Nevi and Moles': 11, 'Nail Fungus and other Nail Disease': 12, 'Poison Ivy Photos and other Contact Dermatitis': 13, 'Psoriasis pictures Lichen Planus and related diseases': 14, 'Scabies Lyme Disease and other Infestations and Bites': 15, 'Seborrheic Keratoses and other Benign Tumors': 16, 'Systemic Disease': 17, 'Tinea Ringworm Candidiasis and other Fungal Infections': 18, 'Urticaria Hives': 19, 'Vascular Tumors': 2

In [6]:
# Vision Transformer Model
def build_vit_model(img_size, num_classes):
    inputs = layers.Input(shape=(img_size, img_size, 3))
    x = layers.Rescaling(1.0 / 255.0)(inputs)
    x = layers.Conv2D(64, (16, 16), strides=(16, 16), padding='valid')(x)
    x = layers.Reshape((-1, 64))(x)
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = layers.MultiHeadAttention(num_heads=8, key_dim=64)(x, x)
    x = layers.GlobalAveragePooling1D()(x)
    x = layers.Dense(512, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    model = models.Model(inputs, outputs)
    return model

model = build_vit_model(IMG_SIZE, 23)

In [7]:
# Compile Model
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
              metrics=['accuracy'])

model.summary()

In [8]:
# Callbacks
checkpoint = ModelCheckpoint("model/vit_model.keras", save_best_only=True, monitor='val_loss', mode='min')
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
lr_scheduler = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=1e-6, verbose=1)


In [9]:
# Model Training
history = model.fit(train_generator,
                    validation_data=test_generator,
                    epochs=EPOCHS,
                    callbacks=[checkpoint, early_stop, lr_scheduler])


Epoch 1/20


  self._warn_if_super_not_called()


[1m126/126[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 238ms/step - accuracy: 0.0807 - loss: 3.0719 - val_accuracy: 0.0952 - val_loss: 3.0042 - learning_rate: 1.0000e-04
Epoch 2/20
[1m126/126[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 246ms/step - accuracy: 0.0877 - loss: 3.0139 - val_accuracy: 0.1177 - val_loss: 2.9695 - learning_rate: 1.0000e-04
Epoch 3/20
[1m126/126[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 255ms/step - accuracy: 0.1061 - loss: 2.9874 - val_accuracy: 0.1219 - val_loss: 2.9521 - learning_rate: 1.0000e-04
Epoch 4/20
[1m126/126[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 246ms/step - accuracy: 0.1067 - loss: 2.9742 - val_accuracy: 0.1169 - val_loss: 2.9417 - learning_rate: 1.0000e-04
Epoch 5/20
[1m126/126[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 248ms/step - accuracy: 0.1087 - loss: 2.9723 - val_accuracy: 0.1332 - val_loss: 2.9384 - learning_rate: 1.0000e-04
Epoch 6/20
[1m126/126[0m [32m━━━━━━━━━━━━━━