In [None]:
import kagglehub

from pathlib import Path
import random
import time
import csv

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from PIL import Image
from datasets import Dataset, Image as ImageFeature
from transformers import SamModel, SamProcessor

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode


# Select the compute device once; later cells move models and batches to it.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device: ", device)

import transformers, torchvision

# Print the library versions this run was executed with.
for _lib_label, _lib_module in (
    ("Transformers", transformers),
    ("Torch", torch),
    ("Torchvision", torchvision),
):
    print(f"{_lib_label} Version:", _lib_module.__version__)

## Configuration

In [None]:
# If false, U-Net is used
CHOOSE_SAM = True

# Use weights in cross entropy to counteract class imbalance
USE_WEIGHTED_LOSS = True
WEIGHTS_CALC_METHOD = "median"
WEIGHT_SCALAR = 0.2

USE_FOCAL_LOSS = False  # No effect if not USE_WEIGHTED_LOSS
FOCAL_LOSS_GAMMA = 2.0

# Reproducibility
SEED = 42
# Actually apply the seed: without these calls only train_test_split used it,
# so shuffling and weight initialisation differed from run to run.
# This does not by itself guarantee bit-identical results on GPU.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# image sets
TARGET_IMAGE_SIZE = (1024, 1024)

# Train set statistics for seed 42
MEAN = [0.4093, 0.3789, 0.2801]
STD = [0.1428, 0.1080, 0.0972]

# dataloaders
BATCH_SIZE = 2
NUM_WORKERS = 2

# training
LEARNING_RATE = 5e-5
NUM_EPOCHS = 3
PATIENCE = 3


Log Config

In [None]:
# Print a human-readable summary of the configuration cell.
# The loss label mirrors how the criterion is chosen later in the notebook:
# focal loss only applies when weighted loss is enabled as well.
if USE_WEIGHTED_LOSS and USE_FOCAL_LOSS:
    loss_label = f"FocalLoss (γ={FOCAL_LOSS_GAMMA})"
elif USE_WEIGHTED_LOSS:
    loss_label = "Weighted CE"
else:
    loss_label = "CE"

# Report the weighting method that is actually configured instead of a
# hard-coded "median".
class_weights_label = f"Yes ({WEIGHTS_CALC_METHOD})" if USE_WEIGHTED_LOSS else "No"

print("=" * 60)
print("TRAINING CONFIGURATION SUMMARY")
print("=" * 60)
print(f"Model:          {'SAM (head only)' if CHOOSE_SAM else 'U-Net (full)'}")
print(f"Loss:           {loss_label}")
print(f"Class weights:  {class_weights_label}")
print(f"Reproducibility: SEED={SEED}")
print(f"Image size:     {TARGET_IMAGE_SIZE}")
print(f"Normalization:  mean={MEAN}, std={STD}")
print("-" * 60)
print(f"Batch size:     {BATCH_SIZE}")
print(f"Workers:        {NUM_WORKERS}")
# NOTE: wd=1e-4 and patience=5 are literals that mirror the optimizer and
# scheduler cell; keep them in sync by hand if that cell changes.
print(f"Optimizer:      AdamW(lr={LEARNING_RATE}, wd=1e-4)")
print("Scheduler:      ReduceLROnPlateau(patience=5)")
print(f"Epochs:         {NUM_EPOCHS}")
print(f"Early stop:     Patience={PATIENCE}")
print("=" * 60)

In [None]:
# Fetch the DeepGlobe land-cover dataset through kagglehub; the returned
# location is wrapped in a Path and used as the dataset root.
dataset_location = kagglehub.dataset_download(
    "balraj98/deepglobe-land-cover-classification-dataset"
)
DATA_DIR = Path(dataset_location)
print("Path to the dataset: ", DATA_DIR.absolute())

# metadata.csv is the table from which the image/mask paths are read.
metadata_df = pd.read_csv(DATA_DIR.joinpath("metadata.csv"))
print("First rows of metadata:")
display(metadata_df.head())

In [None]:
# Keep only rows that have both an image and a mask path, then turn the
# stored paths into absolute strings under DATA_DIR.
path_columns = ["sat_image_path", "mask_path"]
labeled_df = metadata_df.dropna(subset=path_columns).copy()
for path_column in path_columns:
    labeled_df[path_column] = labeled_df[path_column].apply(
        lambda rel_path: str(DATA_DIR / str(rel_path))
    )

# Fractions
train_frac = 0.70
valid_frac = 0.15
test_frac  = 0.15

# Step 1: split off train; the remainder holds valid + test together.
train_df, temp_df = train_test_split(
    labeled_df,
    test_size=valid_frac + test_frac,
    random_state=SEED,
    shuffle=True,
)

# Step 2: split the remainder into valid and test. The test share is
# expressed relative to the remainder, not to the full dataset.
valid_df, test_df = train_test_split(
    temp_df,
    test_size=test_frac / (valid_frac + test_frac),
    random_state=SEED,
    shuffle=True,
)

splits = dict(train=train_df, valid=valid_df, test=test_df)

for split_name, split_df in splits.items():
    print(f"{split_name.capitalize()} size:", len(split_df))

In [None]:
# Load the class table; the row position in class_dict.csv becomes class_id.
class_df = pd.read_csv(DATA_DIR / "class_dict.csv").reset_index(names=["class_id"])

print("Classes Table:")
display(class_df)

# class_id -> class name
LABEL_MAP = class_df.set_index('class_id')['name'].to_dict()

# (r, g, b) -> class_id, plus the inverse lookup used for visualisation.
rgb_tuples = class_df[["r", "g", "b"]].apply(tuple, axis=1)
DEEPGLOBE_CLASSES = dict(zip(rgb_tuples, class_df["class_id"]))
ID_TO_RGB = {cid: color for color, cid in DEEPGLOBE_CLASSES.items()}

