# Oil Spill Detection – Model Training & Evaluation

This notebook covers:
- Model architecture design
- Custom attention modules
- Training loop
- Validation & testing
- mIoU and per-class IoU evaluation

The model is trained for semantic segmentation of oil spill phenomena in aerial imagery.


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from tqdm import tqdm

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Using device:", DEVICE)


Using device: cuda


## Attention Modules (Implemented from Scratch)

We implement Channel Attention and Spatial Attention modules
to enhance feature representation, especially for:
- Thin oil sheens
- Small objects (ships, platforms)


In [3]:
class ChannelAttention(nn.Module):
    """Channel-wise gating driven by globally pooled descriptors.

    The input is reduced to one value per channel twice, once with global
    average pooling and once with global max pooling. Both descriptors pass
    through the same two-layer 1x1-convolution bottleneck, the results are
    summed, squashed with a sigmoid, and used to rescale the input channels.

    Args:
        in_channels: Number of channels of the incoming feature map.
        reduction: Divisor applied to ``in_channels`` to size the bottleneck.
            The bottleneck is never narrower than one channel.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Integer division yields 0 when in_channels < reduction, which would
        # create a zero-channel bottleneck; clamp to at least one channel.
        hidden_channels = max(1, in_channels // reduction)

        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, in_channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled per channel by the attention gate."""
        avg = self.fc(self.avg_pool(x))
        max_ = self.fc(self.max_pool(x))
        # Broadcasting the [B, C, 1, 1] gate rescales every spatial location.
        return x * self.sigmoid(avg + max_)


In [4]:
class SpatialAttention(nn.Module):
    """Per-pixel gating computed from channel-pooled statistics.

    The channel axis is collapsed into a mean plane and a max plane. The two
    planes are stacked, passed through a single bias-free 7x7 convolution and
    a sigmoid, and the resulting one-channel map multiplies the input.
    """

    def __init__(self):
        super().__init__()
        # Two pooled planes in, one attention plane out; padding=3 keeps H x W.
        self.conv = nn.Conv2d(
            in_channels=2, out_channels=1, kernel_size=7, padding=3, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled at every spatial location by the attention map."""
        mean_plane = x.mean(dim=1, keepdim=True)
        peak_plane = x.max(dim=1, keepdim=True).values
        pooled = torch.cat((mean_plane, peak_plane), dim=1)
        gate = self.sigmoid(self.conv(pooled))
        return x * gate


## Encoder: Modified ResNet-50 Backbone

- Pretrained on ImageNet
- Used as a feature extractor
- Outputs multi-scale feature maps


In [5]:
class ResNetEncoder(nn.Module):
    """ResNet-50 backbone exposed as a stem plus four residual stages.

    The torchvision ResNet-50 is loaded with its ``IMAGENET1K_V1`` weights and
    its layers are re-grouped so the forward pass can return the output of
    every residual stage.
    """

    def __init__(self):
        super().__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

        # Stem: initial convolution, batch norm, ReLU and max pooling.
        stem_layers = [resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool]
        self.stage0 = nn.Sequential(*stem_layers)

        # Residual stages; the trailing number is the stage's output width.
        self.stage1 = resnet.layer1   # 256
        self.stage2 = resnet.layer2   # 512
        self.stage3 = resnet.layer3   # 1024
        self.stage4 = resnet.layer4   # 2048

    def forward(self, x):
        """Return the outputs of stages 1-4 as a 4-tuple, shallowest first."""
        stage_outputs = []
        hidden = self.stage0(x)
        for stage in (self.stage1, self.stage2, self.stage3, self.stage4):
            hidden = stage(hidden)
            stage_outputs.append(hidden)
        return tuple(stage_outputs)


## Decoder Blocks with Attention


In [6]:
class DecoderBlock(nn.Module):
    """Conv-BN-ReLU unit refined by channel and then spatial attention.

    Args:
        in_ch: Channels of the incoming feature map.
        out_ch: Channels produced by the 3x3 convolution and kept through
            both attention modules.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        fusion_layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*fusion_layers)
        self.ca = ChannelAttention(out_ch)
        self.sa = SpatialAttention()

    def forward(self, x):
        """Apply the convolution unit, then channel and spatial attention."""
        return self.sa(self.ca(self.conv(x)))


