In [None]:
from dataclasses import dataclass
from typing import Literal, Optional, Dict
import math
from scipy.stats import beta as beta_dist
from scipy.stats import betabinom
from scipy.stats import norm
import numpy as np
from scipy import stats
from typing import Dict, Any, Tuple
from scipy import stats

import matplotlib.pyplot as plt

@dataclass
class SSBCResult:
    """Outcome of a Small-Sample Beta Correction search (returned by ``ssbc_correct``)."""

    alpha_target: float     # requested miscoverage rate
    alpha_corrected: float  # corrected miscoverage rate, u_star / (n + 1)
    u_star: int             # selected numerator u of the corrected rate
    n: int                  # calibration set size
    satisfied_mass: float   # probability that coverage >= target
    mode: Literal["beta", "beta-binomial"]
    details: Dict           # search diagnostics (bounds, initial guess, per-u log)


def ssbc_correct(
    alpha_target: float,
    n: int,
    delta: float,
    *,
    mode: Literal["beta", "beta-binomial"] = "beta",
    m: Optional[int] = None,
    bracket_width: Optional[int] = None,  # Δ in Algorithm 1
) -> SSBCResult:
    """
    Small-Sample Beta Correction (SSBC), corrected acceptance rule.

    Find the largest α' = u/(n+1) ≤ α_target such that:
    P(Coverage(α') ≥ 1 - α_target) ≥ 1 - δ

    where Coverage(α') ~ Beta(n+1-u, u) for infinite test window, or the
    number of covered points out of m ~ BetaBinomial(m, n+1-u, u) for a
    finite test window.

    The search starts in a bracket around a normal-approximation guess. The
    tail mass does not increase as u grows, so whenever the answer is not
    strictly inside the bracket the scan continues past the relevant edge
    (upward while values keep passing, downward until one passes) instead of
    returning a sub-optimal u.

    Args:
        alpha_target: Target miscoverage rate, in (0, 1)
        n: Calibration set size (>= 1)
        delta: Risk tolerance (PAC parameter), in (0, 1)
        mode: "beta" for infinite test window, "beta-binomial" for finite
        m: Test window size (for beta-binomial mode); defaults to n
        bracket_width: Initial search radius around the guess. Default:
            max(5, min(int(2·z_δ·√u_target), n // 10)), capped at 100.

    Returns:
        SSBCResult. ``details`` holds ``u_max``, ``u_star_guess``,
        ``search_range`` (smallest and largest u actually evaluated),
        ``bracket_width``, ``delta``, ``m``, ``acceptance_rule`` and
        ``search_log`` (one entry per evaluated u, sorted by u).
        If no u in [1, u_max] satisfies the rule, u_star falls back to 1 and
        ``satisfied_mass`` reports the (insufficient) mass at u = 1.

    Raises:
        ValueError: on out-of-range alpha_target, n, delta, mode, or m < 1
            in beta-binomial mode.
    """
    if not (0.0 < alpha_target < 1.0):
        raise ValueError("alpha_target must be in (0,1).")
    if n < 1:
        raise ValueError("n must be >= 1.")
    if not (0.0 < delta < 1.0):
        raise ValueError("delta must be in (0,1).")
    if mode not in ("beta", "beta-binomial"):
        raise ValueError("mode must be 'beta' or 'beta-binomial'.")

    # Maximum u to search (α' must be ≤ α_target)
    u_max = min(n, math.floor(alpha_target * (n + 1)))
    target_coverage = 1 - alpha_target
    pass_threshold = 1 - delta

    # Initial guess for u using a normal approximation:
    # u ≈ u_target - z_δ * sqrt(u_target), with u_target = (n+1)*α_target
    # and z_δ = Φ^(-1)(1-δ).
    u_target = (n + 1) * alpha_target
    z_delta = norm.ppf(1 - delta)  # quantile function (inverse CDF)
    u_star_guess = max(1, math.floor(u_target - z_delta * math.sqrt(u_target)))

    # Clamp to valid range
    u_star_guess = min(u_max, u_star_guess)

    # Bracket width (Δ in Algorithm 1)
    if bracket_width is None:
        bracket_width = max(5, min(int(2 * z_delta * math.sqrt(u_target)), n // 10))
        bracket_width = min(bracket_width, 100)  # cap at 100 for efficiency

    # Initial search bounds - never outside [1, u_max]
    u_min = max(1, u_star_guess - bracket_width)
    u_search_max = min(u_max, u_star_guess + bracket_width)

    # Degenerate bracket (e.g. u_max == 0): fall back to the full range
    if u_min > u_search_max:
        u_min = 1
        u_search_max = u_max

    m_eval = None
    k_thresh = None
    if mode == "beta-binomial":
        m_eval = m if m is not None else n
        if m_eval < 1:
            raise ValueError("m must be >= 1 for beta-binomial mode.")
        # Minimum number of covered test points that counts as "coverage >= target"
        k_thresh = math.ceil(target_coverage * m_eval)

    def tail_mass(u):
        """P(coverage >= target) when calibrating at α' = u/(n+1)."""
        a = n + 1 - u
        if mode == "beta":
            # Survival function of Beta(a, u) at the target coverage
            return beta_dist.sf(target_coverage, a, u)
        # P(X ≥ k_thresh) where X ~ BetaBinomial(m, a, u)
        return betabinom.sf(k_thresh - 1, m_eval, a, u)

    search_log = []

    def evaluate(u):
        """Evaluate one candidate u, record it in the log and return its entry."""
        ptail = tail_mass(u)
        entry = {
            'u': u,
            'alpha_prime': u / (n + 1),
            'a': n + 1 - u,
            'b': u,
            'ptail': ptail,
            'threshold': pass_threshold,
            'passes': ptail >= pass_threshold,
        }
        search_log.append(entry)
        return entry

    u_star = None
    mass_star = None
    searched_lo, searched_hi = u_min, u_search_max

    # Scan the bracket, remembering the largest u that passes.
    for u in range(u_min, u_search_max + 1):
        entry = evaluate(u)
        if entry['passes']:
            u_star = u
            mass_star = entry['ptail']

    if u_star is None:
        # Nothing in the bracket passes: the answer, if any, lies below it.
        # Walking downward, the first passing u is the largest admissible one.
        for u in range(u_min - 1, 0, -1):
            entry = evaluate(u)
            searched_lo = u
            if entry['passes']:
                u_star = u
                mass_star = entry['ptail']
                break
    elif u_star == u_search_max:
        # The top of the bracket still passes: keep climbing until the first
        # failure or until u_max is reached.
        for u in range(u_search_max + 1, u_max + 1):
            entry = evaluate(u)
            searched_hi = u
            if not entry['passes']:
                break
            u_star = u
            mass_star = entry['ptail']

    search_log.sort(key=lambda item: item['u'])

    # If nothing passes at all, fall back to u=1 (most conservative)
    if u_star is None:
        u_star = 1
        mass_star = tail_mass(u_star)

    alpha_corrected = u_star / (n + 1)

    return SSBCResult(
        alpha_target=alpha_target,
        alpha_corrected=alpha_corrected,
        u_star=u_star,
        n=n,
        satisfied_mass=mass_star,
        mode=mode,
        details=dict(
            u_max=u_max,
            u_star_guess=u_star_guess,
            search_range=(searched_lo, searched_hi),
            bracket_width=bracket_width,
            delta=delta,
            m=(m if (mode == "beta-binomial") else None),
            acceptance_rule="P(Coverage >= target) >= 1-delta",
            search_log=search_log,
        ),
    )

class BinaryClassifierSimulator:
    """Synthetic binary-classifier outputs whose scores are Beta-distributed per true class."""

    def __init__(self, p_class1, beta_params_class0, beta_params_class1, seed=None):
        """
        Configure the simulator.

        Parameters:
        -----------
        p_class1 : float
            Prior probability that a sample's true label is 1
        beta_params_class0 : tuple (a, b)
            Beta parameters of p(class=1) for samples whose true label is 0
        beta_params_class1 : tuple (a, b)
            Beta parameters of p(class=1) for samples whose true label is 1
        seed : int, optional
            Seed handed to numpy's default_rng for reproducibility
        """
        self.rng = np.random.default_rng(seed)
        self.p_class1 = p_class1
        self.p_class0 = 1.0 - p_class1
        (self.a0, self.b0) = beta_params_class0
        (self.a1, self.b1) = beta_params_class1

    def generate(self, n_samples):
        """
        Draw n_samples labelled examples together with their class scores.

        Parameters:
        -----------
        n_samples : int
            How many samples to draw

        Returns:
        --------
        labels : np.ndarray, shape (n_samples,)
            True binary labels (0 or 1)
        probs : np.ndarray, shape (n_samples, 2)
            Per-sample scores [p(class=0), p(class=1)]; each row sums to 1
        """
        # One draw for all labels first, then one Beta draw per sample in
        # order, so the random stream is consumed sample by sample.
        labels = self.rng.choice(
            [0, 1], size=n_samples, p=[self.p_class0, self.p_class1]
        )
        shapes_for_0 = (self.a0, self.b0)
        shapes_for_1 = (self.a1, self.b1)

        probs = np.zeros((n_samples, 2))
        for row, label in zip(probs, labels):
            shape_a, shape_b = shapes_for_0 if label == 0 else shapes_for_1
            row[1] = self.rng.beta(shape_a, shape_b)  # p(class=1)
            row[0] = 1.0 - row[1]                     # p(class=0)

        return labels, probs

def clopper_pearson_intervals(labels, confidence=0.95):
    """
    Compute Clopper-Pearson (exact binomial) confidence intervals for class prevalences.

    Parameters:
    -----------
    labels : array-like
        Binary labels (0 or 1). Lists and tuples are accepted; they are
        converted with np.asarray so the element-wise comparison works.
    confidence : float, default=0.95
        Confidence level (e.g., 0.95 for 95% CI)

    Returns:
    --------
    intervals : dict
        Dictionary with keys 0 and 1, each containing:
        - 'count': number of samples in this class
        - 'proportion': observed proportion
        - 'lower': lower bound of CI
        - 'upper': upper bound of CI
        For empty input there is no information, so each class gets
        count 0, proportion 0.0 and the vacuous interval [0.0, 1.0].
    """
    # A plain list compared with `==` yields a single bool, which would
    # silently produce zero counts; force an array first.
    labels = np.asarray(labels)
    alpha = 1 - confidence
    n_total = len(labels)

    intervals = {}

    for label in [0, 1]:
        count = np.sum(labels == label)

        if n_total == 0:
            # Avoid 0/0 and NaN quantiles when there are no observations.
            intervals[label] = {
                'count': count,
                'proportion': 0.0,
                'lower': 0.0,
                'upper': 1.0
            }
            continue

        proportion = count / n_total

        # Clopper-Pearson uses Beta distribution quantiles
        # Lower bound: Beta(count, n-count+1) at alpha/2   (0 when count == 0)
        # Upper bound: Beta(count+1, n-count) at 1-alpha/2 (1 when count == n)
        if count == 0:
            lower = 0.0
        else:
            lower = stats.beta.ppf(alpha/2, count, n_total - count + 1)

        if count == n_total:
            upper = 1.0
        else:
            upper = stats.beta.ppf(1 - alpha/2, count + 1, n_total - count)

        intervals[label] = {
            'count': count,
            'proportion': proportion,
            'lower': lower,
            'upper': upper
        }

    return intervals

def split_by_class(labels, probs):
    """
    Partition calibration data by true label (Mondrian conformal prediction).

    Parameters:
    -----------
    labels : np.ndarray, shape (n,)
        True binary labels (0 or 1)
    probs : np.ndarray, shape (n, 2)
        Classification probabilities [P(class=0), P(class=1)]

    Returns:
    --------
    class_data : dict
        Keys 0 and 1; each value is a dict with
        - 'labels': the labels belonging to that class
        - 'probs': the probability rows of those samples
        - 'indices': positions of those samples in the input arrays
        - 'n': how many samples fall in the class
    """
    def _subset(wanted):
        # Boolean membership mask for this class, reused for every field.
        member = labels == wanted
        positions = np.nonzero(member)[0]
        return {
            'labels': labels[member],
            'probs': probs[member],
            'indices': positions,
            'n': np.sum(member),
        }

    return {wanted: _subset(wanted) for wanted in (0, 1)}


In [None]:
# --- quick sanity check ---
# Runs ssbc_correct on a small and a large calibration set and prints the
# chosen correction together with the per-u search log.
if __name__ == "__main__":
    # Example: alpha=0.10, n=50, delta=0.10 (90% PAC)
    print("="*60)
    print("n=50 test:")
    print("="*60)
    result_beta = ssbc_correct(0.10, n=50, delta=0.10, mode="beta")
    print("Beta mode:")
    print(f"  alpha_target:    {result_beta.alpha_target}")
    print(f"  alpha_corrected: {result_beta.alpha_corrected:.6f}")
    print(f"  u_star:          {result_beta.u_star}")
    print(f"  u_star_guess:    {result_beta.details['u_star_guess']}")
    print(f"  search_range:    {result_beta.details['search_range']}")
    print(f"  satisfied_mass:  {result_beta.satisfied_mass:.6f}")
    print("\n  Search log:")
    for entry in result_beta.details['search_log']:
        status = "✓ PASS" if entry['passes'] else "✗ FAIL"
        print(f"    u={entry['u']}: α'={entry['alpha_prime']:.4f}, "
              f"Beta({entry['a']},{entry['b']}), "
              f"P(C≥0.9)={entry['ptail']:.4f} {status}")

    # Large calibration set with a deliberately narrow bracket.
    # The banner now matches the n actually passed to ssbc_correct.
    print("\n" + "="*60)
    print("n=50000 test:")
    print("="*60)
    result_large = ssbc_correct(0.10, n=50000, delta=0.10, mode="beta", bracket_width=15)
    print("Beta mode:")
    print(f"  alpha_target:    {result_large.alpha_target}")
    print(f"  alpha_corrected: {result_large.alpha_corrected:.6f}")
    print(f"  u_star:          {result_large.u_star}")
    print(f"  u_star_guess:    {result_large.details['u_star_guess']}")
    print(f"  search_range:    {result_large.details['search_range']}")
    print(f"  satisfied_mass:  {result_large.satisfied_mass:.6f}")
    # Compare against the size of the exhaustive range 1..u_max rather than
    # a stale hard-coded number.
    print(f"  searched {len(result_large.details['search_log'])} values "
          f"instead of {result_large.details['u_max']}")
    print("\n  Search log :")
    log = result_large.details['search_log']
    for entry in log:
        status = "✓ PASS" if entry['passes'] else "✗ FAIL"
        print(f"    u={entry['u']}: α'={entry['alpha_prime']:.4f}, "
              f"Beta({entry['a']},{entry['b']}), "
              f"P(C≥0.9)={entry['ptail']:.4f} {status}")

In [None]:


# Smoke test for the simulator
if __name__ == "__main__":
    # Imbalanced data with a 10% positive class.
    # True label 0 scores p(class=1) ~ Beta(2, 8), mean 0.2 (low, i.e. correct)
    # True label 1 scores p(class=1) ~ Beta(8, 2), mean 0.8 (high, i.e. correct)
    sim = BinaryClassifierSimulator(
        p_class1=0.10,
        beta_params_class0=(2, 8),
        beta_params_class1=(8, 2),
        seed=42,
    )
    labels, probs = sim.generate(n_samples=50)

    print(f"Shape of labels: {labels.shape}")
    print(f"Shape of probs: {probs.shape}")

    print("\nFirst 5 samples:")
    for i, (p0, p1) in enumerate(probs[:5]):
        print(f"Label={labels[i]}, P(0)={p0:.3f}, P(1)={p1:.3f}")

    print(f"\nClass balance: {np.bincount(labels)}")
    # Average positive-class score within each true class
    score_1 = probs[:, 1]
    print(f"Mean P(1) when true=0: {score_1[labels == 0].mean():.3f}")
    print(f"Mean P(1) when true=1: {score_1[labels == 1].mean():.3f}")

In [None]:

# Exercise clopper_pearson_intervals on simulated and degenerate label sets
if __name__ == "__main__":

    # Simulated, imbalanced labels
    sim = BinaryClassifierSimulator(
        p_class1=0.10,
        beta_params_class0=(2, 8),
        beta_params_class1=(8, 2),
        seed=42,
    )
    labels, probs = sim.generate(n_samples=100)

    # Prevalence intervals for both classes
    intervals = clopper_pearson_intervals(labels, confidence=0.95)

    print("Clopper-Pearson 95% Confidence Intervals:", "=" * 60, sep="\n")
    for label in (0, 1):
        info = intervals[label]
        print(f"\nClass {label}:")
        print(f"  Count:      {info['count']}/{len(labels)}")
        print(f"  Proportion: {info['proportion']:.4f}")
        print(f"  95% CI:     [{info['lower']:.4f}, {info['upper']:.4f}]")

    # Degenerate inputs: one class entirely absent
    print("\n" + "=" * 60, "Edge case tests:", sep="\n")

    # Only class 0 present -> inspect the interval reported for class 1
    labels_zeros = np.zeros(50, dtype=int)
    intervals_zeros = clopper_pearson_intervals(labels_zeros)
    absent_1 = intervals_zeros[1]
    print(f"\nAll class 0: Class 1 interval = [{absent_1['lower']:.4f}, {absent_1['upper']:.4f}]")

    # Only class 1 present -> inspect the interval reported for class 0
    labels_ones = np.ones(50, dtype=int)
    intervals_ones = clopper_pearson_intervals(labels_ones)
    absent_0 = intervals_ones[0]
    print(f"All class 1: Class 0 interval = [{absent_0['lower']:.4f}, {absent_0['upper']:.4f}]")

In [None]:
def _mondrian_cp_interval(count, total, confidence=0.95):
    """Clopper-Pearson interval for `count` successes out of `total` trials (all zeros when total == 0)."""
    if total == 0:
        return {'count': count, 'proportion': 0.0, 'lower': 0.0, 'upper': 0.0}

    alpha = 1 - confidence
    # The Beta quantile is undefined at the extremes, so pin those bounds.
    lower = 0.0 if count == 0 else stats.beta.ppf(alpha/2, count, total - count + 1)
    upper = 1.0 if count == total else stats.beta.ppf(1 - alpha/2, count + 1, total - count)
    return {
        'count': count,
        'proportion': count / total,
        'lower': lower,
        'upper': upper
    }


def _mondrian_prediction_sets(probs, threshold_0, threshold_1):
    """Build one prediction set per row: every class whose score 1 - P(class) is within its threshold."""
    keep_0 = (1.0 - probs[:, 0]) <= threshold_0
    keep_1 = (1.0 - probs[:, 1]) <= threshold_1
    return [
        [label for label, keep in ((0, in_0), (1, in_1)) if keep]
        for in_0, in_1 in zip(keep_0, keep_1)
    ]


def _mondrian_singleton_bounds(alpha, n_singletons, n_singleton_errors, n_abstentions, n_doublets):
    """Return (rho, kappa, bound, observed, n_escalations); the first four are None unless both singletons and escalations occur."""
    n_escalations = n_doublets + n_abstentions
    if n_escalations > 0 and n_singletons > 0:
        rho = n_singletons / n_escalations
        kappa = n_abstentions / n_escalations
        bound = alpha * (1 + 1/rho) - kappa/rho
        observed = n_singleton_errors / n_singletons
    else:
        rho = kappa = bound = observed = None
    return rho, kappa, bound, observed, n_escalations


def mondrian_conformal_calibrate(class_data, alpha_target, delta, mode="beta", m=None):
    """
    Perform Mondrian (per-class) conformal calibration with SSBC correction.

    For each class, compute:
    1. Nonconformity scores: s(x, y) = 1 - P(y|x)
    2. SSBC-corrected alpha for PAC guarantee
    3. Conformal quantile threshold (the k-th smallest score with
       k = min(n + 1 - u_star, n), computed in integer arithmetic)
    4. Singleton error rate bounds via PAC guarantee

    Then evaluate prediction set sizes on calibration data PER CLASS and MARGINALLY.

    Parameters:
    -----------
    class_data : dict
        Output from split_by_class()
    alpha_target : float or dict
        Target miscoverage rate for each class
        If float: same for both classes
        If dict: {0: α0, 1: α1} for per-class control
    delta : float or dict
        PAC risk tolerance for each class
        If float: same for both classes
        If dict: {0: δ0, 1: δ1} for per-class control
    mode : str
        "beta" (infinite test) or "beta-binomial" (finite test)
    m : int, optional
        Test window size for beta-binomial mode

    Returns:
    --------
    calibration_result : dict
        Dictionary with keys 0 and 1, each containing calibration info.
        A class with no samples gets threshold None and an 'error' message.

    prediction_stats : dict
        Dictionary with keys:
        - 0, 1: per-class statistics (conditioned on true label)
        - 'marginal': overall statistics (ignoring true labels)
        If either class has no threshold, this is instead a dict holding
        only an 'error' message.
    """
    # Handle scalar or dict inputs for alpha and delta
    if isinstance(alpha_target, (int, float)):
        alpha_dict = {0: alpha_target, 1: alpha_target}
    else:
        alpha_dict = alpha_target

    if isinstance(delta, (int, float)):
        delta_dict = {0: delta, 1: delta}
    else:
        delta_dict = delta

    calibration_result = {}

    # Step 1: Calibrate per class
    for label in [0, 1]:
        data = class_data[label]
        n = data['n']
        alpha_class = alpha_dict[label]
        delta_class = delta_dict[label]

        if n == 0:
            calibration_result[label] = {
                'n': 0,
                'alpha_target': alpha_class,
                'alpha_corrected': None,
                'delta': delta_class,
                'threshold': None,
                'scores': np.array([]),
                'ssbc_result': None,
                'error': 'No calibration samples for this class'
            }
            continue

        # Nonconformity scores: s(x, y) = 1 - P(y|x) for the true class
        scores = 1.0 - data['probs'][:, label]

        # Apply SSBC to get corrected alpha
        ssbc_result = ssbc_correct(
            alpha_target=alpha_class,
            n=n,
            delta=delta_class,
            mode=mode,
            m=m
        )
        alpha_corrected = ssbc_result.alpha_corrected

        # Conformal quantile rank. alpha_corrected is u_star / (n + 1), so
        # ceil((n + 1) * (1 - alpha_corrected)) is exactly n + 1 - u_star;
        # using the integers avoids a float round-off pushing ceil() up a rank.
        k = int(min(n + 1 - ssbc_result.u_star, n))

        sorted_scores = np.sort(scores)
        threshold = sorted_scores[k - 1] if k > 0 else sorted_scores[0]

        calibration_result[label] = {
            'n': n,
            'alpha_target': alpha_class,
            'alpha_corrected': alpha_corrected,
            'delta': delta_class,
            'threshold': threshold,
            'scores': sorted_scores,
            'ssbc_result': ssbc_result,
            'k': k
        }

    # Step 2: Evaluate prediction sets (needs a threshold for both classes)
    if calibration_result[0].get('threshold') is None or calibration_result[1].get('threshold') is None:
        return calibration_result, {
            'error': 'Cannot compute prediction sets - missing threshold for at least one class'
        }

    threshold_0 = calibration_result[0]['threshold']
    threshold_1 = calibration_result[1]['threshold']

    prediction_stats = {}

    # Step 2a: Evaluate per true class
    for true_label in [0, 1]:
        data = class_data[true_label]
        n_class = data['n']

        if n_class == 0:
            prediction_stats[true_label] = {
                'n_class': 0,
                'error': 'No samples in this class'
            }
            continue

        prediction_sets = _mondrian_prediction_sets(
            data['probs'][:n_class], threshold_0, threshold_1
        )

        # Count set sizes and correctness
        n_abstentions = sum(len(ps) == 0 for ps in prediction_sets)
        n_doublets = sum(len(ps) == 2 for ps in prediction_sets)
        n_singletons_correct = sum(ps == [true_label] for ps in prediction_sets)
        n_singletons_incorrect = sum(
            len(ps) == 1 and true_label not in ps for ps in prediction_sets
        )
        n_singletons_total = n_singletons_correct + n_singletons_incorrect

        # PAC bounds on the singleton error rate for this class
        rho, kappa, alpha_singlet_bound, alpha_singlet_observed, n_escalations = (
            _mondrian_singleton_bounds(
                alpha_dict[true_label],
                n_singletons_total,
                n_singletons_incorrect,
                n_abstentions,
                n_doublets,
            )
        )

        prediction_stats[true_label] = {
            'n_class': n_class,
            'abstentions': _mondrian_cp_interval(n_abstentions, n_class),
            'singletons': _mondrian_cp_interval(n_singletons_total, n_class),
            'singletons_correct': _mondrian_cp_interval(n_singletons_correct, n_class),
            'singletons_incorrect': _mondrian_cp_interval(n_singletons_incorrect, n_class),
            'doublets': _mondrian_cp_interval(n_doublets, n_class),
            'prediction_sets': prediction_sets,
            'pac_bounds': {
                'rho': rho,
                'kappa': kappa,
                'alpha_singlet_bound': alpha_singlet_bound,
                'alpha_singlet_observed': alpha_singlet_observed,
                'n_singletons': n_singletons_total,
                'n_escalations': n_escalations
            }
        }

    # Step 2b: MARGINAL ANALYSIS (ignoring true labels)
    # Reconstruct the full dataset and restore the original sample order
    all_labels = np.concatenate([class_data[0]['labels'], class_data[1]['labels']])
    all_probs = np.concatenate([class_data[0]['probs'], class_data[1]['probs']], axis=0)
    all_indices = np.concatenate([class_data[0]['indices'], class_data[1]['indices']])

    sort_idx = np.argsort(all_indices)
    all_labels = all_labels[sort_idx]
    all_probs = all_probs[sort_idx]

    n_total = len(all_labels)

    all_prediction_sets = _mondrian_prediction_sets(
        all_probs[:n_total], threshold_0, threshold_1
    )

    # Count overall set sizes
    n_abstentions_total = sum(len(ps) == 0 for ps in all_prediction_sets)
    n_singletons_total = sum(len(ps) == 1 for ps in all_prediction_sets)
    n_doublets_total = sum(len(ps) == 2 for ps in all_prediction_sets)

    # Break down singletons by predicted class
    n_singletons_pred_0 = sum(ps == [0] for ps in all_prediction_sets)
    n_singletons_pred_1 = sum(ps == [1] for ps in all_prediction_sets)

    # Overall coverage: the true label is inside the prediction set
    n_covered = sum(y in ps for y, ps in zip(all_labels, all_prediction_sets))
    coverage = n_covered / n_total

    # Errors on singletons: singleton sets that miss the true label
    n_singletons_covered = sum(
        len(ps) == 1 and y in ps for y, ps in zip(all_labels, all_prediction_sets)
    )
    n_singletons_errors = n_singletons_total - n_singletons_covered

    # Overall PAC bounds (using weighted average of alphas for interpretation)
    # Note: This is approximate since we have different α per class
    if (n_doublets_total + n_abstentions_total) > 0 and n_singletons_total > 0:
        # Weighted average alpha (by class size)
        n_0 = class_data[0]['n']
        n_1 = class_data[1]['n']
        alpha_weighted = (n_0 * alpha_dict[0] + n_1 * alpha_dict[1]) / (n_0 + n_1)
    else:
        alpha_weighted = None

    (rho_marginal, kappa_marginal, alpha_singlet_bound_marginal,
     alpha_singlet_observed_marginal, n_escalations_total) = _mondrian_singleton_bounds(
        alpha_weighted,
        n_singletons_total,
        n_singletons_errors,
        n_abstentions_total,
        n_doublets_total,
    )

    prediction_stats['marginal'] = {
        'n_total': n_total,
        'coverage': {
            'count': n_covered,
            'rate': coverage,
            'ci_95': _mondrian_cp_interval(n_covered, n_total)
        },
        'abstentions': _mondrian_cp_interval(n_abstentions_total, n_total),
        'singletons': {
            **_mondrian_cp_interval(n_singletons_total, n_total),
            'pred_0': n_singletons_pred_0,
            'pred_1': n_singletons_pred_1,
            'errors': n_singletons_errors
        },
        'doublets': _mondrian_cp_interval(n_doublets_total, n_total),
        'prediction_sets': all_prediction_sets,
        'pac_bounds': {
            'rho': rho_marginal,
            'kappa': kappa_marginal,
            'alpha_weighted': alpha_weighted,
            'alpha_singlet_bound': alpha_singlet_bound_marginal,
            'alpha_singlet_observed': alpha_singlet_observed_marginal,
            'n_singletons': n_singletons_total,
            'n_escalations': n_escalations_total
        }
    }

    return calibration_result, prediction_stats

In [None]:
def report_prediction_stats(prediction_stats: Dict[Any, Any],
                            calibration_result: Dict[Any, Any],
                            verbose: bool = True) -> Dict[str, Any]:
    """
    Pretty/robust summary for Mondrian conformal prediction stats.

    Tolerates multiple schema shapes:
      - dicts with 'rate'/'ci_95' or 'proportion'/'lower'/'upper'
        ('ci_95' may be a (lo, hi) pair or a dict with 'lower'/'upper')
      - raw integer counts, including NumPy integers
        (e.g., marginal['singletons']['pred_0'] = 339)
      - per-class singleton correct/incorrect either nested under 'singletons'
        OR as top-level aliases 'singletons_correct' / 'singletons_incorrect'.

    Also computes Clopper–Pearson CIs when missing, and splits marginal
    singleton errors by predicted class.

    Args:
        prediction_stats: Mapping that may contain per-class entries under the
            keys 0 and 1 and a deployment-level entry under 'marginal'.
        calibration_result: Mapping consulted per class as a fallback source
            for 'alpha_target' and 'delta'.
        verbose: When True, print a formatted report to stdout.

    Returns:
        A normalised summary dict with the same top-level keys that were found
        in ``prediction_stats`` (0, 1, 'marginal'). Every rate entry has the
        shape ``{'count', 'rate', 'ci_95': (lo, hi)}``.

    Raises:
        KeyError: If a 'marginal' entry is present but lacks 'n_total'.
    """
    rule = "=" * 80

    # ------------------------- helpers -------------------------
    def cp_interval(count: int, total: int, confidence: float = 0.95) -> Dict[str, float]:
        """Clopper–Pearson exact CI using SciPy."""
        alpha = 1 - confidence
        count = int(count)
        total = int(total)
        if total <= 0:
            return {'count': count, 'proportion': 0.0, 'lower': 0.0, 'upper': 0.0}
        p = count / total
        if count == 0:
            lower = 0.0
            upper = stats.beta.ppf(1 - alpha/2, 1, total)
        elif count == total:
            lower = stats.beta.ppf(alpha/2, total, 1)
            upper = 1.0
        else:
            lower = stats.beta.ppf(alpha/2, count, total - count + 1)
            upper = stats.beta.ppf(1 - alpha/2, count + 1, total - count)
        return {'count': count, 'proportion': float(p), 'lower': float(lower), 'upper': float(upper)}

    def as_dict(x: Any) -> Dict[str, Any]:
        return x if isinstance(x, dict) else {}

    def get_count(x: Any, default: int = 0) -> int:
        if isinstance(x, dict):
            return int(x.get('count', default))
        # np.int64 & co. are NOT subclasses of int; without np.integer here a
        # NumPy count would silently be reported as `default`.
        if isinstance(x, (int, np.integer)):
            return int(x)
        return default

    def get_rate(x: Any, default: Optional[float] = 0.0) -> Optional[float]:
        if isinstance(x, dict):
            if 'rate' in x:
                return float(x['rate'])
            if 'proportion' in x:
                return float(x['proportion'])
            # Some producers put only count; caller must compute rate.
            return default
        if isinstance(x, (float,)):
            return float(x)
        return default

    def get_ci_tuple(x: Any) -> Tuple[float, float]:
        if not isinstance(x, dict):
            return (0.0, 0.0)
        ci = x.get('ci_95')
        if isinstance(ci, (tuple, list)) and len(ci) == 2:
            return float(ci[0]), float(ci[1])
        # Some producers store the interval itself as a lower/upper dict.
        if isinstance(ci, dict) and 'lower' in ci and 'upper' in ci:
            return float(ci['lower']), float(ci['upper'])
        lo = x.get('lower', 0.0)
        hi = x.get('upper', 0.0)
        return float(lo), float(hi)

    def ensure_ci(d: Dict[str, Any], count: int, total: int) -> Tuple[float, float, float]:
        """
        Return (rate, lo, hi). If d already has rate/CI, use them; else compute CP from count/total.
        """
        r = get_rate(d, default=None)
        lo, hi = get_ci_tuple(d)
        # A (0, 0) interval with data behind it means "no CI supplied".
        if r is None or (lo == 0.0 and hi == 0.0 and (count > 0 or total > 0)):
            ci = cp_interval(count, total)
            return ci['proportion'], ci['lower'], ci['upper']
        return float(r), float(lo), float(hi)

    def unpack(raw: Any) -> Tuple[Dict[str, Any], int]:
        """Split a raw entry (dict or bare count) into (dict view, count)."""
        return as_dict(raw), get_count(raw)

    def rate_entry(count: int, rate: float, lo: float, hi: float) -> Dict[str, Any]:
        return {'count': count, 'rate': rate, 'ci_95': (lo, hi)}

    def pct(x: float) -> str:
        return f"{x:6.2%}"

    def fmt3(v: Any) -> str:
        """Three-decimal formatting that tolerates a missing (None) value."""
        return f"{float(v):.3f}" if v is not None else "n/a"

    def emit(prefix: str, count: int, denom: int, rate: float, lo: float, hi: float) -> None:
        """Print one 'count / denom = rate  95% CI' report line."""
        print(f"{prefix}{count:4d} / {denom:4d} = {pct(rate)}  "
              f"95% CI: [{lo:.4f}, {hi:.4f}]")

    def print_pac_check(pac_dict: Dict[str, Any]) -> None:
        """Print ρ/κ and, when both values exist, the bound-vs-observed check."""
        print(f"    ρ = {fmt3(pac_dict.get('rho', 0))}, κ = {fmt3(pac_dict.get('kappa', 0))}")
        bound = pac_dict.get('alpha_singlet_bound')
        observed = pac_dict.get('alpha_singlet_observed')
        # Keys may be present but None (no bound computable); skip instead of float(None).
        if bound is not None and observed is not None:
            bound = float(bound)
            observed = float(observed)
            ok = '✓' if observed <= bound else '✗'
            print(f"    α'_bound:    {bound:.4f}")
            print(f"    α'_observed: {observed:.4f} {ok}")

    def class_singleton_errors(label: int) -> int:
        """Count of incorrect singletons for true class `label` (nested or flat alias)."""
        entry = as_dict(prediction_stats.get(label, {}))
        nested = as_dict(entry.get('singletons', {}))
        return get_count(nested.get('incorrect', entry.get('singletons_incorrect', {})))

    summary: Dict[str, Any] = {}

    if verbose:
        print(rule)
        print("PREDICTION SET STATISTICS (All rates with 95% Clopper–Pearson CIs)")
        print(rule)

    # ----------------- per-class (conditioned on Y) -----------------
    for class_label in [0, 1]:
        if class_label not in prediction_stats:
            continue
        cls = prediction_stats[class_label]

        if isinstance(cls, dict) and 'error' in cls:
            if verbose:
                print(f"\nClass {class_label}: {cls['error']}")
            summary[class_label] = {'error': cls['error']}
            continue

        n = int(cls.get('n', cls.get('n_class', 0)))
        cal_entry = as_dict(calibration_result.get(class_label, {}))
        alpha_target = cls.get('alpha_target', cal_entry.get('alpha_target', None))
        delta = cls.get('delta', cal_entry.get('delta', None))

        # Counts are taken from the raw value so bare integers are honoured.
        abst, abst_count = unpack(cls.get('abstentions', {}))
        sing, sing_count = unpack(cls.get('singletons', {}))
        # Accept both nested and flat aliases
        sing_corr, sing_corr_count = unpack(sing.get('correct', cls.get('singletons_correct', {})))
        sing_inc, sing_inc_count = unpack(sing.get('incorrect', cls.get('singletons_incorrect', {})))
        doub, doub_count = unpack(cls.get('doublets', {}))
        pac = as_dict(cls.get('pac_bounds', {}))

        # Ensure rates/CIs (fallback to CP if missing)
        abst_rate, abst_lo, abst_hi = ensure_ci(abst, abst_count, n)
        sing_rate, sing_lo, sing_hi = ensure_ci(sing, sing_count, n)
        sing_corr_rate, sing_corr_lo, sing_corr_hi = ensure_ci(sing_corr, sing_corr_count, n)
        sing_inc_rate, sing_inc_lo, sing_inc_hi = ensure_ci(sing_inc, sing_inc_count, n)
        doub_rate, doub_lo, doub_hi = ensure_ci(doub, doub_count, n)

        # P(error | singleton, Y=class); denominator clamped to 1 when there are no singletons
        err_given_single_ci = cp_interval(sing_inc_count, sing_count if sing_count > 0 else 1)

        class_summary = {
            'n': n,
            'alpha_target': alpha_target,
            'delta': delta,
            'abstentions': rate_entry(abst_count, abst_rate, abst_lo, abst_hi),
            'singletons': {
                **rate_entry(sing_count, sing_rate, sing_lo, sing_hi),
                'correct': rate_entry(sing_corr_count, sing_corr_rate, sing_corr_lo, sing_corr_hi),
                'incorrect': rate_entry(sing_inc_count, sing_inc_rate, sing_inc_lo, sing_inc_hi),
                'error_given_singleton': {
                    'count': sing_inc_count,
                    'denom': sing_count,
                    'rate': err_given_single_ci['proportion'],
                    'ci_95': (err_given_single_ci['lower'], err_given_single_ci['upper'])
                }
            },
            'doublets': rate_entry(doub_count, doub_rate, doub_lo, doub_hi),
            'pac_bounds': pac
        }
        summary[class_label] = class_summary

        if verbose:
            print(f"\n{rule}")
            print(f"CLASS {class_label} (Conditioned on True Label = {class_label})")
            print(rule)
            alpha_str = fmt3(alpha_target)
            delta_str = fmt3(delta)
            print(f"  n={n}, α_target={alpha_str}, δ={delta_str}")

            print("\nPrediction Set Breakdown:")
            emit("  Abstentions:  ", abst_count, n, abst_rate, abst_lo, abst_hi)
            emit("  Singletons:   ", sing_count, n, sing_rate, sing_lo, sing_hi)
            emit("    ├─ Correct:   ", sing_corr_count, n, sing_corr_rate, sing_corr_lo, sing_corr_hi)
            emit("    └─ Incorrect: ", sing_inc_count, n, sing_inc_rate, sing_inc_lo, sing_inc_hi)

            emit(f"  Singleton error | Y={class_label}: ", sing_inc_count, sing_count,
                 err_given_single_ci['proportion'],
                 err_given_single_ci['lower'], err_given_single_ci['upper'])

            emit("\n  Doublets:     ", doub_count, n, doub_rate, doub_lo, doub_hi)

            if pac and pac.get('rho', None) is not None:
                print(f"\n  PAC Singleton Error Rate (δ={delta_str}):")
                print_pac_check(pac)

    # ----------------- marginal / deployment view -----------------
    if 'marginal' in prediction_stats:
        marg = prediction_stats['marginal']
        n_total = int(marg['n_total'])

        cov, cov_count = unpack(marg.get('coverage', {}))
        abst_m, abst_m_count = unpack(marg.get('abstentions', {}))
        sing_m, sing_total = unpack(marg.get('singletons', {}))
        doub_m, doub_m_count = unpack(marg.get('doublets', {}))
        pac_m = as_dict(marg.get('pac_bounds', {}))

        cov_rate, cov_lo, cov_hi = ensure_ci(cov, cov_count, n_total)
        abst_m_rate, abst_m_lo, abst_m_hi = ensure_ci(abst_m, abst_m_count, n_total)
        sing_m_rate, sing_m_lo, sing_m_hi = ensure_ci(sing_m, sing_total, n_total)
        doub_m_rate, doub_m_lo, doub_m_hi = ensure_ci(doub_m, doub_m_count, n_total)

        # pred_0 / pred_1 may be dicts or ints (counts); prefer a provided
        # rate/CI, else compute off n_total (an empty dict forces the CP path).
        raw_s0, s0_count = unpack(sing_m.get('pred_0', 0))
        raw_s1, s1_count = unpack(sing_m.get('pred_1', 0))
        s0_rate, s0_lo, s0_hi = ensure_ci(raw_s0, s0_count, n_total)
        s1_rate, s1_lo, s1_hi = ensure_ci(raw_s1, s1_count, n_total)

        # overall singleton errors (dict or int). Denominator should be sing_total.
        raw_s_err, s_err_count = unpack(sing_m.get('errors', 0))
        s_err_rate, s_err_lo, s_err_hi = ensure_ci(
            raw_s_err, s_err_count, sing_total if sing_total > 0 else 1)

        # errors by predicted class via per-class incorrect singletons
        # (pred 0 errors happen when Y=1 singleton is wrong; pred 1 errors when Y=0 singleton is wrong)
        err_pred0_count = class_singleton_errors(1)
        err_pred1_count = class_singleton_errors(0)
        pred0_err_ci = cp_interval(err_pred0_count, s0_count if s0_count > 0 else 1)
        pred1_err_ci = cp_interval(err_pred1_count, s1_count if s1_count > 0 else 1)

        marginal_summary = {
            'n_total': n_total,
            'coverage': rate_entry(cov_count, cov_rate, cov_lo, cov_hi),
            'abstentions': rate_entry(abst_m_count, abst_m_rate, abst_m_lo, abst_m_hi),
            'singletons': {
                **rate_entry(sing_total, sing_m_rate, sing_m_lo, sing_m_hi),
                'pred_0': rate_entry(s0_count, s0_rate, s0_lo, s0_hi),
                'pred_1': rate_entry(s1_count, s1_rate, s1_lo, s1_hi),
                'errors': rate_entry(s_err_count, s_err_rate, s_err_lo, s_err_hi),
                'errors_by_pred': {
                    'pred_0': {
                        'count': err_pred0_count,
                        'denom': s0_count,
                        'rate': pred0_err_ci['proportion'],
                        'ci_95': (pred0_err_ci['lower'], pred0_err_ci['upper'])
                    },
                    'pred_1': {
                        'count': err_pred1_count,
                        'denom': s1_count,
                        'rate': pred1_err_ci['proportion'],
                        'ci_95': (pred1_err_ci['lower'], pred1_err_ci['upper'])
                    }
                }
            },
            'doublets': rate_entry(doub_m_count, doub_m_rate, doub_m_lo, doub_m_hi),
            'pac_bounds': pac_m
        }
        summary['marginal'] = marginal_summary

        if verbose:
            print(f"\n{rule}")
            print("MARGINAL ANALYSIS (Deployment View - Ignores True Labels)")
            print(rule)
            print(f"  Total samples: {n_total}")

            print("\nOverall Coverage:")
            emit("  Covered: ", cov_count, n_total, cov_rate, cov_lo, cov_hi)

            print("\nPrediction Set Distribution:")
            emit("  Abstentions: ", abst_m_count, n_total, abst_m_rate, abst_m_lo, abst_m_hi)
            emit("  Singletons:  ", sing_total, n_total, sing_m_rate, sing_m_lo, sing_m_hi)
            emit("    ├─ Pred 0: ", s0_count, n_total, s0_rate, s0_lo, s0_hi)
            emit("    ├─ Pred 1: ", s1_count, n_total, s1_rate, s1_lo, s1_hi)
            emit("    ├─ Errors (overall): ", s_err_count, sing_total,
                 s_err_rate, s_err_lo, s_err_hi)
            emit("    ├─ Pred 0 errors:    ", err_pred0_count, s0_count,
                 pred0_err_ci['proportion'], pred0_err_ci['lower'], pred0_err_ci['upper'])
            emit("    └─ Pred 1 errors:    ", err_pred1_count, s1_count,
                 pred1_err_ci['proportion'], pred1_err_ci['lower'], pred1_err_ci['upper'])

            emit("  Doublets:    ", doub_m_count, n_total, doub_m_rate, doub_m_lo, doub_m_hi)

            if pac_m and pac_m.get('rho', None) is not None:
                aw_str = fmt3(pac_m.get('alpha_weighted', None))
                print(f"\n  Overall PAC Bounds (weighted α={aw_str}):")
                print_pac_check(pac_m)

                n_escalations = int(pac_m.get('n_escalations',
                                              doub_m_count + abst_m_count))
                # Guard the share against an empty evaluation set.
                escalation_share = n_escalations / n_total if n_total > 0 else 0.0
                print(f"\n  Deployment Decision Mix:")
                print(f"    Automate: {sing_total} singletons ({sing_m_rate:.1%})")
                print(f"    Escalate: {n_escalations} doublets+abstentions "
                      f"({escalation_share:.1%})")

    return summary


# Test it
if __name__ == "__main__":
    # Demo: simulate a binary classifier, split the samples per true class
    # (Mondrian split), inspect the split, then calibrate and report.
    sim = BinaryClassifierSimulator(
        p_class1=0.50,
        beta_params_class0=(2, 7),  # Class 0: low P(1) scores
        beta_params_class1=(7, 2),  # Class 1: high P(1) scores
    )
    labels, probs = sim.generate(n_samples=1000)

    # Split by class
    class_data = split_by_class(labels, probs)

    # Overlaid histograms of the P(1) score for each true class.
    hist_bins = np.linspace(0, 1, 50)
    plt.hist(class_data[0]['probs'][:, 1], bins=hist_bins)
    plt.hist(class_data[1]['probs'][:, 1], bins=hist_bins, alpha=0.5)
    plt.show()

    print("Mondrian Split Summary:", "=" * 60, sep="\n")
    for label in (0, 1):
        data = class_data[label]
        other = 1 - label
        print(f"\nClass {label}:")
        print(f"  n = {data['n']}")
        print(f"  Indices (first 5): {data['indices'][:5]}")
        print(f"  Labels unique: {np.unique(data['labels'])}")
        print(f"  Mean P({label}): {data['probs'][:, label].mean():.3f}")
        print(f"  Mean P({other}): {data['probs'][:, other].mean():.3f}")

    # Verify split is complete and non-overlapping
    total_samples = class_data[0]['n'] + class_data[1]['n']
    print(f"\nTotal samples after split: {total_samples} (should be {len(labels)})")

    # Show example samples from each class
    print("\n" + "=" * 60)
    print("Example samples from each class:")
    for label in (0, 1):
        data = class_data[label]
        shown = min(3, data['n'])
        print(f"\nClass {label} (first 3 samples):")
        for i in range(shown):
            sample_line = (f"  Sample {data['indices'][i]}: "
                           f"Label={data['labels'][i]}, "
                           f"P(0)={data['probs'][i, 0]:.3f}, "
                           f"P(1)={data['probs'][i, 1]:.3f}")
            print(sample_line)

    # Four blank lines separate the split diagnostics from the report.
    print("\n" * 3)

    cal_result, pred_stats = mondrian_conformal_calibrate(
        class_data=class_data,
        alpha_target={0: 0.10, 1: 0.10},
        delta={0: 0.10, 1: 0.10},
        mode="beta"
    )

    # Generate report with ALL Clopper-Pearson CIs
    summary = report_prediction_stats(pred_stats, cal_result, verbose=True)

In [None]:
# Hyperparameter grids for the sweep: {0.05, 0.10, 0.15} for every axis.
# np.arange with a float step has an unreliable length: rounding in
# (0.20 - 0.05) / 0.05 yields 3.0000000000000004, so arange emits a 4th point
# at 0.20 despite the exclusive stop. np.linspace pins the point count.
delta_0 = np.linspace(0.05, 0.15, 3)
delta_1 = np.linspace(0.05, 0.15, 3)
alpha_0 = np.linspace(0.05, 0.15, 3)
alpha_1 = np.linspace(0.05, 0.15, 3)

In [None]:
import itertools
import numpy as np
import pandas as pd

# --- your libs must already expose these:
# from your_module import mondrian_conformal_calibrate, report_prediction_stats

def sweep_hyperparams_and_collect(
    class_data,
    alpha_0, delta_0,
    alpha_1, delta_1,
    mode="beta",
    extra_metrics=None,
    quiet=True,
):
    """
    Sweep (a0,d0,a1,d1), run mondrian_conformal_calibrate + report_prediction_stats,
    and return a tidy DataFrame with hyperparams + selected metrics.

    Args:
        class_data: Per-class data forwarded unchanged to
            mondrian_conformal_calibrate.
        alpha_0, delta_0, alpha_1, delta_1: Iterables of candidate values; the
            full Cartesian product is evaluated.
        mode: Forwarded to mondrian_conformal_calibrate.
        extra_metrics: Optional mapping ``name -> fn(summary)``; each result is
            stored in column ``name``. A metric that raises is recorded as NaN
            (deliberate best-effort so one bad metric cannot abort the sweep).
        quiet: When False, print each hyperparameter combination as it runs.

    Returns:
        DataFrame with one row per combination, sorted by (a0, d0, a1, d1).
        If any grid is empty the frame has zero rows but still carries the
        expected columns.
    """
    sort_keys = ['a0', 'd0', 'a1', 'd1']
    metric_columns = ['cov', 'sing_rate',
                      'err_all', 'err_pred0', 'err_pred1',
                      'err_y0', 'err_y1',
                      'esc_rate', 'n_total', 'sing_count',
                      'm_abst', 'm_doublets']

    # robust getter: walk nested dict keys, returning `default` on any miss
    def g(d, *keys, default=None):
        cur = d
        for k in keys:
            if not isinstance(cur, dict) or k not in cur:
                return default
            cur = cur[k]
        return cur

    rows = []
    for a0, d0, a1, d1 in itertools.product(alpha_0, delta_0, alpha_1, delta_1):
        if not quiet:
            print(f"a0={a0:.3f}, d0={d0:.3f}, a1={a1:.3f}, d1={d1:.3f}")
        cal_result, pred_stats = mondrian_conformal_calibrate(
            class_data=class_data,
            alpha_target={0: float(a0), 1: float(a1)},
            delta={0: float(d0), 1: float(d1)},
            mode=mode
        )
        summary = report_prediction_stats(pred_stats, cal_result, verbose=False)

        # `or 0` also maps an explicit None in the summary to zero.
        n_total   = int(g(summary, 'marginal', 'n_total', default=0) or 0)
        cov       = float(g(summary, 'marginal', 'coverage', 'rate', default=0.0) or 0.0)
        sing_rate = float(g(summary, 'marginal', 'singletons', 'rate', default=0.0) or 0.0)
        sing_cnt  = int(g(summary, 'marginal', 'singletons', 'count', default=0) or 0)
        abst_cnt  = int(g(summary, 'marginal', 'abstentions', 'count', default=0) or 0)
        doub_cnt  = int(g(summary, 'marginal', 'doublets', 'count', default=0) or 0)
        esc_rate  = (abst_cnt + doub_cnt) / float(n_total if n_total else 1)

        err_all   = float(g(summary, 'marginal', 'singletons', 'errors', 'rate', default=0.0) or 0.0)
        err_p0    = float(g(summary, 'marginal', 'singletons', 'errors_by_pred', 'pred_0', 'rate', default=0.0) or 0.0)
        err_p1    = float(g(summary, 'marginal', 'singletons', 'errors_by_pred', 'pred_1', 'rate', default=0.0) or 0.0)

        err_y0    = float(g(summary, 0, 'singletons', 'error_given_singleton', 'rate', default=0.0) or 0.0)
        err_y1    = float(g(summary, 1, 'singletons', 'error_given_singleton', 'rate', default=0.0) or 0.0)

        row = {
            'a0': float(a0), 'd0': float(d0),
            'a1': float(a1), 'd1': float(d1),
            'cov': cov,
            'sing_rate': sing_rate,
            'err_all': err_all,
            'err_pred0': err_p0,
            'err_pred1': err_p1,
            'err_y0': err_y0,
            'err_y1': err_y1,
            'esc_rate': esc_rate,
            'n_total': int(n_total),
            'sing_count': int(sing_cnt),
            # handy extras for hover:
            'm_abst': abst_cnt, 'm_doublets': doub_cnt,
        }

        if extra_metrics:
            for name, fn in extra_metrics.items():
                try:
                    row[name] = fn(summary)
                except Exception:
                    # Best-effort by design: a failing user metric becomes NaN.
                    row[name] = np.nan

        rows.append(row)

    if not rows:
        # pd.DataFrame([]) has no columns, so sort_values would raise KeyError;
        # return an empty frame that still has the expected schema.
        columns = sort_keys + metric_columns
        columns += [name for name in (extra_metrics or {}) if name not in columns]
        return pd.DataFrame(columns=columns)

    df = pd.DataFrame(rows)
    return df.sort_values(sort_keys, kind='mergesort').reset_index(drop=True)


# ------------------------ Plotly interactive chart ------------------------

def plot_parallel_coordinates_plotly(
    df,
    columns=None,
    color='err_all',
    color_continuous_scale=None,
    title="Mondrian sweep – interactive parallel coordinates",
    height=600,
    base_opacity=0.9,        # opacity of the active / selected lines
    unselected_opacity=0.06  # fade level for unselected lines
):
    """
    Build a Plotly Express parallel-coordinates figure from a sweep DataFrame.

    Args:
        df: DataFrame whose columns become the plot axes.
        columns: Columns to use as dimensions. When None, a built-in list of
            sweep columns is used, restricted to those present in ``df``.
        color: Column used to colour the lines; ignored if not in ``df``.
        color_continuous_scale: Colour scale; falls back to
            ``px.colors.sequential.Blugrn`` when falsy.
        title: Figure title, passed to ``update_layout``.
        height: Figure height, passed to ``update_layout``.
        base_opacity: Accepted for interface compatibility; this function
            does not read it.
        unselected_opacity: Alpha component of the rgba colour assigned to
            unselected lines.

    Returns:
        The figure object returned by ``px.parallel_coordinates``.
    """
    import plotly.express as px

    if columns is None:
        preferred = [
            'a0', 'd0', 'a1', 'd1',
            'cov', 'sing_rate',
            'err_all', 'err_pred0', 'err_pred1',
            'err_y0', 'err_y1',
            'esc_rate',
        ]
        columns = [name for name in preferred if name in df.columns]

    color_column = color if color in df.columns else None
    scale = color_continuous_scale or px.colors.sequential.Blugrn
    axis_labels = {name: name for name in columns}

    fig = px.parallel_coordinates(
        df,
        dimensions=columns,
        color=color_column,
        color_continuous_scale=scale,
        labels=axis_labels,
    )

    # Fade the unselected lines so the brushed selection stands out.
    if fig.data:
        faded = f"rgba(1,1,1,{float(unselected_opacity)})"
        fig.data[0].unselected.update(line=dict(color=faded))

    layout_options = dict(
        title=title,
        height=height,
        margin=dict(l=40, r=40, t=60, b=40),
        plot_bgcolor="white",
        paper_bgcolor="white",
        font=dict(size=14),
        uirevision=True,  # keep user brushing across updates
    )
    fig.update_layout(**layout_options)

    # Larger fonts for axis labels, ranges and ticks.
    fig.update_traces(
        labelfont=dict(size=14),
        rangefont=dict(size=12),
        tickfont=dict(size=12),
    )

    # Title the colorbar with the colouring column, when one is in use.
    if color in df.columns and fig.data and getattr(fig.data[0], "line", None):
        trace_line = fig.data[0].line
        if getattr(trace_line, "colorbar", None) is not None:
            trace_line.colorbar.title = color

    return fig



# -------- convenience wrapper: run sweep + show plotly figure --------

def sweep_and_plot_parallel_plotly(
    class_data,
    delta_0, delta_1,
    alpha_0, alpha_1,
    mode="beta",
    extra_metrics=None,
    color='err_all',
    color_continuous_scale=None,
    title=None,
    height=600,
):
    """
    Run the hyperparameter sweep and plot it as parallel coordinates.

    Args:
        class_data: Forwarded to sweep_hyperparams_and_collect.
        delta_0, delta_1, alpha_0, alpha_1: Candidate value grids, forwarded
            by keyword to sweep_hyperparams_and_collect.
        mode: Forwarded to sweep_hyperparams_and_collect.
        extra_metrics: Forwarded to sweep_hyperparams_and_collect.
        color: Column used to colour the lines in the plot.
        color_continuous_scale: Colour scale forwarded to the plot function.
        title: Figure title. When None, the plot function's own default
            title is used.
        height: Figure height forwarded to the plot function.

    Returns:
        Tuple ``(df, fig)``: the sweep DataFrame and the figure.
    """
    df = sweep_hyperparams_and_collect(
        class_data=class_data,
        alpha_0=alpha_0, delta_0=delta_0,
        alpha_1=alpha_1, delta_1=delta_1,
        mode=mode,
        extra_metrics=extra_metrics,
        quiet=True
    )

    plot_kwargs = dict(
        color=color,
        color_continuous_scale=color_continuous_scale,
        height=height,
    )
    # Forwarding title=None would override the plot function's default title
    # and leave the figure untitled, so only pass a title that was supplied.
    if title is not None:
        plot_kwargs['title'] = title

    fig = plot_parallel_coordinates_plotly(df, **plot_kwargs)
    return df, fig


In [None]:
# Sweep grids: {0.05, 0.10, 0.15} on every axis (3**4 = 81 combinations).
# np.linspace is used instead of np.arange(0.05, 0.20, 0.05): with a float
# step, rounding makes arange emit a 4th point at 0.20 despite the exclusive stop.
delta_0 = np.linspace(0.05, 0.15, 3)
delta_1 = np.linspace(0.05, 0.15, 3)
alpha_0 = np.linspace(0.05, 0.15, 3)
alpha_1 = np.linspace(0.05, 0.15, 3)

# Run the full sweep over the grids above and build the interactive figure.
sweep_settings = dict(
    class_data=class_data,
    delta_0=delta_0, delta_1=delta_1,
    alpha_0=alpha_0, alpha_1=alpha_1,
    mode="beta",
    color='a1',  # try 'esc_rate' or 'cov' too
)
df, fig = sweep_and_plot_parallel_plotly(**sweep_settings)

# in a notebook:
fig.show()
