CHANGELOG

v0.1 11/13/2023
- Divide into Training and Hold out only

v0.0 7/14/2023
- Initiated from the wsi-arterio_batches_all_classification Data_Partitioner_v0.4

In [1]:
from jarvis.utils.general import gpus
gpus.autoselect()

import glob, numpy as np, pandas as pd, tensorflow as tf, matplotlib.pyplot as plt, os, time

# import after setting OPENCV_IO_MAX_IMAGE_PIXELS to 2^50
os.environ["OPENCV_IO_MAX_IMAGE_PIXELS"] = pow(2,50).__str__() 
import cv2

import sys  
sys.path.append('/home/jjlou/Jerry/jerry_packages')
from jerry_utils import restart_kernel

[ 2023-11-13 22:09:44 ] CUDA_VISIBLE_DEVICES automatically set to: 0           


In [2]:
# Output root; the 'hold_out' and 'train' shard directories are created beneath it
root = '/home/jjlou/Jerry/wsi-arterio/arteriosclerotic_vessel_detection_and_fine_segmentation/Vessel_WallsLumen_Segmentation/data_test'
shape = (512, 512) # target size handed to cv2.resize for every image and mask
shard_size = 20 # number of image file paths loaded (and saved together) per shard

# Patient IDs held out from training; each ID is matched as a substring of the file paths
hold_out_list = ['UCI-37-18', 'D-492', 'V019-B1_VP', 'V019_B3_VP']

In [3]:
# Gather every positive patch path (dats) and every mask path (msk).
# Each batch is sorted on its own, then batches 1-2 are placed ahead of batch 3.
dats_pos_batches1and2 = glob.glob('/data/raw/wsi_arterio_2nd_alg/data/pos/*/*')
dats_pos_batches1and2.sort()
dats_pos_batch3 = glob.glob('/data/raw/wsi_arterio_2nd_alg/batch_3_data/annotations/pos/*/*')
dats_pos_batch3.sort()
dats_pos = [*dats_pos_batches1and2, *dats_pos_batch3]

msk_batches1and2 = glob.glob('/data/raw/wsi_arterio_2nd_alg/data/msk/*/*')
msk_batches1and2.sort()
msk_batch3 = glob.glob('/data/raw/wsi_arterio_2nd_alg/batch_3_data/annotations/msk/*/*')
msk_batch3.sort()
msk_pos = [*msk_batches1and2, *msk_batch3]

