CHANGELOG

v0.3 11/15/2023
- hold out and external testing

v0.2 10/17/2023
- change to segmentation only

v0.1 7/31/2023
- OOM errors
- process each patch individually and shard the dataset to save

v0.0 7/30/2023
- initial version

In [1]:
from jarvis.utils.general import gpus
gpus.autoselect()

import glob, numpy as np, pandas as pd, tensorflow as tf, matplotlib.pyplot as plt, os, time
from jarvis.utils.display import imshow
from jarvis.utils import arrays as jars
from scipy import ndimage

# import after setting OPENCV_IO_MAX_IMAGE_PIXELS to 2^50
os.environ["OPENCV_IO_MAX_IMAGE_PIXELS"] = pow(2,50).__str__() 
import cv2

import sys  
sys.path.append('/home/jjlou/Jerry/jerry_packages')
from jerry_utils import restart_kernel, show



In [2]:
#### Recursive algorithm to create 784x784 patches at different resolutions ####

def create_tiles(arr, shape=(784, 784), **kwargs):
    """Pad a 2D array onto a half-tile grid and tile it via create_tiles_recursive.

    Each axis of ``arr`` is padded independently (the result is not forced to
    be square) up to the next multiple of half the tile size, and the padded
    array is then handed to ``create_tiles_recursive``.

    Args:
        arr: 2D array to tile.
        shape: (rows, cols) of every output tile.
        **kwargs: accepted for call compatibility; not used by this function.

    Returns:
        The list of tiles returned by ``create_tiles_recursive``.

    Raises:
        AssertionError: if ``arr`` is not two-dimensional.
    """
    assert arr.ndim == 2

    # --- Round each axis up to a whole number of half-tiles
    half_tile = np.array(shape) * 0.5
    n_half_tiles = np.ceil(np.array(arr.shape) / half_tile)
    padded_shape = np.round(n_half_tiles * half_tile).astype('int')

    padded = pad_zeros_to_shape(padded_shape is not None and arr, shape=padded_shape)

    # --- Hand the padded array over for multi-resolution tiling
    return create_tiles_recursive(raw=padded, shape=shape)

def create_tiles_recursive(raw, shape=(784, 784), overlap=0.5, sub=None, tiles=None, **kwargs):
    """Collect tiles of ``raw`` at successively lower resolutions.

    While the current array is larger than one tile along either axis, it is
    tiled with ``create_tiles_single`` and ``raw`` is then resampled with
    ``ndimage.zoom`` so that the next array is ``shape * overlap`` smaller per
    axis (before the rounding done by ``zoom``). Once the current array fits
    inside a single tile it is zero-padded to ``shape`` and appended last.

    Args:
        raw: full-resolution array; every lower resolution is resampled from it.
        shape: (rows, cols) of every output tile.
        overlap: fraction of ``shape`` used both as the tiling step and as the
            per-level size reduction.
        sub: array to start tiling from; defaults to ``raw``.
        tiles: list to extend with the tiles; a fresh list is used when None.
            A list passed in here is extended in place.
        **kwargs: forwarded to ``create_tiles_single`` for the first level only.

    Returns:
        A new list holding all collected tiles followed by the final padded tile.
    """
    collected = [] if tiles is None else tiles
    current = raw if sub is None else sub

    # Only the first level receives the caller's extra keyword arguments
    extra = kwargs

    # NOTE(review): for strongly elongated inputs one axis may reach a zoom
    # factor of 0 before the other axis fits in a tile — confirm inputs avoid this.
    while (current.shape[0] > shape[0]) or (current.shape[1] > shape[1]):

        # --- Tiles at the current resolution
        collected += create_tiles_single(arr=current, shape=shape, overlap=overlap, **extra)
        extra = {}

        # --- Next, lower resolution, always resampled from the original array
        scale = (np.array(current.shape) - np.array(shape) * overlap) / np.array(raw.shape)
        current = ndimage.zoom(raw, scale)

    # --- Whatever remains fits in one tile: pad it and finish
    return collected + [pad_zeros_to_shape(arr=current, shape=shape)]

