In [1]:
import os
import json
import yaml
import pickle
import random
import datetime
from pathlib import Path
from typing import Dict, List
import itertools

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics import classification_report, confusion_matrix

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.applications.efficientnet import preprocess_input
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

# -----------------------------
# CONFIG (Windows raw strings)
# -----------------------------
# Dataset roots (raw strings so the backslashes are taken literally).
TRAIN_DIR = r"C:\Users\sagni\Downloads\MRI Scan\archive\Training"
TEST_DIR = r"C:\Users\sagni\Downloads\MRI Scan\archive\Testing"

# Every artifact of the run is written directly under OUTPUT_DIR.
OUTPUT_DIR = r"C:\Users\sagni\Downloads\MRI Scan"
_artifact_root = Path(OUTPUT_DIR)

# Required artifacts.
MODEL_H5 = str(_artifact_root / "model.h5")
CLASS_PKL = str(_artifact_root / "class_indices.pkl")
RUN_YAML = str(_artifact_root / "run_config.yaml")
METRICS_JSON = str(_artifact_root / "metrics.json")

# Optional extras: plots and CSV exports of the evaluation results.
ACC_PNG = str(_artifact_root / "accuracy_loss.png")
CM_PNG = str(_artifact_root / "confusion_matrix.png")
CR_CSV = str(_artifact_root / "classification_report.csv")
CM_CSV = str(_artifact_root / "confusion_matrix.csv")

# Training hyper-parameters.
IMG_SIZE = (256, 256)  # (height, width) fed to the network; author's choice for MRI
BATCH_SIZE = 16
EPOCHS = 20            # upper bound on the number of epochs
VAL_SPLIT = 0.1        # fraction of the Training folder held out for validation
SEED = 42
BASE_LR = 1e-3
AUGMENT = True         # toggle light augmentation of the training images

# -----------------------------
# Reproducibility
# -----------------------------
def set_seed(s=SEED):
    """Seed the Python, NumPy and TensorFlow random number generators with ``s``."""
    for seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
        seed_fn(s)


# Seed once, before any generator or model is created.
set_seed()

# -----------------------------
# Sanity checks
# -----------------------------
# Create the output folder up front so later writes cannot fail on a missing dir.
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

# Fail fast, with the offending path in the message, if a dataset root is missing.
for _missing_msg, _data_dir in (
    ("Training dir not found: ", TRAIN_DIR),
    ("Testing dir not found:  ", TEST_DIR),
):
    if not Path(_data_dir).exists():
        raise FileNotFoundError(f"{_missing_msg}{_data_dir}")

# Class names are the sub-folder names of TRAIN_DIR, sorted so the
# name -> index mapping is stable.
classes = sorted(entry.name for entry in Path(TRAIN_DIR).iterdir() if entry.is_dir())
if len(classes) == 0:
    raise RuntimeError(f"No class subfolders found under: {TRAIN_DIR}")
print("[INFO] Classes:", classes)

# -----------------------------
# Data generators
# -----------------------------
# The training and validation generators both read TRAIN_DIR with the same
# validation_split; they differ only in augmentation and in the subset requested.
_split_kwargs = dict(
    preprocessing_function=preprocess_input,
    validation_split=VAL_SPLIT,
)

# Light geometric/brightness jitter, applied to training images only when AUGMENT is on.
_augment_kwargs = (
    dict(
        rotation_range=5,
        width_shift_range=0.05,
        height_shift_range=0.05,
        zoom_range=0.05,
        brightness_range=(0.95, 1.05),
        fill_mode="nearest",
    )
    if AUGMENT
    else {}
)

train_datagen = ImageDataGenerator(**_split_kwargs, **_augment_kwargs)
val_datagen = ImageDataGenerator(**_split_kwargs)
test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

# Arguments common to all three directory iterators. Passing `classes`
# explicitly pins the same label order for train, validation and test.
_flow_kwargs = dict(
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    classes=classes,
    class_mode="categorical",
)

