In [10]:
import os
import cv2
import numpy as np
from sklearn.cluster import KMeans
from tqdm import tqdm


def segment_image_kmeans(image_path, n_clusters=3):
    """
    Segment an image into ``n_clusters`` flat colours with K-means.

    Every pixel is treated as one 3-component colour sample. After clustering,
    each pixel is replaced by the centre of the cluster it was assigned to.

    Parameters
    ----------
    image_path : str
        Path of the image, loaded with ``cv2.imread``.
    n_clusters : int, default 3
        Number of colour clusters passed to ``KMeans``.

    Returns
    -------
    numpy.ndarray or None
        ``uint8`` image in BGR channel order with the same shape as the loaded
        image, or ``None`` when ``cv2.imread`` could not load the file.
    """
    img = cv2.imread(image_path)
    if img is None:
        return None
    # Clustering is done on RGB values; the result is converted back to BGR
    # before returning so it can be handed straight to cv2.imwrite.
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # One row per pixel, one column per colour channel.
    pixels = img_rgb.reshape(-1, 3)
    # Fixed random_state keeps the cluster initialisation reproducible.
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    labels = kmeans.fit_predict(pixels)
    # Cluster centres are float means. Round to the nearest integer level
    # rather than truncating, which would bias every colour darker.
    centers = np.rint(kmeans.cluster_centers_).astype(np.uint8)
    # Paint each pixel with its cluster centre and restore the image shape.
    segmented = centers[labels].reshape(img_rgb.shape)
    return cv2.cvtColor(segmented, cv2.COLOR_RGB2BGR)


def process_dataset_kmeans(source_root, dest_root, n_clusters=3):
    """
    Run K-means segmentation over a dataset laid out as
    ``source_root/{train,val,test}/{class_name}/`` and write the results to the
    same relative location under ``dest_root``.

    Parameters
    ----------
    source_root : str
        Root directory containing the split folders. Missing splits are skipped.
    dest_root : str
        Root directory for the segmented images. Class folders are created as
        needed.
    n_clusters : int, default 3
        Number of colour clusters forwarded to ``segment_image_kmeans``.

    Returns
    -------
    None
        Results are written to disk. Images that cannot be read or written are
        skipped and reported in a warning line before the completion message.
    """
    image_exts = ('.png', '.jpg', '.jpeg')
    unreadable = 0
    unwritten = 0
    for split in ('train', 'val', 'test'):
        src_split = os.path.join(source_root, split)
        dst_split = os.path.join(dest_root, split)
        if not os.path.isdir(src_split):
            continue
        # Sorted so the processing order does not depend on the filesystem.
        for class_name in sorted(os.listdir(src_split)):
            src_class = os.path.join(src_split, class_name)
            dst_class = os.path.join(dst_split, class_name)
            if not os.path.isdir(src_class):
                continue
            os.makedirs(dst_class, exist_ok=True)
            # Filter before handing the list to tqdm so the progress total
            # counts only files that will actually be processed.
            img_files = [
                entry for entry in sorted(os.listdir(src_class))
                if entry.lower().endswith(image_exts)
            ]
            for img_file in tqdm(img_files, desc=f"{split}/{class_name}"):
                src_img_path = os.path.join(src_class, img_file)
                segmented = segment_image_kmeans(src_img_path, n_clusters=n_clusters)
                if segmented is None:
                    unreadable += 1
                    continue
                # Same filename, so the output format follows the input extension.
                dst_img_path = os.path.join(dst_class, img_file)
                # cv2.imwrite signals failure through its return value.
                if not cv2.imwrite(dst_img_path, segmented):
                    unwritten += 1
    if unreadable or unwritten:
        print(f"Warning: {unreadable} image(s) could not be read and {unwritten} could not be written.")
    print("Segmentation complete. Segmented images saved under dest_root.")


if __name__ == "__main__":
    # Output location of the K-means segmented dataset.
    dest_root = r"C:\Users\rakti\Downloads\a\Final-project\Datasets\npld_kmeans"
    # Input is the directory the background-removal cell writes to, so that
    # cell has to be run first; splits that do not exist are skipped silently.
    source_root = r"C:\Users\rakti\Downloads\a\Final-project\Datasets\npld_bg_removed"
    process_dataset_kmeans(source_root=source_root, dest_root=dest_root, n_clusters=3)

