# Contextual Adaptation: Deforestation Segmentation in Kalimantan (Indonesia)

This notebook adapts a satellite-image segmentation approach to the context of tropical forest loss in Kalimantan, Indonesia, where deforestation is strongly associated with oil palm expansion and related land-use change.

Pipeline:
1) Export Sentinel-2 SR (median composite) and Hansen GFC forest loss mask from Google Earth Engine (GEE)
2) Download GeoTIFFs to server and tile into 512×512 patches
3) Train a baseline U-Net vs an adapted U-Net (BCE+Dice loss + data augmentation)
4) Evaluate IoU/Dice/Precision/Recall, run paired significance test, and inspect failure cases


In [None]:
import os, random, math
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import precision_score, recall_score
from scipy.stats import ttest_rel, wilcoxon

# Report the PyTorch build and whether a CUDA device is visible.
for _label, _value in (("torch:", torch.__version__), ("cuda:", torch.cuda.is_available())):
    print(_label, _value)


## 1) AOI + Time Window

To keep exports manageable for coursework, we use a smaller AOI within East Kalimantan.
You can expand later, but this is a safe default for GPU server workflows.

AOI bounding-box order: `(lon_min, lat_min, lon_max, lat_max)`, in decimal degrees.


In [None]:
# ========= CONFIG =========
# Area of interest: an East Kalimantan sub-region kept small so the exports
# stay manageable. Order: (lon_min, lat_min, lon_max, lat_max), in degrees.
AOI_BBOX = (116.0, -1.8, 117.2, -0.6)

# Imagery window handed to ImageCollection.filterDate(START_DATE, END_DATE).
# NOTE(review): Earth Engine documents filterDate's end date as exclusive, so
# 2023-12-31 itself is probably not included — confirm if that matters.
START_DATE, END_DATE = "2019-01-01", "2023-12-31"

# Scenes whose CLOUDY_PIXEL_PERCENTAGE is not below this value are dropped.
MAX_CLOUD_PCT = 20

# Four input bands, in this order: Green, Red, NIR, SWIR.
S2_BANDS = ["B3", "B4", "B8", "B11"]

# Hansen lossyear range (inclusive) that becomes the binary loss label.
LOSS_START_YEAR, LOSS_END_YEAR = 2019, 2023

# Export pixel size in metres; B11 is natively 20 m, so the export resamples it.
EXPORT_SCALE = 10

# Local folders, created if missing.
ROOT = Path(".")
DATA_RAW = ROOT / "data_raw"   # downloaded GeoTIFFs
DATA_NPY = ROOT / "data"       # .npy patches and splits
for _folder in (DATA_RAW, DATA_NPY):
    _folder.mkdir(exist_ok=True)

for _name, _value in (
    ("AOI_BBOX:", AOI_BBOX),
    ("DATA_RAW:", DATA_RAW.resolve()),
    ("DATA_NPY:", DATA_NPY.resolve()),
):
    print(_name, _value)


## 2) Export Data from Google Earth Engine (GEE)

We export:
- Sentinel-2 Surface Reflectance (harmonized) median composite over AOI and time range
- Hansen Global Forest Change (v1.12) lossyear -> binary loss mask (2019–2023)

Exports are written to a Google Drive folder; once the tasks finish, download the two GeoTIFFs into `./data_raw/`.


In [None]:
import ee

# Earth Engine must be authenticated before anything else in this notebook can run;
# on failure, print a hint and re-raise so the cell stops.
try:
    ee.Initialize()
except Exception as err:
    print("❌ Earth Engine not initialized:", err)
    print("Run in terminal: earthengine authenticate")
    raise
else:
    print("✅ Earth Engine initialized.")


In [None]:
def bbox_to_ee_geometry(bbox):
    """Return an ``ee.Geometry.Rectangle`` for a bounding box.

    Args:
        bbox: ``(lon_min, lat_min, lon_max, lat_max)``. Anything that does not
            unpack into exactly four values raises ``ValueError``.
    """
    west, south, east, north = bbox
    corners = [west, south, east, north]
    return ee.Geometry.Rectangle(corners)

def _mask_s2_scl_clouds(img):
    """Mask cloud, cloud-shadow and cirrus pixels of one S2 L2A image via its SCL band.

    Scene Classification (SCL) classes removed: 3 (cloud shadow), 8 (cloud,
    medium probability), 9 (cloud, high probability), 10 (thin cirrus).
    """
    scl = img.select("SCL")
    cloudy = scl.eq(3).Or(scl.eq(8)).Or(scl.eq(9)).Or(scl.eq(10))
    return img.updateMask(cloudy.Not())


def s2_sr_median(aoi, start_date, end_date, max_cloud, bands, mask_clouds=False):
    """Median composite of harmonized Sentinel-2 SR over *aoi*, clipped to *aoi*.

    Scenes are filtered by footprint, date range and the scene-level
    ``CLOUDY_PIXEL_PERCENTAGE < max_cloud``; only *bands* are kept.

    Args:
        mask_clouds: If True, additionally mask cloud/shadow/cirrus pixels in
            every scene using its SCL band before taking the median. A
            scene-level cloud filter alone can leave clouds in the tropics.
            Defaults to False (the original behaviour).
    """
    col = (ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")
           .filterBounds(aoi)
           .filterDate(start_date, end_date)
           .filter(ee.Filter.lt("CLOUDY_PIXEL_PERCENTAGE", max_cloud)))
    if mask_clouds:
        # SCL must still be present here, so band selection happens after masking.
        col = col.map(_mask_s2_scl_clouds)
    return col.select(bands).median().clip(aoi)

def hansen_loss_mask(aoi, start_year, end_year):
    """Binary forest-loss mask from Hansen GFC v1.12 ``lossyear``, clipped to *aoi*.

    Pixels whose loss year lies in ``[start_year, end_year]`` (inclusive) are 1,
    all others 0. The result is a uint8 band named ``loss_mask``.

    Raises:
        ValueError: if ``start_year < 2001`` or ``start_year > end_year``.
            Per the GFC dataset description, ``lossyear`` is 0 where no loss was
            detected. A start year of 2000 or earlier makes the lower bound
            ``<= 0``, which would label every pixel in the AOI as loss. An
            inverted range would silently produce an all-zero mask.
    """
    if start_year < 2001:
        raise ValueError(f"start_year must be >= 2001 (lossyear starts at 2001), got {start_year}")
    if start_year > end_year:
        raise ValueError(f"start_year ({start_year}) must not exceed end_year ({end_year})")

    gfc = ee.Image("UMD/hansen/global_forest_change_2024_v1_12")
    lossyear = gfc.select("lossyear")  # 1..24 => 2001..2024
    start_off = start_year - 2000
    end_off = end_year - 2000
    mask = lossyear.gte(start_off).And(lossyear.lte(end_off)).rename("loss_mask").uint8()
    return mask.clip(aoi)

def export_to_drive(image, aoi, description, folder, scale, crs="EPSG:4326", max_pixels=1e13):
    """Start an Earth Engine batch export of *image* to a Google Drive folder.

    The task description doubles as the Drive file prefix; *aoi* is the export
    region, *scale* the pixel size (metres) and *crs* the output projection.

    Returns:
        The started export task object.
    """
    export_kwargs = {
        "image": image,
        "description": description,
        "folder": folder,
        "fileNamePrefix": description,
        "region": aoi,
        "scale": scale,
        "crs": crs,
        "maxPixels": max_pixels,
    }
    task = ee.batch.Export.image.toDrive(**export_kwargs)
    task.start()
    return task


In [None]:
DRIVE_FOLDER = "kalimantan_cw2_exports"

aoi = bbox_to_ee_geometry(AOI_BBOX)

s2_img = s2_sr_median(aoi, START_DATE, END_DATE, MAX_CLOUD_PCT, S2_BANDS)
loss = hansen_loss_mask(aoi, LOSS_START_YEAR, LOSS_END_YEAR)

# (image, Drive file prefix) pairs; exports are started in this order.
EXPORT_JOBS = [
    (s2_img, "S2_SR_EKAL_2019_2023_MEDIAN"),
    (loss, "HANSEN_LOSSMASK_EKAL_2019_2023"),
]
tasks = [export_to_drive(img, aoi, prefix, DRIVE_FOLDER, EXPORT_SCALE) for img, prefix in EXPORT_JOBS]

print("✅ Started export tasks:", len(tasks))
print("Now go to GEE Code Editor -> Tasks, or check Google Drive folder:", DRIVE_FOLDER)


### After export finishes

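Download these two files from Google Drive into `./data_raw/`. Earth Engine splits exports that are too large for a single file into several tiles; if that happens, mosaic the tiles (e.g. with GDAL) into one GeoTIFF per product, saved under these names: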

- `data_raw/S2_SR_EKAL_2019_2023_MEDIAN.tif`
- `data_raw/HANSEN_LOSSMASK_EKAL_2019_2023.tif`

Then run the preprocessing below.


In [None]:
import rasterio
from rasterio.windows import Window

# The two GeoTIFFs expected from the export/download step.
S2_TIF = DATA_RAW / "S2_SR_EKAL_2019_2023_MEDIAN.tif"
MASK_TIF = DATA_RAW / "HANSEN_LOSSMASK_EKAL_2019_2023.tif"

for _tif in (S2_TIF, MASK_TIF):
    assert _tif.exists(), f"Missing {_tif}"

# Patch side length and step between patch origins, in pixels.
# A STRIDE of 256 would give half-overlapping patches (more samples).
PATCH, STRIDE = 512, 512

def _check_same_grid(img_ds, mask_ds):
    """Raise ValueError unless two rasterio datasets share size, CRS and origin (within half a pixel)."""
    if (img_ds.height, img_ds.width) != (mask_ds.height, mask_ds.width):
        raise ValueError(
            f"Raster size mismatch: image {img_ds.height}x{img_ds.width} "
            f"vs mask {mask_ds.height}x{mask_ds.width}"
        )
    if img_ds.crs != mask_ds.crs:
        raise ValueError(f"CRS mismatch: image {img_ds.crs} vs mask {mask_ds.crs}")
    ta, tb = img_ds.transform, mask_ds.transform
    if abs(ta.c - tb.c) > abs(ta.a) / 2 or abs(ta.f - tb.f) > abs(ta.e) / 2:
        raise ValueError("Raster origins differ by more than half a pixel; image and mask are not aligned")


def tile_to_patches(s2_path, mask_path, out_root, patch=512, stride=512, max_patches=None):
    """Cut an aligned Sentinel-2 image / loss mask pair into square ``.npy`` patches.

    Patch origins lie on a regular grid every *stride* pixels, visited row by
    row. Partial patches at the right/bottom edges are skipped. Images are saved
    as float32 ``(patch, patch, 4)`` arrays in ``out_root/all_images``. Masks are
    binarised (``> 0``) float32 ``(patch, patch)`` arrays in ``out_root/all_masks``.
    Both use matching ``{idx:06d}.npy`` names. Existing ``.npy`` files in those
    two folders are deleted first, so a re-run never mixes old and new patches.
    Non-finite image values are replaced by 0 so later mean/std stay finite.

    Args:
        s2_path: 4-band image GeoTIFF.
        mask_path: 1-band mask GeoTIFF on the same pixel grid.
        out_root: Root folder (a ``Path``) for the two output folders.
        patch: Patch side length in pixels.
        stride: Step between patch origins (``< patch`` gives overlap).
        max_patches: Optional cap on the number of patches; ``None`` = no cap.

    Returns:
        Number of patches written.

    Raises:
        ValueError: on unexpected band counts, or if the rasters are not on the same grid.
    """
    import itertools

    out_img = out_root / "all_images"
    out_msk = out_root / "all_masks"
    for folder in (out_img, out_msk):
        folder.mkdir(parents=True, exist_ok=True)
        for stale in folder.glob("*.npy"):
            stale.unlink()

    idx = 0
    with rasterio.open(s2_path) as s2, rasterio.open(mask_path) as msk:
        if s2.count != 4:
            raise ValueError(f"Expected 4 bands, got {s2.count}")
        if msk.count != 1:
            raise ValueError(f"Expected 1-band mask, got {msk.count}")
        # Without this, a shifted or differently sized mask would silently mislabel every patch.
        _check_same_grid(s2, msk)
        H, W = s2.height, s2.width

        windows = (Window(left, top, patch, patch)
                   for top in range(0, H - patch + 1, stride)
                   for left in range(0, W - patch + 1, stride))
        if max_patches is not None:
            windows = itertools.islice(windows, max_patches)

        for win in windows:
            img = s2.read(window=win)          # (4, patch, patch)
            y = msk.read(1, window=win)        # (patch, patch)

            img = np.transpose(img, (1, 2, 0)).astype(np.float32)  # (patch, patch, 4)
            img = np.nan_to_num(img, nan=0.0, posinf=0.0, neginf=0.0)
            y = (y > 0).astype(np.float32)                          # binarise

            np.save(out_img / f"{idx:06d}.npy", img)
            np.save(out_msk / f"{idx:06d}.npy", y)
            idx += 1

    print("✅ Saved patches:", idx)
    return idx

# Tile the full rasters with no cap on the patch count; N is the number written.
N = tile_to_patches(S2_TIF, MASK_TIF, DATA_NPY, PATCH, STRIDE)


In [None]:
from sklearn.model_selection import train_test_split

ALL_IMG_DIR = DATA_NPY / "all_images"
ALL_MSK_DIR = DATA_NPY / "all_masks"

# Patch ids are the file stems, sorted so the seeded split is reproducible.
ids = sorted(path.stem for path in ALL_IMG_DIR.glob("*.npy"))

# 80/20 (train+val)/test, then 80/20 train/val  ->  roughly 64/16/20 overall.
# NOTE(review): patches are assigned at random, so neighbouring tiles can end up
# in different splits; spatial autocorrelation may make test scores optimistic.
_trainval_ids, test_ids = train_test_split(ids, test_size=0.2, random_state=42)
train_ids, val_ids = train_test_split(_trainval_ids, test_size=0.2, random_state=42)

def materialize_split(split_name, split_ids):
    """Copy the patches listed in *split_ids* into ``DATA_NPY/<split_name>/{images,masks}``.

    Any ``.npy`` files already in those two folders are removed first.
    ``load_npy_stack`` reads every ``.npy`` in a folder, so without this step a
    re-run with a different split would leave earlier patches behind and leak
    them across train/val/test.
    """
    import shutil

    for kind, src_dir in (("images", ALL_IMG_DIR), ("masks", ALL_MSK_DIR)):
        dst_dir = DATA_NPY / split_name / kind
        dst_dir.mkdir(parents=True, exist_ok=True)
        for stale in dst_dir.glob("*.npy"):
            stale.unlink()
        for sid in split_ids:
            shutil.copyfile(src_dir / f"{sid}.npy", dst_dir / f"{sid}.npy")

for _split_name, _split_ids in (("train", train_ids), ("val", val_ids), ("test", test_ids)):
    materialize_split(_split_name, _split_ids)

print("Split sizes:", len(train_ids), len(val_ids), len(test_ids))


In [None]:
def load_npy_stack(image_dir, mask_dir):
    """Load matching image/mask ``.npy`` patches from two folders into stacked arrays.

    Images and masks are paired by file stem rather than by position, so a
    missing or extra file raises instead of silently shifting every later pair.
    The output arrays are preallocated and filled in place, avoiding the second
    full copy that ``list`` + ``np.stack`` holds at peak.

    Returns:
        ``(X, Y)``: float32 images ``(N, H, W, 4)`` and float32 masks
        ``(N, H, W)``, ordered by sorted stem.

    Raises:
        ValueError: if the folders contain different stems, contain no files, or
            a patch has an unexpected or inconsistent shape.
    """
    images = {p.stem: p for p in Path(image_dir).glob("*.npy")}
    masks = {p.stem: p for p in Path(mask_dir).glob("*.npy")}
    if images.keys() != masks.keys():
        only_img = sorted(images.keys() - masks.keys())[:5]
        only_msk = sorted(masks.keys() - images.keys())[:5]
        raise ValueError(f"Image/mask files do not match: images only {only_img}, masks only {only_msk}")
    if not images:
        raise ValueError(f"No .npy patches found in {image_dir}")

    stems = sorted(images)
    X = Y = None
    for i, stem in enumerate(stems):
        img = np.load(images[stem]).astype("float32")
        msk = np.load(masks[stem]).astype("float32")
        if img.shape[-1:] != (4,):
            raise ValueError(f"{stem}: expected 4 channels in last axis, got shape {img.shape}")
        if msk.ndim != 2:
            raise ValueError(f"{stem}: expected a 2-D mask, got shape {msk.shape}")
        if X is None:
            X = np.empty((len(stems), *img.shape), dtype=np.float32)
            Y = np.empty((len(stems), *msk.shape), dtype=np.float32)
        elif img.shape != X.shape[1:] or msk.shape != Y.shape[1:]:
            # Needed explicitly: slice assignment would broadcast instead of failing.
            raise ValueError(f"{stem}: patch shape {img.shape}/{msk.shape} differs from the first patch")
        X[i] = img
        Y[i] = msk
    return X, Y

def _split_dirs(split):
    """Return the (images, masks) folders of a materialised split."""
    split_root = DATA_NPY / split
    return split_root / "images", split_root / "masks"

X_train, Y_train = load_npy_stack(*_split_dirs("train"))
X_val, Y_val = load_npy_stack(*_split_dirs("val"))
X_test, Y_test = load_npy_stack(*_split_dirs("test"))

def _channel_mean_std(X, chunk=8):
    """Per-channel mean and std (+1e-6) of an ``(N, H, W, C)`` array, accumulated in float64.

    Processes *chunk* samples at a time so the float64 temporaries stay small.
    Returns ``(mean, std)``, each shaped ``(1, 1, 1, C)``, in X's float dtype
    (float64 for integer input).
    """
    n_chan = X.shape[-1]
    total = np.zeros(n_chan, dtype=np.float64)
    total_sq = np.zeros(n_chan, dtype=np.float64)
    for start in range(0, X.shape[0], chunk):
        block = X[start:start + chunk].astype(np.float64)
        total += block.sum(axis=(0, 1, 2))
        total_sq += np.square(block).sum(axis=(0, 1, 2))
    count = X.shape[0] * X.shape[1] * X.shape[2]
    mean = total / count
    # Population variance as E[x^2] - E[x]^2; float64 keeps the cancellation error
    # small, and clipping guards against tiny negative round-off.
    var = np.maximum(total_sq / count - mean ** 2, 0.0)
    std = np.sqrt(var) + 1e-6
    out_dtype = np.result_type(X.dtype, np.float32)
    shape = (1, 1, 1, n_chan)
    return mean.reshape(shape).astype(out_dtype), std.reshape(shape).astype(out_dtype)


def normalise_images(X, mean=None, std=None):
    """Standardise images per channel: ``(X - mean) / std``.

    Args:
        X: images shaped ``(N, H, W, C)``.
        mean, std: statistics to apply, shaped ``(1, 1, 1, C)``. If *mean* is
            None, both are computed from *X* (population std, plus 1e-6). Use
            this on the training set, then pass the results for val/test.

    Returns:
        ``(X_normalised, mean, std)``.

    Why float64: reducing a float32 array over the non-contiguous axes (0, 1, 2)
    makes NumPy add values one at a time in float32. Over millions of pixels
    that loses noticeable precision, so the statistics are accumulated in float64.
    """
    if mean is None:
        mean, std = _channel_mean_std(X)
    return (X - mean) / std, mean, std

# Statistics come from the training split only and are reused for val/test.
X_train, mean, std = normalise_images(X_train)
X_val = normalise_images(X_val, mean, std)[0]
X_test = normalise_images(X_test, mean, std)[0]

for _label, _X, _Y in (("Train:", X_train, Y_train), ("Val:  ", X_val, Y_val), ("Test: ", X_test, Y_test)):
    print(_label, _X.shape, _Y.shape)
print("Foreground ratio (train):", float(Y_train.mean()))


In [None]:
# Show one random training patch (channel 1 = B4, red) next to its loss mask.
i = np.random.randint(0, len(X_train))
fig, (ax_img, ax_msk) = plt.subplots(1, 2, figsize=(10, 4))
ax_img.imshow(X_train[i][:, :, 1], cmap="gray")
ax_img.set_title("Input B4 (Red)")
ax_msk.imshow(Y_train[i], cmap="gray")
ax_msk.set_title("Loss mask (2019–2023)")
fig.tight_layout()
plt.show()


In [None]:
class NpySegDataset(Dataset):
    """Map-style dataset over in-memory image/mask arrays.

    Items of ``X`` are transposed from ``(H, W, C)`` to ``(C, H, W)``. Items of
    ``Y`` gain a leading channel axis, giving ``(1, H, W)``. Both are returned
    as float tensors. With ``augment=True`` each item may get horizontal and/or
    vertical flips (applied to image and mask alike) and an image-only
    gain/offset jitter.
    """

    def __init__(self, X, Y, augment=False):
        self.X, self.Y, self.augment = X, Y, augment

    def __len__(self):
        return len(self.X)

    def _augment(self, x, y):
        # Independent p=0.5 flips: along axis 1 first, then along axis 0.
        for axis in (1, 0):
            if random.random() < 0.5:
                x, y = np.flip(x, axis=axis).copy(), np.flip(y, axis=axis).copy()
        # With p=0.5: gain in [0.9, 1.1) and offset in [-0.1, 0.1), image only.
        if random.random() < 0.5:
            gain = 1.0 + (random.random() - 0.5) * 0.2
            offset = (random.random() - 0.5) * 0.2
            x = gain * x + offset
        return x, y

    def __getitem__(self, idx):
        image, mask = self.X[idx], self.Y[idx]
        if self.augment:
            image, mask = self._augment(image, mask)
        image_t = torch.from_numpy(np.transpose(image, (2, 0, 1))).float()
        mask_t = torch.from_numpy(mask[np.newaxis, ...]).float()
        return image_t, mask_t

BATCH = 4
# Page-locked host memory only speeds up copies to a CUDA device; on CPU-only
# machines recent PyTorch versions warn about it, so enable it just with CUDA.
PIN_MEMORY = torch.cuda.is_available()
# NOTE(review): worker processes receive the dataset by pickling. With the
# "spawn" start method (Windows/macOS), a class defined inside a notebook may
# not be importable by workers; use num_workers=0 there if loading fails.
NUM_WORKERS = 2


def _make_loader(X, Y, augment, shuffle):
    """Wrap (X, Y) in an NpySegDataset and a DataLoader using the shared settings above."""
    return DataLoader(
        NpySegDataset(X, Y, augment=augment),
        batch_size=BATCH,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
    )


train_loader = _make_loader(X_train, Y_train, augment=True, shuffle=True)
val_loader = _make_loader(X_val, Y_val, augment=False, shuffle=False)
test_loader = _make_loader(X_test, Y_test, augment=False, shuffle=False)


In [None]:
class DoubleConv(nn.Module):
    """(3x3 conv -> BatchNorm -> ReLU) twice; padding=1 keeps the spatial size."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)


class UNet(nn.Module):
    """Three-level U-Net returning raw logits ``(B, out_ch, H, W)`` for input ``(B, in_ch, H, W)``.

    Each decoder stage upsamples with a stride-2 transposed conv and concatenates
    the matching encoder features. When H or W is not a multiple of 8, the
    floor-dividing max-pools leave an upsampled map one pixel short of its skip
    connection. It is then zero-padded to the skip size so ``torch.cat`` works.
    For the 512x512 patches used here this padding is a no-op. Module names
    (and therefore state-dict keys) are unchanged.
    """

    def __init__(self, in_ch=4, out_ch=1, base=32):
        super().__init__()
        # Encoder: width doubles at each of the three 2x max-pool levels.
        self.d1 = DoubleConv(in_ch, base);   self.p1 = nn.MaxPool2d(2)
        self.d2 = DoubleConv(base, base*2);  self.p2 = nn.MaxPool2d(2)
        self.d3 = DoubleConv(base*2, base*4); self.p3 = nn.MaxPool2d(2)
        self.b  = DoubleConv(base*4, base*8)

        # Decoder: upsample, then fuse with the skip (hence the doubled input width).
        self.u3 = nn.ConvTranspose2d(base*8, base*4, 2, 2); self.c3 = DoubleConv(base*8, base*4)
        self.u2 = nn.ConvTranspose2d(base*4, base*2, 2, 2); self.c2 = DoubleConv(base*4, base*2)
        self.u1 = nn.ConvTranspose2d(base*2, base,   2, 2); self.c1 = DoubleConv(base*2, base)

        self.out = nn.Conv2d(base, out_ch, 1)

    @staticmethod
    def _pad_to(x, ref):
        """Zero-pad *x* (split evenly between both sides) to the spatial size of *ref*."""
        dh = ref.size(2) - x.size(2)
        dw = ref.size(3) - x.size(3)
        if dh == 0 and dw == 0:
            return x
        return F.pad(x, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])

    def forward(self, x):
        x1 = self.d1(x)
        x2 = self.d2(self.p1(x1))
        x3 = self.d3(self.p2(x2))
        xb = self.b(self.p3(x3))

        x = self.c3(torch.cat([self._pad_to(self.u3(xb), x3), x3], 1))
        x = self.c2(torch.cat([self._pad_to(self.u2(x), x2), x2], 1))
        x = self.c1(torch.cat([self._pad_to(self.u1(x), x1), x1], 1))
        return self.out(x)


In [None]:
# Shared BCE-on-logits criterion (mean over all elements), used by both objectives.
bce = nn.BCEWithLogitsLoss(reduction="mean")

def dice_loss(logits, targets, eps=1e-6):
    """Soft Dice loss averaged over batch and channels.

    For each (sample, channel): ``1 - (2*sum(p*t) + eps) / (sum(p) + sum(t) + eps)``
    with ``p = sigmoid(logits)``, summed over the two spatial axes.

    *eps* appears in both the numerator and the denominator. An empty target
    predicted as empty therefore scores 0 (perfect) rather than 1. With the
    default tiny *eps*, empty tiles still contribute almost no gradient. A
    larger *eps* (e.g. 1.0) is the usual "smooth Dice" choice if false
    positives on empty tiles should be penalised.
    """
    p = torch.sigmoid(logits)
    inter = (p * targets).sum(dim=(2, 3))
    den = (p + targets).sum(dim=(2, 3))
    return (1 - (2 * inter + eps) / (den + eps)).mean()

def loss_baseline(logits, targets):
    """Baseline objective: mean binary cross-entropy on raw logits (the shared ``bce``)."""
    criterion = bce
    return criterion(logits, targets)

def loss_adapted(logits, targets, alpha=0.5):
    """Adapted objective: ``alpha * BCE + (1 - alpha) * Dice`` on raw logits."""
    bce_term = bce(logits, targets)
    dice_term = dice_loss(logits, targets)
    return alpha * bce_term + (1 - alpha) * dice_term


In [None]:
@torch.no_grad()
def compute_metrics(model, loader, device, threshold=0.5):
    """Evaluate a binary segmentation model on *loader*.

    Predictions are ``sigmoid(model(x)) > threshold``. IoU and Dice come from
    per-(sample, channel) TP/FP/FN pixel counts. When a tile has neither
    ground-truth nor predicted foreground, both are defined as 1.0: a correct
    "no loss" prediction, not a failure.

    Returns:
        dict with
        - ``IoU_mean`` / ``Dice_mean``: means over all samples (each sample
          weighted equally, including those in a smaller final batch);
        - ``Precision_mean`` / ``Recall_mean``: mean over batches of each
          batch's pixel-level precision / recall (0 when undefined);
        - ``IoU_per_batch``: mean IoU of each batch, in loader order;
        - ``IoU_per_sample``: IoU of each sample, in loader order (finer-grained
          pairs for paired significance tests).
    """
    model.eval()
    sample_ious, sample_dices, batch_ious, precisions, recalls = [], [], [], [], []
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        pred = (torch.sigmoid(model(x)) > threshold).float()

        tp = (pred * y).sum(dim=(2, 3))
        fp = (pred * (1 - y)).sum(dim=(2, 3))
        fn = ((1 - pred) * y).sum(dim=(2, 3))

        union = tp + fp + fn
        empty = union == 0
        one = torch.ones_like(tp)
        # Counts are whole numbers, so clamp_min(1) only matters where `empty` already overrides.
        iou = torch.where(empty, one, tp / union.clamp_min(1))
        dice = torch.where(empty, one, 2 * tp / (2 * tp + fp + fn).clamp_min(1))

        per_sample_iou = iou.mean(dim=1)
        sample_ious.extend(per_sample_iou.tolist())
        sample_dices.extend(dice.mean(dim=1).tolist())
        batch_ious.append(per_sample_iou.mean().item())

        tp_b, fp_b, fn_b = tp.sum().item(), fp.sum().item(), fn.sum().item()
        precisions.append(tp_b / (tp_b + fp_b) if tp_b + fp_b else 0.0)
        recalls.append(tp_b / (tp_b + fn_b) if tp_b + fn_b else 0.0)

    return {
        "IoU_mean": float(np.mean(sample_ious)),
        "Dice_mean": float(np.mean(sample_dices)),
        "Precision_mean": float(np.mean(precisions)),
        "Recall_mean": float(np.mean(recalls)),
        "IoU_per_batch": batch_ious,
        "IoU_per_sample": sample_ious,
    }


In [None]:
def train_model(model, train_loader, val_loader, device, loss_fn, lr=3e-4, epochs=30, patience=6, name="model"):
    """Train with Adam, early-stop on validation IoU, and restore the best weights.

    After every epoch, ``compute_metrics`` runs on *val_loader*. The weights
    with the best ``IoU_mean`` so far are saved to ``f"{name}_best.pt"``.
    Training stops after *patience* consecutive epochs without improvement.

    Returns:
        ``(model, hist, best_path)``: the model reloaded from the best
        checkpoint, a list with one dict per epoch (``epoch``, ``train_loss``
        and the validation metrics), and the checkpoint path.
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    best_iou = -1.0
    best_path = f"{name}_best.pt"
    saved = False
    bad = 0
    hist = []

    for ep in range(1, epochs + 1):
        model.train()
        tr_loss = 0.0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            logits = model(x)
            loss = loss_fn(logits, y)
            loss.backward()
            opt.step()
            tr_loss += loss.item()
        tr_loss /= max(len(train_loader), 1)  # avoid ZeroDivisionError on an empty loader

        val_m = compute_metrics(model, val_loader, device)
        hist.append({"epoch": ep, "train_loss": tr_loss, **val_m})
        print(f"Epoch {ep:02d} | train_loss {tr_loss:.4f} | val_IoU {val_m['IoU_mean']:.4f} | val_Dice {val_m['Dice_mean']:.4f}")

        val_iou = val_m["IoU_mean"]
        # Always checkpoint the first epoch. With a NaN IoU (e.g. an empty
        # val_loader), `val_iou > best_iou` is never true, so no file would be
        # written and the final torch.load would fail.
        if not saved or val_iou > best_iou:
            best_iou = -1.0 if math.isnan(val_iou) else val_iou
            torch.save(model.state_dict(), best_path)
            saved = True
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                print("Early stopping.")
                break

    model.load_state_dict(torch.load(best_path, map_location=device))
    return model, hist, best_path


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42

# The baseline is plain U-Net + BCE, so it trains on un-augmented patches.
# Augmentation belongs to the adapted recipe being evaluated; sharing the
# augmented train_loader would let the baseline benefit from it too and blur the
# comparison. Batch size, workers and pinning are copied from train_loader so
# nothing else differs.
baseline_train_loader = DataLoader(
    NpySegDataset(X_train, Y_train, augment=False),
    batch_size=train_loader.batch_size,
    shuffle=True,
    num_workers=train_loader.num_workers,
    pin_memory=train_loader.pin_memory,
)

# Re-seed before building each model so both start from identical initial weights.
torch.manual_seed(SEED)
baseline = UNet(in_ch=4, out_ch=1, base=32).to(device)
baseline, hist_b, ckpt_b = train_model(
    baseline, baseline_train_loader, val_loader, device,
    loss_fn=loss_baseline, lr=3e-4, epochs=30, patience=6, name="unet_baseline"
)

torch.manual_seed(SEED)
adapted = UNet(in_ch=4, out_ch=1, base=32).to(device)
adapted, hist_a, ckpt_a = train_model(
    adapted, train_loader, val_loader, device,
    loss_fn=lambda l, t: loss_adapted(l, t, alpha=0.5), lr=3e-4, epochs=30, patience=6, name="unet_adapted"
)


In [None]:
# Same test_loader (fixed order) for both models, baseline first.
m_b, m_a = (compute_metrics(net, test_loader, device) for net in (baseline, adapted))

for _label, _metrics in (("Baseline test:", m_b), ("Adapted  test:", m_a)):
    print(_label, _metrics)


In [None]:
# Paired comparison of per-batch test IoU; pairs line up because test_loader is not shuffled.
a, b = (np.array(metrics["IoU_per_batch"]) for metrics in (m_b, m_a))

t_stat, p_t = ttest_rel(a, b)
print("Paired t-test p-value:", p_t)

# wilcoxon can raise (e.g. when every paired difference is zero); report it instead of stopping.
try:
    w_stat, p_w = wilcoxon(a, b)
except Exception as e:
    print("Wilcoxon failed:", e)
else:
    print("Wilcoxon p-value:", p_w)


In [None]:
@torch.no_grad()
def show_examples(model, X, Y, n=6, title="", device=None, threshold=0.5):
    """Plot input (channel 1 = B4), ground truth and thresholded prediction for random samples.

    Args:
        model: segmentation model returning logits; output ``[0, 0]`` is shown.
        X: images ``(N, H, W, C)``, transposed to ``(1, C, H, W)`` for the model.
        Y: masks ``(N, H, W)``.
        n: number of rows; capped at ``len(X)`` because sampling is without replacement.
        title: figure title.
        device: where to run the model. Defaults to the device of the model's
            first parameter (CPU if it has none), instead of a notebook global.
        threshold: probability cut-off for the binary prediction.
    """
    model.eval()
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    n = min(n, len(X))
    if n == 0:
        return

    idxs = np.random.choice(len(X), size=n, replace=False)
    plt.figure(figsize=(12, 3 * n))
    for row, idx in enumerate(idxs):
        x = torch.from_numpy(np.transpose(X[idx], (2, 0, 1))).unsqueeze(0).float().to(device)
        prob = torch.sigmoid(model(x)).cpu().numpy()[0, 0]
        panels = (
            (X[idx][:, :, 1], "Input B4 (Red)"),
            (Y[idx], "GT loss"),
            ((prob > threshold).astype(np.uint8), "Pred"),
        )
        for col, (img, label) in enumerate(panels, 1):
            plt.subplot(n, 3, row * 3 + col)
            plt.imshow(img, cmap="gray")
            plt.title(label)
            plt.axis("off")
    plt.suptitle(title)
    plt.tight_layout()
    plt.show()

for _model, _title in (
    (baseline, "Baseline – random test samples"),
    (adapted, "Adapted – random test samples"),
):
    show_examples(_model, X_test, Y_test, n=6, title=_title)


## What to write (short)

- Sentinel-2 SR harmonized imagery was filtered by cloud cover and aggregated with a median composite.
- Labels were derived from Hansen Global Forest Change (lossyear) to form a binary loss mask for 2019–2023.
- Adaptation: BCE+Dice to address class imbalance and intensity augmentation to reflect tropical atmospheric variability.
- Report IoU/Dice/Precision/Recall and run a paired test (t-test or Wilcoxon) on per-batch IoU.
- Include qualitative failure cases (cloud/atmospheric noise, plantation vs secondary forest confusion, boundary ambiguity).
