In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
pip install opencv-python numpy matplotlib tqdm

Note: you may need to restart the kernel to use updated packages.


In [3]:
import os
import cv2
import numpy as np
from tqdm import tqdm

# -----------------------------
# CONFIG
# -----------------------------
IMAGE_DIR = "/kaggle/input/xbd-dataset/xbd/tier1/images"
MASK_DIR  = "/kaggle/input/xbd-dataset/xbd/tier1/masks"

PIXEL_DIFF_THRESHOLD = 30

# Damage severity thresholds (ratio of changed pixels)
SLIGHT_THR = 0.02
MODERATE_THR = 0.10

# -----------------------------
# FUNCTIONS
# -----------------------------
def load_gray(path):
    img = cv2.imread(path)
    return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

def compute_change_binary(pre, post):
    diff = cv2.absdiff(post, pre)
    diff = cv2.GaussianBlur(diff, (5, 5), 0)
    _, binary = cv2.threshold(diff, PIXEL_DIFF_THRESHOLD, 255, cv2.THRESH_BINARY)
    return binary

def damage_severity(change_ratio):
    if change_ratio < SLIGHT_THR:
        return 0  # No / Very slight
    elif change_ratio < MODERATE_THR:
        return 1  # Moderate
    else:
        return 2  # High

def load_mask(mask_path):
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    return (mask > 0).astype(np.uint8)

def iou_score(pred, gt):
    intersection = np.logical_and(pred, gt).sum()
    union = np.logical_or(pred, gt).sum()
    return intersection / union if union != 0 else 1.0

# -----------------------------
# PAIR IMAGES
# -----------------------------
files = os.listdir(IMAGE_DIR)

pre_imgs, post_imgs = {}, {}
for f in files:
    if "pre" in f:
        key = f.replace("_pre_disaster.png", "")
        pre_imgs[key] = f
    elif "post" in f:
        key = f.replace("_post_disaster.png", "")
        post_imgs[key] = f

keys = list(set(pre_imgs.keys()) & set(post_imgs.keys()))

# -----------------------------
# METRICS
# -----------------------------
correct = 0
total = 0
ious = []

# -----------------------------
# PROCESS
# -----------------------------
for key in tqdm(keys):
    pre = load_gray(os.path.join(IMAGE_DIR, pre_imgs[key]))
    post = load_gray(os.path.join(IMAGE_DIR, post_imgs[key]))

    change_binary = compute_change_binary(pre, post)
    change_ratio = np.count_nonzero(change_binary) / change_binary.size
    pred_severity = damage_severity(change_ratio)

    # ---- Ground Truth from mask ----
    mask_path = os.path.join(MASK_DIR, key + "_post_disaster.png")
    if not os.path.exists(mask_path):
        continue

    gt_mask = load_mask(mask_path)
    gt_ratio = gt_mask.sum() / gt_mask.size
    gt_severity = damage_severity(gt_ratio)

    # ---- Accuracy ----
    if pred_severity == gt_severity:
        correct += 1
    total += 1

    # ---- IoU ----
    ious.append(iou_score(change_binary > 0, gt_mask))

# -----------------------------
# RESULTS
# -----------------------------
accuracy = correct / total if total > 0 else 0
mean_iou = np.mean(ious)

print(f"Classification Accuracy: {accuracy:.3f}")
print(f"Mean IoU: {mean_iou:.3f}")


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 2799/2799 [04:38<00:00, 10.07it/s]

Classification Accuracy: 0.243
Mean IoU: 0.057





In [4]:
PIXEL_DIFF_THRESHOLDS = [10, 20, 30, 40, 50]

results_pixel = []

for pix_thr in PIXEL_DIFF_THRESHOLDS:
    correct, total, ious = 0, 0, []

    print(f"\nRunning for PIXEL_DIFF_THRESHOLD = {pix_thr}")

    for key in tqdm(keys, desc=f"PixelThr={pix_thr}", leave=False):
        pre = load_gray(os.path.join(IMAGE_DIR, pre_imgs[key]))
        post = load_gray(os.path.join(IMAGE_DIR, post_imgs[key]))

        diff = cv2.absdiff(post, pre)
        diff = cv2.GaussianBlur(diff, (5, 5), 0)
        _, change_binary = cv2.threshold(diff, pix_thr, 255, cv2.THRESH_BINARY)

        change_ratio = np.count_nonzero(change_binary) / change_binary.size
        pred_severity = damage_severity(change_ratio)

        mask_path = os.path.join(MASK_DIR, key + "_post_disaster.png")
        if not os.path.exists(mask_path):
            continue

        gt_mask = load_mask(mask_path)
        gt_ratio = gt_mask.sum() / gt_mask.size
        gt_severity = damage_severity(gt_ratio)

        if pred_severity == gt_severity:
            correct += 1
        total += 1
        ious.append(iou_score(change_binary > 0, gt_mask))

    acc = correct / total
    miou = np.mean(ious)

    results_pixel.append((pix_thr, acc, miou))
    print(f"PixelThr={pix_thr} | Accuracy={acc:.3f} | Mean IoU={miou:.3f}")



