# OpenEarhMap Semantinc Segmentation

This demo code demonstrates training and testing of UNet-EfficientNet-B4 for the OpenEarthMap dataset (https://open-earth-map.org/). This demo code is based on the work from the "segmentation_models.pytorch" repository by qubvel, available at: https://github.com/qubvel/segmentation_models.pytorch. We extend our sincere appreciation to the original author for their invaluable contributions to the field of semantic segmentation and for providing this open-source implementation.

---

### Requirements

In [None]:
%pip install rasterio 
%pip install pretrainedmodels 
%pip install efficientnet_pytorch 
%pip install timm 

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


### Import
---

In [12]:
import sys #library ที่จัดเตรียมฟังก์ชันและตัวแปรที่ใช้เพื่อจัดการกับส่วนต่างๆของ Python Runtime Environment
sys.path.append(r'C:\Users\lucky\Desktop\Understanding-spatial-data-for-development\OEM_230725\SPD.ipynb')  # Correct path to the directory containing source.py
import os #นำเข้า module OS มาในโค้ดภาษา Python
import time #คำสั่งต่างๆ มากมายที่เกี่ยวกับเวลา
import numpy as np # library ที่ใช้ในการคำนวนทางคณิตศาสตร์ในภาษา Python
import torch #library for use Machine Learning
import torch.nn as nn #Neural Network คือ Layer ที่ใช้ในการคำนวณค่า
from torch.utils.data import DataLoader #คือ Dataset object และลิสต์ของจำนวนข้อมูลที่ต้องการจะแบ่ง
import source  # Ensure 'source.py' is in the 'final_finish' directory
import segmentation_models_pytorch as smp #
import glob
import torchvision.transforms.functional as TF
import math
import cv2
from PIL import Image
import warnings
from pathlib import Path


warnings.filterwarnings("ignore")
os.environ['CUDA_VISIBLE_DEVICES'] = '0'


### Define main parameters

In [13]:
OEM_ROOT = r"C:\Users\lucky\Desktop\Understanding-spatial-data-for-development\OEM_230725\data\OpenEarthMap_Demo"
OEM_DATA_DIR = os.path.join(OEM_ROOT, 'train_val')
TEST_DIR = OEM_ROOT+'test/'
TRAIN_LIST = os.path.join(OEM_ROOT, "train.txt")
VAL_LIST = os.path.join(OEM_ROOT, "val.txt")
WEIGHT_DIR = r"C:\Users\lucky\Desktop\Understanding-spatial-data-for-development\OEM_230725\weight" # path to save weights
OUT_DIR = r"C:\Users\lucky\Desktop\Understanding-spatial-data-for-development\OEM_230725\result" # path to save prediction images
os.makedirs(WEIGHT_DIR, exist_ok=True)
test_large = OEM_ROOT+'/N35.675E139.725.tif'

seed = 0
learning_rate = 0.0001
batch_size = 4
n_epochs = 5
classes = [1, 2, 3, 4, 5, 6, 7, 8]
n_classes = len(classes)+1
classes_wt = np.ones([n_classes], dtype=np.float32)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Number of epochs   :", n_epochs)
print("Number of classes  :", n_classes)
print("Batch size         :", batch_size)
print("Device             :", device)

Number of epochs   : 5
Number of classes  : 9
Batch size         : 4
Device             : cuda


### Prepare training and validation file lists

In this demo for Google Colab, we use only two regions, i.e., Tokyo and Kyoto for training. To train with the full set, please download the OpenEarthMap dataset from https://zenodo.org/record/7223446. Note for xBD data preparation is available at https://github.com/bao18/open_earth_map.

In [14]:
img_pths = [f for f in Path(OEM_DATA_DIR).rglob("*.png") if "labels" in str(f)]
train_pths = [str(f) for f in img_pths if f.name in np.loadtxt(TRAIN_LIST, dtype=str)]
val_pths = [str(f) for f in img_pths if f.name in np.loadtxt(VAL_LIST, dtype=str)]

print("Total samples      :", len(img_pths))
print("Training samples   :", len(train_pths))
print("Validation samples :", len(val_pths))

Total samples      : 10000
Training samples   : 20
Validation samples : 10


### Define training and validation dataloaders

In [15]:
trainset = source.dataset.Dataset(train_pths, classes=classes, size=512, train=True)
validset = source.dataset.Dataset(val_pths, classes=classes, train=False)

train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0)
valid_loader = DataLoader(validset, batch_size=batch_size, shuffle=False, num_workers=0)


### Setup network

