AutoMix Threshold Router on Log Probability and Entropy

In [1]:
# Rquired libraries
import os
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Tuple

# Configuration

Set the dataset name and other parameters here.

In [53]:
# Configuration

# Dataset name: cnli_short / coqa_short / narrativeqa_short / qasper_short
DATASET_NAME = "qasper_short"

# Router data path
LOGIT_DATA_PATH = os.path.join("logits_entropy", f"{DATASET_NAME}_answerlogits.jsonl")

# Output directory for results
DIR = f"{DATASET_NAME}_outputs"
OUTPUT_DIR = "logits_entropy_threshold_output"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Cost parameters (for cost-performance analysis)
# Cost model: C = C_SLM + C_ver + w * C_LLM
# where C_ver = C_SLM (verification is performed by SLM)
C_SLM = 1.0  # Cost of small language model (LM1)
C_LLM = 20.0  # Cost of large language model (LM2)
C_ver = 4 * C_SLM  # Cost of verification (performed by SLM)

# Cost when keeping LM1: C = C_SLM + C_ver = 2 * C_SLM
# Cost when routing to LM2: C = C_SLM + C_ver + C_LLM = 2 * C_SLM + C_LLM
COST_LM1 = C_SLM + C_ver
COST_LM2 = C_SLM + C_ver + C_LLM

print(f"Dataset: {DATASET_NAME}")
print(f"Router data path: {LOGIT_DATA_PATH}")
print(f"Output directory: {OUTPUT_DIR}")
print(f"\nCost Parameters:")
print(f"  C_SLM: {C_SLM}")
print(f"  C_LLM: {C_LLM}")
print(f"  C_ver: {C_ver}")
print(f"  Cost when keeping LM1: {COST_LM1}")
print(f"  Cost when routing to LM2: {COST_LM2}")


Dataset: qasper_short
Router data path: logits_entropy/qasper_short_answerlogits.jsonl
Output directory: logits_entropy_threshold_output

Cost Parameters:
  C_SLM: 1.0
  C_LLM: 20.0
  C_ver: 4.0
  Cost when keeping LM1: 5.0
  Cost when routing to LM2: 25.0


# Load Logits and Entropy Data

In [54]:
# Load logit data

if not os.path.exists(LOGIT_DATA_PATH):
    raise FileNotFoundError(f"Logits data not found: {LOGIT_DATA_PATH}")

print(f"[INFO] Loading Logits data from {LOGIT_DATA_PATH}")
df = pd.read_json(LOGIT_DATA_PATH, lines=True, orient="records")

print(f"[INFO] Loaded {len(df)} samples")
print(f"[INFO] Columns: {list(df.columns)}")
print(f"\n[INFO] Sample data:")
df.head()

[INFO] Loading Logits data from logits_entropy/qasper_short_answerlogits.jsonl
[INFO] Loaded 1000 samples
[INFO] Columns: ['generated_text', 'avg_logp', 'avg_entropy_nats', 'avg_entropy_bits']

[INFO] Sample data:


Unnamed: 0,generated_text,avg_logp,avg_entropy_nats,avg_entropy_bits
0,A very small lexicon consisting of 15 positive...,-0.157852,0.323955,0.467368
1,A very small lexicon consisting of 15 positive...,-0.157852,0.323955,0.467368
2,The proposed method performed well even with a...,-0.152505,0.370588,0.534645
3,The relations of Cause and Concession are used...,-0.226092,0.52257,0.753909
4,The relations of Cause and Concession are used...,-0.226092,0.52257,0.753909


In [55]:
# Merge with router data to get performance metrics
ROUTER_DATA_PATH_FULL = os.path.join("router_data", f"router_data_{DATASET_NAME}.jsonl")

if not os.path.exists(ROUTER_DATA_PATH_FULL):
    raise FileNotFoundError(f"Router data not found: {ROUTER_DATA_PATH_FULL}")

print(f"[INFO] Loading router data from {ROUTER_DATA_PATH_FULL}")
router_df = pd.read_json(ROUTER_DATA_PATH_FULL, lines=True, orient="records")

print(f"[INFO] Loaded {len(router_df)} router samples")
print(f"[INFO] Router columns: {list(router_df.columns)}")

# Merge logits/entropy data with router data
# Assuming they are in the same order (same number of samples)
assert len(df) == len(router_df), f"Data length mismatch: {len(df)} vs {len(router_df)}"

# Merge the dataframes
df_merged = router_df.copy()
df_merged["avg_logp"] = df["avg_logp"].values
df_merged["avg_entropy_nats"] = df["avg_entropy_nats"].values

