# Baseline UNet (MobileNetV2 encoder)
Minimal training script using segmentation_models_pytorch to predict bolus masks from a single frame (frame1).


In [3]:
from pathlib import Path
import random
import re

import cv2
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
import segmentation_models_pytorch as smp

DATA_ROOT = Path(".")
IMAGES_DIR = DATA_ROOT / "images"
MASKS_DIR = DATA_ROOT / "masks"
SEED = 42
BATCH_SIZE = 8
NUM_EPOCHS = 40
LR = 1e-3
NUM_WORKERS = 0

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device", device)


device cuda


In [10]:
FILENAME_PATTERN = re.compile(r"^(?P<prefix>[A-Za-z0-9]{6})(?P<frame>\d+)_")

def parse_prefix(path: Path):
    m = FILENAME_PATTERN.match(path.stem)
    if not m:
        raise ValueError(f"Unexpected filename: {path.name}")
    return m.group("prefix")

def load_gray(path: Path) -> np.ndarray:
    img = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise FileNotFoundError(path)
    return img

def collect_frames(images_dir: Path, masks_dir: Path):
    rows = []
    for img_path in images_dir.glob("*.png"):
        try:
            prefix = parse_prefix(img_path)
        except ValueError:
            continue
        mask_path = masks_dir / img_path.name
        if not mask_path.exists():
            continue
        rows.append({
            "frame": img_path,
            "mask": mask_path,
            "sequence": prefix,
        })
    return rows

rows = collect_frames(IMAGES_DIR, MASKS_DIR)
print("total frames", len(rows))
if not rows:
    raise RuntimeError("No data found")

seqs = sorted({row["sequence"] for row in rows})
train_seq, temp_seq = train_test_split(seqs, test_size=0.30, random_state=SEED, shuffle=True)
val_seq, test_seq = train_test_split(temp_seq, test_size=0.50, random_state=SEED, shuffle=True)

def split_rows(all_rows, allowed):
    return [r for r in all_rows if r["sequence"] in allowed]

splits = {
    "train": split_rows(rows, set(train_seq)),
    "val": split_rows(rows, set(val_seq)),
    "test": split_rows(rows, set(test_seq)),
}
for name, subset in splits.items():
    print(name, len(subset), "samples")


total frames 6424
train 4821 samples
val 856 samples
test 747 samples


In [11]:
class FrameDataset(Dataset):
    def __init__(self, rows):
        self.rows = list(rows)

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        row = self.rows[idx]
        frame = load_gray(row["frame"]).astype(np.float32) / 255.0
        mask = (load_gray(row["mask"]) > 0).astype(np.float32)

        x = frame[None, ...]
        y = mask[None, ...]
        return torch.from_numpy(x), torch.from_numpy(y)

loader_kwargs = dict(num_workers=NUM_WORKERS, pin_memory=device.type == "cuda")
train_loader = DataLoader(FrameDataset(splits["train"]), batch_size=BATCH_SIZE, shuffle=True, **loader_kwargs)
val_loader = DataLoader(FrameDataset(splits["val"]), batch_size=BATCH_SIZE, shuffle=False, **loader_kwargs)
test_loader = DataLoader(FrameDataset(splits["test"]), batch_size=BATCH_SIZE, shuffle=False, **loader_kwargs)

xb, yb = next(iter(train_loader))
print("batch", xb.shape, yb.shape)


batch torch.Size([8, 1, 512, 512]) torch.Size([8, 1, 512, 512])


In [14]:
model = smp.Unet(
    encoder_name="mobilenet_v2",
    encoder_weights=None,
    in_channels=1,
    classes=1,
)
model.to(device)
print("params (M)", sum(p.numel() for p in model.parameters())/1e6)

bce = nn.BCEWithLogitsLoss()

def dice_loss(logits, targets, eps=1e-6):
    probs = torch.sigmoid(logits)
    inter = (probs * targets).sum(dim=(1,2,3))
    denom = probs.sum(dim=(1,2,3)) + targets.sum(dim=(1,2,3))
    dice = (2*inter + eps)/(denom + eps)
    return 1 - dice.mean()

optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=3)