def create_tiles_single(arr, shape=(784, 784), overlap=0.5, **kwargs):
    """Cut ``arr`` into fixed-size tiles on a regular grid.

    Tile origins advance by ``round(shape * overlap)`` along each axis, and
    only tiles lying completely inside ``arr`` are produced.

    Args:
        arr: array to slice along its first two axes.
        shape: (rows, cols) of every tile.
        overlap: fraction of ``shape`` used as the step between tile origins.
        **kwargs: accepted for call compatibility; not used by this function.

    Returns:
        List of slices of ``arr`` in row-major order of their origins; empty
        when ``arr`` is smaller than one tile along either axis.
    """
    tile_h, tile_w = shape[0], shape[1]

    # --- Step between neighbouring tile origins
    step_h, step_w = np.round(np.array(shape) * overlap).astype('int')

    return [
        arr[top:top + tile_h, left:left + tile_w]
        for top in range(0, arr.shape[0] - tile_h + 1, step_h)
        for left in range(0, arr.shape[1] - tile_w + 1, step_w)
    ]

def pad_zeros_to_shape(arr, shape=(784, 784), **kwargs):
    """Zero-pad a 2D array on all sides until it has the requested shape.

    The missing rows/columns are split between the leading and trailing edge;
    the leading share is ``np.round(missing / 2)`` and the trailing edge gets
    the remainder.

    Args:
        arr: array to pad.
        shape: target (rows, cols).
        **kwargs: accepted for call compatibility; not used by this function.

    Returns:
        ``arr`` itself when it already has the target shape, otherwise a new
        padded array.

    Raises:
        AssertionError: if ``arr`` is larger than ``shape`` along either axis.
    """
    already_sized = (arr.shape[0] == shape[0]) and (arr.shape[1] == shape[1])
    if already_sized:
        return arr

    assert arr.shape[0] <= shape[0]
    assert arr.shape[1] <= shape[1]

    # --- Rows/columns still missing, split between the two edges of each axis
    missing = np.array(shape) - np.array(arr.shape)
    before = np.round(missing / 2).astype('int')
    after = missing - before

    row_pad = (before[0], after[0])
    col_pad = (before[1], after[1])
    return np.pad(arr, (row_pad, col_pad))

In [3]:
def load_data(dat=None, msk=None, path_root=None, shard_size=None):
    """Tile one image/mask pair and save the tiles as tf.data shards.

    The image and mask are loaded through ``jars.create`` and cast to uint8.
    The first three image channels are tiled separately with ``create_tiles``
    and restacked per tile along axis 2; the mask is tiled from its channel 0.
    Consecutive groups of ``shard_size`` tiles are saved to
    ``{path_root}/shard_{k}``. A shard directory that already contains a
    ``.snapshot`` file is skipped.

    Args:
        dat: path of the image file.
        msk: path of the mask file.
        path_root: directory that receives the ``shard_{k}`` sub-directories.
        shard_size: number of tiles per shard.

    Raises:
        AssertionError: if a shard directory holds ``.shard`` files but no
            ``.snapshot`` file.
    """
    # assumes data[0] is (rows, cols, channels) with >= 3 image channels — TODO confirm
    image = jars.create(dat).data[0].astype('uint8')
    mask = jars.create(msk).data[0].astype('uint8')

    # --- Tile every channel on its own, then rebuild 3-channel tiles
    per_channel = tuple(create_tiles(image[:, :, c]) for c in range(3))
    channels = tf.data.Dataset.from_tensor_slices(per_channel)
    images = [tf.stack([zero, one, two], axis=2) for zero, one, two in channels]

    masks = create_tiles(mask[:, :, 0])

    # --- Save the tiles in consecutive groups of shard_size
    for shard_idx, start in enumerate(range(0, len(images), shard_size)):
        save_path = f'{path_root}/shard_{shard_idx}'

        # Shard already written by an earlier run
        if glob.glob(f'{save_path}/*/*/*.snapshot'):
            continue

        assert not glob.glob(f'{save_path}/*/*.shard'), f'{save_path} did not save properly'

        stop = start + shard_size
        shard = tf.data.Dataset.from_tensor_slices((images[start:stop], masks[start:stop]))
        shard.save(save_path)

