In [None]:
import rasterio
import numpy as np
import cv2
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from shapely.geometry import Polygon
import geopandas as gpd
import pandas as pd
import torch
import matplotlib.pyplot as plt
import gc

# Load the Segment Anything (SAM) model once at module level; get_segmentation uses `sam`.
model_type = "vit_h"                     # registry key for the ViT-H backbone
sam_checkpoint = "sam_vit_h_4b8939.pth"  # ViT-H weights, resolved relative to the working directory

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Build the model from the checkpoint, then move its weights to `device`.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

def get_segmentation(patch, **generator_kwargs):
    """
    Run SAM's automatic mask generator on one image patch.

    Args:
        patch (np.ndarray): Image patch to segment. SamAutomaticMaskGenerator.generate
            documents its input as an HWC uint8 image.
        **generator_kwargs: Optional keyword arguments forwarded unchanged to
            SamAutomaticMaskGenerator (e.g. points_per_side, pred_iou_thresh).
            When omitted, the library defaults are used, exactly as before.

    Returns:
        list: One dict per mask, as returned by SamAutomaticMaskGenerator.generate;
            process_patch reads each dict's 'segmentation' entry.
    """
    # A new generator wrapping the module-level `sam` model is built on every call,
    # so different calls may use different generator settings.
    mask_generator = SamAutomaticMaskGenerator(sam, **generator_kwargs)
    return mask_generator.generate(patch)

def pixel_to_tm2(pixel_coords, transform):
    """
    Map pixel (column, row) coordinates to map coordinates via an affine transform.

    The name refers to TM2 coordinates, but nothing here enforces a CRS: the
    output is in whatever space `transform` maps pixels into.

    Args:
        pixel_coords (iterable): Sequence of (x, y) pixel coordinates, where x is the
            column index and y the row index (e.g. OpenCV contour points).
        transform (Affine): Pixel-to-map transform; it is applied as
            `transform * (x, y)` to each point.

    Returns:
        np.ndarray: One row of map coordinates per input point.
    """
    # NOTE(review): under a rasterio-style transform, an integer (col, row) maps to
    # the pixel's upper-left corner, not its centre, so vertices may sit half a
    # pixel off the traced pixel centres. Confirm whether that matters.
    mapped = []
    for col, row in pixel_coords:
        mapped.append(transform * (col, row))
    return np.array(mapped)

def is_contour_on_edge(contour, patch_size):
    """
    Check whether a contour touches the border of its patch.

    Args:
        contour (np.ndarray): (N, 2) array of (x, y) = (column, row) pixel points,
            e.g. an OpenCV contour reshaped to two columns.
        patch_size (int or tuple): Either a single side length (square patch, the
            original behaviour) or a (height, width) pair such as patch.shape[:2],
            which is required for non-square patches.

    Returns:
        bool: True if any point lies on the first or last row/column of the patch.
    """
    if isinstance(patch_size, (tuple, list)):
        height, width = patch_size[0], patch_size[1]
    else:
        height = width = patch_size
    x_min, y_min = contour.min(axis=0)
    x_max, y_max = contour.max(axis=0)
    # x indexes columns (compare with width); y indexes rows (compare with height).
    return bool(x_min <= 0 or y_min <= 0 or x_max >= width - 1 or y_max >= height - 1)

def process_patch(patch, transform):
    """
    Segment one patch and convert the mask outlines to georeferenced polygons.

    Contours that touch the patch border are skipped (presumably so that objects
    cut by the patch boundary are picked up whole from an overlapping neighbour).

    Args:
        patch (np.ndarray): Image patch, assumed (height, width, channels); passed
            unchanged to get_segmentation.
        transform (Affine): Pixel-to-map transform for this patch.

    Returns:
        list: shapely Polygon objects in the transform's map coordinates.
    """
    # (height, width): raster-edge patches can be smaller than the nominal patch
    # size and non-square, so rows and columns are bounded separately.
    patch_shape = patch.shape[:2]
    polygons = []

    for mask in get_segmentation(patch):
        # findContours expects an 8-bit single-channel image, not a boolean mask.
        mask_np = mask['segmentation'].astype(np.uint8)
        contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        for contour in contours:
            # OpenCV contours are (N, 1, 2); flatten to (N, 2) rows of (x, y).
            points = contour.reshape(-1, 2)
            # A polygon needs at least three vertices.
            if len(points) < 3 or is_contour_on_edge(points, patch_shape):
                continue
            polygons.append(Polygon(pixel_to_tm2(points, transform)))

    return polygons

def main_seg(input_img, output_shp, patch_size, overlap_percentage):
    """
    Tile a raster into overlapping patches, segment each with SAM and save polygons.

    Args:
        input_img (str): Path to the input raster readable by rasterio.
        output_shp (str): Path for the output shapefile.
        patch_size (int): Side length, in pixels, of the square patches. Patches at
            the right/bottom edges are clipped to the raster.
        overlap_percentage (float): Fraction of patch_size shared by neighbouring
            patches, in [0, 1) (e.g. 0.5 for 50% overlap).

    Raises:
        ValueError: If patch_size is not positive or overlap_percentage is
            outside [0, 1). Without this check the tiling step would be zero
            or negative.

    Notes:
        Objects lying entirely inside an overlap region can be segmented in more
        than one patch, so the output may contain duplicate polygons.
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    if not 0 <= overlap_percentage < 1:
        raise ValueError(
            f"overlap_percentage must be in [0, 1), got {overlap_percentage}"
        )
    overlap = int(patch_size * overlap_percentage)
    step = patch_size - overlap  # >= 1 given the checks above

    gdfs = []

    with rasterio.open(input_img) as src:
        print(f"Image dimensions: {src.width}x{src.height}")

        for col_off in range(0, src.width, step):
            for row_off in range(0, src.height, step):
                # Clip the window to the raster so edge patches have an explicit,
                # smaller shape instead of relying on out-of-range window reads.
                window = rasterio.windows.Window(
                    col_off,
                    row_off,
                    min(patch_size, src.width - col_off),
                    min(patch_size, src.height - row_off),
                )
                transform = src.window_transform(window)
                patch = np.moveaxis(src.read(window=window), 0, -1)  # bands last
                # Single-band rasters: replicate to three channels for SAM's RGB input.
                if patch.shape[2] == 1:
                    patch = np.repeat(patch, 3, axis=2)
                patch = patch[:, :, :3]  # use the first three bands

                polygons = process_patch(patch, transform)
                gdfs.append(gpd.GeoDataFrame(geometry=polygons, crs=src.crs))

    # Concatenate all per-patch GeoDataFrames and save to a single shapefile.
    gdf = pd.concat(gdfs, ignore_index=True)
    gdf.to_file(output_shp)
    print(f"Output saved to {output_shp}")

In [15]:
# Example Usage
if __name__ == "__main__":
    # Paths: replace the placeholder with a real raster (e.g. a GeoTIFF) before running.
    input_image = 'path_to_input_image.tif'
    output_shapefile = 'output_masks.shp'

    # Tiling: 256-pixel patches, neighbours overlapping by half a patch.
    patch_size = 256
    overlap_percentage = 0.5

    main_seg(
        input_image,
        output_shapefile,
        patch_size=patch_size,
        overlap_percentage=overlap_percentage,
    )