## Generate segmentation maps from satellite images

### Prepare the pretrained model

In [1]:
import torch

# Class index -> class name; handed to the model config when it is loaded.
id2label = dict(enumerate([
    "Ignore",
    "Background",
    "Building",
    "Road",
    "Water",
    "Barren",
    "Forest",
    "Agricultural",
]))
# Inverse lookup: class name -> class index.
label2id = {name: index for index, name in id2label.items()}

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation

# Preprocessor built with the library's default SegFormer settings; no
# checkpoint-specific preprocessing config is loaded here.
# NOTE(review): newer transformers releases deprecate SegformerFeatureExtractor
# in favor of SegformerImageProcessor -- confirm against the pinned version.
feature_extractor = SegformerFeatureExtractor()

# Load the fine-tuned checkpoint from the Hugging Face Hub, overriding the label
# mapping in its config with id2label/label2id. use_auth_token=True sends the
# locally stored Hub token (presumably the repo is private or gated -- verify).
# NOTE(review): newer transformers versions rename use_auth_token to `token`.
model = SegformerForSemanticSegmentation.from_pretrained(
    "wu-pr-gw/segformer-b2-finetuned-with-LoveDA",
    use_auth_token=True,
    id2label=id2label,
    label2id=label2id,
)
model = model.to(device)



### Generate segmentation maps from satellite images in batches

In [3]:
import os
import glob
import re

import numpy as np
import torch
from PIL import Image

INFER_BATCH_SIZE = 5
TARGET_SIZE = 512
PROGRESS_EVERY = 1000  # print a progress line each time this many more images are done

DATASET_ROOT = "./dataset"
OUTPUT_DIR = os.path.join(DATASET_ROOT, "masked_images")


def _image_id(filepath):
    """Return the numeric id at the end of an image file's name.

    e.g. "./dataset/images/satellite000003.png" -> 3

    Raises:
        ValueError: if the file name (without extension) does not end in digits.
    """
    # Parse the basename, not the full path: fixed offsets into the full path
    # depend on DATASET_ROOT (slicing "./dataset/..." at [9:15] gives "/image").
    stem = os.path.splitext(os.path.basename(filepath))[0]
    match = re.search(r"(\d+)$", stem)
    if match is None:
        raise ValueError(f"cannot extract a numeric id from {filepath!r}")
    return int(match.group(1))


def _load_rgb(image_path):
    """Open an image file, convert it to RGB, and close the file handle."""
    with Image.open(image_path) as image:
        # convert() returns a new, fully loaded image independent of the file.
        return image.convert(mode="RGB")


image_paths = glob.glob(os.path.join(DATASET_ROOT, "images", "*.png"))
# Sort image files by their numeric id (e.g. satellite000003.png -> 3).
image_paths.sort(key=_image_id)

# Image.save does not create directories, so make sure the output dir exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)

total = len(image_paths)
for batch_start in range(0, total, INFER_BATCH_SIZE):
    batch_paths = image_paths[batch_start:batch_start + INFER_BATCH_SIZE]
    batch_images = [_load_rgb(path) for path in batch_paths]
    inputs = feature_extractor(images=batch_images, return_tensors="pt").to(device)
    # Inference only: disable autograd so no graph is built for the forward pass.
    with torch.no_grad():
        outputs = model(**inputs)
    segmentation_maps = feature_extractor.post_process_semantic_segmentation(
        outputs=outputs,
        target_sizes=[(TARGET_SIZE, TARGET_SIZE)] * len(batch_images),
    )
    for image_path, segmentation_map in zip(batch_paths, segmentation_maps):
        # Per-pixel class ids, stored as an 8-bit grayscale ("L") PNG.
        segmentation_map_image = Image.fromarray(
            segmentation_map.cpu().numpy().astype(np.uint8), "L"
        )
        stem = os.path.splitext(os.path.basename(image_path))[0]
        segmentation_map_image.save(os.path.join(OUTPUT_DIR, stem + "_masked.png"))

    done = batch_start + len(batch_paths)
    # Report whenever a PROGRESS_EVERY boundary is crossed (independent of the
    # batch size) and always after the final batch.
    if done // PROGRESS_EVERY > batch_start // PROGRESS_EVERY or done == total:
        print(f"{done} / {total}")