print("RGB -> class_id: ", DEEPGLOBE_CLASSES)
print("class_id -> RGB: ", ID_TO_RGB)

In [None]:
def map_rgb_to_class_id(pil_mask: Image.Image) -> np.ndarray:
    """Convert a colour-coded mask image into a 2-D array of class ids.

    Each pixel whose colour exactly equals a key of ``DEEPGLOBE_CLASSES``
    receives the mapped class id; every other pixel keeps the fill value 6.

    Args:
        pil_mask: Mask image; it is converted with ``np.array(..., uint8)``.

    Returns:
        ``int64`` array with the first two dimensions of the mask.
    """
    rgb = np.array(pil_mask, dtype=np.uint8)
    # NOTE(review): 6 is a hard-coded fallback id, presumably the "unknown"
    # class - confirm against class_dict.csv. Matching is exact, so any
    # off-palette pixel ends up here.
    class_ids = np.full(rgb.shape[:2], fill_value=6, dtype=np.int64)

    for color, cid in DEEPGLOBE_CLASSES.items():
        color_arr = np.array(color, dtype=np.uint8)
        pixel_matches = (rgb == color_arr).all(axis=-1)
        class_ids[pixel_matches] = cid
    return class_ids

print("Function map_rgb_to_class_id defined.")

In [None]:
def decode_class_id_to_rgb(class_id_mask: np.ndarray) -> np.ndarray:
    """Render a 2-D class-id mask as an ``(H, W, 3)`` uint8 colour image.

    Colours come from ``ID_TO_RGB``; ids that are not in that table stay
    black (all zeros).
    """
    height, width = class_id_mask.shape
    canvas = np.zeros((height, width, 3), dtype=np.uint8)

    for cid, color in ID_TO_RGB.items():
        selected = class_id_mask == cid
        canvas[selected] = np.array(color, dtype=np.uint8)
    return canvas

print("Function decode_class_id_to_rgb defined.")

In [None]:
# Image pipeline: resize -> tensor -> normalise with the configured stats.
image_steps = [
    T.Resize(TARGET_IMAGE_SIZE, interpolation=InterpolationMode.BILINEAR),
    T.ToTensor(),
    T.Normalize(mean=MEAN, std=STD),
]
image_transform = T.Compose(image_steps)

# Mask pipeline: nearest-neighbour resize only, so resizing never blends
# two class colours into a new one.
mask_steps = [
    T.Resize(TARGET_IMAGE_SIZE, interpolation=InterpolationMode.NEAREST),
]
mask_transform = T.Compose(mask_steps)

print("Image and Mask transformation defined")

