# In [None]:
import rasterio
import numpy as np
import cv2
from rasterio.features import shapes
from shapely.geometry import shape, mapping
from scipy import ndimage
import matplotlib.pyplot as plt
from pathlib import Path
import json

# --- Configuration ---
# Input SAM mask and output file paths.
mask_path = "outputs/tile_24576_24576_masks.tif"
clean_mask_path = "masks_clean.tif"
boundaries_path = "masks_boundaries.tif"
polygons_geojson = "polygons.geojson"

# --- Watershed Separation ---
# Watershed seeds are foreground pixels whose distance to the background
# exceeds WATERSHED_THRESHOLD x the largest such distance in the tile (0-1).
# Higher values keep only the thickest cores, so objects joined by a thin
# neck are more likely to get separate seeds; lower values let neighbouring
# seeds join into one.
USE_WATERSHED_SEPARATION = False
WATERSHED_THRESHOLD = 0.25

# --- Morphological Processing ---
# All kernels below are K x K elliptical structuring elements; one iteration
# grows (dilate) or shrinks (erode) an object by about K // 2 pixels.
USE_MORPHOLOGY = True
MORPH_CLOSE_KERNEL = 3       # closing fills small holes and gaps
MORPH_CLOSE_ITERATIONS = 1

# DILATION STRATEGY: expand objects before merging so nearby fragments touch.
MORPH_DILATE_KERNEL = 5      # ~2px growth per iteration (5 // 2)
MORPH_DILATE_ITERATIONS = 1

# EROSION: shrink back after merging. Keep kernel and iterations equal to the
# dilation settings so objects return to roughly their original extent.
MORPH_ERODE_KERNEL = 5       # ~2px shrink per iteration (5 // 2)
MORPH_ERODE_ITERATIONS = 1

# --- SMART MERGING (Two-Pass) ---
# Objects below MERGE_THRESHOLD pixels are merged into neighbours found in an
# elliptical MERGE_SEARCH_DISTANCE x MERGE_SEARCH_DISTANCE window, i.e. within
# roughly MERGE_SEARCH_DISTANCE // 2 pixels of the object's edge.
USE_SMART_MERGING = True
MERGE_THRESHOLD = 30000
MERGE_SEARCH_DISTANCE = 8

# --- Filtering Parameters (Applied AFTER merging) ---
# Area bounds, in pixels, of each object's largest outer contour.
MIN_AREA = 5000
MAX_AREA = 20000000

# Shape quality filters (computed on the object's largest outer contour):
#   compactness  = 4*pi*area / perimeter**2   (1.0 for a circle)
#   solidity     = area / convex-hull area
#   aspect ratio = long / short side of the axis-aligned bounding box
#   convexity    = convex-hull perimeter / perimeter
#   extent       = area / bounding-box area
MIN_COMPACTNESS = 0.005
MIN_SOLIDITY = 0.4
MAX_ASPECT_RATIO = 150.0
MIN_CONVEXITY = 0.4
MIN_EXTENT = 0.1

# Output options.
BOUNDARY_THICKNESS = 5       # side of the square kernel used to thicken outlines
CREATE_BOUNDARY_FILE = True
SAVE_GEOJSON = True

def load_sam_mask(mask_path):
    """Read a SAM mask raster and return it as a single-band instance mask.

    Three input layouts are handled:

    * multi-band: every band is one object; band ``i`` (0-based) becomes
      instance ID ``i + 1``. Where bands overlap, the lowest band index wins.
    * single band with several distinct non-zero values: used unchanged.
    * single band with at most one non-zero value (binary mask): split into
      instances by ``scipy.ndimage.label`` connected-component labelling.

    Parameters
    ----------
    mask_path : str or path-like
        Raster readable by ``rasterio.open``.

    Returns
    -------
    tuple
        ``(instance_mask, profile, transform, crs)``; ``profile`` is a copy
        of the source profile, ``transform``/``crs`` the source georeferencing.
    """
    with rasterio.open(mask_path) as src:
        data = src.read()
        profile = src.profile.copy()
        transform, crs = src.transform, src.crs

    rule = '=' * 60
    band_total = data.shape[0] if data.ndim == 3 else 1
    print(f"\n{rule}")
    print(f"MASK ANALYSIS")
    print(f"{rule}")
    print(f"Shape: {data.shape}")
    print(f"Dtype: {data.dtype}")
    print(f"Bands: {band_total}")

    if data.ndim == 3 and data.shape[0] > 1:
        n_bands = data.shape[0]
        print(f"\n‚úì Multi-band mask: {n_bands} bands")
        instance_mask = np.zeros(data.shape[1:], dtype=np.uint16)
        print(f"\nAnalyzing bands:")
        used = 0
        for idx, band in enumerate(data):
            hit = band > 0
            n_hit = np.count_nonzero(hit)
            if not n_hit:
                continue
            used += 1
            # Only claim pixels no earlier band has taken.
            instance_mask[hit & (instance_mask == 0)] = idx + 1
            if idx < 10:
                print(f"  Band {idx + 1}: {n_hit:,} pixels")
        if n_bands > 10:
            print(f"  ... and {n_bands - 10} more bands")
        print(f"\n‚úì Converted {used} bands")
    else:
        instance_mask = data.squeeze()
        present = np.unique(instance_mask)
        present = present[present > 0]
        if present.size > 1:
            print(f"\n‚úì Single-band instance mask: {present.size} objects")
        else:
            print(f"\n‚ö† Binary mask - separating connected regions")
            instance_mask, n_regions = ndimage.label(instance_mask > 0)
            print(f"‚Üí Found {n_regions} connected regions")

    num_instances = len(np.unique(instance_mask)) - 1
    total_pixels = np.sum(instance_mask > 0)
    print(f"\n{rule}")
    print(f"RESULT: {num_instances} objects, {total_pixels:,} pixels")
    print(f"{rule}")

    return instance_mask, profile, transform, crs

def apply_watershed_separation(instance_mask):
    """Split touching objects with a marker-controlled watershed.

    Seeds are the connected regions of foreground pixels whose distance to the
    background exceeds ``WATERSHED_THRESHOLD`` times the largest such distance
    in the tile. A foreground component that gets no seed (its deepest point
    is below the threshold) is seeded as a whole, so it is kept. Watershed
    lines between regions become background (0). Original instance IDs are
    not preserved: the result is relabelled 1..N from the watershed regions.

    Parameters
    ----------
    instance_mask : numpy.ndarray
        2-D label image, 0 = background.

    Returns
    -------
    numpy.ndarray
        ``instance_mask`` unchanged when ``USE_WATERSHED_SEPARATION`` is
        False, otherwise a new uint16 label image.
    """
    if not USE_WATERSHED_SEPARATION:
        return instance_mask

    print(f"\n{'='*60}")
    print(f"WATERSHED SEPARATION")
    print(f"{'='*60}")

    binary_mask = (instance_mask > 0).astype(np.uint8)

    # Seeds: pixels deeper inside the foreground than a fraction of the
    # deepest point in the tile.
    dist = cv2.distanceTransform(binary_mask, cv2.DIST_L2, 3)
    peaks = (dist > WATERSHED_THRESHOLD * dist.max()).astype(np.uint8)
    _, markers = cv2.connectedComponents(peaks)

    # Grow each seed by one pixel. cv2.dilate has no int32 support, so go
    # through float32 (exact below 2**24); a uint8 round-trip would wrap
    # labels above 255 onto other seeds.
    kernel = np.ones((2, 2), np.uint8)
    markers = cv2.dilate(markers.astype(np.float32), kernel, iterations=1).astype(np.int32)
    markers[binary_mask == 0] = 0

    # Seed every foreground component that received no peak with the whole
    # component, so it survives instead of being flooded by the background.
    n_components, components = cv2.connectedComponents(binary_mask)
    has_seed = np.zeros(n_components, dtype=bool)
    has_seed[0] = True  # component 0 is the background
    has_seed[np.unique(components[markers > 0])] = True
    unseeded = np.flatnonzero(~has_seed)
    if unseeded.size:
        lut = np.zeros(n_components, dtype=np.int32)
        lut[unseeded] = markers.max() + 1 + np.arange(unseeded.size, dtype=np.int32)
        extra = lut[components]
        markers = np.where(extra > 0, extra, markers).astype(np.int32)

    # OpenCV treats marker 0 as "unknown, to be flooded", so the background
    # needs a marker of its own or every object would grow into it.
    background_label = int(markers.max()) + 1
    markers[binary_mask == 0] = background_label

    binary_3ch = cv2.cvtColor(binary_mask * 255, cv2.COLOR_GRAY2BGR)
    markers = cv2.watershed(binary_3ch, markers)

    # -1 marks watershed lines (OpenCV also sets the 1-pixel image frame to
    # -1); those and the background region become 0. Then relabel to 1..N.
    regions = np.where((markers > 0) & (markers != background_label), markers, 0)
    ids, compact = np.unique(regions, return_inverse=True)
    compact = compact.reshape(regions.shape)
    if ids[0] != 0:  # no background pixel at all: keep 0 free for background
        compact = compact + 1
    separated_mask = compact.astype(np.uint16)

    before_count = len(np.unique(instance_mask)) - 1
    after_count = len(np.unique(separated_mask)) - 1

    print(f"‚úì Threshold: {WATERSHED_THRESHOLD}")
    print(f"  Objects before: {before_count}")
    print(f"  Objects after: {after_count}")

    return separated_mask

