Paper: Ronneberger, O., Fischer, P., & Brox, T. (2015). U-Net: Convolutional Networks for Biomedical Image Segmentation (arXiv:1505.04597). arXiv. https://doi.org/10.48550/arXiv.1505.04597

ISBI-2012 electron microscopy (EM) segmentation data: https://downloads.imagej.net/ISBI-2012-challenge.zip

Reference notebook on handling this data (alexklibisz/isbi-2012): https://github.com/alexklibisz/isbi-2012/blob/master/notebooks/data.ipynb

In [None]:
# Re-import edited modules (e.g. random_neural_net_models) automatically
# before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
import math
import typing as T
from collections import Counter
from pathlib import Path
from typing import Iterator

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tifffile as tiff
import torch
import torch.nn as nn
import torch.nn.modules.loss as torch_loss
import torch.optim as optim
import torchvision.transforms.v2 as vision_trafos_v2
from einops import rearrange
from tensordict import tensorclass
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torchvision import tv_tensors

import random_neural_net_models.learner as rnnm_learner
import random_neural_net_models.utils as utils

In [None]:
# Local copy of the ISBI-2012 challenge data; the displayed boolean tells
# whether the folder exists.
path_em_data = Path("../data/ISBI-2012-challenge")
path_em_data, path_em_data.exists()

In [None]:
# Training images (X) and their label images (Y), read from TIFF files as
# numpy arrays; slices are indexed along the first axis below.
X = tiff.imread(path_em_data / "train-volume.tif")
Y = tiff.imread(path_em_data / "train-labels.tif")

In [None]:
# Shapes of the image stack and the label stack (must match for the datasets).
X.shape, Y.shape

In [None]:
# Preview the first few training slices next to their label images.
# After the loop `_x` / `_y` hold the last *plotted* slice, which the next
# cells inspect. (The previous enumerate/break version advanced one extra
# element before breaking, so `_x` ended up on a slice that was never shown.)
n_preview = 4
for _x, _y in zip(X[:n_preview], Y[:n_preview]):
    _, axs = plt.subplots(ncols=2)
    axs[0].imshow(_x)
    axs[1].imshow(_y)
    plt.tight_layout()

In [None]:
# Mean and standard deviation of the intensities of `_x`, the last slice bound
# by the preview loop above.
_x.mean(), _x.std()

In [None]:
# Counts of each label value in `_y`, and the five most frequent intensities in `_x`.
Counter(_y.ravel()), Counter(_x.ravel()).most_common(5)

In [None]:
# Seed for reproducibility (project helper; presumably seeds the Python, NumPy
# and torch RNGs — TODO confirm). This also fixes the random model init below.
utils.make_deterministic(42)

In [None]:
# Compute device chosen by the project helper (presumably an accelerator when
# available, otherwise CPU — TODO confirm); passed to the Learner below.
device = utils.get_device()
device

In [None]:
class ISBIDatasetWithLabels(Dataset):
    """Map-style dataset of ISBI-2012 slices paired with binary label masks.

    Item ``idx`` is ``(image, labels)``:
    ``image`` is ``X[idx]`` wrapped as ``tv_tensors.Image``.
    ``labels`` is the boolean mask ``Y[idx] == 255`` wrapped as
    ``tv_tensors.Mask``.
    If ``transform`` is given, it is called jointly as
    ``transform(image, labels)``, so geometric augmentations stay aligned.

    Args:
        X: stack of input images, indexed along the first axis.
        Y: stack of label images; must have exactly the same shape as ``X``.
        transform: optional joint transform, e.g. a
            ``torchvision.transforms.v2`` pipeline, see
            https://pytorch.org/vision/stable/auto_examples/transforms/plot_transforms_getting_started.html#sphx-glr-auto-examples-transforms-plot-transforms-getting-started-py
        add_channel: if True (default) the image is returned as produced by
            ``tv_tensors.Image`` / the transform. If False, a leading
            singleton axis (the channel axis) is removed with ``squeeze(0)``
            after the transform.

    Raises:
        ValueError: if ``X`` and ``Y`` differ in shape.
    """

    def __init__(
        self,
        X: np.ndarray,
        Y: np.ndarray,
        transform: T.Optional[T.Callable] = None,
        add_channel: bool = True,
    ):
        if X.shape != Y.shape:
            raise ValueError(
                f"X and Y must have the same shape, got {X.shape=} and {Y.shape=}"
            )

        self.X = X
        self.Y = Y
        self.n = len(X)
        self.transform = transform
        self.add_channel = add_channel

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int) -> T.Tuple[torch.Tensor, torch.Tensor]:
        img = tv_tensors.Image(self.X[idx])
        # Binarise the labels: pixels equal to 255 become True, all others False.
        labels = tv_tensors.Mask(self.Y[idx] == 255)

        if self.transform is not None:
            img, labels = self.transform(img, labels)

        if not self.add_channel:
            img = img.squeeze(0)

        return img, labels


