**WEED EXTRACTION USING SAM AND COLOR FILTERING**

This script is used to extract **individual weed instances** from images generated by the Stable Diffusion pipeline, specifically for weed augmentation.

Unlike crops—which are augmented via inpainting—weeds are extracted as standalone cut-out objects and later pasted into new scenes.
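Pasting a cut-out into a new scene is essentially alpha compositing. The sketch below is a minimal NumPy version. It assumes the cut-out is a 4-channel uint8 array whose color channels use the same order as the background (BGR/BGRA if both were loaded with OpenCV), and that the cut-out fits entirely inside the background at the chosen offset. `paste_rgba` is a hypothetical helper for illustration and is not defined elsewhere in this notebook.

```python
import numpy as np


def paste_rgba(background: np.ndarray, cutout: np.ndarray, x: int, y: int) -> np.ndarray:
    """Alpha-composite an h x w x 4 uint8 cut-out onto an H x W x 3 uint8 background.

    (x, y) is the top-left corner of the paste. The cut-out must fit entirely
    inside the background. The input background is not modified.
    """
    h, w = cutout.shape[:2]
    if x < 0 or y < 0 or y + h > background.shape[0] or x + w > background.shape[1]:
        raise ValueError("cut-out does not fit at the requested position")
    out = background.copy()
    region = out[y:y + h, x:x + w].astype(np.float32)
    color = cutout[:, :, :3].astype(np.float32)
    # Alpha in [0, 1] with shape h x w x 1, so it broadcasts over the 3 color channels
    alpha = cutout[:, :, 3:4].astype(np.float32) / 255.0
    blended = alpha * color + (1.0 - alpha) * region
    out[y:y + h, x:x + w] = np.round(blended).astype(np.uint8)
    return out


# Toy usage: a black 4x4 background and a 2x2 cut-out
toy_bg = np.zeros((4, 4, 3), dtype=np.uint8)
toy_weed = np.zeros((2, 2, 4), dtype=np.uint8)
toy_weed[0, 0] = (0, 200, 0, 255)  # fully opaque green pixel
toy_weed[1, 1] = (0, 200, 0, 128)  # roughly half-transparent green pixel
result = paste_rgba(toy_bg, toy_weed, x=1, y=1)
print(result[:, :, 1])  # green channel of the composite
```

Because the blend is weighted by alpha, fully transparent pixels leave the background untouched. Semi-transparent ones mix the two images.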

The following block uses the Segment Anything Model (SAM) to automatically segment objects in a synthetic weed image and to generate a transparent PNG cut-out of the most likely weed instance.

Key steps:
- Loads a synthetic image containing weeds.
- Uses SAM to generate object masks.
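- Filters the masks by the mean RGB color of each region. It keeps those with a green-dominant profile, where the green channel is above 70 and greater than 90% of both the red and the blue channel, because such regions are likely to be plants.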
- Among the green masks, selects the largest one as the most reliable weed region.
- Converts the selected mask into an alpha channel and creates a final RGBA PNG with transparency.
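The color filter and the "largest green mask" rule can be isolated from SAM and tested on toy data. In the sketch below, `is_green_dominant` and `select_plant_mask` are hypothetical helper names. They use the same thresholds as the code cell in this notebook and assume each mask is a dict with a boolean `'segmentation'` array and an `'area'` pixel count, which are two of the keys `SamAutomaticMaskGenerator.generate` returns.

```python
import numpy as np


def is_green_dominant(avg_rgb, min_green=70.0, ratio=0.9):
    """G > min_green, G > ratio * R and G > ratio * B (thresholds as in the SAM cell)."""
    r, g, b = avg_rgb
    return g > min_green and g > ratio * r and g > ratio * b


def select_plant_mask(image_rgb, masks, min_green=70.0, ratio=0.9):
    """Return the 'segmentation' of the largest green-dominant mask, or None if there is none."""
    candidates = []
    for m in masks:
        seg = m["segmentation"]
        if not seg.any():
            continue  # an empty mask has no mean color
        avg_rgb = image_rgb[seg].mean(axis=0)  # mean over the N pixels inside the mask
        if is_green_dominant(avg_rgb, min_green, ratio):
            candidates.append(m)
    if not candidates:
        return None
    return max(candidates, key=lambda m: m["area"])["segmentation"]


# Toy example: left half green foliage, right half brown soil
toy_rgb = np.zeros((4, 4, 3), dtype=np.uint8)
toy_rgb[:, :2] = (20, 150, 30)
toy_rgb[:, 2:] = (120, 80, 40)

left = np.zeros((4, 4), dtype=bool)
left[:, :2] = True
right = ~left
small = np.zeros((4, 4), dtype=bool)
small[0, :2] = True
toy_masks = [{"segmentation": s, "area": int(s.sum())} for s in (right, small, left)]

plant = select_plant_mask(toy_rgb, toy_masks)
print(int(plant.sum()), "weed pixels selected")
```

The brown region is rejected because its green mean (80) is not above 90% of its red mean (120). Both green masks pass the filter, and the larger one wins.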

In [None]:
import numpy as np
import cv2
import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

# SAM ViT-H backbone; the checkpoint file must be available locally
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device)
mask_generator = SamAutomaticMaskGenerator(sam)

image_path = "weed.jpg"
image = cv2.imread(image_path)  # BGR; None if the file cannot be read
if image is None:
    raise FileNotFoundError(f"Could not read {image_path}")
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # SAM expects RGB input

# Mask generation: each mask is a dict with, among others,
# 'segmentation' (boolean H x W array) and 'area' (number of pixels)
masks = mask_generator.generate(image_rgb)
print(f"Found {len(masks)} masks")

# Filter masks to find green plant masks
plant_masks = []
for m in masks:
    seg = m['segmentation']

    # Mean RGB color of the pixels inside the mask
    r_mean, g_mean, b_mean = image_rgb[seg].mean(axis=0)

    # Predominantly green: G above 70 and greater than 90% of both R and B
    if g_mean > 70 and g_mean > r_mean * 0.9 and g_mean > b_mean * 0.9:
        plant_masks.append(m)

if not plant_masks:
    raise RuntimeError("No green-dominant mask found; consider relaxing the color thresholds")

# Use the largest green mask as the weed region (True = weed pixel)
plant_mask = max(plant_masks, key=lambda x: x['area'])['segmentation']
plant_mask_u8 = plant_mask.astype(np.uint8)

# Alpha channel: 255 on the weed, 0 (fully transparent) elsewhere
alpha = plant_mask_u8 * 255
bgr_masked = cv2.bitwise_and(image, image, mask=plant_mask_u8)  # zero out background pixels
b, g, r = cv2.split(bgr_masked)
plant_png = cv2.merge((b, g, r, alpha))
cv2.imwrite("pianta_ritagliata_sam.png", plant_png)  # file name is Italian for "cut-out plant"
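The mask-to-cut-out step can also be written in plain NumPy, which makes it testable without SAM or OpenCV. This is a sketch under stated assumptions rather than the notebook's method. `mask_to_cutout` is a hypothetical helper, and its optional `crop` step, which trims the cut-out to the mask's bounding box so that less transparent padding is carried into the paste stage, is not part of the cell above.

```python
import numpy as np


def mask_to_cutout(image_bgr, mask, crop=True):
    """Build an h x w x 4 BGRA cut-out from an H x W x 3 uint8 image and a boolean mask.

    Pixels outside the mask are zeroed and get alpha 0; pixels inside get alpha 255.
    With crop=True the result is trimmed to the bounding box of the mask.
    """
    mask = mask.astype(bool)
    if not mask.any():
        raise ValueError("mask is empty")
    color = image_bgr * mask[:, :, None]  # bool broadcasts over channels; dtype stays uint8
    alpha = mask.astype(np.uint8) * 255
    cutout = np.dstack([color, alpha])
    if crop:
        rows = np.flatnonzero(mask.any(axis=1))  # indices of rows containing mask pixels
        cols = np.flatnonzero(mask.any(axis=0))  # indices of columns containing mask pixels
        cutout = cutout[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]
    return cutout


# Toy example: a diagonal two-pixel "weed" in a 5x5 image
toy_image = np.full((5, 5, 3), 10, dtype=np.uint8)
toy_mask = np.zeros((5, 5), dtype=bool)
toy_mask[1, 2] = True
toy_mask[2, 3] = True
cutout = mask_to_cutout(toy_image, toy_mask)
print(cutout.shape)
```

Cropping keeps the full alpha information, so the trimmed cut-out composites the same way as the full-size one at the corresponding offset.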

The following block processes weed images (in RGBA format) and extracts their binary segmentation masks from the alpha channel.

Key steps:
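- Loads each PNG image from segmented_weeds/ and checks that it has four channels (OpenCV loads them in BGRA order). Files that cannot be read or that lack an alpha channel are reported and skipped.
- Extracts the alpha channel and converts it into a binary mask (0 = background, 255 = foreground), where every pixel with nonzero alpha counts as foreground.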
- Saves the binary mask in segmented_weeds_masks/.

The resulting masks can serve as ground-truth labels for semantic segmentation training. They also define exactly which pixels are pasted in controlled paste-augmentation workflows.
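When a weed is pasted into a scene for segmentation training, its mask also has to be written into the label map at the same offset as its pixels, so that image and label stay aligned. The sketch below assumes an integer label map with one index per class. The helper `paste_mask_into_labels` and the class index `WEED_CLASS = 2` are hypothetical choices for illustration, not something defined in this notebook.

```python
import numpy as np

WEED_CLASS = 2  # hypothetical class index; use whatever your label scheme defines


def paste_mask_into_labels(labels, mask, x, y, class_id=WEED_CLASS):
    """Write class_id into an H x W label map wherever the h x w 0/255 mask is foreground.

    (x, y) is the top-left corner and must match the offset used when pasting
    the weed pixels. Later pastes overwrite earlier ones, mirroring occlusion.
    The input label map is not modified.
    """
    h, w = mask.shape
    if x < 0 or y < 0 or y + h > labels.shape[0] or x + w > labels.shape[1]:
        raise ValueError("mask does not fit at the requested position")
    out = labels.copy()
    region = out[y:y + h, x:x + w]  # a view into out, so assignment updates out
    region[mask > 0] = class_id
    return out


# Toy example: 4x4 label map (0 = soil, 1 = crop) and a 2x2 weed mask
toy_labels = np.zeros((4, 4), dtype=np.uint8)
toy_labels[0, 0] = 1
toy_weed_mask = np.array([[255, 0], [0, 255]], dtype=np.uint8)
new_labels = paste_mask_into_labels(toy_labels, toy_weed_mask, x=1, y=1)
print(new_labels)
```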

In [None]:
import os
import cv2
import numpy as np

# Directories
directory_weeds = "segmented_weeds"
output_directory = "segmented_weeds_masks"
os.makedirs(output_directory, exist_ok=True)

def get_weed_and_mask(weed_path):
    # Load the image unchanged so that the alpha channel is preserved (OpenCV order: BGRA)
    weed = cv2.imread(weed_path, cv2.IMREAD_UNCHANGED)
    if weed is None:
        raise ValueError(f"{weed_path} could not be read as an image!")
    if weed.ndim != 3 or weed.shape[2] != 4:
        raise ValueError(f"{weed_path} does not have an alpha channel!")

    # Any non-transparent pixel is foreground; stored as 0 or 255 for saving as PNG
    alpha = weed[:, :, 3]
    mask = (alpha > 0).astype(np.uint8) * 255

    return weed, mask

# Process each PNG weed image (other files are skipped)
for weed_file in sorted(os.listdir(directory_weeds)):
    if not weed_file.lower().endswith(".png"):
        continue
    weed_path = os.path.join(directory_weeds, weed_file)
    try:
        weed, mask = get_weed_and_mask(weed_path)
        print(f"✅ Processed {weed_file}: shape {weed.shape}, mask pixels: {np.count_nonzero(mask)}")

        # Save the mask as <name>_mask.png
        base_name = os.path.splitext(weed_file)[0]
        cv2.imwrite(os.path.join(output_directory, f"{base_name}_mask.png"), mask)

    except ValueError as e:
        print(f"⚠️ {e}")
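The thresholding inside `get_weed_and_mask` can be isolated as a NumPy-only function for testing. In this sketch, `alpha_to_mask` and its `threshold` parameter are hypothetical. With `threshold=0` it reproduces the `a > 0` rule used in the notebook, while a higher value discards semi-transparent edge pixels. The difference only matters for cut-outs with soft edges, because those produced by the SAM cell have alpha values of exactly 0 or 255.

```python
import numpy as np


def alpha_to_mask(image, threshold=0):
    """Return a 0/255 uint8 mask from the alpha channel of an H x W x 4 array.

    A pixel is foreground when its alpha is strictly greater than `threshold`.
    Raises ValueError when the input has no fourth (alpha) channel.
    """
    if image is None or image.ndim != 3 or image.shape[2] != 4:
        raise ValueError("expected an H x W x 4 image with an alpha channel")
    return np.where(image[:, :, 3] > threshold, 255, 0).astype(np.uint8)


# Toy 1x4 BGRA row with alpha values 0, 1, 128 and 255
toy_bgra = np.zeros((1, 4, 4), dtype=np.uint8)
toy_bgra[0, :, 3] = [0, 1, 128, 255]
print(alpha_to_mask(toy_bgra)[0])                 # rule used in the notebook (a > 0)
print(alpha_to_mask(toy_bgra, threshold=127)[0])  # drops pixels with alpha <= 127
```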