<a href="https://colab.research.google.com/github/rabinatwayana/DL_torchgeo_MMFlood_Segmentation/blob/master/Drive_mmflood.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Deep Learning-based Semantic Flood Mapping using UNet and SegFormer on Multimodal Earth Observation Data

Author: Rabina Twayana

This notebook aims to map floods through semantic segmentation by training, evaluating, and comparing two deep learning models, UNet and SegFormer, on the multimodal MMFlood dataset.

## Install Packages

The notebook was run in Google Colab, where the following packages were installed:

In [None]:
!pip install torchgeo
!pip install wandb

## Import Packages

In [None]:
import torchgeo
from lightning.pytorch import Trainer
# WandbLogger docs: https://lightning.ai/docs/pytorch/stable/extensions/generated/lightning.pytorch.loggers.WandbLogger.html#lightning.pytorch.loggers.WandbLogger
from lightning.pytorch.loggers import WandbLogger
# Import the callback from the same `lightning.pytorch` namespace as Trainer and TorchGeo;
# the standalone `pytorch_lightning` package defines separate, incompatible classes.
from lightning.pytorch.callbacks import ModelCheckpoint
from torchgeo.trainers import SemanticSegmentationTask
from torch.utils.data import DataLoader
from datetime import datetime


# import wandb
# wandb.login()

# Project that the run is recorded to
# project = "MMFlood_DL_Experiments"


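checkpoint_callback = ModelCheckpoint(
    # Key under which TorchGeo's SemanticSegmentationTask logs validation IoU
    # (0.5/0.6 releases); check trainer.callback_metrics if your version differs.
    monitor="val_MulticlassJaccardIndex",
    mode="max",               # keep the checkpoint with the highest IoU
    save_top_k=1,             # save only the best model
    filename="best-{epoch:02d}-{val_MulticlassJaccardIndex:.4f}"
)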


# UNet run
wandb_logger_unet = WandbLogger(
    project="MMFlood_DL_Experiments",
    name="unet"
)

# SegFormer run
wandb_logger_segformer = WandbLogger(
    project="MMFlood_DL_Experiments",
    name="segformer"
)


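# Example global trainer for the UNet run; train_model() below builds its own
# logger, checkpoint callback and Trainer for each experiment.
trainer = Trainer(logger=wandb_logger_unet, callbacks=[checkpoint_callback])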

## Dataset
The MMFlood dataset is a multimodal flood delineation dataset (Montello et al., 2022).

Some Sentinel-1 tiles have missing data, which is automatically set to 0. The corresponding mask pixels are set to 255 and should be ignored when computing performance metrics.

Dataset features:

- 1,748 Sentinel-1 tiles of varying pixel dimensions

- Multimodal inputs: Sentinel-1 SAR imagery (VV and VH polarizations), digital elevation models (DEMs), and hydrography maps (the latter available for only 1,012 of the 1,748 tiles)

- 95 flood events from 42 different countries

- Flood delineation maps (ground truth) are obtained from the Copernicus Emergency Management Service (EMS)

- Missing data in Sentinel-1 tiles are set to 0, and the corresponding mask pixels are set to 255 (these must be ignored in performance computation)
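
Because the tiles have varying pixel dimensions, PyTorch's default `DataLoader` collation cannot stack them into one batch tensor as-is. One common workaround, sketched below as our own assumption rather than something prescribed by MMFlood or TorchGeo, is to crop every sample to a fixed square patch and pad smaller tiles first, filling padded image pixels with 0 and padded mask pixels with 255 so they are ignored exactly like missing data. `crop_or_pad_sample` and the 512-pixel patch size are hypothetical; each sample is assumed to be a dict with an `image` tensor of shape (C, H, W) and a `mask` tensor of shape (H, W).

```python
import torch
import torch.nn.functional as F


def crop_or_pad_sample(sample, size=512, ignore_index=255, generator=None):
    """Return a copy of `sample` with a (C, size, size) image and a (size, size) mask.

    Hypothetical helper: tiles smaller than `size` are padded on the bottom/right
    (image with 0, mask with `ignore_index`), then a random size x size window is cropped.
    """
    image, mask = sample["image"], sample["mask"]
    _, h, w = image.shape

    # Pad so that both spatial sides are at least `size` pixels.
    pad_h, pad_w = max(size - h, 0), max(size - w, 0)
    if pad_h or pad_w:
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        image = F.pad(image, (0, pad_w, 0, pad_h), value=0.0)
        mask = F.pad(mask, (0, pad_w, 0, pad_h), value=ignore_index)
        h, w = h + pad_h, w + pad_w

    # Choose a random top-left corner for the crop window.
    top = int(torch.randint(0, h - size + 1, (1,), generator=generator))
    left = int(torch.randint(0, w - size + 1, (1,), generator=generator))

    out = dict(sample)  # keep any extra keys (e.g. metadata)
    out["image"] = image[:, top:top + size, left:left + size]
    out["mask"] = mask[top:top + size, left:left + size]
    return out


# Usage: a 300 x 700 tile with 4 bands (e.g. VV, VH, DEM, hydrography)
sample = {"image": torch.rand(4, 300, 700), "mask": torch.zeros(300, 700, dtype=torch.long)}
out = crop_or_pad_sample(sample, size=512)
print(out["image"].shape, out["mask"].shape)  # torch.Size([4, 512, 512]) torch.Size([512, 512])
```

Padding with 255 rather than 0 keeps the artificial border out of both the loss and the metrics, at the cost of fewer labeled pixels per patch for small tiles.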

Dataset classes:

- no flood

- flood
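
Ignoring the 255 label means excluding those pixels before counting intersections and unions. The sketch below computes the IoU of the flood class on a toy 2×3 tile; `flood_iou` is a hypothetical helper, and its value may differ from the aggregate IoU that TorchGeo logs during training, which typically covers both classes. PyTorch's cross-entropy loss skips the same pixels through its `ignore_index` argument.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = 255  # MMFlood mask value for missing Sentinel-1 data


def flood_iou(pred, target, ignore_index=IGNORE_INDEX):
    """IoU of the flood class (label 1), skipping pixels labeled `ignore_index`."""
    valid = target != ignore_index
    pred_flood = (pred == 1) & valid
    true_flood = (target == 1) & valid
    intersection = (pred_flood & true_flood).sum().item()
    union = (pred_flood | true_flood).sum().item()
    return intersection / union if union > 0 else float("nan")


# Toy 2x3 tile: 0 = no flood, 1 = flood, 255 = missing data
target = torch.tensor([[0, 1, 1],
                       [255, 1, 0]])
pred = torch.tensor([[0, 1, 0],
                     [1, 1, 1]])
print(flood_iou(pred, target))  # 0.5 (0.4 if the 255 pixel were scored as a false positive)

# The loss skips the same pixels: targets equal to 255 contribute nothing.
logits = torch.randn(1, 2, 2, 3)  # (batch, classes, H, W)
loss = F.cross_entropy(logits, target.unsqueeze(0), ignore_index=IGNORE_INDEX)
print(loss.item())
```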



## Deep Learning Models

### a) UNet

### b) SegFormer

## Train Model

In [None]:


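import wandb


def train_model(model_name, input_type, train_dataset, val_dataset, max_epochs=50, batch_size=8):
    """
    Train a TorchGeo model (UNet, SegFormer) with per-run wandb logging and checkpointing.

    Args:
        model_name (str): 'unet' or 'segformer'
        input_type (str): description of input bands, e.g., 's1_dem_hydro'
        train_dataset, val_dataset: datasets whose samples are dicts with an
            'image' tensor (C, H, W) and a 'mask' tensor (H, W)
        max_epochs (int): number of training epochs
        batch_size (int): batch size

    Returns:
        str: path to the best checkpoint (highest validation IoU)
    """

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Define a wandb logger with a unique run name per experiment
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    wandb_logger = WandbLogger(
        project="MMFlood_DL_Experiments",
        name=f"{model_name}_{input_type}_{timestamp}"
    )

    # Keep only the checkpoint with the best validation IoU
    # ('val_MulticlassJaccardIndex' is the key logged by TorchGeo 0.5/0.6; verify for your version)
    checkpoint_callback = ModelCheckpoint(
        monitor="val_MulticlassJaccardIndex",
        mode="max",
        save_top_k=1,
        filename=f"best-{model_name}-{{epoch:02d}}-{{val_MulticlassJaccardIndex:.4f}}",
        dirpath=f"checkpoints/{model_name}_{input_type}"
    )

    # Infer the number of input bands from the first sample instead of hard-coding it
    in_channels = train_dataset[0]["image"].shape[0]

    # Define the TorchGeo task (a LightningModule).
    # Older TorchGeo releases only accept 'unet', 'deeplabv3+' and 'fcn' here;
    # 'segformer' requires a version whose SemanticSegmentationTask supports it.
    task = SemanticSegmentationTask(
        model=model_name,
        in_channels=in_channels,
        num_classes=2,        # no flood, flood
        loss="ce",
        ignore_index=255      # missing Sentinel-1 data
    )

    # Initialize the Lightning Trainer
    trainer = Trainer(
        max_epochs=max_epochs,
        accelerator="auto",
        devices=1,
        logger=wandb_logger,
        callbacks=[checkpoint_callback]
    )

    # Fit model
    trainer.fit(task, train_loader, val_loader)

    # Close the wandb run so the next call starts a new run instead of reusing this one
    wandb.finish()

    # Return path of best checkpoint
    return checkpoint_callback.best_model_path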
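
`train_model` expects datasets whose samples are dicts with an `image` tensor of shape (C, H, W) and an integer `mask` tensor of shape (H, W), which is the sample format TorchGeo's `SemanticSegmentationTask` consumes. The cells shown here do not construct `train_dataset` and `val_dataset`, so for a quick smoke test of the pipeline a hypothetical synthetic stand-in such as `SyntheticFloodDataset` can be used. Its random 4-band tiles and labels carry no real information and must be replaced with actual MMFlood samples for any experiment.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SyntheticFloodDataset(Dataset):
    """Hypothetical stand-in for MMFlood: random tiles with 0/1/255 masks."""

    def __init__(self, length=8, in_channels=4, size=64, seed=0):
        g = torch.Generator().manual_seed(seed)
        self.images = torch.rand(length, in_channels, size, size, generator=g)
        # Labels 0 (no flood) / 1 (flood); a left border strip is marked 255 (missing data).
        self.masks = torch.randint(0, 2, (length, size, size), generator=g)
        self.masks[:, :, :4] = 255

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return {"image": self.images[idx], "mask": self.masks[idx]}


train_dataset = SyntheticFloodDataset(length=8, seed=0)
val_dataset = SyntheticFloodDataset(length=4, seed=1)

# The default collate function stacks the dict values into batch tensors.
batch = next(iter(DataLoader(train_dataset, batch_size=2)))
print(batch["image"].shape, batch["mask"].shape)  # torch.Size([2, 4, 64, 64]) torch.Size([2, 64, 64])
```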


In [None]:
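# train_dataset and val_dataset must be defined beforehand (samples are dicts with 'image' and 'mask')
# Example: train UNet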
best_unet_ckpt = train_model(
    model_name="unet",
    input_type="s1_dem_hydro",
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    max_epochs=50
)

# Example: train SegFormer
best_segformer_ckpt = train_model(
    model_name="segformer",
    input_type="s1_dem_hydro",
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    max_epochs=50
)

print("Best UNet checkpoint:", best_unet_ckpt)
print("Best SegFormer checkpoint:", best_segformer_ckpt)


## References
Montello, F., Arnaudo, E., & Rossi, C. (2022). MMFlood: A Multimodal Dataset for Flood Delineation From Satellite Imagery. IEEE Access, 10, 96774–96787. https://doi.org/10.1109/ACCESS.2022.3205419