print(f"\n[INFO] Merged data shape: {df_merged.shape}")
print(f"[INFO] Merged columns: {list(df_merged.columns)}")
print(f"\n[INFO] Sample merged data:")
df_merged.head()




[INFO] Loading router data from router_data/router_data_qasper_short.jsonl
[INFO] Loaded 1000 router samples
[INFO] Router columns: ['id', 'dataset', 'gold_output', 'slm_pred', 'llm_pred', 'slm_confidence', 'perf_slm', 'perf_llm']

[INFO] Merged data shape: (1000, 10)
[INFO] Merged columns: ['id', 'dataset', 'gold_output', 'slm_pred', 'llm_pred', 'slm_confidence', 'perf_slm', 'perf_llm', 'avg_logp', 'avg_entropy_nats']

[INFO] Sample merged data:


Unnamed: 0,id,dataset,gold_output,slm_pred,llm_pred,slm_confidence,perf_slm,perf_llm,avg_logp,avg_entropy_nats
0,753990d0b621d390ed58f20c4d9e4f065f0dc672,qasper_short,a vocabulary of positive and negative predicat...,A seed lexicon consisting of 15 positive words...,a small list of 30 Japanese emotion predicates...,0.75,0.357143,0.25641,-0.157852,0.323955
1,753990d0b621d390ed58f20c4d9e4f065f0dc672,qasper_short,seed lexicon consists of positive and negative...,A seed lexicon consisting of 15 positive words...,a list of 15 positive and 15 negative Japanese...,0.75,0.6,0.357143,-0.157852,0.323955
2,9d578ddccc27dd849244d632dd0f6bf27348ad81,qasper_short,Using all data to train: AL -- BiGRU achieved ...,"The proposed method performed well, even with ...",BiGRU trained with AL+CA+CO achieved the highe...,0.75,0.09009,0.105263,-0.152505,0.370588
3,02e4bf719b1a504e385c35c6186742e720bcb281,qasper_short,"based on the relation between events, the sugg...",The answer is through the exploitation of disc...,By exploiting discourse relations: for Cause p...,0.75,0.176471,0.318182,-0.226092,0.52257
4,02e4bf719b1a504e385c35c6186742e720bcb281,qasper_short,cause relation: both events in the relation sh...,The relations are used to propagate polarity t...,By using discourse relations between event pai...,0.75,0.222222,0.355556,-0.226092,0.52257


In [56]:
# Calculate the threshold for the logits and entropy based on the max and min values
# partition into 10 equal-sized bins

# Get min and max values for logits and entropy
logits_min = df_merged["avg_logp"].min()
logits_max = df_merged["avg_logp"].max()
entropy_min = df_merged["avg_entropy_nats"].min()
entropy_max = df_merged["avg_entropy_nats"].max()

# Create 12 evenly spaced thresholds (including min and max), then remove first and last
# to get 10 thresholds between min and max
logits_thresholds = np.linspace(logits_min, logits_max, 11)[1:-1]
entropy_thresholds = np.linspace(entropy_min, entropy_max, 11)[1:-1]

print(f"\nLogits thresholds (9 bins, excluding min and max):")
print(logits_thresholds)
print(f"\nEntropy thresholds (9 bins, excluding min and max):")
print(entropy_thresholds)



Logits thresholds (9 bins, excluding min and max):
[-0.97362576 -0.86590608 -0.7581864  -0.65046671 -0.54274703 -0.43502735
 -0.32730767 -0.21958799 -0.11186831]

Entropy thresholds (9 bins, excluding min and max):
[0.19565922 0.37014548 0.54463174 0.719118   0.89360426 1.06809052
 1.24257678 1.41706304 1.5915493 ]


# Threshold Router Function

In [57]:
# Threshold-based router function using logits and entropy

