In [None]:
import tensorflow as tf
import numpy as np
import cv2
import matplotlib.pyplot as plt
from enum import Enum
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.utils.class_weight import compute_class_weight

# Define skin tone classes
class SkinTone(Enum):
    """Skin tone labels the classifier can output.

    Each member's value is the integer class index; ``SkinTone(index)`` maps
    a predicted index back to its label.
    """

    BLACK = 0  # class index 0
    BROWN = 1  # class index 1
    WHITE = 2  # class index 2

# Load dataset
def load_dataset(dataset_path, batch_size=32, img_size=(224, 224)):
    """
    Build training and validation generators from a class-per-folder dataset.

    Args:
        dataset_path: Directory containing one sub-directory per class.
        batch_size: Number of images per batch.
        img_size: Target (height, width) every image is resized to.

    Returns:
        A ``(train_generator, val_generator)`` tuple. Both yield images
        scaled to [0, 1] with sparse integer labels. 80% of each class goes
        to training and 20% to validation.
    """
    validation_split = 0.2

    # Random augmentation is applied to the training subset only.
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        validation_split=validation_split,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True,
    )
    # Validation images must NOT be augmented, otherwise val_loss/val_accuracy
    # are measured on randomly distorted inputs. Using the same
    # validation_split keeps the two subsets disjoint, because the split is
    # taken deterministically from the sorted file list.
    val_datagen = ImageDataGenerator(
        rescale=1./255,
        validation_split=validation_split,
    )

    flow_kwargs = dict(
        target_size=img_size,
        batch_size=batch_size,
        class_mode='sparse',
    )

    train_generator = train_datagen.flow_from_directory(
        dataset_path,
        subset='training',
        **flow_kwargs
    )

    val_generator = val_datagen.flow_from_directory(
        dataset_path,
        subset='validation',
        shuffle=False,  # stable order makes validation metrics reproducible
        **flow_kwargs
    )

    print("Class names:", train_generator.class_indices)
    return train_generator, val_generator

# Build the model
def build_model(num_classes=len(SkinTone), input_shape=(224, 224, 3)):
    """
    Build a skin tone classification model using EfficientNetB0 as the base model.

    Args:
        num_classes: Number of output classes (size of the softmax layer).
        input_shape: (height, width, channels) of the input images.

    Returns:
        A compiled Keras model that takes RGB images scaled to [0, 1] and
        outputs per-class probabilities. The EfficientNetB0 backbone is
        frozen and is kept as a single nested layer of the returned model.
    """
    base_model = EfficientNetB0(weights='imagenet', include_top=False, input_shape=input_shape)
    base_model.trainable = False  # Freeze the base model initially

    inputs = tf.keras.Input(shape=input_shape)
    # Keras EfficientNet ships its own preprocessing and expects pixel values
    # in [0, 255]. The data pipelines in this file feed images divided by 255,
    # so scale them back up; otherwise the pretrained features are useless.
    x = tf.keras.layers.Rescaling(255.0)(inputs)
    # training=False keeps BatchNormalization in inference mode, including
    # after the backbone is unfrozen for fine-tuning.
    x = base_model(x, training=False)
    x = GlobalAveragePooling2D()(x)
    x = Dense(1024, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(0.01))(x)
    x = Dropout(0.5)(x)
    predictions = Dense(num_classes, activation='softmax')(x)

    model = Model(inputs=inputs, outputs=predictions)
    lr_schedule = tf.keras.optimizers.schedules.CosineDecay(initial_learning_rate=5e-4, decay_steps=1000, alpha=1e-6)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

# Train the model
def _find_backbone(model):
    """Return the first nested Keras model inside `model`, or `model` itself if it is flat."""
    for layer in model.layers:
        if isinstance(layer, tf.keras.Model):
            return layer
    return model


def _uses_lr_schedule(optimizer):
    """Return True when the optimizer's learning rate is a LearningRateSchedule."""
    # A schedule is serialized as a dict in the optimizer config; a plain
    # learning rate is serialized as a number.
    return isinstance(optimizer.get_config().get("learning_rate"), dict)


