# Import Necessary Libraries

In [42]:
import numpy as np
from tensorflow import keras
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import ModelCheckpoint
from keras.utils import plot_model
# from keras.utils.vis_utils import plot_model

# Model Parameters

In [28]:
# Network input size as (height, width, channels), then the training batch size and epoch count.
input_shape = (128, 128, 3)
batch_sizes, epoch = 8, 100

# Function for Stain Normalization

In [None]:
def Stain_Normalization(img, saveFile=None, Io=255, alpha=1, beta=0.15):
    """Normalize the H&E stain appearance of an RGB image to a fixed reference.

    The image's two stain vectors are estimated in optical-density (OD) space
    from the plane of the two largest OD-covariance eigenvectors. Per-pixel
    stain concentrations are unmixed by least squares and rescaled so their
    99th percentiles match reference maxima. The image is then re-rendered
    with a fixed reference stain matrix.

    Args:
        img: (H, W, 3) RGB image (the notebook passes the output of
            ``cv2.cvtColor(..., cv2.COLOR_BGR2RGB)``).
        saveFile: Unused; kept only for backward compatibility.
        Io: Transmitted-light intensity used to convert to and from OD.
        alpha: Percentile (in percent) used to pick the robust extreme
            angles that define the two stain vectors.
        beta: OD threshold. Pixels with OD below ``beta`` in any channel are
            treated as background and ignored when estimating stain vectors.

    Returns:
        The normalized image as an (H, W, 3) ``uint8`` array.

    Raises:
        ValueError: If fewer than two pixels remain after background removal,
            so no stain vectors can be estimated.
    """
    # Reference stain matrix (column 0: hematoxylin, column 1: eosin) and
    # reference maximum concentrations.
    HERef = np.array([[0.587, 0.136],
                      [0.754, 0.833],
                      [0.294, 0.536]])
    maxCRef = np.array([1.9705, 1.0308])

    h, w, _ = img.shape

    # Optical density per pixel, one row per pixel. The +1 avoids log(0) on
    # black pixels. np.float was removed in NumPy 1.24, so use float64.
    OD = -np.log((img.reshape((-1, 3)).astype(np.float64) + 1) / Io)

    # Remove background / transparent pixels (OD below beta in any channel).
    ODhat = OD[~np.any(OD < beta, axis=1)]
    if ODhat.shape[0] < 2:
        raise ValueError(
            "Stain_Normalization: not enough non-background pixels to estimate "
            "stain vectors; check the image content or lower beta."
        )

    # eigh returns eigenvalues in ascending order, so columns 1:3 are the
    # eigenvectors of the two largest eigenvalues.
    _, eigvecs = np.linalg.eigh(np.cov(ODhat.T))
    plane = eigvecs[:, 1:3]

    # Project onto that plane and take robust extreme angles.
    That = ODhat.dot(plane)
    phi = np.arctan2(That[:, 1], That[:, 0])
    minPhi = np.percentile(phi, alpha)
    maxPhi = np.percentile(phi, 100 - alpha)

    vMin = plane.dot(np.array([(np.cos(minPhi), np.sin(minPhi))]).T)
    vMax = plane.dot(np.array([(np.cos(maxPhi), np.sin(maxPhi))]).T)

    # Heuristic: put the hematoxylin vector first and the eosin vector second.
    if vMin[0] > vMax[0]:
        HE = np.array((vMin[:, 0], vMax[:, 0])).T
    else:
        HE = np.array((vMax[:, 0], vMin[:, 0])).T

    # Stain concentrations for every pixel (background included).
    C = np.linalg.lstsq(HE, OD.T, rcond=None)[0]

    # Rescale concentrations so their 99th percentiles match the reference.
    maxC = np.array([np.percentile(C[0, :], 99), np.percentile(C[1, :], 99)])
    C2 = C / (maxC / maxCRef)[:, np.newaxis]

    # Re-render with the reference stain matrix. Values above 255 are set to
    # 254 (behaviour kept from the original implementation).
    Inorm = Io * np.exp(-HERef.dot(C2))
    Inorm[Inorm > 255] = 254
    return np.reshape(Inorm.T, (h, w, 3)).astype(np.uint8)

# Function for DCSAM 1

