
# Preprocessing — Static Augmented Dataset Builder
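Build a binary PlantVillage dataset (0 = healthy, 1 = diseased). The training split is class-balanced by augmenting healthy images; the validation and test splits are saved unaugmented. Every image is written as an RGB PNG alongside CSV label files, so the result can be consumed from both TensorFlow and PyTorch.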


In [6]:
import os
import math
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_datasets as tfds

# Reproducibility: seed NumPy's legacy global RNG and TensorFlow's global seed.
# The unseeded tf.random / tf.image.random_* ops used for augmentation derive
# their seeds from the TensorFlow global seed.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Resize target.
# NOTE(review): IMAGE_SIZE is not referenced anywhere else in this notebook, so
# images are currently written at their original resolution — confirm intent.
IMAGE_SIZE = (224, 224)
DATASET_NAME = "plant_village"

# Output layout: PNGs go to preprocessed/{train,val,test}/, while the
# <split>_labels.csv files are written directly under preprocessed/.
PROJECT_ROOT = Path(".").resolve()
PREPROCESSED_DIR = PROJECT_ROOT / "preprocessed"
PREPROCESSED_DIR.mkdir(parents=True, exist_ok=True)
for split_dir_name in ("train", "val", "test"):
    PREPROCESSED_DIR.joinpath(split_dir_name).mkdir(parents=True, exist_ok=True)

AUTOTUNE = tf.data.AUTOTUNE

print("TensorFlow GPU:", tf.config.list_physical_devices("GPU"))


TensorFlow GPU: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


## 1) Load the dataset and derive binary labels

In [7]:
def load_and_split_plant_village(seed=SEED):
    """Load PlantVillage from TFDS and split it into train/val/test.

    If TFDS returns exactly one split named "train", its examples are shuffled
    once with a fixed seed and carved into 70% train, 15% val and the
    remainder as test. Otherwise the TFDS split dict is returned unchanged.

    Args:
        seed: Seed for the one-time shuffle that decides split membership.

    Returns:
        (data, info): ``data`` maps split name -> ``tf.data.Dataset`` of
        ``(image, label)`` pairs (``as_supervised=True``); ``info`` is the
        TFDS dataset info object.
    """
    data, info = tfds.load(DATASET_NAME, with_info=True, as_supervised=True)

    if len(data) == 1 and "train" in data:
        full = data["train"]

        # Prefer the dataset's known cardinality over decoding every image
        # just to count them; fall back to a full pass only if it is unknown.
        total_size = int(full.cardinality().numpy())
        if total_size < 0:
            total_size = sum(1 for _ in full)

        train_size = int(0.7 * total_size)
        val_size = int(0.15 * total_size)

        # With a fixed seed and reshuffle_each_iteration=False the permutation
        # is the same on every iteration, so take/skip give disjoint splits.
        # NOTE(review): a buffer of total_size keeps every decoded image in
        # memory whenever any split is iterated; consider shuffling indices.
        full = full.shuffle(total_size, seed=seed, reshuffle_each_iteration=False)
        train_ds = full.take(train_size)
        val_ds = full.skip(train_size).take(val_size)
        test_ds = full.skip(train_size + val_size)  # remainder after int() truncation

        data = {"train": train_ds, "val": val_ds, "test": test_ds}

    return data, info


# Materialise the train/val/test splits plus TFDS metadata, then show the metadata.
plant_data, info = load_and_split_plant_village(seed=SEED)

print(info)
def binary_label_lookup(label_names):
    """Map each class name to 0 (healthy) or 1 (diseased), indexed by class id.

    Class names look like ``"<Plant>___<condition>"``. A class is healthy when
    the text after the first ``"___"`` equals ``"healthy"`` (case-insensitive);
    a name without ``"___"`` is judged on the whole string.

    Returns:
        ``np.ndarray`` of ``int32`` with one entry per class name.
    """
    return np.array(
        [0 if name.split("___", 1)[-1].lower() == "healthy" else 1
         for name in label_names],
        dtype=np.int32,
    )


def make_binary_labels(dataset_splits, info):
    """Relabel every split to binary targets and count the original train split.

    Args:
        dataset_splits: Mapping split name -> dataset of ``(image, label)`` with
            integer class ids.
        info: TFDS dataset info; ``info.features["label"].names`` supplies the
            class names.

    Returns:
        (binary, total_healthy, total_diseased): ``binary`` maps each split name
        to the dataset with labels replaced by 0/1; the totals are ``int``
        counts over ``dataset_splits["train"]``.
    """
    binary_lookup = binary_label_lookup(info.features["label"].names)
    lookup_tf = tf.constant(binary_lookup)

    def to_binary(image, label):
        # Gather the 0/1 target for the integer class id.
        return image, tf.gather(lookup_tf, tf.cast(label, tf.int32))

    binary = {
        split: ds.map(to_binary, num_parallel_calls=AUTOTUNE)
        for split, ds in dataset_splits.items()
    }

    # Count the original train split with the SAME lookup that produces the
    # binary labels, so the printed totals cannot disagree with the labels.
    # This is a full pass over the train split.
    binary_counts = np.zeros(2, dtype=np.int64)
    for _, label in tfds.as_numpy(dataset_splits["train"]):
        binary_counts[binary_lookup[int(label)]] += 1

    total_healthy = int(binary_counts[0])
    total_diseased = int(binary_counts[1])

    print(f"ORIGINAL TRAIN — healthy={total_healthy}, diseased={total_diseased}")
    return binary, total_healthy, total_diseased


binary_data, total_healthy, total_diseased = make_binary_labels(
    dataset_splits=plant_data,
    info=info,
)


2025-11-27 19:35:53.793500: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


tfds.core.DatasetInfo(
    name='plant_village',
    full_name='plant_village/1.0.2',
    description="""
    The PlantVillage dataset consists of 54303 healthy and unhealthy leaf images
    divided into 38 categories by species and disease.
    
    NOTE: The original dataset is not available from the original source
    (plantvillage.org), therefore we get the unaugmented dataset from a paper that
    used that dataset and republished it. Moreover, we dropped images with
    Background_without_leaves label, because these were not present in the original
    dataset.
    
    Original paper URL: https://arxiv.org/abs/1511.08060 Dataset URL:
    https://data.mendeley.com/datasets/tywbtsjrjv/1
    """,
    homepage='https://arxiv.org/abs/1511.08060',
    data_dir='/Users/pratyaksh/tensorflow_datasets/plant_village/1.0.2',
    file_format=tfrecord,
    download_size=827.82 MiB,
    dataset_size=815.37 MiB,
    features=FeaturesDict({
        'image': Image(shape=(None, None, 3), dtype=ui

2025-11-27 19:36:02.579860: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [8]:
def augment_healthy(image):
    """Strongly augment one healthy leaf image.

    Works on a float32 copy in [0, 1]. It applies random horizontal and
    vertical flips, a random rotation by a multiple of 90 degrees, and then
    saturation, hue, brightness and contrast jitter, in that order. The
    result is clipped and converted back to uint8.

    Args:
        image: HxWx3 image tensor or array (presumably uint8 from TFDS).

    Returns:
        uint8 ``tf.Tensor``; for non-square inputs an odd number of quarter
        turns swaps height and width.
    """
    pixels = tf.image.convert_image_dtype(image, tf.float32)

    # Geometric: horizontal flip, then vertical flip, then k * 90 degrees.
    pixels = tf.image.random_flip_up_down(tf.image.random_flip_left_right(pixels))
    quarter_turns = tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)
    pixels = tf.image.rot90(pixels, quarter_turns)

    # Photometric jitter (wider ranges than the diseased-class augmentation).
    pixels = tf.image.random_saturation(pixels, 0.7, 1.4)
    pixels = tf.image.random_hue(pixels, 0.05)
    pixels = tf.image.random_brightness(pixels, 0.15)
    pixels = tf.image.random_contrast(pixels, 0.8, 1.3)

    return tf.image.convert_image_dtype(
        tf.clip_by_value(pixels, 0.0, 1.0), tf.uint8
    )


def augment_diseased(image):
    """Mildly augment one diseased leaf image.

    Works on a float32 copy in [0, 1]. It applies a random horizontal flip,
    a random rotation by a multiple of 90 degrees, and then small contrast,
    brightness and hue jitter, in that order. The result is clipped and
    converted back to uint8.

    Args:
        image: HxWx3 image tensor or array (presumably uint8 from TFDS).

    Returns:
        uint8 ``tf.Tensor``.
    """
    pixels = tf.image.convert_image_dtype(image, tf.float32)

    # Geometric (mild): horizontal flip only, then k * 90 degrees.
    pixels = tf.image.random_flip_left_right(pixels)
    quarter_turns = tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)
    pixels = tf.image.rot90(pixels, quarter_turns)

    # Photometric jitter (mild).
    pixels = tf.image.random_contrast(pixels, 0.9, 1.1)
    pixels = tf.image.random_brightness(pixels, 0.08)
    pixels = tf.image.random_hue(pixels, 0.03)

    return tf.image.convert_image_dtype(
        tf.clip_by_value(pixels, 0.0, 1.0), tf.uint8
    )


def augment_diseased_with_replacement(image, label):
    """50% keep original, 50% mildly augment.

    Args:
        image: HxWx3 image; may be a ``tf.Tensor`` or a NumPy array (e.g. from
            ``tfds.as_numpy``).
        label: Returned unchanged.

    Returns:
        ``(image_out, label)``. ``image_out`` is always a ``tf.Tensor``: either
        ``augment_diseased(image)`` or the input converted to a tensor. Callers
        can therefore call ``.numpy()`` on it uniformly.
    """
    # In eager mode tf.cond just calls the selected branch and returns its
    # result as-is; without this conversion the keep-original branch would
    # hand back a raw NumPy array (which has no .numpy()).
    image = tf.convert_to_tensor(image)
    do_aug = tf.random.uniform([]) > 0.5
    # Augment inside the branch so the work is only done when it is used.
    return tf.cond(
        do_aug,
        lambda: (augment_diseased(image), label),
        lambda: (image, label),
    )


In [None]:
def healthy_source_indices(healthy_count, target_count):
    """Indices into the healthy pool used to produce ``target_count`` augmentations.

    The pool is cycled in order (0, 1, ..., n-1, 0, 1, ...) and truncated to
    exactly ``target_count``. This matches making repeated full passes over
    the pool and then clipping.

    Raises:
        ValueError: if ``target_count > 0`` but the pool is empty.
    """
    if target_count <= 0:
        return []
    if healthy_count <= 0:
        raise ValueError("Cannot balance: the training split contains no healthy images.")
    return [i % healthy_count for i in range(target_count)]


def build_balanced_train_dataset(train_ds, total_healthy, total_diseased):
    """Balance the binary train split by oversampling and augmenting healthy images.

    Every healthy output is a strongly augmented copy (originals are not kept),
    cycled until it matches the diseased count. Each diseased image is kept or
    mildly augmented with probability 0.5. The combined list is shuffled with
    NumPy's global RNG.

    Args:
        train_ds: Dataset of ``(image, label)`` with binary labels
            (0 = healthy, 1 = diseased).
        total_healthy, total_diseased: Kept for interface compatibility; not
            used. The counts are recomputed from ``train_ds``.

    Returns:
        (images, labels): shuffled lists of equal length; images are NumPy
        arrays, labels are ints 0/1, and each class has as many entries as
        there are diseased images.

    Raises:
        ValueError: if ``train_ds`` has diseased images but no healthy ones.
    """
    # NOTE(review): the whole train split plus its augmented copies is held in
    # host memory at once; for large splits consider streaming to disk.
    healthy_images = []
    diseased_images = []

    # Step 1: materialise the train split, partitioned by binary label.
    for img, lbl in tfds.as_numpy(train_ds):
        (healthy_images if lbl == 0 else diseased_images).append(img)

    healthy_count = len(healthy_images)
    diseased_count = len(diseased_images)

    print(f"Loaded into memory: healthy={healthy_count}, diseased={diseased_count}")

    # Step 2: augment healthy up to exactly diseased_count. Only the images that
    # are kept get augmented, instead of computing whole passes and truncating.
    source_idx = healthy_source_indices(healthy_count, diseased_count)
    multiplier = math.ceil(diseased_count / healthy_count) if healthy_count else 0
    print(f"Healthy multiplier: {multiplier}x")

    # .numpy() copies each result to a host array, matching the diseased list.
    augmented_healthy = [augment_healthy(healthy_images[i]).numpy() for i in source_idx]

    # Step 3: diseased (mild augmentation w/ replacement).
    diseased_label = tf.constant(1, tf.int32)
    augmented_diseased = []
    for img in diseased_images:
        aug, _ = augment_diseased_with_replacement(img, diseased_label)
        # np.asarray accepts both an EagerTensor and an untouched NumPy array.
        augmented_diseased.append(np.asarray(aug))

    assert len(augmented_healthy) == len(augmented_diseased)

    print(f"FINAL balanced count: {len(augmented_healthy)} healthy, {len(augmented_diseased)} diseased")

    # Combine
    final_images = augmented_healthy + augmented_diseased
    final_labels = [0] * len(augmented_healthy) + [1] * len(augmented_diseased)

    # Shuffle (NumPy global RNG, seeded in the config cell).
    idx = np.arange(len(final_images))
    np.random.shuffle(idx)

    final_images = [final_images[i] for i in idx]
    final_labels = [final_labels[i] for i in idx]

    return final_images, final_labels


final_train_images, final_train_labels = build_balanced_train_dataset(
    train_ds=binary_data["train"],
    total_healthy=total_healthy,
    total_diseased=total_diseased,
)


2025-11-27 19:36:09.524777: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Loaded into memory: healthy=10646, diseased=27366
Healthy multiplier: 3x


In [None]:
def save_image_split(images, labels, split_name):
    """Write images as PNGs plus a label CSV for one split.

    Files are named ``<split_name>_<index:06d>.png`` under
    ``PREPROCESSED_DIR/<split_name>/``. The label file is
    ``PREPROCESSED_DIR/<split_name>_labels.csv`` with columns ``filename`` and
    ``label``.

    Args:
        images: Iterable of HxWxC uint8 images (NumPy arrays or tensors).
        labels: Iterable of integer labels aligned with ``images``; iteration
            stops at the shorter of the two.
        split_name: Sub-directory name and filename prefix, e.g. ``"train"``.
    """
    split_dir = PREPROCESSED_DIR / split_name
    split_dir.mkdir(parents=True, exist_ok=True)

    rows = []
    for i, (img, label) in enumerate(zip(images, labels)):
        fname = f"{split_name}_{i:06d}.png"
        tf.io.write_file(str(split_dir / fname), tf.io.encode_png(img))
        rows.append({"filename": fname, "label": int(label)})

    # Explicit columns keep the CSV header even for an empty split, so readers
    # can always index the "label" column.
    df = pd.DataFrame(rows, columns=["filename", "label"])
    df.to_csv(PREPROCESSED_DIR / f"{split_name}_labels.csv", index=False)

    print(f"Saved {len(rows)} → {split_name}")


# Persist the balanced, augmented training split.
save_image_split(images=final_train_images, labels=final_train_labels, split_name="train")


# Save VAL and TEST directly (no augmentation)
def save_split_from_tfds(ds, split_name):
    """Stream a dataset split to disk as-is (no augmentation).

    Images are pulled one at a time via ``tfds.as_numpy``, so the split is not
    held in memory. Each image is written as ``<split_name>_<index:06d>.png``
    under ``PREPROCESSED_DIR/<split_name>/``. A
    ``PREPROCESSED_DIR/<split_name>_labels.csv`` with columns ``filename`` and
    ``label`` is also written.

    Args:
        ds: Dataset of ``(image, label)`` pairs with uint8 HxWxC images.
        split_name: Sub-directory name and filename prefix, e.g. ``"val"``.
    """
    split_dir = PREPROCESSED_DIR / split_name
    # Don't rely on an earlier cell having created the directory.
    split_dir.mkdir(parents=True, exist_ok=True)

    rows = []
    for i, (img, label) in enumerate(tfds.as_numpy(ds)):
        fname = f"{split_name}_{i:06d}.png"
        tf.io.write_file(str(split_dir / fname), tf.io.encode_png(img))
        rows.append({"filename": fname, "label": int(label)})

    # Explicit columns keep the CSV header even for an empty split.
    df = pd.DataFrame(rows, columns=["filename", "label"])
    df.to_csv(PREPROCESSED_DIR / f"{split_name}_labels.csv", index=False)
    print(f"Saved {len(rows)} → {split_name}")


# Validation and test are written unaugmented, in dataset order.
for eval_split in ("val", "test"):
    save_split_from_tfds(binary_data[eval_split], eval_split)


In [None]:
def summarize(pre_dir=PREPROCESSED_DIR):
    """Tabulate class counts per split from the saved label CSVs.

    Args:
        pre_dir: Directory containing ``train_labels.csv``, ``val_labels.csv``
            and ``test_labels.csv``.

    Returns:
        DataFrame with one row per split (train, val, test) and columns
        ``split``, ``healthy`` (label 0), ``diseased`` (label 1), ``total``.
    """
    def _count_split(split_name):
        labels_df = pd.read_csv(pre_dir / f"{split_name}_labels.csv")
        label_col = labels_df["label"]
        return {
            "split": split_name,
            "healthy": (label_col == 0).sum(),
            "diseased": (label_col == 1).sum(),
            "total": len(labels_df),
        }

    return pd.DataFrame([_count_split(name) for name in ("train", "val", "test")])

# Per-split class balance, read back from the written CSVs.
summary_df = summarize(pre_dir=PREPROCESSED_DIR)
summary_df
