In [1]:
import tensorflow as tf
import numpy as np
import cv2
from tensorflow.keras import applications
import os
import csv
import sys
import pandas as pd
import matplotlib.pyplot as plt
import random
from tqdm import tqdm
import tensorflow.keras.backend as K


In [2]:
os.chdir('D:/hubmap segmentation challenge/')

In [12]:
def mask_to_rle(img, size):
    pixels = cv2.resize(img,size).T.flatten()
    pixels[0] = 0
    pixels[-1] = 0
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)

In [13]:
def rle_to_mask(rle_string, size):
    rle = np.array(list(map(int, rle_string.split())))
    label = np.zeros((size*size), dtype=np.uint8)
    for start, end in zip(rle[::2], rle[1::2]):
        label[start:start+end] = 1
    return label.reshape(size, size).T

In [4]:
class KSAC_layer(tf.keras.layers.Layer):
    def __init__(self, input_shape, filters, dilation_rates=[6, 12, 18], batchnorm=True):
        super().__init__()
        self.dilation_rates = dilation_rates
        self.batchnorms = []
        self.filters = filters
        if batchnorm:
            self.batchnorms = [tf.keras.layers.BatchNormalization() for _ in dilation_rates]
        self.kernel_initializer = tf.keras.initializers.GlorotUniform()
        self.kernel_shape = (3, 3, input_shape[-1], filters)
        self.kernel = tf.Variable(self.kernel_initializer(self.kernel_shape), trainable=True)

    def call(self, x, training=False):
        feature_maps = [tf.nn.conv2d(x, self.kernel, (1, 1), 'SAME', dilations=d) for d in self.dilation_rates]
        if len(self.batchnorms) > 0:
            for i in range(len(feature_maps)):
                feature_maps[i] = self.batchnorms[i](feature_maps[i], training=training)
        return sum(feature_maps)

In [5]:
class KSAC_pooling(tf.keras.layers.Layer):
    def __init__(self, filters, batchnorm = False):
        super().__init__()
        self.filters = filters
        self.batchnorm = []
        if batchnorm:
            self.batchnorm = tf.keras.layers.BatchNormalization()
        self.conv_layer = tf.keras.layers.Conv2D(filters, 1, (1,1))

    def call(self, x):
        x = tf.keras.layers.GlobalAveragePooling2D(keepdims=True)(x)
        x = self.conv_layer(x)
        if self.batchnorm != []:
            x = self.batchnorm(x)
        return tf.image.resize(images=x, size=x.shape[1:-1])

In [6]:
class KSAC_block(tf.keras.layers.Layer):
    def __init__(self, filters, input_shape, dilation_rate=[6, 12, 18], batchnorm=True):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(filters, 1, (1, 1))
        self.batchnorm = []
        if batchnorm:
            self.batchnorm = tf.keras.layers.BatchNormalization()


        self.ksac_layer = KSAC_layer(input_shape, filters, dilation_rate, batchnorm)
        self.ksac_pooling = KSAC_pooling(filters, batchnorm)
        self.bias = tf.Variable(tf.zeros_initializer()((filters,)), trainable=True, name='bias')

    def call(self, x):
        y = self.conv1(x)
        if self.batchnorm != []:
            y = self.batchnorm(y)
        return tf.nn.relu(y + self.ksac_layer(x) + self.ksac_pooling(x) + self.bias)

In [7]:
class DeepLabV3_Decoder(tf.keras.layers.Layer):
    def __init__(self, filters, out_size, batchnorm=True):
        super().__init__()
        self.batchnorm = []
        if batchnorm:
            self.batchnorm = tf.keras.layers.BatchNormalization()
        self.conv1 = tf.keras.layers.Conv2D(filters, 1, (1, 1))
        self.conv2 = tf.keras.layers.Conv2D(1, 3, (1, 1), "SAME", activation='sigmoid')
        self.out_size = out_size

    def call(self, x1, x2):
        x2 = self.conv1(x2)
        x2 = self.batchnorm(x2)
        x2 = tf.nn.relu(x2)
        x2 = tf.image.resize(images=x2, size=x1.shape[1:-1])
        x = tf.concat([x1, x2], axis=-1)
        x = self.conv2(x)
        x = tf.image.resize(images=x, size=self.out_size)
        return x

