In [None]:
# Install transformers straight from the main branch on GitHub (unpinned, so the
# installed code can differ from one run to the next).
# NOTE(review): `%pip install ...` would target the running kernel's environment;
# `!pip` uses whichever pip is first on PATH — confirm they are the same here.
!pip install git+https://github.com/huggingface/transformers.git
# Interactive Hugging Face login.
# NOTE(review): presumably needed because the facebook/sam3 checkpoint is gated — confirm.
!hf auth login


In [1]:
# Use a pipeline as a high-level helper
# NOTE(review): `pipe` is not referenced anywhere else in the visible notebook;
# the segmentation code loads facebook/sam3 again via Sam3Model / Sam3Processor.
from transformers import pipeline

# Build a "mask-generation" pipeline backed by the facebook/sam3 checkpoint.
pipe = pipeline("mask-generation", model="facebook/sam3")

Loading weights:   0%|          | 0/685 [00:00<?, ?it/s]

In [5]:

# import torch
# import numpy as np
# import matplotlib.pyplot as plt
# from PIL import Image, ImageDraw, ImageFont
# from transformers import Sam3Model, Sam3Processor


# IMAGE_PATH = "roads/tile_10240_57344.tif"
# TEXT_PROMPT = "road"

# SCORE_THRESHOLD = 0.55  # instance confidence
# MASK_THRESHOLD = 0.40   # pixel threshold


