
# Preprocessing — Static Augmented Dataset Builder
Build a balanced, augmented PlantVillage binary dataset (0=healthy, 1=diseased) and save RGB JPEGs plus a CSV of labels for use with TensorFlow and PyTorch.


In [None]:
# ======================================================================
#   PLANTVILLAGE PREPROCESSING (JPEG) WITH RESUME + TIMING + A+B OPT
# ======================================================================

import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import time
from pathlib import Path
from PIL import Image
import csv, math
import re
from datetime import datetime

# =========================== SETTINGS ================================
# Seed shared by TensorFlow's and NumPy's global random generators.
SEED = 42

# Number of training examples pulled per loop iteration, and the (height, width)
# every image is resized to before it is saved.
BATCH_SIZE = 750
IMAGE_SIZE = (224, 224)

# When True the training loop only times the first batch it reaches and stops.
PROFILE_MODE = False

tf.random.set_seed(SEED)
np.random.seed(SEED)

# Output layout: preprocessed/images/<number>.jpg and preprocessed/labels.csv
OUTPUT_DIR = Path("preprocessed")
IMG_DIR = OUTPUT_DIR.joinpath("images")
CSV_PATH = OUTPUT_DIR.joinpath("labels.csv")

OUTPUT_DIR.mkdir(exist_ok=True)
IMG_DIR.mkdir(exist_ok=True)

# ======================================================================
# RESUME SUPPORT
# ======================================================================
# Collect (image number, file name) for every JPEG already in IMG_DIR. The
# image number is the first run of digits in the file name; JPEGs without any
# digits in their name are ignored.
existing_jpegs = []
for jpeg_path in IMG_DIR.glob("*.jpg"):
    first_digits = re.search(r"(\d+)", jpeg_path.name)
    if first_digits is not None:
        existing_jpegs.append((int(first_digits.group(1)), jpeg_path.name))

# Any numbered JPEG on disk means a previous run got at least that far.
resume_mode = bool(existing_jpegs)

if resume_mode:
    # Continue numbering right after the highest image number found and
    # append to the existing CSV instead of overwriting it.
    existing_jpegs.sort()
    last_num, last_file = existing_jpegs[-1]
    filename_id = last_num + 1
    csv_mode = "a"
    print(f"üîÅ RESUMING from image {filename_id:06d}.jpg (last was {last_file})")
else:
    # Fresh run: numbering starts at 1 and the CSV is (re)created.
    filename_id = 1
    csv_mode = "w"
    print("üÜï STARTING NEW PREPROCESSING RUN ‚Äî no existing numbered JPEGs found")

# ======================================================================
# LOAD DATASET
# ======================================================================
print("Loading PlantVillage...")
data, info = tfds.load("plant_village", with_info=True, as_supervised=True)
full_dataset = data["train"]

# Binary mapping: a class counts as healthy (0) when the part of its name after
# the first "___" is "healthy" (case-insensitive); every other class is
# diseased (1). The array is indexed by the original class id.
label_names = info.features["label"].names
is_diseased = [
    name.split("___", 1)[-1].lower() != "healthy"
    for name in label_names
]
binary_lookup = np.array(is_diseased, dtype=np.int32)
binary_lookup_tf = tf.constant(binary_lookup)

def to_binary(img, lbl):
    """Replace the class id ``lbl`` with its entry in ``binary_lookup_tf``.

    The image is passed through untouched; the returned label is whatever
    ``binary_lookup_tf`` holds at index ``lbl`` (0 = healthy, 1 = diseased).
    """
    binary_lbl = tf.gather(binary_lookup_tf, lbl)
    return img, binary_lbl

full_dataset = full_dataset.map(to_binary, num_parallel_calls=tf.data.AUTOTUNE)

# ======================================================================
# MANUAL 80/10/10 SPLIT
# ======================================================================
# Read the example count from the TFDS split metadata rather than iterating
# (and decoding) the entire dataset just to count it.
total_count = int(info.splits["train"].num_examples)

val_size  = int(total_count * 0.10)
test_size = int(total_count * 0.10)
# Train receives the remainder so the three sizes always sum to total_count.
train_size = total_count - val_size - test_size

# Shuffle once with a buffer covering the whole dataset and a fixed seed.
# reshuffle_each_iteration=False keeps the order identical every time the
# dataset is iterated, which is what makes the take/skip slices below disjoint
# and stable across the several passes made later.
# NOTE(review): a buffer of total_count keeps the full dataset in the shuffle
# buffer; confirm this fits in memory on the target machine.
full_dataset = full_dataset.shuffle(total_count, seed=SEED, reshuffle_each_iteration=False)