## Full Model Architecture – OilSpillNet


In [7]:
class OilSpillNet(nn.Module):
    """Encoder-decoder segmentation network with attention-refined decoding.

    The four feature maps returned by ``ResNetEncoder`` are merged from the
    deepest to the shallowest. Each decoder output is bilinearly resized to
    the spatial size of the next shallower encoder map and concatenated with
    it. A 1x1 convolution turns the last decoder output into class logits.

    Args:
        num_classes: Number of output channels (one logit map per class).
    """

    def __init__(self, num_classes):
        super().__init__()
        self.encoder = ResNetEncoder()

        # Input widths are encoder-stage channels plus the previous decoder's
        # output channels (skip connection by concatenation).
        self.dec4 = DecoderBlock(2048, 512)
        self.dec3 = DecoderBlock(1024 + 512, 256)
        self.dec2 = DecoderBlock(512 + 256, 128)
        self.dec1 = DecoderBlock(256 + 128, 64)

        self.final = nn.Conv2d(64, num_classes, kernel_size=1)

    @staticmethod
    def _resize_like(tensor, reference):
        """Bilinearly resize ``tensor`` to the spatial size of ``reference``."""
        return F.interpolate(
            tensor, size=reference.shape[2:], mode="bilinear", align_corners=False
        )

    def forward(self, x):
        """Return per-class logits with the same spatial size as ``x``."""
        x1, x2, x3, x4 = self.encoder(x)

        d4 = self._resize_like(self.dec4(x4), x3)
        d3 = self._resize_like(self.dec3(torch.cat([x3, d4], dim=1)), x2)
        d2 = self._resize_like(self.dec2(torch.cat([x2, d3], dim=1)), x1)
        d1 = self.dec1(torch.cat([x1, d2], dim=1))

        # Classify at decoder resolution, then upsample only the logits. A 1x1
        # convolution is a per-pixel affine map and bilinear weights sum to
        # one, so the two operations commute: same result (up to rounding) as
        # upsampling the 64-channel map first, with far less memory/compute.
        logits = self.final(d1)
        return self._resize_like(logits, x)


## Loss Functions
- Cross Entropy (class-weighted)
- Dice Loss (overlap-focused)


