In [None]:
import rasterio
import numpy as np
import cv2
from rasterio.features import shapes, rasterize
from shapely.geometry import shape
from scipy import ndimage
import matplotlib.pyplot as plt
from pathlib import Path

# --- Configuration ---
# Input mask and the two rasters written by main()
mask_path = "masks.tif"
clean_mask_path = "masks_clean.tif"
boundaries_path = "masks_boundaries.tif"

# Fine-tuned separation parameters for thin boundaries
# Morphological separation: ellipse kernel size, erosion passes, dilation passes
EROSION_KERNEL_SIZE, EROSION_ITERATIONS, DILATION_ITERATIONS = 3, 2, 1
# Connected regions smaller than this many pixels are discarded
MIN_FIELD_AREA = 100
# Canny hysteresis thresholds (low, high) used by detect_thin_boundaries
EDGE_THRESHOLD_LOW, EDGE_THRESHOLD_HIGH = 20, 80

# Advanced separation settings
# True -> watershed separation, False -> erosion/dilation separation
USE_WATERSHED = True
# Fraction of the peak normalized distance a pixel needs to become a seed
WATERSHED_THRESHOLD = 0.3

# Boundary visibility settings
BOUNDARY_THICKNESS = 4  # side (px) of the square kernel used to thicken boundaries
CREATE_BOUNDARY_FILE = True  # Save boundaries as separate file

def detect_thin_boundaries(mask):
    """
    Detect boundaries between touching fields with Canny edge detection

    Parameters:
    -----------
    mask : numpy.ndarray
        Binary mask of fields (any non-zero pixel counts as field)

    Returns:
    --------
    numpy.ndarray : Boolean mask of boundary pixels. The Canny edges are
        passed through a morphological gradient, so the returned band is
        roughly 3 px wide rather than 1 px.
    """
    # Map every non-zero pixel to 255. Multiplying by 255 directly overflows
    # for uint8 masks that already store 255 (255 * 255 wraps to 1), which
    # left Canny with an almost flat image and no edges.
    mask_uint8 = np.where(mask > 0, 255, 0).astype(np.uint8)

    # Apply Gaussian blur to reduce noise
    blurred = cv2.GaussianBlur(mask_uint8, (3, 3), 0)

    # Detect edges with Canny (1-px edges)
    edges = cv2.Canny(blurred, EDGE_THRESHOLD_LOW, EDGE_THRESHOLD_HIGH)

    # Morphological gradient (dilate - erode) with a 3x3 cross. This widens
    # the 1-px Canny edges into a band; it does NOT thin them.
    kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
    edge_band = cv2.morphologyEx(edges, cv2.MORPH_GRADIENT, kernel)

    return edge_band > 0

def fine_morphological_separation(mask):
    """
    Separate touching fields with minimal boundary thickness

    The mask is eroded (which splits regions joined by thin necks) and then
    dilated back, clipped to the original mask so fields never grow.

    Parameters:
    -----------
    mask : numpy.ndarray
        Binary mask with touching fields

    Returns:
    --------
    numpy.ndarray : uint8 mask with separated fields (thin boundaries)
    """
    # Create small structuring element for fine control
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                       (EROSION_KERNEL_SIZE, EROSION_KERNEL_SIZE))

    # Very light erosion to just separate touching regions
    eroded = cv2.erode(mask.astype(np.uint8), kernel, iterations=EROSION_ITERATIONS)

    # Dilation distributes over union, so dilating every connected component
    # on its own and OR-ing the results (the previous per-label loop, which
    # rescanned the whole image once per component) is identical to one
    # dilation of all eroded pixels at once.
    seeds = (eroded > 0).astype(np.uint8)
    dilated = cv2.dilate(seeds, kernel, iterations=DILATION_ITERATIONS)

    # Ensure we don't exceed original mask boundaries
    return np.minimum(dilated, mask).astype(np.uint8)

