In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import tifffile as tiff
import cv2
import os
from tqdm.notebook import tqdm
import zipfile

In [None]:
# Run configuration. Set EXTERNAL / PSEUDO to the Python value None (not the
# string 'None', which is truthy and would enable the feature) to disable them.
SHFT = [0, .5]  # [0] for no shift, or [0, .5] to also cut tiles shifted by half a tile
EXTERNAL = 'ext'  # None to skip external data, or 'ext' to include it
PSEUDO = 'v55'  # None to skip pseudo-labelling, or a version tag such as 'v41'
SIZE = 512  # output tile side in pixels: 256 or 512
REDUCE = 2  # downscale factor applied before tiling
SAT_THR = 40  # HSV saturation threshold used by the tile filter
PIX_THR = 200 * SIZE // 256  # a tile is kept only if more than this many pixels exceed SAT_THR
DATA_PATH = './data2'
MASKS = f'{DATA_PATH}/train.csv'
DATA = f'{DATA_PATH}/train/'
if PSEUDO:
    MASKS_PSEUDO = f'{DATA_PATH}/pseudolbl_{PSEUDO}.csv'
    DATA_PSEUDO = f'{DATA_PATH}/test/'

# Output folder names encode the options in a fixed order:
#   <kind>_r{REDUCE}_s{SIZE}[_shft][_{EXTERNAL}][_{PSEUDO}]/
_out_suffix = (
    ('_shft' if len(SHFT) > 1 else '')
    + (f'_{EXTERNAL}' if EXTERNAL else '')
    + (f'_{PSEUDO}' if PSEUDO else '')
)
TILES_PATH = f'{DATA_PATH}/tiles_r{REDUCE}_s{SIZE}{_out_suffix}/'
MASKS_PATH = f'{DATA_PATH}/masks_r{REDUCE}_s{SIZE}{_out_suffix}/'
# makedirs also creates DATA_PATH if needed and is a no-op when the folder exists.
os.makedirs(TILES_PATH, exist_ok=True)
os.makedirs(MASKS_PATH, exist_ok=True)

