# QC Defect Classifier
Multi-label defect detection model for photo quality assessment.

**Pipeline:**
1. Upload & extract training data zip
2. Load labels and build data pipeline
3. Train MobileNetV2 (frozen head → fine-tune)
4. Evaluate with per-defect threshold tuning
5. Export models (.h5, SavedModel, TFLite) + config JSON

## Imports & GPU Check

In [None]:
import os
import json
import zipfile
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.metrics import (
    classification_report, multilabel_confusion_matrix,
    precision_recall_curve, f1_score,
    precision_score, recall_score
)
import matplotlib.pyplot as plt
import random

# Report the TensorFlow version and the GPU devices TensorFlow can see.
# An empty GPU list means no GPU is visible to TensorFlow in this runtime.
print("TF version:", tf.__version__)
print("GPU:", tf.config.list_physical_devices('GPU'))

## Upload & Extract Data
Expected zip structure:
```
sorted/
  labels.csv
  good/
  bad/
  unreadable/
```

In [None]:
from google.colab import files

print("Upload your training data zip (images/ folder + labels.csv):")
uploaded = files.upload()

# Guard against an empty upload result (e.g. the dialog was dismissed):
# indexing the first key would otherwise fail with an opaque IndexError.
if not uploaded:
    raise RuntimeError(
        "No file was uploaded. Re-run this cell and select your training data zip."
    )
zip_filename = next(iter(uploaded))
print(f"Uploaded: {zip_filename}")

with zipfile.ZipFile(zip_filename, "r") as z:
    z.extractall("training_data")

# Auto-detect structure: labels.csv either at the archive root or under sorted/.
if os.path.isfile("training_data/labels.csv"):
    DATA_DIR = "training_data"
elif os.path.isfile("training_data/sorted/labels.csv"):
    DATA_DIR = "training_data/sorted"
else:
    # Print the extracted directory tree so the user can see what the zip contained.
    for root, dirs, _files in os.walk("training_data"):
        for d in dirs:
            print(os.path.join(root, d))
    raise FileNotFoundError("Could not find labels.csv. Check your zip structure.")

print(f"Data directory: {DATA_DIR}")

## Config

In [None]:
# Side length (pixels) images are resized to; also the model's input height/width.
IMG_SIZE = 224
# Number of samples per batch in the training and validation datasets.
BATCH_SIZE = 32
# Maximum epochs for phase 1 (classification head, base frozen).
EPOCHS_HEAD = 10
# Maximum epochs for phase 2 (fine-tuning).
EPOCHS_FT = 20
# Seed for the NumPy permutation that defines the train/val split.
SEED = 42
# Global decision threshold: the baseline for the macro-F1 comparison and the
# fallback when a defect has no tuned threshold.
DEFAULT_THRESHOLD = 0.5
# How many highest-scoring defects the inference function reports as "top_k_risks".
TOP_K = 3

# Filenames in labels.csv are joined onto IMG_DIR, i.e. they are relative to DATA_DIR.
IMG_DIR = DATA_DIR
LABEL_FILE = os.path.join(DATA_DIR, "labels.csv")

# Label columns read from labels.csv; this order is also the order of the model outputs.
DEFECT_NAMES = [
    "blur", "glare", "shadow", "angle",
    "cropped", "too_far", "too_close", "low_contrast"
]
NUM_DEFECTS = len(DEFECT_NAMES)

## Load Labels & Train/Val Split

In [None]:
df = pd.read_csv(LABEL_FILE)

# Validate the label file up front so problems are reported here with a clear
# message instead of surfacing later as a KeyError, a ZeroDivisionError, a NaN
# loss, or a file-not-found error from inside the tf.data pipeline.
required_columns = ["filename", *DEFECT_NAMES]
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
    raise ValueError(f"{LABEL_FILE} is missing required column(s): {missing_columns}")
if df.empty:
    raise ValueError(f"{LABEL_FILE} contains no rows; there is nothing to train on.")

image_paths = df["filename"].apply(lambda x: os.path.join(IMG_DIR, x)).values
labels = df[DEFECT_NAMES].values.astype("float32")

