In [1]:
# ============================
# PlantDocX — Safe-Mode Train & Export (1 epoch, fixed class_names)
# ============================
import os, sys, json, pickle, math, random, warnings
warnings.filterwarnings("ignore")

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.metrics import classification_report, accuracy_score
from collections import Counter

try:
    import yaml
    HAVE_YAML = True
except Exception:
    HAVE_YAML = False

# Force CPU to avoid CUDA runtime issues; comment if your GPU is good
try:
    tf.config.set_visible_devices([], "GPU")
except Exception:
    pass

DATA_ROOT   = r"C:\Users\sagni\Downloads\PlantDocX\archive\PlantVillage"
OUTPUT_DIR  = r"C:\Users\sagni\Downloads\PlantDocX"
os.makedirs(OUTPUT_DIR, exist_ok=True)

SEED = 42
random.seed(SEED); np.random.seed(SEED); tf.random.set_seed(SEED)

def has_image_children(path):
    if not os.path.isdir(path): return False
    for d in os.listdir(path):
        p = os.path.join(path, d)
        if os.path.isdir(p):
            for dd in os.listdir(p):
                if dd.lower().endswith((".jpg",".jpeg",".png",".bmp",".webp")):
                    return True
    return False

candidates = [
    DATA_ROOT,
    os.path.join(DATA_ROOT, "color"),
    os.path.join(DATA_ROOT, "segmented"),
    os.path.join(DATA_ROOT, "grayscale"),
]
DATA_DIR = None
for c in candidates:
    if has_image_children(c):
        DATA_DIR = c; break
if DATA_DIR is None:
    raise RuntimeError(f"Could not locate class folders under {DATA_ROOT}")

print(f"[INFO] Using dataset directory: {DATA_DIR}")

IMG_SIZE  = (224, 224)
BASE_BATCH = 32
EPOCHS    = 1
VAL_SPLIT = 0.2

def make_datasets(batch_size):
    # Create raw datasets first, capture class_names now (before transforms)
    tr_raw = tf.keras.utils.image_dataset_from_directory(
        DATA_DIR, labels="inferred", label_mode="int", color_mode="rgb",
        batch_size=batch_size, image_size=IMG_SIZE, shuffle=True, seed=SEED,
        validation_split=VAL_SPLIT, subset="training",
    )
    class_names = tr_raw.class_names  # <-- capture here

    va_raw = tf.keras.utils.image_dataset_from_directory(
        DATA_DIR, labels="inferred", label_mode="int", color_mode="rgb",
        batch_size=batch_size, image_size=IMG_SIZE, shuffle=False, seed=SEED,
        validation_split=VAL_SPLIT, subset="validation",
    )

    # Robust to corrupt files
    try:
        tr = tr_raw.ignore_errors()
        va = va_raw.ignore_errors()
    except Exception:
        try:
            tr = tr_raw.apply(tf.data.experimental.ignore_errors())
            va = va_raw.apply(tf.data.experimental.ignore_errors())
        except Exception:
            tr, va = tr_raw, va_raw

    AUTOTUNE = tf.data.AUTOTUNE
    tr = tr.prefetch(AUTOTUNE)
    va = va.prefetch(AUTOTUNE)
    return tr, va, class_names

train_ds, val_ds, class_names = make_datasets(BASE_BATCH)
num_classes = len(class_names)
print(f"[INFO] Classes ({num_classes}): {class_names[:10]}{'...' if num_classes>10 else ''}")

def gather_labels(ds):
    labs = []
    for _, y in ds:
        labs.extend(y.numpy().tolist())
    return np.array(labs, dtype=int)

y_train_all = gather_labels(train_ds)
counts = Counter(y_train_all.tolist())
total = int(sum(counts.values()))
class_weight = {cls_idx: (total/(num_classes*cnt)) for cls_idx, cnt in counts.items() if cnt > 0}
print("[INFO] Class counts (train):", dict(counts))
print("[INFO] Class weights:", {int(k): float(v) for k, v in class_weight.items()})

data_augment = keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.05),
    layers.RandomZoom(0.1),
    layers.RandomContrast(0.1),
], name="augment")
preprocess = layers.Rescaling(1./255)