In [None]:
if EXTERNAL:
    # Downscale each external image/mask pair by REDUCE, apply the same
    # saturation filter as the main tiling loop, and copy the survivors into
    # TILES_PATH / MASKS_PATH.
    ext_imgs_path = './data/images_1024'
    ext_msks_path = './data/masks_1024'
    for img_name in tqdm(os.listdir(ext_imgs_path)):
        img_path = f'{ext_imgs_path}/{img_name}'
        msk_path = f'{ext_msks_path}/{img_name}'
        img = cv2.imread(img_path)
        if img is None:
            # Skip unreadable files instead of crashing in cv2.resize below.
            print('error load image:', img_path)
            continue
        msk = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE)
        if msk is None:
            print('error load mask:', msk_path)
            continue
        img = cv2.resize(img,
                         (img.shape[1] // REDUCE, img.shape[0] // REDUCE),
                         interpolation=cv2.INTER_AREA)
        # Nearest-neighbour keeps mask values as exact labels.
        msk = cv2.resize(msk,
                         (msk.shape[1] // REDUCE, msk.shape[0] // REDUCE),
                         interpolation=cv2.INTER_NEAREST)
        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        h, s, v = cv2.split(hsv)
        if (s > SAT_THR).sum() <= PIX_THR or img.sum() <= PIX_THR:
            continue
        # NOTE(review): cv2.imread returns BGR and cv2.imwrite expects BGR, so this
        # conversion stores the tile with R and B swapped relative to the source
        # file. That is only right if the external images were themselves saved
        # channel-swapped - confirm against how ./data/images_1024 was produced.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Drop underscores before the last '_' so that everything before the first
        # '_' of the new name is the original prefix (presumably the image id used
        # by the listing cell's x[:x.find('_')]).
        cut = img_name.rfind('_')
        img_name_ = img_name[:cut].replace('_', '') + img_name[cut:]
        is_written = cv2.imwrite(f'{TILES_PATH}/{img_name_}', img)
        if not is_written:
            print('error write to file', f'{TILES_PATH}/{img_name_}')
        is_written = cv2.imwrite(f'{MASKS_PATH}/{img_name_}', msk)
        if not is_written:
            print('error write to file', f'{MASKS_PATH}/{img_name_}')

In [None]:
def enc2mask(encs, shape):
    """Decode run-length encodings into a single label image.

    Args:
        encs: iterable of RLE strings of the form "start length start length ..."
            with 1-based starts. The m-th encoding (0-based) is painted with label
            ``m + 1``; float NaN entries (missing encodings) are skipped.
        shape: (width, height) of the target image. Runs index a flat buffer that
            is reshaped to ``shape`` and then transposed, i.e. the runs are laid
            out column-major with respect to the returned array.

    Returns:
        np.uint8 array of shape ``(shape[1], shape[0])``. Later encodings overwrite
        earlier ones where they overlap; labels above 255 would wrap (uint8).
    """
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for m, enc in enumerate(encs):
        # `np.float` was removed in NumPy 1.24; builtin float also matches np.float64.
        if isinstance(enc, float) and np.isnan(enc):
            continue
        s = enc.split()
        for i in range(len(s) // 2):
            start = int(s[2 * i]) - 1  # RLE starts are 1-based
            length = int(s[2 * i + 1])
            img[start : start + length] = 1 + m
    return img.reshape(shape).T

def mask2enc(mask, n=1):
    """Run-length encode labels 1..n of `mask` (inverse of enc2mask).

    The mask is transposed before flattening, so runs are column-major relative
    to `mask`. Each entry of the returned list is either a string of
    "start length" pairs with 1-based starts, or np.nan when that label does
    not occur in the mask.
    """
    flat = mask.T.flatten()
    result = []
    for label in range(1, n + 1):
        hits = (flat == label).astype(np.int8)
        if not hits.any():
            result.append(np.nan)
            continue
        # Zero-pad both ends so every run has a rising and a falling edge.
        padded = np.concatenate([[0], hits, [0]])
        edges = np.flatnonzero(np.diff(padded)) + 1
        # Odd slots hold run ends; turn them into lengths.
        edges[1::2] -= edges[::2]
        result.append(' '.join(map(str, edges)))
    return result

In [None]:
# Pair each label CSV with the folder holding its TIFFs; the pseudo-labelled
# test set is added only when PSEUDO is enabled.
label_sources = [(MASKS, DATA)]
if PSEUDO:
    label_sources.append((MASKS_PSEUDO, DATA_PSEUDO))

dfs, paths = [], []
for csv_file, image_dir in label_sources:
    dfs.append(pd.read_csv(csv_file).set_index('id'))
    paths.append(image_dir)

In [None]:
# Sanity check: first rows of each label table and the image folder it maps to.
for labels_df, img_dir in zip(dfs, paths):
    preview = labels_df.head()
    print(preview)
    print(img_dir)

In [None]:
# Cut every labelled slide into SIZE x SIZE tiles (after downscaling by REDUCE),
# once per offset in SHFT. Keep tiles that pass the saturation filter, write
# image/mask PNG pairs named `<id>_<tile>_<shift>.png`, and report per-channel
# mean/std of the kept tiles for each data source.


def _pad_and_crop(arr, pad0, pad1, offset):
    """Zero-pad the first two axes of `arr` by `pad0` / `pad1` (split as evenly as
    possible between both sides), then trim `offset` pixels from each of those edges."""
    pads = [[pad0 // 2, pad0 - pad0 // 2], [pad1 // 2, pad1 - pad1 // 2]]
    pads += [[0, 0]] * (arr.ndim - 2)  # leave any channel axis untouched
    arr = np.pad(arr, pads, constant_values=0)
    return arr[offset : arr.shape[0] - offset, offset : arr.shape[1] - offset]


def _split_tiles(arr, size):
    """Split an (H, W, ...) array whose H and W are multiples of `size` into an
    (N, size, size, ...) stack of tiles in row-major grid order."""
    rows, cols = arr.shape[0] // size, arr.shape[1] // size
    trailing = arr.shape[2:]
    grid = arr.reshape(rows, size, cols, size, *trailing)
    order = (0, 2, 1, 3) + tuple(range(4, grid.ndim))
    return grid.transpose(order).reshape(-1, size, size, *trailing)


tile_span = REDUCE * SIZE  # tile side length at full resolution
for df_masks, data_path in zip(dfs, paths):
    print('processing:', data_path)
    x_tot, x2_tot = [], []
    for index, encs in tqdm(df_masks.iterrows(), total=len(df_masks), desc='images'):
        # Read and decode each slide once; it is reused for every shift below.
        img_full = tiff.imread(os.path.join(data_path, index + '.tiff'))
        if len(img_full.shape) == 5: img_full = np.transpose(img_full.squeeze(), (1, 2, 0))
        if img_full.shape[0] == 3: img_full = np.transpose(img_full.squeeze(), (1, 2, 0))
        mask_full = enc2mask(encs, (img_full.shape[1], img_full.shape[0]))
        shape = img_full.shape

        for ishift, shift in enumerate(SHFT):
            # Trimming `offset` from every edge moves the tile grid by shift * tile.
            offset = int(tile_span * shift)
            try:
                # Pad both spatial sides up to a multiple of tile_span.
                pad0 = (tile_span - shape[0] % tile_span) % tile_span
                pad1 = (tile_span - shape[1] % tile_span) % tile_span
                img = _pad_and_crop(img_full, pad0, pad1, offset)
                mask = _pad_and_crop(mask_full, pad0, pad1, offset)

                img = cv2.resize(img,
                                 (img.shape[1] // REDUCE, img.shape[0] // REDUCE),
                                 interpolation=cv2.INTER_AREA)
                # Nearest-neighbour keeps mask values as exact labels.
                mask = cv2.resize(mask,
                                  (mask.shape[1] // REDUCE, mask.shape[0] // REDUCE),
                                  interpolation=cv2.INTER_NEAREST)
                img_tiles = _split_tiles(img, SIZE)
                mask_tiles = _split_tiles(mask, SIZE)

                for i, (im, m) in enumerate(zip(img_tiles, mask_tiles)):
                    hsv = cv2.cvtColor(im, cv2.COLOR_BGR2HSV)
                    h, s, v = cv2.split(hsv)
                    if (s > SAT_THR).sum() <= PIX_THR or im.sum() <= PIX_THR:
                        continue
                    x_tot.append((im / 255).reshape(-1, 3).mean(0))
                    x2_tot.append(((im / 255) ** 2).reshape(-1, 3).mean(0))
                    # Swap the first/last channels before cv2.imwrite, which expects BGR.
                    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
                    tile_name = f'{index}_{i}_{ishift}.png'
                    is_written = cv2.imwrite(f'{TILES_PATH}/{tile_name}', im)
                    if not is_written:
                        print('error write to file', f'{TILES_PATH}/{tile_name}')
                    is_written = cv2.imwrite(f'{MASKS_PATH}/{tile_name}', m)
                    if not is_written:
                        print('error write to file', f'{MASKS_PATH}/{tile_name}')
            except Exception as err:
                # Log and continue with the next shift/slide; unlike a bare
                # `except:` this lets KeyboardInterrupt stop the run.
                print('error', index, shape, f'shift={shift}', repr(err))
        # Release the full-resolution arrays before the next slide is loaded.
        del img_full, mask_full

    if x_tot:
        # Tiles all have the same size, so averaging per-tile moments is exact.
        img_avg = np.array(x_tot).mean(0)
        img_std = np.sqrt(np.array(x2_tot).mean(0) - img_avg ** 2)
        print('mean:', img_avg, '| std:', img_std)
    else:
        print('no tiles kept for', data_path)

In [None]:
# Count the written tiles/masks, show a few names, then the set of image ids
# (the part of each file name before the first '_') present in each folder.
listings = [os.listdir(folder) for folder in (TILES_PATH, MASKS_PATH)]
for names in listings:
    print(len(names))
    print(names[:10])
for names in listings:
    images = {name[:name.find('_')] for name in names}
    print(images)

In [None]:
# Show a NUM x NUM grid of saved tiles with their masks overlaid, starting at
# tile number idx0 of the sorted file list.
NUM = 5
columns, rows = NUM, NUM
fnames = sorted(os.listdir(TILES_PATH))
# Start at tile 10000, or earlier if there are not enough tiles to fill the grid.
idx0 = max(0, min(10000, len(fnames) - rows * columns))
fig = plt.figure(figsize=(columns * NUM, rows * NUM))
for i in range(rows):
    for j in range(columns):
        idx = i + j * columns
        if idx0 + idx >= len(fnames):
            continue
        fname = fnames[idx0 + idx]
        img = cv2.imread(f'{TILES_PATH}/{fname}')
        mask = cv2.imread(f'{MASKS_PATH}/{fname}', cv2.IMREAD_GRAYSCALE)
        fig.add_subplot(rows, columns, idx + 1)
        plt.axis('off')
        # cv2.imread returns BGR; convert so matplotlib shows the file's real colours.
        plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        plt.imshow(mask, alpha=.3)
        plt.title(fname)
plt.show()