def skeleton_based_separation(mask):
    """
    Separate touching fields with a marker-based watershed

    Seeds are the pixels whose normalized distance to the background exceeds
    WATERSHED_THRESHOLD times the maximum; cv2.watershed floods outward from
    them and marks the lines where two floods meet with -1.

    Parameters:
    -----------
    mask : numpy.ndarray
        Binary mask with touching fields (any non-zero pixel is field)

    Returns:
    --------
    numpy.ndarray : uint8 mask, 1 = field, 0 = background or watershed line
    """
    binary = (mask > 0).astype(np.uint8)

    # Distance transform to find centers of fields
    dist = cv2.distanceTransform(binary, cv2.DIST_L2, 3)

    # Normalize distance
    dist_norm = cv2.normalize(dist, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

    # Find local maxima (field centers) with configurable threshold
    # Lower threshold = more aggressive separation
    _, peaks = cv2.threshold(dist_norm, WATERSHED_THRESHOLD * dist_norm.max(), 255, cv2.THRESH_BINARY)
    peaks = peaks.astype(np.uint8)

    # Label the peaks
    _, markers = cv2.connectedComponents(peaks)

    # Small dilation of markers. Dilate in float32 (exact for labels < 2**24):
    # a uint8 cast wraps every label above 255, merging unrelated seeds and
    # turning labels 256, 512, ... into "unknown" (0).
    kernel = np.ones((2, 2), np.uint8)
    markers = cv2.dilate(markers.astype(np.float32), kernel, iterations=1)
    markers = markers.astype(np.int32)

    # Mark background (0 = "unknown" for cv2.watershed, so it is flooded too)
    markers[binary == 0] = 0

    # Apply watershed
    mask_3ch = cv2.cvtColor(binary * 255, cv2.COLOR_GRAY2BGR)
    markers = cv2.watershed(mask_3ch, markers)

    # Watershed lines are -1. Background pixels were flooded as "unknown" and
    # now carry field labels, so restrict the result to the input mask;
    # otherwise nearly the whole image would be reported as field.
    separated = ((markers > 0) & (binary > 0)).astype(np.uint8)

    return separated

def adaptive_separation(mask):
    """
    Dispatch to the configured separation method

    USE_WATERSHED selects skeleton_based_separation (watershed); otherwise
    fine_morphological_separation (erosion/dilation) is used. The chosen
    method and its parameters are printed first.

    Parameters:
    -----------
    mask : numpy.ndarray
        Binary mask with touching fields

    Returns:
    --------
    numpy.ndarray : Mask with thin boundaries
    """
    if not USE_WATERSHED:
        # Morphological method (thinner boundaries)
        print(f"  → Using morphological separation (kernel={EROSION_KERNEL_SIZE}, iter={EROSION_ITERATIONS})")
        return fine_morphological_separation(mask)

    # Watershed method (better for complex touching patterns)
    print(f"  → Using watershed separation (threshold={WATERSHED_THRESHOLD})")
    return skeleton_based_separation(mask)

def remove_small_fields(mask, min_area):
    """
    Remove small disconnected regions

    Parameters:
    -----------
    mask : numpy.ndarray
        Binary mask (any non-zero pixel is foreground)
    min_area : int
        Minimum area in pixels; regions with fewer pixels are dropped

    Returns:
    --------
    numpy.ndarray : Mask of the same dtype as ``mask`` with kept regions set to 1
    """
    labeled, num_features = ndimage.label(mask)

    # Pixel count of every label in a single pass (index 0 = background),
    # instead of scanning the whole image once per region.
    areas = np.bincount(labeled.ravel(), minlength=num_features + 1)
    keep = areas >= min_area
    keep[0] = False  # never keep the background

    cleaned_mask = np.zeros_like(mask)
    cleaned_mask[keep[labeled]] = 1

    removed_count = num_features - int(np.count_nonzero(keep))
    if removed_count > 0:
        print(f"  → Removed {removed_count} small fields (< {min_area} pixels)")

    return cleaned_mask

def count_fields(mask):
    """
    Return how many separate fields a mask contains

    Fields are the connected non-zero regions found by ``ndimage.label``
    with its default structuring element.

    Parameters:
    -----------
    mask : numpy.ndarray
        Binary mask

    Returns:
    --------
    int : Number of connected regions
    """
    _, region_count = ndimage.label(mask)
    return region_count

def extract_field_polygons(mask, transform):
    """
    Vectorize every connected field of a mask into georeferenced polygons

    Parameters:
    -----------
    mask : numpy.ndarray
        Binary mask with separated fields
    transform : rasterio.Affine
        Geospatial transform applied to the polygon coordinates

    Returns:
    --------
    list : (polygon, field_id, area) tuples, where field_id is the label from
        ``ndimage.label`` and area is that label's pixel count
    """
    labeled, num_fields = ndimage.label(mask)

    result = []
    for region_id in range(1, num_fields + 1):
        region = (labeled == region_id).astype(np.uint8)
        region_area = np.sum(region)

        result.extend(
            (shape(geometry), region_id, region_area)
            for geometry, value in shapes(region, mask=region, transform=transform)
            if value > 0
        )

    return result

def extract_visible_boundaries(separated_mask, thickness=3):
    """
    Build a raster of field outlines, optionally thickened for display

    Each connected field gets a 3x3 morphological gradient (its inner and
    outer edge pixels); all outlines are combined and, when ``thickness`` > 1,
    dilated with a ``thickness`` x ``thickness`` square.

    Parameters:
    -----------
    separated_mask : numpy.ndarray
        Binary mask with separated fields
    thickness : int
        Side length in pixels of the thickening kernel

    Returns:
    --------
    numpy.ndarray : uint8 boundary mask (non-zero = boundary)
    """
    labeled, num_fields = ndimage.label(separated_mask)
    ring_kernel = np.ones((3, 3), np.uint8)
    outline_map = np.zeros_like(separated_mask, dtype=np.uint8)

    for region_id in range(1, num_fields + 1):
        region = (labeled == region_id).astype(np.uint8)
        outline = cv2.morphologyEx(region, cv2.MORPH_GRADIENT, ring_kernel)
        np.maximum(outline_map, outline, out=outline_map)

    if thickness <= 1:
        return outline_map

    return cv2.dilate(outline_map, np.ones((thickness, thickness), np.uint8), iterations=1)

def create_boundary_polygons(separated_mask, transform):
    """
    Create polygon lines for field boundaries

    Parameters:
    -----------
    separated_mask : numpy.ndarray
        Binary mask with separated fields
    transform : rasterio.Affine
        Geospatial transform

    Returns:
    --------
    tuple : (boundary_lines, field_polygons)
        boundary_lines : list of shapely LineString, one per exterior ring
            and one per interior ring (hole) of every polygon
        field_polygons : list of shapely Polygon vectorized from each field
    """
    from shapely.geometry import LineString

    # Label fields
    labeled, num_fields = ndimage.label(separated_mask)

    # Extract all field polygons
    field_polygons = []
    for field_id in range(1, num_fields + 1):
        field_mask = (labeled == field_id).astype(np.uint8)

        for geom, val in shapes(field_mask, mask=field_mask, transform=transform):
            if val > 0:
                field_polygons.append(shape(geom))

    # Extract boundary lines from polygon rings (exterior first, then holes)
    boundary_lines = []
    for poly in field_polygons:
        boundary_lines.append(LineString(poly.exterior.coords))
        boundary_lines.extend(LineString(ring.coords) for ring in poly.interiors)

    return boundary_lines, field_polygons

def visualize_with_boundaries(original, separated, boundaries, save_path=None):
    """
    Show a 2x2 comparison of the original mask, separated fields and boundaries

    Panels: original mask; separated mask with boundaries painted red;
    fields colour-coded by connected-component label; boundaries alone.

    Parameters:
    -----------
    original : numpy.ndarray
        Original combined mask
    separated : numpy.ndarray
        Separated fields mask
    boundaries : numpy.ndarray
        Boundary raster (non-zero = boundary)
    save_path : str, optional
        If given, the figure is saved there at 300 dpi before being shown
    """
    labeled_sep, _ = ndimage.label(separated)
    num_original = count_fields(original)
    num_separated = count_fields(separated)
    on_boundary = boundaries > 0

    # Grey field mask as RGB, with boundary pixels painted pure red
    overlay = np.repeat(separated[..., np.newaxis], 3, axis=-1).astype(float)
    overlay[on_boundary] = [1.0, 0, 0]

    # White boundaries on a black background
    boundary_display = np.where(on_boundary, 255, 0).astype(np.uint8)

    title_style = {"fontsize": 14, "fontweight": 'bold'}
    fig, axs = plt.subplots(2, 2, figsize=(16, 16))

    axs[0, 0].imshow(original, cmap="gray", interpolation='nearest')
    axs[0, 0].set_title(f"Original Mask\n({num_original} fields detected)", **title_style)

    axs[0, 1].imshow(overlay, interpolation='nearest')
    axs[0, 1].set_title(f"Separated Fields with Boundaries\n({num_separated} fields, boundaries in RED)",
                        **title_style)

    axs[1, 0].imshow(labeled_sep, cmap="nipy_spectral", interpolation='nearest')
    axs[1, 0].set_title("Color-Coded Fields\n(Each color = unique field)", **title_style)

    axs[1, 1].imshow(boundary_display, cmap="gray", interpolation='nearest')
    axs[1, 1].set_title(f"Field Boundaries Only\n({BOUNDARY_THICKNESS} pixel thickness)", **title_style)

    for panel in axs.flat:
        panel.axis("off")

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Visualization saved to: {save_path}")

    plt.show()

def main():
    """
    Run the field-separation pipeline end to end

    Reads ``mask_path``, merges all bands into one binary mask, separates
    touching fields, removes small artifacts, extracts polygons and
    boundaries, writes ``clean_mask_path`` (and ``boundaries_path`` when
    CREATE_BOUNDARY_FILE is set) and shows/saves a comparison figure.

    Returns
    -------
    tuple : (separated_mask, boundaries, field_polygons, boundary_lines)

    Raises
    ------
    FileNotFoundError
        If ``mask_path`` does not exist.
    """
    # Check input file
    if not Path(mask_path).exists():
        raise FileNotFoundError(f"Mask file not found: {mask_path}")

    print(f"Reading mask from: {mask_path}")

    # Read mask
    with rasterio.open(mask_path) as src:
        mask_data = src.read()
        profile = src.profile.copy()
        transform = src.transform

        print(f"Mask shape: {mask_data.shape}")

    # Combine layers if multi-band: a pixel is field if any band is non-zero
    if mask_data.ndim == 3:
        mask_combined = np.max(mask_data, axis=0) > 0
    else:
        mask_combined = mask_data > 0

    mask_combined = mask_combined.astype(np.uint8)

    # Pixel totals as Python ints: np.sum of a uint8 array is unsigned, so the
    # "removed" difference below would wrap to a huge value instead of going
    # negative if separation ever produced more pixels than the input.
    original_pixels = int(np.sum(mask_combined))

    original_fields = count_fields(mask_combined)
    print(f"\nOriginal fields detected: {original_fields}")
    print(f"Original pixels: {original_pixels}")

    # Apply adaptive separation
    print("\nApplying thin-boundary separation...")
    separated_mask = adaptive_separation(mask_combined)

    # Remove small artifacts
    print("\nCleaning small artifacts...")
    separated_mask = remove_small_fields(separated_mask, MIN_FIELD_AREA)

    final_fields = count_fields(separated_mask)
    final_pixels = int(np.sum(separated_mask))
    print(f"\nFinal fields detected: {final_fields}")
    print(f"Final pixels: {final_pixels}")
    print(f"Boundary pixels removed: {original_pixels - final_pixels}")

    # Extract polygons for each field
    print("\nExtracting field polygons and boundaries...")
    field_polygons = extract_field_polygons(separated_mask, transform)
    print(f"Total polygons extracted: {len(field_polygons)}")

    # Extract visible boundaries
    print(f"\nExtracting visible boundaries (thickness: {BOUNDARY_THICKNESS} pixels)...")
    boundaries = extract_visible_boundaries(separated_mask, thickness=BOUNDARY_THICKNESS)
    boundary_pixels = np.sum(boundaries > 0)
    print(f"Boundary pixels: {boundary_pixels}")

    # Extract boundary lines (the polygon list is already in field_polygons)
    boundary_lines, _ = create_boundary_polygons(separated_mask, transform)
    print(f"Boundary lines extracted: {len(boundary_lines)}")

    if field_polygons:
        areas = [area for _, _, area in field_polygons]
        print("\nField Statistics:")
        print(f"  Field size range: {min(areas)} - {max(areas)} pixels")
        print(f"  Average field size: {np.mean(areas):.1f} pixels")
        print(f"  Total fields: {len(areas)}")

    # Single-band uint8 output; 0 is declared nodata
    profile.update({
        'count': 1,
        'dtype': 'uint8',
        'compress': 'lzw',
        'nodata': 0
    })

    # Save separated mask
    print(f"\nSaving separated mask to: {clean_mask_path}")
    with rasterio.open(clean_mask_path, "w", **profile) as dst:
        dst.write(separated_mask, 1)

    print("✓ Separated mask saved successfully!")

    # Save boundaries as separate file
    if CREATE_BOUNDARY_FILE:
        print(f"\nSaving boundaries to: {boundaries_path}")
        with rasterio.open(boundaries_path, "w", **profile) as dst:
            dst.write(boundaries, 1)
        print("✓ Boundaries saved successfully!")

    # Visualize results
    print("\nGenerating visualization...")
    visualize_with_boundaries(
        mask_combined,
        separated_mask,
        boundaries,
        save_path="field_separation_comparison.png"
    )

    return separated_mask, boundaries, field_polygons, boundary_lines

if __name__ == "__main__":
    try:
        separated_mask, boundaries, field_polygons, boundary_lines = main()
    except Exception as e:
        # Notebook entry point: report the failure without killing the kernel
        print(f"\nError: {e}")
        import traceback
        traceback.print_exc()
    else:
        print("\nProcessing Complete!")

        # Number only the files that were actually written
        outputs = [f"{clean_mask_path} - Separated field masks"]
        if CREATE_BOUNDARY_FILE:
            outputs.append(f"{boundaries_path} - Field boundaries ({BOUNDARY_THICKNESS}px thick)")
        outputs.append("field_separation_comparison.png - Visual comparison")

        print("\nOutput Files Created:")
        for number, description in enumerate(outputs, start=1):
            print(f"  {number}. {description}")

In [None]:
import rasterio
from rasterio.features import shapes
import numpy as np
from shapely.geometry import shape
import geopandas as gpd
import matplotlib.pyplot as plt
from rasterio.plot import show

# --- Paths ---
mask_path = "masks_boundaries.tif"
original_image_path = "tiff_testing/tile_28672_16384.tif"

# --- Read the first band of the boundary raster and its georeferencing ---
with rasterio.open(mask_path) as src:
    mask_data, transform, crs = src.read(1), src.transform, src.crs

# --- Vectorize every non-zero region into a shapely geometry ---
mask_binary = (mask_data > 0).astype(np.uint8)
polygons = [shape(geometry) for geometry, _ in shapes(mask_binary, mask=mask_binary, transform=transform)]

# --- Wrap the geometries in a GeoDataFrame in the mask's CRS ---
gdf = gpd.GeoDataFrame(geometry=polygons, crs=crs)

# --- Draw the original raster and outline each geometry in red on top ---
with rasterio.open(original_image_path) as src_img:
    fig, ax = plt.subplots(figsize=(10, 10))
    show(src_img, ax=ax)
    gdf.boundary.plot(ax=ax, edgecolor='red', linewidth=1)
    ax.set_title("Mask Boundaries Overlay")
    ax.axis("off")
    plt.show()


In [None]:
import rasterio
from rasterio.features import shapes
import numpy as np
from shapely.geometry import shape
import geopandas as gpd
import matplotlib.pyplot as plt

# --- Paths ---
mask_path = "masks_clean.tif"
original_image_path = "tiff_testing/tile_28672_16384.tif"

# --- Read cleaned mask ---
with rasterio.open(mask_path) as src:
    mask_data = src.read(1)
    transform = src.transform
    crs = src.crs

# --- Extract all polygons ---
# Vectorize by pixel value so touching objects with different IDs (instance
# masks) stay separate; for a 0/1 mask this is identical to binarizing.
mask_binary = (mask_data > 0).astype(np.uint8)
polygons = [
    shape(geom)
    for geom, val in shapes(mask_data.astype(np.int32), mask=mask_binary, transform=transform)
]

n_polygons = len(polygons)
print(f"Total polygons found: {n_polygons}")

# --- Read original image ---
with rasterio.open(original_image_path) as src_img:
    img_data = src_img.read()
    img_transform = src_img.transform
    # [left, right, bottom, top] for imshow(extent=...)
    img_extent = [img_transform[2], img_transform[2] + img_transform[0] * src_img.width,
                  img_transform[5] + img_transform[4] * src_img.height, img_transform[5]]

# --- Prepare image ---
if img_data.shape[0] >= 3:
    rgb = np.transpose(img_data[:3], (1, 2, 0))
    if rgb.max() > 255:
        rgb = (rgb / rgb.max() * 255).astype(np.uint8)
    display_img = rgb
    cmap_to_use = None
else:
    display_img = img_data[0]
    cmap_to_use = 'gray'

# --- Create grid of subplots ---
if n_polygons == 0:
    # plt.subplots(0, n_cols) raises ValueError; there is nothing to draw
    print("No polygons to display.")
else:
    n_cols = 4
    n_rows = int(np.ceil(n_polygons / n_cols))

    # squeeze=False always returns a 2-D array of Axes. The previous
    # "[axes] if n_polygons == 1" wrapped the whole 1x4 array, so axes[0]
    # was an ndarray and imshow failed when exactly one polygon was found.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 4*n_rows), squeeze=False)
    axes = axes.flatten()

    for idx, polygon in enumerate(polygons):
        ax = axes[idx]

        # Show original image
        if cmap_to_use:
            ax.imshow(display_img, cmap=cmap_to_use, extent=img_extent)
        else:
            ax.imshow(display_img, extent=img_extent)

        # Plot this polygon
        gdf_single = gpd.GeoDataFrame(geometry=[polygon], crs=crs)
        gdf_single.boundary.plot(ax=ax, edgecolor='red', linewidth=1.5)

        ax.set_title(f"Polygon {idx+1}", fontsize=10)
        ax.axis("off")

    # Hide unused subplots
    for i in range(n_polygons, len(axes)):
        axes[i].axis("off")

    plt.tight_layout()
    plt.show()

In [31]:
## New Way

In [None]:
import rasterio
import numpy as np
import cv2
from rasterio.features import shapes
from shapely.geometry import shape, mapping
from scipy import ndimage
import matplotlib.pyplot as plt
from pathlib import Path
import json

# --- Configuration ---
# Input SAM mask
mask_path = "masks.tif"
# Outputs: instance mask, boundary raster, vector polygons
clean_mask_path, boundaries_path = "masks_clean.tif", "masks_boundaries.tif"
polygons_geojson = "polygons.geojson"

# Objects with fewer pixels than this are dropped (0 disables the filter in main)
MIN_FIELD_AREA = 100
# Side (px) of the square kernel used to thicken boundaries
BOUNDARY_THICKNESS = 4
# Optional outputs
CREATE_BOUNDARY_FILE, SAVE_GEOJSON = True, True

def load_sam_mask(mask_path):
    """
    Load SAM mask and convert to instance mask
    SAM can output masks as:
    1. Multi-band (each band = one object)
    2. Single band with unique IDs per object

    A single band with at most one distinct non-zero value is treated as a
    binary mask and split into connected components with ``ndimage.label``.

    Parameters
    ----------
    mask_path : str
        Path of the mask raster (opened with rasterio).

    Returns
    -------
    tuple : (instance_mask, profile, transform, crs)
        instance_mask : 2-D array, 0 = background, each positive value = one object
        profile : copy of the source rasterio profile
        transform, crs : georeferencing of the source raster
    """
    with rasterio.open(mask_path) as src:
        mask_data = src.read()
        profile = src.profile.copy()
        transform = src.transform
        crs = src.crs

    print(f"\n{'='*60}")
    print(f"MASK ANALYSIS")
    print(f"{'='*60}")
    print(f"Shape: {mask_data.shape}")
    print(f"Dtype: {mask_data.dtype}")
    print(f"Bands: {mask_data.shape[0] if mask_data.ndim == 3 else 1}")

    # Handle multi-band masks (SAM format: each band = one object)
    if mask_data.ndim == 3 and mask_data.shape[0] > 1:
        print(f"\n✓ Multi-band mask detected: {mask_data.shape[0]} bands")
        print(f"→ Each band likely represents ONE separate object")

        # Convert bands to instance mask (ID = 1-based band index)
        instance_mask = np.zeros(mask_data.shape[1:], dtype=np.uint16)

        print(f"\nAnalyzing bands:")
        valid_bands = 0
        for band_idx in range(mask_data.shape[0]):
            band = mask_data[band_idx]
            nonzero = np.count_nonzero(band > 0)

            if nonzero > 0:
                valid_bands += 1
                band_mask = band > 0
                # Earlier bands win where bands overlap
                instance_mask[band_mask & (instance_mask == 0)] = band_idx + 1

                if band_idx < 10:  # Show first 10
                    print(f"  Band {band_idx+1}: {nonzero:,} pixels")

        if mask_data.shape[0] > 10:
            print(f"  ... and {mask_data.shape[0] - 10} more bands")

        print(f"\n✓ Converted {valid_bands} bands to instance mask")

    else:
        # Take the single band explicitly; squeeze() would also drop a
        # spatial axis of length 1 and hand a 1-D array downstream.
        instance_mask = mask_data[0] if mask_data.ndim == 3 else mask_data
        unique_vals = np.unique(instance_mask)
        unique_vals = unique_vals[unique_vals > 0]

        if len(unique_vals) > 1:
            print(f"\n✓ Single-band instance mask detected")
            print(f"→ {len(unique_vals)} unique object IDs found")
        else:
            print(f"\n⚠ Binary mask detected (all objects = 1)")
            print(f"→ Will separate using connected components")
            instance_mask, num = ndimage.label(instance_mask > 0)
            print(f"→ Found {num} connected regions")

    # Count positive IDs directly: "len(unique) - 1" assumes a background
    # value exists and undercounts when every pixel belongs to an object.
    num_instances = int(np.count_nonzero(np.unique(instance_mask) > 0))
    total_pixels = np.sum(instance_mask > 0)

    print(f"\n{'='*60}")
    print(f"RESULT: {num_instances} objects, {total_pixels:,} pixels")
    print(f"{'='*60}")

    return instance_mask, profile, transform, crs

def remove_small_objects(instance_mask, min_area):
    """
    Remove small objects and reassign IDs

    Objects (positive IDs) with fewer than ``min_area`` pixels are set to 0.
    The survivors are renumbered 1..K in ascending order of their old ID.

    Parameters
    ----------
    instance_mask : numpy.ndarray
        Integer mask, 0 = background, each positive value = one object.
    min_area : int
        Minimum object size in pixels.

    Returns
    -------
    numpy.ndarray : uint16 mask of the same shape with consecutive IDs.
    """
    # One pass over the image: every distinct value, its pixel count, and for
    # each pixel the index of its value (instead of one full scan per ID).
    ids, inverse, counts = np.unique(instance_mask, return_inverse=True, return_counts=True)

    is_object = ids > 0
    keep = is_object & (counts >= min_area)
    kept = int(np.count_nonzero(keep))
    removed = int(np.count_nonzero(is_object)) - kept

    # Lookup table old-ID-index -> new ID (0 for background / removed)
    lut = np.zeros(ids.shape[0], dtype=np.uint16)
    lut[keep] = np.arange(1, kept + 1)
    # reshape: the shape of return_inverse differs across NumPy versions
    cleaned_mask = lut[inverse].reshape(instance_mask.shape)

    print(f"\nCleaning small objects:")
    print(f"  Removed: {removed} objects (< {min_area} px)")
    print(f"  Kept: {kept} objects")

    return cleaned_mask

def extract_polygons_per_object(instance_mask, transform):
    """
    Extract ONE geometry per object

    Every positive ID is vectorized. If an object consists of several
    disconnected parts, they are merged into a single geometry (typically a
    MultiPolygon) instead of silently keeping only the first part.

    Parameters
    ----------
    instance_mask : numpy.ndarray
        Integer mask, 0 = background, each positive value = one object.
    transform : rasterio.Affine
        Geospatial transform applied to the polygon coordinates.

    Returns
    -------
    list of dict : {'polygon': shapely geometry, 'id': int, 'area': int pixels}
    """
    from shapely.ops import unary_union

    unique_ids = np.unique(instance_mask)
    unique_ids = unique_ids[unique_ids > 0]

    polygons = []

    print(f"\nExtracting polygons:")

    for inst_id in unique_ids:
        # Create binary mask for this object
        obj_mask = (instance_mask == inst_id).astype(np.uint8)
        area = np.sum(obj_mask)

        parts = [
            shape(geom)
            for geom, val in shapes(obj_mask, mask=obj_mask, transform=transform)
            if val > 0
        ]
        if not parts:
            continue

        # Keep every part so the geometry matches the reported pixel area
        poly = parts[0] if len(parts) == 1 else unary_union(parts)
        polygons.append({
            'polygon': poly,
            'id': int(inst_id),
            'area': int(area)
        })

    print(f"  ✓ Extracted {len(polygons)} polygons")

    # Show stats
    if polygons:
        areas = [p['area'] for p in polygons]
        print(f"\nPolygon Statistics:")
        print(f"  Count: {len(polygons)}")
        print(f"  Area range: {min(areas):,} - {max(areas):,} pixels")
        print(f"  Average: {np.mean(areas):,.0f} pixels")

    return polygons

def save_geojson(polygons, output_path, crs):
    """
    Write polygon records to a GeoJSON FeatureCollection file

    Each record's 'id' and 'area' become the feature properties "id" and
    "area_pixels". A named "crs" member is written from ``str(crs)``, falling
    back to "EPSG:4326" when ``crs`` is falsy.

    Parameters
    ----------
    polygons : list of dict
        Records with 'polygon', 'id' and 'area' keys.
    output_path : str
        Destination file path (overwritten).
    crs : object or None
        CRS whose string form is stored in the file.
    """
    features = [
        {
            "type": "Feature",
            "properties": {
                "id": record['id'],
                "area_pixels": record['area']
            },
            "geometry": mapping(record['polygon'])
        }
        for record in polygons
    ]

    crs_name = str(crs) if crs else "EPSG:4326"
    collection = {
        "type": "FeatureCollection",
        "crs": {"type": "name", "properties": {"name": crs_name}},
        "features": features
    }

    with open(output_path, 'w') as handle:
        json.dump(collection, handle, indent=2)

    print(f"\n✓ GeoJSON saved: {output_path}")
    print(f"  → {len(features)} polygons")

def extract_boundaries(instance_mask, thickness=4):
    """
    Rasterize the outline of every object, optionally thickened

    Each positive ID gets a 3x3 morphological gradient (inner + outer edge);
    outlines are OR-ed together and, when ``thickness`` > 1, dilated with a
    ``thickness`` x ``thickness`` square.

    Returns
    -------
    numpy.ndarray : uint8 raster, 1 = boundary, 0 = elsewhere
    """
    object_ids = np.unique(instance_mask)
    object_ids = object_ids[object_ids > 0]

    gradient_kernel = np.ones((3, 3), np.uint8)
    outline_map = np.zeros_like(instance_mask, dtype=np.uint8)

    for object_id in object_ids:
        footprint = (instance_mask == object_id).astype(np.uint8)
        # Gradient of a 0/1 image is itself 0/1, so OR == elementwise max
        outline_map |= cv2.morphologyEx(footprint, cv2.MORPH_GRADIENT, gradient_kernel)

    if thickness <= 1:
        return outline_map
    return cv2.dilate(outline_map, np.ones((thickness, thickness), np.uint8), iterations=1)

def visualize_results(instance_mask, boundaries, polygons, save_path):
    """
    Plot a 2x2 overview of the instance mask and its boundaries, then save it

    Panels: colour-coded instance mask; binary mask with boundaries in red;
    binary mask; boundaries alone. The figure is saved at 300 dpi to
    ``save_path`` and then shown.

    Parameters
    ----------
    instance_mask : numpy.ndarray
        Integer mask, 0 = background, each positive value = one object.
    boundaries : numpy.ndarray
        Boundary raster (non-zero = boundary).
    polygons : list
        Extracted polygons; only its length is used (panel title).
    save_path : str
        Output image path.
    """
    binary_mask = (instance_mask > 0).astype(np.uint8)
    # Count positive IDs directly; "len(unique) - 1" undercounts by one when
    # the mask contains no background (0) pixel.
    num_objects = int(np.count_nonzero(np.unique(instance_mask) > 0))

    fig, axs = plt.subplots(2, 2, figsize=(18, 18))

    # 1. Instance mask (color-coded)
    axs[0, 0].imshow(instance_mask, cmap="nipy_spectral", interpolation='nearest')
    axs[0, 0].set_title(f"Instance Mask\n({num_objects} objects, each color = unique object)",
                        fontsize=14, fontweight='bold')
    axs[0, 0].axis("off")

    # 2. With boundaries overlay
    rgb = np.stack([binary_mask, binary_mask, binary_mask], axis=-1).astype(float)
    rgb[boundaries > 0] = [1.0, 0, 0]

    axs[0, 1].imshow(rgb, interpolation='nearest')
    axs[0, 1].set_title(f"Objects with Boundaries\n({len(polygons)} polygons, boundaries in RED)",
                        fontsize=14, fontweight='bold', color='darkgreen')
    axs[0, 1].axis("off")

    # 3. Binary mask
    axs[1, 0].imshow(binary_mask, cmap="gray", interpolation='nearest')
    axs[1, 0].set_title("Binary Mask\n(All objects combined)",
                        fontsize=14, fontweight='bold')
    axs[1, 0].axis("off")

    # 4. Boundaries only
    boundary_display = np.zeros_like(instance_mask, dtype=np.uint8)
    boundary_display[boundaries > 0] = 255

    axs[1, 1].imshow(boundary_display, cmap="gray", interpolation='nearest')
    axs[1, 1].set_title(f"Boundaries Only\n({BOUNDARY_THICKNESS}px thickness)",
                        fontsize=14, fontweight='bold')
    axs[1, 1].axis("off")

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"\n✓ Visualization saved: {save_path}")
    plt.show()

