CHANGELOG

v0.8 11/18/2023
- hold out and external testing

v0.7 10/15/2023
- modified 0.6 for classification augmentation

v0.6 10/10/2023
- load each raw shard one at a time, augment, then save

v0.5 10/9/2023
- convert to Vessel_WallsLumen_Segmentation version
- repeat 30 times (done for original dataset)

v0.4
- initiate version for arteriolosclerotic vessel classification and segmentation

v0.3
    augment = tf.keras.Sequential([
        layers.RandomTranslation(0.08, 0.08, fill_mode='wrap'),
        layers.RandomRotation(1, fill_mode='wrap'),
        layers.RandomFlip(),
    ])
    
    image = rgb_perturb_stain_concentration(image, sigma1=0.58, sigma2=0.27)

v0.2
- roll color augmentation into tf.data.Dataset.map; this seems to resolve the memory issue

v0.1 6/29/2023
- add in color augmentation
- add in restart run all function to refresh memory each loop

v0.0 6/16/23:
- no color augmentation
- no stain normalization

Pseudocode Outline

Create list of folder names in wsi-arterio_batches_all/data
Loop through list of folder names
    Load train dats and msk
    Load valid dats and msk
    Augment train dats and msk tf dataset
    Convert valid dats and msk to tf dataset
    Save train and valid into the same folder

In [1]:
# Select a GPU before TensorFlow is imported further down; the recorded cell
# output shows this call sets CUDA_VISIBLE_DEVICES.
from jarvis.utils.general import gpus
gpus.autoselect()

import glob, numpy as np, pandas as pd, tensorflow as tf, matplotlib.pyplot as plt, skimage.io, os, time
from tensorflow.keras import Input, Model, models, layers, optimizers, losses, callbacks, utils
from pathlib import Path
from jarvis.utils.display import imshow
from PIL import Image
from skimage.transform import rescale, resize
from histomicstk.preprocessing.augmentation import rgb_perturb_stain_concentration
from IPython.display import HTML, Javascript, display

Image.MAX_IMAGE_PIXELS = None

import sys  
sys.path.append('/home/jjlou/Jerry/jerry_packages')
from jerry_utils import restart_kernel, load_dataset

[ 2024-04-02 22:54:31 ] CUDA_VISIBLE_DEVICES automatically set to: 1           


In [2]:
# Base data directory: every input shard is globbed from, and every processed
# shard is saved under, a sub-folder of this path.
root = '/home/jjlou/Jerry/wsi-arterio/arteriosclerotic_vessel_detection_and_fine_segmentation/Arteriolosclerosis_classification/data_test'

In [3]:
# random flip, translation, and rotation via tensorflow layers
def aug(images, labels):
    """Apply random translation, rotation and flip to the image input.

    Args:
        images: Image tensor; cast to uint8 before the augmentation layers run.
            (presumably a single 512x512x3 patch when used after ``reshape``
            in the pipeline below -- TODO confirm)
        labels: Label tensor. It is only cast to uint8; no spatial transform
            is applied to it.

    Returns:
        Tuple ``(augmented_images, labels)``, both cast to uint8.
    """
    augment = tf.keras.Sequential([
        layers.RandomTranslation(0.05, 0.05),
        layers.RandomRotation(1),
        layers.RandomFlip(),
    ])

    images = tf.cast(images, 'uint8')

    # Ask for training mode explicitly: the random preprocessing layers only
    # perturb their input in training mode, and when called outside of
    # Model.fit (e.g. from Dataset.map) the mode would otherwise be left to a
    # library-version-dependent default.
    images = augment(images, training=True)

    return tf.cast(images, 'uint8'), tf.cast(labels, 'uint8')

In [4]:
def color_aug(image, labels):
    """Colour-augment one patch by perturbing its stain concentrations.

    Runs HistomicsTK's ``rgb_perturb_stain_concentration`` on ``image`` with
    ``sigma1=0.43`` and ``sigma2=0.17``. ``labels`` is passed through and only
    cast.

    Returns:
        Tuple ``(image, labels)``, both cast to uint8.
    """
    perturbed = rgb_perturb_stain_concentration(image, sigma1=0.43, sigma2=0.17)
    image_u8 = tf.cast(perturbed, 'uint8')
    labels_u8 = tf.cast(labels, 'uint8')
    return image_u8, labels_u8

In [5]:
#tf.numpy_function creates output with unknown shape.. so need to reshape it
def reshape(image, labels, *, image_shape=(512, 512, 3), label_shape=(1,)):
    """Give image and label tensors a static shape and cast them to uint8.

    Used after a ``tf.numpy_function`` step, whose outputs have unknown shape.

    Args:
        image: Image tensor to reshape.
        labels: Label tensor to reshape.
        image_shape: Target shape for ``image``. Defaults to ``(512, 512, 3)``,
            the patch shape this notebook was written for.
        label_shape: Target shape for ``labels``. Defaults to ``(1,)``.

    Returns:
        Tuple ``(image, labels)`` with the requested shapes, both cast to uint8.
    """
    image = tf.reshape(image, image_shape)
    labels = tf.reshape(labels, label_shape)
    return tf.cast(image, 'uint8'), tf.cast(labels, 'uint8')

In [None]:
train_save = f'{root}/Train_Color.BasicMorph.Aug_x38'
train_shards_load_pos = sorted(glob.glob(f'{root}/train/pos/*'))
train_shards_load_neg = sorted(glob.glob(f'{root}/train/neg/*'))

