In [None]:
import os
import cv2
import skimage.io
from tqdm.notebook import tqdm
import zipfile
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# --- Tiling configuration -------------------------------------------------
sz = 156                 # side length, in pixels, of each square tile
N = 64                   # number of tiles kept per image
TIFF = 1                 # index of the page read from each multi-page TIFF
RESIZE = None            # optional divisor applied to image width/height before tiling
DARK_THRESHOLD = None    # optional value passed to tile() as `dark`
SHOW = None              # if set, only the first SHOW images are processed and plotted

# --- Paths ------------------------------------------------------------------
TRAIN = './data/train_images/'
TILES = './data/train_images_tiles_q1_156_64/'
MASKS = './data/train_label_masks/'

# makedirs(exist_ok=True) also creates missing parent directories and does not
# fail when the directory already exists, so re-running this cell is safe.
os.makedirs(TILES, exist_ok=True)

In [None]:
def tile(img, dark=None, tile_size=None, n_tiles=None):
    """Split an image into square tiles and keep the ones with the lowest pixel sum.

    The image is padded with 255 on both spatial axes so that its height and
    width become multiples of the tile size, cut into a grid of tiles, and the
    tiles are ranked by the sum of their pixel values (ascending). If the grid
    yields fewer tiles than requested, all-255 tiles are appended first.

    Parameters
    ----------
    img : np.ndarray
        Array of shape (height, width, 3); the reshape below requires exactly
        three channels.
    dark : number, optional
        When truthy, pixel values below this threshold are counted as 255
        while ranking, so very dark pixels do not make a tile rank higher.
        The returned tiles still contain the original pixel values.
    tile_size : int, optional
        Side length of each tile. Defaults to the notebook-level ``sz``.
    n_tiles : int, optional
        Number of tiles to return. Defaults to the notebook-level ``N``.

    Returns
    -------
    list of dict
        ``n_tiles`` entries ``{'img': tile_array, 'idx': rank}`` where ``rank``
        is the position of the tile in the ascending pixel-sum ordering.
    """
    # Fall back to the notebook configuration only when no override is given.
    tile_size = sz if tile_size is None else tile_size
    n_tiles = N if n_tiles is None else n_tiles

    height, width = img.shape[:2]
    pad0 = (tile_size - height % tile_size) % tile_size
    pad1 = (tile_size - width % tile_size) % tile_size
    # Split the padding between both sides of each spatial axis.
    img = np.pad(img,
                 [[pad0 // 2, pad0 - pad0 // 2], [pad1 // 2, pad1 - pad1 // 2], [0, 0]],
                 constant_values=255)

    # (rows, tile_size, cols, tile_size, 3) -> (rows * cols, tile_size, tile_size, 3)
    tiles = img.reshape(img.shape[0] // tile_size, tile_size,
                        img.shape[1] // tile_size, tile_size, 3)
    tiles = tiles.transpose(0, 2, 1, 3, 4).reshape(-1, tile_size, tile_size, 3)

    if len(tiles) < n_tiles:
        tiles = np.pad(tiles, [[0, n_tiles - len(tiles)], [0, 0], [0, 0], [0, 0]],
                       constant_values=255)

    # Rank on a copy with sub-threshold pixels replaced by 255; select from the originals.
    ranked = np.where(tiles < dark, 255, tiles) if dark else tiles
    idxs = np.argsort(ranked.reshape(ranked.shape[0], -1).sum(-1))[:n_tiles]

    return [{'img': tile_img, 'idx': i} for i, tile_img in enumerate(tiles[idxs])]

In [None]:
# Per-tile channel means of x and x**2 (pixels scaled to [0, 1]); used afterwards
# to compute the dataset mean and standard deviation.
x_tot, x2_tot = [], []

# Each file name in MASKS minus its last 10 characters is used as the image id
# (presumably a '_mask.tiff' suffix -- TODO confirm against the data directory).
# Sorted so the processing order and the SHOW subset do not depend on the
# arbitrary ordering returned by os.listdir.
names = sorted(name[:-10] for name in os.listdir(MASKS))
if SHOW:
    names = names[:SHOW]
    # squeeze=False keeps `axes` two-dimensional even when there is a single
    # row or column, so axes[i_name, idx] below is always valid.
    fig, axes = plt.subplots(figsize=(18, SHOW), nrows=len(names), ncols=N,
                             squeeze=False)

for i_name, name in tqdm(enumerate(names), total=len(names)):
    slide = skimage.io.MultiImage(os.path.join(TRAIN, name + '.tiff'))[TIFF]
    if RESIZE:
        # cv2.resize takes the target size as (width, height).
        slide = cv2.resize(slide, (int(slide.shape[1] / RESIZE), int(slide.shape[0] / RESIZE)))

    for t in tile(slide, dark=DARK_THRESHOLD):
        tile_img, idx = t['img'], t['idx']
        if SHOW:
            axes[i_name, idx].imshow(tile_img)
            axes[i_name, idx].axis('off')

        scaled = tile_img / 255
        x_tot.append(scaled.reshape(-1, 3).mean(0))
        x2_tot.append((scaled ** 2).reshape(-1, 3).mean(0))

        # cv2.imwrite expects BGR channel order, so swap the channels before
        # saving (COLOR_RGB2BGR performs the same swap as COLOR_BGR2RGB).
        out_path = os.path.join(TILES, f'{name}_{idx}.png')
        is_written = cv2.imwrite(out_path, cv2.cvtColor(tile_img, cv2.COLOR_RGB2BGR))
        if not is_written:
            print('error write to file', out_path)

if SHOW:
    plt.show()

In [None]:
# Per-channel mean over all saved tiles. Every tile has the same number of
# pixels, so the mean of the per-tile means equals the mean over all pixels.
img_avr = np.array(x_tot).mean(0)
# Var[x] = E[x^2] - E[x]^2. Clamp at 0 so floating-point rounding cannot
# produce a tiny negative variance (and hence NaN) before the square root.
img_var = np.maximum(np.array(x2_tot).mean(0) - img_avr ** 2, 0)
img_std = np.sqrt(img_var)
# img_std is already the standard deviation; do not take the square root again.
print('mean:', img_avr, '| std:', img_std)