if np.isnan(labels).any():
    raise ValueError(f"{LABEL_FILE} has empty/NaN values in the defect columns.")

missing_files = [p for p in image_paths if not os.path.isfile(p)]
if missing_files:
    preview = ", ".join(missing_files[:5])
    raise FileNotFoundError(
        f"{len(missing_files)} image(s) listed in {LABEL_FILE} were not found, e.g.: {preview}"
    )

print("Defect distribution:")
for i, name in enumerate(DEFECT_NAMES):
    count = int(labels[:, i].sum())
    pct = count / len(labels) * 100
    print(f"  {name:15s} {count:5d} ({pct:.1f}%)")
print(f"  {'TOTAL IMAGES':15s} {len(labels):5d}")

# Train/val split: seeded random permutation, first 80% train, remaining 20% val.
np.random.seed(SEED)
idx = np.random.permutation(len(image_paths))
split = int(0.8 * len(idx))

train_idx, val_idx = idx[:split], idx[split:]
x_train, y_train = image_paths[train_idx], labels[train_idx]
x_val, y_val = image_paths[val_idx], labels[val_idx]

print(f"\nTrain: {len(x_train)} | Val: {len(x_val)}")

## Data Pipeline

In [None]:
def load_image(path, label):
    """Read the file at `path`, decode it to 3 channels and resize it.

    Returns a `(image, label)` pair: the image resized to IMG_SIZE x IMG_SIZE
    and cast to float32, and `label` passed through unchanged. No pixel
    rescaling happens here.
    """
    raw_bytes = tf.io.read_file(path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    resized = tf.image.resize(decoded, (IMG_SIZE, IMG_SIZE))
    return tf.cast(resized, tf.float32), label

def _decoded_cached_dataset(paths, targets):
    """Build a dataset of (image, label) pairs decoded via load_image, then cached."""
    pairs = tf.data.Dataset.from_tensor_slices((paths, targets))
    return pairs.map(load_image, num_parallel_calls=tf.data.AUTOTUNE).cache()


# Training data is shuffled after the cache; validation data keeps its order.
train_ds = (
    _decoded_cached_dataset(x_train, y_train)
    .shuffle(1000)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

val_ds = (
    _decoded_cached_dataset(x_val, y_val)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

## Preview Training Images

In [None]:
# Show up to 12 images from the first training batch, titled with their active defects.
plt.figure(figsize=(14, 8))
for sample_images, sample_labels in train_ds.take(1):
    n_preview = min(12, len(sample_images))
    for pos in range(n_preview):
        plt.subplot(3, 4, pos + 1)
        plt.imshow(sample_images[pos].numpy().astype("uint8"))
        present = [
            defect
            for col, defect in enumerate(DEFECT_NAMES)
            if sample_labels[pos][col] > 0.5
        ]
        if present:
            plt.title(", ".join(present), fontsize=8)
        else:
            plt.title("no defects", fontsize=8)
        plt.axis("off")
plt.suptitle("Sample Training Images", fontsize=14)
plt.tight_layout()
plt.savefig("preview_training_samples.png", dpi=100)
plt.show()
print("Saved preview_training_samples.png")

## Augmentation & Model

In [None]:
# Random augmentation layers; they are applied inside the model graph (see below).
# NOTE(review): brightness/contrast/zoom/rotation changes may interact with the
# low_contrast / too_far / too_close / angle labels — confirm the labels remain
# valid under these augmentations.
data_augmentation = keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.15),
    layers.RandomZoom(0.1),
    layers.RandomBrightness(0.2),
    layers.RandomContrast(0.2),
], name="data_augmentation")

# ImageNet-pretrained MobileNetV2 without its classification top, frozen for phase 1.
base_model = keras.applications.MobileNetV2(
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
    include_top=False,
    weights="imagenet"
)
base_model.trainable = False

print(f"Base model layers: {len(base_model.layers)}")
print(f"Base model params: {base_model.count_params():,}")

