In [None]:
### Installing Keras Package for Vision Transformer
!pip install vit-keras


In [None]:
### Importing Necessary Packages
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow as tf, tensorflow.keras.backend as K
import re,random,os,math
import pandas as pd
from sklearn.model_selection import KFold
from vit_keras import vit, utils

from kaggle_datasets import KaggleDatasets
# Resolve the GCS location of the competition dataset, then list the
# training and test TFRecord files stored underneath it.
GCS_DS_PATH = KaggleDatasets().get_gcs_path('ranzcr-clip-catheter-line-classification')
TRAINING_FILENAMES, TEST_FILENAMES = (
    tf.io.gfile.glob(GCS_DS_PATH + shard_pattern)
    for shard_pattern in ('/train_tfrecords/*.tfrec', '/test_tfrecords/*.tfrec')
)

# Echo both file lists under a short separator line.
print("#" * 5, TRAINING_FILENAMES, TEST_FILENAMES, sep="\n")

In [None]:

# Choose a tf.distribute strategy that matches the available hardware.
try:
    # The resolver is created without arguments (the original author notes that
    # Kaggle sets TPU_NAME); a ValueError here is treated as "no TPU available".
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if not tpu:
    # No TPU: fall back to TensorFlow's default strategy.
    strategy = tf.distribute.get_strategy()
else:
    # TPU found: connect to it and initialise it before building the strategy.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)

print("REPLICAS: ", strategy.num_replicas_in_sync)




In [None]:
### Tunable experiment settings

SEED = 42  # single seed shared by every random number generator seeded below

# Input geometry: images are resized to image_size x image_size with 3 channels.
image_size = 128
input_shape = (image_size, image_size, 3)

# Training schedule.
folds = 5
num_epochs = 10
batch_size = strategy.num_replicas_in_sync * 4  # 4 samples times the replica count
AUTO = tf.data.experimental.AUTOTUNE

# Seed Python, the hash seed environment variable, NumPy and TensorFlow.
print(f'setting everything to seed {SEED}')
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
#### Learning-rate schedule parameters (consumed by lrfn)

LR_START = 1e-5                                  # rate at epoch 0
LR_MAX = 5e-5 * strategy.num_replicas_in_sync    # peak rate, scaled by the replica count
LR_MIN = 1e-5                                    # value the decay phase converges towards
LR_RAMPUP_EPOCHS = 5                             # length of the linear warm-up
LR_SUSTAIN_EPOCHS = 0                            # epochs held at LR_MAX after warm-up
LR_EXP_DECAY = 0.8                               # per-epoch multiplicative decay factor