Running for PIXEL_DIFF_THRESHOLD = 10


                                                                

PixelThr=10 | Accuracy=0.209 | Mean IoU=0.062

Running for PIXEL_DIFF_THRESHOLD = 20


                                                                

PixelThr=20 | Accuracy=0.213 | Mean IoU=0.061

Running for PIXEL_DIFF_THRESHOLD = 30


                                                                

PixelThr=30 | Accuracy=0.243 | Mean IoU=0.057

Running for PIXEL_DIFF_THRESHOLD = 40


                                                                

PixelThr=40 | Accuracy=0.306 | Mean IoU=0.054

Running for PIXEL_DIFF_THRESHOLD = 50


                                                                

PixelThr=50 | Accuracy=0.395 | Mean IoU=0.049




In [5]:
SEVERITY_THRESHOLDS = [
    (0.01, 0.05),
    (0.02, 0.10),
    (0.05, 0.15),
    (0.05, 0.20),
]

results_severity = []

for slight_thr, moderate_thr in SEVERITY_THRESHOLDS:
    SLIGHT_THR = slight_thr
    MODERATE_THR = moderate_thr

    correct, total, ious = 0, 0, []

    print(f"\nRunning for SLIGHT<{SLIGHT_THR}, MODERATE<{MODERATE_THR}")

    for key in tqdm(keys, desc=f"S={SLIGHT_THR},M={MODERATE_THR}", leave=False):
        pre = load_gray(os.path.join(IMAGE_DIR, pre_imgs[key]))
        post = load_gray(os.path.join(IMAGE_DIR, post_imgs[key]))

        change_binary = compute_change_binary(pre, post)
        change_ratio = np.count_nonzero(change_binary) / change_binary.size
        pred_severity = damage_severity(change_ratio)

        mask_path = os.path.join(MASK_DIR, key + "_post_disaster.png")
        if not os.path.exists(mask_path):
            continue

        gt_mask = load_mask(mask_path)
        gt_ratio = gt_mask.sum() / gt_mask.size
        gt_severity = damage_severity(gt_ratio)

        if pred_severity == gt_severity:
            correct += 1
        total += 1
        ious.append(iou_score(change_binary > 0, gt_mask))

    acc = correct / total
    miou = np.mean(ious)

    results_severity.append((slight_thr, moderate_thr, acc, miou))
    print(f"Slight<{slight_thr}, Moderate<{moderate_thr} | Acc={acc:.3f} | IoU={miou:.3f}")



Running for SLIGHT<0.01, MODERATE<0.05


                                                                  

Slight<0.01, Moderate<0.05 | Acc=0.326 | IoU=0.057

Running for SLIGHT<0.02, MODERATE<0.1


                                                                 

Slight<0.02, Moderate<0.1 | Acc=0.243 | IoU=0.057

Running for SLIGHT<0.05, MODERATE<0.15


                                                                  

Slight<0.05, Moderate<0.15 | Acc=0.216 | IoU=0.057

Running for SLIGHT<0.05, MODERATE<0.2


                                                                 

Slight<0.05, Moderate<0.2 | Acc=0.223 | IoU=0.057




# Siamese U-Net

In [6]:
import os
import torch
import cv2
from torch.utils.data import Dataset

class XBDDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform

        self.pre_imgs = {}
        self.post_imgs = {}

        for f in os.listdir(image_dir):
            if "pre" in f:
                key = f.replace("_pre_disaster.png", "")
                self.pre_imgs[key] = f
            elif "post" in f:
                key = f.replace("_post_disaster.png", "")
                self.post_imgs[key] = f

        self.keys = list(set(self.pre_imgs) & set(self.post_imgs))

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        key = self.keys[idx]

        pre = cv2.imread(os.path.join(self.image_dir, self.pre_imgs[key]))
        post = cv2.imread(os.path.join(self.image_dir, self.post_imgs[key]))
        mask = cv2.imread(os.path.join(self.mask_dir, key + "_post_disaster.png"), 0)

        pre = cv2.cvtColor(pre, cv2.COLOR_BGR2RGB)
        post = cv2.cvtColor(post, cv2.COLOR_BGR2RGB)
        mask = (mask > 0).astype("float32")

        pre = torch.tensor(pre).permute(2, 0, 1) / 255.0
        post = torch.tensor(post).permute(2, 0, 1) / 255.0
        mask = torch.tensor(mask).unsqueeze(0)

        return pre, post, mask


In [7]:
import torch.nn as nn

class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


In [8]:
class SiameseUNet(nn.Module):
    def __init__(self):
        super().__init__()

        # Shared encoder
        self.enc1 = ConvBlock(3, 64)
        self.enc2 = ConvBlock(64, 128)
        self.enc3 = ConvBlock(128, 256)

        self.pool = nn.MaxPool2d(2)

        # Decoder
        self.up2 = nn.ConvTranspose2d(512, 128, 2, stride=2)
        self.dec2 = ConvBlock(384, 128)

        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = ConvBlock(192, 64)

        self.final = nn.Conv2d(64, 1, 1)

    def forward(self, pre, post):
        # Encoder (shared)
        p1 = self.enc1(pre)
        p2 = self.enc2(self.pool(p1))
        p3 = self.enc3(self.pool(p2))

        q1 = self.enc1(post)
        q2 = self.enc2(self.pool(q1))
        q3 = self.enc3(self.pool(q2))

        # Feature fusion
        f3 = torch.cat([p3, q3], dim=1)

        # Decoder
        d2 = self.up2(f3)
        d2 = self.dec2(torch.cat([d2, p2, q2], dim=1))

        d1 = self.up1(d2)
        d1 = self.dec1(torch.cat([d1, p1, q1], dim=1))

        return torch.sigmoid(self.final(d1))


In [9]:


model = SiameseUNet().cuda()
# model = SiameseUNet()
print(model)

RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

In [None]:
# import torch
# torch.cuda.empty_cache()

In [None]:
# from torch.utils.data import DataLoader
# import torch.optim as optim
# import torch

# dataset = XBDDataset(
#     image_dir="/kaggle/input/xbd-dataset/xbd/tier1/images",
#     mask_dir="/kaggle/input/xbd-dataset/xbd/tier1/masks"
# )

# loader = DataLoader(dataset, batch_size=1, shuffle=True)


# criterion = nn.BCELoss()
# optimizer = optim.Adam(model.parameters(), lr=1e-4)

# for epoch in range(10):
#     model.train()
#     epoch_loss = 0

#     for pre, post, mask in loader:
#         pre, post, mask = pre.cuda(), post.cuda(), mask.cuda()

#         pred = model(pre, post)
#         loss = criterion(pred, mask)

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         epoch_loss += loss.item()

#     print(f"Epoch {epoch+1} | Loss: {epoch_loss/len(loader):.4f}")


In [None]:
from torch.utils.data import DataLoader
import torch.optim as optim
import torch
import torch.nn as nn
from tqdm import tqdm

dataset = XBDDataset(
    image_dir="/kaggle/input/xbd-dataset/xbd/tier1/images",
    mask_dir="/kaggle/input/xbd-dataset/xbd/tier1/masks"
)

loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=2,       # safe for Kaggle
    pin_memory=True      # helps GPU transfer
)

criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(10):
    model.train()
    epoch_loss = 0.0

    progress_bar = tqdm(
        loader,
        desc=f"Epoch [{epoch+1}/10]",
        leave=False
    )

    for pre, post, mask in progress_bar:
        pre  = pre.to(device, non_blocking=True)
        post = post.to(device, non_blocking=True)
        mask = mask.to(device, non_blocking=True)

        optimizer.zero_grad()

        pred = model(pre, post)
        loss = criterion(pred, mask)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        progress_bar.set_postfix(
            loss=f"{epoch_loss / (progress_bar.n + 1):.4f}"
        )

    print(f"Epoch {epoch+1} | Avg Loss: {epoch_loss / len(loader):.4f}")


In [None]:
import matplotlib.pyplot as plt


avg_loss = [0.1837, 0.0958, 0.0810, 0.0728, 0.0693, 0.0647, 0.0620, 0.0604, 0.0587, 0.0566]
x = [a for a in range(1,11)]