# Graph: augmentation -> MobileNetV2 preprocessing -> base (called with
# training=False) -> global average pooling -> dropout -> Dense(128, relu)
# -> dropout -> one sigmoid output per defect (multi-label).
inputs = keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
x = data_augmentation(inputs)
x = keras.applications.mobilenet_v2.preprocess_input(x)
x = base_model(x, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.3)(x)
x = layers.Dense(128, activation="relu")(x)
x = layers.Dropout(0.2)(x)
outputs = layers.Dense(NUM_DEFECTS, activation="sigmoid")(x)

model = keras.Model(inputs, outputs, name="qc_defect_model")
model.summary()

## Phase 1: Train Head (Base Frozen)

In [None]:
# Phase 1 setup: compile for multi-label training, then fit with early stopping
# and learning-rate reduction, both driven by validation loss.
model.compile(
    loss="binary_crossentropy",
    optimizer=keras.optimizers.Adam(1e-3),
    metrics=[
        keras.metrics.BinaryAccuracy(name="bin_acc"),
        keras.metrics.AUC(name="auc"),
    ],
)

early_stop = keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=3, restore_best_weights=True
)
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss", factor=0.5, patience=2, min_lr=1e-6, verbose=1
)
phase1_callbacks = [early_stop, reduce_lr]

print("Phase 1: Training classification head (base frozen)...")
history1 = model.fit(
    train_ds,
    epochs=EPOCHS_HEAD,
    validation_data=val_ds,
    callbacks=phase1_callbacks,
)

## Phase 2: Fine-Tune Top Layers

In [None]:
# Unfreeze the base model, then re-freeze its first `fine_tune_from` layers so
# that only the layers from index 100 onward are trainable.
base_model.trainable = True
fine_tune_from = 100
for layer in base_model.layers[:fine_tune_from]:
    layer.trainable = False

trainable_count = sum(1 for l in base_model.layers if l.trainable)
print(f"Fine-tuning {trainable_count} of {len(base_model.layers)} base layers")

# Fresh callbacks for phase 2; early stopping is more patient here (5 vs 3 epochs).
early_stop_ft = keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=5,
    restore_best_weights=True,
)

reduce_lr_ft = keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.5,
    patience=2,
    min_lr=1e-6,
    verbose=1,
)

# Re-compile after changing the trainable flags, with a lower learning rate
# (1e-4) than phase 1 (1e-3). Metric names match phase 1 so the two histories
# can be concatenated key by key.
model.compile(
    optimizer=keras.optimizers.Adam(1e-4),
    loss="binary_crossentropy",
    metrics=[
        keras.metrics.BinaryAccuracy(name="bin_acc"),
        keras.metrics.AUC(name="auc"),
    ],
)

print("Phase 2: Fine-tuning top layers...")
history2 = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS_FT,
    callbacks=[early_stop_ft, reduce_lr_ft],
)

## Training History Plots

In [None]:
def combine_histories(h1, h2):
    """Concatenate the per-epoch metric series of two training histories.

    Args:
        h1: Object whose ``history`` attribute maps metric name to a list of
            per-epoch values (first training phase).
        h2: Same structure, for the second training phase.

    Returns:
        dict mapping every metric name found in either history to the phase-1
        values followed by the phase-2 values. A metric recorded in only one
        phase is padded with NaN for the other phase, so every series keeps
        the same epoch alignment instead of raising KeyError or being dropped.
    """
    first, second = h1.history, h2.history
    # Epoch count of each phase = length of its longest recorded series.
    n_first = max((len(values) for values in first.values()), default=0)
    n_second = max((len(values) for values in second.values()), default=0)
    nan = float("nan")

    combined = {}
    # Keep h1's key order, then append keys that only appear in h2.
    for key in [*first, *(k for k in second if k not in first)]:
        head = list(first[key]) if key in first else [nan] * n_first
        tail = list(second[key]) if key in second else [nan] * n_second
        combined[key] = head + tail
    return combined

