CHANGELOG

v0.9
- hold out and external testing
- check for white out after color augmentation. If yes, only basic morph aug.
- shard each i,m pair into its own individual dataset object

v0.8 11/24/2023
- for hold out and external testing

v0.7 11/1/2023
- divide train into pos and neg masks
- repeat pos x5 because there are about 5x as many neg as pos dats/mask pairs

v0.6 10/9/2023
- change to segmentation mask only
- load each shard one by one, augment, then save

v0.5
- revise v0.4 to version for vessel detection and rough segmentation

v0.4
- initiate version for arteriolosclerotic vessel classification and fine segmentation

v0.3
    augment = tf.keras.Sequential([
        layers.RandomTranslation(0.08, 0.08, fill_mode='wrap'),
        layers.RandomRotation(1, fill_mode='wrap'),
        layers.RandomFlip(),
    ])
    
    image = rgb_perturb_stain_concentration(image, sigma1=0.58, sigma2=0.27)

v0.2
- roll color augmentation into tf.data.Dataset.map; this seems to resolve the memory issue

v0.1 6/29/2023
- add in color augmentation
- add in restart run all function to refresh memory each loop

v0.0 6/16/23:
- no color augmentation
- no stain normalization

Pseudocode Outline

Create list of folder names in wsi-arterio_batches_all/data
Loop through list of folder names
    Load train dats and msk
    Load valid dats and msk
    Augment train dats and msk tf dataset
    Convert valid dats and msk to tf dataset
    Save train and valid into the same folder

In [1]:
from jarvis.utils.general import gpus
gpus.autoselect()

import glob, numpy as np, tensorflow as tf, time
from tensorflow.keras import layers
from histomicstk.preprocessing.augmentation import rgb_perturb_stain_concentration

import sys  
sys.path.append('/home/jjlou/Jerry/jerry_packages')
from jerry_utils import restart_kernel, shard_dataset

[ 2024-01-22 08:43:35 ] CUDA_VISIBLE_DEVICES automatically set to: 2           


In [2]:
# Base directory of the dataset; the input and output paths used by this
# notebook are built from it.
root = '/home/jjlou/Jerry/wsi-arterio/vessel_detection_and_rough_segmentation/data_test'
# Number of copies made of each image/mask pair (via Dataset.repeat) before
# augmentation; it is also embedded in the output folder name.
repeats = 5

In [3]:
# random translation, rotation and flip via tensorflow (Keras) layers
def aug(images, mask):
    """Apply one random geometric augmentation jointly to an image and its mask.

    The mask is appended to the image as an extra channel so that both pass
    through the same random translation / rotation / flip, and the result is
    then split back into image and mask.

    Args:
        images: image tensor whose first three channels on the last axis are
            the image (assumes an unbatched H x W x 3 tile -- TODO confirm).
        mask: mask tensor with the same spatial shape and no channel axis.

    Returns:
        Tuple ``(image, mask)`` of uint8 tensors; the mask has a trailing
        channel axis of size 1.
    """
    augment = tf.keras.Sequential([
        layers.RandomTranslation(0.05, 0.05),
        layers.RandomRotation(1),
        layers.RandomFlip(),
    ])

    images = tf.cast(images, 'uint8')
    mask = tf.cast(mask, 'uint8')

    # A single mask channel is sufficient: the layers transform every channel
    # of one input identically, and only one mask channel is read back below.
    images_mask = tf.concat([images, tf.expand_dims(mask, -1)], -1)

    # Request training mode explicitly so the random layers are active no
    # matter what the installed Keras version uses as its implicit default.
    images_mask = augment(images_mask, training=True)
    # NOTE(review): the translation/rotation layers are used with their
    # default interpolation, so mask edges may pick up in-between values
    # before the uint8 cast -- confirm this is acceptable downstream.

    image = images_mask[:, :, :3]
    mask = images_mask[:, :, 3:4]

    return tf.cast(image, 'uint8'), tf.cast(mask, 'uint8')

In [4]:
def color_aug(image, mask):
    """Perturb the stain concentration of ``image`` and cast both inputs to uint8.

    The image goes through ``rgb_perturb_stain_concentration`` with
    ``sigma1=0.43`` and ``sigma2=0.17``; the mask is only cast, not altered.

    Returns:
        Tuple ``(image, mask)``, both cast to uint8.
    """
    perturbed = rgb_perturb_stain_concentration(image, sigma1=0.43, sigma2=0.17)
    image_u8 = tf.cast(perturbed, 'uint8')
    mask_u8 = tf.cast(mask, 'uint8')
    return image_u8, mask_u8

