In [None]:
import os
import zipfile
from contextlib import ExitStack

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rasterio
import tifffile as tiff
from rasterio.windows import Window
from tqdm.notebook import tqdm

In [None]:
# Tiling parameters: WINDOW x WINDOW tiles are cut with neighbouring tiles
# overlapping by at least OVERLAP pixels, then every kept tile is resized
# to sz x sz before it is saved.
WINDOW, OVERLAP = 1024, 204
sz = 480

# Inputs: RLE mask table and directory of training .tiff images.
MASKS = '../input/hubmap-kidney-segmentation/train.csv'
DATA = '../input/hubmap-kidney-segmentation/train/'

# Outputs: zip archives receiving the image tiles and the matching mask tiles.
OUT_TRAIN, OUT_MASKS = 'train.zip', 'masks.zip'

In [None]:
# Run-length-encoding (RLE) helpers: enc2mask decodes RLE strings into a label
# mask, mask2enc encodes label values of a mask back into RLE strings.
def enc2mask(encs, shape):
    """Decode run-length encodings into a single uint8 label mask.

    Args:
        encs: iterable of RLE strings of the form "start length start length ...".
            Starts are 1-based indices into the column-major flattening of the
            returned mask. The m-th encoding (0-based) is painted with label
            ``m + 1``; later encodings overwrite earlier ones where runs overlap.
            Missing entries (NaN, as pandas yields for empty cells, or None) are
            skipped.
        shape: ``(n_cols, n_rows)`` -- note the order. The returned array has
            shape ``(shape[1], shape[0])``.

    Returns:
        np.ndarray of dtype uint8 with shape ``(shape[1], shape[0])``.
        Labels wrap around past 255 because of the uint8 dtype.
    """
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for m, enc in enumerate(encs):
        # `np.float` (an alias of the builtin float) was removed in NumPy 1.24;
        # checking the builtin keeps the same semantics (np.float64 subclasses float).
        if enc is None or (isinstance(enc, float) and np.isnan(enc)):
            continue
        s = enc.split()
        for i in range(len(s) // 2):
            start = int(s[2 * i]) - 1  # RLE starts are 1-based
            length = int(s[2 * i + 1])
            img[start:start + length] = 1 + m
    # Row-major reshape followed by .T == column-major fill of a (shape[1], shape[0]) array.
    return img.reshape(shape).T

def mask2enc(mask, n=1):
    """Encode the label values ``1..n`` of ``mask`` as run-length strings.

    Pixels are numbered 1-based in column-major order (the mask is transposed
    before flattening). For each label the output holds a string
    "start length start length ...", or ``np.nan`` when the label is absent.

    Returns:
        list of length ``n``; element ``k`` encodes label ``k + 1``.
    """
    flat = mask.T.flatten()
    result = []
    for label in range(1, n + 1):
        hits = (flat == label).astype(np.int8)
        if not hits.any():
            result.append(np.nan)
            continue
        # Zero-pad both ends so every run has a rising and a falling edge.
        padded = np.concatenate([[0], hits, [0]])
        edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
        # Even slots are run starts; turn each odd slot (run end) into a length.
        edges[1::2] -= edges[::2]
        result.append(' '.join(map(str, edges)))
    return result

# Ground-truth RLE table indexed by image id (the id is also the .tiff file stem used below).
df_masks = pd.read_csv(MASKS, index_col='id')
df_masks.head()

In [None]:
def _grid_starts(length, window, min_overlap):
    """Return int64 start offsets of ``window``-long tiles covering ``[0, length)``.

    Starts are spread evenly so consecutive starts are at most
    ``window - min_overlap`` apart. The last tile is pinned flush to the far
    edge, and every start is clamped to ``length - window`` so each tile lies
    fully inside the axis.

    Raises:
        ValueError: if ``length < window`` (no full tile fits on this axis).
    """
    if length < window:
        raise ValueError(
            f"Axis of length {length} is shorter than window={window}; "
            "all tiles must be window x window")
    n = length // (window - min_overlap) + 1
    starts = np.linspace(0, length, num=n, endpoint=False, dtype=np.int64)
    starts[-1] = length - window  # last tile ends exactly at `length`
    # With a large min_overlap, evenly spaced starts can overshoot length - window
    # (which used to yield truncated tiles); clamp them so every tile is full-size.
    return np.minimum(starts, length - window)


def make_grid(shape, window=256, min_overlap=32):
    """
    Return an (N, 4) int64 array of square tiles covering a 2-D image.

    Args:
        shape: ``(x, y)`` image size; ``x`` indexes the first array axis and
            ``y`` the second (e.g. ``img_dataset.shape``).
        window: tile side length in pixels.
        min_overlap: minimum overlap between neighbouring tiles in pixels.

    Returns:
        np.ndarray of shape (N, 4). Each row holds the slice bounds
        ``x1, x2, y1, y2``, with ``x2 - x1 == y2 - y1 == window``.
        ``N = nx * ny``, ordered with x varying slowest and y fastest.

    Raises:
        ValueError: if ``min_overlap >= window``, or if either dimension is
            smaller than ``window``.
    """
    if min_overlap >= window:
        raise ValueError(
            f"min_overlap ({min_overlap}) must be smaller than window ({window})")
    x, y = shape
    x1 = _grid_starts(x, window, min_overlap)
    y1 = _grid_starts(y, window, min_overlap)
    # 'ij' indexing reproduces the nested-loop order: x outer, y inner.
    xs, ys = (g.ravel() for g in np.meshgrid(x1, y1, indexing='ij'))
    return np.stack([xs, xs + window, ys, ys + window], axis=1)

In [None]:
# Cut every training image into WINDOW x WINDOW tiles, drop near-blank tiles,
# resize the rest to sz x sz and store matching image/mask PNGs in two zip
# archives, while accumulating per-channel statistics of the kept tiles.
std_th = 7  # a tile is skipped when the pixel std of every channel is <= std_th
x_means, x2_means = [], []  # per-tile channel means of x and x**2 (x scaled to [0, 1])

with zipfile.ZipFile(OUT_TRAIN, 'w') as img_out, \
     zipfile.ZipFile(OUT_MASKS, 'w') as mask_out:
    for index, encs in tqdm(df_masks.iterrows(), total=len(df_masks)):
        # ExitStack closes the image and any subdatasets once this image is done;
        # previously they were never closed, leaking file handles per image.
        with ExitStack() as stack:
            img_dataset = stack.enter_context(
                rasterio.open(os.path.join(DATA, index + '.tiff')))
            # When the main dataset does not expose 3 bands, bands are read from its
            # subdatasets instead (presumably one band each -- TODO confirm).
            img_layers = []
            if img_dataset.count != 3:
                img_layers = [stack.enter_context(rasterio.open(sub))
                              for sub in img_dataset.subdatasets]
            slice_coords = make_grid(img_dataset.shape, window=WINDOW, min_overlap=OVERLAP)  # (num_slices, 4) int64
            # `encs` is the row's Series; each column value is decoded as one label.
            mask = enc2mask(encs, (img_dataset.shape[1], img_dataset.shape[0]))  # (H, W) uint8

            for i, (x1, x2, y1, y2) in enumerate(slice_coords):
                tile_window = Window.from_slices((x1, x2), (y1, y2))
                if img_dataset.count == 3:
                    img_tile = img_dataset.read(window=tile_window)  # (C, H, W)
                else:
                    img_tile = np.zeros((3, WINDOW, WINDOW), np.uint8)
                    for j, layer in enumerate(img_layers):
                        img_tile[j, :, :] = layer.read(window=tile_window)[0]
                mask_tile = mask[x1:x2, y1:y2]
                img_tile = img_tile.transpose((1, 2, 0))  # (H, W, C)

                # Remove image slices with no real content. Skipped tiles still
                # consume an index `i`, so archive names are not contiguous.
                if np.all(np.array([np.std(img_tile[:, :, c]) for c in range(3)]) <= std_th):
                    continue

                img_tile = cv2.resize(img_tile, (sz, sz), interpolation=cv2.INTER_AREA)
                # Nearest-neighbour keeps mask labels integral.
                mask_tile = cv2.resize(mask_tile, (sz, sz), interpolation=cv2.INTER_NEAREST)

                tile_norm = (img_tile / 255.0).reshape(-1, 3)
                x_means.append(tile_norm.mean(0))
                x2_means.append((tile_norm ** 2).mean(0))

                # cv2 encodes BGR; tiles are assumed to be in RGB band order.
                img_png = cv2.imencode('.png', cv2.cvtColor(img_tile, cv2.COLOR_RGB2BGR))[1]
                img_out.writestr(f'{index}_{i}.png', img_png.tobytes())
                mask_png = cv2.imencode('.png', mask_tile)[1]
                mask_out.writestr(f'{index}_{i}.png', mask_png.tobytes())

# Image stats: every kept tile has the same size (sz x sz), so the mean of the
# per-tile means equals the dataset mean; std uses E[x^2] - E[x]^2.
if not x_means:
    raise RuntimeError('No tiles passed the std_th filter; cannot compute image statistics.')
img_mean = np.stack(x_means).mean(0)
img_std = np.sqrt(np.stack(x2_means).mean(0) - img_mean ** 2)
print(f'mean:{img_mean}, std:{img_std}')

In [None]:
# Preview a rows x columns grid of saved image tiles (no mask overlay),
# starting at archive entry idx0.
columns, rows = 4, 4
idx0 = 763  # first archive entry to show (arbitrary offset into the archive)
n_show = rows * columns
fig, axes = plt.subplots(rows, columns, figsize=(columns * 4, rows * 4), squeeze=False)
for ax in axes.flat:
    ax.axis('off')
with zipfile.ZipFile(OUT_TRAIN, 'r') as img_arch:
    fnames = img_arch.namelist()
    # Clamp the start so a smaller archive shows its last tiles instead of raising IndexError.
    start = max(0, min(idx0, len(fnames) - n_show))
    for ax, fname in zip(axes.flat, fnames[start:start + n_show]):
        buf = np.frombuffer(img_arch.read(fname), np.uint8)
        # Tiles were stored as BGR PNGs; convert back to RGB for matplotlib.
        img = cv2.cvtColor(cv2.imdecode(buf, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
        ax.imshow(img)
        ax.set_title(fname, fontsize=8)
plt.show()

In [None]:
# Preview a rows x columns grid of saved image tiles with their masks overlaid
# (alpha 0.2), starting at archive entry idx0.
columns, rows = 4, 4
idx0 = 763  # first archive entry to show (arbitrary offset into the archive)
n_show = rows * columns
fig, axes = plt.subplots(rows, columns, figsize=(columns * 4, rows * 4), squeeze=False)
for ax in axes.flat:
    ax.axis('off')
with zipfile.ZipFile(OUT_TRAIN, 'r') as img_arch, \
     zipfile.ZipFile(OUT_MASKS, 'r') as msk_arch:
    fnames = img_arch.namelist()
    # Clamp the start so a smaller archive shows its last tiles instead of raising IndexError.
    start = max(0, min(idx0, len(fnames) - n_show))
    for ax, fname in zip(axes.flat, fnames[start:start + n_show]):
        img_buf = np.frombuffer(img_arch.read(fname), np.uint8)
        # Tiles were stored as BGR PNGs; convert back to RGB for matplotlib.
        img = cv2.cvtColor(cv2.imdecode(img_buf, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
        # Image and mask tiles are written under the same file name in each archive.
        mask_buf = np.frombuffer(msk_arch.read(fname), np.uint8)
        mask = cv2.imdecode(mask_buf, cv2.IMREAD_GRAYSCALE)
        ax.imshow(img)
        ax.imshow(mask, alpha=0.2)
        ax.set_title(fname, fontsize=8)
plt.show()

In [None]:
# List the working directory (all entries, long format, human-readable sizes,
# reverse name order) to check the generated train.zip / masks.zip archives.
! ls -alrh 