### Getting Started with a Colab TPU
#### Data Preparation
1. Create the stratified GroupKFold TFRecords: https://www.kaggle.com/shigengtian/stratified-groupkfold-tfrecords
2. Upload the TFRecord files to Google Cloud Storage (GCS).
3. Upload `train_folds.csv` to the Colab runtime.

##### Authenticate with GCP

In [None]:
# Authenticate this Colab session with Google Cloud so that the gs:// paths
# globbed later in the notebook can be read.
from google.colab import auth
auth.authenticate_user()
# Placeholder: replace with your own GCP project ID.
project_id = 'xxxxxx'
# IPython shell escape; {project_id} is substituted from the Python variable above.
!gcloud config set project {project_id}

### TPU Initialization

In [None]:
%tensorflow_version 2.x
import tensorflow as tf
print("Tensorflow version " + tf.__version__)

try:
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
  # print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except ValueError:
  raise BaseException('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!')

tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
tpu_strategy = tf.distribute.experimental.TPUStrategy(tpu)

In [None]:
import pandas as pd

In [None]:
# Load the fold assignments and display how many rows fall outside fold 0
# (presumably the training-set size when fold 0 is held out for validation).
df_fold = pd.read_csv("train_folds.csv")
df_fold.loc[df_fold["fold"] != 0].shape[0]

In [None]:
# List the fold-0 training TFRecord shards in the bucket (placeholder bucket name).
gcs_pattern = "gs://xxxx/ranzcr/0_train*"
filenames = tf.io.gfile.glob(pattern=gcs_pattern)

In [None]:
# Display the matched shard paths to confirm the glob found the TFRecords.
filenames

In [None]:
import re
import numpy as np
import pandas as pd
import tensorflow as tf

# ---- Training configuration -------------------------------------------------
EPOCHS = 100
device = "TPU"
# Global batch size: 8 examples per replica, across all replicas in the strategy.
batch_size = 8 * tpu_strategy.num_replicas_in_sync
IMAGE_SIZE = (600, 600)

# Warm-up / sustain / exponential-decay schedule parameters (consumed by lrfn).
# The peak rate scales with the replica count, i.e. with the global batch size.
start_lr = 1e-4
min_lr = 1e-6
max_lr = 5e-5 * tpu_strategy.num_replicas_in_sync
rampup_epochs = 5
sustain_epochs = 0
exp_decay = 0.6

# Let tf.data choose parallelism and prefetch buffer sizes at runtime.
AUTO = tf.data.experimental.AUTOTUNE
# Fold table; its column order defines the TFRecord feature spec and label order
# used by parse_tfrecord.
df_fold = pd.read_csv("train_folds.csv")

# Fold-0 TFRecord shards (placeholder bucket name -- change the path).
train_fns = tf.io.gfile.glob('gs://xxx/ranzcr/0_train*')
validation_fns = tf.io.gfile.glob('gs://xxx/ranzcr/0_val*')



def parse_tfrecord(example):
    """Decode one serialized TFRecord example into an ``(image, label)`` pair.

    The feature spec is derived from the columns of the module-level
    ``df_fold`` table: ``StudyInstanceUID``, ``PatientID`` and ``images`` are
    parsed as strings, every other column as an int64 scalar. An ``images``
    string feature (the encoded JPEG) is always included.

    Args:
        example: scalar string tensor holding one serialized ``tf.train.Example``.

    Returns:
        ``(image, label)``. ``image`` is the 3-channel JPEG resized to
        ``IMAGE_SIZE``; ``tf.image.resize`` returns float32 and no rescaling is
        applied, so pixels stay in the 0-255 range. ``label`` stacks, in
        ``df_fold`` column order, every column except the id, image and fold
        columns.
    """
    string_cols = ('StudyInstanceUID', 'PatientID', 'images')
    non_label_cols = ('StudyInstanceUID', 'PatientID', 'images', 'fold')
    col_names = df_fold.columns

    spec = {
        name: tf.io.FixedLenFeature([], tf.string if name in string_cols else tf.int64)
        for name in col_names
    }
    # Ensure the encoded-image feature is present even when 'images' is not a CSV column.
    spec["images"] = tf.io.FixedLenFeature([], tf.string)
    parsed = tf.io.parse_single_example(example, spec)

    decoded = tf.image.decode_jpeg(parsed['images'], channels=3)
    resized = tf.image.resize(decoded, IMAGE_SIZE)

    label = tf.stack([parsed[name] for name in col_names if name not in non_label_cols])
    return resized, label