def apply_logits_entropy_router(df: pd.DataFrame, logits_threshold: float = None, entropy_threshold: float = None) -> Tuple[pd.DataFrame, Dict]:
    """
    Apply threshold-based router using logits and/or entropy to route queries between LM1 and LM2.
    
    Routing logic:
    - If avg_logp >= logits_threshold OR avg_entropy_nats <= entropy_threshold: Keep LM1
    - Otherwise: Route to LM2
    """
    results = []
    lm1_count = 0
    lm2_count = 0
    
    for i, row in df.iterrows():
        avg_logp = float(row["avg_logp"])
        avg_entropy = float(row["avg_entropy_nats"])
        slm_pred = row["slm_pred"]  # LM1's answer
        llm_pred = row["llm_pred"]  # LM2's answer
        
        # Routing decision based on logits and/or entropy
        # High logits (>= threshold) AND low entropy (<= threshold) = keep LM1
        keep_lm1 = True
        
        if logits_threshold is not None:
            keep_lm1 = keep_lm1 and (avg_logp >= logits_threshold)
        
        if entropy_threshold is not None:
            keep_lm1 = keep_lm1 and (avg_entropy <= entropy_threshold)
        
        if keep_lm1:
            final_ans = slm_pred
            action = "keep"  # Keep LM1's answer
            model_used = "lm1"
            lm1_count += 1
            cost = COST_LM1  # C_SLM + C_ver = 5.0
        else:
            final_ans = llm_pred
            action = "route"  # Route to LM2
            model_used = "lm2"
            lm2_count += 1
            cost = COST_LM2  # C_SLM + C_ver + C_LLM = 25.0
        
        # Compute performance for the chosen answer
        perf_chosen = row["perf_slm"] if model_used == "lm1" else row["perf_llm"]
        
        results.append({
            "id": row.get("id", i),
            "dataset": row.get("dataset", DATASET_NAME),
            "avg_logp": avg_logp,
            "avg_entropy_nats": avg_entropy,
            "logits_threshold": logits_threshold,
            "entropy_threshold": entropy_threshold,
            "action": action,
            "model_used": model_used,
            "gold_answer": row.get("gold_output"),
            "final_answer": final_ans,
            "slm_pred": slm_pred,
            "llm_pred": llm_pred,
            "perf_chosen": perf_chosen,
            "perf_slm": row.get("perf_slm"),
            "perf_llm": row.get("perf_llm"),
            "cost": cost,
        })
    
    results_df = pd.DataFrame(results)
    
    # Compute summary statistics
    total = len(results_df)
    lm1_ratio = lm1_count / total
    lm2_ratio = lm2_count / total
    avg_perf = results_df["perf_chosen"].mean()
    avg_perf_slm = results_df["perf_slm"].mean()
    avg_perf_llm = results_df["perf_llm"].mean()
    avg_cost = results_df["cost"].mean()
    total_cost = results_df["cost"].sum()
    
    stats = {
        "logits_threshold": logits_threshold,
        "entropy_threshold": entropy_threshold,
        "total": total,
        "lm1_count": lm1_count,
        "lm2_count": lm2_count,
        "lm1_ratio": lm1_ratio,
        "lm2_ratio": lm2_ratio,
        "avg_perf": avg_perf,
        "avg_perf_slm": avg_perf_slm,
        "avg_perf_llm": avg_perf_llm,
        "avg_cost": avg_cost,
        "total_cost": total_cost,
    }
    
    return results_df, stats


In [None]:
# Test with single thresholds (one at a time)


# Test 1: Only logits threshold
print("="*60)
print("Test 1: Using ONLY logits threshold")
print("="*60)
TEST_LOGITS_THRESHOLD = logits_thresholds[5]  # Middle threshold

results_df_logits, stats_logits = apply_logits_entropy_router(
    df_merged, 
    logits_threshold=TEST_LOGITS_THRESHOLD,
    entropy_threshold=None  # Only using logits
)

print(f"Logits threshold: {stats_logits['logits_threshold']:.4f}")
print(f"Entropy threshold: {stats_logits['entropy_threshold']}")
print(f"LM1 (keep): {stats_logits['lm1_count']} ({stats_logits['lm1_ratio']:.2%})")
print(f"LM2 (route): {stats_logits['lm2_count']} ({stats_logits['lm2_ratio']:.2%})")
print(f"Average performance: {stats_logits['avg_perf']:.4f}")
print(f"Average cost: {stats_logits['avg_cost']:.2f}\n")

# Test 2: Only entropy threshold
print("="*60)
print("Test 2: Using ONLY entropy threshold")
print("="*60)
TEST_ENTROPY_THRESHOLD = entropy_thresholds[5]  # Middle threshold

results_df_entropy, stats_entropy = apply_logits_entropy_router(
    df_merged, 
    logits_threshold=None,  # Only using entropy
    entropy_threshold=TEST_ENTROPY_THRESHOLD
)

print(f"Logits threshold: {stats_entropy['logits_threshold']}")
print(f"Entropy threshold: {stats_entropy['entropy_threshold']:.4f}")
print(f"LM1 (keep): {stats_entropy['lm1_count']} ({stats_entropy['lm1_ratio']:.2%})")
print(f"LM2 (route): {stats_entropy['lm2_count']} ({stats_entropy['lm2_ratio']:.2%})")
print(f"Average performance: {stats_entropy['avg_perf']:.4f}")
print(f"Average cost: {stats_entropy['avg_cost']:.2f}\n")


