In [None]:
import os

import numpy as np
import pandas as pd
import random
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import KFold
import tensorflow as tf
import tensorflow.keras.layers as L
import tensorflow_addons as tfa
from keras.applications.xception import Xception as BaseModel
from keras.applications.xception import preprocess_input
from kaggle_datasets import KaggleDatasets
import itertools

In [None]:
import tensorflow as tf
print("Tensorflow version " + tf.__version__)

# Pick a distribution strategy: use the TPU when one can be resolved, otherwise
# fall back to the default (single-device) strategy. The resolver call is
# expected to raise ValueError when no TPU is attached.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
except ValueError:
    tpu = None

if not tpu:
    strategy = tf.distribute.get_strategy()
else:
    # The TPU system must be connected and initialised before the strategy
    # object is created.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)

# Let tf.data choose parallelism / prefetch depth dynamically.
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [None]:
SEED = 42


def seed_everything(seed):
    """Seed every global RNG used by this notebook.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and
    TensorFlow's global seed, and exports ``PYTHONHASHSEED``.

    Args:
        seed: Integer seed applied to all generators.
    """
    # NOTE: PYTHONHASHSEED is only read at interpreter start-up, so it affects
    # processes launched after this point, not hashing in the running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)


seed_everything(SEED)

# Global batch size: 16 examples per replica of the active strategy.
BATCH_SIZE = 16 * strategy.num_replicas_in_sync

DEBUG = False  # NOTE(review): not read anywhere in the visible notebook.
IMG_SIZE = 900  # images are resized to IMG_SIZE x IMG_SIZE in decode_image
NUM_CLASSES = 11  # one sigmoid output per label column in read_tfrecord


In [None]:
# Parsing schema for tf.io.parse_single_example: two image byte strings, the
# study identifier, and eleven int64 label flags (in the order read_tfrecord
# stacks them into the target vector).
feature_map = {
    'image': tf.io.FixedLenFeature([], tf.string),
    'image_annotation': tf.io.FixedLenFeature([], tf.string),
    'StudyInstanceUID': tf.io.FixedLenFeature([], tf.string),
    **{
        label: tf.io.FixedLenFeature([], tf.int64)
        for label in (
            'ETT - Abnormal',
            'ETT - Borderline',
            'ETT - Normal',
            'NGT - Abnormal',
            'NGT - Borderline',
            'NGT - Incompletely Imaged',
            'NGT - Normal',
            'CVC - Abnormal',
            'CVC - Borderline',
            'CVC - Normal',
            'Swan Ganz Catheter Present',
        )
    },
}

def decode_image(image_data):
    """Decode JPEG bytes into an (IMG_SIZE, IMG_SIZE, 3) float32 tensor.

    Args:
        image_data: Scalar string tensor holding the encoded JPEG.

    Returns:
        The 3-channel image resized to IMG_SIZE x IMG_SIZE, as float32.
    """
    decoded = tf.image.decode_jpeg(image_data, channels=3)
    resized = tf.image.resize(decoded, [IMG_SIZE, IMG_SIZE])
    # Make the static shape explicit so downstream batching sees fixed dims.
    shaped = tf.reshape(resized, [IMG_SIZE, IMG_SIZE, 3])
    return tf.cast(shaped, tf.float32)

def read_tfrecord(example):
    """Parse one serialized record into an (image, labels) pair.

    Args:
        example: Scalar string tensor containing a serialized example that
            matches ``feature_map``.

    Returns:
        Tuple of the decoded image (see ``decode_image``) and an 11-element
        float32 vector of label flags.
    """
    parsed = tf.io.parse_single_example(example, feature_map)
    # This notebook trains on the annotated rendering, not the raw 'image' field.
    image = decode_image(parsed['image_annotation'])
    label_keys = (
        'ETT - Abnormal',
        'ETT - Borderline',
        'ETT - Normal',
        'NGT - Abnormal',
        'NGT - Borderline',
        'NGT - Incompletely Imaged',
        'NGT - Normal',
        'CVC - Abnormal',
        'CVC - Borderline',
        'CVC - Normal',
        'Swan Ganz Catheter Present',
    )
    labels = [parsed[key] for key in label_keys]
    return image, tf.cast(labels, tf.float32)


def data_augment(img, target):
    """Randomly mirror the image horizontally; the target passes through as-is."""
    flipped = tf.image.random_flip_left_right(img)
    return flipped, target

def preprocess_input_image(img, target):
    """Apply the backbone's ``preprocess_input`` to the image; target unchanged."""
    normalized = preprocess_input(img)
    return normalized, target

