CHANGELOG

v0.8 11/13/2023
- convert to train and hold-out test

v0.6 10/10/2023
- load each raw shard one at a time, augment, then save

v0.5 10/9/2023
- convert to Vessel_WallsLumen_Segmentation version
- repeat 30 times (done for original dataset)

v0.4
- initiate version for arteriolosclerotic vessel classification and segmentation

v0.3
    augment = tf.keras.Sequential([
        layers.RandomTranslation(0.08, 0.08, fill_mode='wrap'),
        layers.RandomRotation(1, fill_mode='wrap'),
        layers.RandomFlip(),
    ])
    
    image = rgb_perturb_stain_concentration(image, sigma1=0.58, sigma2=0.27)

v0.2
- roll color augmentation into tf.data.Dataset.map; this seems to resolve the memory issue

v0.1 6/29/2023
- add in color augmentation
- add a restart-and-run-all function to refresh memory on each loop

v0.0 6/16/23:
- no color augmentation
- no stain normalization

Pseudocode Outline

Create list of folder names in wsi-arterio_batches_all/data
Loop through list of folder names
    Load train dats and msk
    Load valid dats and msk
    Augment train dats and msk tf dataset
    Convert valid dats and msk to tf dataset
    Save train and valid into the same folder

In [1]:
from jarvis.utils.general import gpus
gpus.autoselect()

import glob, numpy as np, pandas as pd, tensorflow as tf, matplotlib.pyplot as plt, skimage.io, os, time
from tensorflow.keras import Input, Model, models, layers, optimizers, losses, callbacks, utils
from pathlib import Path
from jarvis.utils.display import imshow
from PIL import Image
from skimage.transform import rescale, resize
from histomicstk.preprocessing.augmentation import rgb_perturb_stain_concentration
from IPython.display import HTML, Javascript, display

Image.MAX_IMAGE_PIXELS = None

import sys  
sys.path.append('/home/jjlou/Jerry/jerry_packages')
from jerry_utils import restart_kernel, load_dataset



In [2]:
# Base data directory. The cells below read input shards from its `hold_out`,
# `external` and `train` subfolders and write the processed shards to sibling
# folders under this same directory.
root = '/home/jjlou/Jerry/wsi-arterio/arteriosclerotic_vessel_detection_and_fine_segmentation/Vessel_WallsLumen_Segmentation/data_test'

In [3]:
# random translation / rotation / flip via tensorflow (Keras) preprocessing layers
def aug(images, mask):
    """Apply one shared random geometric transform to an image and its mask.

    The mask is replicated into three channels and concatenated behind the
    image channels, so that a single pass through the augmentation stack
    moves both in lockstep. Afterwards the first three channels are returned
    as the image and the fourth channel (the first mask copy) as the mask.

    Args:
        images: image tensor; cast to uint8 before augmentation.
            Assumes layout (H, W, 3) -- TODO confirm against caller.
        mask: mask tensor; cast to uint8 before augmentation.
            Assumes layout (H, W) -- TODO confirm against caller.

    Returns:
        Tuple ``(image, mask)`` of uint8 tensors; the mask carries a trailing
        singleton axis inserted at position 2.
    """
    # NOTE(review): the Keras random translation/rotation layers interpolate
    # bilinearly by default, which may blend mask label values along edges --
    # confirm whether nearest-neighbour interpolation is wanted for the mask.
    pipeline = tf.keras.Sequential([
        layers.RandomTranslation(0.05, 0.05),
        layers.RandomRotation(1),
        layers.RandomFlip(),
    ])

    image_u8 = tf.cast(images, tf.uint8)
    mask_u8 = tf.cast(mask, tf.uint8)

    # Three identical mask channels appended after the image channels.
    mask_rgb = tf.stack([mask_u8] * 3, axis=-1)
    stacked = tf.concat([image_u8, mask_rgb], axis=-1)

    transformed = pipeline(stacked)

    # Channels 0-2 are the image; channel 3 is the first mask copy, sliced
    # as 3:4 so the channel axis is kept.
    out_image = transformed[:, :, :3]
    out_mask = transformed[:, :, 3:4]

    return tf.cast(out_image, tf.uint8), tf.cast(out_mask, tf.uint8)

In [4]:
def color_aug(image, mask):
    """Perturb the stain concentrations of ``image``; pass ``mask`` through.

    Args:
        image: RGB image handed to ``rgb_perturb_stain_concentration``.
        mask: segmentation mask; only its dtype is changed.

    Returns:
        Tuple ``(image, mask)``, both cast to uint8.
    """
    perturbed = rgb_perturb_stain_concentration(image, sigma1=0.90, sigma2=0.60)
    image_u8 = tf.cast(perturbed, tf.uint8)
    mask_u8 = tf.cast(mask, tf.uint8)
    return image_u8, mask_u8

In [5]:
#tf.numpy_function creates output with unknown shape.. so need to reshape it
def reshape(image, mask):
    """Pin the static shapes of an image/mask pair and cast both to uint8.

    Args:
        image: tensor reshaped to (512, 512, 3).
        mask: tensor reshaped to (512, 512).

    Returns:
        Tuple ``(image, mask)`` of uint8 tensors with the shapes above.
    """
    image_shape = [512, 512, 3]
    mask_shape = [512, 512]
    fixed_image = tf.reshape(image, image_shape)
    fixed_mask = tf.reshape(mask, mask_shape)
    return tf.cast(fixed_image, tf.uint8), tf.cast(fixed_mask, tf.uint8)

