In [None]:
import math
import os

# The project imports below keep their original relative order on purpose:
# the star imports may export overlapping names, and `import tensorflow as tf`
# stays last so that `tf` is guaranteed to be the TensorFlow package.
from get_data import *
from metrics import MultipleClassAUROC
from model import *
import tensorflow as tf

In [None]:
# Absolute, normalised path of the dataset directory, resolved relative to the
# current working directory. No later cell visible in this notebook reads it.
DATASET_PATH = os.path.normpath(os.path.join(os.getcwd(), '../dataset/'))

In [None]:
# Load the dataset together with its label list, split it into
# train / dev / test parts, then build one batch generator per part.
# (`get_data`, `split_train_dev_test` and `get_data_generator` are not defined
# in this notebook -- presumably they come from `from get_data import *`.)
data, all_labels = get_data()
train_df, valid_df, test_df = split_train_dev_test(data)
train_gen, valid_gen, test_gen = (
    get_data_generator(split_df, all_labels)
    for split_df in (train_df, valid_df, test_df)
)

In [None]:
# Build the base model that is annotated and made quantization-aware in the
# cells below. `get_model` is not defined in this notebook -- presumably it is
# provided by `from model import *` (TODO confirm).
model = get_model()

In [None]:
import tensorflow_model_optimization as tfmot

In [None]:
# Short aliases for the tfmot Keras quantization entry points used below.
# `quantize_annotate_layer` is called when annotating BatchNormalization layers
# and `quantize_scope` wraps the `quantize_apply` call.
quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
# NOTE(review): `quantize_annotate_model` is not referenced by any cell visible
# in this notebook.
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope

In [None]:
class DefaultBNQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
    """Quantize config that only attaches a quantizer to a layer's output.

    It reports no weights and no activations to quantize, and supplies a single
    8-bit moving-average quantizer for the layer output. In this notebook it is
    applied to BatchNormalization layers by
    `apply_quantization_to_batch_normalization`.
    """

    def get_weights_and_quantizers(self, layer):
        """Return an empty list: no weights are registered for quantization."""
        return list()

    def get_activations_and_quantizers(self, layer):
        """Return an empty list: no activations are registered for quantization."""
        return list()

    def set_quantize_weights(self, layer, quantize_weights):
        """Do nothing; there are no quantized weights to write back."""
        return None

    def set_quantize_activations(self, layer, quantize_activations):
        """Do nothing; there are no quantized activations to write back."""
        return None

    def get_output_quantizers(self, layer):
        """Return a one-element list holding the output quantizer.

        A fresh `MovingAverageQuantizer` (8 bits, per-tensor, asymmetric, full
        range) is created on every call.
        """
        output_quantizer = tfmot.quantization.keras.quantizers.MovingAverageQuantizer(
            num_bits=8,
            per_axis=False,
            symmetric=False,
            narrow_range=False,
        )
        return [output_quantizer]

    def get_config(self):
        """Return an empty config dict; the class has no constructor arguments."""
        return dict()

In [None]:
def apply_quantization_to_batch_normalization(layer):
    """Clone function that marks BatchNormalization layers for quantization.

    Args:
        layer: the layer object handed over by the model-cloning routine.

    Returns:
        `quantize_annotate_layer(layer, DefaultBNQuantizeConfig())` when
        `layer` is a `tf.keras.layers.BatchNormalization` instance; otherwise
        the very same `layer` object, untouched.
    """
    is_batch_norm = isinstance(layer, tf.keras.layers.BatchNormalization)
    if not is_batch_norm:
        return layer
    return quantize_annotate_layer(layer, DefaultBNQuantizeConfig())

In [None]:
# Clone the base model, routing every layer through the clone function so that
# BatchNormalization layers come back annotated and all other layers are
# returned as-is.
annotated_model = tf.keras.models.clone_model(
    model, clone_function=apply_quantization_to_batch_normalization
)

In [None]:
# The custom config class is registered under its own name for the duration of
# the `with` block, which is where `quantize_apply` runs.
with quantize_scope({'DefaultBNQuantizeConfig': DefaultBNQuantizeConfig}):
    # Use `quantize_apply` to actually make the model quantization aware.
    quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)

In [None]:
# Print the layer-by-layer summary of the quantization-aware model so the
# result of `quantize_apply` can be inspected.
quant_aware_model.summary()

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint
# Weights file written by the checkpoint callback; the same name is also passed
# to the AUROC callback as `weights_path`.
output_weights_name = 'FP_32_QAT_weights.h5'

# The quantization-aware model is the one that gets trained and evaluated.
model_train = quant_aware_model

# Save weights only (not the full model), and only when the monitored quantity
# improves. No `monitor` is given, so the callback's default applies.
checkpoint = ModelCheckpoint(
    filepath=output_weights_name,
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)

In [None]:
# Dict handed to the AUROC callback as `stats`.
# NOTE(review): what the callback writes into it is defined in `metrics`, which
# is not visible here.
training_stats = dict()

