In [None]:
import copy
import warnings
import numpy as np
from sklearn.metrics import auc, roc_auc_score, precision_recall_curve

# Metric definitions

## Basic metrics

In [None]:
class BasicMetrics:
    """Confusion-matrix style counts and squared error for one set of predictions.

    ``event_observed`` holds the ground-truth labels and ``preds`` the model
    output; both are deep-copied at construction. All counts are computed with
    element-wise arithmetic (``preds * labels``, ``1 - preds``), so ``preds``
    is presumably a vector of 0/1 decisions or scores in [0, 1] -- TODO confirm
    with the caller.
    """

    def __init__(self, df, preds):
        """Store private copies of the labels and the predictions.

        Args:
            df: Object exposing an ``event_observed`` attribute with one label
                per sample.
            preds: Predictions, element-wise compatible with
                ``df.event_observed``.
        """
        self.event_observed = copy.deepcopy(df.event_observed)
        self.preds = copy.deepcopy(preds)


    def _event_not_observed(self):
        """Return the logical negation of the labels as a boolean mask.

        ``~`` is a *bitwise* complement on integer data (``~1 == -2``,
        ``~0 == -1``), which would turn the negative-class counts into negative
        numbers for integer-coded 0/1 labels. Non-boolean labels are therefore
        cast to bool first; labels that are already boolean are negated as-is.
        """
        labels = self.event_observed
        if labels.dtype.kind != "b":
            labels = labels.astype(bool)
        return ~labels


    def get_n_rp(self):
        """Return the sum of the labels (number of observed events)."""
        n_rp = self.event_observed.sum()
        assert 0 <= n_rp, f"{n_rp=}"
        return n_rp


    def get_n_rn(self):
        """Return the number of samples whose label is falsy (no event observed)."""
        n_rn = self._event_not_observed().sum()
        assert 0 <= n_rn, f"{n_rn=}"
        return n_rn


    def get_n_tp(self):
        """Return the sum of ``preds`` over samples with an observed event."""
        n_tp = (self.preds * self.event_observed).sum()
        assert 0 <= n_tp, f"{n_tp=}"
        return n_tp


    def get_n_fp(self):
        """Return the sum of ``preds`` over samples without an observed event."""
        n_fp = (self.preds * self._event_not_observed()).sum()
        assert 0 <= n_fp, f"{n_fp=}"
        return n_fp


    def get_n_tn(self):
        """Return the sum of ``1 - preds`` over samples without an observed event."""
        n_tn = ((1 - self.preds) * self._event_not_observed()).sum()
        assert 0 <= n_tn, f"{n_tn=}"
        return n_tn


    def get_n_fn(self):
        """Return the sum of ``1 - preds`` over samples with an observed event."""
        n_fn = ((1 - self.preds) * self.event_observed).sum()
        assert 0 <= n_fn, f"{n_fn=}"
        return n_fn


    def get_n_pp(self):
        """Return the sum of ``preds`` (number of predicted positives)."""
        n_pp = self.preds.sum()
        assert 0 <= n_pp, f"{n_pp=}"
        return n_pp


    def get_n_np(self):
        """Return the sum of ``1 - preds`` (number of predicted negatives)."""
        n_np = (1 - self.preds).sum()
        assert 0 <= n_np, f"{n_np=}"
        return n_np


    def get_mse(self):
        """Return the mean squared difference between labels and predictions.

        The assertion requires the result to lie in [0, 1].
        """
        # The subtraction builds a new array, so no defensive copy of preds is needed.
        mse = ((self.event_observed - self.preds)**2).mean()
        assert 0 <= mse <= 1, f"{mse=}"
        return mse


    def get_inverted_mse(self):
        """Return ``1 - MSE`` so that larger values mean better predictions."""
        inverted_mse = 1 - self.get_mse()
        return inverted_mse

## Advanced metrics