def main():
    """
    Convert a SAM mask into an instance mask, boundaries and per-object polygons

    Reads ``mask_path``; writes ``clean_mask_path`` (uint16 instance IDs),
    ``boundaries_path`` when CREATE_BOUNDARY_FILE is set, ``polygons_geojson``
    when SAVE_GEOJSON is set, and ``polygons_visualization.png``.

    Returns
    -------
    tuple : (instance_mask, boundaries, polygons)

    Raises
    ------
    FileNotFoundError
        If ``mask_path`` does not exist.
    """
    if not Path(mask_path).exists():
        raise FileNotFoundError(f"Mask file not found: {mask_path}")

    print(f"\n{'='*60}")
    print(f"SAM MASK → INDIVIDUAL POLYGONS")
    print(f"{'='*60}")
    print(f"\nInput: {mask_path}")

    # Load SAM mask
    instance_mask, profile, transform, crs = load_sam_mask(mask_path)

    # Remove small objects
    if MIN_FIELD_AREA > 0:
        instance_mask = remove_small_objects(instance_mask, MIN_FIELD_AREA)

    # Extract polygons (one per object)
    polygons = extract_polygons_per_object(instance_mask, transform)

    # Extract boundaries
    print(f"\nExtracting boundaries ({BOUNDARY_THICKNESS}px)...")
    boundaries = extract_boundaries(instance_mask, BOUNDARY_THICKNESS)
    print(f"  ✓ Boundary pixels: {np.sum(boundaries > 0):,}")

    # Save instance mask (single band, uint16 IDs, 0 = nodata)
    profile.update({
        'count': 1,
        'dtype': 'uint16',
        'compress': 'lzw',
        'nodata': 0
    })

    print(f"\nSaving files...")
    with rasterio.open(clean_mask_path, "w", **profile) as dst:
        dst.write(instance_mask.astype(np.uint16), 1)
    print(f"  ✓ Instance mask: {clean_mask_path}")

    # Only files that were actually written are listed in the summary
    outputs = [f"{clean_mask_path} - Instance mask (unique ID per object)"]

    # Save boundaries
    if CREATE_BOUNDARY_FILE:
        profile_boundary = profile.copy()
        profile_boundary['dtype'] = 'uint8'

        with rasterio.open(boundaries_path, "w", **profile_boundary) as dst:
            dst.write(boundaries, 1)
        print(f"  ✓ Boundaries: {boundaries_path}")
        outputs.append(f"{boundaries_path} - Boundaries")

    # Save GeoJSON
    if SAVE_GEOJSON:
        save_geojson(polygons, polygons_geojson, crs)
        outputs.append(f"{polygons_geojson} - Polygons (GeoJSON) ← USE THIS!")

    # Visualize
    print(f"\nGenerating visualization...")
    visualize_results(instance_mask, boundaries, polygons, "polygons_visualization.png")
    outputs.append("polygons_visualization.png - Visualization")

    print(f"\n{'='*60}")
    print(f"COMPLETE!")
    print(f"{'='*60}")
    print(f"\nOutput files:")
    for number, description in enumerate(outputs, start=1):
        print(f"  {number}. {description}")
    print(f"\n✓ Each object has its own polygon!")

    return instance_mask, boundaries, polygons

