# In [None]:
# dermai_skin_disease_classifier_final.py

# 1. Import Required Libraries
import numpy as np
import pandas as pd
import os, cv2, matplotlib.pyplot as plt, seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
from sklearn.utils.class_weight import compute_class_weight
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, GlobalAveragePooling2D, BatchNormalization
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

# -------------------------------
# 2. Load Dataset
# -------------------------------
IMAGE_DIRS = ['HAM10000_images_part_1', 'HAM10000_images_part_2']  # both folders
CSV_PATH = 'HAM10000_metadata.csv'

metadata = pd.read_csv(CSV_PATH)
print("[INFO] Class distribution:\n", metadata['dx'].value_counts())

# -------------------------------
# 3. Preprocess Images
# -------------------------------
# Every metadata row is looked up in each image folder in turn; the first
# readable file wins. Rows without a readable image are reported and skipped.
images, labels = [], []

# Only two columns are needed, so iterate them directly rather than paying
# for a pandas Series per row with iterrows().
for image_id, dx in zip(metadata['image_id'], metadata['dx']):
    img = None
    for folder in IMAGE_DIRS:
        img_path = os.path.join(folder, image_id + '.jpg')
        if os.path.exists(img_path):
            # cv2.imread returns None for unreadable files; in that case
            # keep looking in the remaining folders.
            img = cv2.imread(img_path)
            if img is not None:
                break
    if img is None:
        print(f"[WARNING] Image {image_id} not found!")
        continue
    # Channel order is left exactly as cv2.imread delivers it (BGR).
    images.append(cv2.resize(img, (224, 224)))
    labels.append(dx)

if not images:
    # Fail here with a clear message instead of an obscure error further on.
    raise FileNotFoundError(
        f"No images listed in {CSV_PATH} could be loaded from {IMAGE_DIRS}"
    )

# Build the tensor as float32 and scale in place: the default float64 result
# of `np.array(images) / 255.0` needs twice the memory for no benefit.
images = np.asarray(images, dtype=np.float32)
images /= 255.0
labels = np.array(labels)

# -------------------------------
# 4. Encode Labels
# -------------------------------
# String diagnosis codes -> integer ids -> one-hot rows.
le = LabelEncoder().fit(labels)
labels_encoded = le.transform(labels)
labels_categorical = to_categorical(
    labels_encoded,
    num_classes=le.classes_.shape[0],
)

# -------------------------------
# 5. Train-Test Split
# -------------------------------
# 80/20 split, stratified on the one-hot labels, with a fixed seed.
(X_train, X_test,
 y_train, y_test) = train_test_split(
    images,
    labels_categorical,
    test_size=0.2,
    random_state=42,
    stratify=labels_categorical,
)

# -------------------------------
# 6. Data Augmentation
# -------------------------------
datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    zoom_range=0.1,
)
# datagen.fit(X_train) is deliberately not called: it only computes statistics
# for featurewise_center / featurewise_std_normalization / zca_whitening, none
# of which are enabled here, and it would copy the whole training tensor.

# -------------------------------
# 7. Class Weights (handle imbalance)
# -------------------------------
# Derive the weights from the training labels only, so the held-out split
# does not influence anything used during training.
train_label_ids = np.argmax(y_train, axis=1)
train_class_ids = np.unique(train_label_ids)
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=train_class_ids,
    y=train_label_ids,
)
# Plain int/float keys and values keep the mapping clean for Keras and print.
class_weights = {
    int(class_id): float(weight)
    for class_id, weight in zip(train_class_ids, class_weights)
}
print("[INFO] Class Weights:", class_weights)

# -------------------------------
# 8. Custom CNN Model
# -------------------------------
# Three Conv -> BatchNorm -> MaxPool stages (32, 64, 128 filters) followed by
# a dense classification head with dropout.
cnn_model = Sequential()

# First stage declares the input shape.
cnn_model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 3)))
cnn_model.add(BatchNormalization())
cnn_model.add(MaxPooling2D(2, 2))

# Remaining stages differ only in the number of filters.
for n_filters in (64, 128):
    cnn_model.add(Conv2D(n_filters, (3, 3), activation='relu'))
    cnn_model.add(BatchNormalization())
    cnn_model.add(MaxPooling2D(2, 2))

# Classification head: one softmax unit per encoded class.
cnn_model.add(Flatten())
cnn_model.add(Dense(256, activation='relu'))
cnn_model.add(Dropout(0.5))
cnn_model.add(Dense(len(le.classes_), activation='softmax'))

cnn_model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# -------------------------------
# 9. Transfer Learning ResNet50
# -------------------------------
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224,224,3))
base_model.trainable = False  # Freeze every backbone layer

# The ImageNet weights expect 0-255 pixels in BGR order with the per-channel
# ImageNet means subtracted (what resnet50.preprocess_input produces from RGB
# input). This script feeds cv2.imread images (already BGR) scaled to [0, 1],
# so the model undoes the scaling and subtracts the BGR means itself. Keeping
# this inside the model means callers keep passing the same [0, 1] arrays.
resnet_input = tf.keras.Input(shape=(224, 224, 3))
x = tf.keras.layers.Rescaling(
    scale=255.0,
    offset=[-103.939, -116.779, -123.68],  # ImageNet channel means, BGR order
)(resnet_input)