class ISBIDatasetWithLabelsIterable(IterableDataset):
    """Iterable dataset of ISBI-2012 slices and binary masks, repeated per pass.

    Each iteration yields all slices in index order, ``n_repetitions`` times.
    Every item is ``(image, labels)``:
    ``image`` is ``X[idx]`` wrapped as ``tv_tensors.Image``.
    ``labels`` is ``Y[idx] == 255`` wrapped as ``tv_tensors.Mask``.
    ``transform`` (if given) is called as ``transform(image, labels)`` on
    every yield, so random augmentations differ between repetitions.

    With a ``DataLoader`` using ``num_workers > 0``, each worker yields only
    the indices ``idx`` with ``idx % num_workers == worker_id``. This avoids
    every worker replaying the full dataset. In the main process (e.g.
    ``num_workers=0`` or a direct ``generate()`` call) all indices are used.

    Args:
        X: stack of input images, indexed along the first axis.
        Y: stack of label images; must have exactly the same shape as ``X``.
        transform: optional joint transform, e.g. a
            ``torchvision.transforms.v2`` pipeline, see
            https://pytorch.org/vision/stable/auto_examples/transforms/plot_transforms_getting_started.html#sphx-glr-auto-examples-transforms-plot-transforms-getting-started-py
        add_channel: if True (default) the image is yielded as produced by
            ``tv_tensors.Image`` / the transform. If False, a leading
            singleton axis is removed with ``squeeze(0)`` after the transform.
        n_repetitions: number of times all (assigned) slices are yielded per
            iteration.

    Raises:
        ValueError: if ``X`` and ``Y`` differ in shape.
    """

    def __init__(
        self,
        X: np.ndarray,
        Y: np.ndarray,
        transform: T.Optional[T.Callable] = None,
        add_channel: bool = True,
        n_repetitions: int = 1,
    ):
        if X.shape != Y.shape:
            raise ValueError(
                f"X and Y must have the same shape, got {X.shape=} and {Y.shape=}"
            )

        self.X = X
        self.Y = Y
        self.n_repetitions = n_repetitions
        self.n = len(X)
        self.transform = transform
        self.add_channel = add_channel

    def _indices(self) -> range:
        """Slice indices this process should yield (worker shard or all)."""
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            return range(self.n)
        return range(worker_info.id, self.n, worker_info.num_workers)

    def _load(self, idx: int) -> T.Tuple[torch.Tensor, torch.Tensor]:
        """Build the (optionally transformed) ``(image, labels)`` pair for ``idx``."""
        img = tv_tensors.Image(self.X[idx])
        # Binarise the labels: pixels equal to 255 become True, all others False.
        labels = tv_tensors.Mask(self.Y[idx] == 255)

        if self.transform is not None:
            img, labels = self.transform(img, labels)

        if not self.add_channel:
            img = img.squeeze(0)

        return img, labels

    def generate(self) -> T.Iterator[T.Tuple[torch.Tensor, torch.Tensor]]:
        # Resolved lazily, i.e. inside the worker process that iterates.
        indices = self._indices()
        for _ in range(self.n_repetitions):
            for idx in indices:
                yield self._load(idx)

    def __iter__(self) -> Iterator:
        # generate() already returns a generator, which is its own iterator.
        return self.generate()


@tensorclass
class ISBIBlockWithLabels:
    """Batch container holding stacked images and their label masks.

    Built by ``collate_isbi_dataset_to_block_with_labels``. The models read
    ``.image`` and the loss (``BCEISBI``) reads ``.labels``.
    """

    # Stacked images; later code indexes image[ix, 0], i.e. assumes (B, C, H, W) — TODO confirm.
    image: torch.Tensor
    # Stacked label masks, one per image; presumably (B, H, W) booleans — TODO confirm.
    labels: torch.Tensor


def collate_isbi_dataset_to_block_with_labels(
    input: T.List[T.Tuple[torch.Tensor, torch.Tensor]]
) -> ISBIBlockWithLabels:
    """Collate ``(image, labels)`` samples into one ``ISBIBlockWithLabels``.

    Images and labels are stacked separately along a new leading batch axis.
    The tensorclass batch size is set to the number of stacked samples.
    """
    stacked_images = torch.stack([pair[0] for pair in input])
    stacked_labels = torch.stack([pair[1] for pair in input])
    n_samples = stacked_images.shape[0]

    return ISBIBlockWithLabels(
        image=stacked_images,
        labels=stacked_labels,
        batch_size=[n_samples],
    )

