# UNet++ on Frame-Difference Dataset

In [1]:
from pathlib import Path
import random
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import cv2
from tqdm.auto import tqdm
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

# --- Paths -------------------------------------------------------------------
DATA_ROOT = Path('.')
DIFF_DIR, MASK_DIR = DATA_ROOT / 'd_images', DATA_ROOT / 'd_masks'

# --- Reproducibility ---------------------------------------------------------
# Python's `random` drives the augmentation flips; numpy and torch are seeded too.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

# --- Hyperparameters ---------------------------------------------------------
TRAIN_BATCH_SIZE = VAL_BATCH_SIZE = 4
NUM_WORKERS = 0
LEARNING_RATE = 1e-3
NUM_EPOCHS = 20
AUGMENT_TRAIN = True

# --- Device: prefer Apple MPS, then CUDA, otherwise CPU ----------------------
if torch.backends.mps.is_available():
    _device_name = 'mps'
elif torch.cuda.is_available():
    _device_name = 'cuda'
else:
    _device_name = 'cpu'
device = torch.device(_device_name)
print('Using device:', device)

Using device: mps


In [2]:
def load_pairs(diff_dir: Path, mask_dir: Path, seq_len: int = 6) -> pd.DataFrame:
    """Match every ``*.npy`` diff in ``diff_dir`` with its PNG mask in ``mask_dir``.

    For a diff file ``<base>_diff.npy`` the expected mask is ``<base>_diff.png``
    (all ``_diff`` occurrences are stripped from the stem, then one suffix is
    re-appended). Diffs without a mask on disk are skipped.

    Args:
        diff_dir: directory scanned (non-recursively) for ``*.npy`` files.
        mask_dir: directory holding the mask images.
        seq_len: number of leading filename characters used as the sequence id
            (defaults to 6, the value previously hard-coded).

    Returns:
        DataFrame with columns ``diff`` (Path), ``mask`` (Path) and ``sequence``
        (str), one row per matched pair in sorted diff-filename order. The
        columns are present even when no pairs are found.
    """
    pairs = []
    for diff_path in sorted(diff_dir.glob('*.npy')):
        base = diff_path.stem.replace('_diff', '')
        mask_path = mask_dir / f"{base}_diff.png"
        if mask_path.exists():
            pairs.append(dict(diff=diff_path, mask=mask_path, sequence=diff_path.name[:seq_len]))
    # Explicit columns so an empty result still supports df['sequence'] etc.
    return pd.DataFrame(pairs, columns=['diff', 'mask', 'sequence'])

pairs_df = load_pairs(DIFF_DIR, MASK_DIR)
print('Total pairs:', len(pairs_df))
if not len(pairs_df):
    raise RuntimeError('No diff/mask pairs found. Run build_diff_dataset.py first.')

# Split by sequence rather than by frame, so each sequence lands in exactly one
# split: 70% train, and the remaining 30% halved into val and test.
seqs = pairs_df['sequence'].unique()
train_seq, temp_seq = train_test_split(seqs, test_size=0.30, random_state=SEED, shuffle=True)
val_seq, test_seq = train_test_split(temp_seq, test_size=0.50, random_state=SEED, shuffle=True)

sequences_by_split = {'train': train_seq, 'val': val_seq, 'test': test_seq}
splits = {
    split_name: pairs_df[pairs_df['sequence'].isin(split_seqs)].reset_index(drop=True)
    for split_name, split_seqs in sequences_by_split.items()
}
for split_name, split_df in splits.items():
    n_sequences = split_df['sequence'].nunique()
    print(f"{split_name}: {len(split_df)} samples from {n_sequences} sequences")


Total pairs: 6337
train: 4761 samples from 60 sequences
val: 843 samples from 13 sequences
test: 733 samples from 14 sequences