if __name__ == "__main__":
    import traceback

    # Notebook entry point: report failures without killing the kernel
    try:
        instance_mask, boundaries, polygons = main()
    except Exception as err:
        print(f"\nERROR: {err}")
        traceback.print_exc()

In [None]:
import rasterio
import numpy as np
import geopandas as gpd
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# --- Configuration ---
# Polygons written by the SAM pipeline and the raster they were derived from
polygons_geojson = "polygons.geojson"
original_image_path = "tiff_testing/tile_28672_16384.tif"

# Grid layout and drawing style
GRID_COLS = 4           # Number of columns in the grid
POLYGON_COLOR = 'red'   # Boundary color
LINE_WIDTH = 2          # Boundary thickness
SHOW_POLYGON_ID = True  # Show polygon ID in title

def load_polygons(geojson_path):
    """
    Read polygons from a GeoJSON file into a GeoDataFrame

    Prints the feature count and, when an ``area_pixels`` column exists,
    its min / max / mean.

    Returns
    -------
    geopandas.GeoDataFrame : the loaded features
    """
    print(f"Loading polygons from: {geojson_path}")
    gdf = gpd.read_file(geojson_path)
    print(f"✓ Loaded {len(gdf)} polygons")

    if 'area_pixels' not in gdf.columns:
        return gdf

    areas = gdf['area_pixels']
    print(f"\nPolygon Statistics:")
    print(f"  Area range: {areas.min():,} - {areas.max():,} pixels")
    print(f"  Average area: {areas.mean():,.0f} pixels")
    return gdf

