# Importing Libraries

In [None]:
import os
import pandas as pd
import numpy as np
import glob2
import random
import cv2
from skimage import io
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from model import *
from losses import *
from data_generator import DataGenerator

import tensorflow as tf
%matplotlib inline

In [None]:
# Dataset root, resolved against the current working directory once so that
# every later path built from it is absolute.
_RELATIVE_DATASET_DIR = '../dataset/lgg-mri-segmentation/kaggle_3m/'
DATASET_PATH = os.path.abspath(_RELATIVE_DATASET_DIR)

In [None]:
# Load data.csv from the dataset root. The directory and file name are passed
# to os.path.join as separate components so it inserts the platform's path
# separator (concatenating them first would make the join a no-op).
data = pd.read_csv(os.path.join(DATASET_PATH, 'data.csv'))
data.info()

In [None]:
# Preview the first five rows of the metadata table.
data.head(n=5)

In [None]:
# Recursively collect every .tif under the dataset root (scans and masks
# alike) and sort the paths lexicographically in place.
images = glob2.glob(DATASET_PATH + '/**/*.tif')
images.sort()
len(images)

In [None]:
# The patient ID is the name of the directory containing each file.
# os.path is used instead of splitting on '/' so this also works with the
# native path separator on Windows.
patient_id = [os.path.basename(os.path.dirname(x)) for x in images]
patient_id[:5]

In [None]:
# One row per .tif file: the owning patient and the file's path.
df = pd.DataFrame({'patient_id': patient_id, 'image_path': images})
df.head()

In [None]:
# Separate MRI slices from their segmentation masks. Mask files end in
# "_mask.tif"; matching that suffix (rather than any occurrence of "mask" in
# the full absolute path) keeps a directory name that contains "mask" from
# misclassifying every file.
is_mask_file = df['image_path'].str.endswith('_mask.tif')
df_imgs = df[~is_mask_file]   # MRI slices
df_masks = df[is_mask_file]   # segmentation masks

IMG_SUFFIX = '.tif'
MASK_SUFFIX = '_mask.tif'


def slice_sort_key(path, suffix):
    """Return ``(patient_dir, slice_number)`` for an image or mask path.

    File names have the form ``<patient>_<slice><suffix>``, e.g.
    ``TCGA_DU_6408_19860521_12.tif`` or ``TCGA_DU_6408_19860521_12_mask.tif``.
    The slice number is parsed from the file name itself rather than by
    slicing the absolute path at fixed offsets derived from one example
    directory, which would break for a patient directory of another length.

    Raises:
        ValueError: if the part after the last '_' is not an integer.
    """
    stem = os.path.basename(path)[:-len(suffix)]
    return os.path.dirname(path), int(stem.rsplit('_', 1)[-1])


# Sort both lists with the same key so that imgs[i] and masks[i] refer to the
# same patient and slice.
imgs = sorted(df_imgs["image_path"].values, key=lambda p: slice_sort_key(p, IMG_SUFFIX))
masks = sorted(df_masks["image_path"].values, key=lambda p: slice_sort_key(p, MASK_SUFFIX))

if not imgs:
    raise FileNotFoundError("No .tif images found under {}".format(DATASET_PATH))

# Verify the pairing for every row, not just one random sample: each mask must
# be its image's path with ".tif" replaced by "_mask.tif".
unpaired = [(img, mask) for img, mask in zip(imgs, masks)
            if mask != img[:-len(IMG_SUFFIX)] + MASK_SUFFIX]
if len(imgs) != len(masks) or unpaired:
    raise ValueError("Image/mask pairing failed ({} images, {} masks, {} mismatched pairs)"
                     .format(len(imgs), len(masks), len(unpaired)))

# Sorting check
idx = random.randint(0, len(imgs)-1)
print("Path to the Image:", imgs[idx], "\nPath to the Mask:", masks[idx])

In [None]:
# Final dataframe: one row per (image, mask) pair.
# patient_id is derived from each sorted image path so it stays aligned with
# image_path. df_imgs.patient_id is in glob order, which is not guaranteed to
# match the order of the sorted `imgs` list, so reusing it could attach rows
# to the wrong patient.
brain_df = pd.DataFrame({"patient_id": [os.path.basename(os.path.dirname(p)) for p in imgs],
                         "image_path": imgs,
                         "mask_path": masks
                        })