In [None]:
class DeepGlobeDataset(Dataset):
    """Map-style dataset that yields image tensors and class-id masks.

    ``Dataset`` here is whichever class that name is bound to at definition
    time; this notebook imports the name from both ``datasets`` and
    ``torch.utils.data``, and the latter import comes last.

    Args:
        dataframe: Table with ``sat_image_path`` and ``mask_path`` columns.
        image_transform: Optional callable applied to the RGB image; when
            omitted, a plain ``T.ToTensor()`` is used.
        mask_transform: Optional callable applied to the RGB mask before it
            is converted to class ids.
    """

    def __init__(self, dataframe, image_transform=None, mask_transform=None):
        # A positional 0..n-1 index keeps ``iloc`` lookups aligned with idx.
        self.dataframe = dataframe.reset_index(drop=True)
        self.image_transform = image_transform
        self.mask_transform = mask_transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``{"pixel_values": image tensor, "labels": long mask}``."""
        record = self.dataframe.iloc[idx]

        sat_img = Image.open(record["sat_image_path"]).convert("RGB")
        mask_img = Image.open(record["mask_path"]).convert("RGB")

        to_tensor = self.image_transform if self.image_transform is not None else T.ToTensor()
        pixel_values = to_tensor(sat_img)

        if self.mask_transform is not None:
            mask_img = self.mask_transform(mask_img)

        labels = torch.from_numpy(map_rgb_to_class_id(mask_img)).long()

        return {
            "pixel_values": pixel_values,
            "labels": labels,
        }
print("DeepGlobeDataset defined.")

### Create Datasets and Dataloaders

In [None]:
# One dataset per split; all splits share the same transforms.
datasets = {}
for split_name, split_df in splits.items():
    datasets[split_name] = DeepGlobeDataset(
        dataframe=split_df,
        image_transform=image_transform,
        mask_transform=mask_transform,
    )

for split_name, split_ds in datasets.items():
    print(f"{split_name.capitalize()} dataset size:", len(split_ds))

# Only the training loader shuffles; valid/test keep a fixed order.
dataloaders = {
    split_name: DataLoader(
        datasets[split_name],
        batch_size=BATCH_SIZE,
        shuffle=(split_name == "train"),
        num_workers=NUM_WORKERS,
    )
    for split_name in ("train", "valid", "test")
}

print("Batch size:", BATCH_SIZE)
for split_name in dataloaders:
    print(f"{split_name.capitalize()} DataLoader created.")

In [None]:
def compute_class_weights_from_loader(train_loader, num_classes, device="cpu",
                                      method="inverse", eps=1e-6):
    """Derive per-class loss weights from label pixel counts.

    Args:
        train_loader: Iterable of batches; each batch is a mapping whose
            ``"labels"`` entry is an integer tensor of class ids.
        num_classes: Number of classes (length of the returned vector).
        device: Device the resulting weights are moved to.
        method: ``"inverse"`` (1 / count, rescaled to mean 1) or ``"median"``
            (median frequency divided by class frequency).
        eps: Added to every count so empty classes do not divide by zero.

    Returns:
        ``float32`` tensor of shape ``(num_classes,)`` on ``device``.

    Raises:
        ValueError: If ``method`` is not one of the two supported names
            (raised after the loader has been consumed).
    """
    pixel_totals = torch.zeros(num_classes, dtype=torch.float64)

    # One pass over the loader, accumulating a histogram of label values.
    for batch in train_loader:
        flat_labels = batch["labels"].view(-1)
        pixel_totals += torch.bincount(flat_labels, minlength=num_classes).to(torch.float64)

    pixel_totals = pixel_totals + eps

    if method == "median":
        rel_freq = pixel_totals / pixel_totals.sum()
        class_weights = rel_freq.median() / rel_freq
    elif method == "inverse":
        inverse_counts = 1.0 / pixel_totals
        class_weights = inverse_counts / inverse_counts.mean()
    else:
        raise ValueError(f"Unknown method {method}")

    return class_weights.to(device=device, dtype=torch.float32)

In [None]:
import matplotlib.patches as mpatches

def legend_patches_for_present_classes(gt_mask_np, pred_mask_np, label_map, id_to_rgb):
    """Build legend handles for every class id present in either mask.

    Args:
        gt_mask_np: Ground-truth class-id array.
        pred_mask_np: Predicted class-id array.
        label_map: Mapping class id -> display name; missing ids fall back
            to ``"class_<id>"``.
        id_to_rgb: Mapping class id -> ``(r, g, b)`` with 0..255 channels;
            ids without a colour are skipped.

    Returns:
        List of ``mpatches.Patch`` handles ordered by ascending class id.
    """
    seen_ids = set(np.unique(gt_mask_np)) | set(np.unique(pred_mask_np))

    handles = []
    for raw_id in sorted(seen_ids):
        class_id = int(raw_id)
        if class_id in id_to_rgb:
            rgb = id_to_rgb[class_id]
            # matplotlib expects colour channels in 0..1
            unit_color = tuple(rgb[channel] / 255.0 for channel in range(3))
            display_name = label_map.get(class_id, f"class_{class_id}")
            handles.append(mpatches.Patch(color=unit_color, label=display_name))

    return handles


def add_landuse_legend(ax, patches, title="Land-Cover Classes", loc="upper left"):
    """Attach a framed class legend to ``ax``; do nothing if ``patches`` is empty."""
    if patches:
        legend_options = dict(
            handles=patches,
            title=title,
            loc=loc,
            frameon=True,
            fontsize=8,
            title_fontsize=9,
        )
        ax.legend(**legend_options)

## Define UNet

In [None]:
class DoubleConv(nn.Module):
    """Two stacked ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second
    keeps ``out_channels``. Spatial size is preserved by the padding.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # Attribute name and layer order are kept so state-dict keys match.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Encoder step: 2x2 max-pool followed by a ``DoubleConv``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Decoder step: transposed-conv upsampling, skip concat, ``DoubleConv``.

    ``forward(x1, x2)`` upsamples ``x1`` (halving its channels), zero-pads
    it to the spatial size of the skip tensor ``x2``, concatenates both along
    the channel axis (skip first) and applies the double convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        upsampled = self.up(x1)

        # Size gap between the skip tensor and the upsampled tensor.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        left, top = gap_w // 2, gap_h // 2
        # F.pad order for the last two dims: (left, right, top, bottom)
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        return self.conv(torch.cat([x2, upsampled], dim=1))


class UNet(nn.Module):
    """U-Net with four down steps, four up steps and a 1x1 output conv.

    Args:
        n_channels: Number of input image channels.
        n_classes: Number of output logit channels.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Encoder (channel width doubles at every step)
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)

        # Decoder (mirrors the encoder) and per-pixel classifier
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        """Return per-pixel class logits for the input batch ``x``."""
        # Collect encoder outputs; each one but the last is a skip tensor.
        encoder_outputs = [self.inc(x)]
        for down_step in (self.down1, self.down2, self.down3, self.down4):
            encoder_outputs.append(down_step(encoder_outputs[-1]))

        # Start from the deepest features and consume skips deepest-first.
        x = encoder_outputs.pop()
        for up_step in (self.up1, self.up2, self.up3, self.up4):
            x = up_step(x, encoder_outputs.pop())

        return self.outc(x)

## Define SAM Adaptation

In [None]:
class SegmentationHead(nn.Module):
    """Convolutional head that turns encoder features into class logits.

    Two ``Conv3x3 -> BatchNorm -> ReLU`` stages (256 then 128 channels) are
    followed by a 1x1 classifier and a bilinear resize to ``target_size``.

    Args:
        in_channels: Channel count of the incoming feature map.
        num_classes: Number of output logit channels.
        target_size: ``(height, width)`` of the returned logits.
    """

    def __init__(self, in_channels: int, num_classes: int, target_size):

        super().__init__()
        self.target_size = target_size

        self.conv1 = nn.Conv2d(in_channels, 256, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(256)
        self.conv2 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv_out = nn.Conv2d(128, num_classes, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return logits of shape ``(B, num_classes, *target_size)``."""
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))

        # Classify at feature resolution, then upsample the logits.
        # A 1x1 conv and bilinear interpolation are both linear and the
        # interpolation weights sum to 1, so this gives the same result as
        # upsampling first (up to float rounding). It avoids materialising a
        # 128-channel tensor at full output resolution.
        x = self.conv_out(x)
        return F.interpolate(
            x,
            size=self.target_size,
            mode="bilinear",
            align_corners=False
        )
print("SegmentationHead defined.")

