# In [None]:
import os
import cv2
import numpy as np
from tqdm import tqdm


def remove_background_grabcut(image_path, iter_count=5, margin=10):
    """Cut the foreground out of an image with GrabCut and return it as BGRA.

    GrabCut is seeded with a rectangle inset ``margin`` pixels from every
    image border (``cv2.GC_INIT_WITH_RECT``): everything outside the rectangle
    is taken as certain background, everything inside as probable foreground.

    Args:
        image_path: Path of the image, loaded with ``cv2.imread``.
        iter_count: Number of GrabCut iterations to run.
        margin: Inset of the seed rectangle in pixels. It is reduced
            automatically when the image is too small to hold the requested
            inset, so the rectangle always has a positive size.

    Returns:
        A ``uint8`` array of shape ``(h, w, 4)`` in OpenCV channel order
        (B, G, R, A): alpha is 255 on foreground pixels and 0 on background
        pixels, whose colour channels are zeroed as well.
        ``None`` if the image could not be read.
        If no inset rectangle fits at all (a side shorter than 3 px, or
        ``margin < 1``) the image is returned unchanged with a fully opaque
        alpha channel.
    """
    img = cv2.imread(image_path)
    if img is None:
        return None

    h, w = img.shape[:2]

    # A fixed inset of 10 gives a rectangle of zero or negative size as soon
    # as a side is <= 20 px, which makes cv2.grabCut raise. Clamp the inset so
    # at least one pixel stays inside the rectangle and at least one pixel of
    # border stays outside it. For larger images this equals `margin`.
    inset = min(margin, (min(h, w) - 1) // 2)
    if inset < 1:
        # No room for both a background border and a foreground region:
        # keep every pixel rather than failing.
        opaque = np.full((h, w), 255, dtype=np.uint8)
        return np.dstack((img, opaque))

    # Buffers GrabCut works on: the label mask plus the two 1x65 model arrays.
    mask = np.zeros((h, w), np.uint8)
    bgd_model = np.zeros((1, 65), np.float64)
    fgd_model = np.zeros((1, 65), np.float64)
    rect = (inset, inset, w - 2 * inset, h - 2 * inset)
    cv2.grabCut(img, mask, rect, bgd_model, fgd_model,
                iterCount=iter_count, mode=cv2.GC_INIT_WITH_RECT)

    # Labels GC_BGD / GC_PR_BGD are background; everything else is foreground.
    is_background = (mask == cv2.GC_BGD) | (mask == cv2.GC_PR_BGD)
    keep = np.where(is_background, 0, 1).astype(np.uint8)

    # Zero the colour of background pixels and append the alpha plane.
    foreground = img * keep[:, :, np.newaxis]
    alpha = keep * np.uint8(255)
    return np.dstack((foreground, alpha))


def process_remove_background(source_root, dest_root, splits=('train', 'val', 'test')):
    """Run background removal over a ``<split>/<class>/<image>`` dataset tree.

    For every split directory found under ``source_root`` and every class
    directory inside it, each ``.png`` / ``.jpg`` / ``.jpeg`` file is passed to
    ``remove_background_grabcut`` and the result is written to the mirrored
    location under ``dest_root`` as ``<name>.png`` (PNG keeps the alpha
    channel). Missing splits and non-directory entries are skipped.

    Images that cannot be read, segmented or written are reported through
    ``tqdm.write`` and counted; they do not abort the run.

    NOTE: two source files that differ only by extension (``a.jpg`` and
    ``a.png``) map to the same output name; files are processed in sorted
    order, so the last one wins deterministically.

    Args:
        source_root: Root directory containing the split folders.
        dest_root: Root directory that receives the processed images.
        splits: Names of the split sub-directories to process.
    """
    image_exts = ('.png', '.jpg', '.jpeg')
    failures = 0

    for split in splits:
        src_split = os.path.join(source_root, split)
        if not os.path.isdir(src_split):
            continue
        dst_split = os.path.join(dest_root, split)

        for class_name in sorted(os.listdir(src_split)):
            src_class = os.path.join(src_split, class_name)
            if not os.path.isdir(src_class):
                continue
            dst_class = os.path.join(dst_split, class_name)
            os.makedirs(dst_class, exist_ok=True)

            # Filter before handing the list to tqdm so the progress total
            # counts images only, not every directory entry.
            img_files = sorted(
                f for f in os.listdir(src_class) if f.lower().endswith(image_exts)
            )
            for img_file in tqdm(img_files, desc=f"Removing bg {split}/{class_name}"):
                src_img = os.path.join(src_class, img_file)
                try:
                    rgba = remove_background_grabcut(src_img)
                except cv2.error as exc:
                    # One bad image must not abort the rest of the batch.
                    tqdm.write(f"GrabCut failed for {src_img}: {exc}")
                    failures += 1
                    continue
                if rgba is None:
                    tqdm.write(f"Could not read {src_img}; skipped.")
                    failures += 1
                    continue

                # Save with same filename but as PNG to keep alpha channel
                name, _ = os.path.splitext(img_file)
                dst_img = os.path.join(dst_class, f"{name}.png")
                # cv2.imwrite signals failure through its return value.
                if not cv2.imwrite(dst_img, rgba):
                    tqdm.write(f"Could not write {dst_img}.")
                    failures += 1

    if failures:
        print(f"{failures} image(s) could not be processed; see messages above.")
    print("Background removal complete. Images saved under dest_root.")


if __name__ == "__main__":
    # Script entry point: hard-coded local dataset locations. Raw strings keep
    # the Windows backslashes literal.
    dest_root = r"C:\Users\rakti\Downloads\a\Final-project\Datasets\npld_bg_removed"
    source_root = r"C:\Users\rakti\Downloads\a\Final-project\Datasets\New Plant Disease Dataset ttv"
    # Mirror the source tree into dest_root with backgrounds removed.
    process_remove_background(source_root, dest_root)