# **Imports**

In [1]:
!pip install -q segmentation-models-pytorch torchmetrics transformers

In [22]:
import json
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import DataLoader

import segmentation_models_pytorch as smp
from torchmetrics import JaccardIndex, Accuracy, F1Score
from torchmetrics.segmentation import DiceScore
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torch.amp import GradScaler, autocast

from PIL import Image
import numpy as np

from IPython.display import FileLink

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


Device: cuda


# **Setup**

## **Dataset**

In [24]:
# Initial setup
DATA_ROOT = Path("/kaggle/input/data-clean")
IMG_DIR = DATA_ROOT / "images"
MSK_DIR = DATA_ROOT / "masks"

train_img_dir = IMG_DIR / "train"
val_img_dir   = IMG_DIR / "validation"
test_img_dir  = IMG_DIR / "test"

train_msk_dir = MSK_DIR / "train"
val_msk_dir   = MSK_DIR / "validation"
test_msk_dir  = MSK_DIR / "test"

CONFIG = {
    "experiment_name": "unet_resnet34_dacl10k_512",
    "model": "Unet",
    "encoder": "resnet34",
    "encoder_weights": "imagenet",
    "num_classes": 14,
    "image_size": (512, 512),
    "batch_size": 8,
    "epochs": 40,
    "learning_rate": 1e-4,
    "loss": "CrossEntropyLoss",
    "optimizer": "Adam",
    "scheduler": "ReduceLROnPlateau",
    "metrics": [
        "mean_iou_per_class",
        "dice_macro",
        "f1_macro",
        "global_pixel_accuracy",
    ],
}

NUM_CLASSES = CONFIG["num_classes"]
BATCH_SIZE  = CONFIG["batch_size"]
EPOCHS      = CONFIG["epochs"]
LR          = CONFIG["learning_rate"]
IMAGE_SIZE  = CONFIG["image_size"]