Test 1: Using ONLY logits threshold
Logits threshold: -0.4350
Entropy threshold: None
LM1 (keep): 907 (90.70%)
LM2 (route): 93 (9.30%)
Average performance: 0.3208
Average cost: 6.86

Test 2: Using ONLY entropy threshold
Logits threshold: None
Entropy threshold: 1.0681
LM1 (keep): 949 (94.90%)
LM2 (route): 51 (5.10%)
Average performance: 0.3108
Average cost: 6.02



# Threshold Sweep for Logits and Entropy

In [None]:
# Threshold sweep: test logits and entropy thresholds separately


print(f"[INFO] Analyzing threshold sweep for {DATASET_NAME}")

# Separate result variables for logits and entropy
logits_results = []
logits_all_results = []
entropy_results = []
entropy_all_results = []

# Test 1: Logits thresholds only (entropy_threshold = None)
for logits_thresh in logits_thresholds:
    results_df, stats = apply_logits_entropy_router(
        df_merged,
        logits_threshold=logits_thresh,
        entropy_threshold=None  # Only using logits
    )
    logits_results.append(stats)
    logits_all_results.append(results_df)
    

# Test 2: Entropy thresholds only (logits_threshold = None)
for entropy_thresh in entropy_thresholds:
    results_df, stats = apply_logits_entropy_router(
        df_merged,
        logits_threshold=None,  # Only using entropy
        entropy_threshold=entropy_thresh
    )
    entropy_results.append(stats)
    entropy_all_results.append(results_df)
    

# Convert to separate DataFrames for easier analysis
logits_sweep_df = pd.DataFrame(logits_results)
entropy_sweep_df = pd.DataFrame(entropy_results)

# Also create a combined DataFrame for backward compatibility
sweep_df = pd.concat([logits_sweep_df, entropy_sweep_df], ignore_index=True)

print(f"\n{'='*60}")
print(f"Threshold Sweep Summary for {DATASET_NAME}")
print(f"{'='*60}")
print("\nLogits-only results:")
print(logits_sweep_df[["logits_threshold", "entropy_threshold", "lm1_ratio", "lm2_ratio", 
                       "avg_perf", "avg_cost", "total_cost"]].to_string(index=False))
print(f"\n{'='*60}")
print("\nEntropy-only results:")
print(entropy_sweep_df[["logits_threshold", "entropy_threshold", "lm1_ratio", "lm2_ratio", 
                        "avg_perf", "avg_cost", "total_cost"]].to_string(index=False))
print(f"{'='*60}\n")


[INFO] Analyzing threshold sweep for qasper_short

Threshold Sweep Summary for qasper_short

Logits-only results:
 logits_threshold entropy_threshold  lm1_ratio  lm2_ratio  avg_perf  avg_cost  total_cost
        -0.973626              None      0.998      0.002  0.302848      5.04      5040.0
        -0.865906              None      0.994      0.006  0.306242      5.12      5120.0
        -0.758186              None      0.988      0.012  0.307260      5.24      5240.0
        -0.650467              None      0.975      0.025  0.308617      5.50      5500.0
        -0.542747              None      0.954      0.046  0.309875      5.92      5920.0
        -0.435027              None      0.907      0.093  0.320810      6.86      6860.0
        -0.327308              None      0.797      0.203  0.337952      9.06      9060.0
        -0.219588              None      0.532      0.468  0.365406     14.36     14360.0
        -0.111868              None      0.195      0.805  0.383237     21.1

  sweep_df = pd.concat([logits_sweep_df, entropy_sweep_df], ignore_index=True)


In [None]:
# Calculate IBC (Incremental Benefit Per Cost) metrics

# Based on the IBC metric defined in the paper:
# IBC_M = (P_M - P_SLM) / (C_M - C_SLM)
# IBC_BASE = (P_LLM - P_SLM) / (C_LLM - C_SLM)
# Δ_IBC(M) = ((IBC_M - IBC_BASE) / IBC_BASE) * 100

# Get baseline performance values (same for both methods)
P_SLM = logits_sweep_df["avg_perf_slm"].iloc[0]  # Average performance of SLM (with verification)
P_LLM = logits_sweep_df["avg_perf_llm"].iloc[0]  # Average performance of LLM (with verification)

# Baseline costs
C_SLM_baseline = COST_LM1  # Cost of SLM method: 5.0
C_LLM_baseline = COST_LM2  # Cost of LLM method: 25.0

# Calculate IBC_BASE: Baseline incremental benefit per cost
IBC_BASE = (P_LLM - P_SLM) / (C_LLM_baseline - C_SLM_baseline)

