# In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.optimizers import Adamax, AdamW
import os
import numpy as np
import cv2
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tensorflow.keras.applications import VGG16

# Load OpenCV's bundled frontal-face Haar cascade; this single classifier is
# shared by every face-detection call in this file.
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
# An unreadable cascade file leaves the classifier empty rather than raising
# at construction time, so fail fast here instead of letting every later
# detectMultiScale call error out image by image.
if face_cascade.empty():
    raise RuntimeError(
        "Failed to load Haar cascade: "
        + cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
    )

def is_valid_image(face_image):
    """Return whether a face crop passes basic size, shape and brightness checks.

    Args:
        face_image: Image array whose first two axes are (height, width).

    Returns:
        bool: False when either side is shorter than 100 pixels, when the
        width/height ratio lies outside [0.4, 2.5], or when the mean pixel
        value lies outside [30, 225]; True otherwise.
    """
    height, width = face_image.shape[:2]

    # Size check first: it also rejects zero-height/zero-width arrays, which
    # would otherwise raise ZeroDivisionError in the aspect-ratio division.
    if height < 100 or width < 100:
        return False

    aspect_ratio = float(width) / height
    if aspect_ratio < 0.4 or aspect_ratio > 2.5:
        return False

    # Mean over every pixel and channel; the thresholds presumably assume a
    # 0-255 value range -- confirm against callers.
    mean_color = np.mean(face_image)
    if mean_color < 30 or mean_color > 225:
        return False

    return True

def load_images_and_labels(folder_path, label_mapping, img_size=(128, 128)):
    """Load one cropped face per image file from class-named subfolders.

    Each immediate subfolder of ``folder_path`` whose name is a key of
    ``label_mapping`` is scanned for ``.png``/``.jpg``/``.jpeg`` files. For
    every readable file the first face reported by the Haar cascade is
    cropped, validated with ``is_valid_image``, resized to ``img_size`` and
    scaled to the [0, 1] range. Files that cannot be read, contain no
    detected face, or fail validation are skipped.

    Args:
        folder_path: Directory containing one subfolder per class.
        label_mapping: Dict mapping subfolder name -> label value.
        img_size: Target size passed to ``cv2.resize`` for every kept crop.

    Returns:
        tuple: ``(images, labels, class_counts)`` where ``images`` and
        ``labels`` are NumPy arrays of equal length (images stay in OpenCV's
        BGR channel order) and ``class_counts`` maps each class name to the
        number of crops kept for it.
    """
    images = []
    labels = []
    class_counts = {class_name: 0 for class_name in label_mapping}

    # Sorted listings keep the sample order independent of the filesystem's
    # enumeration order, so repeated runs see the data in the same order.
    for subfolder_name in sorted(os.listdir(folder_path)):
        subfolder_path = os.path.join(folder_path, subfolder_name)
        if not os.path.isdir(subfolder_path) or subfolder_name not in label_mapping:
            continue
        label = label_mapping[subfolder_name]

        for file_name in sorted(os.listdir(subfolder_path)):
            if not file_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue
            file_path = os.path.join(subfolder_path, file_name)

            # Best-effort loading: one bad file is reported and skipped
            # instead of aborting the whole scan.
            try:
                image = cv2.imread(file_path)
                if image is None:
                    print(f"Warning: Image at {file_path} could not be loaded.")
                    continue

                gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
                faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5)
                if len(faces) == 0:
                    continue

                # Only the first detection is used for each file.
                x, y, w, h = faces[0]
                face_image = image[y:y+h, x:x+w]

                # Validate the raw crop, not the resized one: after resizing,
                # every crop has exactly img_size dimensions, which makes the
                # aspect-ratio and minimum-size checks meaningless (and makes
                # them reject everything whenever img_size is below 100).
                if not is_valid_image(face_image):
                    continue

                face_image = cv2.resize(face_image, img_size)
                images.append(face_image / 255.0)
                labels.append(label)
                class_counts[subfolder_name] += 1

            except Exception as e:
                print(f"Error loading image {file_path}: {e}")

    print(f"Total images loaded: {len(images)}")
    print(f"Total labels collected: {len(labels)}")

    unique_classes = np.unique(labels)
    print(f"Total unique classes: {len(unique_classes)}")

    for class_name, count in class_counts.items():
        print(f"Class '{class_name}' has {count} images.")

    return np.array(images), np.array(labels), class_counts