# training=False keeps the frozen BatchNormalization layers in inference mode.
x = base_model(x, training=False)
x = GlobalAveragePooling2D()(x)
x = Dense(256, activation='relu')(x)
x = Dropout(0.5)(x)
output = Dense(len(le.classes_), activation='softmax')(x)

resnet_model = Model(inputs=resnet_input, outputs=output)
resnet_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# -------------------------------
# 10. Callbacks
# -------------------------------
def build_callbacks(checkpoint_path):
    """Return a fresh callback list for one training run.

    Each run needs its own instances: ModelCheckpoint remembers the best
    monitored value it has seen, so sharing one instance between models makes
    the second model compete against the first model's score, and sharing one
    path makes the runs overwrite each other's checkpoint.

    Args:
        checkpoint_path: File that receives the best weights of this run.

    Returns:
        A list with EarlyStopping, ModelCheckpoint and ReduceLROnPlateau.
    """
    return [
        EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
        ModelCheckpoint(checkpoint_path, save_best_only=True),
        ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, verbose=1),
    ]


# `callbacks` and 'best_dermai_model.h5' are kept under their original names;
# they now belong to the ResNet50 run only.
callbacks = build_callbacks('best_dermai_model.h5')
cnn_callbacks = build_callbacks('best_dermai_cnn_model.h5')

# NOTE(review): both runs validate (and early-stop) on the test split, so the
# reported test metrics are optimistic; consider a separate validation split.

# -------------------------------
# 11. Train CNN
# -------------------------------
print("[INFO] Training Custom CNN...")
history_cnn = cnn_model.fit(
    datagen.flow(X_train, y_train, batch_size=32),
    validation_data=(X_test, y_test),
    epochs=20, callbacks=cnn_callbacks,
    class_weight=class_weights
)

# -------------------------------
# 12. Train ResNet50
# -------------------------------
print("[INFO] Training ResNet50...")
history_resnet = resnet_model.fit(
    datagen.flow(X_train, y_train, batch_size=32),
    validation_data=(X_test, y_test),
    epochs=10, callbacks=callbacks,
    class_weight=class_weights
)

# -------------------------------
# 13. Evaluation Function
# -------------------------------
def evaluate_model(model, X_test, y_test, model_name):
    """Print a classification report, plot a confusion matrix and print ROC AUC.

    Args:
        model: Fitted model whose ``predict`` returns per-class probabilities.
        X_test: Inputs passed to ``model.predict``.
        y_test: One-hot encoded true labels (argmax over axis 1 gives the id).
        model_name: Label used in the printed output and the plot title.

    Class names are taken from the module-level ``le`` label encoder.
    """
    # Pin the label set so the report rows and matrix axes always line up
    # with le.classes_, even if a class is missing from y_test and y_pred.
    class_ids = np.arange(len(le.classes_))

    y_pred_probs = model.predict(X_test)
    y_pred = np.argmax(y_pred_probs, axis=1)
    y_true = np.argmax(y_test, axis=1)

    print(f"\nClassification Report ({model_name}):\n",
          classification_report(y_true, y_pred, labels=class_ids,
                                target_names=le.classes_))

    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    plt.figure(figsize=(8,6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=le.classes_, yticklabels=le.classes_)
    plt.title(f'Confusion Matrix - {model_name}')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.show()

    # roc_auc_score raises ValueError when a class has no positive sample in
    # y_test; report that instead of aborting the whole evaluation.
    try:
        auc_score = roc_auc_score(y_test, y_pred_probs, multi_class='ovr')
    except ValueError as err:
        print(f"{model_name} ROC AUC Score: undefined ({err})")
    else:
        print(f"{model_name} ROC AUC Score: {auc_score:.4f}")

evaluate_model(cnn_model, X_test, y_test, "Custom CNN")
evaluate_model(resnet_model, X_test, y_test, "ResNet50")

# -------------------------------
# 14. Failure Case Analysis
# -------------------------------
# Show up to nine test images the ResNet50 model got wrong, with the true
# and predicted class names.
y_pred = np.argmax(resnet_model.predict(X_test), axis=1)
y_true = np.argmax(y_test, axis=1)

misclassified_idx = np.where(y_pred != y_true)[0][:9]
plt.figure(figsize=(12, 12))
for i, idx in enumerate(misclassified_idx):
    plt.subplot(3, 3, i+1)
    # The images were read with cv2.imread (BGR), while imshow expects RGB;
    # reverse the channel axis so the lesions are shown in their real colours.
    plt.imshow(X_test[idx][..., ::-1])
    plt.title(f"True: {le.classes_[y_true[idx]]}\nPred: {le.classes_[y_pred[idx]]}")
    plt.axis('off')
plt.suptitle("Failure Case Analysis (ResNet50)")
plt.show()

# -------------------------------
# 15. Save Models
# -------------------------------
# Persist both trained models, CNN first, then ResNet50.
for trained_model, save_path in (
    (cnn_model, "dermai_cnn_model_final.h5"),
    (resnet_model, "dermai_resnet_model_final.h5"),
):
    trained_model.save(save_path)
print("[INFO] Models saved successfully!")