def train_model(model, train_generator, val_generator, epochs=30,
                fine_tune_epochs=10, fine_tune_at=100):
    """
    Train the model in two stages: frozen-backbone training, then fine-tuning.

    Args:
        model: Compiled Keras model to train (modified in place).
        train_generator: Training data iterator exposing a ``classes`` array.
        val_generator: Validation data iterator.
        epochs: Maximum number of epochs for the first stage.
        fine_tune_epochs: Maximum number of epochs for the fine-tuning stage.
        fine_tune_at: Number of leading backbone layers that stay frozen
            during fine-tuning.

    Returns:
        A ``(history, history_finetune)`` tuple of Keras History objects.
    """
    labels = train_generator.classes
    present_classes = np.unique(labels)
    balanced = compute_class_weight(
        class_weight='balanced',
        classes=present_classes,
        y=labels
    )
    # Key the weights by the actual class id rather than by position so the
    # mapping stays right even if a class id is missing from the data.
    class_weights = {int(cls): float(w) for cls, w in zip(present_classes, balanced)}

    callbacks = [EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)]
    # ReduceLROnPlateau assigns to optimizer.learning_rate, which raises a
    # TypeError when the optimizer was built with a LearningRateSchedule.
    # Only attach it when the learning rate is a plain, settable value.
    if not _uses_lr_schedule(model.optimizer):
        callbacks.append(
            ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=1e-6)
        )

    history = model.fit(
        train_generator,
        validation_data=val_generator,
        epochs=epochs,
        class_weight=class_weights,
        callbacks=callbacks
    )

    # Unfreeze some layers and fine-tune.
    # model.layers[0] is the InputLayer, not the backbone, so locate the
    # backbone explicitly (nested model, or the model itself when flat).
    backbone = _find_backbone(model)
    backbone.trainable = True
    for layer in backbone.layers[:fine_tune_at]:
        layer.trainable = False
    # Keep BatchNormalization frozen: updating its statistics on a small
    # dataset tends to wreck the pretrained features.
    for layer in backbone.layers[fine_tune_at:]:
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.trainable = False

    # Recompile so the new trainable flags take effect, with a low LR.
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    history_finetune = model.fit(
        train_generator,
        validation_data=val_generator,
        epochs=fine_tune_epochs,
        class_weight=class_weights,
        callbacks=[EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)]
    )

    return history, history_finetune

# Save the model
def save_model(model, model_path="skin_tone_model.h5"):
    """Write `model` to `model_path` via ``model.save`` and print the destination."""
    destination = model_path
    model.save(destination)
    print(f"Model saved at {destination}")

# Load a pre-trained model
def load_trained_model(model_path="skin_tone_model.h5"):
    """Load the Keras model stored at `model_path`, print a confirmation, and return it."""
    restored = tf.keras.models.load_model(model_path)
    print(f"Model loaded from {model_path}")
    return restored

# Skin segmentation using HSV color space
def segment_skin(image):
    """
    Black out every pixel whose HSV value lies outside a fixed skin range.

    The input is converted from BGR to HSV, thresholded between
    (0, 48, 80) and (20, 255, 255), and the resulting mask is applied to the
    original image. Returns an image of the same shape with non-matching
    pixels set to zero.
    """
    hsv_floor = np.array([0, 48, 80], dtype=np.uint8)
    hsv_ceiling = np.array([20, 255, 255], dtype=np.uint8)
    skin_mask = cv2.inRange(
        cv2.cvtColor(image, cv2.COLOR_BGR2HSV), hsv_floor, hsv_ceiling
    )
    return cv2.bitwise_and(image, image, mask=skin_mask)