# def load_model():
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     model = Sam3Model.from_pretrained(
#         "facebook/sam3",
#         torch_dtype=torch.bfloat16
#     ).to(device)
#     processor = Sam3Processor.from_pretrained("facebook/sam3")
#     return model, processor, device

# def visualize(image, results):
#     overlay = image.convert("RGBA")
#     masks = results["masks"].float().cpu().numpy()
#     boxes = results["boxes"].float().cpu().numpy()

#     if len(masks) == 0:
#         return overlay.convert("RGB")

#     cmap = plt.cm.plasma
#     draw = ImageDraw.Draw(overlay)

#     try:
#         font = ImageFont.truetype("arial.ttf", 14)
#     except:
#         font = ImageFont.load_default()

#     for i, mask in enumerate(masks):
#         color = tuple(int(c * 255) for c in cmap(i / len(masks))[:3])
#         layer = Image.new("RGBA", image.size, color + (0,))
#         m = Image.fromarray((mask * 255).astype(np.uint8))
#         layer.putalpha(m.point(lambda x: 80 if x > 0 else 0))
#         overlay = Image.alpha_composite(overlay, layer)

#         x1, y1, x2, y2 = boxes[i]
#         label = f"road_{i+1}"
#         draw.rectangle(draw.textbbox((x1, y1), label, font=font), fill="black")
#         draw.text((x1, y1), label, fill="white", font=font)

#     return overlay.convert("RGB")

# def main():
#     model, processor, device = load_model()

#     image = Image.open(IMAGE_PATH).convert("RGB")

#     inputs = processor(
#         images=image,
#         text=TEXT_PROMPT,
#         return_tensors="pt"
#     ).to(device)

#     with torch.no_grad():
#         outputs = model(**inputs)

#     results = processor.post_process_instance_segmentation(
#         outputs,
#         threshold=SCORE_THRESHOLD,
#         mask_threshold=MASK_THRESHOLD,
#         target_sizes=inputs["original_sizes"].tolist()
#     )[0]

#     vis = visualize(image, results)

#     plt.figure(figsize=(12, 6))
#     plt.subplot(1, 2, 1)
#     plt.imshow(image)
#     plt.title("Original")
#     plt.axis("off")

#     plt.subplot(1, 2, 2)
#     plt.imshow(vis)
#     plt.title("Road Segmentation (No Filtering)")
#     plt.axis("off")
#     plt.show()

# if __name__ == "__main__":
#     main()

In [None]:
"""
Complete Road Segmentation Pipeline with SAM3
Segments roads from images and exports to GeoJSON with unique IDs
Includes enhanced post-processing: smoothing and strengthening
"""

import os
import json
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from PIL import Image
import rasterio
from rasterio.features import shapes
from shapely.geometry import shape, mapping
from skimage.morphology import skeletonize
from transformers import Sam3Model, Sam3Processor


# PART 1: SAM3 MODEL AND SEGMENTATION


def load_sam3():
    """
    Load the facebook/sam3 checkpoint and its processor.

    Picks "cuda" when torch reports a CUDA device, otherwise "cpu", loads the
    model weights as bfloat16, moves the model to that device and switches it
    to eval mode.

    Returns:
        (model, processor, device) where device is the string "cuda" or "cpu"

    Raises:
        Whatever exception the loading raised; it is printed and re-raised.
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    try:
        sam_model = Sam3Model.from_pretrained(
            "facebook/sam3",
            torch_dtype=torch.bfloat16
        )
        sam_model = sam_model.to(device)
        sam_processor = Sam3Processor.from_pretrained("facebook/sam3")
        sam_model.eval()
    except Exception as e:
        # Report the failure here, but let the caller decide how to handle it.
        print(f"Error loading SAM3 model: {e}")
        raise

    return sam_model, sam_processor, device

def sam3_to_instance_mask(masks):
    """
    Collapse a stack of per-instance masks into one labelled image.

    Args:
        masks: (N,H,W) tensor of per-instance masks

    Returns:
        (H,W) uint16 array where pixels of the i-th mask (1-based) hold i and
        everything else is 0, or None when the stack is empty. Where masks
        overlap, the mask that comes later in the stack wins.
    """
    if len(masks) == 0:
        return None

    stack = masks.float().cpu().numpy()
    _, height, width = stack.shape
    labels = np.zeros((height, width), dtype=np.uint16)

    # Paint the masks in order; a pixel is "on" when its value exceeds 0.5.
    for label, layer in enumerate(stack, start=1):
        labels[layer > 0.5] = label

    return labels


# PART 2: MASK POST-PROCESSING


def postprocess_mask(instance_mask, smooth_iterations=4, close_kernel_size=9, open_kernel_size=3):
    """
    Post-process an instance mask to smooth and strengthen road segments.

    Each instance is processed on its own: repeated morphological closing,
    one morphological opening, then two Gaussian-blur + 0.5-threshold passes.
    Instances that are still non-empty afterwards are written to the output
    with consecutive IDs starting at 1. Because closing can grow an instance,
    instances may overlap after processing; the one processed later wins.

    Args:
        instance_mask: (H,W) uint16 instance segmentation mask (0 = background)
        smooth_iterations: Number of morphological closing iterations
        close_kernel_size: Kernel size for closing (fills gaps)
        open_kernel_size: Kernel size for opening (removes noise)

    Returns:
        Cleaned instance mask with the same shape and dtype as the input
    """
    # Count real instance IDs directly. `len(np.unique(mask)) - 1` is only
    # correct when at least one background (0) pixel exists.
    source_ids = np.unique(instance_mask)
    source_ids = source_ids[source_ids > 0]

    print(f"  Post-processing mask...")
    print(f"    Original objects: {len(source_ids)}")

    cleaned_mask = np.zeros_like(instance_mask)

    if len(source_ids) == 0:
        # Nothing to smooth; skip building the kernels altogether.
        print(f"    After smoothing: 0 objects")
        return cleaned_mask

    # The structuring elements do not depend on the instance, so build them once.
    close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (close_kernel_size, close_kernel_size))
    open_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (open_kernel_size, open_kernel_size))

    valid_id = 1

    for inst_id in source_ids:
        # Extract single object
        obj_mask = (instance_mask == inst_id).astype(np.uint8)

        # Morphological closing (fill gaps, smooth boundaries)
        for _ in range(smooth_iterations):
            obj_mask = cv2.morphologyEx(obj_mask, cv2.MORPH_CLOSE, close_kernel)

        # Morphological opening (remove small noise)
        obj_mask = cv2.morphologyEx(obj_mask, cv2.MORPH_OPEN, open_kernel)

        # Gaussian blur + threshold for smoother edges
        obj_mask_float = cv2.GaussianBlur(obj_mask.astype(np.float32), (7, 7), 1.5)
        obj_mask = (obj_mask_float > 0.5).astype(np.uint8)

        # Additional light smoothing pass
        obj_mask_float = cv2.GaussianBlur(obj_mask.astype(np.float32), (5, 5), 0.8)
        obj_mask = (obj_mask_float > 0.5).astype(np.uint8)

        # Keep the object only if anything survived the smoothing
        if np.any(obj_mask):
            cleaned_mask[obj_mask > 0] = valid_id
            valid_id += 1

    # Non-zero unique values == instances actually present in the output.
    remaining = int(np.count_nonzero(np.unique(cleaned_mask)))
    print(f"    After smoothing: {remaining} objects")

    return cleaned_mask


def process_image_sam3(
    image_path,
    model,
    processor,
    device,
    text_prompt,
    output_dir,
    score_threshold=0.55,
    mask_threshold=0.40
):
    """
    Segment a single image with SAM3 and save the post-processed instance mask.

    Args:
        image_path: Path of the image to segment
        model: Loaded SAM3 model
        processor: Matching SAM3 processor
        device: Device the processor inputs are moved to
        text_prompt: Text prompt passed to the processor
        output_dir: Folder for the output mask (created if missing)
        score_threshold: Instance confidence threshold
        mask_threshold: Pixel threshold

    Returns:
        Dict with keys "image", "mask" (path of the written GeoTIFF) and
        "num_objects", or None when no segment was found, nothing survived
        post-processing, or any error occurred (the error is printed).
    """
    os.makedirs(output_dir, exist_ok=True)
    base = os.path.splitext(os.path.basename(image_path))[0]

    try:
        # Load image
        image = Image.open(image_path).convert("RGB")
        print(f"  Image size: {image.size}")

        # Prepare inputs (SAM3 style)
        inputs = processor(
            images=image,
            text=text_prompt,  # Direct text, not list
            return_tensors="pt"
        ).to(device)

        # Inference
        with torch.no_grad():
            outputs = model(**inputs)

        # Post-process with SAM3 parameters
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=score_threshold,       # instance confidence
            mask_threshold=mask_threshold,   # pixel threshold
            target_sizes=inputs["original_sizes"].tolist()
        )[0]

        # Check if masks found
        if "masks" not in results or len(results["masks"]) == 0:
            print(f"  No objects found")
            return None

        print(f"  Found {len(results['masks'])} road segments")

        # Create instance mask
        instance_mask = sam3_to_instance_mask(results["masks"])

        if instance_mask is None:
            print(f"  Failed to create instance mask")
            return None

        # Get georeferencing from original image if available
        try:
            with rasterio.open(image_path) as src:
                transform = src.transform
                crs = src.crs
        except Exception:
            # rasterio could not read the file: fall back to a transform built
            # from the mask's pixel bounds and no CRS. `except Exception` (not a
            # bare except) so KeyboardInterrupt/SystemExit still propagate.
            h, w = instance_mask.shape
            transform = rasterio.transform.from_bounds(0, 0, w, h, w, h)
            crs = None

        # ===== POST-PROCESSING =====
        print(f"\n  === Post-Processing ===")

        # Smooth and strengthen mask
        instance_mask = postprocess_mask(
            instance_mask,
            smooth_iterations=3,
            close_kernel_size=7,
            open_kernel_size=3
        )

        # Count non-zero IDs directly. `len(np.unique(mask)) - 1` reports 0 for
        # a tile completely covered by a single road (no background pixel),
        # which used to make such tiles be discarded.
        final_count = int(np.count_nonzero(np.unique(instance_mask)))
        print(f"  Final segments: {final_count}")

        if final_count == 0:
            print(f"  No valid road segments after post-processing")
            return None

        # Save mask as GeoTIFF
        mask_path = os.path.join(output_dir, f"{base}_instance_mask.tif")
        with rasterio.open(
            mask_path,
            "w",
            driver="GTiff",
            height=instance_mask.shape[0],
            width=instance_mask.shape[1],
            count=1,
            dtype=instance_mask.dtype,
            transform=transform,
            crs=crs,
            compress='lzw'
        ) as dst:
            dst.write(instance_mask, 1)

        print(f"  Saved mask: {mask_path}")

        return {
            "image": image_path,
            "mask": mask_path,
            "num_objects": int(instance_mask.max())
        }

    except Exception as e:
        # Best effort per image: report the failure and let the caller move on.
        print(f"  Error processing {image_path}: {e}")
        import traceback
        traceback.print_exc()
        return None


# PART 3: POLYGON EXTRACTION AND METRICS


def calculate_road_metrics(obj_mask):
    """
    Compute shape metrics for one road segment mask.

    The largest external contour is wrapped in a minimum-area rectangle to get
    length (longer side) and width (shorter side); the skeleton pixel count is
    used as a centre-line length.

    Returns:
        Dict with 'length', 'width', 'aspect_ratio', 'elongation',
        'skeleton_length' and 'length_width_ratio' (same value as
        'aspect_ratio'), or None when no contour is found or the
        computation fails (a warning is printed in that case).
    """
    try:
        found, _hierarchy = cv2.findContours(
            obj_mask,
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE
        )
        if not found:
            return None

        # Minimum-area rectangle around the biggest contour
        largest = max(found, key=cv2.contourArea)
        _center, (side_a, side_b), _angle = cv2.minAreaRect(largest)
        long_side = max(side_a, side_b)
        short_side = min(side_a, side_b)

        # Skeleton pixel count
        centerline = np.sum(skeletonize(obj_mask > 0))

        if short_side > 0:
            ratio = long_side / short_side
        else:
            ratio = 0

        pixel_area = np.count_nonzero(obj_mask)
        if pixel_area > 0:
            stretch = centerline / np.sqrt(pixel_area)
        else:
            stretch = 0

        return {
            'length': long_side,
            'width': short_side,
            'aspect_ratio': ratio,
            'elongation': stretch,
            'skeleton_length': centerline,
            'length_width_ratio': ratio
        }
    except Exception as e:
        print(f"    Warning: Metric calculation failed: {e}")
        return None

def extract_polygons_per_object(instance_mask, transform, global_id_offset=0, image_name=None):
    """
    Extract one geometry per instance from an instance mask, with globally unique IDs.

    Args:
        instance_mask: (H,W) integer mask, 0 = background, >0 = instance ID
        transform: Affine transform passed to rasterio's `shapes`
        global_id_offset: Value added to each local instance ID to form 'global_id'
        image_name: Optional name stored as 'source_image'

    Returns:
        List with one dict per instance containing 'polygon' (shapely geometry),
        'global_id', 'local_id', 'area_pixels', optionally 'source_image', and
        the metric fields when `calculate_road_metrics` succeeds. An instance
        made of several disconnected parts yields a MultiPolygon so the
        geometry covers the same pixels as 'area_pixels' and the metrics.
        Previously only the first part was kept and the rest silently dropped.
    """
    unique_ids = np.unique(instance_mask)
    unique_ids = unique_ids[unique_ids > 0]

    polygons = []

    for inst_id in unique_ids:
        obj_mask = (instance_mask == inst_id).astype(np.uint8)
        area = np.count_nonzero(obj_mask)

        # Calculate metrics
        metrics = calculate_road_metrics(obj_mask)

        # Collect every part of this instance (GeoJSON-like Polygon dicts)
        parts = [
            geom
            for geom, val in shapes(obj_mask, mask=obj_mask, transform=transform)
            if val > 0
        ]
        if not parts:
            continue

        if len(parts) == 1:
            geometry = shape(parts[0])
        else:
            # Combine the parts through a GeoJSON MultiPolygon mapping so the
            # already-imported `shape` can build it (no extra import needed).
            geometry = shape({
                "type": "MultiPolygon",
                "coordinates": [part["coordinates"] for part in parts]
            })

        poly_data = {
            'polygon': geometry,
            'global_id': global_id_offset + int(inst_id),
            'local_id': int(inst_id),
            'area_pixels': int(area)
        }

        if image_name:
            poly_data['source_image'] = image_name

        if metrics:
            poly_data.update({
                'length_pixels': float(metrics['length']),
                'width_pixels': float(metrics['width']),
                'aspect_ratio': float(metrics['aspect_ratio']),
                'elongation': float(metrics['elongation']),
                'skeleton_length': float(metrics['skeleton_length']),
                'length_width_ratio': float(metrics['length_width_ratio'])
            })

        polygons.append(poly_data)

    return polygons

def process_mask_to_polygons(mask_path, global_id_offset=0):
    """
    Read one instance-mask file and extract its polygons.

    Args:
        mask_path: Path to a '*_instance_mask.tif' file
        global_id_offset: Offset added to the local IDs of this mask

    Returns:
        (polygons, next_offset, crs). `next_offset` is
        `global_id_offset + max_local_id + 1`, so one ID value is left unused
        between consecutive files. On any error the message is printed and
        `([], global_id_offset, None)` is returned.
    """
    try:
        with rasterio.open(mask_path) as src:
            instance_mask = src.read(1)
            transform = src.transform
            crs = src.crs

        # Get image name from mask filename
        image_name = os.path.basename(mask_path).replace('_instance_mask.tif', '')

        # Count non-zero IDs directly; `len(np.unique(mask)) - 1` is off by one
        # when the mask contains no background pixel.
        object_count = int(np.count_nonzero(np.unique(instance_mask)))

        print(f"  Processing: {os.path.basename(mask_path)}")
        print(f"    Objects: {object_count}")
        print(f"    ID offset: {global_id_offset}")

        # Extract polygons
        polygons = extract_polygons_per_object(
            instance_mask,
            transform,
            global_id_offset,
            image_name
        )

        max_local_id = int(instance_mask.max())
        next_offset = global_id_offset + max_local_id + 1

        print(f"    Extracted {len(polygons)} polygons")
        print(f"    ID range: {global_id_offset + 1} to {global_id_offset + max_local_id}")

        return polygons, next_offset, crs

    except Exception as e:
        # Best effort: a broken mask must not abort the whole batch.
        print(f"  Error processing {mask_path}: {e}")
        return [], global_id_offset, None

def save_combined_geojson(all_polygons, output_path, crs=None):
    """
    Write every polygon record into one GeoJSON FeatureCollection.

    Args:
        all_polygons: Records as produced by the polygon extraction step
        output_path: Destination file
        crs: Optional CRS; its string form is written as the collection's
            CRS name, "EPSG:4326" is written when it is missing
    """
    metric_keys = (
        'length_pixels',
        'width_pixels',
        'aspect_ratio',
        'elongation',
        'skeleton_length',
        'length_width_ratio',
    )

    def to_feature(record):
        # Mandatory attributes first, optional ones only when present
        props = {
            "id": record['global_id'],
            "local_id": record['local_id'],
            "area_pixels": record['area_pixels']
        }
        if 'source_image' in record:
            props['source_image'] = record['source_image']
        if 'length_pixels' in record:
            props.update({key: record[key] for key in metric_keys})
        return {
            "type": "Feature",
            "properties": props,
            "geometry": mapping(record['polygon'])
        }

    features = [to_feature(record) for record in all_polygons]

    crs_name = str(crs) if crs else "EPSG:4326"
    collection = {
        "type": "FeatureCollection",
        "crs": {"type": "name", "properties": {"name": crs_name}},
        "features": features
    }

    with open(output_path, 'w') as handle:
        json.dump(collection, handle, indent=2)

    print(f"\n✓ Combined GeoJSON saved: {output_path}")
    print(f"  Total features: {len(features)}")


# PART 4: COMPLETE PIPELINE


def run_complete_pipeline(
    input_folder="roads",
    output_mask_folder="outputs_sam3",
    output_geojson="all_roads_combined.geojson",
    text_prompt="road",
    score_threshold=0.55,
    mask_threshold=0.40
):
    """
    Complete pipeline: Segment images → Post-process → Extract polygons → Save GeoJSON

    Args:
        input_folder: Folder with the input images (.tif/.tiff/.jpg/.jpeg/.png)
        output_mask_folder: Folder the instance masks are written to
        output_geojson: Path of the combined GeoJSON file
        text_prompt: Text prompt handed to SAM3
        score_threshold: Instance confidence threshold
        mask_threshold: Pixel threshold

    Returns:
        None. Progress and a final summary are printed. The function returns
        early when the input folder is missing or empty, when the model cannot
        be loaded, or when no image yields a mask.
    """
    print("="*70)
    print("COMPLETE ROAD SEGMENTATION PIPELINE WITH SAM3")
    print("="*70)

    # Validate the inputs before the (expensive) model load so a wrong path
    # fails immediately.
    if not os.path.isdir(input_folder):
        print(f"Error: Folder '{input_folder}' not found!")
        return

    # Get all image files
    image_files = sorted([
        f for f in os.listdir(input_folder)
        if f.lower().endswith(('.tif', '.tiff', '.jpg', '.jpeg', '.png'))
    ])

    if not image_files:
        print(f"No images found in '{input_folder}'")
        return

    # STEP 1: LOAD MODEL
    print("\n[STEP 1] Loading SAM3 Model...")
    try:
        model, processor, device = load_sam3()
    except Exception as e:
        print(f"Failed to load model: {e}")
        return

    print(f"\n[STEP 2] Segmenting {len(image_files)} images...")
    print(f"Parameters: score_threshold={score_threshold}, mask_threshold={mask_threshold}")
    print("-" * 70)

    results = []
    for i, file in enumerate(image_files, 1):
        path = os.path.join(input_folder, file)
        print(f"\n[{i}/{len(image_files)}] {file}")

        res = process_image_sam3(
            path,
            model,
            processor,
            device,
            text_prompt=text_prompt,
            output_dir=output_mask_folder,
            score_threshold=score_threshold,
            mask_threshold=mask_threshold
        )

        if res:
            results.append(res)

    if not results:
        print("\nNo objects detected in any images!")
        return

    # STEP 3: POLYGON EXTRACTION
    print("\n" + "="*70)
    print("[STEP 3] Extracting Polygons from Masks...")
    print("-" * 70)

    # Use only the masks written by this run. Listing the output folder would
    # also pick up stale masks from earlier runs and mix them into the GeoJSON
    # and the summary. The set removes duplicates when two inputs share a base
    # name and therefore wrote the same mask file.
    mask_paths = sorted({res["mask"] for res in results})

    all_polygons = []
    global_id_offset = 0
    crs = None

    for i, mask_path in enumerate(mask_paths, 1):
        print(f"\n[{i}/{len(mask_paths)}] {os.path.basename(mask_path)}")

        polygons, next_offset, mask_crs = process_mask_to_polygons(
            mask_path,
            global_id_offset
        )

        # The first CRS found is used for the combined output
        if mask_crs and not crs:
            crs = mask_crs

        all_polygons.extend(polygons)
        global_id_offset = next_offset

    # STEP 4: SAVE GEOJSON
    if all_polygons:
        print("\n" + "="*70)
        print("[STEP 4] Saving Combined GeoJSON...")
        print("-" * 70)

        save_combined_geojson(all_polygons, output_geojson, crs)

    # FINAL SUMMARY
    print("\n" + "="*70)
    print("PIPELINE COMPLETE - SUMMARY")
    print("="*70)
    print(f"Input images:           {len(image_files)}")
    print(f"Successfully segmented: {len(results)}")
    print(f"Mask files created:     {len(mask_paths)}")
    print(f"Total polygons:         {len(all_polygons)}")
    if all_polygons:
        total_objects = sum(r['num_objects'] for r in results)
        print(f"Total road segments:    {total_objects}")
        print(f"Average per image:      {total_objects/len(results):.1f}")
        print(f"Global ID range:        1 to {global_id_offset - 1}")
    print(f"\nOutput GeoJSON:         {output_geojson}")
    print(f"Output masks folder:    {output_mask_folder}/")
    print("="*70)


# MAIN EXECUTION

# Entry point: runs the whole pipeline when this cell/script is executed directly.
# The thresholds passed here (0.60 / 0.42) override the function defaults (0.55 / 0.40).
if __name__ == "__main__":
    # Run the complete pipeline with SAM3 parameters
    run_complete_pipeline(
        input_folder="roads",                          # Input images
        output_mask_folder="outputs_sam3",             # Mask output
        output_geojson="all_roads_combined.geojson",   # Final GeoJSON
        text_prompt="road",                            # Text prompt
        score_threshold=0.60,                          # Instance confidence
        mask_threshold=0.42                            # Pixel threshold
    )

In [None]:
"""
Visualize original images and their instance masks side by side
Handles large images (2048x2048) with automatic resizing
"""

import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from PIL import Image
import rasterio
import cv2

def load_image_and_mask(image_path, mask_path):
    """
    Read an image and its instance mask from disk.

    Returns:
        (image, mask): the image converted to RGB (PIL image) and the first
        band of the mask file as an array
    """
    # Original image, forced to RGB
    rgb_image = Image.open(image_path)
    rgb_image = rgb_image.convert("RGB")

    # Instance IDs live in band 1 of the mask raster
    with rasterio.open(mask_path) as dataset:
        first_band = dataset.read(1)

    return rgb_image, first_band

def resize_if_needed(img, mask, max_size=1024):
    """
    Resize image and mask if either side is larger than max_size.
    Maintains aspect ratio.

    Args:
        img: Image to display (PIL image when a resize is needed)
        mask: Instance mask array matching the image
        max_size: Largest allowed width/height

    Returns:
        (img, mask) unchanged when both sides already fit; otherwise the
        bilinearly resized image and the nearest-neighbour resized mask
        (the resized mask is int32).
    """
    h, w = np.asarray(img).shape[:2]

    # Check if resizing needed
    if h <= max_size and w <= max_size:
        return img, mask

    # Calculate scale; clamp to 1 pixel so an extreme aspect ratio
    # (e.g. 1 x 5000) cannot produce a zero-sized, invalid target.
    scale = min(max_size / w, max_size / h)
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))

    print(f"  Resizing from {w}x{h} to {new_w}x{new_h}")

    # Resize image
    img_resized = img.resize((new_w, new_h), Image.BILINEAR)

    # Resize mask (using nearest neighbor to preserve instance IDs)
    mask_resized = cv2.resize(
        mask.astype(np.int32),
        (new_w, new_h),
        interpolation=cv2.INTER_NEAREST
    )

    return img_resized, mask_resized

def create_colored_mask(mask):
    """
    Turn an instance mask into an RGB image, one colour per instance ID.

    Colours are sampled evenly from the nipy_spectral colormap for the IDs
    0..max(mask) and looked up per pixel. A mask whose maximum is 0 gives an
    all-black (H,W,3) uint8 image.
    """
    top_id = int(mask.max())

    if top_id != 0:
        # One palette row per ID, index 0 being the background
        palette = cm.nipy_spectral(np.linspace(0, 1, top_id + 1))[:, :3]
        palette = (palette * 255).astype(np.uint8)
        # Fancy indexing maps every pixel's ID to its RGB triple
        return palette[mask]

    # No instances at all: plain black image
    rows, cols = mask.shape
    return np.zeros((rows, cols, 3), dtype=np.uint8)

def visualize_single_pair(image_path, mask_path, max_size=1024, save_path=None):
    """
    Visualize single image-mask pair

    Shows the (possibly downsized) image next to its colour-coded instance
    mask; the image title reports the original size.

    Args:
        image_path: Path to original image
        mask_path: Path to instance mask
        max_size: Maximum dimension for display (default 1024)
        save_path: Optional path to save visualization
    """
    print(f"Visualizing: {os.path.basename(image_path)}")

    # Load data
    image, mask = load_image_and_mask(image_path, mask_path)

    # Resize if needed
    image_display, mask_display = resize_if_needed(image, mask, max_size)

    # Create colored mask
    colored_mask = create_colored_mask(mask_display)

    # Get statistics (highest ID in the displayed mask is used as the count)
    num_instances = int(mask_display.max())
    original_size = image.size

    # Create visualization
    fig, axes = plt.subplots(1, 2, figsize=(16, 8))

    # Original image
    axes[0].imshow(image_display)
    axes[0].set_title(f"Original Image\nSize: {original_size[0]}x{original_size[1]}",
                      fontsize=14, fontweight='bold')
    axes[0].axis('off')

    # Instance mask
    axes[1].imshow(colored_mask)
    axes[1].set_title(f"Instance Mask\n{num_instances} road segments detected",
                      fontsize=14, fontweight='bold')
    axes[1].axis('off')

    # Overall title
    fig.suptitle(os.path.basename(image_path), fontsize=16, fontweight='bold')

    fig.tight_layout()

    # Save if requested (on the figure itself, not the implicit current figure)
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"  Saved to: {save_path}")

    plt.show()
    print()

def visualize_multiple_pairs(image_folder, mask_folder, max_size=1024, 
                             save_dir=None, show_plots=True):
    """
    Visualize multiple image-mask pairs

    For every image in `image_folder` the mask named
    '<image base name>_instance_mask.tif' is looked up in `mask_folder`;
    images without a mask are skipped. A failure on one pair is printed and
    the loop continues with the next one.

    Args:
        image_folder: Folder containing original images
        mask_folder: Folder containing instance masks
        max_size: Maximum dimension for display
        save_dir: Optional directory to save visualizations
        show_plots: Whether to show plots (set False for batch processing)
    """
    print("="*70)
    print("VISUALIZING IMAGE-MASK PAIRS")
    print("="*70)

    # Create save directory if needed
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Get all images
    image_files = sorted([
        f for f in os.listdir(image_folder)
        if f.lower().endswith(('.tif', '.tiff', '.jpg', '.jpeg', '.png'))
    ])

    if not image_files:
        print(f"No images found in {image_folder}")
        return

    print(f"\nFound {len(image_files)} images\n")

    # Process each image
    for i, img_file in enumerate(image_files, 1):
        image_path = os.path.join(image_folder, img_file)

        # Find corresponding mask
        base_name = os.path.splitext(img_file)[0]
        mask_file = f"{base_name}_instance_mask.tif"
        mask_path = os.path.join(mask_folder, mask_file)

        if not os.path.exists(mask_path):
            print(f"[{i}/{len(image_files)}] Skipping {img_file} - no mask found")
            continue

        print(f"[{i}/{len(image_files)}] {img_file}")

        # Prepare save path
        save_path = None
        if save_dir:
            save_path = os.path.join(save_dir, f"{base_name}_visualization.png")

        # Visualize
        fig = None  # lets the error path close a half-built figure
        try:
            # Load data
            image, mask = load_image_and_mask(image_path, mask_path)
            image_display, mask_display = resize_if_needed(image, mask, max_size)
            colored_mask = create_colored_mask(mask_display)

            num_instances = int(mask_display.max())
            original_size = image.size

            # Create visualization
            fig, axes = plt.subplots(1, 2, figsize=(16, 8))

            axes[0].imshow(image_display)
            axes[0].set_title(f"Original Image\nSize: {original_size[0]}x{original_size[1]}",
                              fontsize=14, fontweight='bold')
            axes[0].axis('off')

            axes[1].imshow(colored_mask)
            axes[1].set_title(f"Instance Mask\n{num_instances} road segments",
                              fontsize=14, fontweight='bold')
            axes[1].axis('off')

            fig.suptitle(img_file, fontsize=16, fontweight='bold')
            fig.tight_layout()

            if save_path:
                fig.savefig(save_path, dpi=150, bbox_inches='tight')
                print(f"  Saved: {save_path}")

            if show_plots:
                plt.show()
            else:
                plt.close(fig)

            print(f"  Instances: {num_instances}")
            print()

        except Exception as e:
            # Do not leave a half-drawn figure open when one pair fails;
            # open figures accumulate over a long batch.
            if fig is not None:
                plt.close(fig)
            print(f"  Error: {e}")
            print()
            continue

    print("="*70)
    print("VISUALIZATION COMPLETE")
    print("="*70)

def visualize_grid(image_folder, mask_folder, max_images=9, max_size=1024, save_path=None):
    """
    Visualize multiple image-mask pairs in a grid

    Up to three pairs are placed per row; each pair takes two neighbouring
    axes (image, then mask). Images without a mask file, or that fail to
    load, leave their two cells blank.

    Args:
        image_folder: Folder containing original images
        mask_folder: Folder containing instance masks
        max_images: Maximum number of images to show in grid
        max_size: Maximum dimension for each image
        save_path: Optional path to save grid visualization
    """
    print("="*70)
    print("CREATING GRID VISUALIZATION")
    print("="*70)

    # Get all images
    image_files = sorted([
        f for f in os.listdir(image_folder)
        if f.lower().endswith(('.tif', '.tiff', '.jpg', '.jpeg', '.png'))
    ])[:max_images]

    if not image_files:
        print(f"No images found in {image_folder}")
        return

    # Calculate grid size
    n_images = len(image_files)
    n_cols = min(3, n_images)
    n_rows = (n_images + n_cols - 1) // n_cols

    # squeeze=False always returns a 2-D axes array, so no reshape special
    # case is needed for a single row.
    fig, axes = plt.subplots(
        n_rows, n_cols * 2,
        figsize=(n_cols * 8, n_rows * 4),
        squeeze=False
    )

    # Turn every axis off up front: unused cells AND cells of skipped or failed
    # images then stay blank instead of showing an empty framed plot.
    for ax in axes.ravel():
        ax.axis('off')

    for idx, img_file in enumerate(image_files):
        row = idx // n_cols
        col = idx % n_cols

        image_path = os.path.join(image_folder, img_file)
        base_name = os.path.splitext(img_file)[0]
        mask_path = os.path.join(mask_folder, f"{base_name}_instance_mask.tif")

        if not os.path.exists(mask_path):
            continue

        try:
            # Load and resize
            image, mask = load_image_and_mask(image_path, mask_path)
            image_display, mask_display = resize_if_needed(image, mask, max_size)
            colored_mask = create_colored_mask(mask_display)

            # Plot image
            ax_img = axes[row, col * 2]
            ax_img.imshow(image_display)
            ax_img.set_title(f"{base_name}\nOriginal", fontsize=10)

            # Plot mask
            ax_mask = axes[row, col * 2 + 1]
            ax_mask.imshow(colored_mask)
            ax_mask.set_title(f"{int(mask_display.max())} segments\nMask", fontsize=10)

            print(f"Added: {img_file}")

        except Exception as e:
            print(f"Error processing {img_file}: {e}")
            continue

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"\nGrid saved to: {save_path}")

    plt.show()
    print("\n" + "="*70)


# EXAMPLE USAGE


# Entry point for the visualization cell: renders every image in "roads" next to
# its mask from "outputs_sam3" and saves each figure under "visualizations".
# NOTE(review): the label says "OPTION 2" but no other option is shown in this cell.
if __name__ == "__main__":
    
    # OPTION 2: Visualize all image-mask pairs (one by one)
    print("\n### OPTION 2: Multiple Visualizations ###\n")
    visualize_multiple_pairs(
        image_folder="roads",
        mask_folder="outputs_sam3",
        max_size=1024,
        save_dir="visualizations",  # Optional: save all visualizations
        show_plots=True  # Set False for batch processing without display
    )
    