
# Multi-Model Training & Profiling
Complete binary PlantVillage comparison (MobileNetV2, EfficientNetB0, ResNet50, VGG16, DenseNet121, InceptionV3) with timing and memory stats. Optimized defaults for M3 Pro.


In [1]:
# ============================================================
# STEP 1 ‚Äî Imports, Global Utilities, Memory Tracking
# ============================================================

import os
import time
import math
import psutil
from collections import defaultdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# TensorFlow
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers, models

# sklearn metrics for evaluation
from sklearn.metrics import (
    confusion_matrix,
    classification_report,
    precision_recall_fscore_support,
    roc_auc_score,
)

print("\n===== CHECKING ACCELERATORS =====")
print(f"TensorFlow GPU devices: {tf.config.list_physical_devices('GPU')}")


# ============================================================
# MEMORY TRACKING HELPERS
# ============================================================

def get_memory_usage():
    """Return RAM + GPU memory in GB. Safe on Apple Silicon."""
    info = {"ram_gb": None, "gpu_current_gb": None, "gpu_peak_gb": None}

    # RAM usage (always available)
    try:
        process = psutil.Process(os.getpid())
        info["ram_gb"] = process.memory_info().rss / (1024 ** 3)
    except Exception:
        info["ram_gb"] = None

    # GPU memory ‚Äî works on NVIDIA CUDA only
    try:
        mem_info = tf.config.experimental.get_memory_info("GPU:0")
        info["gpu_current_gb"] = mem_info["current"] / (1024 ** 3)
        info["gpu_peak_gb"] = mem_info["peak"] / (1024 ** 3)
    except Exception:
        # Apple Metal + some TF versions: unsupported
        info["gpu_current_gb"] = None
        info["gpu_peak_gb"] = None

    return info


def reset_gpu_memory_stats():
    """Safely reset GPU memory stats (CUDA only)."""
    try:
        tf.config.experimental.reset_memory_stats("GPU:0")
    except Exception:
        pass



===== CHECKING ACCELERATORS =====
‚úì PyTorch running on Apple MPS GPU
TensorFlow GPU devices: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [2]:
# ============================================================
# STEP 2 ‚Äî Global Config + TensorFlow Hardware Setup
# ============================================================

# Global config
BATCH_SIZE = 32
IMAGE_SIZE = (224, 224)
EPOCHS = 10
SEED = 42

# Output directories
os.makedirs("models", exist_ok=True)
os.makedirs("logs", exist_ok=True)
os.makedirs("results", exist_ok=True)

print("\n===== STEP 2: TensorFlow Hardware Configuration =====")
print("Configuring TensorFlow runtime...")

# -----------------------------------------------------------
# GPU detection
# -----------------------------------------------------------
gpus = tf.config.list_physical_devices("GPU")

if len(gpus) > 0:
    try:
        # Enable memory growth (prevents OOM issues)
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)

        print(f"‚úì TensorFlow detected {len(gpus)} GPU(s)")
        print("‚úì GPU memory growth enabled")
    except Exception as e:
        print(f"‚ö† Error enabling memory growth: {e}")
else:
    print("‚ö† No TensorFlow GPU device found ‚Äî running on CPU (slower)")
    print("  On Apple Silicon, ensure 'tensorflow-macos' + 'tensorflow-metal' are installed.")

# -----------------------------------------------------------
# CPU Thread Optimization for Apple Silicon
# -----------------------------------------------------------
try:
    tf.config.threading.set_inter_op_parallelism_threads(8)
    tf.config.threading.set_intra_op_parallelism_threads(8)
    print("‚úì Optimized TensorFlow threading for Apple Silicon CPU")
except Exception:
    print("‚ö† Threading optimization not supported on this setup")

# -----------------------------------------------------------
# Log TF Version / Hardware
# -----------------------------------------------------------
print(f"TensorFlow version    : {tf.__version__}")
print(f"Detected TF GPUs      : {len(tf.config.list_physical_devices('GPU'))}")
print("=========================================================\n")