val_raw   = full_dataset.take(val_size)
test_raw  = full_dataset.skip(val_size).take(test_size)
train_raw = full_dataset.skip(val_size + test_size)

train_raw = train_raw.prefetch(tf.data.AUTOTUNE)

# ======================================================================
# COUNT HEALTHY/DISEASED IN TRAIN
# ======================================================================
# Count healthy (label 0) and diseased (any other label) train examples.
# Only the labels are batched and converted to NumPy, so images are not copied
# out of TensorFlow one by one just to be thrown away.
H = 0
D = 0
for lbl_batch in tfds.as_numpy(train_raw.map(lambda _img, lbl: lbl).batch(BATCH_SIZE)):
    healthy_here = int(np.count_nonzero(lbl_batch == 0))
    H += healthy_here
    D += len(lbl_batch) - healthy_here

# Number of copies of each healthy image needed to reach at least as many
# healthy as diseased images. Never below 1: a multiplier of 0 (no diseased
# images) would drop every healthy image, and H == 0 would otherwise divide
# by zero even though there is nothing to replicate.
healthy_mult = max(1, math.ceil(D / H)) if H else 1
print(f"Train Healthy={H}, Diseased={D}, multiplier={healthy_mult}")

# ======================================================================
# AUGMENT (A+B OPTIMIZED)
# ======================================================================
def aug_healthy(img):
    """Return a randomly augmented uint8 version of one healthy-class image.

    Applied in this order: random left/right flip, random up/down flip, 0-3
    quarter-turn rotations, saturation jitter in [0.8, 1.25], a hue shift
    drawn from [-0.05, 0.05), brightness jitter of at most 0.12 and contrast
    jitter in [0.8, 1.25]. The result is clipped to [0, 1] before being
    converted back to uint8.
    """
    pixels = tf.image.convert_image_dtype(img, tf.float32)

    # Geometric jitter.
    pixels = tf.image.random_flip_left_right(pixels)
    pixels = tf.image.random_flip_up_down(pixels)
    quarter_turns = tf.random.uniform([], 0, 4, dtype=tf.int32)
    pixels = tf.image.rot90(pixels, k=quarter_turns)

    # Colour jitter.
    pixels = tf.image.random_saturation(pixels, 0.8, 1.25)

    # FAST hue (A+B): draw the delta explicitly and apply adjust_hue directly.
    hue_shift = tf.random.uniform([], -0.05, 0.05)
    pixels = tf.image.adjust_hue(pixels, hue_shift)

    pixels = tf.image.random_brightness(pixels, 0.12)
    pixels = tf.image.random_contrast(pixels, 0.8, 1.25)

    # Brightness/contrast can push values outside [0, 1]; clip before the
    # conversion back to uint8.
    clipped = tf.clip_by_value(pixels, 0, 1)
    return tf.image.convert_image_dtype(clipped, tf.uint8)

def aug_diseased(img):
    """Return a mildly augmented uint8 version of one diseased-class image.

    Applied in this order: random left/right flip, 0-3 quarter-turn rotations,
    contrast jitter in [0.9, 1.1], a hue shift drawn from [-0.03, 0.03) and
    brightness jitter of at most 0.08. The result is clipped to [0, 1] before
    being converted back to uint8.
    """
    pixels = tf.image.convert_image_dtype(img, tf.float32)

    # Geometric jitter.
    pixels = tf.image.random_flip_left_right(pixels)
    quarter_turns = tf.random.uniform([], 0, 4, dtype=tf.int32)
    pixels = tf.image.rot90(pixels, k=quarter_turns)

    # Colour jitter, with narrower ranges than the healthy-class augmentation.
    pixels = tf.image.random_contrast(pixels, 0.9, 1.1)

    hue_shift = tf.random.uniform([], -0.03, 0.03)
    pixels = tf.image.adjust_hue(pixels, hue_shift)

    pixels = tf.image.random_brightness(pixels, 0.08)

    clipped = tf.clip_by_value(pixels, 0, 1)
    return tf.image.convert_image_dtype(clipped, tf.uint8)