def load_original_image(image_path):
    """
    Read a raster and prepare it for ``imshow``

    Returns
    -------
    tuple : (display_img, img_extent, cmap)
        display_img : HxWx3 array built from the first three bands when the
            raster has at least three, otherwise the first band (HxW).
            RGB data whose maximum exceeds 255 is rescaled to uint8 0-255.
        img_extent : [left, right, bottom, top] computed from the transform
        cmap : None for RGB, 'gray' for single-band display
    """
    print(f"\nLoading original image: {image_path}")

    with rasterio.open(image_path) as src:
        img_data = src.read()
        img_crs = src.crs
        geo = src.transform
        # Affine fields: a = pixel width, c = left, e = pixel height, f = top
        img_extent = [
            geo.c,                      # left
            geo.c + geo.a * src.width,  # right
            geo.f + geo.e * src.height, # bottom
            geo.f                       # top
        ]

    print(f"✓ Image shape: {img_data.shape}")
    print(f"✓ Image CRS: {img_crs}")

    if img_data.shape[0] < 3:
        # Grayscale
        return img_data[0], img_extent, 'gray'

    # RGB: bands-first -> channels-last
    rgb = np.transpose(img_data[:3], (1, 2, 0))
    peak = rgb.max()
    if peak > 255:
        rgb = (rgb / peak * 255).astype(np.uint8)
    return rgb, img_extent, None

