In [1]:
from samgeo import SamGeo2
import numpy as np
import os
import rasterio
from rasterio.plot import show
import matplotlib.pyplot as plt

In [44]:
# SAM2 automatic-mask-generation configuration, kept in a single dict so the
# settings can be inspected or logged separately from model construction.
# NOTE(review): every key is forwarded verbatim to samgeo's SamGeo2; the exact
# meaning of each threshold is defined by samgeo/SAM2, not by this notebook.
sam2_settings = dict(
    model_id="sam2-hiera-large",
    apply_postprocessing=False,
    points_per_side=32,
    points_per_batch=64,
    pred_iou_thresh=0.6,
    stability_score_thresh=0.85,
    stability_score_offset=0.7,
    crop_n_layers=1,
    box_nms_thresh=0.9,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=25.0,
    use_m2m=True,
)
sam2 = SamGeo2(**sam2_settings)

In [None]:
# Segment the test tile, then write the raw mask raster ("masks.tif", read back
# by the cleanup cells below) and a rendered annotation overlay to disk.
image = "tiff_testing/tile_28672_16384.tif"
sam2.generate(image)

sam2.save_masks(output="masks.tif")
sam2.show_anns(
    axis="off",
    alpha=0.7,
    output="annotations.tif",
)

In [None]:
import rasterio
import numpy as np
from rasterio.features import shapes, rasterize
from shapely.geometry import shape
import matplotlib.pyplot as plt

# Round-trip the SAM masks through vector polygons: merge every band into one
# binary footprint, polygonize it, burn the polygons back into a raster, and
# save the result as a single-band uint8 GeoTIFF.

# --- Paths ---
mask_path = "masks.tif"
clean_mask_path = "masks_clean.tif"

# --- Read mask ---
with rasterio.open(mask_path) as src:
    mask_data = src.read()  # all bands, shape (bands, rows, cols)
    profile = src.profile
    transform = src.transform

# --- Combine all layers into a single 2D mask ---
# A pixel is foreground if it is non-zero in any band.
mask_combined = (np.max(mask_data, axis=0) > 0).astype(np.uint8)

# --- Extract polygons from combined mask ---
polygons = [shape(geom) for geom, _ in shapes(mask_combined, mask=mask_combined, transform=transform)]
print(f"Total polygons detected: {len(polygons)}")

# --- Rasterize polygons back to a cleaned mask ---
# NOTE(review): the polygons are pixel-aligned, so all_touched=True may also burn
# pixels that merely touch a polygon edge -- confirm this growth is intended.
if polygons:
    clean_mask = rasterize([(poly, 1) for poly in polygons],
                           out_shape=mask_combined.shape,
                           transform=transform,
                           all_touched=True,
                           dtype=np.uint8)
else:
    # Some rasterio versions raise ValueError for an empty shape list; an
    # all-background mask is the correct result here anyway.
    clean_mask = np.zeros(mask_combined.shape, dtype=np.uint8)

# --- Save cleaned mask ---
# The output is one uint8 band, so the input profile's band count and dtype
# (possibly multi-band / uint16 IDs) must not be reused as-is.
clean_profile = dict(profile)
clean_profile.update(count=1, dtype="uint8")
nodata = clean_profile.get("nodata")
if nodata is not None and not (0 <= nodata <= 255):
    # An inherited nodata value outside the uint8 range cannot be written.
    clean_profile["nodata"] = None
with rasterio.open(clean_mask_path, "w", **clean_profile) as dst:
    dst.write(clean_mask, 1)

print(f"Cleaned mask saved at: {clean_mask_path}")

# --- Visualize original combined mask vs cleaned mask ---
fig, axs = plt.subplots(1, 2, figsize=(14, 7))

axs[0].imshow(mask_combined, cmap="gray")
axs[0].set_title("Original Combined Mask")
axs[0].axis("off")

axs[1].imshow(clean_mask, cmap="gray")
axs[1].set_title("Cleaned Mask (Rasterized Polygons)")
axs[1].axis("off")

plt.show()


In [None]:
# Leave the Neighbours Alone - Debug Version
import rasterio
import numpy as np
from rasterio.features import shapes, rasterize
from shapely.geometry import shape
import matplotlib.pyplot as plt
from scipy import ndimage