train_shard_num_pos = len(train_shards_load_pos)
train_shard_num_neg = len(train_shards_load_neg)


def _augment_train_shards(shard_paths, prefix, repeats):
    """Augment every raw training shard of one class and save it.

    Each shard is loaded on its own, shuffled, repeated ``repeats`` times,
    colour-augmented, reshaped, morphologically augmented, batched with batch
    size 1 and saved to ``{train_save}/{prefix}_shard_{index}``.

    Shards that already contain a ``*.snapshot`` file are skipped, so the cell
    can be re-run after each kernel restart and resume where it stopped.

    Args:
        shard_paths: Sorted list of raw shard directories for this class.
        prefix: Folder-name prefix for the saved shards ('pos' or 'neg').
        repeats: Number of times each shard is repeated before augmentation.

    Raises:
        AssertionError: If a target folder has ``*.shard`` files but no
            ``*.snapshot`` file (presumably an interrupted save -- verify).
    """
    # Fast exit: the last shard of this class has already been written.
    if glob.glob(f'{train_save}/{prefix}_shard_{len(shard_paths)-1}/*/*/*.snapshot'):
        return

    for shard_idx, s in enumerate(shard_paths):
        shard_save = f'{train_save}/{prefix}_shard_{shard_idx}'
        if glob.glob(f'{shard_save}/*/*/*.snapshot'):
            continue  # already saved on a previous run

        assert not glob.glob(f'{shard_save}/*/*.shard'), f'{shard_save} did not save properly'
        shard = tf.data.Dataset.load(s)
        shard = (
            shard.shuffle(len(shard), reshuffle_each_iteration=True)
            .repeat(repeats)
            # numpy_function outputs have unknown shape, hence the reshape step.
            .map(lambda i, m: tf.numpy_function(color_aug, [i, m], [tf.uint8, tf.uint8]), num_parallel_calls=tf.data.AUTOTUNE)
            .map(reshape, num_parallel_calls=tf.data.AUTOTUNE)
            .map(aug, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(1)
        )
        shard.save(shard_save)
        # Restart the kernel after each shard to refresh memory (see changelog
        # v0.1); the sleep presumably gives the restart time to take effect.
        restart_kernel()
        time.sleep(10)


# Augment Train Pos
# The original note gives 19 as the average pos/neg patch ratio; negative
# shards are repeated x2, so positive shards are repeated 19 x 2 = 38 times.
_augment_train_shards(train_shards_load_pos, 'pos', repeats=38)

# Augment Train Neg
_augment_train_shards(train_shards_load_neg, 'neg', repeats=2)

In [None]:
hold_save = f'{root}/hold_processed'
hold_pos = sorted(glob.glob(f'{root}/hold_out/pos/*'))
hold_neg = sorted(glob.glob(f'{root}/hold_out/neg/*'))
hold_load = hold_pos + hold_neg

# Re-save every hold-out shard batched with batch size 1 (no augmentation).
# Shard numbers run over the combined pos + neg list, so neg shards continue
# counting where the pos shards stop.
# NOTE(review): this outer guard counts every entry under hold_save, including
# a folder left by an interrupted save, so such a folder can make the whole
# loop (and its assert) be skipped -- confirm this is acceptable.
if len(glob.glob(f'{hold_save}/*')) < len(hold_load):
    for shard_idx, s in enumerate(hold_load):
        modifier = s.split('/')[-2]  # parent folder name: 'pos' or 'neg'
        shard_save = f'{hold_save}/{modifier}_shard_{shard_idx}'
        if glob.glob(f'{shard_save}/*/*/*.snapshot'):
            continue  # already saved on a previous run

        assert not glob.glob(f'{shard_save}/*/*.shard'), f'{shard_save} did not save properly'
        shard = tf.data.Dataset.load(s)
        shard = shard.batch(1)
        shard.save(shard_save)
        # Restart the kernel after each shard to refresh memory (see changelog
        # v0.1); the sleep presumably gives the restart time to take effect.
        restart_kernel()
        time.sleep(10)

In [None]:
external_save = f'{root}/external_processed'
external_pos = sorted(glob.glob(f'{root}/external/pos/*'))
external_neg = sorted(glob.glob(f'{root}/external/neg/*'))
external_load = external_pos + external_neg

# Re-save every external-test shard batched with batch size 1 (no
# augmentation). Shard numbers run over the combined pos + neg list, so neg
# shards continue counting where the pos shards stop.
# NOTE(review): this outer guard counts every entry under external_save,
# including a folder left by an interrupted save, so such a folder can make
# the whole loop (and its assert) be skipped -- confirm this is acceptable.
if len(glob.glob(f'{external_save}/*')) < len(external_load):
    for shard_idx, s in enumerate(external_load):
        modifier = s.split('/')[-2]  # parent folder name: 'pos' or 'neg'
        shard_save = f'{external_save}/{modifier}_shard_{shard_idx}'
        if glob.glob(f'{shard_save}/*/*/*.snapshot'):
            continue  # already saved on a previous run

        assert not glob.glob(f'{shard_save}/*/*.shard'), f'{shard_save} did not save properly'
        shard = tf.data.Dataset.load(s)
        shard = shard.batch(1)
        shard.save(shard_save)
        # Restart the kernel after each shard to refresh memory (see changelog
        # v0.1); the sleep presumably gives the restart time to take effect.
        restart_kernel()
        time.sleep(10)