In [16]:
network = smp.Unet(
    classes=n_classes,
    activation=None,
    encoder_weights="imagenet",
    encoder_name="efficientnet-b4",
    decoder_attention_type="scse",
)

# count parameters
params = 0
for p in network.parameters():
    if p.requires_grad:
        params += p.numel()

criterion = source.losses.CEWithLogitsLoss(weights=classes_wt)
criterion_name = 'CE'
metric = source.metrics.IoU2()
optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
network_fout = f"{network.name}_s{seed}_{criterion.name}"
OUT_DIR += network_fout # path to save prediction images
os.makedirs(OUT_DIR, exist_ok=True)

print("Model output name  :", network_fout)
print("Number of parameters: ", params)

if torch.cuda.device_count() > 1:
    print("Number of GPUs :", torch.cuda.device_count())
    network = torch.nn.DataParallel(network)
    optimizer = torch.optim.Adam(
        [dict(params=network.module.parameters(), lr=learning_rate)]
    )

https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth
Model output name  : u-efficientnet-b4_s0_CELoss
Number of parameters:  20304278


### Visualization functions

In [17]:
class_rgb = {
    "Bareland": [128, 0, 0],
    "Grass": [0, 255, 36],
    "Pavement": [148, 148, 148],
    "Road": [255, 255, 255],
    "Tree": [34, 97, 38],
    "Water": [0, 69, 255],
    "Cropland": [75, 181, 73],
    "buildings": [222, 31, 7],
}

class_gray = {
    "Bareland": 1,
    "Grass": 2,
    "Pavement": 3,
    "Road": 4,
    "Tree": 5,
    "Water": 6,
    "Cropland": 7,
    "buildings": 8,
}

def label2rgb(a):
    """
    a: labels (HxW)
    """
    out = np.zeros(shape=a.shape + (3,), dtype="uint8")
    for k, v in class_gray.items():
        out[a == v, 0] = class_rgb[k][0]
        out[a == v, 1] = class_rgb[k][1]
        out[a == v, 2] = class_rgb[k][2]
    return out

### Training

In [None]:
### Training
start = time.time()

max_score = 0
train_hist = []
valid_hist = []

for epoch in range(n_epochs):
    print(f"\n🚀 Epoch: {epoch + 1}")

    logs_train, train_iou_per_class = source.runner.train_epoch(
        model=network,
        optimizer=optimizer,
        criterion=criterion,
        metric=metric,
        dataloader=train_loader,
        device=device,
    )

    logs_valid, valid_iou_per_class = source.runner.valid_epoch(
        model=network,
        criterion=criterion,
        metric=metric,
        dataloader=valid_loader,
        device=device,
    )

    train_hist.append(logs_train)
    valid_hist.append(logs_valid)

    score = logs_valid["mIoU"]

    if max_score < score:
        max_score = score
        torch.save(network.state_dict(), os.path.join(WEIGHT_DIR, f"{network_fout}.pth"))
        print("✅ Model saved!")

    # Print per-class IoU
    print("\n📌 **IoU per class (as %):**")
    for i, class_iou in enumerate(valid_iou_per_class, start=0):
        print(f" - Class {i}: {class_iou * 100:.2f}%")

    # Find the best-performing class
    best_class = np.nanargmax(valid_iou_per_class) + 1
    best_class_iou = valid_iou_per_class[best_class - 1] * 100
    print(f"\n🏆 **Max IoU Score: Class {best_class} with {best_class_iou:.2f}%**")

    # Print epoch summary
    print(f"\n📌 **Epoch {epoch + 1} Summary**")
    print(f"🏆 Max Train mIoU so far: {logs_train['mIoU'] * 100:.2f}%")
    print(f"🏆 Max Valid mIoU so far: {logs_valid['mIoU'] * 100:.2f}%")

end = time.time()
print("\n⏳ Processing time:", end - start)


🚀 Epoch: 1


Train: 100%|██████████| 5/5 [00:19<00:00,  3.81s/it, CELoss=2.52, mIoU=3.19%]
Valid: 100%|██████████| 3/3 [00:26<00:00,  8.69s/it, CELoss=2.34, mIoU=3.53%]


✅ Model saved!

📌 **IoU per class (as %):**
 - Class 1: 0.00%
 - Class 2: 2.64%
 - Class 3: 0.07%
 - Class 4: 0.05%
 - Class 5: 3.54%
 - Class 6: 0.00%
 - Class 7: 0.03%
 - Class 8: 0.18%
 - Class 9: 2.44%

🏆 **Max IoU Score: Class 5 with 3.54%**

