## Segmenting blood vessels in optical images of the retina using U-Net.

**Model is based on:**  
Ronneberger et al., "[U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597)," MICCAI, 2015

**Dataset is from:**  
Hoover et al., "[Locating Blood Vessels in Retinal Images by Piecewise Threshold Probing of a Matched Filter Response](http://cecas.clemson.edu/~ahoover/stare/)," IEEE Transactions on Medical Imaging, 19(3):203–210, 2000


In [10]:
import pathlib
import imageio
import numpy as np
import torch
import torchvision
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

In [11]:
class DatasetSTARE(torch.utils.data.Dataset):
    """Retinal images and vessel masks from the STARE dataset.

    Images are read from ``<database>/images/*.ppm`` and masks from
    ``<database>/masks1/*.ppm``. Both file lists are sorted by name and paired
    by position, and a fixed split over indices 0-19 selects the subset.
    """

    # Fixed split over the sorted file indices 0-19: two images each for
    # validation and test, the remaining sixteen for training.
    _SPLITS = {
        'train': [i for i in range(20) if i not in (6, 7, 11, 12)],
        'val': [6, 11],
        'test': [7, 12],
    }

    def __init__(self, database, subset, nimages, transform=None, *,
                 min_mask_percent=5, max_tries=None):
        """Constructor.

        database: path (``pathlib.Path`` or ``str``) to the dataset directory
            that contains the ``images`` and ``masks1`` sub-directories.
        subset: can be 'train', 'val', or 'test'.
        nimages: dataset length reported by ``__len__``. Items cycle through the
            loaded images, so this is the number of (randomly cropped/augmented)
            samples drawn per epoch.
        transform: transforms used for data augmentation. It is called once on
            the image and mask concatenated along the channel axis, so a single
            transform call is applied to both.
        min_mask_percent: when ``transform`` is given, a transformed sample is
            rejected and re-drawn while ``100 * mask.sum() / mask.numel()`` is
            below this value (default 5).
        max_tries: maximum number of transform draws per item before raising
            ``RuntimeError``. ``None`` (default) retries indefinitely.

        Raises:
            ValueError: if ``subset`` is unknown, or if the selected images and
                masks differ in number (they could not be paired correctly).
            FileNotFoundError: if no images were found for the subset.
        """
        super().__init__()
        if subset not in self._SPLITS:
            raise ValueError(
                f"subset must be one of {sorted(self._SPLITS)}, got {subset!r}")
        indices = set(self._SPLITS[subset])
        database = pathlib.Path(database)  # also accept plain string paths
        fimages = sorted((database / 'images').glob('*.ppm'))
        fmasks = sorted((database / 'masks1').glob('*.ppm'))
        self._images = np.array([imageio.imread(x) for i, x in enumerate(fimages) if i in indices])
        self._masks = np.array([imageio.imread(x) for i, x in enumerate(fmasks) if i in indices])
        # Fail early: an empty dataset would otherwise surface later as a
        # ZeroDivisionError in __getitem__ (index % 0).
        if len(self._images) == 0:
            raise FileNotFoundError(
                f"no '*.ppm' images found for subset {subset!r} under {database / 'images'}")
        # Images and masks are paired purely by position, so the counts must agree.
        if len(self._images) != len(self._masks):
            raise ValueError(
                f"found {len(self._images)} images but {len(self._masks)} masks "
                f"for subset {subset!r}")
        self._transform = transform
        self._nimages = nimages
        self._min_mask_percent = min_mask_percent
        self._max_tries = max_tries

    def __len__(self):
        """Dataset size (the ``nimages`` passed to the constructor)."""
        return self._nimages

    def __getitem__(self, index):
        """Image and its corresponding mask at a given index.

        The index wraps around the loaded images. The image is transposed from
        (height, width, channels) to (channels, height, width). Both arrays are
        converted to float32 and divided by 255, and the mask gains a leading
        channel axis. When a transform is set, it is re-drawn until the mask
        coverage reaches ``min_mask_percent``.

        Raises:
            RuntimeError: if ``max_tries`` draws all fail the coverage check.
        """
        image = self._images[index % len(self._images)]
        mask = self._masks[index % len(self._masks)]
        image = image.transpose([2, 0, 1])  # channels, height, width
        image = image.astype(np.float32) / 255  # 8-bit integer -> 32-bit float
        mask = mask.astype(np.float32) / 255
        # .copy() yields C-contiguous arrays (the transpose above is only a view).
        image = torch.as_tensor(image.copy())
        mask = torch.as_tensor(mask.copy()).unsqueeze(0)
        if self._transform is None:
            return image, mask

        nchannels = len(image)
        attempts = 0
        while self._max_tries is None or attempts < self._max_tries:
            attempts += 1
            # Transform image and mask together so they stay spatially aligned.
            combined = self._transform(torch.cat([image, mask], dim=0))
            image2, mask2 = combined[:nchannels], combined[nchannels:]
            # Skip samples where the mask occupies too little of the image.
            if not (100 * mask2.sum() / mask2.numel() < self._min_mask_percent):
                return image2, mask2
        raise RuntimeError(
            f"no transformed sample reached {self._min_mask_percent}% mask "
            f"coverage after {self._max_tries} tries")