===== STEP 2: TensorFlow Hardware Configuration =====
Configuring TensorFlow for Apple M-Series / CUDA...
‚úì TensorFlow detected 1 GPU(s)
‚úì GPU memory growth enabled
‚úì Optimized TensorFlow threading for M-Series CPU
TensorFlow version    : 2.16.2
Detected TF GPUs      : 1
PyTorch device chosen : mps



In [3]:
# ============================================================
# STEP 3 ‚Äî Load Dataset + Split + Convert to Binary Labels
# ============================================================

def load_and_split_plant_village(seed=SEED):
    """
    Loads the PlantVillage dataset.
    If only a 'train' split exists, creates train/val/test = 70/15/15.
    Returns:
        plant_village_data (dict of datasets)
        info (tfds metadata)
    """
    print("===== Loading PlantVillage Dataset =====")

    plant_village_data, info = tfds.load(
        "plant_village",
        with_info=True,
        as_supervised=True,
        shuffle_files=True,
    )

    # If TFDS provided only `train`, manually split it.
    if list(plant_village_data.keys()) == ["train"]:
        print("‚ö† Dataset only has 'train' split ‚Üí Creating 70/15/15 splits...")

        full = plant_village_data["train"]
        total_size = info.splits["train"].num_examples

        train_size = int(0.70 * total_size)
        val_size = int(0.15 * total_size)

        full = full.shuffle(total_size, seed=seed)

        train_ds = full.take(train_size)
        val_ds = full.skip(train_size).take(val_size)
        test_ds = full.skip(train_size + val_size)

        plant_village_data = {
            "train": train_ds,
            "validation": val_ds,
            "test": test_ds,
        }

        print(f"‚úì train={train_size}, val={val_size}, test={total_size - train_size - val_size}")

    else:
        print("‚úì TFDS already provides train/validation/test splits")

    print("=========================================\n")
    return plant_village_data, info



def make_binary_labels(plant_village_data, info):
    """
    Converts multi-class PlantVillage labels -> binary (healthy=0, diseased=1).
    Also calculates total counts of healthy vs diseased for augmentation strategy.
    Returns:
        binary_data (dict of datasets)
        total_healthy (int)
        total_diseased (int)
    """

    print("===== Converting to Binary Labels (Healthy=0, Diseased=1) =====")

    # List of all original text labels in TFDS
    label_names = info.features["label"].names

    # Build lookup: per-index ‚Üí 0/1
    binary_lookup = np.array(
        [0 if name.lower().endswith("healthy") else 1 for name in label_names],
        dtype=np.int32
    )
    binary_lookup_tf = tf.constant(binary_lookup)

    def to_binary_label(image, label):
        """
        Map TFDS integer label ‚Üí 0 (healthy) / 1 (diseased)
        using the lookup table.
        """
        label = tf.cast(label, tf.int32)
        binary_label = tf.gather(binary_lookup_tf, label)
        return image, binary_label

    # Apply to each split
    binary_data = {
        split: ds.map(to_binary_label, num_parallel_calls=tf.data.AUTOTUNE)
        for split, ds in plant_village_data.items()
    }

    # -------------------------------------------
    # Compute original class imbalance statistics
    # -------------------------------------------
    print("Counting healthy vs diseased samples in training split...")

    total_healthy = 0
    total_diseased = 0

    for _, label in tfds.as_numpy(plant_village_data["train"]):
        label_str = info.features["label"].int2str(int(label))
        class_name = label_str.split("___", 1)[-1].lower()
        if class_name == "healthy":
            total_healthy += 1
        else:
            total_diseased += 1

    print(f"‚úì Healthy: {total_healthy}")
    print(f"‚úì Diseased: {total_diseased}")
    print(f"‚Üí Imbalance ratio: {total_diseased / max(total_healthy, 1):.2f}:1")
    print("===============================================================\n")

    return binary_data, total_healthy, total_diseased