📌 **Epoch 1 Summary**
🏆 Max Train mIoU so far: 3.19%
🏆 Max Valid mIoU so far: 3.53%

🚀 Epoch: 2


Train: 100%|██████████| 5/5 [00:19<00:00,  3.85s/it, CELoss=2.41, mIoU=3.72%]
Valid: 100%|██████████| 3/3 [00:26<00:00,  8.70s/it, CELoss=2.27, mIoU=3.24%]



📌 **IoU per class (as %):**
 - Class 1: 0.00%
 - Class 2: 2.52%
 - Class 3: 0.12%
 - Class 4: 0.78%
 - Class 5: 2.55%
 - Class 6: 0.01%
 - Class 7: 0.05%
 - Class 8: 0.17%
 - Class 9: 1.73%

🏆 **Max IoU Score: Class 5 with 2.55%**

📌 **Epoch 2 Summary**
🏆 Max Train mIoU so far: 3.72%
🏆 Max Valid mIoU so far: 3.24%

🚀 Epoch: 3


Train: 100%|██████████| 5/5 [00:19<00:00,  3.85s/it, CELoss=2.32, mIoU=4.04%]
Valid: 100%|██████████| 3/3 [00:26<00:00,  8.70s/it, CELoss=2.19, mIoU=4.46%]


✅ Model saved!

📌 **IoU per class (as %):**
 - Class 1: 0.00%
 - Class 2: 2.20%
 - Class 3: 0.09%
 - Class 4: 0.62%
 - Class 5: 4.08%
 - Class 6: 0.00%
 - Class 7: 0.00%
 - Class 8: 0.06%
 - Class 9: 3.76%

🏆 **Max IoU Score: Class 5 with 4.08%**

📌 **Epoch 3 Summary**
🏆 Max Train mIoU so far: 4.04%
🏆 Max Valid mIoU so far: 4.46%

🚀 Epoch: 4


Train: 100%|██████████| 5/5 [00:19<00:00,  3.86s/it, CELoss=2.25, mIoU=4.31%]
Valid: 100%|██████████| 3/3 [00:26<00:00,  8.69s/it, CELoss=2.19, mIoU=4.85%]


✅ Model saved!

📌 **IoU per class (as %):**
 - Class 1: 0.00%
 - Class 2: 1.47%
 - Class 3: 0.11%
 - Class 4: 0.24%
 - Class 5: 4.55%
 - Class 6: 0.00%
 - Class 7: 0.00%
 - Class 8: 0.03%
 - Class 9: 5.45%

🏆 **Max IoU Score: Class 9 with 5.45%**

📌 **Epoch 4 Summary**
🏆 Max Train mIoU so far: 4.31%
🏆 Max Valid mIoU so far: 4.85%

🚀 Epoch: 5


Train: 100%|██████████| 5/5 [00:19<00:00,  3.84s/it, CELoss=2.19, mIoU=4.85%]
Valid: 100%|██████████| 3/3 [00:26<00:00,  8.69s/it, CELoss=2.08, mIoU=5.70%]

✅ Model saved!

📌 **IoU per class (as %):**
 - Class 1: 0.00%
 - Class 2: 1.72%
 - Class 3: 0.12%
 - Class 4: 0.29%
 - Class 5: 4.93%
 - Class 6: 0.00%
 - Class 7: 0.00%
 - Class 8: 0.03%
 - Class 9: 6.66%

🏆 **Max IoU Score: Class 9 with 6.66%**

📌 **Epoch 5 Summary**
🏆 Max Train mIoU so far: 4.85%
🏆 Max Valid mIoU so far: 5.70%

⏳ Processing time: 227.48291087150574





### Testing


In [21]:
start = time.time()

# load network
network.load_state_dict(torch.load(os.path.join(WEIGHT_DIR, f"{network_fout}.pth")))
network.to(device).eval()

test_pths = glob.glob(TEST_DIR+"/*.tif")
#testset = source.dataset.Dataset(test_pths, classes=classes, train=False)

