In [1]:
import os
import json
import yaml
import pickle
import random
import datetime
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
from tqdm import tqdm

# ML / DL
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.applications.efficientnet import preprocess_input
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

# Metrics / Plots
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import itertools

# ----------------------------
# CONFIG (Windows raw strings)
# ----------------------------
# Root folder of the resized dataset; every class has its own sub-folder below it.
_DATASET_BASE = r"C:\Users\sagni\Downloads\Plastic Detector\archive\dataset-resized"

# Absolute class folders (Windows paths). Their order here does not matter: the
# label order is derived later by sorting the folder names.
DATA_DIRS = [
    _DATASET_BASE + "\\" + class_folder
    for class_folder in ("plastic", "trash", "paper", "metal", "glass", "cardboard")
]
# Parent directory containing all class folders (…\dataset-resized)
DATA_ROOT = str(Path(DATA_DIRS[0]).parent)

# Every artifact produced by this script is written into OUTPUT_DIR.
OUTPUT_DIR = r"C:\Users\sagni\Downloads\Plastic Detector"
_OUTPUT_PATH = Path(OUTPUT_DIR)
MODEL_H5 = str(_OUTPUT_PATH / "model.h5")
CLASS_PKL = str(_OUTPUT_PATH / "class_indices.pkl")
RUN_YAML = str(_OUTPUT_PATH / "run_config.yaml")
METRICS_JSON = str(_OUTPUT_PATH / "metrics.json")
VAL_PRED_JSON = str(_OUTPUT_PATH / "val_predictions.json")

# Plot / report artifacts
ACC_PNG = str(_OUTPUT_PATH / "accuracy_loss.png")
CM_PNG = str(_OUTPUT_PATH / "confusion_matrix.png")
CR_CSV = str(_OUTPUT_PATH / "classification_report.csv")
CM_CSV = str(_OUTPUT_PATH / "confusion_matrix.csv")

# Training hyper-parameters
IMG_SIZE = (224, 224)
BATCH_SIZE = 32
EPOCHS = 12            # upper bound; EarlyStopping may end training sooner
VAL_SPLIT = 0.2        # fraction of images held out for validation
SEED = 42
LEARNING_RATE = 1e-3   # Adam learning rate for the frozen-backbone phase
AUGMENT = True         # set False for baseline

# ----------------------------
# Reproducibility
# ----------------------------
def set_seed(seed=SEED):
    """Seed Python's ``random``, NumPy and TensorFlow generators with one value.

    Args:
        seed: Integer seed applied, in order, to ``random``, ``np.random`` and
            ``tf.random`` (defaults to the module-level ``SEED``).

    Note:
        Only the RNG seeds are set here; no op-level determinism is enabled.
    """
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)

set_seed(SEED)

# ----------------------------
# Prepare output directory
# ----------------------------
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

# ----------------------------
# Verify folders exist
# ----------------------------
# Alphabetical folder names. This list is passed as `classes=` to
# flow_from_directory below, so it fixes the label -> index mapping.
expected_classes = sorted([Path(p).name for p in DATA_DIRS])
if len(set(expected_classes)) != len(expected_classes):
    raise ValueError(f"Duplicate class folder names in DATA_DIRS: {expected_classes}")

_data_root_path = Path(DATA_ROOT)
if not _data_root_path.is_dir():
    raise FileNotFoundError(f"DATA_ROOT not found: {DATA_ROOT}")
for p in DATA_DIRS:
    class_dir = Path(p)
    # is_dir() (not exists()) so a stray file with a class name is also rejected.
    if not class_dir.is_dir():
        raise FileNotFoundError(f"Class folder missing: {p}")
    # flow_from_directory(DATA_ROOT, classes=expected_classes) looks each class up
    # by name directly under DATA_ROOT, so every class folder must live there.
    if class_dir.parent != _data_root_path:
        raise ValueError(f"Class folder is not directly inside DATA_ROOT ({DATA_ROOT}): {p}")

print("[INFO] Data root:", DATA_ROOT)
print("[INFO] Classes:", expected_classes)