# ======================================================================
# BATCH RESIZE
# ======================================================================
def resize_batch(batch_uint8):
    """Resize a batch of images to ``IMAGE_SIZE`` and return it as a NumPy array.

    The batch is converted to float32, resized with ``tf.image.resize``'s
    default interpolation, converted back to uint8 and materialised with
    ``.numpy()``.
    """
    as_float = tf.image.convert_image_dtype(batch_uint8, tf.float32)
    resized = tf.image.resize(as_float, IMAGE_SIZE)
    return tf.image.convert_image_dtype(resized, tf.uint8).numpy()

# ======================================================================
# JPEG SAVE
# ======================================================================
def save_jpeg(arr, filename):
    """Write the image array ``arr`` to ``IMG_DIR / filename`` as a quality-95 JPEG."""
    target_path = IMG_DIR / filename
    picture = Image.fromarray(arr)
    picture.save(target_path, format="JPEG", quality=95)

# ======================================================================
# CSV HEADER
# ======================================================================
# A new run always starts from a fresh CSV containing only the header row.
# A resumed run keeps the rows already on disk, but if the CSV is missing or
# empty the header is (re)written so that later readers such as
# pandas.read_csv do not mistake the first data row for column names.
header_needed = (
    not resume_mode
    or not CSV_PATH.exists()
    or CSV_PATH.stat().st_size == 0
)
if header_needed:
    with open(CSV_PATH, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["filename", "label", "split"])

# ======================================================================
# SKIP COMPLETED BATCHES
# ======================================================================
def _save_group(group_images, label, writer, next_id):
    """Resize a group of images, save each as a numbered JPEG and log it.

    ``group_images`` is a sequence of same-sized images (NumPy arrays and/or
    tensors). Each one is written as ``<next_id:06d>.jpg`` and recorded in the
    CSV as a ``train`` row with ``label``. Returns the next unused image
    number. An empty group is a no-op, since ``np.stack`` raises on an empty
    sequence.
    """
    if len(group_images) == 0:
        return next_id
    for out in resize_batch(np.stack(group_images)):
        fname = f"{next_id:06d}.jpg"
        save_jpeg(out, fname)
        writer.writerow([fname, label, "train"])
        next_id += 1
    return next_id


# Number of numbered JPEGs already on disk from earlier runs.
processed_images = len(existing_jpegs)

# Resume bookkeeping. Each batch produces exactly
# healthy_mult * (#healthy) + (#diseased) images, so completed batches are
# identified by accumulating that exact count rather than a dataset-wide
# average, which can skip one batch too many or too few.
# NOTE(review): this assumes every numbered JPEG in IMG_DIR is a train output
# written with the same batch order and multiplier — confirm for directories
# produced by earlier versions of this script. A batch that was only partly
# written is re-processed in full under new numbers, so its earlier partial
# files remain on disk as duplicates.
locating_resume_point = True
accounted_images = 0
skip_batches = 0

# ======================================================================
# TRAIN LOOP WITH TIMING + A+B OPT
# ======================================================================
batch_idx = 0

with open(CSV_PATH, "a", newline="") as f:
    writer = csv.writer(f)

    for images, labels in train_raw.batch(BATCH_SIZE):

        labels_np = labels.numpy()
        h_idx = np.where(labels_np == 0)[0]
        d_idx = np.where(labels_np == 1)[0]

        if locating_resume_point:
            batch_outputs = healthy_mult * len(h_idx) + len(d_idx)
            if accounted_images + batch_outputs <= processed_images:
                # Every image of this batch is already on disk.
                accounted_images += batch_outputs
                batch_idx += 1
                continue
            locating_resume_point = False
            skip_batches = batch_idx
            print(f"‚è≠ Skipping first {skip_batches} batches")

        # START batch timer
        batch_start = time.time()

        images_np = images.numpy()

        # ================== PROFILING MODE ==================
        if PROFILE_MODE:
            t0 = time.time()

            # Recomputed here so the split itself is what gets timed.
            h_idx = np.where(labels_np == 0)[0]
            d_idx = np.where(labels_np == 1)[0]

            t_split = time.time()

            # aug speed (first 20 healthy images of the batch)
            aug_times = []
            for img in images_np[h_idx][:20]:
                a0 = time.time()
                _ = aug_healthy(img)
                aug_times.append(time.time() - a0)

            # resize speed (first 50 images of the batch)
            r0 = time.time()
            _ = resize_batch(images_np[:50])
            r1 = time.time()

            # jpeg save speed
            s0 = time.time()
            save_jpeg(images_np[0], "profiling_test.jpg")
            s1 = time.time()

            print("\nüîç ===== PROFILING THIS BATCH =====")
            print(f"Split time:       {(t_split - t0):.5f}s")
            print(f"Avg augment time: {np.mean(aug_times):.5f}s")
            print(f"Batch resize:     {(r1 - r0):.5f}s")
            print(f"JPEG save time:   {(s1 - s0):.5f}s")
            print("========================================\n")

            print("‚õî Profiling ON ‚Äî stopping after one batch.")
            break

        # ================== REAL PIPELINE ==================
        # ---- HEALTHY (replicated) ----
        # The first pass keeps the originals; every later pass writes a freshly
        # augmented copy. The fancy-indexed copy is taken once per batch.
        healthy_imgs = images_np[h_idx]
        for rep in range(healthy_mult):
            outs = [aug_healthy(img) if rep > 0 else img for img in healthy_imgs]
            filename_id = _save_group(outs, 0, writer, filename_id)

        # ---- DISEASED ----
        # Each diseased image is written once, augmented with probability 0.5.
        outs = [
            aug_diseased(img) if np.random.rand() > 0.5 else img
            for img in images_np[d_idx]
        ]
        filename_id = _save_group(outs, 1, writer, filename_id)

        # Resume is driven by the JPEGs on disk, so push this batch's rows to
        # the CSV now; otherwise an interrupted run can leave images without
        # labels.
        f.flush()

        # END batch timer
        batch_end = time.time()
        elapsed = batch_end - batch_start
        now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        print(f"‚úÖ Batch {batch_idx} completed in {elapsed:.2f}s at {now}")

        batch_idx += 1