def has_mask(mask_path):
    """Return 1 if the mask image at ``mask_path`` has any non-zero pixel, else 0.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``mask_path``. ``cv2.imread``
            reports failure by returning ``None`` rather than raising, which
            would otherwise surface later as an unrelated-looking error.
    """
    mask = cv2.imread(mask_path)
    if mask is None:
        raise FileNotFoundError("Could not read mask image: {}".format(mask_path))
    return int(np.max(mask) > 0)
    
# Flag each row with 1 when its mask contains any tumour pixels, otherwise 0.
brain_df['mask'] = brain_df['mask_path'].apply(has_mask)
brain_df

# Data Visualization

For the data-visualization code, please refer to `brain_mri_FP32_training_resunet.ipynb`.

# Data Split 

In [None]:
# Keep only the slices whose mask is non-empty.
brain_df_mask = brain_df.loc[brain_df['mask'] == 1]
brain_df_mask.shape

In [None]:
# Create train / validation / test sets: 85% train, and the remaining 15% is
# halved into test and validation.
# A fixed seed makes the split (and therefore the reported test metrics)
# reproducible across runs.
# NOTE(review): rows are split per slice, not per patient, so slices from one
# patient can appear in both train and test; consider a grouped split on
# patient_id. The FP32 weights loaded below come from a separate notebook and
# may also have been trained on some of these test slices -- TODO confirm.
SPLIT_SEED = 42
X_train, X_val = train_test_split(brain_df_mask, test_size=0.15, random_state=SPLIT_SEED)
X_test, X_val = train_test_split(X_val, test_size=0.5, random_state=SPLIT_SEED)
print("Train size is {}, valid size is {} & test size is {}".format(len(X_train), len(X_val), len(X_test)))

train_ids = list(X_train.image_path)
train_mask = list(X_train.mask_path)

val_ids = list(X_val.image_path)
val_mask = list(X_val.mask_path)

# Data Generator

In [None]:
# Batch generators over the (image path, mask path) lists for fit().
train_data, val_data = DataGenerator(train_ids, train_mask), DataGenerator(val_ids, val_mask)

# Quantization-Aware Training (QAT) Modelling

In [None]:
import tensorflow_model_optimization as tfmot
# Short module-level aliases for the TFMOT Keras quantization helpers.
_tfmot_keras_quant = tfmot.quantization.keras
quantize_annotate_layer = _tfmot_keras_quant.quantize_annotate_layer
quantize_annotate_model = _tfmot_keras_quant.quantize_annotate_model
quantize_scope = _tfmot_keras_quant.quantize_scope

In [None]:
class DefaultBNQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
    """QuantizeConfig that quantizes only a layer's output.

    Used for BatchNormalization layers. No weight or activation quantizers are
    registered. A single 8-bit, per-tensor (``per_axis=False``), asymmetric
    ``MovingAverageQuantizer`` is attached to the layer output.
    """

    def get_weights_and_quantizers(self, layer):
        """Register no weight quantizers."""
        return list()

    def get_activations_and_quantizers(self, layer):
        """Register no activation quantizers."""
        return list()

    def set_quantize_weights(self, layer, quantize_weights):
        """No-op: there are no quantized weights to write back."""

    def set_quantize_activations(self, layer, quantize_activations):
        """No-op: there are no quantized activations to write back."""

    def get_output_quantizers(self, layer):
        """Return the one quantizer applied to the layer's output."""
        output_quantizer = tfmot.quantization.keras.quantizers.MovingAverageQuantizer(
            num_bits=8,
            per_axis=False,
            symmetric=False,
            narrow_range=False,
        )
        return [output_quantizer]

    def get_config(self):
        """This config has no constructor arguments to serialize."""
        return dict()

In [None]:
def apply_quantization_to_batch_normalization(layer):
    """``clone_function`` that annotates BatchNormalization layers for QAT.

    BatchNormalization layers are wrapped with ``quantize_annotate_layer``
    using ``DefaultBNQuantizeConfig``. Every other layer is returned unchanged
    (not annotated).
    """
    if not isinstance(layer, tf.keras.layers.BatchNormalization):
        return layer
    return quantize_annotate_layer(layer, DefaultBNQuantizeConfig())

# Loading FP32 model