# ----------------------------
# Data generators (Keras)
# ----------------------------
# Random geometric augmentation, used for the training stream only (when AUGMENT).
_augment_kwargs = (
    dict(
        rotation_range=15,
        width_shift_range=0.1,
        height_shift_range=0.1,
        zoom_range=0.1,
        horizontal_flip=True,
        fill_mode="nearest",
    )
    if AUGMENT
    else {}
)

# Both generators share the EfficientNet preprocessing and the same
# validation_split, so subset="training"/"validation" refer to one split.
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    validation_split=VAL_SPLIT,
    **_augment_kwargs,
)
val_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    validation_split=VAL_SPLIT,
)

# Directory structure must be:
# DATA_ROOT/cardboard, glass, metal, paper, plastic, trash
_flow_kwargs = dict(
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    classes=expected_classes,  # lock class order to your folders
    class_mode="categorical",
    seed=SEED,
)
train_gen = train_datagen.flow_from_directory(
    DATA_ROOT, shuffle=True, subset="training", **_flow_kwargs
)
# shuffle=False keeps predictions aligned with val_gen.filenames / val_gen.classes.
val_gen = val_datagen.flow_from_directory(
    DATA_ROOT, shuffle=False, subset="validation", **_flow_kwargs
)

num_classes = len(train_gen.class_indices)
print("[INFO] Class indices:", train_gen.class_indices)

# ----------------------------
# Build model (EfficientNetB0 TL)
# ----------------------------
# Place model construction on the first GPU when one is visible, else on the CPU.
# device_name is also recorded in the metrics JSON and the run-config YAML.
if tf.config.list_physical_devices("GPU"):
    device_name = "/GPU:0"
else:
    device_name = "/CPU:0"

with tf.device(device_name):
    # ImageNet-pretrained EfficientNetB0 without its classifier top, frozen so that
    # only the new head below is trained in this phase.
    base = EfficientNetB0(
        include_top=False,
        input_shape=(*IMG_SIZE, 3),
        weights="imagenet",
    )
    base.trainable = False

    # Head: backbone features -> global average pool -> dropout -> softmax.
    inputs = layers.Input(shape=(*IMG_SIZE, 3))
    # training=False calls the backbone in inference mode (relevant to its
    # BatchNorm layers if the backbone is unfrozen for fine-tuning later).
    x = base(inputs, training=False)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.2)(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    model = models.Model(inputs, outputs)

    opt = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
    model.compile(
        optimizer=opt,
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    model.summary()

# ----------------------------
# Callbacks
# ----------------------------
callbacks = [
    # Stop after 3 epochs without a val_accuracy gain. restore_best_weights asks
    # Keras to roll back to the best epoch. NOTE(review): some older tf.keras
    # versions only do this when training actually stops early — confirm.
    EarlyStopping(
        monitor="val_accuracy",
        patience=3,
        restore_best_weights=True,
    ),
    # Halve the learning rate after 2 epochs without val_loss improvement,
    # never going below 1e-6.
    ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.5,
        patience=2,
        min_lr=1e-6,
        verbose=1,
    ),
    # Write the whole model to MODEL_H5 each time val_accuracy improves.
    ModelCheckpoint(
        MODEL_H5,
        monitor="val_accuracy",
        save_best_only=True,
        verbose=1,
    ),
]

# ----------------------------
# Train
# ----------------------------
history = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=EPOCHS,
    callbacks=callbacks,
    verbose=1
)

# Optional quick fine-tune (uncomment to enable):
# base.trainable = True
# for layer in base.layers[:-40]:
#     layer.trainable = False
# model.compile(optimizer=tf.keras.optimizers.Adam(1e-4),
#               loss="categorical_crossentropy",
#               metrics=["accuracy"])
# history_ft = model.fit(train_gen, validation_data=val_gen, epochs=4, callbacks=callbacks, verbose=1)

# The save below overwrites the best checkpoint that ModelCheckpoint wrote to
# MODEL_H5. Depending on the Keras version, EarlyStopping(restore_best_weights=True)
# may only restore the best weights when training actually stops early. When all
# epochs run, the in-memory model would still hold the last epoch's weights.
# Reapply the best weights explicitly, so the saved file and all later
# evaluation use the best-val_accuracy model.
_early_stop = next((cb for cb in callbacks if isinstance(cb, EarlyStopping)), None)
_best_weights = getattr(_early_stop, "best_weights", None)
if _best_weights is not None:
    model.set_weights(_best_weights)