In [None]:
class SamSegmentationModel(nn.Module):
    """Frozen image encoder plus a trainable ``SegmentationHead``.

    The encoder is always run under ``torch.no_grad()``, so only the head
    receives gradients.

    Args:
        sam_image_encoder: Encoder module whose output exposes a 4-D
            ``last_hidden_state`` or ``image_embeddings``.
        num_classes: Number of output classes for the head.
        target_size: ``(height, width)`` of the logits produced by the head.
    """

    def __init__(self, sam_image_encoder, num_classes: int, target_size):
        super().__init__()
        self.sam_image_encoder = sam_image_encoder

        # Probe the encoder once with a dummy image to discover how many
        # feature channels the head has to accept. The dummy lives on the
        # encoder's own device and dtype, not on a global device, so the
        # probe cannot fail with a device mismatch.
        reference_param = next(sam_image_encoder.parameters())
        with torch.no_grad():
            dummy = torch.zeros(
                1, 3, TARGET_IMAGE_SIZE[0], TARGET_IMAGE_SIZE[1],
                device=reference_param.device,
                dtype=reference_param.dtype
            )
            feat = self._extract_features(self.sam_image_encoder(dummy))

        in_channels = feat.shape[1]
        print("Channels from SAM-Encoder:", in_channels)

        self.head = SegmentationHead(
            in_channels=in_channels,
            num_classes=num_classes,
            target_size=target_size,
        )

    @staticmethod
    def _extract_features(encoder_output):
        """Return the ``(B, C, H, W)`` feature map from an encoder output.

        Raises:
            ValueError: If neither ``last_hidden_state`` nor
                ``image_embeddings`` is available, or the tensor is not 4-D.
        """
        if getattr(encoder_output, "last_hidden_state", None) is not None:
            features = encoder_output.last_hidden_state
        elif getattr(encoder_output, "image_embeddings", None) is not None:
            features = encoder_output.image_embeddings
        else:
            raise ValueError(f"Unexpected Encoder-Output: {encoder_output}")

        if features.dim() != 4:
            raise ValueError(f"Expected Feature-Shape (B, C, H, W), was: {features.shape}")
        return features

    def train(self, mode: bool = True):
        """Switch the head's mode but keep the frozen encoder in eval mode.

        ``model.train()`` would otherwise put the encoder into training mode
        as well, re-enabling any train-only layers (e.g. dropout) it has.
        """
        super().train(mode)
        self.sam_image_encoder.eval()
        return self

    def forward(self, pixel_values):
        """Return head logits ``(B, num_classes, H_out, W_out)``."""
        with torch.no_grad():
            encoder_output = self.sam_image_encoder(pixel_values)

        features = self._extract_features(encoder_output)
        return self.head(features)
print("SamSegmentationModel defined.")

### Instantiate Model

In [None]:
num_classes = len(DEEPGLOBE_CLASSES)
CHECKPOINT_DIR = Path("./checkpoints")
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

if CHOOSE_SAM:
    # SamModel is already imported at the top of the notebook.
    target_size = TARGET_IMAGE_SIZE
    # Derive the file from CHECKPOINT_DIR so both stay in sync.
    CHECKPOINT_PATH = CHECKPOINT_DIR / "sam_deepglobe_head.pt"

    sam_model = SamModel.from_pretrained(
        "facebook/sam-vit-base",
    )

    # Only the vision encoder is used; it is frozen and kept in eval mode.
    sam_image_encoder = sam_model.vision_encoder.to(device)
    sam_image_encoder.requires_grad_(False)
    sam_image_encoder.eval()

    model = SamSegmentationModel(
        sam_image_encoder=sam_image_encoder,
        num_classes=num_classes,
        target_size=target_size
    ).to(device)

    model.eval()

    # Debug
    print("Hidden Channels aus SAM:", model.head.conv1.in_channels)
    print("Number of classes: ", num_classes)
else:
    CHECKPOINT_PATH = CHECKPOINT_DIR / "UNet.pt"
    model = UNet(n_channels=3, n_classes=num_classes)
    model.to(device)

In [None]:
class FocalLoss(nn.Module):
    """Multi-class focal loss built on top of cross entropy.

    Per pixel: ``weight[target] * (1 - p_t) ** gamma * CE``, where
    ``p_t = exp(-CE)`` is the predicted probability of the true class.
    The result is averaged over all pixels whose target is not
    ``ignore_index``.

    Args:
        gamma: Focusing exponent; 0 reduces to (weighted) cross entropy.
        weight: Optional 1-D tensor of per-class weights.
        ignore_index: Target value that is excluded from the loss.
    """

    def __init__(self, gamma=2.0, weight=None, ignore_index=-100):
        super().__init__()
        self.gamma = gamma
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, logits, targets):
        """Return the scalar focal loss for ``logits`` (B, C, ...) and ``targets`` (B, ...)."""
        # cross_entropy already yields 0 for ignored positions.
        ce_loss = F.cross_entropy(logits, targets, ignore_index=self.ignore_index, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        valid = targets != self.ignore_index
        if self.weight is not None:
            # Ignored targets (e.g. -100) are not valid indices into the
            # weight vector; map them to class 0. Their loss is 0 anyway.
            safe_targets = torch.where(valid, targets, torch.zeros_like(targets))
            focal_loss = self.weight[safe_targets] * focal_loss

        # Average over contributing pixels only. Without ignored pixels this
        # equals the plain mean; clamp avoids 0/0 when everything is ignored.
        num_valid = valid.sum().clamp(min=1)
        return focal_loss.sum() / num_valid

In [None]:
if USE_WEIGHTED_LOSS:
    # Class weights are estimated from the training labels only.
    class_weights = compute_class_weights_from_loader(
        dataloaders['train'],
        num_classes=len(DEEPGLOBE_CLASSES),
        device=device,
        method=WEIGHTS_CALC_METHOD
    )
    print(f"Computed weights for Cross-Entropy: {class_weights}")
    if USE_FOCAL_LOSS:
        criterion = FocalLoss(
            gamma=FOCAL_LOSS_GAMMA,
            weight=class_weights
        )
    else:
        # NOTE(review): with the default reduction="mean", weighted
        # CrossEntropyLoss divides by the sum of the applied weights, so a
        # uniform factor such as WEIGHT_SCALAR cancels out of the loss value.
        criterion = torch.nn.CrossEntropyLoss(weight=class_weights * WEIGHT_SCALAR)
else:
    criterion = torch.nn.CrossEntropyLoss()

# Hand the optimizer only parameters that can actually receive gradients
# (for SAM the encoder is frozen, so this is the head alone).
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(
    trainable_params,
    lr=LEARNING_RATE,
    weight_decay=1e-4,
    betas=(0.9, 0.999)  # Default Adam
)

# mode='max' because the scheduler is stepped with validation mIoU.
# NOTE(review): patience=5 is larger than the configured NUM_EPOCHS, so the
# learning rate is never reduced with the current settings.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=5,
    min_lr=1e-7,
    threshold=0.001  # min improvement
)