In [None]:
# Joint augmentation pipeline applied to every (image, mask) pair:
# random shear (no rotation), random 64x64 crop, random vertical/horizontal
# flips, then cast to float32 and Normalize(mean=0, std=255), i.e. x / 255
# (maps uint8 [0, 255] to [0, 1] if X is uint8).
# NOTE(review): torchvision v2 ToDtype/Normalize given a single dtype are
# expected to touch only the image, leaving the boolean Mask as is — confirm.
# NOTE(review): the same random pipeline is also used for the validation
# dataset, so validation losses are computed on random crops/flips.
trafos = vision_trafos_v2.Compose(
    [
        vision_trafos_v2.RandomAffine(degrees=0, shear=5),
        vision_trafos_v2.RandomCrop(size=(64, 64)),
        vision_trafos_v2.RandomVerticalFlip(),
        vision_trafos_v2.RandomHorizontalFlip(),
        vision_trafos_v2.ToDtype(torch.float32),
        vision_trafos_v2.Normalize(mean=[0.0], std=[255.0]),
    ]
)

In [None]:
# Hold out the last 5 slices for validation and train on the rest.
# (Map-style alternative: ISBIDatasetWithLabels(X0, Y0, transform=trafos).)
# Each iterable dataset yields every slice n_repetitiopns times per pass,
# re-applying the random `trafos` on every yield.
# NOTE: "n_repetitiopns" is a typo of "n_repetitions"; kept because later cells use it.
n_repetitiopns = 5
X0, X1, Y0, Y1 = X[:-5], X[-5:], Y[:-5], Y[-5:]
ds_train = ISBIDatasetWithLabelsIterable(
    X0, Y0, transform=trafos, n_repetitions=n_repetitiopns
)
ds_valid = ISBIDatasetWithLabelsIterable(
    X1, Y1, transform=trafos, n_repetitions=n_repetitiopns
)

In [None]:
def show_image_and_labels(image: torch.Tensor, labels: torch.Tensor):
    """Plot one (image, labels) sample together with value histograms.

    The top row shows the first channel of ``image`` and the label mask.
    The bottom row shows the histograms of their values. The mean/std and
    min/max of the image channel are printed as well.
    """
    channel = image[0]
    flat_pixels = channel.ravel()

    _, grid = plt.subplots(ncols=2, nrows=2)
    (ax_image, ax_labels), (ax_image_hist, ax_labels_hist) = grid

    ax_image.imshow(channel)
    ax_labels.imshow(labels)

    sns.histplot(x=flat_pixels, ax=ax_image_hist)
    sns.histplot(x=labels.ravel(), ax=ax_labels_hist)

    print(flat_pixels.mean(), flat_pixels.std())
    print(flat_pixels.min(), flat_pixels.max())

    plt.tight_layout()


# Draw one augmented sample straight from the training generator and inspect
# it (with the map-style dataset this would be `ds_train[0]`).
img, labels = next(iter(ds_train.generate()))

show_image_and_labels(img, labels)

In [None]:
# Samples per batch. The iterable datasets define the sample order
# themselves, so no shuffling is requested (DataLoader rejects shuffle=True
# for an IterableDataset). Batches are collated into ISBIBlockWithLabels.
batch_size = 5
dl_train = DataLoader(
    ds_train,
    batch_size=batch_size,
    collate_fn=collate_isbi_dataset_to_block_with_labels,
)
dl_valid = DataLoader(
    ds_valid,
    batch_size=batch_size,
    collate_fn=collate_isbi_dataset_to_block_with_labels,
)

In [None]:
# Smoke test: fetch and display the first collated training batch.
next(iter(dl_train))

Simplest model: four 3×3 convolutions plus a 1×1 output convolution, without any down- or upsampling