model.save(MODEL_H5)
print(f"[INFO] Saved model: {MODEL_H5}")

# ----------------------------
# Save class indices (PKL)
# ----------------------------
# Persist the label -> index mapping used in training (e.g. {'cardboard': 0, ...}),
# so predicted indices can later be mapped back to class names.
Path(CLASS_PKL).write_bytes(pickle.dumps(train_gen.class_indices))
print(f"[INFO] Saved class indices: {CLASS_PKL}")

# ----------------------------
# Compute and save metrics JSON
# ----------------------------
_hist = history.history

# Values from the last epoch that actually ran.
final_train_acc = float(_hist["accuracy"][-1])
final_train_loss = float(_hist["loss"][-1])
final_val_acc = float(_hist["val_accuracy"][-1])
final_val_loss = float(_hist["val_loss"][-1])

# ModelCheckpoint (save_best_only) and EarlyStopping both track val_accuracy, so
# the checkpointed model corresponds to the best epoch, which need not be the
# last one. Record it too. np.argmax returns the first index of the maximum.
_best_idx = int(np.argmax(_hist["val_accuracy"]))

metrics_payload = {
    "timestamp": datetime.datetime.now().isoformat(),
    "device": device_name,
    "epochs_run": len(_hist["loss"]),
    "final": {
        "train_accuracy": final_train_acc,
        "train_loss": final_train_loss,
        "val_accuracy": final_val_acc,
        "val_loss": final_val_loss
    },
    "best": {
        "epoch": _best_idx + 1,  # 1-based, like "Epoch i/N" in the Keras log
        "train_accuracy": float(_hist["accuracy"][_best_idx]),
        "train_loss": float(_hist["loss"][_best_idx]),
        "val_accuracy": float(_hist["val_accuracy"][_best_idx]),
        "val_loss": float(_hist["val_loss"][_best_idx])
    },
    "history": {k: [float(v) for v in vals] for k, vals in _hist.items()}
}

with open(METRICS_JSON, "w", encoding="utf-8") as f:
    json.dump(metrics_payload, f, indent=2)
print(f"[INFO] Saved metrics: {METRICS_JSON}")

# ----------------------------
# Validation predictions JSON
# ----------------------------
idx_to_class = {v: k for k, v in train_gen.class_indices.items()}

val_gen.reset()
all_probs = model.predict(val_gen, verbose=1)
top1_idx = np.argmax(all_probs, axis=1)
top1_conf = np.max(all_probs, axis=1)

# zip() below would silently drop rows on a length mismatch; fail loudly instead.
if len(all_probs) != len(val_gen.filenames):
    raise RuntimeError(
        f"Got {len(all_probs)} predictions for {len(val_gen.filenames)} validation files"
    )

# val_gen was created with shuffle=False, so predictions, val_gen.filenames and
# val_gen.classes are in the same order. The ground-truth label is included so
# the file can be scored without re-reading the dataset.
val_records = [
    {
        "file": rel_path.replace("\\", "/"),
        "pred_class": idx_to_class[int(pred_i)],
        "confidence": float(conf),
        "true_class": idx_to_class[int(true_i)],
        "correct": bool(int(pred_i) == int(true_i)),
    }
    for rel_path, true_i, pred_i, conf in zip(
        val_gen.filenames, val_gen.classes, top1_idx, top1_conf
    )
]

with open(VAL_PRED_JSON, "w", encoding="utf-8") as f:
    json.dump(val_records, f, indent=2)
print(f"[INFO] Saved validation predictions: {VAL_PRED_JSON}")

# ----------------------------
# PLOTS: Accuracy/Loss & Confusion Matrix
# ----------------------------
# 1) Accuracy & Loss curves (combined + separate)
def _plot_history_pair(ax, hist, key, *, train_label, val_label, ylabel, title, legend_loc):
    """Draw the training and validation curves of one metric onto ``ax``.

    Args:
        ax: Matplotlib Axes to draw on.
        hist: ``History.history`` dict; must contain ``key`` and ``"val_" + key``.
        key: Metric name, e.g. ``"accuracy"`` or ``"loss"``.
        train_label: Legend label for the training curve.
        val_label: Legend label for the validation curve.
        ylabel: Y-axis label.
        title: Axes title.
        legend_loc: Matplotlib legend location string.
    """
    # 1-based x-axis so the points line up with "Epoch i/N" in the Keras log.
    epochs = range(1, len(hist[key]) + 1)
    ax.plot(epochs, hist[key], label=train_label)
    ax.plot(epochs, hist["val_" + key], label=val_label)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(alpha=0.25)
    ax.legend(loc=legend_loc)


