CHANGELOG

v0.3
- individualize each repeat shard

In [1]:
from jarvis.utils.general import gpus
gpus.autoselect()

import glob, numpy as np, tensorflow as tf, skimage.io, os, time, cv2
from tensorflow.keras import layers
from histomicstk.preprocessing.augmentation import rgb_perturb_stain_concentration
from IPython.display import HTML, Javascript, display

import sys  
sys.path.append('/home/jjlou/Jerry/jerry_packages')
from jerry_utils import restart_kernel, load_dataset, shard_dataset

[ 2023-12-02 21:38:49 ] CUDA_VISIBLE_DEVICES automatically set to: 3           


In [2]:
# Root folder of the dataset; training shards are read from `{root}/train`
# and the augmented copies are written next to them.
root = (
    '/home/jjlou/Jerry/wsi-arterio/'
    'arteriosclerotic_vessel_detection_and_fine_segmentation/'
    'Vessel_WallsLumen_Segmentation/data_test_v2'
)
# Number of passes over each source shard through the random augmentation pipeline.
repeats = 12

In [3]:
# Joint geometric augmentation of an image and its mask via Keras preprocessing
# layers: random translation, random rotation and random flip.
def aug(images, mask, interpolation='nearest'):
    """Apply one identical random translation/rotation/flip to an image and its mask.

    The image and mask are concatenated along the last axis and passed through a
    single ``tf.keras.Sequential`` of random layers, so both receive exactly the
    same spatial transform.

    Args:
        images: channels-last image tensor whose first 3 channels are the image
            (in this notebook it arrives as [512, 512, 3] after ``reshape``).
        mask: rank-2 label tensor [H, W]; a channel axis is appended here.
        interpolation: resampling mode for RandomTranslation / RandomRotation.
            Defaults to 'nearest' so mask pixels keep their original label values.
            Bilinear resampling (the Keras default) blends neighbouring labels at
            class boundaries, and the final uint8 cast turns those blends into
            spurious labels. Pass 'bilinear' to reproduce the previous behaviour.

    Returns:
        (image, mask) as uint8 tensors: image [H, W, 3], mask [H, W, 1].
    """
    augment = tf.keras.Sequential([
        layers.RandomTranslation(0.05, 0.05, interpolation=interpolation),
        # factor 1 => rotation angle drawn from [-2*pi, 2*pi] (fraction of 2*pi).
        layers.RandomRotation(1, interpolation=interpolation),
        layers.RandomFlip(),
    ])

    images = tf.cast(images, 'uint8')
    mask = tf.cast(mask, 'uint8')

    # One mask channel is enough: the joint tensor is [H, W, 3 + 1].
    images_mask = tf.concat([images, mask[..., tf.newaxis]], -1)
    images_mask = augment(images_mask)

    image = images_mask[:, :, :3]
    mask = images_mask[:, :, 3:4]  # slice keeps the channel axis -> [H, W, 1]

    return tf.cast(image, 'uint8'), tf.cast(mask, 'uint8')

In [4]:
def aug_rgb2gray(image, mask):
    """Convert an image to 3-channel grayscale and pass the mask through.

    The pipeline cell calls this through ``tf.numpy_function``, so it receives
    NumPy arrays and must return NumPy arrays (not TF tensors).

    Args:
        image: array with 512*512*3 elements; reshaped to (512, 512, 3).
        mask: label array; returned unchanged apart from the uint8 cast.

    Returns:
        (image, mask) as uint8 NumPy arrays. image is (512, 512, 3) with the
        single gray channel repeated in all three channels.
    """
    image = np.reshape(image, (512, 512, 3))
    # NOTE(review): COLOR_BGR2GRAY treats channel 0 as blue. If the stored images
    # are RGB (e.g. read with skimage.io), COLOR_RGB2GRAY weights R and B
    # correctly — confirm against the code that wrote the source shards.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # Build outputs with NumPy: tf.numpy_function's contract is NumPy-in/NumPy-out.
    gray3 = np.stack([gray, gray, gray], axis=-1)
    return gray3.astype(np.uint8), np.asarray(mask).astype(np.uint8)

In [5]:
# tf.numpy_function outputs carry an unknown static shape, so restore it here
# before the Keras augmentation step.
def reshape(image, mask):
    """Pin the static shape and dtype of the numpy_function outputs.

    Returns:
        (image, mask): image reshaped to [512, 512, 3], mask to [512, 512],
        both cast to uint8.
    """
    fixed_image = tf.reshape(image, (512, 512, 3))
    fixed_mask = tf.reshape(mask, (512, 512))
    return tf.cast(fixed_image, 'uint8'), tf.cast(fixed_mask, 'uint8')

In [6]:
# Build the augmented training set. Each source shard is repeated `repeats` times
# through grayscale conversion and joint geometric augmentation. The result is
# saved under `train_save`, mirroring the `{folder}/{ID}` layout of `train_load`.
train_load = f'{root}/train'
train_save = f'{root}/train_Gray.Color.BasicMorph.Aug_x{repeats}'
train_shards_load = sorted(glob.glob(f'{train_load}/*/*'))

# Completion is decided per shard, not by counting output directories. An output
# directory left behind by an interrupted save must still be detected, so it is
# not skipped as "done".
for s in train_shards_load:
    folder, ID = s.split('/')[-2:]
    shard_save = f'{train_save}/{folder}/{ID}'
    # A snapshot inside shard_{repeats-1} marks this shard as finished.
    # Presumably shard_dataset writes shard_0 .. shard_{repeats-1}; TODO confirm in jerry_utils.
    if glob.glob(f'{shard_save}/shard_{repeats-1}/*/*/*.snapshot'):
        continue
    # Leftover *.shard files without the final snapshot are treated as a broken save.
    assert not glob.glob(f'{shard_save}/*/*.shard'), f'{shard_save} did not save properly'
    shard = tf.data.Dataset.load(s)
    assert len(shard) > 0, f'{s} has 0 elements'
    shard = (
        shard.shuffle(len(shard), reshuffle_each_iteration=True)
        .repeat(repeats)
        .map(lambda i, m: tf.numpy_function(aug_rgb2gray, [i, m], [tf.uint8, tf.uint8]),
             num_parallel_calls=tf.data.AUTOTUNE)
        # numpy_function drops static shape information; `reshape` restores it.
        .map(reshape, num_parallel_calls=tf.data.AUTOTUNE)
        .map(aug, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(1)
    )
    shard_dataset(shard, shard_save)
    # Restart the kernel after each saved shard (presumably to release memory).
    # The sleep gives the restart request time to take effect.
    restart_kernel()
    time.sleep(10)