def apply_pre_merge_morphology(instance_mask):
    """Close and then dilate every object on its own, ahead of merging.

    Each instance is processed as an independent binary mask:

    1. elliptical closing (MORPH_CLOSE_KERNEL, MORPH_CLOSE_ITERATIONS) to fill
       small holes;
    2. elliptical dilation (MORPH_DILATE_KERNEL, MORPH_DILATE_ITERATIONS) to
       grow it so nearby fragments touch.

    A step is skipped when its kernel size or iteration count is 0. Objects
    are written back in ascending ID order, so where grown objects overlap
    the higher ID wins. Returns the input unchanged when USE_MORPHOLOGY is
    False, otherwise a new uint16 label image.
    """
    if not USE_MORPHOLOGY:
        return instance_mask

    rule = '=' * 60
    print(f"\n{rule}")
    print(f"MORPHOLOGICAL PROCESSING (Pre-Merge)")
    print(f"{rule}")

    do_close = MORPH_CLOSE_KERNEL > 0 and MORPH_CLOSE_ITERATIONS > 0
    do_dilate = MORPH_DILATE_KERNEL > 0 and MORPH_DILATE_ITERATIONS > 0
    close_se = dilate_se = None
    if do_close:
        close_se = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                             (MORPH_CLOSE_KERNEL, MORPH_CLOSE_KERNEL))
    if do_dilate:
        dilate_se = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                              (MORPH_DILATE_KERNEL, MORPH_DILATE_KERNEL))

    labels = np.unique(instance_mask)
    grown = np.zeros_like(instance_mask, dtype=np.uint16)
    for label in labels[labels > 0]:
        piece = (instance_mask == label).astype(np.uint8)
        if do_close:
            piece = cv2.morphologyEx(piece, cv2.MORPH_CLOSE, close_se,
                                     iterations=MORPH_CLOSE_ITERATIONS)
        if do_dilate:
            piece = cv2.dilate(piece, dilate_se, iterations=MORPH_DILATE_ITERATIONS)
        grown[piece > 0] = label

    print(f"‚úì Applied pre-merge operations:")
    if MORPH_CLOSE_KERNEL > 0:
        print(f"  ‚Ä¢ Closing: kernel={MORPH_CLOSE_KERNEL}, iterations={MORPH_CLOSE_ITERATIONS}")
    if MORPH_DILATE_KERNEL > 0:
        print(f"  ‚Ä¢ Dilation: kernel={MORPH_DILATE_KERNEL}, iterations={MORPH_DILATE_ITERATIONS}")
        print(f"    ‚Üí Objects expanded by ~{MORPH_DILATE_KERNEL}px (closes gaps!)")

    px_before = np.sum(instance_mask > 0)
    px_after = np.sum(grown > 0)
    delta = px_after - px_before
    pct = (delta / px_before * 100) if px_before > 0 else 0

    print(f"\nüìä Pixel changes:")
    print(f"  Before: {px_before:,} pixels")
    print(f"  After: {px_after:,} pixels")
    print(f"  Change: {delta:+,} pixels ({pct:+.1f}%)")

    return grown

def apply_post_merge_morphology(instance_mask):
    """Erode every object on its own after merging, undoing the pre-merge growth.

    Uses an elliptical MORPH_ERODE_KERNEL x MORPH_ERODE_KERNEL element applied
    MORPH_ERODE_ITERATIONS times (skipped if either is not positive).

    Returns the input unchanged when USE_MORPHOLOGY is False, or when
    MORPH_ERODE_KERNEL is 0 (after printing a notice). Otherwise it returns a
    new uint16 label image.
    """
    if not USE_MORPHOLOGY:
        return instance_mask
    if MORPH_ERODE_KERNEL == 0:
        print(f"\n‚ö† Post-merge erosion disabled (MORPH_ERODE_KERNEL = 0)")
        return instance_mask

    rule = '=' * 60
    print(f"\n{rule}")
    print(f"MORPHOLOGICAL PROCESSING (Post-Merge)")
    print(f"{rule}")

    erode_se = None
    if MORPH_ERODE_KERNEL > 0 and MORPH_ERODE_ITERATIONS > 0:
        erode_se = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                             (MORPH_ERODE_KERNEL, MORPH_ERODE_KERNEL))

    labels = np.unique(instance_mask)
    shrunk = np.zeros_like(instance_mask, dtype=np.uint16)
    for label in labels[labels > 0]:
        piece = (instance_mask == label).astype(np.uint8)
        if erode_se is not None:
            piece = cv2.erode(piece, erode_se, iterations=MORPH_ERODE_ITERATIONS)
        shrunk[piece > 0] = label

    print(f"‚úì Applied post-merge operations:")
    print(f"  ‚Ä¢ Erosion: kernel={MORPH_ERODE_KERNEL}, iterations={MORPH_ERODE_ITERATIONS}")
    print(f"    ‚Üí Objects shrunk back by ~{MORPH_ERODE_KERNEL}px")

    px_before = np.sum(instance_mask > 0)
    px_after = np.sum(shrunk > 0)
    delta = px_after - px_before
    pct = (delta / px_before * 100) if px_before > 0 else 0

    print(f"\nüìä Pixel changes:")
    print(f"  Before: {px_before:,} pixels")
    print(f"  After: {px_after:,} pixels")
    print(f"  Change: {delta:+,} pixels ({pct:+.1f}%)")

    return shrunk

def _object_areas(label_mask):
    """Return ``{label: pixel_count}`` for every non-zero label, in ascending label order."""
    ids, counts = np.unique(label_mask, return_counts=True)
    nonzero = ids > 0
    return dict(zip(ids[nonzero], counts[nonzero]))


def _neighbor_ids(label_mask, obj_id, kernel):
    """Return the non-zero labels, other than ``obj_id``, under ``obj_id`` dilated by ``kernel``.

    The dilation runs on a window around the object, padded by the full kernel
    size on every side, instead of on the whole image. The padding exceeds the
    kernel's reach, so the result equals that of a full-image dilation.
    """
    obj = label_mask == obj_id
    rows = np.flatnonzero(obj.any(axis=1))
    if rows.size == 0:
        return np.empty(0, dtype=label_mask.dtype)
    cols = np.flatnonzero(obj.any(axis=0))
    pad_r, pad_c = kernel.shape[:2]
    r0 = max(int(rows[0]) - pad_r, 0)
    r1 = min(int(rows[-1]) + pad_r + 1, label_mask.shape[0])
    c0 = max(int(cols[0]) - pad_c, 0)
    c1 = min(int(cols[-1]) + pad_c + 1, label_mask.shape[1])
    window = label_mask[r0:r1, c0:c1]
    reach = cv2.dilate(obj[r0:r1, c0:c1].astype(np.uint8), kernel, iterations=1) > 0
    ids = np.unique(window[reach])
    return ids[(ids > 0) & (ids != obj_id)]


def merge_small_objects(instance_mask):
    """Absorb objects smaller than MERGE_THRESHOLD pixels into their neighbours.

    Pass 1 (small -> small): small objects are visited from largest to
    smallest. Every other still-present small object inside the search window
    is relabelled to the visiting object's ID. An absorbing object can itself
    be absorbed later, so clusters of fragments may collapse into one object.

    Pass 2 (small -> large): objects still below MERGE_THRESHOLD after pass 1
    are relabelled to their largest (post-pass-1 area) large neighbour inside
    the search window. Objects with no large neighbour stay as they are
    ("orphans").

    The search window is an elliptical MERGE_SEARCH_DISTANCE x
    MERGE_SEARCH_DISTANCE kernel, i.e. it reaches about half of
    MERGE_SEARCH_DISTANCE pixels beyond the object's edge.

    Returns the input unchanged when USE_SMART_MERGING is False or no small
    objects exist; otherwise a relabelled copy with the same dtype.
    """
    if not USE_SMART_MERGING:
        return instance_mask

    print(f"\n{'='*60}")
    print(f"SMART MERGING (Two-Pass Strategy)")
    print(f"{'='*60}")
    print(f"Merge threshold: {MERGE_THRESHOLD:,} pixels")
    print(f"Search distance: {MERGE_SEARCH_DISTANCE}px")

    # One pass over the image instead of one full-image comparison per object.
    object_areas = _object_areas(instance_mask)

    small_objects = {obj_id for obj_id, area in object_areas.items() if area < MERGE_THRESHOLD}
    large_objects = {obj_id for obj_id, area in object_areas.items() if area >= MERGE_THRESHOLD}

    print(f"\nüìä Initial classification:")
    print(f"  Large objects (‚â•{MERGE_THRESHOLD:,}px): {len(large_objects)}")
    print(f"  Small objects (<{MERGE_THRESHOLD:,}px): {len(small_objects)}")

    if not small_objects:
        print(f"\n‚úì No small objects to merge!")
        return instance_mask

    search_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                              (MERGE_SEARCH_DISTANCE, MERGE_SEARCH_DISTANCE))

    # PASS 1: Small -> Small
    print(f"\n{'‚îÄ'*60}")
    print(f"PASS 1: Merging Small ‚Üí Small")
    print(f"{'‚îÄ'*60}")

    merged_mask = instance_mask.copy()
    merge_map = {}  # absorbed id -> id that absorbed it
    small_to_small_merges = 0

    for small_id in sorted(small_objects, key=lambda x: object_areas[x], reverse=True):
        if small_id in merge_map:
            continue  # already absorbed by an earlier (larger) small object
        for neighbor_id in _neighbor_ids(merged_mask, small_id, search_kernel):
            if neighbor_id in small_objects and neighbor_id not in merge_map:
                merged_mask[merged_mask == neighbor_id] = small_id
                merge_map[neighbor_id] = small_id
                small_to_small_merges += 1

    print(f"‚úì Pass 1 complete: {small_to_small_merges} merges")

    object_areas_pass1 = _object_areas(merged_mask)
    small_objects_pass1 = {obj_id for obj_id, area in object_areas_pass1.items()
                           if area < MERGE_THRESHOLD}
    large_objects_pass1 = {obj_id for obj_id, area in object_areas_pass1.items()
                           if area >= MERGE_THRESHOLD}

    promoted = len(large_objects_pass1) - len(large_objects)

    print(f"\nüìä After Pass 1:")
    print(f"  Large objects: {len(large_objects_pass1)} (+{promoted} promoted)")
    print(f"  Small objects: {len(small_objects_pass1)} (remaining)")

    # PASS 2: Small -> Large
    print(f"\n{'‚îÄ'*60}")
    print(f"PASS 2: Merging Small ‚Üí Large")
    print(f"{'‚îÄ'*60}")

    small_to_large_merges = 0
    orphan_count = 0

    for small_id in small_objects_pass1:
        large_neighbors = [nid for nid in _neighbor_ids(merged_mask, small_id, search_kernel)
                           if nid in large_objects_pass1]
        if large_neighbors:
            # Ties go to the lowest ID (first maximum in ascending order).
            target_id = max(large_neighbors, key=lambda nid: object_areas_pass1[nid])
            merged_mask[merged_mask == small_id] = target_id
            small_to_large_merges += 1
        else:
            orphan_count += 1

    print(f"‚úì Pass 2 complete: {small_to_large_merges} merges, {orphan_count} orphans")

    before_count = len(object_areas)
    after_count = len(_object_areas(merged_mask))

    print(f"\n{'='*60}")
    print(f"MERGING SUMMARY")
    print(f"{'='*60}")
    print(f"  Objects before: {before_count}")
    print(f"  Objects after: {after_count}")
    print(f"  Eliminated: {before_count - after_count}")
    print(f"    ‚Ä¢ Small‚ÜíSmall: {small_to_small_merges}")
    print(f"    ‚Ä¢ Small‚ÜíLarge: {small_to_large_merges}")
    print(f"    ‚Ä¢ Orphans: {orphan_count}")

    return merged_mask

