##### Silence TensorFlow C++ INFO and WARNING log messages

In [1]:
import os
# Level "2" hides TensorFlow's C++ INFO and WARNING output; it is set here,
# before tensorflow is imported in the next cell.
os.environ.update(TF_CPP_MIN_LOG_LEVEL="2")

##### Import necessary libraries

In [2]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import load_img, img_to_array, array_to_img
from tensorflow.keras import layers
import tensorflow_probability as tfp
import numpy as np
import random

##### Enable GPU memory growth (TensorFlow allocates GPU memory on demand)

In [3]:
physical_devices = tf.config.list_physical_devices('GPU')
# Enable memory growth on *every* visible GPU: TensorFlow requires the setting
# to be identical across GPUs, and indexing [0] alone raised IndexError on a
# machine without a GPU. With no GPU the loop is a no-op and training runs on CPU.
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)
if not physical_devices:
    print("No GPU found; TensorFlow will run on CPU.")

##### Define constants/parameters

In [4]:
# (height, width) every image is resized to.
IMG_SIZE = (512, 512)
# Number of output channels of the network, i.e. fuzzy clusters.
NUM_CLASSES = 10
# Images per training batch.
BATCH_SIZE = 10
# Filters of each encoder stage; the decoder uses them in reverse order.
FILTERS = [32, 64, 128, 256]
# Directory holding the training images -- replace the placeholder.
path_to_images = '<path to images>'

##### Create the custom loss function class

In [5]:
class RFCM_loss():
    '''
    Unsupervised Robust Fuzzy C-Means (RFCM) loss for ConvNet-based image segmentation.

    Junyu Chen, et al. Learning Fuzzy Clustering for SPECT/CT Segmentation
    via Convolutional Neural Networks. Medical Physics, 2021.

    Ground-truth segmentation is NOT needed by this loss function: the input
    image is passed where a Keras loss normally receives ``y_true``. Only the
    fuzzy-clustering data term is implemented here (no spatial regularizer).
    '''

    # Guards the centroid division against a cluster whose total membership
    # in an image is (numerically) zero.
    _EPS = 1e-7

    def __init__(self, fuzzy_factor=2):
        '''
        :param fuzzy_factor: exponent applied to the membership probabilities to
            control fuzzy overlap between clusters, default value = 2.
        '''
        self.fuzzy_factor = fuzzy_factor

    def rfcm_loss_func(self, image, y_pred):
        '''
        Compute the fuzzy C-means clustering loss of a batch.

        For every image and every cluster k, the membership is
        ``p_k ** fuzzy_factor`` and the centroid is the membership-weighted mean
        of that image's pixel values. The loss accumulates
        ``membership * squared distance to the centroid``, divides by the
        number of clusters, and averages over all pixels of the batch.

        :param image: input image batch given to the ConvNet; assumed to be
            (batch, height, width, channels).
        :param y_pred: prediction from the ConvNet, (batch, height, width, num_clusters),
            assuming that softmax has been applied over the last axis.
        :return: scalar loss tensor.
        '''
        # num_clus equals the number of output channels of the network
        num_clus = y_pred.shape[-1]
        image_shape = tf.shape(image)

        # Flatten the spatial dims: img -> (batch, pixels, channels),
        # seg -> (batch, pixels, num_clus). The channel count is read from the
        # tensor instead of being hard-coded to 3.
        img = tf.reshape(image, (-1, tf.reduce_prod(image_shape[1:-1]), image_shape[-1]))
        seg = tf.reshape(y_pred, (-1, tf.reduce_prod(tf.shape(y_pred)[1:-1]), num_clus))

        J = 0.
        for k in range(num_clus):
            # membership weights of cluster k, (batch, pixels)
            mem = tf.pow(seg[..., k], self.fuzzy_factor)
            weights = mem[..., tf.newaxis]  # (batch, pixels, 1)
            # Per-image centroid, (batch, 1, channels). Reducing over the pixel
            # axis only clusters each image independently; previously a single
            # centroid was pooled over every image in the batch, so training
            # (batch of 10) and inference (batch of 1) optimised different things.
            centroid = (tf.reduce_sum(img * weights, axis=1, keepdims=True)
                        / (tf.reduce_sum(weights, axis=1, keepdims=True) + self._EPS))
            # squared Euclidean distance to the centroid over channels, (batch, pixels)
            sq_dist = tf.reduce_sum(tf.square(img - centroid), axis=-1)
            J += mem * sq_dist / num_clus
        return tf.reduce_mean(J)

    def loss(self, I, J):
        '''Alias for ``rfcm_loss_func`` (``I`` = input image, ``J`` = prediction).'''
        return self.rfcm_loss_func(I, J)

