# Train a Semantic Segmentation Model using Segmentation-Models-PyTorch

[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/opengeos/geoai/blob/main/docs/examples/train_segmentation_model.ipynb)

This notebook demonstrates how to train semantic segmentation models for extracting features from imagery (e.g., building detection) using the [segmentation-models-pytorch](https://smp.readthedocs.io) library. Unlike instance segmentation with Mask R-CNN, which predicts a separate mask for each object, this approach classifies every pixel, here as either building or background.

## Install packages
To use the new functionality, ensure the required packages are installed.


In [1]:
%pip install geoai-py


## Import libraries

In [2]:
import geoai

## Download sample data

We'll use the same dataset as the Mask R-CNN example for consistency.

In [3]:
train_raster_url = (
    "https://huggingface.co/datasets/giswqs/geospatial/resolve/main/naip_rgb_train.tif"
)
train_vector_url = "https://huggingface.co/datasets/giswqs/geospatial/resolve/main/naip_train_buildings.geojson"
test_raster_url = (
    "https://huggingface.co/datasets/giswqs/geospatial/resolve/main/naip_test.tif"
)

In [4]:
train_raster_path = geoai.download_file(train_raster_url)
train_vector_path = geoai.download_file(train_vector_url)
test_raster_path = geoai.download_file(test_raster_url)

Downloading naip_rgb_train.tif: 100%|██████████| 8.88M/8.88M [00:00<00:00, 44.4MB/s]
Downloading naip_train_buildings.geojson: 334kB [00:00, 46.2MB/s]
Downloading naip_test.tif: 100%|██████████| 19.7M/19.7M [00:00<00:00, 37.9MB/s]


## Visualize sample data

In [5]:
geoai.get_raster_info(train_raster_path)

{'driver': 'GTiff',
 'width': 2503,
 'height': 1126,
 'count': 3,
 'dtype': 'uint8',
 'crs': 'EPSG:26911',
 'transform': Affine(0.6000000000000046, 0.0, 454780.8,
        0.0, -0.6, 5278242.6),
 'bounds': BoundingBox(left=454780.8, bottom=5277567.0, right=456282.6, top=5278242.6),
 'resolution': (0.6000000000000046, 0.6),
 'nodata': None,
 'band_stats': [{'band': 1,
   'min': 12.0,
   'max': 251.0,
   'mean': 150.6730747259594,
   'std': 48.01908734374099},
  {'band': 2,
   'min': 49.0,
   'max': 251.0,
   'mean': 141.92468895229808,
   'std': 43.46595463573497},
  {'band': 3,
   'min': 53.0,
   'max': 251.0,
   'mean': 120.89909373405554,
   'std': 41.78086244480775}]}
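
As a quick sanity check on the metadata above, the pixel dimensions and the resolution should reproduce the extent given in `bounds`. The sketch below only re-uses numbers copied from that output; the variable names are illustrative and are not part of the `geoai` API.

```python
# Values copied from the get_raster_info output above.
width, height = 2503, 1126            # pixels
res_x, res_y = 0.6, 0.6               # metres per pixel (EPSG:26911 uses metres)
left, bottom, right, top = 454780.8, 5277567.0, 456282.6, 5278242.6

extent_x = width * res_x              # ground width in metres
extent_y = height * res_y             # ground height in metres
area_km2 = extent_x * extent_y / 1e6  # footprint of the training image

print(f"{extent_x:.1f} m x {extent_y:.1f} m, about {area_km2:.2f} km2")
# Both extents should agree with the bounding box.
print(abs(extent_x - (right - left)) < 1e-6, abs(extent_y - (top - bottom)) < 1e-6)
```

The training image therefore covers roughly 1.5 km by 0.7 km at 0.6 m per pixel.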

In [6]:
geoai.view_vector_interactive(train_vector_path, tiles=train_raster_path)

In [7]:
geoai.view_raster(test_raster_path)

Map(center=[47.6464835, -117.59043650000001], controls=(ZoomControl(options=['position', 'zoom_in_text', 'zoom…

## Create training data

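We'll create the same training tiles as in the Mask R-CNN example: 512 x 512 pixel tiles with a stride of 256 pixels, so neighboring tiles overlap by half.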

In [8]:
out_folder = "buildings"
tiles = geoai.export_geotiff_tiles(
    in_raster=train_raster_path,
    out_folder=out_folder,
    in_class_data=train_vector_path,
    tile_size=512,
    stride=256,
    buffer_radius=0,
)


Raster info for naip_rgb_train.tif:
  CRS: EPSG:26911
  Dimensions: 2503 x 1126
  Resolution: (0.6000000000000046, 0.6)
  Bands: 3
  Bounds: BoundingBox(left=454780.8, bottom=5277567.0, right=456282.6, top=5278242.6)
Loaded 735 features from naip_train_buildings.geojson
Vector CRS: EPSG:4326
Reprojecting features from EPSG:4326 to EPSG:26911
Found 1 unique classes: ['building']


Generated: 36, With features: 36: 100%|██████████| 36/36 [00:06<00:00,  5.21it/s]


------- Export Summary -------
Total tiles exported: 36
Tiles with features: 36 (100.0%)
Average feature pixels per tile: 46795.0
Output saved to: buildings

------- Georeference Verification -------
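
The count of 36 tiles can be reproduced with simple sliding-window arithmetic. The sketch below assumes that partial tiles at the right and bottom edges are kept (hence the `ceil`); that assumption is chosen because it matches the log above, not because it is taken from the `geoai` source. `count_tiles` is a hypothetical helper.

```python
import math

def count_tiles(width, height, tile_size, stride):
    """Count sliding-window tiles, assuming partial edge tiles are also kept."""
    nx = math.ceil((width - tile_size) / stride) + 1
    ny = math.ceil((height - tile_size) / stride) + 1
    return nx, ny, nx * ny

# Raster dimensions and tiling parameters from the cell and log above.
nx, ny, total = count_tiles(2503, 1126, tile_size=512, stride=256)
print(nx, ny, total)  # 9 4 36

# Share of building pixels in an average tile, from the export summary.
building_fraction = 46795.0 / (512 * 512)
print(f"{building_fraction:.1%}")  # 17.9%
```

On average fewer than one pixel in five is a building, so the background class dominates the training data.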





## Train semantic segmentation model

Now we'll train a semantic segmentation model using the new `train_segmentation_model` function. This function supports various architectures from `segmentation-models-pytorch`:

- **Architectures**: `unet`, `unetplusplus`, `deeplabv3`, `deeplabv3plus`, `fpn`, `pspnet`, `linknet`, `manet`, `segformer`
- **Encoders**: `resnet34`, `resnet50`, `efficientnet-b0`, `mobilenet_v2`, etc.

For more details, please refer to the [segmentation-models-pytorch documentation](https://smp.readthedocs.io/en/latest/models.html).

### Example 1: U-Net with ResNet34 encoder


In [9]:
# Train U-Net model
geoai.train_segmentation_model(
    images_dir=f"{out_folder}/images",
    labels_dir=f"{out_folder}/labels",
    output_dir=f"{out_folder}/unet_models",
    architecture="unet",
    encoder_name="resnet34",
    encoder_weights="imagenet",
    num_channels=3,
    num_classes=2,  # background and building
    batch_size=8,
    num_epochs=5,
    learning_rate=0.001,
    val_split=0.2,
    verbose=True,
)

Using device: cuda
Found 36 image files and 36 label files
Training on 28 images, validating on 8 images
Checking image sizes for compatibility...
All sampled images have the same size: (512, 512)
No resizing needed.
Testing data loader...
Data loader test passed.


config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/87.3M [00:00<?, ?B/s]

Starting training with unet + resnet34
Model parameters: 24,436,514
Epoch: 1, Batch: 1/4, Loss: 0.5944, Time: 2.93s
Epoch 1/5: Train Loss: 0.4697, Val Loss: 0.4332, Val IoU: 0.4902, Val F1: 0.5835, Val Precision: 0.6575, Val Recall: 0.5762
Saving best model with IoU: 0.4902
Epoch: 2, Batch: 1/4, Loss: 0.3526, Time: 1.21s
Epoch 2/5: Train Loss: 0.3015, Val Loss: 1.3010, Val IoU: 0.5125, Val F1: 0.6136, Val Precision: 0.6918, Val Recall: 0.5977
Saving best model with IoU: 0.5125
Epoch: 3, Batch: 1/4, Loss: 0.2357, Time: 1.39s
Epoch 3/5: Train Loss: 0.2000, Val Loss: 8.7563, Val IoU: 0.4317, Val F1: 0.4876, Val Precision: 0.5797, Val Recall: 0.5122
Epoch: 4, Batch: 1/4, Loss: 0.1590, Time: 1.97s
Epoch 4/5: Train Loss: 0.1637, Val Loss: 8.9805, Val IoU: 0.4252, Val F1: 0.4738, Val Precision: 0.5593, Val Recall: 0.5064
Epoch: 5, Batch: 1/4, Loss: 0.1487, Time: 1.84s
Epoch 5/5: Train Loss: 0.1386, Val Loss: 4.5415, Val IoU: 0.4432, Val F1: 0.4994, Val Precision: 0.7806, Val Recall: 0.5237
Tr
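
The split sizes and the `Batch: 1/4` counter in the log follow from the parameters passed above. This sketch assumes the validation count is rounded up and that the last, smaller batch is kept; both assumptions are inferred from the logged numbers rather than from the library code.

```python
import math

n_images = 36     # tiles found by the trainer
val_split = 0.2
batch_size = 8

# Assumption: the validation size is rounded up (36 * 0.2 = 7.2 -> 8).
n_val = math.ceil(n_images * val_split)
n_train = n_images - n_val

# Assumption: the final partial batch is not dropped (28 / 8 = 3.5 -> 4).
batches_per_epoch = math.ceil(n_train / batch_size)

print(n_train, n_val, batches_per_epoch)  # 28 8 4
```

With only 8 validation tiles, per-epoch validation metrics can swing noticeably from one epoch to the next, as the validation loss does in the log above.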

### Example 2: SegFormer with ResNet152 encoder

In [None]:
geoai.train_segmentation_model(
    images_dir=f"{out_folder}/images",
    labels_dir=f"{out_folder}/labels",
    output_dir=f"{out_folder}/segformer_models",
    architecture="segformer",
    encoder_name="resnet152",
    encoder_weights="imagenet",
    num_channels=3,
    num_classes=2,
    batch_size=6,  # Smaller batch size for more complex model
    num_epochs=5,
    learning_rate=0.0005,
    val_split=0.2,
)

## Run inference

Now we'll use the trained model to make predictions on the test image.

In [None]:
# Define paths
masks_path = "naip_test_semantic_prediction.tif"
model_path = f"{out_folder}/unet_models/best_model.pth"

In [None]:
# Run semantic segmentation inference
geoai.semantic_segmentation(
    input_path=test_raster_path,
    output_path=masks_path,
    model_path=model_path,
    architecture="unet",
    encoder_name="resnet34",
    num_channels=3,
    num_classes=2,
    window_size=512,
    overlap=256,
    batch_size=4,
)
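
With `window_size=512` and `overlap=256`, consecutive windows start 256 pixels apart. The sketch below shows one common way to lay out such windows along a single axis so that the whole image is covered. `window_origins` is a hypothetical helper and the 1200-pixel axis length is made up; how `geoai` handles image edges internally may differ.

```python
def window_origins(length, window_size, overlap):
    """Return start offsets along one axis so the windows cover [0, length)."""
    stride = window_size - overlap
    if length <= window_size:
        return [0]
    starts = list(range(0, length - window_size + 1, stride))
    # Add a final window flush with the edge if the regular grid stops short.
    if starts[-1] + window_size < length:
        starts.append(length - window_size)
    return starts

# Hypothetical axis length; the real test raster size is not shown in this notebook.
starts = window_origins(1200, window_size=512, overlap=256)
print(starts)  # [0, 256, 512, 688]
```

Overlapping windows mean most pixels are predicted more than once, which gives the inference step a chance to reduce seams at tile borders.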

## Output probability map (optional)

You can also output the probability map by providing the `probability_path` parameter. This will save a multi-band raster where each band represents the probability for each class (0-1 range).

In [None]:
# Run inference with probability output
probability_path = "naip_test_probability_map.tif"

geoai.semantic_segmentation(
    input_path=test_raster_path,
    output_path=masks_path,
    model_path=model_path,
    architecture="unet",
    encoder_name="resnet34",
    num_channels=3,
    num_classes=2,
    window_size=512,
    overlap=256,
    batch_size=4,
    probability_path=probability_path,  # Output probability map
)

In [None]:
# Visualize probability map for building class (band 2)
geoai.view_raster(
    probability_path, indexes=[2], basemap=test_raster_path, backend="ipyleaflet"
)
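
Per-class probabilities of this kind are typically obtained by applying a softmax across the class dimension of the raw model output. The following NumPy sketch illustrates that idea on made-up logits; whether `geoai` computes the bands exactly this way is an assumption.

```python
import numpy as np

def softmax(logits, axis=0):
    """Numerically stable softmax along the class axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

# Hypothetical logits with shape (classes, height, width) = (2, 2, 2).
logits = np.array([
    [[2.0, 0.0], [1.0, -1.0]],  # class 0: background
    [[0.0, 0.0], [3.0, 1.0]],   # class 1: building
])

probs = softmax(logits, axis=0)
building_prob = probs[1]  # index 1 here corresponds to band 2 in a 1-based raster

print(probs.sum(axis=0))  # every pixel sums to 1 across classes
print(building_prob.round(3))
```

Because the class probabilities sum to 1 at every pixel, the building band alone is enough to describe a two-class prediction.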

You can also control the classification threshold for binary segmentation. By default, argmax is used, but you can specify a custom threshold:

In [None]:
# Run inference with custom probability threshold
masks_path_threshold = "naip_test_semantic_prediction_threshold2.tif"

geoai.semantic_segmentation(
    input_path=test_raster_path,
    output_path=masks_path_threshold,
    model_path=model_path,
    architecture="unet",
    encoder_name="resnet34",
    num_channels=3,
    num_classes=2,
    window_size=512,
    overlap=256,
    batch_size=4,
    probability_threshold=0.3,  # Only classify as building if probability >= 0.3
)
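
To see how a custom threshold differs from argmax, consider a two-class case where the probabilities sum to 1: argmax is then equivalent to a cut at 0.5 on the building probability, and a lower threshold labels more pixels as building. The values below are made up for illustration.

```python
import numpy as np

# Hypothetical building probabilities for four pixels.
building_prob = np.array([0.10, 0.35, 0.55, 0.90])
background_prob = 1.0 - building_prob
stacked = np.stack([background_prob, building_prob])  # shape (classes, pixels)

argmax_mask = stacked.argmax(axis=0)                   # default behaviour
thresh_mask = (building_prob >= 0.3).astype(np.uint8)  # custom threshold of 0.3

print(argmax_mask)  # [0 0 1 1]
print(thresh_mask)  # [0 1 1 1]
```

Lowering the threshold trades precision for recall: more true buildings are captured, at the cost of more false positives.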

## Vectorize masks

Convert the predicted mask to vector format for better visualization and analysis.

In [None]:
output_vector_path = "naip_test_semantic_prediction.geojson"
gdf = geoai.orthogonalize(masks_path, output_vector_path, epsilon=2)

## Add geometric properties

In [None]:
gdf_props = geoai.add_geometric_properties(gdf, area_unit="m2", length_unit="m")

## Visualize results

In [None]:
geoai.view_raster(masks_path, nodata=0, basemap=test_raster_path, backend="ipyleaflet")

In [None]:
geoai.view_vector_interactive(gdf_props, column="area_m2", tiles=test_raster_path)

In [None]:
gdf_filtered = gdf_props[(gdf_props["area_m2"] > 50)]
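
The filter above keeps only polygons larger than 50 m², presumably to drop small detections that are unlikely to be buildings. As a rough illustration of what an area in square metres means for a footprint, the sketch below computes planar polygon area with the shoelace formula on made-up coordinates in a metre-based CRS. It assumes `area_m2` is such a planar area; `polygon_area` and the footprints are hypothetical.

```python
def polygon_area(coords):
    """Shoelace formula; coords are (x, y) vertices in metres, ring not closed."""
    area = 0.0
    n = len(coords)
    for i in range(n):
        x1, y1 = coords[i]
        x2, y2 = coords[(i + 1) % n]
        area += x1 * y2 - x2 * y1
    return abs(area) / 2.0

# Hypothetical footprints in a projected CRS with metre units.
footprints = {
    "shed": [(0, 0), (5, 0), (5, 4), (0, 4)],      # 5 m x 4 m
    "house": [(0, 0), (12, 0), (12, 9), (0, 9)],   # 12 m x 9 m
}

areas = {name: polygon_area(coords) for name, coords in footprints.items()}
kept = {name: area for name, area in areas.items() if area > 50}

print(areas)  # {'shed': 20.0, 'house': 108.0}
print(kept)   # {'house': 108.0}
```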

In [None]:
geoai.view_vector_interactive(gdf_filtered, column="area_m2", tiles=test_raster_path)

In [None]:
geoai.create_split_map(
    left_layer=gdf_filtered,
    right_layer=test_raster_path,
    left_args={"style": {"color": "red", "fillOpacity": 0.2}},
    basemap=test_raster_path,
)

## Model Performance Analysis

Let's examine the training curves and model performance:

In [None]:
geoai.plot_performance_metrics(
    history_path=f"{out_folder}/unet_models/training_history.pth",
    figsize=(15, 5),
    verbose=True,
)

![image](https://github.com/user-attachments/assets/9355446f-f9ba-4818-aedb-4bb5dee56813)

## Performance Metrics

**IoU (Intersection over Union)** and **Dice score** are both popular metrics for evaluating the similarity between two binary masks, often in image segmentation tasks. They are related but not identical.

---

### 🔸 **Definitions**

#### **IoU (Jaccard Index)**

$$
\text{IoU} = \frac{|A \cap B|}{|A \cup B|}
$$

* Measures the overlap between predicted region $A$ and ground truth region $B$ relative to their union.
* Ranges from 0 (no overlap) to 1 (perfect overlap).

#### **Dice Score (F1 Score for Sets)**

$$
\text{Dice} = \frac{2|A \cap B|}{|A| + |B|}
$$

* Measures the overlap between $A$ and $B$ relative to the sum of their sizes, so the intersection is counted twice in the numerator.
* Also ranges from 0 to 1.

---

### 🔸 **Key Differences**

| Metric   | Formula                     | Penalizes                                                 | Sensitivity                                  |
| -------- | --------------------------- | --------------------------------------------------------- | -------------------------------------------- |
| **IoU**  | $\frac{TP}{TP + FP + FN}$   | FP and FN equally, each counted in full against TP        | Drops faster for partial overlaps (IoU ≤ Dice) |
| **Dice** | $\frac{2TP}{2TP + FP + FN}$ | FP and FN equally, but TP counts twice, so errors cost less | More forgiving of partial overlaps           |

> TP: True Positive, FP: False Positive, FN: False Negative
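
A minimal NumPy sketch of both formulas, computed from pixel counts on two small made-up masks. `iou_and_dice` is a hypothetical helper, not a `geoai` function, and returning 1.0 when both masks are empty is just one possible convention.

```python
import numpy as np

def iou_and_dice(pred, truth):
    """Compute IoU and Dice for two binary masks from TP, FP and FN counts."""
    pred = np.asarray(pred, dtype=bool)
    truth = np.asarray(truth, dtype=bool)
    tp = np.logical_and(pred, truth).sum()
    fp = np.logical_and(pred, ~truth).sum()
    fn = np.logical_and(~pred, truth).sum()
    if tp + fp + fn == 0:
        return 1.0, 1.0  # both masks empty: treated here as perfect agreement
    iou = tp / (tp + fp + fn)
    dice = 2 * tp / (2 * tp + fp + fn)
    return float(iou), float(dice)

# Hypothetical masks: TP = 2, FP = 1, FN = 1.
pred = [[1, 1, 0],
        [0, 1, 0]]
truth = [[1, 0, 0],
         [0, 1, 1]]

iou, dice = iou_and_dice(pred, truth)
print(round(iou, 4), round(dice, 4))  # 0.5 0.6667
```

For the same pair of masks, Dice is never lower than IoU, because the true positives are counted twice in both its numerator and denominator.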

---

### 🔸 **Relationship**

Dice and IoU are mathematically related:

$$
\text{Dice} = \frac{2 \cdot \text{IoU}}{1 + \text{IoU}} \quad \text{or} \quad \text{IoU} = \frac{\text{Dice}}{2 - \text{Dice}}
$$
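
One way to see where this identity comes from, using only the set definitions given earlier and assuming $|A \cup B| > 0$: the sizes of the two regions add up to the union plus the intersection, since the intersection is counted in both.

$$
|A| + |B| = |A \cup B| + |A \cap B|
$$

Substituting into the Dice definition and dividing numerator and denominator by $|A \cup B|$ gives:

$$
\text{Dice} = \frac{2|A \cap B|}{|A \cup B| + |A \cap B|} = \frac{2 \cdot \frac{|A \cap B|}{|A \cup B|}}{1 + \frac{|A \cap B|}{|A \cup B|}} = \frac{2 \cdot \text{IoU}}{1 + \text{IoU}}
$$

Because the mapping is monotonic, both metrics rank any two predictions on the same image in the same order. They can still disagree once scores are averaged over many images or classes.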

---

### 🔸 **When to Use What**

* **IoU**: Common in object detection and semantic segmentation benchmarks (e.g., COCO, Pascal VOC).
* **Dice**: Preferred in medical imaging and when class imbalance is an issue, due to its sensitivity to small regions.