In [1]:
import os
import torch
import numpy as np
import geopandas as gpd
import cv2
from PIL import Image
from torchvision import transforms as T
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
from shapely.geometry import Polygon
import rasterio
from rasterio.transform import Affine

# --- Input / output locations ---------------------------------------------
# Georeferenced, cropped trap images to segment (iterated by the loop below).
# NOTE(review): these paths begin with a backslash but carry no drive letter,
# so they resolve against the *current* drive; the merge cell further down
# uses explicit 'h:\\...' paths instead — confirm both point to the same place.
cropped_images_dir = '\\Yehmh\\ZF\\202403_georeferenced\\traps_images_T4\\'
# One output shapefile per input image is written into this folder.
shapefiles_dir = '\\Yehmh\\ZF\\202403_georeferenced\\traps_images_seg_2\\'
os.makedirs(shapefiles_dir, exist_ok=True)

# --- Segment Anything (SAM) model -----------------------------------------
model_type = "vit_h"                      # key into sam_model_registry
sam_checkpoint = 'sam_vit_h_4b8939.pth'   # weights file matching model_type

# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Automatic (prompt-free) mask generator; segment_image_to_shapefile calls
# its .generate() on each image.
mask_generator = SamAutomaticMaskGenerator(sam)

# Square-resize + tensor conversion pipeline.
# NOTE(review): not referenced by any visible cell — the mask generator is
# fed the raw numpy image directly. Candidate for removal.
transform = T.Compose([T.Resize((1024, 1024)), T.ToTensor()])

# Function to convert pixel coordinates to geographic (TM2) coordinates
def pixel_to_tm2(pixel_coords, transform):
    """Map pixel ``(x, y)`` pairs to map coordinates through an affine transform.

    Args:
        pixel_coords: Iterable of ``(x, y)`` pixel pairs, e.g. an ``(N, 2)``
            OpenCV contour where x is the column and y is the row.
        transform: Object supporting ``transform * (x, y) -> (X, Y)``, such as
            the rasterio/affine ``Affine`` read from the source raster. The
            output is in whatever CRS that transform targets (TM2 in this
            notebook, per the function name).

    Returns:
        ``np.ndarray`` built from the mapped pairs; shape ``(N, 2)`` for a
        non-empty input.

    NOTE(review): under rasterio's convention ``Affine * (col, row)`` yields the
    upper-left corner of that pixel rather than its centre — confirm that a
    half-pixel offset is acceptable for the output polygons.
    """
    mapped = []
    for point in pixel_coords:
        col, row = point
        mapped.append(transform * (col, row))
    return np.array(mapped)

def _read_georeferencing(image_path):
    """Return ``(affine_transform, crs)`` for the raster at ``image_path``.

    Both values are read while the rasterio dataset is still open, so nothing
    touches the dataset object after its ``with`` block has closed it.
    """
    with rasterio.open(image_path) as src:
        return src.transform, src.crs


def _read_rgb_array(image_path):
    """Load ``image_path`` with Pillow and return it as an RGB numpy array.

    The context manager closes the underlying file handle. Pillow may keep a
    TIFF file open after decoding, which would otherwise leak one handle per
    image in the batch loop.
    """
    with Image.open(image_path) as img:
        return np.array(img.convert('RGB'))


def _mask_to_polygons(segmentation, affine_transform):
    """Convert one SAM mask into a list of map-space shapely Polygons.

    ``segmentation`` is presumably a 2-D boolean mask (SAM's ``'segmentation'``
    entry) — TODO confirm if output_mode is ever changed. Only external contours
    are traced (``RETR_EXTERNAL``), so holes inside a mask are not represented.
    """
    mask_np = segmentation.astype(np.uint8)
    contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    polygons = []
    for contour in contours:
        # OpenCV contours are (N, 1, 2) arrays of (x=col, y=row) pixel points.
        points = contour.reshape(-1, 2)
        if len(points) >= 3:  # at least 3 points to form a polygon
            polygons.append(Polygon(pixel_to_tm2(points, affine_transform)))
    return polygons


# Function to segment an image and save as shapefile
def segment_image_to_shapefile(image_path, output_path):
    """Segment one georeferenced image with SAM and save the masks as polygons.

    Every mask produced by the module-level ``mask_generator`` is traced with
    OpenCV. Each external contour with at least 3 vertices is mapped from
    pixel to map coordinates through the raster's affine transform. All
    polygons are then written to ``output_path`` with the raster's CRS.

    Args:
        image_path: Raster readable by both rasterio and Pillow (e.g. GeoTIFF).
        output_path: Destination vector file, e.g. ``.../segmented_X.shp``.

    Returns:
        None. When no polygon is found, a message is printed and no file is
        written.
    """
    affine_transform, crs = _read_georeferencing(image_path)
    image_array = _read_rgb_array(image_path)

    with torch.no_grad():
        # Generate masks using SAM Automatic Mask Generator
        masks = mask_generator.generate(image_array)

    polygons = []
    for mask in masks:
        polygons.extend(_mask_to_polygons(mask['segmentation'], affine_transform))

    if polygons:
        gdf = gpd.GeoDataFrame(geometry=polygons, crs=crs)
        gdf.to_file(output_path)
    else:
        print(f'No valid polygons found in {image_path}')

