
# Multi-Model Training & Profiling
Complete binary PlantVillage comparison (MobileNetV2, EfficientNetB0, ResNet50, VGG16, DenseNet121, InceptionV3) with timing and memory stats. Optimized defaults for M3 Pro.


In [2]:

import os
import time
from collections import defaultdict
import math

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers, models

from sklearn.metrics import (
    confusion_matrix,
    classification_report,
    precision_recall_fscore_support,
    roc_auc_score,
)


def get_memory_usage():
    info = {"ram_gb": None, "gpu_current_gb": None, "gpu_peak_gb": None}
    try:
        import psutil
        process = psutil.Process(os.getpid())
        info["ram_gb"] = process.memory_info().rss / (1024 ** 3)
    except Exception:
        pass
    try:
        mem_info = tf.config.experimental.get_memory_info("GPU:0")
        info["gpu_current_gb"] = mem_info["current"] / (1024 ** 3)
        info["gpu_peak_gb"] = mem_info["peak"] / (1024 ** 3)
    except Exception:
        pass
    return info


def reset_gpu_memory_stats():
    try:
        tf.config.experimental.reset_memory_stats("GPU:0")
    except Exception:
        pass


In [3]:

# Global config
BATCH_SIZE = 32
IMAGE_SIZE = (224, 224)
EPOCHS = 10
SEED = 42

os.makedirs("models", exist_ok=True)
os.makedirs("logs", exist_ok=True)
os.makedirs("results", exist_ok=True)

print("Configuring TensorFlow for M3 Pro...")

gpus = tf.config.list_physical_devices("GPU")
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"✓ Metal GPU acceleration enabled: {len(gpus)} GPU(s) found")
    except RuntimeError as e:
        print(f"GPU configuration error: {e}")
else:
    print("⚠ No GPU found - training will use CPU (slower)")

# Tune threading for the M3 Pro CPU
tf.config.threading.set_inter_op_parallelism_threads(8)
tf.config.threading.set_intra_op_parallelism_threads(8)

print(f"TensorFlow version: {tf.__version__}")
print(f"Num GPUs Available: {len(tf.config.list_physical_devices('GPU'))}")


Configuring TensorFlow for M3 Pro...
✓ Metal GPU acceleration enabled: 1 GPU(s) found
TensorFlow version: 2.16.2
Num GPUs Available: 1



## 1) Data loading and binary labels


In [4]:

def load_and_split_plant_village(seed=SEED):
    plant_village_data, info = tfds.load(
        "plant_village", with_info=True, as_supervised=True
    )
    if len(plant_village_data) == 1 and "train" in plant_village_data:
        full_train = plant_village_data["train"]
        total_size = sum(1 for _ in full_train)
        train_size = int(0.7 * total_size)
        val_size = int(0.15 * total_size)
        full_train = full_train.shuffle(total_size, seed=seed)
        train_ds = full_train.take(train_size)
        val_ds = full_train.skip(train_size).take(val_size)
        test_ds = full_train.skip(train_size + val_size)
        plant_village_data = {
            "train": train_ds,
            "validation": val_ds,
            "test": test_ds,
        }
    return plant_village_data, info


def make_binary_labels(plant_village_data, info):
    label_names = info.features["label"].names
    binary_lookup = np.array(
        [
            0 if name.split("___", 1)[-1].lower() == "healthy" else 1
            for name in label_names
        ],
        dtype=np.int32,
    )
    binary_lookup_tf = tf.constant(binary_lookup)

    def to_binary_label(image, label):
        label = tf.cast(label, tf.int32)
        binary_label = tf.gather(binary_lookup_tf, label)
        return image, binary_label

    binary_data = {
        split: ds.map(to_binary_label, num_parallel_calls=tf.data.AUTOTUNE)
        for split, ds in plant_village_data.items()
    }

    train_original = plant_village_data["train"]
    counts = defaultdict(lambda: {"healthy": 0, "diseased": 0})
    for _, label in tfds.as_numpy(train_original):
        label_str = info.features["label"].int2str(int(label))
        plant_type, disease_name = label_str.split("___", 1)
        if disease_name.lower() == "healthy":
            counts[plant_type]["healthy"] += 1
        else:
            counts[plant_type]["diseased"] += 1

    total_healthy = sum(v["healthy"] for v in counts.values())
    total_diseased = sum(v["diseased"] for v in counts.values())
    return binary_data, total_healthy, total_diseased



