# Threshold moving
- Maximize the positive predictive value (aka precision)
- While also maximizing true positive rate (aka sensitivity or recall)
- While minimizing the number of removed samples

In [1]:
import sys

# Add the parent of the working directory to the module search path.
# NOTE(review): presumably so the `phenobase` package imported in the next
# cell resolves when the notebook runs from a subdirectory — confirm.
sys.path.append("..")

In [2]:
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import PrecisionRecallDisplay, precision_recall_curve

from phenobase.pylib.binary_metrics import Metrics
from phenobase.pylib import util

In [3]:
# Input table; the cells below read its "checkpoint", "trait", "y_true",
# and "y_pred" columns.
CSV = Path("..") / "data" / "score.csv"
# Path for a thresholds table. Not referenced in the cells shown here;
# presumably written by a later cell — confirm.
OUT = Path("..") / "data" / "thresholds.csv"

# Load the whole table once; every later cell filters this frame.
DF = pd.read_csv(CSV)
# DF.head()

In [4]:
# Distinct values of the "checkpoint" column, in order of first appearance.
CHECKPOINTS = DF["checkpoint"].unique()
# CHECKPOINTS

In [5]:
# Distinct values of the "trait" column, in order of first appearance.
TRAITS = DF["trait"].unique()
# TRAITS

In [6]:
@dataclass
class Score:
    """Best threshold-search result for one (trait, checkpoint) pair."""

    # Precision reached at `threshold` (first element returned by
    # find_best_score); the per-trait lists are ranked by this value.
    score: float
    # Threshold that produced `score` (second element returned by
    # find_best_score).
    threshold: float
    # Value from the "checkpoint" column identifying the scored model.
    checkpoint: str
    # Value from the "trait" column this result belongs to.
    trait: str

In [7]:
def find_best_score(metrics, recall_limit=0.7):
    """Sweep thresholds and return the most precise one that keeps enough recall.

    Thresholds 0.01, 0.02, ... 0.99 are tried in ascending order. For each
    one, ``metrics.filter_y`` is called with both bounds set to that
    threshold, and ``metrics.precision``, ``metrics.recall`` and
    ``metrics.total`` are then read back.

    Args:
        metrics: Object exposing ``filter_y(thresh_lo=..., thresh_hi=...)``
            plus ``precision``, ``recall`` and ``total`` attributes. It is
            re-filtered on every step, so it is left in the state of the
            last threshold tried.
        recall_limit: Minimum recall a threshold must reach to be considered.

    Returns:
        A ``(precision, threshold)`` tuple for the winning threshold, or
        ``(0.0, 0.0)`` when no threshold qualifies. When several thresholds
        tie on precision, the highest one wins.
    """
    top_precision, top_threshold = 0.0, 0.0
    for cutoff in np.arange(0.01, 1.0, 0.01):
        metrics.filter_y(thresh_lo=cutoff, thresh_hi=cutoff)

        # Each guard is written as `not (x >= y)` so that a NaN metric is
        # rejected, exactly as it would be by a plain `x >= y` test.
        precision = metrics.precision
        if not (precision >= top_precision):
            continue
        if not (metrics.recall >= recall_limit):
            continue
        if not (metrics.total > 0):
            continue

        # `>=` above means a later threshold replaces an equally precise one.
        top_precision, top_threshold = precision, cutoff
    return top_precision, top_threshold

In [8]:
# Score every (trait, checkpoint) pair that has rows in DF, grouped by trait.
scores = defaultdict(list)
for trait in TRAITS:
    trait_mask = DF["trait"] == trait
    for checkpoint in CHECKPOINTS:
        df = DF.loc[(DF["checkpoint"] == checkpoint) & trait_mask]
        if len(df) == 0:
            # This checkpoint has no rows for this trait.
            continue
        best = Metrics(y_true=df["y_true"], y_pred=df["y_pred"])
        score = find_best_score(best)
        precision, threshold = score
        scores[trait].append(
            Score(
                score=precision,
                threshold=threshold,
                checkpoint=checkpoint,
                trait=trait,
            )
        )

# Rank each trait's candidates from highest to lowest precision; from here
# on `scores` is a plain dict, so a missing trait raises KeyError.
scores = {
    name: sorted(candidates, key=lambda entry: entry.score, reverse=True)
    for name, candidates in scores.items()
}

In [10]:
# Print the confusion matrix for the top-ranked checkpoints of every trait.
TOP_N = 5  # how many of the best checkpoints to show per trait
for trait in TRAITS:
    # Slice instead of indexing 0..TOP_N-1 so a trait with fewer than TOP_N
    # scored checkpoints shows what it has rather than raising IndexError.
    for score in scores[trait][:TOP_N]:
        print(trait, score)
        df = DF.loc[(DF["checkpoint"] == score.checkpoint) & (DF["trait"] == score.trait)]
        # Rebuild the metrics from the raw rows and re-apply the winning
        # threshold, so the matrix reflects that threshold.
        metrics = Metrics(y_true=df["y_true"], y_pred=df["y_pred"])
        metrics.filter_y(thresh_lo=score.threshold, thresh_hi=score.threshold)
        metrics.display_matrix()
    print()

flowers Score(score=0.8095238095238095, threshold=0.89, checkpoint='data/tuned/effnet_528_flowers_prec_wt/checkpoint-19897', trait='flowers')
tp =  306    fn =  130
fp =   72    tn =  404
total =  912
flowers Score(score=0.7994858611825193, threshold=0.98, checkpoint='data/tuned/effnet_528_flowers_prec_wt/checkpoint-20200', trait='flowers')
tp =  311    fn =  125
fp =   78    tn =  398
total =  912
flowers Score(score=0.7989690721649485, threshold=0.73, checkpoint='data/tuned/effnet_528_flowers_prec_nowt/checkpoint-19897', trait='flowers')
tp =  310    fn =  126
fp =   78    tn =  398
total =  912
flowers Score(score=0.7979274611398963, threshold=0.99, checkpoint='data/tuned/effnet_528_flowers_prec_nowt/checkpoint-19998', trait='flowers')
tp =  308    fn =  128
fp =   78    tn =  398
total =  912
flowers Score(score=0.7974025974025974, threshold=0.98, checkpoint='data/tuned/effnet_528_flowers_prec_wt/checkpoint-19998', trait='flowers')
tp =  307    fn =  129
fp =   78    tn =  398
tota