# Inspect the SAM mask raster band by band, then build a single-band raster in
# which every field has its own integer ID:
#   * if a band carries several distinct non-zero values, those values are
#     treated as field IDs and copied through (later bands win on overlap);
#   * otherwise every band is treated as a binary mask and each connected
#     component (scipy.ndimage.label default connectivity) gets a fresh ID.

N_PREVIEW_LAYERS = 4  # number of bands shown in the preview figure

# --- Paths ---
mask_path = "masks.tif"
clean_mask_path = "masks_clean.tif"

# --- Read mask ---
with rasterio.open(mask_path) as src:
    mask_data = src.read()  # all bands, shape (bands, rows, cols)
    profile = src.profile
    transform = src.transform

n_layers = mask_data.shape[0]
print(f"Number of layers in input: {n_layers}")
print(f"Mask shape: {mask_data.shape}")

# --- Visualize original layers first ---
n_preview = min(N_PREVIEW_LAYERS, n_layers)
# squeeze=False always yields a 2-D Axes array, so indexing works for one band too.
fig1, axes1 = plt.subplots(1, n_preview, figsize=(16, 4), squeeze=False)
axes1 = axes1[0]
for i in range(n_preview):
    layer = mask_data[i]
    axes1[i].imshow(layer, cmap="tab20")
    axes1[i].set_title(f"Original Layer {i+1}\nUnique values: {len(np.unique(layer))}")
    axes1[i].axis("off")
plt.suptitle(f"Original Mask Layers (showing first {N_PREVIEW_LAYERS})")
plt.tight_layout()
plt.show()

# --- Check what values are in each layer ---
print("\nLayer value analysis:")
for i, layer in enumerate(mask_data):
    unique_vals = np.unique(layer)
    print(f"Layer {i+1}: {len(unique_vals)} unique values - {unique_vals[:10]}")  # Show first 10

# --- Decide whether the input already carries field IDs ---
max_val = np.max(mask_data)
print(f"\nMax value in original mask: {max_val}")

# "max > 1" alone is not evidence of IDs: a binary mask may be stored as 0/255.
# Require more than one distinct non-zero value within some band instead.
has_field_ids = any(np.unique(layer[layer > 0]).size > 1 for layer in mask_data)

# Accumulate in uint32 so neither copied IDs nor sequential component IDs can
# wrap; narrowed back to uint16 below whenever the values fit.
clean_mask = np.zeros(mask_data.shape[1:], dtype=np.uint32)

if has_field_ids:
    print("Original mask appears to have unique field IDs already!")
    print("Using direct approach to preserve IDs...")

    for layer in mask_data:
        # Copy non-zero values, letting later layers overwrite if needed
        fg = layer > 0
        clean_mask[fg] = layer[fg]

    print(f"Unique fields preserved: {len(np.unique(clean_mask[clean_mask > 0]))}")

else:
    print("Original mask is binary, extracting polygons...")

    current_id = 1
    for i, layer in enumerate(mask_data):
        layer_mask = (layer > 0).astype(np.uint8)
        if not layer_mask.any():
            continue

        labeled_array, num_features = ndimage.label(layer_mask)
        print(f"Layer {i+1}: {num_features} connected components")

        # Shift the whole label image by the running offset in one step rather
        # than one full-image comparison per component (O(components x pixels)).
        fg = labeled_array > 0
        clean_mask[fg] = labeled_array[fg] + (current_id - 1)
        current_id += num_features

    print(f"Total unique fields: {current_id - 1}")

# Keep the historical uint16 output unless the IDs do not fit in it.
out_dtype = np.uint16 if clean_mask.max() <= np.iinfo(np.uint16).max else np.uint32
clean_mask = clean_mask.astype(out_dtype)

# --- Save cleaned mask ---
profile.update(count=1, dtype=np.dtype(out_dtype).name)
with rasterio.open(clean_mask_path, "w", **profile) as dst:
    dst.write(clean_mask, 1)

n_output_fields = len(np.unique(clean_mask[clean_mask > 0]))
print(f"\nCleaned mask saved at: {clean_mask_path}")
print(f"Unique field IDs in output: {n_output_fields}")

