# Sample images
Sample a few images from each dataset and visualize them.

Install additional dependencies

In [None]:
# !pip install -U "matplotlib>=3.5"

In [None]:
# # or use conda
# !conda install -U "matplotlib>=3.5"

In [None]:
import os
import random
from typing import Dict, List, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np

from bioimageloader import Config
from bioimageloader.utils import random_label_cmap
from bioimageloader._experimentals import ROOTS, load_all_datasets

# Colormap used below whenever a label mask is drawn
cmap = random_label_cmap()
# Globally turn off interpolation for images drawn by matplotlib
plt.rcParams.update({'image.interpolation': 'none'})

In [None]:
def sample_n_resize(
    dataset,
    n_sample,
    height,
    width=None,
    rand_seed=None,
) -> Tuple[List[int], List[Dict[str, np.ndarray]]]:
    """Randomly sample items from a dataset and resize them

    Parameters
    ----------
    dataset
        Object supporting ``len(dataset)`` and ``dataset[i]``. Each item is
        a dict with an ``'image'`` array and, when ``dataset.output`` is
        ``'mask'`` or ``'both'``, a ``'mask'`` array.
    n_sample : int
        Maximum number of items to draw. Fewer are returned when the
        dataset is smaller.
    height : int
        Output height of every resized array.
    width : int, optional
        Output width. If not set, it is computed separately for each image
        so that the aspect ratio of that image is kept.
    rand_seed : int, optional
        Seed for the global ``random`` module. Any value other than
        ``None`` is used, including ``0``.

    Returns
    -------
    indices : list of int
        Randomly chosen indices
    samples : list of dictionary
        dict['image'|'mask', np.ndarray]

    Notes
    -----
    The dicts returned by ``dataset[i]`` are updated in place with the
    resized arrays.
    """
    # ``is not None`` so that a seed of 0 is honoured as well
    if rand_seed is not None:
        random.seed(rand_seed)
    shuffled = list(range(len(dataset)))
    random.shuffle(shuffled)
    indices = shuffled[:n_sample]

    samples = []
    for ind in indices:
        data = dataset[ind]
        image = data['image']
        # Work out the target width per image; overwriting ``width`` here
        # would freeze the first image's aspect ratio for all later ones.
        if width is None:
            aspect = image.shape[1] / image.shape[0]
            out_width = max(1, int(aspect * height))
        else:
            out_width = width
        data['image'] = cv2.resize(image, (out_width, height))
        if dataset.output in ('mask', 'both'):
            mask = data['mask']
            # cv2 cannot resize boolean arrays
            if mask.dtype == bool:
                mask = mask.astype(np.float32)
            # Nearest neighbour keeps label values intact; the default
            # bilinear interpolation would blend ids at object borders.
            data['mask'] = cv2.resize(
                mask, (out_width, height), interpolation=cv2.INTER_NEAREST
            )
        samples.append(data)
    return indices, samples

## Load datasets
Point to root directories

Below, I have datasets under `../Data`

In [None]:
# `ROOTS` is imported from the `_experimentals` module.
# Evaluating it here shows its current value; the following cells show
# how to overwrite it manually.
ROOTS

In [None]:
# # If you already have config file
# cfg = Config('../configs/dummy_cfg.yml')
# ROOTS = dict((k, v['root_dir']) for k, v in cfg.items())

In [None]:
# ROOTS = {
#     # anno
#     'DSB2018'                 : '../Data/DSB2018',
#     'TNBC'                    : '../Data/TNBC_NucleiSegmentation',
#     'ComputationalPathology'  : '../Data/ComputationalPathology',
#     'S_BSST265'               : '../Data/BioStudies',
#     'MurphyLab'               : '../Data/2009_ISBI_2DNuclei_code_data',
#     'BBBC006'                 : '../Data/bbbc/006',
#     'BBBC007'                 : '../Data/bbbc/007',
#     'BBBC008'                 : '../Data/bbbc/008',
#     'BBBC018'                 : '../Data/bbbc/018',
#     'BBBC020'                 : '../Data/bbbc/020',
#     'BBBC039'                 : '../Data/bbbc/039',
#     # partial anno
#     'DigitalPathology'        : '../Data/DigitalPathology',
#     'UCSB'                    : '../Data/UCSB_BioSegmentation',
#     'BBBC002'                 : '../Data/bbbc/002',
#     # no anno
#     'BBBC013'                 : '../Data/bbbc/013',
#     'BBBC014'                 : '../Data/bbbc/014',
#     'BBBC015'                 : '../Data/bbbc/015',
#     'BBBC016'                 : '../Data/bbbc/016',
#     'BBBC026'                 : '../Data/bbbc/026',
#     'BBBC041'                 : '../Data/bbbc/041',
#     'FRUNet'                  : '../Data/FRU_processing',
#     'BBBC021'                 : '../Data/bbbc/021',
# }