In [4]:
# ============================================================
# STEP 4 ‚Äî Augmentation + Balanced Training Dataset
# ============================================================

def build_augmented_train_ds(binary_data, total_healthy, total_diseased):
    """
    Creates a balanced augmented training dataset:
    - Aggressive augmentation for healthy class
    - Mild augmentation / 50% replacement for diseased class
    - Replicates healthy images to balance dataset
    """

    print("===== Building Augmented Training Dataset =====")

    train_ds = binary_data["train"]

    # ----------------------------------------------
    # Split into healthy vs diseased datasets
    # ----------------------------------------------
    healthy_label = 0
    diseased_label = 1

    healthy_ds = train_ds.filter(lambda _, lbl: tf.equal(lbl, healthy_label))
    diseased_ds = train_ds.filter(lambda _, lbl: tf.equal(lbl, diseased_label))

    print(f"Original counts ‚Üí Healthy={total_healthy}, Diseased={total_diseased}")

    # ============================================================
    # AUGMENTATION FUNCTIONS
    # ============================================================

    def augment_healthy(image, label):
        """Strong augmentation to create variety for minority class."""
        img = tf.image.convert_image_dtype(image, tf.float32)

        img = tf.image.random_flip_left_right(img)
        img = tf.image.random_flip_up_down(img)
        img = tf.image.rot90(img, tf.random.uniform([], 0, 4, dtype=tf.int32))
        img = tf.image.random_saturation(img, 0.8, 1.25)
        img = tf.image.random_hue(img, 0.05)
        img = tf.image.random_brightness(img, 0.12)
        img = tf.image.random_contrast(img, 0.8, 1.25)

        img = tf.clip_by_value(img, 0.0, 1.0)
        return tf.image.convert_image_dtype(img, tf.uint8), label


    def augment_diseased_with_replacement(image, label):
        """Mild augmentations, 50% chance to keep original."""
        img = tf.image.convert_image_dtype(image, tf.float32)

        def apply_aug():
            aug = tf.image.random_flip_left_right(img)
            aug = tf.image.rot90(aug, tf.random.uniform([], 0, 4, dtype=tf.int32))
            aug = tf.image.random_contrast(aug, 0.9, 1.1)
            aug = tf.image.random_brightness(aug, 0.08)
            aug = tf.image.random_hue(aug, 0.03)
            aug = tf.clip_by_value(aug, 0.0, 1.0)
            return tf.image.convert_image_dtype(aug, tf.uint8)

        # 50% chance original / 50% chance augmented
        prob = tf.random.uniform([], 0.0, 1.0)
        chosen = tf.cond(prob > 0.5, apply_aug, lambda: image)
        return chosen, label

    # ============================================================
    # OVERSAMPLING / REPLICATION FOR HEALTHY CLASS
    # ============================================================

    # Compute replication factor
    if total_healthy == 0:
        healthy_multiplier = 1
    else:
        healthy_multiplier = max(1, math.ceil(total_diseased / total_healthy) - 1)

    print(f"Healthy class replication multiplier: {healthy_multiplier}x")

    # Apply augmentations
    augmented_healthy_datasets = [healthy_ds]  # original included

    for _ in range(healthy_multiplier):
        augmented_healthy_datasets.append(
            healthy_ds.map(augment_healthy, num_parallel_calls=tf.data.AUTOTUNE)
        )

    # Concatenate all healthy datasets
    healthy_augmented_ds = augmented_healthy_datasets[0]
    for ds in augmented_healthy_datasets[1:]:
        healthy_augmented_ds = healthy_augmented_ds.concatenate(ds)

    healthy_augmented_ds = healthy_augmented_ds.shuffle(4096)

    # ============================================================
    # AUGMENT DISEASED CLASS
    # ============================================================

    diseased_augmented_ds = diseased_ds.map(
        augment_diseased_with_replacement, num_parallel_calls=tf.data.AUTOTUNE
    )

    # ============================================================
    # COMBINE + SHUFFLE + PREFETCH
    # ============================================================

    augmented_train_ds = (
        healthy_augmented_ds
        .concatenate(diseased_augmented_ds)
        .shuffle(8192)
        .prefetch(tf.data.AUTOTUNE)
    )

    new_healthy = total_healthy * (healthy_multiplier + 1)
    print(f"Augmented training ‚Üí Healthy‚âà{new_healthy}, Diseased‚âà{total_diseased}")
    print("Final ratio ‚âà 1:1")
    print("===============================================================\n")

    # Update dict
    binary_data["train"] = augmented_train_ds
    return binary_data