def step(loader, train, desc):
    model.train(mode=train)
    losses, dices, ious = [], [], []
    iterator = tqdm(loader, desc=desc, leave=False)
    for x, y in iterator:
        x = x.to(device)
        y = y.to(device)
        with torch.set_grad_enabled(train):
            logits = model(x)
            loss = bce(logits, y) + dice_loss(logits, y)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                iterator.set_postfix(loss=f"{loss.item():.4f}")
        probs = torch.sigmoid(logits)
        pred = (probs > 0.5).float()
        inter = (pred * y).sum(dim=(1,2,3))
        union = pred.sum(dim=(1,2,3)) + y.sum(dim=(1,2,3)) - inter
        dice = (2*inter)/(pred.sum(dim=(1,2,3)) + y.sum(dim=(1,2,3))).clamp_min(1e-6)
        iou = inter/union.clamp_min(1e-6)
        losses.append(loss.item())
        dices.extend(dice.detach().cpu().numpy())
        ious.extend(iou.detach().cpu().numpy())
    return float(np.mean(losses)), float(np.mean(dices)), float(np.mean(ious))

best_dice = 0.0
best_state = None
for epoch in range(1, NUM_EPOCHS + 1):
    train_loss, train_dice, train_iou = step(train_loader, train=True, desc=f"Train {epoch}")
    val_loss, val_dice, val_iou = step(val_loader, train=False, desc="Val")
    scheduler.step(val_dice)
    if val_dice > best_dice:
        best_dice = val_dice
        best_state = {k: v.cpu() for k, v in model.state_dict().items()}
    print(f"Epoch {epoch:03d} | train {train_loss:.4f} dice {train_dice:.3f} iou {train_iou:.3f} | val dice {val_dice:.3f} iou {val_iou:.3f}")

if best_state:
    model.load_state_dict(best_state)
    model.to(device)
    torch.save(model.state_dict(), "baseline_unet_mobilenetv2.pth")
    print("saved best checkpoint")


params (M) 6.628369


Train 1:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 001 | train 0.5847 dice 0.494 iou 0.380 | val dice 0.639 iou 0.513


Train 2:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 002 | train 0.3915 dice 0.639 iou 0.508 | val dice 0.645 iou 0.517


Train 3:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 003 | train 0.3689 dice 0.659 iou 0.530 | val dice 0.631 iou 0.502


Train 4:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 004 | train 0.3466 dice 0.680 iou 0.551 | val dice 0.659 iou 0.528


Train 5:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 005 | train 0.3373 dice 0.688 iou 0.560 | val dice 0.654 iou 0.528


Train 6:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 006 | train 0.3235 dice 0.701 iou 0.574 | val dice 0.685 iou 0.561


Train 7:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 007 | train 0.3233 dice 0.701 iou 0.575 | val dice 0.676 iou 0.547


Train 8:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 008 | train 0.3145 dice 0.709 iou 0.583 | val dice 0.675 iou 0.555


Train 9:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 009 | train 0.3121 dice 0.711 iou 0.585 | val dice 0.681 iou 0.555


Train 10:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 010 | train 0.2974 dice 0.725 iou 0.601 | val dice 0.694 iou 0.570


Train 11:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 011 | train 0.2949 dice 0.727 iou 0.603 | val dice 0.664 iou 0.539


Train 12:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 012 | train 0.3037 dice 0.719 iou 0.594 | val dice 0.662 iou 0.536


Train 13:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 013 | train 0.2892 dice 0.732 iou 0.609 | val dice 0.702 iou 0.581


Train 14:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 014 | train 0.2879 dice 0.734 iou 0.610 | val dice 0.660 iou 0.543


Train 15:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 015 | train 0.2826 dice 0.738 iou 0.615 | val dice 0.676 iou 0.554


Train 16:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 016 | train 0.2754 dice 0.745 iou 0.623 | val dice 0.684 iou 0.558


Train 17:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 017 | train 0.2767 dice 0.744 iou 0.622 | val dice 0.684 iou 0.565


Train 18:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 018 | train 0.2567 dice 0.762 iou 0.643 | val dice 0.713 iou 0.590


Train 19:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 019 | train 0.2508 dice 0.768 iou 0.649 | val dice 0.712 iou 0.592


Train 20:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 020 | train 0.2417 dice 0.776 iou 0.659 | val dice 0.714 iou 0.597


Train 21:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 021 | train 0.2365 dice 0.781 iou 0.665 | val dice 0.704 iou 0.584


Train 22:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 022 | train 0.2387 dice 0.779 iou 0.662 | val dice 0.702 iou 0.583


Train 23:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 023 | train 0.2343 dice 0.783 iou 0.667 | val dice 0.714 iou 0.596


Train 24:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 024 | train 0.2362 dice 0.781 iou 0.665 | val dice 0.697 iou 0.583


Train 25:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 025 | train 0.2180 dice 0.798 iou 0.686 | val dice 0.716 iou 0.600