In [None]:
# Load the datasets with the helper from `_experimentals`.
# NOTE(review): `ROOTS` is not passed in here; presumably the helper reads
# it from its own module. Confirm that a `ROOTS` rebound in this notebook
# is actually picked up.
datasets = load_all_datasets()
datasets  # display what was loaded

## Sample

In [None]:
# Seed passed to `sample_n_resize` as `rand_seed`
SEED = 42
# number of samples for each dataset
NUM_SAMPLE = 2
# Target height passed to `sample_n_resize`
HEIGHT = 256  # resize

In [None]:
# One placeholder axes per dataset; the grid slot of each one hosts a
# subfigure so that every dataset gets its own title.
# ``squeeze=False`` keeps ``big_axes`` an array even with a single dataset.
fig, big_axes = plt.subplots(len(datasets), 1, squeeze=False,
                             constrained_layout=True,
                             figsize=(10, 4*len(datasets)), dpi=150)
gridspec = big_axes[0, 0].get_subplotspec().get_gridspec()

# Every sample takes two columns: the image and, if present, its mask
n_cols = 2 * NUM_SAMPLE

for i, dset in enumerate(datasets):
    indices, samples = sample_n_resize(dset, NUM_SAMPLE, HEIGHT, rand_seed=SEED)  # sample
    subfig = fig.add_subfigure(gridspec[i])
    subfig.suptitle(dset.acronym)
    axes = subfig.subplots(1, n_cols, squeeze=False)[0]
    for j, (ind, data) in enumerate(zip(indices, samples)):
        col = 2 * j  # first of the two columns belonging to this sample
        axes[col].imshow(data['image'])
        axes[col].set_title(f'{ind}/{len(dset) - 1}')
        if 'mask' in data:
            axes[col + 1].imshow(data['mask'], cmap=cmap)
    # Hide ticks and frames, including those of unused columns
    for _ax in axes:
        _ax.axis('off')

# The placeholder axes only provided the layout; hide them
for _ax in big_axes.ravel():
    _ax.axis('off')

## Save Samples

In [None]:
# Seed passed to `sample_n_resize` when sampling the images to save
SEED = 42
# NUM_SAMPLE and HEIGHT keep the values set in the "Sample" section;
# uncomment to override them here
# NUM_SAMPLE = 2
# HEIGHT = 256
# Directory the PNG files are written to
DIR = '../docs/_static/sample_images'

# Create the output directory if it does not exist yet
os.makedirs(DIR, exist_ok=True)

In [None]:
# Write every sampled image, and its annotation when the dataset outputs
# one, to DIR as a PNG without axes or padding.
for i, dset in enumerate(datasets):
    indices, samples = sample_n_resize(dset, NUM_SAMPLE, HEIGHT, rand_seed=SEED)  # sample
    for ind, data in zip(indices, samples):
        # --- image ---
        image_fig, image_ax = plt.subplots()
        image_ax.imshow(data['image'])
        image_ax.axis('off')
        image_fig.savefig(
            os.path.join(DIR, f'{dset.acronym}_{i:d}_image_{ind:04d}.png'),
            bbox_inches='tight',
            pad_inches=0,
        )
        plt.close(image_fig)  # release the figure before opening the next

        # --- annotation (only for datasets that output masks) ---
        if dset.output not in ('mask', 'both'):
            continue
        annotation = data['mask']
        if annotation.dtype == bool:
            annotation = annotation.astype(np.float32)
        mask_fig, mask_ax = plt.subplots()
        mask_ax.imshow(annotation, cmap=cmap)
        mask_ax.axis('off')
        mask_fig.savefig(
            os.path.join(DIR, f'{dset.acronym}_{i:d}_annotation_{ind:04d}.png'),
            bbox_inches='tight',
            pad_inches=0,
        )
        plt.close(mask_fig)