# In [None]:
import os
import cv2
import numpy as np
from sklearn.cluster import KMeans
from tqdm import tqdm


def segment_image_kmeans(image_path, n_clusters=3, random_state=42):
    """Colour-quantise an image by clustering its pixels with K-means.

    Every pixel is replaced by the centre of the cluster it is assigned to, so
    the result contains at most ``n_clusters`` distinct colours.

    Args:
        image_path: Path of the image file, read with ``cv2.imread``.
        n_clusters: Number of K-means clusters (i.e. output colours).
        random_state: Seed passed to ``KMeans`` so results are reproducible.

    Returns:
        The segmented image as a BGR ``uint8`` array with the same shape as the
        decoded input, or ``None`` if OpenCV could not read the file.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        return None
    # Cluster in RGB order; converted back to BGR before returning.
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # One row per pixel, one column per channel.
    pixels = img_rgb.reshape(-1, 3)
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state)
    labels = kmeans.fit_predict(pixels)
    # Cluster centres are float means; round them (a bare uint8 cast would
    # truncate and bias every colour downward) and clamp to the 8-bit range.
    centers = np.clip(np.rint(kmeans.cluster_centers_), 0, 255).astype(np.uint8)
    # Paint each pixel with its cluster's centre colour.
    segmented = centers[labels].reshape(img_rgb.shape)
    # Back to BGR so the result can be written with cv2.imwrite.
    return cv2.cvtColor(segmented, cv2.COLOR_RGB2BGR)


def process_dataset_kmeans(source_root, dest_root, n_clusters=3, *,
                           splits=('train', 'val', 'test'),
                           extensions=('.png', '.jpg', '.jpeg')):
    """Apply K-means colour segmentation to every image in a split/class tree.

    Expects the layout ``source_root/<split>/<class_name>/<image>`` and mirrors
    it under ``dest_root``, writing each segmented image with its original
    filename. Missing split directories and non-directory entries at the class
    level are skipped.

    Args:
        source_root: Root directory containing the split sub-directories.
        dest_root: Root directory the segmented tree is written to.
        n_clusters: Number of K-means clusters passed to
            ``segment_image_kmeans``.
        splits: Names of the split sub-directories to process, in order.
        extensions: Filename suffixes (matched case-insensitively) treated as
            images; a single string is also accepted.

    Images that cannot be read or written do not abort the run; they are
    listed in the summary printed at the end.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    extensions = tuple(ext.lower() for ext in extensions)

    n_saved = 0
    failed = []
    for split in splits:
        src_split = os.path.join(source_root, split)
        if not os.path.isdir(src_split):
            continue
        dst_split = os.path.join(dest_root, split)
        # os.listdir order is arbitrary; sort for deterministic processing.
        for class_name in sorted(os.listdir(src_split)):
            src_class = os.path.join(src_split, class_name)
            if not os.path.isdir(src_class):
                continue
            dst_class = os.path.join(dst_split, class_name)
            os.makedirs(dst_class, exist_ok=True)
            # Filter before tqdm so the progress-bar total counts only images.
            image_files = [name for name in sorted(os.listdir(src_class))
                           if name.lower().endswith(extensions)]
            for img_file in tqdm(image_files, desc=f"{split}/{class_name}"):
                src_img_path = os.path.join(src_class, img_file)
                segmented = segment_image_kmeans(src_img_path, n_clusters=n_clusters)
                if segmented is None:
                    failed.append(f"read: {src_img_path}")
                    continue
                # Keep the original filename so the output mirrors the input.
                dst_img_path = os.path.join(dst_class, img_file)
                # cv2.imwrite returns False when it cannot write the file.
                if cv2.imwrite(dst_img_path, segmented):
                    n_saved += 1
                else:
                    failed.append(f"write: {dst_img_path}")

    print(f"Segmentation complete. {n_saved} segmented image(s) saved under {dest_root}.")
    if failed:
        print(f"{len(failed)} image(s) could not be processed:")
        for entry in failed:
            print(f"  {entry}")


if __name__ == "__main__":
    # Dataset locations on the author's machine; edit these before running
    # anywhere else. The names stay module-level so later cells can reuse them.
    source_root = r"C:\Users\rakti\Downloads\a\Final-project\Datasets\npld_bg_removed"
    dest_root = r"C:\Users\rakti\Downloads\a\Final-project\Datasets\npld_kmeans"
    process_dataset_kmeans(
        source_root=source_root,
        dest_root=dest_root,
        n_clusters=3,
    )