In [None]:
class SimpleModel(nn.Module):
    """Baseline per-pixel classifier without down- or upsampling.

    Four 3x3 "same"-padded, stride-1 convolutions (1 -> 64 -> 64 -> 64 -> 64
    channels), each followed by a ReLU, then a 1x1 convolution to a single
    output channel. The output is returned as raw logits (no sigmoid); the
    sigmoid is part of the BCE-with-logits loss used below.
    """

    def __init__(self):

        super().__init__()

        # NOTE: every layer is kept as a named attribute (conv*, act_conv*) AND
        # wrapped in `self.net` below. The attribute names are presumably what
        # the callbacks' name_patterns (".*conv.*", ".*act.*") match, and the
        # creation/initialisation order fixes the seeded random weights, so
        # keep both stable when editing.
        self.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=64,
            kernel_size=(3, 3),
            stride=1,
            padding="same",
        )
        nn.init.kaiming_normal_(self.conv1.weight, nonlinearity="relu")
        self.act_conv1 = nn.ReLU()

        self.conv2 = nn.Conv2d(
            in_channels=64,
            out_channels=64,
            kernel_size=(3, 3),
            stride=1,
            padding="same",
        )
        nn.init.kaiming_normal_(self.conv2.weight, nonlinearity="relu")
        self.act_conv2 = nn.ReLU()

        self.conv3 = nn.Conv2d(
            in_channels=64,
            out_channels=64,
            kernel_size=(3, 3),
            stride=1,
            padding="same",
        )
        nn.init.kaiming_normal_(self.conv3.weight, nonlinearity="relu")
        self.act_conv3 = nn.ReLU()

        self.conv4 = nn.Conv2d(
            in_channels=64,
            out_channels=64,
            kernel_size=(3, 3),
            stride=1,
            padding="same",
        )
        nn.init.kaiming_normal_(self.conv4.weight, nonlinearity="relu")
        self.act_conv4 = nn.ReLU()

        # Output layer: 1x1 conv producing one logit per pixel.
        self.conv5 = nn.Conv2d(
            in_channels=64,
            out_channels=1,
            kernel_size=(1, 1),
            stride=1,
            padding="same",
        )
        # NOTE(review): ReLU-gain Kaiming init on the output layer, although no
        # ReLU follows it.
        nn.init.kaiming_normal_(self.conv5.weight, nonlinearity="relu")

        self.net = nn.Sequential(
            self.conv1,
            self.act_conv1,
            self.conv2,
            self.act_conv2,
            self.conv3,
            self.act_conv3,
            self.conv4,
            self.act_conv4,
            self.conv5,
        )

    def forward(self, input: ISBIBlockWithLabels) -> torch.Tensor:
        """Return per-pixel logits with the same spatial size as ``input.image``."""
        x = input.image.float()
        return self.net(x)

Loss: per-pixel binary cross-entropy on the logits (`BCEWithLogitsLoss`)

In [None]:
class BCEISBI(torch_loss.BCEWithLogitsLoss):
    """``BCEWithLogitsLoss`` adapted to ``ISBIBlockWithLabels`` batches.

    The model logits and the batch's label mask are both flattened and
    compared element-wise; the labels are cast to float as the parent loss
    requires. Constructor arguments are inherited unchanged from
    ``BCEWithLogitsLoss``.
    """

    def forward(
        self, inference: torch.Tensor, input: ISBIBlockWithLabels
    ) -> torch.Tensor:
        flat_logits = inference.ravel()
        flat_targets = input.labels.ravel().float()
        return super().forward(flat_logits, flat_targets)

In [None]:
# Baseline: plain convolution stack without down-/upsampling.
model = SimpleModel()

In [None]:
# Initial learning rate for Adam. The LR finder and the OneCycleLR scheduler
# created later presumably take over the effective lr — TODO confirm.
learning_rate = 0.1
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss = BCEISBI()
loss_callback = rnnm_learner.TrainLossCallback()

save_dir = Path(
    "./models"
)  # location used by learner.find_learning_rate to store the model before the search

# name_patterns are regexes over module names: ".*act.*" selects SimpleModel's
# act_conv* ReLUs and ".*conv.*" its conv* layers (assuming the callbacks match
# against module attribute names). Other models may need different patterns.
activations_callback = rnnm_learner.TrainActivationsCallback(
    every_n=10, max_depth_search=4, name_patterns=(".*act.*",)
)
gradients_callback = rnnm_learner.TrainGradientsCallback(
    every_n=10, max_depth_search=4, name_patterns=(".*conv.*",)
)
parameters_callback = rnnm_learner.TrainParametersCallback(
    every_n=10, max_depth_search=4, name_patterns=(".*conv.*",)
)

callbacks = [
    loss_callback,
    activations_callback,
    gradients_callback,
    parameters_callback,
]

In [None]:
# Bundle model, optimizer, loss and monitoring callbacks into the project Learner.
learner = rnnm_learner.Learner(
    model,
    optimizer,
    loss,
    callbacks=callbacks,
    save_dir=save_dir,
    device=device,
)

In [None]:
# Learning-rate range test. The arguments are presumably
# (start lr, end lr, number of steps) — TODO confirm in rnnm_learner.
lr_find_callback = rnnm_learner.LRFinderCallback(1e-5, 100, 100)

learner.find_learning_rate(
    dl_train, n_epochs=50, lr_find_callback=lr_find_callback
)

In [None]:
# Loss vs. learning rate from the range test (log y-axis, clipped to [0.1, 10]).
lr_find_callback.plot(yscale="log", ylim=(1e-1, 10))

