## IMPORTS

In [1]:
import tensorflow as tf
import os
from glob import glob
!pip install -q tensorflow-io
import tensorflow_io as tfio


from matplotlib import pyplot as plt

from IPython import display
import imageio

import numpy as np
import math

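## HYPERPARAMETERS

**TODO:** this section has no code cell, yet later cells use `BUFFER_SIZE`, `BATCH_SIZE`, `save_figure_path`, `checkpoint_dir`, `log_dir`, `encoder_path`, `decoder_path`, `autoencoder_path`, `base_model_path` and `classifier_path`. Define them here so the notebook survives *Restart Kernel → Run All*.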

## Read dataset from train and validation tfrecords

In [2]:
# Paths of the TFRecord files.
# DATA_DIR is the folder that holds the records and RECORD_SUFFIX is the
# platform tag embedded in the file names. Switching machines means changing
# these two values instead of editing four full paths.
DATA_DIR = "E:\\dataset\\Leisang\\myTry\\BleedingDataDCM"
RECORD_SUFFIX = "win"

# for ubuntu
# DATA_DIR = "/media/ytx/Japan_Deep_Data/dataset/LeiSang/myTry/BleedingDataDCM"
# RECORD_SUFFIX = "unbuntu"  # spelling matches the existing file names


def _tfrecord_path(split, polarity):
    """Return the path of `<split>_<polarity>_samples_<RECORD_SUFFIX>.tfrecords` inside DATA_DIR."""
    return "{}/{}_{}_samples_{}.tfrecords".format(DATA_DIR, split, polarity, RECORD_SUFFIX)


# for train
train_positive = _tfrecord_path("train", "positive")
train_negative = _tfrecord_path("train", "negative")

# for val
val_positive = _tfrecord_path("val", "positive")
val_negative = _tfrecord_path("val", "negative")

# Open one dataset of serialized records per TFRecord file
# (train/val x positive/negative); decoding happens later in `.map`.
(train_positive_dataset,
 train_negative_dataset,
 val_positive_dataset,
 val_negative_dataset) = [
    tf.data.TFRecordDataset(record_file)
    for record_file in (train_positive, train_negative, val_positive, val_negative)
]



## PREPROCESSING

In [3]:
IMG_HEIGHT = 256
IMG_WIDTH = 256

def resize(input, target):
    """Resize an image and its mask to IMG_HEIGHT x IMG_WIDTH with nearest-neighbour sampling.

    Both results are reshaped to [IMG_HEIGHT, IMG_WIDTH, 1], so each argument
    must hold exactly one single-channel image (any leading dimension of size 1
    is dropped by the reshape).

    Returns:
        (resized_input, resized_target)
    """
    # Shape probes; inside a tf.function they only print while tracing.
    print(input.shape)
    print(target.shape)

    out_size = [IMG_HEIGHT, IMG_WIDTH]
    out_shape = [IMG_HEIGHT, IMG_WIDTH, 1]
    nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR

    scaled_input = tf.image.resize(input, out_size, method=nearest)
    scaled_target = tf.image.resize(target, out_size, method=nearest)
    return tf.reshape(scaled_input, out_shape), tf.reshape(scaled_target, out_shape)

# @tf.function()
def truncate(x, min, max):
    """Clamp every element of `x` to the closed interval [min, max]."""
    return tf.clip_by_value(x, min, max)
 
def norm(x, min, max):
    """Min-max scale `x`: (x - min) / (max - min).

    Values inside [min, max] map to [0, 1]; the function itself does not clip.
    """
    offset_from_min = tf.subtract(x, min)
    value_span = tf.subtract(max, min)
    return tf.math.divide(offset_from_min, value_span)

def linear_normalization(input, min=30720.0, max=34816.0):
    """Clip `input` to the window [min, max], then rescale that window to [0, 1]."""
    return norm(truncate(input, min, max), min, max)

In [4]:
# parser the dataset to decode the features
# Create a dictionary describing the features.
image_feature_description = {
    'dicom_path': tf.io.FixedLenFeature([], tf.string),
    'seg_label': tf.io.FixedLenFeature([], tf.string),
    'cls_label': tf.io.FixedLenFeature([], tf.int64),
}

@tf.function()
def _parse_image_function(example_proto):
    """Decode one serialized example into model-ready tensors.

    The record stores the path of a DICOM file, the raw uint8 bytes of a
    512x512 mask and an integer class label. The DICOM pixels are read from
    disk, and both image and mask are resized to IMG_HEIGHT x IMG_WIDTH. The
    image is window-normalised with `linear_normalization`; the mask is
    divided by 255.

    Returns:
        (dicom_path, input_image, seg_label, cls_label)
    """
    features = tf.io.parse_single_example(example_proto, image_feature_description)

    # DICOM image: path -> file bytes -> uint16 pixels -> float32
    dicom_path = features["dicom_path"]
    dicom_pixels = tfio.image.decode_dicom_image(tf.io.read_file(dicom_path), dtype=tf.uint16)
    input_image = tf.cast(dicom_pixels, tf.float32)

    # Mask: raw uint8 bytes -> float32 -> [-1, 512, 512, 1]
    mask_values = tf.cast(tf.io.decode_raw(features['seg_label'], tf.uint8), tf.float32)
    seg_label = tf.reshape(mask_values, [-1, 512, 512, 1])

    # preprocessing: resize both, then bring each into its normalised range
    input_image, seg_label = resize(input_image, seg_label)
    print(input_image.shape)  # trace-time shape probe
    print(seg_label.shape)    # trace-time shape probe

    input_image = linear_normalization(input_image)
    seg_label = seg_label / 255.0

    return dicom_path, input_image, seg_label, features["cls_label"]

# Hard-coded sample counts per split; the names suggest positive / negative
# record totals (TODO confirm against the record files).
train_nb_pos, train_nb_neg = 3035, 33588
val_nb_pos, val_nb_neg = 238, 4420

# Attach the decoder to every raw record dataset; parsing runs lazily on iteration.
parsed_train_positive_dataset = train_positive_dataset.map(_parse_image_function)
parsed_train_negative_dataset = train_negative_dataset.map(_parse_image_function)
parsed_val_positive_dataset = val_positive_dataset.map(_parse_image_function)
parsed_val_negative_dataset = val_negative_dataset.map(_parse_image_function)

(None, None, None, None)
(None, 512, 512, 1)
(256, 256, 1)
(256, 256, 1)


In [None]:
# # check the dataset
# def winwise(input,LB,HB):
#         # 20 ,380 for range (-32768, 32767)
#         # for tf input , (0, 65535)-? LB =  32788, 33148
#         input[input<LB] = LB # low boundary , if < LW , set to LW
#         input[input>HB] = HB # high boundary, if > Hw, Set to 255
#         return input

# Visual sanity check of two decoded positive training samples.
# The loop variables deliberately avoid the name `input`: at cell level that
# would replace the builtin `input()` for the rest of the kernel session.
for sample_path, sample_input, sample_target, sample_cls in parsed_train_positive_dataset.take(2):
    print("dicom_path:", sample_path)
    print("input_image.shape", sample_input.shape)
    print("seg_label", sample_target.shape)
    print("cls_label", sample_cls)
    # reshape label so it has exactly the image's shape
    sample_target = tf.reshape(sample_target, sample_input.shape)

    input_np = sample_input.numpy()
    target_np = sample_target.numpy()

    fig, axes = plt.subplots(1,2, figsize=(10,10))
    axes[0].imshow(np.squeeze(input_np), cmap='gray')
    axes[0].set_title('input range:[{}, {}]'.format(input_np.min(), input_np.max()))
    axes[1].imshow(np.squeeze(target_np), cmap='gray')
    axes[1].set_title('target range:[{}, {}]'.format(target_np.min(), target_np.max()))
    print(sample_path)

## DEFINE LOAD FUNCTION IN INPUT PIPELINE

## INPUT PIPELINE

In [None]:
# balance the dataset: draw from the positive and the negative stream with equal probability
train_dataset = (
    tf.data.experimental.sample_from_datasets(
        [parsed_train_positive_dataset, parsed_train_negative_dataset],
        weights=[0.5, 0.5])
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(2)
)

# validation: all positives followed by all negatives, then shuffled
test_dataset = (
    parsed_val_positive_dataset
    .concatenate(parsed_val_negative_dataset)
    .shuffle(BUFFER_SIZE)  # for the random check; comment this line out for a fixed order
    .batch(BATCH_SIZE)
    .prefetch(2)
)



# To use this dataset, you'll need the number of steps per epoch.

The definition of "epoch" in this case is less clear. Say it's the number of batches required to see each positive example once:

In [None]:
# Batches needed to see every positive example once when half of each batch is
# positive: ceil(2 * train_nb_pos / BATCH_SIZE). np.ceil returns a float.
resampled_steps_per_epoch = np.ceil(2.0*train_nb_pos/BATCH_SIZE)

In [None]:
# Sanity check: pull one batch from the validation pipeline, print its
# shapes/labels and show the first image next to its segmentation mask.
for Dicom_path, one_batch_input, one_batch_segLabel, one_batch_classLabel  in test_dataset.take(1):
    print("dicom_path:", Dicom_path[0])
    print("input_image.shape", one_batch_input.shape)
    print("seg_label", one_batch_segLabel.shape)
    print("cls_label", one_batch_classLabel)
    fig4, axes4 = plt.subplots(1,2, figsize=(10,10))
    # left: first (clipped + normalised) input of the batch
    axes4[0].imshow(np.squeeze(one_batch_input[0].numpy()), cmap='gray')
    axes4[0].set_title("cliped_norm_input")
    # right: its mask (divided by 255 during parsing)
    axes4[1].imshow(np.squeeze(one_batch_segLabel[0].numpy()), cmap='gray')
    axes4[1].set_title("cliped_norm_target")
    plt.show()

## DESIGN MODEL

In [None]:
# Channel count passed to the final Conv2DTranspose of the segmentation model.
OUTPUT_CHANNELS = 1
# Size of the encoder's Dense output, i.e. the autoencoder bottleneck.
latent_dim =50

In [None]:
# Model input components.
# The encoder consumes IMG_HEIGHT x IMG_WIDTH single-channel images (the width
# previously reused IMG_HEIGHT, which only worked because both are 256).
en_inputs = tf.keras.layers.Input(shape=[IMG_HEIGHT, IMG_WIDTH, 1])
de_inputs = tf.keras.layers.Input(shape=[latent_dim])

# Define encoder part ---->
# One stride-1 convolution followed by three stride-2 convolutions (each
# halves the spatial size), then a linear Dense projection to the latent vector.
x = en_inputs
x = tf.keras.layers.Conv2D(
    filters=16, kernel_size=3, strides=(1, 1), activation='relu', padding="same")(x)
x = tf.keras.layers.Conv2D(
    filters=32, kernel_size=3, strides=(2, 2), activation='relu', padding="same")(x)
x = tf.keras.layers.Conv2D(
    filters=64, kernel_size=3, strides=(2, 2), activation='relu', padding="same")(x)
x = tf.keras.layers.Conv2D(
    filters=128, kernel_size=3, strides=(2, 2), activation='relu', padding="same")(x)
x = tf.keras.layers.Flatten()(x)
# No activation on the latent vector
latent_v = tf.keras.layers.Dense(latent_dim)(x)
encoder = tf.keras.Model(inputs=en_inputs, outputs=latent_v, name='encoder')


# Define decoder part ---->
# The decoder starts from a feature map 8x smaller than the image (three
# stride-2 upsamplings follow); for 256x256 images this is (32, 32, 128),
# i.e. 131072 Dense units.
seed_shape = (IMG_HEIGHT // 8, IMG_WIDTH // 8, 128)
x = tf.keras.layers.Dense(
    units=seed_shape[0] * seed_shape[1] * seed_shape[2],
    activation=tf.nn.relu)(de_inputs)
x = tf.keras.layers.Reshape(target_shape=seed_shape)(x)
x = tf.keras.layers.Conv2DTranspose(
    filters=64, kernel_size=3, strides=(2, 2), padding="same", activation='relu')(x)
x = tf.keras.layers.Conv2DTranspose(
    filters=32, kernel_size=3, strides=(2, 2), padding="same", activation='relu')(x)
x = tf.keras.layers.Conv2DTranspose(
    filters=16, kernel_size=3, strides=(2, 2), padding="same", activation='relu')(x)

decoded = tf.keras.layers.Conv2DTranspose(
    filters=1, kernel_size=3, strides=(1, 1), padding="same", use_bias=True)(x)
# output with sigmoid so the reconstruction lies in [0, 1]
decoded = tf.sigmoid(decoded)
decoder = tf.keras.Model(inputs=de_inputs, outputs=decoded, name='decoder')


# Define AE model---->
# The autoencoder chains the encoder graph into the decoder model, so all
# three models share the same weights.
outputs = decoder(latent_v)
autoencoder = tf.keras.Model(inputs=en_inputs, outputs=outputs, name='AE')

# a sperately decoder is struggle leave for the moment

# # # create a placeholder for an encoded (32-dimensional) input
# encoded_input = tf.keras.layers.Input(shape=(latent_dim))
# # # retrieve the last layer of the autoencoder model
# decoder_layer = autoencoder.layers[-1]
# # # create the decoder model
# decoder = tf.keras.Model(encoded_input, decoder_layer)
             

In [None]:
# Draw the encoder graph (with tensor shapes) into the file `encoder_path`.
# NOTE(review): `encoder_path` is not assigned in the cells above - confirm where it is defined.
tf.keras.utils.plot_model(encoder, to_file=encoder_path, show_shapes=True, dpi=64)

In [None]:
# Draw the decoder graph (with tensor shapes) into the file `decoder_path`.
# NOTE(review): `decoder_path` is not assigned in the cells above - confirm where it is defined.
tf.keras.utils.plot_model(decoder, to_file=decoder_path, show_shapes=True, dpi=64)

In [None]:
# Draw the full autoencoder graph (with tensor shapes) into the file `autoencoder_path`.
# NOTE(review): `autoencoder_path` is not assigned in the cells above - confirm where it is defined.
tf.keras.utils.plot_model(autoencoder, to_file=autoencoder_path, show_shapes=True, dpi=64)

## OPTIMIZER AND OBJECTIVE LOSSES


In [None]:

# Loss objects shared by the autoencoder train/test steps.
BCE = tf.keras.losses.BinaryCrossentropy()
Huber = tf.keras.losses.Huber(delta=0.1)

@tf.function()
def compute_loss(decoded_x, x):
    """Reconstruction loss of the autoencoder.

    Args:
        decoded_x: model output (used as y_pred).
        x: reference image (used as y_true).

    Returns:
        (total_loss, BCE_loss, Huber_loss) with total_loss = Huber_loss + BCE_loss.
        Both Keras losses average rather than sum, which keeps the magnitude small.
    """
    bce_term = BCE(y_true=x, y_pred=decoded_x)
    huber_term = Huber(y_true=x, y_pred=decoded_x)
    return huber_term + bce_term, bce_term, huber_term
    
# appliy graidients  this is acutaully trianing step
# @tf.function()
# def compute_apply_gradients(model, x, optimizer, epoch):
#     decoded_x =  model(x)
#     with tf.GradientTape() as tape:
#         total_loss, CE_loss, L1_loss = compute_loss(decoded_x, x)
#     gradients = tape.gradient(total_loss, model.trainable_variables)
#     optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    
#     with summary_writer.as_default():
#         # write scalars to the tensorboard after each train step
#         tf.summary.scalar('total_loss', total_loss, step=epoch)
#         tf.summary.scalar('CE_loss', CE_loss, step=epoch)
#         tf.summary.scalar('L1_loss', L1_loss, step=epoch) 
@tf.function()
def train_step(model, input, target, epoch):
    """Run one optimisation step of `model` on a batch.

    The forward pass and loss are recorded on a gradient tape, then the
    module-level `optimizer` applies the gradients of the total loss to
    `model.trainable_variables`. `epoch` is accepted but not used.

    Returns:
        (total_loss, BCE_loss, Huber_loss) as produced by `compute_loss`.
    """
    print("in trianing step")  # executes while tracing only
    with tf.GradientTape() as tape:
        losses = compute_loss(model(input, training=True), target)
    weights = model.trainable_variables
    optimizer.apply_gradients(zip(tape.gradient(losses[0], weights), weights))
    return losses

@tf.function()
def test_step(model, input, target, epoch):
    """Evaluate `model` on one batch without updating it.

    The model runs in inference mode (`training=False`) and no gradient tape is
    opened, because nothing is differentiated here. `epoch` is accepted for
    signature parity with `train_step` but is not used.

    Returns:
        (total_loss, BCE_loss, Huber_loss) as produced by `compute_loss`.
    """
    decoded_img = model(input, training=False)
    total_loss, BCE_loss, Huber_loss = compute_loss(decoded_img, target)
    return total_loss, BCE_loss, Huber_loss
#     train_avg_loss(train_loss)
#     train_avg_metric(metric)     

In [None]:
# # check the AE OUTPUT
# print("inp.shape:", cliped_norm_input.shape)
# print("inp[tf.newaxis,...]:", cliped_norm_input[tf.newaxis,...].shape)
# AE_output = autoencoder(resized_input[tf.newaxis,...], training=False)  # inp is the image sample from cell code 6 ; 
# print(AE_output.shape)
# plt.imshow(np.squeeze(AE_output[0,...]), cmap="gray")

## TRAINING PREPARATION

In [None]:

# define the epoch image check
def generate_images(model, test_input, tar, path, total_loss, batch_idx, epoch, save=True):
    """Plot input, ground truth and prediction for the first sample of a batch.

    Two figures are produced: a three-panel overview titled with the sample's
    path, and a larger figure with the prediction only. With `save` set, both
    are written below `save_figure_path` (the second one into its
    `Predictions` sub-folder) before being shown.
    """
    prediction = model(test_input, training=True)
    write_files = save == True  # same equality test as before, evaluated once

    first_input, first_target = test_input[0], tar[0]
    predicted_img = np.squeeze(prediction[0])
    loss_caption = 'Pred. [B:{}/E:{}]: loss->'.format(batch_idx,epoch)+ str(total_loss)
    panels = (
        (np.squeeze(first_input),
         'Input range:[{},{}]'.format(first_input.numpy().min(), first_input.numpy().max())),
        (np.squeeze(first_target),
         'GT range:[{},{}]'.format(first_target.numpy().min(), first_target.numpy().max())),
        (predicted_img, loss_caption),
    )

    fig5 = plt.figure(figsize=(15,5))
    fig5.suptitle(str(path[0].numpy(), 'utf-8'))  # bytes path -> text
    for position, (image, caption) in enumerate(panels, start=1):
        ax = fig5.add_subplot(1, 3, position, title=caption)
        ax.imshow(image, cmap="gray")
        ax.axis('off')
    if write_files:
        fig5.savefig(save_figure_path + "/image_at_epoch_{:04d}_batch_{}.png".format(epoch,batch_idx))

    # stand-alone view of the prediction
    plt.figure(figsize=(8,8))
    plt.title(loss_caption)
    plt.imshow(predicted_img, cmap="gray")
    plt.axis('off')
    if write_files:
        plt.savefig(save_figure_path + "/Predictions/pred_only_at_epoch_{:04d}_batch_{}.png".format(epoch,batch_idx))
    plt.show()


# define the  encoder latent space check
def en_generate_images(model, test_input, tar, path, batch_idx, epoch, save=True):
    """Plot input, ground truth and the encoder output for the first sample of a batch.

    The encoder returns a flat latent vector, which `imshow` cannot draw
    because it needs at least 2-D data. The vector is therefore shown as a
    single-row image (outputs that are already 2-D are drawn unchanged).

    Two figures are produced: a three-panel overview titled with the sample's
    path, and a larger figure with the latent output only. With `save` set,
    both are written below `save_figure_path` (the second one into its
    `Predictions` sub-folder) before being shown.
    """
    prediction = model(test_input, training=True)

    latent_img = np.squeeze(prediction[0])
    # promote 0-D / 1-D outputs to a 2-D row so matplotlib accepts them
    while latent_img.ndim < 2:
        latent_img = latent_img[np.newaxis, ...]

    latent_caption = 'Pred. Latent Space {}'.format(prediction[0].shape)
    display_list = [np.squeeze(test_input[0]), np.squeeze(tar[0]), latent_img]
    title = ['Input range:[{},{}]'.format(test_input[0].numpy().min(), test_input[0].numpy().max()),
             'GT range:[{},{}]'.format(tar[0].numpy().min(), tar[0].numpy().max()),
             latent_caption]

    fig5 = plt.figure(figsize=(15,5))
    fig5.suptitle(str(path[0].numpy(), 'utf-8'))  # decode: convert bytes b'' to normal ''

    for i in range(3):
        ax = fig5.add_subplot(1, 3, i+1, title=title[i])
        if i == 2:
            # stretch the one-row latent strip so it fills the panel
            ax.imshow(display_list[i], cmap="gray", aspect="auto")
        else:
            ax.imshow(display_list[i], cmap="gray")
        ax.axis('off')

    if save:
        fig5.savefig(save_figure_path + "/image_at_epoch_{:04d}_batch_{}.png".format(epoch,batch_idx))

    # stand-alone view of the latent output
    plt.figure(figsize=(8,8))
    plt.title(latent_caption)
    plt.imshow(latent_img, cmap="gray", aspect="auto")
    plt.axis('off')
    if save:
        plt.savefig(save_figure_path + "/Predictions/pred_only_at_epoch_{:04d}_batch_{}.png".format(epoch,batch_idx))
    plt.show()


def en_de_generate_images(encoder,decoder, test_input, tar, path, batch_idx, epoch, save=True):
    """Plot input, ground truth and the encoder -> decoder reconstruction of a batch's first sample.

    The batch is encoded in inference mode and the latent result is decoded.
    Two figures are produced: a three-panel overview titled with the sample's
    path, and a larger figure with the reconstruction only. With `save` set,
    both are written below `save_figure_path` (the second one into its
    `Predictions` sub-folder) before being shown.
    """
    prediction = decoder(encoder(test_input, training=False))
    write_files = save == True  # same equality test as before, evaluated once

    first_input, first_target = test_input[0], tar[0]
    reconstructed_img = np.squeeze(prediction[0])
    shape_caption = 'en-to-de Pred. out {}'.format(prediction.shape)
    panels = (
        (np.squeeze(first_input),
         'Input range:[{},{}]'.format(first_input.numpy().min(), first_input.numpy().max())),
        (np.squeeze(first_target),
         'GT range:[{},{}]'.format(first_target.numpy().min(), first_target.numpy().max())),
        (reconstructed_img, shape_caption),
    )

    fig5 = plt.figure(figsize=(15,5))
    fig5.suptitle(str(path[0].numpy(), 'utf-8'))  # bytes path -> text
    for position, (image, caption) in enumerate(panels, start=1):
        ax = fig5.add_subplot(1, 3, position, title=caption)
        ax.imshow(image, cmap="gray")
        ax.axis('off')
    if write_files:
        fig5.savefig(save_figure_path + "/image_at_epoch_{:04d}_batch_{}.png".format(epoch,batch_idx))

    # stand-alone view of the reconstruction
    plt.figure(figsize=(8,8))
    plt.title(shape_caption)
    plt.imshow(reconstructed_img, cmap="gray")
    plt.axis('off')
    if write_files:
        plt.savefig(save_figure_path + "/Predictions/pred_only_at_epoch_{:04d}_batch_{}.png".format(epoch,batch_idx))
    plt.show()

## TRAINING

In [None]:
import datetime

epochs = 100
# define optimizer (used by `train_step` through the module-level name)
optimizer = tf.keras.optimizers.Adam(1e-4)
## get current working directory
cwd = os.getcwd()
print("current working directory:", cwd)

# Output folders for the figures; `exist_ok=True` makes the cell re-runnable
# without a separate existence check.
full_AE_saves = os.path.join(cwd, save_figure_path)
print(full_AE_saves)
os.makedirs(full_AE_saves, exist_ok=True)
predictions_save_path = os.path.join(full_AE_saves, "Predictions")
os.makedirs(predictions_save_path, exist_ok=True)

# define check points

# pre_saved_ckpt_path = os.path.join(checkpoint_dir, "ckpt")
# change to pre-saved model path:
# 8
# (every backslash is escaped; the former "\d" was an invalid escape sequence
# that only happened to produce the same string)
pre_saved_ckpt_path = "E:\\Projects\\logs\\dicoms\\fixedArranged\\fixed\\8Bleed_AE_Huber_BCE_Sigmoid_NormNarrow_ShuffleTr_RandWinTrTe_SaveBest\\training_checkpoints\\ckpt"
# for ubuntu
# pre_saved_ckpt_path = "/media/ytx/Japan_Deep_Data/DicomProject2020/logs/NewFramework/fixed/8Bleed_AE_Huber_BCE_Sigmoid_NormNarrow_ShuffleTr_RandWinTrTe_SaveBest/training_checkpoints/ckpt"
# the path for the new training
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

if not os.path.exists(pre_saved_ckpt_path):
    print("please check the pre-trained saved model path")

# contents of states to be saved as attributes on the checkpoint object
checkpoint_ob = tf.train.Checkpoint(step=tf.Variable(1),
                                    epoch=tf.Variable(1),
                                    optimizer=optimizer,
                                    encoder=encoder,
                                    decoder=decoder,
                                    autoencoder=autoencoder)
# define restore checkpoint manager (reads the pre-trained autoencoder)
restore_manager = tf.train.CheckpointManager(checkpoint_ob, pre_saved_ckpt_path, max_to_keep=1)

# # define checkpoint manager for new training
# manager =  tf.train.CheckpointManager(checkpoint_ob, checkpoint_prefix, max_to_keep=1)

datetime_rec = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

# TensorBoard writers; the sub-folder name is appended to `log_dir` by plain
# string concatenation, so `log_dir` has to end with a path separator.
train_summary_writer = tf.summary.create_file_writer(
    log_dir + "train")
val_summary_writer = tf.summary.create_file_writer(
    log_dir + "val")

In [None]:
# start tensorboard
# !kill 5032
%load_ext tensorboard
%tensorboard --logdir {log_dir}

In [None]:
# load pre-trained model
# `latest_checkpoint` is None when the directory holds no checkpoint, and
# restoring None does nothing - say so instead of silently continuing with
# untrained weights.
latest_pretrained_ckpt = restore_manager.latest_checkpoint
if latest_pretrained_ckpt is None:
    print("no checkpoint found in:", pre_saved_ckpt_path,
          "-> encoder / decoder / autoencoder keep their initial weights")
checkpoint_ob.restore(latest_pretrained_ckpt)

# counters stored alongside the weights
step = checkpoint_ob.step
epoch = checkpoint_ob.epoch
print(int(step))
print(int(epoch))

# short aliases for the models tracked by the checkpoint
AE = checkpoint_ob.autoencoder
en = checkpoint_ob.encoder
de = checkpoint_ob.decoder
print(AE)


## use the pre-saved encoder as the base model and design the new model
base_model = en
print(base_model.trainable)

In [None]:
base_model = en  # use encoder as the base model

# check the loaded ae output
for TEST_dicom_path, TEST_input, TEST_target, TEST_cls_label in test_dataset.take(1):  # take one batch from the test set
    print("dicom_path:", TEST_dicom_path[0])
    print("input_image.shape", TEST_input.shape)
    print("seg_label", TEST_target.shape)
    print("cls_label", TEST_cls_label)
    test_total_loss, test_BCE_loss, test_Huber_loss = test_step(AE, TEST_input, TEST_target, epoch)
    generate_images(AE, TEST_input, TEST_target, TEST_dicom_path, test_total_loss.numpy(), int(step), int(epoch), save=False)

    # check the loaded encoder and decoder
    en_de_generate_images(base_model, de, TEST_input, TEST_target, TEST_dicom_path, int(step), int(epoch), save=False)

# check whether the base model is trainable
print(base_model.trainable)

# Single switch for fine-tuning (True) versus freezing (False) the pre-trained
# encoder. The flag is assigned directly: previously only the True branch
# existed, so setting the flag to False could never freeze the encoder.
trainable_or_not = True
base_model.trainable = trainable_or_not
print("after reset basemodel trainable property:", base_model.trainable)

In [None]:
# Inspect the pre-trained encoder: print its layer table and draw its graph
# into the file `base_model_path`.
# NOTE(review): `base_model_path` is not assigned in the cells above - confirm where it is defined.
base_model.summary()
tf.keras.utils.plot_model(base_model, to_file=base_model_path, show_shapes=True, dpi=128)

In [None]:
# Add the classification head (bleed present or not).
# A 2-unit Dense layer is attached to the output of the encoder's last layer
# and passed through softmax; the resulting model takes `en_inputs`, the
# encoder's input, so it reuses the encoder's layers and weights.
print("# of layers in the base_model:", len(base_model.layers))
class_head_nodes = tf.keras.layers.Dense(2)(base_model.layers[-1].output)
class_softmax_output =  tf.nn.softmax(class_head_nodes)
# both prints show the tensor the head is attached to
print(base_model.output)
print(base_model.layers[-1].output)
Classifier =  tf.keras.Model(inputs=en_inputs, outputs=class_softmax_output, name='classifier_head')
# check the architecture
Classifier.summary()
# NOTE(review): `classifier_path` is not assigned in the cells above - confirm where it is defined.
tf.keras.utils.plot_model(Classifier, to_file=classifier_path, show_shapes=True, dpi=128)

In [None]:
# Set up the segmentation part.
# First take a feature map out of the base model: with the encoder built as
# Input -> 4x Conv2D -> Flatten -> Dense, `layers[-3]` is the last Conv2D,
# i.e. the deepest feature map that is still spatial.
base_feature =  base_model.layers[-3].output
print(base_feature)


# define down_sample block as the pix-2-pix from tensorflow official set
def downsample(filters, size, apply_batchnorm=True):
    """Build a pix2pix-style downsampling block.

    The block is a `Sequential` of a stride-2 `Conv2D` (no bias, weights drawn
    from N(0, 0.02)), an optional `BatchNormalization` and a `LeakyReLU`.

    Args:
        filters: number of convolution filters.
        size: convolution kernel size.
        apply_batchnorm: insert batch normalisation after the convolution.
    """
    block = tf.keras.Sequential()
    block.add(tf.keras.layers.Conv2D(
        filters, size, strides=2, padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        use_bias=False))
    if apply_batchnorm:
        block.add(tf.keras.layers.BatchNormalization())
    block.add(tf.keras.layers.LeakyReLU())
    return block

def upsample(filters, size, apply_dropout=False):
    """Build a pix2pix-style upsampling block.

    The block is a `Sequential` of a stride-2 `Conv2DTranspose` (no bias,
    weights drawn from N(0, 0.02)), `BatchNormalization`, an optional
    `Dropout(0.5)` and a `ReLU`.

    Args:
        filters: number of transposed-convolution filters.
        size: kernel size.
        apply_dropout: insert dropout after the batch normalisation.
    """
    block = tf.keras.Sequential()
    block.add(tf.keras.layers.Conv2DTranspose(
        filters, size, strides=2, padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        use_bias=False))
    block.add(tf.keras.layers.BatchNormalization())
    if apply_dropout:
        block.add(tf.keras.layers.Dropout(0.5))
    block.add(tf.keras.layers.ReLU())
    return block


# Feature maps kept for the skip connections of the decoder; filled by en_for_seg().
skips = []

def en_for_seg():
    """Build the second, segmentation-specific encoder on top of `en_inputs`.

    Three downsampling blocks with 64, 128 and 256 filters are chained (the
    first one without batch normalisation). The output of every block is
    appended to the module-level `skips` list.

    Returns:
        A Keras model mapping `en_inputs` to the output of the last block.
    """
    x = en_inputs
    for filters, use_batchnorm in ((64, False), (128, True), (256, True)):
        x = downsample(filters, 4, apply_batchnorm=use_batchnorm)(x)
        skips.append(x)
    # a Model object is returned so the branch can be inspected / trained on its own
    return tf.keras.Model(inputs=en_inputs, outputs=x)

en_coder_for_seg_out = en_for_seg()
print(en_coder_for_seg_out.output)
tf.keras.utils.plot_model(en_coder_for_seg_out, to_file="encodre_for_seg.png", show_shapes=True, dpi=128)





# Shapes in the comments below assume 256x256 inputs.

# concatenate the feature from the base model and the feature from the new encoder
concate_features =  tf.keras.layers.Concatenate()([base_feature, en_coder_for_seg_out.output])
print(concate_features)

# one more encoding step
further_en = downsample(512, 4)(concate_features)  # output (None, 16, 16, 512)
print("one more step encoding:",further_en)

# global pooling of the deepest features
GA  = tf.keras.layers.GlobalAveragePooling2D()(further_en)  # squeeze to (None, 512)
print("GA:", GA)
# reshape GA to image format (None, 1, 1, C)
RESHAPE_GA =  tf.reshape(GA, [-1,1,1, GA.shape[-1]])
print("RESHAPE_GA:", RESHAPE_GA)
# resize to the spatial size of the feature maps
resize_GA =  tf.image.resize(RESHAPE_GA, (further_en.shape[1], further_en.shape[2]))
print("resized GA:", resize_GA)
# concatenate further_en and resize_GA
bottom_concate =  tf.keras.layers.Concatenate()([resize_GA, further_en])
print("bottom_concate:", bottom_concate)

print()
print("start to upsample--------->")
# Decoder: every upsample block doubles the spatial size (dropout is off by
# default) and is followed by a concatenation with the matching skip feature.
up_feature1 =  upsample(512, 4) (bottom_concate)  # (None, 32, 32, 512)
print("up_feature1:", up_feature1)
up_concate1 =  tf.keras.layers.Concatenate()([up_feature1, skips[-1]])
print("up_concate1:", up_concate1)
up_feature2 =  upsample(256, 4) (up_concate1)  # (None, 64, 64, 256)
print("up_feature2:", up_feature2)
up_concate2 =  tf.keras.layers.Concatenate()([up_feature2, skips[-2]])
print("up_concate2:", up_concate2)
up_feature3 =  upsample(128, 4) (up_concate2)  # (None, 128, 128, 128)
print("up_feature3:", up_feature3)  # previously printed up_feature2 by mistake
up_concate3 =  tf.keras.layers.Concatenate()([up_feature3, skips[-3]])
print("up_concate3:", up_concate3)

# last layer output
initializer = tf.random_normal_initializer(0., 0.02)
last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                         strides=2, # stride 2
                                         padding='same',
                                         kernel_initializer=initializer,
                                         activation='sigmoid') # (bs, 256, 256, OUTPUT_CHANNELS)
seg_out  =  last(up_concate3)

# Two-headed model: segmentation map plus the class probabilities of the classifier head.
Final_seg_model = tf.keras.Model(inputs=en_inputs, outputs=[seg_out, class_softmax_output])
tf.keras.utils.plot_model(Final_seg_model, to_file="new_model.png", show_shapes=True, dpi=128)

In [None]:
# check new model
# check the AE OUTPUT
# print("inp.shape:", cliped_norm_input.shape)
# print("inp[tf.newaxis,...]:", cliped_norm_input[tf.newaxis,...].shape)
# AE_output = autoencoder(resized_input[tf.newaxis,...], training=False)  # inp is the image sample from cell code 6 ; 
# print(AE_output.shape)
# plt.imshow(np.squeeze(AE_output[0,...]), cmap="gray")

# Smoke test of the two-headed model: run one validation batch through it,
# print the shapes of both outputs and show the first predicted mask.
for TEST_dicom_path, TEST_input, TEST_target, TEST_cls_label in test_dataset.take(1):  # take one batch from the test set
    print("inp[tf.newaxis,...]:", TEST_input.shape)
    print("dicom_path:", TEST_dicom_path[0])
    print("input_image.shape", TEST_input.shape)
    print("seg_label", TEST_target.shape)
    print("cls_label", TEST_cls_label)
    
    new_model_output =  Final_seg_model(TEST_input)
    # model outputs are [seg_out, class_softmax_output]
    output1=   new_model_output[0]
    output2=   new_model_output[1]
    print("output1 shape:", output1.shape)  
    print("output2 shape:", output2.shape)  
#     test_total_loss, test_BCE_loss, test_Huber_loss = test_step(AE, TEST_input, TEST_target, epoch) 
#     generate_images(AE, TEST_input, TEST_target,TEST_dicom_path,test_total_loss.numpy(), int(step), int(epoch), save=False)
    # first segmentation map of the batch
    plt.imshow(np.squeeze(output1[0,...]), cmap="gray")
#     # check the loaded encoder and decoder
#     en_de_generate_images(en, de,TEST_input, TEST_target,TEST_dicom_path, int(step), int(epoch), save=False)

In [None]:
# design new loss for the new model as new_train_step, new_test_step
# Optimizer and loss objects of the two-headed (segmentation + classification) model.
new_optimizer = tf.keras.optimizers.Adam(1e-4)
BCE = tf.keras.losses.BinaryCrossentropy()
Huber = tf.keras.losses.Huber(delta=0.1)

@tf.function()
def new_compute_loss(pred_seg, gt_seg, pred_cls, gt_cls):
    """Combined loss of the segmentation and the classification head.

    Args:
        pred_seg / gt_seg: predicted and reference segmentation maps.
        pred_cls / gt_cls: predicted class probabilities and reference labels.

    Returns:
        (total_loss, cls_BCE_loss, seg_total_loss) where
        seg_total_loss = Huber(seg) + BCE(seg) and
        total_loss = cls_BCE_loss + seg_total_loss.
    """
    seg_total_loss = (Huber(y_true=gt_seg, y_pred=pred_seg)
                      + BCE(y_true=gt_seg, y_pred=pred_seg))
    cls_BCE_loss = BCE(y_true=gt_cls, y_pred=pred_cls)
    return cls_BCE_loss + seg_total_loss, cls_BCE_loss, seg_total_loss
    
# appliy graidients  this is acutaully trianing step
# @tf.function()
# def compute_apply_gradients(model, x, optimizer, epoch):
#     decoded_x =  model(x)
#     with tf.GradientTape() as tape:
#         total_loss, CE_loss, L1_loss = compute_loss(decoded_x, x)
#     gradients = tape.gradient(total_loss, model.trainable_variables)
#     optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    
#     with summary_writer.as_default():
#         # write scalars to the tensorboard after each train step
#         tf.summary.scalar('total_loss', total_loss, step=epoch)
#         tf.summary.scalar('CE_loss', CE_loss, step=epoch)
#         tf.summary.scalar('L1_loss', L1_loss, step=epoch) 
def mask_to_categorical(cls_tar, num_cls=2):
    """One-hot encode integer class labels into float32 vectors of length `num_cls`."""
    class_indices = tf.cast(cls_tar, tf.int32)
    return tf.cast(tf.one_hot(class_indices, num_cls), tf.float32)

@tf.function()
def _new_train_step_graph(model, input, seg_target, cls_target):
    """Traced body of `new_train_step`; it only receives the model and tensors."""
    print(cls_target)  # trace-time probe of the label tensor
    # change cls label 0 or 1 into [1, 0] / [0, 1]
    cls_target = mask_to_categorical(cls_target)
    with tf.GradientTape() as tape:
        # model outputs are [seg_out, class_softmax_output]
        seg_out, class_softmax_output = model(input, training=True)
        total_loss, cls_BCE_loss, seg_total_loss = new_compute_loss(pred_seg=seg_out,
                                                                    gt_seg=seg_target,
                                                                    pred_cls=class_softmax_output,
                                                                    gt_cls=cls_target)
    gradients = tape.gradient(total_loss, model.trainable_variables)
    new_optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return total_loss, cls_BCE_loss, seg_total_loss


def new_train_step(model, input, seg_target, cls_target, epoch):
    """Run one optimisation step of the two-headed model on a batch.

    Args:
        model: model returning (segmentation map, class probabilities).
        input: batch of input images.
        seg_target: batch of reference segmentation maps.
        cls_target: batch of integer class labels (one-hot encoded internally).
        epoch: accepted for call compatibility; not used.

    Returns:
        (total_loss, cls_BCE_loss, seg_total_loss)

    `epoch` is deliberately not forwarded into the traced function: a Python
    integer argument is treated as a constant by `tf.function`, so passing the
    loop counter would build a new graph for every epoch.
    """
    return _new_train_step_graph(model, input, seg_target, cls_target)

@tf.function()
def _new_test_step_graph(model, input, seg_target, cls_target):
    """Traced body of `new_test_step`; it only receives the model and tensors."""
    # Integer labels of shape (batch,) are one-hot encoded exactly as in
    # new_train_step; labels that already carry a class axis are left untouched.
    if cls_target.shape.rank == 1:
        cls_target = mask_to_categorical(cls_target)
    seg_out, class_softmax_output = model(input, training=False)
    total_loss, cls_BCE_loss, seg_total_loss = new_compute_loss(pred_seg=seg_out,
                                                                gt_seg=seg_target,
                                                                pred_cls=class_softmax_output,
                                                                gt_cls=cls_target)
    return total_loss, cls_BCE_loss, seg_total_loss


def new_test_step(model, input, seg_target, cls_target, epoch):
    """Evaluate the two-headed model on one batch without updating it.

    Args:
        model: model returning (segmentation map, class probabilities).
        input: batch of input images.
        seg_target: batch of reference segmentation maps.
        cls_target: integer class labels of shape (batch,), or labels that are
            already one-hot encoded.
        epoch: accepted for call compatibility; not used.

    Returns:
        (total_loss, cls_BCE_loss, seg_total_loss)

    The model runs in inference mode and no gradient tape is opened because
    nothing is differentiated. `epoch` is not forwarded into the traced
    function so that a Python loop counter cannot trigger a retrace per epoch.
    """
    return _new_test_step_graph(model, input, seg_target, cls_target)

In [None]:
# define the epoch image check
def new_generate_images(model, test_input, tar_seg, tar_cls, path, total_loss, batch_idx, epoch, save=True):
    """Plot input, reference mask and predicted mask for the first sample of a batch.

    The title of the third panel also reports the reference class (raw and
    one-hot) and the predicted class probabilities. Two figures are produced:
    a three-panel overview titled with the sample's path, and a larger figure
    with the predicted mask only. With `save` set, both are written below
    `save_figure_path` (the second one into its `Predictions` sub-folder)
    before being shown.
    """
    # NOTE(review): the model is called with training=True while new_test_step
    # uses training=False; with BatchNormalization layers in the model this
    # call uses batch statistics - confirm that this is intended.
    seg_out, class_sigmoid_output  = model(test_input, training=True)

    # The class count is fixed: without it, to_categorical infers the width
    # from the largest label in the batch, so an all-negative batch would be
    # encoded with a single column.
    one_hot_cls_tar = tf.keras.utils.to_categorical(tar_cls, num_classes=2)

    display_list = [np.squeeze(test_input[0]), np.squeeze(tar_seg[0]), np.squeeze(seg_out[0])]
    title = ['Input range:[{},{}]'.format(test_input[0].numpy().min(), test_input[0].numpy().max()),
             'GT range:[{},{}]'.format(tar_seg[0].numpy().min(), tar_seg[0].numpy().max()),
             '[B:{}/E:{}]: tar_cls:{}({})->pred_cls:{}'.format(batch_idx, epoch,
                                                               tar_cls[0],
                                                               one_hot_cls_tar[0],
                                                               class_sigmoid_output[0])]

    fig5 = plt.figure(figsize=(15,5))
    fig5.suptitle(str(path[0].numpy(), 'utf-8'))  # decode: convert bytes b'' to normal ''

    for i in range(3):
        ax = fig5.add_subplot(1, 3, i+1, title=title[i])
        ax.imshow(display_list[i], cmap="gray")
        ax.axis('off')

    if save:
        fig5.savefig(save_figure_path + "/image_at_epoch_{:04d}_batch_{}.png".format(epoch,batch_idx))

    # stand-alone view of the predicted mask
    plt.figure(figsize=(8,8))
    plt.title('Pred. [B:{}/E:{}]: loss->'.format(batch_idx,epoch)+ str(total_loss))
    plt.imshow(np.squeeze(seg_out[0]), cmap="gray")
    plt.axis('off')
    if save:
        plt.savefig(save_figure_path + "/Predictions/pred_only_at_epoch_{:04d}_batch_{}.png".format(epoch,batch_idx))
    plt.show()

In [None]:
import time


# Build the checkpoint manager for the new model.
# The directory used below is the one computed in this cell.
# NOTE(review): earlier revisions computed `new_ckpt_prefix` but then passed a
# different, externally defined `checkpoint_prefix` to the manager; confirm both
# pointed at the same folder so existing checkpoints are still picked up.
new_ckpt_prefix = os.path.join(checkpoint_dir, "ckpt")
os.makedirs(new_ckpt_prefix, exist_ok=True)  # no check-then-create race

# State tracked by the checkpoint. The keyword names (step, epoch, optimizer,
# new_model) are part of the saved object graph, so they must not be renamed
# if previously written checkpoints are to remain restorable.
new_ckpt_ob = tf.train.Checkpoint(step=tf.Variable(1),
                                  epoch=tf.Variable(1),
                                  optimizer=new_optimizer,
                                  new_model=Final_seg_model)

# Only the most recent checkpoint is kept on disk.
new_manager = tf.train.CheckpointManager(new_ckpt_ob, new_ckpt_prefix, max_to_keep=1)

# Resume from the latest checkpoint in the folder when one exists.
if new_manager.latest_checkpoint:
    new_ckpt_ob.restore(new_manager.latest_checkpoint)
    print("Restored from {}".format(new_manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

# Aliases for the checkpointed counters: `step` counts batches within the
# current epoch, `ckpt_epoch` is the epoch to run next.
step = new_ckpt_ob.step
ckpt_epoch = new_ckpt_ob.epoch

# Best (lowest) mean test loss seen so far; a checkpoint is written only when
# an epoch improves on it. Starting at +inf guarantees the first evaluated
# epoch is saved regardless of the loss scale.
test_avg_tmp_loss = float("inf")

# Training loop.
#
# Per epoch: run `resampled_steps_per_epoch` training steps, periodically
# preview the model on one test batch, evaluate on the whole test set, log
# the epoch means to TensorBoard, and checkpoint when the mean test loss
# improves on the best value seen so far.
preview_every_n_steps = 50    # how often (in batches) a test preview is drawn
validate_every_n_epochs = 1   # how often (in epochs) the test set is evaluated

for epoch in range(int(ckpt_epoch), epochs + 1):
    train_total_loss_mean = tf.keras.metrics.Mean()
    train_seg_total_loss_mean = tf.keras.metrics.Mean()
    train_cls_BCE_loss_mean = tf.keras.metrics.Mean()

    start_time = time.time()
    end_time = start_time  # stays defined even if train_dataset yields nothing

    for train_dicom_path, train_input, train_seg_target, train_cls_label in train_dataset:
        # one training step
        total_loss, cls_BCE_loss, seg_total_loss = new_train_step(
            Final_seg_model, train_input, train_seg_target, train_cls_label, epoch)
        end_time = time.time()

        # accumulate the running means of the training losses
        train_total_loss_mean(total_loss)
        train_seg_total_loss_mean(seg_total_loss)
        train_cls_BCE_loss_mean(cls_BCE_loss)

        current_step = int(step)
        is_first_step = current_step == 1
        is_preview_step = current_step % preview_every_n_steps == 0

        # Draw a preview on the very first batch and then every
        # `preview_every_n_steps` batches; only the periodic preview clears
        # the previous cell output and prints the loss summary.
        if is_preview_step:
            display.clear_output(wait=True)
        if is_first_step or is_preview_step:
            for test_dicom_path, test_input, test_seg_target, test_cls_label in test_dataset.take(1):
                new_generate_images(Final_seg_model, test_input, test_seg_target, test_cls_label,
                                    test_dicom_path, total_loss.numpy(), current_step, epoch)
        if is_preview_step:
            print('Epoch{}: {}/{}: total_loss: {}; cls_BCE_loss: {}; seg_total_loss: {} '
                  'time elapse {}'.format(epoch, current_step, resampled_steps_per_epoch,
                                          total_loss, cls_BCE_loss, seg_total_loss,
                                          end_time - start_time))

        # `step` starts at 1, so the epoch is complete once it reaches
        # resampled_steps_per_epoch (using `>` here would run one extra step).
        if current_step >= resampled_steps_per_epoch:
            break
        step.assign_add(1)

    # Advance the checkpointed counters *before* a checkpoint may be written,
    # so that a restored run resumes at the start of the next epoch instead of
    # re-entering the finished one with `step` already at its end value.
    last_step = int(step)
    ckpt_epoch.assign_add(1)
    step.assign(1)

    # validation
    if epoch % validate_every_n_epochs == 0:
        test_total_loss_mean = tf.keras.metrics.Mean()
        test_seg_total_loss_mean = tf.keras.metrics.Mean()
        test_cls_BCE_loss_mean = tf.keras.metrics.Mean()
        for test_dicom_path, test_input, test_seg_target, test_cls_label in test_dataset:
            test_total_loss, test_cls_BCE_loss, test_seg_total_loss = new_test_step(
                Final_seg_model, test_input, test_seg_target, test_cls_label, epoch)
            test_total_loss_mean(test_total_loss)
            test_seg_total_loss_mean(test_seg_total_loss)
            test_cls_BCE_loss_mean(test_cls_BCE_loss)

        mean_test_loss = test_total_loss_mean.result()

        # The elapsed time reported here covers the training part of the epoch
        # (up to the end of the last training step).
        print('Epoch: {}, Test total_loss set loss: {},  time elapse for current epoch {}'.format(
            epoch, mean_test_loss, end_time - start_time))

        # Epoch-level means go to TensorBoard. The tag spellings are left
        # untouched so new points land on the same curves as existing logs.
        with train_summary_writer.as_default():
            print("writing train logs to tensorboard...")
            tf.summary.scalar('total_loss', train_total_loss_mean.result(), step=epoch)
            tf.summary.scalar('seg_total_loss', train_seg_total_loss_mean.result(), step=epoch)
            tf.summary.scalar('cls_BCE_los', train_cls_BCE_loss_mean.result(), step=epoch)

        with val_summary_writer.as_default():
            print("writing val logs to tensorboard...")
            tf.summary.scalar('total_loss', mean_test_loss, step=epoch)
            tf.summary.scalar('seg_total_loss', test_seg_total_loss_mean.result(), step=epoch)
            tf.summary.scalar('cls_BCE_los', test_cls_BCE_loss_mean.result(), step=epoch)

        # Keep only the best model: checkpoint when the mean test loss improves.
        if mean_test_loss < test_avg_tmp_loss:
            save_path = new_manager.save()  # returns the path of the written checkpoint
            print("Saved checkpoint for epoch {}-  step {}: {}".format(epoch, last_step, save_path))
            test_avg_tmp_loss = mean_test_loss

print("training finished")
