## Train Knowledge Distillation model
This notebook's idea is almost the same as step 2 of the 3-stage training in [RANZCR / ResNet200D / 3-stage training / step2](https://www.kaggle.com/yasufuminakama/ranzcr-resnet200d-3-stage-training-step2).

The Keras implementation is borrowed from the official Keras "Knowledge Distillation" example code:  
[Knowledge Distillation](https://keras.io/examples/vision/knowledge_distillation/#construct-distiller-class)

   

Teacher model training notebook:  
[[Keras TPU] RANZCR Train annotation](https://www.kaggle.com/enukuro/keras-tpu-ranzcr-train-annotation)


Notebook for creating the annotation TFRecords:  
[Annotation RANZCR CLiP 900](https://www.kaggle.com/enukuro/annotation-ranzcr-clip-900)

In [None]:
import os

import numpy as np
import pandas as pd
import random
import tensorflow as tf
import tensorflow.keras.layers as L
import tensorflow_addons as tfa
from keras import backend as K
from kaggle_datasets import KaggleDatasets
from keras.applications.xception import Xception as BaseModel
from keras.applications.xception import preprocess_input
import itertools
import gc

In [None]:
import tensorflow as tf
print(f"Tensorflow version {tf.__version__}")

# Probe for a TPU; a ValueError from the resolver is treated as "no TPU".
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
except ValueError:
    tpu = None

if not tpu:
    # No TPU found: fall back to TensorFlow's current default strategy.
    strategy = tf.distribute.get_strategy()
else:
    # Connect to the TPU cluster and initialise it before building the
    # distribution strategy on top of it.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)

# Lets tf.data choose parallelism / prefetch buffer sizes automatically.
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [None]:
# Single seed shared by every random number generator used in this notebook.
SEED = 42


def seed_everything(seed):
    """Seed Python, NumPy and TensorFlow RNGs and pin PYTHONHASHSEED.

    Args:
        seed: Integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for set_seed in (random.seed, np.random.seed, tf.random.set_seed):
        set_seed(seed)


seed_everything(SEED)

# Global batch size: 16 examples for each replica of the distribution strategy.
BATCH_SIZE = strategy.num_replicas_in_sync * 16
# Side length (pixels) that decoded images are resized to.
IMG_SIZE = 900
# Number of output labels predicted by the models.
NUM_CLASSES = 11

# Index of the cross-validation fold that this run trains.
target_fold = 1

In [None]:
# Schema for parsing one serialized example from the TFRecord files:
# byte-string fields first, followed by the int64 label columns.
_bytes_feature_keys = (
    'image',
    'image_annotation',
    'StudyInstanceUID',
)
_int64_feature_keys = (
    'ETT - Abnormal',
    'ETT - Borderline',
    'ETT - Normal',
    'NGT - Abnormal',
    'NGT - Borderline',
    'NGT - Incompletely Imaged',
    'NGT - Normal',
    'CVC - Abnormal',
    'CVC - Borderline',
    'CVC - Normal',
    'Swan Ganz Catheter Present',
)
feature_map = {
    **{key: tf.io.FixedLenFeature([], tf.string) for key in _bytes_feature_keys},
    **{key: tf.io.FixedLenFeature([], tf.int64) for key in _int64_feature_keys},
}

def decode_image(image_data):
    """Decode JPEG bytes into a float32 image of size IMG_SIZE x IMG_SIZE x 3.

    Args:
        image_data: Scalar string tensor holding a JPEG-encoded image.

    Returns:
        A float32 tensor reshaped to [IMG_SIZE, IMG_SIZE, 3].
    """
    side_lengths = [IMG_SIZE, IMG_SIZE]
    decoded = tf.image.decode_jpeg(image_data, channels=3)
    resized = tf.image.resize(decoded, side_lengths)
    # The explicit reshape pins the full shape of the tensor.
    shaped = tf.reshape(resized, side_lengths + [3])
    return tf.cast(shaped, tf.float32)

def read_tfrecord(example):
    """Parse one serialized example into model inputs and a label vector.

    Args:
        example: Serialized example read from a TFRecord file, parsed with
            the module-level ``feature_map`` schema.

    Returns:
        A pair ``(images, target)`` where ``images`` is the list
        ``[image, image_annotation]`` (both decoded and passed through
        ``preprocess_input``) and ``target`` is the float32 label vector.
    """
    parsed = tf.io.parse_single_example(example, feature_map)

    # Label columns, in the order the target vector is assembled.
    label_keys = (
        'ETT - Abnormal',
        'ETT - Borderline',
        'ETT - Normal',
        'NGT - Abnormal',
        'NGT - Borderline',
        'NGT - Incompletely Imaged',
        'NGT - Normal',
        'CVC - Abnormal',
        'CVC - Borderline',
        'CVC - Normal',
        'Swan Ganz Catheter Present',
    )
    labels = [parsed[key] for key in label_keys]

    # Plain image first, annotated image second.
    inputs = [
        preprocess_input(decode_image(parsed[key]))
        for key in ('image', 'image_annotation')
    ]
    return inputs, tf.cast(labels, tf.float32)

def data_augment(img, target):
    """Randomly mirror an example left-to-right.

    ``img`` carries the plain image and its annotated counterpart together
    (see ``read_tfrecord``), so both must receive the *same* flip; otherwise
    the student and the teacher would be shown differently mirrored views of
    the same study. A single random draw therefore decides the flip for the
    whole tensor instead of flipping each element independently.

    Args:
        img: Image tensor whose last three dimensions are
            [height, width, channels]; as produced by ``read_tfrecord`` it
            holds the (image, annotated image) pair along the first axis.
        target: Label tensor, returned unchanged.

    Returns:
        The (possibly flipped) ``img`` and the untouched ``target``.
    """
    flip = tf.random.uniform([]) < 0.5
    img = tf.cond(flip, lambda: tf.image.flip_left_right(img), lambda: img)
    return img, target

def get_dataset(filenames, shuffled=False, repeated=False,
                cached=False, augmented=False):
    """Build a batched ``tf.data`` pipeline from TFRecord files.

    Args:
        filenames: TFRecord file paths to read.
        shuffled: Shuffle examples with a 1024-element buffer seeded by SEED.
        repeated: Repeat the dataset indefinitely.
        cached: Cache the parsed examples.
        augmented: Apply ``data_augment`` to every example.

    Returns:
        A dataset yielding batches of BATCH_SIZE examples (incomplete final
        batches are dropped) with prefetching enabled.
    """
    ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
    ds = ds.map(read_tfrecord, num_parallel_calls=AUTOTUNE)

    # Optional stages, applied in a fixed order: cache -> shuffle ->
    # augment -> repeat.
    if cached:
        ds = ds.cache()
    if shuffled:
        ds = ds.shuffle(1024, seed=SEED)
    if augmented:
        ds = ds.map(data_augment, num_parallel_calls=AUTOTUNE)
    if repeated:
        ds = ds.repeat()

    return ds.batch(BATCH_SIZE, drop_remainder=True).prefetch(AUTOTUNE)



In [None]:
# https://keras.io/examples/vision/knowledge_distillation/#construct-distiller-class
class Distiller(tf.keras.Model):
    """Knowledge-distillation wrapper around a student and a teacher model.

    Both wrapped models are called as ``predictions, features = model(x)``.
    During training the student receives the plain image and the teacher the
    annotated image; the student's supervised loss is combined with a loss
    between the two feature outputs. Only the student's variables are
    updated. Evaluation uses the student alone.

    The input ``x`` of every step is a single tensor that holds the
    (image, annotated image) pair along axis 1.
    """

    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
    ):
        """ Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between
                teacher features and student features
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha

    @staticmethod
    def _split_inputs(x):
        """Split the paired input into (image, image_annotation).

        Only axis 1 (the pair axis) is squeezed, so a batch or any other
        dimension of size 1 is never dropped by accident.
        """
        image, image_annotation = tf.split(x, 2, axis=1)
        return tf.squeeze(image, axis=1), tf.squeeze(image_annotation, axis=1)

    @staticmethod
    def _flatten_features(features):
        """Flatten everything but the batch dimension.

        The batch size is read from the tensor itself: under a distribution
        strategy each replica only sees its own slice of the global batch,
        so the global BATCH_SIZE must not be used as the leading dimension.
        """
        return tf.reshape(features, [tf.shape(features)[0], -1])

    def _scaled_student_loss(self, y, predictions):
        """Student loss summed over the examples and divided by BATCH_SIZE."""
        per_example_loss = self.student_loss_fn(y, predictions)
        return tf.reduce_sum(per_example_loss * (1. / BATCH_SIZE))

    def _evaluate_student(self, data):
        """Shared evaluation logic: student-only loss and metric update."""
        x, y = data
        image, _ = self._split_inputs(x)

        # Compute predictions
        y_prediction, _ = self.student(image, training=False)

        # Calculate the loss
        student_loss = self._scaled_student_loss(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

    @tf.function
    def train_step(self, data):
        """Run one optimisation step on the student.

        Args:
            data: ``(x, y)`` where ``x`` holds the (image, annotated image)
                pair along axis 1 and ``y`` holds the labels.

        Returns:
            Dict with the compiled metrics plus ``student_loss`` and
            ``distillation_loss``.
        """
        x, y = data
        image, image_annotation = self._split_inputs(x)

        # The teacher is run outside the tape, in inference mode; only its
        # feature output is used.
        _, teacher_features = self.teacher(image_annotation, training=False)
        with tf.GradientTape() as tape:
            student_predictions, student_features = self.student(image, training=True)

            student_loss = self._scaled_student_loss(y, student_predictions)
            # NOTE: the distillation loss is deliberately left as returned by
            # distillation_loss_fn (it is not divided by BATCH_SIZE); if it is
            # not a scalar, tape.gradient differentiates the sum of its
            # elements.
            distillation_loss = self.distillation_loss_fn(
                self._flatten_features(teacher_features),
                self._flatten_features(student_features),
            )
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Compute gradients
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    @tf.function
    def valid_step(self, data):
        """Evaluate the student on one batch (same logic as ``test_step``)."""
        return self._evaluate_student(data)

    @tf.function
    def test_step(self, data):
        """Evaluate the student on one batch.

        Args:
            data: ``(x, y)`` in the same layout as for ``train_step``; only
                the plain image of the pair is used.

        Returns:
            Dict with the compiled metrics plus ``student_loss``.
        """
        return self._evaluate_student(data)

In [None]:
def get_model(name):
    """Build an ImageNet-initialised classifier that also returns its features.

    Args:
        name: Name given to the resulting Keras model.

    Returns:
        A Keras model with two outputs: the NUM_CLASSES sigmoid predictions
        and the average-pooled output of the backbone.
    """
    backbone = BaseModel(
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
        include_top=False,
        weights='imagenet',
        pooling="avg",
    )
    pooled_features = backbone.output

    # Classification head: dropout followed by one sigmoid unit per label.
    dropped = L.Dropout(0.5)(pooled_features)
    predictions = L.Dense(NUM_CLASSES, activation="sigmoid")(dropped)

    return tf.keras.models.Model(
        inputs=backbone.input,
        outputs=[predictions, pooled_features],
        name=name,
    )

def get_distiller_model(fold=0):
    """Create and compile a ``Distiller`` for one cross-validation fold.

    The student and the teacher share the architecture from ``get_model``;
    the teacher's weights are loaded from the checkpoint of the given fold.
    Everything is created inside ``strategy.scope()``.

    Args:
        fold: Fold index used to pick the teacher checkpoint file.

    Returns:
        The compiled ``Distiller``.
    """
    with strategy.scope():
        student = get_model('student')
        teacher = get_model('teacher')
        teacher.load_weights(f'../input/ranzcr-annotation-teacher/teacher_model_{fold}.h5')

        distiller = Distiller(student=student, teacher=teacher)
        distiller.compile(
            # `learning_rate` is the supported keyword; the old `lr` alias is
            # deprecated and rejected by newer Keras releases.
            optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
            metrics=[tf.keras.metrics.AUC(multi_label=True)],
            # Reduction.NONE: the Distiller reduces the student loss itself.
            student_loss_fn=tfa.losses.SigmoidFocalCrossEntropy(alpha = 0.5, gamma = 2, reduction=tf.keras.losses.Reduction.NONE),
            distillation_loss_fn=tf.keras.losses.MSE,
            alpha=0.3
        )
    return distiller

In [None]:
# GCS location of the Kaggle dataset that holds the annotation TFRecords.
TF_REC_DS_PATH = KaggleDatasets().get_gcs_path('ranzcr-annotation-900-tfrecords')

# tfrec_files[fold] is the shuffled list of the five shard paths of that fold.
tfrec_files = []
for fold in range(5):
    training_files = [f'{TF_REC_DS_PATH}/{fold}_{num}.tfrec' for num in range(5)]
    random.shuffle(training_files)
    tfrec_files.append(training_files)

In [None]:
# Per-fold counts used below to derive the number of steps
# (presumably the number of examples stored in each fold's TFRecords -- TODO confirm).
fold_item_counts = [1804, 1783, 1809, 1851, 1848]

for fold in range(5):
    
    # Only the fold selected by `target_fold` is trained in this run.
    if fold != target_fold:
        continue
        
    print(f'fold_{fold} start')
    # Train on the shards of every other fold, validate on this fold's shards.
    train_filenames = list(itertools.chain.from_iterable([tfrec_files[i] for i in range(5) if i != fold]))
    val_filenames = tfrec_files[fold]

    random.shuffle(train_filenames)

    train_dataset = get_dataset(train_filenames, shuffled=True, augmented=True, repeated=True)
    val_dataset = get_dataset(val_filenames, shuffled=False, cached=True)

    # The training dataset repeats forever, so the epoch length is set explicitly.
    steps_per_epoch = (sum(fold_item_counts) - fold_item_counts[fold]) // BATCH_SIZE
    # NOTE(review): computed but not passed to `fit` below.
    validation_steps = fold_item_counts[fold] // BATCH_SIZE
    
    model = get_distiller_model(fold)
    
    # Keep only the weights with the best validation student loss, and lower
    # the learning rate when that loss stops improving.
    sv = tf.keras.callbacks.ModelCheckpoint(f'distiller_model_{fold}.h5', monitor='val_student_loss', verbose=1, save_best_only=True, save_weights_only=True)
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_student_loss', verbose=1, factor=0.1, patience=3, min_delta=0.0001, min_lr=1e-6)

    model.fit(
        train_dataset,
        steps_per_epoch=steps_per_epoch,
        epochs=15,
        callbacks=[reduce_lr, sv],
        validation_data=val_dataset,
    )
    # Reload the best checkpoint written by `sv`, then export the student's
    # weights on their own.
    # NOTE(review): `built` is set by hand, presumably so that `load_weights`
    # accepts the subclassed model -- confirm.
    model.built = True
    model.load_weights(f'./distiller_model_{fold}.h5')
    model.get_layer('student').save_weights(f'student_model_{fold}.h5')
    
#     tf.keras.backend.clear_session()
#     del model
#     gc.collect()
    