# Tune Classification Thresholds

This notebook sweeps a range of classification thresholds, plots per-class precision and recall at each threshold, and reports the threshold with the highest overall accuracy.

## Setup


In [None]:
import sys
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Make the project root importable so `src.*` modules resolve. This assumes
# the notebook is launched from a direct subdirectory of the project root
# (presumably something like notebooks/) — TODO confirm.
project_root = Path().resolve().parent
root_entry = str(project_root)
# Drop any existing entry before re-inserting at the front: re-running this
# cell must not stack duplicates, and the project root should still take
# precedence over other locations on sys.path.
sys.path[:] = [entry for entry in sys.path if entry != root_entry]
sys.path.insert(0, root_entry)

from src.eval import evaluate_dataset


## Sweep Thresholds


In [None]:
# Evaluate the classifier over a grid of candidate thresholds
# (np.arange excludes the stop value, so this covers 0.0 through 9.5).
thresholds = np.arange(0.0, 10.0, 0.5)
results = []

for threshold in thresholds:
    metrics = evaluate_dataset(threshold=threshold)
    # Thresholds whose evaluation reports no samples are left out of the table.
    if not metrics.get("num_samples", 0) > 0:
        continue
    # One row per threshold: overall accuracy, then per-class precision under
    # the bare class name and per-class recall under "recall_<class>".
    row = {"threshold": threshold, "accuracy": metrics["overall_accuracy"]}
    row.update(metrics["per_class_precision"])
    row.update(
        (f"recall_{class_name}", recall)
        for class_name, recall in metrics["per_class_recall"].items()
    )
    results.append(row)

results_df = pd.DataFrame(results)


## Plot Results


In [None]:
# Plot per-class precision and recall against threshold, then report the
# threshold with the highest overall accuracy.
# Each entry is (class column key, legend label, marker). Precision columns use
# the bare class name; recall columns are prefixed with "recall_".
class_styles = [
    ("recipe", "Recipe", "o"),
    ("workout", "Workout", "s"),
    ("quote", "Quote", "^"),
]

if len(results_df) > 0:
    fig, axes = plt.subplots(2, 1, figsize=(10, 8))

    # (axis, column prefix, metric name) for the precision and recall panels.
    panels = [
        (axes[0], "", "Precision"),
        (axes[1], "recall_", "Recall"),
    ]
    for ax, prefix, metric_name in panels:
        for class_key, label, marker in class_styles:
            column = f"{prefix}{class_key}"
            if column not in results_df.columns:
                # A class that never appeared in any evaluation has no column;
                # skip it rather than failing the whole plot with a KeyError.
                print(f"Skipping {metric_name.lower()} for '{class_key}': no '{column}' column")
                continue
            ax.plot(results_df["threshold"], results_df[column], label=label, marker=marker)
        ax.set_xlabel("Threshold")
        ax.set_ylabel(metric_name)
        ax.set_title(f"{metric_name} vs Threshold")
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()

    # Find best threshold (highest accuracy). idxmax skips NaN values and
    # returns the first (lowest-threshold) row on ties; guard the all-NaN case,
    # where there is no maximum to report.
    if results_df["accuracy"].notna().any():
        best_idx = results_df["accuracy"].idxmax()
        best_threshold = results_df.loc[best_idx, "threshold"]
        print(f"Best threshold: {best_threshold:.2f}")
        print(f"Best accuracy: {results_df.loc[best_idx, 'accuracy']:.3f}")
    else:
        print("No valid accuracy values; cannot pick a best threshold.")
else:
    print("No thresholds produced evaluation samples; nothing to plot.")