history = combine_histories(history1, history2)
phase1_epochs = len(history1.history["loss"])

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# One panel per metric: (history key, panel title, y-axis label).
panels = [
    ("bin_acc", "Binary Accuracy", "Accuracy"),
    ("auc", "AUC", "AUC"),
    ("loss", "Loss", "Loss"),
]
for panel_ax, (metric_key, panel_title, y_label) in zip(axes, panels):
    panel_ax.plot(history[metric_key], label="Train")
    panel_ax.plot(history["val_" + metric_key], label="Validation")
    # Dashed line between the last phase-1 epoch and the first fine-tune epoch.
    panel_ax.axvline(
        x=phase1_epochs - 0.5,
        color="gray",
        linestyle="--",
        alpha=0.5,
        label="Fine-tune start",
    )
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel("Epoch")
    panel_ax.set_ylabel(y_label)
    panel_ax.legend()
    panel_ax.grid(True, alpha=0.3)

plt.suptitle("Training History", fontsize=14)
plt.tight_layout()
plt.savefig("training_history.png", dpi=100)
plt.show()
print("Saved training_history.png")

## Evaluation: Collect Predictions

In [None]:
print("=" * 55)
print("EVALUATION")
print("=" * 55)

# val_ds is built without shuffling, so separate passes over it visit the
# samples in the same order and the labels stay aligned with the predictions.
y_true = np.concatenate(
    [batch_labels.numpy() for _, batch_labels in val_ds], axis=0
)

# One predict() call over the whole dataset instead of one call per batch
# inside a Python loop, which adds per-call overhead for no benefit.
y_pred_raw = model.predict(val_ds, verbose=0)

## Per-Defect Threshold Tuning

In [None]:
def find_best_threshold(y_true_col, y_pred_col, default=DEFAULT_THRESHOLD):
    """Pick the score threshold with the highest F1 for one defect column.

    Returns `default` when the column's labels sum to zero (no positives);
    otherwise returns, as a Python float, the threshold from the
    precision-recall curve whose F1 is largest (first one on ties).
    """
    has_positives = y_true_col.sum() != 0
    if not has_positives:
        return default

    prec, rec, candidates = precision_recall_curve(y_true_col, y_pred_col)
    # Drop the last precision/recall entry so both arrays line up with `candidates`.
    prec, rec = prec[:-1], rec[:-1]

    # F1 = 2PR / (P + R), defined as 0 wherever P + R is 0.
    denom = prec + rec
    f1 = np.zeros_like(denom, dtype=float)
    np.divide(2 * (prec * rec), denom, out=f1, where=denom > 0)

    winner = int(np.argmax(f1))
    return float(candidates[winner])

per_defect_thresholds = {}
print("Per-defect optimal thresholds (maximizing F1):")
print("-" * 55)
for i, name in enumerate(DEFECT_NAMES):
    best_t = find_best_threshold(y_true[:, i], y_pred_raw[:, i])
    # Store the exact threshold. Rounding to 3 decimals can lift it above the
    # score it was selected at, so the boundary sample would stop being
    # predicted positive and the applied threshold would no longer be the one
    # that maximized F1. Displays below still format it to 2-3 decimals.
    per_defect_thresholds[name] = float(best_t)
    positives = int(y_true[:, i].sum())
    print(f"  {name:15s}  threshold={best_t:.3f}  (positives={positives})")

# Broadcast one threshold per output column across all validation rows.
threshold_array = np.array([per_defect_thresholds[n] for n in DEFECT_NAMES])
y_pred_tuned = (y_pred_raw >= threshold_array).astype(int)

## Precision-Recall Curves

In [None]:
# One precision-recall panel per defect on a 2x4 grid, filled row by row.
fig, axes = plt.subplots(2, 4, figsize=(18, 9))
for i, name in enumerate(DEFECT_NAMES):
    ax = axes.flat[i]
    truth_col = y_true[:, i]
    score_col = y_pred_raw[:, i]

    if truth_col.sum() > 0:
        precision, recall, _ = precision_recall_curve(truth_col, score_col)
        ax.plot(recall, precision, linewidth=2)

        # Mark the operating point reached with this defect's tuned threshold.
        best_t = per_defect_thresholds[name]
        best_pred = (score_col >= best_t).astype(int)
        bp = precision_score(truth_col, best_pred, zero_division=0)
        br = recall_score(truth_col, best_pred, zero_division=0)
        ax.plot(br, bp, "ro", markersize=8, label=f"t={best_t:.2f}")
        ax.legend(fontsize=9)
    else:
        ax.text(0.5, 0.5, "No positives", ha="center", va="center", transform=ax.transAxes)

    ax.set_title(name, fontsize=11, fontweight="bold")
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_xlim([0, 1.05])
    ax.set_ylim([0, 1.05])
    ax.grid(True, alpha=0.3)