In [3]:
class DiffDataset(Dataset):
    """Dataset of (frame-difference, mask) tensor pairs, one per DataFrame row.

    Args:
        df: DataFrame with a ``diff`` column (path to a ``.npy`` array) and a
            ``mask`` column (path to an image readable by OpenCV), e.g. the
            output of ``load_pairs``.
        augment: if True, independently flip each sample horizontally and
            vertically with probability 0.5 each, applying the same flip to
            diff and mask.

    Each item is ``(diff_tensor, mask_tensor)``. Both are float32 with shape
    ``(1, H, W)``. The diff is divided by 255 and the mask is binarised to {0, 1}.
    """

    def __init__(self, df, augment=False):
        self.df = df
        self.augment = augment

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        diff = np.load(row['diff']).astype(np.float32)
        mask = cv2.imread(str(row['mask']), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError(row['mask'])
        if diff.shape != mask.shape:
            # Fail here, naming the files, instead of with an opaque
            # collate/broadcast error several steps later.
            raise ValueError(
                f"Shape mismatch: diff {row['diff']} has shape {diff.shape}, "
                f"mask {row['mask']} has shape {mask.shape}"
            )
        mask = (mask > 0).astype(np.float32)

        diff_norm = diff / 255.0
        if self.augment:
            # .copy() makes the flipped views contiguous for torch.from_numpy.
            if random.random() < 0.5:
                diff_norm = np.flip(diff_norm, axis=1).copy()
                mask = np.flip(mask, axis=1).copy()
            if random.random() < 0.5:
                diff_norm = np.flip(diff_norm, axis=0).copy()
                mask = np.flip(mask, axis=0).copy()

        # Add the single channel dimension expected by the Conv2d model.
        diff_tensor = torch.from_numpy(diff_norm).unsqueeze(0)
        mask_tensor = torch.from_numpy(mask).unsqueeze(0)
        return diff_tensor, mask_tensor

def _make_loader(split, batch_size, shuffle, augment=False):
    """Wrap ``splits[split]`` in a DiffDataset and DataLoader using the shared NUM_WORKERS."""
    return DataLoader(
        DiffDataset(splits[split], augment=augment),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
    )


# Only the training split is shuffled and (optionally) augmented.
data_loaders = {
    'train': _make_loader('train', TRAIN_BATCH_SIZE, shuffle=True, augment=AUGMENT_TRAIN),
    'val': _make_loader('val', VAL_BATCH_SIZE, shuffle=False),
    'test': _make_loader('test', VAL_BATCH_SIZE, shuffle=False),
}

# Sanity check: pull one batch from each loader and show the tensor shapes.
for split_name, split_loader in data_loaders.items():
    batch_x, batch_y = next(iter(split_loader))
    print(split_name, batch_x.shape, batch_y.shape)

n_train, n_val, n_test = (len(splits[key]) for key in ('train', 'val', 'test'))
print(f"Train/val/test samples: {n_train}/{n_val}/{n_test}")
print(f"Batch sizes: train={TRAIN_BATCH_SIZE}, val/test={VAL_BATCH_SIZE}, num_workers={NUM_WORKERS}")
print(f"Hyperparams: epochs={NUM_EPOCHS}, lr={LEARNING_RATE}, augment_train={AUGMENT_TRAIN}")

train torch.Size([4, 1, 512, 512]) torch.Size([4, 1, 512, 512])
val torch.Size([4, 1, 512, 512]) torch.Size([4, 1, 512, 512])
test torch.Size([4, 1, 512, 512]) torch.Size([4, 1, 512, 512])


In [4]:
class ConvBlock(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages; padding=1 preserves H and W.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # bias=False: a conv bias is redundant when BatchNorm follows immediately.
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class UNetPlusPlus(nn.Module):
    """UNet++ (nested U-Net) with a 5-level encoder and dense skip pathways.

    Node ``convI_J`` sits at encoder depth I (I max-pools below the input
    resolution). It combines all previous nodes ``xI_0 .. xI_{J-1}`` at that
    depth with the upsampled node ``x{I+1}_{J-1}`` from the level below.

    Args:
        in_channels: channels of the input image.
        num_classes: channels of the output logit map.
        filters: channel widths of the five encoder depths.

    ``forward`` returns raw logits of shape ``(N, num_classes, H, W)``, with
    the same H and W as the input. Decoder upsampling targets the exact spatial
    size of the skip tensor. Inputs whose H/W are not multiples of 16 therefore
    work too, with results unchanged for sizes that are multiples of 16.
    """

    def __init__(self, in_channels=1, num_classes=1, filters=(32, 64, 128, 256, 512)):
        super().__init__()
        f = filters
        self.pool = nn.MaxPool2d(2)
        # Kept for backward compatibility. forward() uses self._up, which is
        # numerically identical for even sizes but also handles odd ones.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Encoder (column 0).
        self.conv0_0 = ConvBlock(in_channels, f[0])
        self.conv1_0 = ConvBlock(f[0], f[1])
        self.conv2_0 = ConvBlock(f[1], f[2])
        self.conv3_0 = ConvBlock(f[2], f[3])
        self.conv4_0 = ConvBlock(f[3], f[4])

        # Nested decoder nodes: input = J same-depth tensors + one upsampled tensor.
        self.conv0_1 = ConvBlock(f[0] + f[1], f[0])
        self.conv1_1 = ConvBlock(f[1] + f[2], f[1])
        self.conv2_1 = ConvBlock(f[2] + f[3], f[2])
        self.conv3_1 = ConvBlock(f[3] + f[4], f[3])

        self.conv0_2 = ConvBlock(f[0] * 2 + f[1], f[0])
        self.conv1_2 = ConvBlock(f[1] * 2 + f[2], f[1])
        self.conv2_2 = ConvBlock(f[2] * 2 + f[3], f[2])

        self.conv0_3 = ConvBlock(f[0] * 3 + f[1], f[0])
        self.conv1_3 = ConvBlock(f[1] * 3 + f[2], f[1])

        self.conv0_4 = ConvBlock(f[0] * 4 + f[1], f[0])

        self.final = nn.Conv2d(f[0], num_classes, kernel_size=1)

    def _up(self, x, ref):
        """Bilinearly upsample ``x`` to the spatial size of ``ref``.

        With align_corners=True the sampling grid depends only on the input and
        output sizes, so for ``ref`` exactly twice the size of ``x`` this equals
        ``nn.Upsample(scale_factor=2, align_corners=True)``.
        """
        return F.interpolate(x, size=ref.shape[-2:], mode='bilinear', align_corners=True)

    def forward(self, x):
        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x2_0 = self.conv2_0(self.pool(x1_0))
        x3_0 = self.conv3_0(self.pool(x2_0))
        x4_0 = self.conv4_0(self.pool(x3_0))

        x0_1 = self.conv0_1(torch.cat([x0_0, self._up(x1_0, x0_0)], dim=1))
        x1_1 = self.conv1_1(torch.cat([x1_0, self._up(x2_0, x1_0)], dim=1))
        x2_1 = self.conv2_1(torch.cat([x2_0, self._up(x3_0, x2_0)], dim=1))
        x3_1 = self.conv3_1(torch.cat([x3_0, self._up(x4_0, x3_0)], dim=1))

        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self._up(x1_1, x0_0)], dim=1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self._up(x2_1, x1_0)], dim=1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self._up(x3_1, x2_0)], dim=1))

        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self._up(x1_2, x0_0)], dim=1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self._up(x2_2, x1_0)], dim=1))

        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self._up(x1_3, x0_0)], dim=1))

        logits = self.final(x0_4)
        return logits

