In [1]:
import tensorflow as tf
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
from transunet import TransUNet

In [2]:
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
    
def fix_gpu():
    config = ConfigProto()
    config.gpu_options.allow_growth = True
    session = InteractiveSession(config=config)
    
fix_gpu()

In [3]:
from tensorflow import keras
from utils import *

In [4]:
filter_num = [32, 64, 128, 256]

In [6]:
model = models.transunet_2d((256, 256, 1), filter_num=filter_num, n_labels=1, stack_num_down=2, stack_num_up=2,
                        embed_dim=768, num_mlp=3072, num_heads=12, num_transformer=12,
                        activation='ReLU', mlp_activation='GELU', output_activation='Sigmoid', 
                        batch_norm=True, pool=True, unpool='bilinear', name='transunet')

In [2]:
# model.summary()

In [5]:
def get_dataset(dataset_type, batch_size, image_gen_shape, augment, blacklist=None):
  
    # Prepare data generator
    if dataset_type == 'ge':    
        raw_dir = "../dataUSGthyroid/GE_processed"
        blacklist = ['2061093', '1909699', '2026988', '2052396', '2051655', 
                     '2056390', '176349', '1645263', '2060219', '65544']
    elif dataset_type == 'samsung':
        raw_dir = "../dataUSGthyroid/samsung_processed"
        blacklist = ['2089146', '2110713', '2096868', '441540', '2090807', 
                     '2090038', '2091948', '935892', '2058398', '2096659']
    else:
        raise Exception('Sorry, there is no dataset type called: ' + dataset_type)
        
    raw_masks = raw_dir + "/masks"
    raw_images = raw_dir + "/images"
    raw_images_paths = sorted(glob.glob(raw_images + '**/*', recursive=True))
    raw_masks_paths = sorted(glob.glob(raw_masks + '**/*', recursive=True))
    
    if blacklist:
        raw_images_paths = _filter_paths(raw_images_paths, blacklist)
        raw_masks_paths = _filter_paths(raw_masks_paths, blacklist)
    
    TEST_LEN = 16
    VAL_LEN = 10*16
    TRAIN_LEN = len(raw_images_paths) - VAL_LEN - TEST_LEN
        
    train_images = raw_images_paths[:TRAIN_LEN]
    validation_images = raw_images_paths[-(VAL_LEN+TEST_LEN):-TEST_LEN]
    test_images =  raw_images_paths[-TEST_LEN:]

    train_masks = raw_masks_paths[:TRAIN_LEN]
    validation_masks = raw_masks_paths[-(VAL_LEN+TEST_LEN):-TEST_LEN]
    test_masks =  raw_masks_paths[-TEST_LEN:]

    train_gen = UltraSoundImages(batch_size, train_images, train_masks, size=image_gen_shape, dataset_type=dataset_type, augment=augment)
    val_gen = UltraSoundImages(batch_size, validation_images, validation_masks, size=image_gen_shape, dataset_type=dataset_type, augment=augment)
    test_gen = UltraSoundImages(batch_size, test_images, test_masks, size=image_gen_shape, dataset_type=dataset_type, augment=augment)
    
    return train_gen, val_gen, test_gen

def _filter_paths(paths, blacklist):
    new_paths = []
    for path in paths:
        file_id = path.split('/')[-1].split('_')[0]
        if file_id not in blacklist:
            new_paths.append(path)
    return new_paths

In [7]:
def get_dataset_2(dataset_type='samsung', size=(256,256)):
    if dataset_type=='samsung':
        dataset_path = "../dataUSGthyroid/samsung_processed"
    elif dataset_type=='ge':
        dataset_path = "../dataUSGthyroid/GE_processed"
    else:
        return None
    raw_images = dataset_path + "/images"
    raw_masks = dataset_path + "/masks"

    raw_images_paths = sorted(glob.glob(raw_images + '**/*', recursive=True))
    raw_masks_paths = sorted(glob.glob(raw_masks + '**/*', recursive=True))

    TEST_LEN = 16
    VAL_LEN = 4*16
    # TEST_LEN = 5
    # VAL_LEN = 40
    TRAIN_LEN = len(raw_images_paths) - VAL_LEN - TEST_LEN
    batch_size = 4


    train_images = raw_images_paths[:TRAIN_LEN]
    validation_images = raw_images_paths[-(VAL_LEN+TEST_LEN):-TEST_LEN]
    test_images =  raw_images_paths[-TEST_LEN:]

    train_masks = raw_masks_paths[:TRAIN_LEN]
    validation_masks = raw_masks_paths[-(VAL_LEN+TEST_LEN):-TEST_LEN]
    test_masks =  raw_masks_paths[-TEST_LEN:]

    train_gen = UltraSoundImages(batch_size, train_images, train_masks, size=size, augment=False)
    val_gen = UltraSoundImages(batch_size, validation_images, validation_masks, dataset_type=dataset_type, size=size, augment=True)
    test_gen = UltraSoundImages(batch_size, test_images, test_masks, dataset_type=dataset_type, size=size, augment=True)
    return train_gen, val_gen, test_gen