def lrfn(epoch):
    """Return the learning rate to use for the given epoch index.

    The schedule has three phases driven by the module-level LR_* constants:
    a linear ramp from LR_START towards LR_MAX, an optional plateau at
    LR_MAX, and an exponential decay from LR_MAX towards LR_MIN.
    """
    plateau_end = LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS

    # Phase 1: linear warm-up.
    if epoch < LR_RAMPUP_EPOCHS:
        return (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS * epoch + LR_START

    # Phase 2: hold the peak rate.
    if epoch < plateau_end:
        return LR_MAX

    # Phase 3: exponential decay, counted from the end of the plateau.
    return (LR_MAX - LR_MIN) * LR_EXP_DECAY**(epoch - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS) + LR_MIN
    
# Keras callback that applies lrfn at the start of every epoch and logs the chosen rate.
lr_callback = tf.keras.callbacks.LearningRateScheduler(schedule=lrfn, verbose=True)

In [None]:
# Feature spec for labeled records: one scalar int64 per label column,
# plus the encoded image stored as a scalar string.
LABELED_TFREC_FORMAT = {
    label_column: tf.io.FixedLenFeature([], tf.int64)
    for label_column in (
        'CVC - Abnormal',
        'CVC - Borderline',
        'CVC - Normal',
        'ETT - Abnormal',
        'ETT - Borderline',
        'ETT - Normal',
        'NGT - Abnormal',
        'NGT - Borderline',
        'NGT - Incompletely Imaged',
        'NGT - Normal',
        'Swan Ganz Catheter Present',
    )
}
LABELED_TFREC_FORMAT['image'] = tf.io.FixedLenFeature([], tf.string)


# Feature spec for unlabeled records: the encoded image and its study identifier.
UNLABELED_TFREC_FORMAT = dict(
    image=tf.io.FixedLenFeature([], tf.string),
    StudyInstanceUID=tf.io.FixedLenFeature([], tf.string),
)

In [None]:
def decode_image(image_data):
    """Decode JPEG bytes into a float32 tensor of shape (image_size, image_size, 3).

    Pixel values are divided by 255 after the cast to float32, and the image is
    resized to the module-level ``image_size`` in both spatial dimensions.
    """
    target_hw = [image_size, image_size]
    pixels = tf.image.decode_jpeg(image_data, channels=3)
    pixels = tf.cast(pixels, tf.float32) / 255.0
    pixels = tf.image.resize(pixels, target_hw)
    # The explicit reshape pins the static shape of the returned tensor.
    return tf.reshape(pixels, target_hw + [3])


def read_labeled_tfrecord(example):
    """Parse one serialized labeled example into an (image, labels) pair.

    ``labels`` is an int32 tensor holding the 11 target columns in the order
    listed in ``label_keys``; edit that tuple to change what the model predicts.
    """
    label_keys = (
        'ETT - Abnormal',
        'ETT - Borderline',
        'ETT - Normal',
        'NGT - Abnormal',
        'NGT - Borderline',
        'NGT - Incompletely Imaged',
        'NGT - Normal',
        'CVC - Abnormal',
        'CVC - Borderline',
        'CVC - Normal',
        'Swan Ganz Catheter Present',
    )
    parsed = tf.io.parse_single_example(example, LABELED_TFREC_FORMAT)
    image = decode_image(parsed['image'])
    targets = tf.dtypes.cast([parsed[key] for key in label_keys], tf.int32)
    return image, targets


def read_unlabeled_tfrecord(example):
    """Parse one serialized unlabeled example into an (image, StudyInstanceUID) pair."""
    parsed = tf.io.parse_single_example(example, UNLABELED_TFREC_FORMAT)
    return decode_image(parsed['image']), parsed['StudyInstanceUID']


def load_dataset(filenames, labeled=True, ordered=False):
    """Build a tf.data pipeline of parsed examples from TFRecord files.

    Args:
        filenames: TFRecord file paths to read.
        labeled: when truthy, elements are (image, labels) pairs; otherwise
            they are (image, id) pairs.
        ordered: when falsy, deterministic ordering is switched off so records
            can be consumed as soon as they arrive.

    Returns:
        The mapped (unbatched) ``tf.data.Dataset``.
    """
    parse_fn = read_labeled_tfrecord if labeled else read_unlabeled_tfrecord

    read_options = tf.data.Options()
    if not ordered:
        read_options.experimental_deterministic = False

    return (
        tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)  # parallel reads across files
        .with_options(read_options)
        .map(parse_fn, num_parallel_calls=AUTO)
    )


def get_training_dataset(dataset, do_aug=True):
    """Turn a parsed dataset into an endlessly repeating, shuffled, batched stream.

    ``do_aug`` is accepted for interface compatibility but is not used here.
    The pipeline repeats first, then shuffles with a 2048-element buffer,
    batches by ``batch_size`` and prefetches.
    """
    return (
        dataset
        .repeat()
        .shuffle(2048)
        .batch(batch_size)
        .prefetch(AUTO)
    )


def get_validation_dataset(dataset, do_onehot=True):
    """Batch, cache and prefetch a parsed dataset for validation.

    ``do_onehot`` is accepted for interface compatibility but is not used here.
    Caching happens after batching, so whole batches are what gets cached.
    """
    return dataset.batch(batch_size).cache().prefetch(AUTO)


def get_test_dataset(ordered=False):
    """Return the batched, prefetched (image, id) dataset built from TEST_FILENAMES.

    ``ordered`` is forwarded to ``load_dataset`` to control deterministic ordering.
    """
    unlabeled = load_dataset(TEST_FILENAMES, labeled=False, ordered=ordered)
    return unlabeled.batch(batch_size).prefetch(AUTO)