In [4]:
def _read_resized(path, shape):
    """Read the image at ``path`` with OpenCV and resize it to ``shape``.

    Raises:
        FileNotFoundError: if ``cv2.imread`` returns ``None`` (it signals an
            unreadable or missing file that way instead of raising).
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f'cv2.imread could not read {path!r}')
    # NOTE: cv2.resize takes dsize as (width, height); with a square shape the
    # distinction does not matter.
    return cv2.resize(img, shape)


def load_data(dats=None, msk=None, shape=shape):
    """Build a ``tf.data.Dataset`` of (image, mask) pairs from file paths.

    Args:
        dats: sequence of image file paths.
        msk: sequence of mask file paths. Paths that contain the letter 'L'
            form one group and all other paths form a second group
            (presumably lumen vs. wall masks -- TODO confirm). The i-th mask
            of each group is combined and paired with the i-th image.
        shape: size passed to ``cv2.resize`` for every image and mask.

    Returns:
        A dataset built with ``from_tensor_slices`` over
        ``(images, masks)``, both uint8. Each mask pixel is 0 where neither
        source mask is set, 1 where exactly one is set and 2 where both are.

    Raises:
        FileNotFoundError: if an image or mask file cannot be read.
        ValueError: if the two mask groups differ in size, or the number of
            combined masks differs from the number of images.
    """
    # Stack the resized images directly; no TensorFlow round trip is needed
    # to obtain a uint8 array.
    images = np.stack([_read_resized(d, shape) for d in dats]).astype('uint8', copy=False)

    l_masks = []
    other_masks = []
    for m in msk:
        # Binarise on the first channel only: any non-zero pixel counts as mask.
        binary = (_read_resized(m, shape)[:, :, 0] > 0).astype(float)
        # NOTE(review): the 'L' test runs on the whole path, so a directory
        # name containing 'L' would also select this branch -- confirm the
        # naming scheme before tightening it to the file name.
        if 'L' in m:
            l_masks.append(binary)
        else:
            other_masks.append(binary)

    # zip() would silently drop the surplus masks and could mis-pair the rest.
    if len(l_masks) != len(other_masks):
        raise ValueError(
            f"mask groups differ in size: {len(l_masks)} paths with 'L' "
            f"vs {len(other_masks)} without")
    if len(l_masks) != len(images):
        raise ValueError(
            f'{len(images)} images but {len(l_masks)} mask pairs')

    masks = np.stack([
        (other + l_part).astype('uint8')
        for l_part, other in zip(l_masks, other_masks)
    ])

    dataset = tf.data.Dataset.from_tensor_slices((images, masks))

    return dataset

In [5]:
def process_data(dats=None, msk=None, path_root=None, path_modifier='', shape=shape, shard_size=shard_size):
    """Save ``dats``/``msk`` as consecutive dataset shards under ``path_root``.

    The image list is cut into chunks of ``shard_size``. Each image is
    expected to own two consecutive entries of ``msk``, so the chunk starting
    at image index ``i`` uses ``msk[2*i : 2*(i + shard_size)]``. A shard whose
    directory already contains a ``*.snapshot`` file is skipped, which makes
    the function resumable.

    Args:
        dats: sequence of image file paths.
        msk: sequence of mask file paths, two per image, in image order.
        path_root: directory under which the shard directories are created.
        path_modifier: prefix for the shard directory name; only needed when
            several datasets are saved into the same ``path_root``.
        shape: resize target forwarded to ``load_data``.
        shard_size: number of images per shard.
    """
    for shard_idx, start in enumerate(range(0, len(dats), shard_size)):
        save_path = f'{path_root}/{path_modifier}shard_{shard_idx}'
        # Shard already written on an earlier run: nothing to do.
        if glob.glob(f'{save_path}/*/*/*.snapshot'):
            continue

        shard_dats = dats[start:start + shard_size]
        shard_msk = msk[start * 2:(start + shard_size) * 2]
        # Forward shape so the caller's value is honoured.
        shard = load_data(dats=shard_dats, msk=shard_msk, shape=shape)

        # Only need a path_modifier if "concatenating" datasets by saving to the same path_root
        shard.save(save_path)

        # Record the first and last image path held by this shard; slicing
        # already clipped the final (possibly shorter) shard.
        record = {'start': shard_dats[0], 'stop': shard_dats[-1]}
        record = pd.DataFrame.from_dict(record, orient='index')
        record.to_csv(f'{save_path}/record.csv')

        # NOTE(review): restart_kernel comes from jerry_utils; presumably it
        # restarts the notebook kernel to release memory, after which the
        # notebook is re-run and finished shards are skipped above -- confirm.
        restart_kernel()
        time.sleep(10)

In [6]:
# Collect every patch and mask path that contains a hold-out patient ID.
# Results are grouped by ID (in hold_out_list order) and keep the order of
# dats_pos / msk_pos inside each group, so images and masks stay aligned.
hold_out_dats_pos = [d for h in hold_out_list for d in dats_pos if h in d]
hold_out_msk_pos = [m for h in hold_out_list for m in msk_pos if h in m]

hold_path = f'{root}/hold_out'
# Number of shards the hold-out set splits into; the last shard's snapshot
# marks the whole set as already processed.
hold_num = len(range(0, len(hold_out_dats_pos), shard_size))
if not glob.glob(f'{hold_path}/shard_{hold_num-1}/*/*/*.snapshot'):
    process_data(
        dats=hold_out_dats_pos,
        msk=hold_out_msk_pos,
        path_root=hold_path,
        path_modifier='')

# Remove hold out dats and msks from the original lists, in place.
# A set lookup avoids one linear list.remove() scan per held-out path.
_hold_out_dats = set(hold_out_dats_pos)
_hold_out_msk = set(hold_out_msk_pos)
dats_pos[:] = [d for d in dats_pos if d not in _hold_out_dats]
msk_pos[:] = [m for m in msk_pos if m not in _hold_out_msk]

In [7]:
# Everything still in dats_pos / msk_pos after the hold-out removal is training data.
train_path = f'{root}/train'
# Shard count for the training set; a snapshot in the final shard means the
# whole set has already been written.
train_num = len(range(0, len(dats_pos), shard_size))
if not glob.glob(f'{train_path}/shard_{train_num-1}/*/*/*.snapshot'):
    process_data(dats=dats_pos, msk=msk_pos, path_root=train_path, path_modifier='')