def load_dataset(filenames):
    """Stream decoded ``(image, label)`` pairs from TFRecord files.

    Single definition -- the block previously defined this function twice,
    identically, so the first copy was dead code.

    Args:
        filenames: list of TFRecord paths (local or ``gs://``).

    Returns:
        A ``tf.data.Dataset`` of ``(image, label)`` tuples produced by
        ``parse_tfrecord``.
    """
    # num_parallel_reads interleaves reads from several shards at once, which
    # keeps throughput up when the files live in a remote bucket.
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    return records.map(parse_tfrecord, num_parallel_calls=AUTO)

def get_training_dataset(shuffle_buffer=10000):
    """Build the infinite, shuffled, augmented training input pipeline.

    Args:
        shuffle_buffer: number of serialized records held in the shuffle
            buffer (default 10000, the previously hard-coded value).

    Returns:
        A repeating ``tf.data.Dataset`` of ``(image, label)`` batches of
        ``batch_size`` elements, prefetched with autotuning.
    """
    def data_augment(img, label):
        img = tf.image.random_flip_left_right(img)
        img = tf.image.random_flip_up_down(img)
        # NOTE(review): adjust_brightness is deterministic, and for float images
        # it adds the delta directly. parse_tfrecord yields float pixels in the
        # 0-255 range, so +0.1 is effectively a no-op. Was random_brightness
        # with a 0-255-scaled delta intended? Left unchanged to preserve results.
        img = tf.image.adjust_brightness(img, 0.1)
        img = tf.image.random_contrast(img, 0.9, 1)
        img = tf.image.random_saturation(img, 0.9, 1)
        return img, label

    # Shuffle the *serialized* records before decoding. A buffer of decoded
    # 600x600x3 float32 images costs ~4.3 MB per element (~43 GB for 10000),
    # while encoded JPEG records are far smaller. Shuffling before repeat()
    # also keeps epoch boundaries intact (reshuffled each epoch).
    records = tf.data.TFRecordDataset(train_fns, num_parallel_reads=AUTO)
    dataset = (
        records.shuffle(shuffle_buffer)
        .repeat()
        .map(parse_tfrecord, num_parallel_calls=AUTO)
        .map(data_augment, num_parallel_calls=AUTO)
    )
    # Prefetch the next batch while training (autotuned buffer size).
    return dataset.batch(batch_size).prefetch(AUTO)


def get_validate_dataset():
    """Build the validation input pipeline.

    Decoded ``(image, label)`` pairs from ``validation_fns`` with no
    augmentation, shuffling or repetition, batched by ``batch_size`` and
    prefetched with autotuning.
    """
    return load_dataset(validation_fns).batch(batch_size).prefetch(AUTO)


# Build both input pipelines once (training first, then validation).
training_dataset, validation_dataset = get_training_dataset(), get_validate_dataset()

def get_dataset_iterator(dataset, n_examples):
    """Return a NumPy iterator over ``dataset`` regrouped into batches of ``n_examples``.

    Any existing batching is undone first, so each yielded batch holds up to
    ``n_examples`` individual elements.
    """
    rebatched = dataset.unbatch().batch(n_examples)
    return rebatched.as_numpy_iterator()