In [None]:
def calc_steps_per_epoch(
    ds_train: T.Union[Dataset, IterableDataset],
    dl_train: DataLoader,
    X0: np.ndarray,
    n_repetitiopns: int,
    batch_size: int,
) -> int:
    """Return the number of batches (optimizer steps) in one training epoch.

    For a sized dataset the answer is ``len(dl_train)``. For an unsized
    (iterable) dataset the DataLoader cannot report its length. The count is
    then derived from ``len(X0) * n_repetitiopns`` samples: it is rounded up
    if the loader keeps a final partial batch, and down if
    ``dl_train.drop_last`` is set, mirroring what the loader actually yields.

    Args:
        ds_train: training dataset behind ``dl_train``.
        dl_train: training DataLoader.
        X0: training images; only ``len(X0)`` is used.
        n_repetitiopns: how often each training image is yielded per epoch.
        batch_size: samples per batch; must be positive.

    Raises:
        ValueError: if the length has to be derived and ``batch_size <= 0``.
    """
    if hasattr(ds_train, "__len__"):
        return len(dl_train)

    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size=}")

    n_samples = len(X0) * n_repetitiopns
    if getattr(dl_train, "drop_last", False):
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)


# Batches per epoch, needed by OneCycleLR to size its schedule.
steps_per_epoch = calc_steps_per_epoch(
    ds_train, dl_train, X0, n_repetitiopns, batch_size
)
steps_per_epoch

In [None]:
# One-cycle schedule peaking at 4e-4 (picked from the LR-finder plot), stepped
# after every batch. OneCycleLR raises if stepped more than
# epochs * steps_per_epoch times, so steps_per_epoch must match the loader.
learning_rate = 4e-4
n_epochs = 50

scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=learning_rate,
    epochs=n_epochs,
    steps_per_epoch=steps_per_epoch,
)
scheduler_callback = rnnm_learner.EveryBatchSchedulerCallback(scheduler)
learner.update_callback(scheduler_callback)

In [None]:
# Train the baseline, evaluating on the validation loader (presumably once per epoch).
learner.fit(dl_train, n_epochs=n_epochs, dataloader_valid=dl_valid)

In [None]:
# Training/validation loss curves; window/window_valid are presumably
# moving-average window sizes — TODO confirm.
loss_callback.plot(yscale="log", window=20, window_valid=10)

In [None]:
# Statistics of the conv* parameters recorded during training.
parameters_callback.plot()

In [None]:
# Statistics of the conv* gradients recorded during training.
gradients_callback.plot()

In [None]:
# Statistics of the act* (ReLU) activations recorded during training.
activations_callback.plot()

In [None]:
# Predict on the validation loader; also return the input batches so that
# predictions can be compared with the labels below.
y_logits, inputs = learner.predict(dl_valid, return_inputs=True)

In [None]:
# Display the returned inputs (indexed like an ISBIBlockWithLabels batch below).
inputs

In [None]:
# Convert logits to probabilities. `.numpy()` needs a CPU tensor; `.cpu()` is
# a no-op if predict() already returned CPU tensors and otherwise moves them.
y_probs = y_logits.detach().sigmoid().cpu().numpy()
y_probs.shape

In [None]:
# Validation-loss history of the baseline, kept for comparison with the U-Net.
losses_simple = loss_callback.get_losses_valid()

losses_simple

In [None]:
def show_inputs_labels_and_predictions(
    ix: int, inputs: ISBIBlockWithLabels, y_prob: np.ndarray
):
    """Plot input image, ground-truth labels and predicted probabilities for sample ``ix``.

    Args:
        ix: index of the sample along the batch axis.
        inputs: batch whose first image channel is ``image[ix, 0]`` and whose
            mask is ``labels[ix]``.
        y_prob: predicted probabilities, indexed as ``y_prob[ix, 0]``.
    """
    # `.detach().cpu()` lets matplotlib convert the tensors even if they live
    # on an accelerator; it is a no-op for CPU tensors.
    image = inputs.image[ix, 0].detach().cpu()
    labels = inputs.labels[ix].detach().cpu()

    _, axs = plt.subplots(ncols=3)
    panels = (
        (image, "input"),
        (labels, "labels"),
        (y_prob[ix, 0], "predicted probability"),
    )
    for ax, (data, title) in zip(axs, panels):
        ax.imshow(data)
        ax.set_title(title)
    plt.show()

In [None]:
# Show input, labels and prediction for every validation sample.
for ix in range(y_logits.shape[0]):
    show_inputs_labels_and_predictions(ix, inputs, y_probs)

Shallow U-Net model