In [4]:
def process_data(dats=None, msks=None, path_root=None, shard_size=None):
    """Run ``load_data`` for every image/mask pair that has no saved output yet.

    The output folder of a pair is ``{path_root}/{ID}``, where ID is the
    second-to-last component of the image path. A pair is skipped when its
    expected last shard directory already contains a ``.snapshot`` file.
    After each processed pair ``restart_kernel()`` is called, followed by a
    10 second sleep.

    Args:
        dats: image file paths.
        msks: mask file paths, paired with ``dats`` by position.
        path_root: directory that receives one sub-folder per pair.
        shard_size: number of tiles per shard, forwarded to ``load_data``.

    Raises:
        AssertionError: if the expected last shard directory holds ``.shard``
            files but no ``.snapshot`` file.
    """
    # NOTE(review): the last shard index is derived from the number of input
    # files, not from the number of tiles per file — confirm this is intended.
    last_shard = len(dats) // shard_size

    for dat_path, msk_path in zip(dats, msks):
        slide_id = dat_path.split('/')[-2]
        folder = f'{path_root}/{slide_id}'
        last_dir = f'{folder}/shard_{last_shard}'

        # Output already present for this pair
        if glob.glob(f'{last_dir}/*/*/*.snapshot'):
            continue

        assert not glob.glob(f'{last_dir}/*/*.shard'), f'{folder} did not save properly'

        load_data(dat=dat_path, msk=msk_path, path_root=folder, shard_size=shard_size)
        restart_kernel()
        time.sleep(10)

In [5]:
# Number of tiles per saved shard and the output root for all datasets
shard_size = 100
root = '/home/jjlou/Jerry/wsi-arterio/vessel_detection_and_rough_segmentation/data_test'

# Slide IDs reserved for the hold-out set
hold_out_list = [
    'UCI-15-12-MF-HE',
    'UCI-15-12-OCC-HE',
]

# Slide IDs reserved for the external test set
external_list = [
    'V019_B1_HE',
    'V019_B3_HE',
]

# Slides present in the data folder but not included in the study
remove = [
    'UCI-15-12-OCC-AB',
    'UCI-15-12-ST-AB',
    'V019_B2_AB40',
    'V019_B3_AB40',
]

In [6]:
# Image (dat.hdf5) and label (lbl.hdf5) files, each list in sorted order;
# later cells pair the two lists by position.
dats = glob.glob('/data/raw/wsi_arterio/data/*/patches_full/*/dat.hdf5')
dats.sort()
msk = glob.glob('/data/raw/wsi_arterio/data/*/patches_full/*/lbl.hdf5')
msk.sort()

In [7]:
# Gather every image/mask path that mentions one of the excluded slide names
dats_remove = [d for r in remove for d in dats if r in d]
msk_remove = [m for r in remove for m in msk if r in m]

# Drop the gathered paths from the working lists in place
for dr in dats_remove:
    dats.remove(dr)

for mr in msk_remove:
    msk.remove(mr)

In [8]:
# Collect the image and mask paths whose text contains a hold-out slide ID
hold_out_dats_list = [d for h in hold_out_list for d in dats if h in d]
hold_out_msk_list = [m for h in hold_out_list for m in msk if h in m]

# Process the hold-out set only while fewer output entries exist than input files
hold_path = f'{root}/hold_out'
n_hold_out_saved = len(glob.glob(f'{hold_path}/*'))
if n_hold_out_saved < len(hold_out_dats_list):
    process_data(dats=hold_out_dats_list, msks=hold_out_msk_list, path_root=hold_path, shard_size=shard_size)

# Take the hold-out files out of the working lists
for d in hold_out_dats_list:
    dats.remove(d)

for m in hold_out_msk_list:
    msk.remove(m)

In [9]:
# Collect the image and mask paths whose text contains an external slide ID
external_dats_list = [d for h in external_list for d in dats if h in d]
external_msk_list = [m for h in external_list for m in msk if h in m]

# Process the external set only while fewer output entries exist than input files
external_path = f'{root}/external'
n_external_saved = len(glob.glob(f'{external_path}/*'))
if n_external_saved < len(external_dats_list):
    process_data(dats=external_dats_list, msks=external_msk_list, path_root=external_path, shard_size=shard_size)

# Take the external files out of the working lists
for d in external_dats_list:
    dats.remove(d)

for m in external_msk_list:
    msk.remove(m)

In [10]:
# Files still left in dats/msk at this point are written to the train folder;
# processing runs only while fewer output entries exist than input files.
train_path = f'{root}/train'
n_train_saved = len(glob.glob(f'{train_path}/*'))
if n_train_saved < len(dats):
    process_data(dats=dats, msks=msk, path_root=train_path, shard_size=shard_size)