In [5]:
# ============================================================
# STEP 5 ‚Äî TF Dataset Preprocessing + Batched Pipelines
# ============================================================

def prepare_dataset(ds, batch_size=BATCH_SIZE, image_size=IMAGE_SIZE):
    """
    Takes a raw TF dataset (image,label) and applies:
      - resizing to image_size
      - scaling to [0,1]
      - batching
      - prefetching
    Returns:
        A performant tf.data.Dataset for training or evaluation.
    """

    AUTOTUNE = tf.data.AUTOTUNE

    def preprocess(image, label):
        # Resize all images to the chosen model input size (224x224)
        image = tf.image.resize(image, image_size)
        # Normalize to [0, 1] float32
        image = tf.cast(image, tf.float32) / 255.0
        return image, label

    return (
        ds.map(preprocess, num_parallel_calls=AUTOTUNE)
        .batch(batch_size)
        .prefetch(AUTOTUNE)
    )


def prepare_all_splits(binary_data):
    """
    Wraps prepare_dataset() for all three splits.
    Ensures consistent preprocessing for:
        - train_ds
        - val_ds
        - test_ds
    Returns:
        (train_ds, val_ds, test_ds)
    """

    print("===== Preparing TF Datasets (Batched + Prefetched) =====")

    train_ds = prepare_dataset(binary_data["train"])
    val_ds   = prepare_dataset(binary_data["validation"])
    test_ds  = prepare_dataset(binary_data["test"])

    print(f"‚úì Batch size    : {BATCH_SIZE}")
    print(f"‚úì Image size    : {IMAGE_SIZE}")
    print("‚úì train/val/test datasets ready")
    print("=========================================================\n")

    return train_ds, val_ds, test_ds


In [None]:
# ============================================================
# STEP 6 ‚Äî Model Builders (TensorFlow)
# ============================================================

from tensorflow.keras.applications import (
    MobileNetV2,
    EfficientNetB0,
    VGG16,
    ResNet50,
    Xception,
)

TF_MODEL_REGISTRY = {
    "mobilenet_v2": MobileNetV2,
    "efficientnet_b0": EfficientNetB0,
    "vgg16": VGG16,
    "resnet50": ResNet50,
    "xception": Xception,
}


def create_binary_classifier(base_model_fn, input_shape=(224, 224, 3), dropout=0.2):
    base_model = base_model_fn(
        input_shape=input_shape,
        include_top=False,
        weights="imagenet",
    )
    base_model.trainable = False

    inputs = base_model.input
    x = layers.GlobalAveragePooling2D()(base_model.output)
    x = layers.Dropout(dropout)(x)
    outputs = layers.Dense(1, activation="sigmoid")(x)

    model = models.Model(inputs=inputs, outputs=outputs)
    return model


def get_model_builder(model_key):
    key = model_key.lower()
    if key not in TF_MODEL_REGISTRY:
        raise ValueError(f"Unknown TensorFlow model key: {model_key}")

    def builder():
        return create_binary_classifier(TF_MODEL_REGISTRY[key])

    return builder


MODELS_TO_TRAIN = [
    "mobilenet_v2",
    "efficientnet_b0",
]


In [None]:
# ============================================================
# STEP 7 ‚Äî TensorFlow Training Utilities
# ============================================================

