In [1]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, widgets, VBox, HBox, Layout, Output, Label

# NOTE(review): scienceplots is never referenced by name in this cell;
# presumably it is imported so the 'science' style is available to
# matplotlib — confirm.
import scienceplots
plt.style.use('science')
# Turn off LaTeX-based text rendering for figure text.
plt.rcParams['text.usetex'] = False


def softmax_regression_predict(pre_test_probabilities, weights_of_evidence):
    """Combine pre-test probabilities with log-likelihood ratios via a softmax.

    Each diagnosis gets the score ``log(prior) + weight``; the scores are
    exponentiated and normalized so the result sums to 1. The priors do not
    have to be normalized because a common scale factor cancels out.

    Args:
        pre_test_probabilities: Non-negative prior weight per diagnosis.
        weights_of_evidence: Log-likelihood ratio per diagnosis (same length).

    Returns:
        1-D ``numpy.ndarray`` with one post-test probability per diagnosis.

    Raises:
        ValueError: If the lengths differ, the input is empty, a prior is
            negative or NaN, or no prior is strictly positive.
    """
    if len(pre_test_probabilities) != len(weights_of_evidence):
        raise ValueError("Lengths of pre_test_probabilities and weights_of_evidence must match.")

    priors = np.asarray(pre_test_probabilities, dtype=float)
    weights = np.asarray(weights_of_evidence, dtype=float)

    if priors.size == 0:
        raise ValueError("At least one diagnosis is required.")
    if np.any(np.isnan(priors)) or np.any(priors < 0):
        raise ValueError("Pre-test probabilities must be non-negative numbers.")
    if not np.any(priors > 0):
        # All-zero priors would give -inf - (-inf) = NaN in the shift below.
        raise ValueError("At least one pre-test probability must be > 0.")

    # A prior of exactly 0 legitimately maps to log(0) = -inf, which becomes a
    # posterior of 0 after exponentiation; silence only that expected warning.
    with np.errstate(divide="ignore"):
        combined = np.log(priors) + weights

    # Subtract the maximum before exponentiating so exp() cannot overflow.
    exp_combined = np.exp(combined - np.max(combined))
    return exp_combined / np.sum(exp_combined)


def stacked_chart_pre_post(pre_test_probabilities, post_testing_probabilities, dx_names, weights_of_evidence, figsize=(6, 6)):
    """Plot pre- and post-test probabilities as two stacked bars.

    Every diagnosis is one segment in both bars; dashed grey lines join the
    top and bottom edges of the same segment across the two bars. Both
    probability vectors are rescaled to sum to 1 before plotting.

    Args:
        pre_test_probabilities: Pre-test weight per diagnosis.
        post_testing_probabilities: Post-test weight per diagnosis.
        dx_names: Diagnosis labels used in the legend.
        weights_of_evidence: Log-likelihood ratio per diagnosis; the legend
            shows ``exp(weight)`` as "Cumulative LR".
        figsize: Figure size forwarded to ``plt.subplots``.

    Raises:
        ValueError: If the four sequences differ in length, or either
            probability vector does not have a positive sum.
    """
    if not (len(pre_test_probabilities) == len(post_testing_probabilities) == len(dx_names) == len(weights_of_evidence)):
        raise ValueError("All input arrays must have the same length.")

    pre = np.asarray(pre_test_probabilities, dtype=float)
    post = np.asarray(post_testing_probabilities, dtype=float)
    if pre.sum() <= 0 or post.sum() <= 0:
        # Dividing by a zero sum would fill the chart with NaN heights.
        raise ValueError("Probabilities must have a positive sum.")

    # Normalize so each bar has total height 1.
    columns = [pre / pre.sum(), post / post.sum()]
    x_labels = ["Before Info", "After Info"]

    _, ax = plt.subplots(figsize=figsize)

    x = np.arange(len(columns))          # one x position per bar
    bottoms = np.zeros(len(columns))     # running stack height of each bar
    width = 0.12

    # Segment edges per diagnosis, kept for the dashed connector lines.
    segment_tops = []
    segment_bottoms = []

    for i, category in enumerate(dx_names):
        values = [column[i] for column in columns]
        bars = ax.bar(
            x, values, width=width, bottom=bottoms,
            label=f"{category} (Cumulative LR: {np.exp(weights_of_evidence[i]):.2f})",
        )
        tops, lows = [], []
        for bar, value in zip(bars, values):
            # Probability label centred inside the segment.
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_y() + bar.get_height() / 2,
                f"{value:.2f}",
                ha="center",
                va="center",
                fontsize=10,
                color="white",
            )
            tops.append(bar.get_y() + bar.get_height())
            lows.append(bar.get_y())
        segment_tops.append(tops)
        segment_bottoms.append(lows)

        bottoms += values

    # Connectors are drawn after all bars, as before, so draw order is unchanged.
    for tops, lows in zip(segment_tops, segment_bottoms):
        ax.plot(x, tops, linestyle="--", color="gray", alpha=0.5)
        ax.plot(x, lows, linestyle="--", color="gray", alpha=0.5)

    ax.set_xticks(x)
    ax.set_xticklabels(x_labels)
    ax.set_ylabel("Probability")
    ax.set_title("Before-Data vs After-Data Probabilities")

    # Pin the legend's top-centre to a point below the axes. With loc="best"
    # the corner attached to the anchor point was picked by matplotlib's
    # overlap heuristic, so the legend was not reliably centred under the chart.
    ax.legend(title="Diagnoses", loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=1)

    plt.tight_layout()
    plt.show()