In [5]:
# tf.numpy_function returns tensors of unknown static shape, so the shape has
# to be pinned again afterwards
def reshape(image, mask):
    """Reshape an image/mask pair to the fixed tile size and cast to uint8.

    Args:
        image: tensor reshaped to ``[512, 512, 3]``.
        mask: tensor reshaped to ``[512, 512]``.

    Returns:
        Tuple ``(image, mask)`` of uint8 tensors with the shapes above.
    """
    image_shape = [512, 512, 3]
    mask_shape = [512, 512]
    fixed_image = tf.cast(tf.reshape(image, image_shape), 'uint8')
    fixed_mask = tf.cast(tf.reshape(mask, mask_shape), 'uint8')
    return fixed_image, fixed_mask

In [6]:
# Input paths: every entry two levels below train_Basic.Aug, in sorted order.
# Each path is later opened with tf.data.Dataset.load.
load = sorted(glob.glob(f'{root}/train_Basic.Aug/*/*'))
# Output root for the augmented datasets; per-instance folders are created
# underneath it.
save = f'{root}/train_Color.Basic.Aug_x{repeats}_sharded'

In [7]:
def _with_color_aug(dataset):
    """Append the color-augmentation step (plus shape restore) to ``dataset``."""
    return (
        dataset
        .map(lambda i, m: tf.numpy_function(color_aug, [i, m], [tf.uint8, tf.uint8]),
             num_parallel_calls=tf.data.AUTOTUNE)
        # tf.numpy_function drops the static shape, so pin it again.
        .map(reshape, num_parallel_calls=tf.data.AUTOTUNE)
    )


def _color_aug_is_usable(image, mask, low=65, high=210):
    """Return True when a trial color augmentation of the pair looks sane.

    The pair is run once through ``color_aug``. The result is rejected when
    the augmentation raises, or when the mean intensity of the augmented
    image lies outside ``[low, high]`` (washed-out or too dark).
    The trial is a separate call to ``color_aug`` from the ones that produce
    the saved data, so this is a heuristic screen only.
    """
    probe = _with_color_aug(
        tf.data.Dataset.from_tensors((image, mask))
        .map(reshape, num_parallel_calls=tf.data.AUTOTUNE)
    )
    try:
        augmented, _ = next(iter(probe))
        # reduce_mean keeps the input dtype, so a uint8 input would be averaged
        # in integer arithmetic; cast first to get the real mean.
        mean = float(tf.math.reduce_mean(tf.cast(augmented, tf.float32)))
    except Exception as err:
        # Deliberate fallback: a failed color augmentation means "basic
        # augmentation only". KeyboardInterrupt/SystemExit are not swallowed.
        print(f'color_aug probe failed ({type(err).__name__}); using basic augmentation only')
        return False
    return low <= mean <= high


def _augment_pipeline(dataset, use_color):
    """Build the batched pipeline: reshape, optional color aug, geometric aug."""
    dataset = dataset.map(reshape, num_parallel_calls=tf.data.AUTOTUNE)
    if use_color:
        dataset = _with_color_aug(dataset)
    return dataset.map(aug, num_parallel_calls=tf.data.AUTOTUNE).batch(1)


# For every input shard: give each image/mask pair its own dataset, repeat it
# `repeats` times, augment it, and save it under `save`.
for s in load:
    ID = s.split('/')[-2]
    num_shard = s.split('/')[-1]
    block = tf.data.Dataset.load(s)
    block_length = len(block)

    # The block counts as done when the last instance's last shard exists.
    if glob.glob(f'{save}/{ID}/{num_shard}_instance.{block_length-1}/shard_{repeats-1}'):
        continue

    for instance_num, (image, mask) in enumerate(block):
        shard_save = f'{save}/{ID}/{num_shard}_instance.{instance_num}'
        # Skip instances whose final shard has already been written.
        if glob.glob(f'{shard_save}/shard_{repeats-1}/*/*/*.snapshot'):
            continue

        shard = tf.data.Dataset.from_tensors((image, mask)).repeat(repeats)
        use_color = _color_aug_is_usable(image, mask)
        shard_dataset(_augment_pipeline(shard, use_color), shard_save)

    # The changelog (v0.1) describes the restart as a way to refresh memory
    # each loop; presumably the notebook is re-run afterwards and the checks
    # above skip finished work -- confirm against restart_kernel.
    restart_kernel()
    time.sleep(10)