# In [None]:
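# This code MUST be run as a Jupyter Notebook cell, started from the notebook's
# own folder: all paths are resolved relative to the current working directory
# (the project root is taken to be two levels above it). Running it generates
# the final trained model file (final_cnn_model.h5).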

import os
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import ResNet50 # The PRE-TRAINED CNN MODEL
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
import matplotlib.pyplot as plt

# --- 1. Configuration ---
# Paths are resolved relative to the current working directory: the project
# root is taken to be two directories above it. The processed images are
# produced by 01_data_preprocessing.ipynb, which must be run first.
BASE_DIR = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))
PROCESSED_DATA_DIR = os.path.join(BASE_DIR, 'data', 'processed')
MODELS_DIR = os.path.join(BASE_DIR, 'models')
TARGET_MODEL_PATH = os.path.join(MODELS_DIR, 'final_cnn_model.h5')

# Square input resolution fed to ResNet50, and images per training step.
IMG_SIZE, BATCH_SIZE = 224, 32

# Number of output classes of the classification head. CRITICAL: the intended
# labels are Normal, Ischemic Stroke and Hemorrhagic Stroke. NOTE(review): the
# label index each class receives comes from the sub-folder names under
# PROCESSED_DATA_DIR, not from the order written in this comment.
NUM_CLASSES = 3

# Epoch budgets for the two training phases.
EPOCHS_PHASE_1 = 10  # backbone frozen: only the new head is trained
EPOCHS_PHASE_2 = 10  # fine-tuning with the top part of the backbone unfrozen

# --- 2. Data Loading with Augmentation ---
# Training images get random augmentation to reduce overfitting. Validation
# images get only the same pixel scaling, so validation metrics are measured
# on the real, unmodified images.
#
# NOTE(review): the ResNet50 ImageNet weights were trained with
# tf.keras.applications.resnet50.preprocess_input, not 1/255 scaling. The
# 1/255 scaling is kept here because code that loads the saved model
# presumably feeds it 1/255-scaled images. Changing this requires updating
# the inference-side preprocessing at the same time.
VALIDATION_SPLIT = 0.2  # fraction of each class's images held out for validation

train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    validation_split=VALIDATION_SPLIT,
)
# Same rescale and same split fraction, but no augmentation. Which files fall
# into the 'validation' subset depends on the directory contents and
# validation_split, not on the augmentation settings. The two subsets
# therefore stay complementary.
val_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=VALIDATION_SPLIT,
)

try:
    train_generator = train_datagen.flow_from_directory(
        PROCESSED_DATA_DIR,
        target_size=(IMG_SIZE, IMG_SIZE),
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        subset='training',
        seed=42,
    )
    validation_generator = val_datagen.flow_from_directory(
        PROCESSED_DATA_DIR,
        target_size=(IMG_SIZE, IMG_SIZE),
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        subset='validation',
        shuffle=False,  # deterministic order; aggregate metrics do not depend on it
        seed=42,
    )
except FileNotFoundError as e:
    # Only a missing directory gets the "run preprocessing first" hint. Any
    # other failure propagates with its own, accurate traceback.
    print("\n--- FATAL: Data directory not found. Please run 01_data_preprocessing.ipynb first. ---")
    print(f"Error: {e}")
    # Stop execution if data is missing
    raise

print(f"Detected Classes: {train_generator.class_indices}")

# An existing but empty (or mis-structured) directory does not raise above: it
# yields 0 images. Without this check, that would only surface later inside
# model.fit.
if train_generator.samples == 0 or validation_generator.samples == 0:
    raise RuntimeError(
        f"No usable images under {PROCESSED_DATA_DIR} "
        f"(training: {train_generator.samples}, validation: {validation_generator.samples}). "
        "Please run 01_data_preprocessing.ipynb first."
    )

# The classification head is built with NUM_CLASSES outputs. A different
# number of class folders would make labels and predictions incompatible.
if train_generator.num_classes != NUM_CLASSES:
    raise ValueError(
        f"Expected {NUM_CLASSES} class folders in {PROCESSED_DATA_DIR}, "
        f"found {train_generator.num_classes}: {train_generator.class_indices}"
    )

# --- 3. Build Transfer Learning Model (ResNet50 CNN) ---

print("\n--- Building ResNet50 Transfer Learning Model ---")

# Convolutional backbone: ResNet50 initialised with ImageNet weights, without
# its original 1000-class classifier, accepting IMG_SIZE x IMG_SIZE x 3 inputs.
base_model = ResNet50(
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
    include_top=False,
    weights='imagenet',
)

# New classification head for stroke detection:
#   1. pool the final feature maps into one vector per image,
#   2. apply a 512-unit ReLU layer,
#   3. apply 50% dropout for regularisation,
#   4. finish with a NUM_CLASSES-way softmax.
head = GlobalAveragePooling2D()(base_model.output)
head = Dense(512, activation='relu')(head)
head = Dropout(0.5)(head)
predictions = Dense(NUM_CLASSES, activation='softmax', name='final_output')(head)

# End-to-end model: backbone input -> class probabilities.
model = Model(inputs=base_model.input, outputs=predictions)

# --- 4. Training Phase 1: Train only the new top layers ---
# Freeze every backbone layer first, so that only the freshly initialised
# head is updated and the pretrained features stay untouched in this phase.
for layer in base_model.layers:
    layer.trainable = False

print("\n--- Starting Training Phase 1: Training Classification Head (Fast Learning) ---")

# Compile after changing `trainable` so the frozen state takes effect.
# 'adam' uses Keras' default Adam settings.
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

history_1 = model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=EPOCHS_PHASE_1,
)

# --- 5. Training Phase 2: Fine-Tuning the top of the backbone (target: >= 95% accuracy) ---
print("\n--- Starting Training Phase 2: Fine-Tuning Top Layers (Slow Learning) ---")

# Unfreeze backbone layers from index `fine_tune_at` onward and freeze the
# rest. The flag is set explicitly for every layer, so this cell does not
# depend on the state left behind by phase 1.
# NOTE(review): 140 is intended as "roughly the last convolutional block" of
# ResNet50 (about the last 35 layers). Confirm the boundary against
# base_model.summary().
fine_tune_at = 140
for index, layer in enumerate(base_model.layers):
    # BatchNormalization layers stay frozen. In Keras a frozen BN layer runs in
    # inference mode and keeps its pretrained moving statistics. An unfrozen
    # one would re-estimate them from small training batches, which tends to
    # wreck the pretrained features during fine-tuning.
    layer.trainable = (
        index >= fine_tune_at
        and not isinstance(layer, tf.keras.layers.BatchNormalization)
    )

# Re-compile so the new trainable flags take effect. A very low learning rate
# avoids destroying the pretrained weights.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

history_2 = model.fit(
    train_generator,
    epochs=EPOCHS_PHASE_2,
    validation_data=validation_generator
)

# --- 6. Final Saving ---

# Create the destination folder if it is missing (a no-op when it already
# exists), then write the trained model to TARGET_MODEL_PATH.
os.makedirs(os.path.dirname(TARGET_MODEL_PATH), exist_ok=True)
model.save(TARGET_MODEL_PATH)
print(f"\nModel training complete and SAVED to: {TARGET_MODEL_PATH}")