In [8]:
class DiceLoss(nn.Module):
    """Soft multi-class Dice loss computed from raw logits.

    Args:
        smooth: Constant added to numerator and denominator of the Dice ratio
            so that a class absent from both prediction and target scores 1
            instead of dividing by zero.
    """

    def __init__(self, smooth=1):
        super().__init__()
        self.smooth = smooth

    def forward(self, preds, targets):
        """Return ``1 - mean Dice`` over every (sample, class) pair.

        Args:
            preds: Logits of shape [B, C, H, W].
            targets: Integer class indices of shape [B, H, W], with values in
                ``[0, C)``.
        """
        # The class count comes from the logits themselves, so the loss does
        # not depend on a module-level NUM_CLASSES being defined and in sync.
        num_classes = preds.shape[1]

        preds = F.softmax(preds, dim=1)
        # one_hot puts the class axis last; move it next to the batch axis.
        targets_oh = F.one_hot(targets, num_classes).permute(0, 3, 1, 2)

        intersection = (preds * targets_oh).sum(dim=(2, 3))
        union = preds.sum(dim=(2, 3)) + targets_oh.sum(dim=(2, 3))
        dice = (2 * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()


In [9]:
# Loss terms used by the training loop.
# `class_weights` is not defined in this part of the notebook; it is moved to
# DEVICE here so it lives on the same device as the logits.
# NOTE(review): presumably one weight per class, in class-index order - confirm.
ce_loss   = nn.CrossEntropyLoss(weight=class_weights.to(DEVICE))
# Soft Dice term with the default smoothing constant.
dice_loss = DiceLoss()


## Metrics: mIoU


In [10]:
def compute_miou(pred, target, num_classes=None):
    """Return the mean IoU over the classes present in ``pred`` or ``target``.

    Args:
        pred: Tensor of predicted class indices.
        target: Tensor of ground-truth class indices, same shape as ``pred``.
        num_classes: Number of classes to score. When omitted, the
            module-level ``NUM_CLASSES`` is used, as before.

    Returns:
        Scalar tensor with the IoU averaged over every class whose union is
        non-empty, or 0.0 when no class occurs at all (the same convention as
        ``compute_miou_batch``).
    """
    if num_classes is None:
        num_classes = NUM_CLASSES

    ious = []
    for cls in range(num_classes):
        pred_c = pred == cls
        target_c = target == cls
        inter = (pred_c & target_c).sum()
        union = pred_c.sum() + target_c.sum() - inter
        # Classes absent from both tensors are skipped rather than scored 0.
        if union > 0:
            ious.append(inter.float() / union.float())

    # torch.stack raises on an empty list, so handle "no class present" here.
    if not ious:
        return torch.tensor(0.0, device=pred.device)
    return torch.stack(ious).mean()


In [11]:
def compute_miou_batch(outputs, masks, num_classes):
    """
    Mean IoU of one batch, averaged per image first and per class second.

    outputs: Tensor [B, C, H, W] of class scores
    masks:   Tensor [B, H, W] of ground-truth class indices
    return:  scalar mIoU for the batch (0.0 if no class could be scored)

    For each class, IoU is computed separately for every image in which the
    class appears in the prediction or the mask; those values are averaged,
    and the per-class averages are averaged again.
    """
    predicted = outputs.argmax(dim=1)  # [B, H, W]

    class_scores = []
    for class_id in range(num_classes):
        pred_hit = predicted == class_id
        mask_hit = masks == class_id

        # Per-image pixel counts, shape [B].
        overlap = (pred_hit & mask_hit).sum(dim=(1, 2)).float()
        combined = (
            pred_hit.sum(dim=(1, 2)).float()
            + mask_hit.sum(dim=(1, 2)).float()
            - overlap
        )

        # Only images where the class occurs contribute to its average.
        present = combined > 0
        if present.any():
            class_scores.append((overlap[present] / combined[present]).mean())

    if not class_scores:
        return torch.tensor(0.0, device=outputs.device)

    return torch.stack(class_scores).mean()


## Training Configuration


In [12]:
# ===============================
# FINAL TRAINING + EARLY STOPPING
# ===============================
# Trains OilSpillNet with AdamW on the summed cross-entropy + Dice loss,
# validates every epoch with the batch-averaged mIoU, checkpoints the best
# epoch, and stops once `patience` consecutive epochs bring no new best.
# Relies on NUM_CLASSES, train_loader and val_loader defined outside this cell.

model = OilSpillNet(NUM_CLASSES).to(DEVICE)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=1e-4
)

# LR scheduler driven by validation mIoU: mode="max" because a higher mIoU is
# better; the learning rate is multiplied by 0.5 when the metric plateaus
# (scheduler patience of 2 epochs).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",
    factor=0.5,
    patience=2
)

num_epochs = 50
# Best validation mIoU seen so far; a checkpoint is written only when an epoch
# strictly exceeds it, so an epoch scoring exactly 0.0 never saves a model.
best_miou = 0.0
# Early-stopping patience (epochs); separate from the scheduler's patience.
patience = 5
epochs_no_improve = 0

