# MLP Verifier Training for Router

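This notebook trains a small MLP verifier for each dataset. Given the small language model's (SLM) confidence score, the verifier predicts whether the query should be routed to the large language model (LLM) or answered with the SLM's output.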


In [None]:
# Required libraries
import os
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from typing import List, Dict, Tuple
from metric_compute import f1_score, normalize_cnli_label
from prepare_router_data import compute_perf

# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")


Using device: cpu


In [2]:
# Configuration
DATASETS = ["cnli_short", "coqa_short", "narrative_qa_short", "qasper_short"]
ROUTER_DATA_DIR = "router_data"
MODEL_SAVE_DIR = "mlp_verifier_models"
RESULTS_DIR = "mlp_verifier_results"

# Create output directories up front so later cells can write without checks.
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

# Cost parameters, expressed relative to one SLM call (C_SLM = 1).
C_SLM = 1.0
C_LLM = 20.0
C_ver = 4 * C_SLM
# Cost when the SLM answer is kept: SLM call plus verification.
COST_LM1 = C_SLM + C_ver
# Cost when escalating to the LLM: SLM + verifier cost plus the LLM call.
COST_LM2 = C_SLM + C_ver + C_LLM

# Training parameters
TRAIN_RATIO = 0.8
BATCH_SIZE = 32
NUM_EPOCHS = 100
LEARNING_RATE = 0.001

# Logits thresholds to test: 0.1, 0.2, ..., 0.9.
# Built from integers so each entry equals the exact float literal;
# np.arange(0.1, 1.0, 0.1) accumulates error (e.g. 0.30000000000000004),
# which shifts ">= threshold" comparisons and leaks into the saved results.
LOGITS_THRESHOLDS = np.arange(1, 10) / 10

print("Cost Parameters:")
print(f"  C_SLM: {C_SLM}")
print(f"  C_LLM: {C_LLM}")
print(f"  C_ver: {C_ver}")
print(f"  COST_LM1: {COST_LM1}")
print(f"  COST_LM2: {COST_LM2}")


Cost Parameters:
  C_SLM: 1.0
  C_LLM: 20.0
  C_ver: 4.0
  COST_LM1: 5.0
  COST_LM2: 25.0


In [3]:
# Define MLP model with hidden widths 32-64-64-64-32 and a sigmoid output.
class MLPVerifier(nn.Module):
    """Feed-forward scorer mapping input features to a value in (0, 1).

    Hidden layers have widths 32, 64, 64, 64, 32, each followed by ReLU.
    A final Linear(32, 1) feeds a Sigmoid, so the output is a probability-like
    score suitable for ``nn.BCELoss``.
    """

    def __init__(self, input_dim=1):
        super().__init__()
        widths = [input_dim, 32, 64, 64, 64, 32]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend([nn.Linear(fan_in, fan_out), nn.ReLU()])
        modules.extend([nn.Linear(widths[-1], 1), nn.Sigmoid()])
        # Keep the attribute name and module order unchanged so saved
        # state_dict keys remain "layers.<idx>.weight"/"layers.<idx>.bias".
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Score ``x`` (last dimension == input_dim); output's last dimension is 1."""
        return self.layers(x)

# Dataset class
class ConfidenceDataset(Dataset):
    """Pair confidence features with binary labels as float32 tensors.

    Args:
        confidences: Array-like of N scalar confidences (shape (N,)), or a
            2-D array-like of shape (N, d). 1-D input is stored as (N, 1) so it
            matches ``MLPVerifier(input_dim=1)``.
        labels: Array-like of N labels (0/1); stored as (N, 1) to match the
            model's output shape for ``nn.BCELoss``.

    Raises:
        ValueError: If confidences and labels have different sample counts.
    """

    def __init__(self, confidences, labels):
        # Explicit float32 conversion instead of the legacy torch.FloatTensor
        # constructor (which, e.g., treats a bare int as a size). The
        # contiguous float32 copy also accepts lists and strided/int arrays.
        features = torch.tensor(np.ascontiguousarray(confidences, dtype=np.float32))
        if features.dim() == 1:
            features = features.unsqueeze(1)
        targets = torch.tensor(np.ascontiguousarray(labels, dtype=np.float32)).reshape(-1, 1)
        if features.shape[0] != targets.shape[0]:
            raise ValueError(
                f"confidences has {features.shape[0]} samples but labels has {targets.shape[0]}"
            )
        self.confidences = features
        self.labels = targets

    def __len__(self):
        return len(self.confidences)

    def __getitem__(self, idx):
        return self.confidences[idx], self.labels[idx]


In [None]:
def _evaluate_loader(model: nn.Module, loader: DataLoader, criterion: nn.Module) -> Tuple[float, float]:
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns:
        (mean per-batch loss, accuracy of thresholding outputs at 0.5).
        Both are 0.0 for an empty loader instead of dividing by zero.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for conf, lab in loader:
            conf, lab = conf.to(device), lab.to(device)
            outputs = model(conf)
            loss_sum += criterion(outputs, lab).item()
            predicted = (outputs > 0.5).float()
            total += lab.size(0)
            correct += (predicted == lab).sum().item()
    if total == 0:
        return 0.0, 0.0
    return loss_sum / len(loader), correct / total


def train_mlp_verifier(dataset_name: str, confidences: np.ndarray, labels: np.ndarray) -> MLPVerifier:
    """
    Train an MLP verifier for one dataset and return its best checkpoint.

    The data is split TRAIN_RATIO / (1 - TRAIN_RATIO) with a fixed seed. The
    held-out split drives both early stopping and checkpoint selection, so
    the reported "test" accuracy is an optimistic estimate.

    Args:
        dataset_name: Name of the dataset (used in log messages)
        confidences: Array of confidence scores, one per sample
        labels: Array of binary labels (1 if LLM better, 0 if SLM better)

    Returns:
        The trained model with the weights from the epoch that reached the
        highest held-out accuracy (not simply the last epoch run).
    """
    # Stratified splitting raises ValueError when a class has fewer than two
    # samples; fall back to a plain random split in that case.
    _, class_counts = np.unique(labels, return_counts=True)
    stratify = labels if class_counts.size > 0 and class_counts.min() >= 2 else None
    X_train, X_test, y_train, y_test = train_test_split(
        confidences, labels, test_size=1 - TRAIN_RATIO, random_state=42, stratify=stratify
    )

    # Create datasets and dataloaders
    train_loader = DataLoader(ConfidenceDataset(X_train, y_train), batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(ConfidenceDataset(X_test, y_test), batch_size=BATCH_SIZE, shuffle=False)

    # Initialize model
    model = MLPVerifier(input_dim=1).to(device)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    best_test_acc = 0.0
    best_state = None
    patience = 20  # Generous patience to allow more training
    patience_counter = 0

    for epoch in range(NUM_EPOCHS):
        # Training
        model.train()
        train_loss = 0.0
        for conf, lab in train_loader:
            conf, lab = conf.to(device), lab.to(device)
            optimizer.zero_grad()
            loss = criterion(model(conf), lab)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Validation
        test_loss, test_acc = _evaluate_loader(model, test_loader, criterion)

        # Print every epoch for the first 20 epochs, then every 10
        verbose = (epoch + 1) <= 20 or (epoch + 1) % 10 == 0
        if verbose:
            print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Train Loss: {train_loss/max(len(train_loader), 1):.4f}, "
                  f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, "
                  f"Best: {best_test_acc:.4f}, Patience: {patience_counter}/{patience}")

        # Early stopping with best-checkpoint tracking
        if best_state is None or test_acc > best_test_acc:
            best_test_acc = test_acc
            # Clone: state_dict() tensors alias the live parameters and would
            # keep changing as training continues.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            patience_counter = 0
            if verbose:
                print("  -> New best test accuracy!")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1} (no improvement for {patience} epochs)")
                break

    # Return the best weights, not whatever the final (possibly degraded) epoch left.
    if best_state is not None:
        model.load_state_dict(best_state)

    print(f"\nTraining completed for {dataset_name}. Best test accuracy: {best_test_acc:.4f}")
    return model


In [7]:
def calculate_delta_ibc(p_slm: float, p_llm: float, p_m: float, c_m: float,
                        c_slm: float = COST_LM1, c_llm: float = COST_LM2) -> float:
    """
    Return the percentage change of a router's IBC relative to the baseline.

    IBC_M = (P_M - P_SLM) / (C_M - C_SLM)
    IBC_BASE = (P_LLM - P_SLM) / (C_LLM - C_SLM)
    delta_IBC = ((IBC_M - IBC_BASE) / IBC_BASE) * 100

    Args:
        p_slm: Mean performance when always keeping the SLM answer.
        p_llm: Mean performance when always using the LLM answer.
        p_m: Mean performance of the router being evaluated.
        c_m: Mean cost of the router being evaluated.
        c_slm: Per-sample cost of keeping the SLM answer (default COST_LM1).
        c_llm: Per-sample cost of escalating to the LLM (default COST_LM2).

    Returns:
        delta IBC in percent. Returns the sentinel -100.0 when c_m <= c_slm
        (IBC_M would divide by zero or a negative cost delta), and 0.0 when
        IBC_BASE is exactly zero.

    NOTE(review): when IBC_BASE < 0 (the SLM outperforms the LLM on average)
    the sign of the result no longer means better/worse; confirm intended.
    """
    # Router spent no more than always-SLM: slope undefined, use sentinel.
    if c_m <= c_slm:
        return -100.0

    router_slope = (p_m - p_slm) / (c_m - c_slm)
    base_slope = (p_llm - p_slm) / (c_llm - c_slm)

    if base_slope == 0:
        return 0.0
    return ((router_slope - base_slope) / base_slope) * 100


In [None]:
# Process each dataset
all_results = {}


def _sweep_logits_thresholds(df: pd.DataFrame, confidences: np.ndarray,
                             mlp_predictions: np.ndarray) -> List[Dict]:
    """Evaluate the combined MLP/confidence routing rule at every LOGITS_THRESHOLDS value.

    Args:
        df: Router data with per-row ``perf_slm`` and ``perf_llm`` columns.
        confidences: SLM confidence for each row of ``df``.
        mlp_predictions: MLP sigmoid output for each row of ``df``.

    Returns:
        One metrics dict per threshold, in LOGITS_THRESHOLDS order.
    """
    n_samples = len(df)
    perf_slm = df["perf_slm"].values
    perf_llm = df["perf_llm"].values
    # Threshold-independent quantities: computed once, not once per threshold.
    p_slm = df["perf_slm"].mean()
    p_llm = df["perf_llm"].mean()
    mlp_prefers_llm = mlp_predictions > 0.5

    threshold_results = []
    for threshold in LOGITS_THRESHOLDS:
        # If MLP predicts 1 (LLM better) OR confidence >= threshold, use LLM;
        # otherwise keep the SLM answer.
        # NOTE(review): this escalates when the SLM is *confident*. Routers
        # usually escalate on low confidence, and the "logits threshold" name
        # suggests it may have been meant for the MLP score. Confirm direction.
        use_llm = mlp_prefers_llm | (confidences >= threshold)

        final_perfs = np.where(use_llm, perf_llm, perf_slm)
        final_costs = np.where(use_llm, COST_LM2, COST_LM1)

        avg_perf = final_perfs.mean()
        avg_cost = final_costs.mean()
        lm1_count = (~use_llm).sum()
        lm2_count = use_llm.sum()
        delta_ibc = calculate_delta_ibc(p_slm, p_llm, avg_perf, avg_cost)

        threshold_results.append({
            "threshold": float(threshold),
            "avg_perf": float(avg_perf),
            "avg_cost": float(avg_cost),
            "lm1_count": int(lm1_count),
            "lm2_count": int(lm2_count),
            "lm1_ratio": float(lm1_count / n_samples),
            "lm2_ratio": float(lm2_count / n_samples),
            "delta_ibc": float(delta_ibc),
            "p_slm": float(p_slm),
            "p_llm": float(p_llm)
        })
    return threshold_results


for dataset_name in DATASETS:
    print(f"\n{'='*60}")
    print(f"Processing {dataset_name}")
    print(f"{'='*60}")

    # Load router data
    router_data_path = os.path.join(ROUTER_DATA_DIR, f"router_data_{dataset_name}.jsonl")
    if not os.path.exists(router_data_path):
        print(f"Warning: {router_data_path} not found, skipping...")
        continue

    df = pd.read_json(router_data_path, lines=True, orient="records")
    print(f"Loaded {len(df)} samples")
    # An empty file would crash below (np.bincount / .min() on empty arrays).
    if df.empty:
        print(f"Warning: {router_data_path} has no samples, skipping...")
        continue

    # Create labels: 1 if LLM better, 0 if SLM better
    labels = (df["perf_llm"] > df["perf_slm"]).astype(int).values
    confidences = df["slm_confidence"].values

    print(f"Label distribution: {np.bincount(labels)}")
    print(f"Confidence range: [{confidences.min():.2f}, {confidences.max():.2f}]")

    # Train MLP verifier
    print("\nTraining MLP verifier...")
    model = train_mlp_verifier(dataset_name, confidences, labels)

    # Save model
    model_save_path = os.path.join(MODEL_SAVE_DIR, f"mlp_verifier_{dataset_name}.pth")
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved to: {model_save_path}")

    # Score every sample. NOTE: this includes the samples the verifier was
    # trained on, so the routing metrics below are in-sample (optimistic).
    model.eval()
    confidences_tensor = torch.FloatTensor(confidences).unsqueeze(1).to(device)
    with torch.no_grad():
        mlp_predictions = model(confidences_tensor).cpu().numpy().flatten()

    threshold_results = _sweep_logits_thresholds(df, confidences, mlp_predictions)
    all_results[dataset_name] = threshold_results

    # Print summary
    print(f"\n{'='*60}")
    print(f"Results Summary for {dataset_name}")
    print(f"{'='*60}")
    results_df = pd.DataFrame(threshold_results)
    print(results_df[["threshold", "avg_perf", "avg_cost", "lm1_ratio", "lm2_ratio", "delta_ibc"]].to_string(index=False))

    # Save results
    results_path = os.path.join(RESULTS_DIR, f"mlp_verifier_results_{dataset_name}.json")
    with open(results_path, 'w', encoding='utf-8') as f:
        json.dump(threshold_results, f, indent=2)
    print(f"\nResults saved to: {results_path}")


In [None]:
# Generate plots for each dataset
for dataset_name in DATASETS:
    if dataset_name not in all_results:
        continue

    results = all_results[dataset_name]
    thresholds = [r["threshold"] for r in results]
    performances = [r["avg_perf"] for r in results]
    delta_ibcs = [r["delta_ibc"] for r in results]

    # Dual y-axes; width grows with the number of thresholds so every x tick label fits.
    fig, ax1 = plt.subplots(figsize=(max(12, len(thresholds) * 1.2), 6))

    # Left y-axis: Performance
    color1 = 'tab:blue'
    ax1.set_xlabel('Logits Threshold', fontsize=12, fontweight='bold')
    ax1.set_ylabel('Performance (Accuracy/F1)', color=color1, fontsize=12, fontweight='bold')
    line1 = ax1.plot(thresholds, performances, color=color1, marker='o', linewidth=2,
                     markersize=6, label='Performance')
    ax1.tick_params(axis='y', labelcolor=color1)
    ax1.grid(True, alpha=0.3)

    # Show every threshold value on the x-axis
    ax1.set_xticks(thresholds)
    ax1.set_xticklabels([f'{t:.1f}' for t in thresholds], rotation=45, ha='right', fontsize=9)

    # Right y-axis: Delta IBC
    ax2 = ax1.twinx()
    color2 = 'tab:red'
    ax2.set_ylabel('Delta IBC (%)', color=color2, fontsize=12, fontweight='bold')
    line2 = ax2.plot(thresholds, delta_ibcs, color=color2, marker='s', linewidth=2,
                     markersize=6, label='Delta IBC', linestyle='--')
    ax2.tick_params(axis='y', labelcolor=color2)

    ax1.set_title(f'{dataset_name.upper()} - MLP Verifier Results\nPerformance vs Delta IBC',
                  fontsize=14, fontweight='bold', pad=20)

    # Single legend covering both axes. Use a dedicated name so the global
    # `labels` array from the training cell is not overwritten.
    lines = line1 + line2
    legend_labels = [line.get_label() for line in lines]
    ax1.legend(lines, legend_labels, loc='best', fontsize=10)

    fig.tight_layout()

    # Save plot (both PDF and PNG, as intended)
    plot_base = os.path.join(RESULTS_DIR, f'{dataset_name}_mlp_verifier_plot')
    for ext in ('pdf', 'png'):
        plot_path = f'{plot_base}.{ext}'
        fig.savefig(plot_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to: {plot_path}")
    plt.close(fig)


In [None]:
# Create summary table across all datasets: one row per (dataset, threshold).
# Field order below (after "dataset") is the column order in the CSV.
_SUMMARY_FIELDS = ("threshold", "avg_perf", "avg_cost", "delta_ibc", "lm1_ratio", "lm2_ratio")

summary_data = []
for dataset_name in DATASETS:
    if dataset_name not in all_results:
        continue
    results = all_results[dataset_name]
    for r in results:
        row = {"dataset": dataset_name}
        row.update((field, r[field]) for field in _SUMMARY_FIELDS)
        summary_data.append(row)

summary_df = pd.DataFrame(summary_data)

# Persist, then echo the table to stdout
summary_path = os.path.join(RESULTS_DIR, "mlp_verifier_summary.csv")
summary_df.to_csv(summary_path, index=False)
print(f"Summary saved to: {summary_path}")

banner = "=" * 60
print("\n" + banner)
print("Summary Table")
print(banner)
print(summary_df.to_string(index=False))
