CHANGELOG

v0.8
- hold out and external testing

v0.7 11/1/2023
- divide train into pos and neg masks
- repeat pos x5 because there are about 5x as many neg as pos dats/mask pairs

v0.6 10/9/2023
- change to segmentation mask only
- load each shard one by one, augment, then save

v0.5
- revise v0.4 to version for vessel detection and rough segmentation

v0.4
- initiate version for arteriolosclerotic vessel classification and fine segmentation

v0.3
    augment = tf.keras.Sequential([
        layers.RandomTranslation(0.08, 0.08, fill_mode='wrap'),
        layers.RandomRotation(1, fill_mode='wrap'),
        layers.RandomFlip(),
    ])
    
    image = rgb_perturb_stain_concentration(image, sigma1=0.58, sigma2=0.27)

v0.2
- roll color augmentation into tf.data.Dataset.map; this seems to resolve the memory issue

v0.1 6/29/2023
- add in color augmentation
- add a "restart and run all" function to refresh memory on each loop

v0.0 6/16/23:
- no color augmentation
- no stain normalization

Pseudocode Outline

Create list of folder names in wsi-arterio_batches_all/data
Loop through list of folder names
    Load train dats and msk
    Load valid dats and msk
    Augment train dats and msk tf dataset
    Convert valid dats and msk to tf dataset
    Save train and valid into the same folder

In [1]:
from jarvis.utils.general import gpus
gpus.autoselect()

import glob, numpy as np, pandas as pd, tensorflow as tf, matplotlib.pyplot as plt, skimage.io, os, gc
from tensorflow.keras import Input, Model, models, layers, optimizers, losses, callbacks, utils
from pathlib import Path
from jarvis.utils.display import imshow
from PIL import Image
from skimage.transform import rescale, resize
from histomicstk.preprocessing.augmentation import rgb_perturb_stain_concentration
from IPython.display import HTML, Javascript, display

Image.MAX_IMAGE_PIXELS = None

import sys  
sys.path.append('/home/jjlou/Jerry/jerry_packages')
from jerry_utils import restart_kernel, load_dataset, save_dataset

[ 2023-11-24 03:50:56 ] CUDA_VISIBLE_DEVICES automatically set to: 3           


In [2]:
# Base directory of the data tree: the input shard folders (train_balanced,
# external_filtered, hold_out_filtered) are globbed from here and the
# processed outputs are written alongside them.
root = '/home/jjlou/Jerry/wsi-arterio/vessel_detection_and_rough_segmentation/data_test'

In [3]:
# Random crop and random flip via TensorFlow/Keras preprocessing layers.
def aug(images, mask):
    """Apply one shared random 512x512 crop and random flip to an image/mask pair.

    The mask is appended to the image as an extra channel so that both pass
    through the same random transform, and the two are split apart afterwards.

    Args:
        images: a single image tensor. The channel slicing below treats it as
            (height, width, 3) -- TODO confirm against the stored shards.
        mask: the matching mask without a channel axis (one is added here).

    Returns:
        Tuple ``(image, mask)`` of uint8 tensors; the returned mask has a
        trailing axis of size 1.
    """
    augment = tf.keras.Sequential([
        layers.RandomCrop(512, 512),
        layers.RandomFlip(),
    ])

    images = tf.cast(images, 'uint8')
    mask = tf.cast(mask, 'uint8')

    # A single mask channel is enough: only one channel is read back after the
    # transform, so there is no need to replicate the mask three times.
    mask = tf.expand_dims(mask, -1)
    images_mask = tf.concat([images, mask], -1)

    # Keras random preprocessing layers only randomise in training mode.
    # Request it explicitly instead of relying on the (version-dependent)
    # default used when the layers are called outside of Model.fit.
    images_mask = augment(images_mask, training=True)

    image = images_mask[:, :, :3]
    # Slice with 3:4 (not 3) so the mask keeps its channel axis.
    mask = images_mask[:, :, 3:4]

    # The preprocessing layers may return floats; restore uint8 for storage.
    return tf.cast(image, 'uint8'), tf.cast(mask, 'uint8')

In [4]:
def color_aug(image, mask):
    """Run HistomicsTK's ``rgb_perturb_stain_concentration`` on ``image``.

    ``mask`` is passed through untouched apart from the dtype cast.

    Returns:
        Tuple ``(image, mask)``, both cast to uint8.
    """
    perturbed = rgb_perturb_stain_concentration(image, sigma1=0.43, sigma2=0.17)
    image_u8 = tf.cast(perturbed, 'uint8')
    mask_u8 = tf.cast(mask, 'uint8')
    return image_u8, mask_u8

In [5]:
# tf.numpy_function yields tensors with unknown static shape, so the shape
# has to be restored explicitly afterwards.
def reshape(image, mask):
    """Reshape ``image`` to 512x512x3 and ``mask`` to 512x512x1, both as uint8."""
    image_shape = [512, 512, 3]
    mask_shape = [512, 512, 1]
    shaped_image = tf.cast(tf.reshape(image, image_shape), 'uint8')
    shaped_mask = tf.cast(tf.reshape(mask, mask_shape), 'uint8')
    return shaped_image, shaped_mask