# Report what was really built instead of a fixed "UNet / CrossEntropyLoss".
model_label = "SAM" if CHOOSE_SAM else "UNet"
print(f"{model_label} model, {type(criterion).__name__}, and AdamW optimizer initialized.")

In [None]:
def fast_confusion_matrix(preds, labels, num_classes):
    """Count (label, prediction) pairs into a ``num_classes`` x ``num_classes`` matrix.

    Rows index the true label, columns the prediction. Positions whose label
    lies outside ``[0, num_classes)`` are dropped before counting.
    """
    flat_preds = preds.reshape(-1)
    flat_labels = labels.reshape(-1)

    in_range = (flat_labels >= 0) & (flat_labels < num_classes)

    # Encode every (label, pred) pair as one integer: label * K + pred.
    pair_codes = num_classes * flat_labels[in_range] + flat_preds[in_range]
    pair_counts = np.bincount(pair_codes, minlength=num_classes ** 2)

    return pair_counts.reshape(num_classes, num_classes)

In [None]:
def compute_batch_metrics(logits, labels, num_classes):
    """Return ``(pixel_accuracy, mean_iou)`` for a single batch.

    Predictions are the arg-max over dimension 1 of ``logits``. Mean IoU
    averages only classes that occur in the prediction or the labels; it is
    0.0 when no class qualifies.
    """
    pred_np = torch.argmax(logits, dim=1).cpu().numpy()
    label_np = labels.cpu().numpy()

    num_correct = (pred_np == label_np).sum()
    num_pixels = np.prod(label_np.shape)
    pixel_acc = num_correct / max(1, num_pixels)

    class_ious = []
    for class_id in range(num_classes):
        pred_hits = pred_np == class_id
        label_hits = label_np == class_id

        union = np.logical_or(pred_hits, label_hits).sum()
        if union > 0:
            overlap = np.logical_and(pred_hits, label_hits).sum()
            class_ious.append(overlap / union)

    mean_iou = float(np.mean(class_ious)) if len(class_ious) > 0 else 0.0

    return pixel_acc, mean_iou
print("Metrics Function Defined.")

In [None]:
def train_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Module called as ``model(pixel_values)``; put in train mode.
        dataloader: Iterable of batches with ``"pixel_values"`` and
            ``"labels"`` tensors.
        optimizer: Optimizer stepped once per batch.
        criterion: Callable ``criterion(logits, labels)`` returning the loss.
        device: Device the batch tensors are moved to.

    Returns:
        Mean of the per-batch losses as a Python float (0.0 if the loader
        yields no batches).
    """
    model.train()
    running_loss = 0.0
    num_batches = 0

    for batch_number, batch in enumerate(dataloader, start=1):
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()

        logits = model(pixel_values)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for the sum and the log line.
        batch_loss = loss.item()
        running_loss += batch_loss
        num_batches += 1

        if batch_number % 20 == 0:
            print(f"Batch {batch_number}/{len(dataloader)} - Loss: {batch_loss:.4f}")

    return running_loss / max(1, num_batches)
print("train_epoch function defined.")

In [None]:
def eval_epoch(model, dataloader, criterion, device, num_classes):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Returns:
        Tuple ``(avg_loss, pixel_acc, mean_iou, per_class_iou, total_cm)``:
        the mean per-batch loss, overall pixel accuracy, mean IoU over
        classes with a non-empty union, a ``{class_id: IoU or nan}`` dict,
        and the accumulated confusion matrix (rows = labels).
    """
    model.eval()

    loss_sum = 0.0
    batch_count = 0
    confusion = np.zeros((num_classes, num_classes), dtype=np.int64)

    with torch.no_grad():
        for batch in dataloader:
            inputs = batch["pixel_values"].to(device)
            targets = batch["labels"].to(device)

            logits = model(inputs)
            loss = criterion(logits, targets)

            predicted = torch.argmax(logits, dim=1)
            confusion += fast_confusion_matrix(
                predicted.cpu().numpy(),
                targets.cpu().numpy(),
                num_classes,
            )

            loss_sum += loss.item()
            batch_count += 1

    avg_loss = loss_sum / max(1, batch_count)

    # IoU per class from the confusion matrix: tp / (tp + fp + fn).
    per_class_iou = {}
    for class_id in range(num_classes):
        tp = confusion[class_id, class_id]
        fp = confusion[:, class_id].sum() - tp
        fn = confusion[class_id, :].sum() - tp
        denom = tp + fp + fn
        # Classes absent from both labels and predictions are marked nan.
        per_class_iou[class_id] = (tp / denom) if denom > 0 else np.nan

    defined_ious = [value for value in per_class_iou.values() if not np.isnan(value)]
    mean_iou = float(np.mean(defined_ious)) if len(defined_ious) > 0 else 0.0

    pixel_total = confusion.sum()
    pixel_acc = (np.trace(confusion) / pixel_total) if pixel_total > 0 else 0.0

    return avg_loss, pixel_acc, mean_iou, per_class_iou, confusion