plt.plot(x, avg_loss)
plt.title("Training loss for Siamese UNet")
plt.show()

In [None]:
import os

save_path = "/kaggle/working/siamese_unet_xbd.pth"

# Move model to CPU before saving (avoids GPU-specific issues)
model_cpu = model.to("cpu")

torch.save(
    {
        "model_state_dict": model_cpu.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epochs": 10
    },
    save_path
)

print(f"Model saved to {save_path}")


In [None]:
model_params = sum(p.numel() for p in model.parameters()) * 4 / 1024**2

print("Model parameter size is(in MB):", model_params)

# Testing on trained model

In [None]:
import torch.nn as nn

class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


In [None]:
class SiameseUNet(nn.Module):
    def __init__(self):
        super().__init__()

        # Shared encoder
        self.enc1 = ConvBlock(3, 64)
        self.enc2 = ConvBlock(64, 128)
        self.enc3 = ConvBlock(128, 256)

        self.pool = nn.MaxPool2d(2)

        # Decoder
        self.up2 = nn.ConvTranspose2d(512, 128, 2, stride=2)
        self.dec2 = ConvBlock(384, 128)

        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = ConvBlock(192, 64)

        self.final = nn.Conv2d(64, 1, 1)

    def forward(self, pre, post):
        # Encoder (shared)
        p1 = self.enc1(pre)
        p2 = self.enc2(self.pool(p1))
        p3 = self.enc3(self.pool(p2))

        q1 = self.enc1(post)
        q2 = self.enc2(self.pool(q1))
        q3 = self.enc3(self.pool(q2))

        # Feature fusion
        f3 = torch.cat([p3, q3], dim=1)

        # Decoder
        d2 = self.up2(f3)
        d2 = self.dec2(torch.cat([d2, p2, q2], dim=1))

        d1 = self.up1(d2)
        d1 = self.dec1(torch.cat([d1, p1, q1], dim=1))

        return torch.sigmoid(self.final(d1))


In [None]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model = SiameseUNet().to(device)
# print(model)


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

checkpoint = torch.load("/kaggle/input/datasets/draqa123/siamese-unet/siamese_unet_xbd.pth", map_location=device)

model = SiameseUNet().to(device)
model.load_state_dict(checkpoint["model_state_dict"])

model.eval()


In [None]:
print("Loaded successfully")
print(sum(p.numel() for p in model.parameters()))


In [None]:
import os
import torch
import cv2
from torch.utils.data import Dataset

class XBDDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform

        self.pre_imgs = {}
        self.post_imgs = {}

        for f in os.listdir(image_dir):
            if "pre" in f:
                key = f.replace("_pre_disaster.png", "")
                self.pre_imgs[key] = f
            elif "post" in f:
                key = f.replace("_post_disaster.png", "")
                self.post_imgs[key] = f

        self.keys = list(set(self.pre_imgs) & set(self.post_imgs))

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        key = self.keys[idx]

        pre = cv2.imread(os.path.join(self.image_dir, self.pre_imgs[key]))
        post = cv2.imread(os.path.join(self.image_dir, self.post_imgs[key]))
        mask = cv2.imread(os.path.join(self.mask_dir, key + "_post_disaster.png"), 0)

        pre = cv2.cvtColor(pre, cv2.COLOR_BGR2RGB)
        post = cv2.cvtColor(post, cv2.COLOR_BGR2RGB)
        mask = (mask > 0).astype("float32")

        pre = torch.tensor(pre).permute(2, 0, 1) / 255.0
        post = torch.tensor(post).permute(2, 0, 1) / 255.0
        mask = torch.tensor(mask).unsqueeze(0)

        return pre, post, mask


In [None]:
dataset = XBDDataset(
    image_dir="/kaggle/input/xbd-dataset/xbd/tier1/images",
    mask_dir="/kaggle/input/xbd-dataset/xbd/tier1/masks"
)

In [None]:
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.eval()

test_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False
)

with torch.no_grad():
    for pre, post, mask in test_loader:

        pre  = pre.to(device)
        post = post.to(device)
        mask = mask.to(device)

        # Forward pass
        pred = model(pre, post)

        # Threshold
        pred_mask = (pred > 0.5).float()

        break   # only first test sample


In [None]:
plt.figure(figsize=(15,5))

# Pre-disaster image
plt.subplot(1,4,1)
plt.title("Pre Image")
plt.imshow(pre[0].cpu().permute(1,2,0))
plt.axis("off")

