CHANGELOG

v0.2 11/18/2023
- hold out and external testing

v0.1 10/13/2023
- change to classification version

v0.0 7/14/2023
- initiate from the wsi-arterio_batches_all_classification Data_Partitioner_v0.4

In [1]:
from jarvis.utils.general import gpus
gpus.autoselect()

import glob, numpy as np, pandas as pd, tensorflow as tf, matplotlib.pyplot as plt, os, gc
from jarvis.utils.display import imshow
from tqdm import tqdm

# import after setting OPENCV_IO_MAX_IMAGE_PIXELS to 2^50
os.environ["OPENCV_IO_MAX_IMAGE_PIXELS"] = pow(2,50).__str__() 
import cv2

import sys  
sys.path.append('/home/jjlou/Jerry/jerry_packages')
from jerry_utils import restart_kernel



In [2]:
# --- Partitioning configuration ---------------------------------------------

# Substrings matched against file paths to pull patches out of the main pool
hold_out_list = [
    'UCI-37-18',
    'D-492',
]
external_list = [
    'V019-B1_VP',
    'V019_B3_VP',
]

# Every image is resized to this size before being saved
shape = (512, 512)

# Number of file paths loaded (and saved together) per shard
shard_size = 5

# Output directory under which the partitions are written
root = '/home/jjlou/Jerry/wsi-arterio/arteriosclerotic_vessel_detection_and_fine_segmentation/Arteriolosclerosis_classification/data_test'

In [3]:
# Build the sorted list of patch file paths for each class. Batches 1-2 and
# batch 3 live in different directory trees, so each is globbed separately and
# the results are then joined (batches 1-2 first).

# Positive class
dats_pos_batches1and2 = glob.glob('/data/raw/wsi_arterio_2nd_alg/data/pos/*/*')
dats_pos_batches1and2.sort()
dats_pos_batch3 = glob.glob('/data/raw/wsi_arterio_2nd_alg/batch_3_data/annotations/pos/*/*')
dats_pos_batch3.sort()
dats_pos = [*dats_pos_batches1and2, *dats_pos_batch3]

# Negative class
dats_neg_batches1and2 = glob.glob('/data/raw/wsi_arterio_2nd_alg/data/neg/*/*')
dats_neg_batches1and2.sort()
dats_neg_batch3 = glob.glob('/data/raw/wsi_arterio_2nd_alg/batch_3_data/annotations/neg/*/*')
dats_neg_batch3.sort()
dats_neg = [*dats_neg_batches1and2, *dats_neg_batch3]

In [4]:
def load_data(dats=None, cls=None, shape=None):
    """Read image files, resize them, and pair each with one constant label.

    Args:
        dats: sequence of image file paths to read with ``cv2.imread``.
        cls: class label attached to every image; cast to uint8.
        shape: target size handed to ``cv2.resize`` as ``dsize``. OpenCV
            interprets ``dsize`` as (width, height), which only matters for
            non-square sizes.

    Returns:
        A ``tf.data.Dataset`` built with ``from_tensor_slices`` from the
        (images, labels) lists, both cast to uint8.

    Raises:
        OSError: if ``cv2.imread`` cannot read one of the paths.
    """
    images = []
    for d in dats:
        img = cv2.imread(d)
        # cv2.imread reports a missing/unreadable file by returning None rather
        # than raising; fail here so the error names the offending path.
        if img is None:
            raise OSError(f'cv2.imread could not read image: {d}')
        images.append(tf.cast(cv2.resize(img, shape), dtype='uint8'))

    labels = [tf.cast(cls, dtype='uint8') for _ in images]

    return tf.data.Dataset.from_tensor_slices((images, labels))

In [5]:
def process_data(dats=None, cls=None, shape=None, shard_size=shard_size, path_root=None, path_modifier=''):
    """Save ``dats`` to disk as consecutive dataset shards of ``shard_size`` files.

    Shard ``k`` covers ``dats[k*shard_size:(k+1)*shard_size]`` and is saved to
    ``{path_root}/{path_modifier}shard_{k}``. A shard whose directory already
    contains a ``.snapshot`` file is skipped, so re-running resumes where the
    previous run stopped. After each newly saved shard a ``record.csv`` holding
    the first and last file path of the shard is written into the shard
    directory, then ``restart_kernel()`` is called.

    Args:
        dats: list of image file paths.
        cls: class label passed through to ``load_data``.
        shape: resize target passed through to ``load_data``.
        shard_size: number of file paths per shard (defaults to the notebook's
            module-level ``shard_size`` at definition time).
        path_root: directory that receives the shard sub-directories.
        path_modifier: prefix for the shard directory names; only needed when
            several datasets are saved under the same ``path_root``.

    Raises:
        AssertionError: if a shard directory has ``.shard`` entries but no
            ``.snapshot`` file (treated as an incomplete earlier save).
    """
    # Local import: the notebook's import cell never imports `time`, so the
    # sleep below would otherwise raise NameError.
    import time

    for shard_idx, start in enumerate(range(0, len(dats), shard_size)):
        save_path = f'{path_root}/{path_modifier}shard_{shard_idx}'

        # Already saved on an earlier run
        if glob.glob(f'{save_path}/*/*/*.snapshot'):
            continue

        assert not glob.glob(f'{save_path}/*/*.shard'), f'{save_path} did not save properly'

        shard_dats = dats[start:start + shard_size]
        shard = load_data(dats=shard_dats, cls=cls, shape=shape)
        shard.save(save_path)

        # The slice is never empty here, so its ends are the shard's first and
        # last file (the final shard may be shorter than shard_size).
        record = {'start': shard_dats[0], 'stop': shard_dats[-1]}
        record = pd.DataFrame.from_dict(record, orient='index')
        record.to_csv(f'{save_path}/record.csv')

        # NOTE(review): presumably restarts the Jupyter kernel to release memory
        # between shards, with the sleep giving the restart time to take
        # effect; confirm against jerry_utils.restart_kernel.
        restart_kernel()
        time.sleep(10)

