In [None]:
import pandas as pd
from pyprojroot import here
from sklearn.metrics import precision_recall_curve, auc
import matplotlib.pyplot as plt

from colors import *

from nutils import bootstrap_auc

import warnings
# Install a process-wide "ignore" filter: every warning raised afterwards in this
# kernel session is suppressed (presumably to keep the plot cell's output clean).
# NOTE(review): this also hides sklearn/pandas warnings that can flag real data
# problems — consider narrowing it with `category=` / `module=` arguments.
warnings.filterwarnings('ignore')

In [None]:
# Precision-recall curves for every prediction file matching '*lgbm-11-*.csv',
# one curve per file (the 'cri' target is skipped). Each legend entry shows the
# PR-AUC and the (lb-ub) interval returned by `bootstrap_auc`; the PR-AUC of each
# file is also collected in `auc_values`, keyed by TARGET-MODEL-ORIGIN-FS-HPO.
fig, ax = plt.subplots(figsize=(4, 4))

PROB_DIR = here() / 'data/processed/matrices/prob'
TRUE_DIR = here() / 'data/processed/true_matrices'

auc_values = {}

for path in PROB_DIR.glob('*lgbm-11-*.csv'):
    # File stems are expected to look like TARGET-MODEL-ORIGIN-FS-HPO.
    TARGET, MODEL, ORIGIN, FS, HPO = path.stem.split('-')[:5]

    if TARGET == 'cri':
        continue

    pred_name = f'{TARGET}-{MODEL}-{ORIGIN}-{FS}-{HPO}'

    y_true = pd.read_csv(TRUE_DIR / f'{TARGET}.csv', index_col='Datetime')
    # Read the globbed file itself rather than rebuilding its path from the parts.
    y_pred = pd.read_csv(path, index_col='Datetime')

    # Align truth and prediction on the shared Datetime index; drop rows where either is missing.
    df = pd.concat([y_true, y_pred], axis=1).dropna()
    # Each CSV must contribute exactly one value column, otherwise iloc[:, 1]
    # would silently pick the wrong series.
    assert df.shape[1] == 2, (
        f'{path.name}: expected 1 true + 1 pred column, got {df.shape[1]} columns'
    )
    y_true = df.iloc[:, 0]
    y_pred = df.iloc[:, 1]

    precision, recall, _ = precision_recall_curve(y_true, y_pred, pos_label=1)
    _auc = auc(recall, precision)
    lb, ub = bootstrap_auc(y_true, y_pred)
    auc_values[pred_name] = _auc

    # Baseline "guess" models get a sparse dotted line to stand apart from real predictions.
    if MODEL == 'guess':
        kind, linestyle = 'Guess', (0, (1, 5))
    else:
        kind, linestyle = 'Pred', '-'

    ax.plot(recall, precision,
            label=f"{TARGET.capitalize()} {kind} {_auc:.2f} ({lb:.2f}-{ub:.2f})",
            color=c[TARGET],
            ls=linestyle)

# Axis setup is loop-invariant: do it once, after all curves are drawn.
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')

# Order legend entries alphabetically by label text (groups curves by target).
handles, labels = ax.get_legend_handles_labels()
if handles:  # zip(*[]) cannot be unpacked when no curve was drawn
    labels, handles = zip(*sorted(zip(labels, handles), key=lambda lh: lh[0]))
    # Pin the legend's left-centre to the right edge of the axes; with the default
    # loc='best' the placement around the anchor point would depend on the data.
    ax.legend(handles, labels, frameon=False, loc='center left', bbox_to_anchor=(1, 0.5))

fig.savefig(here() / 'output/plots/prcurve.png', dpi=300, bbox_inches='tight')