In [6]:
# Collect the training images in a deterministic (sorted) order. The extension
# match is case-insensitive so ".jpg", ".jpeg", ".JPG", ".PNG", ... are all
# picked up; previously only ".JPG" and ".png" were accepted.
input_img_paths = sorted(
    os.path.join(path_to_images, fname)
    for fname in os.listdir(path_to_images)
    if fname.lower().endswith((".jpg", ".jpeg", ".png"))
)

##### Helper class that loads the input images from disk in batches

In [7]:
class CustomDataset(tf.keras.utils.Sequence):

    """Keras ``Sequence`` that loads, resizes and scales images in batches.

    Each item is an ``(x, x)`` pair: the image batch is both the model input and
    the loss target, because the unsupervised RFCM loss receives the input
    image in place of a label mask.
    """

    def __init__(self, batch_size, img_size, input_img_paths):
        """
        :param batch_size: maximum number of images per batch.
        :param img_size: (height, width) every image is resized to.
        :param input_img_paths: image file paths, served in the given order.
        """
        # Newer Keras versions (PyDataset) expect the base initializer to run.
        super().__init__()
        self.batch_size = batch_size
        self.img_size = img_size
        self.input_img_paths = input_img_paths

    def __len__(self):
        # Ceiling division: the trailing partial batch is included. Floor
        # division permanently skipped the last len % batch_size images, always
        # the same ones because the path order never changes.
        return (len(self.input_img_paths) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(x, x)`` with ``x`` float32 in [0, 1], shape (n, H, W, 3)."""
        start = idx * self.batch_size
        batch_paths = self.input_img_paths[start : start + self.batch_size]
        # Sized to the actual number of paths so the (possibly partial) last
        # batch contains no all-zero placeholder images.
        x = np.zeros((len(batch_paths),) + tuple(self.img_size) + (3,), dtype="float32")
        for j, path in enumerate(batch_paths):
            x[j] = img_to_array(load_img(path, target_size=self.img_size))
        x /= 255.0
        return x, x

In [8]:
# Training batches; each item is an (images, images) pair.
train_gen = CustomDataset(
    batch_size=BATCH_SIZE,
    img_size=IMG_SIZE,
    input_img_paths=input_img_paths,
)

##### The encoder part of the model

In [9]:
def create_encoder(inputs, filters_list, size=3, apply_dropout=False):
    """Build the downsampling (encoder) half of the U-Net.

    Every stage applies SeparableConv2D -> (Dropout) -> ReLU -> BatchNorm ->
    SeparableConv2D -> MaxPooling2D(3, strides=2) -> (Dropout) -> ReLU -> BatchNorm.
    The second ReLU is named ``activation_<filters>`` so that callers can look
    up skip-connection tensors by name.

    :param inputs: Keras input tensor the encoder is built on.
    :param filters_list: number of filters of each successive stage.
    :param size: convolution kernel size.
    :param apply_dropout: insert ``Dropout(0.5)`` after each convolution.
    :return: ``tf.keras.Model`` from ``inputs`` to the list of stage outputs.
    """
    stage_outputs = []
    current = inputs

    for n_filters in filters_list:
        # first separable convolution of the stage
        current = layers.SeparableConv2D(n_filters, size, padding='same', use_bias=False)(current)
        if apply_dropout:
            current = layers.Dropout(0.5)(current)
        current = layers.ReLU()(current)
        current = layers.BatchNormalization()(current)

        # second separable convolution, then halve the spatial resolution
        current = layers.SeparableConv2D(n_filters, size, padding='same', use_bias=False)(current)
        current = layers.MaxPooling2D(3, strides=2, padding="same")(current)
        if apply_dropout:
            current = layers.Dropout(0.5)(current)
        # named so that get_unet can fetch this tensor as a skip connection
        current = layers.ReLU(name=f'activation_{n_filters}')(current)
        current = layers.BatchNormalization()(current)

        stage_outputs.append(current)

    return tf.keras.Model(inputs=inputs, outputs=stage_outputs)

##### The decoder part of the model

In [10]:
def get_decode_layer(filters, size=3, apply_dropout=False):
    """Build one upsampling (decoder) stage of the U-Net.

    Two ``Conv2DTranspose`` layers (default strides, 'same' padding, no bias),
    then ``UpSampling2D(2)``, ``ReLU`` and ``BatchNormalization``, optionally
    followed by ``Dropout(0.5)``.

    :param filters: number of filters of both transposed convolutions.
    :param size: kernel size of the transposed convolutions.
    :param apply_dropout: append ``Dropout(0.5)`` at the end of the stage.
    :return: ``tf.keras.Sequential`` holding the stage (built on first call).
    """
    stage = [
        layers.Conv2DTranspose(filters, size, padding='same', use_bias=False)
        for _ in range(2)
    ]
    stage += [layers.UpSampling2D(2), layers.ReLU(), layers.BatchNormalization()]
    if apply_dropout:
        stage.append(layers.Dropout(0.5))
    return tf.keras.Sequential(stage)

##### Create model function

In [11]:
def get_unet(input_shape, filters_list, output_channels:int, activation:str):
    """Assemble the U-Net segmentation model.

    The encoder comes from ``create_encoder``; skip connections are the layers
    named ``activation_<filters>``; each decoder stage from ``get_decode_layer``
    is concatenated with the matching skip, and a final 1x1 ``Conv2D`` produces
    ``output_channels`` maps.

    :param input_shape: (height, width) of the 3-channel input images.
    :param filters_list: filters per encoder stage (the decoder mirrors them).
    :param output_channels: number of output channels (clusters).
    :param activation: activation of the final 1x1 convolution, e.g. 'softmax'.
    :return: ``tf.keras.Model``.
    """
    inputs = layers.Input(shape=input_shape + (3,))

    encoder = create_encoder(inputs, filters_list)
    skip_tensors = [encoder.get_layer(f'activation_{n}').output for n in filters_list]
    decoder_blocks = [get_decode_layer(n) for n in reversed(filters_list)]

    # deepest stage is the bottleneck; the rest are consumed deepest-first
    x = skip_tensors[-1]
    lateral_skips = skip_tensors[-2::-1]

    # upsample and attach the skip connection of matching resolution
    for decode_block, skip in zip(decoder_blocks[:-1], lateral_skips):
        upsampled = decode_block(x)
        x = layers.Concatenate()([upsampled, skip])

    x = decoder_blocks[-1](x)
    outputs = layers.Conv2D(
        filters=output_channels,
        kernel_size=1,
        padding='same',
        activation=activation,
    )(x)

    return tf.keras.Model(inputs=inputs, outputs=outputs)


In [12]:
# One softmax probability map per cluster.
unet_model = get_unet(
    input_shape=IMG_SIZE,
    filters_list=FILTERS,
    output_channels=NUM_CLASSES,
    activation='softmax',
)

In [13]:
# The RFCM loss is unsupervised: its "y_true" argument receives the input image
# (train_gen yields (x, x)), so no label-based metrics are configured.
unet_model.compile(
    loss=RFCM_loss().rfcm_loss_func,
    optimizer='adam',
)

In [14]:
# Save the model whenever the training loss reaches a new minimum.
# NOTE(review): the "4cl" suffix does not match NUM_CLASSES (10); rename the
# path if it is meant to encode the number of clusters.
callback_02 = tf.keras.callbacks.ModelCheckpoint(
    filepath='./sem_seg_model_backbone_4cl',
    save_best_only=True,
    monitor="loss",
    mode='min',
)

In [1]:
epochs = 100
# Unsupervised training: each batch from train_gen is (images, images).
unet_model.fit(
    train_gen,
    callbacks=[callback_02],
    epochs=epochs,
)

##### Helper function to visualize results

In [16]:
def get_superimposed_img(path, num_clusters, model, seed=None):
    """Segment one image with ``model`` and render the clusters as a color image.

    Every pixel is assigned its argmax cluster. Each cluster id in
    ``range(num_clusters)`` is painted with one random RGB color; pixels with any
    other id stay black. Despite the name, the mask is NOT blended with the
    original image.

    :param path: path of the image file to segment.
    :param num_clusters: number of cluster ids to colorize.
    :param model: Keras model, called on a (1, H, W, 3) batch scaled to [0, 1].
    :param seed: optional seed for the color palette (reproducible colors).
    :return: PIL image of the colorized segmentation (via ``array_to_img``).
    """
    image = load_img(path, target_size=IMG_SIZE)
    batch = np.expand_dims(img_to_array(image) / 255.0, axis=0)
    preds = model.predict(batch)
    # (H, W) map of cluster ids for the single image in the batch
    mask = np.argmax(preds, axis=-1)[0]

    rng = np.random.default_rng(seed)
    segm_img = np.zeros(tuple(IMG_SIZE) + (3,), dtype="float32")
    for clus in range(num_clusters):
        # Exactly one color per cluster (the old inner loop drew three and kept
        # the last); the upper bound is exclusive, so 256 makes 255 reachable.
        segm_img[mask == clus] = rng.integers(0, 256, size=3)

    return array_to_img(segm_img)

In [2]:
# Bind the loss to an instance: the unbound RFCM_loss.rfcm_loss_func would be
# invoked with (y_true, y_pred) landing in (self, image) if Keras ever called it.
seg_model = tf.keras.models.load_model(
    '<path to your model>',
    custom_objects={'rfcm_loss_func': RFCM_loss().rfcm_loss_func},
)
get_superimposed_img(random.choice(input_img_paths), NUM_CLASSES, seg_model)