for epoch in range(num_epochs):
    # -------- TRAIN --------
    model.train()
    train_loss = 0.0

    for imgs, masks in train_loader:
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(imgs)

        # Unweighted sum of the two loss terms.
        loss = ce_loss(outputs, masks) + dice_loss(outputs, masks)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Mean loss per batch (not per sample).
    train_loss /= len(train_loader)

    # -------- VALIDATION --------
    model.eval()
    val_miou = 0.0

    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
            outputs = model(imgs)
            val_miou += compute_miou_batch(outputs, masks, NUM_CLASSES).item()

    # Mean of the per-batch mIoU values; every batch carries equal weight.
    val_miou /= len(val_loader)

    # Step scheduler on the validation metric, then read back the LR for logging.
    scheduler.step(val_miou)
    current_lr = optimizer.param_groups[0]["lr"]

    print(
        f"Epoch [{epoch+1}/{num_epochs}] | "
        f"Train Loss: {train_loss:.4f} | "
        f"Val mIoU: {val_miou:.4f} | "
        f"LR: {current_lr:.2e}"
    )

    # -------- EARLY STOPPING --------
    if val_miou > best_miou:
        best_miou = val_miou
        torch.save(model.state_dict(), "/kaggle/working/best_model.pth")
        epochs_no_improve = 0
        print(" Best model saved")
    else:
        epochs_no_improve += 1
        print(f" No improvement ({epochs_no_improve}/{patience})")

    if epochs_no_improve >= patience:
        print(" Early stopping triggered")
        break

# Save last model (backup)
# NOTE(review): at this point `model` holds the weights of the last epoch that
# ran, not those of the best epoch; the best weights exist only in
# best_model.pth and are not reloaded in this cell.
torch.save(model.state_dict(), "/kaggle/working/last_model.pth")

print("\n Training finished")
print(f" Best mIoU: {best_miou:.4f}")


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 213MB/s]


Epoch [1/50] | Train Loss: 2.0413 | Val mIoU: 0.2738 | LR: 3.00e-04
 Best model saved
Epoch [2/50] | Train Loss: 1.7091 | Val mIoU: 0.3467 | LR: 3.00e-04
 Best model saved
Epoch [3/50] | Train Loss: 1.6076 | Val mIoU: 0.3356 | LR: 3.00e-04
 No improvement (1/5)
Epoch [4/50] | Train Loss: 1.5076 | Val mIoU: 0.3271 | LR: 3.00e-04
 No improvement (2/5)
Epoch [5/50] | Train Loss: 1.4589 | Val mIoU: 0.3559 | LR: 3.00e-04
 Best model saved
Epoch [6/50] | Train Loss: 1.3834 | Val mIoU: 0.3546 | LR: 3.00e-04
 No improvement (1/5)
Epoch [7/50] | Train Loss: 1.3611 | Val mIoU: 0.3789 | LR: 3.00e-04
 Best model saved
Epoch [8/50] | Train Loss: 1.3460 | Val mIoU: 0.3987 | LR: 3.00e-04
 Best model saved
Epoch [9/50] | Train Loss: 1.3055 | Val mIoU: 0.3628 | LR: 3.00e-04
 No improvement (1/5)
Epoch [10/50] | Train Loss: 1.2798 | Val mIoU: 0.4129 | LR: 3.00e-04
 Best model saved
Epoch [11/50] | Train Loss: 1.2316 | Val mIoU: 0.4040 | LR: 3.00e-04
 No improvement (1/5)
Epoch [12/50] | Train Loss: 1.24

## Model Saving


In [13]:
# Write the current in-memory weights to disk and report the best score.
# NOTE(review): the file name is spelled "finall_model.pth"; left unchanged
# because other code may load it by this exact path - confirm before renaming.
# NOTE(review): `model` is whatever state is in memory when this cell runs;
# no checkpoint is loaded here, so this is not necessarily the best epoch.
torch.save(model.state_dict(), "/kaggle/working/finall_model.pth")
print("Training finished")
print("Best mIoU:", best_miou)


Training finished
Best mIoU: 0.5506737354923698


## Validation & Test Evaluation