plt.suptitle("Precision-Recall Curves (red dot = tuned threshold)", fontsize=14)
plt.tight_layout()
plt.savefig("precision_recall_curves.png", dpi=100)
plt.show()
print("Saved precision_recall_curves.png")

## Classification Report

In [None]:
# Integer view of the ground truth, shared by the report and both F1 scores.
y_true_int = y_true.astype(int)

print("Classification Report (per-defect tuned thresholds):")
print("-" * 55)
report_text = classification_report(
    y_true_int,
    y_pred_tuned,
    target_names=DEFECT_NAMES,
    zero_division=0,
)
print(report_text)

# Compare macro F1 under the single global threshold against the tuned thresholds.
y_pred_global = (y_pred_raw >= DEFAULT_THRESHOLD).astype(int)
f1_global, f1_tuned = (
    f1_score(y_true_int, candidate, average="macro", zero_division=0)
    for candidate in (y_pred_global, y_pred_tuned)
)
print(f"Macro F1 with global threshold (0.5): {f1_global:.4f}")
print(f"Macro F1 with tuned thresholds:       {f1_tuned:.4f}")
print(f"Improvement:                           {(f1_tuned - f1_global):+.4f}")

## Per-Defect Confusion Matrices

In [None]:
# One 2x2 confusion matrix per defect, computed from the tuned-threshold predictions.
mcm = multilabel_confusion_matrix(y_true.astype(int), y_pred_tuned)
fig, axes = plt.subplots(2, 4, figsize=(18, 9))
binary_ticks = [0, 1]
binary_tick_labels = ["No", "Yes"]
for i, (cm, name) in enumerate(zip(mcm, DEFECT_NAMES)):
    ax = axes.flat[i]
    im = ax.imshow(cm, cmap="Blues", interpolation="nearest")
    t = per_defect_thresholds[name]
    ax.set_title(f"{name} (t={t:.2f})", fontsize=11, fontweight="bold")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Actual")
    ax.set_xticks(binary_ticks)
    ax.set_xticklabels(binary_tick_labels)
    ax.set_yticks(binary_ticks)
    ax.set_yticklabels(binary_tick_labels)

    # White text on cells above half the matrix maximum, black text otherwise.
    half_max = cm.max() / 2
    for row, col in np.ndindex(2, 2):
        cell = cm[row, col]
        text_color = "white" if cell > half_max else "black"
        ax.text(col, row, str(cell),
                ha="center", va="center", fontsize=14,
                color=text_color)

plt.suptitle("Per-Defect Confusion Matrices (tuned thresholds)", fontsize=14)
plt.tight_layout()
plt.savefig("confusion_matrices.png", dpi=100)
plt.show()
print("Saved confusion_matrices.png")

## Sample Predictions

In [None]:
# Plot the first 16 validation images with true (T) and predicted (P) defects.
plt.figure(figsize=(16, 10))
shown = 0
for images, batch_labels in val_ds:
    preds = model.predict(images, verbose=0)
    # Take only as many images from this batch as are still needed to reach 16.
    take = min(len(images), 16 - shown)
    for i in range(take):
        plt.subplot(4, 4, shown + 1)
        plt.imshow(images[i].numpy().astype("uint8"))

        true_defects = [
            name for j, name in enumerate(DEFECT_NAMES)
            if batch_labels[i][j] > 0.5
        ]
        pred_defects = [
            name for j, name in enumerate(DEFECT_NAMES)
            if preds[i][j] >= per_defect_thresholds[name]
        ]

        true_str = ", ".join(true_defects) if true_defects else "none"
        pred_str = ", ".join(pred_defects) if pred_defects else "none"
        is_exact = set(true_defects) == set(pred_defects)
        title_color = "green" if is_exact else "red"

        plt.title(f"T: {true_str}\nP: {pred_str}", fontsize=7,
                  color=title_color)
        plt.axis("off")
        shown += 1
    if shown >= 16:
        break