def visualize_polygons_grid(gdf, display_img, img_extent, cmap=None, 
                           n_cols=4, figsize_per_plot=(4, 4)):
    """
    Visualize each polygon separately in a grid.

    Every subplot shows the full background image with one polygon's boundary
    drawn on top. The figure is saved to ``polygons_grid_visualization.png``
    (300 dpi) and then shown.
    
    Parameters:
    -----------
    gdf : GeoDataFrame
        Polygons to visualize. If it is empty, nothing is drawn or saved.
    display_img : numpy.ndarray
        Image to display as background
    img_extent : list
        [left, right, bottom, top] extent for image
    cmap : str, optional
        Colormap for grayscale images
    n_cols : int
        Number of columns in grid (must be >= 1)
    figsize_per_plot : tuple
        (width, height) of each subplot, in inches

    Raises:
    -------
    ValueError
        If ``n_cols`` is less than 1.
    """
    if n_cols < 1:
        raise ValueError(f"n_cols must be >= 1, got {n_cols}")

    n_polygons = len(gdf)
    if n_polygons == 0:
        # plt.subplots(0, n_cols) raises; there is nothing to draw anyway.
        print("\nNo polygons to visualize; skipping grid visualization.")
        return

    # Ceiling division: enough rows to hold every polygon.
    n_rows = -(-n_polygons // n_cols)
    
    fig_width = figsize_per_plot[0] * n_cols
    fig_height = figsize_per_plot[1] * n_rows
    
    print(f"\nCreating visualization grid:")
    print(f"  Grid: {n_rows} rows × {n_cols} cols")
    print(f"  Figure size: {fig_width} × {fig_height}")
    
    # squeeze=False always yields a 2-D array of Axes, so ravel() gives a flat
    # list for every grid shape. The previous `[axes]` special case broke for a
    # single polygon in a multi-column grid (axes[0] was itself an array).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height),
                             squeeze=False)
    axes = axes.ravel()
    
    # Plot each polygon
    for idx, (_, row) in enumerate(gdf.iterrows()):
        ax = axes[idx]
        
        # Display original image
        if cmap:
            ax.imshow(display_img, cmap=cmap, extent=img_extent)
        else:
            ax.imshow(display_img, extent=img_extent)
        
        # Plot this polygon only, keeping the source CRS
        gdf_single = gpd.GeoDataFrame(geometry=[row.geometry], crs=gdf.crs)
        gdf_single.boundary.plot(ax=ax, edgecolor=POLYGON_COLOR, linewidth=LINE_WIDTH)
        
        # Set title
        if SHOW_POLYGON_ID and 'id' in gdf.columns:
            title = f"Polygon {idx+1} (ID: {row['id']})"
        else:
            title = f"Polygon {idx+1}"
        
        # Add area to title if available
        if 'area_pixels' in gdf.columns:
            title += f"\n{row['area_pixels']:,} px"
        
        ax.set_title(title, fontsize=10, fontweight='bold')
        ax.axis("off")
    
    # Hide unused subplots in the last row
    for i in range(n_polygons, len(axes)):
        axes[i].axis("off")
    
    plt.tight_layout()
    
    # Save figure
    output_path = "polygons_grid_visualization.png"
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"\n✓ Saved visualization: {output_path}")
    
    plt.show()