In [6]:
# Collect every patch path that contains a hold-out patient ID. Paths are
# grouped by ID in hold_out_list order. dict.fromkeys drops repeats while
# keeping that order, so a path matching more than one ID is saved once and
# removed once (a second .remove() of the same path would raise ValueError).
hold_out_dats_pos = list(dict.fromkeys(
    d for h in hold_out_list for d in dats_pos if h in d))
hold_out_dats_neg = list(dict.fromkeys(
    dn for h in hold_out_list for dn in dats_neg if h in dn))

hold_path = f'{root}/hold_out'
# Index of the last shard process_data would write (-1 when the list is empty)
hold_num_pos = len(range(0, len(hold_out_dats_pos), shard_size)) - 1
hold_num_neg = len(range(0, len(hold_out_dats_neg), shard_size)) - 1

# Only call process_data while the last shard has no snapshot yet
if not glob.glob(f'{hold_path}/pos/pos_shard_{hold_num_pos}/*/*/*.snapshot'):

    # process positive hold out
    process_data(
        dats=hold_out_dats_pos,
        cls=1,
        shape=shape,
        shard_size=shard_size,
        path_root=f'{hold_path}/pos',
        path_modifier='pos_')

if not glob.glob(f'{hold_path}/neg/neg_shard_{hold_num_neg}/*/*/*.snapshot'):

    # process negative hold out
    process_data(
        dats=hold_out_dats_neg,
        cls=0,
        shape=shape,
        shard_size=shard_size,
        path_root=f'{hold_path}/neg',
        path_modifier='neg_')

# Remove the hold-out paths from the remaining pools (in place)
for d in hold_out_dats_pos:
    dats_pos.remove(d)
for n in hold_out_dats_neg:
    dats_neg.remove(n)

In [7]:
# Collect every patch path that contains an external-set ID. Paths are grouped
# by ID in external_list order. dict.fromkeys drops repeats while keeping that
# order, so a path matching more than one ID is saved once and removed once
# (a second .remove() of the same path would raise ValueError).
external_dats_pos = list(dict.fromkeys(
    d for h in external_list for d in dats_pos if h in d))
external_dats_neg = list(dict.fromkeys(
    dn for h in external_list for dn in dats_neg if h in dn))

external_path = f'{root}/external'
# Index of the last shard process_data would write (-1 when the list is empty)
external_num_pos = len(range(0, len(external_dats_pos), shard_size)) - 1
external_num_neg = len(range(0, len(external_dats_neg), shard_size)) - 1

# Only call process_data while the last shard has no snapshot yet
if not glob.glob(f'{external_path}/pos/pos_shard_{external_num_pos}/*/*/*.snapshot'):

    # process positive external
    process_data(
        dats=external_dats_pos,
        cls=1,
        shape=shape,
        shard_size=shard_size,
        path_root=f'{external_path}/pos',
        path_modifier='pos_')

if not glob.glob(f'{external_path}/neg/neg_shard_{external_num_neg}/*/*/*.snapshot'):

    # process negative external
    process_data(
        dats=external_dats_neg,
        cls=0,
        shape=shape,
        shard_size=shard_size,
        path_root=f'{external_path}/neg',
        path_modifier='neg_')

# Remove the external paths from the remaining pools (in place)
for d in external_dats_pos:
    dats_pos.remove(d)
for n in external_dats_neg:
    dats_neg.remove(n)

In [8]:
# Save the current contents of dats_pos / dats_neg as the training partition.
train_path = f'{root}/train'

# Index of the last shard process_data would write for each class
train_num_pos = len(range(0, len(dats_pos), shard_size)) - 1
train_num_neg = len(range(0, len(dats_neg), shard_size)) - 1

# Positive class first, then negative; a class is skipped once its last shard
# already has a snapshot on disk.
for tag, last_idx, paths, label in (
    ('pos', train_num_pos, dats_pos, 1),
    ('neg', train_num_neg, dats_neg, 0),
):
    if glob.glob(f'{train_path}/{tag}/{tag}_shard_{last_idx}/*/*/*.snapshot'):
        continue

    process_data(
        dats=paths,
        cls=label,
        shape=shape,
        shard_size=shard_size,
        path_root=f'{train_path}/{tag}',
        path_modifier=f'{tag}_')