def _with_stem_suffix(path, suffix):
    """Return ``path`` with ``suffix`` appended to the file stem (a.png -> a_acc.png).

    Only the file name is changed, unlike ``str.replace``, which would also
    rewrite any ".png" occurring in a directory name.
    """
    p = Path(path)
    return str(p.with_name(f"{p.stem}{suffix}{p.suffix}"))


_hist = history.history
_ACC_SPEC = dict(key="accuracy", train_label="Train Acc", val_label="Val Acc",
                 ylabel="Accuracy", legend_loc="lower right")
_LOSS_SPEC = dict(key="loss", train_label="Train Loss", val_label="Val Loss",
                  ylabel="Loss", legend_loc="upper right")

# Separate accuracy and loss figures (<ACC_PNG stem>_acc.png / _loss.png)
for _spec, _title, _suffix in (
    (_ACC_SPEC, "Training vs Validation Accuracy", "_acc"),
    (_LOSS_SPEC, "Training vs Validation Loss", "_loss"),
):
    fig, ax = plt.subplots(figsize=(9, 6))
    _plot_history_pair(ax, _hist, title=_title, **_spec)
    fig.tight_layout()
    fig.savefig(_with_stem_suffix(ACC_PNG, _suffix), dpi=200)
    plt.close(fig)

# Combined canvas: accuracy on top, loss below
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))
_plot_history_pair(ax1, _hist, title="Accuracy", **_ACC_SPEC)
_plot_history_pair(ax2, _hist, title="Loss", **_LOSS_SPEC)
fig.tight_layout()
fig.savefig(ACC_PNG, dpi=200)
plt.close(fig)

print(f"[INFO] Saved accuracy/loss plots: {ACC_PNG} (+ separate _acc/_loss PNGs)")

# 2) Confusion Matrix (counts + normalized heatmap) + reports
y_true = val_gen.classes                            # true indices (order aligned to class_indices)
labels_order = [idx_to_class[i] for i in range(num_classes)]
label_ids = list(range(num_classes))
cm = confusion_matrix(y_true, top1_idx, labels=label_ids)

# Row-normalise (each row sums to 1 over predictions). A class with no
# validation samples gets an all-zero row instead of a divide-by-zero warning
# plus NaNs.
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(
    cm.astype("float"),
    row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals > 0,
)

# Save raw counts and classification report
pd.DataFrame(cm, index=labels_order, columns=labels_order).to_csv(CM_CSV, index=True)
# labels= keeps target_names aligned even if a class never appears in y_true or
# in the predictions. Without it, classification_report raises on a size mismatch.
report = classification_report(
    y_true,
    top1_idx,
    labels=label_ids,
    target_names=labels_order,
    output_dict=True,
    zero_division=0,
)
pd.DataFrame(report).to_csv(CR_CSV)
print(f"[INFO] Saved classification report CSV: {CR_CSV}")
print(f"[INFO] Saved confusion matrix CSV: {CM_CSV}")

# Plot normalized heatmap with annotations (counts + %)
fig, ax = plt.subplots(figsize=(9, 7))
im = ax.imshow(cm_norm, interpolation="nearest", cmap="viridis")
plt.title("Confusion Matrix (Normalized)")
cbar = plt.colorbar(im, fraction=0.046, pad=0.04)
cbar.ax.set_ylabel("Proportion", rotation=90)
tick_marks = np.arange(len(labels_order))
plt.xticks(tick_marks, labels_order, rotation=45, ha="right")
plt.yticks(tick_marks, labels_order)

# Light text on dark cells, dark text on light cells.
thresh = cm_norm.max() / 2.0
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    count = cm[i, j]
    perc = cm_norm[i, j] * 100.0
    txt = f"{count}\n{perc:.1f}%"
    ax.text(j, i, txt,
            ha="center", va="center",
            color="white" if cm_norm[i, j] > thresh else "black",
            fontsize=9)