def visualize_polygons_overlay(gdf, display_img, img_extent, cmap=None):
    """
    Draw every polygon boundary on top of the original image in one figure.

    The figure is saved to ``polygons_overlay_visualization.png`` (300 dpi)
    and then shown.

    Parameters:
    -----------
    gdf : GeoDataFrame
        Polygons to overlay
    display_img : numpy.ndarray
        Background image
    img_extent : list
        [left, right, bottom, top] extent for the image
    cmap : str, optional
        Colormap for grayscale images (only passed to imshow when truthy)
    """
    _, ax = plt.subplots(figsize=(12, 12))

    # Build imshow arguments once instead of branching on the call itself.
    imshow_kwargs = {'extent': img_extent}
    if cmap:
        imshow_kwargs['cmap'] = cmap
    ax.imshow(display_img, **imshow_kwargs)

    # All outlines in a single call
    gdf.boundary.plot(ax=ax, edgecolor=POLYGON_COLOR, linewidth=LINE_WIDTH)

    ax.set_title(f"All {len(gdf)} Polygons Overlay", fontsize=16, fontweight='bold')
    ax.axis("off")
    plt.tight_layout()

    overlay_png = "polygons_overlay_visualization.png"
    plt.savefig(overlay_png, dpi=300, bbox_inches='tight')
    print(f"\n✓ Saved overlay visualization: {overlay_png}")

    plt.show()