# Instantiate the network on the selected device and report its size in millions.
model = UNetPlusPlus().to(device)
param_count = sum(map(torch.numel, model.parameters()))
print('Model params:', param_count / 1e6, 'M')


Model params: 9.159105 M


In [5]:
def dice_loss(logits, targets, eps=1e-6):
    """Soft Dice loss, computed per sample and averaged over the batch.

    Args:
        logits: raw model outputs of shape ``(N, ...)`` with at least 2 dims.
            Sigmoid is applied here.
        targets: 0/1 ground truth with the same shape as ``logits``.
        eps: smoothing added to numerator and denominator. It keeps the ratio
            defined when a sample's prediction and target sums are both zero.

    Returns:
        Scalar tensor: mean over the batch of ``1 - Dice``.

    Raises:
        ValueError: if ``logits`` has fewer than 2 dimensions (no batch axis
            plus at least one per-sample axis).
    """
    if logits.dim() < 2:
        raise ValueError(f"dice_loss expects logits of shape (N, ...), got {tuple(logits.shape)}")
    probs = torch.sigmoid(logits)
    # Reduce over every non-batch dim, so (N, C, H, W), (N, H, W), (N, C, D, H, W) all work.
    dims = tuple(range(1, probs.dim()))
    numerator = 2 * (probs * targets).sum(dim=dims) + eps
    denominator = probs.sum(dim=dims) + targets.sum(dim=dims) + eps
    loss = 1 - (numerator / denominator)
    return loss.mean()

# Epoch budget for the training loop, and Adam over every model parameter.
num_epochs = NUM_EPOCHS
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [6]:
# Train for num_epochs epochs and checkpoint the weights with the lowest validation Dice loss.
best_val = float('inf')
train_history = []
val_history = []