def calculate_shape_metrics(obj_mask):
    """Compute size and shape descriptors for one binary object mask.

    Only the largest external contour is measured; holes and smaller
    disconnected parts are ignored.

    Parameters
    ----------
    obj_mask : numpy.ndarray
        2-D uint8 mask, non-zero = object.

    Returns
    -------
    dict or None
        ``area`` and ``perimeter`` of the contour, bounding-box ``width`` and
        ``height``, plus the ratios ``compactness`` (4*pi*area/perimeter**2),
        ``solidity`` (area / hull area), ``aspect_ratio`` (long / short side
        of the axis-aligned bounding box), ``extent`` (area / bounding-box
        area) and ``convexity`` (hull perimeter / perimeter).
        ``None`` when no contour exists or its area or perimeter is 0.
    """
    # [-2] picks the contour list under both OpenCV 3.x, which returns
    # (image, contours, hierarchy), and 4.x, which returns (contours, hierarchy).
    contours = cv2.findContours(obj_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if len(contours) == 0:
        return None

    contour = max(contours, key=cv2.contourArea)
    area = cv2.contourArea(contour)
    perimeter = cv2.arcLength(contour, True)
    if area == 0 or perimeter == 0:
        return None

    _, _, w, h = cv2.boundingRect(contour)
    hull = cv2.convexHull(contour)
    hull_area = cv2.contourArea(hull)
    hull_perimeter = cv2.arcLength(hull, True)
    box_area = w * h

    # area and perimeter are known to be non-zero here.
    return {
        'area': area,
        'perimeter': perimeter,
        'width': w,
        'height': h,
        'compactness': (4 * np.pi * area) / (perimeter ** 2),
        'solidity': area / hull_area if hull_area > 0 else 0,
        'aspect_ratio': max(w, h) / min(w, h) if min(w, h) > 0 else 0,
        'extent': area / box_area if box_area > 0 else 0,
        'convexity': hull_perimeter / perimeter,
    }

def filter_objects_by_quality(instance_mask):
    """Drop objects that fail the size/shape criteria and renumber the rest.

    Objects are examined in ascending ID order. The first failing criterion
    decides the removal reason; criteria are checked in this order: area too
    small, area too large, compactness, solidity, aspect ratio, convexity,
    extent. Objects for which ``calculate_shape_metrics`` returns ``None`` are
    counted as ``too_small``. Survivors are relabelled 1..K in the order kept.

    Returns a new uint16 label image.
    """
    criteria = (
        ('too_small', lambda m: m['area'] < MIN_AREA),
        ('too_large', lambda m: m['area'] > MAX_AREA),
        ('low_compactness', lambda m: m['compactness'] < MIN_COMPACTNESS),
        ('low_solidity', lambda m: m['solidity'] < MIN_SOLIDITY),
        ('high_aspect_ratio', lambda m: m['aspect_ratio'] > MAX_ASPECT_RATIO),
        ('low_convexity', lambda m: m['convexity'] < MIN_CONVEXITY),
        ('low_extent', lambda m: m['extent'] < MIN_EXTENT),
    )
    removed_reasons = {name: 0 for name, _ in criteria}

    rule = '=' * 60
    print(f"\n{rule}")
    print(f"FILTERING OBJECTS")
    print(f"{rule}")
    print(f"Criteria:")
    print(f"  Area: {MIN_AREA:,} - {MAX_AREA:,} pixels")
    print(f"  Compactness: ‚â• {MIN_COMPACTNESS:.3f}")
    print(f"  Solidity: ‚â• {MIN_SOLIDITY:.2f}")
    print(f"  Aspect ratio: ‚â§ {MAX_ASPECT_RATIO:.1f}")
    print(f"  Convexity: ‚â• {MIN_CONVEXITY:.2f}")
    print(f"  Extent: ‚â• {MIN_EXTENT:.2f}")

    labels = np.unique(instance_mask)
    cleaned_mask = np.zeros_like(instance_mask, dtype=np.uint16)
    kept_objects = []

    for label in labels[labels > 0]:
        obj_mask = (instance_mask == label).astype(np.uint8)
        metrics = calculate_shape_metrics(obj_mask)
        if metrics is None:
            removed_reasons['too_small'] += 1
            continue
        failed = next((name for name, fails in criteria if fails(metrics)), None)
        if failed is not None:
            removed_reasons[failed] += 1
            continue
        kept_objects.append(metrics)
        cleaned_mask[obj_mask > 0] = len(kept_objects)

    total_removed = sum(removed_reasons.values())
    total_kept = len(kept_objects)

    print(f"\nüìä Results:")
    print(f"  ‚úì Kept: {total_kept} objects")
    print(f"  ‚úó Removed: {total_removed} objects")

    if total_removed > 0:
        print(f"\nüìã Removal reasons:")
        for name, count in removed_reasons.items():
            if count > 0:
                print(f"  ‚Ä¢ {name.replace('_', ' ').title()}: {count}")

    if kept_objects:
        kept_areas = [m['area'] for m in kept_objects]
        print(f"\nüìà Kept objects stats:")
        print(f"  Area: {min(kept_areas):,.0f} - {max(kept_areas):,.0f} px")

    return cleaned_mask

def extract_polygons_per_object(instance_mask, transform):
    """Vectorise the label image into one (multi)polygon per object.

    ``rasterio.features.shapes`` runs once over the whole label image and
    yields one polygon per connected region together with its label. Regions
    that share a label are combined with ``shapely.ops.unary_union``. An object
    made of several disconnected parts (e.g. fragments merged across a gap)
    therefore yields a MultiPolygon instead of only its first part.

    Parameters
    ----------
    instance_mask : numpy.ndarray
        2-D label image, 0 = background.
    transform : affine transform
        Pixel-to-map transform applied to polygon coordinates (as returned by
        rasterio's ``src.transform``).

    Returns
    -------
    list of dict
        One entry per non-zero label, in ascending label order, with keys
        ``polygon`` (shapely geometry), ``id`` (int label) and ``area`` (the
        label's pixel count).
    """
    from collections import defaultdict
    from shapely.ops import unary_union

    print(f"\n{'='*60}")
    print(f"EXTRACTING POLYGONS")
    print(f"{'='*60}")

    # int32 is one of the dtypes shapes() accepts, whatever the input dtype.
    labels = np.asarray(instance_mask).astype(np.int32)

    parts = defaultdict(list)
    for geom, value in shapes(labels, mask=labels > 0, transform=transform):
        parts[int(value)].append(shape(geom))

    ids, counts = np.unique(labels, return_counts=True)
    pixel_areas = dict(zip(ids.tolist(), counts.tolist()))

    polygons = []
    for inst_id in sorted(parts):
        geoms = parts[inst_id]
        poly = geoms[0] if len(geoms) == 1 else unary_union(geoms)
        polygons.append({
            'polygon': poly,
            'id': inst_id,
            'area': int(pixel_areas[inst_id])
        })

    print(f"‚úì Extracted {len(polygons)} polygons")

    if polygons:
        areas = [p['area'] for p in polygons]
        print(f"  Area range: {min(areas):,} - {max(areas):,} pixels")

    return polygons

def save_geojson(polygons, output_path, crs):
    """Write polygon records to a GeoJSON FeatureCollection file.

    Parameters
    ----------
    polygons : list of dict
        Records with keys ``polygon`` (geometry accepted by
        ``shapely.geometry.mapping``), ``id`` (written as ``properties.id``)
        and ``area`` (written as ``properties.area_pixels``).
    output_path : str or path-like
        Destination file; overwritten if it exists.
    crs : CRS or None
        CRS of the polygon coordinates. When truthy it is recorded as
        ``str(crs)`` in the legacy ``crs`` member. When missing, the member is
        omitted rather than claiming EPSG:4326: without a CRS the coordinates
        may well be pixel units.
    """
    features = [
        {
            "type": "Feature",
            "properties": {
                "id": poly_data['id'],
                "area_pixels": poly_data['area']
            },
            "geometry": mapping(poly_data['polygon'])
        }
        for poly_data in polygons
    ]

    geojson = {"type": "FeatureCollection"}
    if crs:
        geojson["crs"] = {
            "type": "name",
            "properties": {"name": str(crs)}
        }
    geojson["features"] = features

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(geojson, f, indent=2)

    print(f"\n‚úì GeoJSON saved: {output_path}")
    print(f"  ‚Üí {len(features)} polygons")

def extract_boundaries(instance_mask, thickness=4):
    """Return a 0/1 uint8 map of object outlines, thickened to ``thickness`` px.

    A pixel is a boundary pixel when its 3x3 neighbourhood holds more than one
    label value (object/background, or two different objects). This equals
    the union of per-object 3x3 morphological gradients. It is computed in a
    single pass instead of one full-image pass per object.

    Parameters
    ----------
    instance_mask : numpy.ndarray
        2-D label image, 0 = background.
    thickness : int
        Side of the square kernel used to dilate the outlines; values <= 1
        leave them at their raw width.

    Returns
    -------
    numpy.ndarray
        uint8 array with the input's shape: 1 on boundaries, 0 elsewhere.
    """
    # mode='nearest' replicates edge pixels, so the image border never counts
    # as an edge by itself (as with OpenCV's default morphology border).
    local_max = ndimage.maximum_filter(instance_mask, size=3, mode='nearest')
    local_min = ndimage.minimum_filter(instance_mask, size=3, mode='nearest')
    boundaries = (local_max != local_min).astype(np.uint8)

    if thickness > 1:
        thick_kernel = np.ones((thickness, thickness), np.uint8)
        boundaries = cv2.dilate(boundaries, thick_kernel, iterations=1)

    return boundaries

def visualize_results(instance_mask, boundaries, polygons, save_path):
    """Save (300 dpi) and display a 2x2 overview figure of the pipeline output.

    The four panels are:

    * the label image in nipy_spectral colours;
    * the binary foreground with boundary pixels painted red;
    * the binary foreground alone;
    * the boundary map in white on black.

    ``polygons`` is only used for its length in a panel title.
    """
    foreground = (instance_mask > 0).astype(np.uint8)
    n_objects = len(np.unique(instance_mask)) - 1
    edge = boundaries > 0

    overlay = np.repeat(foreground[..., np.newaxis], 3, axis=-1).astype(float)
    overlay[edge] = [1.0, 0, 0]

    edge_img = np.zeros_like(instance_mask, dtype=np.uint8)
    edge_img[edge] = 255

    panels = (
        ((0, 0), instance_mask, "nipy_spectral",
         f"Instance Mask (Smart Merge + Dilation)\n({n_objects} objects)"),
        ((0, 1), overlay, None,
         f"Objects with Boundaries\n({len(polygons)} polygons)"),
        ((1, 0), foreground, "gray", f"Binary Mask"),
        ((1, 1), edge_img, "gray", f"Boundaries ({BOUNDARY_THICKNESS}px)"),
    )

    _, axs = plt.subplots(2, 2, figsize=(18, 18))
    for pos, image, cmap, title in panels:
        ax = axs[pos]
        if cmap is None:
            ax.imshow(image, interpolation='nearest')
        else:
            ax.imshow(image, cmap=cmap, interpolation='nearest')
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.axis("off")

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"\n‚úì Visualization saved: {save_path}")
    plt.show()

