### Set up environment

In [None]:
# Show which GPU (if any) the runtime has: capture the lines printed by
# `nvidia-smi`, join them back into one string and print it.
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
print(gpu_info)

# Mount Google Drive so the project files stored there are reachable from the runtime.
from google.colab import drive
drive.mount('/content/drive')

# Symlink everything inside the Drive project folder into /content/, then
# work from /content/src (presumably one of the linked folders — TODO confirm).
!ln -s /content/drive/My\ Drive/Colab\ Notebooks/hku-oph/* /content/
%cd /content/src

# scikit-plot is imported further down as `scikitplot`.
# NOTE(review): the version is not pinned, so a re-run may pick up a different release.
!pip install scikit-plot
%matplotlib inline

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load in 

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the "../input/" directory.
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
        
# ML libraries required
from fastai import *
from fastai.vision import *
from fastai.callbacks import *
from fastai.metrics import KappaScore # solution evaluated with quadratic kappa
from fastai.tabular import * # for ensemble model training
import torch
# efficientnet is not integrated into fastai yet


# Other libraries required
import matplotlib.pyplot as plt
from models.efficientnet_pytorch import EfficientNet

# garbage collector
import gc

import random
from datetime import datetime

def seed_everything(seed_value, use_cuda=True):
    """Seed the random number generators used in this notebook.

    Sets ``PYTHONHASHSEED`` in the environment and seeds numpy, torch (CPU)
    and Python's ``random`` module. When ``use_cuda`` is true it additionally
    seeds the CUDA generators and switches cuDNN to deterministic mode with
    benchmarking disabled.

    Args:
        seed_value: integer seed handed to every generator.
        use_cuda: whether to also seed CUDA and set the cuDNN flags.
    """
    os.environ['PYTHONHASHSEED'] = str(seed_value)

    # CPU-side generators: numpy, torch, then the Python stdlib
    for cpu_seeder in (np.random.seed, torch.manual_seed, random.seed):
        cpu_seeder(seed_value)

    if not use_cuda:
        return

    # GPU-side generators: current device first, then all devices
    for gpu_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        gpu_seeder(seed_value)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(42, True)

from sklearn.model_selection import StratifiedKFold

### Define model details

In [None]:
import json
# Defaults; both names are overwritten below once the config file is parsed.
models_config = {}
only_train_model = None

# fine_tune_config.json is expected to hold three top-level sections:
# "meta", "ensemble" and "models".
with open('fine_tune_config.json') as config_file:
    models_config = json.load(config_file)

meta_config = models_config["meta"]
ensemble_config = models_config["ensemble"]
# From here on `models_config` refers only to the "models" section.
models_config = models_config["models"]

# Optional key in the "meta" section; stays None when absent.
only_train_model = meta_config["only_train_model"] if "only_train_model" in meta_config else None

### Import preprocessing modules

In [None]:
# Execute the preprocessing notebook in this namespace.
# NOTE(review): presumably this is what defines `contrast_and_crop` and
# `advprop_normalise`, which are used below — confirm.
%run preprocessing.ipynb

### Load data

In [None]:
class PreProcessCommonWrapper:
    """Callable that runs ``contrast_and_crop`` with a fixed image dimension.

    The instance carries ``__name__`` and ``__annotations__`` attributes like a
    plain function would — presumably so it can be handed to ``TfmPixel`` in
    place of a function (TODO confirm against fastai's transform wrapper).
    """

    def __init__(self, image_dim):
        """Remember the target image dimension and set function-like metadata."""
        self.__annotations__ = {}
        self.__name__ = "PreProcessCommonWrapper"
        self.image_dim = image_dim

    def __call__(self, t):
        """Apply ``contrast_and_crop`` to ``t`` using the stored image dimension."""
        return contrast_and_crop(t, self.image_dim)

def get_df():
    """Read the image index CSV and remap its ``diagnosis`` column.

    The remapping moves grade 3 to 4 and grade 2.5 to 3; all other values are
    left untouched.

    Returns:
        A tuple ``(df, base_dir, train_dir)``. Both directories currently
        point at the same dataset folder.
    """
    dataset_root = os.path.join('../', 'datasets/hkuretinopathydataset/')
    base_dir, train_dir = dataset_root, dataset_root

    df = pd.read_csv(os.path.join(base_dir, 'images_2_fields.csv'))

    # Work out both masks before writing anything, so rows that were
    # originally 2.5 can never be confused with rows that were originally 3.
    was_grade_3 = df.diagnosis == 3
    was_grade_2_5 = df.diagnosis == 2.5
    df.loc[was_grade_3, "diagnosis"] = 4
    df.loc[was_grade_2_5, "diagnosis"] = 3
    return df, base_dir, train_dir
    
def load_data(model_config, train_index, val_index):
    """Build the fastai data bunch for one model configuration and one fold.

    Args:
        model_config: dict for a single model; the keys read here are
            "image_dim", "advprop" and "batch_size".
        train_index: row positions of the image frame used for training.
        val_index: row positions of the image frame used for validation.

    Returns:
        A tuple ``(df, data)`` with the image frame (paths rewritten) and the
        transformed, normalised data bunch.
    """
    print(model_config)
    df, base_dir, train_dir = get_df()

    def _relocate(original_path):
        # keep only the last two path components and re-root them under train_dir
        tail = original_path.split('/')[-2:]
        return os.path.join(train_dir, '/'.join(tail))

    df['path'] = df['path'].map(_relocate)

    # Labels are integers in the CSV but are treated as floats (FloatList),
    # i.e. the models are trained as regressors.
    image_items = ImageList.from_df(df=df, path='./', cols='path')
    split_items = image_items.split_by_idxs(train_index, val_index)
    src = split_items.label_from_df(cols='diagnosis', label_cls=FloatList)

    train_tfms, valid_tfms = get_transforms(do_flip=True, flip_vert=True, max_rotate=360, max_warp=0,
                                            max_zoom=1.3, max_lighting=0.1, p_lighting=0.5)

    # Custom pre-processing (contrast and crop), plus the AdvProp normalisation
    # when the model was pretrained with AdvProp.
    image_dim = model_config["image_dim"]
    pre_process_ccs = [TfmPixel(PreProcessCommonWrapper(image_dim))()]
    advprop = model_config["advprop"]
    if advprop:
        pre_process_ccs.append(TfmPixel(advprop_normalise)())

    # Augmentations differ between the two sets; the pre-processing is appended to both.
    tfms = [train_tfms + pre_process_ccs, valid_tfms + pre_process_ccs]

    transformed = src.transform(tfms, size=image_dim, resize_method=ResizeMethod.CROP, padding_mode='zeros')
    bunch = transformed.databunch(bs=model_config["batch_size"], num_workers=1)
    # NOTE(review): with advprop, None is handed to normalize(); in fastai v1
    # that is documented to fall back to statistics computed from a batch
    # rather than skipping normalisation — confirm this is intended.
    data = bunch.normalize(None if advprop else imagenet_stats)

    print("loaded data")
    return (df, data)

### Helper functions in loading models

In [None]:
def getModel(model_name, data, model_dir=None, advprop=False, **kwargs):
    """Create a pretrained EfficientNet whose final layer matches ``data.c``.

    Args:
        model_name: name passed to ``EfficientNet.from_pretrained``.
        data: data bunch; only ``data.c`` (number of outputs) is read.
        model_dir: optional directory; it is made absolute but is not used
            for anything else in this function.
        advprop: forwarded to ``EfficientNet.from_pretrained``.
        **kwargs: accepted and ignored.

    Returns:
        The EfficientNet module with ``_fc`` replaced by a fresh ``nn.Linear``.
    """
    import os.path as osp
    if model_dir is not None:
        model_dir = osp.abspath(model_dir)  # result is currently unused

    network = EfficientNet.from_pretrained(model_name, advprop=advprop)
    # Swap the classification head for one with data.c outputs.
    head_inputs = network._fc.in_features
    network._fc = nn.Linear(head_inputs, data.c)
    return network

def get_learner(model_name, data, model_dir="models/"):
    """Wrap an EfficientNet in a fastai ``Learner`` with mixup and fp16 enabled.

    The learner reports ``quadratic_kappa`` as its metric. ``model_dir`` is
    only forwarded to ``getModel``; it is not passed to ``Learner``.

    NOTE(review): ``advprop`` is not forwarded, so ``getModel`` always runs
    with its default ``advprop=False`` — confirm this is intended for configs
    that set "advprop".
    """
    network = getModel(model_name, data, model_dir=model_dir)
    learner = Learner(data, network, metrics=[quadratic_kappa])
    return learner.mixup().to_fp16()

# quadratic kappa score
from sklearn.metrics import cohen_kappa_score
def quadratic_kappa(y_hat, y):
    """Quadratic-weighted Cohen's kappa between rounded predictions and targets.

    Args:
        y_hat: tensor of raw (continuous) predictions; rounded to the nearest
            integer before scoring.
        y: tensor of target grades.

    Returns:
        A scalar tensor holding the kappa score, placed on the device that
        ``y_hat`` was on when passed in.
    """
    # Return the score on the device the predictions came from instead of a
    # hard-coded 'cuda:0', so the metric also works on CPU-only runtimes.
    device = y_hat.device
    # scikit-learn works on host memory, so move both tensors to the CPU first.
    y_hat = y_hat.cpu()
    y = y.cpu()
    score = cohen_kappa_score(torch.round(y_hat), y, weights='quadratic')
    return torch.tensor(score, device=device)


In [None]:
# Silence all Python warnings for the rest of the notebook.
# NOTE(review): this also hides deprecation and data-conversion warnings that may be useful.
import warnings 
warnings.filterwarnings("ignore")
# Optional sanity check (disabled): load one configuration and look at a batch.
# df, data = load_data(models_config["effnet-b3"], range(0,800), range(801,1000))
# data.show_batch(rows=3, figsize=(10,10))

### Train main models routine

In [None]:
def _fold_checkpoint_exists(path, config_name, fold_n):
    """Return True when `path` contains the saved weights file for this model and fold."""
    checkpoint_file = f'{config_name}_fold_{fold_n}.pth'
    return os.path.isfile(os.path.join(path, checkpoint_file))


def train_net(fold_n, train_index, val_index):
    """Fine-tune every configured model on one fold and collect validation predictions.

    For each entry in ``models_config`` the data is loaded and a learner is
    built. Then either the weights already saved for this fold are loaded
    (no retraining), or the ``<config_name>_best`` weights are loaded and the
    model is trained with ``fit_one_cycle``, saving on every improvement of
    the validation loss as ``<config_name>_fold_<fold_n>``.

    Args:
        fold_n: fold number, used in checkpoint names and log messages.
        train_index: row positions used for training.
        val_index: row positions used for validation.

    Returns:
        ``(ensemble_predictions, ensemble_labels, valid_predictions, valid_labels)``.
        The two ensemble values are ``None`` once at least one model has been
        processed (collecting them is currently disabled). ``valid_predictions``
        maps each config name to a flat list of predictions. ``valid_labels``
        are the targets returned for the last model processed.
    """
    valid_predictions = {}
    valid_labels = []
    ensemble_predictions = {}
    ensemble_labels = []
    for config_name, config in models_config.items():
        print(f"---- TRAINING STARTING FOR {config_name} ----")
        _, data = load_data(config, train_index, val_index)
        learner = get_learner(config["pretrained_name"], data, model_dir=config["pretrained_path"])

        # Check if a previous training on this fold has been done. The check
        # looks for the exact checkpoint file name: splitting file names on
        # '_' breaks for config names that contain '_' and raises IndexError
        # on files without one.
        # NOTE(review): the file is looked up in config["pretrained_path"],
        # while learner.load() resolves names against the learner's own model
        # directory, which get_learner does not set from that path — confirm
        # both point at the same folder.
        fold_checkpoint = f'{config_name}_fold_{fold_n}'
        using_old_fold = _fold_checkpoint_exists(config["pretrained_path"], config_name, fold_n)

        if using_old_fold:
            print(f"-- LOADED PREVIOUS TRAINED FOLD #{fold_n} FOR {config_name} --")
            learner.load(fold_checkpoint)  # load trained weights
        else:
            print(f"-- LOADED PREVIOUS APTOS PRETRAINED FOR {config_name} --")
            learner.load(f'{config_name}_best')  # load trained weights

            # Only retrain if this fold has not been trained before.
            learner.fit_one_cycle(
                config["epoch_n"],
                config["lr"],
                callbacks=[SaveModelCallback(learner, every='improvement', monitor='valid_loss', name=fold_checkpoint)]
            )

        # Collecting train-set predictions for a learned ensemble is disabled
        # for now, so the ensemble outputs are reset to None.
        ensemble_predictions = None
        ensemble_labels = None
        predictions, valid_labels = learner.get_preds(DatasetType.Valid)
        valid_predictions[config_name] = predictions.flatten().tolist()

        learner.show_results()

        # Drop the learner before building the next one to release memory.
        del learner
        gc.collect()
    return (ensemble_predictions, ensemble_labels, valid_predictions, valid_labels)


### Ensemble layer

In [None]:
def train_ensemble_average(fold_n, ensemble_predictions, ensemble_labels):
    """Train a small tabular network that combines the per-model predictions.

    Args:
        fold_n: fold number, used in the checkpoint name ``ensemble_<fold_n>``.
        ensemble_predictions: dict mapping model names to prediction lists.
            It is modified in place: a "diagnosis" entry is added.
        ensemble_labels: targets stored under "diagnosis".

    Learning rate and epoch count come from ``ensemble_config``.
    """
    # NOTE: this writes into the caller's dict.
    ensemble_predictions["diagnosis"] = ensemble_labels
    ensemble_df = pd.DataFrame(ensemble_predictions)

    # One continuous input column per configured model.
    tabular_items = TabularList.from_df(df=ensemble_df, cont_names=models_config.keys(), procs=[Normalize])
    ensemble_bunch = tabular_items.split_by_rand_pct().label_from_df(cols='diagnosis').databunch()

    learner = tabular_learner(ensemble_bunch, layers=[100, 50], ps=[0.5,0.2], metrics=[quadratic_kappa])

    lr = ensemble_config["lr"]
    epoch_n = ensemble_config["epoch_n"]
    # Save weights whenever the validation loss improves.
    save_best = SaveModelCallback(learner, every='improvement', monitor='valid_loss', name=f'ensemble_{fold_n}')
    learner.fit_one_cycle(epoch_n, lr, callbacks=[save_best])

def simple_avg_ensemble(valid_predictions, valid_labels):
    """Average the per-model predictions row by row.

    Args:
        valid_predictions: dict mapping model names to equal-length lists of
            predictions.
        valid_labels: targets, one per row.

    Returns:
        A DataFrame with one column per model, an "ensemble" column holding
        the row-wise mean of the model columns, and a "diagnosis" column
        holding ``valid_labels``.
    """
    frame = pd.DataFrame.from_dict(valid_predictions)
    # The mean is taken before "diagnosis" is added, so labels never leak into it.
    row_means = frame.mean(axis=1)
    frame['ensemble'] = row_means
    frame["diagnosis"] = valid_labels
    return frame

### ROC curve, sensitivity and specificity

In [None]:
%matplotlib inline
from sklearn.metrics import roc_curve, auc
from itertools import cycle
from scipy import interpolate
import matplotlib.pyplot as plt
import scikitplot as skplt

def compress_predictions(df):
    """Clip the "ensemble" column to [0, 4] and scale the frame to [0, 1].

    The clipping is written into ``df`` itself. The division by 4 is applied
    to every column and returned as a new frame, so ``df`` keeps its clipped
    but unscaled values.
    """
    below_range = df["ensemble"] < 0
    above_range = df["ensemble"] > 4
    df.loc[below_range, "ensemble"] = 0
    df.loc[above_range, "ensemble"] = 4
    return df / 4

def compress_labels(df):
    """Add an integer "referral" column: 1 where diagnosis >= 2, otherwise 0.

    The column is added to ``df`` in place and the same frame is returned.
    """
    no_referral = df["diagnosis"] < 2
    needs_referral = df["diagnosis"] >= 2
    df.loc[no_referral, "referral"] = 0
    df.loc[needs_referral, "referral"] = 1
    # The column is created as float by the masked writes above; store it as int.
    df["referral"] = df["referral"].astype(int)
    return df

def get_and_show_ROC(df, fold_n):
    """Compute, plot and save the referral ROC curve for one fold.

    Args:
        df: frame with "ensemble" (predicted grade) and "diagnosis" columns.
            It is modified in place by the helpers: "ensemble" is clipped to
            [0, 4] and a "referral" column is added.
        fold_n: fold number, used in the saved figure name.

    Returns:
        ``(cutoff_df, auc)`` where ``cutoff_df`` lists sensitivity, specificity
        and threshold for each ROC point, with the AUC repeated on each row.
    """
    scores = compress_predictions(df).ensemble
    referral = compress_labels(df).referral

    fpr, tpr, threshold = roc_curve(referral, scores)
    roc_auc = auc(fpr, tpr)
    print(f"AUC: {roc_auc}")

    # sensitivity is the true-positive rate; specificity is 1 - false-positive rate
    cutoff_df = pd.DataFrame({
        "sensitivity": tpr,
        "specificity": 1 - fpr,
        "threshold": threshold,
        "auc": roc_auc
    })

    # plot_roc is given one probability column per class
    class_scores = pd.DataFrame({"no": 1 - scores, "yes": scores})
    skplt.metrics.plot_roc(referral, class_scores)
    plt.savefig(f'roc_fold_{fold_n}.png')
    plt.show()
    return (cutoff_df, roc_auc)

### 2-field combined inference

In [None]:
def integrate_2_fields(valid_predictions, valid_labels):
    """Combine the predictions of two image fields into one prediction.

    The per-model predictions are averaged row by row, then the rows are cut
    into a first and a second half. Row ``k`` of the first half is paired with
    row ``k`` of the second half, and the two ensemble values are averaged.

    NOTE(review): this assumes the rows are ordered so that the two halves
    line up (first field of a case in the first half, its second field at the
    same offset in the second half) — confirm against how the validation
    indices are produced.

    Args:
        valid_predictions: dict mapping model names to equal-length lists of
            predictions.
        valid_labels: targets, one per row.

    Returns:
        A DataFrame with one row per pair: the model columns and "diagnosis"
        of the first-half row, plus "first_field", "second_field", "diff"
        (absolute difference) and "ensemble" (mean of the two fields).

    Raises:
        ValueError: if the number of rows is odd, so the halves cannot be paired.
    """
    combined = pd.DataFrame.from_dict(valid_predictions)
    combined['ensemble'] = combined.mean(axis=1)
    combined["diagnosis"] = valid_labels

    n_rows = len(combined)
    if n_rows % 2 != 0:
        raise ValueError(
            f"integrate_2_fields needs an even number of rows to pair the two fields, got {n_rows}"
        )
    half = n_rows // 2

    # Work on an explicit copy of the first half so the column writes below
    # never target a view of `combined`.
    results = combined.iloc[:half].copy()
    results["second_field"] = combined["ensemble"].iloc[half:].tolist()
    results = results.rename(columns={"ensemble": "first_field"})
    results["diff"] = (results["first_field"] - results["second_field"]).abs()

    # take simple average
    results['ensemble'] = (results['first_field'] + results['second_field']) / 2

    return results

### Entrypoint

In [None]:
# When True, each fold's validation predictions are paired up and averaged
# across the two image fields (integrate_2_fields); when False, only the
# per-model average is taken (simple_avg_ensemble).
is_integrate_2_fields = True

if __name__ == '__main__':
    # Build stratified folds over the diagnosis labels and, for each fold:
    # train/load every model, combine the predictions, then compute the ROC.
    import warnings
    warnings.filterwarnings("ignore")
    df = get_df()[0]
    skf = StratifiedKFold(n_splits=meta_config["n_splits"])
    i = 0  # fold counter, used in log messages and output file names
    for train_index, val_index in skf.split(df.index, df['diagnosis']):
        # To run a single fold only, skip the others (disabled):
        # if i != 4:
        #   i += 1
        #   continue
        print(f'------ FOLD {i} ------')
        print("-- TRAINING INDICES --")
        print(train_index)
        print("-- VALIDATION INDICES --")
        print(val_index)
        ensemble_predictions, ensemble_labels, valid_predictions, valid_labels = train_net(i, train_index, val_index)
        # take a simple average for now
        if is_integrate_2_fields:
            valid_predictions_labels = integrate_2_fields(valid_predictions, valid_labels)
            # Written before the ROC step on purpose: get_and_show_ROC clips the
            # "ensemble" column and adds a "referral" column to this frame in place.
            # NOTE: predictions are only saved on this branch, and the file
            # name carries no extension.
            valid_predictions_labels.to_csv(f'models/predictions_fold_{i}')
        else:
            valid_predictions_labels = simple_avg_ensemble(valid_predictions, valid_labels)
        # evaluate the sensitivity and specificity of this fold
        cutoff_df, auc_scaler = get_and_show_ROC(valid_predictions_labels, i)
        cutoff_df.to_csv(f'models/cutoffs_fold_{i}')
        # Learned ensemble is disabled; train_net currently returns None for its inputs.
        # train_ensemble_average(i, ensemble_predictions, ensemble_labels)
        i += 1