In [None]:
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from src.datasets.ebhi_seg_dataset import EBHISegDataset
from src.datasets.transforms import get_train_transforms
from src.models.unet import UNet
from src.models.losses import combined_loss
from src.utils.paths import load_config, get_data_paths, resolve_path
from src.utils.seed import set_seed

# --- Build a tiny training subset for an overfit sanity check ---------------
# Load the project configuration and the data directories derived from it.
cfg = load_config('config/default.yaml')
paths = get_data_paths(cfg)

# Number of sample ids taken from the head of the training split.
SMALL_SUBSET_SIZE = 8

# Read the training ids, dropping blank lines and surrounding whitespace.
train_split = resolve_path(paths['splits'] / 'train.txt')
ids = [
    line.strip()
    for line in train_split.read_text(encoding='utf-8').splitlines()
    if line.strip()
]
small_ids = ids[:SMALL_SUBSET_SIZE]
if not small_ids:
    # An empty subset cannot be trained on; fail here with a clear message
    # instead of a less obvious error further down the notebook.
    raise ValueError(f'No sample ids found in {train_split}')

# Write the reduced split next to the train split that was actually read.
# The original wrote to the unresolved `paths['splits']`, which may point to
# a different directory than the resolved one when the configured path is
# relative (NOTE(review): depends on what `resolve_path` does - confirm).
tmp = train_split.with_name('small_train.txt')
tmp.write_text('\n'.join(small_ids), encoding='utf-8')

# Preprocessing parameters come straight from the config.
image_size = int(cfg['data']['image_size'])
mean = cfg['data']['normalization']['mean']
std = cfg['data']['normalization']['std']
transform = get_train_transforms(image_size, mean, std)

# `ignore_index` is shared by the dataset here and by the loss in the
# training cell; -1 is the fallback when the config does not set it.
ignore_index = int(cfg['loss'].get('ignore_index', -1))

# Seed before the dataset/loader are created (presumably so shuffling and
# any random transforms are reproducible - verify against `set_seed`).
set_seed(int(cfg['training'].get('seed', 42)))

dataset = EBHISegDataset(
    split_file=str(tmp),
    images_dir=str(paths['processed_images']),
    masks_dir=str(paths['processed_masks']),
    transform=transform,
    ignore_index=ignore_index,
)
loader = DataLoader(dataset, batch_size=2, shuffle=True)


In [None]:
# Imported before training so a missing plotting backend is reported
# immediately rather than after the whole run has finished.
import matplotlib.pyplot as plt

# --- Overfit sanity check: train a small UNet on the tiny subset ------------
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(in_channels=3, num_classes=cfg['data']['num_classes'], base_channels=16).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# The loss weighting does not change during the run, so read it once
# instead of on every batch.
dice_weight = cfg['loss'].get('dice_weight', 1.0)
num_epochs = 50

# Mean per-sample loss for each epoch, in epoch order.
loss_history = []

for epoch in range(1, num_epochs + 1):
    model.train()
    running = 0.0
    for batch in loader:
        imgs = batch['image'].to(device)
        masks = batch['mask'].to(device)
        optimizer.zero_grad()
        logits = model(imgs)
        loss = combined_loss(logits, masks, dice_weight=dice_weight, ignore_index=ignore_index)
        loss.backward()
        optimizer.step()
        # Weight the batch loss by the batch size so that a smaller final
        # batch does not skew the per-sample average computed below.
        running += loss.item() * imgs.size(0)
    epoch_loss = running / len(dataset)
    loss_history.append(epoch_loss)
    if epoch % 5 == 0:
        print(f'Epoch {epoch}: loss={epoch_loss:.4f}')

# Plot against 1-based epoch numbers so the x-axis agrees with the printed
# log (plotting the bare list would label the first epoch as 0).
plt.plot(range(1, len(loss_history) + 1), loss_history)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Tiny subset overfit sanity check')
plt.show()