def main():
    """Run the pipeline: load polygons and image, then render grid and overlay figures."""
    banner = "=" * 60

    def announce(message):
        # Blank line, then the message framed by banner rules.
        print("\n" + banner)
        print(message)
        print(banner)

    print(banner)
    print("VISUALIZE POLYGONS FROM GEOJSON")
    print(banner)

    gdf = load_polygons(polygons_geojson)
    display_img, img_extent, cmap = load_original_image(original_image_path)

    announce("Creating grid visualization (one polygon per subplot)...")
    visualize_polygons_grid(gdf, display_img, img_extent, cmap, n_cols=GRID_COLS)

    announce("Creating overlay visualization (all polygons)...")
    visualize_polygons_overlay(gdf, display_img, img_extent, cmap)

    announce("VISUALIZATION COMPLETE!")
    print("\nOutput files:")
    for summary_line in (
        "  1. polygons_grid_visualization.png - Each polygon separate",
        "  2. polygons_overlay_visualization.png - All polygons overlaid",
    ):
        print(summary_line)

if __name__ == "__main__":
    import traceback

    # Best-effort entry point: report any failure with a full traceback
    # instead of letting the exception propagate.
    try:
        main()
    except Exception as err:
        print(f"\nERROR: {err}")
        traceback.print_exc()