def main():
    """Run the full mask-cleaning pipeline on ``mask_path``.

    Stages, in order:

    1. load the mask;
    2. watershed (optional);
    3. pre-merge closing and dilation;
    4. two-pass small-object merging;
    5. post-merge erosion;
    6. size/shape filtering;
    7. polygon and boundary extraction;
    8. write the cleaned uint16 mask, the boundary raster (if
       CREATE_BOUNDARY_FILE) and the GeoJSON (if SAVE_GEOJSON);
    9. save and show a visualisation PNG.

    Returns
    -------
    tuple
        ``(instance_mask, boundaries, polygons)`` from the final stages.

    Raises
    ------
    FileNotFoundError
        If ``mask_path`` does not exist.
    """
    def banner(*titles):
        # Section header: a rule, the title line(s), a rule.
        rule = '=' * 60
        print(f"\n{rule}")
        for title in titles:
            print(title)
        print(f"{rule}")

    if not Path(mask_path).exists():
        raise FileNotFoundError(f"Mask file not found: {mask_path}")

    banner("SAM MASK ‚Üí CLEAN POLYGONS", "WITH DILATION + TWO-PASS MERGING")
    print(f"\nInput: {mask_path}")

    mask, profile, transform, crs = load_sam_mask(mask_path)

    # Raster clean-up stages, each taking and returning a label image.
    for stage in (apply_watershed_separation,    # split touching objects (optional)
                  apply_pre_merge_morphology,    # fill holes, dilate to close gaps
                  merge_small_objects,           # two-pass small-object merging
                  apply_post_merge_morphology,   # erode back towards original size
                  filter_objects_by_quality):    # drop orphans and bad shapes
        mask = stage(mask)

    polygons = extract_polygons_per_object(mask, transform)

    banner("EXTRACTING BOUNDARIES")
    boundaries = extract_boundaries(mask, BOUNDARY_THICKNESS)
    print(f"‚úì Boundary pixels: {np.sum(boundaries > 0):,}")

    profile.update({'count': 1, 'dtype': 'uint16', 'compress': 'lzw', 'nodata': 0})

    banner("SAVING FILES")

    with rasterio.open(clean_mask_path, "w", **profile) as dst:
        dst.write(mask.astype(np.uint16), 1)
    print(f"‚úì Instance mask: {clean_mask_path}")

    if CREATE_BOUNDARY_FILE:
        boundary_profile = profile.copy()
        boundary_profile['dtype'] = 'uint8'
        with rasterio.open(boundaries_path, "w", **boundary_profile) as dst:
            dst.write(boundaries, 1)
        print(f"‚úì Boundaries: {boundaries_path}")

    if SAVE_GEOJSON:
        save_geojson(polygons, polygons_geojson, crs)

    banner("GENERATING VISUALIZATION")
    visualize_results(mask, boundaries, polygons, "polygons_visualization.png")

    banner("‚úì COMPLETE!")
    summary = (
        f"\nOutput files:",
        f"  1. {clean_mask_path} - Filtered instance mask",
        f"  2. {boundaries_path} - Boundaries",
        f"  3. {polygons_geojson} - Polygons (GeoJSON)",
        f"  4. polygons_visualization.png - Visualization",
        f"\n‚úì Processing:",
        f"  ‚Ä¢ Pre-merge dilation ({MORPH_DILATE_KERNEL}px) - closed gaps",
        f"  ‚Ä¢ Two-pass merging - consolidated fragments",
        f"  ‚Ä¢ Post-merge erosion ({MORPH_ERODE_KERNEL}px) - restored size",
        f"  ‚Ä¢ Quality filtering - removed artifacts",
    )
    for line in summary:
        print(line)

    return mask, boundaries, polygons

if __name__ == "__main__":
    import traceback

    # Run the pipeline; on failure report the error and full traceback
    # instead of propagating it.
    try:
        instance_mask, boundaries, polygons = main()
    except Exception as exc:
        print(f"\nERROR: {exc}")
        traceback.print_exc()

# In [10]:
import rasterio
import numpy as np
import cv2
from rasterio.features import shapes
from shapely.geometry import shape, mapping
from scipy import ndimage
import matplotlib.pyplot as plt
from pathlib import Path
import json

# --- Configuration ---
# NOTE(review): mask_path is commented out here, so any later use of it in
# this cell relies on a value defined earlier in the session (e.g. a previous
# cell) — confirm which tile is meant to be processed.
# mask_path = "outputs/tile_22528_36864_masks.tif"
clean_mask_path = "masks_clean.tif"
boundaries_path = "masks_boundaries.tif"
polygons_geojson = "polygons.geojson"

# --- Watershed Separation (for touching objects) ---
# Seeds are foreground pixels farther from the background than
# WATERSHED_THRESHOLD x the largest such distance. Higher values keep only the
# thickest cores, so objects joined by a thin neck are more likely to get
# separate seeds; low values (like 0.1) let neighbouring seeds join up.
USE_WATERSHED_SEPARATION = True  # Separate touching objects using watershed
WATERSHED_THRESHOLD = 0.1        # Fraction of max distance (0-1); higher = more splitting

# --- Morphological Processing ---
USE_MORPHOLOGY = False        # Apply morphological operations to improve shapes
MORPH_CLOSE_KERNEL = 3       # Kernel size for closing (fill small holes)
MORPH_CLOSE_ITERATIONS = 1   # How many times to apply closing
MORPH_DILATE_KERNEL = 0      # Kernel size for dilation (expand objects)
MORPH_DILATE_ITERATIONS = 0  # How many times to dilate
MORPH_ERODE_KERNEL = 0       # Kernel size for erosion (shrink back slightly)
MORPH_ERODE_ITERATIONS = 0   # How many times to erode (use to smooth after dilation)

# --- Filtering Parameters ---
MIN_AREA = 10000              # Minimum area in pixels (increase to remove smaller blobs)
MAX_AREA = 20000000          # Maximum area (remove if too large)