plt.ylabel("True label")
plt.xlabel("Predicted label")
plt.tight_layout()
plt.savefig(CM_PNG, dpi=220)
plt.close(fig)

print(f"[INFO] Saved confusion matrix heatmap: {CM_PNG}")

# ----------------------------
# Save YAML run config (last)
# ----------------------------
run_cfg = {
    "run": {
        "timestamp": datetime.datetime.now().isoformat(),
        "seed": SEED,
        "device": device_name
    },
    "data": {
        "data_root": DATA_ROOT,
        "class_dirs": DATA_DIRS,
        "classes": expected_classes,
        "val_split": VAL_SPLIT,
        "image_size": list(IMG_SIZE),
        "batch_size": BATCH_SIZE,
        "augment": AUGMENT
    },
    "model": {
        "architecture": "EfficientNetB0",
        "transfer_learning": True,
        # Read from the backbone instead of hard-coded, so this stays accurate
        # if the optional fine-tuning step (which sets base.trainable = True) is enabled.
        "frozen_base": not base.trainable,
        "optimizer": "Adam",
        "learning_rate": LEARNING_RATE,
        "epochs": EPOCHS,                                # configured maximum
        "epochs_run": len(history.history["loss"]),      # actually completed
        "num_classes": num_classes
    },
    "artifacts": {
        "model_h5": MODEL_H5,
        "class_indices_pkl": CLASS_PKL,
        "metrics_json": METRICS_JSON,
        "val_predictions_json": VAL_PRED_JSON,
        "accuracy_loss_png": ACC_PNG,
        "confusion_matrix_png": CM_PNG,
        "classification_report_csv": CR_CSV,
        "confusion_matrix_csv": CM_CSV
    }
}

with open(RUN_YAML, "w", encoding="utf-8") as f:
    yaml.safe_dump(run_cfg, f, sort_keys=False, allow_unicode=True)
print(f"[INFO] Saved run config: {RUN_YAML}")

print("\n[DONE] All artifacts saved to:", OUTPUT_DIR)


[INFO] Data root: C:\Users\sagni\Downloads\Plastic Detector\archive\dataset-resized
[INFO] Classes: ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
Found 2024 images belonging to 6 classes.
Found 503 images belonging to 6 classes.
[INFO] Class indices: {'cardboard': 0, 'glass': 1, 'metal': 2, 'paper': 3, 'plastic': 4, 'trash': 5}


  self._warn_if_super_not_called()


Epoch 1/12
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 593ms/step - accuracy: 0.5149 - loss: 1.3329
Epoch 1: val_accuracy improved from -inf to 0.75149, saving model to C:\Users\sagni\Downloads\Plastic Detector\model.h5