In [None]:
from pathlib import Path
import json
import numpy as np
import pandas as pd
import time

def save_final_test_report(run_dir, run_name, choose_sam, test_loss, test_acc, test_miou,
                           per_class_iou, test_cm, label_map, extra_config=None):
    """Write the final test artefacts into ``<run_dir>/final_report``.

    Files written:
        * ``per_class_iou.csv``    - class id, name, support pixels, IoU
        * ``confusion_matrix.npy`` - raw confusion matrix
        * ``summary.json``         - scalar metrics, timestamp and config
        * ``comparison_row.csv``   - one-row table with the headline metrics

    Args:
        run_dir: Run directory (anything ``Path`` accepts).
        run_name: Identifier stored in the summary files.
        choose_sam: Stored as a bool in the summary files.
        test_loss, test_acc, test_miou: Scalar metrics, stored as floats.
        per_class_iou: Mapping class id -> IoU; missing ids become NaN.
        test_cm: Square confusion matrix; rows are summed as support.
        label_map: Mapping class id -> name; missing ids become ``class_<id>``.
        extra_config: Optional dict stored under ``"config"``.

    Returns:
        Path of the ``final_report`` directory.
    """
    run_dir = Path(run_dir)
    out_dir = run_dir / "final_report"
    out_dir.mkdir(parents=True, exist_ok=True)

    k = int(test_cm.shape[0])
    class_ids = range(k)

    # per-class IoU CSV
    per_class_table = pd.DataFrame({
        "class_id": np.arange(k, dtype=int),
        "class_name": [label_map.get(i, f"class_{i}") for i in class_ids],
        "support_pixels": test_cm.sum(axis=1).astype(np.int64),
        "iou": [float(per_class_iou.get(i, np.nan)) for i in class_ids],
    })
    per_class_table.to_csv(out_dir / "per_class_iou.csv", index=False)

    # confusion matrix raw
    np.save(out_dir / "confusion_matrix.npy", test_cm)

    # summary JSON
    summary = dict(
        run_name=run_name,
        timestamp=time.strftime("%Y-%m-%d %H:%M:%S"),
        choose_sam=bool(choose_sam),
        test_loss=float(test_loss),
        pixel_accuracy=float(test_acc),
        mIoU=float(test_miou),
        num_classes=k,
        config=extra_config or {},
    )
    with open(out_dir / "summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    # one-line file for quick comparisons
    headline_keys = ("run_name", "choose_sam", "test_loss", "pixel_accuracy", "mIoU")
    comparison_row = {key: summary[key] for key in headline_keys}
    pd.DataFrame([comparison_row]).to_csv(out_dir / "comparison_row.csv", index=False)

    print(f"[OK] Final report saved: {out_dir}")
    return out_dir

## Training Loop

In [None]:
# In-memory per-epoch histories, plus early-stopping bookkeeping.
train_loss_history, val_loss_history = [], []
val_pixel_acc_history, val_miou_history = [], []
best_val_miou = -1.0   # below any reachable mIoU, so epoch 1 always saves
patience_counter = 0

print("Logging lists initialized.")

# Each run gets its own timestamped directory under ./runs.
run_dir = Path("runs") / time.strftime("%Y-%m-%d_%H-%M-%S")
run_dir.mkdir(parents=True, exist_ok=True)

# TensorBoard
tb_writer = SummaryWriter(log_dir=str(run_dir))

per_class_cols = [f"iou_class_{k}" for k in range(num_classes)]

# CSV: fixed scalar columns followed by one IoU column per class.
csv_path = run_dir / "metrics.csv"
csv_fields = [
    "epoch",
    "train_loss",
    "val_loss",
    "val_pixel_acc",
    "val_miou",
    "epoch_time_sec",
    "lr",
    "weight_decay",
    "loss_gap",
    "miou_num_valid_classes",
    *per_class_cols,
]

# Create the file with just the header; rows are appended once per epoch.
with csv_path.open("w", newline="") as f:
    csv.DictWriter(f, fieldnames=csv_fields).writeheader()

# ============================================
# Main loop: train, validate, log, checkpoint on improvement, early-stop.
for epoch in range(NUM_EPOCHS):
    print("\n==============================")
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}")
    print("==============================")

    epoch_start = time.time()

    # training
    train_loss = train_epoch(
        model=model,
        dataloader=dataloaders['train'],
        optimizer=optimizer,
        criterion=criterion,
        device=device
    )

    # validation
    val_loss, val_acc, val_miou, per_class_iou, val_cm = eval_epoch(
        model=model,
        dataloader=dataloaders['valid'],
        criterion=criterion,
        device=device,
        num_classes=num_classes
    )

    # The scheduler tracks validation mIoU (higher is better).
    scheduler.step(val_miou)

    epoch_time = time.time() - epoch_start
    # Read after scheduler.step(), so this is the LR for the next epoch.
    lr = optimizer.param_groups[0]["lr"]
    wd = optimizer.param_groups[0].get("weight_decay", 0.0)

    # `v == v` is False only for NaN, i.e. classes without a defined IoU.
    valid_iou_vals = [v for v in per_class_iou.values() if v == v]
    miou_num_valid_classes = len(valid_iou_vals)

    # --- TensorBoard Scalars ---
    tb_writer.add_scalar("loss/train", train_loss, epoch + 1)
    tb_writer.add_scalar("loss/val", val_loss, epoch + 1)
    tb_writer.add_scalar("metrics/val_pixel_acc", val_acc, epoch + 1)
    tb_writer.add_scalar("metrics/val_miou", val_miou, epoch + 1)
    tb_writer.add_scalar("time/epoch_sec", epoch_time, epoch + 1)
    tb_writer.add_scalar("opt/lr", lr, epoch + 1)
    tb_writer.flush()

    # --- CSV row ---
    row = {
        "epoch": epoch + 1,
        "train_loss": train_loss,
        "val_loss": val_loss,
        "val_pixel_acc": val_acc,
        "val_miou": val_miou,
        "epoch_time_sec": epoch_time,
        "lr": lr,
        "weight_decay": wd,
        "loss_gap": train_loss - val_loss,
        "miou_num_valid_classes": miou_num_valid_classes,
    }

    # Undefined (NaN) IoUs are written as empty cells.
    for k in range(num_classes):
        v = per_class_iou.get(k, float("nan"))
        row[f"iou_class_{k}"] = (float(v) if v == v else "")

    with csv_path.open("a", newline="") as f:
        w = csv.DictWriter(f, fieldnames=csv_fields)
        w.writerow(row)

    # Prints
    print(f"Train loss: {train_loss:.4f}")
    print("Validation:")
    print(f"  Loss: {val_loss:.4f}")
    print(f"  PixelAcc: {val_acc:.4f}")
    print(f"  mIoU: {val_miou:.4f}")

    train_loss_history.append(train_loss)
    val_loss_history.append(val_loss)
    val_pixel_acc_history.append(val_acc)
    val_miou_history.append(val_miou)

    # Checkpointing
    if val_miou > best_val_miou:
        best_val_miou = val_miou

        # Shared metadata; "epoch" is the zero-based loop index.
        # val_acc is cast so the checkpoint holds no NumPy scalar objects.
        checkpoint_state = {
            "epoch": epoch,
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": val_loss,
            "accuracy": float(val_acc),
            "iou": val_miou,
        }
        if CHOOSE_SAM:
            # Only the trainable head is stored; the encoder is pretrained.
            checkpoint_state["head_state_dict"] = model.head.state_dict()
        else:
            checkpoint_state["model_state_dict"] = model.state_dict()
        torch.save(checkpoint_state, CHECKPOINT_PATH)

        print(f"Model saved! Current best validation IoU: {best_val_miou:.4f}")
        patience_counter = 0
    else:
        patience_counter += 1
        print(f"Patience: No improvements since {patience_counter} epochs.")

    # Stop as soon as PATIENCE epochs in a row brought no improvement.
    # The previous `>` comparison allowed one extra epoch.
    if patience_counter >= PATIENCE:
        print("Patience limit reached. Terminating Training...")
        break