def generalized_predictor_wrapper(diagnoses):
    """Normalize the priors, run the softmax update, and print a results table.

    Args:
        diagnoses: list of dicts:
            [
              { 'name': <str>, 'pretest': <float from 0 to 1>, 'loglr': <float from -5 to +5> },
              ...
            ]

    Returns:
        dict with keys ``'names'``, ``'pre_test_probs'`` (the raw, un-normalized
        inputs), ``'weights_of_evidence'`` and ``'predicted_probs'``.

    Raises:
        ValueError: If the pre-test probabilities do not have a positive sum
            (for example every slider at 0, or an empty ``diagnoses`` list).
    """
    pre_test_probs = []
    weights_of_evidence = []
    names = []

    for diag in diagnoses:
        names.append(diag['name'])
        pre_test_probs.append(diag['pretest'])
        weights_of_evidence.append(diag['loglr'])

    pre_test_sum = sum(pre_test_probs)
    if pre_test_sum <= 0:
        # Previously this surfaced as a bare ZeroDivisionError from the
        # normalization below; fail with an explanatory message instead.
        raise ValueError("Sum of pre-test probabilities must be > 0.")
    norm_pre_test_probs = [p / pre_test_sum for p in pre_test_probs]

    # The softmax is scale-invariant, so passing the normalized priors is a
    # convenience for reporting, not a requirement.
    predicted_probs = softmax_regression_predict(norm_pre_test_probs, weights_of_evidence)

    # Display the results
    print(f"{'Diagnosis':<20}{'Pre-Test Pr':<15}{'Norm. Pre-Test':<15}{'Total LR':<15}{'LogLR':<15}{'Post-Test Pr':<15}")
    print("-" * 95)
    for i, name in enumerate(names):
        print(f"{name:<20}{pre_test_probs[i]:<15.2f}{norm_pre_test_probs[i]:<15.2f}{np.exp(weights_of_evidence[i]):<15.2f}{weights_of_evidence[i]:<15.2f}{predicted_probs[i]:<15.2f}")

    return {
        'names': names,
        'pre_test_probs': pre_test_probs,
        'weights_of_evidence': weights_of_evidence,
        'predicted_probs': predicted_probs
    }


def create_diagnosis_widget():
    """Build the interactive calculator UI and return it as a single widget.

    The left column holds a "Diagnoses" count slider, one input row per
    diagnosis (name, pre-test probability, total LR) and a "Calculate"
    button; the right side is an output area that receives the printed
    table and the chart.

    Changing the count keeps the rows that already exist (and whatever the
    user entered in them); rows are only appended or dropped at the end.

    Returns:
        An ``HBox`` containing the controls and the output area.
    """
    diagnosis_count = widgets.IntSlider(
        value=1, min=1, max=10, step=1, description="Diagnoses"
    )
    container = VBox()
    chart_output = Output()

    def make_diagnosis_box(index):
        """Create the Name / Pre-Test / Total LR inputs for one diagnosis."""
        name_widget = widgets.Text(
            value=f"Diagnosis {index + 1}",
            description="Name:"
        )

        # Pre-test probability slider from 0.0 to 1.0, 2-decimal readout
        pretest_widget = widgets.FloatSlider(
            value=0.5,
            min=0.0,
            max=1.0,
            step=0.01,
            description="Pre-Test:",
            readout_format='.2f'
        )

        # Log-scaled LR slider; its value is the LR itself (starts at LR = 1)
        # and is converted back to a log-LR in gather_inputs.
        loglr_widget = widgets.FloatLogSlider(
            value=1.0,
            base=np.e,
            min=-5.0,
            max=5.0,
            step=0.1,
            description="Total LR:",
            readout=True,
            readout_format=".2f"
        )

        return VBox([name_widget, pretest_widget, loglr_widget])

    def update_widgets(change):
        """Resize the list of diagnosis rows to the slider's new value."""
        if change['name'] != 'value':
            return
        wanted = change['new']
        # Keep existing rows so user-entered values survive a resize.
        rows = list(container.children[:wanted])
        rows.extend(make_diagnosis_box(i) for i in range(len(rows), wanted))
        # Assign once instead of growing the tuple row by row.
        container.children = tuple(rows)

    diagnosis_count.observe(update_widgets, names='value')

    # Force the initial layout
    update_widgets({'name': 'value', 'new': diagnosis_count.value})

    def gather_inputs(_):
        """Read every row, run the predictor, and redraw the output area."""
        diagnoses = []
        for diagnosis_box in container.children:
            name = diagnosis_box.children[0].value
            pretest = diagnosis_box.children[1].value
            loglr = np.log(diagnosis_box.children[2].value)  # Transform LR back to LogLR

            diagnoses.append({
                'name': name,
                'pretest': pretest,
                'loglr': loglr
            })

        chart_output.clear_output()
        with chart_output:
            results = generalized_predictor_wrapper(diagnoses)
            stacked_chart_pre_post(
                results['pre_test_probs'],
                results['predicted_probs'],
                results['names'],
                results['weights_of_evidence']
            )

    calculate_button = widgets.Button(description="Calculate")
    calculate_button.on_click(gather_inputs)

    return HBox([VBox([diagnosis_count, container, calculate_button]), chart_output])