In [25]:
# Dataset setup
class Dacl10kDataset(Dataset):
    def __init__(self, img_dir, msk_dir, image_size=(512, 512)):
        self.img_dir = Path(img_dir)
        self.msk_dir = Path(msk_dir)
        self.image_size = image_size

        self.img_paths = sorted([p for p in self.img_dir.iterdir()])

        # Transformations for training images
        self.img_transform = transforms.Compose([
            transforms.Resize(self.image_size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize(  # Normalize each channel with ImageNet normalization values
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

        self.mask_resize = transforms.Resize(
            self.image_size,
            interpolation=transforms.InterpolationMode.NEAREST, # Change interpolation value to keep integers
        )

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        msk_path = self.msk_dir / img_path.name.replace("jpg", "png") # same filename, jpg -> png

        # Image
        img = Image.open(img_path).convert("RGB")
        img = self.img_transform(img)

        # Mask
        mask = Image.open(msk_path)
        mask = self.mask_resize(mask)
        mask = torch.from_numpy(np.array(mask, dtype=np.int64))  # [H, W] long with 0..13

        return img, mask

In [26]:
# Dataset definition
train_dataset = Dacl10kDataset(train_img_dir, train_msk_dir, IMAGE_SIZE)
val_dataset   = Dacl10kDataset(val_img_dir,   val_msk_dir,   IMAGE_SIZE)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

print("Train samples:", len(train_dataset), " | batches:", len(train_loader))
print("Val samples:  ", len(val_dataset),   " | batches:", len(val_loader))


Train samples: 5895  | batches: 737
Val samples:   1040  | batches: 130


## **Model, loss, optimizer, metrics**

In [27]:
# Model
model = smp.Unet(
    encoder_name=CONFIG["encoder"],
    encoder_weights=CONFIG["encoder_weights"],
    in_channels=3,
    classes=NUM_CLASSES,
).to(device)

# Loss and optimizer. Scheduler will decrease learning rate as  it hits a plateau
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=3,
    verbose=True,
)

# Metrics
# 1) Mean IoU per class (Jaccard)
miou_metric = JaccardIndex(
    task="multiclass",
    num_classes=NUM_CLASSES,
).to(device)

# 2) Dice score (macro over classes)
dice_metric = DiceScore(
    num_classes=NUM_CLASSES,
    average="macro",
).to(device)

# 3) F1 Score (macro over classes)
f1_metric = F1Score(
    task="multiclass",
    num_classes=NUM_CLASSES,
    average="macro",
).to(device)

# 4) Global Pixel Accuracy
acc_metric = Accuracy(
    task="multiclass",
    num_classes=NUM_CLASSES,
).to(device)

print("UNet params (M):", sum(p.numel() for p in model.parameters()) / 1e6)

UNet params (M): 24.438254


In [28]:
scaler = GradScaler(enabled=(device.type == "cuda")) # Uses AMP to speed up training

CHECKPOINT_EVERY = 2  # epochs

def train_one_epoch(model, loader, optimizer, criterion, epoch):
    model.train() # set model to training mode
    running_loss = 0.0 # start total loss at 0.0

    for step, (images, masks) in enumerate(loader, start=1):
        images = images.to(device, non_blocking=True)
        masks  = masks.to(device, non_blocking=True)

        optimizer.zero_grad() # zero the gradients

         # Uses AMP to speed up training
        with autocast(device_type="cuda", enabled=(device.type == "cuda")):
            outputs = model(images)           
            loss = criterion(outputs, masks)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item() # add losses in each step

        if step % 50 == 0 or step == 1:
            print(f"[Epoch {epoch}] Step {step}/{len(loader)} - Loss: {loss.item():.4f}")

    return running_loss / len(loader) # return mean loss across an epoch


@torch.no_grad()
def evaluate(model, loader, criterion):
    model.eval()
    val_loss = 0.0

    # Reset metrics each evaluation
    miou_metric.reset()
    dice_metric.reset()
    f1_metric.reset()
    acc_metric.reset()

    for images, masks in loader:
        images = images.to(device, non_blocking=True)
        masks  = masks.to(device, non_blocking=True)

        # Uses AMP to speed up training
        with autocast(device_type="cuda", enabled=(device.type == "cuda")):
            outputs = model(images)
            loss = criterion(outputs, masks)

        val_loss += loss.item()

        preds = torch.argmax(outputs, dim=1)  # Get class index with highest probability

        # Update metrics 
        miou_metric.update(preds, masks)
        dice_metric.update(preds, masks)
        f1_metric.update(preds, masks)
        acc_metric.update(preds, masks)

    val_loss /= len(loader) # compute mean loss

    miou = miou_metric.compute().item()   # mean IoU per class
    dice = dice_metric.compute().item()   # macro Dice
    mf1  = f1_metric.compute().item()     # macro F1
    acc  = acc_metric.compute().item()    # global pixel accuracy

    return val_loss, miou, dice, mf1, acc


# **Train**

In [31]:
history = [] 
best_miou = 0.0

for epoch in range(1, EPOCHS + 1):
    print(f"\n===== Epoch {epoch}/{EPOCHS} =====")

    # Training epoch
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, epoch)

    # Metrics for train and validation sets
    train_loss_eval, train_miou, train_dice, train_f1, train_acc = evaluate(model, train_loader, criterion)
    val_loss, val_miou, val_dice, val_f1, val_acc = evaluate(model, val_loader, criterion)

    # step scheduler on val_loss
    scheduler.step(val_loss)

    current_lr = optimizer.param_groups[0]["lr"]
    
    print(
        f"Epoch {epoch:03d} | "
        f"TrainLoss(step): {train_loss:.4f} | "
        f"TrainLoss(eval): {train_loss_eval:.4f} | "
        f"ValLoss: {val_loss:.4f} | "
        f"Train mIoU: {train_miou:.4f} | "
        f"Val mIoU: {val_miou:.4f} | "
        f"Val Dice: {val_dice:.4f} | "
        f"Val F1: {val_f1:.4f} | "
        f"Val Acc: {val_acc:.4f} | "
        f"LR: {current_lr:.6f}"
    )

    # Store metrics
    history.append({
        "epoch": epoch,
        # training loss from the actual training loop
        "train_loss_step": float(train_loss),
        # training loss recomputed in eval mode (no dropout, BN in eval)
        "train_loss_eval": float(train_loss_eval),
        "train_miou": float(train_miou),
        "train_dice": float(train_dice),
        "train_f1_macro": float(train_f1),
        "train_global_pixel_accuracy": float(train_acc),
        "val_loss": float(val_loss),
        "val_miou": float(val_miou),
        "val_dice": float(val_dice),
        "val_f1_macro": float(val_f1),
        "val_global_pixel_accuracy": float(val_acc),
        "lr": float(current_lr),
    })

    # save best model by mIoU
    if val_miou > best_miou:
        best_miou = val_miou
        torch.save(model.state_dict(), "unet_best_miou.pth")
        print("  -> New best mIoU; weights saved to unet_best_miou.pth")

    # periodic full checkpoint save
    if epoch % CHECKPOINT_EVERY == 0:
        ckpt_path = f"unet_checkpoint_epoch_{epoch}.pth"
        torch.save({
            "config": CONFIG,
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "best_miou": best_miou,
            "history": history,
        }, ckpt_path)
        print(f"  -> Checkpoint saved to {ckpt_path}")

# final "last" checkpoint
torch.save({
    "config": CONFIG,
    "epoch": EPOCHS,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "scheduler_state_dict": scheduler.state_dict(),
    "best_miou": best_miou,
    "history": history,
}, "unet_last.pth")

print("\nTraining complete. Best mIoU:", best_miou)

# Store training logs in JSON format

output_path = Path("/kaggle/working/unet_results.json")

results = {
    "config": CONFIG,
    "history": history,
}

with open(output_path, "w") as f:
    json.dump(results, f, indent=2)

print("Saved metrics to:", output_path)

# Download link for results
FileLink('/kaggle/working/unet_results.json')



===== Epoch 1/40 =====
[Epoch 1] Step 1/737 - Loss: 0.8133
[Epoch 1] Step 50/737 - Loss: 0.8753
[Epoch 1] Step 100/737 - Loss: 0.4947
[Epoch 1] Step 150/737 - Loss: 1.0082
[Epoch 1] Step 200/737 - Loss: 0.8362
[Epoch 1] Step 250/737 - Loss: 0.8627
[Epoch 1] Step 300/737 - Loss: 0.6260
[Epoch 1] Step 350/737 - Loss: 0.4254
[Epoch 1] Step 400/737 - Loss: 0.4679
[Epoch 1] Step 450/737 - Loss: 0.7867
[Epoch 1] Step 500/737 - Loss: 0.9071
[Epoch 1] Step 550/737 - Loss: 0.7628
[Epoch 1] Step 600/737 - Loss: 0.8683
[Epoch 1] Step 650/737 - Loss: 0.7622
[Epoch 1] Step 700/737 - Loss: 0.6173
Epoch 001 | TrainLoss(step): 0.8307 | TrainLoss(eval): 0.7277 | ValLoss: 0.8137 | Train mIoU: 0.1415 | Val mIoU: 0.1265 | Val Dice: 1.3896 | Val F1: 0.1755 | Val Acc: 0.7672 | LR: 0.000100
  -> New best mIoU; weights saved to unet_best_miou.pth

===== Epoch 2/40 =====
[Epoch 2] Step 1/737 - Loss: 1.2544
[Epoch 2] Step 50/737 - Loss: 1.0275
[Epoch 2] Step 100/737 - Loss: 0.7002
[Epoch 2] Step 150/737 - Loss

In [32]:
!zip -r /kaggle/working/unet_outputs.zip /kaggle/working

  adding: kaggle/working/ (stored 0%)
  adding: kaggle/working/unet_checkpoint_epoch_18.pth (deflated 8%)
  adding: kaggle/working/unet_checkpoint_epoch_14.pth (deflated 8%)
  adding: kaggle/working/unet_checkpoint_epoch_34.pth (deflated 8%)
  adding: kaggle/working/unet_checkpoint_epoch_40.pth (deflated 8%)
  adding: kaggle/working/unet_checkpoint_epoch_38.pth (deflated 8%)
  adding: kaggle/working/unet_best_miou.pth (deflated 7%)
  adding: kaggle/working/unet_checkpoint_epoch_28.pth (deflated 8%)
  adding: kaggle/working/unet_checkpoint_epoch_26.pth (deflated 8%)
  adding: kaggle/working/unet_checkpoint_epoch_10.pth (deflated 8%)
  adding: kaggle/working/unet_last.pth (deflated 8%)
  adding: kaggle/working/unet_checkpoint_epoch_24.pth (deflated 8%)
  adding: kaggle/working/.virtual_documents/ (stored 0%)
  adding: kaggle/working/unet_checkpoint_epoch_6.pth (deflated 8%)
  adding: kaggle/working/unet_checkpoint_epoch_16.pth (deflated 8%)
  adding: kaggle/working/unet_checkpoint_epoch_