def build_model(img_size=IMG_SIZE, n_classes=num_classes):
    inputs = keras.Input(shape=img_size + (3,))
    x = data_augment(inputs)
    x = preprocess(x)
    try:
        backbone = keras.applications.EfficientNetB0(include_top=False, weights="imagenet", input_tensor=x)
        x = backbone.output
        print("[INFO] Using EfficientNetB0(weights='imagenet').")
    except Exception as e:
        print("[WARN] EfficientNetB0(weights='imagenet') unavailable; falling back to MobileNetV2(no weights).", e)
        backbone = keras.applications.MobileNetV2(include_top=False, weights=None, input_tensor=x)
        x = backbone.output
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.25)(x)
    outputs = layers.Dense(n_classes, activation="softmax")(x)
    model = keras.Model(inputs, outputs, name="plantdocx_classifier")
    model.compile(optimizer=keras.optimizers.Adam(1e-3),
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model

model = build_model()
model.summary(line_length=120)

def train_with_auto_batch(model, base_batch, max_retries=2):
    batch = base_batch
    attempts = 0
    global train_ds, val_ds
    while True:
        try:
            callbacks = [
                keras.callbacks.ModelCheckpoint(
                    filepath=os.path.join(OUTPUT_DIR, "cls_model_best.keras"),
                    monitor="val_accuracy", save_best_only=True, save_weights_only=False
                )
            ]
            history = model.fit(
                train_ds, validation_data=val_ds,
                epochs=EPOCHS,
                class_weight=class_weight if len(class_weight)>0 else None,
                verbose=1, callbacks=callbacks
            )
            return history
        except tf.errors.ResourceExhaustedError:
            attempts += 1
            if attempts > max_retries or batch <= 4:
                print("[ERROR] OOM even after retries; giving up.")
                raise
            batch //= 2
            print(f"[WARN] OOM encountered. Retrying with smaller batch_size={batch} ...")
            train_ds, val_ds, _ = make_datasets(batch)

history = train_with_auto_batch(model, BASE_BATCH)

best_model_path = os.path.join(OUTPUT_DIR, "cls_model_best.keras")
if os.path.exists(best_model_path):
    try:
        model = tf.keras.models.load_model(best_model_path, safe_mode=False)
        print(f"[INFO] Loaded best model from {best_model_path}")
    except Exception as e:
        print("[WARN] Could not reload best .keras model, using current weights.", e)

# Evaluate on validation
y_true, y_pred = [], []
for xb, yb in val_ds:
    probs = model.predict(xb, verbose=0)
    yhat = np.argmax(probs, axis=1)
    y_true.extend(yb.numpy().tolist())
    y_pred.extend(yhat.tolist())
y_true = np.array(y_true, dtype=int)
y_pred = np.array(y_pred, dtype=int)

acc = float(accuracy_score(y_true, y_pred))
print(f"[INFO] Validation accuracy: {acc:.4f}")

# Robust classification report (align labels to full set)
all_labels = np.arange(num_classes, dtype=int)
report = classification_report(
    y_true, y_pred,
    labels=all_labels,
    target_names=class_names,
    digits=4,
    zero_division=0
)
with open(os.path.join(OUTPUT_DIR, "classification_report.txt"), "w", encoding="utf-8") as f:
    f.write(report)
print("[INFO] Wrote classification_report.txt")

# Save models
h5_path = os.path.join(OUTPUT_DIR, "cls_model.h5")
try:
    model.save(h5_path)
    print(f"[INFO] Saved H5 -> {h5_path}")
except Exception as e:
    print("[WARN] Could not save H5 model:", e)

keras_path = os.path.join(OUTPUT_DIR, "cls_model.keras")
try:
    model.save(keras_path)
    print(f"[INFO] Saved native Keras -> {keras_path}")
except Exception as e:
    print("[WARN] Could not save .keras model:", e)

# preprocessor.pkl
preproc = {
    "image_size": IMG_SIZE,
    "rescale": 1/255.0,
    "class_names": class_names,
    "label_map": {name: i for i, name in enumerate(class_names)},
    "augment": {"flip":"horizontal","rotation":0.05,"zoom":0.1,"contrast":0.1},
    "seed": SEED,
    "dataset_dir_used": DATA_DIR
}
with open(os.path.join(OUTPUT_DIR, "preprocessor.pkl"), "wb") as f:
    pickle.dump(preproc, f)
print(f"[INFO] Saved preprocessor.pkl -> {os.path.join(OUTPUT_DIR, 'preprocessor.pkl')}")

# model_config.yaml (or JSON fallback)
config = {
    "project": "PlantDocX",
    "dataset_dir_used": DATA_DIR,
    "img_height": IMG_SIZE[0],
    "img_width": IMG_SIZE[1],
    "num_classes": num_classes,
    "batch_size": BASE_BATCH,
    "epochs": EPOCHS,
    "optimizer": "adam",
    "loss": "sparse_categorical_crossentropy",
    "metrics": ["accuracy"],
    "class_names": class_names,
}
yaml_path = os.path.join(OUTPUT_DIR, "model_config.yaml")
try:
    if HAVE_YAML:
        with open(yaml_path, "w", encoding="utf-8") as f:
            yaml.safe_dump(config, f, sort_keys=False, allow_unicode=True)
    else:
        with open(yaml_path, "w", encoding="utf-8") as f:
            f.write(json.dumps(config, indent=2))
    print(f"[INFO] Saved model_config.yaml -> {yaml_path}")
except Exception as e:
    print("[WARN] Could not write YAML cleanly; writing JSON instead.", e)
    with open(os.path.join(OUTPUT_DIR, "model_config.json"), "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

# metrics.json
metrics = {
    "val_accuracy": acc,
    "num_classes": num_classes,
    "class_counts_train": {int(k): int(v) for k, v in counts.items()},
    "class_weights": {int(k): float(v) for k, v in class_weight.items()},
}
with open(os.path.join(OUTPUT_DIR, "metrics.json"), "w", encoding="utf-8") as f:
    json.dump(metrics, f, indent=2)
print(f"[INFO] Saved metrics.json -> {os.path.join(OUTPUT_DIR, 'metrics.json')}")

print("\n[DONE] All artifacts saved in:", OUTPUT_DIR)


[INFO] Using dataset directory: C:\Users\sagni\Downloads\PlantDocX\archive\PlantVillage
Found 20638 files belonging to 15 classes.
Using 16511 files for training.
Found 20638 files belonging to 15 classes.
Using 4127 files for validation.
[INFO] Classes (15): ['Pepper__bell___Bacterial_spot', 'Pepper__bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy', 'Tomato_Bacterial_spot', 'Tomato_Early_blight', 'Tomato_Late_blight', 'Tomato_Leaf_Mold', 'Tomato_Septoria_leaf_spot']...
[INFO] Class counts (train): {14: 1229, 5: 1686, 2: 811, 7: 1568, 3: 812, 9: 1379, 12: 2574, 8: 767, 10: 1344, 11: 1143, 1: 1176, 6: 809, 13: 295, 0: 797, 4: 121}
[INFO] Class weights: {14: 0.8956333062110117, 5: 0.652866745749308, 2: 1.3572544184134814, 7: 0.7019982993197279, 3: 1.3555829228243022, 9: 0.7982112642011119, 12: 0.42763532763532763, 8: 1.4351151673185572, 10: 0.8189980158730159, 11: 0.963021289005541, 1: 0.9359977324263039, 6: 1.3606098063452823, 13: 3.7312994350282485, 

[1m516/516[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1077s[0m 2s/step - accuracy: 0.8036 - loss: 0.6381 - val_accuracy: 0.0000e+00 - val_loss: 8.4210
[INFO] Loaded best model from C:\Users\sagni\Downloads\PlantDocX\cls_model_best.keras




[INFO] Validation accuracy: 0.0000
[INFO] Wrote classification_report.txt
[WARN] Could not save H5 model: cannot pickle 'module' object
[INFO] Saved native Keras -> C:\Users\sagni\Downloads\PlantDocX\cls_model.keras
[INFO] Saved preprocessor.pkl -> C:\Users\sagni\Downloads\PlantDocX\preprocessor.pkl
[INFO] Saved model_config.yaml -> C:\Users\sagni\Downloads\PlantDocX\model_config.yaml
[INFO] Saved metrics.json -> C:\Users\sagni\Downloads\PlantDocX\metrics.json

[DONE] All artifacts saved in: C:\Users\sagni\Downloads\PlantDocX