In [None]:
class ConvBlock(nn.Module):
    """Two 3x3 convolutions, each followed by a ReLU.

    Both convolutions use stride 1 and "same" padding, so the spatial size is
    preserved; channels go ``n_in -> n_out -> n_out``. Weights use Kaiming
    normal initialisation for ReLU.
    """

    def __init__(self, n_in: int, n_out: int):
        super().__init__()

        # Creation order (conv1, act1, conv2, act2, net) matters: it fixes the
        # seeded random weights and the module/state_dict naming.
        self.conv1 = self._he_conv3x3(n_in, n_out)
        self.act1 = nn.ReLU()
        self.conv2 = self._he_conv3x3(n_out, n_out)
        self.act2 = nn.ReLU()

        self.net = nn.Sequential(self.conv1, self.act1, self.conv2, self.act2)

    @staticmethod
    def _he_conv3x3(n_in: int, n_out: int) -> nn.Conv2d:
        """Build a 3x3 same-padded conv with ReLU-gain Kaiming-normal weights."""
        conv = nn.Conv2d(
            in_channels=n_in,
            out_channels=n_out,
            kernel_size=(3, 3),
            stride=1,
            padding="same",
        )
        nn.init.kaiming_normal_(conv.weight, nonlinearity="relu")
        return conv

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class ResConvBlock(nn.Module):
    """Two 3x3 conv + ReLU layers whose activations are summed.

    The shortcut adds the activation after the *first* conv (not the block
    input) to the activation after the second conv, so it also works when
    ``n_in != n_out``. Spatial size is preserved (stride 1, "same" padding).
    """

    def __init__(self, n_in: int, n_out: int):
        super().__init__()

        # Creation order (conv1, act1, conv2, act2) fixes the seeded random
        # weights and the module naming; keep it.
        self.conv1 = self._he_conv3x3(n_in, n_out)
        self.act1 = nn.ReLU()
        self.conv2 = self._he_conv3x3(n_out, n_out)
        self.act2 = nn.ReLU()

    @staticmethod
    def _he_conv3x3(n_in: int, n_out: int) -> nn.Conv2d:
        """Build a 3x3 same-padded conv with ReLU-gain Kaiming-normal weights."""
        conv = nn.Conv2d(
            in_channels=n_in,
            out_channels=n_out,
            kernel_size=(3, 3),
            stride=1,
            padding="same",
        )
        nn.init.kaiming_normal_(conv.weight, nonlinearity="relu")
        return conv

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        first = self.act1(self.conv1(x))
        second = self.act2(self.conv2(first))
        return first + second


class Up(nn.Module):
    """Decoder upsampling step.

    Applies 2x nearest-neighbour upsampling, then a 3x3 "same"-padded
    convolution (``n_in -> n_out`` channels) and a ReLU.
    """

    def __init__(self, n_in: int, n_out: int):
        super().__init__()

        # Registration order (up_nn, up_conv, up_act) and the order of random
        # initialisation are kept so seeded runs and state_dict keys match.
        self.up_nn = nn.UpsamplingNearest2d(scale_factor=2)
        conv = nn.Conv2d(
            in_channels=n_in, out_channels=n_out, kernel_size=3, padding="same"
        )
        nn.init.kaiming_normal_(conv.weight, nonlinearity="relu")
        self.up_conv = conv
        self.up_act = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        upsampled = self.up_nn(x)
        features = self.up_conv(upsampled)
        return self.up_act(features)


def get_conv_block(
    n_in: int, n_out: int, res: bool
) -> T.Union[ConvBlock, ResConvBlock]:
    """Return a ``ResConvBlock`` if ``res`` is truthy, else a plain ``ConvBlock``."""
    block_cls = ResConvBlock if res else ConvBlock
    return block_cls(n_in=n_in, n_out=n_out)


