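This module contains helper functions for measuring the performance of a model: reconstruction-error-based anomaly detection with an autoencoder, plus metrics and plots (confusion matrix, precision/recall, ROC/AUC) for signal-quality predictions.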

In [None]:
import numpy as np
import h5py
import os
from pickle import dump, load
from sklearn.metrics import precision_recall_curve, confusion_matrix, roc_curve, auc
import seaborn as sns
from tqdm.notebook import trange, tqdm

import tensorflow as tf

In [None]:
# In an interactive notebook session (no __file__ in globals), execute the
# data-preparation notebook so its definitions are available to later cells --
# presumably create_dataset_from_file (used by performance_per_test_file) and
# possibly plt, which this file uses but never imports; confirm against that notebook.
if __name__ == '__main__' and '__file__' not in globals():
  %run Code/Final/DataPreparation.ipynb

In [None]:
def find_errors(data, model, error_fun='mse'):
  '''
  Compute the per-sample reconstruction error of an autoencoder.

  INPUT
  data: scaled data with input shape that depends on the model; the first
        axis indexes samples
  model: an autoencoder (any object with a Keras-style `predict` method) which
         will be used to reconstruct the input data
  error_fun: 'mse' (Mean Square Error) or 'mae' (Mean Absolute Error)
  OUTPUT
  errors: 1-D array with one reconstruction error per sample, averaged over
          every non-sample axis
  RAISES
  ValueError: if error_fun is neither 'mse' nor 'mae'
  '''
  # Validate before the (potentially expensive) forward pass; previously an
  # unknown value fell through and crashed with UnboundLocalError.
  if error_fun not in ('mse', 'mae'):
    raise ValueError("error_fun must be 'mse' or 'mae', got %r" % (error_fun,))

  data = np.asarray(data)
  decoded = np.asarray(model.predict(data))
  residual = data - decoded

  # Average over every axis except the sample axis. For the 3-D inputs this
  # function was written for, this is exactly axis=(1, 2).
  per_sample_axes = tuple(range(1, residual.ndim))
  if error_fun == 'mse':
    return np.mean(residual ** 2, axis=per_sample_axes)
  return np.mean(np.abs(residual), axis=per_sample_axes)


In [None]:
def find_anomalies(data, model, quantile, error_fun='mse'):
  '''
  Return the indices of samples whose reconstruction error is anomalous.

  The `quantile`-th quantile of the reconstruction errors serves as the
  threshold: any sample the model reconstructed with an error at or above
  it is marked as an outlier.

  INPUT
  data: scaled data with input shape that depends on the model
  model: an autoencoder which will be used to reconstruct the input data
  quantile: quantile (in [0, 1]) of the reconstruction errors at or above
            which a sample is considered an anomaly
  error_fun: Mean Square Error ('mse') or Mean Absolute Error ('mae')
  OUTPUT
  anomaly_idxs: the indices of the anomalous samples, in increasing order
  '''
  reconstruction_errors = np.array(find_errors(data, model, error_fun))
  cutoff = np.quantile(reconstruction_errors, quantile)
  is_anomalous = reconstruction_errors >= cutoff
  return np.nonzero(is_anomalous)[0]

In [None]:
def performance(data, leftovers, model, quantile, quality_file, path, error_fun='mse', snippet_len=125, samples_per_label=125):
  '''
  (Not used in the final model)
  Evaluate anomaly detection on one recording against its quality labels.

  This is for a specific recording (data) with its specific quality_file; it
  doesn't take historical data into consideration. A window whose
  reconstruction error is at or above the `quantile`-th quantile is predicted
  as low quality (positive).

  INPUT
  data: the scaled windows of one recording; data.shape[1] is used as the
        window length in samples (presumably (n_windows, window_len, ...) --
        confirm against create_dataset_from_file)
  leftovers: how many data points do not have a corresponding quality label (obsolete)
  model: the autoencoder used to reconstruct `data`
  quantile: quantile of the reconstruction errors used as anomaly threshold
  quality_file: HDF5 file whose 'data' dataset holds the quality labels
                (0 = low quality, 1 = normal quality)
  path: unused; kept for backward compatibility
  error_fun: 'mse' or 'mae'
  snippet_len: unused; kept for backward compatibility
  samples_per_label: number of signal samples covered by one quality label
                     (previously hard-coded as 125)
  OUTPUT
  (accuracy, precision, recall); a metric whose denominator is zero (e.g. a
  recording with no low-quality labels) is NaN instead of raising
  ZeroDivisionError.
  '''
  errors = np.asarray(find_errors(data, model, error_fun))
  thresh = np.quantile(errors, quantile)

  # Predicted classes, as indices into the windows of `data`.
  anomaly_idxs = np.where(errors >= thresh)[0]
  normal_idxs = np.where(errors < thresh)[0]

  # The context manager closes the HDF5 handle even if reading fails.
  with h5py.File(quality_file, 'r') as file_signal:
    signal_quality = file_signal['data'][()]

  signal_quality = signal_quality.flatten()
  # Align the labels with the windows: skip the labels before the first full
  # window and drop those covering the trailing `leftovers` samples. Note that
  # `-leftovers // k` is `(-leftovers) // k`, i.e. minus ceil(leftovers / k).
  # A zero `leftovers` needs its own branch because `[first:0]` would be empty.
  first = data.shape[1] // samples_per_label - 1
  if leftovers != 0:
    signal_quality = signal_quality[first:-leftovers // samples_per_label]
  else:
    signal_quality = signal_quality[first:]

  # Ground truth: positive = low quality = 0, negative = normal quality = 1.
  anomaly_indices = np.where(signal_quality == 0)[0]
  normal_indices = np.where(signal_quality == 1)[0]

  tp = len(np.intersect1d(anomaly_indices, anomaly_idxs))
  fp = len(np.intersect1d(normal_indices, anomaly_idxs))
  tn = len(np.intersect1d(normal_indices, normal_idxs))
  fn = len(np.intersect1d(anomaly_indices, normal_idxs))

  # Undefined ratios become NaN so one degenerate recording does not abort
  # a whole evaluation run.
  recall = tp / (tp + fn) if tp + fn else float('nan')
  precision = tp / (tp + fp) if tp + fp else float('nan')
  n_labelled = len(normal_indices) + len(anomaly_indices)
  accuracy = (tp + tn) / n_labelled if n_labelled else float('nan')

  print('accuracy = %.2f \nprecision = %.2f \nrecall = %.2f' % (accuracy, precision, recall))

  return accuracy, precision, recall

In [None]:
def performance_per_test_file(model, quantile=0.90, path=r'Data/fetal_quality_assessment', error_fun='mse', snippet_len=125):
  '''
  (Not used in the final model)
  Run `performance` over every recording in the `<path>/test` folder.

  Each recording's labels are read from
  `<path>/quality_files/<file name without its last 12 characters>.quality`.

  INPUT
  model: the autoencoder to evaluate
  quantile: quantile of the reconstruction errors used as anomaly threshold
  path: root folder containing the `test` and `quality_files` subfolders
  error_fun: 'mse' or 'mae'
  snippet_len: window length passed to create_dataset_from_file
  OUTPUT
  perform: dict mapping each test file name to (accuracy, precision, recall)
  '''
  test_dir = os.path.join(path, 'test')
  quality_dir = os.path.join(path, 'quality_files')

  perform = dict()

  # os.listdir returns entries in arbitrary order; sort so the printed output
  # and the dict order are reproducible across runs and machines.
  for name in tqdm(sorted(os.listdir(test_dir))):
    print(name)
    # Replace the recording file's 12-character suffix to get its label file.
    quality_file = name[:-12] + '.quality'

    data, leftovers = create_dataset_from_file(os.path.join(test_dir, name), snippet_len, sliding_window=True, trim_zeros=False, return_leftovers=True)

    accuracy, precision, recall = performance(data, leftovers, model, quantile, os.path.join(quality_dir, quality_file), path, error_fun, snippet_len)

    perform[name] = accuracy, precision, recall

  return perform


In [None]:
def confusion(y_true, y_likelihoods, thr=0.5):
  '''
  Present the confusion matrix of thresholded quality predictions.

  INPUT
  y_true: the true labels of the data (low quality -> 1, high quality -> 0)
  y_likelihoods: the likelihood that a sample is of bad quality (a number between 0 and 1)
  thr: the threshold the likelihood must strictly exceed for a sample to be
       considered of bad quality (tunable)
  '''
  # asarray so plain lists compare element-wise.
  y_true = np.asarray(y_true)
  y_likelihoods = np.asarray(y_likelihoods)

  # Strictly above the threshold -> bad quality (1); everything else, including
  # likelihoods exactly equal to `thr`, stays good quality (0).
  y_pred = np.zeros(y_true.size)
  y_pred[np.where(y_likelihoods > thr)[0]] = 1

  LABELS = ['Good quality', 'Bad quality']

  # Pin the label order so the matrix is always 2x2 and lines up with LABELS,
  # even when only one class occurs in y_true and y_pred.
  conf_matrix = confusion_matrix(y_true, y_pred, labels=[0, 1])
  plt.figure(figsize=(16, 10))
  sns.heatmap(conf_matrix, xticklabels=LABELS, yticklabels=LABELS, annot=True, fmt="d")
  plt.title("Confusion matrix")
  plt.ylabel('True class')
  plt.xlabel('Predicted class')
  plt.show()

In [None]:
def scatter(y_true, y_likelihoods):
  '''
  Present a scatter plot where each sample is a point: its index against its
  likelihood, coloured by its true quality.

  INPUT
  y_true: the true labels of the data (low quality -> 1, high quality -> 0)
  y_likelihoods: the likelihood that a sample is of bad quality (a number between 0 and 1)
  '''
  # asarray so plain lists compare element-wise and support fancy indexing;
  # a list `== 0` would otherwise be a single False and plot nothing.
  y_true = np.asarray(y_true)
  y_likelihoods = np.asarray(y_likelihoods)

  good_indices = np.where(y_true == 0)[0]
  bad_indices = np.where(y_true == 1)[0]
  good_points = y_likelihoods[good_indices]
  bad_points = y_likelihoods[bad_indices]

  plt.figure(figsize=(16,10))
  plt.scatter(good_indices, good_points, label='high quality', c='b', marker='.')
  plt.scatter(bad_indices, bad_points, label='low quality', c='r', marker='.')
  plt.ylabel("Likelihood")
  plt.xlabel("Data point index")

  plt.legend()
  plt.show()

In [None]:
def precision_recall_curv(y_true, y_likelihoods):
  '''
  Present the precision-recall curve, in order to assess the model and tune
  the threshold.

  The x-axis is the rank of each candidate threshold returned by
  precision_recall_curve, rescaled to 0-100 (not the threshold value itself).
  The first point where the precision and recall curves cross is printed as
  (x,y).

  INPUT
  y_true: the true labels of the data (low quality -> 1, high quality -> 0)
  y_likelihoods: the likelihood that a sample is of bad quality (a number between 0 and 1)
  '''
  precision_rt, recall_rt, threshold_rt = precision_recall_curve(y_true, y_likelihoods)

  # precision_rt / recall_rt have one more element than threshold_rt: element
  # i belongs to threshold_rt[i], and the final (precision=1, recall=0) entry
  # has no threshold. Drop that entry so all three arrays line up (slicing
  # with [1:] would shift every point by one threshold).
  precision_rt = precision_rt[:-1]
  recall_rt = recall_rt[:-1]

  # threshold_rt is increasing, so a threshold's rank is its position; rescale
  # ranks to 0-100. linspace also avoids a division by zero when there is only
  # a single threshold.
  percentiles = np.linspace(0.0, 100.0, num=len(threshold_rt))

  # find intersection: where precision - recall changes sign
  idx = np.argwhere(np.diff(np.sign(precision_rt - recall_rt))).flatten()
  if idx.size:
    x = percentiles[idx[0]]
    y = precision_rt[idx[0]]
    print('(x,y)=(%.2f, %.2f)' % (x,y))
  else:
    print('Precision and recall curves do not intersect.')

  plt.figure(figsize=(16,10))
  plt.plot(percentiles, precision_rt, label="Precision", linewidth=2)
  plt.plot(percentiles, recall_rt, label="Recall", linewidth=2)
  plt.title('Precision and recall for different threshold values')
  plt.xlabel('Threshold')
  plt.ylabel('Precision/Recall')
  plt.legend()
  plt.show()

In [None]:
def roc_auc(y_true, y_likelihoods):
  '''
  Plot the ROC curve of the likelihoods and report its AUC in the legend.

  INPUT
  y_true: the true labels of the data (low quality -> 1, high quality -> 0)
  y_likelihoods: the likelihood a sample is good or bad quality (a number between 0 and 1)
  '''
  fpr, tpr, _ = roc_curve(y_true, y_likelihoods)
  # Named `area` so the function's own name is not shadowed inside its body.
  area = auc(fpr, tpr)
  diagonal = [0, 1]

  plt.figure(figsize=(12, 12))
  plt.plot(fpr, tpr, linewidth=5, label='AUC = %0.3f' % area)
  # Reference diagonal from (0, 0) to (1, 1).
  plt.plot(diagonal, diagonal, linewidth=5)
  plt.xlim([-0.01, 1])
  plt.ylim([0, 1.01])
  plt.legend(loc='lower right')
  plt.title('Receiver operating characteristic curve (ROC)')
  plt.ylabel('True Positive Rate')
  plt.xlabel('False Positive Rate')
  plt.show()


In [None]:
def precision_recall_metrics(y_true, y_likelihoods, thr):
  '''
  Calculate, print and return precision, recall and accuracy.

  A sample is predicted as bad quality (positive) when its likelihood is
  strictly greater than `thr`; otherwise it is predicted as good quality.

  INPUT
  y_true: the true labels of the data (low quality -> 1, high quality -> 0)
  y_likelihoods: the likelihood that a sample is of bad quality (a number between 0 and 1)
  thr: the threshold the likelihood must exceed to consider a sample of bad quality (tunable)
  OUTPUT
  (precision, recall, accuracy) as floats; a metric whose denominator is zero
  (e.g. nothing predicted as bad quality) is NaN instead of raising
  ZeroDivisionError
  RAISES
  ValueError: if y_true and y_likelihoods hold a different number of samples
  '''
  y_true = np.asarray(y_true).ravel()
  y_likelihoods = np.asarray(y_likelihoods).ravel()
  if y_true.size != y_likelihoods.size:
    raise ValueError('y_true and y_likelihoods must have the same number of samples, got %d and %d'
                     % (y_true.size, y_likelihoods.size))

  real_pos = y_true == 1
  real_neg = y_true == 0
  pred_pos = y_likelihoods > thr
  pred_neg = ~pred_pos

  tp = int(np.count_nonzero(real_pos & pred_pos))
  tn = int(np.count_nonzero(real_neg & pred_neg))
  fp = int(np.count_nonzero(real_neg & pred_pos))
  fn = int(np.count_nonzero(real_pos & pred_neg))

  def _ratio(num, den):
    # NaN marks an undefined metric instead of crashing the caller.
    return num / den if den else float('nan')

  precision = _ratio(tp, tp + fp)
  recall = _ratio(tp, tp + fn)
  accuracy = _ratio(tp + tn, y_true.size)

  print('Precision = %.4f \nRecall = %.4f \nAccuracy = %.4f' % (precision, recall, accuracy))

  return precision, recall, accuracy

Previous versions of the performance-metric functions (kept for reference only; commented out)

In [None]:
# def scatter_plot(model, file_path, quality_path, quantile=None, snippet_len=125, scaler_path='Models/ALL_scaler.pkl'):

#   scaler = load(open(scaler_path, 'rb'))
#   d, leftovers = create_dataset_from_file(file_path, snippet_len, sliding_window=True, trim_zeros=False, return_leftovers=True)
#   d = scale_data(d, scaler, 125)
#   errors = find_errors(d, model)
#   segment_errors=errors[:, -125:]
#   mean_errors = np.mean(segment_errors, axis=1)

#   file_signal = h5py.File(quality_path, 'r')
#   signal_quality = file_signal['data'][()]
#   file_signal.close()

#   signal_quality = remove_nan(signal_quality)
#   signal_quality = signal_quality.flatten()

#   if leftovers!=0:
#     signal_quality = signal_quality[snippet_len//125-1:-leftovers//125]
#   else:
#     signal_quality = signal_quality[snippet_len//125-1:]

#   bad_indices = np.where(signal_quality == 0)[0]
#   good_indices = np.where(signal_quality == 1)[0]

#   good_points = mean_errors[good_indices]
#   bad_points = mean_errors[bad_indices]

#   plt.figure(figsize=(16,10))
#   # plt.xlim(6000, 8000)
#   # plt.ylim(0, 100)
#   plt.scatter(good_indices, good_points, label='hign quality', c='b')
#   plt.scatter(bad_indices, bad_points, label='low quality', c='r')
#   plt.title("Reconstruction error")
#   plt.ylabel("Reconstruction error")
#   plt.xlabel("Data point index")

#   if quantile is not None:
#     thr = np.quantile(mean_errors, quantile)
#     plt.hlines(thr, 0, len(signal_quality), colors="g", zorder=100, label='Threshold')

#   plt.legend()
#   plt.show()





# def precision_recall(model, file_path, quality_path, snippet_len=125, scaler_path='Models/ALL_scaler.pkl'):

#   scaler = load(open(scaler_path, 'rb'))
#   d, leftovers = create_dataset_from_file(file_path, snippet_len, sliding_window=True, trim_zeros=False, return_leftovers=True)
#   d = scale_data(d, scaler, 125)
#   errors = find_errors(d, model)
#   segment_errors=errors[:, -125:]
#   mean_errors = np.mean(segment_errors, axis=1)

#   file_signal = h5py.File(quality_path, 'r')
#   signal_quality = file_signal['data'][()]
#   file_signal.close()

#   signal_quality = remove_nan(signal_quality)
#   signal_quality = signal_quality.flatten()

#   if leftovers!=0:
#     signal_quality = signal_quality[snippet_len//125-1:-leftovers//125]
#   else:
#     signal_quality = signal_quality[snippet_len//125-1:]

#   y_true = np.logical_xor(signal_quality, np.ones(len(signal_quality)))  # because we consider the bad ones positives
#   precision_rt, recall_rt, threshold_rt = precision_recall_curve(y_true, mean_errors)
#   percentiles = np.argsort(threshold_rt) * 100. / (len(threshold_rt) - 1)
#   plt.figure(figsize=(16,10))
#   plt.plot(percentiles, precision_rt[1:], label="Precision",linewidth=2)
#   plt.plot(percentiles, recall_rt[1:], label="Recall",linewidth=2)
#   plt.title('Precision and recall for different threshold values')
#   plt.xlabel('Threshold')
#   plt.ylabel('Precision/Recall')
#   plt.legend()
#   plt.show()




# def confusion(model, file_path, quality_path, quantile, snippet_len=125, scaler_path='Models/ALL_scaler.pkl'):

#   scaler = load(open(scaler_path, 'rb'))
#   d, leftovers = create_dataset_from_file(file_path, snippet_len, sliding_window=True, trim_zeros=False, return_leftovers=True)
#   d = scale_data(d, scaler, 125)
#   errors = find_errors(d, model)
#   segment_errors=errors[:, -125:]
#   mean_errors = np.mean(segment_errors, axis=1)

#   file_signal = h5py.File(quality_path, 'r')
#   signal_quality = file_signal['data'][()]
#   file_signal.close()

#   signal_quality = remove_nan(signal_quality)
#   signal_quality = signal_quality.flatten()

#   if leftovers!=0:
#     signal_quality = signal_quality[snippet_len//125-1:-leftovers//125]
#   else:
#     signal_quality = signal_quality[snippet_len//125-1:]

#   thresh = np.quantile(mean_errors, quantile)
#   anomaly_idxs = np.where(np.array(mean_errors) >= thresh)[0]
#   y_pred = np.zeros(len(signal_quality))
#   y_pred[anomaly_idxs] = 1
#   y_true = np.logical_xor(signal_quality, np.ones(len(signal_quality)))  # because we consider the bad ones positives
#   LABELS = ['Good quality', 'Bad quality']

#   conf_matrix = confusion_matrix(y_true, y_pred)
#   plt.figure(figsize=(16, 10))
#   sns.heatmap(conf_matrix, xticklabels=LABELS, yticklabels=LABELS, annot=True, fmt="d");
#   plt.title("Confusion matrix")
#   plt.ylabel('True class')
#   plt.xlabel('Predicted class')
#   plt.show()



# def roc_auc(model, file_path, quality_path, snippet_len=125, scaler_path='Models/ALL_scaler.pkl'):

#   scaler = load(open(scaler_path, 'rb'))
#   d, leftovers = create_dataset_from_file(file_path, snippet_len, sliding_window=True, trim_zeros=False, return_leftovers=True)
#   d = scale_data(d, scaler, 125)
#   errors = find_errors(d, model)
#   segment_errors=errors[:, -125:]
#   mean_errors = np.mean(segment_errors, axis=1)

#   file_signal = h5py.File(quality_path, 'r')
#   signal_quality = file_signal['data'][()]
#   file_signal.close()

#   signal_quality = remove_nan(signal_quality)
#   signal_quality = signal_quality.flatten()

#   if leftovers!=0:
#     signal_quality = signal_quality[snippet_len//125-1:-leftovers//125]
#   else:
#     signal_quality = signal_quality[snippet_len//125-1:]

#   y_true = np.logical_xor(signal_quality, np.ones(len(signal_quality)))  # because we consider the bad ones positives

#   false_pos_rate, true_pos_rate, thresholds = roc_curve(y_true, mean_errors)
#   roc_auc = auc(false_pos_rate, true_pos_rate,)

#   plt.figure(figsize=(12, 12))
#   plt.plot(false_pos_rate, true_pos_rate, linewidth=5, label='AUC = %0.3f'% roc_auc)
#   plt.plot([0,1],[0,1], linewidth=5)
#   plt.xlim([-0.01, 1])
#   plt.ylim([0, 1.01])
#   plt.legend(loc='lower right')
#   plt.title('Receiver operating characteristic curve (ROC)')
#   plt.ylabel('True Positive Rate')
#   plt.xlabel('False Positive Rate')
#   plt.show()




