# Threshold Sensitivity Analysis

This notebook analyzes the effect of the reasoning threshold on the model's performance.
It reads from `metrics_summary.jsonl`.

In [None]:
import json
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Set plot style
sns.set_theme(style="whitegrid")

# Constants
# The metrics file is looked up one directory above the current working
# directory; resolve() turns the relative path into an absolute one.
PROJECT_ROOT = Path("../").resolve()
METRICS_FILE = PROJECT_ROOT / "metrics_summary.jsonl"

# Selection keys: records are kept only when their "model" and "benchmark"
# fields equal these two values.
MODEL_NAME = "Qwen_Qwen3-4B-Thinking-2507"  # compared against record["model"]
CATEGORY = "gsm8k"  # compared against record["benchmark"]

In [None]:
def load_metrics(file_path):
    """Load all records from a JSON Lines file.

    Args:
        file_path: Location of the file, given either as a ``str`` or as any
            ``os.PathLike`` object such as ``pathlib.Path``.

    Returns:
        A list holding one decoded JSON value per non-blank line, in file
        order. When the file does not exist a warning is printed and an
        empty list is returned.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Coerce first so that plain string paths are accepted as well as Path
    # objects (a bare str has no .exists()).
    file_path = Path(file_path)
    if not file_path.exists():
        print(f"Warning: {file_path} not found.")
        return []
    with open(file_path, "r", encoding="utf-8") as f:
        # Whitespace-only lines are skipped rather than parsed.
        return [json.loads(line) for line in f if line.strip()]

# Read the whole metrics file once; all_metrics is an empty list when the
# file is missing, so the count below doubles as a quick sanity check.
all_metrics = load_metrics(METRICS_FILE)
print(f"Total records loaded: {len(all_metrics)}")

In [None]:
# Filter data for specific model and category
model_metrics = [
    r for r in all_metrics
    if r.get("model") == MODEL_NAME
    and r.get("benchmark") == CATEGORY
    and r.get("sub_category") is None  # Only main benchmark data
]

# Identify baseline (rollout): the first matching record, or None if absent.
baseline_entry = next(
    (r for r in model_metrics if r.get("method") == "rollout"),
    None
)

if baseline_entry:
    print(f"Baseline (Rollout) found: Acc={baseline_entry.get('accuracy', 0):.2f}%, AvgTokens={baseline_entry.get('avg_token_length', 0):.1f}")
else:
    print("Warning: Baseline (Rollout) not found.")

# The baseline token count is the same for every row, so resolve it once.
# It is only usable as a denominator when it is a positive number; anything
# else (missing baseline, null, non-numeric, zero) disables token reduction.
base_tokens = baseline_entry.get("avg_token_length", 0) if baseline_entry else None
if not isinstance(base_tokens, (int, float)) or base_tokens <= 0:
    base_tokens = None

plot_data = []

for r in model_metrics:
    if r.get("method") != "thinkbrake":
        continue

    # Only the threshold conversion is allowed to reject a row. Keeping the
    # try this narrow prevents a bad token value from silently discarding
    # an otherwise valid data point.
    try:
        t_val = float(r.get("threshold"))
    except (ValueError, TypeError):
        continue

    # Calculate Token Reduction (percent saved relative to the baseline);
    # left as None when either side is not a usable number.
    curr_tokens = r.get("avg_token_length", 0)
    token_reduction = None
    if base_tokens is not None and isinstance(curr_tokens, (int, float)):
        token_reduction = (base_tokens - curr_tokens) / base_tokens * 100

    entry = {
        "threshold": t_val,
        "accuracy": r.get("accuracy", 0),
        "majority_accuracy": r.get("majority_accuracy"),
        "avg_token_length": r.get("avg_token_length", 0),
        "token_reduction": token_reduction,
        "type": "ThinkBrake",
    }

    # Add pass@k if available: one "pass@<k>" column per key of the mapping.
    pass_at_k = r.get("pass@k")
    if isinstance(pass_at_k, dict):
        for k, v in pass_at_k.items():
            entry[f"pass@{k}"] = v

    plot_data.append(entry)

df = pd.DataFrame(plot_data)
if not df.empty:
    # Sort so that line plots over the threshold axis are drawn left to right.
    df = df.sort_values("threshold")
    print(df.head())
else:
    print("No ThinkBrake data available for plotting.")

In [None]:
if not df.empty:
    fig, ax1 = plt.subplots(figsize=(10, 6))

    # Optional legend handles are reset on every run. Notebook globals
    # survive between cell executions, so probing locals() for these names
    # could pick up artists that belong to a previous figure.
    line2 = []
    line_base = None

    # Plot Accuracy on primary y-axis
    color = 'tab:blue'
    ax1.set_xlabel('Threshold')
    ax1.set_ylabel('Accuracy (%)', color=color)
    line1 = ax1.plot(df['threshold'], df['accuracy'], marker='o', color=color, label='Accuracy')
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.grid(True)

    # Plot Baseline Accuracy if available
    if baseline_entry:
        base_acc = baseline_entry.get("accuracy", 0)
        line_base = ax1.axhline(y=base_acc, color='gray', linestyle='--', label='Rollout Baseline')

    # Create secondary y-axis for Token Reduction (only when at least one
    # row has a computed value)
    if "token_reduction" in df.columns and df["token_reduction"].notna().any():
        ax2 = ax1.twinx()
        color = 'tab:red'
        ax2.set_ylabel('Token Reduction (%)', color=color)
        line2 = ax2.plot(df['threshold'], df['token_reduction'], marker='s', color=color, linestyle='-.', label='Token Reduction')
        ax2.tick_params(axis='y', labelcolor=color)
        ax2.grid(False)

    plt.title(f'Threshold Sensitivity: {MODEL_NAME} on {CATEGORY}')

    # Combine legends: the two axes keep separate artist lists, so gather the
    # handles into one fresh list (leaving line1 itself untouched) and draw
    # a single legend on the primary axis.
    lines = [*line1, *line2]
    if line_base is not None:
        lines.append(line_base)

    labels = [handle.get_label() for handle in lines]
    ax1.legend(lines, labels, loc='best')

    plt.show()