[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m65s[0m 767ms/step - accuracy: 0.5171 - loss: 1.3281 - val_accuracy: 0.7515 - val_loss: 0.7417 - learning_rate: 0.0010
Epoch 2/12
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 578ms/step - accuracy: 0.8090 - loss: 0.5856
Epoch 2: val_accuracy improved from 0.75149 to 0.79523, saving model to C:\Users\sagni\Downloads\Plastic Detector\model.h5




[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m43s[0m 668ms/step - accuracy: 0.8091 - loss: 0.5853 - val_accuracy: 0.7952 - val_loss: 0.5951 - learning_rate: 0.0010
Epoch 3/12
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 593ms/step - accuracy: 0.8416 - loss: 0.4920
Epoch 3: val_accuracy improved from 0.79523 to 0.80318, saving model to C:\Users\sagni\Downloads\Plastic Detector\model.h5




[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m44s[0m 680ms/step - accuracy: 0.8417 - loss: 0.4915 - val_accuracy: 0.8032 - val_loss: 0.5392 - learning_rate: 0.0010
Epoch 4/12
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 580ms/step - accuracy: 0.8663 - loss: 0.3918
Epoch 4: val_accuracy improved from 0.80318 to 0.81909, saving model to C:\Users\sagni\Downloads\Plastic Detector\model.h5




[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m43s[0m 676ms/step - accuracy: 0.8663 - loss: 0.3918 - val_accuracy: 0.8191 - val_loss: 0.5196 - learning_rate: 0.0010
Epoch 5/12
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.8781 - loss: 0.3590
Epoch 5: val_accuracy improved from 0.81909 to 0.82306, saving model to C:\Users\sagni\Downloads\Plastic Detector\model.h5




[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m108s[0m 2s/step - accuracy: 0.8781 - loss: 0.3590 - val_accuracy: 0.8231 - val_loss: 0.4912 - learning_rate: 0.0010
Epoch 6/12
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 584ms/step - accuracy: 0.8807 - loss: 0.3371
Epoch 6: val_accuracy improved from 0.82306 to 0.83499, saving model to C:\Users\sagni\Downloads\Plastic Detector\model.h5




[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m43s[0m 673ms/step - accuracy: 0.8807 - loss: 0.3372 - val_accuracy: 0.8350 - val_loss: 0.4766 - learning_rate: 0.0010
Epoch 7/12
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 606ms/step - accuracy: 0.9061 - loss: 0.2780
Epoch 7: val_accuracy did not improve from 0.83499
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m44s[0m 688ms/step - accuracy: 0.9059 - loss: 0.2785 - val_accuracy: 0.8330 - val_loss: 0.4677 - learning_rate: 0.0010
Epoch 8/12
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 619ms/step - accuracy: 0.9108 - loss: 0.2917
Epoch 8: val_accuracy did not improve from 0.83499
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m45s[0m 707ms/step - accuracy: 0.9108 - loss: 0.2917 - val_accuracy: 0.8270 - val_loss: 0.4720 - learning_rate: 0.0010
Epoch 9/12
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 577ms/step - accuracy: 0.9094 - loss: 0.2835
Epoch 9: v



[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m43s[0m 666ms/step - accuracy: 0.9094 - loss: 0.2834 - val_accuracy: 0.8410 - val_loss: 0.4447 - learning_rate: 0.0010
Epoch 10/12
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 763ms/step - accuracy: 0.9117 - loss: 0.2687
Epoch 10: val_accuracy improved from 0.84095 to 0.86084, saving model to C:\Users\sagni\Downloads\Plastic Detector\model.h5




[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m54s[0m 852ms/step - accuracy: 0.9118 - loss: 0.2684 - val_accuracy: 0.8608 - val_loss: 0.4439 - learning_rate: 0.0010
Epoch 11/12
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 590ms/step - accuracy: 0.9145 - loss: 0.2503
Epoch 11: val_accuracy did not improve from 0.86084
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m43s[0m 673ms/step - accuracy: 0.9146 - loss: 0.2503 - val_accuracy: 0.8569 - val_loss: 0.4304 - learning_rate: 0.0010
Epoch 12/12
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 620ms/step - accuracy: 0.9209 - loss: 0.2471
Epoch 12: val_accuracy did not improve from 0.86084
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m45s[0m 703ms/step - accuracy: 0.9210 - loss: 0.2469 - val_accuracy: 0.8509 - val_loss: 0.4264 - learning_rate: 0.0010




[INFO] Saved model: C:\Users\sagni\Downloads\Plastic Detector\model.h5
[INFO] Saved class indices: C:\Users\sagni\Downloads\Plastic Detector\class_indices.pkl
[INFO] Saved metrics: C:\Users\sagni\Downloads\Plastic Detector\metrics.json
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 633ms/step
[INFO] Saved validation predictions: C:\Users\sagni\Downloads\Plastic Detector\val_predictions.json
[INFO] Saved accuracy/loss plots: C:\Users\sagni\Downloads\Plastic Detector\accuracy_loss.png (+ separate _acc/_loss PNGs)
[INFO] Saved classification report CSV: C:\Users\sagni\Downloads\Plastic Detector\classification_report.csv
[INFO] Saved confusion matrix CSV: C:\Users\sagni\Downloads\Plastic Detector\confusion_matrix.csv
[INFO] Saved confusion matrix heatmap: C:\Users\sagni\Downloads\Plastic Detector\confusion_matrix.png
[INFO] Saved run config: C:\Users\sagni\Downloads\Plastic Detector\run_config.yaml

[DONE] All artifacts saved to: C:\Users\sagni\Downloads\Plastic Detector