if locating_resume_point:
    # Every batch was already accounted for (or the dataset was empty).
    skip_batches = batch_idx
    print(f"‚è≠ Skipping first {skip_batches} batches")

print("\nTRAIN DONE (profiling or resume).")


üîÅ RESUMING from image 045104.jpg (last was 045103.jpg)
Loading PlantVillage...


2025-11-28 15:29:16.259687: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M3 Pro
2025-11-28 15:29:16.259709: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 18.00 GB
2025-11-28 15:29:16.259712: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 6.66 GB
2025-11-28 15:29:16.259727: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2025-11-28 15:29:16.259737: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
2025-11-28 15:29:20.180028: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2025-11-28 15:29:27.669710: W

Train Healthy=12052, Diseased=31391, multiplier=3
‚è≠ Skipping first 38 batches
‚úÖ Batch 38 completed in 20.75s at 2025-11-28 15:29:52
‚úÖ Batch 39 completed in 35.41s at 2025-11-28 15:30:27
‚úÖ Batch 40 completed in 50.06s at 2025-11-28 15:31:17
‚úÖ Batch 41 completed in 67.66s at 2025-11-28 15:32:25
‚úÖ Batch 42 completed in 81.95s at 2025-11-28 15:33:47


In [None]:
import pandas as pd

# Load the labels CSV written by the preprocessing cell and print a few
# sanity-check breakdowns of it.
df = pd.read_csv(CSV_PATH)

print("===== BASIC INFO =====")
print(df.head())
print(f"\nTotal images: {len(df)}")

# Unique splits
print("\nCounts per split:")
print(df['split'].value_counts())

# Unique labels
print("\nCounts per label:")
label_map = {0: "healthy", 1: "diseased"}
print(df['label'].map(label_map).value_counts())

# Split + label combination (computed once and reused for the summary table)
per_split_label = df.groupby(['split', 'label']).size()
print("\nCounts per split per label:")
print(per_split_label)

# Summary table: one row per split, one column per class. reindex guarantees
# both class columns exist (filled with 0) even when the CSV contains only one
# of the labels, so assigning the readable column names cannot fail on a
# length mismatch.
summary = per_split_label.unstack(fill_value=0).reindex(
    columns=list(label_map), fill_value=0
)
summary.columns = [label_map[code] for code in summary.columns]
summary["total"] = summary.sum(axis=1)

print("\n===== SUMMARY TABLE =====")
print(summary)


===== BASIC INFO =====
     filename  label  split
0  000001.jpg      0  train
1  000002.jpg      0  train
2  000003.jpg      0  train
3  000004.jpg      0  train
4  000005.jpg      0  train

Total images: 23228

Counts per split:
split
train    23228
Name: count, dtype: int64

Counts per label:
label
healthy     12795
diseased    10433
Name: count, dtype: int64

Counts per split per label:
split  label
train  0        12795
       1        10433
dtype: int64

===== SUMMARY TABLE =====
       healthy  diseased  total
split                          
train    12795     10433  23228