# --- Visualize comparison ---
mask_combined = (np.max(mask_data, axis=0) > 0).astype(np.uint8)

fig, axs = plt.subplots(1, 3, figsize=(20, 7))

# Original combined
axs[0].imshow(mask_combined, cmap="gray")
axs[0].set_title("Original (combined)")
axs[0].axis("off")

# Cleaned binary
axs[1].imshow(clean_mask > 0, cmap="gray")
axs[1].set_title("Cleaned (binary)")
axs[1].axis("off")

# Cleaned with IDs
im = axs[2].imshow(clean_mask, cmap="tab20", interpolation='nearest')
axs[2].set_title(f"Cleaned ({n_output_fields} fields)")
axs[2].axis("off")
plt.colorbar(im, ax=axs[2])

plt.tight_layout()
plt.show()

# --- Zoom into bottom-right area to verify ---
h, w = clean_mask.shape
zoom_slice = (slice(int(h*0.7), h), slice(int(w*0.7), w))

fig2, axs2 = plt.subplots(1, 2, figsize=(12, 6))
axs2[0].imshow(mask_combined[zoom_slice], cmap="gray")
axs2[0].set_title("Original bottom-right (zoomed)")
axs2[0].axis("off")

axs2[1].imshow(clean_mask[zoom_slice], cmap="tab20", interpolation='nearest')
axs2[1].set_title("Cleaned bottom-right (zoomed)")
axs2[1].axis("off")

plt.tight_layout()
plt.show()

In [None]:
# Leave the Neighbours Alone - Smooth While Preserving IDs
import rasterio
import numpy as np
from rasterio.features import shapes, rasterize
from shapely.geometry import shape
import matplotlib.pyplot as plt
from scipy import ndimage

# Smooth every field's outline independently while keeping its integer ID, so
# touching neighbours stay distinct in the output raster.
#
# Geometry work happens in PIXEL coordinates: shapes()/rasterize() are called
# without a transform, so they use their default identity transform. The
# tolerances below are therefore in pixels and mean the same thing whatever the
# raster's CRS units are. Previously they were applied in CRS units, where e.g.
# a geographic CRS would turn buffer(0.3) into 0.3 degrees.
SIMPLIFY_TOLERANCE_PX = 0.5  # tolerance passed to shapely simplify()
BUFFER_OUT_PX = 0.3          # outward buffer ...
BUFFER_IN_PX = 0.1           # ... followed by this inward buffer
PROGRESS_EVERY = 50          # print progress every N fields

# --- Paths ---
mask_path = "masks.tif"
clean_mask_path = "masks_clean.tif"

# --- Read mask ---
with rasterio.open(mask_path) as src:
    mask_data = src.read()  # all bands, shape (bands, rows, cols)
    profile = src.profile    # carries the georeferencing for the output file
    transform = src.transform

print(f"Number of layers in input: {mask_data.shape[0]}")

# --- Combine all layers preserving unique IDs ---
# Widen to uint32 only when IDs would not fit in uint16 (avoids silent truncation).
id_dtype = np.uint32 if mask_data.max() > np.iinfo(np.uint16).max else np.uint16
original_mask = np.zeros(mask_data.shape[1:], dtype=id_dtype)
for layer in mask_data:
    fg = layer > 0
    original_mask[fg] = layer[fg]  # later layers win where layers overlap

unique_ids = np.unique(original_mask[original_mask > 0])
print(f"Total unique fields in original: {len(unique_ids)}")

# --- Polygonize all fields in a single pass ---
# shapes() groups connected pixels of equal value, giving the same polygons as
# polygonizing each field's binary mask separately, but without one full-image
# pass per field. int32 is used because shapes() does not accept uint32.
polys_by_id = {}
for geom, value in shapes(original_mask.astype(np.int32), mask=original_mask > 0):
    polys_by_id.setdefault(int(value), []).append(shape(geom))