def compile_tf_model(model, lr=1e-3):
    model.compile(
        optimizer=tf.keras.optimizers.Adam(lr),
        loss="binary_crossentropy",
        metrics=[
            tf.keras.metrics.BinaryAccuracy(name="accuracy"),
            tf.keras.metrics.AUC(name="auc"),
            tf.keras.metrics.Precision(name="precision"),
            tf.keras.metrics.Recall(name="recall"),
        ],
    )


def train_tf_model(model, train_ds, val_ds, ckpt_path, log_dir, epochs=10):
    compile_tf_model(model)

    cb = [
        tf.keras.callbacks.TensorBoard(log_dir=log_dir),
        tf.keras.callbacks.ModelCheckpoint(
            ckpt_path, monitor="val_auc", mode="max", save_best_only=True
        ),
        tf.keras.callbacks.EarlyStopping(
            monitor="val_auc", mode="max", patience=3, restore_best_weights=True
        ),
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss", factor=0.5, patience=2, min_lr=1e-6
        ),
    ]

    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        callbacks=cb,
        verbose=1,
    )
    return model, history


def train_single_model(model_key, train_ds, val_ds, epochs=EPOCHS):
    print(f"\n===== TRAINING MODEL: {model_key} =====")

    builder = get_model_builder(model_key)
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    ckpt_path = f"models/{model_key}_{timestamp}.keras"
    log_dir = f"logs/{model_key}_{timestamp}"

    reset_gpu_memory_stats()
    mem_before = get_memory_usage()
    t0 = time.time()

    model = builder()
    num_params = model.count_params()

    model, history = train_tf_model(
        model,
        train_ds,
        val_ds,
        ckpt_path=ckpt_path,
        log_dir=log_dir,
        epochs=epochs,
    )

    train_time = time.time() - t0
    mem_after = get_memory_usage()

    print(f"‚úì Model trained: {model_key}")
    print(f"‚Üí Params: {num_params:,}")
    print(f"‚Üí Training time: {train_time:.2f} sec ({train_time/60:.2f} min)")

    return model, history, ckpt_path, num_params, train_time, mem_before, mem_after


In [None]:
# ============================================================
# STEP 8 ‚Äî Evaluation & End-to-End Orchestration (TensorFlow)
# ============================================================

def predict_tf(model, test_ds):
    y_true, y_pred, y_proba = [], [], []

    for images, labels in test_ds:
        probs = model.predict(images, verbose=0).flatten()
        preds = (probs > 0.5).astype(int)

        y_true.extend(labels.numpy())
        y_pred.extend(preds)
        y_proba.extend(probs)

    return np.array(y_true), np.array(y_pred), np.array(y_proba)


def compute_metrics(y_true, y_pred, y_proba):
    acc = (y_true == y_pred).mean()

    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="binary", zero_division=0
    )

    try:
        auc = roc_auc_score(y_true, y_proba)
    except ValueError:
        auc = np.nan

    return {
        "test_accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "auc": auc,
        "report": classification_report(
            y_true,
            y_pred,
            target_names=["Healthy", "Diseased"],
            zero_division=0,
        ),
    }


def save_confusion_matrix(y_true, y_pred, model_key):
    cm = confusion_matrix(y_true, y_pred)

    plt.figure(figsize=(6, 5))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=["Healthy", "Diseased"],
        yticklabels=["Healthy", "Diseased"],
    )
    plt.title(f"{model_key} ‚Äî Confusion Matrix")
    plt.ylabel("True Label")
    plt.xlabel("Predicted Label")

    out_path = f"models/{model_key}_cm.png"
    plt.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close()

    return out_path


def evaluate_model(model, test_ds, model_key):
    print(f"‚Üí Evaluating: {model_key}")

    y_true, y_pred, y_proba = predict_tf(model, test_ds)
    metrics = compute_metrics(y_true, y_pred, y_proba)
    metrics["cm_path"] = save_confusion_matrix(y_true, y_pred, model_key)

    print(
        f"{model_key} ‚Üí ACC={metrics['test_accuracy']:.4f}, "
        f"PREC={metrics['precision']:.4f}, "
        f"REC={metrics['recall']:.4f}, "
        f"F1={metrics['f1']:.4f}, "
        f"AUC={metrics['auc']:.4f}"
    )

    return metrics