tb_writer.close()
print("Logs saved in:", run_dir)

In [None]:
# Restore the best checkpoint written during training.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# NOTE(security): weights_only=False unpickles arbitrary Python objects; use
# it only for checkpoint files produced by this notebook, never for
# untrusted files.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=False)

if CHOOSE_SAM:
    # SAM checkpoints contain only the head weights.
    model.head.load_state_dict(checkpoint["head_state_dict"])
else:
    model.load_state_dict(checkpoint['model_state_dict'])

print("Best model loaded successfully.")

## Evaluation

In [None]:
# Score the current model on the test split.
test_loss, test_acc, test_miou, per_class_iou, test_cm = eval_epoch(
    model=model,
    dataloader=dataloaders['test'],
    criterion=criterion,
    device=device,
    num_classes=num_classes,
)

model_tag = "SAM" if CHOOSE_SAM else "UNET"
run_name = f"{model_tag}_seed{SEED}_bs{BATCH_SIZE}_ep{NUM_EPOCHS}"

# Settings stored with the report. "lr" is the optimizer's current learning
# rate, i.e. after any scheduler reductions.
extra_config = dict(
    seed=SEED,
    batch_size=BATCH_SIZE,
    epochs=NUM_EPOCHS,
    target_image_size=TARGET_IMAGE_SIZE,
    weighted_loss=USE_WEIGHTED_LOSS,
    lr=optimizer.param_groups[0]["lr"],
)

final_dir = save_final_test_report(
    run_dir=run_dir,
    run_name=run_name,
    choose_sam=CHOOSE_SAM,
    test_loss=test_loss,
    test_acc=test_acc,
    test_miou=test_miou,
    per_class_iou=per_class_iou,
    test_cm=test_cm,
    label_map=LABEL_MAP,
    extra_config=extra_config,
)

print("\n--- Test Metrics (Best Model) ---")
for metric_label, metric_value in (
    ("Test Loss", test_loss),
    ("Test Accuracy", test_acc),
    ("Test IoU", test_miou),
):
    print(f"{metric_label}: {metric_value:.4f}")
print(f"Per Class IoU: {per_class_iou}")