In [6]:
model_type = 'transunet_2d'
image_input_shape = (384, 384, 1)
image_gen_shape = image_input_shape
loss_type = 'custom_focal_tversky'
dataset_type = 'samsung'
batch_size = 4
augment = True
epochs = 10

# train_gen, val_gen, test_gen = get_dataset_2(dataset_type, (256, 256, 1))

In [11]:
images, masks = train_gen.__getitem__(0)
images.shape

(4, 256, 256, 1)

In [7]:
train_gen, val_gen, test_gen = get_dataset(dataset_type, 4, (512,512), True)

Loading images from NRRD format and resizing
Finished loading
Loading images from NRRD format and resizing
Finished loading
Loading images from NRRD format and resizing
Finished loading


In [15]:
images, masks = train_gen.__getitem__(0)
images.shape

(16, 512, 512, 1)

In [8]:
from keras_unet_collection import models, losses

In [8]:
names = []
for name, dev_type in gpus:
    names.append(name)

In [9]:
def custom_focal_tversky(y_true, y_pred, alpha=0.7, gamma=4/3):
    return losses.focal_tversky(y_true, y_pred, alpha=alpha, gamma=gamma)

custom_objects = {'custom_focal_tversky': custom_focal_tversky}

In [8]:
# run_opts = tf.RunOptions(report_tensor_allocations_upon_oom = True)

In [19]:
# tf.contrib.keras.backend.clear_session()

In [11]:
# gpus = tf.config.list_logical_devices('GPU')
strategy = tf.distribute.MirroredStrategy(tf.config.list_logical_devices('GPU'))
with strategy.scope():
    # fix_gpu()
    # Metrics
    # mean_iou = tf.keras.metrics.MeanIoU(num_classes=2)
    metrics = ['accuracy']
    
    # model = get_model(model_type, image_input_shape, filter_num)
    model = models.transunet_2d((512, 512, 1), filter_num=filter_num, n_labels=1, stack_num_down=2, stack_num_up=2,
                            embed_dim=768, num_mlp=3072, num_heads=12, num_transformer=12,
                            activation='ReLU', mlp_activation='GELU', output_activation='Sigmoid', backbone='ResNet50')
    
    # model = tf.keras.models.load_model('output/transunet_2d/samsung/size_256_model_1/model/model', custom_objects=custom_objects)
    # print(model.summary())
    # model.compile(optimizer=keras.optimizers.SGD(learning_rate=1e-2), loss='binary_crossentropy', metrics=['accuracy'])
    # model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    loss_func = get_loss_function(loss_type=loss_type)
    model.compile(optimizer=keras.optimizers.SGD(learning_rate=5e-2), loss=loss_func, metrics=metrics)
    # model.compile(optimizer=keras.optimizers.Adam(), loss=loss_func, metrics=metrics)
    history = model.fit(train_gen, validation_data=val_gen, epochs=epochs)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2')


ValueError: Cannot assign to variable conv1_conv/kernel:0 due to variable shape (7, 7, 1, 64) and value shape (64, 3, 7, 7) are incompatible

In [29]:
model.summary()

Model: "transunet_model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_13 (InputLayer)           [(None, 256, 256, 1) 0                                            
__________________________________________________________________________________________________
transunet_down0_0 (Conv2D)      (None, 256, 256, 64) 576         input_13[0][0]                   
__________________________________________________________________________________________________
transunet_down0_0_bn (BatchNorm (None, 256, 256, 64) 256         transunet_down0_0[0][0]          
__________________________________________________________________________________________________
transunet_down0_0_activation (R (None, 256, 256, 64) 0           transunet_down0_0_bn[0][0]       
____________________________________________________________________________________

In [27]:
images, masks = val_gen.__getitem__(0, augment=True)

In [28]:
images.shape

(2, 256, 256, 1)

In [30]:
# model.summary()