def expand_mask_dims(image, mask):
    """Return ``image`` unchanged and ``mask`` with a new axis at position 2."""
    return image, tf.expand_dims(mask, 2)

# Hold-out split: no augmentation. Each input shard only gets a trailing mask
# axis and a batch size of 1, then is written to its own `shard_<i>` folder.
hold_load = f'{root}/hold_out'
hold_save = f'{root}/hold_out_Color.BasicMorph.Aug'
hold_shards_load = sorted(glob.glob(f'{hold_load}/*'))
hold_shard_num = len(hold_shards_load)

# The whole split is skipped once the last expected output shard already
# contains a snapshot file.
if not glob.glob(f'{hold_save}/shard_{hold_shard_num-1}/*/*/*.snapshot'):
    for shard_idx, s in enumerate(hold_shards_load):
        shard_save = f'{hold_save}/shard_{shard_idx}'

        # Already exported on an earlier pass: nothing to do for this shard.
        if glob.glob(f'{shard_save}/*/*/*.snapshot'):
            continue

        # `.shard` files without any snapshot are treated as a broken save.
        assert not glob.glob(f'{shard_save}/*/*.shard'), f'{shard_save} did not save properly'

        shard = tf.data.Dataset.load(s)
        shard = shard.map(expand_mask_dims, num_parallel_calls=tf.data.AUTOTUNE)
        shard = shard.batch(1)
        shard.save(shard_save)

        # Presumably restarts the notebook kernel to release memory between
        # shards (see jerry_utils.restart_kernel) -- TODO confirm.
        restart_kernel()
        time.sleep(10)

# External split: no augmentation. Each input shard only gets a trailing mask
# axis and a batch size of 1, then is written to its own `shard_<i>` folder.
external_load = f'{root}/external'
external_save = f'{root}/external_Color.BasicMorph.Aug'
external_shards_load = sorted(glob.glob(f'{external_load}/*'))
# Count the EXTERNAL shards. The previous code counted the hold-out shards,
# so the completion gate below looked at the wrong "last" shard and could
# skip unprocessed external shards when the two splits differ in size.
external_shard_num = len(external_shards_load)

# The whole split is skipped once the last expected output shard already
# contains a snapshot file.
if not glob.glob(f'{external_save}/shard_{external_shard_num-1}/*/*/*.snapshot'):
    for shard_idx, s in enumerate(external_shards_load):
        shard_save = f'{external_save}/shard_{shard_idx}'

        # Already exported on an earlier pass: nothing to do for this shard.
        if glob.glob(f'{shard_save}/*/*/*.snapshot'):
            continue

        # `.shard` files without any snapshot are treated as a broken save.
        assert not glob.glob(f'{shard_save}/*/*.shard'), f'{shard_save} did not save properly'

        shard = tf.data.Dataset.load(s)
        shard = (
            shard.map(expand_mask_dims, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(1)
        )
        shard.save(shard_save)

        # Presumably restarts the notebook kernel to release memory between
        # shards (see jerry_utils.restart_kernel) -- TODO confirm.
        restart_kernel()
        time.sleep(10)

In [6]:
# Training split: every shard is shuffled, repeated, colour-augmented,
# shape-fixed, geometrically augmented and saved to its own `shard_<i>` folder.
# NOTE(review): the output folder is suffixed `_x48` while the pipeline below
# repeats each shard 45 times -- confirm which number is intended.
train_load = f'{root}/train'
train_save = f'{root}/train_Color.BasicMorph.Aug_x48'
train_shards_load = sorted(glob.glob(f'{train_load}/*'))
train_shard_num = len(train_shards_load)

for shard_idx, s in enumerate(train_shards_load):
    shard_save = f'{train_save}/shard_{shard_idx}'

    # Already exported on an earlier pass: nothing to do for this shard.
    if glob.glob(f'{shard_save}/*/*/*.snapshot'):
        continue

    # `.shard` files without any snapshot are treated as a broken save.
    assert not glob.glob(f'{shard_save}/*/*.shard'), f'{shard_save} did not save properly'

    shard = tf.data.Dataset.load(s)
    shard = shard.shuffle(len(shard), reshuffle_each_iteration=True)
    shard = shard.repeat(45)
    # color_aug runs as plain Python through tf.numpy_function, which loses
    # the static shapes; `reshape` restores them in the next step.
    shard = shard.map(
        lambda i, m: tf.numpy_function(func=color_aug, inp=[i, m], Tout=[tf.uint8, tf.uint8]),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    shard = shard.map(reshape, num_parallel_calls=tf.data.AUTOTUNE)
    shard = shard.map(aug, num_parallel_calls=tf.data.AUTOTUNE)
    shard = shard.batch(1)
    shard.save(shard_save)

    # Presumably restarts the notebook kernel to release memory between
    # shards (see jerry_utils.restart_kernel) -- TODO confirm.
    restart_kernel()
    time.sleep(10)