class Model2Layers(nn.Module):
    """Minimal U-Net with one pooling level and one skip connection.

    The encoder is a conv block (1 -> 64), a 2x2 max-pool, then a
    3x3 conv + ReLU (64 -> 128). The decoder upsamples with ``Up``
    (128 -> 64), concatenates the 64-channel encoder features (128
    channels), applies a conv block (128 -> 64) and finally a 1x1 conv to
    one logit channel. The "h/w" comments give spatial sizes for 64x64
    inputs.

    Args:
        res: if True, use residual conv blocks (``get_conv_block(..., res=True)``).
    """

    def __init__(self, res: bool = False):

        super().__init__()

        self.net_in = get_conv_block(n_in=1, n_out=64, res=res)  # h/w 64

        self.sample_down = nn.MaxPool2d(kernel_size=(2, 2), stride=2)  # h/w 32

        self.mid_conv = nn.Conv2d(
            in_channels=64,
            out_channels=128,
            kernel_size=(3, 3),
            stride=1,
            padding="same",
        )  # h/w 32
        nn.init.kaiming_normal_(self.mid_conv.weight, nonlinearity="relu")
        self.mid_act_conv = nn.ReLU()

        self.sample_up = Up(n_in=128, n_out=64)  # h/w 64

        self.net_out = get_conv_block(n_in=128, n_out=64, res=res)  # h/w 64

        self.net_final = nn.Conv2d(
            in_channels=64, out_channels=1, kernel_size=(1, 1), padding="same"
        )  # h/w 64

    def forward(self, input: ISBIBlockWithLabels) -> torch.Tensor:
        """Return per-pixel logits with the same spatial size as ``input.image``.

        Raises:
            ValueError: if height or width is odd. Max-pooling followed by 2x
                upsampling only restores even sizes, so the skip
                concatenation would otherwise fail.
        """
        x = input.image.float()
        height, width = x.shape[-2:]
        if height % 2 or width % 2:
            raise ValueError(
                f"Model2Layers needs height and width divisible by 2, got {(height, width)}"
            )

        x = self.net_in(x)

        y = self.sample_down(x)
        y = self.mid_act_conv(self.mid_conv(y))

        y = self.sample_up(y)

        # Skip connection: concatenate encoder and decoder features on channels.
        z = torch.cat((x, y), dim=1)

        z = self.net_out(z)
        z = self.net_final(z)

        return z


class Model3Layers(nn.Module):
    """Shallow U-Net with two pooling levels and skip connections.

    Encoder: conv block (1 -> 64), 2x2 max-pool, conv block (64 -> 128),
    2x2 max-pool, then a 3x3 conv + ReLU (128 -> 256).

    Decoder:
    - ``Up`` (256 -> 128), concatenate with the 128-channel encoder
      features, conv block (256 -> 128).
    - ``Up`` (128 -> 64), concatenate with the 64-channel encoder features,
      conv block (128 -> 64).
    - A final 1x1 conv to one logit channel.

    The "h/w" comments give spatial sizes for 64x64 inputs.

    Args:
        res: if True, use residual conv blocks (``get_conv_block(..., res=True)``).
            Defaults to False, consistent with ``Model2Layers``.
    """

    def __init__(self, res: bool = False):

        super().__init__()

        self.net_in = get_conv_block(n_in=1, n_out=64, res=res)  # h/w 64

        self.sample_down1 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)  # h/w 32

        self.net_down_layer1 = get_conv_block(
            n_in=64, n_out=128, res=res
        )  # h/w 32

        self.sample_down2 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)  # h/w 16

        self.mid_conv = nn.Conv2d(
            in_channels=128,
            out_channels=256,
            kernel_size=(3, 3),
            stride=1,
            padding="same",
        )  # h/w 16
        nn.init.kaiming_normal_(self.mid_conv.weight, nonlinearity="relu")
        self.mid_act_conv = nn.ReLU()

        self.sample_up1 = Up(n_in=256, n_out=128)  # h/w 32

        self.net_up_layer1 = get_conv_block(
            n_in=256, n_out=128, res=res
        )  # h/w 32

        self.sample_up2 = Up(n_in=128, n_out=64)  # h/w 64

        self.net_out = get_conv_block(n_in=128, n_out=64, res=res)  # h/w 64

        self.net_final = nn.Conv2d(
            in_channels=64, out_channels=1, kernel_size=(1, 1), padding="same"
        )  # h/w 64

    def forward(self, input: ISBIBlockWithLabels) -> torch.Tensor:
        """Return per-pixel logits with the same spatial size as ``input.image``.

        Raises:
            ValueError: if height or width is not divisible by 4. Two
                max-pools followed by two 2x upsamplings only restore such
                sizes, so the skip concatenations would otherwise fail.
        """
        x0 = input.image.float()
        height, width = x0.shape[-2:]
        if height % 4 or width % 4:
            raise ValueError(
                f"Model3Layers needs height and width divisible by 4, got {(height, width)}"
            )

        x0 = self.net_in(x0)

        x1 = self.sample_down1(x0)
        x1 = self.net_down_layer1(x1)

        z = self.sample_down2(x1)
        z = self.mid_act_conv(self.mid_conv(z))

        # Decoder level 1: upsample and fuse with the 128-channel encoder features.
        x4 = self.sample_up1(z)
        x4 = torch.cat((x1, x4), dim=1)
        x4 = self.net_up_layer1(x4)

        # Decoder level 2: upsample and fuse with the 64-channel encoder features.
        x5 = self.sample_up2(x4)
        x5 = torch.cat((x0, x5), dim=1)
        x5 = self.net_out(x5)

        y = self.net_final(x5)

        return y