def run_all_models(models=MODELS_TO_TRAIN, epochs=EPOCHS):
    print("\n==============================")
    print("===== MODELS TO TRAIN =====")
    for i, m in enumerate(models, start=1):
        print(f"{i}. {m}")
    print("==============================\n")

    # --------------------------------------------------------
    # Load + preprocess dataset
    # --------------------------------------------------------
    plant_village_data, info = load_and_split_plant_village()
    binary_data, total_healthy, total_diseased = make_binary_labels(plant_village_data, info)
    binary_data = build_augmented_train_ds(binary_data, total_healthy, total_diseased)
    train_ds, val_ds, test_ds = prepare_all_splits(binary_data)

    all_rows = []
    histories = {}

    def flatten(prefix, mem):
        return {
            f"{prefix}_ram_gb": mem["ram_gb"],
            f"{prefix}_gpu_current_gb": mem["gpu_current_gb"],
            f"{prefix}_gpu_peak_gb": mem["gpu_peak_gb"],
        }

    # --------------------------------------------------------
    # Loop through models
    # --------------------------------------------------------
    for model_key in models:
        print("\n====================================================")
        print(f"üöÄ Starting training for: {model_key}")
        print("====================================================")

        model, history, ckpt_path, params, train_time, mem_before, mem_after = train_single_model(
            model_key,
            train_ds,
            val_ds,
            epochs,
        )

        print(f"üîç Evaluating model: {model_key}")

        metrics = evaluate_model(
            model,
            test_ds=test_ds,
            model_key=model_key,
        )

        row = {
            "model": model_key,
            "params": params,
            "train_time_sec": train_time,
            "train_time_min": train_time / 60,
            "checkpoint_path": ckpt_path,
            "classification_report": metrics.pop("report"),
        }
        row.update(flatten("mem_before", mem_before))
        row.update(flatten("mem_after", mem_after))
        row.update(metrics)

        all_rows.append(row)
        histories[model_key] = history

        print(f"‚úî Finished model: {model_key}")

    df = pd.DataFrame(all_rows)
    df.to_csv("results/model_comparison.csv", index=False)

    print("\n===== ALL MODELS COMPLETED SUCCESSFULLY =====")
    print("üìÅ Results saved to results/model_comparison.csv\n")

    return df, histories


: 

In [None]:
df, histories = run_all_models()



===== MODELS TO TRAIN =====
1. alexnet

===== Loading PlantVillage Dataset =====


2025-11-27 15:34:05.618687: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M3 Pro
2025-11-27 15:34:05.618724: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 18.00 GB
2025-11-27 15:34:05.618730: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 6.66 GB
2025-11-27 15:34:05.618763: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2025-11-27 15:34:05.618774: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


‚ö† Dataset only has 'train' split ‚Üí Creating 70/15/15 splits...
‚úì train=38012, val=8145, test=8146

===== Converting to Binary Labels (Healthy=0, Diseased=1) =====
Counting healthy vs diseased samples in training split...


2025-11-27 15:34:13.113485: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


‚úì Healthy: 10465
‚úì Diseased: 27547
‚Üí Imbalance ratio: 2.63:1

===== Building Augmented Training Dataset =====
Original counts ‚Üí Healthy=10465, Diseased=27547
Healthy class replication multiplier: 2x
Augmented training ‚Üí Healthy‚âà31395, Diseased‚âà27547
Final ratio ‚âà 1:1

===== Preparing TF Datasets (Batched + Prefetched) =====
‚úì Batch size    : 32
‚úì Image size    : (224, 224)
‚úì train/val/test datasets ready

‚Üí Extracting raw TF dataset for PyTorch pipeline...