train/1 1 Apple healthy: 100%|██████████| 351/351 [00:09<00:00, 35.24it/s]
train/1 2 Apple Apple scab: 100%|██████████| 352/352 [00:10<00:00, 34.52it/s]
train/1 3 Apple Black rot: 100%|██████████| 347/347 [00:10<00:00, 31.76it/s]
train/1 4 Apple Cedar apple rust: 100%|██████████| 308/308 [00:08<00:00, 34.56it/s]
train/2 1 Tomato healthy: 100%|██████████| 336/336 [00:10<00:00, 32.84it/s]
train/2 2 Tomato Early blight: 100%|██████████| 336/336 [00:10<00:00, 30.74it/s]
train/2 3 Tomato Late blight: 100%|██████████| 324/324 [00:10<00:00, 30.94it/s]
train/2 4 Tomato Bacterial spot: 100%|██████████| 297/297 [00:09<00:00, 31.34it/s]
train/2 5 Tomato Leaf Mold: 100%|██████████| 329/329 [00:09<00:00, 33.30it/s]
train/2 6 Tomato Septoria leaf spot: 100%|██████████| 305/305 [00:10<00:00, 29.82it/s]
train/2 7 Tomato Spider mites Two-spotted spider mite: 100%|██████████| 304/304 [00:11<00:00, 27.58it/s]
train/2 8 Tomato Tomato mosaic virus: 100%|██████████| 313/313 [00:11<00:00, 26.43it/s]
train/2 

Segmentation complete. Segmented images saved under dest_root.





In [9]:
import os
import cv2
import numpy as np
from tqdm import tqdm