# --- Smooth each field; collect (geometry, id) pairs in ascending-ID order ---
burn_shapes = []
n_fields = len(unique_ids)
for n_done, field_id in enumerate(unique_ids, start=1):
    field_id = int(field_id)
    for poly in polys_by_id.get(field_id, []):
        try:
            smoothed = poly.simplify(tolerance=SIMPLIFY_TOLERANCE_PX, preserve_topology=True)
            buffered = smoothed.buffer(BUFFER_OUT_PX).buffer(-BUFFER_IN_PX)

            if buffered.is_valid and not buffered.is_empty:
                burn_shapes.append((buffered, field_id))
            elif smoothed.is_valid and not smoothed.is_empty:
                # Buffering degenerated; fall back to the simplified outline.
                burn_shapes.append((smoothed, field_id))
            else:
                # Nothing usable left after smoothing: keep the pixel-exact polygon.
                burn_shapes.append((poly, field_id))
        except Exception as e:
            print(f"Error processing field {field_id}: {e}")
            # Keep original if error
            burn_shapes.append((poly, field_id))

    # Progress counts processed fields, not the ID value (IDs need not be contiguous).
    if n_done % PROGRESS_EVERY == 0:
        print(f"Processed {n_done}/{n_fields} fields...")

# --- Burn all fields at once ---
# rasterize() uses MergeAlg.replace by default: shapes are painted in list order
# and later ones overwrite earlier ones, so higher IDs win on overlap just as in
# a per-field overwrite loop.
if burn_shapes:
    clean_mask = rasterize(burn_shapes,
                           out_shape=original_mask.shape,
                           fill=0,
                           all_touched=False,
                           dtype=id_dtype)
else:
    clean_mask = np.zeros(original_mask.shape, dtype=id_dtype)

n_clean_fields = len(np.unique(clean_mask[clean_mask > 0]))
print(f"\nUnique fields in cleaned mask: {n_clean_fields}")

# --- Save cleaned mask ---
profile.update(count=1, dtype=np.dtype(id_dtype).name)
with rasterio.open(clean_mask_path, "w", **profile) as dst:
    dst.write(clean_mask, 1)

print(f"Cleaned mask saved at: {clean_mask_path}")

# --- Visualize comparison ---
fig, axs = plt.subplots(1, 3, figsize=(20, 7))

# Original
axs[0].imshow(original_mask, cmap="tab20", interpolation='nearest')
axs[0].set_title(f"Original ({len(unique_ids)} fields)")
axs[0].axis("off")

# Cleaned
axs[1].imshow(clean_mask, cmap="tab20", interpolation='nearest')
axs[1].set_title(f"Cleaned ({n_clean_fields} fields)")
axs[1].axis("off")

# Difference: +1 = pixel added by smoothing, -1 = pixel removed. In "RdBu",
# vmin (-1) renders red and vmax (+1) renders blue, matching the title.
diff = (clean_mask > 0).astype(int) - (original_mask > 0).astype(int)
axs[2].imshow(diff, cmap="RdBu", vmin=-1, vmax=1)
axs[2].set_title("Difference (red=removed, blue=added)")
axs[2].axis("off")

plt.tight_layout()
plt.show()

# --- Zoom into bottom-right to verify separation ---
h, w = clean_mask.shape
zoom_slice = (slice(int(h*0.7), h), slice(int(w*0.7), w))

fig2, axs2 = plt.subplots(1, 2, figsize=(14, 7))

axs2[0].imshow(original_mask[zoom_slice], cmap="tab20", interpolation='nearest')
axs2[0].set_title("Original bottom-right (zoomed)")
axs2[0].axis("off")

axs2[1].imshow(clean_mask[zoom_slice], cmap="tab20", interpolation='nearest')
axs2[1].set_title("Cleaned bottom-right (zoomed)")
axs2[1].axis("off")

plt.tight_layout()
plt.show()

print("\n✓ Each field keeps its unique ID")
print("✓ Boundaries are smoothed")
print("✓ Fields remain separated!")

In [None]:
import rasterio
from rasterio.features import shapes
import numpy as np
from shapely.geometry import shape
import geopandas as gpd
import matplotlib.pyplot as plt
from rasterio.plot import show

# Overlay the outline of every field region in the cleaned mask on the source image.

# --- Paths ---
mask_path = "masks_clean.tif"
original_image_path = "tiff_testing/tile_28672_16384.tif"

# --- Read mask ---
with rasterio.open(mask_path) as src:
    mask_data = src.read(1)  # read single layer; non-zero pixels are treated as fields
    transform = src.transform
    crs = src.crs