plt.suptitle("Sample Predictions (green=exact match, red=mismatch)", fontsize=13)
plt.tight_layout()
plt.savefig("sample_predictions.png", dpi=100)
plt.show()
print("Saved sample_predictions.png")

## Save Models

In [None]:
# 1) Keras HDF5 file.
h5_path = "qc_defect_model.h5"
model.save(h5_path)
h5_size = os.path.getsize(h5_path) / (1024 * 1024)
print(f"Saved Keras model: {h5_path} ({h5_size:.1f} MB)")

# 2) TensorFlow SavedModel directory. Keras 3 rejects extension-less paths in
# model.save(), so prefer model.export() when the installed Keras provides it
# and fall back to the legacy save() behaviour otherwise.
saved_model_dir = "qc_defect_model"
if hasattr(model, "export"):
    model.export(saved_model_dir)
else:
    model.save(saved_model_dir)
print(f"Saved TF SavedModel: {saved_model_dir}/")

# 3) TFLite flatbuffer converted from the SavedModel with default optimizations.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

tflite_path = "qc_defect_model.tflite"
with open(tflite_path, "wb") as f:
    f.write(tflite_model)

tflite_size = os.path.getsize(tflite_path) / (1024 * 1024)
print(f"Saved TFLite model: {tflite_path} ({tflite_size:.1f} MB)")
print(f"Size reduction: {h5_size / tflite_size:.1f}x smaller than .h5")

## Save Thresholds & Config

In [None]:
# Per-defect weights used by the inference function's severity score.
severity_by_defect = {
    "blur": 0.9,
    "glare": 0.7,
    "shadow": 0.5,
    "angle": 0.8,
    "cropped": 1.0,
    "too_far": 0.8,
    "too_close": 0.6,
    "low_contrast": 0.4,
}

# Everything inference needs besides the model: names, thresholds, sizes, weights.
model_config = {}
model_config["defect_names"] = DEFECT_NAMES
model_config["per_defect_thresholds"] = per_defect_thresholds
model_config["default_threshold"] = DEFAULT_THRESHOLD
model_config["top_k"] = TOP_K
model_config["img_size"] = IMG_SIZE
model_config["severity_weights"] = severity_by_defect

config_path = "qc_model_config.json"
with open(config_path, "w") as f:
    json.dump(model_config, f, indent=2)
print(f"Saved model config: {config_path}")

## Download Models & Plots

In [None]:
print("Downloading models, config, and plots...")
# Models and config first, then the plots saved by earlier cells.
artifacts = [
    h5_path,
    tflite_path,
    config_path,
    "training_history.png",
    "precision_recall_curves.png",
    "confusion_matrices.png",
    "sample_predictions.png",
    "preview_training_samples.png",
]
for artifact in artifacts:
    files.download(artifact)

## Production Inference Function