In [None]:
# Build the FP32 segmentation model and restore its pretrained weights.
FP32_WEIGHTS_PATH = './ResUNet-segModel-weights.hdf5'
seg_model = get_model()
seg_model.load_weights(FP32_WEIGHTS_PATH)

# Making the Model Quantization-Aware

In [None]:
# Clone seg_model, passing each layer through
# apply_quantization_to_batch_normalization so that BatchNormalization layers
# come back annotated for quantization.
annotated_model = tf.keras.models.clone_model(
    seg_model, clone_function=apply_quantization_to_batch_normalization
)

In [None]:
# quantize_apply has to resolve the custom QuantizeConfig by name, so the class
# is registered in quantize_scope for the duration of the call.
custom_quantize_objects = {'DefaultBNQuantizeConfig': DefaultBNQuantizeConfig}
with quantize_scope(custom_quantize_objects):
    # Turn the annotated model into the quantization-aware model.
    quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)

In [None]:
# Print the layer-by-layer summary of the quantization-aware model.
quant_aware_model.summary()

In [None]:
# Save the model whenever the validation loss improves
# (ModelCheckpoint saves the full model unless save_weights_only=True).
qaware_checkpointer = ModelCheckpoint(filepath="QAware_ResUNet-segModel-weights.hdf5",
                                      monitor='val_loss',
                                      verbose=1,
                                      save_best_only=True
                                     )

# `learning_rate` is the documented argument name; `lr` is a deprecated alias
# that newer Keras versions reject.
# NOTE(review): 0.05 is high for fine-tuning pretrained weights; QAT fine-tuning
# usually uses a small learning rate -- consider tuning.
adam = tf.keras.optimizers.Adam(learning_rate=0.05, epsilon=0.1)

quant_aware_model.compile(optimizer=adam,
                          loss=focal_tversky,
                          metrics=[tversky, dice_coef]
                         )

# restore_best_weights=True rolls the model back to its best-val_loss weights
# when early stopping triggers, instead of keeping the last epoch's weights.
earlystopping = EarlyStopping(monitor='val_loss',
                              mode='min',
                              verbose=1,
                              patience=20,
                              restore_best_weights=True
                             )

# Cut the learning rate by 5x after 10 epochs without val_loss improvement.
reduce_lr = ReduceLROnPlateau(monitor='val_loss',
                              mode='min',
                              verbose=1,
                              patience=10,
                              min_delta=0.0001,
                              factor=0.2
                             )

In [None]:
# Fine-tune the quantization-aware model; quant_h keeps the per-epoch history.
training_callbacks = [qaware_checkpointer, earlystopping, reduce_lr]
quant_h = quant_aware_model.fit(
    train_data,
    validation_data=val_data,
    epochs=60,
    callbacks=training_callbacks,
)

# Evaluation

In [None]:
test_ids = list(X_test.image_path)
test_mask = list(X_test.mask_path)
test_data = DataGenerator(test_ids, test_mask)

# After fit(), the in-memory weights may come from the last epoch rather than
# the best one. Reload the checkpoint written by ModelCheckpoint
# (save_best_only=True) so that this evaluation and the TFLite export below
# both use the best-validation-loss weights.
quant_aware_model.load_weights("QAware_ResUNet-segModel-weights.hdf5")

# evaluate() returns [loss, tversky, dice_coef], in the order of the compiled
# loss and metrics.
_, tv, dice = quant_aware_model.evaluate(test_data)
print("Segmentation tversky is {:.2f}%".format(tv*100))
print("Segmentation Dice is {:.2f}%".format(dice*100))

# Converting the QAT Model to an INT8-Quantized TensorFlow Lite Model and Saving It

In [None]:
# Convert the quantization-aware Keras model to a TensorFlow Lite flatbuffer.
# Optimize.DEFAULT enables the converter's default quantization; per the TFMOT
# QAT guide, the QAT-annotated parts use the ranges learned during training.
converter = tf.lite.TFLiteConverter.from_keras_model(quant_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

# Serialized TFLite model as bytes.
quantized_tflite_model = converter.convert()

In [None]:
# Write the serialized TFLite model to disk as raw bytes.
tflite_path = 'QAT_INT8_Brain_MRI_Segmentation.tflite'
with open(tflite_path, 'wb') as tflite_file:
    tflite_file.write(quantized_tflite_model)