In [None]:
# Shallow U-Net with two pooling levels and plain (non-residual) conv blocks.
model = Model3Layers(res=False)

In [None]:
# Fresh optimizer/loss/callbacks for the U-Net. The initial lr is presumably
# superseded by the LR finder and the OneCycleLR scheduler below — TODO confirm.
learning_rate = 0.1
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss = BCEISBI()
loss_callback = rnnm_learner.TrainLossCallback()

save_dir = Path(
    "./models"
)  # location used by learner.find_learning_rate to store the model before the search

# name_patterns are regexes over module names. For Model3Layers, ".*act.*"
# selects act1/act2, mid_act_conv and up_act, and ".*conv.*" selects
# conv1/conv2, mid_conv and up_conv (assuming the callbacks match against
# dotted module names up to max_depth_search). Other models may need
# different patterns.
activations_callback = rnnm_learner.TrainActivationsCallback(
    every_n=10, max_depth_search=4, name_patterns=(".*act.*",)
)
gradients_callback = rnnm_learner.TrainGradientsCallback(
    every_n=10, max_depth_search=4, name_patterns=(".*conv.*",)
)
parameters_callback = rnnm_learner.TrainParametersCallback(
    every_n=10, max_depth_search=4, name_patterns=(".*conv.*",)
)

callbacks = [
    loss_callback,
    activations_callback,
    gradients_callback,
    parameters_callback,
]

In [None]:
# New Learner wrapping the U-Net with its own optimizer and callbacks.
learner = rnnm_learner.Learner(
    model,
    optimizer,
    loss,
    callbacks=callbacks,
    save_dir=save_dir,
    device=device,
)

In [None]:
# Learning-rate range test for the U-Net (same settings as for the baseline;
# arguments presumably (start lr, end lr, number of steps) — TODO confirm).
lr_find_callback = rnnm_learner.LRFinderCallback(1e-5, 100, 100)

learner.find_learning_rate(
    dl_train, n_epochs=50, lr_find_callback=lr_find_callback
)

In [None]:
# Loss vs. learning rate on a linear y-axis limited to [0, 1].
lr_find_callback.plot(yscale="linear", ylim=(0, 1))

In [None]:
# Batches per epoch for the OneCycleLR schedule (same loaders as before).
steps_per_epoch = calc_steps_per_epoch(
    ds_train, dl_train, X0, n_repetitiopns, batch_size
)

In [None]:
# One-cycle schedule peaking at 3e-4, stepped after every batch. OneCycleLR
# raises if stepped more than epochs * steps_per_epoch times.
learning_rate = 3e-4
n_epochs = 50

scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=learning_rate,
    epochs=n_epochs,
    steps_per_epoch=steps_per_epoch,
)
scheduler_callback = rnnm_learner.EveryBatchSchedulerCallback(scheduler)
learner.update_callback(scheduler_callback)

In [None]:
# Train the U-Net with validation on dl_valid.
learner.fit(dl_train, n_epochs=n_epochs, dataloader_valid=dl_valid)

In [None]:
# Training/validation loss curves of the U-Net (log y-axis).
loss_callback.plot(yscale="log")

In [None]:
# Predict on the validation loader and keep the inputs for plotting.
y_logits, inputs = learner.predict(dl_valid, return_inputs=True)

In [None]:
# Display the returned validation inputs.
inputs

In [None]:
# Convert logits to probabilities. `.numpy()` needs a CPU tensor; `.cpu()` is
# a no-op if predict() already returned CPU tensors and otherwise moves them.
y_probs = y_logits.detach().sigmoid().cpu().numpy()
y_probs.shape

In [None]:
# Last validation losses of the baseline vs. the U-Net
# (`display` is the IPython/Jupyter display function).
losses_shallow = loss_callback.get_losses_valid()
display(losses_simple.tail(), losses_shallow.tail())

In [None]:
# Overlay both validation-loss curves on a log y-axis. Assumes
# get_losses_valid() returns frames with "iteration" and "loss_valid"
# columns — TODO confirm.
fig, ax = plt.subplots()

sns.lineplot(
    data=losses_simple, x="iteration", y="loss_valid", label="simple", ax=ax
)
sns.lineplot(
    data=losses_shallow, x="iteration", y="loss_valid", label="u-net", ax=ax
)

ax.legend(title="model")
ax.set(yscale="log")

plt.tight_layout()

In [None]:
# Show input, labels and U-Net prediction for every validation sample.
for ix in range(y_logits.shape[0]):
    show_inputs_labels_and_predictions(ix, inputs, y_probs)