In [None]:
def predict_photo_quality(image_path, model, config):
    """
    Production-ready inference for a single photo.

    Args:
        image_path: Path to the image file to score.
        model: Model whose ``predict`` returns one score per defect for a
            batch of images.
        config: Dict with "defect_names", "per_defect_thresholds" and
            "severity_weights"; optional "default_threshold", "top_k" and
            "img_size" fall back to the module-level defaults.

    Raises:
        ValueError: If the model returns a different number of scores than
            there are entries in config["defect_names"].

    Returns a dict structured for Laravel / API / mobile consumption:
    {
        "filename": "photo_001.jpg",
        "overall_quality": "fail",       # "pass" or "fail"
        "severity_score": 0.82,          # 0.0 (clean) to 1.0 (worst)
        "defects_flagged": [             # defects that crossed their threshold
            {"name": "blur", "confidence": 0.91, "threshold": 0.42},
            {"name": "angle", "confidence": 0.73, "threshold": 0.38}
        ],
        "top_k_risks": [                 # top K defects by confidence, even if below threshold
            {"name": "blur", "confidence": 0.91},
            {"name": "angle", "confidence": 0.73},
            {"name": "glare", "confidence": 0.28}
        ],
        "all_scores": {                  # raw sigmoid outputs
            "blur": 0.91, "glare": 0.28
        }
    }
    """
    thresholds = config["per_defect_thresholds"]
    defect_names = config["defect_names"]
    severity_weights = config["severity_weights"]
    # Honor the threshold stored in the config before the module-level default.
    default_threshold = config.get("default_threshold", DEFAULT_THRESHOLD)
    top_k = config.get("top_k", TOP_K)
    img_size = config.get("img_size", IMG_SIZE)

    # Training resized images with tf.image.resize (bilinear by default);
    # load_img defaults to nearest-neighbour, so request bilinear here to keep
    # inference preprocessing close to what the model was trained on.
    img = keras.utils.load_img(
        image_path,
        target_size=(img_size, img_size),
        interpolation="bilinear",
    )
    img_array = np.expand_dims(keras.utils.img_to_array(img), axis=0)
    raw_preds = model.predict(img_array, verbose=0)[0]

    if len(raw_preds) != len(defect_names):
        raise ValueError(
            f"Model returned {len(raw_preds)} scores but config lists "
            f"{len(defect_names)} defect names."
        )

    # Plain Python floats keep the result JSON-serializable.
    scores = [(name, float(raw_preds[i])) for i, name in enumerate(defect_names)]
    all_scores = {name: round(score, 4) for name, score in scores}

    defects_flagged = []
    for name, score in scores:
        t = thresholds.get(name, default_threshold)
        if score >= t:
            defects_flagged.append({
                "name": name,
                "confidence": round(score, 4),
                "threshold": t,
            })
    defects_flagged.sort(key=lambda d: d["confidence"], reverse=True)

    sorted_defects = sorted(scores, key=lambda x: x[1], reverse=True)
    top_k_risks = [
        {"name": name, "confidence": round(conf, 4)}
        for name, conf in sorted_defects[:top_k]
    ]

    # Weighted mean of the scores; defects without a configured weight use 0.5.
    weighted_sum = sum(score * severity_weights.get(name, 0.5) for name, score in scores)
    max_possible = sum(severity_weights.get(name, 0.5) for name in defect_names)
    # Avoid dividing by zero when there are no defects or all weights are 0.
    severity_score = round(weighted_sum / max_possible, 4) if max_possible > 0 else 0.0

    return {
        "filename": os.path.basename(image_path),
        "overall_quality": "fail" if defects_flagged else "pass",
        "severity_score": severity_score,
        "defects_flagged": defects_flagged,
        "top_k_risks": top_k_risks,
        "all_scores": all_scores,
    }

## Quick Inference Test

In [None]:
# Run the production inference function on up to 5 random validation images.
n_samples = min(5, len(x_val))
test_indices = random.sample(range(len(x_val)), n_samples)
banner = "=" * 55
print("Production inference test:")
print(banner)

for sample_idx in test_indices:
    result = predict_photo_quality(x_val[sample_idx], model, model_config)

    print(f"\n  {result['filename']}")
    print(f"    Quality:  {result['overall_quality'].upper()}")
    print(f"    Severity: {result['severity_score']:.2f}")

    flagged = result["defects_flagged"]
    if not flagged:
        print("    Flagged:  none")
    else:
        print("    Flagged:")
        for d in flagged:
            print(f"      - {d['name']:15s} {d['confidence']:.3f}  (t={d['threshold']:.3f})")

    print(f"    Top-{TOP_K} risks:")
    for d in result["top_k_risks"]:
        print(f"      - {d['name']:15s} {d['confidence']:.3f}")

# Full JSON payload for the first sampled image.
print("\n" + banner)
print("Example JSON response (for Laravel API):")
print(banner)
first_path = x_val[test_indices[0]]
sample_result = predict_photo_quality(first_path, model, model_config)
print(json.dumps(sample_result, indent=2))

print("\nDone.")