# --- Convert mask to polygons, one per connected region of each value ---
# Polygonize on the ID values rather than a binary footprint: a binary mask
# fuses touching fields into one polygon and hides the boundary between them.
# int32 is accepted by shapes() and holds any ID written by the cleaning cells.
mask_binary = (mask_data > 0).astype(np.uint8)
polygons = []
field_ids = []
for geom, value in shapes(mask_data.astype(np.int32), mask=mask_binary, transform=transform):
    polygons.append(shape(geom))
    field_ids.append(int(value))

# --- Create GeoDataFrame ---
gdf = gpd.GeoDataFrame({"field_id": field_ids}, geometry=polygons, crs=crs)

# --- Plot mask boundaries over original image ---
with rasterio.open(original_image_path) as src_img:
    fig, ax = plt.subplots(figsize=(10,10))
    show(src_img, ax=ax)  # show the original raster
    gdf.boundary.plot(ax=ax, edgecolor='red', linewidth=1)  # overlay polygons
    plt.title("Mask Boundaries Overlay")
    plt.axis("off")
    plt.show()


In [None]:
import rasterio
from rasterio.features import shapes
import numpy as np
from shapely.geometry import shape
import geopandas as gpd
import matplotlib.pyplot as plt

# Draw each field polygon of the cleaned mask, one per subplot, over the source image.

# --- Paths ---
mask_path = "masks_clean.tif"
original_image_path = "tiff_testing/tile_28672_16384.tif"

# --- Read cleaned mask ---
with rasterio.open(mask_path) as src:
    mask_data = src.read(1)
    transform = src.transform
    crs = src.crs

# --- Extract all polygons ---
# Polygonize on the ID values so touching fields stay separate polygons instead
# of being fused by a binary footprint. int32 is a dtype shapes() accepts.
mask_binary = (mask_data > 0).astype(np.uint8)
polygons = [shape(geom) for geom, _ in shapes(mask_data.astype(np.int32), mask=mask_binary, transform=transform)]

n_polygons = len(polygons)
print(f"Total polygons found: {n_polygons}")

# --- Read original image ---
with rasterio.open(original_image_path) as src_img:
    img_data = src_img.read()
    img_transform = src_img.transform
    # [left, right, bottom, top] for imshow, assuming a north-up transform.
    img_extent = [img_transform[2], img_transform[2] + img_transform[0] * src_img.width,
                  img_transform[5] + img_transform[4] * src_img.height, img_transform[5]]

# --- Prepare image ---
if img_data.shape[0] >= 3:
    rgb = np.transpose(img_data[:3], (1, 2, 0))
    if rgb.max() > 255:
        rgb = (rgb / rgb.max() * 255).astype(np.uint8)
    display_img = rgb
    cmap_to_use = None
else:
    display_img = img_data[0]
    cmap_to_use = 'gray'

if n_polygons == 0:
    # plt.subplots(0, n_cols) would raise; there is simply nothing to draw.
    print("No polygons to display.")
else:
    # --- Create grid of subplots ---
    n_cols = 4
    n_rows = int(np.ceil(n_polygons / n_cols))

    # squeeze=False always returns a 2-D Axes array, so ravel() gives a flat
    # sequence of n_rows * n_cols axes for ANY polygon count. (Wrapping the result
    # in a list broke for a single polygon, because subplots(1, 4) already
    # returns an array.)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 4*n_rows), squeeze=False)
    axes = axes.ravel()

    for idx, polygon in enumerate(polygons):
        ax = axes[idx]

        # Show original image
        if cmap_to_use:
            ax.imshow(display_img, cmap=cmap_to_use, extent=img_extent)
        else:
            ax.imshow(display_img, extent=img_extent)

        # Plot this polygon
        gdf_single = gpd.GeoDataFrame(geometry=[polygon], crs=crs)
        gdf_single.boundary.plot(ax=ax, edgecolor='red', linewidth=1.5)

        ax.set_title(f"Polygon {idx+1}", fontsize=10)
        ax.axis("off")

    # Hide unused subplots
    for spare_ax in axes[n_polygons:]:
        spare_ax.axis("off")

    plt.tight_layout()
    plt.show()

    print(f"✓ Displayed all {n_polygons} polygons in grid")