In [None]:
# default_exp callbacks

# callbacks
> The callback classes used in network training. Mainly for early stopping techniques.

In [None]:
#export
from sklearn.metrics import roc_auc_score, precision_recall_curve
from sklearn.metrics import auc as calculate_auc
from sklearn.metrics import mean_squared_error
from sklearn.metrics import accuracy_score
import tensorflow as tf
import os
import numpy as np

## Some helper functions

### R2 Score

First we define an `r2_score` between two vectors as the squared Pearson correlation coefficient, $r^2$, where $m_x$ and $m_y$ are the means of $x$ and $y$ and

$$r = \frac{\sum (x - m_x) (y - m_y)}{\sqrt{\sum (x - m_x)^2 \sum (y - m_y)^2}}$$

In [None]:
#export
# `scipy.stats.stats` is a deprecated private namespace; import from the public module.
from scipy.stats import pearsonr
def r2_score(x,y):
    """Squared Pearson Correlation

    x, y: 1-D array-likes of equal length.

    Returns the square of the Pearson correlation coefficient computed by
    `scipy.stats.pearsonr`, so perfectly correlated and perfectly
    anti-correlated inputs both score 1.
    """
    pcc, _ = pearsonr(x, y)
    return pcc**2

In [None]:
# Two small example vectors.
a = np.array([0, 0, 0, 1, 1, 1, 1])
b = np.arange(7)

# Reference Pearson coefficient obtained directly from scipy.
r, _ = pearsonr(a, b)

# r2_score must equal the square of that coefficient.
np.testing.assert_almost_equal(r**2, r2_score(a, b))

### Precision-AUC Score

*This part is adapted from the official `scikit-learn` [documentation](https://scikit-learn.org/stable/auto_examples/model_selection/plot_precision_recall.html).*

The `precision_recall_curve` function computes precision-recall pairs at different decision thresholds for a binary classification task. It takes an array of true binary labels, `y_true`, and an array of scores for the positive class, which can be either estimated probabilities or the output of a decision function.

The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of
true positives and ``fp`` the number of false positives. The precision is
intuitively the ability of the classifier not to label as positive a sample
that is negative.

The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of
true positives and ``fn`` the number of false negatives. The recall is
intuitively the ability of the classifier to find all the positive samples.

In [None]:
# Four samples: true binary labels and the scores assigned to them.
y_true = np.array([0, 0, 1, 1])
y_scores = np.array([0.1, 0.4, 0.35, 0.8])
precision, recall, thresholds = precision_recall_curve(y_true, y_scores)
# Display the three arrays; `precision` and `recall` are reused by later cells.
precision, recall, thresholds

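The `calculate_auc` function (an alias for `sklearn.metrics.auc`) calculates the area under a curve from its x and y coordinates using the trapezoidal rule. Here recall is the x-axis and precision is the y-axis.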

In [None]:
# Area under the precision-recall curve: recall is passed as x, precision as y.
calculate_auc(recall, precision)

In [None]:
#export
def prc_auc_score(y_true, y_score):
    """Precision-Area under curve Score

    Builds the precision-recall curve for the binary labels `y_true` and the
    scores `y_score`, then returns the area under it (recall on the x-axis).
    """
    precisions, recalls, _ = precision_recall_curve(y_true, y_score)
    return calculate_auc(recalls, precisions)

In [None]:
# The same example, computed in one call by the exported helper.
prc_auc_score(y_true, y_scores)

In [None]:
# The helper must agree with the manual two-step computation shown earlier.
np.testing.assert_almost_equal(calculate_auc(recall, precision), prc_auc_score(y_true, y_scores))

## Regression

In [None]:
#export
class Reg_EarlyStoppingAndPerformance(tf.keras.callbacks.Callback):
    """The callback class used in regression problems.

    After every epoch it predicts on the training and validation data, records
    loss / RMSE / R2 in `self.history`, and stops training once the monitored
    criterion has not improved for `patience` consecutive epochs. The weights
    of the best epoch are restored when training ends.
    """
    def __init__(self, train_data, valid_data, y_scaler, MASK = -1e10, patience=5, criteria = 'val_loss', verbose = 0):
        """
        train_data: tuple `(x, y)` used to compute the training metrics
        valid_data: tuple `(x_val, y_val)` used to compute the validation metrics
        y_scaler: None, sklearn MinMaxScaler, or StandardScaler
        MASK: sentinel value in `y` marking entries to exclude from the metrics
        patience: number of epochs without improvement tolerated before stopping
        criteria: 'val_loss' (lower is better) or 'val_r2' (higher is better)
        verbose: if truthy, print one summary line per epoch
        """
        super().__init__()

        assert criteria in ['val_loss', 'val_r2'], 'not support %s ! only %s' % (criteria, ['val_loss', 'val_r2'])
        self.x, self.y  = train_data
        self.x_val, self.y_val = valid_data
        self.y_scaler = y_scaler

        self.history = {'loss':[],
                        'val_loss':[],

                        'rmse':[],
                        'val_rmse':[],

                        'r2':[],
                        'val_r2':[],

                        'epoch':[]}
        self.MASK = MASK
        self.patience = patience
        # best_weights stores the weights of the best epoch seen so far.
        self.best_weights = None
        self.criteria = criteria
        self.best_epoch = 0
        self.verbose = verbose

    @staticmethod
    def _fmt(value):
        "Format a metric with four decimals; a missing value (None) is shown as 'nan'."
        return '{0:.4f}'.format(np.nan if value is None else value)

    def _inverse_scale(self, y_true, y_pred, inner_y_true):
        "Map predictions (and, when `inner_y_true`, the targets too) back to original units."
        if self.y_scaler is not None:
            y_pred = self.y_scaler.inverse_transform(y_pred)
            if inner_y_true:
                y_true = self.y_scaler.inverse_transform(y_true)
        return y_true, y_pred

    def _per_task_scores(self, y_true, y_pred, inner_y_true, score_fn):
        "Apply `score_fn(y_true_col, y_pred_col)` to every output column, skipping masked entries."
        # inverse_transform generally maps the MASK sentinel to some other value,
        # so sentinel positions are located both before and after un-scaling.
        masked = np.asarray(y_true) == self.MASK
        y_true, y_pred = self._inverse_scale(y_true, y_pred, inner_y_true)
        masked = masked | (np.asarray(y_true) == self.MASK)

        scores = []
        for i in range(y_pred.shape[1]):
            keep = ~masked[:, i]
            scores.append(score_fn(y_true[:, i][keep], y_pred[:, i][keep]))
        return scores

    def rmse(self, y_true, y_pred, inner_y_true = True):
        """Per-column root mean squared error, ignoring masked targets.

        inner_y_true: True if `y_true` is in the scaled space (it is then
        inverse-transformed together with `y_pred`); False if `y_true` is
        already in original units and only `y_pred` needs un-scaling.
        Returns a list with one RMSE per output column.
        """
        return self._per_task_scores(y_true, y_pred, inner_y_true,
                                     lambda t, p: np.sqrt(mean_squared_error(t, p)))


    def r2(self, y_true, y_pred, inner_y_true = True):
        """Per-column squared Pearson correlation, ignoring masked targets.

        Arguments are the same as for `rmse`. Returns a list with one R2 per
        output column.
        """
        return self._per_task_scores(y_true, y_pred, inner_y_true, r2_score)


    def on_train_begin(self, logs=None):
        # The number of epochs waited since the criterion last improved.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # Start from the worst possible value for the chosen criterion.
        # (`np.inf`, not `np.Inf`: the capitalised alias was removed in NumPy 2.0.)
        if self.criteria == 'val_loss':
            self.best = np.inf
        else:
            self.best = -np.inf

    def _restore_best_weights(self, announce=False):
        "Load the best recorded weights into the model, if any epoch has been recorded."
        if self.best_weights is None:
            # No epoch ever counted as an improvement (e.g. the criterion was NaN).
            return
        if announce:
            print('\nRestoring model weights from the end of the best epoch.')
        self.model.set_weights(self.best_weights)

    def _update_early_stopping(self, current, epoch):
        "Record `current` if it is the best so far, otherwise count towards `patience`."
        if self.criteria == 'val_loss':
            improved = current <= self.best   # lower is better
        else:
            improved = current >= self.best   # higher is better

        if improved:
            self.best = current
            self.wait = 0
            self.best_weights = self.model.get_weights()
            self.best_epoch = epoch
            return

        # A NaN criterion compares False above and therefore counts as "no improvement".
        self.wait += 1
        if self.wait >= self.patience:
            self.stopped_epoch = epoch
            self.model.stop_training = True
            self._restore_best_weights(announce=True)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}

        y_pred = self.model.predict(self.x)
        rmse_mean = np.nanmean(self.rmse(self.y, y_pred))
        r2_mean = np.nanmean(self.r2(self.y, y_pred))

        y_pred_val = self.model.predict(self.x_val)
        rmse_mean_val = np.nanmean(self.rmse(self.y_val, y_pred_val))
        r2_mean_val = np.nanmean(self.r2(self.y_val, y_pred_val))

        self.history['loss'].append(logs.get('loss'))
        self.history['val_loss'].append(logs.get('val_loss'))

        self.history['rmse'].append(rmse_mean)
        self.history['val_rmse'].append(rmse_mean_val)

        self.history['r2'].append(r2_mean)
        self.history['val_r2'].append(r2_mean_val)

        self.history['epoch'].append(epoch)

        if self.verbose:
            eph = str(epoch+1).zfill(4)
            print('\repoch: %s, loss: %s - val_loss: %s; rmse: %s - rmse_val: %s;  r2: %s - r2_val: %s' % (eph,
                                                                                                           self._fmt(logs.get('loss')),
                                                                                                           self._fmt(logs.get('val_loss')),
                                                                                                           self._fmt(rmse_mean),
                                                                                                           self._fmt(rmse_mean_val),
                                                                                                           self._fmt(r2_mean),
                                                                                                           self._fmt(r2_mean_val)),
                  end=100*' '+'\n')

        if self.criteria == 'val_loss':
            current = logs.get(self.criteria)
            if current is None:
                raise ValueError("criteria 'val_loss' needs 'val_loss' in the epoch logs; "
                                 "pass validation data to fit()")
        else:
            current = r2_mean_val
        self._update_early_stopping(current, epoch)

    def on_train_end(self, logs=None):
        self._restore_best_weights()
        if self.stopped_epoch > 0:
            print('\nEpoch %05d: early stopping' % (self.stopped_epoch + 1))


    def evaluate(self, testX, testY):
        """evaluate, return rmse and r2

        `testY` is treated as being in original units: only the predictions
        are inverse-transformed by `y_scaler`.
        """
        y_pred = self.model.predict(testX)
        rmse_list = self.rmse(testY, y_pred, inner_y_true = False)
        r2_list = self.r2(testY, y_pred, inner_y_true = False)
        return rmse_list, r2_list

## Classification

In [None]:
#export
class CLA_EarlyStoppingAndPerformance(tf.keras.callbacks.Callback):
    """The callback class used in classification problems.

    After every epoch it predicts on the training and validation data, records
    loss and the chosen metric in `self.history` (under the 'auc' / 'val_auc'
    keys whatever the metric), and stops training once the monitored criterion
    has not improved for `patience` consecutive epochs. The weights of the
    best epoch are restored when training ends.
    """
    _METRICS = ('ROC', 'PRC', 'ACC')

    def __init__(self, train_data, valid_data, MASK = -1, patience=5, criteria = 'val_loss', metric = 'ROC', last_avf = None, verbose = 0):
        """
        train_data: tuple `(x, y)` used to compute the training metric
        valid_data: tuple `(x_val, y_val)` used to compute the validation metric
        MASK: sentinel value in `y` marking entries to exclude from the metric
        patience: number of epochs without improvement tolerated before stopping
        criteria: 'val_loss' (lower is better) or 'val_auc' (higher is better)
        metric: 'ROC' (ROC-AUC), 'PRC' (precision-recall AUC) or 'ACC' (accuracy)
        last_avf: if None the predictions are passed through a sigmoid before
            scoring; otherwise they are scored as-is
            (presumably the model's last activation function -- confirm with callers)
        verbose: if truthy, print one summary line per epoch
        """
        super().__init__()

        sp = ['val_loss', 'val_auc']
        assert criteria in sp, 'not support %s ! only %s' % (criteria, sp)
        # Fail fast: an unknown metric previously surfaced only at the first epoch.
        assert metric in self._METRICS, 'not support %s ! only %s' % (metric, list(self._METRICS))
        self.x, self.y  = train_data
        self.x_val, self.y_val = valid_data
        self.last_avf = last_avf

        self.history = {'loss':[],
                        'val_loss':[],
                        'auc':[],
                        'val_auc':[],

                        'epoch':[]}
        self.MASK = MASK
        self.patience = patience
        # best_weights stores the weights of the best epoch seen so far.
        self.best_weights = None
        self.criteria = criteria
        self.metric = metric
        self.best_epoch = 0
        self.verbose = verbose

    @staticmethod
    def _fmt(value):
        "Format a metric with four decimals; a missing value (None) is shown as 'nan'."
        return '{0:.4f}'.format(np.nan if value is None else value)

    def sigmoid(self, x):
        "Element-wise logistic function."
        s = 1/(1+np.exp(-x))
        return s


    def roc_auc(self, y_true, y_pred):
        """Per-column score (ROC-AUC, PRC-AUC or accuracy, depending on `self.metric`).

        Entries whose target equals `self.MASK` are ignored. A column that
        cannot be scored contributes NaN. Returns a list with one score per
        output column.
        """
        if self.metric not in self._METRICS:
            raise ValueError('not support %s ! only %s' % (self.metric, list(self._METRICS)))

        # Without a final activation the model output is treated as logits.
        if self.last_avf is None:
            y_prob = self.sigmoid(y_pred)
        else:
            y_prob = y_pred

        N_classes = y_prob.shape[1]

        aucs = []
        for i in range(N_classes):
            y_prob_one_class = y_prob[:,i]
            y_true_one_class = y_true[:, i]
            mask = ~(y_true_one_class == self.MASK)
            try:
                if self.metric == 'ROC':
                    auc = roc_auc_score(y_true_one_class[mask], y_prob_one_class[mask]) #ROC_AUC
                elif self.metric == 'PRC':
                    auc = prc_auc_score(y_true_one_class[mask], y_prob_one_class[mask]) #PRC_AUC
                else:
                    auc = accuracy_score(y_true_one_class[mask], np.round(y_prob_one_class[mask])) #ACC
            except Exception:
                # Best effort: a column that cannot be scored becomes NaN and is
                # skipped by the nanmean in on_epoch_end. `Exception` (rather than
                # a bare except) lets KeyboardInterrupt/SystemExit propagate.
                auc = np.nan
            aucs.append(auc)
        return aucs


    def on_train_begin(self, logs=None):
        # The number of epochs waited since the criterion last improved.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # Start from the worst possible value for the chosen criterion.
        # (`np.inf`, not `np.Inf`: the capitalised alias was removed in NumPy 2.0.)
        if self.criteria == 'val_loss':
            self.best = np.inf
        else:
            self.best = -np.inf

    def _restore_best_weights(self, announce=False):
        "Load the best recorded weights into the model, if any epoch has been recorded."
        if self.best_weights is None:
            # No epoch ever counted as an improvement (e.g. the criterion was NaN).
            return
        if announce:
            print('\nRestoring model weights from the end of the best epoch.')
        self.model.set_weights(self.best_weights)

    def _update_early_stopping(self, current, epoch):
        "Record `current` if it is the best so far, otherwise count towards `patience`."
        if self.criteria == 'val_loss':
            improved = current <= self.best   # lower is better
        else:
            improved = current >= self.best   # higher is better

        if improved:
            self.best = current
            self.wait = 0
            self.best_weights = self.model.get_weights()
            self.best_epoch = epoch
            return

        # A NaN criterion compares False above and therefore counts as "no improvement".
        self.wait += 1
        if self.wait >= self.patience:
            self.stopped_epoch = epoch
            self.model.stop_training = True
            self._restore_best_weights(announce=True)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}

        y_pred = self.model.predict(self.x)
        roc_mean = np.nanmean(self.roc_auc(self.y, y_pred))

        y_pred_val = self.model.predict(self.x_val)
        roc_val_mean = np.nanmean(self.roc_auc(self.y_val, y_pred_val))

        self.history['loss'].append(logs.get('loss'))
        self.history['val_loss'].append(logs.get('val_loss'))
        self.history['auc'].append(roc_mean)
        self.history['val_auc'].append(roc_val_mean)
        self.history['epoch'].append(epoch)

        if self.verbose:
            eph = str(epoch+1).zfill(4)
            # Only the label differs between accuracy and the two AUC metrics.
            label = 'acc' if self.metric == 'ACC' else 'auc'
            print('\repoch: %s, loss: %s - val_loss: %s; %s: %s - val_%s: %s' % (eph,
                                                                                 self._fmt(logs.get('loss')),
                                                                                 self._fmt(logs.get('val_loss')),
                                                                                 label,
                                                                                 self._fmt(roc_mean),
                                                                                 label,
                                                                                 self._fmt(roc_val_mean)), end=100*' '+'\n')

        if self.criteria == 'val_loss':
            current = logs.get(self.criteria)
            if current is None:
                raise ValueError("criteria 'val_loss' needs 'val_loss' in the epoch logs; "
                                 "pass validation data to fit()")
        else:
            current = roc_val_mean
        self._update_early_stopping(current, epoch)

    def on_train_end(self, logs=None):
        self._restore_best_weights()
        if self.stopped_epoch > 0:
            print('\nEpoch %05d: early stopping' % (self.stopped_epoch + 1))


    def evaluate(self, testX, testY):
        "Predict on `testX` and return the per-column scores against `testY`."
        y_pred = self.model.predict(testX)
        roc_list = self.roc_auc(testY, y_pred)
        return roc_list

            