# Post-disaster image
plt.subplot(1,4,2)
plt.title("Post Image")
plt.imshow(post[0].cpu().permute(1,2,0))
plt.axis("off")

# Ground Truth Mask
plt.subplot(1,4,3)
plt.title("Ground Truth")
plt.imshow(mask[0].cpu().squeeze(), cmap="gray")
plt.axis("off")

# Predicted Mask
plt.subplot(1,4,4)
plt.title("Prediction")
plt.imshow(pred_mask[0].cpu().squeeze(), cmap="gray")
plt.axis("off")

plt.show()


In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.eval()

test_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False
)

total_dice = 0
total_iou = 0
total_correct = 0
total_pixels = 0

with torch.no_grad():
    for pre, post, mask in tqdm(test_loader):

        pre  = pre.to(device)
        post = post.to(device)
        mask = mask.to(device)

        # Forward
        pred = model(pre, post)

        pred_mask = (pred > 0.5).float()

        # ---- Metrics ----

        # Dice
        intersection = (pred_mask * mask).sum()
        dice = (2 * intersection) / (
            pred_mask.sum() + mask.sum() + 1e-8
        )

        # IoU
        union = pred_mask.sum() + mask.sum() - intersection
        iou = intersection / (union + 1e-8)

        # Pixel Accuracy
        correct = (pred_mask == mask).sum()
        total_correct += correct.item()
        total_pixels += torch.numel(mask)

        total_dice += dice.item()
        total_iou += iou.item()

# Final Results
num_samples = len(test_loader)

print("Evaluation Results:")
print("Dice Score :", total_dice / num_samples)
print("IoU        :", total_iou / num_samples)
print("Pixel Acc  :", total_correct / total_pixels)


In [None]:
with torch.no_grad():
    pre, post, mask = next(iter(test_loader))

    pre  = pre.to(device)
    post = post.to(device)
    mask = mask.to(device)

    pred = model(pre, post)

    print("Pred raw min/max:", pred.min().item(), pred.max().item())
    print("Mask min/max:", mask.min().item(), mask.max().item())
    print("Pred shape:", pred.shape)
    print("Mask shape:", mask.shape)


# Using SPAUnet

In [None]:
import os
import cv2
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm


In [None]:
class XBDDataset(Dataset):
    def __init__(self, image_dir, mask_dir=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir

        self.pre_imgs = {}
        self.post_imgs = {}

        for f in os.listdir(image_dir):
            if "pre_disaster" in f:
                key = f.replace("_pre_disaster.png", "")
                self.pre_imgs[key] = f
            elif "post_disaster" in f:
                key = f.replace("_post_disaster.png", "")
                self.post_imgs[key] = f

        self.keys = sorted(list(set(self.pre_imgs) & set(self.post_imgs)))

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        key = self.keys[idx]

        pre = cv2.imread(os.path.join(self.image_dir, self.pre_imgs[key]))
        post = cv2.imread(os.path.join(self.image_dir, self.post_imgs[key]))

        pre = cv2.cvtColor(pre, cv2.COLOR_BGR2RGB)
        post = cv2.cvtColor(post, cv2.COLOR_BGR2RGB)

        pre = torch.tensor(pre).permute(2, 0, 1).float() / 255.0
        post = torch.tensor(post).permute(2, 0, 1).float() / 255.0

        if self.mask_dir:
            mask_path = os.path.join(self.mask_dir, key + "_post_disaster.png")
            mask = cv2.imread(mask_path, 0)
            mask = torch.tensor(mask).long()
            return pre, post, mask

        return pre, post


In [None]:
tier1 = XBDDataset("/kaggle/input/xbd-dataset/xbd/tier1/images", "/kaggle/input/xbd-dataset/xbd/tier1/masks")
tier3 = XBDDataset("/kaggle/input/xbd-dataset/xbd/tier3/images", "/kaggle/input/xbd-dataset/xbd/tier3/masks")

full_labeled = torch.utils.data.ConcatDataset([tier1, tier3])

total_len = len(full_labeled)
labeled_len = int(0.2 * total_len)

labeled_dataset, _ = random_split(
    full_labeled,
    [labeled_len, total_len - labeled_len]
)

val_dataset = XBDDataset("/kaggle/input/xbd-dataset/xbd/hold/images", "/kaggle/input/xbd-dataset/xbd/hold/masks")

labeled_loader = DataLoader(labeled_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8)


# model arch.

In [None]:
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


In [None]:
class SPABlock(nn.Module):
    def __init__(self, in_ch):
        super().__init__()

        self.pool1 = nn.AdaptiveAvgPool2d(1)
        self.pool2 = nn.AdaptiveAvgPool2d(2)
        self.pool3 = nn.AdaptiveAvgPool2d(4)

        self.conv = nn.Conv2d(in_ch * 3, in_ch, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        h, w = x.shape[2:]

        p1 = F.interpolate(self.pool1(x), size=(h, w), mode='bilinear')
        p2 = F.interpolate(self.pool2(x), size=(h, w), mode='bilinear')
        p3 = F.interpolate(self.pool3(x), size=(h, w), mode='bilinear')

        out = torch.cat([p1, p2, p3], dim=1)
        attention = self.sigmoid(self.conv(out))

        return x * attention


In [None]:
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()

        self.enc1 = DoubleConv(3, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)

        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        x1 = self.enc1(x)
        x2 = self.enc2(self.pool(x1))
        x3 = self.enc3(self.pool(x2))
        x4 = self.enc4(self.pool(x3))
        return x1, x2, x3, x4


In [None]:
class UpBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x, skip):
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)
        return self.conv(x)


