# Preparing a Dataset to Build a Segmentation Model

In the previous notebook, we created a simple segmentation model using a prepared dataset. However, that notebook only included the resulting data, not the source code used to prepare it. In this notebook, I include the code I used to prepare the data, with one major change: [histogram equalization](https://en.wikipedia.org/wiki/Histogram_equalization).

[As the previous notebook shows](https://www.kaggle.com/code/purplejester/fast-ai-01-basic-segmentation-model-training), some of the training samples look rather strange. In these samples, a white blob covers a huge area of the image, and no distinguishable parts of the X-ray scan are visible. The "raw" samples also vary a lot in brightness: some are bright, while others are dim. To solve this issue, we need to normalize the samples so that each of them has a similar distribution of pixel values.

The notebook creates a new dataset that (hopefully) works better for our model.

# Imports

In [None]:
import os
from dataclasses import dataclass, field, asdict
from pprint import pprint as pp
from pathlib import Path
from typing import List

import altair as alt
import cv2 as cv
import pandas as pd
import PIL.Image
from skimage import exposure

from fastai.vision.all import *
from fastcore.all import *
from fastprogress import progress_bar
from fast_ai_utils import rle_decode, get_dataset_size, COMBINED_MASK_CODES, DATA_ROOT, OUTPUT_DIR, TRAIN_CSV

In [None]:
get_dataset_size()  # hard-coded from the number of samples in the "raw" dataset

In [None]:
png_files = get_image_files(DATA_ROOT)
png_files[np.random.randint(0, get_dataset_size(), size=5)]

In [None]:
df_labels = pd.read_csv(TRAIN_CSV)
df_labels.sample(5)

Before we start preparing the dataset, let's briefly check its properties.

In [None]:
(
    alt.Chart(
        data=(
            # notna(): True means the row HAS a mask (isna() would invert the labels)
            df_labels.segmentation.notna().value_counts()
            .rename_axis("Has Mask?").reset_index(name="Count")
        )
    )
    .mark_bar()
    .encode(y="Has Mask?:O", x="Count", color="Has Mask?")
)

In [None]:
(
    alt.Chart(
        data=(
            # rename_axis + reset_index(name=...) yields the same columns in pandas 1.x and 2.x
            df_labels["class"].value_counts(dropna=False)
            .rename_axis("Class").reset_index(name="Count")
        )
    )
    .mark_bar()
    .encode(y="Class:O", x="Count", color="Class")
)

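First of all, a significant fraction of the images has no segmentation mask. Second, each class appears the same number of times in the table. However, these counts include the rows without a mask. They only tell us that every class is listed equally often. To judge class imbalance, we would need to count only the rows that actually have a segmentation.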
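
Counting classes only among rows with a segmentation gives a better picture of class balance. The chart above counts every row, including those without a mask. Below is a minimal sketch on a toy table shaped like `train.csv` (columns `id`, `class`, `segmentation`). The ids, class names and RLE strings are made up for illustration.

```python
import pandas as pd

# Toy stand-in for train.csv: one row per (slice id, class); None means "no mask".
df = pd.DataFrame({
    "id": ["s1", "s1", "s1", "s2", "s2", "s2"],
    "class": ["class_a", "class_b", "class_c"] * 2,
    "segmentation": ["1 5", None, "10 3", "2 4", None, None],
})

# Counting every row gives equal totals by construction...
all_rows = df["class"].value_counts()

# ...while counting only rows with a mask shows how often each class is present.
with_mask = (
    df.loc[df["segmentation"].notna(), "class"]
    .value_counts()
    .reindex(all_rows.index, fill_value=0)  # keep classes that never have a mask
)
print(with_mask.to_dict())
```

On the real data, the same `notna()` filter applied to `df_labels` would show whether some classes are present much less often than others.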

In [None]:
UNIQUE_CLASSES = df_labels["class"].unique().tolist()
UNIQUE_CLASSES

Also, we have only three segmentation classes. But could some of the segmentation masks overlap? We'll come back to this when building the masks. First, let's read the image properties and prepare the samples.
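
Once the masks are decoded into boolean arrays, one simple way to check for overlaps is to count how many classes claim each pixel. Here is a minimal sketch on toy 3×3 masks; the class names and mask values are hypothetical.

```python
import numpy as np

# Hypothetical per-class binary masks for one slice (True = pixel belongs to class).
masks = {
    "class_a": np.array([[1, 1, 0], [0, 0, 0], [0, 0, 0]], dtype=bool),
    "class_b": np.array([[0, 1, 1], [0, 0, 0], [0, 0, 0]], dtype=bool),
    "class_c": np.array([[0, 0, 0], [0, 0, 0], [1, 1, 1]], dtype=bool),
}

# Stack into (n_classes, H, W) and count, per pixel, how many classes claim it.
stacked = np.stack(list(masks.values()))
labels_per_pixel = stacked.sum(axis=0)

# Any pixel claimed by more than one class is an overlap.
overlap_pixels = int((labels_per_pixel > 1).sum())
print(overlap_pixels)  # 1
```

If overlaps exist, a single-channel mask holding one integer class label per pixel cannot represent them. This is one reason to give each class its own bit.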

# Reading image properties

In the dataset we use, each file path includes additional information about the sample it points to. We can use this information when building a model, for example, to come up with a reliable cross-validation scheme. Therefore, we parse this information from the file paths and store it in a convenient object called `ImageProperties`.

In [None]:
@dataclass
class ImageProperties:
    slice_id: int
    h: int
    w: int
    px_h: float
    px_w: float
    rel_path: str
    case: int = field(init=False)
    day: int = field(init=False)
    sample_id: str = field(init=False)
        
    def __post_init__(self):
        case_str, day_str, *_ = self.rel_path.split("/")
        self.case = int(case_str.replace("case", ""))
        self.day = int(day_str.split("_")[-1].replace("day", ""))
        self.sample_id = f"case{self.case}_day{self.day}_slice_{self.slice_id:04d}"

def get_image_properties(path: Path) -> ImageProperties:
    rel_path = path.relpath(DATA_ROOT/"train")
    _, *parts = path.stem.split("_")
    slice_id, h, w = map(int, parts[:3])
    px_h, px_w = map(float, parts[3:])
    return ImageProperties(slice_id, h, w, px_h, px_w, str(rel_path))
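
The parsing logic above can be checked on a single path without touching the dataset. Below is a small sketch that repeats the same steps with the standard library only. The relative path is hypothetical, and its layout is inferred from what `ImageProperties` and `get_image_properties` expect.

```python
from pathlib import PurePosixPath

# Hypothetical relative path; the layout is inferred from the parsing code above.
rel_path = PurePosixPath("case123/case123_day20/scans/slice_0065_266_266_1.50_1.50.png")

# Directory parts carry the case and day numbers.
case_str, day_str, *_ = rel_path.parts
case = int(case_str.replace("case", ""))
day = int(day_str.split("_")[-1].replace("day", ""))

# The file stem carries the slice id, image size, and pixel spacing.
_, *parts = rel_path.stem.split("_")
slice_id, h, w = map(int, parts[:3])
px_h, px_w = map(float, parts[3:])

sample_id = f"case{case}_day{day}_slice_{slice_id:04d}"
print(sample_id, (h, w), (px_h, px_w))
# case123_day20_slice_0065 (266, 266) (1.5, 1.5)
```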

In [None]:
%%time
cached_file = OUTPUT_DIR/"image_props.csv"

# uncomment the following line to force re-reading properties (used for debugging)
# cached_file.unlink()

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

if cached_file.exists():
    print(f"Reading image properties from cache: {cached_file}")
    df_props = pd.read_csv(cached_file)
    props = [
        ImageProperties(**{
            k: v for 
            k, v in t._asdict().items() 
            if k not in ("case", "day", "sample_id")
        })
        for t in df_props.itertuples(index=False)
    ]
    
else:
    print("Extracting metadata from image names...", end="")
    props = list(map(get_image_properties, png_files))
    df_props = pd.DataFrame(props)
    df_props.to_csv(cached_file, index=False)
    print("done!")

> **Note:** the trivial caching logic above isn't strictly required here, but it helped to debug the snippet. Including such conditional processing steps in a notebook can therefore help you iterate more quickly. Also, we store the retrieved properties in the working directory so that they are included in this notebook's output and can be reused later.
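
If this pattern comes up often, it could be wrapped in a small helper. Below is a minimal sketch; `load_or_compute` and its `refresh` flag are hypothetical names, not part of the notebook or any library. `refresh=True` plays the role of the commented-out `cached_file.unlink()` line.

```python
import tempfile
from pathlib import Path
from typing import Callable

import pandas as pd


def load_or_compute(cache_path: Path, compute: Callable[[], pd.DataFrame],
                    refresh: bool = False) -> pd.DataFrame:
    """Return a DataFrame cached at `cache_path`, computing and saving it if missing."""
    if cache_path.exists() and not refresh:
        return pd.read_csv(cache_path)
    df = compute()
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(cache_path, index=False)
    return df


# Usage: the second call is served from the cache, so `compute_props` runs only once.
calls = []

def compute_props() -> pd.DataFrame:
    calls.append(1)
    return pd.DataFrame({"slice_id": [1, 2], "h": [266, 266]})

with tempfile.TemporaryDirectory() as tmp:
    cache = Path(tmp) / "working" / "image_props.csv"
    first = load_or_compute(cache, compute_props)
    second = load_or_compute(cache, compute_props)

print(len(calls), first.equals(second))  # 1 True
```

CSV keeps the cache human-readable. The trade-off is that only plain column values survive the round trip, which is why the notebook rebuilds `ImageProperties` objects from the rows.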

In [None]:
df_props = pd.read_csv(cached_file)
df_props.head()

In [None]:
MASK_CODES = dict(zip(UNIQUE_CLASSES, (0b001, 0b010, 0b100)))
MASK_CODES
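
Each class gets its own bit, so overlapping masks can share one `uint8` image. They are combined with bitwise OR, and each class is recovered with bitwise AND. A sketch with toy masks and hypothetical class names follows. With three bits there are 2³ = 8 possible pixel values, which presumably is what `COMBINED_MASK_CODES` enumerates; that is an assumption, since its definition isn't shown here.

```python
import numpy as np

# Hypothetical class names; each gets its own bit, mirroring MASK_CODES above.
codes = {"class_a": 0b001, "class_b": 0b010, "class_c": 0b100}

# Three toy masks written into one uint8 array with bitwise OR.
combined = np.zeros((2, 3), dtype=np.uint8)
combined[0, :2] |= codes["class_a"]  # pixels (0, 0) and (0, 1)
combined[0, 1:] |= codes["class_b"]  # pixels (0, 1) and (0, 2); (0, 1) overlaps
combined[1, :] |= codes["class_c"]

# Each class mask is recovered with bitwise AND, even where classes overlap.
recovered = {name: (combined & bit) > 0 for name, bit in codes.items()}

print(combined)
# [[1 3 2]
#  [4 4 4]]
```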

# Equalizing samples and building segmentation masks

Now that we have the metadata, we're ready to process the images. For this purpose, we use the `scikit-image` package, which includes many useful utilities for this kind of work. In essence, histogram equalization should make the samples more similar to each other, without extremely bright or dim outliers.
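
Why does equalization remove brightness differences? Global histogram equalization maps every pixel to the cumulative distribution function (CDF) value of its intensity. The output therefore depends only on how intensities rank within the image. Below is a simplified NumPy sketch of that idea. It is not scikit-image's exact implementation, and the "dim" and "bright" images are synthetic.

```python
import numpy as np


def equalize_hist_sketch(image: np.ndarray) -> np.ndarray:
    """Global histogram equalization: replace each pixel with its intensity's CDF value."""
    _, inverse, counts = np.unique(image.ravel(), return_inverse=True, return_counts=True)
    cdf = np.cumsum(counts) / image.size     # fraction of pixels <= each distinct value
    return cdf[inverse].reshape(image.shape)  # values in (0, 1]


rng = np.random.default_rng(0)
dim = rng.integers(0, 64, size=(32, 32))  # synthetic dim scan with values 0..63
bright = dim * 4 + 3                       # same content, brighter and stretched

eq_dim = equalize_hist_sketch(dim)
eq_bright = equalize_hist_sketch(bright)
print(np.allclose(eq_dim, eq_bright))  # True
```

Any strictly increasing brightness change leaves the ranks, and therefore the output, unchanged. The notebook's `adaptive=True` path calls `exposure.equalize_adapthist`, which implements CLAHE (Contrast Limited Adaptive Histogram Equalization). CLAHE applies the same idea to local tiles and clips the histogram, so its output is not purely rank-based like this global version.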

In [None]:
image_sizes = df_props.set_index("sample_id")[["h", "w"]].to_dict("index")

In [None]:
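def equalize(data: np.ndarray, adaptive: bool) -> np.ndarray:
    # Min-max scale the 16-bit scan to 0..255, guarding against constant slices
    data = data.astype(np.float64) - np.min(data)
    peak = np.max(data)
    if peak > 0:
        data = data / peak
    data = (data * 255).astype(np.uint8)
    # CLAHE (local, contrast-limited) or global histogram equalization;
    # both return floats in [0, 1]
    method = (
        exposure.equalize_adapthist
        if adaptive
        else exposure.equalize_hist
    )
    return method(data)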

In [None]:
%%time 

images_dir = OUTPUT_DIR/"images"

if not images_dir.exists():
    images_dir.mkdir()

for prop in progress_bar(props):
    src_path = DATA_ROOT/"train"/prop.rel_path
    dst_path = images_dir/f"{prop.sample_id}.png"
    # imread expects an IMREAD_* flag; IMREAD_ANYDEPTH keeps the PNG's 16-bit depth
    arr = cv.imread(str(src_path), cv.IMREAD_ANYDEPTH)
    arr = equalize(arr, adaptive=True)
    arr = (arr * 255).astype(np.uint8)  # equalize() returns floats in [0, 1]
    pil_image = PIL.Image.fromarray(arr)
    pil_image.save(dst_path)

In [None]:
%%time

empty_masks = set()

masks_dir = OUTPUT_DIR/"masks"

if not masks_dir.exists():
    masks_dir.mkdir()

for sample_id, df_group in progress_bar(df_labels.groupby("id")):
    size = image_sizes[sample_id]
    shape = size["h"], size["w"]
    decoded_mask = np.zeros(shape, dtype=np.uint8)

    for _, row in df_group.iterrows():
        rle_mask = row.segmentation
        if isinstance(rle_mask, str):
            decoded_mask |= rle_decode(rle_mask, shape, value=MASK_CODES[row["class"]])
    
    if not decoded_mask.any():
        empty_masks.add(sample_id)
    
    mask_image = PIL.Image.fromarray(decoded_mask.T)
    mask_path = masks_dir/f"{sample_id}.png"
    mask_image.save(mask_path)
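
The RLE decoding itself lives in `fast_ai_utils.rle_decode`, which isn't shown. As a rough sketch, assume the common Kaggle format of space-separated, 1-based "start length" pairs. Under that assumption, decoding amounts to filling runs of a flat array with the class code. `rle_decode_sketch` is a hypothetical stand-in, and its row-major pixel order is an assumption.

```python
import numpy as np


def rle_decode_sketch(rle: str, shape: tuple, value: int = 1) -> np.ndarray:
    """Decode 1-based "start length" run pairs into a uint8 mask filled with `value`."""
    tokens = [int(t) for t in rle.split()]
    starts = np.array(tokens[0::2]) - 1  # RLE starts are 1-based
    lengths = np.array(tokens[1::2])
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for start, length in zip(starts, lengths):
        flat[start:start + length] = value
    return flat.reshape(shape)  # row-major (C order) reshape


mask = rle_decode_sketch("2 3 7 2", (3, 3), value=0b010)
print(mask)
# [[0 2 2]
#  [2 0 0]
#  [2 2 0]]
```

If the runs instead count pixels column by column, `flat.reshape(shape, order="F")` would be the matching decode. The notebook's `decoded_mask.T` may account for such an ordering difference, but that depends on how `rle_decode` is written.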

# Sanity check

The dataset is ready, but it's always a good idea to check that everything looks as expected. For this purpose, we quickly fine-tune a model on the new data (with a random train/validation split) and visualize the samples and predictions one more time.
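
A cheap complementary check, before any training, is that every image has a mask with the same file stem. This is exactly what `get_y` assumes. Below is a minimal sketch, assuming the `images/` and `masks/` layout produced above. `unpaired_files` is a hypothetical helper, and the demo runs on a throwaway directory.

```python
import tempfile
from pathlib import Path
from typing import Set, Tuple


def unpaired_files(output_dir: Path) -> Tuple[Set[str], Set[str]]:
    """Return (images without a mask, masks without an image), matched by file stem."""
    image_stems = {p.stem for p in (output_dir / "images").glob("*.png")}
    mask_stems = {p.stem for p in (output_dir / "masks").glob("*.png")}
    return image_stems - mask_stems, mask_stems - image_stems


# Demo on a temporary directory with one missing and one orphaned mask.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for sub, stems in {"images": ["a", "b"], "masks": ["a", "c"]}.items():
        (root / sub).mkdir()
        for stem in stems:
            (root / sub / f"{stem}.png").touch()
    missing_masks, orphan_masks = unpaired_files(root)

print(missing_masks, orphan_masks)  # {'b'} {'c'}
```

On the prepared data, `unpaired_files(OUTPUT_DIR)` should return two empty sets.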

In [None]:
def get_items(source_dir: Path):
    return get_image_files(source_dir.joinpath("images"))[:get_dataset_size()]

def get_y(fn: Path):
    return fn.parent.parent.joinpath("masks").joinpath(f"{fn.stem}.png")

seg = DataBlock(blocks=(ImageBlock, MaskBlock(COMBINED_MASK_CODES)),
                get_items=get_items,
                get_y=get_y,
                splitter=RandomSplitter(),
                item_tfms=[Resize(192, method="squash")])

dls = seg.dataloaders(OUTPUT_DIR, bs=32)

learn = unet_learner(dls, resnet18, metrics=DiceMulti)

learn.fine_tune(3)

interp = SegmentationInterpretation.from_learner(learn)

interp.plot_top_losses(4, largest=True)

Much better this time! Each sample is now a clearly visible X-ray scan instead of the white blobs we had before.

# What's Next?

Now we can use this dataset as a starting point for building new models, without worrying about poor-quality images. Let's see what can be achieved. Stay tuned!