<a href="https://colab.research.google.com/github/rabinatwayana/DL_torchgeo_MMFlood_Segmentation/blob/master/Drive_mmflood.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Deep Learning-based Semantic Flood Mapping using UNet and SegFormer on Multimodal Earth Observation Data

Author: Rabina Twayana

This notebook aims to perform semantic segmentation for flood mapping by training, evaluating, and comparing different deep learning models.

## Environment and Install Packages

The notebook was run in Google Colab. The following packages were installed.

In [None]:
# Case: Google Colab
# !pip install torchgeo
# !pip install wandb

# Case: Local (using conda)
# conda create -n torchgeo_env python=3.11
# conda activate torchgeo_env  
# conda install -c conda-forge torchgeo
# !conda install wandb -y

In [None]:
# import wandb


## Import Packages

In [27]:
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger ##https://lightning.ai/docs/pytorch/stable/extensions/generated/lightning.pytorch.loggers.WandbLogger.html#lightning.pytorch.loggers.WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint  # same lightning.pytorch namespace as Trainer and WandbLogger
from torchgeo.trainers import SemanticSegmentationTask
from torch.utils.data import DataLoader
from datetime import datetime
from torchgeo.datasets import MMFlood
import json
import pandas as pd


# # import wandb
# # wandb.login()

# # Project that the run is recorded to
# # project = "MMFlood_DL_Experiments"


# checkpoint_callback = ModelCheckpoint(
#     monitor="val_iou",        # metric to monitor (IoU in your case)
#     mode="max",               # save the checkpoint with max val_iou
#     save_top_k=1,             # save only the best model
#     filename="best-{epoch:02d}-{val_iou:.4f}"
# )


# # UNet run
# # wandb_logger_unet = WandbLogger(
# #     project="MMFlood_DL_Experiments",
# #     name="unet"
# # )

# # SegFormer run
# wandb_logger_segformer = WandbLogger(
#     project="MMFlood_DL_Experiments",
#     name="segformer"
# )


# trainer = Trainer(logger=wandb_logger_segformer, callbacks=[checkpoint_callback])

## Dataset
The MMFlood dataset is a multimodal flood delineation dataset (Montello et al., 2022).

Some Sentinel-1 tiles have missing data, which is automatically set to 0. The corresponding pixels in the masks are set to 255 and should be ignored when computing performance metrics.

Dataset features:

- 1,748 Sentinel-1 tiles of varying pixel dimensions

- Multimodal dataset: Sentinel-1, DEMs, and hydrography maps (the latter available for 1,012 of the 1,748 tiles)

- 95 flood events from 42 different countries ranging from 2014 to 2021

- Flood delineation maps (ground truth) are obtained from Copernicus EMS

- Missing data in Sentinel-1 tiles is set to 0, and the corresponding mask pixels are set to 255 (these must be ignored in performance computation)

- Image size is 2000 * 2000

Dataset classes:

- no flood

- flood
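
Because mask pixels with value 255 mark missing Sentinel-1 data, any metric has to drop them before counting. The following is a minimal NumPy sketch of a flood-class IoU that does this; the function name `masked_iou` and the toy arrays are illustrative assumptions, not part of the dataset tooling.

```python
import numpy as np

IGNORE_INDEX = 255  # mask value for pixels with missing Sentinel-1 data


def masked_iou(pred: np.ndarray, target: np.ndarray, positive: int = 1) -> float:
    """IoU of the `positive` class, computed only over valid (non-255) pixels."""
    valid = target != IGNORE_INDEX
    pred_pos = (pred == positive) & valid
    target_pos = (target == positive) & valid
    union = np.logical_or(pred_pos, target_pos).sum()
    if union == 0:
        return float("nan")  # class absent from both prediction and target
    intersection = np.logical_and(pred_pos, target_pos).sum()
    return float(intersection / union)


# Toy 2x3 example: the last column of the target is missing data.
target = np.array([[1, 0, 255],
                   [1, 1, 255]])
pred = np.array([[1, 1, 1],
                 [0, 1, 1]])
print(masked_iou(pred, target))  # 0.5
```

Predictions in the ignored column do not affect the score, which is the same behaviour requested from the training task via `ignore_index=255`.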



### Download data

If the data already exists in the root directory, the download is skipped. The download failed when attempted in Colab, so the download and data subsetting were done locally.

In [None]:
dataset = MMFlood(
    root="data",   # where data will be stored
    download=True,         # this triggers download
    checksum=True          # optional but recommended
)


In [None]:
def get_activations_stats(annot_path):
    # Load JSON
    with open(annot_path, "r") as f:
        data = json.load(f)

    # Convert dict to DataFrame
    df = pd.DataFrame.from_dict(data, orient="index")
    print("Column List: ",list(df.columns))

    country_counts = df["country"].value_counts()
    total_countries = country_counts.shape[0]

    print(f"\nTotal Countries: {total_countries}")


    print(f"\nTotal activations: {len(df)}")

    subset_counts = df["subset"].value_counts()
    print("\nTotal train/test/val count\n",subset_counts)

    

    table_counts = pd.crosstab(
        df["country"],
        df["subset"]
    )
    table_counts["total"]=table_counts["train"]+table_counts["test"]+table_counts["val"]
    print("\nCountry wise Train/Test/Val activation count ")
    table_counts_sorted = table_counts.sort_values(
        by="test",
        ascending=False
    )
    print(table_counts_sorted)

    # Reset index to keep activation ID
    # df = df.reset_index().rename(columns={"index": "activation_id"})


get_activations_stats("data/original_activations.json")

### Removed activations from TEST dataset
- France (EMSR492, EMSR411)
- Italy (EMSR333, EMSR141, EMSR330, EMSR496, EMSR548)
- Ireland (EMSR149)

### Move from train to val
- Germany (EMSR497)
- Ireland (EMSR156)
- Greece (EMSR117)
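
The two lists above can be checked against the edited metadata before training. The sketch below assumes the metadata is a dict keyed by activation ID with a `subset` field, as in the activations JSON used in this notebook; `check_split_edits` and the toy records are hypothetical.

```python
def check_split_edits(metadata, removed_ids, moved_to_val_ids):
    """Return a list of problems found; an empty list means the edits were applied."""
    problems = []
    for act_id in removed_ids:
        if act_id in metadata:
            problems.append(f"{act_id} should have been removed")
    for act_id in moved_to_val_ids:
        info = metadata.get(act_id)
        if info is None:
            problems.append(f"{act_id} is missing")
        elif info["subset"] != "val":
            problems.append(f"{act_id} is in '{info['subset']}', expected 'val'")
    return problems


# Toy metadata: EMSR149 was removed, EMSR156 moved to val, EMSR117 not yet moved.
toy_metadata = {
    "EMSR156": {"country": "Ireland", "subset": "val"},
    "EMSR117": {"country": "Greece", "subset": "train"},
}
result = check_split_edits(toy_metadata, ["EMSR149"], ["EMSR156", "EMSR117"])
print(result)
```

An empty result would indicate that every removal and reassignment listed above is reflected in the JSON.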

In [None]:
selected_countries= ["Greece", "Spain", "France", "Italy", "Germany", "UK", "Australia", "Ireland"]
delete_test_activations = ['EMSR492','EMSR411','EMSR333','EMSR141','EMSR330','EMSR496','EMSR548','EMSR149']
train_to_val_activations = ['EMSR497','EMSR156','EMSR117']


def run_data_selection(original_json_path, new_json_path, selected_countries=(), delete_test_activations=(), train_to_val_activations=()):
    """Write a filtered copy of the activations JSON.

    Keeps activations from selected_countries, drops delete_test_activations,
    and reassigns train_to_val_activations to the "val" subset.
    """
    # Load existing annotations
    with open(original_json_path) as f:
        metadata = json.load(f)

    selected_metadata = {}
    for act_id, act_info in metadata.items():
        # Keep only activations from the selected countries
        if act_info["country"] not in selected_countries:
            continue
        # Drop the activations excluded from the test set
        if act_id in delete_test_activations:
            print(f"Removed {act_id}")
            continue
        # Reassign the chosen training activations to validation
        if act_id in train_to_val_activations:
            act_info["subset"] = "val"
        selected_metadata[act_id] = act_info

    # Save new JSON
    with open(new_json_path, "w") as f:
        json.dump(selected_metadata, f, indent=4)



In [None]:
# MMFlood reads the activations.json file inside data root directory.
import os
os.rename("data/activations.json", "data/original_activations.json")
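
Since MMFlood reads `activations.json` from the root directory, the original file has to be kept under another name before a filtered version is written. Re-running a plain `os.rename` after the filtered file exists can silently replace that backup on POSIX systems. The sketch below shows one way to guard against this; `backup_once` is a hypothetical helper and the demonstration uses a temporary directory rather than the real data folder.

```python
import os
import tempfile


def backup_once(src: str, dst: str) -> bool:
    """Rename src to dst unless the backup already exists. Returns True if renamed."""
    if os.path.exists(dst):
        return False  # keep the existing backup untouched
    os.rename(src, dst)
    return True


# Demonstration in a temporary directory instead of the real data folder.
with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "activations.json")
    dst = os.path.join(tmp, "original_activations.json")
    with open(src, "w") as f:
        f.write("{}")
    first = backup_once(src, dst)
    # Simulate the filtered file being written, then the cell being re-run.
    with open(src, "w") as f:
        f.write("{}")
    second = backup_once(src, dst)
    print(first, second)  # True False
```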

In [None]:
original_json_path = "data/original_activations.json"
new_json_path = "data/activations.json"

run_data_selection(original_json_path, new_json_path, selected_countries, delete_test_activations,train_to_val_activations )
get_activations_stats(new_json_path)

In [None]:
import os
import json
import shutil
import pandas as pd

def create_selected_data_folder(act_json_path, activations_dir, target_dir):
    # Load metadata
    with open(act_json_path) as f:
        metadata = json.load(f)

    df = pd.DataFrame.from_dict(metadata, orient="index")
    df = df.reset_index().rename(columns={"index": "activation_id"})
    activation_ids = df["activation_id"].astype(str).tolist()

    os.makedirs(target_dir, exist_ok=True)

    copied = 0
    missing = []

    # List folders once (much faster)
    all_folders = [
        d for d in os.listdir(activations_dir)
        if os.path.isdir(os.path.join(activations_dir, d))
    ]

    for act_id in activation_ids:
        # Find matching folders
        matched = [d for d in all_folders if d.startswith(act_id)]

        if not matched:
            missing.append(act_id)
            continue

        for folder in matched:
            # print(folder)
            src = os.path.join(activations_dir, folder)
            dst = os.path.join(target_dir, folder)

            if not os.path.exists(dst):
                shutil.copytree(src, dst)
                copied += 1

    print(f"Copied {copied} folders")
    if missing:
        print(f"No folder found for  activation IDs: {missing}")
create_selected_data_folder( "data/activations.json","data/activations","selected_data/activations")
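
The prefix match in `create_selected_data_folder` relies on folder names starting with the activation ID. If one ID were a prefix of another, `startswith` would also pick up the longer one. A stricter variant is sketched below, assuming folders are named `<activation_id>-<index>`; that layout and the folder names in the example are assumptions made for illustration, and `match_folders` is a hypothetical helper.

```python
def match_folders(activation_id, folder_names):
    """Return folders named exactly `activation_id` or `activation_id-<suffix>`."""
    return [
        name for name in folder_names
        if name == activation_id or name.startswith(activation_id + "-")
    ]


# Hypothetical folder names, including one whose ID merely shares a prefix.
folders = ["EMSR117-0", "EMSR117-1", "EMSR1170-0", "EMSR156-0"]
strict = match_folders("EMSR117", folders)
loose = [name for name in folders if name.startswith("EMSR117")]
print(strict)  # ['EMSR117-0', 'EMSR117-1']
print(loose)   # ['EMSR117-0', 'EMSR117-1', 'EMSR1170-0']
```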


In [None]:
shutil.move("data/activations.json", "selected_data/activations.json")

In [None]:
# Dry-run check of the selected data. copytree is commented out, so nothing is
# copied: the function only counts folders that would be copied and reports
# activation IDs without a folder of exactly that name.
# Note: this redefines create_selected_data_folder from the cell above.
import shutil
import os
def create_selected_data_folder(act_json_path, activations_dir, target_dir):
    with open(act_json_path) as f:
        metadata = json.load(f)
    df = pd.DataFrame.from_dict(metadata, orient="index")
    df = df.reset_index().rename(columns={"index": "activation_id"})
    activation_ids = list(df["activation_id"])

    # Ensure target directory exists
    os.makedirs(target_dir, exist_ok=True)

    copied = 0
    missing = []

    for act_id in activation_ids:
        src = os.path.join(activations_dir, act_id)
        dst = os.path.join(target_dir, act_id)

        if os.path.isdir(src):
            if not os.path.exists(dst):
                # shutil.copytree(src, dst)
                copied += 1
        else:
            missing.append(act_id)

    print(f"Folders that would be copied: {copied}")
    if missing:
        print(f"Missing activations: {missing}")



In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:




# train_dataset = MMFlood(root="data/", split="train", include_hydro=True)
# val_dataset = MMFlood(root="data/", split="val", include_hydro=False)
# test_dataset = MMFlood(root="data/", split="test", include_hydro=False)

# print("Train samples:", len(train_dataset))
# print("Val samples:", len(val_dataset))
# print("Test samples:", len(test_dataset))


## Deep Learning Models

### a) UNet (Base model)

### b) SegFormer

## Train Model

In [None]:
train_dataset = MMFlood(root="selected_data/", split="train", include_hydro=False, include_dem=True)
val_dataset = MMFlood(root="selected_data/", split="val", include_hydro=False,include_dem=True)
test_dataset = MMFlood(root="selected_data/", split="test", include_hydro=False,include_dem=True)

print("Train samples:", len(train_dataset))
print("Val samples:", len(val_dataset))
print("Test samples:", len(test_dataset))

Train samples: 551
Val samples: 127
Test samples: 109
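
The counts printed above correspond to a split of roughly 70/16/14. A short sketch of the arithmetic, using only the three numbers from the output:

```python
# Sample counts printed by the cell above.
counts = {"train": 551, "val": 127, "test": 109}

total = sum(counts.values())
shares = {name: n / total for name, n in counts.items()}

for name, share in shares.items():
    print(f"{name}: {share:.1%}")
```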


In [None]:
from torch.utils.data import Subset

indices = list(range(len(train_dataset)))
train_subset = Subset(train_dataset, indices)
train_loader = DataLoader(train_subset, batch_size=16, shuffle=True)


In [14]:
train_loader

<torch.utils.data.dataloader.DataLoader at 0x1978ec690>

In [36]:
# import torch
# from torch.utils.data import DataLoader
# from pytorch_lightning import Trainer
# from pytorch_lightning.callbacks import ModelCheckpoint
# from pytorch_lightning.loggers import WandbLogger
# from datetime import datetime

# # TorchGeo imports
from torchgeo.samplers import RandomGeoSampler
from torchgeo.datasets import stack_samples
# from torchgeo.trainers import SemanticSegmentationTask

def train_model(
    model_name: str,
    input_type: str,
    train_dataset,
    val_dataset,
    max_epochs: int = 50,
    batch_size: int = 8,
    patch_size: int = 256,          # size of random patches (in pixels)
    num_train_patches: int = 10000, # how many random patches per epoch
    num_val_patches: int = 2000,
    in_channels: int = 3,           # IMPORTANT: set this to your actual number of input bands!
    num_classes: int = 2,
    learning_rate: float = 0.001,
):
    """
    Train a TorchGeo semantic segmentation model (UNet or SegFormer) with proper geospatial sampling.
    
    Args:
        model_name (str): 'unet' or 'segformer'
        input_type (str): description of input bands (used for logging/checkpoint naming)
        train_dataset, val_dataset: TorchGeo geospatial datasets (e.g. MMFlood or IntersectionDataset)
        max_epochs (int): number of training epochs
        batch_size (int): batch size
        patch_size (int): size of random patches (square)
        num_train_patches (int): number of random patches per epoch for training
        num_val_patches (int): number of random patches per epoch for validation
        in_channels (int): number of input channels (bands); MUST match your data!
        num_classes (int): number of output classes (including background if needed)
        learning_rate (float): initial learning rate

    Returns:
        str: path to the best saved checkpoint
    """
    # 1. Create proper TorchGeo samplers (this fixes the TypeError!)
    train_sampler = RandomGeoSampler(
        train_dataset,
        size=patch_size,
        length=num_train_patches,
        # res=...  # optional: if you want to force a specific resolution
    )

    val_sampler = RandomGeoSampler(
        val_dataset,
        size=patch_size,
        length=num_val_patches,
    )

    # 2. Create DataLoaders with the samplers
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=4,
        pin_memory=True,
        collate_fn=stack_samples,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        sampler=val_sampler,
        num_workers=4,
        pin_memory=True,
        collate_fn=stack_samples,
    )

    # 3. Dynamic Weights & Biases logger
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    wandb_logger = WandbLogger(
        project="MMFlood_DL_Experiments",
        name=f"{model_name}_{input_type}_{timestamp}",
        log_model="all",  # log every checkpoint during training (True logs checkpoints only at the end)
    )

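    # 4. Checkpoint callback (saves best model based on val IoU)
    # NOTE: `monitor` must match a metric name the task actually logs;
    # inspect trainer.callback_metrics after a validation pass if unsure.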
    checkpoint_callback = ModelCheckpoint(
        monitor="val_iou",
        mode="max",
        save_top_k=1,
        filename=f"best-{model_name}-{{epoch:02d}}-{{val_iou:.4f}}",
        dirpath=f"checkpoints/{model_name}_{input_type}",
        auto_insert_metric_name=False,
    )

    # 5. Create the LightningModule (SemanticSegmentationTask)
    task = SemanticSegmentationTask(
        model=model_name,           # "unet" or "segformer"
        backbone="resnet50",        # encoder backbone for the segmentation model
        in_channels=in_channels,    # CRITICAL: must match your actual input bands!
        num_classes=num_classes,
        loss="ce",                  # cross-entropy
        ignore_index=255,           # usually used for invalid/no-data pixels
        lr=learning_rate,
        # class_weights=...,        # optional
        # weights="imagenet",       # optional: for pretrained backbones
    )

    # 6. Initialize PyTorch Lightning Trainer
    trainer = Trainer(
        max_epochs=max_epochs,
        accelerator="auto",         # "gpu", "cpu", "mps"...
        devices=1,
        logger=wandb_logger,
        callbacks=[checkpoint_callback],
        # precision="16-mixed",     # optional: mixed precision training
        # log_every_n_steps=10,
        # check_val_every_n_epoch=5,
    )

    # 7. Train the model
    trainer.fit(
        model=task,
        train_dataloaders=train_loader,
        val_dataloaders=val_loader,
    )

    # 8. Return path to the best checkpoint
    return checkpoint_callback.best_model_path


# ────────────────────────────────────────────────────────────
# Example usage
# ────────────────────────────────────────────────────────────

# if __name__ == "__main__":
    # Assume these are already created as TorchGeo datasets
    # train_dataset = ...
    # val_dataset = ...

    # # Example 1: Train U-Net
    # best_unet_ckpt = train_model(
    #     model_name="unet",
    #     input_type="s1_dem_hydro",
    #     train_dataset=train_dataset,
    #     val_dataset=val_dataset,
    #     max_epochs=50,
    #     batch_size=8,
    #     patch_size=256,
    #     in_channels=5,          # CHANGE THIS to match your actual number of bands!
    #     num_classes=2,
    # )
    # print("Best UNet checkpoint:", best_unet_ckpt)

# Example 2: Train SegFormer
best_segformer_ckpt = train_model(
    model_name="segformer",
    input_type="s1_dem_hydro",
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    max_epochs=50,
    batch_size=8,
    patch_size=256,
    in_channels=2,          # CHANGE THIS to match your actual number of bands!
    num_classes=2,
)
print("Best SegFormer checkpoint:", best_segformer_ckpt)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
/Users/rabinatwayana/miniforge3/envs/torchgeo_env/lib/python3.11/site-packages/lightning/pytorch/loggers/wandb.py:400: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.

  | Name          | Type             | Params | Mode  | FLOPs
-------------------------------------------------------------------
0 | model         | Segformer        | 24.8 M | train | 0    
1 | criterion     | CrossEntropyLoss | 0      | train | 0    
2 | train_metrics | MetricCollection | 0      | train | 0    
3 | val_metrics   | MetricCollection | 0      | train | 0    
4 | test_metrics  | MetricColl

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/rabinatwayana/miniforge3/envs/torchgeo_env/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:429: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


                                                                           

/Users/rabinatwayana/miniforge3/envs/torchgeo_env/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:429: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Epoch 0:   9%|▉         | 117/1250 [12:03<1:56:50,  0.16it/s, v_num=y83h]


Detected KeyboardInterrupt, attempting graceful shutdown ...
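
The progress bar total of 1250 follows from the sampler settings: 10,000 random patches per epoch at a batch size of 8. The sketch below reproduces that figure and gives a rough time estimate from the 0.16 it/s rate shown in the log; the estimate assumes the rate stays constant and excludes validation.

```python
import math

num_train_patches = 10_000   # default `num_train_patches` in train_model
batch_size = 8               # value passed to train_model
its_per_second = 0.16        # rate reported by the progress bar (CPU run)
max_epochs = 50              # value passed to train_model

steps_per_epoch = math.ceil(num_train_patches / batch_size)
hours_per_epoch = steps_per_epoch / its_per_second / 3600
total_hours = hours_per_epoch * max_epochs

print(steps_per_epoch)             # 1250
print(round(hours_per_epoch, 1))   # 2.2
print(round(total_hours))          # 109
```

At this rate a full 50-epoch run would take several days on CPU, so a GPU runtime or a smaller `num_train_patches` is worth considering.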



## References
Montello, F., Arnaudo, E., & Rossi, C. (2022). MMFlood: A Multimodal Dataset for Flood Delineation From Satellite Imagery. IEEE Access, 10, 96774–96787. https://doi.org/10.1109/ACCESS.2022.3205419