def remove_background_grabcut(image_path, iter_count=5, margin=10):
    """
    Remove the background of an image with OpenCV's GrabCut algorithm.

    GrabCut is initialised with a rectangle inset ``margin`` pixels from every
    border of the image. Pixels GrabCut labels as (probable) background are
    zeroed and made fully transparent; all others keep their colour and get
    an alpha of 255.

    Parameters
    ----------
    image_path : str
        Path of the image, loaded with ``cv2.imread``.
    iter_count : int, default 5
        Number of GrabCut iterations.
    margin : int, default 10
        Requested inset of the initial rectangle in pixels. It is reduced for
        images too small to hold it, so the rectangle is always at least one
        pixel wide and high.

    Returns
    -------
    numpy.ndarray or None
        ``uint8`` array of shape ``(height, width, 4)`` in B, G, R, A channel
        order, or ``None`` when ``cv2.imread`` could not load the file. When no
        inset of at least one pixel fits (image under 3 px on a side, or
        ``margin`` < 1), GrabCut is skipped and the image is returned
        unchanged with a fully opaque alpha channel.
    """
    img = cv2.imread(image_path)
    if img is None:
        return None
    h, w = img.shape[:2]
    # Largest inset that still leaves a rectangle of at least 1x1 pixels.
    # For images of 21 px or more per side this is simply `margin`.
    inset = min(margin, (min(h, w) - 1) // 2)
    if inset < 1:
        # No border left to serve as background: keep every pixel.
        opaque = np.full((h, w), 255, np.uint8)
        return np.dstack((img, opaque))
    # Label mask plus the two model buffers GrabCut fills in.
    mask = np.zeros((h, w), np.uint8)
    bgdModel = np.zeros((1, 65), np.float64)
    fgdModel = np.zeros((1, 65), np.float64)
    # Rectangle is (x, y, width, height).
    rect = (inset, inset, w - 2 * inset, h - 2 * inset)
    cv2.grabCut(img, mask, rect, bgdModel, fgdModel, iterCount=iter_count, mode=cv2.GC_INIT_WITH_RECT)
    # Collapse the four GrabCut labels to a binary keep-mask:
    # definite/probable background -> 0, definite/probable foreground -> 1.
    is_background = (mask == cv2.GC_BGD) | (mask == cv2.GC_PR_BGD)
    keep = np.where(is_background, 0, 1).astype('uint8')
    # Zero the colour of background pixels and build the alpha channel.
    fg = img * keep[:, :, np.newaxis]
    alpha = keep * 255
    return np.dstack((fg, alpha))


def process_remove_background(source_root, dest_root):
    """
    Remove the background from every image of a dataset laid out as
    ``source_root/{train,val,test}/{class_name}/`` and save the results as PNG
    files in the same relative location under ``dest_root``.

    Parameters
    ----------
    source_root : str
        Root directory containing the split folders. Missing splits are skipped.
    dest_root : str
        Root directory for the output images. Class folders are created as
        needed.

    Returns
    -------
    None
        Results are written to disk. Images that cannot be read or written are
        skipped and reported in a warning line before the completion message.
    """
    image_exts = ('.png', '.jpg', '.jpeg')
    unreadable = 0
    unwritten = 0
    for split in ('train', 'val', 'test'):
        src_split = os.path.join(source_root, split)
        dst_split = os.path.join(dest_root, split)
        if not os.path.isdir(src_split):
            continue
        # Sorted so the processing order does not depend on the filesystem.
        for class_name in sorted(os.listdir(src_split)):
            src_class = os.path.join(src_split, class_name)
            dst_class = os.path.join(dst_split, class_name)
            if not os.path.isdir(src_class):
                continue
            os.makedirs(dst_class, exist_ok=True)
            # Filter before handing the list to tqdm so the progress total
            # counts only files that will actually be processed.
            img_files = [
                entry for entry in sorted(os.listdir(src_class))
                if entry.lower().endswith(image_exts)
            ]
            for img_file in tqdm(img_files, desc=f"Removing bg {split}/{class_name}"):
                src_img = os.path.join(src_class, img_file)
                rgba = remove_background_grabcut(src_img)
                if rgba is None:
                    unreadable += 1
                    continue
                # Always written as PNG so the alpha channel is kept.
                name, _ = os.path.splitext(img_file)
                dst_img = os.path.join(dst_class, f"{name}.png")
                # cv2.imwrite signals failure through its return value.
                if not cv2.imwrite(dst_img, rgba):
                    unwritten += 1
    if unreadable or unwritten:
        print(f"Warning: {unreadable} image(s) could not be read and {unwritten} could not be written.")
    print("Background removal complete. Images saved under dest_root.")


if __name__ == '__main__':
    # Output location of the background-removed PNG copies.
    dest_root = r"C:\Users\rakti\Downloads\a\Final-project\Datasets\npld_bg_removed"
    # Input dataset, expected to contain train/val/test split folders.
    source_root = r"C:\Users\rakti\Downloads\a\Final-project\Datasets\New Plant Disease Dataset ttv"
    process_remove_background(source_root=source_root, dest_root=dest_root)

Removing bg train/1 1 Apple healthy: 100%|██████████| 351/351 [01:20<00:00,  4.35it/s]
Removing bg train/1 2 Apple Apple scab: 100%|██████████| 352/352 [01:37<00:00,  3.61it/s]
Removing bg train/1 3 Apple Black rot: 100%|██████████| 347/347 [01:01<00:00,  5.60it/s]
Removing bg train/1 4 Apple Cedar apple rust: 100%|██████████| 308/308 [01:32<00:00,  3.35it/s]
Removing bg train/2 1 Tomato healthy: 100%|██████████| 336/336 [01:24<00:00,  3.99it/s]
Removing bg train/2 2 Tomato Early blight: 100%|██████████| 336/336 [01:36<00:00,  3.50it/s]
Removing bg train/2 3 Tomato Late blight: 100%|██████████| 324/324 [02:02<00:00,  2.64it/s]
Removing bg train/2 4 Tomato Bacterial spot: 100%|██████████| 297/297 [01:04<00:00,  4.59it/s]
Removing bg train/2 5 Tomato Leaf Mold: 100%|██████████| 329/329 [01:36<00:00,  3.40it/s]
Removing bg train/2 6 Tomato Septoria leaf spot: 100%|██████████| 305/305 [01:17<00:00,  3.92it/s]
Removing bg train/2 7 Tomato Spider mites Two-spotted spider mite: 100%|█████████

Background removal complete. Images saved under dest_root.



