In [1]:
import numpy as np
from pathlib import Path
import os
import shutil
from PIL import Image
import matplotlib.pyplot as plt

In [2]:
src_fldr = Path('stenoses_data/')

In [5]:
def get_range(mask, axis):
    size = mask.shape[axis]
    idxs = np.where((mask == 1).any(axis=axis))[0]
    imin, imax = idxs.min(), idxs.max()
    eps = 20
    imin, imax = max(imin-eps, 0), min(imax+eps, size)
    # return imin, imax
    return imax - imin

def get_slice(L, idx): return [l[idx] for l in L]
def slice_up(L):
    n = len(L[0])
    for i,l in enumerate(L):
        assert len(l) == n, f'shape mismatch at {i}'
    return [get_slice(L,i) for i in range(n)]

In [6]:
sub = src_fldr / 'train'
sizes = []
for pt in (sub/'masks').iterdir():
    mask = Image.open(pt)
    mask = np.array(mask)
    sizes.append((get_range(mask, 0), get_range(mask, 1)))
hsizes, wsizes = slice_up(sizes)
print(min(hsizes), max(hsizes))
print(min(wsizes), max(wsizes))

134 205
51 89


In [7]:
sub = src_fldr / 'test'
sizes = []
for pt in (sub/'masks').iterdir():
    mask = Image.open(pt)
    mask = np.array(mask)
    sizes.append((get_range(mask, 0), get_range(mask, 1)))
hsizes, wsizes = slice_up(sizes)
print(min(hsizes), max(hsizes))
print(min(wsizes), max(wsizes))

177 222
53 72


In [15]:
tgt_fldr = Path('./cut_stenoses_data')
tgt_fldr.mkdir()

In [13]:
def get_range(mask, axis):
    size = mask.shape[axis]
    idxs = np.where((mask == 1).any(axis=axis))[0]
    imin, imax = idxs.min(), idxs.max()
    eps = 20
    imin, imax = max(imin-eps, 0), min(imax+eps, size)
    pad = 256 - (imax - imin)
    assert pad >= 0
    assert size >= 256
    imin = max(imin - pad // 2, 0)
    imax = imin + 256
    return imin, imax

In [10]:
def read(pt):
    return np.array(Image.open(pt))
def write(arr, pt):
    Image.fromarray(arr).save(pt)

In [16]:
for phase in ['train', 'test']:
    sub = src_fldr / phase
    (tgt_fldr / phase).mkdir()
    (tgt_fldr / phase / 'images').mkdir()
    (tgt_fldr / phase / 'masks').mkdir()
    for mask_pt in (sub/'masks').iterdir():
        img_pt = sub / 'images' / mask_pt.name
        mask = read(mask_pt)
        img = read(img_pt)
        hmin, hmax = get_range(mask, 0)
        wmin, wmax = get_range(mask, 1)
        mask = mask[hmin:hmax, wmin:wmax]
        img = img[hmin:hmax, wmin:wmax]
        write(mask, tgt_fldr / mask_pt.relative_to(src_fldr))
        write(img, tgt_fldr / img_pt.relative_to(src_fldr))