def get_dataset(filenames, shuffled=False, repeated=False,
                cached=False, augmented=False):
    """Build a batched, prefetched tf.data pipeline over TFRecord shards.

    Stage order: read -> parse/decode -> [cache] -> [shuffle] -> preprocess
    -> [augment] -> [repeat] -> batch -> prefetch.

    Args:
        filenames: TFRecord paths to read (interleaved in parallel).
        shuffled: Shuffle examples with a 1024-element buffer seeded by SEED.
        repeated: Repeat the dataset indefinitely.
        cached: Cache parsed examples in memory after the first pass.
        augmented: Apply ``data_augment`` (random horizontal flip).

    Returns:
        A ``tf.data.Dataset`` of (image, label) batches of exactly BATCH_SIZE.
    """
    ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
    ds = ds.map(read_tfrecord, num_parallel_calls=AUTOTUNE)
    # NOTE(review): this caches decoded float32 IMG_SIZE x IMG_SIZE x 3 images
    # (~9.7 MB each at IMG_SIZE=900) in memory — confirm host RAM suffices.
    ds = ds.cache() if cached else ds
    ds = ds.shuffle(1024, seed=SEED) if shuffled else ds
    ds = ds.map(preprocess_input_image, num_parallel_calls=AUTOTUNE)
    ds = ds.map(data_augment, num_parallel_calls=AUTOTUNE) if augmented else ds
    ds = ds.repeat() if repeated else ds
    # drop_remainder keeps every batch the same static size (presumably needed
    # for TPU execution); a trailing partial batch is discarded.
    return ds.batch(BATCH_SIZE, drop_remainder=True).prefetch(AUTOTUNE)



In [None]:
def get_model(learning_rate=1e-3):
    """Build and compile the multi-label classifier on top of ``BaseModel``.

    Everything is created inside ``strategy.scope()`` so model and optimizer
    variables are placed by the active distribution strategy (TPU or default).

    Args:
        learning_rate: Initial Adam learning rate. Defaults to 1e-3, the value
            this function previously hard-coded.

    Returns:
        A compiled ``tf.keras.Model`` mapping (IMG_SIZE, IMG_SIZE, 3) images
        to NUM_CLASSES independent sigmoid probabilities.
    """
    with strategy.scope():
        # ImageNet-pretrained backbone without its classification head.
        base_model = BaseModel(input_shape=(IMG_SIZE, IMG_SIZE, 3), include_top=False, weights='imagenet')
        x = L.GlobalAveragePooling2D()(base_model.output)
        x = L.Dropout(0.5)(x)
        # One independent sigmoid per label: the targets are multi-label.
        outputs = L.Dense(NUM_CLASSES, activation="sigmoid")(x)

        model = tf.keras.models.Model(inputs=base_model.input, outputs=outputs)
        model.compile(
            # `lr` is a deprecated alias that newer Keras versions reject;
            # `learning_rate` is the supported keyword.
            optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
            # NOTE(review): with Reduction.NONE the loss object returns unreduced
            # per-example values, so gradients are presumably taken of their sum
            # rather than their mean — confirm this scaling is intended.
            loss=tfa.losses.SigmoidFocalCrossEntropy(alpha=0.5, gamma=2, reduction=tf.keras.losses.Reduction.NONE),
            metrics=[tf.keras.metrics.AUC(multi_label=True)])
    return model

In [None]:
GCS_DS_PATH = KaggleDatasets().get_gcs_path("ranzcr-annotation-900-tfrecords")
TF_REC_DS_PATH = GCS_DS_PATH

# tfrec_files[k] lists the five shard paths "<k>_<shard>.tfrec" of CV fold k,
# in a random (seeded) order. Each fold serves as validation once and as
# training data for the other folds.
tfrec_files = []
for fold in range(5):
    fold_files = [TF_REC_DS_PATH + f'/{fold}_{shard}.tfrec' for shard in range(5)]
    random.shuffle(fold_files)
    tfrec_files.append(fold_files)

len(tfrec_files)

In [None]:
# Number of records in each fold of the annotation TFRecords. These counts size
# steps_per_epoch, because the training dataset repeats indefinitely.
# The previously used counts [6018, 5939, 6033, 6176, 5917] presumably belong to
# the non-annotation variant of the dataset — TODO confirm before reusing them.
fold_item_counts = [1804, 1783, 1809, 1851, 1848]
NUM_FOLDS = len(fold_item_counts)
assert len(tfrec_files) == NUM_FOLDS, "expected one list of shard paths per fold"

# Keras History dicts for every fold (the bare `history` name keeps only the last).
histories = []
for fold in range(NUM_FOLDS):
    # Reset Keras global state left over from the previous fold's model.
    tf.keras.backend.clear_session()

    # Out-of-fold shards train the model; this fold's shards validate it.
    train_filenames = list(itertools.chain.from_iterable(
        tfrec_files[i] for i in range(NUM_FOLDS) if i != fold))
    val_filenames = tfrec_files[fold]
    random.shuffle(train_filenames)

    train_dataset = get_dataset(train_filenames, shuffled=True, augmented=True, repeated=True)
    val_dataset = get_dataset(val_filenames, shuffled=False, cached=True)

    # The repeated training dataset needs an explicit epoch length. The validation
    # dataset is finite and is consumed in full each epoch. validation_steps is
    # deliberately not passed: stopping the iterator before end-of-data could
    # leave val_dataset's .cache() incomplete, so it would be rebuilt every epoch.
    steps_per_epoch = (sum(fold_item_counts) - fold_item_counts[fold]) // BATCH_SIZE

    model = get_model()

    sv = tf.keras.callbacks.ModelCheckpoint(
        f'teacher_model_{fold}.h5', monitor='val_loss', verbose=1,
        save_best_only=True, save_weights_only=False)
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss', verbose=1, factor=0.1, patience=2,
        min_delta=0.0001, min_lr=1e-6)

    history = model.fit(
        train_dataset,
        steps_per_epoch=steps_per_epoch,
        epochs=20,
        callbacks=[reduce_lr, sv],
        validation_data=val_dataset,
    )
    histories.append(history.history)