for fn_img in test_pths:
  img = source.dataset.load_multiband(fn_img)
  h, w = img.shape[:2]
  power = math.ceil(np.log2(h) / np.log2(2))
  shape = (2 ** power, 2 ** power)
  img = cv2.resize(img, shape)

  # test time augmentation
  imgs = []
  imgs.append(img.copy())
  imgs.append(img[:, ::-1, :].copy())
  imgs.append(img[::-1, :, :].copy())
  imgs.append(img[::-1, ::-1, :].copy())

  input = torch.cat([TF.to_tensor(x).unsqueeze(0) for x in imgs], dim=0).float().to(device)

  pred = []
  with torch.no_grad():
      msk = network(input)
      msk = torch.softmax(msk[:, :, ...], dim=1)
      msk = msk.cpu().numpy()
      pred = (msk[0, :, :, :] + msk[1, :, :, ::-1] + msk[2, :, ::-1, :] + msk[3, :, ::-1, ::-1])/4
  pred = pred.argmax(axis=0).astype("uint8")
  size = pred.shape[0:]
  y_pr = cv2.resize(pred, (w, h), interpolation=cv2.INTER_NEAREST)

  # save image as png
  filename = os.path.splitext(os.path.basename(fn_img))[0]
  y_pr_rgb = label2rgb(y_pr)
  Image.fromarray(y_pr_rgb).save(os.path.join(OUT_DIR, filename+'test.png'))

end = time.time()
print('Processing time:',end - start)

Processing time: 0.5042271614074707


### Testing a model for a large Geotiff image

A sample image is provided by the Geospatial Information Authority of Japan at https://cyberjapandata.gsi.go.jp/xyz/seamlessphoto/{z}/{x}/{y}.jpg


In [22]:
start = time.time()

# load network
network.load_state_dict(torch.load(os.path.join(WEIGHT_DIR, f"{network_fout}.pth")))
network.to(device).eval()

# process large Geotiff image
img0 = source.dataset.load_multiband(test_large)

# get crs and transform
crs, trans = source.dataset.get_crs(test_large)

if img0.shape[2] > 3:
    img0 = img0[:,:,[0,1,2]]
height = img0.shape[0]
width = img0.shape[1]
band = img0.shape[2]

patch_size = 512
stride = 256
C = int(np.ceil( (width - patch_size) / stride ) + 1)
R = int(np.ceil( (height - patch_size) / stride ) + 1)

# weight matrix B for avoiding boundaries of patches
if patch_size > stride:
    w = patch_size
    s1 = stride
    s2 = w - s1
    d = 1/(1+s2)
    B1 = np.ones((w,w))
    B1[:,s1::] = np.dot(np.ones((w,1)),(-np.arange(1,s2+1)*d+1).reshape(1,s2))
    B2 = np.flip(B1)
    B3 = B1.T
    B4 = np.flip(B3)
    B = B1*B2*B3*B4
else:
    B = np.ones((w,w))

img1 = np.zeros((patch_size+stride*(R-1), patch_size+stride*(C-1),3))
img1[0:height,0:width,:] = img0.copy()

pred_all = np.zeros((9,patch_size+stride*(R-1), patch_size+stride*(C-1)))
weight = np.zeros((patch_size+stride*(R-1), patch_size+stride*(C-1)))

for r in range(R):
    for c in range(C):
        img = img1[r*stride:r*stride+patch_size,c*stride:c*stride+patch_size,:].copy().astype(np.float32)/255
        imgs = []
        imgs.append(img.copy())
        imgs.append(img[:, ::-1, :].copy())
        imgs.append(img[::-1, :, :].copy())
        imgs.append(img[::-1, ::-1, :].copy())

        input = torch.cat([TF.to_tensor(x).unsqueeze(0) for x in imgs], dim=0).float().to(device)

        pred = []
        with torch.no_grad():
            msk = network(input)
            msk = torch.softmax(msk[:, :, ...], dim=1)
            msk = msk.cpu().numpy()

            pred = (msk[0, :, :, :] + msk[1, :, :, ::-1] + msk[2, :, ::-1, :] + msk[3, :, ::-1, ::-1])/4

        pred_all[:,r*stride:r*stride+patch_size,c*stride:c*stride+patch_size] += pred.copy()*B
        weight[r*stride:r*stride+patch_size,c*stride:c*stride+patch_size] += B

for b in range(9):
    pred_all[b,:,:] = pred_all[b,:,:]/weight
    if b == 0:
        pred_all[b,:,:] = 0

pred_all = pred_all.argmax(axis=0).astype("uint8")

filename = os.path.splitext(os.path.basename(test_large))[0]
pr_rgb = label2rgb(pred_all)
Image.fromarray(pr_rgb[0:height,0:width,:]).save(os.path.join(OUT_DIR, filename+'_pr.png'))

# save geotiff
pr_rgb = np.transpose(pr_rgb[0:height,0:width,:], (2,0,1))
source.dataset.save_img(os.path.join(OUT_DIR, filename+'_pr.tif'),pr_rgb,crs,trans)

end = time.time()
print('Processing time:',end - start)

Processing time: 218.02295088768005