2025-11-27 15:36:07.785005: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:450] ShuffleDatasetV3:8: Filling up shuffle buffer (this may take a while): 45558 of 54303
2025-11-27 15:36:10.574776: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:480] Shuffle buffer filled.
2025-11-27 15:38:38.380581: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [None]:
# ============================================================
# STEP 9 ‚Äî Visualization: Model Comparison Plots
# ============================================================

def plot_bar_metric(df, metric, ylabel, title, filename, sort=True):
    """
    Creates a bar chart for any metric.
    Automatically skips missing or empty metrics.
    """
    if metric not in df.columns:
        print(f"‚ö† Metric '{metric}' not found. Skipping plot.")
        return

    if df[metric].dropna().empty:
        print(f"‚ö† Metric '{metric}' has no values. Skipping plot.")
        return

    # Sort models for readability
    if sort:
        df = df.sort_values(metric, ascending=False)

    plt.figure(figsize=(10, 6))
    x = np.arange(len(df))
    plt.bar(x, df[metric], color="#4F81BD")

    plt.xticks(x, df["model"], rotation=35, ha="right")
    plt.ylabel(ylabel, fontsize=12)
    plt.title(title, fontsize=15, fontweight="bold")
    plt.grid(axis="y", linestyle="--", alpha=0.4)

    plt.tight_layout()
    out_path = os.path.join("results", filename)
    plt.savefig(out_path, dpi=150)
    plt.close()

    print(f"‚úì Saved plot: {out_path}")


def generate_all_plots(results_df):
    print("\n===== Generating Comparison Plots =====")

    df_sorted = results_df.sort_values("test_accuracy", ascending=False)

    # ------------------------------------------------------------
    # CORE PERFORMANCE METRICS
    # ------------------------------------------------------------
    plot_bar_metric(
        df_sorted,
        metric="test_accuracy",
        ylabel="Accuracy",
        title="Model Comparison ‚Äî Test Accuracy",
        filename="accuracy_by_model.png",
    )

    plot_bar_metric(
        df_sorted,
        metric="f1",
        ylabel="F1 Score",
        title="Model Comparison ‚Äî F1 Score",
        filename="f1_by_model.png",
    )

    plot_bar_metric(
        df_sorted,
        metric="auc",
        ylabel="AUC Score",
        title="Model Comparison ‚Äî ROC-AUC",
        filename="auc_by_model.png",
    )

    # ------------------------------------------------------------
    # TRAINING TIME
    # ------------------------------------------------------------
    plot_bar_metric(
        df_sorted,
        metric="train_time_min",
        ylabel="Minutes",
        title="Training Time by Model (Minutes)",
        filename="training_time_by_model.png",
    )

    # ------------------------------------------------------------
    # MEMORY USAGE
    # ------------------------------------------------------------
    if "mem_after_ram_gb" in df_sorted.columns:
        plot_bar_metric(
            df_sorted,
            metric="mem_after_ram_gb",
            ylabel="RAM (GB)",
            title="RAM Usage After Training",
            filename="ram_usage_after_training.png",
            sort=False,
        )

    if "mem_after_gpu_peak_gb" in df_sorted.columns:
        plot_bar_metric(
            df_sorted,
            metric="mem_after_gpu_peak_gb",
            ylabel="Peak GPU Memory (GB)",
            title="GPU Peak Memory by Model",
            filename="gpu_peak_memory_by_model.png",
            sort=False,
        )

    print("\n===== ALL PLOTS GENERATED SUCCESSFULLY =====")
    print("üìÅ Check 'results/' folder for all graphs.")
    print("üìÅ Confusion matrices are inside 'models/' directory.")
    print("üìÅ Comparison table saved at: results/model_comparison.csv")


Saved: results/accuracy_by_model.png
Saved: results/f1_by_model.png
Saved: results/auc_by_model.png
Saved: results/train_time_by_model.png
Saved: results/ram_usage_by_model.png
Saved: results/gpu_peak_by_model.png
All training, evaluation, and comparison complete.
Check models/ for checkpoints, results/ for CSV & plots, logs/ for TensorBoard.