for epoch in range(1, num_epochs + 1):
    # --- training pass ---
    model.train()
    running_loss = 0.0
    train_bar = tqdm(data_loaders['train'], desc=f"Epoch {epoch}/{num_epochs} [train]", leave=False)
    for inputs, targets in train_bar:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = dice_loss(outputs, targets)
        loss.backward()
        optimizer.step()
        # Read the scalar once: each .item() copies to the host and waits on the device.
        batch_loss = loss.item()
        # Weight by batch size so the epoch mean is per-sample even with a short last batch.
        running_loss += batch_loss * inputs.size(0)
        train_bar.set_postfix(loss=batch_loss)
    # max(1, ...) matches the validation guard and avoids ZeroDivisionError on an empty split.
    epoch_loss = running_loss / max(1, len(data_loaders['train'].dataset))
    train_history.append(epoch_loss)

    # --- validation pass ---
    model.eval()
    val_loss = 0.0
    val_bar = tqdm(data_loaders['val'], desc=f"Epoch {epoch}/{num_epochs} [val]", leave=False)
    with torch.no_grad():
        for inputs, targets in val_bar:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            loss = dice_loss(outputs, targets)
            batch_loss = loss.item()
            val_loss += batch_loss * inputs.size(0)
            val_bar.set_postfix(loss=batch_loss)
    val_loss /= max(1, len(data_loaders['val'].dataset))
    val_history.append(val_loss)

    # Overwrite the checkpoint only when validation loss improves.
    if val_loss < best_val:
        best_val = val_loss
        torch.save(model.state_dict(), 'unetpp_best.pt')
    print(f"Epoch {epoch:02d} | train loss {epoch_loss:.4f} | val loss {val_loss:.4f}")

KeyboardInterrupt: 

In [None]:
# Load best model for evaluation
checkpoint_path = Path('unetpp_best.pt')
if checkpoint_path.exists():
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
else:
    # Without a checkpoint, the metrics below describe whatever weights are in memory.
    print(f"Warning: {checkpoint_path} not found; evaluating current in-memory weights.")
model.eval()

all_preds = []    # per-batch sigmoid probabilities (for ROC AUC)
all_targets = []  # per-batch ground-truth masks as uint8 (for ROC AUC)
per_sample_metrics = []

with torch.no_grad():
    test_bar = tqdm(data_loaders['test'], desc='[test]', leave=False)
    for inputs, targets in test_bar:
        inputs = inputs.to(device)
        targets = targets.to(device)
        logits = model(inputs)
        probs = torch.sigmoid(logits)
        preds = (probs > 0.5).float()

        all_preds.append(probs.cpu().numpy())
        # Targets are 0/1, so uint8 is lossless and a quarter of float32's memory.
        all_targets.append(targets.cpu().numpy().astype(np.uint8))

        # Per-sample hard Dice / IoU of the 0.5-thresholded prediction. The eps
        # terms make an empty prediction on an empty mask score 1 instead of 0/0.
        inter = (preds * targets).sum(dim=(1, 2, 3))
        union = ((preds + targets) > 0).float().sum(dim=(1, 2, 3))
        dice = (2 * inter + 1e-6) / (preds.sum(dim=(1, 2, 3)) + targets.sum(dim=(1, 2, 3)) + 1e-6)
        iou = (inter + 1e-6) / (union + 1e-6)
        for d, j in zip(dice.cpu().numpy(), iou.cpu().numpy()):
            per_sample_metrics.append(dict(dice=d, iou=j))

if per_sample_metrics:
    probs_flat = np.concatenate([p.reshape(-1) for p in all_preds])
    labels_flat = np.concatenate([t.reshape(-1) for t in all_targets])
    # ROC AUC is undefined unless both classes occur among the labels.
    has_both_classes = labels_flat.max() != labels_flat.min()
    auc = roc_auc_score(labels_flat, probs_flat) if has_both_classes else None
    dice_mean = float(np.mean([m['dice'] for m in per_sample_metrics]))
    iou_mean = float(np.mean([m['iou'] for m in per_sample_metrics]))
else:
    # Empty test split: np.concatenate([]) would raise, so report NaN / None instead.
    auc = None
    dice_mean = iou_mean = float('nan')

results = {
    'dice_mean': dice_mean,
    'iou_mean': iou_mean,
    'auc': float(auc) if auc is not None else None,
    'samples': len(per_sample_metrics),
}
print('Test metrics:', results)