In [None]:
class SPAUNet(nn.Module):
    def __init__(self, num_classes=5):
        super().__init__()

        self.encoder = Encoder()
        self.spa = SPABlock(512)

        self.up3 = UpBlock(512, 256)
        self.up2 = UpBlock(256, 128)
        self.up1 = UpBlock(128, 64)

        self.final = nn.Conv2d(64, num_classes, 1)

    def forward(self, pre, post):
        p1, p2, p3, p4 = self.encoder(pre)
        q1, q2, q3, q4 = self.encoder(post)

        f1 = torch.abs(p1 - q1)
        f2 = torch.abs(p2 - q2)
        f3 = torch.abs(p3 - q3)
        f4 = torch.abs(p4 - q4)

        bottleneck = self.spa(f4)

        d3 = self.up3(bottleneck, f3)
        d2 = self.up2(d3, f2)
        d1 = self.up1(d2, f1)

        return self.final(d1)


In [None]:
def compute_iou(pred, mask, num_classes=5):
    ious = []
    pred = torch.argmax(pred, dim=1)

    for cls in range(num_classes):
        pred_cls = (pred == cls)
        mask_cls = (mask == cls)

        intersection = (pred_cls & mask_cls).sum().float()
        union = (pred_cls | mask_cls).sum().float()

        if union == 0:
            continue

        ious.append(intersection / union)

    return torch.mean(torch.stack(ious))


In [None]:
import os

save_dir = "/kaggle/working/checkpoints"
os.makedirs(save_dir, exist_ok=True)


In [None]:
best_iou = 0
start_epoch = 0


In [None]:
resume_path = os.path.join(save_dir, "last_model.pth")

if os.path.exists(resume_path):
    print("Resuming training...")
    checkpoint = torch.load(resume_path)
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    start_epoch = checkpoint["epoch"]
    best_iou = checkpoint["best_iou"]


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SPAUNet(num_classes=5).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

epochs = 30

for epoch in range(start_epoch, epochs):

    model.train()
    train_loss = 0

    for pre, post, mask in tqdm(labeled_loader):
        pre = pre.to(device)
        post = post.to(device)
        mask = mask.to(device)

        optimizer.zero_grad()
        output = model(pre, post)
        loss = criterion(output, mask)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    model.eval()
    val_iou = 0

    with torch.no_grad():
        for pre, post, mask in val_loader:
            pre = pre.to(device)
            post = post.to(device)
            mask = mask.to(device)

            output = model(pre, post)
            val_iou += compute_iou(output, mask).item()

    val_iou /= len(val_loader)

    print(f"\nEpoch {epoch+1}")
    print(f"Train Loss: {train_loss/len(labeled_loader):.4f}")
    print(f"Val mIoU: {val_iou:.4f}")

    # ðŸ”¹ Save Last Model Every Epoch
    torch.save({
        "epoch": epoch + 1,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "best_iou": best_iou
    }, os.path.join(save_dir, "last_model.pth"))

    # ðŸ”¹ Save Best Model
    if val_iou > best_iou:
        best_iou = val_iou
        torch.save(model.state_dict(),
                   os.path.join(save_dir, "best_model.pth"))
        print("âœ… Saved Best Model")