In [8]:
class KSAC_network(tf.keras.Model):
    def __init__(self, input_shape, filters, dilation_rate=[6,12,18], batchnorm=True):
        super().__init__()
        resnet_backbone = applications.resnet50.ResNet50(
            include_top=False,
            weights=None,
            input_tensor=None,
            input_shape=input_shape,
            pooling=None,
            classes=1000,
        )
        resnet_backbone = tf.keras.Model(inputs=resnet_backbone.inputs,
                                         outputs=[resnet_backbone.get_layer('conv3_block4_out').output,
                                                  resnet_backbone.get_layer('conv4_block6_out').output])
        #print(resnet_backbone.summary())

        x = tf.keras.Input(input_shape)
        x1, x2 = resnet_backbone(x)
        print(x.shape)
        # print(x1.shape, x2.shape)
        x2 = KSAC_block(filters, x2.shape, dilation_rate, batchnorm)(x2)
        print(x.shape)
        logits = DeepLabV3_Decoder(filters, input_shape[:-1])(x1,x2)
        self.model = tf.keras.Model(inputs=x, outputs=logits)
        self.loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    def compile(self, optimizer, *args, **kwargs):
        self.focal_loss_metric = keras.metrics.Mean(name="focal_loss")
        self.accuracy_metric = keras.metrics.Mean(name='accuracy')
        self.optimizer = optimizer
        super(KSAC_network, self).compile(*args, **kwargs)

    

In [9]:
ksac_network_1 = KSAC_network((512,512,3), 128)
ksac_network_2 = KSAC_network((512,512,3), 128)
ksac_network_1.model.load_weights('models/ksac_network_weights/ksac_model_100.h5')
ksac_network_2.model.load_weights('models/ksac_network_weights/ksac_model_2_100.h5')

(None, 512, 512, 3)
(None, 512, 512, 3)
(None, 512, 512, 3)
(None, 512, 512, 3)


In [10]:
def infer(model, image):
    """going to flip the model and infer from all possible angles then find overlapping mask"""
    threshold = 0.2
    
    logits_normal = model(image)
    normal_binarised = (logits_normal.numpy() > threshold).astype(np.uint8)
    
    logits_horizontal_flip = np.flip(model(np.flip(image, axis=0)), axis=0)
    horizontal_binarised = (logits_horizontal_flip.numpy() > threshold).astype(np.uint8)
    
    logits_vertical_flip = np.flip(model(np.flip(image, axis=1)), axis=1)
    vertical_binarised = (logits_vertical_flip.numpy() > threshold).astype(np.uint8)
    
    total_mask = (normal_binarised+horizontal_binarised+vertical_binarised).astype(np.float32)
    #including on the final mask any region which more than one flip agrees on
    total_mask = (total_mask > 2.0).astype(np.uint8)
    return total_mask

In [69]:
def ensemble_infer(models, image):
    total_mask = np.zeros((np.shape(image)[1], np.shape(image)[2],1))
    for model in models:
        total_mask += infer(model, image).squeeze(0)
    total_mask = (total_mask > len(models)/2).astype(np.uint8)
    return total_mask

In [100]:
test_data = pd.read_csv('test.csv').values

In [103]:
def test(models, data):
    output_dict = {'id':[],'rle':[]}
    for i in range(0,len(data)):
        #print(str(data[i]))
        image = (cv2.imread(('test_images/'+str(data[i][0])+'.tiff')))
        image = np.expand_dims(cv2.resize((image/255),(512,512)),0)
        mask = ensemble_infer(models,image)
        #print(data[i][4])
        rle = mask_to_rle(mask, (data[i][4],data[i][4]))
        output_dict['id'].append(str(data[i][0]))
        output_dict['rle'].append(rle)
    return output_dict

In [104]:
output_dict = test([ksac_network_1.model, ksac_network_2.model], test_data)

In [93]:
output_dataframe = pd.DataFrame(output_dict)

In [95]:
output_dataframe.to_csv('output_test.csv')