# Segment all cropped images and save as shapefiles.
# Files are processed in sorted order so runs are reproducible. The extension
# check is case-insensitive so '.TIF' / '.tiff' inputs are picked up as well.
for image_name in sorted(os.listdir(cropped_images_dir)):
    if not image_name.lower().endswith(('.tif', '.tiff')):
        continue
    image_path = os.path.join(cropped_images_dir, image_name)
    stem = os.path.splitext(image_name)[0]
    output_path = os.path.join(shapefiles_dir, f'segmented_{stem}.shp')
    segment_image_to_shapefile(image_path, output_path)
    # NOTE: printed even when segment_image_to_shapefile found no polygons
    # (it then prints its own message and writes no file).
    print(f'Segmented and saved {image_name} as {output_path}')

print('Segmentation and shapefile creation completed.')


Segmented and saved T1S1.tif as \Yehmh\ZF\202403_georeferenced\traps_images_seg_2\segmented_T1S1.shp
Segmented and saved T1S10.tif as \Yehmh\ZF\202403_georeferenced\traps_images_seg_2\segmented_T1S10.shp
Segmented and saved T1S2.tif as \Yehmh\ZF\202403_georeferenced\traps_images_seg_2\segmented_T1S2.shp
Segmented and saved T1S3.tif as \Yehmh\ZF\202403_georeferenced\traps_images_seg_2\segmented_T1S3.shp
Segmented and saved T1S4.tif as \Yehmh\ZF\202403_georeferenced\traps_images_seg_2\segmented_T1S4.shp
Segmented and saved T1S5.tif as \Yehmh\ZF\202403_georeferenced\traps_images_seg_2\segmented_T1S5.shp
Segmented and saved T1S6.tif as \Yehmh\ZF\202403_georeferenced\traps_images_seg_2\segmented_T1S6.shp
Segmented and saved T1S7.tif as \Yehmh\ZF\202403_georeferenced\traps_images_seg_2\segmented_T1S7.shp
Segmented and saved T1S8.tif as \Yehmh\ZF\202403_georeferenced\traps_images_seg_2\segmented_T1S8.shp
Segmented and saved T1S9.tif as \Yehmh\ZF\202403_georeferenced\traps_images_seg_2\segment

In [4]:
import gc

import torch

# Release GPU memory cached by PyTorch after the segmentation run.
# Run the garbage collector first so that tensors held only by unreachable
# objects are returned to the caching allocator *before* the cache is emptied;
# emptying first would leave those blocks reserved.
# NOTE: memory still referenced by live globals (e.g. `sam`, `mask_generator`)
# is not released — `del` them first if the GPU must be freed completely.
gc.collect()
torch.cuda.empty_cache()

In [6]:
import gc

# Force a full collection of all generations. The returned number of
# unreachable objects found is shown as the cell's output.
gc.collect(generation=2)

3449

In [8]:
# merge all shape files

import os
import geopandas as gpd
import pandas as pd

# Directory containing the individual shapefiles.
# NOTE(review): the segmentation cell writes to '...\\traps_images_seg_2\\'
# (drive-relative path), while this cell reads 'h:\\...\\traps_images_seg\\'.
# Confirm which run is meant to be merged. This cell also overwrites the
# `shapefiles_dir` global set by the segmentation cell.
shapefiles_dir = 'h:\\Yehmh\\ZF\\202403_georeferenced\\traps_images_seg\\'
merged_shapefile_path = 'h:\\Yehmh\\ZF\\202403_georeferenced\\merged_traps_seg.shp'

# Collect the shapefile names (sorted for a deterministic row order).
shapefile_names = sorted(
    name for name in os.listdir(shapefiles_dir) if name.lower().endswith('.shp')
)
if not shapefile_names:
    # Fail with a clear message instead of a cryptic error from to_file()
    # on an empty, geometry-less frame.
    raise FileNotFoundError(f'No shapefiles found in {shapefiles_dir}')

# Read every file once and concatenate in a single call, rather than
# growing a frame inside the loop (quadratic copying).
gdfs = [gpd.read_file(os.path.join(shapefiles_dir, name)) for name in shapefile_names]
merged_gdf = gpd.GeoDataFrame(pd.concat(gdfs, ignore_index=True))

# Save the merged shapefile
merged_gdf.to_file(merged_shapefile_path, driver='ESRI Shapefile')

print('Merging shapefiles completed.')


Merging shapefiles completed.