def create_model(num_classes=11, drop_connect_rate=0.2):
    """Build and compile an EfficientNetB5-based multi-label classifier.

    The whole backbone is fine-tuned. Its features are global-average-pooled
    and fed to a sigmoid dense head that gives one independent probability per
    label.

    Args:
        num_classes: number of output labels. Must equal the length of the
            label vector built by ``parse_tfrecord``; defaults to the previously
            hard-coded 11.
        drop_connect_rate: drop-connect rate passed to EfficientNetB5
            (default 0.2, the previously hard-coded value).

    Returns:
        A compiled ``tf.keras.Sequential`` model (Adam, binary cross-entropy,
        multi-label AUC).
    """
    backbone = tf.keras.applications.EfficientNetB5(
        input_shape=[*IMAGE_SIZE, 3],
        include_top=False,
        drop_connect_rate=drop_connect_rate,
    )
    backbone.trainable = True
    model = tf.keras.Sequential([
        backbone,
        tf.keras.layers.GlobalAveragePooling2D(),
        # Sigmoid rather than softmax: each label is an independent binary target.
        tf.keras.layers.Dense(num_classes, activation='sigmoid'),
    ])
    model.compile(
        # String alias -> Keras' default Adam learning rate unless a callback changes it.
        optimizer='adam',
        loss='binary_crossentropy',
        # multi_label=True computes AUC per label and averages, instead of pooling labels.
        metrics=[tf.keras.metrics.AUC(multi_label=True)],
    )
    return model


# Build the model. On TPU its variables must be created inside the strategy
# scope so they are placed on (and replicated across) the TPU cores.
if device != "TPU":
    print("GPU")
    model = create_model()
else:
    with tpu_strategy.scope():
        print("TPU")
        model = create_model()

### Training

In [None]:
def lrfn(epoch):
    """Return the learning rate for ``epoch`` (0-based).

    * epochs before ``rampup_epochs``: linear interpolation from ``start_lr``
      towards ``max_lr``;
    * the next ``sustain_epochs`` epochs: constant ``max_lr``;
    * afterwards: ``max_lr`` decaying towards ``min_lr`` by a factor of
      ``exp_decay`` per epoch.
    """
    if epoch < rampup_epochs:
        slope = (max_lr - start_lr) / rampup_epochs
        return slope * epoch + start_lr
    if epoch < rampup_epochs + sustain_epochs:
        return max_lr
    decay_steps = epoch - rampup_epochs - sustain_epochs
    return (max_lr - min_lr) * exp_decay ** decay_steps + min_lr
    
# Per-epoch schedule from lrfn (lrfn is passed directly; the wrapping lambda was
# redundant). NOTE(review): this callback is NOT in the callbacks list given to
# model.fit below. If it were enabled, it would reset the LR at the start of
# every epoch and override ReduceLROnPlateau -- use one or the other.
lr_callback = tf.keras.callbacks.LearningRateScheduler(lrfn, verbose=True)

# Cut the LR 5x when val_loss has not improved by >= 1e-4 for 3 epochs.
rlr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.2, patience=3, verbose=0,
    min_delta=1e-4, min_lr=1e-6, mode='min',
)

# Keep only the best (lowest val_loss) model on the runtime's local disk.
ckp = tf.keras.callbacks.ModelCheckpoint(
    '/content/efn5_v3.h5', monitor='val_loss',
    verbose=0, save_best_only=True, mode='min',
)

# NOTE(review): patience matches ReduceLROnPlateau's patience (3), so training
# may stop in the same epoch the LR is first reduced, before the lower LR is
# ever used. Consider a larger patience here. Left unchanged.
es = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', min_delta=1e-6, patience=3, mode='min',
    restore_best_weights=False, verbose=0,
)

# Training-set size for fold 0: every row outside fold 0 (assumes the '0_train*'
# TFRecords were built from this same train_folds.csv). Derived from the fold
# table instead of the hard-coded 24080 so it stays right if folds are regenerated.
n_train = int((df_fold["fold"] != 0).sum())

history = model.fit(
    training_dataset,
    validation_data=validation_dataset,
    # The training dataset repeats forever, so each epoch must be bounded explicitly.
    steps_per_epoch=n_train // batch_size,
    epochs=EPOCHS,
    callbacks=[rlr, ckp, es],
)

hist_df = pd.DataFrame(history.history)