print(f"\n{'='*60}")
print(f"IBC (Incremental Benefit Per Cost) Baseline")
print(f"{'='*60}")
print(f"P_SLM (SLM average performance): {P_SLM:.4f}")
print(f"P_LLM (LLM average performance): {P_LLM:.4f}")
print(f"C_SLM (SLM cost): {C_SLM_baseline:.2f}")
print(f"C_LLM (LLM cost): {C_LLM_baseline:.2f}")
print(f"IBC_BASE: {IBC_BASE:.6f}")
print(f"{'='*60}\n")

# Calculate IBC for logits thresholds
logits_sweep_df["IBC"] = (logits_sweep_df["avg_perf"] - P_SLM) / (logits_sweep_df["avg_cost"] - C_SLM_baseline)
logits_sweep_df["delta_IBC"] = ((logits_sweep_df["IBC"] - IBC_BASE) / IBC_BASE) * 100

# Calculate IBC for entropy thresholds
entropy_sweep_df["IBC"] = (entropy_sweep_df["avg_perf"] - P_SLM) / (entropy_sweep_df["avg_cost"] - C_SLM_baseline)
entropy_sweep_df["delta_IBC"] = ((entropy_sweep_df["IBC"] - IBC_BASE) / IBC_BASE) * 100

# Update combined DataFrame
sweep_df = pd.concat([logits_sweep_df, entropy_sweep_df], ignore_index=True)

# Display the sweep results with IBC metrics - Logits only
print(f"\n{'='*60}")
print(f"Logits-Only Threshold Sweep with IBC Metrics for {DATASET_NAME}")
print(f"{'='*60}")
display_cols = ["logits_threshold", "entropy_threshold", "lm1_ratio", "lm2_ratio", 
                "avg_perf", "avg_cost", "IBC", "delta_IBC"]
print(logits_sweep_df[display_cols].round(6).to_string(index=False))
print(f"{'='*60}\n")

# Display the sweep results with IBC metrics - Entropy only
print(f"\n{'='*60}")
print(f"Entropy-Only Threshold Sweep with IBC Metrics for {DATASET_NAME}")
print(f"{'='*60}")
print(entropy_sweep_df[display_cols].round(6).to_string(index=False))
print(f"{'='*60}\n")

# Display interpretation
print("IBC Metric Interpretation:")
print(f"  IBC_BASE (baseline): {IBC_BASE:.6f} - Benefit of always using LLM over SLM")
print(f"  Positive Δ_IBC: Method is more cost-effective than baseline")
print(f"  Negative Δ_IBC: Method is less cost-effective than baseline")
print()


# Save separate summaries
logits_summary_path = os.path.join(OUTPUT_DIR, f"logits_only_sweep_summary_{DATASET_NAME}.json")
logits_summary = logits_sweep_df.to_dict(orient="records")
with open(logits_summary_path, "w", encoding="utf-8") as f:
    json.dump(logits_summary, f, ensure_ascii=False, indent=2)

entropy_summary_path = os.path.join(OUTPUT_DIR, f"entropy_only_sweep_summary_{DATASET_NAME}.json")
entropy_summary = entropy_sweep_df.to_dict(orient="records")
with open(entropy_summary_path, "w", encoding="utf-8") as f:
    json.dump(entropy_summary, f, ensure_ascii=False, indent=2)

print(f"Logits-only summary saved to: {logits_summary_path}")
print(f"Entropy-only summary saved to: {entropy_summary_path}\n")




IBC (Incremental Benefit Per Cost) Baseline
P_SLM (SLM average performance): 0.3019
P_LLM (LLM average performance): 0.3853
C_SLM (SLM cost): 5.00
C_LLM (LLM cost): 25.00
IBC_BASE: 0.004166


Logits-Only Threshold Sweep with IBC Metrics for qasper_short
 logits_threshold entropy_threshold  lm1_ratio  lm2_ratio  avg_perf  avg_cost      IBC  delta_IBC
        -0.973626              None      0.998      0.002  0.302848      5.04 0.022727 445.478226
        -0.865906              None      0.994      0.006  0.306242      5.12 0.035865 760.802920
        -0.758186              None      0.988      0.012  0.307260      5.24 0.022175 432.226056
        -0.650467              None      0.975      0.025  0.308617      5.50 0.013358 220.601522
        -0.542747              None      0.954      0.046  0.309875      5.92 0.008626 107.043919
        -0.435027              None      0.907      0.093  0.320810      6.86 0.010146 143.510314
        -0.327308              None      0.797      0.203  

  sweep_df = pd.concat([logits_sweep_df, entropy_sweep_df], ignore_index=True)