In [2]:
def spatial_conv1(x, kernel, filter_number, dilation):
    """DCSAM-1 block: a dense projection gated by multi-branch conv features.

    A ReLU ``Dense`` projection of ``x`` is multiplied element-wise by a
    sigmoid ``Dense`` gate. The gate is computed over the concatenation of
    ``x``, two stacked Conv-BN-ReLU stages (``kernel`` x ``kernel``, dilation
    ``dilation``, ``filter_number`` filters) and a 1x1 squeeze branch with
    ``filter_number // 8`` filters.

    Returns:
        Tensor with ``filter_number`` channels.
    """
    def conv_bn_relu(tensor, filters, size, rate):
        # Conv2D ('same' padding) -> BatchNorm on axis 3 -> ReLU.
        out = keras.layers.Conv2D(filters, (size, size), dilation_rate=rate, padding='same')(tensor)
        out = keras.layers.BatchNormalization(axis=3)(out)
        return keras.layers.Activation('relu')(out)

    value = keras.layers.Dense(filter_number, activation="relu")(x)

    # Layers are created in the same order as before so auto-generated names stay stable.
    first = conv_bn_relu(x, filter_number, kernel, dilation)
    bottleneck = conv_bn_relu(x, filter_number // 8, 1, 1)
    second = conv_bn_relu(first, filter_number, kernel, dilation)

    features = keras.layers.concatenate([x, first, second, bottleneck])
    gate = keras.layers.Dense(filter_number, use_bias=False, activation="sigmoid")(features)
    return gate * value

# Function for DCSAM 2

In [3]:
def spatial_conv2(x, kernel, filter_number, dilation):
    """DCSAM-2 block: like DCSAM-1 but with a third stacked conv stage.

    A ReLU ``Dense`` projection of ``x`` is multiplied element-wise by a
    sigmoid ``Dense`` gate. The gate is computed over the concatenation of
    ``x``, three successive Conv-BN-ReLU stages (``kernel`` x ``kernel``,
    dilation ``dilation``, ``filter_number`` filters) and a 1x1 squeeze
    branch with ``filter_number // 8`` filters.

    Returns:
        Tensor with ``filter_number`` channels.
    """
    def conv_bn_relu(tensor, filters, size, rate):
        # Conv2D ('same' padding) -> BatchNorm on axis 3 -> ReLU.
        out = keras.layers.Conv2D(filters, (size, size), dilation_rate=rate, padding='same')(tensor)
        out = keras.layers.BatchNormalization(axis=3)(out)
        return keras.layers.Activation('relu')(out)

    value = keras.layers.Dense(filter_number, activation="relu")(x)

    # Layers are created in the same order as before so auto-generated names stay stable.
    stage1 = conv_bn_relu(x, filter_number, kernel, dilation)
    stage2 = conv_bn_relu(stage1, filter_number, kernel, dilation)
    bottleneck = conv_bn_relu(x, filter_number // 8, 1, 1)
    stage3 = conv_bn_relu(stage2, filter_number, kernel, dilation)

    features = keras.layers.concatenate([x, stage1, stage2, stage3, bottleneck])
    gate = keras.layers.Dense(filter_number, use_bias=False, activation="sigmoid")(features)
    return gate * value

# Function for Encoder Block 1

In [14]:
def encoder_block1(input, kernel, filter_number, dilation):
    """Encoder stage built on DCSAM-1.

    Returns:
        ``(skip, down)``: the full-resolution DCSAM-1 features (used later as
        the skip connection) and the 2x2 max-pooled, dropout-regularized
        (rate 0.3) tensor that feeds the next stage.
    """
    skip = spatial_conv1(input, kernel, filter_number, dilation)
    pool = keras.layers.MaxPooling2D((2,2), strides=(2, 2))
    drop = keras.layers.Dropout(0.3)
    return skip, drop(pool(skip))

# Function for Encoder Block 2

In [15]:
def encoder_block2(input, kernel, filter_number, dilation):
    """Encoder stage built on DCSAM-2.

    Returns:
        ``(skip, down)``: the full-resolution DCSAM-2 features (used later as
        the skip connection) and the 2x2 max-pooled, dropout-regularized
        (rate 0.3) tensor that feeds the next stage.
    """
    skip = spatial_conv2(input, kernel, filter_number, dilation)
    pool = keras.layers.MaxPooling2D((2,2), strides=(2, 2))
    drop = keras.layers.Dropout(0.3)
    return skip, drop(pool(skip))

# Function for Decoder Block 1

In [6]:
def decoder_block_segment1(input, kernel, skip_features, filter_number, dilation):
    """Decoder stage: upsample 2x, fuse with the encoder skip, refine with DCSAM-1.

    ``skip_features`` must have the same height and width as ``input`` after
    the 2x upsampling, since the two are concatenated on the channel axis.
    """
    up = keras.layers.UpSampling2D(size=(2, 2), data_format="channels_last")
    fused = keras.layers.concatenate([up(input), skip_features])
    return spatial_conv1(fused, kernel, filter_number, dilation)

# Function for Decoder Block 2

In [9]:
def decoder_block_segment2(input, kernel, skip_features, filter_number, dilation):
    """Decoder stage: upsample 2x, fuse with the encoder skip, refine with DCSAM-2.

    ``skip_features`` must have the same height and width as ``input`` after
    the 2x upsampling, since the two are concatenated on the channel axis.
    """
    up = keras.layers.UpSampling2D(size=(2, 2), data_format="channels_last")
    fused = keras.layers.concatenate([up(input), skip_features])
    return spatial_conv2(fused, kernel, filter_number, dilation)

# Function for DCSA_Net

In [33]:
def DCSA_Net(input_shape):
    """Build the DCSA-Net encoder-decoder segmentation model.

    Args:
        input_shape: Shape of one image without the batch axis, e.g.
            ``(128, 128, 3)`` as set in the "Model Parameters" cell. Height and
            width must be divisible by 16. The encoder max-pools four times
            (p1..p4) and the decoder upsamples four times, concatenating each
            result with a same-size skip tensor.

    Returns:
        An uncompiled ``keras.Model`` named "DCSA_Net". Its single output is a
        sigmoid map with the input's height and width and 1 channel.
    """
    
    inputs = keras.Input(input_shape)

    # Stem: parallel 5x5 and 3x3 Conv-BN-ReLU branches on the raw input,
    # concatenated (64 + 64 = 128 channels). s1 is kept as the last skip.
    conv1 = keras.layers.Conv2D(64, 5, strides=1, dilation_rate=1, padding='same')(inputs)
    conv1 = keras.layers.BatchNormalization(axis=3)(conv1)
    conv1 = keras.layers.Activation("relu")(conv1)
    conv2 = keras.layers.Conv2D(64, 3, strides=1, dilation_rate=1, padding='same')(inputs)
    conv2 = keras.layers.BatchNormalization(axis=3)(conv2)
    conv2 = keras.layers.Activation("relu")(conv2)
    s1 = keras.layers.concatenate([conv1, conv2])
    p1 = keras.layers.MaxPooling2D((2,2), strides=(2, 2))(s1)
    
    # Encoder: each block returns (pre-pool skip s*, pooled output p*);
    # dilation grows 1 -> 2 -> 3 with depth.
    s2, p2 = encoder_block1(p1, 3, 128, 1)
    s3, p3 = encoder_block1(p2, 3, 256, 2)
    s4, p4 = encoder_block2(p3, 3, 512, 3)
    
    # Bridge: 1x1 Conv-BN-ReLU on the deepest pooled features.
    bridge = keras.layers.Conv2D(512, 1, strides=1, dilation_rate=1, padding='same')(p4)
    bridge = keras.layers.BatchNormalization()(bridge)
    bridge = keras.layers.Activation("relu")(bridge) #Bridge
    
    # Decoder: each block upsamples 2x and fuses the matching encoder skip.
    d1 = decoder_block_segment1(bridge, 3, s4, 512, 1)
    d2 = decoder_block_segment1(d1, 3, s3, 256, 1)
    d3 = decoder_block_segment2(d2, 3, s2, 128, 1)
    
    # Final upsample back to the input resolution, fused with the stem output.
    d4 = keras.layers.UpSampling2D(size=(2, 2), data_format="channels_last")(d3)
    d4 = keras.layers.concatenate([d4, s1])
    
    # Output head: parallel 5x5 and 3x3 Conv-BN-ReLU branches (128 channels total).
    conv3 = keras.layers.Conv2D(64, 5, strides=1, dilation_rate=1, padding='same')(d4)
    conv3 = keras.layers.BatchNormalization(axis=3)(conv3)
    conv3 = keras.layers.Activation("relu")(conv3)
    conv4 = keras.layers.Conv2D(64, 3, strides=1, dilation_rate=1, padding='same')(d4)
    conv4 = keras.layers.BatchNormalization(axis=3)(conv4)
    conv4 = keras.layers.Activation("relu")(conv4)
    
    concat = keras.layers.concatenate([conv3, conv4])
    
    # Channel attention: global max pool -> sigmoid Dense(128) weights that
    # rescale each of the 128 channels of `concat`. The hard-coded 128 must
    # equal 64 + 64 above.
    cnn_feature_vector = keras.layers.GlobalMaxPooling2D()(concat)
    attention_weights = keras.layers.Dense(units=128, activation='sigmoid')(cnn_feature_vector)
    attention_weights = keras.layers.Reshape((1,1,128))(attention_weights)
    weighted_image_output = keras.layers.Multiply()([concat, attention_weights])
    # Re-inject the raw input image before the final projection.
    merged_image = keras.layers.concatenate([inputs, weighted_image_output])
    merged_image = keras.layers.Conv2D(3, (3, 3), padding='same', activation='relu')(merged_image)
    
    # One-channel sigmoid probability map.
    outputs = keras.layers.Conv2D(1, (1,1), padding='same', activation='sigmoid')(merged_image)
    
    model = keras.Model(inputs, outputs, name="DCSA_Net")
    
    return model

In [34]:
# Build the network for the configured input_shape and print its layer-by-layer summary.
DCSA_model = DCSA_Net(input_shape)

DCSA_model.summary()

Model: "DCSA_Net"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            [(None, 128, 128, 3) 0                                            
__________________________________________________________________________________________________
conv2d_29 (Conv2D)              (None, 128, 128, 64) 4864        input_3[0][0]                    
__________________________________________________________________________________________________
conv2d_30 (Conv2D)              (None, 128, 128, 64) 1792        input_3[0][0]                    
__________________________________________________________________________________________________
batch_normalization_27 (BatchNo (None, 128, 128, 64) 256         conv2d_29[0][0]                  
___________________________________________________________________________________________

In [20]:
def dice_coef(y_true, y_pred):
    """Smoothed Dice coefficient over all elements of the batch.

    Computes ``(2 * sum(t * p) + 1) / (sum(t) + sum(p) + 1)`` on the
    flattened tensors. The +1 smoothing keeps the ratio defined when both
    tensors are all zeros.
    """
    smooth = 1.0
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    overlap = K.sum(truth * guess)
    denominator = K.sum(truth) + K.sum(guess) + smooth
    return (2.0 * overlap + smooth) / denominator

In [21]:
def dice_coef_loss(y_true, y_pred):
    """Dice loss: one minus the smoothed Dice coefficient."""
    coefficient = dice_coef(y_true, y_pred)
    return 1 - coefficient

In [22]:
def jacard_coef(y_true, y_pred):
    """Smoothed Jaccard index (IoU) over all elements of the batch.

    Computes ``(I + 1) / (sum(t) + sum(p) - I + 1)`` with ``I = sum(t * p)``
    on the flattened tensors. The +1 smoothing keeps it defined for empty
    inputs.
    """
    truth, guess = K.flatten(y_true), K.flatten(y_pred)
    overlap = K.sum(truth * guess)
    union = K.sum(truth) + K.sum(guess) - overlap
    return (overlap + 1.0) / (union + 1.0)

In [23]:
def jacard_coef_loss(y_true, y_pred):
    """Jaccard loss: one minus the smoothed Jaccard index.

    Mirrors ``dice_coef_loss`` (``1 - Dice``). The gradient is identical to
    the previous ``-J`` formulation, but the loss is non-negative (it lies in
    [0, 1) for masks and predictions in [0, 1]) instead of negative.
    """
    return 1 - jacard_coef(y_true, y_pred)

In [24]:
callback = [
    # Cut the learning rate 10x after 5 epochs without improvement of the
    # default monitored quantity (val_loss), then wait 5 epochs before another cut.
    keras.callbacks.ReduceLROnPlateau(
        factor=0.1,
        patience=5,
        cooldown=5,
        verbose=1,
    ),
    # Write only the best weights seen so far (by val_loss) to weight.h5.
    ModelCheckpoint(
        'weight.h5',
        verbose=1,
        save_best_only=True,
        save_weights_only=True,
    ),
]

In [25]:
def bce_dice_jaccard_loss(y_true, y_pred):
    """Combined segmentation loss: mean binary cross-entropy + Dice loss + Jaccard loss.

    Keras matches a *list* of losses one-to-one with model outputs, and
    DCSA_Net has a single output. The three criteria are therefore summed
    into one callable instead of being passed as ``[bce, dice, jaccard]``.
    """
    # Cast once so integer/boolean masks work with the float products in the Dice/Jaccard terms.
    y_true = K.cast(y_true, y_pred.dtype)
    bce = K.mean(keras.losses.binary_crossentropy(y_true, y_pred))
    return bce + dice_coef_loss(y_true, y_pred) + jacard_coef_loss(y_true, y_pred)


DCSA_model.compile(optimizer="Adam", loss=bce_dice_jaccard_loss, metrics=['accuracy', dice_coef, jacard_coef])

In [None]:
# Train the model. `callback` lowers the learning rate on plateaus and checkpoints the best weights to weight.h5.
# NOTE(review): X_train / y_train / X_test / y_test are not defined in the visible cells; confirm they come
# from the data-loading step. The test split is used as validation data, so it also drives checkpoint
# selection and LR reduction. If image_dataset_Test below is the same data, its scores will be optimistic.
results = DCSA_model.fit(X_train, y_train, batch_size=batch_sizes, epochs=epoch, validation_data=(X_test, y_test), callbacks=callback, shuffle=True)

# Model Testing

In [None]:
# Evaluate the trained network on the held-out test set.
# NOTE(review): image_dataset_Test / mask_dataset_Test are not defined in the visible cells; they are
# assumed to be the preprocessed test images and their binary masks. Confirm against the data loader.
# NOTE(review): DCSA_model currently holds the final-epoch weights. Call
# DCSA_model.load_weights('weight.h5') first to evaluate the best checkpoint instead.
# Predict with the built model instance; DCSA_Net is only the builder function.
y_pred = DCSA_model.predict(image_dataset_Test)
y_pred_thresholded = y_pred > 0.5

# Binarize the ground truth with the same threshold (this also handles 0/255 masks). Align the
# prediction's shape with it (the model emits a trailing channel axis of size 1) so the logical
# ops below compare pixel to pixel instead of broadcasting.
gt_mask = np.asarray(mask_dataset_Test) > 0.5
y_pred_thresholded = y_pred_thresholded.reshape(gt_mask.shape)

gt_flat = gt_mask.flatten()
pred_flat = y_pred_thresholded.flatten()

# Pixel accuracy: fraction of pixels whose predicted label equals the ground truth.
accuracy = np.mean(gt_flat == pred_flat)

intersection = np.logical_and(gt_mask, y_pred_thresholded)
union = np.logical_or(gt_mask, y_pred_thresholded)

# Both scores are defined as 1.0 when prediction and ground truth are both empty (avoids 0/0).
total_foreground = gt_mask.sum() + y_pred_thresholded.sum()
DiceCoef = 2. * intersection.sum() / total_foreground if total_foreground else 1.0

union_size = np.sum(union)
iou_score = np.sum(intersection) / union_size if union_size else 1.0

print("Accuracy score is: ", accuracy)
print("Dice score is: ", DiceCoef)
print("IoU score is: ", iou_score)

# Test a Single Image: Split the 512x512 Image into Overlapping 128x128 Patches

In [46]:
from patchify import patchify, unpatchify
import cv2

# Read the 512x512 test image. cv2.imread returns None instead of raising when it
# cannot read the file, so fail early with a clear message.
image = cv2.imread("Image.bmp") ## size 512
if image is None:
    raise FileNotFoundError("Could not read 'Image.bmp'")

# Convert OpenCV's BGR channel order to RGB.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

normalize_image = Stain_Normalization(image)

r, g, b = cv2.split(normalize_image)

# 128x128 windows with a 64-pixel step (50% overlap) on each channel plane.
# Each result is a (rows, cols, 128, 128) grid of patches.
R_patches = patchify(r, (128, 128), step=64)
G_patches = patchify(g, (128, 128), step=64)
B_patches = patchify(b, (128, 128), step=64)

# Stack the channels on a new last axis -> (rows, cols, 128, 128, 3), so rgb[i, j]
# is one RGB patch in the model's (128, 128, 3) input layout. cv2.merge works on
# 2-D image planes and does not build this 5-D layout from 4-D patch grids.
rgb = np.stack((R_patches, G_patches, B_patches), axis=-1)

# Scale pixel values to [0, 1].
rgb = rgb / 255.

# Predict Each 128x128 Patch and Merge the Results Back into the 512x512 Image

In [49]:
def cell_segmentation(array, model=None, image_shape=None, threshold=0.5):
    """Segment nuclei in a grid of image patches and stitch the result into one mask.

    Args:
        array: Patch grid of shape (n_rows, n_cols, patch_h, patch_w, channels),
            e.g. the ``rgb`` array from the patching cell (values in [0, 1]).
        model: Keras model used for prediction. Defaults to the notebook's
            ``DCSA_model`` (the previous code referenced an undefined ``model``).
        image_shape: (height, width) of the image the patches were cut from.
            Defaults to ``r.shape`` from the patching cell.
        threshold: Probability above which a pixel counts as nucleus.

    Returns:
        The post-processed uint8 mask (0 or 255) returned by ``post_processing``.
    """
    if model is None:
        model = DCSA_model
    if image_shape is None:
        image_shape = r.shape

    n_rows, n_cols, patch_h, patch_w = array.shape[:4]

    # Predict all patches in one batched call instead of one predict() per patch.
    # Row-major flattening keeps the original (i, j) patch order.
    batch = np.reshape(array, (n_rows * n_cols,) + array.shape[2:])
    probabilities = model.predict(batch)[..., 0]
    masks = (probabilities > threshold).astype(np.uint8)

    # Restore the (row, col) grid; using both grid dimensions (not shape[0] twice)
    # keeps non-square grids correct.
    grid = masks.reshape(n_rows, n_cols, patch_h, patch_w)
    reconstructed_image = unpatchify(grid, image_shape)
    binary = np.squeeze(reconstructed_image * 255).astype('uint8')

    return post_processing(binary)

# Function for Post-processing

In [50]:
from scipy import ndimage

def post_processing(binary, min_size=20):
    """Clean a binary nuclei mask.

    The steps are a morphological opening, hole filling, and removal of
    connected components that are too small.

    Args:
        binary: 2-D uint8 mask, non-zero where nuclei were predicted
            (``cell_segmentation`` passes a 0/255 image).
        min_size: Components with fewer pixels than this are discarded.

    Returns:
        uint8 mask of the same shape: 255 on kept components, 0 elsewhere.
    """
    # 5x5 elliptical structuring element for the opening (removes small specks and thin spurs).
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    opening = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=1)
    fill = ndimage.binary_fill_holes(opening).astype('uint8')

    # num_labels counts the background (label 0) too; labels holds each pixel's component id.
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(fill, connectivity=8)

    # Lookup table indexed by label: True for foreground components that are large enough.
    # This replaces a per-component full-image comparison loop.
    keep = np.zeros(num_labels, dtype=bool)
    keep[1:] = stats[1:, cv2.CC_STAT_AREA] >= min_size

    return np.where(keep[labels], 255, 0).astype(np.uint8)

In [None]:
# Segment the patched test image, stitch the patches back together and post-process the mask.
Nuclei_Segmentation = cell_segmentation(rgb)

In [None]:
# Display the final nuclei mask in grayscale without axes.
# NOTE(review): `plt` is not imported in the visible cells; presumably matplotlib.pyplot
# is imported elsewhere. Confirm, or add `import matplotlib.pyplot as plt`. `fig` is unused.
fig = plt.figure(figsize=(10,10))
plt.imshow(Nuclei_Segmentation, cmap='gray')
plt.axis('off')
plt.show()