# Threshold moving
- Maximize the positive predictive value (aka precision)
- While also maximizing true positive rate (aka sensitivity or recall)
- While minimizing the number of removed samples

In [1]:
import sys

sys.path.append("..")

In [2]:
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import PrecisionRecallDisplay, precision_recall_curve

from phenobase.pylib.binary_metrics import Metrics
from phenobase.pylib import util

In [3]:
# Input scores and the output location for thresholds, both under ../data.
CSV = Path("..", "data", "score.csv")
OUT = Path("..", "data", "thresholds.csv")

# Load the whole score table once; later cells filter it per trait/checkpoint.
DF = pd.read_csv(CSV)
# DF.head()  # uncomment to preview the table

In [4]:
# Distinct values of the "checkpoint" column across all rows of DF.
CHECKPOINTS = DF["checkpoint"].unique()
# CHECKPOINTS

In [5]:
# Distinct values of the "trait" column across all rows of DF.
TRAITS = DF["trait"].unique()
# TRAITS

In [6]:
def move_threshold(trait, checkpoint):
    """Find the decision threshold that maximizes the F-score for one model.

    Selects the rows of the module-level ``DF`` matching ``trait`` and
    ``checkpoint``, computes the precision-recall curve from their ``y_true``
    and ``y_pred`` columns, and picks the threshold with the highest F-score
    (harmonic mean of precision and recall).

    Args:
        trait: Value to match against the ``trait`` column of ``DF``.
        checkpoint: Value to match against the ``checkpoint`` column of ``DF``.

    Returns:
        A ``(f_score, threshold, checkpoint)`` tuple for the best threshold,
        or ``None`` when ``DF`` has no rows for this trait/checkpoint pair.
    """
    df = DF.loc[(DF["checkpoint"] == checkpoint) & (DF["trait"] == trait)]
    if df.empty:
        return None

    precision, recall, thresholds = precision_recall_curve(df["y_true"], df["y_pred"])

    # precision/recall carry one extra trailing point (precision=1, recall=0)
    # that has no matching threshold. Drop it so every index into the F-scores
    # is also a valid index into `thresholds`.
    precision = precision[:-1]
    recall = recall[:-1]

    # When precision and recall are both zero the F-score is 0/0. Treat that as
    # 0.0 instead of NaN: np.argmax would otherwise select the NaN entry, and a
    # NaN score would also break the ranking of the returned tuples.
    denominator = precision + recall
    f_scores = np.divide(
        2.0 * precision * recall,
        denominator,
        out=np.zeros_like(denominator, dtype=float),
        where=denominator > 0,
    )

    idx = np.argmax(f_scores)
    return f_scores[idx], thresholds[idx], checkpoint

In [7]:
# Gather the best (f_score, threshold, checkpoint) tuple of every checkpoint,
# grouped by trait. Pairs with no rows in DF yield None and are skipped.
bests = defaultdict(list)
for trait in TRAITS:
    for checkpoint in CHECKPOINTS:
        best = move_threshold(trait, checkpoint)
        if best is None:
            continue
        bests[trait].append(best)

# Rank each trait's candidates from highest to lowest; tuples compare
# element-wise, so the F-score decides first.
bests = {name: sorted(results, reverse=True) for name, results in bests.items()}

In [11]:
# Show the three highest-ranked candidates per trait. Slicing (rather than
# indexing 0..2) tolerates traits that have fewer than three results, and
# .get() tolerates traits that have none at all.
for trait in TRAITS:
    for best in bests.get(trait, [])[:3]:
        print(trait, best)
    print()

flowers (0.7904967602591791, 0.525034368038178, 'data/tuned/effnet_528_flowers_prec_wt/checkpoint-20099')
flowers (0.7887931034482758, 0.0252350941300392, 'data/tuned/effnet_528_flowers_prec_wt/checkpoint-19998')
flowers (0.7844925883694414, 0.107626773416996, 'data/tuned/effnet_528_flowers_prec_wt/checkpoint-19897')

fruits (0.7028688524590163, 3.3582030027901e-06, 'data/tuned/effnet_528_fruits_prec_wt/checkpoint-14880')
fruits (0.7010752688172043, 0.00024449560442, 'data/tuned/effnet_528_fruits_prec_wt/checkpoint-18414')
fruits (0.6998916576381365, 0.0007000703481025, 'data/tuned/effnet_528_fruits_prec_wt/checkpoint-18600')

leaves (0.9880715705765407, 0.0002770601131487, 'data/tuned/effnet_528_leaves_prec_wt/checkpoint-23084')
leaves (0.9880478087649401, 0.0023480465169996, 'data/tuned/effnet_528_leaves_prec_wt/checkpoint-22852')
leaves (0.987389659520807, 0.923058807849884, 'data/tuned/vit_384_lg_all_prec_wt/checkpoint-36630')



In [9]:
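# NOTE: disabled plotting sketch. It refers to `df`, `precision`, `recall`, and
# `idx`, which are local to move_threshold() and not defined at notebook scope,
# so it cannot be run as-is by simply uncommenting it.
# no_skill = len(df["y_true"][df["y_true"]==1]) / len(df["y_true"])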

# plt.plot([0,1], [no_skill, no_skill], linestyle='--', label='No Skill')
# plt.plot(recall, precision, marker='.', label='Logistic')
# plt.scatter(recall[idx], precision[idx], marker='o', color='black', label='Best')
# plt.xlabel('Recall')
# plt.ylabel('Precision')
# plt.legend()
# plt.show()