train_flow = train_datagen.flow_from_directory(
    TRAIN_DIR, shuffle=True, subset="training", seed=SEED, **_flow_kwargs
)
# Validation is not shuffled so its metrics are computed over a fixed order.
val_flow = val_datagen.flow_from_directory(
    TRAIN_DIR, shuffle=False, subset="validation", seed=SEED, **_flow_kwargs
)
# The test iterator must not shuffle: predictions are later matched against
# test_flow.classes position by position.
test_flow = test_datagen.flow_from_directory(
    TEST_DIR, shuffle=False, **_flow_kwargs
)

num_classes = len(classes)
print("[INFO] Class indices:", train_flow.class_indices)

# -----------------------------
# Model: EfficientNetB0
# -----------------------------
# Use the first GPU when TensorFlow can see one, otherwise fall back to the CPU.
if tf.config.list_physical_devices("GPU"):
    device = "/GPU:0"
else:
    device = "/CPU:0"

_input_shape = (*IMG_SIZE, 3)

with tf.device(device):
    # ImageNet-pretrained backbone without its classifier head, frozen so only
    # the new head is trained.
    base = EfficientNetB0(include_top=False, input_shape=_input_shape, weights="imagenet")
    base.trainable = False

    # Head: backbone features -> global average pool -> dropout -> softmax.
    inputs = layers.Input(shape=_input_shape)
    x = base(inputs, training=False)  # backbone always called in inference mode
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.25)(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    model = models.Model(inputs, outputs)

    _optimizer = tf.keras.optimizers.Adam(BASE_LR)
    model.compile(
        optimizer=_optimizer,
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )

model.summary()

# -----------------------------
# Callbacks
# -----------------------------
# Stop once val_accuracy has not improved for 4 epochs and roll back to the best weights.
_early_stop = EarlyStopping(
    monitor="val_accuracy", patience=4, restore_best_weights=True, verbose=1
)
# Halve the learning rate after 2 epochs without val_loss improvement (floor 1e-6).
_lr_schedule = ReduceLROnPlateau(
    monitor="val_loss", factor=0.5, patience=2, min_lr=1e-6, verbose=1
)
# Write the model to MODEL_H5 each time val_accuracy reaches a new best.
_checkpoint = ModelCheckpoint(
    MODEL_H5, monitor="val_accuracy", save_best_only=True, verbose=1
)
callbacks = [_early_stop, _lr_schedule, _checkpoint]

# -----------------------------
# Train
# -----------------------------
_fit_kwargs = dict(
    validation_data=val_flow,
    epochs=EPOCHS,
    callbacks=callbacks,
    verbose=1,
)
history = model.fit(train_flow, **_fit_kwargs)

# Save the model.
# ModelCheckpoint(save_best_only=True) has already written the best-val_accuracy
# model to MODEL_H5 during training. Saving again unconditionally would overwrite
# that checkpoint with whatever weights are currently in memory, which are the
# last-epoch weights whenever EarlyStopping did not restore the best ones.
# So only save here when no checkpoint was produced by this run
# (no epoch completed, or the file is missing).
_checkpoint_written = bool(history.history.get("val_accuracy")) and Path(MODEL_H5).is_file()
if not _checkpoint_written:
    model.save(MODEL_H5)
print(f"[INFO] Saved model → {MODEL_H5}")

# Save the class-name -> index map (PKL) so predictions can be decoded later.
with open(CLASS_PKL, "wb") as f:
    pickle.dump(train_flow.class_indices, f)
print(f"[INFO] Saved class_indices → {CLASS_PKL}")

# -----------------------------
# Evaluate on TEST set
# -----------------------------
# model.evaluate returns [loss, accuracy], matching the metrics given to compile().
test_metrics = model.evaluate(test_flow, verbose=1)
test_loss, test_acc = float(test_metrics[0]), float(test_metrics[1])
print(f"[INFO] Test accuracy: {test_acc:.4f}")

# Predict for confusion matrix & report. test_flow was created with
# shuffle=False, so row i of the predictions corresponds to test_flow.classes[i].
probs_test = model.predict(test_flow, verbose=1)
y_pred = np.argmax(probs_test, axis=1)
y_true = test_flow.classes
idx_to_class = {v: k for k, v in train_flow.class_indices.items()}
labels_order = [idx_to_class[i] for i in range(len(idx_to_class))]

# Explicit label ids shared by the report and the confusion matrix. Without
# `labels`, classification_report derives the label set from the data and raises
# ValueError when a class is absent from both y_true and y_pred, because the
# number of labels then no longer matches len(target_names).
label_ids = list(range(len(labels_order)))

# Optional helpful artifacts (CSV + plots).
cr = classification_report(
    y_true,
    y_pred,
    labels=label_ids,
    target_names=labels_order,
    output_dict=True,
    zero_division=0,
)
pd.DataFrame(cr).to_csv(CR_CSV)
cm = confusion_matrix(y_true, y_pred, labels=label_ids)
pd.DataFrame(cm, index=labels_order, columns=labels_order).to_csv(CM_CSV)

# Accuracy/Loss plot (optional): two stacked panels, training vs. validation curves.
_curves_fig = plt.figure(figsize=(10, 8))
ax1 = plt.subplot(2, 1, 1)
ax2 = plt.subplot(2, 1, 2)

# (axis, history key, short label used in legend/y-axis, panel title)
_panels = (
    (ax1, "accuracy", "Acc", "Accuracy"),
    (ax2, "loss", "Loss", "Loss"),
)
for _axis, _metric, _short, _title in _panels:
    _axis.plot(history.history[_metric], label=f"Train {_short}")
    _axis.plot(history.history[f"val_{_metric}"], label=f"Val {_short}")
    _axis.set_title(_title)
    _axis.set_xlabel("Epoch")
    _axis.set_ylabel(_short)
    _axis.grid(alpha=0.25)
    _axis.legend()

plt.tight_layout()
plt.savefig(ACC_PNG, dpi=200)
plt.close()

# Confusion matrix heatmap (optional).
# Row-normalise so each cell is the fraction of that true class. A class with no
# test samples has a zero row total; the masked divide leaves such rows at 0
# instead of producing NaN and a RuntimeWarning.
_row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(
    cm,
    _row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=_row_totals != 0,
)

fig = plt.figure(figsize=(9, 7))
ax = plt.gca()
im = ax.imshow(cm_norm, interpolation="nearest", cmap="viridis")
plt.title("Confusion Matrix (Normalized)")
cbar = plt.colorbar(im, fraction=0.046, pad=0.04)
cbar.ax.set_ylabel("Proportion", rotation=90)

ticks = np.arange(len(labels_order))
plt.xticks(ticks, labels_order, rotation=45, ha="right")
plt.yticks(ticks, labels_order)

# Annotate each cell with the raw count and the row percentage. Text is white
# on cells above half of the maximum proportion and black otherwise.
thresh = cm_norm.max() / 2.0
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    ax.text(
        j, i, f"{cm[i,j]}\n{cm_norm[i,j]*100:.1f}%",
        ha="center", va="center",
        color="white" if cm_norm[i, j] > thresh else "black",
        fontsize=9,
    )

plt.ylabel("True label")
plt.xlabel("Predicted label")
plt.tight_layout()
plt.savefig(CM_PNG, dpi=220)
plt.close()

# -----------------------------
# Save metrics.json (training curves + test metrics)
# -----------------------------
_curves = history.history

# Epoch with the highest val_accuracy, the quantity monitored by EarlyStopping
# and ModelCheckpoint. np.argmax returns the first maximum, which matches the
# callbacks' strict-improvement rule. "final" below is the last epoch run and can
# differ from this epoch, so both are recorded.
_best_idx = int(np.argmax(_curves["val_accuracy"]))

metrics_payload = {
    "timestamp": datetime.datetime.now().isoformat(),
    "device": device,
    "classes": labels_order,
    "params": {
        "img_size": list(IMG_SIZE),
        "batch_size": BATCH_SIZE,
        "epochs_requested": EPOCHS,
        "val_split_from_training": VAL_SPLIT,
        "base_lr": BASE_LR,
        "augment": AUGMENT
    },
    # Cast to plain floats so every value is JSON-serialisable.
    "history": {k: [float(v) for v in vals] for k, vals in _curves.items()},
    "final": {
        "train_accuracy": float(_curves["accuracy"][-1]),
        "train_loss": float(_curves["loss"][-1]),
        "val_accuracy": float(_curves["val_accuracy"][-1]),
        "val_loss": float(_curves["val_loss"][-1]),
        "test_accuracy": test_acc,
        "test_loss": test_loss
    },
    # Additive keys: number of epochs actually run (early stopping may end
    # training before EPOCHS) and the train/val metrics at the best epoch.
    "epochs_run": len(_curves["val_accuracy"]),
    "best": {
        "epoch": _best_idx + 1,  # 1-based, as in the Keras progress log
        "train_accuracy": float(_curves["accuracy"][_best_idx]),
        "train_loss": float(_curves["loss"][_best_idx]),
        "val_accuracy": float(_curves["val_accuracy"][_best_idx]),
        "val_loss": float(_curves["val_loss"][_best_idx])
    }
}
with open(METRICS_JSON, "w", encoding="utf-8") as f:
    json.dump(metrics_payload, f, indent=2)
print(f"[INFO] Saved metrics.json → {METRICS_JSON}")

# -----------------------------
# Save run_config.yaml
# -----------------------------
# Each section is built separately and assembled in a fixed order; the dump
# uses sort_keys=False, so the YAML keeps exactly this key order.
_run_section = {
    "timestamp": datetime.datetime.now().isoformat(),
    "seed": SEED,
    "device": device,
}
_data_section = {
    "train_dir": TRAIN_DIR,
    "test_dir": TEST_DIR,
    "classes": labels_order,
    "val_split": VAL_SPLIT,
}
_model_section = {
    "architecture": "EfficientNetB0",
    "transfer_learning": True,
    "frozen_base": True,
    "optimizer": "Adam",
    "learning_rate": BASE_LR,
    "epochs": EPOCHS,
    "image_size": list(IMG_SIZE),
    "num_classes": num_classes,
}
# Paths of every file this run writes.
_artifact_section = {
    "model_h5": MODEL_H5,
    "class_indices_pkl": CLASS_PKL,
    "metrics_json": METRICS_JSON,
    "classification_report_csv": CR_CSV,
    "confusion_matrix_csv": CM_CSV,
    "accuracy_loss_png": ACC_PNG,
    "confusion_matrix_png": CM_PNG,
}
run_cfg = dict(
    run=_run_section,
    data=_data_section,
    model=_model_section,
    artifacts=_artifact_section,
)

with open(RUN_YAML, "w", encoding="utf-8") as f:
    yaml.safe_dump(run_cfg, f, sort_keys=False, allow_unicode=True)
print(f"[INFO] Saved run_config.yaml → {RUN_YAML}")

print("\n[DONE] All artifacts saved to:", OUTPUT_DIR)


[INFO] Classes: ['glioma', 'meningioma', 'notumor', 'pituitary']
Found 5143 images belonging to 4 classes.
Found 569 images belonging to 4 classes.
Found 1311 images belonging to 4 classes.
[INFO] Class indices: {'glioma': 0, 'meningioma': 1, 'notumor': 2, 'pituitary': 3}


  self._warn_if_super_not_called()


Epoch 1/20
[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 427ms/step - accuracy: 0.6961 - loss: 0.7753
Epoch 1: val_accuracy improved from -inf to 0.78910, saving model to C:\Users\sagni\Downloads\MRI Scan\model.h5




[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m170s[0m 471ms/step - accuracy: 0.6964 - loss: 0.7747 - val_accuracy: 0.7891 - val_loss: 0.6029 - learning_rate: 0.0010
Epoch 2/20
[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 400ms/step - accuracy: 0.8599 - loss: 0.3875
Epoch 2: val_accuracy improved from 0.78910 to 0.81019, saving model to C:\Users\sagni\Downloads\MRI Scan\model.h5




[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m137s[0m 424ms/step - accuracy: 0.8600 - loss: 0.3875 - val_accuracy: 0.8102 - val_loss: 0.5187 - learning_rate: 0.0010
Epoch 3/20
[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 367ms/step - accuracy: 0.8821 - loss: 0.3346
Epoch 3: val_accuracy improved from 0.81019 to 0.84007, saving model to C:\Users\sagni\Downloads\MRI Scan\model.h5




[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m126s[0m 392ms/step - accuracy: 0.8821 - loss: 0.3346 - val_accuracy: 0.8401 - val_loss: 0.4360 - learning_rate: 0.0010
Epoch 4/20
[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 375ms/step - accuracy: 0.8957 - loss: 0.2958
Epoch 4: val_accuracy did not improve from 0.84007
[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m129s[0m 401ms/step - accuracy: 0.8957 - loss: 0.2958 - val_accuracy: 0.8366 - val_loss: 0.4236 - learning_rate: 0.0010
Epoch 5/20
[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 377ms/step - accuracy: 0.9079 - loss: 0.2754
Epoch 5: val_accuracy did not improve from 0.84007
[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m129s[0m 401ms/step - accuracy: 0.9079 - loss: 0.2754 - val_accuracy: 0.8366 - val_loss: 0.4609 - learning_rate: 0.0010
Epoch 6/20
[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 363ms/step - accuracy: 0.8972 - loss: 0.



[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m125s[0m 387ms/step - accuracy: 0.8972 - loss: 0.2700 - val_accuracy: 0.8418 - val_loss: 0.4275 - learning_rate: 0.0010
Epoch 7/20
[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 386ms/step - accuracy: 0.9102 - loss: 0.2627
Epoch 7: val_accuracy improved from 0.84183 to 0.84710, saving model to C:\Users\sagni\Downloads\MRI Scan\model.h5




[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m133s[0m 413ms/step - accuracy: 0.9102 - loss: 0.2627 - val_accuracy: 0.8471 - val_loss: 0.4246 - learning_rate: 5.0000e-04
Epoch 8/20
[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.9157 - loss: 0.2448
Epoch 8: val_accuracy improved from 0.84710 to 0.85237, saving model to C:\Users\sagni\Downloads\MRI Scan\model.h5




[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m576s[0m 2s/step - accuracy: 0.9157 - loss: 0.2448 - val_accuracy: 0.8524 - val_loss: 0.3993 - learning_rate: 5.0000e-04
Epoch 9/20
[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 361ms/step - accuracy: 0.9156 - loss: 0.2308
Epoch 9: val_accuracy improved from 0.85237 to 0.85940, saving model to C:\Users\sagni\Downloads\MRI Scan\model.h5




[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m124s[0m 386ms/step - accuracy: 0.9156 - loss: 0.2309 - val_accuracy: 0.8594 - val_loss: 0.3860 - learning_rate: 5.0000e-04
Epoch 10/20
[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 360ms/step - accuracy: 0.9231 - loss: 0.2294
Epoch 10: val_accuracy did not improve from 0.85940
[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m124s[0m 383ms/step - accuracy: 0.9231 - loss: 0.2294 - val_accuracy: 0.8594 - val_loss: 0.3940 - learning_rate: 5.0000e-04
Epoch 11/20
[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 354ms/step - accuracy: 0.9140 - loss: 0.2370
Epoch 11: ReduceLROnPlateau reducing learning rate to 0.0002500000118743628.

Epoch 11: val_accuracy did not improve from 0.85940
[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m122s[0m 378ms/step - accuracy: 0.9140 - loss: 0.2369 - val_accuracy: 0.8366 - val_loss: 0.4199 - learning_rate: 5.0000e-04
Epoch 12/20
[1m322/322



[INFO] Saved model → C:\Users\sagni\Downloads\MRI Scan\model.h5
[INFO] Saved class_indices → C:\Users\sagni\Downloads\MRI Scan\class_indices.pkl
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 232ms/step - accuracy: 0.8402 - loss: 0.3460
[INFO] Test accuracy: 0.8886
[1m82/82[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 266ms/step
[INFO] Saved metrics.json → C:\Users\sagni\Downloads\MRI Scan\metrics.json
[INFO] Saved run_config.yaml → C:\Users\sagni\Downloads\MRI Scan\run_config.yaml

[DONE] All artifacts saved to: C:\Users\sagni\Downloads\MRI Scan