In [14]:
def evaluate_miou(model, dataloader, num_classes, desc="Validating"):
    """Return the batch-averaged mIoU of ``model`` over ``dataloader``.

    The model is switched to eval mode and left in that mode. Each batch is
    scored with ``compute_miou_batch`` and the per-batch scores are averaged
    with equal weight per batch.

    Args:
        model: Network mapping an image batch to [B, C, H, W] class scores.
        dataloader: Iterable of ``(imgs, masks)`` batches that supports ``len``.
        num_classes: Number of classes passed on to ``compute_miou_batch``.
        desc: Label shown on the progress bar; defaults to the previous
            hard-coded text so existing calls print the same thing.

    Returns:
        Mean of the per-batch mIoU values as a float, or 0.0 when the
        dataloader yields no batches.
    """
    model.eval()

    # Guard the final division: an empty loader would raise ZeroDivisionError.
    num_batches = len(dataloader)
    if num_batches == 0:
        return 0.0

    total_miou = 0.0

    with torch.no_grad():
        for imgs, masks in tqdm(dataloader, desc=desc):
            imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
            outputs = model(imgs)

            total_miou += compute_miou_batch(outputs, masks, num_classes).item()

    return total_miou / num_batches

# Score the in-memory model on the validation split.
# NOTE(review): no checkpoint is loaded before this call, so the weights scored
# are those currently held by `model`, which may differ from best_model.pth.
val_miou = evaluate_miou(model, val_loader, NUM_CLASSES)
print(f"Validation mIoU: {val_miou:.4f}")


Validating: 100%|██████████| 85/85 [00:13<00:00,  6.30it/s]

Validation mIoU: 0.5387





In [15]:
def per_class_iou(model, dataloader, num_classes):
    """Return the dataset-level IoU of every class as a NumPy array.

    Unlike ``compute_miou_batch``, intersections and unions are summed over
    the whole dataloader before dividing, so every pixel has equal weight.

    Args:
        model: Network mapping an image batch to [B, C, H, W] class scores.
        dataloader: Iterable of ``(imgs, masks)`` batches.
        num_classes: Number of classes to score.

    Returns:
        float32 array of length ``num_classes``; a class that never occurs in
        predictions or masks gets an IoU of 0.
    """
    model.eval()
    # Integer counters: float32 cannot represent every pixel count above
    # 2**24, so accumulating in it over a whole dataset drifts.
    inter = torch.zeros(num_classes, dtype=torch.long, device=DEVICE)
    union = torch.zeros(num_classes, dtype=torch.long, device=DEVICE)

    with torch.no_grad():
        for imgs, masks in tqdm(dataloader, desc="Per-class IoU"):
            imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
            preds = torch.argmax(model(imgs), dim=1)

            for c in range(num_classes):
                pred_c = preds == c
                mask_c = masks == c
                inter[c] += (pred_c & mask_c).sum()
                union[c] += (pred_c | mask_c).sum()

    # Divide in float64; the epsilon only guards classes with an empty union.
    iou = inter.double() / (union.double() + 1e-6)
    # Cast back so callers keep receiving a float32 array.
    return iou.float().cpu().numpy()

# Dataset-level IoU per class on the validation split, one line per class index.
iou_per_class = per_class_iou(model, val_loader, NUM_CLASSES)

for idx, val in enumerate(iou_per_class):
    print(f"{idx}: IoU = {val:.4f}")


Per-class IoU: 100%|██████████| 85/85 [00:13<00:00,  6.46it/s]

0: IoU = 0.7293
1: IoU = 0.7502
2: IoU = 0.7203
3: IoU = 0.6236
4: IoU = 0.6793
5: IoU = 0.6222





In [16]:
# Build the held-out test split with the validation-time transforms.
# OilSpillDataset, ROOT, val_transforms and DataLoader are defined outside
# this part of the notebook.
test_ds = OilSpillDataset(ROOT, "test", val_transforms)

# shuffle=False keeps the sample order fixed between evaluation runs.
test_loader = DataLoader(
    test_ds,
    batch_size=8,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

print("Test samples:", len(test_ds))


[TEST] 343 samples loaded
Test samples: 343


In [17]:
# Score the in-memory model on the test split (batch-averaged mIoU).
# The progress bar still reads "Validating" because evaluate_miou uses that
# label for every dataloader.
test_miou = evaluate_miou(model, test_loader, NUM_CLASSES)
print(f"Test mIoU: {test_miou:.4f}")


Validating: 100%|██████████| 43/43 [00:07<00:00,  6.02it/s]

Test mIoU: 0.5611