In [None]:
def show_random_examples(dataloader, model, device, num_examples=6):
    """Plot image / ground truth / prediction triples, one row per example.

    Despite the name, examples are taken in loader order: the first
    ``num_examples`` samples the loader yields are shown.

    Args:
        dataloader: Iterable of batches with ``"pixel_values"`` and ``"labels"``.
        model: Module producing class logits; switched to eval mode.
        device: Device the batch tensors are moved to.
        num_examples: Number of rows in the figure.
    """
    model.eval()
    shown = 0

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(
        num_examples, 3, figsize=(15, 4 * num_examples),
        gridspec_kw={"width_ratios": [1, 1, 1.2]},
        squeeze=False,
    )

    std_arr = np.array(STD)
    mean_arr = np.array(MEAN)

    with torch.no_grad():
        for batch in dataloader:
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            preds = torch.argmax(model(pixel_values), dim=1)

            # Take only as many samples from this batch as rows remain.
            take = min(pixel_values.size(0), num_examples - shown)
            for i in range(take):
                # Undo the normalisation and move channels last for imshow.
                img = pixel_values[i].cpu().permute(1, 2, 0).numpy()
                img = np.clip(img * std_arr + mean_arr, 0.0, 1.0)

                gt_mask_np = labels[i].cpu().numpy()
                pred_mask_np = preds[i].cpu().numpy()

                patches = legend_patches_for_present_classes(
                    gt_mask_np,
                    pred_mask_np,
                    LABEL_MAP,
                    ID_TO_RGB
                )

                panels = (
                    ("Image", img),
                    ("Ground Truth", decode_class_id_to_rgb(gt_mask_np)),
                    ("Prediction", decode_class_id_to_rgb(pred_mask_np)),
                )
                for ax, (panel_title, panel_image) in zip(axes[shown], panels):
                    ax.imshow(panel_image)
                    ax.set_title(panel_title)
                    ax.axis("off")

                # The legend sits to the right of the prediction panel.
                axes[shown][2].legend(
                    handles=patches,
                    title="Land Cover",
                    loc="center left",
                    bbox_to_anchor=(1.02, 0.5),
                    frameon=True,
                    fontsize=8,
                    title_fontsize=9
                )

                shown += 1

            if shown >= num_examples:
                break

    plt.tight_layout()
    plt.show()
show_random_examples(dataloaders['test'], model, device, num_examples=6)

In [None]:
def plot_confusion_matrix(
    cm,
    class_names=None,
    normalize=None,
    title="Confusion Matrix",
    label_map=None,
):
    """Draw a confusion matrix as an annotated heat map.

    Args:
        cm: Square matrix with ground truth on rows, predictions on columns.
        class_names: Tick labels; defaults to ``0..k-1``.
        normalize: ``"true"`` (rows sum to 1), ``"pred"`` (columns sum to 1),
            ``"all"`` (whole matrix sums to 1) or anything else for raw counts.
        title: Base title; a suffix naming the normalisation is appended.
        label_map: Optional mapping applied to every entry of ``class_names``.
    """
    cm = cm.astype(np.float64)

    # Pick the normalisation, the cell number format and the title suffix.
    fmt, title_suffix = "d", ""
    if normalize in ("true", "pred"):
        sum_axis = 1 if normalize == "true" else 0
        denom = cm.sum(axis=sum_axis, keepdims=True)
        # Rows/columns with a zero sum stay all-zero instead of becoming NaN.
        cm = np.divide(cm, denom, out=np.zeros_like(cm), where=denom != 0)
        fmt = ".2f"
        title_suffix = " (row-normalized)" if normalize == "true" else " (col-normalized)"
    elif normalize == "all":
        denom = cm.sum()
        if denom != 0:
            cm = cm / denom
        fmt = ".3f"
        title_suffix = " (global-normalized)"
    plot_title = title + title_suffix

    plt.figure(figsize=(7, 6))
    plt.imshow(cm, interpolation="nearest")
    plt.title(plot_title)
    plt.xlabel("Predicted")
    plt.ylabel("Ground Truth")
    plt.colorbar()

    k = cm.shape[0]

    if class_names is None:
        class_names = list(range(k))
    if label_map is not None:
        class_names = [label_map.get(c, str(c)) for c in class_names]

    tick_positions = np.arange(k)
    plt.xticks(tick_positions, class_names, rotation=45, ha="right")
    plt.yticks(tick_positions, class_names)

    # Cells above 60% of the maximum get white text for contrast.
    thresh = (cm.max() * 0.6) if cm.size > 0 else 0
    for i, j in np.ndindex(k, k):
        val = cm[i, j]
        if fmt == "d":
            txt = format(int(val), fmt)
        else:
            txt = format(val, fmt)
        text_color = "white" if val > thresh else "black"
        plt.text(j, i, txt, ha="center", va="center", color=text_color)

    plt.tight_layout()
    plt.show()

In [None]:
# Four views of the test confusion matrix: raw counts, then normalised by
# predicted class, by true class, and over the whole matrix.
for norm_mode in (None, "pred", "true", "all"):
    plot_confusion_matrix(
        test_cm,
        normalize=norm_mode,
        title="Test Confusion Matrix",
        label_map=LABEL_MAP,
    )

In [None]:
# Reload the per-epoch metrics from the CSV written during training, so the
# plots reflect what was logged to disk.
df = pd.read_csv(csv_path)

epochs = df["epoch"].values
train_loss_history = df["train_loss"].values
val_loss_history = df["val_loss"].values
val_pixel_acc_history = df["val_pixel_acc"].values
val_miou_history = df["val_miou"].values

# One figure per entry: (y-axis label, title, [(series, legend label), ...])
curve_specs = [
    # 1) Loss
    ("Loss", "Loss over epochs",
     [(train_loss_history, "Train Loss"), (val_loss_history, "Val Loss")]),
    # 2) Pixel Accuracy
    ("Pixel Accuracy", "Validation Pixel Accuracy over epochs",
     [(val_pixel_acc_history, "Val Pixel Accuracy")]),
    # 3) Mean IoU
    ("Mean IoU", "Validation Mean IoU over epochs",
     [(val_miou_history, "Val Mean IoU")]),
]

for y_label, plot_title, curves in curve_specs:
    plt.figure(figsize=(7, 4))
    for series, series_label in curves:
        plt.plot(epochs, series, label=series_label)
    plt.xlabel("Epoch")
    plt.ylabel(y_label)
    plt.title(plot_title)
    plt.legend()
    plt.grid(True)
    plt.show()