# Now create and display the revised widget
diagnosis_widget = create_diagnosis_widget()
# NOTE(review): `display` is not imported anywhere in this cell; presumably it
# is the built-in provided by the IPython/Jupyter session — confirm.
display(diagnosis_widget)

HBox(children=(VBox(children=(IntSlider(value=1, description='Diagnoses', max=10, min=1), VBox(children=(VBox(…

Here's the intuition for the problem: 

We are asking the LLM (or the literature) to estimate several one-vs-rest likelihood ratios, one per diagnosis (i.e., the likelihood of the finding if the diagnosis is present divided by its likelihood if the diagnosis is absent). 

For example, consider the case where there are 3 diagnoses under consideration: it's either A, B, or C (more precisely, let's say it's one and only one of A, B, or C). This means that the pre-test probabilities must sum to 1. 

If we then gather a new piece of information, we get 3 likelihood ratios

- A vs not A (== B or C). This is LR1
- B vs not B (== A or C). This is LR2
- C vs not C (== A or B). This is LR3 

Because all three LRs refer to the same patient and the same piece of data, they must all be *true at once*. And even after the new piece of information is incorporated, all post-test probabilities must still sum to 1. This is called coherence. 

Because we have estimated each one-vs-rest (OVR) likelihood ratio (e.g. the LR for A vs not A) independently and imperfectly, when the LRs are applied the resulting post-test probabilities may not sum to 1 - which is called incoherence. Conversely, if we apply all the OVR likelihood ratios and the resulting post-test probabilities do sum to 1, the LRs are coherent. 

The naive way - and what I initially did - was to apply all the one-vs-rest likelihood ratios to update each pre-test probability, then re-scale the resulting post-test probabilities so that they sum to 1. 

This is problematic because it breaks the Bayes odds–LR identity and therefore distorts the way evidence updates the probabilities. You can see this by considering: 

When an LR of 2 is applied to a situation where the pre-test probability is 50% (1:1 odds), the post-test probability is 66.7% (2:1 odds) - a difference of about 17 percentage points. When an LR of 2 is applied to a situation where the pre-test probability is 90% (9:1 odds), the post-test probability is 94.7% (18:1 odds) - a difference of less than 5 percentage points. 

So, normalizing in probability space (rather than log-odds space) implies that you actually updated with a different LR that differs based on the probability of the diagnosis. The naive probability normalization dilutes strong (probable) diagnoses and inflates weak (improbable) ones.

Consider: 
- Priors (A- 0.50, B- 0.30, C- 0.20); 
- LLM estimated OVR LRs (4, 3, 2).
- Binary posteriors (slices): q=(0.800, 0.563, 0.333); sum s=1.696 == non-coherent
- Renormalize in probability space => P=(0.472, 0.332, 0.197).
- Implied OVR LRs from the bayesian update P: (0.893, 1.158, 0.979) — nowhere near (4, 3, 2)! 

That’s the odds–LR mismatch we're trying to solve

So, an alternative way to solve this 'coupling problem' would be to rescale the estimated LRs in log-odds space - dividing or multiplying them all by some factor such that the set of OVR LRs is consistent (== results in post-test probabilities that sum to 1). This would work (it's what's called an intercept‑only calibration; it fixes a global bias but can’t correct class‑specific errors), but an even better approach is to consider that the existing LRs are some sort of best guess that may follow a normal distribution (i.e. result from random additive factors pushing the estimate erroneously in either direction), and that we can fit a set of LRs to the guesses that minimizes the squared error. 

Retinkered, with Bayes-coherent OVR LR Solver: 
- Priors (A- 0.50, B- 0.30, C- 0.20); 
- LLM estimated OVR LRs (4, 3, 2).
- Renormalized OVR LRs (1.08, .99, .905)
- Resulting Post-test probabilities => P=(0.52, 0.3, 0.18).

In essence, we are fitting class “scores” that define a single multiclass model but are estimated noisily; the fitted model is the best one (smallest squared error) that implies a coherent set of OVR LRs (i.e. a posterior that sums to 1).

That's what the new algorithm does (with an additional L2 regularization step to ensure that extreme values don't make the problem too hard/unstable to numerically solve)


In [10]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, widgets, VBox, HBox, Layout, Output, Label
from dataclasses import dataclass
from typing import Optional, Dict, Any
from scipy.optimize import least_squares
from scipy.special import expit, logsumexp

# Optional styling
# NOTE(review): "optional" describes intent only — the import below is
# unconditional, so this cell raises ImportError when scienceplots is missing.
import scienceplots
plt.style.use('science')
# Turn off LaTeX-based text rendering for figure text.
plt.rcParams['text.usetex'] = False

# ============================================================
#  Bayes‑coherent OVR solver: concepts and data flow
#  -----------------------------------------------------------
#  • priors π_k (pre‑test probabilities) – sum to 1
#  • OVR LRs λ_k  (diagnosis k vs "not k") – possibly inconsistent
#  • We seek class log‑*scores* s_k (roughly, log likelihoods) so that:
#       model_logLR_k = s_k - log(∑_{j≠k} π_j e^{s_j}) + log(1-π_k)
#    matches the input log λ_k in least‑squares sense.
#  • Then posterior P_k ∝ π_k e^{s_k} is Bayes‑coherent by construction.
#  • L2 regularization on s controls numeric stability / overfitting.
# ============================================================

@dataclass
class OVRProjectionResult:
    """Result bundle returned by ``project_ovr_to_coherent``."""
    posterior: np.ndarray                 # Coherent posterior P_k = softmax(log π_k + s_k)
    log_scores: np.ndarray                # s_k with one baseline fixed to 0 (identifiability)
    lr_vs_baseline: np.ndarray            # Common-reference LRs: LR_{k:baseline} = exp(s_k)
    ovr_lr_fitted: np.ndarray             # Coherent OVR LRs implied by the fitted s
    rmse: float                           # RMS of the (weighted) log-LR residuals; ridge rows excluded
    success: bool                         # Optimizer convergence flag
    message: str                          # Optimizer termination message
    nfev: int                             # Number of residual evaluations
    diagnostics: Dict[str, Any]           # Helpful extras: baseline index, sum_q_init, etc.


def project_ovr_to_coherent(
    priors: np.ndarray,
    ovr_lr: np.ndarray,
    weights: Optional[np.ndarray] = None,
    baseline: Optional[int] = None,
    reg: float = 0.0,
    max_nfev: int = 200,
    ftol: float = 1e-10,
    xtol: float = 1e-10,
    gtol: float = 1e-10,
) -> OVRProjectionResult:
    """
    Project possibly-inconsistent one-vs-rest LRs (λ_k) to a Bayes-coherent
    multiclass model by estimating log-scores s_k that minimize squared errors
    between input log-OVR LRs and model-implied log-OVR LRs with the true
    mixture denominator. Then compute posterior P_k ∝ π_k * exp(s_k).

    Interpretation:
      - s_k acts like log L_k (class-conditional likelihood) up to a global offset.
      - The 'rest' likelihood in OVR is the PRIOR-WEIGHTED mixture of other classes.
      - Regularization 'reg' is ridge on s (Gaussian prior, stabilizer).

    Args:
        priors: Pre-test probabilities π_k; clipped away from 0/1 and
            renormalized to sum to 1.
        ovr_lr: One-vs-rest likelihood ratios λ_k, strictly positive and finite.
        weights: Optional per-class least-squares weights (default: all 1).
        baseline: Index of the class whose log-score is pinned to 0. Defaults
            to the class with the largest prior. Negative indices count from
            the end.
        reg: Ridge strength on the free log-scores; applied only when > 0.
        max_nfev, ftol, xtol, gtol: Forwarded to ``scipy.optimize.least_squares``.

    Returns:
        OVRProjectionResult. ``rmse`` is computed from the K data residuals
        only, so it does not depend on whether ridge rows were appended.
        For a single class the posterior is 1 whatever the LR, so the input
        LR is returned unchanged and no optimization is run.

    Raises:
        ValueError: On empty input, length mismatches, or LRs that are not
            finite and strictly positive.
        IndexError: If ``baseline`` is outside ``[-K, K)``.
    """
    # ---- Input sanitation / normalization ------------------------------------
    pi = np.asarray(priors, dtype=float).copy()
    lam = np.asarray(ovr_lr, dtype=float).copy()
    K = pi.size
    if K == 0:
        raise ValueError("priors must contain at least one class.")
    if lam.size != K:
        raise ValueError("ovr_lr and priors must have same length K.")
    if not np.all(np.isfinite(lam)):
        # inf/NaN would make the initial residuals non-finite inside the solver.
        raise ValueError("All OVR LRs must be finite.")
    if np.any(lam <= 0):
        raise ValueError("All OVR LRs must be > 0 (strictly positive).")

    eps = 1e-12
    pi = np.clip(pi, eps, 1 - eps)        # avoid 0/1 which break logs and logits
    pi = pi / pi.sum()                    # ensure priors sum to 1

    # Fix one class as the log-score baseline to resolve the additive constant
    if baseline is None:
        baseline = int(np.argmax(pi))     # sensible default: most probable prior
    else:
        baseline = int(baseline)
        if not -K <= baseline < K:
            raise IndexError("baseline index out of range.")
        # Normalize negative indices; otherwise no class would be excluded
        # from the free parameters and the fit would lose identifiability.
        baseline %= K

    # Optional per-class weights for the least-squares fit (confidence on λ_k)
    if weights is None:
        w = np.ones(K, dtype=float)
    else:
        w = np.asarray(weights, dtype=float)
        if w.size != K:
            raise ValueError("weights must have length K.")
        w = np.clip(w, eps, None)

    # ---- Degenerate single-class case ----------------------------------------
    # With one class there is no "rest": the posterior is 1 for any LR, and the
    # general path would hit log(1 - π) = -inf and fail in the optimizer.
    if K == 1:
        return OVRProjectionResult(
            posterior=np.ones(1),
            log_scores=np.zeros(1),
            lr_vs_baseline=np.ones(1),
            ovr_lr_fitted=lam.copy(),
            rmse=0.0,
            success=True,
            message="Single class: posterior is 1 for any LR; no fit performed.",
            nfev=0,
            diagnostics={
                "K": K,
                "baseline": baseline,
                "sum_q_init": 1.0,
                "priors": pi,
                "input_ovr_lr": lam,
            },
        )

    # Precompute log priors and helper transforms
    log_pi = np.log(pi)
    log_one_minus_pi = np.log(1.0 - pi)
    log_lambda = np.log(lam)
    logit_pi = log_pi - log_one_minus_pi
    sqrt_w = np.sqrt(w)

    # ---- Initialization -------------------------------------------------------
    # q_init = binary posteriors for each "k vs not-k" (the coherence slices).
    # If inputs are perfectly coherent, ∑ q_init == 1 and s0_full already solves it.
    q_init = expit(logit_pi + log_lambda)

    # From P_k ∝ π_k e^{s_k} ⇒ s_k = log P_k - log π_k (up to a constant). Use q_init.
    s0_full = np.log(q_init) - log_pi       # defined up to an additive constant
    s0_full -= s0_full[baseline]            # fix baseline's s to 0 for identifiability
    keep = np.array([i for i in range(K) if i != baseline], dtype=int)
    theta0 = s0_full[keep]                  # optimizer coordinates (exclude baseline)

    def unpack(theta):
        """Insert the optimized coordinates back into the full s vector (baseline fixed)."""
        s = np.empty(K, dtype=float)
        s[baseline] = 0.0
        s[keep] = theta
        return s

    def model_log_ovr_lr(s):
        """Return (model log-OVR-LR per class, rescaled π_k e^{s_k}, rescaled rest sums)."""
        # Unnormalized log "joint" for class k: z_k = log π_k + s_k
        z = log_pi + s
        # log-sum-exp trick: work in a numerically stable frame
        m = np.max(z)
        ez = np.exp(z - m)                          # ∝ π_k e^{s_k}, rescaled
        rest = np.maximum(np.sum(ez) - ez, eps)     # ∑_{j≠k} π_j e^{s_j} (rescaled, guarded)
        log_rest = np.log(rest) + m                 # undo rescaling to get true log sum
        return s - log_rest + log_one_minus_pi, ez, rest

    # ---- Residuals and analytic Jacobian (Gauss–Newton-friendly) -------------
    # Residual for class k:
    #   r_k = sqrt(w_k) * [ log λ_k - ( s_k - log ∑_{j≠k} π_j e^{s_j} + log(1-π_k) ) ]
    # where the middle term is the model's OVR log-LR for k vs rest-mixture.
    def residuals(theta):
        model_loglr, _, _ = model_log_ovr_lr(unpack(theta))
        r = sqrt_w * (log_lambda - model_loglr)
        # Optional L2 (ridge) regularization: append √reg * theta as extra residuals.
        if reg > 0.0:
            r = np.concatenate([r, np.sqrt(reg) * theta])
        return r

    def jacobian(theta):
        # For the full s (including baseline) the partials are:
        #   ∂r_k/∂s_k = -sqrt(w_k)
        #   ∂r_k/∂s_j (j ≠ k) = sqrt(w_k) * (π_j e^{s_j}) / (∑_{ℓ≠k} π_ℓ e^{s_ℓ})
        _, ez, rest = model_log_ovr_lr(unpack(theta))
        J = sqrt_w[:, None] * (ez[None, :] / rest[:, None])
        np.fill_diagonal(J, -sqrt_w)
        # Drop the baseline column (its s is fixed)
        J = J[:, keep]
        # Ridge rows contribute √reg * I (Tikhonov → better conditioning).
        if reg > 0.0:
            J = np.vstack([J, np.sqrt(reg) * np.eye(theta.size)])
        return J

    # Trust Region Reflective least-squares with analytic Jacobian
    res = least_squares(residuals,
                        theta0,
                        jac=jacobian,
                        method="trf",
                        max_nfev=max_nfev, ftol=ftol, xtol=xtol, gtol=gtol)

    # ---- Recover coherent posterior and fitted OVR LRs ------------------------
    s_hat = unpack(res.x)                   # fitted class log-scores (baseline = 0)
    z_hat = log_pi + s_hat                  # log unnormalized posteriors
    posterior = np.exp(z_hat - logsumexp(z_hat))  # coherent P_k

    # Fitted OVR LRs implied by the model, computed with the same stabilized
    # formula the optimizer used (avoids overflow for large scores).
    model_loglr_hat, _, _ = model_log_ovr_lr(s_hat)
    ovr_lr_hat = np.exp(model_loglr_hat)

    # res.fun also holds the ridge rows when reg > 0; only the first K entries
    # are data residuals, so restrict the RMSE to those.
    data_residuals = res.fun[:K]

    out = OVRProjectionResult(
        posterior=posterior,
        log_scores=s_hat,
        lr_vs_baseline=np.exp(s_hat),     # LR_{k:baseline}; baseline s=0 ⇒ exp(s_k)
        ovr_lr_fitted=ovr_lr_hat,
        rmse=float(np.sqrt(np.mean(data_residuals**2))),
        success=bool(res.success),
        message=res.message,
        nfev=int(res.nfev),
        diagnostics={
            "K": K,
            "baseline": baseline,
            # Coherence diagnostic: if inputs were coherent, this sum would be 1.
            "sum_q_init": float(np.sum(q_init)),
            "priors": pi,
            "input_ovr_lr": lam,
        },
    )
    return out



# ============================================================
#  Visualization: show pre vs post, annotate Input vs Fitted
# ============================================================
def stacked_chart_pre_post(pre_test_probabilities, post_testing_probabilities, dx_names,
                           input_loglr, fitted_ovr_lr=None, figsize=(6, 6)):
    """
    Visualize how the posterior reallocates probability mass.

    Two stacked bars ("Before Info" / "After Info") are drawn with one segment
    per diagnosis; dashed grey lines join matching segment edges. Both
    probability vectors are rescaled to sum to 1 before plotting.

    Legend shows, per diagnosis:
        "Input OVR LR: <..> | Fitted OVR LR: <..>"
    • input_loglr: log-LRs as supplied (possibly inconsistent); shown as exp()
    • fitted_ovr_lr: coherent LRs implied by the projected model (optional;
      the "Fitted" part is omitted when this is None)

    Raises:
        ValueError: If the per-diagnosis sequences differ in length, or either
            probability vector does not have a positive sum.
    """
    if not (len(pre_test_probabilities) == len(post_testing_probabilities) == len(dx_names) == len(input_loglr)):
        raise ValueError("All input arrays must have the same length.")

    pre = np.array(pre_test_probabilities, dtype=float)
    post = np.array(post_testing_probabilities, dtype=float)
    if pre.sum() <= 0 or post.sum() <= 0:
        # Dividing by a zero sum would fill the chart with NaN heights.
        raise ValueError("Probabilities must have a positive sum.")
    pre = pre / pre.sum()
    post = post / post.sum()

    columns = [pre, post]
    x_labels = ["Before Info", "After Info"]

    _, ax = plt.subplots(figsize=figsize)
    x = np.arange(len(columns))          # one x position per bar
    bottoms = np.zeros(len(columns))     # running stack height of each bar
    width = 0.12

    # Segment edges per diagnosis, kept to draw dotted connectors (visual funnel)
    segment_tops = []
    segment_bottoms = []

    for i, category in enumerate(dx_names):
        values = [column[i] for column in columns]
        # Annotate Input vs Fitted OVR LRs side-by-side
        lr_text = f"Input OVR LR: {np.exp(input_loglr[i]):.2f}"
        if fitted_ovr_lr is not None:
            lr_text += f" | Fitted OVR LR: {fitted_ovr_lr[i]:.2f}"
        bars = ax.bar(x, values, width=width, bottom=bottoms,
                      label=f"{category} ({lr_text})")
        tops, lows = [], []
        for bar, value in zip(bars, values):
            # Probability label centred inside the segment.
            ax.text(bar.get_x() + bar.get_width()/2,
                    bar.get_y() + bar.get_height()/2,
                    f"{value:.2f}",
                    ha="center", va="center", fontsize=10, color="white")
            tops.append(bar.get_y() + bar.get_height())
            lows.append(bar.get_y())
        segment_tops.append(tops)
        segment_bottoms.append(lows)
        bottoms += values

    # Dotted lines: where each class "starts" and "ends"
    for tops, lows in zip(segment_tops, segment_bottoms):
        ax.plot(x, tops, linestyle="--", color="gray", alpha=0.5)
        ax.plot(x, lows, linestyle="--", color="gray", alpha=0.5)

    ax.set_xticks(x)
    ax.set_xticklabels(x_labels)
    ax.set_ylabel("Probability")
    ax.set_title("Before-Data vs After-Data Probabilities")
    # Pin the legend's top-centre to a point below the axes. With loc="best"
    # the corner attached to the anchor point was picked by matplotlib's
    # overlap heuristic, so the legend was not reliably centred under the chart.
    ax.legend(title="Diagnoses", loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=1)
    plt.tight_layout()
    plt.show()


# ============================================================
#  Wrapper: read inputs, run either OVR projection (recommended)
#  or "common-reference softmax" (assumes shared denominator).
# ============================================================
def generalized_predictor_wrapper(diagnoses, method="ovr_projection", reg=1e-6):
    """
    Turn UI-style diagnosis records into post-test probabilities and print a table.

    diagnoses: list of dicts:
      [{ 'name': <str>, 'pretest': float in [0,1], 'loglr': float (log OVR LR) }, ...]

    method:
      - "ovr_projection": Bayes‑coherent OVR projection (mixture-aware least squares);
                          `reg` is passed through to the solver
      - "softmax_scores": treat 'loglr' as log-likelihoods vs a single common reference
                          (i.e., already coherent up to a shared constant)

    Returns a dict with keys 'names', 'pre_test_probs' (raw inputs as an array),
    'weights_of_evidence' (the input log-LRs), 'predicted_probs', 'fitted_ovr_lr'
    and 'fitted_loglr' (the last two are None for "softmax_scores").

    Raises ValueError when the pre-test probabilities do not have a positive
    sum or when `method` is not one of the two names above.
    """
    # Pull the three fields out of every record in one pass.
    records = [(entry['name'], entry['pretest'], entry['loglr']) for entry in diagnoses]
    names = [record[0] for record in records]

    # Raw priors are kept for reporting; the normalized copy feeds the model.
    pre = np.array([record[1] for record in records], dtype=float)
    if pre.sum() <= 0:
        raise ValueError("Sum of pre-test probabilities must be > 0.")
    priors = pre / pre.sum()
    input_loglr = np.array([record[2] for record in records], dtype=float)

    if method not in ("ovr_projection", "softmax_scores"):
        raise ValueError("Unknown method")

    # Both tables start with the same three columns.
    lead_columns = f"{'Diagnosis':<20}{'Pre-Test':<12}{'Norm Pre':<12}"

    if method == "softmax_scores":
        # Scores are taken as common-reference log-likelihoods, so Bayes is
        # just softmax(log π + loglr), shifted by the max for stability.
        logits = np.log(priors) + input_loglr
        unnormalized = np.exp(logits - np.max(logits))
        post = unnormalized / np.sum(unnormalized)
        fitted_ovr = None
        fitted_loglr = None

        header = lead_columns + f"{'Input LR':<12}{'LogLR':<12}{'Post-Test':<12}"
        print(header)
        print("-" * len(header))
        for name, raw, prior, loglr, p_post in zip(names, pre, priors, input_loglr, post):
            print(f"{name:<20}{raw:<12.3f}{prior:<12.3f}"
                  f"{np.exp(loglr):<12.3f}{loglr:<12.3f}{p_post:<12.3f}")
    else:
        # Core solver: project inconsistent OVR LRs to a single coherent model.
        res = project_ovr_to_coherent(priors=priors, ovr_lr=np.exp(input_loglr), reg=reg)
        post = res.posterior
        fitted_ovr = res.ovr_lr_fitted
        fitted_loglr = np.log(fitted_ovr)

        # Optimizer status plus the coherence check on the initial binary slices.
        print(f"OVR projection success: {res.success} (nfev={res.nfev}) | rmse(log-LR)={res.rmse:.4g}")
        print(f"Baseline class index: {res.diagnostics['baseline']} | sum q_init: {res.diagnostics['sum_q_init']:.6f}")

        # Table: inputs next to the coherent values they were adjusted to.
        header = lead_columns + f"{'Input OVR LR':<15}{'Fitted OVR LR':<16}{'Fitted LogLR':<15}{'Post-Test':<12}"
        print(header)
        print("-" * len(header))
        for name, raw, prior, loglr, fit_lr, fit_loglr, p_post in zip(
                names, pre, priors, input_loglr, fitted_ovr, fitted_loglr, post):
            print(f"{name:<20}{raw:<12.3f}{prior:<12.3f}"
                  f"{np.exp(loglr):<15.3f}{fit_lr:<16.3f}{fit_loglr:<15.3f}{p_post:<12.3f}")

    return {
        'names': names,
        'pre_test_probs': pre,
        'weights_of_evidence': input_loglr,  # original inputs, kept for plotting context
        'predicted_probs': post,
        'fitted_ovr_lr': fitted_ovr,
        'fitted_loglr': fitted_loglr
    }
    

# ============================================================
#  Notebook widget: quick UI to experiment with priors and LRs
#  • "OVR coherent projection" (recommended) runs the solver
#  • "Common-reference softmax" assumes scores already share a
#    single denominator (no OVR mixture coupling enforced)
# ============================================================
def create_diagnosis_widget():
    """Build the interactive calculator UI and return it as a single widget.

    The left column holds the "Diagnoses" count slider, the method dropdown,
    the ridge-strength slider, one input row per diagnosis (name, pre-test
    probability, OVR LR) and a "Calculate" button; the right side is an
    output area that receives the printed table and the chart.

    Changing the count keeps the rows that already exist (and whatever the
    user entered in them); rows are only appended or dropped at the end.

    Returns:
        An ``HBox`` containing the controls and the output area.
    """
    diagnosis_count = widgets.IntSlider(value=3, min=1, max=12, step=1, description="Diagnoses")
    method_dd = widgets.Dropdown(options=[('OVR coherent projection','ovr_projection'),
                                          ('Common-reference softmax','softmax_scores')],
                                 value='ovr_projection', description='Method')
    # Ridge strength slider (log-scale). ↑ reg = more shrinkage / stability.
    reg_slider = widgets.FloatLogSlider(value=1e-6, base=10, min=-12, max=-2, step=0.5,
                                        description="Reg (L2)", readout_format=".1e")

    container = VBox()
    chart_output = Output()

    def make_row(index):
        """Create the Name / Pre-Test / OVR LR inputs for one diagnosis."""
        name_widget = widgets.Text(value=f"Diagnosis {index + 1}", description="Name:")
        pretest_widget = widgets.FloatSlider(value=0.33, min=0.0, max=1.0, step=0.01,
                                             description="Pre-Test:", readout_format='.2f')
        # Slider shows LR on a log scale; we convert to log() at readout.
        lr_widget = widgets.FloatLogSlider(value=1.0, base=np.e, min=-5.0, max=5.0, step=0.1,
                                           description="OVR LR:", readout=True, readout_format=".2f")
        return VBox([name_widget, pretest_widget, lr_widget])

    def update_widgets(change):
        """Resize the list of diagnosis rows to the slider's new value."""
        if change['name'] != 'value':
            return
        wanted = change['new']
        # Keep existing rows so user-entered values survive a resize.
        rows = list(container.children[:wanted])
        rows.extend(make_row(i) for i in range(len(rows), wanted))
        container.children = tuple(rows)

    diagnosis_count.observe(update_widgets, names='value')
    # Build the initial rows without waiting for a slider event.
    update_widgets({'name': 'value', 'new': diagnosis_count.value})

    def gather_inputs(_):
        """Read every row, run the selected method, and redraw the output area."""
        # Assemble inputs from the UI
        diagnoses = []
        for box in container.children:
            name = box.children[0].value
            pretest = box.children[1].value
            loglr = np.log(box.children[2].value)  # store log OVR LR
            diagnoses.append({'name': name, 'pretest': pretest, 'loglr': loglr})

        chart_output.clear_output()
        with chart_output:
            # Run chosen method and plot pre vs post; legend shows Input vs Fitted LRs
            results = generalized_predictor_wrapper(diagnoses, method=method_dd.value, reg=reg_slider.value)
            stacked_chart_pre_post(results['pre_test_probs'],
                                   results['predicted_probs'],
                                   results['names'],
                                   results['weights_of_evidence'],
                                   fitted_ovr_lr=results['fitted_ovr_lr'])

    calculate_button = widgets.Button(description="Calculate")
    calculate_button.on_click(gather_inputs)

    left = VBox([diagnosis_count, method_dd, reg_slider, container, calculate_button])
    return HBox([left, chart_output])


# Build & display widget
diagnosis_widget = create_diagnosis_widget()
# NOTE(review): `display` is not imported anywhere in this cell; presumably it
# is the built-in provided by the IPython/Jupyter session — confirm.
display(diagnosis_widget)

HBox(children=(VBox(children=(IntSlider(value=3, description='Diagnoses', max=12, min=1), Dropdown(description…