## 2) Augmentation and balancing


In [5]:

def build_augmented_train_ds(binary_data, total_healthy, total_diseased):
    train_ds = binary_data["train"]
    healthy_label = 0
    diseased_label = 1

    healthy_ds = train_ds.filter(lambda _, lbl: tf.equal(lbl, healthy_label))
    diseased_ds = train_ds.filter(lambda _, lbl: tf.equal(lbl, diseased_label))

    def augment_healthy(image, label):
        image_f = tf.image.convert_image_dtype(image, tf.float32)
        image_f = tf.image.random_flip_left_right(image_f)
        image_f = tf.image.random_flip_up_down(image_f)
        image_f = tf.image.rot90(
            image_f,
            tf.random.uniform([], minval=0, maxval=4, dtype=tf.int32),
        )
        image_f = tf.image.random_saturation(image_f, 0.8, 1.25)
        image_f = tf.image.random_hue(image_f, 0.05)
        image_f = tf.image.random_brightness(image_f, 0.12)
        image_f = tf.image.random_contrast(image_f, 0.8, 1.25)
        image_f = tf.clip_by_value(image_f, 0.0, 1.0)
        image_aug = tf.image.convert_image_dtype(image_f, tf.uint8)
        return image_aug, label

    def augment_diseased_with_replacement(image, label):
        image_f = tf.image.convert_image_dtype(image, tf.float32)

        def augmented():
            aug = tf.image.random_flip_left_right(image_f)
            aug = tf.image.rot90(
                aug, tf.random.uniform([], minval=0, maxval=4, dtype=tf.int32)
            )
            aug = tf.image.random_contrast(aug, 0.9, 1.1)
            aug = tf.image.random_brightness(aug, 0.08)
            aug = tf.image.random_hue(aug, 0.03)
            aug = tf.clip_by_value(aug, 0.0, 1.0)
            return tf.image.convert_image_dtype(aug, tf.uint8)

        def original():
            return image

        choice = tf.random.uniform([], 0.0, 1.0)
        return tf.cond(choice > 0.5, augmented, original), label

    healthy_multiplier = 0
    if total_healthy > 0:
        healthy_multiplier = max(1, math.ceil(total_diseased / total_healthy) - 1)

    augmented_healthy_datasets = [healthy_ds]
    for _ in range(healthy_multiplier):
        augmented_healthy_datasets.append(
            healthy_ds.map(augment_healthy, num_parallel_calls=tf.data.AUTOTUNE)
        )

    healthy_augmented_ds = augmented_healthy_datasets[0]
    for ds in augmented_healthy_datasets[1:]:
        healthy_augmented_ds = healthy_augmented_ds.concatenate(ds)

    healthy_augmented_ds = healthy_augmented_ds.shuffle(4096)

    diseased_augmented_ds = diseased_ds.map(
        augment_diseased_with_replacement, num_parallel_calls=tf.data.AUTOTUNE
    )

    augmented_train_ds = healthy_augmented_ds.concatenate(diseased_augmented_ds)
    augmented_train_ds = augmented_train_ds.shuffle(8192).prefetch(tf.data.AUTOTUNE)

    binary_data["train"] = augmented_train_ds
    return binary_data



## 3) tf.data prep


In [6]:

def prepare_dataset(ds, batch_size=BATCH_SIZE, image_size=IMAGE_SIZE):
    autotune = tf.data.AUTOTUNE

    def preprocess(image, label):
        image = tf.image.resize(image, image_size)
        image = tf.cast(image, tf.float32) / 255.0
        return image, label

    return (
        ds.map(preprocess, num_parallel_calls=autotune)
        .batch(batch_size)
        .prefetch(autotune)
    )


def prepare_all_splits(binary_data):
    train_ds = prepare_dataset(binary_data["train"])
    val_ds = prepare_dataset(binary_data["validation"])
    test_ds = prepare_dataset(binary_data["test"])
    return train_ds, val_ds, test_ds



## 4) Model factory


In [8]:

def create_binary_classifier(base_model_fn, input_shape=(224, 224, 3)):
    base_model = base_model_fn(
        input_shape=input_shape, include_top=False, weights="imagenet"
    )
    base_model.trainable = False
    x = layers.GlobalAveragePooling2D()(base_model.output)
    x = layers.Dropout(0.2)(x)
    output = layers.Dense(1, activation="sigmoid")(x)
    model = models.Model(inputs=base_model.input, outputs=output)
    return model


def get_model_builder(model_key):
    key = model_key.lower()
    if key == "mobilenet_v2":
        return lambda: create_binary_classifier(tf.keras.applications.MobileNetV2)
    if key == "efficientnet_b0":
        return lambda: create_binary_classifier(tf.keras.applications.EfficientNetB0)
    if key == "resnet50":
        return lambda: create_binary_classifier(tf.keras.applications.ResNet50)
    if key == "vgg16":
        return lambda: create_binary_classifier(tf.keras.applications.VGG16)
    if key == "densenet121":
        return lambda: create_binary_classifier(tf.keras.applications.DenseNet121)
    if key == "inception_v3":
        return lambda: create_binary_classifier(tf.keras.applications.InceptionV3)
    raise ValueError(f"Unknown model key: {model_key}")


MODELS_TO_TRAIN = [
    # "mobilenet_v2",
    # "efficientnet_b0",
    "resnet50"
    # "vgg16",
    # "densenet121",
    # "inception_v3",
]



## 5) Training helpers


In [9]:

def compile_model(model, lr=1e-3):
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
        loss="binary_crossentropy",
        metrics=["accuracy", tf.keras.metrics.AUC(name="auc")],
    )


def train_single_model(model_key, train_ds, val_ds, epochs=EPOCHS):
    builder = get_model_builder(model_key)
    model = builder()
    compile_model(model)

    num_params = model.count_params()

    timestamp = time.strftime("%Y%m%d-%H%M%S")
    log_dir = os.path.join("logs", f"{model_key}_{timestamp}")
    ckpt_path = os.path.join("models", f"{model_key}_{timestamp}_best.h5")

    callbacks = [
        tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1),
        tf.keras.callbacks.EarlyStopping(
            monitor="val_loss", patience=3, restore_best_weights=True, verbose=1
        ),
        tf.keras.callbacks.ModelCheckpoint(
            ckpt_path, monitor="val_accuracy", save_best_only=True, verbose=1
        ),
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss", factor=0.5, patience=2, verbose=1, min_lr=1e-6
        ),
    ]

    reset_gpu_memory_stats()
    mem_before = get_memory_usage()
    t0 = time.time()

    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        callbacks=callbacks,
        verbose=1,
    )

    t1 = time.time()
    mem_after = get_memory_usage()

    train_time_sec = t1 - t0

    return model, history, ckpt_path, num_params, train_time_sec, mem_before, mem_after



## 6) Evaluation helpers


In [10]:

def evaluate_model(model, test_ds, model_key):
    y_true = []
    y_pred = []
    y_proba = []

    for images, labels in test_ds:
        probs = model.predict(images, verbose=0).flatten()
        preds = (probs > 0.5).astype(int)
        y_proba.extend(probs)
        y_pred.extend(preds)
        y_true.extend(labels.numpy())

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    y_proba = np.array(y_proba)

    acc = (y_true == y_pred).mean()
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="binary"
    )

    try:
        auc = roc_auc_score(y_true, y_proba)
    except ValueError:
        auc = np.nan

    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(6, 5))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=["Healthy", "Diseased"],
        yticklabels=["Healthy", "Diseased"],
    )
    plt.title(f"{model_key} - Confusion Matrix")
    plt.ylabel("True")
    plt.xlabel("Predicted")
    cm_path = os.path.join("models", f"{model_key}_confusion_matrix.png")
    plt.savefig(cm_path, dpi=150, bbox_inches="tight")
    plt.close()

    print(f"{model_key} Test Accuracy: {acc:.4f} ({acc*100:.2f}%)")
    print(f"Precision: {precision:.4f} | Recall: {recall:.4f} | F1: {f1:.4f} | AUC: {auc:.4f}")

    return {
        "test_accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "auc": auc,
        "cm_path": cm_path,
    }



## 7) Run training & collect results


In [None]:

plant_village_data, info = load_and_split_plant_village()
binary_data, total_healthy, total_diseased = make_binary_labels(plant_village_data, info)
binary_data = build_augmented_train_ds(binary_data, total_healthy, total_diseased)
train_ds, val_ds, test_ds = prepare_all_splits(binary_data)

all_results = []
histories = {}

for model_key in MODELS_TO_TRAIN:
    model, history, ckpt_path, num_params, train_time_sec, mem_before, mem_after = (
        train_single_model(model_key, train_ds, val_ds, epochs=EPOCHS)
    )

    metrics = evaluate_model(model, test_ds, model_key)

    def mem_dict_to_flat(prefix, d):
        return {
            f"{prefix}_ram_gb": d.get("ram_gb"),
            f"{prefix}_gpu_current_gb": d.get("gpu_current_gb"),
            f"{prefix}_gpu_peak_gb": d.get("gpu_peak_gb"),
        }

    row = {
        "model": model_key,
        "num_params": num_params,
        "train_time_sec": train_time_sec,
        "train_time_min": train_time_sec / 60.0,
        "checkpoint_path": ckpt_path,
    }
    row.update(mem_dict_to_flat("mem_before", mem_before))
    row.update(mem_dict_to_flat("mem_after", mem_after))
    row.update(metrics)

    all_results.append(row)
    histories[model_key] = history

results_df = pd.DataFrame(all_results)
csv_path = os.path.join("results", "model_comparison.csv")
results_df.to_csv(csv_path, index=False)
print(f"Model comparison table saved to: {csv_path}")
print(results_df[["model", "test_accuracy", "f1", "auc", "train_time_min"]])


2025-11-27 13:35:08.921419: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M3 Pro
2025-11-27 13:35:08.921453: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 18.00 GB
2025-11-27 13:35:08.921456: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 6.66 GB
2025-11-27 13:35:08.921474: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2025-11-27 13:35:08.921486: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
2025-11-27 13:35:12.529265: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2025-11-27 13:35:19.734561: W

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m94765736/94765736[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 0us/step
Epoch 1/10


2025-11-27 13:35:25.707189: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.
2025-11-27 13:35:37.395942: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:450] ShuffleDatasetV3:76: Filling up shuffle buffer (this may take a while): 6490 of 8192
2025-11-27 13:35:37.788232: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:480] Shuffle buffer filled.


    856/Unknown [1m159s[0m 167ms/step - accuracy: 0.9630 - auc: 0.1083 - loss: 0.0712


## 8) Comparison plots


In [None]:

def plot_bar_metric(df, metric, ylabel, title, filename):
    plt.figure(figsize=(8, 5))
    x = np.arange(len(df))
    plt.bar(x, df[metric])
    plt.xticks(x, df["model"], rotation=30, ha="right")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.tight_layout()
    out_path = os.path.join("results", filename)
    plt.savefig(out_path, dpi=150)
    plt.close()
    print(f"Saved: {out_path}")

results_df_sorted = results_df.sort_values("test_accuracy", ascending=False)

plot_bar_metric(
    results_df_sorted,
    "test_accuracy",
    "Accuracy",
    "Test Accuracy by Model",
    "accuracy_by_model.png",
)

plot_bar_metric(
    results_df_sorted,
    "f1",
    "F1-score",
    "F1-score by Model",
    "f1_by_model.png",
)

plot_bar_metric(
    results_df_sorted,
    "auc",
    "AUC",
    "ROC-AUC by Model",
    "auc_by_model.png",
)

plot_bar_metric(
    results_df_sorted,
    "train_time_min",
    "Minutes",
    "Training Time (min) by Model",
    "train_time_by_model.png",
)

if results_df_sorted["mem_after_ram_gb"].notna().any():
    plot_bar_metric(
        results_df_sorted,
        "mem_after_ram_gb",
        "GB",
        "RAM Usage After Training by Model",
        "ram_usage_by_model.png",
    )

if results_df_sorted["mem_after_gpu_peak_gb"].notna().any():
    plot_bar_metric(
        results_df_sorted,
        "mem_after_gpu_peak_gb",
        "GB",
        "GPU Peak Memory by Model",
        "gpu_peak_by_model.png",
    )

print("All training, evaluation, and comparison complete.")
print("Check models/ for checkpoints, results/ for CSV & plots, logs/ for TensorBoard.")
