In [None]:
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, cohen_kappa_score, confusion_matrix
import os

In [None]:
# --- Configuration: paths, image/batch sizes, training length and seeds ---
SUBSET_DIR = '../data/subset/'   # directory the image generators read files from
MODELS_DIR = '../models/'        # where the trained model is written
RESULTS_DIR = '../results/'      # where evaluation figures are written
IMG_SIZE = 224                   # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 32
EPOCHS = 25
SEED = 42                        # single seed for the stochastic steps of this notebook

# Seed NumPy (used by ImageDataGenerator augmentation/shuffling) and TensorFlow
# (weight initialisation) so Restart & Run All is repeatable. GPU kernels may
# still introduce small nondeterminism.
np.random.seed(SEED)
tf.random.set_seed(SEED)

# --- Step 3: Load the balanced label subset and split it 80/10/10 ---
# NOTE(review): the former in-cell rebalancing (groupby('diagnosis') + sample)
# referenced CSV_PATH and SAMPLES_PER_CLASS, which are never defined in this
# notebook (NameError on a fresh kernel), and its result was immediately
# overwritten by the read below, so it has been removed. The balanced CSV is
# presumably produced by an earlier preprocessing step — TODO confirm.
print("--- Step 3: Loading Balanced Data Subset ---")
balanced_df = pd.read_csv('../data/balanced_labels.csv')
# flow_from_dataframe(class_mode='categorical') requires string labels.
balanced_df = balanced_df.assign(diagnosis=balanced_df['diagnosis'].astype(str))

# 80% train / 20% held out, stratified so every grade appears in each split.
train_df, holdout_df = train_test_split(
    balanced_df, test_size=0.2, random_state=42, stratify=balanced_df['diagnosis']
)

# Split the held-out 20% in half: validation (monitored during training) and
# a separate test set used only for the final evaluation.
val_df, test_df = train_test_split(
    holdout_df, test_size=0.5, random_state=42, stratify=holdout_df['diagnosis']
)
assert len(train_df) + len(val_df) + len(test_df) == len(balanced_df)

# --- Data Generators ---
# Keras' EfficientNet models include their own rescaling/normalisation and
# expect raw pixel values in [0, 255], so no `rescale=1/255` here (it would
# double-scale the inputs fed to the backbone).
train_datagen = ImageDataGenerator(
    rotation_range=20, width_shift_range=0.1,
    height_shift_range=0.1, shear_range=0.1, zoom_range=0.1,
    horizontal_flip=True, fill_mode='nearest'
)
val_datagen = ImageDataGenerator()  # no augmentation; used for validation and testing

# Pin the label -> index mapping from the training split so all three
# generators agree even if a split happened to miss a class.
class_names = sorted(train_df['diagnosis'].unique())

# Arguments shared by every generator.
# NOTE(review): x_col values are used as filenames relative to SUBSET_DIR;
# rows whose file is missing (e.g. id_code without an extension) are skipped
# by Keras with a warning — confirm the CSV stores full filenames.
flow_kwargs = dict(
    directory=SUBSET_DIR,
    x_col='id_code',
    y_col='diagnosis',
    classes=class_names,
    target_size=(IMG_SIZE, IMG_SIZE),
    class_mode='categorical',
)

train_generator = train_datagen.flow_from_dataframe(
    dataframe=train_df, batch_size=BATCH_SIZE, shuffle=True, seed=42, **flow_kwargs
)
validation_generator = val_datagen.flow_from_dataframe(
    dataframe=val_df, batch_size=BATCH_SIZE, shuffle=False, **flow_kwargs
)
# Important: the test generator must NOT shuffle, so predictions line up with
# test_generator.classes during evaluation.
test_generator = val_datagen.flow_from_dataframe(
    dataframe=test_df,
    batch_size=1,  # Process one image at a time for evaluation
    shuffle=False,  # DO NOT SHUFFLE
    **flow_kwargs
)

# =============================================================================
# 4. BUILD AND TRAIN THE FINAL MODEL (EfficientNetB0)
# =============================================================================
# Frozen ImageNet EfficientNetB0 backbone + small trainable classification head.
base_model = EfficientNetB0(input_shape=(IMG_SIZE, IMG_SIZE, 3), include_top=False, weights='imagenet')
base_model.trainable = False  # only the new head's weights are updated
num_classes = len(train_generator.class_indices)  # one softmax unit per label

x = GlobalAveragePooling2D()(base_model.output)
x = Dense(128, activation='relu')(x)
class_probs = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=class_probs)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

final_model_path = os.path.join(MODELS_DIR, 'final_model.h5')
os.makedirs(MODELS_DIR, exist_ok=True)  # saving fails if the directory is missing

# Write the model only when validation loss improves, so the file on disk is
# the best epoch rather than simply the last one.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    final_model_path, monitor='val_loss', save_best_only=True, verbose=1
)

print("\n--- Training Final Model (EfficientNetB0) ---")
history = model.fit(
    train_generator,
    epochs=EPOCHS,
    validation_data=validation_generator,
    callbacks=[checkpoint],
)
print(f"Final model saved to {final_model_path}")


# =============================================================================
# 5. EVALUATE THE FINAL MODEL ON THE TEST SET
# =============================================================================
print("\n--- Evaluating Final Model on Test Set ---")
# Reload the model written by the training step, so this section can also be
# run after a kernel restart without retraining.
final_model = load_model(os.path.join(MODELS_DIR, 'final_model.h5'))

# Predict over the whole unshuffled test generator. Its own length is used
# (not len(test_df)) because flow_from_dataframe drops rows whose image file
# is missing, which would desynchronise predictions and labels.
test_probs = final_model.predict(test_generator)
y_pred = np.argmax(test_probs, axis=1)
y_true = test_generator.classes
class_indices = test_generator.class_indices
class_labels = sorted(class_indices, key=class_indices.get)  # names in index order
label_ids = list(range(len(class_labels)))
assert len(y_pred) == len(y_true), "predictions and labels are misaligned"

# --- Calculate and Print Metrics ---
# Unweighted kappa treats all disagreements equally; quadratic-weighted kappa
# also penalises by how far off a prediction is, which suits ordinal grades
# (assumes class index order follows grade order — TODO confirm).
kappa = cohen_kappa_score(y_true, y_pred)
qw_kappa = cohen_kappa_score(y_true, y_pred, weights='quadratic')
# Explicit `labels` keeps target_names aligned even if a class never appears.
report = classification_report(y_true, y_pred, labels=label_ids, target_names=class_labels)
print(f"Cohen's Kappa Score: {kappa:.4f}")
print(f"Quadratic Weighted Kappa: {qw_kappa:.4f}")
print("\nClassification Report:\n", report)

# --- Generate and Save Confusion Matrix ---
cm = confusion_matrix(y_true, y_pred, labels=label_ids)
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_labels, yticklabels=class_labels, ax=ax)
ax.set(title="Confusion Matrix - Final Model", xlabel='Predicted', ylabel='Actual')

cm_path = os.path.join(RESULTS_DIR, 'final_model_confusion_matrix.png')
os.makedirs(RESULTS_DIR, exist_ok=True)  # savefig fails if the directory is missing
fig.savefig(cm_path, bbox_inches='tight')
print(f"Confusion matrix saved to {cm_path}")
plt.show()