In [6]:
def expand_mask_dims(image, mask):
    """Return ``image`` unchanged and ``mask`` with a new axis inserted at index 2."""
    return image, tf.expand_dims(mask, axis=2)

In [7]:
def process_data_train(load=None, save_path=None):
    """Augment each training shard and save it under ``save_path``.

    Every shard directory in ``load`` is loaded as a ``tf.data.Dataset``,
    mapped through ``aug``, batched with batch size 1 and saved to
    ``{save_path}/{ID}/{shard_num}``, where ``ID`` and ``shard_num`` are the
    last two '/'-separated components of the shard path. Shards whose output
    directory already contains a ``.snapshot`` file are skipped, so the
    function can be re-run to resume where it left off.

    Args:
        load: iterable of shard directory paths. ``None`` is treated as empty.
        save_path: root directory for the processed shards.

    Raises:
        AssertionError: if an output directory holds ``.shard`` entries but no
            ``.snapshot`` file (the save appears to be incomplete).
    """
    # `time` is not imported at the top of the file; import it here so that
    # the sleep below does not fail with NameError.
    import time

    if load is None:
        load = ()

    for s in load:
        parts = s.split('/')
        ID, shard_num = parts[-2], parts[-1]
        save_file = f'{save_path}/{ID}/{shard_num}'

        # Already processed on a previous run: nothing to do.
        if glob.glob(f'{save_file}/*/*/*.snapshot'):
            continue

        assert not glob.glob(f'{save_file}/*/*.shard'), f'{save_file} did not save'

        shard = tf.data.Dataset.load(s)
        shard = (
            shard.map(aug)
            # Stain/color augmentation is currently disabled; the two stages
            # below belong together if it is re-enabled.
            #.map(lambda i, m: tf.numpy_function(color_aug, [i, m], [tf.uint8, tf.uint8]))
            #.map(lambda i, m: reshape(i, m))
            .batch(1)
        )
        shard.save(save_file)

        # restart_kernel comes from jerry_utils; presumably it restarts the
        # notebook kernel to release memory, and the pause gives the restart
        # time to take effect -- TODO confirm.
        restart_kernel()
        time.sleep(10)

In [8]:
def process_data_valid(load=None, save_path=None):
    """Add a channel axis to each evaluation shard's masks and save it.

    Every shard directory in ``load`` is loaded as a ``tf.data.Dataset``,
    mapped through ``expand_mask_dims`` (no augmentation), batched with batch
    size 1 and saved to ``{save_path}/{ID}/{shard_num}``, where ``ID`` and
    ``shard_num`` are the last two '/'-separated components of the shard path.
    Shards whose output directory already contains a ``.snapshot`` file are
    skipped, so the function can be re-run to resume where it left off.

    Args:
        load: iterable of shard directory paths. ``None`` is treated as empty.
        save_path: root directory for the processed shards.

    Raises:
        AssertionError: if an output directory holds ``.shard`` entries but no
            ``.snapshot`` file (the save appears to be incomplete).
    """
    # `time` is not imported at the top of the file; import it here so that
    # the sleep below does not fail with NameError.
    import time

    if load is None:
        load = ()

    for s in load:
        parts = s.split('/')
        ID, shard_num = parts[-2], parts[-1]
        save_file = f'{save_path}/{ID}/{shard_num}'

        # Already processed on a previous run: nothing to do.
        if glob.glob(f'{save_file}/*/*/*.snapshot'):
            continue

        assert not glob.glob(f'{save_file}/*/*.shard'), f'{save_file} did not save'

        shard = tf.data.Dataset.load(s)
        shard = (
            shard.map(expand_mask_dims, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(1)
        )
        shard.save(save_file)

        # restart_kernel comes from jerry_utils; presumably it restarts the
        # notebook kernel to release memory, and the pause gives the restart
        # time to take effect -- TODO confirm.
        restart_kernel()
        time.sleep(10)

In [9]:
# Training set: augmented shards are written to train_Basic.Aug.
train_save = f'{root}/train_Basic.Aug'

# Collect every <ID>/<shard> directory of the balanced training set, in sorted order.
train_load = glob.glob(f'{root}/train_balanced/*/*')
train_load.sort()

process_data_train(train_load, train_save)

In [10]:
# External test set: processed shards are written to external_processed.
external_save = f'{root}/external_processed'

# Collect every <ID>/<shard> directory of the filtered external set, in sorted order.
external_load = glob.glob(f'{root}/external_filtered/*/*')
external_load.sort()

process_data_valid(external_load, external_save)

In [11]:
# Hold-out test set: processed shards are written to hold_out_processed.
hold_out_save = f'{root}/hold_out_processed'

# Collect every <ID>/<shard> directory of the filtered hold-out set, in sorted order.
hold_out_load = glob.glob(f'{root}/hold_out_filtered/*/*')
hold_out_load.sort()

process_data_valid(hold_out_load, hold_out_save)