def count_data_items(filenames):
    """Sum the per-file item counts encoded in TFRecord file names.

    Each name is expected to carry its item count between a dash and a dot,
    e.g. ``flowers00-230.tfrec`` holds 230 items.

    Args:
        filenames: iterable of TFRecord file names or paths.

    Returns:
        The total item count as a Python ``int`` (``0`` for an empty input).

    Raises:
        ValueError: if a name does not contain a ``-<digits>.`` count.
    """
    # Compiled once, outside the loop. `+` (not `*`) requires at least one digit,
    # so an empty "-." sequence can never be mistaken for a count.
    count_pattern = re.compile(r"-([0-9]+)\.")

    total = 0
    for filename in filenames:
        match = count_pattern.search(filename)
        if match is None:
            raise ValueError(
                f"cannot read the item count from TFRecord file name: {filename!r}"
            )
        total += int(match.group(1))
    return total


# Estimate the per-fold split sizes from the total number of labeled images:
# (folds - 1) / folds of them for training and 1 / folds for validation.
_labeled_image_total = count_data_items(TRAINING_FILENAMES)
NUM_TRAINING_IMAGES = int(_labeled_image_total * (folds - 1.) / folds)
NUM_VALIDATION_IMAGES = int(_labeled_image_total * (1. / folds))
NUM_TEST_IMAGES = count_data_items(TEST_FILENAMES)

# Number of full batches that make up one training epoch.
STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // batch_size

_dataset_summary = 'Dataset: {} training images, {} validation images, {} unlabeled test images'
print(_dataset_summary.format(NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES))

In [None]:
###### Training and saving weights for inference, one fold at a time

# Split the list of training files (not individual images) into folds and keep,
# for every fold number, the index arrays of its training and validation files.
kfold = KFold(n_splits=folds, shuffle=True, random_state=42)
trn_dict, val_dict = {}, {}
for f, (trn_ind, val_ind) in enumerate(kfold.split(TRAINING_FILENAMES)):
    trn_dict[f], val_dict[f] = trn_ind, val_ind

print(trn_dict, val_dict, sep="\n")


def train(fold=0):
    """Train a ViT-L/32 based multi-label classifier on one cross-validation fold.

    The model is a pretrained vit-keras backbone followed by an 11-unit sigmoid
    layer. The weights with the best validation AUC are written to
    ``./fold{fold}vit.h5``.

    Args:
        fold: key into ``trn_dict`` / ``val_dict`` selecting which TFRecord
            files are used for training and for validation.

    Returns:
        The history object returned by ``model.fit``.
    """
    with strategy.scope():
        transformer = vit.vit_l32(
            image_size=image_size,
            include_top=False,
            pretrained_top=False,
            weights="imagenet21k+imagenet2012",
        )

        model = tf.keras.Sequential([
            transformer,
            tf.keras.layers.Dense(11, activation='sigmoid'),
        ])
        model.compile(
            optimizer=tf.keras.optimizers.Adam(),
            loss='binary_crossentropy',
            # The name is fixed on purpose: Keras auto-names successive AUC
            # instances "auc", "auc_1", ... within one session, so without it
            # the "val_auc" key monitored below would not exist for any model
            # built after the first one and no checkpoint would be saved.
            metrics=[tf.keras.metrics.AUC(multi_label=True, name='auc')],
        )
    # summary() prints by itself and returns None, so it is not wrapped in print().
    model.summary()

    checkpoint_filepath = f'./fold{fold}vit.h5'
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_auc",
        save_best_only=True,
        save_weights_only=True,
        mode='max',
    )

    # Pick this fold's files by position in TRAINING_FILENAMES.
    train_files = [TRAINING_FILENAMES[i] for i in trn_dict[fold]]
    val_files = [TRAINING_FILENAMES[i] for i in val_dict[fold]]
    train_dataset = load_dataset(train_files, labeled=True)
    val_dataset = load_dataset(val_files, labeled=True, ordered=True)

    history = model.fit(
        get_training_dataset(train_dataset),
        # The training dataset repeats forever, so the epoch length must be given.
        steps_per_epoch=STEPS_PER_EPOCH,
        epochs=num_epochs,
        callbacks=[checkpoint_callback, lr_callback],
        validation_data=get_validation_dataset(val_dataset),
        verbose=1,
    )
    return history

# Train the first fold and keep the returned fit history.
history = train(fold=0)