# Preprocess an image for model input
def preprocess_image(image_path, img_size=(224, 224)):
    """
    Load an image from disk and turn it into a single-image model batch.

    Args:
        image_path: Path of the image file to read.
        img_size: Target (height, width), same convention as Keras'
            ``target_size``.

    Returns:
        A float32 array of shape (1, height, width, 3) in RGB order with
        values in [0, 1], or None if the image could not be read or processed
        (the error is printed).
    """
    try:
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Unable to load image at {image_path}")
        # NOTE(review): load_dataset does not apply segment_skin, so the model
        # sees masked images only at inference time — confirm this is intended.
        skin = segment_skin(image)
        # OpenCV loads BGR; the Keras training generators yield RGB.
        skin = cv2.cvtColor(skin, cv2.COLOR_BGR2RGB)
        # cv2.resize takes (width, height), the reverse of img_size.
        height, width = img_size
        skin = cv2.resize(skin, (width, height)).astype(np.float32) / 255.0
        return np.expand_dims(skin, axis=0)
    except Exception as e:
        # Best-effort helper: report the problem and let the caller handle None.
        print(f"Error preprocessing image: {e}")
        return None

# Predict skin tone
def predict_skin_tone(image_path, model):
    """
    Predict the skin tone of the person in the image.

    Prints the predicted label and shows the original image with the
    prediction as its title.

    Args:
        image_path: Path of the image file to classify.
        model: Trained Keras model whose ``predict`` returns per-class scores.

    Returns:
        The predicted ``SkinTone`` member name (e.g. ``"BROWN"``), or None if
        the image could not be preprocessed.
    """
    image = preprocess_image(image_path)
    if image is None:
        print("Failed to preprocess image.")
        return None

    prediction = model.predict(image)
    # The batch holds a single image; convert to a plain int for the Enum lookup.
    predicted_class = int(np.argmax(prediction, axis=-1)[0])
    predicted_skin_tone = SkinTone(predicted_class).name
    print(f"Predicted Skin Tone: {predicted_skin_tone}")

    # Show the unmodified image (converted to RGB for matplotlib).
    plt.imshow(cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB))
    plt.axis('off')
    plt.title(f"Predicted: {predicted_skin_tone}")
    plt.show()
    return predicted_skin_tone

# Main function
if __name__ == "__main__":
    # Stage 1: build the data generators from the class-per-folder dataset.
    dataset_path = "datasets/skin-tone/"
    train_generator, val_generator = load_dataset(dataset_path)

    # Stage 2: build, train and save the classifier.
    model = build_model()
    train_model(model, train_generator, val_generator, epochs=30)
    save_model(model)

    # Stage 3: reload the saved model from disk and classify a sample image.
    model = load_trained_model()
    predict_skin_tone("test_brown2.jpg", model)

# if __name__ == "__main__":
#     model = load_trained_model()
#     predict_skin_tone("test_white.jpg", model)


Found 1200 images belonging to 3 classes.
Found 300 images belonging to 3 classes.
Class names: {'black': 0, 'brown': 1, 'white': 2}


  self._warn_if_super_not_called()


Epoch 1/30
[1m38/38[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 473ms/step - accuracy: 0.3070 - loss: 10.7888 - val_accuracy: 0.3333 - val_loss: 6.1853 - learning_rate: 4.9822e-04
Epoch 2/30
[1m38/38[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 416ms/step - accuracy: 0.3515 - loss: 5.3363 - val_accuracy: 0.3333 - val_loss: 3.4093 - learning_rate: 4.9291e-04
Epoch 3/30
[1m38/38[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 436ms/step - accuracy: 0.2950 - loss: 3.0898 - val_accuracy: 0.3333 - val_loss: 2.2770 - learning_rate: 4.8414e-04
Epoch 4/30
[1m38/38[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 453ms/step - accuracy: 0.3351 - loss: 2.1490 - val_accuracy: 0.3333 - val_loss: 1.7813 - learning_rate: 4.7203e-04
Epoch 5/30
[1m38/38[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 445ms/step - accuracy: 0.3174 - loss: 1.7083 - val_accuracy: 0.3333 - val_loss: 1.5199 - learning_rate: 4.5677e-04
Epoch 6/30
[1m38/38[0m [32m━━━━━━━━━━━━━━