In [1]:
from normalization import normalize_staining
from glob import glob
from os import makedirs
from os.path import join
import os
import numpy as np
from tqdm import tqdm
from skimage.util import montage, view_as_windows
from skimage import io
from random import shuffle

In [2]:
def mosaic_batch(batch_list):
    """Read every image in ``batch_list`` and tile them into one montage.

    Parameters
    ----------
    batch_list : iterable of str
        Paths of the images to read with ``skimage.io.imread``.

    Returns
    -------
    The array produced by ``skimage.util.montage`` for the stacked images,
    with the last axis treated as the channel axis.

    Raises
    ------
    ValueError
        Raised by ``np.stack`` when ``batch_list`` is empty or the images do
        not all share the same shape.
    """
    # Stack along a new leading axis -> (K, M, N[, C]).
    patches = np.stack([io.imread(path) for path in batch_list])
    return montage(patches, channel_axis=-1)

In [3]:
def normalize_batch(batch_list, patch_size):
    """Stain-normalize a batch of patches as one image, then cut it back up.

    The images in ``batch_list`` are stitched together with ``mosaic_batch``,
    the stitched image is passed through ``normalize_staining``, and the
    result is split into windows of ``(patch_size, patch_size, 3)``.

    Parameters
    ----------
    batch_list : iterable of str
        Image paths forwarded to ``mosaic_batch``.
    patch_size : int
        Height and width of each window cut from the normalized image.

    Returns
    -------
    The windowed view returned by ``skimage.util.view_as_windows``.
    """
    window = (patch_size, patch_size, 3)
    normalized = normalize_staining(mosaic_batch(batch_list))
    # Step equals the window shape, so the windows do not overlap.
    return view_as_windows(normalized, window, window)

In [4]:
# --- Configuration -----------------------------------------------------------
batch_size = 256  # number of image paths grouped into one normalization batch
patch_size = 256  # height/width of the windows cut from the normalized montage
input_dir = 'data_bags_raw'
output_dir = 'data_bags_normalized'
makedirs(output_dir, exist_ok=True)

# Build the pattern from input_dir so that changing input_dir above actually
# changes which files are processed (previously the path was hard-coded a
# second time inside the glob pattern and input_dir was never used).
# Matches: <input_dir>/TMA/<dir>/<dir>/<name>.png
# glob() returns files in arbitrary order; sort for a deterministic file list.
input_files = sorted(glob(join(input_dir, 'TMA', '*', '*', '*.png')))

In [5]:
### Batch-wise stain normalization of all patches
# Shuffle a *copy* of the file list with a fixed seed. Each batch is normalized
# as a single stitched image, so which files share a batch influences the
# output; a seeded shuffle keeps the mixing but makes re-runs reproducible,
# and leaving input_files untouched keeps this cell idempotent.
shuffle_seed = 0
rng = np.random.default_rng(shuffle_seed)
shuffled_files = [input_files[k] for k in rng.permutation(len(input_files))]
batches = [shuffled_files[i:i+batch_size] for i in range(0, len(shuffled_files), batch_size)]

for batch in tqdm(batches):
    imgs = normalize_batch(batch, patch_size)
    # Flatten the leading window-grid axes into one axis of patches,
    # keeping the trailing (height, width, channels) of each patch.
    imgs = imgs.reshape((-1,) + imgs.shape[-3:])

    # NOTE(review): patch i is assumed to correspond to batch[i], i.e. the
    # montage tiling order matches the flattened window order -- confirm.
    for i, img_path in enumerate(batch):
        # The last three path components are <class>/<slide>/<image file>.
        class_name, slide_name, img_name = img_path.split(os.sep)[-3:]
        slide_dir = join(output_dir, class_name, slide_name)
        makedirs(slide_dir, exist_ok=True)
        # splitext strips only the final extension, so names that contain
        # extra dots (e.g. "a.b.png") are no longer truncated to "a.png",
        # which could make distinct inputs overwrite each other.
        stem = os.path.splitext(img_name)[0]
        io.imsave(join(slide_dir, stem + '.png'), imgs[i], check_contrast=False)

100%|██████████| 353/353 [2:32:16<00:00, 25.88s/it]  
