# Generating object masks from input prompts with SAM

[![image](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/opengeos/segment-geospatial/blob/main/docs/examples/input_prompts.ipynb)
[![image](https://img.shields.io/badge/Open-Planetary%20Computer-black?style=flat&logo=microsoft)](https://pccompute.westeurope.cloudapp.azure.com/compute/hub/user-redirect/git-pull?repo=https://github.com/opengeos/segment-geospatial&urlpath=lab/tree/segment-geospatial/docs/examples/input_prompts.ipynb&branch=main)
[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/opengeos/segment-geospatial/blob/main/docs/examples/input_prompts.ipynb)

This notebook shows how to generate object masks from input prompts with the Segment Anything Model (SAM).

Make sure you use a GPU runtime for this notebook. For Google Colab, go to `Runtime` -> `Change runtime type` and select `GPU` as the hardware accelerator.

The notebook is adapted from [segment-anything/notebooks/predictor_example.ipynb](https://github.com/opengeos/segment-anything/blob/pypi/notebooks/predictor_example.ipynb), but I have made it much easier to save the segmentation results and visualize them.

## Install dependencies

Run the following cell to install the required dependencies (skip it if they are already installed).

In [2]:
%pip install segment-geospatial
%pip install leafmap
%pip install rioxarray


In [3]:
import os
import leafmap
from samgeo import SamGeo, tms_to_geotiff

## Create an interactive map

In [4]:
m = leafmap.Map(center=[43.1680, -116.7135], zoom=20, height="800px")
m.add_basemap("SATELLITE")
m


## Download a sample image

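Pan and zoom the map to select the area of interest, and use the draw tools to draw a polygon or rectangle on the map. The drawn area is used further below to download satellite imagery; the next cell first downloads a sample GeoTIFF.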

In [5]:
url = "https://drive.google.com/file/d/11-wM9mPONwvFaUycktDwnWAdXM5kwH4A/view?usp=sharing"
leafmap.download_file(url, output="image.tif")

Downloading...
From: https://drive.google.com/uc?id=11-wM9mPONwvFaUycktDwnWAdXM5kwH4A
To: /content/image.tif
100%|██████████| 3.15M/3.15M [00:00<00:00, 160MB/s]


'/content/image.tif'

In [6]:
import rioxarray as rxr
img_raster = rxr.open_rasterio("/content/image.tif")
print("The CRS of your data is:", img_raster.rio.crs)
print("The nodata value of your data is:", img_raster.rio.nodata)
print("The shape of your data is:", img_raster.shape)
print("The spatial resolution of your data is:", img_raster.rio.resolution())
print("The metadata for your data is:", img_raster.attrs)
print("extent", img_raster.rio.bounds())

The CRS of your data is: EPSG:26911
The nodata value of your data is: 255
The shape of your data is: (4, 471, 1413)
The spatial resolution of your data is: (0.007128879999976225, -0.00712887999935831)
The metadata for your data is: {'AREA_OR_POINT': 'Area', 'DataType': 'Generic', 'RepresentationType': 'ATHEMATIC', '_FillValue': 255, 'scale_factor': 1.0, 'add_offset': 0.0}
extent (523281.39023086417, 4779516.929034361, 523291.46333830414, 4779520.286736839)
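The printed extent is consistent with the shape and resolution: multiplying the number of columns and rows by the pixel size recovers the width and height of the bounds, so each pixel is roughly 7 mm on a side and the image covers about 10.07 m × 3.36 m. A quick check using the values printed above (the helper name `footprint` is only for illustration):

```python
# Values copied from the output above for image.tif (EPSG:26911, units in metres).
shape = (4, 471, 1413)  # (bands, rows, cols)
res_x, res_y = 0.007128879999976225, -0.00712887999935831
minx, miny, maxx, maxy = (523281.39023086417, 4779516.929034361,
                          523291.46333830414, 4779520.286736839)


def footprint(n_cols, n_rows, res_x, res_y):
    """Return the (width, height) covered by the raster, in CRS units."""
    return n_cols * abs(res_x), n_rows * abs(res_y)


width, height = footprint(shape[2], shape[1], res_x, res_y)
print(f"width  = {width:.4f} m (from bounds: {maxx - minx:.4f} m)")
print(f"height = {height:.4f} m (from bounds: {maxy - miny:.4f} m)")
```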


In [None]:
if m.user_roi is not None:
    bbox = m.user_roi_bounds()
else:
    bbox = [-122.1497, 37.6311, -122.1203, 37.6458]
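The fallback `bbox` is written as `[min_lon, min_lat, max_lon, max_lat]` (west, south, east, north) in EPSG:4326, which appears to match the order returned by `user_roi_bounds()`. Note that this default area is in California, not at the Idaho location the map is centred on. A minimal containment check (the helper `bbox_contains` is hypothetical, not part of leafmap or samgeo) is handy for confirming that the point prompts used later fall inside the downloaded area:

```python
def bbox_contains(bbox, lon, lat):
    """Return True if (lon, lat) lies inside bbox = [min_lon, min_lat, max_lon, max_lat]."""
    min_lon, min_lat, max_lon, max_lat = bbox
    return min_lon <= lon <= max_lon and min_lat <= lat <= max_lat


default_bbox = [-122.1497, 37.6311, -122.1203, 37.6458]

# The point prompts used later in this notebook, as [lon, lat].
points = [[-122.1419, 37.6383], [-122.1464, 37.6431],
          [-122.1449, 37.6415], [-122.1451, 37.6395]]
print([bbox_contains(default_bbox, lon, lat) for lon, lat in points])

# The map centre (Idaho) is outside the default bbox.
print(bbox_contains(default_bbox, -116.7135, 43.1680))
```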

In [None]:
image = "satellite.tif"
tms_to_geotiff(output=image, bbox=bbox, zoom=16, source="Satellite", overwrite=True)

Saving GeoTIFF. Please wait...
Image saved to satellite.tif


In [None]:
import rioxarray as rxr
img_raster = rxr.open_rasterio("satellite.tif")
print("The CRS of your data is:", img_raster.rio.crs)
print("The nodata value of your data is:", img_raster.rio.nodata)
print("The shape of your data is:", img_raster.shape)
print("The spatial resolution of your data is:", img_raster.rio.resolution())
print("The metadata for your data is:", img_raster.attrs)
print("extent", img_raster.rio.bounds())

The CRS of your data is: EPSG:3857
The nodata value of your data is: None
The shape of your data is: (3, 865, 1370)
The spatial resolution of your data is: (2.388900021402451, -2.388983011722694)
The metadata for your data is: {'AREA_OR_POINT': 'Area', 'scale_factor': 1.0, 'add_offset': 0.0}
extent (-13597642.404551128, 4527442.892607778, -13594369.611521805, 4529509.3629129175)
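The resolution reported for `satellite.tif` follows from `zoom=16`. In Web Mercator (EPSG:3857), a 256-pixel tile at zoom level z spans 2πR / 2^z projected metres (R = 6,378,137 m), so one pixel spans 2πR / (256 · 2^z). The sketch below checks that formula; the value is close to, but not exactly, the resolution printed above, presumably because the mosaic is clipped and resampled to the bbox. Projected metres also overstate true ground distance by a factor of 1/cos(latitude):

```python
import math

R = 6378137.0    # sphere radius used by Web Mercator, in metres
TILE_SIZE = 256  # pixels per tile side for standard XYZ tile services


def webmercator_resolution(zoom):
    """Size of one pixel in EPSG:3857 units at the given XYZ zoom level."""
    return 2 * math.pi * R / (TILE_SIZE * 2 ** zoom)


res = webmercator_resolution(16)
# Approximate ground resolution near the centre latitude of the default bbox.
ground_res = res * math.cos(math.radians(37.63845))
print(f"{res:.4f} m/pixel (projected), ~{ground_res:.2f} m on the ground")
```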


You can also use your own image by setting `image` to its file path. The following cell points `image` back to the sample image downloaded earlier instead of `satellite.tif`.

In [8]:
image = '/content/image.tif'

Display the downloaded image on the map.

In [9]:
m.layers[-1].visible = False
m.add_raster(image, layer_name="Image")
m


## Initialize SAM class

Specify the file path to the model checkpoint. If it is not specified, the model checkpoint will be downloaded automatically; in the run below it was saved to `/root/.cache/torch/hub/checkpoints/`.

Set `automatic=False` to disable the `SamAutomaticMaskGenerator` and enable the `SamPredictor`.
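The checkpoint file name in the download log below follows the naming used by the official Segment Anything release. As an illustration, here is a hypothetical helper (not part of samgeo) that maps each model type to its checkpoint URL and checks whether the file is already cached, assuming the cache directory shown in the log (`~/.cache/torch/hub/checkpoints`). samgeo's own lookup logic may differ. `vit_h` is the largest download (about 2.56 GB per the log); `vit_l` and `vit_b` are smaller:

```python
import os

BASE_URL = "https://dl.fbaipublicfiles.com/segment_anything/"
# Checkpoint file names published with the Segment Anything release.
CHECKPOINTS = {
    "vit_h": "sam_vit_h_4b8939.pth",
    "vit_l": "sam_vit_l_0b3195.pth",
    "vit_b": "sam_vit_b_01ec64.pth",
}


def checkpoint_info(model_type, cache_dir="~/.cache/torch/hub/checkpoints"):
    """Return (url, local_path, is_cached) for a SAM model type (hypothetical helper)."""
    name = CHECKPOINTS[model_type]
    path = os.path.join(os.path.expanduser(cache_dir), name)
    return BASE_URL + name, path, os.path.exists(path)


url, path, cached = checkpoint_info("vit_h")
print(url)
print(path, "(cached)" if cached else "(not cached yet)")
```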

In [10]:
sam = SamGeo(
    model_type="vit_h",
    automatic=False,
    sam_kwargs=None,
)

Downloading...
From: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
To: /root/.cache/torch/hub/checkpoints/sam_vit_h_4b8939.pth


Model checkpoint for vit_h not found.


100%|██████████| 2.56G/2.56G [00:15<00:00, 169MB/s]


Specify the image to segment.

In [11]:
sam.set_image(image)

## Image segmentation with input points

A single point can be used to segment an object. Each point is an (x, y) pair, such as (col, row) in pixel coordinates or (lon, lat) in geographic coordinates. Points can also be given as a file path to a vector dataset. If the points are not in (col, row) pixel coordinates, specify their CRS with the `point_crs` parameter, and they will be transformed to image column and row coordinates automatically.
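Conceptually, `point_crs` converts each point into the image's CRS and then into pixel indices using the raster's geotransform. The pure-Python sketch below does that for `satellite.tif`, which is in EPSG:3857 (Web Mercator), reusing the bounds and resolution printed earlier and the spherical Web Mercator formulas. It illustrates the idea only; it is not samgeo's implementation, which presumably relies on a projection library:

```python
import math

R = 6378137.0  # Web Mercator sphere radius in metres


def lonlat_to_webmercator(lon, lat):
    """Project WGS 84 lon/lat (degrees) to EPSG:3857 x/y (metres)."""
    x = R * math.radians(lon)
    y = R * math.log(math.tan(math.pi / 4 + math.radians(lat) / 2))
    return x, y


def world_to_pixel(x, y, minx, maxy, res_x, res_y):
    """Map projected x/y to (col, row) for a north-up raster with top-left corner (minx, maxy)."""
    col = (x - minx) / res_x
    row = (y - maxy) / res_y  # res_y is negative, so rows increase southwards
    return math.floor(col), math.floor(row)


# Top-left corner and resolution of satellite.tif, copied from the output above.
minx, maxy = -13597642.404551128, 4529509.3629129175
res_x, res_y = 2.388900021402451, -2.388983011722694

x, y = lonlat_to_webmercator(-122.1419, 37.6383)
col, row = world_to_pixel(x, y, minx, maxy, res_x, res_y)
print(col, row)  # pixel indices inside the 1370 x 865 image
```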

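Try a single point input. The example coordinates below (and in the next cell) fall inside the default bounding box used to download `satellite.tif`, so they only overlap the image when `satellite.tif` is the image passed to `sam.set_image()`: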

In [None]:
point_coords = [[-122.1419, 37.6383]]
sam.predict(point_coords, point_labels=1, point_crs="EPSG:4326", output="mask1.tif")
m.add_raster("mask1.tif", layer_name="Mask1", nodata=0, cmap="Blues", opacity=1)
m

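Passing `nodata=0` to `add_raster` suggests that background pixels in the saved mask are 0 and object pixels are nonzero (an inference from this call, not something the notebook states). Under that assumption, the mask's area is the count of nonzero pixels times the pixel area. The minimal sketch below uses a synthetic array in place of the real mask, which you could read with, for example, `rxr.open_rasterio("mask1.tif").values`. With `satellite.tif` the result is in projected square metres, which overstate true ground area by roughly 1/cos²(latitude):

```python
import numpy as np


def mask_area(mask, res_x, res_y):
    """Area covered by nonzero mask pixels, in squared CRS units."""
    pixel_area = abs(res_x * res_y)
    return np.count_nonzero(mask) * pixel_area


# Synthetic 10 x 10 mask with a 4 x 5 block of foreground pixels (value 255).
mask = np.zeros((10, 10), dtype=np.uint8)
mask[2:6, 3:8] = 255

# Resolution of satellite.tif from the output above (EPSG:3857 units).
area = mask_area(mask, 2.388900021402451, -2.388983011722694)
print(f"{np.count_nonzero(mask)} pixels, {area:.1f} projected m^2")
```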

Try multiple points input:

In [None]:
point_coords = [[-122.1464, 37.6431], [-122.1449, 37.6415], [-122.1451, 37.6395]]
sam.predict(point_coords, point_labels=1, point_crs="EPSG:4326", output="mask2.tif")
m.add_raster("mask2.tif", layer_name="Mask2", nodata=0, cmap="Greens", opacity=1)
m

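Here `point_labels=1` is a single value for all three points. In the underlying Segment Anything predictor, prompts are an (N, 2) array of coordinates plus an (N,) array of labels, where 1 marks a foreground point and 0 a background point. The sketch below builds those arrays explicitly with a hypothetical helper, `build_prompts`. Repeating a scalar label for every point is an assumption about how samgeo treats `point_labels=1`, and passing a per-point list such as `[1, 1, 0]` to exclude a region is worth verifying against the samgeo documentation:

```python
import numpy as np

point_coords = [[-122.1464, 37.6431], [-122.1449, 37.6415], [-122.1451, 37.6395]]


def build_prompts(coords, labels=1):
    """Return (N, 2) float coords and (N,) int labels; a scalar label is repeated for every point."""
    coords = np.asarray(coords, dtype=float).reshape(-1, 2)
    labels = np.asarray(labels, dtype=int)
    if labels.ndim == 0:
        labels = np.full(len(coords), int(labels))
    if labels.shape != (len(coords),):
        raise ValueError("need exactly one label per point")
    return coords, labels


coords, labels = build_prompts(point_coords)       # all points are foreground
_, mixed = build_prompts(point_coords, [1, 1, 0])  # last point marks background
print(coords.shape, labels, mixed)
```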

## Interactive segmentation

Display the interactive map and use the marker tool to draw points on the map. Then click on the `Segment` button to segment the objects. The results will be added to the map automatically. Click on the `Reset` button to clear the points and the results.

In [12]:
m = sam.show_map()
m


![](https://i.imgur.com/2Nyg9uW.gif)