def visualize_images(images, labels, label_mapping, num_images_per_class=5):
    """Plot a grid of randomly chosen sample images, one row per class.

    Args:
        images: Array-like of images. Three-channel images are expected in
            OpenCV's BGR channel order (the order ``cv2.imread`` produces)
            and are flipped to RGB for display only.
        labels: Array-like of label values aligned with ``images``.
        label_mapping: Dict mapping class name -> label value; its iteration
            order determines the row order of the grid.
        num_images_per_class: Maximum number of samples drawn per class, and
            the number of columns in the grid.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    images = np.asarray(images)
    labels = np.asarray(labels)
    n_rows = len(label_mapping)

    plt.figure(figsize=(15, 10))
    # Match on the mapped label value rather than on the row position, so the
    # grid stays correct even if label values are not 0..N-1 in dict order.
    for row, (class_name, class_label) in enumerate(label_mapping.items()):
        class_images = images[labels == class_label]
        if len(class_images) == 0:
            continue

        n_shown = min(num_images_per_class, len(class_images))
        chosen = np.random.choice(len(class_images), size=n_shown, replace=False)
        for col, sample_index in enumerate(chosen):
            sample = class_images[sample_index]
            # imshow interprets three-channel data as RGB; reverse the channel
            # axis so BGR images are not drawn with red and blue swapped.
            if sample.ndim == 3 and sample.shape[-1] == 3:
                sample = sample[..., ::-1]

            plt.subplot(n_rows, num_images_per_class, row * num_images_per_class + col + 1)
            plt.imshow(sample)
            plt.axis('off')
            plt.title(class_name)
    plt.tight_layout()
    plt.show()

def build_model(input_shape, num_classes):
    """Build a VGG16-based classifier with a small dense head.

    The ImageNet-initialised VGG16 backbone (without its top) has all but its
    last four layers frozen. Its output is globally average-pooled and batch
    normalised, then passed through two Dense/BatchNormalization/Dropout
    stages (128 and 64 units, ELU activation, dropout 0.5) and a softmax
    layer with ``num_classes`` outputs.

    Args:
        input_shape: Shape of a single input image, passed to both VGG16 and
            the model's Input layer.
        num_classes: Number of units in the softmax output layer.

    Returns:
        An uncompiled Keras ``Model``.
    """
    backbone = VGG16(weights='imagenet', include_top=False, input_shape=input_shape)

    # Only the last four backbone layers remain trainable.
    for frozen_layer in backbone.layers[:-4]:
        frozen_layer.trainable = False

    inputs = layers.Input(shape=input_shape)
    features = backbone(inputs)
    features = layers.GlobalAveragePooling2D()(features)
    features = layers.BatchNormalization()(features)

    # Two identical head stages that differ only in width.
    for units in (128, 64):
        features = layers.Dense(units, activation='elu')(features)
        features = layers.BatchNormalization()(features)
        features = layers.Dropout(0.5)(features)

    outputs = layers.Dense(num_classes, activation='softmax')(features)
    return models.Model(inputs, outputs)

def detect_faces(images):
    """Keep only the images in which the Haar cascade finds at least one face.

    Args:
        images: Iterable of BGR image arrays; presumably floats scaled to
            [0, 1], since each one is multiplied by 255 before detection --
            confirm against callers.

    Returns:
        tuple: ``(detected_images, detected_indices)`` as NumPy arrays.
        ``detected_indices`` holds the position of each kept image in the
        input (not its class label), so callers can look up the matching
        labels themselves.
    """
    detected_images = []
    detected_indices = []

    for index, image in enumerate(images):
        # Round and clip before the uint8 cast: astype alone truncates, so
        # float round-off can shift a pixel down by one, and out-of-range
        # values would wrap around instead of saturating.
        image_uint8 = np.clip(np.rint(image * 255), 0, 255).astype(np.uint8)
        gray = cv2.cvtColor(image_uint8, cv2.COLOR_BGR2GRAY)
        faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5)
        if len(faces) > 0:
            detected_images.append(image)
            detected_indices.append(index)
    return np.array(detected_images), np.array(detected_indices)

if __name__ == "__main__":
    # Dataset root: one subfolder per class, named after the keys of
    # label_mapping below.
    folder_path = "/Users/withmocha/Desktop/DATA/BOAZ/미니 프로젝트 2/data/face data/"

    label_mapping = {
        "spring": 0,
        "summer": 1,
        "fall": 2,
        "winter": 3
    }

    input_shape = (128, 128, 3)
    num_classes = len(label_mapping)

    # Label values in ascending order and their names in that same order.
    # One-hot column j (and therefore model output j) corresponds to
    # class_ids[j] / class_names[j].
    class_ids = sorted(label_mapping.values())
    name_by_id = {class_id: class_name for class_name, class_id in label_mapping.items()}
    class_names = [name_by_id[class_id] for class_id in class_ids]
    class_indices = list(range(num_classes))

    images, labels, class_counts = load_images_and_labels(folder_path=folder_path, label_mapping=label_mapping, img_size=(128, 128))
    if len(images) == 0:
        raise SystemExit(f"No usable face images were loaded from {folder_path}")

    # Pin the categories so every one-hot matrix has num_classes columns even
    # when a class is missing from the data or from one split; an encoder
    # fitted on the observed labels alone would produce fewer columns than
    # the softmax layer has outputs.
    encoder = OneHotEncoder(categories=[class_ids], sparse_output=False)
    labels_onehot = encoder.fit_transform(labels.reshape(-1, 1))

    print("Counts of validated images per class:")
    for class_name, count in class_counts.items():
        print(f"Class '{class_name}': {count} images")

    visualize_images(images, labels, label_mapping, num_images_per_class=5)

    # detect_faces returns positions into `images`, not class labels.
    detected_images, detected_labels = detect_faces(images)
    if len(detected_images) == 0:
        raise SystemExit("Face detection rejected every loaded image; nothing to train on.")

    # Translate the kept positions back into their class labels.
    detected_labels_as_classes = labels[detected_labels]

    detected_class_counts = {class_name: 0 for class_name in label_mapping}
    for detected_label in detected_labels_as_classes:
        detected_class_counts[name_by_id[detected_label]] += 1

    print("Counts of detected faces per class:")
    for class_name, count in detected_class_counts.items():
        print(f"Class '{class_name}': {count} images")

    # One figure per non-empty class.
    for class_name, count in detected_class_counts.items():
        if count > 0:
            class_id = label_mapping[class_name]
            class_images = detected_images[detected_labels_as_classes == class_id]
            visualize_images(class_images, np.full(len(class_images), class_id), label_mapping, num_images_per_class=5)

    train_images, val_images, train_labels, val_labels = train_test_split(
        detected_images, detected_labels_as_classes, test_size=0.1, stratify=detected_labels_as_classes, shuffle=True
    )

    train_labels = np.array(train_labels)
    val_labels = np.array(val_labels)

    # Reuse the already-fitted encoder so both splits share one column layout.
    train_labels_onehot = encoder.transform(train_labels.reshape(-1, 1))
    val_labels_onehot = encoder.transform(val_labels.reshape(-1, 1))

    model = build_model(input_shape=input_shape, num_classes=num_classes)

    optimizer = AdamW(learning_rate=0.00005)
    model.compile(
        optimizer=optimizer,
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    # Stop after 15 epochs without val_loss improvement (restoring the best
    # weights) and halve the learning rate after 3 stagnant epochs.
    early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=15, restore_best_weights=True)
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-6)

    history = model.fit(
        train_images,
        train_labels_onehot,
        validation_data=(val_images, val_labels_onehot),
        epochs=100,
        callbacks=[early_stopping, reduce_lr]
    )

    val_predictions = model.predict(val_images)
    val_predictions_classes = np.argmax(val_predictions, axis=-1)

    val_labels_classes = np.argmax(val_labels_onehot, axis=1)
    # Passing labels= keeps the report and matrix at num_classes rows even
    # when a class is absent from the validation split; without it the
    # target_names length check fails and the tick labels drift out of step.
    print(classification_report(
        val_labels_classes,
        val_predictions_classes,
        labels=class_indices,
        target_names=class_names,
        zero_division=0
    ))

    cm = confusion_matrix(val_labels_classes, val_predictions_classes, labels=class_indices)
    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.colorbar()
    tick_marks = np.arange(num_classes)
    plt.xticks(tick_marks, class_names, rotation=45)
    plt.yticks(tick_marks, class_names)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()


# KeyboardInterrupt: