In [None]:
# Mount Google Drive at /content/drive (Colab runtime only); the data and
# model paths used later in this notebook live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# 기본 라이브러리
import os
import time
import json
import numpy as np

# 이미지 처리
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
import albumentations as A
from albumentations.pytorch import ToTensorV2

# 머신러닝 관련
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import (
    precision_score,
    recall_score,
    f1_score,
    classification_report,
    confusion_matrix,
    roc_auc_score,
    average_precision_score,
    roc_curve,
    auc,
    precision_recall_curve
)

# TensorFlow/Keras 관련
import tensorflow as tf
from tensorflow.keras.applications import (
    MobileNet,
    MobileNetV2,
    MobileNetV3Small,
    MobileNetV3Large
)
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Input  # Input 추가
from tensorflow.keras.models import Model

# 기타 유틸리티
from tqdm import tqdm
from itertools import cycle

In [None]:
#데이터 로드 및 확인
def get_data_paths(base_path, classes=('CN7', 'G2', 'BLDC-400-f', 'BLDC-400-b')):
    """Collect the file paths found in each class sub-directory of ``base_path``.

    Args:
        base_path: Directory expected to contain one sub-directory per class.
        classes: Names of the class sub-directories to scan, in output order.

    Returns:
        dict mapping class name -> sorted list of paths of the regular files
        directly inside that class directory. Classes whose directory does
        not exist (or is not a directory) are omitted from the result.
    """
    data_paths = {}

    for class_name in classes:
        class_path = os.path.join(base_path, class_name)
        # isdir (rather than exists) so a stray *file* named like a class
        # does not make the directory scan raise NotADirectoryError.
        if not os.path.isdir(class_path):
            continue

        # Keep regular files only; sort so the result (and therefore the
        # "first image" shown per class) is the same on every run.
        data_paths[class_name] = sorted(
            entry.path for entry in os.scandir(class_path) if entry.is_file()
        )

    return data_paths

base_path = '/content/drive/MyDrive/augmented_data'
paths = get_data_paths(base_path)

# Print how many files were found for each class
for class_name, image_paths in paths.items():
    print(f"{class_name}: {len(image_paths)} images")

# 클래스 별 예시 이미지 1개씩 출력
def display_one_per_class(paths):
    """Show the first image of every class side by side, in grayscale.

    Args:
        paths: Mapping of class name -> list of image file paths.

    Returns:
        None. Classes with no files are reported and left out of the figure;
        a first file that OpenCV cannot decode is reported and its panel is
        left blank. Nothing is drawn when no class has any file.
    """
    samples = []
    for class_name, image_paths in paths.items():
        if image_paths:
            # Only the first image of each class is shown.
            samples.append((class_name, image_paths[0]))
        else:
            print(f"{class_name}: no images to display")

    if not samples:
        print("No images available to display.")
        return

    # squeeze=False keeps `axes` two-dimensional even for a single class.
    fig, axes = plt.subplots(1, len(samples), figsize=(15, 3), squeeze=False)

    for ax, (class_name, img_path) in zip(axes[0], samples):
        ax.set_title(f'{class_name}')
        ax.axis('off')

        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread signals an unreadable file by returning None.
            print(f"{class_name}: could not read image {img_path}")
            continue
        ax.imshow(img, cmap='gray')

    fig.tight_layout()
    plt.show()

# Display one sample image for each class
display_one_per_class(paths)

In [None]:
#V1,V2

In [None]:
def plot_roc_curves(y_test, y_pred_proba, class_names):
    """Plot one-vs-rest ROC curves, one per class, on a single figure.

    Args:
        y_test: Class labels; one-hot encoded here via ``to_categorical``.
        y_pred_proba: Per-class scores; column ``i`` is scored against class ``i``.
        class_names: Display names; their count defines the number of classes.
    """
    n_classes = len(class_names)

    # Turn the labels into one binary target column per class.
    y_onehot = tf.keras.utils.to_categorical(y_test, n_classes)

    # (fpr, tpr, auc) for every class, in class-index order.
    curves = []
    for cls_idx in range(n_classes):
        fpr_i, tpr_i, _ = roc_curve(y_onehot[:, cls_idx], y_pred_proba[:, cls_idx])
        curves.append((fpr_i, tpr_i, auc(fpr_i, tpr_i)))

    plt.figure(figsize=(10, 8))
    palette = cycle(['aqua', 'darkorange', 'cornflowerblue', 'green'])

    for cls_idx, line_color in zip(range(n_classes), palette):
        fpr_i, tpr_i, area = curves[cls_idx]
        plt.plot(fpr_i, tpr_i, color=line_color, lw=2,
                 label=f'{class_names[cls_idx]} (AUC = {area:0.2f})')

    # Diagonal reference line.
    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Multi-class ROC Curves')
    plt.legend(loc="lower right")
    plt.show()

def calculate_map(y_test, y_pred_proba, class_names):
    """Compute the average precision (AP) of each class and their mean (mAP).

    Args:
        y_test: Class labels; one-hot encoded here via ``to_categorical``.
        y_pred_proba: Per-class scores; column ``i`` is scored against class ``i``.
        class_names: Display names; their count defines the number of classes.

    Returns:
        ``(ap_scores, map_score)``: a dict of class name -> AP, and the mean
        of those AP values. Both are also printed.
    """
    n_classes = len(class_names)
    y_onehot = tf.keras.utils.to_categorical(y_test, n_classes)

    # One-vs-rest AP for every class.
    ap_scores = {
        class_names[cls_idx]: average_precision_score(
            y_onehot[:, cls_idx], y_pred_proba[:, cls_idx]
        )
        for cls_idx in range(n_classes)
    }

    # mAP is the unweighted mean of the per-class APs.
    map_score = np.mean(list(ap_scores.values()))

    print("\n=== Average Precision Scores ===")
    for name in ap_scores:
        print(f"{name}: {ap_scores[name]:.4f}")
    print(f"\nMean Average Precision (mAP): {map_score:.4f}")

    return ap_scores, map_score

def load_data(base_path, num_samples=40):
    """Load up to ``num_samples`` images per class as normalized RGB arrays.

    For every class directory under ``base_path`` the image files
    (.jpg/.jpeg/.png/.bmp, case-insensitive) are sorted by name and the first
    ``num_samples`` are read, converted BGR->RGB, resized to 224x224 and
    normalized with the mean/std given below.

    Args:
        base_path: Directory containing one sub-directory per class.
        num_samples: Maximum number of images taken per class.

    Returns:
        ``(images, labels)`` as NumPy arrays; a label is the index of the
        class in the class list below. Unreadable or failing images are
        reported and skipped, and so are missing class directories.
    """
    # A.Normalize already returns a float32 height x width x channel array,
    # which is the layout Keras consumes, so no tensor conversion is needed.
    # NOTE(review): Keras MobileNet/MobileNetV2 weights expect inputs scaled
    # to [-1, 1]; confirm that mean/std normalization is intended here.
    transforms = A.Compose([
        A.Resize(224, 224),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ])

    image_exts = ('.jpg', '.jpeg', '.png', '.bmp')
    class_names = ['CN7', 'G2', 'BLDC-400-f', 'BLDC-400-b']
    images = []
    labels = []

    print("\n데이터 로딩 시작:")
    for label, class_name in enumerate(class_names):
        class_path = os.path.join(base_path, class_name)
        if not os.path.isdir(class_path):
            # Say so explicitly: a silently missing class only surfaces much
            # later as a confusing training / metrics error.
            print(f"클래스 폴더를 찾을 수 없음: {class_path}")
            continue

        # Single directory listing; endswith accepts a tuple of suffixes.
        files = sorted(
            f for f in os.listdir(class_path) if f.lower().endswith(image_exts)
        )[:num_samples]

        print(f"{class_name} 클래스 로딩 중: {len(files)}개 이미지")

        for file in tqdm(files):
            image_path = os.path.join(class_path, file)
            try:
                image = cv2.imread(image_path)
                if image is None:
                    print(f"이미지를 읽을 수 없음: {image_path}")
                    continue

                # OpenCV decodes to BGR; the normalization stats are in RGB order.
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                images.append(transforms(image=image)['image'])
                labels.append(label)

            except Exception as e:
                # Best effort: one bad file must not abort the whole load.
                print(f"이미지 처리 중 에러 발생 {image_path}: {str(e)}")
                continue

    return np.array(images), np.array(labels)

def load_test_data(test_path):
    """Load every test image of every class as a normalized RGB array.

    For every class directory under ``test_path`` all image files
    (.jpg/.jpeg/.png/.bmp, case-insensitive) are read in name order,
    converted BGR->RGB, resized to 224x224 and normalized with the mean/std
    given below.

    Args:
        test_path: Directory containing one sub-directory per class.

    Returns:
        ``(images, labels)`` as NumPy arrays; a label is the index of the
        class in the class list below. Unreadable or failing images are
        reported and skipped, and so are missing class directories.
    """
    # A.Normalize already returns a float32 height x width x channel array,
    # which is the layout Keras consumes, so no tensor conversion is needed.
    transforms = A.Compose([
        A.Resize(224, 224),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ])

    image_exts = ('.jpg', '.jpeg', '.png', '.bmp')
    class_names = ['CN7', 'G2', 'BLDC-400-f', 'BLDC-400-b']
    images = []
    labels = []

    print("\n테스트 데이터 로딩:")
    for label, class_name in enumerate(class_names):
        class_path = os.path.join(test_path, class_name)
        if not os.path.isdir(class_path):
            # A class absent from the test set changes which metrics can be
            # computed, so report it instead of skipping silently.
            print(f"클래스 폴더를 찾을 수 없음: {class_path}")
            continue

        # Single directory listing; endswith accepts a tuple of suffixes.
        files = sorted(
            f for f in os.listdir(class_path) if f.lower().endswith(image_exts)
        )

        print(f"{class_name} 클래스 발견: {len(files)}개 이미지")

        for file in tqdm(files):
            image_path = os.path.join(class_path, file)
            try:
                image = cv2.imread(image_path)
                if image is None:
                    print(f"이미지를 읽을 수 없음: {image_path}")
                    continue

                # OpenCV decodes to BGR; the normalization stats are in RGB order.
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                images.append(transforms(image=image)['image'])
                labels.append(label)

            except Exception as e:
                # Best effort: one bad file must not abort the whole load.
                print(f"이미지 처리 중 에러 발생 {image_path}: {str(e)}")
                continue

    return np.array(images), np.array(labels)

def evaluate_metrics(y_true, y_pred, y_pred_proba, class_names):
    """Compute a bundle of classification metrics for one set of predictions.

    Args:
        y_true: True integer class labels (indices into ``class_names``).
        y_pred: Predicted integer class labels.
        y_pred_proba: Per-class scores, one column per class, used for ROC AUC.
        class_names: Class display names, ordered by label index.

    Returns:
        dict with:
          * ``basic_metrics``: accuracy, macro/weighted precision, recall, F1
            and ``roc_auc`` (``None`` when it cannot be computed),
          * ``class_report``: ``classification_report`` as a dict,
          * ``confusion_matrix``: one row/column per entry of ``class_names``,
          * ``class_accuracy``: class name -> diagonal / row sum (NaN for a
            class with no true samples).
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Pin the label set so rows/columns always line up with class_names,
    # even when a class is missing from both y_true and y_pred.
    label_ids = list(range(len(class_names)))

    metrics = {
        'accuracy': np.mean(y_true == y_pred),
        'macro_precision': precision_score(y_true, y_pred, average='macro'),
        'macro_recall': recall_score(y_true, y_pred, average='macro'),
        'macro_f1': f1_score(y_true, y_pred, average='macro'),
        'weighted_precision': precision_score(y_true, y_pred, average='weighted'),
        'weighted_recall': recall_score(y_true, y_pred, average='weighted'),
        'weighted_f1': f1_score(y_true, y_pred, average='weighted'),
    }

    class_report = classification_report(
        y_true, y_pred, labels=label_ids, target_names=class_names, output_dict=True
    )
    conf_matrix = confusion_matrix(y_true, y_pred, labels=label_ids)

    try:
        metrics['roc_auc'] = roc_auc_score(
            y_true, y_pred_proba, multi_class='ovr', average='macro'
        )
    except Exception as exc:
        # Best effort (e.g. a class absent from y_true); unlike a bare
        # `except:` this lets KeyboardInterrupt / SystemExit propagate.
        print(f"ROC AUC 계산 불가: {exc}")
        metrics['roc_auc'] = None

    # Per-class accuracy = correct / true samples of that class. Rows with no
    # samples stay NaN instead of triggering a 0/0 runtime warning.
    support = conf_matrix.sum(axis=1)
    class_accuracy = np.full(len(label_ids), np.nan)
    np.divide(conf_matrix.diagonal(), support, out=class_accuracy, where=support > 0)

    return {
        'basic_metrics': metrics,
        'class_report': class_report,
        'confusion_matrix': conf_matrix,
        'class_accuracy': dict(zip(class_names, class_accuracy))
    }

class PerformanceThresholdCallback(tf.keras.callbacks.Callback):
    """Keep the best weights above a quality bar and stop at a target score.

    After every epoch the validation accuracy/loss are read from ``logs``:

    * if accuracy >= ``accuracy_threshold``, loss <= ``loss_threshold`` and
      the accuracy beats the best seen so far, the current weights are saved;
    * if accuracy >= ``stop_accuracy`` and loss < ``stop_loss``, the saved
      best weights (if any) are restored and training is stopped.
    """

    def __init__(self, accuracy_threshold=0.90, loss_threshold=0.3,
                 stop_accuracy=0.99, stop_loss=0.05):
        """
        Args:
            accuracy_threshold: Minimum val_accuracy for weights to be saved.
            loss_threshold: Maximum val_loss for weights to be saved.
            stop_accuracy: val_accuracy at or above which training may stop.
            stop_loss: val_loss strictly below which training may stop.
        """
        super().__init__()
        self.best_weights = None      # weights of the best qualifying epoch
        self.best_accuracy = 0        # val_accuracy of that epoch
        self.accuracy_threshold = accuracy_threshold
        self.loss_threshold = loss_threshold
        self.stop_accuracy = stop_accuracy
        self.stop_loss = stop_loss

    def on_epoch_end(self, epoch, logs=None):
        """Save/restore weights and possibly stop, based on validation logs."""
        # `logs` may be None; a mutable `{}` default would also be shared
        # between calls, so normalize it here instead.
        logs = logs or {}
        # Missing validation metrics fall back to values that never qualify.
        current_accuracy = logs.get('val_accuracy', 0)
        current_loss = logs.get('val_loss', float('inf'))

        if (current_accuracy >= self.accuracy_threshold and
            current_loss <= self.loss_threshold and
            current_accuracy > self.best_accuracy):
            print(f"\n성능 임계값 달성! (accuracy: {current_accuracy:.4f}, loss: {current_loss:.4f})")
            self.best_accuracy = current_accuracy
            self.best_weights = self.model.get_weights()
            print(f"새로운 best weights 저장됨 (정확도: {current_accuracy:.4f})")

        if current_accuracy >= self.stop_accuracy and current_loss < self.stop_loss:
            print(f"\n최종 목표 성능 도달! (accuracy: {current_accuracy:.4f}, loss: {current_loss:.4f})")
            if self.best_weights is not None:
                self.model.set_weights(self.best_weights)
                print("최고 성능 모델의 가중치로 복원됨")
            self.model.stop_training = True

class MobileNetExperiment:
    """Train, evaluate, save and reload a MobileNet-based image classifier.

    A frozen ImageNet-pretrained MobileNet backbone (chosen by ``version``)
    is followed by global average pooling and a softmax ``Dense`` layer with
    ``num_classes`` outputs.
    """

    # Backbone identifiers accepted by ``_create_model``.
    SUPPORTED_VERSIONS = ('v1', 'v2', 'v3-small', 'v3-large')

    def __init__(self, version, num_classes=4, input_shape=(224, 224, 3)):
        """Store the settings and build the compiled model.

        Args:
            version: One of ``SUPPORTED_VERSIONS``.
            num_classes: Number of output classes of the softmax head.
            input_shape: Input image shape passed to the backbone.

        Raises:
            ValueError: If ``version`` is not supported.
        """
        self.version = version
        self.num_classes = num_classes
        self.input_shape = input_shape
        self.model = self._create_model()
        self.history = None           # object exposing a `.history` dict, set by train()/load_model()
        self.training_time = None     # seconds spent in train()
        self.best_val_accuracy = 0
        self.best_weights = None

    def save_model(self, save_path):
        """Save the model (.h5) and a JSON metadata file into ``save_path``.

        The metadata holds the version, best validation accuracy, training
        time and the training history (or ``None`` when there is none).
        """
        os.makedirs(save_path, exist_ok=True)

        model_path = os.path.join(save_path, f'mobilenet_{self.version}_model.h5')
        self.model.save(model_path)

        # Coerce metric values to built-in floats so json.dump never trips
        # over NumPy scalar types.
        history = None
        if self.history is not None:
            history = {
                key: [float(value) for value in values]
                for key, values in self.history.history.items()
            }
        best_val_accuracy = self.best_val_accuracy
        if best_val_accuracy is not None:
            best_val_accuracy = float(best_val_accuracy)

        metadata = {
            'version': self.version,
            'best_val_accuracy': best_val_accuracy,
            'training_time': self.training_time,
            'history': history
        }

        metadata_path = os.path.join(save_path, f'mobilenet_{self.version}_metadata.json')
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f)

        print(f"모델이 저장되었습니다: {model_path}")
        print(f"메타데이터가 저장되었습니다: {metadata_path}")

    @classmethod
    def load_model(cls, load_path, version):
        """Rebuild an experiment from files written by ``save_model``.

        Args:
            load_path: Directory containing the saved model and metadata.
            version: Version string used in the saved file names.

        Returns:
            An instance whose ``model`` is the loaded model. The metadata
            file is optional; without it ``history`` stays ``None``.

        Raises:
            FileNotFoundError: If the model file does not exist.
        """
        model_path = os.path.join(load_path, f'mobilenet_{version}_model.h5')
        metadata_path = os.path.join(load_path, f'mobilenet_{version}_metadata.json')

        # Fail fast, before any expensive work.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"모델 파일을 찾을 수 없습니다: {model_path}")

        # Bypass __init__: it would build a fresh ImageNet backbone only for
        # it to be replaced by the loaded model right away.
        instance = cls.__new__(cls)
        instance.version = version
        instance.model = tf.keras.models.load_model(model_path)
        instance.num_classes = instance.model.output_shape[-1]
        instance.input_shape = tuple(instance.model.input_shape[1:])
        instance.history = None
        instance.training_time = None
        instance.best_val_accuracy = 0
        instance.best_weights = None
        print(f"모델을 불러왔습니다: {model_path}")

        if os.path.exists(metadata_path):
            with open(metadata_path, 'r') as f:
                metadata = json.load(f)

            instance.best_val_accuracy = metadata.get('best_val_accuracy', 0)
            instance.training_time = metadata.get('training_time')
            if metadata.get('history'):
                # Lightweight stand-in exposing `.history` like a Keras History.
                instance.history = type('History', (), {'history': metadata['history']})

            print(f"메타데이터를 불러왔습니다: {metadata_path}")

        return instance

    def _create_model(self):
        """Build and compile the frozen-backbone classifier for ``self.version``.

        Raises:
            ValueError: If ``self.version`` is not in ``SUPPORTED_VERSIONS``.
        """
        if self.version not in self.SUPPORTED_VERSIONS:
            raise ValueError(
                f"Unsupported MobileNet version {self.version!r}; "
                f"expected one of {self.SUPPORTED_VERSIONS}"
            )

        common = dict(
            weights='imagenet',
            include_top=False,
            input_shape=self.input_shape
        )
        if self.version == 'v1':
            base_model = MobileNet(**common)
        elif self.version == 'v2':
            base_model = MobileNetV2(**common)
        elif self.version == 'v3-small':
            # The V3 models rescale raw 0-255 pixels internally by default;
            # the loaders in this notebook already normalize the images, so
            # that built-in step is turned off to avoid scaling twice.
            base_model = MobileNetV3Small(include_preprocessing=False, **common)
        else:  # 'v3-large'
            base_model = MobileNetV3Large(include_preprocessing=False, **common)

        # Only the new classification head is trained.
        base_model.trainable = False

        x = base_model.output
        x = GlobalAveragePooling2D()(x)
        predictions = Dense(self.num_classes, activation='softmax')(x)

        model = Model(inputs=base_model.input, outputs=predictions)
        model.compile(
            optimizer='adam',
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )

        return model

    def train(self, X_train, y_train, X_val, y_val, epochs=10):
        """Fit the model and record history, best val accuracy and wall time.

        Training uses batches of 32 with early stopping on ``val_accuracy``
        (patience 6) and a ``PerformanceThresholdCallback`` that keeps the
        weights of an epoch reaching 100% val accuracy with val loss <= 0.05.
        """
        print(f"\n=== MobileNet {self.version} 학습 시작 ===")

        start_time = time.time()

        # NOTE(review): with restore_best_weights=False the model keeps its
        # last-epoch weights, while best_val_accuracy below reports the
        # maximum over all epochs - confirm this is intended.
        early_stop = tf.keras.callbacks.EarlyStopping(
            monitor='val_accuracy',
            min_delta=0.001,
            patience=6,
            mode='max',
            verbose=1,
            restore_best_weights=False
        )

        performance_threshold = PerformanceThresholdCallback(
            accuracy_threshold=1,  # 100% accuracy
            loss_threshold=0.05    # loss of at most 0.05
        )

        self.history = self.model.fit(
            X_train, y_train,
            validation_data=(X_val, y_val),
            epochs=epochs,
            batch_size=32,
            callbacks=[early_stop, performance_threshold],
            verbose=1
        )

        if performance_threshold.best_weights is not None:
            print("\n임계값을 넘는 best weights가 발견되어 복원합니다.")
            self.model.set_weights(performance_threshold.best_weights)
            self.best_val_accuracy = performance_threshold.best_accuracy
        else:
            print("\n임계값을 넘는 모델이 없어 현재 weights를 유지합니다.")
            self.best_val_accuracy = max(self.history.history['val_accuracy'])

        self.training_time = time.time() - start_time
        print(f"학습 시간: {self.training_time:.2f}초")

    def evaluate(self, X_test, y_test, class_names):
        """Predict on the test set, print/plot the metrics and return them.

        Returns:
            dict with the metrics, class report, confusion matrix, per-class
            accuracy, timing figures, AP/mAP scores, ``best_val_accuracy``
            and ``best_val_loss`` (``None`` when no training history is
            available, e.g. a model loaded without its metadata file).

        Raises:
            ValueError: If ``X_test`` is empty.
        """
        n_samples = len(X_test)
        if n_samples == 0:
            raise ValueError("평가할 테스트 데이터가 없습니다.")

        inference_start_time = time.time()

        # Run inference
        y_pred_proba = self.model.predict(X_test, verbose=0)
        y_pred = np.argmax(y_pred_proba, axis=1)

        inference_time = time.time() - inference_start_time

        # Aggregate metrics
        evaluation_results = evaluate_metrics(y_test, y_pred, y_pred_proba, class_names)

        # ROC curves
        plot_roc_curves(y_test, y_pred_proba, class_names)

        # Per-class AP and mAP
        ap_scores, map_score = calculate_map(y_test, y_pred_proba, class_names)

        # Report
        print("\n=== 모델 성능 평가 결과 ===")
        print(f"\n1. 기본 메트릭:")
        for metric, value in evaluation_results['basic_metrics'].items():
            if value is not None:
                print(f"- {metric}: {value:.4f}")

        print("\n2. 클래스별 성능:")
        for class_name in class_names:
            metrics = evaluation_results['class_report'][class_name]
            print(f"\n{class_name}:")
            print(f"- Precision: {metrics['precision']:.4f}")
            print(f"- Recall: {metrics['recall']:.4f}")
            print(f"- F1-score: {metrics['f1-score']:.4f}")
            print(f"- Support: {metrics['support']}")

        print(f"\n3. 처리 시간:")
        print(f"- 전체 추론 시간: {inference_time:.2f}초")
        print(f"- 이미지당 평균 추론 시간: {(inference_time/n_samples)*1000:.2f}ms")

        self._plot_training_history()
        self._plot_confusion_matrix(evaluation_results['confusion_matrix'], class_names)

        # The history is absent when the model was loaded without metadata.
        val_losses = self.history.history.get('val_loss') if self.history is not None else None
        best_val_loss = min(val_losses) if val_losses else None

        return {
            'metrics': evaluation_results['basic_metrics'],
            'class_report': evaluation_results['class_report'],
            'confusion_matrix': evaluation_results['confusion_matrix'],
            'class_accuracy': evaluation_results['class_accuracy'],
            'training_time': self.training_time,
            'inference_time': inference_time,
            'inference_time_per_image': inference_time/n_samples,
            'best_val_accuracy': self.best_val_accuracy,
            'best_val_loss': best_val_loss,
            'ap_scores': ap_scores,
            'map_score': map_score
        }

    def _plot_training_history(self):
        """Plot train/validation accuracy and loss per epoch, if recorded."""
        if self.history is None:
            # Nothing to draw, e.g. a model loaded without its metadata file.
            print("학습 기록이 없어 학습 곡선을 표시하지 않습니다.")
            return

        plt.figure(figsize=(12, 4))

        plt.subplot(1, 2, 1)
        plt.plot(self.history.history['accuracy'], label='train')
        plt.plot(self.history.history['val_accuracy'], label='validation')
        plt.title(f'MobileNet {self.version} Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(self.history.history['loss'], label='train')
        plt.plot(self.history.history['val_loss'], label='validation')
        plt.title(f'MobileNet {self.version} Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        plt.tight_layout()
        plt.show()

    def _plot_confusion_matrix(self, cm, class_names):
        """Draw the confusion matrix ``cm`` as an annotated heatmap."""
        plt.figure(figsize=(8, 6))
        sns.heatmap(
            cm,
            annot=True,
            fmt='d',
            cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names
        )
        plt.title(f'MobileNet {self.version} Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.show()

def run_single_experiment(version, X_train, y_train, X_val, y_val, X_test, y_test, class_names):
    """Train and evaluate one MobileNet variant, then print a short summary.

    Args:
        version: Backbone identifier handed to ``MobileNetExperiment``.
        X_train, y_train: Training images and labels.
        X_val, y_val: Validation images and labels.
        X_test, y_test: Test images and labels.
        class_names: Class display names used in the evaluation.

    Returns:
        The result dict produced by ``MobileNetExperiment.evaluate``.
    """
    runner = MobileNetExperiment(version)
    runner.train(X_train, y_train, X_val, y_val)
    results = runner.evaluate(X_test, y_test, class_names)

    # Headline figures, one per line.
    summary_lines = (
        f"\n=== MobileNet {version} 실험 결과 ===",
        f"최고 검증 정확도: {results['best_val_accuracy']:.4f}",
        f"최고 검증 loss: {results['best_val_loss']:.4f}",
        f"학습 소요 시간: {results['training_time']:.2f}초",
        f"전체 추론 시간: {results['inference_time']:.2f}초",
        f"이미지당 평균 추론 시간: {results['inference_time_per_image']*1000:.2f}ms",
        f"테스트 정확도: {results['metrics']['accuracy']:.4f}",
    )
    for line in summary_lines:
        print(line)

    return results

# End-to-end run: load data, split, train MobileNet v2, save it, then reload
# the saved model and evaluate it on the separate test set.
if __name__ == "__main__":
    # 1. Configure paths
    base_path = '/content/drive/MyDrive/augmented_data'
    test_path = '/content/drive/MyDrive/test_data2'
    model_save_path = '/content/drive/MyDrive/saved_models'  # directory the model is saved to

    # 2. Load data (at most 102 training images per class)
    print("=== 데이터 로드 시작 ===")
    train_images, train_labels = load_data(base_path, num_samples=102)
    test_images, test_labels = load_test_data(test_path)

    # 3. Split into training / validation sets (20% validation, stratified
    #    by label, fixed random seed)
    X_train, X_val, y_train, y_val = train_test_split(
        train_images,
        train_labels,
        test_size=0.2,
        random_state=42,
        stratify=train_labels
    )

    # 4. Train the model
    version = 'v2'
    experiment = MobileNetExperiment(version)
    experiment.train(X_train, y_train, X_val, y_val)

    # 5. Save the model
    experiment.save_model(model_save_path)

    # 6. Reload the saved model and evaluate it
    loaded_experiment = MobileNetExperiment.load_model(model_save_path, version)
    class_names = ['CN7', 'G2', 'BLDC-400-f', 'BLDC-400-b']
    results = loaded_experiment.evaluate(test_images, test_labels, class_names)

In [None]:
#V3(small,large) V1,V2 랑 레이어 달라서 코드 float 함수 바꿔서 넣어야함

In [None]:
def plot_roc_curves(y_test, y_pred_proba, class_names):
    """Plot one-vs-rest ROC curves, one per class, on a single figure.

    Args:
        y_test: Class labels; one-hot encoded here via ``to_categorical``.
        y_pred_proba: Per-class scores; column ``i`` is scored against class ``i``.
        class_names: Display names; their count defines the number of classes.
    """
    n_classes = len(class_names)

    # Turn the labels into one binary target column per class.
    y_onehot = tf.keras.utils.to_categorical(y_test, n_classes)

    # (fpr, tpr, auc) for every class, in class-index order.
    curves = []
    for cls_idx in range(n_classes):
        fpr_i, tpr_i, _ = roc_curve(y_onehot[:, cls_idx], y_pred_proba[:, cls_idx])
        curves.append((fpr_i, tpr_i, auc(fpr_i, tpr_i)))

    plt.figure(figsize=(10, 8))
    palette = cycle(['aqua', 'darkorange', 'cornflowerblue', 'green'])

    for cls_idx, line_color in zip(range(n_classes), palette):
        fpr_i, tpr_i, area = curves[cls_idx]
        plt.plot(fpr_i, tpr_i, color=line_color, lw=2,
                 label=f'{class_names[cls_idx]} (AUC = {area:0.2f})')

    # Diagonal reference line.
    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Multi-class ROC Curves')
    plt.legend(loc="lower right")
    plt.show()

def calculate_map(y_test, y_pred_proba, class_names):
    """Compute the average precision (AP) of each class and their mean (mAP).

    Args:
        y_test: Class labels; one-hot encoded here via ``to_categorical``.
        y_pred_proba: Per-class scores; column ``i`` is scored against class ``i``.
        class_names: Display names; their count defines the number of classes.

    Returns:
        ``(ap_scores, map_score)``: a dict of class name -> AP, and the mean
        of those AP values. Both are also printed.
    """
    n_classes = len(class_names)
    y_onehot = tf.keras.utils.to_categorical(y_test, n_classes)

    # One-vs-rest AP for every class.
    ap_scores = {
        class_names[cls_idx]: average_precision_score(
            y_onehot[:, cls_idx], y_pred_proba[:, cls_idx]
        )
        for cls_idx in range(n_classes)
    }

    # mAP is the unweighted mean of the per-class APs.
    map_score = np.mean(list(ap_scores.values()))

    print("\n=== Average Precision Scores ===")
    for name in ap_scores:
        print(f"{name}: {ap_scores[name]:.4f}")
    print(f"\nMean Average Precision (mAP): {map_score:.4f}")

    return ap_scores, map_score

def load_data(base_path, num_samples=40):
    """Load up to ``num_samples`` images per class as normalized RGB arrays.

    For every class directory under ``base_path`` the image files
    (.jpg/.jpeg/.png/.bmp, case-insensitive) are sorted by name and the first
    ``num_samples`` are read, converted BGR->RGB, resized to 224x224 and
    normalized with the mean/std given below.

    Args:
        base_path: Directory containing one sub-directory per class.
        num_samples: Maximum number of images taken per class.

    Returns:
        ``(images, labels)`` as NumPy arrays; a label is the index of the
        class in the class list below. Unreadable or failing images are
        reported and skipped, and so are missing class directories.
    """
    # A.Normalize already returns a float32 height x width x channel array,
    # which is the layout Keras consumes, so no tensor conversion is needed.
    # NOTE(review): Keras MobileNet/MobileNetV2 weights expect inputs scaled
    # to [-1, 1]; confirm that mean/std normalization is intended here.
    transforms = A.Compose([
        A.Resize(224, 224),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ])

    image_exts = ('.jpg', '.jpeg', '.png', '.bmp')
    class_names = ['CN7', 'G2', 'BLDC-400-f', 'BLDC-400-b']
    images = []
    labels = []

    print("\n데이터 로딩 시작:")
    for label, class_name in enumerate(class_names):
        class_path = os.path.join(base_path, class_name)
        if not os.path.isdir(class_path):
            # Say so explicitly: a silently missing class only surfaces much
            # later as a confusing training / metrics error.
            print(f"클래스 폴더를 찾을 수 없음: {class_path}")
            continue

        # Single directory listing; endswith accepts a tuple of suffixes.
        files = sorted(
            f for f in os.listdir(class_path) if f.lower().endswith(image_exts)
        )[:num_samples]

        print(f"{class_name} 클래스 로딩 중: {len(files)}개 이미지")

        for file in tqdm(files):
            image_path = os.path.join(class_path, file)
            try:
                image = cv2.imread(image_path)
                if image is None:
                    print(f"이미지를 읽을 수 없음: {image_path}")
                    continue

                # OpenCV decodes to BGR; the normalization stats are in RGB order.
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                images.append(transforms(image=image)['image'])
                labels.append(label)

            except Exception as e:
                # Best effort: one bad file must not abort the whole load.
                print(f"이미지 처리 중 에러 발생 {image_path}: {str(e)}")
                continue

    return np.array(images), np.array(labels)

def load_test_data(test_path):
    """Load every test image of every class as a normalized RGB array.

    For every class directory under ``test_path`` all image files
    (.jpg/.jpeg/.png/.bmp, case-insensitive) are read in name order,
    converted BGR->RGB, resized to 224x224 and normalized with the mean/std
    given below.

    Args:
        test_path: Directory containing one sub-directory per class.

    Returns:
        ``(images, labels)`` as NumPy arrays; a label is the index of the
        class in the class list below. Unreadable or failing images are
        reported and skipped, and so are missing class directories.
    """
    # A.Normalize already returns a float32 height x width x channel array,
    # which is the layout Keras consumes, so no tensor conversion is needed.
    transforms = A.Compose([
        A.Resize(224, 224),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ])

    image_exts = ('.jpg', '.jpeg', '.png', '.bmp')
    class_names = ['CN7', 'G2', 'BLDC-400-f', 'BLDC-400-b']
    images = []
    labels = []

    print("\n테스트 데이터 로딩:")
    for label, class_name in enumerate(class_names):
        class_path = os.path.join(test_path, class_name)
        if not os.path.isdir(class_path):
            # A class absent from the test set changes which metrics can be
            # computed, so report it instead of skipping silently.
            print(f"클래스 폴더를 찾을 수 없음: {class_path}")
            continue

        # Single directory listing; endswith accepts a tuple of suffixes.
        files = sorted(
            f for f in os.listdir(class_path) if f.lower().endswith(image_exts)
        )

        print(f"{class_name} 클래스 발견: {len(files)}개 이미지")

        for file in tqdm(files):
            image_path = os.path.join(class_path, file)
            try:
                image = cv2.imread(image_path)
                if image is None:
                    print(f"이미지를 읽을 수 없음: {image_path}")
                    continue

                # OpenCV decodes to BGR; the normalization stats are in RGB order.
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                images.append(transforms(image=image)['image'])
                labels.append(label)

            except Exception as e:
                # Best effort: one bad file must not abort the whole load.
                print(f"이미지 처리 중 에러 발생 {image_path}: {str(e)}")
                continue

    return np.array(images), np.array(labels)

def evaluate_metrics(y_true, y_pred, y_pred_proba, class_names):
    """Compute a bundle of classification metrics for one set of predictions.

    Args:
        y_true: True integer class labels (indices into ``class_names``).
        y_pred: Predicted integer class labels.
        y_pred_proba: Per-class scores, one column per class, used for ROC AUC.
        class_names: Class display names, ordered by label index.

    Returns:
        dict with:
          * ``basic_metrics``: accuracy, macro/weighted precision, recall, F1
            and ``roc_auc`` (``None`` when it cannot be computed),
          * ``class_report``: ``classification_report`` as a dict,
          * ``confusion_matrix``: one row/column per entry of ``class_names``,
          * ``class_accuracy``: class name -> diagonal / row sum (NaN for a
            class with no true samples).
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Pin the label set so rows/columns always line up with class_names,
    # even when a class is missing from both y_true and y_pred.
    label_ids = list(range(len(class_names)))

    metrics = {
        'accuracy': np.mean(y_true == y_pred),
        'macro_precision': precision_score(y_true, y_pred, average='macro'),
        'macro_recall': recall_score(y_true, y_pred, average='macro'),
        'macro_f1': f1_score(y_true, y_pred, average='macro'),
        'weighted_precision': precision_score(y_true, y_pred, average='weighted'),
        'weighted_recall': recall_score(y_true, y_pred, average='weighted'),
        'weighted_f1': f1_score(y_true, y_pred, average='weighted'),
    }

    class_report = classification_report(
        y_true, y_pred, labels=label_ids, target_names=class_names, output_dict=True
    )
    conf_matrix = confusion_matrix(y_true, y_pred, labels=label_ids)

    try:
        metrics['roc_auc'] = roc_auc_score(
            y_true, y_pred_proba, multi_class='ovr', average='macro'
        )
    except Exception as exc:
        # Best effort (e.g. a class absent from y_true); unlike a bare
        # `except:` this lets KeyboardInterrupt / SystemExit propagate.
        print(f"ROC AUC 계산 불가: {exc}")
        metrics['roc_auc'] = None

    # Per-class accuracy = correct / true samples of that class. Rows with no
    # samples stay NaN instead of triggering a 0/0 runtime warning.
    support = conf_matrix.sum(axis=1)
    class_accuracy = np.full(len(label_ids), np.nan)
    np.divide(conf_matrix.diagonal(), support, out=class_accuracy, where=support > 0)

    return {
        'basic_metrics': metrics,
        'class_report': class_report,
        'confusion_matrix': conf_matrix,
        'class_accuracy': dict(zip(class_names, class_accuracy))
    }

class PerformanceThresholdCallback(tf.keras.callbacks.Callback):
    """Keep the best weights above a quality bar and stop at a target score.

    After every epoch the validation accuracy/loss are read from ``logs``:

    * if accuracy >= ``accuracy_threshold``, loss <= ``loss_threshold`` and
      the accuracy beats the best seen so far, the current weights are saved;
    * if accuracy >= ``stop_accuracy`` and loss < ``stop_loss``, the saved
      best weights (if any) are restored and training is stopped.
    """

    def __init__(self, accuracy_threshold=0.90, loss_threshold=0.3,
                 stop_accuracy=0.99, stop_loss=0.05):
        """
        Args:
            accuracy_threshold: Minimum val_accuracy for weights to be saved.
            loss_threshold: Maximum val_loss for weights to be saved.
            stop_accuracy: val_accuracy at or above which training may stop.
            stop_loss: val_loss strictly below which training may stop.
        """
        super().__init__()
        self.best_weights = None      # weights of the best qualifying epoch
        self.best_accuracy = 0        # val_accuracy of that epoch
        self.accuracy_threshold = accuracy_threshold
        self.loss_threshold = loss_threshold
        self.stop_accuracy = stop_accuracy
        self.stop_loss = stop_loss

    def on_epoch_end(self, epoch, logs=None):
        """Save/restore weights and possibly stop, based on validation logs."""
        # `logs` may be None; a mutable `{}` default would also be shared
        # between calls, so normalize it here instead.
        logs = logs or {}
        # Missing validation metrics fall back to values that never qualify.
        current_accuracy = logs.get('val_accuracy', 0)
        current_loss = logs.get('val_loss', float('inf'))

        if (current_accuracy >= self.accuracy_threshold and
            current_loss <= self.loss_threshold and
            current_accuracy > self.best_accuracy):
            print(f"\n성능 임계값 달성! (accuracy: {current_accuracy:.4f}, loss: {current_loss:.4f})")
            self.best_accuracy = current_accuracy
            self.best_weights = self.model.get_weights()
            print(f"새로운 best weights 저장됨 (정확도: {current_accuracy:.4f})")

        if current_accuracy >= self.stop_accuracy and current_loss < self.stop_loss:
            print(f"\n최종 목표 성능 도달! (accuracy: {current_accuracy:.4f}, loss: {current_loss:.4f})")
            if self.best_weights is not None:
                self.model.set_weights(self.best_weights)
                print("최고 성능 모델의 가중치로 복원됨")
            self.model.stop_training = True

class MobileNetExperiment:
    """Train, persist and evaluate a frozen-backbone MobileNet classifier.

    The model is an ImageNet-initialised MobileNet variant (selected by
    ``version``) with its weights frozen, followed by global average pooling
    and a softmax ``Dense`` layer with ``num_classes`` outputs.

    Attributes:
        version: One of ``SUPPORTED_VERSIONS``.
        num_classes: Number of output classes.
        input_shape: Shape of a single input image.
        model: The compiled Keras model.
        history: Object exposing a ``history`` dict of per-epoch metrics, or
            ``None`` when the model has not been trained / no metadata was found.
        training_time: Wall-clock seconds spent in ``train``, or ``None``.
        best_val_accuracy: Best validation accuracy recorded by ``train``.
        best_weights: Initialised to ``None``; not written by any method of
            this class.
    """

    # Architecture identifiers accepted as ``version``.
    SUPPORTED_VERSIONS = ('v1', 'v2', 'v3-small', 'v3-large')

    def __init__(self, version, num_classes=4, input_shape=(224, 224, 3), model=None):
        """Create an experiment.

        Args:
            version: Architecture identifier, see ``SUPPORTED_VERSIONS``.
            num_classes: Number of output classes of the classification head.
            input_shape: Shape of one input image.
            model: Optional ready-made Keras model. When given, no new backbone
                is built; this lets ``load_model`` skip constructing a model
                that would be thrown away immediately.
        """
        self.version = version
        self.num_classes = num_classes
        self.input_shape = input_shape
        self.model = model if model is not None else self._create_model()
        self.history = None
        self.training_time = None
        self.best_val_accuracy = 0
        self.best_weights = None

    def _create_model(self):
        """Build and compile the classifier for ``self.version``.

        Returns:
            A compiled Keras ``Model`` (Adam optimiser, sparse categorical
            cross-entropy loss, accuracy metric).

        Raises:
            ValueError: If ``self.version`` is not in ``SUPPORTED_VERSIONS``.
        """
        # Fail fast with a clear message instead of hitting an unbound
        # ``base_model`` further down.
        if self.version not in self.SUPPORTED_VERSIONS:
            raise ValueError(
                f"지원하지 않는 MobileNet 버전입니다: {self.version!r} "
                f"(지원 버전: {', '.join(self.SUPPORTED_VERSIONS)})"
            )

        # version -> (backbone constructor, backbone pools internally?)
        # v1/v2 get an explicit GlobalAveragePooling2D layer, the v3 variants
        # are built with pooling='avg'.
        backbones = {
            'v1': (MobileNet, False),
            'v2': (MobileNetV2, False),
            'v3-small': (MobileNetV3Small, True),
            'v3-large': (MobileNetV3Large, True),
        }
        backbone_cls, pools_internally = backbones[self.version]

        inputs = Input(shape=self.input_shape)

        backbone_kwargs = {
            'weights': 'imagenet',
            'include_top': False,
            'input_shape': self.input_shape,
        }
        if pools_internally:
            backbone_kwargs['pooling'] = 'avg'
        base_model = backbone_cls(**backbone_kwargs)

        x = base_model(inputs)
        if not pools_internally:
            x = GlobalAveragePooling2D()(x)

        # Only the classification head below is trained.
        base_model.trainable = False
        outputs = Dense(self.num_classes, activation='softmax')(x)

        model = Model(inputs=inputs, outputs=outputs)
        model.compile(
            optimizer='adam',
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )

        return model

    def _history_dict(self):
        """Return the per-epoch metric dict, or ``{}`` when no history is attached."""
        return getattr(self.history, 'history', None) or {}

    def save_model(self, save_path):
        """Write the model (``.h5``) and a JSON metadata file into ``save_path``.

        Args:
            save_path: Target directory; created if it does not exist.
        """
        os.makedirs(save_path, exist_ok=True)
        model_path = os.path.join(save_path, f'mobilenet_{self.version}_model.h5')
        self.model.save(model_path)

        metadata = {
            'version': self.version,
            'best_val_accuracy': self.best_val_accuracy,
            'training_time': self.training_time,
            'history': self.history.history if self.history else None
        }

        metadata_path = os.path.join(save_path, f'mobilenet_{self.version}_metadata.json')
        with open(metadata_path, 'w') as f:
            # default=float: scalars that are not native Python floats (e.g.
            # NumPy float32) are converted instead of aborting the dump.
            json.dump(metadata, f, default=float)

        print(f"모델이 저장되었습니다: {model_path}")
        print(f"메타데이터가 저장되었습니다: {metadata_path}")

    @classmethod
    def load_model(cls, load_path, version):
        """Load a model (and its metadata, if present) written by ``save_model``.

        Args:
            load_path: Directory containing the saved files.
            version: Architecture identifier used in the file names.

        Returns:
            A new experiment wrapping the loaded model. ``history`` stays
            ``None`` when no metadata file (or no stored history) exists.

        Raises:
            FileNotFoundError: If the model file does not exist.
        """
        model_path = os.path.join(load_path, f'mobilenet_{version}_model.h5')
        metadata_path = os.path.join(load_path, f'mobilenet_{version}_metadata.json')

        # Check before doing any work so a missing file does not first trigger
        # the construction of a fresh backbone.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"모델 파일을 찾을 수 없습니다: {model_path}")

        loaded_model = tf.keras.models.load_model(model_path)
        instance = cls(version, model=loaded_model)
        print(f"모델을 불러왔습니다: {model_path}")

        if os.path.exists(metadata_path):
            with open(metadata_path, 'r') as f:
                metadata = json.load(f)

            instance.best_val_accuracy = metadata.get('best_val_accuracy', instance.best_val_accuracy)
            instance.training_time = metadata.get('training_time', instance.training_time)
            if metadata.get('history'):
                # Wrap the stored dict in a real History object so the rest of
                # the class can keep reading ``self.history.history``.
                restored_history = tf.keras.callbacks.History()
                restored_history.history = metadata['history']
                instance.history = restored_history

            print(f"메타데이터를 불러왔습니다: {metadata_path}")

        return instance

    def train(self, X_train, y_train, X_val, y_val, epochs=10):
        """Fit the model and record history, best validation accuracy and duration.

        Training uses batch size 32 with two callbacks: early stopping on
        ``val_accuracy`` (patience 6, best weights NOT restored) and a
        ``PerformanceThresholdCallback``. If that callback captured weights,
        they are restored after fitting; otherwise the weights of the last
        epoch are kept.

        Args:
            X_train: Training inputs.
            y_train: Training labels.
            X_val: Validation inputs.
            y_val: Validation labels.
            epochs: Maximum number of epochs.
        """
        print(f"\n=== MobileNet {self.version} 학습 시작 ===")
        start_time = time.time()

        early_stop = tf.keras.callbacks.EarlyStopping(
            monitor='val_accuracy',
            min_delta=0.001,
            patience=6,
            mode='max',
            verbose=1,
            restore_best_weights=False
        )

        performance_threshold = PerformanceThresholdCallback(
            accuracy_threshold=1,  # 100% validation accuracy
            loss_threshold=0.05    # validation loss of at most 0.05
        )

        self.history = self.model.fit(
            X_train, y_train,
            validation_data=(X_val, y_val),
            epochs=epochs,
            batch_size=32,
            callbacks=[early_stop, performance_threshold],
            verbose=1
        )

        if performance_threshold.best_weights is not None:
            print("\n임계값을 넘는 best weights가 발견되어 복원합니다.")
            self.model.set_weights(performance_threshold.best_weights)
            self.best_val_accuracy = performance_threshold.best_accuracy
        else:
            # NOTE(review): the weights kept here are those of the last epoch,
            # while best_val_accuracy reports the maximum over all epochs.
            print("\n임계값을 넘는 모델이 없어 현재 weights를 유지합니다.")
            self.best_val_accuracy = max(self.history.history['val_accuracy'])

        self.training_time = time.time() - start_time
        print(f"학습 시간: {self.training_time:.2f}초")

    def evaluate(self, X_test, y_test, class_names):
        """Predict on the test set, print a report, draw plots and return the results.

        Args:
            X_test: Test inputs.
            y_test: True test labels.
            class_names: Class names used for the report and plot labels.

        Returns:
            dict with the keys ``metrics``, ``class_report``,
            ``confusion_matrix``, ``class_accuracy``, ``training_time``,
            ``inference_time``, ``inference_time_per_image``,
            ``best_val_accuracy``, ``best_val_loss``, ``ap_scores`` and
            ``map_score``. ``best_val_loss`` is ``None`` when no training
            history is available.
        """
        inference_start_time = time.time()

        # Run inference; the timing covers predict() and the argmax.
        y_pred_proba = self.model.predict(X_test, verbose=0)
        y_pred = np.argmax(y_pred_proba, axis=1)

        inference_time = time.time() - inference_start_time

        # Aggregate metrics, per-class report and confusion matrix.
        evaluation_results = evaluate_metrics(y_test, y_pred, y_pred_proba, class_names)

        # ROC curves.
        plot_roc_curves(y_test, y_pred_proba, class_names)

        # Per-class average precision and their mean.
        ap_scores, map_score = calculate_map(y_test, y_pred_proba, class_names)

        # Printed report.
        print("\n=== 모델 성능 평가 결과 ===")
        print(f"\n1. 기본 메트릭:")
        for metric, value in evaluation_results['basic_metrics'].items():
            if value is not None:
                print(f"- {metric}: {value:.4f}")

        print("\n2. 클래스별 성능:")
        for class_name in class_names:
            metrics = evaluation_results['class_report'][class_name]
            print(f"\n{class_name}:")
            print(f"- Precision: {metrics['precision']:.4f}")
            print(f"- Recall: {metrics['recall']:.4f}")
            print(f"- F1-score: {metrics['f1-score']:.4f}")
            print(f"- Support: {metrics['support']}")

        print(f"\n3. 처리 시간:")
        print(f"- 전체 추론 시간: {inference_time:.2f}초")
        print(f"- 이미지당 평균 추론 시간: {(inference_time/len(X_test))*1000:.2f}ms")

        self._plot_training_history()
        self._plot_confusion_matrix(evaluation_results['confusion_matrix'], class_names)

        # A model loaded without metadata has no history; report None instead
        # of failing on ``self.history.history``.
        val_losses = self._history_dict().get('val_loss')

        return {
            'metrics': evaluation_results['basic_metrics'],
            'class_report': evaluation_results['class_report'],
            'confusion_matrix': evaluation_results['confusion_matrix'],
            'class_accuracy': evaluation_results['class_accuracy'],
            'training_time': self.training_time,
            'inference_time': inference_time,
            'inference_time_per_image': inference_time/len(X_test),
            'best_val_accuracy': self.best_val_accuracy,
            'best_val_loss': min(val_losses) if val_losses else None,
            'ap_scores': ap_scores,
            'map_score': map_score
        }

    def _plot_training_history(self):
        """Plot train/validation accuracy and loss curves; skipped without history."""
        history = self._history_dict()
        required_keys = ('accuracy', 'val_accuracy', 'loss', 'val_loss')
        if not all(key in history for key in required_keys):
            print("학습 기록이 없어 학습 곡선을 그리지 않습니다.")
            return

        plt.figure(figsize=(12, 4))

        plt.subplot(1, 2, 1)
        plt.plot(history['accuracy'], label='train')
        plt.plot(history['val_accuracy'], label='validation')
        plt.title(f'MobileNet {self.version} Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(history['loss'], label='train')
        plt.plot(history['val_loss'], label='validation')
        plt.title(f'MobileNet {self.version} Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        plt.tight_layout()
        plt.show()

    def _plot_confusion_matrix(self, cm, class_names):
        """Draw ``cm`` as an annotated heatmap labelled with ``class_names``."""
        plt.figure(figsize=(8, 6))
        sns.heatmap(
            cm,
            annot=True,
            fmt='d',
            cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names
        )
        plt.title(f'MobileNet {self.version} Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.show()

if __name__ == "__main__":
    # 1. Paths
    base_path = '/content/drive/MyDrive/augmented_data'
    test_path = '/content/drive/MyDrive/test_data2'
    model_save_path = '/content/drive/MyDrive/saved_models'  # where models are saved

    # 2. Load data
    print("=== 데이터 로드 시작 ===")
    train_images, train_labels = load_data(base_path, num_samples=102)
    test_images, test_labels = load_test_data(test_path)

    # Make sure both image arrays are float32
    train_images = train_images.astype('float32')
    test_images = test_images.astype('float32')

    # 3. Train/validation split (stratified, fixed seed)
    X_train, X_val, y_train, y_val = train_test_split(
        train_images,
        train_labels,
        test_size=0.2,
        random_state=42,
        stratify=train_labels
    )

    # 4. Select the architecture
    version = 'v2'

    try:
        # 5. Reuse a previously saved model when one is available
        loaded_experiment = MobileNetExperiment.load_model(model_save_path, version)
        print("저장된 모델을 성공적으로 불러왔습니다.")
        # Keep ``experiment`` bound in both paths.
        experiment = loaded_experiment
    except Exception as e:
        # Deliberate best-effort fallback: any load failure leads to retraining.
        print(f"모델 로드 중 에러 발생: {str(e)}")
        print("새로운 모델을 학습합니다.")
        # Build the fresh model only here, so a successful load does not pay
        # for constructing a backbone that is never used.
        experiment = MobileNetExperiment(version)
        experiment.train(X_train, y_train, X_val, y_val)
        experiment.save_model(model_save_path)
        loaded_experiment = experiment

    # 6. Evaluate on the test set
    class_names = ['CN7', 'G2', 'BLDC-400-f', 'BLDC-400-b']
    results = loaded_experiment.evaluate(test_images, test_labels, class_names)