# Waste Classification with EfficientNetB0

This notebook builds an end-to-end image classification pipeline for multi-class waste sorting using EfficientNetB0. It covers data acquisition from Google Drive, preprocessing, training with regularization and optimization callbacks, and model evaluation/testing with multiple metrics.


## Workflow Overview

1. **Data acquisition** – Mount Google Drive (or point to a local copy of the dataset) and set dataset paths.
2. **Data preprocessing** – Build TensorFlow datasets with augmentation, batching, caching, and normalization.
3. **Model creation** – Fine-tune EfficientNetB0 with dropout, L2 regularization, Adam optimizer, and learning-rate scheduling. Save the best model as `.h5`.
4. **Model evaluation** – Report loss, accuracy, precision, recall, F1, confusion matrix, and classification report on validation/test sets.
5. **Model testing** – Run single-image inference utilities for manual inspection.


In [None]:
import os
from pathlib import Path
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, regularizers
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.preprocessing import image_dataset_from_directory

from sklearn.metrics import classification_report, confusion_matrix, f1_score, precision_score, recall_score



In [None]:
# Global config
# A single seed is pushed into Python's, NumPy's and TensorFlow's RNGs so
# that shuffling and weight initialisation are repeatable across runs.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
    _seed_fn(SEED)

# Training hyper-parameters shared by the cells below.
IMG_SIZE = (224, 224)
BATCH_SIZE = 32
EPOCHS = 20
LEARNING_RATE = 1e-3
DROPOUT_RATE = 0.2
L2_REG = 1e-4
AUTOTUNE = tf.data.AUTOTUNE

# Plot appearance.
plt.style.use("seaborn-v0_8")
sns.set_context("talk", font_scale=0.9)



## Data Acquisition

The dataset lives on Google Drive with class-specific folders:

- `Cardboard`, `Food Organics`, `Glass`, `Metal`, `Miscellaneous Trash`, `Paper`, `Plastic`, `Textile Trash`, `Vegetation`

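Each folder should contain JPEG/PNG images. Set the `WASTE_DATASET_PATH` environment variable (or edit `DEFAULT_DATASET_PATH` below) to point to the root directory that holds these class folders. The code below treats each sub-folder of that root as one class, so a parent directory containing pre-made `train`, `val`, and `test` splits is not supported without changes.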


In [None]:
# Mount Google Drive when running in Google Colab (safe to skip elsewhere).
# The import doubles as the Colab detection: outside Colab it raises
# ModuleNotFoundError and the local placeholder defaults are used instead.
try:
    from google.colab import drive  # type: ignore
    drive.mount("/content/drive", force_remount=True)
    DEFAULT_DATASET_PATH, DEFAULT_MODEL_DIR = (
        "/content/drive/MyDrive/datasets/waste",
        "/content/drive/MyDrive/experiments/waste",
    )
except ModuleNotFoundError:
    print("Running outside Google Colab. Skipping Drive mount.")
    DEFAULT_DATASET_PATH, DEFAULT_MODEL_DIR = (
        "/path/to/local/waste-dataset",
        "./models",
    )

# Environment variables take precedence over the defaults chosen above.
DATASET_ROOT = Path(os.environ.get("WASTE_DATASET_PATH", DEFAULT_DATASET_PATH))
MODEL_DIR = Path(os.environ.get("WASTE_MODEL_DIR", DEFAULT_MODEL_DIR))
MODEL_PATH = MODEL_DIR / "efficientnetb0_waste_classifier.h5"
# Create the output directory up front so checkpoint writes cannot fail on a
# missing folder.
MODEL_DIR.mkdir(parents=True, exist_ok=True)
DATASET_ROOT


In [None]:
# Fail early with an actionable message when the dataset root is missing.
if not DATASET_ROOT.exists():
    raise FileNotFoundError(
        f"Dataset path {DATASET_ROOT} was not found. Set WASTE_DATASET_PATH env variable or update DEFAULT_DATASET_PATH."
    )

# File suffixes counted as images. Compared case-insensitively so that
# `.JPG` / `.jpeg` files are not silently left out of the count.
_IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}

# One class per sub-folder. Hidden folders (e.g. `.ipynb_checkpoints`, which
# notebook tooling can create) are skipped so they are not counted as classes.
class_names = sorted(
    entry.name
    for entry in DATASET_ROOT.iterdir()
    if entry.is_dir() and not entry.name.startswith(".")
)
print(f"Found {len(class_names)} classes:")
for name in class_names:
    # Only direct children of the class folder are counted.
    num_files = sum(
        1
        for candidate in (DATASET_ROOT / name).iterdir()
        if candidate.is_file() and candidate.suffix.lower() in _IMAGE_SUFFIXES
    )
    print(f" - {name}: ~{num_files} images (jpg/png)")


## Data Processing

We create TensorFlow datasets directly from the directory tree, applying:

- random split (roughly 70% train, 15% validation, 15% test) using a deterministic seed; note that the split is not stratified by class
- on-the-fly augmentation (applied by layers inside the model): random flip, rotation, zoom, and contrast
- batching, caching, shuffling, and prefetching to keep GPUs busy
- EfficientNetB0-specific normalization (built-in preprocessing layer)



In [None]:
VAL_TEST_SPLIT = 0.3  # 70% train, 30% temp (which we split into val/test)

# Arguments shared by both image_dataset_from_directory calls. The same seed
# and validation_split must be used for both subsets so they do not overlap.
base_kwargs = dict(
    directory=DATASET_ROOT,
    labels="inferred",
    label_mode="categorical",
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE,
    interpolation="bilinear",
    shuffle=True,
    seed=SEED,
    color_mode="rgb",
)

train_ds = image_dataset_from_directory(
    **base_kwargs,
    validation_split=VAL_TEST_SPLIT,
    subset="training",
)
val_test_ds = image_dataset_from_directory(
    **base_kwargs,
    validation_split=VAL_TEST_SPLIT,
    subset="validation",
)

# With shuffle=True the dataset is re-shuffled every time it is iterated, so
# take()/skip() on it would select different (and overlapping) images for
# validation and test on each pass. Freeze the order first: cache the held-out
# data and read it through once so the cache is complete before splitting.
# NOTE: this keeps the whole 30% hold-out in memory.
val_test_ds = val_test_ds.cache()
for _ in val_test_ds:
    pass

val_test_cardinality = tf.data.experimental.cardinality(val_test_ds)
val_batches = val_test_cardinality // 2

def split_validation_and_test(dataset, val_batches):
    """Split a batched dataset into (validation, test) by batch count.

    The first ``val_batches`` batches become the validation set and the
    remaining batches become the test set. The two parts are only disjoint
    if ``dataset`` yields its elements in the same order on every iteration.
    """
    val_dataset = dataset.take(val_batches)
    test_dataset = dataset.skip(val_batches)
    return val_dataset, test_dataset

val_ds, test_ds = split_validation_and_test(val_test_ds, val_batches)

print(f"Train batches: {tf.data.experimental.cardinality(train_ds)}")
print(f"Validation batches: {tf.data.experimental.cardinality(val_ds)}")
print(f"Test batches: {tf.data.experimental.cardinality(test_ds)}")


In [None]:
# Random augmentations expressed as Keras layers, stacked in the order they
# are applied: horizontal flip, rotation, zoom, contrast.
data_augmentation = keras.Sequential(name="data_augmentation")
data_augmentation.add(layers.RandomFlip("horizontal"))
data_augmentation.add(layers.RandomRotation(0.1))
data_augmentation.add(layers.RandomZoom(0.1))
data_augmentation.add(layers.RandomContrast(0.1))

# Short alias for the EfficientNet input preprocessing function.
preprocess = keras.applications.efficientnet.preprocess_input

def prepare(ds, training=False):
    """Wrap a dataset with caching, optional shuffling, preprocessing and prefetch.

    Args:
        ds: A ``tf.data.Dataset`` yielding ``(images, labels)`` pairs.
        training: When true, elements are shuffled (buffer of 1000, seeded
            with ``SEED``) after the cache step.

    Returns:
        The wrapped dataset, prefetching with ``AUTOTUNE``.
    """
    def _apply_preprocess(images, labels):
        # Labels pass through untouched; only the images are preprocessed.
        return preprocess(images), labels

    pipeline = ds.cache()
    if training:
        pipeline = pipeline.shuffle(1000, seed=SEED)
    pipeline = pipeline.map(_apply_preprocess, num_parallel_calls=AUTOTUNE)
    return pipeline.prefetch(AUTOTUNE)

# Only the training split is shuffled; validation and test keep their order.
train_ds_ready = prepare(train_ds, training=True)
val_ds_ready, test_ds_ready = prepare(val_ds), prepare(test_ds)



In [None]:
def visualize_batch(dataset, class_names, title="Training batch"):
    """Plot up to 16 images from ``dataset`` in a 4x4 grid with class titles.

    Args:
        dataset: Batched dataset of ``(images, one_hot_labels)``; it is
            re-batched to 16 so exactly one grid is drawn.
        class_names: Sequence mapping a label index to its display name.
        title: Figure super-title.
    """
    images, labels = next(iter(dataset.unbatch().batch(16)))
    plt.figure(figsize=(12, 8))
    for idx in range(min(16, images.shape[0])):
        plt.subplot(4, 4, idx + 1)
        # Images are float pixels in [0, 255]; EfficientNet's preprocess_input
        # leaves them unchanged (per the Keras docs), so there is no [-1, 1]
        # scaling to undo. Convert to uint8 so imshow renders true colours.
        plt.imshow(np.clip(images[idx].numpy(), 0, 255).astype("uint8"))
        label_idx = np.argmax(labels[idx].numpy())
        plt.title(class_names[label_idx], fontsize=8)
        plt.axis("off")
    plt.suptitle(title)
    plt.tight_layout()

visualize_batch(train_ds.take(1), class_names)



## Model Creation – EfficientNetB0 Backbone

We fine-tune EfficientNetB0 pretrained on ImageNet. The network includes:

- data augmentation and EfficientNet preprocessing
- global average pooling + dropout (`0.2`)
- dense classifier with L2 regularization to reduce overfitting
- Adam optimizer (`lr=1e-3`), categorical cross-entropy loss
- callbacks: early stopping, model checkpoint, reduce-on-plateau


In [None]:
def build_model(num_classes: int) -> keras.Model:
    """Build and compile the EfficientNetB0 transfer-learning classifier.

    Pipeline: input -> data augmentation -> EfficientNet preprocessing ->
    frozen ImageNet EfficientNetB0 backbone -> global average pooling ->
    dropout -> L2-regularised softmax classifier.

    Args:
        num_classes: Number of output classes (units in the softmax layer).

    Returns:
        A compiled ``keras.Model`` named ``EfficientNetB0_waste``.
    """
    inputs = layers.Input(shape=IMG_SIZE + (3,), name="input_image")
    x = data_augmentation(inputs)
    x = preprocess(x)

    # Keep the backbone as a single nested layer (named "efficientnetb0" by
    # Keras) so it can be retrieved with model.get_layer(...) for fine-tuning.
    # Building it on `input_tensor=x` would flatten its layers into this model.
    base_model = EfficientNetB0(
        include_top=False,
        weights="imagenet",
        input_shape=IMG_SIZE + (3,),
    )
    base_model.trainable = False  # freeze for warm-up
    # training=False keeps BatchNormalization in inference mode, including
    # after the backbone is unfrozen later.
    x = base_model(x, training=False)

    x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
    # The Dropout layer must be *called* on the tensor; assigning the layer
    # object itself to `x` breaks the graph at the classifier.
    x = layers.Dropout(DROPOUT_RATE, name="dropout")(x)
    outputs = layers.Dense(
        num_classes,
        activation="softmax",
        kernel_regularizer=regularizers.l2(L2_REG),
        name="classifier",
    )(x)

    model = keras.Model(inputs=inputs, outputs=outputs, name="EfficientNetB0_waste")

    optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE)
    model.compile(
        optimizer=optimizer,
        loss="categorical_crossentropy",
        metrics=[
            # The name "accuracy" is what the callbacks monitor as
            # "val_accuracy".
            keras.metrics.CategoricalAccuracy(name="accuracy"),
            keras.metrics.Precision(name="precision"),
            keras.metrics.Recall(name="recall"),
            keras.metrics.TopKCategoricalAccuracy(k=3, name="top3_acc"),
        ],
    )
    return model

model = build_model(len(class_names))
model.summary()



In [None]:
# Persist the full model (not just weights) whenever validation accuracy
# reaches a new maximum.
checkpoint = keras.callbacks.ModelCheckpoint(
    MODEL_PATH,
    monitor="val_accuracy",
    mode="max",
    save_best_only=True,
    save_weights_only=False,
    verbose=1,
)
# Multiply the learning rate by 0.2 after 3 epochs without val_loss
# improvement, never going below 1e-6.
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.2,
    patience=3,
    min_lr=1e-6,
    verbose=1,
)
# Stop after 5 epochs without val_accuracy improvement and roll back to the
# best weights seen.
early_stop = keras.callbacks.EarlyStopping(
    monitor="val_accuracy",
    mode="max",
    patience=5,
    restore_best_weights=True,
    verbose=1,
)

# Warm-up training run.
history = model.fit(
    train_ds_ready,
    validation_data=val_ds_ready,
    epochs=EPOCHS,
    callbacks=[early_stop, reduce_lr, checkpoint],
)



In [None]:
def plot_history(history_obj):
    """Draw train vs. validation curves, one stacked subplot per metric.

    Args:
        history_obj: Object whose ``.history`` dict holds per-epoch values
            for each tracked metric and its ``val_``-prefixed counterpart
            (as returned by ``model.fit``).
    """
    frame = pd.DataFrame(history_obj.history)
    tracked = ("loss", "accuracy", "precision", "recall", "top3_acc")
    _, axes = plt.subplots(len(tracked), 1, figsize=(8, 20))
    for ax, name in zip(axes, tracked):
        ax.plot(frame[name], label=f"train_{name}")
        ax.plot(frame[f"val_{name}"], label=f"val_{name}")
        ax.set_title(name.capitalize())
        ax.legend()
    plt.tight_layout()

plot_history(history)



In [None]:
# Optional fine-tuning: unfreeze top layers for a few more epochs (comment/uncomment as needed)
FINE_TUNE_AT = 200  # unfreeze last 200 layers

# get_layer() raises ValueError (it never returns None) when no layer has the
# requested name, e.g. when the backbone's layers were flattened into `model`.
# In that case fall back to unfreezing the tail of `model` itself.
try:
    base_model = model.get_layer("efficientnetb0")
except ValueError:
    base_model = model

base_model.trainable = True
for layer in base_model.layers[:-FINE_TUNE_AT]:
    layer.trainable = False
# Keep BatchNormalization layers frozen, as the Keras EfficientNet fine-tuning
# guide recommends, so their statistics are not disturbed.
for layer in base_model.layers[-FINE_TUNE_AT:]:
    if isinstance(layer, layers.BatchNormalization):
        layer.trainable = False

# Re-compile so the trainable changes take effect. Fresh metric objects are
# created here: `model.metrics` also contains the internal loss tracker and
# already-used stateful metrics, which must not be passed back to compile().
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE * 0.1),
    loss="categorical_crossentropy",
    metrics=[
        keras.metrics.CategoricalAccuracy(name="accuracy"),
        keras.metrics.Precision(name="precision"),
        keras.metrics.Recall(name="recall"),
        keras.metrics.TopKCategoricalAccuracy(k=3, name="top3_acc"),
    ],
)

fine_tune_history = model.fit(
    train_ds_ready,
    epochs=EPOCHS,
    validation_data=val_ds_ready,
    callbacks=[early_stop, reduce_lr, checkpoint],
)



## Model Testing & Evaluation

We reload the best `.h5` checkpoint and evaluate on the reserved validation/test splits. Metrics reported:

- categorical accuracy & loss (from Keras evaluation)
- precision, recall, F1 score (micro-averaged)
- per-class precision/recall/F1 via `classification_report`
- confusion matrix heatmap for detailed error analysis


In [None]:
# Reload the best checkpoint written during training and score both held-out
# splits.
best_model = keras.models.load_model(MODEL_PATH)

# return_dict=True lets Keras pair every value with its own metric name.
# Zipping `metrics_names` with the returned list is fragile: in some Keras
# versions that attribute does not list one entry per returned value, which
# would mislabel or silently drop metrics.
val_results = best_model.evaluate(val_ds_ready, verbose=0, return_dict=True)
test_results = best_model.evaluate(test_ds_ready, verbose=0, return_dict=True)

# Kept for later cells that may still expect the list-based variables.
metric_names = list(val_results)
val_metrics = [val_results[name] for name in metric_names]
test_metrics = [test_results[name] for name in metric_names]

print("Validation metrics:")
for k, v in val_results.items():
    print(f" - {k}: {v:.4f}")

print("\nTest metrics:")
for k, v in test_results.items():
    print(f" - {k}: {v:.4f}")


In [None]:
def collect_predictions(model, dataset):
    """Run ``model`` over every batch of ``dataset`` and stack the results.

    Args:
        model: Object exposing ``predict(batch, verbose=0)``.
        dataset: Iterable of ``(batch_images, batch_labels)`` pairs where
            ``batch_labels`` has a ``.numpy()`` method.

    Returns:
        Tuple ``(y_true, y_pred)`` of arrays built by vertically stacking the
        per-batch labels and predictions.
    """
    truth_chunks, pred_chunks = [], []
    for images, labels in dataset:
        truth_chunks.append(labels.numpy())
        pred_chunks.append(model.predict(images, verbose=0))
    return np.vstack(truth_chunks), np.vstack(pred_chunks)

# Gather one-hot targets and softmax outputs for the whole test split, then
# reduce both to integer class ids.
y_true, y_pred = collect_predictions(best_model, test_ds_ready)
y_true_labels = np.argmax(y_true, axis=1)
y_pred_labels = np.argmax(y_pred, axis=1)

# For single-label multi-class data, micro-averaged precision, recall and F1
# all coincide with plain accuracy; per-class detail is in the report below.
precision_micro = precision_score(y_true_labels, y_pred_labels, average="micro")
recall_micro = recall_score(y_true_labels, y_pred_labels, average="micro")
f1_micro = f1_score(y_true_labels, y_pred_labels, average="micro")

print(f"Micro Precision: {precision_micro:.4f}")
print(f"Micro Recall: {recall_micro:.4f}")
print(f"Micro F1: {f1_micro:.4f}")

# Pass the full label set explicitly: if a class never shows up in the test
# split (neither as truth nor prediction), the inferred labels would not match
# `target_names` and the report would fail. zero_division=0 reports 0 for
# such classes instead of emitting warnings.
label_ids = np.arange(len(class_names))

print("\nClassification Report:")
print(
    classification_report(
        y_true_labels,
        y_pred_labels,
        labels=label_ids,
        target_names=class_names,
        zero_division=0,
    )
)



In [None]:
# Fix the label set so the matrix always has one row/column per class, in
# `class_names` order, even when a class is absent from the test split;
# otherwise the tick labels would not line up with the matrix.
cm = confusion_matrix(
    y_true_labels,
    y_pred_labels,
    labels=np.arange(len(class_names)),
)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix – Test Set")
plt.tight_layout()
plt.show()



### Single Image Testing Utility


In [None]:
def predict_single_image(model, image_path: str):
    """Classify one image file, display it, and return the top prediction.

    The image is loaded at ``IMG_SIZE``, turned into a batch of one, passed
    through ``preprocess`` and then through ``model``. The image is shown
    with the predicted class and confidence in the title.

    Args:
        model: Model whose ``predict`` returns one score per class.
        image_path: Path to the image file.

    Returns:
        Tuple ``(class_name, confidence)`` for the highest-scoring class, so
        the result can also be used programmatically.
    """
    img = keras.utils.load_img(image_path, target_size=IMG_SIZE)
    img_array = keras.utils.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)  # add the batch dimension
    img_array = preprocess(img_array)
    # verbose=0: no progress bar for a single-image prediction.
    preds = model.predict(img_array, verbose=0)
    top_idx = int(np.argmax(preds[0]))
    confidence = float(preds[0][top_idx])
    plt.imshow(img)
    plt.axis("off")
    plt.title(f"Pred: {class_names[top_idx]} ({confidence:.2%})")
    plt.show()
    return class_names[top_idx], confidence

# Example usage (update path to any image inside the dataset)
# predict_single_image(best_model, DATASET_ROOT / "Plastic" / "example.jpg")



## Next Steps

- Experiment with different EfficientNet variants (B1/B2) or longer fine-tuning.
- Add automated hyperparameter sweeps (Keras Tuner / Optuna) to optimize dropout/L2.
- Deploy the exported `.h5` model behind FastAPI, or convert it to the SavedModel format for TensorFlow Serving or to TF Lite for mobile robotics bins.