Train 26:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 026 | train 0.2155 dice 0.801 iou 0.688 | val dice 0.707 iou 0.592


Train 27:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 027 | train 0.2097 dice 0.806 iou 0.695 | val dice 0.718 iou 0.601


Train 28:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 028 | train 0.2079 dice 0.808 iou 0.697 | val dice 0.708 iou 0.588


Train 29:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 029 | train 0.2080 dice 0.808 iou 0.697 | val dice 0.713 iou 0.596


Train 30:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 030 | train 0.2067 dice 0.809 iou 0.699 | val dice 0.717 iou 0.599


Train 31:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 031 | train 0.2041 dice 0.811 iou 0.702 | val dice 0.704 iou 0.590


Train 32:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 032 | train 0.1950 dice 0.820 iou 0.713 | val dice 0.717 iou 0.601


Train 33:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 033 | train 0.1921 dice 0.822 iou 0.716 | val dice 0.714 iou 0.599


Train 34:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 034 | train 0.1903 dice 0.824 iou 0.719 | val dice 0.717 iou 0.602


Train 35:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 035 | train 0.1877 dice 0.827 iou 0.722 | val dice 0.719 iou 0.604


Train 36:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 036 | train 0.1859 dice 0.828 iou 0.724 | val dice 0.723 iou 0.606


Train 37:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 037 | train 0.1846 dice 0.830 iou 0.726 | val dice 0.716 iou 0.602


Train 38:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 038 | train 0.1836 dice 0.830 iou 0.727 | val dice 0.708 iou 0.594


Train 39:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 039 | train 0.1820 dice 0.832 iou 0.729 | val dice 0.713 iou 0.599


Train 40:   0%|          | 0/603 [00:00<?, ?it/s]

Val:   0%|          | 0/107 [00:00<?, ?it/s]

Epoch 040 | train 0.1809 dice 0.833 iou 0.730 | val dice 0.712 iou 0.597
saved best checkpoint


In [6]:
# Load trained weights before running the Final test evaluation cell
checkpoint_path = "baseline_unet_mobilenetv2.pth"  # update to your file
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild model exactly as when training
model = smp.Unet(
    encoder_name="mobilenet_v2",
    encoder_weights=None,
    in_channels=1,
    classes=1,
).to(device)

checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint if isinstance(checkpoint, dict) else checkpoint["state_dict"])
model.eval()


Unet(
  (encoder): MobileNetV2Encoder(
    (features): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU6(inplace=True)
      )
      (1): InvertedResidual(
        (conv): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (2): InvertedResidual(
        (conv): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(16, 96, kernel_size=(1

In [13]:
bce = nn.BCEWithLogitsLoss()

def dice_loss(logits, targets, eps=1e-6):
    probs = torch.sigmoid(logits)
    inter = (probs * targets).sum(dim=(1,2,3))
    denom = probs.sum(dim=(1,2,3)) + targets.sum(dim=(1,2,3))
    dice = (2*inter + eps)/(denom + eps)
    return 1 - dice.mean()

optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=3)

def step(loader, train, desc):
    model.train(mode=train)
    losses, dices, ious = [], [], []
    iterator = tqdm(loader, desc=desc, leave=False)
    for x, y in iterator:
        x = x.to(device)
        y = y.to(device)
        with torch.set_grad_enabled(train):
            logits = model(x)
            loss = bce(logits, y) + dice_loss(logits, y)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                iterator.set_postfix(loss=f"{loss.item():.4f}")
        probs = torch.sigmoid(logits)
        pred = (probs > 0.5).float()
        inter = (pred * y).sum(dim=(1,2,3))
        union = pred.sum(dim=(1,2,3)) + y.sum(dim=(1,2,3)) - inter
        dice = (2*inter)/(pred.sum(dim=(1,2,3)) + y.sum(dim=(1,2,3))).clamp_min(1e-6)
        iou = inter/union.clamp_min(1e-6)
        losses.append(loss.item())
        dices.extend(dice.detach().cpu().numpy())
        ious.extend(iou.detach().cpu().numpy())
    return float(np.mean(losses)), float(np.mean(dices)), float(np.mean(ious))

In [14]:
# Final test evaluation
with torch.no_grad():
    test_loss, test_dice, test_iou = step(test_loader, train=False, desc="Test")
print(f"Test dice {test_dice:.3f} | Test IoU {test_iou:.3f}")


Test:   0%|          | 0/94 [00:00<?, ?it/s]

Test dice 0.669 | Test IoU 0.542