# Shape quality filters
MIN_COMPACTNESS = 0.01       # Remove irregular shapes (0-1, higher = more compact)
MIN_SOLIDITY = 0.05          # Remove fragmented shapes (0-1, area/convex_hull)
MAX_ASPECT_RATIO = 100.0      # Remove elongated shapes (long/short side ratio)
MIN_CONVEXITY = 0.5         # Remove concave/irregular shapes (0-1)
MIN_EXTENT = 0.10            # Remove thin/sparse shapes (area/bounding_box)

BOUNDARY_THICKNESS = 5
CREATE_BOUNDARY_FILE = True
SAVE_GEOJSON = True

def load_sam_mask(mask_path):
    """Read a SAM mask raster and return it as a single-band instance mask.

    Three input layouts are handled:

    * multi-band: every band is one object; band ``i`` (0-based) becomes
      instance ID ``i + 1``. Where bands overlap, the lowest band index wins.
    * single band with several distinct non-zero values: used unchanged.
    * single band with at most one non-zero value (binary mask): split into
      instances by ``scipy.ndimage.label`` connected-component labelling.

    Parameters
    ----------
    mask_path : str or path-like
        Raster readable by ``rasterio.open``.

    Returns
    -------
    tuple
        ``(instance_mask, profile, transform, crs)``; ``profile`` is a copy
        of the source profile, ``transform``/``crs`` the source georeferencing.
    """
    with rasterio.open(mask_path) as src:
        data = src.read()
        profile = src.profile.copy()
        transform, crs = src.transform, src.crs

    rule = '=' * 60
    band_total = data.shape[0] if data.ndim == 3 else 1
    print(f"\n{rule}")
    print(f"MASK ANALYSIS")
    print(f"{rule}")
    print(f"Shape: {data.shape}")
    print(f"Dtype: {data.dtype}")
    print(f"Bands: {band_total}")

    if data.ndim == 3 and data.shape[0] > 1:
        n_bands = data.shape[0]
        print(f"\n‚úì Multi-band mask: {n_bands} bands")
        instance_mask = np.zeros(data.shape[1:], dtype=np.uint16)
        print(f"\nAnalyzing bands:")
        used = 0
        for idx, band in enumerate(data):
            hit = band > 0
            n_hit = np.count_nonzero(hit)
            if not n_hit:
                continue
            used += 1
            # Only claim pixels no earlier band has taken.
            instance_mask[hit & (instance_mask == 0)] = idx + 1
            if idx < 10:
                print(f"  Band {idx + 1}: {n_hit:,} pixels")
        if n_bands > 10:
            print(f"  ... and {n_bands - 10} more bands")
        print(f"\n‚úì Converted {used} bands")
    else:
        instance_mask = data.squeeze()
        present = np.unique(instance_mask)
        present = present[present > 0]
        if present.size > 1:
            print(f"\n‚úì Single-band instance mask: {present.size} objects")
        else:
            print(f"\n‚ö† Binary mask - separating connected regions")
            instance_mask, n_regions = ndimage.label(instance_mask > 0)
            print(f"‚Üí Found {n_regions} connected regions")

    num_instances = len(np.unique(instance_mask)) - 1
    total_pixels = np.sum(instance_mask > 0)
    print(f"\n{rule}")
    print(f"RESULT: {num_instances} objects, {total_pixels:,} pixels")
    print(f"{rule}")

    return instance_mask, profile, transform, crs

def apply_watershed_separation(instance_mask):
    """
    Apply a marker-controlled watershed to separate touching objects.

    Useful when SAM returns objects that touch or overlap. Seeds are the
    connected regions of foreground pixels whose distance to the background
    exceeds WATERSHED_THRESHOLD times the largest such distance. A foreground
    component that gets no seed is seeded as a whole, so it is kept.
    Watershed lines between regions become background (0). Original instance
    IDs are not preserved: the result is relabelled 1..N.

    Parameters
    ----------
    instance_mask : numpy.ndarray
        2-D label image (0 = background) where objects might be touching

    Returns
    -------
    numpy.ndarray : ``instance_mask`` unchanged when USE_WATERSHED_SEPARATION
        is False, otherwise a new uint16 instance mask with separated objects
    """
    if not USE_WATERSHED_SEPARATION:
        return instance_mask

    print(f"\n{'='*60}")
    print(f"WATERSHED SEPARATION (Separate Touching Objects)")
    print(f"{'='*60}")

    # Create binary mask of all objects
    binary_mask = (instance_mask > 0).astype(np.uint8)

    # Seeds: pixels deeper inside the foreground than a fraction of the
    # deepest point in the tile.
    dist = cv2.distanceTransform(binary_mask, cv2.DIST_L2, 3)
    peaks = (dist > WATERSHED_THRESHOLD * dist.max()).astype(np.uint8)
    _, markers = cv2.connectedComponents(peaks)

    # Grow each seed by one pixel. cv2.dilate has no int32 support, so go
    # through float32 (exact below 2**24); a uint8 round-trip would wrap
    # labels above 255 onto other seeds.
    kernel = np.ones((2, 2), np.uint8)
    markers = cv2.dilate(markers.astype(np.float32), kernel, iterations=1).astype(np.int32)
    markers[binary_mask == 0] = 0

    # Seed every foreground component that received no peak with the whole
    # component, so it survives instead of being flooded by the background.
    n_components, components = cv2.connectedComponents(binary_mask)
    has_seed = np.zeros(n_components, dtype=bool)
    has_seed[0] = True  # component 0 is the background
    has_seed[np.unique(components[markers > 0])] = True
    unseeded = np.flatnonzero(~has_seed)
    if unseeded.size:
        lut = np.zeros(n_components, dtype=np.int32)
        lut[unseeded] = markers.max() + 1 + np.arange(unseeded.size, dtype=np.int32)
        extra = lut[components]
        markers = np.where(extra > 0, extra, markers).astype(np.int32)

    # OpenCV treats marker 0 as "unknown, to be flooded", so the background
    # needs a marker of its own or every object would grow into it.
    background_label = int(markers.max()) + 1
    markers[binary_mask == 0] = background_label

    # Apply watershed
    binary_3ch = cv2.cvtColor(binary_mask * 255, cv2.COLOR_GRAY2BGR)
    markers = cv2.watershed(binary_3ch, markers)

    # -1 marks watershed lines (OpenCV also sets the 1-pixel image frame to
    # -1); those and the background region become 0. Then relabel to 1..N.
    regions = np.where((markers > 0) & (markers != background_label), markers, 0)
    ids, compact = np.unique(regions, return_inverse=True)
    compact = compact.reshape(regions.shape)
    if ids[0] != 0:  # no background pixel at all: keep 0 free for background
        compact = compact + 1
    separated_mask = compact.astype(np.uint16)

    # Count objects before and after
    before_count = len(np.unique(instance_mask)) - 1
    after_count = len(np.unique(separated_mask)) - 1

    print(f"‚úì Watershed separation applied:")
    print(f"  Threshold: {WATERSHED_THRESHOLD}")
    print(f"  Objects before: {before_count}")
    print(f"  Objects after: {after_count}")
    print(f"  New objects separated: {after_count - before_count}")

    return separated_mask