In [None]:
class AdvancedMetrics:
    """Evaluate a classifier on one data split and report a fixed set of metrics.

    At construction the data split is generated, the classifier is asked for
    its predictions on it, and a ``BasicMetrics`` helper is built from both.
    Calling the instance returns a dict with one value per name listed in
    ``_metric_names``.
    """

    def __init__(self, stats_classifier, df_generator, part, adjust=True):
        """Generate the data, compute predictions and prepare the count helper.

        Args:
            stats_classifier: Object exposing a ``horizon`` attribute and a
                ``get_prediction(df)`` method.
            df_generator: Callable accepting ``horizon`` and ``adjust`` keyword
                arguments and returning a container indexable by ``part``.
            part: Key selecting which generated frame to evaluate.
            adjust: Forwarded unchanged to ``df_generator``.
        """
        self.stats_classifier = stats_classifier
        self.df_generator = df_generator
        self.part = part

        dfs = df_generator(horizon=stats_classifier.horizon, adjust=adjust)
        self.df = dfs[part]

        self.preds = self.stats_classifier.get_prediction(self.df)

        self.basic_metrics = BasicMetrics(
            df=self.df,
            preds=self.preds,
        )


    def _auroc_model(self):
        """Return the area under the ROC curve of ``preds`` against the labels.

        When the labels contain a single class only (or are empty) the value
        1.0 is returned without calling ``roc_auc_score``.
        """
        observed = self.df.event_observed
        # `~` is a bitwise complement on integer data (~1 == -2, ~0 == -1), which
        # would make `(~observed).all()` true for every integer-coded label
        # vector. Cast non-boolean labels to bool so the negation is logical.
        if observed.dtype.kind != "b":
            observed = observed.astype(bool)

        if (~observed).all() or observed.all():
            return 1.

        probs = copy.deepcopy(self.preds)
        auroc = roc_auc_score(self.df.event_observed, probs)

        assert 0 <= auroc <= 1, f"{auroc=}"
        return auroc


    def _auprc_model(self):
        """Return the area under the precision-recall curve of ``preds``.

        ``UserWarning``s raised while building the curve are suppressed, and
        NaN recall values are replaced by 0 before integrating (presumably they
        arise when there are no positive labels -- TODO confirm).
        """
        probs = copy.deepcopy(self.preds)

        with warnings.catch_warnings():
            warnings.filterwarnings(action='ignore', category=UserWarning)
            precision, recall, _ = precision_recall_curve(self.df.event_observed, probs)

        recall[np.isnan(recall)] = 0
        auprc = auc(recall, precision)

        # Small tolerance above 1 allows for floating-point error in the integration.
        assert 0 <= auprc <= 1 + 1e-5, f"{auprc=}"
        return auprc


    def _acc_model(self):
        """Return ``(n_tp + n_tn) / (n_rp + n_rn)``, or 0 when the denominator is 0."""
        n_tp = self.basic_metrics.get_n_tp()
        n_tn = self.basic_metrics.get_n_tn()

        n_rp = self.basic_metrics.get_n_rp()
        n_rn = self.basic_metrics.get_n_rn()

        if n_rp + n_rn == 0:
            return 0

        acc = (n_tp + n_tn) / (n_rp + n_rn)

        assert n_tp + n_tn <= n_rp + n_rn
        assert 0 <= acc <= 1, f"{acc=}"
        return acc


    def _tpr_model(self):
        """Return the sensitivity ``n_tp / n_rp``, or 0 when ``n_tp`` is 0.

        A ``1e-9`` term is added to the denominator, so the result is
        marginally below the exact ratio.
        """
        n_tp = self.basic_metrics.get_n_tp()

        if n_tp == 0:
            return 0

        n_rp = self.basic_metrics.get_n_rp()
        tpr = n_tp / (n_rp + 1e-9)

        assert n_tp <= n_rp
        assert 0 <= tpr <= 1, f"{tpr=}"
        return tpr


    def _tnr_model(self):
        """Return the specificity ``n_tn / n_rn``, or 0 when ``n_tn`` is 0.

        A ``1e-9`` term is added to the denominator, so the result is
        marginally below the exact ratio.
        """
        n_tn = self.basic_metrics.get_n_tn()

        if n_tn == 0:
            return 0

        n_rn = self.basic_metrics.get_n_rn()
        tnr = n_tn / (n_rn + 1e-9)

        assert n_tn <= n_rn
        assert 0 <= tnr <= 1, f"{tnr=}"
        return tnr


    def _baa_model(self):
        """Return the balanced accuracy: the mean of sensitivity and specificity."""
        tpr = self._tpr_model()
        tnr = self._tnr_model()
        return (tpr + tnr) / 2


    def _you_model(self):
        """Return ``abs(sensitivity + specificity - 1)``.

        NOTE(review): because of the absolute value, a classifier that is worse
        than chance scores the same as one equally far above chance.
        """
        tpr = self._tpr_model()
        tnr = self._tnr_model()
        you = abs(tpr + tnr - 1)

        assert 0 <= you <= 1, f"{you=}"
        return you


    def _pre_model(self):
        """Return the precision ``n_tp / n_pp``, or 0 when ``n_tp`` is 0.

        A ``1e-9`` term is added to the denominator, so the result is
        marginally below the exact ratio.
        """
        n_tp = self.basic_metrics.get_n_tp()

        if n_tp == 0:
            return 0

        n_pp = self.basic_metrics.get_n_pp()
        pre = n_tp / (n_pp + 1e-9)

        assert n_tp <= n_pp
        assert 0 <= pre <= 1, f"{pre=}"
        return pre


    def _fβ_model(self, β):
        """Return the F-beta score built from precision and sensitivity.

        Args:
            β: Weight of sensitivity relative to precision in
                ``(1 + β²) · pre · tpr / (β² · pre + tpr)``.
        """
        pre = self._pre_model()
        tpr = self._tpr_model()
        # The epsilon keeps the division defined when precision and sensitivity are both 0.
        fβ = (1 + β**2) * pre * tpr / (β**2 * pre + tpr + 1e-9)

        assert 0 <= fβ <= 1, f"{fβ=}"
        return fβ


    def _metric_names(self):
        """Return the metric names in the order ``__call__`` reports them."""
        metric_names = ["F2", "Precision", "AUPRC", "Accuracy", "Balanced accuracy", "Youden", "Sensitivity", "F1", "F_0.5", "AUROC", "Specificity", "1 - MSE"]
        return metric_names


    def __call__(self):
        """Compute every metric and return them as a name -> value dict.

        The key order is checked against ``_metric_names`` so the two
        definitions cannot silently drift apart.
        """
        metrics = {
            "F2": self._fβ_model(β=2),
            "Precision": self._pre_model(),
            "AUPRC": self._auprc_model(),
            "Accuracy": self._acc_model(),
            "Balanced accuracy": self._baa_model(),
            "Youden": self._you_model(),
            "Sensitivity": self._tpr_model(),
            "F1": self._fβ_model(β=1),
            "F_0.5": self._fβ_model(β=.5),
            "AUROC": self._auroc_model(),
            "Specificity": self._tnr_model(),
            "1 - MSE": self.basic_metrics.get_inverted_mse(),
        }

        assert list(metrics.keys()) == self._metric_names(), "Metrics returned and expected are out of sync."
        return metrics