# Batch segmentation with text prompts

This notebook shows how to generate object masks from text prompts with the Segment Anything Model (SAM). Use a GPU runtime for this notebook.

## Setup

In [1]:
# Import the os module
import os

path = 'tmp/'

try:
    os.chdir(path)
    print("Current working directory: {0}".format(os.getcwd()))
except FileNotFoundError:
    print("Directory: {0} does not exist".format(path))
except NotADirectoryError:
    print("{0} is not a directory".format(path))
except PermissionError:
    print("You do not have permissions to change to {0}".format(path))

Current working directory: /home/ec2-user/SageMaker/ODP_Demo/tmp


## Import libraries

In [2]:
import leafmap
from samgeo import tms_to_geotiff, split_raster
from samgeo.text_sam import LangSAM

## Create an interactive map

In [3]:
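# Create a map centered on the area of interest with a satellite basemap
m = leafmap.Map(center=[-22.1278, -51.4430], zoom=17, height="800px")
m.add_basemap("SATELLITE")
m.attribution_control = False  # turn off the attribution control
m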

## Download a sample image

Pan and zoom the map to select the area of interest, then use the draw tools to draw a polygon or rectangle on the map. If you do not draw anything, the next cell falls back to a default bounding box.

In [4]:
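# Use the bounds of the drawn region of interest, if any:
# [min_lon, min_lat, max_lon, max_lat]
bbox = m.user_roi_bounds()
if bbox is None:
    # Default bounding box used when nothing is drawn
    bbox = [-51.4494, -22.1307, -51.4371, -22.1244]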

In [5]:
image = "Image.tif"
# Download satellite tiles for the bbox at zoom level 19 and save them as a GeoTIFF
tms_to_geotiff(output=image, bbox=bbox, zoom=19, source="Satellite", overwrite=True)

Downloaded image 001/209
Downloaded image 002/209
...
Downloaded image 040/209
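The number of downloaded tiles grows quickly with the zoom level. The sketch below estimates how many XYZ tiles cover a bounding box, using the standard Web Mercator ("slippy map") tile formulas. It assumes the bbox order `[min_lon, min_lat, max_lon, max_lat]`; `lonlat_to_tile` and `count_tiles` are hypothetical helpers, and the count that `tms_to_geotiff` reports may differ if it handles edges differently.

```python
import math


def lonlat_to_tile(lon, lat, zoom):
    """Convert WGS84 lon/lat (degrees) to XYZ tile indices at `zoom` (Web Mercator)."""
    n = 2 ** zoom
    x = math.floor((lon + 180.0) / 360.0 * n)
    lat_rad = math.radians(lat)
    # y grows southward, so northern latitudes get smaller tile rows
    y = math.floor((1.0 - math.asinh(math.tan(lat_rad)) / math.pi) / 2.0 * n)
    return x, y


def count_tiles(bbox, zoom):
    """Return (columns, rows) of XYZ tiles covering bbox = [min_lon, min_lat, max_lon, max_lat]."""
    min_lon, min_lat, max_lon, max_lat = bbox
    x_min, y_top = lonlat_to_tile(min_lon, max_lat, zoom)     # north-west corner
    x_max, y_bottom = lonlat_to_tile(max_lon, min_lat, zoom)  # south-east corner
    return x_max - x_min + 1, y_bottom - y_top + 1


# Default bbox from this notebook at zoom 19
cols, rows = count_tiles([-51.4494, -22.1307, -51.4371, -22.1244], zoom=19)
print(cols, rows, cols * rows)
```

Each extra zoom level roughly quadruples the number of tiles, so lowering `zoom` is the quickest way to shorten the download for a large area.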


You can also use your own image: uncomment the following cell and set the path to your GeoTIFF.

In [6]:
# image = '/path/to/your/own/image.tif'

Display the downloaded image on the map.

In [7]:
# Hide the satellite basemap (the last layer added) and display the downloaded image
m.layers[-1].visible = False
m.add_raster(image, layer_name="Image")
m

## Split the image into tiles

In [8]:
# Split the image into 1000 x 1000 pixel tiles with no overlap, saved in tiles/
split_raster(image, out_dir="tiles", tile_size=(1000, 1000), overlap=0)
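To see how `tile_size` and `overlap` determine the number of tiles, here is a minimal sketch of a tiling grid in plain Python. It is not samgeo's `split_raster` implementation: clipping the last row and column to the image edge is an assumption, and the 2,500 x 4,800 pixel image size is made up for illustration.

```python
def _offsets(size, tile, step):
    """Start offsets along one axis so that tiles of length `tile` cover `size` pixels."""
    offsets = [0]
    while offsets[-1] + tile < size:
        offsets.append(offsets[-1] + step)
    return offsets


def tile_windows(width, height, tile_w=1000, tile_h=1000, overlap=0):
    """Return (row, col, x_off, y_off, w, h) windows covering a width x height image.

    Edge tiles are clipped to the image bounds (an assumption of this sketch).
    """
    step_x, step_y = tile_w - overlap, tile_h - overlap
    if step_x <= 0 or step_y <= 0:
        raise ValueError("overlap must be smaller than the tile size")
    windows = []
    for row, y_off in enumerate(_offsets(height, tile_h, step_y)):
        for col, x_off in enumerate(_offsets(width, tile_w, step_x)):
            w = min(tile_w, width - x_off)
            h = min(tile_h, height - y_off)
            windows.append((row, col, x_off, y_off, w, h))
    return windows


# Hypothetical 2,500 x 4,800 pixel image: 5 rows x 3 columns of tiles
windows = tile_windows(width=2500, height=4800)
print(len(windows))
```

A nonzero `overlap` shrinks the step between tiles, which adds tiles but lets objects cut by a tile border appear whole in at least one tile.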

## Free cached GPU memory

In [None]:
import torch
# Release unused memory held by PyTorch's CUDA caching allocator
torch.cuda.empty_cache()

## Initialize LangSAM class

Initializing the LangSAM class may take a few minutes, because it downloads the model weights and sets up the model for inference.

In [9]:
sam = LangSAM()

final text_encoder_type: bert-base-uncased


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Specify text prompts

In [10]:
text_prompt = "tree"

## Segment images

Part of the prediction step is choosing thresholds for object detection and for associating the detected objects with the text prompt. Both thresholds range from 0 to 1 and are passed as arguments when calling the `predict` or `predict_batch` method of the LangSAM class.

`box_threshold`: This value is used for object detection in the image. A higher value makes the model more selective, keeping only the most confident object instances and producing fewer detections. A lower value makes the model more tolerant, producing more detections, including less confident ones.

`text_threshold`: This value is used to associate the detected objects with the provided text prompt. A higher value requires a stronger association between an object and the prompt, giving more precise but potentially fewer matches. A lower value allows looser associations, which can increase the number of matches but also introduce less precise ones.
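LangSAM builds on Grounding DINO to find boxes for the text prompt. The pure-Python sketch below is a simplified illustration of how Grounding DINO-style filtering applies the two thresholds: a box is kept if its best prompt-token score exceeds `box_threshold`, and its label is built from the tokens whose score exceeds `text_threshold`. It is not samgeo's code; `filter_detections`, the whole-word tokens (the real model uses subword tokens), and the scores are hypothetical.

```python
def filter_detections(boxes, token_scores, tokens, box_threshold, text_threshold):
    """Filter boxes and label them with prompt tokens (simplified illustration).

    boxes: list of (x0, y0, x1, y1) tuples
    token_scores: one list per box with a score in [0, 1] for each prompt token
    tokens: the text prompt split into tokens
    """
    results = []
    for box, scores in zip(boxes, token_scores):
        confidence = max(scores)
        if confidence <= box_threshold:
            continue  # detection is not confident enough, drop it
        # Keep only the prompt tokens strongly associated with this box
        phrase = " ".join(t for t, s in zip(tokens, scores) if s > text_threshold)
        results.append((box, phrase, confidence))
    return results


tokens = ["tree"]
boxes = [(0, 0, 10, 10), (5, 5, 20, 20), (30, 30, 40, 40)]
token_scores = [[0.80], [0.20], [0.30]]
kept = filter_detections(boxes, token_scores, tokens, box_threshold=0.24, text_threshold=0.24)
print(kept)
```

Raising `box_threshold` to 0.5 in this example would also drop the third box, while `text_threshold` only changes which words end up in each kept box's label.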

Test different threshold values on your own data. The optimal values depend on the quality and nature of your images and on how specific your text prompts are, so choose a balance that matches whether you care more about precision or recall.
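One way to compare threshold settings is to score each output mask against a reference mask you trust, such as a hand-labeled tile. The sketch below computes pixel-wise precision and recall with NumPy, treating any nonzero pixel as foreground. The reference mask and the `mask_precision_recall` helper are hypothetical; nothing here is part of samgeo.

```python
import numpy as np


def mask_precision_recall(pred, truth):
    """Pixel-wise precision and recall of a predicted mask against a reference mask.

    Any nonzero pixel counts as foreground, so 0/1 and 0/255 masks both work.
    """
    pred = np.asarray(pred) > 0
    truth = np.asarray(truth) > 0
    tp = np.logical_and(pred, truth).sum()   # correctly detected pixels
    fp = np.logical_and(pred, ~truth).sum()  # false detections
    fn = np.logical_and(~pred, truth).sum()  # missed pixels
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return float(precision), float(recall)


pred = [[255, 255], [0, 0]]
truth = [[255, 0], [255, 0]]
print(mask_precision_recall(pred, truth))
```

Lower thresholds usually trade precision for recall, so running this for a few `box_threshold` and `text_threshold` pairs on the same tile makes the trade-off visible.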

In [13]:
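sam.predict_batch(
    images='tiles',           # directory containing the input tiles
    out_dir='masks',          # directory for the output masks
    text_prompt=text_prompt,
    box_threshold=0.24,
    text_threshold=0.24,
    mask_multiplier=255,      # multiply the binary mask by 255
    dtype='uint8',            # data type of the output masks
    merge=True,               # merge the tile masks into masks/merged.tif
    verbose=True,             # print progress for each tile
)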

Processing image 01 of 15: tiles/tile_0_0.tif...
Processing image 02 of 15: tiles/tile_0_1.tif...
Processing image 03 of 15: tiles/tile_0_2.tif...
Processing image 04 of 15: tiles/tile_1_0.tif...
Processing image 05 of 15: tiles/tile_1_1.tif...
Processing image 06 of 15: tiles/tile_1_2.tif...
Processing image 07 of 15: tiles/tile_2_0.tif...
Processing image 08 of 15: tiles/tile_2_1.tif...
Processing image 09 of 15: tiles/tile_2_2.tif...
Processing image 10 of 15: tiles/tile_3_0.tif...
Processing image 11 of 15: tiles/tile_3_1.tif...
Processing image 12 of 15: tiles/tile_3_2.tif...
Processing image 13 of 15: tiles/tile_4_0.tif...
Processing image 14 of 15: tiles/tile_4_1.tif...
Processing image 15 of 15: tiles/tile_4_2.tif...
Saved the merged prediction to masks/merged.tif.
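With `mask_multiplier=255` and `dtype='uint8'`, foreground pixels are stored as 255 and background pixels as 0. The NumPy sketch below illustrates that scaling and a simple way to paste per-tile masks back into one array. It is only an illustration: samgeo's `merge=True` writes a georeferenced GeoTIFF, and `scale_mask`, `mosaic_tiles`, and the max-value rule for overlapping pixels are assumptions of this sketch.

```python
import numpy as np


def scale_mask(binary_mask, multiplier=255, dtype="uint8"):
    """Multiply a 0/1 mask so foreground pixels equal `multiplier`, then cast."""
    return (np.asarray(binary_mask) * multiplier).astype(dtype)


def mosaic_tiles(tiles, height, width):
    """Paste tile masks into one array.

    tiles maps (y_off, x_off) -> 2-D uint8 array; where tiles overlap,
    the maximum value is kept (an assumption of this sketch).
    """
    out = np.zeros((height, width), dtype=np.uint8)
    for (y_off, x_off), tile in tiles.items():
        h, w = tile.shape
        region = out[y_off:y_off + h, x_off:x_off + w]  # view into `out`
        np.maximum(region, tile, out=region)            # writes in place
    return out


tiles = {
    (0, 0): scale_mask([[1, 0], [0, 1]]),
    (0, 2): scale_mask([[1, 1], [0, 0]]),
}
merged = mosaic_tiles(tiles, height=2, width=4)
print(merged)
```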


## Visualize the results

In [14]:
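# Overlay the merged mask; nodata=0 leaves background pixels transparent
m.add_raster('masks/merged.tif', cmap='viridis', nodata=0, layer_name='Mask')
# Add a layer manager to toggle layers on and off
m.add_layer_manager()
m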

![Segmentation results](https://i.imgur.com/JUhNkm6.png)