# Image segmentation with U-Net
Inspired by [this repository](https://)

In [1]:
%load_ext autoreload
%autoreload 2

# Setting up

## Import dependencies

In [2]:
from model import UNet

import logging
import os
from pathlib import Path

import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

import wandb
from evaluate import evaluate
from utils.data_loading import BasicDataset, CarvanaDataset
from utils.dice_score import dice_loss

## Setting the GPU device

In [3]:
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

# Defining helper functions

## Loading the data

### Creating dataset from given directories

In [4]:
def _create_dataset(
    image_directory: Path,
    mask_directory: Path,
    image_scale: float,
):
    try:
        dataset = CarvanaDataset(image_directory, mask_directory, image_scale)
    except (AssertionError, RuntimeError, IndexError):
        dataset = BasicDataset(image_directory, mask_directory, image_scale)

    return dataset

### Splitting dataset into training and validation sets

In [5]:
def _split_dataset(dataset, validation_percentage: float):
    num_validation = int(len(dataset) * validation_percentage)
    num_training = len(dataset) - num_validation
    training_set, validation_set = random_split(
        dataset,
        [num_training, num_validation],
        generator=torch.Generator().manual_seed(0),
    )
    return training_set, validation_set

### Creating data loaders

In [6]:
def _create_data_loaders(training_set, validation_set, batch_size):
    loader_arguments = dict(
        batch_size=batch_size,
        num_workers=os.cpu_count(),
        pin_memory=True,
    )
    training_loader = DataLoader(
        training_set, shuffle=True, **loader_arguments
    )
    validation_loader = DataLoader(
        validation_set, shuffle=False, drop_last=True, **loader_arguments
    )
    return training_loader, validation_loader

## Checkpointing

In [7]:
image_directory = Path('./data/images/')
mask_directory = Path('./data/masks/')
checkpoint_directory = Path('./checkpoints/')

In [8]:
def save_checkpoint(model, dataset, checkpoint_directory, epoch):
    Path(checkpoint_directory).mkdir(parents=True, exist_ok=True)
    state_dict = model.state_dict()
    state_dict['mask_values'] = dataset.mask_values
    torch.save(state_dict, str(checkpoint_directory / f'checkpoint_epoch{epoch}.pth'))
    logging.info(f'Checkpoint {epoch} saved!')

## Training

### Determining the loss for a batch

In [9]:
def _get_batch_loss(model, batch, criterion, automatic_mixed_precision):
    images, true_masks = batch['image'], batch['mask']
    invalid_shape_message = (
        f'Network has been defined with {model.n_channels} input channels, '
        f'but loaded images have {images.shape[1]} channels. Please check that '
        'the images are loaded correctly.'
    )
    assert images.shape[1] == model.n_channels, invalid_shape_message
    images = images.to(
        device=device, dtype=torch.float32, memory_format=torch.channels_last
    )
    true_masks = true_masks.to(device=device, dtype=torch.long)
    with torch.autocast(
        device.type if device.type != 'mps' else 'cpu',
        enabled=automatic_mixed_precision,
    ):
        predicted_masks = model(images)
        if model.n_classes == 1:
            loss = criterion(predicted_masks.squeeze(1), true_masks.float())
            loss += dice_loss(
                F.sigmoid(predicted_masks.squeeze(1)),
                true_masks.float(),
                multiclass=False,
            )
        else:
            loss = criterion(predicted_masks, true_masks)
            loss += dice_loss(
                F.softmax(predicted_masks, dim=1).float(),
                F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(),
                multiclass=True
            )

        return images, loss, true_masks, predicted_masks

### Logging

In [10]:
def _initialize_logging(
    epochs,
    batch_size,
    learning_rate,
    validation_percentage,
    save_checkpoint,
    image_scale,
    num_training,
    num_validation,
    automatic_mixed_precision,
):
    experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
    experiment_config = {
        'epochs': epochs,
        'batch_size': batch_size,
        'learning_rate': learning_rate,
        'val_percent': validation_percentage,
        'save_checkpoint': save_checkpoint,
        'img_scale': image_scale,
        'automatic_mixed_precision': automatic_mixed_precision,
    }
    experiment.config.update(experiment_config)

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Training size:   {num_training}
        Validation size: {num_validation}
        Checkpoints:     {save_checkpoint}
        Device:          {device.type}
        Images scaling:  {image_scale}
        Mixed Precision: {automatic_mixed_precision}
    ''')

### Evaluation

In [11]:
def _evaluate(
    validation_loader,
    epoch,
    num_training,
    batch_size,
    global_step,
    scheduler,
    optimizer,
    images,
    true_masks,
    predicted_masks,
    automatic_mixed_precision,
    experiment,
):
    division_step = (num_training // (5 * batch_size))
    if division_step > 0:
        if global_step % division_step == 0:
            histograms = {}
            for tag, value in model.named_parameters():
                tag = tag.replace('/', '.')
                if not (torch.isinf(value) | torch.isnan(value)).any():
                    histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
                if not (torch.isinf(value.grad) | torch.isnan(value.grad)).any():
                    histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())

            val_score = evaluate(model, validation_loader, device, automatic_mixed_precision)
            scheduler.step(val_score)

            logging.info(f'Validation Dice score: {val_score}')
            try:
                experiment.log({
                    'learning rate': optimizer.param_groups[0]['lr'],
                    'validation Dice': val_score,
                    'images': wandb.Image(images[0].cpu()),
                    'masks': {
                        'true': wandb.Image(true_masks[0].float().cpu()),
                        'pred': wandb.Image(predicted_masks.argmax(dim=1)[0].float().cpu()),
                    },
                    'step': global_step,
                    'epoch': epoch,
                    **histograms
                })
            except:
                pass


# Putting it together

## Parameters

In [12]:
EPOCHS = 5
BATCH_SIZE = 1
LEARNING_RATE = 1e-5
VALIDATION_PERCENTAGE = 0.1
SAVE_CHECKPOINT = True
IMAGE_SCALE = 0.5
AUTOMATIC_MIXED_PRECISION = False
WEIGHT_DECAY = 1e-8
MOMENTUM = 0.999
GRADIENT_CLIPPING = 1.0

BILINEAR = True
N_CHANNELS = 3  # for RGB images
N_CLASSES = 2  # the number of probabilities you want to get per pixel

STATE_DICT_PATH = ''
LOAD_STATE_DICT = False

## Loading the model

In [13]:
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logging.info(f'Using device {device}')

model = UNet(n_channels=N_CHANNELS, n_classes=N_CLASSES, bilinear=BILINEAR)
model = model.to(memory_format=torch.channels_last)

logging.info(
    f'Network:\n'
     f'\t{model.n_channels} input channels\n'
     f'\t{model.n_classes} output channels (classes)\n'
     f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling'
)

INFO: Using device mps
INFO: Network:
	3 input channels
	2 output channels (classes)
	Bilinear upscaling


In [14]:
if LOAD_STATE_DICT:
    state_dict = torch.load(STATE_DICT_PATH, map_location=device)
    del state_dict['mask_values']
    model.load_state_dict(state_dict)
    logging.info(f'Model loaded from {STATE_DICT_PATH}')

In [15]:
model.to(device=device)

UNet(
  (inc): DoubleConvolution(
    (double_convolution): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Downscaling(
    (maxpool_convolution): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConvolution(
        (double_convolution): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), st

## Loading the data

In [16]:
dataset = _create_dataset(image_directory, mask_directory, IMAGE_SCALE)

INFO: Creating dataset with 5088 examples
INFO: Scanning mask files to determine unique values
100%|██████████| 5088/5088 [00:33<00:00, 152.99it/s]
INFO: Unique mask values: [0, 1]


In [17]:
training_set, validation_set = _split_dataset(dataset, VALIDATION_PERCENTAGE)
num_training = len(training_set)
num_validation = len(validation_set)

In [18]:
training_loader, validation_loader = _create_data_loaders(
    training_set, validation_set, BATCH_SIZE
)

## Configuring the training

In [19]:
experiment = _initialize_logging(
    EPOCHS,
    BATCH_SIZE,
    LEARNING_RATE,
    VALIDATION_PERCENTAGE,
    SAVE_CHECKPOINT,
    IMAGE_SCALE,
    num_training,
    num_validation,
    AUTOMATIC_MIXED_PRECISION,
)

[34m[1mwandb[0m: Currently logged in as: [33manony-moose-855306159078586455[0m. Use [1m`wandb login --relogin`[0m to force relogin


INFO: Starting training:
        Epochs:          5
        Batch size:      1
        Learning rate:   1e-05
        Training size:   4580
        Validation size: 508
        Checkpoints:     True
        Device:          mps
        Images scaling:  0.5
        Mixed Precision: False
    


Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x1681f9f90>> (for post_run_cell), with arguments args (<ExecutionResult object at 1681fa1d0, execution_count=19 error_before_exec=None error_in_exec=None info=<ExecutionInfo object at 1657e03d0, raw_cell="experiment = _initialize_logging(
    EPOCHS,
    .." store_history=True silent=False shell_futures=True cell_id=None> result=None>,),kwargs {}:


TypeError: _WandbInit._pause_backend() takes 1 positional argument but 2 were given

In [20]:
optimizer = optim.RMSprop(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    momentum=MOMENTUM,
    foreach=True,
)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 'max', patience=5
)
grad_scaler = torch.cuda.amp.GradScaler(enabled=AUTOMATIC_MIXED_PRECISION)
criterion = torch.nn.CrossEntropyLoss() if model.n_classes > 1 else torch.nn.BCEWithLogitsLoss()

Error in callback <bound method _WandbInit._resume_backend of <wandb.sdk.wandb_init._WandbInit object at 0x1681f9f90>> (for pre_run_cell), with arguments args (<ExecutionInfo object at 16551ee10, raw_cell="optimizer = optim.RMSprop(
    model.parameters(),.." store_history=True silent=False shell_futures=True cell_id=None>,),kwargs {}:


TypeError: _WandbInit._resume_backend() takes 1 positional argument but 2 were given

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x1681f9f90>> (for post_run_cell), with arguments args (<ExecutionResult object at 16820da90, execution_count=20 error_before_exec=None error_in_exec=None info=<ExecutionInfo object at 16551ee10, raw_cell="optimizer = optim.RMSprop(
    model.parameters(),.." store_history=True silent=False shell_futures=True cell_id=None> result=None>,),kwargs {}:


TypeError: _WandbInit._pause_backend() takes 1 positional argument but 2 were given

## Train the model

In [21]:

global_step = 0
for epoch in range(1, EPOCHS + 1):
    model.train()
    epoch_loss = 0
    with tqdm(
        total=num_training, desc=f'Epoch {epoch}/{EPOCHS}', unit='img'
    ) as progress_bar:
        for batch in training_loader:
            images, loss, true_masks, predicted_masks = _get_batch_loss(
                model, batch, criterion, AUTOMATIC_MIXED_PRECISION
            )

            optimizer.zero_grad(set_to_none=True)
            grad_scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIPPING)
            grad_scaler.step(optimizer)
            grad_scaler.update()

            progress_bar.update(images.shape[0])
            global_step += 1
            epoch_loss += loss.item()
            experiment.log({'train loss': loss.item(), 'step': global_step, 'epoch': epoch})
            progress_bar.set_postfix(**{'loss (batch)': loss.item()})

            _evaluate(
                validation_loader,
                epoch,
                num_training,
                BATCH_SIZE,
                global_step,
                scheduler,
                optimizer,
                images,
                true_masks,
                predicted_masks,
                AUTOMATIC_MIXED_PRECISION,
                experiment,
            )

    if SAVE_CHECKPOINT:
        save_checkpoint(model, dataset, checkpoint_directory, epoch)

Error in callback <bound method _WandbInit._resume_backend of <wandb.sdk.wandb_init._WandbInit object at 0x1681f9f90>> (for pre_run_cell), with arguments args (<ExecutionInfo object at 16b44de10, raw_cell="
global_step = 0
for epoch in range(1, EPOCHS + 1).." store_history=True silent=False shell_futures=True cell_id=None>,),kwargs {}:


TypeError: _WandbInit._resume_backend() takes 1 positional argument but 2 were given

Epoch 1/5:   0%|          | 0/4580 [00:01<?, ?img/s]


IndexError: Caught IndexError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/Users/paripasviktor/ltu/adm/venv/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/paripasviktor/ltu/adm/venv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = self.dataset.__getitems__(possibly_batched_index)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/paripasviktor/ltu/adm/venv/lib/python3.11/site-packages/torch/utils/data/dataset.py", line 364, in __getitems__
    return [self.dataset[self.indices[idx]] for idx in indices]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/paripasviktor/ltu/adm/venv/lib/python3.11/site-packages/torch/utils/data/dataset.py", line 364, in <listcomp>
    return [self.dataset[self.indices[idx]] for idx in indices]
            ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
  File "/Users/paripasviktor/ltu/adm/project/utils/data_loading.py", line 108, in __getitem__
    mask = self.preprocess(self.mask_values, mask, self.scale, is_mask=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/paripasviktor/ltu/adm/project/utils/data_loading.py", line 124, in preprocess
    return self._mask(image, new_size, mask_values)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/paripasviktor/ltu/adm/project/utils/data_loading.py", line 141, in _mask
    mask[image == value] = i
    ~~~~^^^^^^^^^^^^^^^^
IndexError: boolean index did not match indexed array along dimension 0; dimension is 959 but corresponding boolean dimension is 640


Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x1681f9f90>> (for post_run_cell), with arguments args (<ExecutionResult object at 1690fac50, execution_count=21 error_before_exec=None error_in_exec=Caught IndexError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/Users/paripasviktor/ltu/adm/venv/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/paripasviktor/ltu/adm/venv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = self.dataset.__getitems__(possibly_batched_index)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/paripasviktor/ltu/adm/venv/lib/python3.11/site-packages/torch/utils/data/dataset.py", line 364, in __getitems__
    return [self.dataset[self.indices[idx]] for idx in indices]
           ^^^^^^^

TypeError: _WandbInit._pause_backend() takes 1 positional argument but 2 were given