def apply_morphological_operations(instance_mask):
    """
    Smooth object shapes with per-object morphology.

    Every positive label is processed on its own, in ascending label order.
    There are up to three steps, all using elliptical structuring elements;
    a step is skipped when its kernel size or its iteration count is 0:

    1. Closing  (MORPH_CLOSE_KERNEL / MORPH_CLOSE_ITERATIONS): fills small
       holes and gaps.
    2. Dilation (MORPH_DILATE_KERNEL / MORPH_DILATE_ITERATIONS): grows the
       object.
    3. Erosion  (MORPH_ERODE_KERNEL / MORPH_ERODE_ITERATIONS): shrinks it
       back. With the same settings as step 2, steps 2+3 together amount to
       one more closing.

    Results are painted into a fresh mask, so where two processed objects
    overlap, the higher label wins.

    Each object is processed inside a padded crop of its bounding box rather
    than on the full image. The padding is large enough that the crop border
    can never influence the result. The output is therefore identical to
    full-image processing, but the cost scales with object size instead of
    (number of objects x image size).

    Parameters
    ----------
    instance_mask : numpy.ndarray
        2-D label image; values <= 0 are background. Labels are assumed to be
        integer-valued (non-integer arrays are cast to int64).

    Returns
    -------
    numpy.ndarray
        uint16 label image, or ``instance_mask`` itself when USE_MORPHOLOGY
        is False.
    """
    if not USE_MORPHOLOGY:
        return instance_mask

    print(f"\n{'='*60}")
    print(f"MORPHOLOGICAL PROCESSING")
    print(f"{'='*60}")

    do_close = MORPH_CLOSE_KERNEL > 0 and MORPH_CLOSE_ITERATIONS > 0
    do_dilate = MORPH_DILATE_KERNEL > 0 and MORPH_DILATE_ITERATIONS > 0
    do_erode = MORPH_ERODE_KERNEL > 0 and MORPH_ERODE_ITERATIONS > 0

    def _ellipse(size):
        """Elliptical size x size structuring element."""
        return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))

    # Structuring elements do not depend on the object: build them once.
    close_kernel = _ellipse(MORPH_CLOSE_KERNEL) if do_close else None
    dilate_kernel = _ellipse(MORPH_DILATE_KERNEL) if do_dilate else None
    erode_kernel = _ellipse(MORPH_ERODE_KERNEL) if do_erode else None

    # A k x k element (anchored at its centre) reaches at most k // 2 pixels
    # per iteration.
    close_reach = (MORPH_CLOSE_KERNEL // 2) * MORPH_CLOSE_ITERATIONS if do_close else 0
    dilate_reach = (MORPH_DILATE_KERNEL // 2) * MORPH_DILATE_ITERATIONS if do_dilate else 0
    erode_reach = (MORPH_ERODE_KERNEL // 2) * MORPH_ERODE_ITERATIONS if do_erode else 0
    # Foreground never extends more than close_reach + dilate_reach beyond
    # the bounding box. An erosion reads at most max(close_reach, erode_reach)
    # further out. With this padding, every read stays inside the crop,
    # except where the crop is clipped to the real image edge -- and that is
    # the same border cv2 sees when processing the full image.
    pad = close_reach + dilate_reach + max(close_reach, erode_reach) + 1

    labels = np.asarray(instance_mask)
    if not np.issubdtype(labels.dtype, np.integer):
        # ndimage.find_objects needs integer labels.
        labels = labels.astype(np.int64)

    improved_mask = np.zeros_like(instance_mask, dtype=np.uint16)
    # One bounding box per label 1..max label (None for absent labels).
    boxes = ndimage.find_objects(labels) if labels.size else []

    for label, bbox in enumerate(boxes, start=1):
        if bbox is None:
            continue
        window = tuple(
            slice(max(sl.start - pad, 0), min(sl.stop + pad, extent))
            for sl, extent in zip(bbox, labels.shape)
        )
        obj_mask = (labels[window] == label).astype(np.uint8)

        # 1. Morphological closing (fill small holes and gaps)
        if do_close:
            obj_mask = cv2.morphologyEx(obj_mask, cv2.MORPH_CLOSE, close_kernel,
                                        iterations=MORPH_CLOSE_ITERATIONS)
        # 2. Dilation (expand object)
        if do_dilate:
            obj_mask = cv2.dilate(obj_mask, dilate_kernel, iterations=MORPH_DILATE_ITERATIONS)
        # 3. Erosion (shrink back)
        if do_erode:
            obj_mask = cv2.erode(obj_mask, erode_kernel, iterations=MORPH_ERODE_ITERATIONS)

        # improved_mask[window] is a view, so this writes into improved_mask.
        improved_mask[window][obj_mask > 0] = label

    # Report only the steps that actually ran.
    print(f"‚úì Applied morphological operations:")
    if do_close:
        print(f"  ‚Ä¢ Closing: kernel={MORPH_CLOSE_KERNEL}, iterations={MORPH_CLOSE_ITERATIONS}")
    if do_dilate:
        print(f"  ‚Ä¢ Dilation: kernel={MORPH_DILATE_KERNEL}, iterations={MORPH_DILATE_ITERATIONS}")
    if do_erode:
        print(f"  ‚Ä¢ Erosion: kernel={MORPH_ERODE_KERNEL}, iterations={MORPH_ERODE_ITERATIONS}")

    # Show before/after stats
    before_pixels = np.sum(instance_mask > 0)
    after_pixels = np.sum(improved_mask > 0)
    change = after_pixels - before_pixels
    change_pct = (change / before_pixels * 100) if before_pixels > 0 else 0

    print(f"\nüìä Pixel changes:")
    print(f"  Before: {before_pixels:,} pixels")
    print(f"  After: {after_pixels:,} pixels")
    print(f"  Change: {change:+,} pixels ({change_pct:+.1f}%)")

    return improved_mask

def calculate_shape_metrics(obj_mask):
    """
    Measure the largest outer contour of a binary object mask.

    Parameters
    ----------
    obj_mask : numpy.ndarray
        Single-channel 8-bit mask (non-zero = object), as accepted by
        ``cv2.findContours``.

    Returns
    -------
    dict or None
        Keys ``area``, ``perimeter``, ``width``, ``height`` (axis-aligned
        bounding box), ``compactness`` (4*pi*area / perimeter**2, 1.0 for a
        circle), ``solidity`` (area / convex-hull area), ``aspect_ratio``
        (longer / shorter bounding-box side), ``extent`` (area / bounding-box
        area) and ``convexity`` (hull perimeter / perimeter).
        Returns None when no contour is found, or when the largest contour
        has zero area or zero perimeter.
    """
    found, _hierarchy = cv2.findContours(obj_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not found:
        return None

    # Only the largest external contour is measured; smaller parts are ignored.
    outline = max(found, key=cv2.contourArea)
    area = cv2.contourArea(outline)
    perimeter = cv2.arcLength(outline, True)
    if not (area and perimeter):
        return None

    _x, _y, box_w, box_h = cv2.boundingRect(outline)

    hull = cv2.convexHull(outline)
    hull_area = cv2.contourArea(hull)
    hull_perimeter = cv2.arcLength(hull, True)

    short_side, long_side = min(box_w, box_h), max(box_w, box_h)
    bbox_area = box_w * box_h

    return {
        'area': area,
        'perimeter': perimeter,
        'width': box_w,
        'height': box_h,
        'compactness': (4 * np.pi * area) / (perimeter ** 2) if perimeter > 0 else 0,
        'solidity': area / hull_area if hull_area > 0 else 0,
        'aspect_ratio': long_side / short_side if short_side > 0 else 0,
        'extent': area / bbox_area if bbox_area > 0 else 0,
        'convexity': hull_perimeter / perimeter if perimeter > 0 else 0,
    }

def filter_objects_by_quality(instance_mask):
    """
    Keep only objects that pass the size and shape-quality thresholds.

    Each positive label is measured with ``calculate_shape_metrics``. It is
    then tested against MIN_AREA, MAX_AREA, MIN_COMPACTNESS, MIN_SOLIDITY,
    MAX_ASPECT_RATIO, MIN_CONVEXITY and MIN_EXTENT, in that order. The first
    failing test is recorded as the removal reason; objects with no
    measurable contour count as 'too_small'. Survivors are relabelled
    consecutively from 1, in ascending order of their original label.

    Each object is measured on a crop of its bounding box plus a 1-pixel
    zero margin rather than on a full-image mask. The metrics do not depend
    on position, so the results are unchanged, but the cost no longer grows
    with (number of objects x image size).

    Parameters
    ----------
    instance_mask : numpy.ndarray
        2-D label image; values <= 0 are background. Labels are assumed to be
        integer-valued (non-integer arrays are cast to int64).

    Returns
    -------
    numpy.ndarray
        uint16 label image containing only the kept objects.
    """
    labels = np.asarray(instance_mask)
    if not np.issubdtype(labels.dtype, np.integer):
        # ndimage.find_objects needs integer labels.
        labels = labels.astype(np.int64)

    cleaned_mask = np.zeros_like(instance_mask, dtype=np.uint16)
    new_id = 1

    # Statistics
    removed_reasons = {
        'too_small': 0,
        'too_large': 0,
        'low_compactness': 0,
        'low_solidity': 0,
        'high_aspect_ratio': 0,
        'low_convexity': 0,
        'low_extent': 0
    }

    print(f"\n{'='*60}")
    print(f"FILTERING OBJECTS")
    print(f"{'='*60}")
    print(f"Criteria:")
    print(f"  Area: {MIN_AREA:,} - {MAX_AREA:,} pixels")
    # ':g' prints small thresholds faithfully (0.005 was shown as 0.01 with '.2f').
    print(f"  Compactness: ‚â• {MIN_COMPACTNESS:g}")
    print(f"  Solidity: ‚â• {MIN_SOLIDITY:g}")
    print(f"  Aspect ratio: ‚â§ {MAX_ASPECT_RATIO:.1f}")
    print(f"  Convexity: ‚â• {MIN_CONVEXITY:g}")
    print(f"  Extent: ‚â• {MIN_EXTENT:g}")

    kept_objects = []
    # One bounding box per label 1..max label (None for absent labels).
    boxes = ndimage.find_objects(labels) if labels.size else []

    for label, bbox in enumerate(boxes, start=1):
        if bbox is None:
            continue
        # The 1-pixel zero margin (where the image allows) keeps contour
        # tracing identical to tracing on the full image.
        window = tuple(
            slice(max(sl.start - 1, 0), min(sl.stop + 1, extent))
            for sl, extent in zip(bbox, labels.shape)
        )
        obj_mask = (labels[window] == label).astype(np.uint8)

        metrics = calculate_shape_metrics(obj_mask)
        if metrics is None:
            removed_reasons['too_small'] += 1
            continue

        # First failing criterion wins; None means the object is kept.
        reason = None
        if metrics['area'] < MIN_AREA:
            reason = 'too_small'
        elif metrics['area'] > MAX_AREA:
            reason = 'too_large'
        elif metrics['compactness'] < MIN_COMPACTNESS:  # irregular blobs
            reason = 'low_compactness'
        elif metrics['solidity'] < MIN_SOLIDITY:  # fragmented shapes
            reason = 'low_solidity'
        elif metrics['aspect_ratio'] > MAX_ASPECT_RATIO:  # elongated shapes
            reason = 'high_aspect_ratio'
        elif metrics['convexity'] < MIN_CONVEXITY:  # very concave shapes
            reason = 'low_convexity'
        elif metrics['extent'] < MIN_EXTENT:  # sparse/thin shapes
            reason = 'low_extent'

        if reason is None:
            # cleaned_mask[window] is a view, so this writes into cleaned_mask.
            cleaned_mask[window][obj_mask > 0] = new_id
            kept_objects.append(metrics)
            new_id += 1
        else:
            removed_reasons[reason] += 1

    total_removed = sum(removed_reasons.values())
    total_kept = new_id - 1

    print(f"\nüìä Results:")
    print(f"  ‚úì Kept: {total_kept} objects")
    print(f"  ‚úó Removed: {total_removed} objects")

    if total_removed > 0:
        print(f"\nüìã Removal reasons:")
        for reason, count in removed_reasons.items():
            if count > 0:
                print(f"  ‚Ä¢ {reason.replace('_', ' ').title()}: {count}")

    if kept_objects:
        areas = [m['area'] for m in kept_objects]
        compactness = [m['compactness'] for m in kept_objects]

        print(f"\nüìà Kept objects stats:")
        print(f"  Area: {min(areas):,.0f} - {max(areas):,.0f} px (avg: {np.mean(areas):,.0f})")
        print(f"  Compactness: {min(compactness):.3f} - {max(compactness):.3f} (avg: {np.mean(compactness):.3f})")

    return cleaned_mask

def extract_polygons_per_object(instance_mask, transform):
    """
    Extract ONE polygon per object.

    The whole label image is polygonised in a single pass with
    ``rasterio.features.shapes``, which yields 4-connected regions of equal
    value; only positive labels are polygonised. When an object consists of
    several disconnected parts, its largest part (by polygon area) is kept
    as the representative polygon. Previously this was whichever part came
    first in scan order.

    Parameters
    ----------
    instance_mask : numpy.ndarray
        2-D label image; values <= 0 are background. Labels are assumed to be
        integer-valued.
    transform : affine.Affine
        Pixel-to-map transform passed to ``shapes``.

    Returns
    -------
    list of dict
        ``{'polygon': shapely geometry, 'id': int, 'area': int}`` in
        ascending id order. ``'area'`` is the object's total pixel count,
        including any smaller parts not represented by the polygon.
    """
    ids, counts = np.unique(instance_mask, return_counts=True)
    positive = ids > 0
    ids, counts = ids[positive], counts[positive]

    print(f"\n{'='*60}")
    print(f"EXTRACTING POLYGONS")
    print(f"{'='*60}")

    labels = np.asarray(instance_mask)
    if labels.dtype not in (np.uint8, np.uint16, np.int16, np.int32):
        # shapes() accepts only a few raster dtypes; uint16 (the pipeline's
        # normal label dtype) passes through untouched.
        labels = labels.astype(np.int32)

    # Largest part seen so far for each label.
    largest = {}
    for geom, value in shapes(labels, mask=labels > 0, transform=transform):
        inst_id = int(value)
        poly = shape(geom)
        best = largest.get(inst_id)
        if best is None or poly.area > best.area:
            largest[inst_id] = poly

    polygons = [
        {'polygon': largest[int(inst_id)], 'id': int(inst_id), 'area': int(area)}
        for inst_id, area in zip(ids, counts)
        if int(inst_id) in largest
    ]

    print(f"‚úì Extracted {len(polygons)} polygons")

    if polygons:
        areas = [p['area'] for p in polygons]
        print(f"\nPolygon Statistics:")
        print(f"  Count: {len(polygons)}")
        print(f"  Area range: {min(areas):,} - {max(areas):,} pixels")
        print(f"  Average: {np.mean(areas):,.0f} pixels")

    return polygons

def save_geojson(polygons, output_path, crs):
    """
    Save polygons to a GeoJSON FeatureCollection.

    Parameters
    ----------
    polygons : list of dict
        Items with keys ``'polygon'`` (a geometry accepted by
        ``shapely.geometry.mapping``), ``'id'`` and ``'area'``. Each item
        becomes one Feature with properties ``id`` and ``area_pixels``.
    output_path : str or path-like
        Destination file; overwritten if it exists.
    crs : object or None
        CRS of the polygon coordinates, written via ``str(crs)`` into the
        legacy ``crs`` member. When falsy, the ``crs`` member is omitted.
        Guessing EPSG:4326 here would mislabel coordinates whose reference
        system is unknown (e.g. plain pixel coordinates) as lon/lat.
    """
    features = [
        {
            "type": "Feature",
            "properties": {
                "id": poly_data['id'],
                "area_pixels": poly_data['area']
            },
            "geometry": mapping(poly_data['polygon'])
        }
        for poly_data in polygons
    ]

    geojson = {"type": "FeatureCollection"}
    if crs:
        geojson["crs"] = {
            "type": "name",
            "properties": {"name": str(crs)}
        }
    geojson["features"] = features

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(geojson, f, indent=2)

    print(f"\n‚úì GeoJSON saved: {output_path}")
    print(f"  ‚Üí {len(features)} polygons")

def extract_boundaries(instance_mask, thickness=4):
    """
    Extract a binary boundary raster outlining every object.

    A pixel is marked as boundary when its 3x3 neighbourhood (clipped to the
    image) contains at least one positive label but is not uniformly one
    label. This equals the union of per-object 3x3 morphological gradients
    (dilation minus erosion, with OpenCV's default border handling). It is
    computed in one vectorised pass instead of one full-image pass per
    object. Edges between touching objects are included.

    Parameters
    ----------
    instance_mask : numpy.ndarray
        2-D label image; values <= 0 are background.
    thickness : int, default 4
        When greater than 1, the boundary raster is additionally dilated with
        a ``thickness x thickness`` square kernel.

    Returns
    -------
    numpy.ndarray
        uint8 array shaped like ``instance_mask``: 1 on boundaries, 0 elsewhere.
    """
    labels = np.asarray(instance_mask)
    if labels.dtype == bool:
        labels = labels.astype(np.uint8)

    # mode='nearest' only replicates in-image values, so both filters see
    # exactly the in-image 3x3 neighbourhood. That matches cv2's default
    # border handling for dilate (outside ignored) and erode (outside
    # treated as foreground).
    neighbourhood_max = ndimage.maximum_filter(labels, size=3, mode='nearest')
    neighbourhood_min = ndimage.minimum_filter(labels, size=3, mode='nearest')
    boundaries = ((neighbourhood_max > 0)
                  & (neighbourhood_max != neighbourhood_min)).astype(np.uint8)

    if thickness > 1:
        thick_kernel = np.ones((thickness, thickness), np.uint8)
        boundaries = cv2.dilate(boundaries, thick_kernel, iterations=1)

    return boundaries

def visualize_results(instance_mask, boundaries, polygons, save_path):
    """
    Save and show a 2x2 summary figure of the pipeline output.

    Panels:
    1. colour-coded instance mask;
    2. binary mask with boundaries in red;
    3. binary mask;
    4. boundaries only.

    The figure is written to ``save_path`` at 300 dpi, shown, and then
    closed, so repeated calls (e.g. one per image in a batch loop) do not
    accumulate open figures.

    Parameters
    ----------
    instance_mask : numpy.ndarray
        2-D label image; values <= 0 are background.
    boundaries : numpy.ndarray
        Array of the same shape; non-zero marks boundary pixels.
    polygons : sequence
        Only its length is used (in a panel title).
    save_path : str or path-like
        Output image path.
    """
    binary_mask = (instance_mask > 0).astype(np.uint8)
    # Count positive labels directly; len(unique) - 1 undercounts when the
    # mask contains no background pixel.
    num_objects = int(np.count_nonzero(np.unique(instance_mask) > 0))

    fig, axs = plt.subplots(2, 2, figsize=(18, 18))

    # 1. Instance mask (color-coded)
    axs[0, 0].imshow(instance_mask, cmap="nipy_spectral", interpolation='nearest')
    processing_text = []
    if USE_WATERSHED_SEPARATION:
        processing_text.append("Watershed")
    if USE_MORPHOLOGY:
        processing_text.append("Morphology")
    title_suffix = f" ({' + '.join(processing_text)})" if processing_text else ""
    axs[0, 0].set_title(f"Instance Mask{title_suffix}\n({num_objects} high-quality objects)",
                        fontsize=14, fontweight='bold')
    axs[0, 0].axis("off")

    # 2. With boundaries overlay
    rgb = np.stack([binary_mask, binary_mask, binary_mask], axis=-1).astype(float)
    rgb[boundaries > 0] = [1.0, 0, 0]

    axs[0, 1].imshow(rgb, interpolation='nearest')
    axs[0, 1].set_title(f"Objects with Boundaries\n({len(polygons)} polygons, boundaries in RED)",
                        fontsize=14, fontweight='bold', color='darkgreen')
    axs[0, 1].axis("off")

    # 3. Binary mask
    axs[1, 0].imshow(binary_mask, cmap="gray", interpolation='nearest')
    axs[1, 0].set_title(f"Binary Mask\n(All filtered objects)",
                        fontsize=14, fontweight='bold')
    axs[1, 0].axis("off")

    # 4. Boundaries only
    boundary_display = np.zeros_like(instance_mask, dtype=np.uint8)
    boundary_display[boundaries > 0] = 255

    axs[1, 1].imshow(boundary_display, cmap="gray", interpolation='nearest')
    axs[1, 1].set_title(f"Boundaries Only\n({BOUNDARY_THICKNESS}px thickness)",
                        fontsize=14, fontweight='bold')
    axs[1, 1].axis("off")

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"\n‚úì Visualization saved: {save_path}")
    plt.show()
    # Release the (large, 18x18 inch) figure once it has been saved and shown.
    plt.close(fig)

def main(mask_path):
    """
    Run the full clean-up pipeline on one SAM mask raster.

    Steps:
    1. Load the mask.
    2. Optional watershed separation.
    3. Optional morphology.
    4. Size/shape filtering.
    5. Polygon and boundary extraction.
    6. Write the cleaned instance mask (always), the boundary raster (if
       CREATE_BOUNDARY_FILE), the GeoJSON (if SAVE_GEOJSON) and a PNG
       visualization.

    Output paths come from module-level configuration, so each call
    overwrites the previous call's outputs. The closing summary lists only
    the files that were actually written.

    Parameters
    ----------
    mask_path : str or path-like
        Mask raster to process.

    Returns
    -------
    tuple
        ``(instance_mask, boundaries, polygons)``.

    Raises
    ------
    FileNotFoundError
        If ``mask_path`` does not exist.
    """
    visualization_path = "polygons_visualization.png"

    if not Path(mask_path).exists():
        raise FileNotFoundError(f"Mask file not found: {mask_path}")

    print(f"\n{'='*60}")
    print(f"SAM MASK ‚Üí FILTERED POLYGONS")
    print(f"{'='*60}")
    print(f"\nInput: {mask_path}")

    # Load SAM mask
    instance_mask, profile, transform, crs = load_sam_mask(mask_path)

    # Apply watershed separation to separate touching objects
    instance_mask = apply_watershed_separation(instance_mask)

    # Apply morphological operations to improve shapes
    instance_mask = apply_morphological_operations(instance_mask)

    # Filter objects by size and shape quality
    instance_mask = filter_objects_by_quality(instance_mask)

    # Extract polygons (one per object)
    polygons = extract_polygons_per_object(instance_mask, transform)

    # Extract boundaries
    print(f"\n{'='*60}")
    print(f"EXTRACTING BOUNDARIES")
    print(f"{'='*60}")
    print(f"Thickness: {BOUNDARY_THICKNESS}px")
    boundaries = extract_boundaries(instance_mask, BOUNDARY_THICKNESS)
    print(f"‚úì Boundary pixels: {np.sum(boundaries > 0):,}")

    # Single-band, LZW-compressed outputs with 0 as nodata (background).
    profile.update({
        'count': 1,
        'dtype': 'uint16',
        'compress': 'lzw',
        'nodata': 0
    })

    print(f"\n{'='*60}")
    print(f"SAVING FILES")
    print(f"{'='*60}")

    with rasterio.open(clean_mask_path, "w", **profile) as dst:
        dst.write(instance_mask.astype(np.uint16), 1)
    print(f"‚úì Instance mask: {clean_mask_path}")
    written = [(clean_mask_path, "Filtered instance mask")]

    if CREATE_BOUNDARY_FILE:
        profile_boundary = profile.copy()
        profile_boundary['dtype'] = 'uint8'

        with rasterio.open(boundaries_path, "w", **profile_boundary) as dst:
            dst.write(boundaries, 1)
        print(f"‚úì Boundaries: {boundaries_path}")
        written.append((boundaries_path, "Boundaries"))

    if SAVE_GEOJSON:
        save_geojson(polygons, polygons_geojson, crs)
        written.append((polygons_geojson, "Polygons (GeoJSON) ‚Üê USE THIS!"))

    # Visualize
    print(f"\n{'='*60}")
    print(f"GENERATING VISUALIZATION")
    print(f"{'='*60}")
    visualize_results(instance_mask, boundaries, polygons, visualization_path)
    written.append((visualization_path, "Visualization"))

    print(f"\n{'='*60}")
    print(f"‚úì COMPLETE!")
    print(f"{'='*60}")
    # List only what was actually written in this run.
    print(f"\nOutput files:")
    for number, (path, description) in enumerate(written, start=1):
        print(f"  {number}. {path} - {description}")

    print(f"\n‚úì Processing applied:")
    if USE_WATERSHED_SEPARATION:
        print(f"  ‚Ä¢ Watershed separation: separated touching objects")
    if USE_MORPHOLOGY:
        print(f"  ‚Ä¢ Morphological processing: filled holes and smoothed edges")
    print(f"  ‚Ä¢ Quality filtering: removed small/irregular shapes")
    print(f"\n‚úì Only high-quality objects with proper shapes!")

    return instance_mask, boundaries, polygons

# if __name__ == "__main__":
#     try:
#         instance_mask, boundaries, polygons = main(mask_path)
#     except Exception as e:
#         print(f"\nERROR: {e}")
#         import traceback
#         traceback.print_exc()

In [None]:
import os
import rasterio
import numpy as np
import matplotlib.pyplot as plt
from samgeo import SamGeo2
import cv2

# Initialize SAM: a single SamGeo2 instance, created once and reused by
# overlay_mask_downsampled for every image. The constructor settings are
# kept in a dict so they can be inspected or logged alongside results.
SAM2_SETTINGS = {
    "model_id": "sam2-hiera-large",
    "apply_postprocessing": False,
    "points_per_side": 32,
    "points_per_batch": 64,
    "pred_iou_thresh": 0.6,
    "stability_score_thresh": 0.85,
    "stability_score_offset": 0.7,
    "crop_n_layers": 1,
    "box_nms_thresh": 0.9,
    "crop_n_points_downscale_factor": 2,
    "min_mask_region_area": 25.0,
    "use_m2m": True,
}
sam2 = SamGeo2(**SAM2_SETTINGS)

def overlay_mask_downsampled(image_path, output_dir="outputs2", max_size=1024, alpha=0.4, overlay_boundaries=True):
    """
    Segment one image and save an instance-mask overlay on a downsampled copy.

    Steps:
    1. Generate SAM masks for ``image_path`` with the module-level ``sam2``
       model and save them to ``<output_dir>/<name>_masks.tif``.
    2. Run ``main`` on that file.
    3. Draw the resulting instance mask (and optionally its boundaries, in
       red) over the image downsampled to at most ``max_size`` pixels per
       side.
    4. Save the result as ``<output_dir>/<name>_overlay.png``.

    Parameters
    ----------
    image_path : str
        Path to the original image. Bands 1-3 are read and displayed as RGB,
        so the image needs at least three bands.
    output_dir : str
        Directory for the mask raster and overlay PNG (created if missing).
    max_size : int
        Maximum width or height for visualization.
    alpha : float
        Opacity of the coloured mask over the image.
    overlay_boundaries : bool
        Whether to overlay boundaries in red.

    Returns
    -------
    dict or None
        Keys ``image_path``, ``mask_path``, ``overlay_path`` and
        ``num_objects``. ``num_objects`` is the number of distinct positive
        labels in the full-resolution instance mask. Returns None when
        ``main`` raised; the error and traceback are printed.
    """
    os.makedirs(output_dir, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(image_path))[0]
    mask_path = os.path.join(output_dir, f"{base_name}_masks.tif")

    # Generate SAM masks
    sam2.generate(image_path)
    sam2.save_masks(mask_path)

    # Run the cleaning pipeline (main, defined in an earlier cell).
    try:
        instance_mask, boundaries, polygons = main(mask_path)
    except Exception as e:
        print(f"ERROR processing {mask_path}: {e}")
        import traceback
        traceback.print_exc()
        return None

    with rasterio.open(image_path) as src:
        img = src.read([1, 2, 3])  # first three bands, displayed as RGB
    img = img.transpose(1, 2, 0).astype(float)  # H x W x C

    # Stretch to 0-1 for display; a constant image would otherwise divide by
    # zero and turn into NaNs.
    lo, hi = img.min(), img.max()
    img = (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)

    orig_h, orig_w = img.shape[:2]
    scale = min(max_size / orig_w, max_size / orig_h, 1.0)

    if scale < 1.0:
        new_w, new_h = int(orig_w * scale), int(orig_h * scale)
        img_ds = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
        # Nearest-neighbour keeps label and boundary values unblended.
        mask_ds = cv2.resize(instance_mask.astype(np.float32), (new_w, new_h), interpolation=cv2.INTER_NEAREST)
        boundaries_ds = cv2.resize(boundaries.astype(np.uint8), (new_w, new_h), interpolation=cv2.INTER_NEAREST)
    else:
        img_ds, mask_ds, boundaries_ds = img, instance_mask, boundaries

    # Count at full resolution: downsampling can drop small objects.
    num_objects = int(np.count_nonzero(np.unique(instance_mask) > 0))

    # The colour table must cover every label still visible after downsampling.
    max_label = int(np.max(mask_ds))
    if max_label <= 0:
        overlay = img_ds.copy()  # copy so the boundary painting below does not alias img
    else:
        import matplotlib.cm as cm
        colors = cm.nipy_spectral(np.linspace(0, 1, max_label + 1))[:, :3]  # RGB only
        mask_rgb = colors[mask_ds.astype(int)]
        overlay = (1 - alpha) * img_ds + alpha * mask_rgb

    if overlay_boundaries:
        overlay[boundaries_ds > 0] = [1.0, 0, 0]  # Red boundaries

    fig = plt.figure(figsize=(12, 12))
    plt.imshow(overlay)
    plt.axis("off")
    plt.title(f"{base_name}: {num_objects} objects")
    save_path = os.path.join(output_dir, f"{base_name}_overlay.png")
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    # Close so figures do not pile up when called once per image in a loop.
    plt.close(fig)
    print(f"Overlay saved: {save_path}")

    return {
        "image_path": image_path,
        "mask_path": mask_path,
        "overlay_path": save_path,
        "num_objects": num_objects
    }


# ---- Main loop ----
# Process every GeoTIFF in image_folder, in name order, and collect the
# summaries returned by overlay_mask_downsampled. Images whose processing
# failed (None) are skipped.
image_folder = "tiff_testing"
results = []

for file in sorted(os.listdir(image_folder)):
    # Accept both ".tif" and ".tiff", case-insensitively.
    if file.lower().endswith((".tif", ".tiff")):
        full_path = os.path.join(image_folder, file)
        print(f"Processing {full_path} ...")
        res = overlay_mask_downsampled(full_path)
        if res:
            results.append(res)

print("All images processed.")