# Per-class AUROC callback evaluated on the validation generator.
auroc = MultipleClassAUROC(
    stats=training_stats,
    weights_path=output_weights_name,
    class_names=all_labels,
    generator=valid_gen,
)

In [None]:
from tensorflow.keras.optimizers import Adam
# Starting learning rate for Adam; the ReduceLROnPlateau callback configured in
# a later cell may lower it during training.
initial_learning_rate = 1e-3

# `learning_rate` is the supported keyword of the tf.keras optimizers; `lr` is
# only a deprecated alias that emits a warning and is rejected by newer Keras
# releases.
optimizer = Adam(learning_rate=initial_learning_rate)

# Binary cross-entropy loss, as in the original cell.
# NOTE(review): presumably a multi-label setup with one sigmoid output per
# entry of `all_labels` -- confirm against `get_model`.
model_train.compile(optimizer=optimizer, loss="binary_crossentropy")

In [None]:
from tensorflow.keras.callbacks import TensorBoard, ReduceLROnPlateau
# TensorBoard event files are written to ./logs under the current directory.
logs_base_dir = os.getcwd()

# Learning-rate schedule: after `patience_reduce_lr` epochs without a lower
# `val_loss`, multiply the learning rate by 0.1, never going below `min_lr`.
patience_reduce_lr = 2
min_lr = 1e-8

callbacks = [
    checkpoint,
    # `batch_size` is intentionally not passed: it is not an argument of the
    # TF2 TensorBoard callback (TF 2.x only warns and ignores it, newer Keras
    # raises on the unknown keyword).
    TensorBoard(log_dir=os.path.join(logs_base_dir, "logs")),
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.1,
        patience=patience_reduce_lr,
        verbose=1,
        mode="min",
        min_lr=min_lr,
    ),
    auroc,
]

In [None]:
epochs = 20

# `Model.fit` accepts generators / Sequences directly; `fit_generator` is
# deprecated. Step counts must be whole numbers: `n / batch_size` is a float
# whenever the sample count is not a multiple of the batch size, so round up to
# include the final partial batch.
# NOTE(review): assumes the generators expose `.n` (sample count) and
# `.batch_size`, as the original cell already did.
fit_history = model_train.fit(
    x=train_gen,
    steps_per_epoch=math.ceil(train_gen.n / train_gen.batch_size),
    epochs=epochs,
    validation_data=valid_gen,
    validation_steps=math.ceil(valid_gen.n / valid_gen.batch_size),
    callbacks=callbacks,
    shuffle=False,
)

In [None]:
import matplotlib.pyplot as plt

# Training vs. validation loss per epoch, drawn in the top-right cell of a
# 2x2 grid on figure number 1.
plt.figure(num=1, figsize=(15, 8))
loss_ax = plt.subplot(2, 2, 2)

loss_ax.plot(fit_history.history['loss'])
loss_ax.plot(fit_history.history['val_loss'])

loss_ax.set_title('model loss')
loss_ax.set_ylabel('loss')
loss_ax.set_xlabel('epoch')
# Legend entries follow the plotting order above.
loss_ax.legend(['train', 'valid'])

plt.show()

In [None]:
# Rewind the test generator so that predictions start at the first batch even
# when this cell is re-run.
test_gen.reset()

# `Model.predict` accepts generators directly; `predict_generator` is
# deprecated. `steps` must be a whole number, so round up to include the final
# partial batch.
# NOTE(review): assumes the generator exposes `.n` (sample count) and
# `.batch_size`, as the original cell already did.
pred_y = model_train.predict(
    test_gen,
    steps=math.ceil(test_gen.n / test_gen.batch_size),
    verbose=1,
)

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
# Leave the test generator rewound to its first batch. No batch is drawn here:
# the curves below only need `test_gen.labels` and `pred_y`.
test_gen.reset()

# Fail early with a readable message if predictions and labels do not line up
# row for row (`roc_curve` would otherwise raise a less specific error).
# NOTE(review): row i of `pred_y` matches row i of `test_gen.labels` only if
# the test generator does not shuffle -- confirm in `get_data_generator`.
assert len(pred_y) == len(test_gen.labels), (
    "prediction rows (%d) do not match test label rows (%d)"
    % (len(pred_y), len(test_gen.labels))
)

# One ROC curve per label on a single set of axes.
fig, c_ax = plt.subplots(1, 1, figsize=(9, 9))
for idx, c_label in enumerate(all_labels):
    # Column `idx` of the label matrix vs. the predicted score for that label.
    fpr, tpr, thresholds = roc_curve(test_gen.labels[:, idx].astype(int), pred_y[:, idx])
    c_ax.plot(fpr, tpr, label='%s (AUC:%0.2f)' % (c_label, auc(fpr, tpr)))

# Legend shows each label together with its AUC.
c_ax.legend()

# Axis labels
c_ax.set_xlabel('False Positive Rate')
c_ax.set_ylabel('True Positive Rate')

# Save as a png
fig.savefig('QAT_FP32.png')