In [1]:
import matplotlib.pyplot as plt
import scienceplots
import numpy as np
from ipywidgets import interact, widgets, VBox, HBox, Layout, Output

# Select the 'science' Matplotlib style (presumably registered by the
# `scienceplots` import above — confirm against that package).
plt.style.use('science')
# Keep text rendering on Matplotlib's built-in engine instead of external LaTeX.
plt.rcParams['text.usetex'] = False

# Lookup table from a qualitative pre-test label to a numeric probability.
# The values equal the logistic function evaluated at -2, -1, 0, +1, +2,
# rounded to three decimals, so adjacent labels are one log-odds unit apart.
certainty_estimates = {
    "HIGHLY_UNLIKELY": 0.119,
    "UNLIKELY": 0.269,
    "UNCERTAIN": 0.5,
    "LIKELY": 0.731,
    "HIGHLY_LIKELY": 0.881
}

def softmax_regression_predict(pre_test_probabilities, weights_of_evidence):
    """
    Predict probabilities using a multinomial logistic regression model (softmax regression).

    Each class score is ``log(pre_test_probability) + weight_of_evidence``; the
    scores are then passed through a softmax. Because softmax is invariant to a
    common additive shift, the pre-test probabilities do not need to sum to 1.

    Parameters:
        pre_test_probabilities (sequence of float): Pre-test probabilities, each in [0, 1].
            A value of exactly 0 is allowed and yields a post-test probability of 0.
        weights_of_evidence (sequence of float): Linear predictors (log-odds contributions) for each class.

    Returns:
        numpy.ndarray: Predicted probabilities for each class (summing to 1).

    Raises:
        ValueError: If the input lengths differ, if any pre-test probability is
            outside [0, 1] (NaN included), or if no pre-test probability is
            positive (this also covers empty input).
    """
    if len(pre_test_probabilities) != len(weights_of_evidence):
        raise ValueError("The lengths of pre_test_probabilities and weights_of_evidence must be the same (equal to the number of diagnoses).")

    priors = np.asarray(pre_test_probabilities, dtype=float)
    evidence = np.asarray(weights_of_evidence, dtype=float)

    # NaN fails both comparisons, so it is rejected here as well.
    if not np.all((priors >= 0) & (priors <= 1)):
        raise ValueError("All pre_test_probabilities must be between 0 and 1.")

    # With every prior at zero (or no classes at all) the softmax is undefined.
    if not np.any(priors > 0):
        raise ValueError("At least one pre_test_probability must be greater than 0.")

    # log(0) = -inf is intentional: such a class ends up with probability 0.
    with np.errstate(divide="ignore"):
        combined = np.log(priors) + evidence

    # Subtract the maximum before exponentiating to avoid overflow.
    exp_combined = np.exp(combined - np.max(combined))
    return exp_combined / np.sum(exp_combined)


def stacked_chart_pre_post(pre_test_probabilities, post_testing_probabilities, dx_names, weights_of_evidence, figsize=(6, 6)):
    """
    Draw two stacked bars ("Before Info" / "After Info") and connect matching segments.

    Both probability vectors are rescaled to sum to 1 before plotting. Each
    diagnosis becomes one coloured segment in both bars; dashed grey lines join
    the top and bottom edges of the same segment across the two bars.

    Parameters:
        pre_test_probabilities (sequence of float): Probabilities before the evidence.
        post_testing_probabilities (sequence of float): Probabilities after the evidence.
        dx_names (sequence of str): Diagnosis names used in the legend.
        weights_of_evidence (sequence of float): Per-diagnosis value shown as "LogLR" in the legend.
        figsize (tuple): Figure size passed to ``plt.subplots``.

    Raises:
        ValueError: If the four sequences differ in length, or if either
            probability vector does not have a positive, finite sum.
    """
    if not (len(pre_test_probabilities) == len(post_testing_probabilities) == len(dx_names) == len(weights_of_evidence)):
        raise ValueError("All input arrays must have the same length.")

    pre = np.asarray(pre_test_probabilities, dtype=float)
    post = np.asarray(post_testing_probabilities, dtype=float)

    # Refuse to normalise by a zero/NaN/inf total: that would plot NaN bars.
    pre_total = pre.sum()
    post_total = post.sum()
    if not (np.isfinite(pre_total) and pre_total > 0 and np.isfinite(post_total) and post_total > 0):
        raise ValueError("Probabilities must have a positive, finite sum.")
    pre = pre / pre_total
    post = post / post_total

    x_labels = ["Before Info", "After Info"]
    x = np.arange(len(x_labels))  # Positions for "Before" and "After"
    width = 0.12

    fig, ax = plt.subplots(figsize=figsize)

    bottoms = np.zeros(len(x_labels))  # Running stack height of each bar
    segment_edges = []                 # (bottom, top) pair of arrays per diagnosis

    for i, category in enumerate(dx_names):
        heights = np.array([pre[i], post[i]])
        bars = ax.bar(x, heights, width=width, bottom=bottoms, label=f"{category} (LogLR: {weights_of_evidence[i]:.2f})")
        for bar, value in zip(bars, heights):
            # Print the probability at the centre of its segment.
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_y() + bar.get_height() / 2,
                f"{value:.2f}",
                ha="center",
                va="center",
                fontsize=10,
                color="white"
            )
        tops = bottoms + heights
        segment_edges.append((bottoms, tops))
        bottoms = tops  # New array each time, so stored edges are never mutated

    # Dashed connectors between the same segment in the two bars.
    for lower, upper in segment_edges:
        ax.plot(x, upper, linestyle="--", color="gray", alpha=0.5)
        ax.plot(x, lower, linestyle="--", color="gray", alpha=0.5)

    ax.set_xticks(x)
    ax.set_xticklabels(x_labels)
    ax.set_ylabel("Probability")
    ax.set_title("Before-Data vs After-Data Probabilities")

    # Anchor the legend's top-centre at a point below the axes. With
    # loc="best" the corner attached to the anchor point is chosen
    # automatically, so the legend could end up overlapping the chart.
    ax.legend(title="Diagnoses", loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=1)

    plt.tight_layout()
    plt.show()


# Generalized wrapper function for n diagnoses
def generalized_predictor_wrapper(diagnoses):
    """
    Compute post-test probabilities for every diagnosis and print a summary table.

    For each entry the qualitative 'pretest' label is looked up in
    ``certainty_estimates`` and the four finding counts are combined into one
    weight of evidence:
    ``major_positive - major_negative/2 + minor_positive/3 - minor_negative/6``.
    The raw (un-normalised) pre-test values and the weights are then handed to
    ``softmax_regression_predict``.

    Parameters:
        diagnoses (list): List of dictionaries with keys 'name', 'pretest',
            'major_positive', 'major_negative', 'minor_positive', 'minor_negative'.

    Returns:
        dict: Keys 'names', 'pre_test_probs', 'norm_pre_test_probs',
        'weights_of_evidence' and 'predicted_probs'.
    """
    names = []
    pre_test_probs = []
    weights_of_evidence = []

    for entry in diagnoses:
        label = entry['name']
        prior = certainty_estimates[entry['pretest']]
        maj_pos = entry['major_positive']
        maj_neg = entry['major_negative']
        min_pos = entry['minor_positive']
        min_neg = entry['minor_negative']

        # Positive findings add evidence, negative findings subtract it;
        # minor findings carry less weight than major ones.
        evidence = maj_pos - (maj_neg / 2) + (min_pos / 3) - (min_neg / 6)

        pre_test_probs.append(prior)
        weights_of_evidence.append(evidence)
        names.append(label)

    # Rescale the raw priors so they sum to 1 (reported alongside the raw values).
    total_prior = sum(pre_test_probs)
    norm_pre_test_probs = [prior / total_prior for prior in pre_test_probs]

    predicted_probs = softmax_regression_predict(pre_test_probs, weights_of_evidence)

    # Plain-text table: one row per diagnosis.
    print(f"{'Diagnosis':<20}{'Pre-Test Prob':<15}{'Norm Pre-Test Prob':<20}{'LR':<15}{'LogLR':<15}{'Post-Test Prob':<15}")
    print("-" * 95)
    rows = zip(names, pre_test_probs, norm_pre_test_probs, weights_of_evidence, predicted_probs)
    for label, raw, normed, evidence, posterior in rows:
        print(f"{label:<20}{raw:<15.2f}{normed:<20.2f}{np.exp(evidence):<15.2f}{evidence:<15.2f}{posterior:<15.2f}")

    return {
        'names': names,
        'pre_test_probs': pre_test_probs,
        'norm_pre_test_probs': norm_pre_test_probs,
        'weights_of_evidence': weights_of_evidence,
        'predicted_probs': predicted_probs
    }


# Dynamic widget for n diagnoses
def create_diagnosis_widget():
    """
    Build the interactive form for entering diagnoses and plotting the result.

    The form holds a slider for the number of diagnoses, one group of inputs
    per diagnosis (name, pre-test label, and four finding counts), and a
    "Calculate" button. Pressing the button runs
    ``generalized_predictor_wrapper`` on the current inputs and draws the
    stacked chart into an output area next to the form.

    Returns:
        HBox: The input column on the left and the chart output on the right.
    """
    diagnosis_count = widgets.IntSlider(value=1, min=1, max=10, step=1, description="Diagnoses")
    container = VBox()
    chart_output = Output()  # Receives the printed table and the chart

    def update_widgets(change):
        """Rebuild the per-diagnosis input groups to match the slider value."""
        if change['name'] != 'value':  # Only react to the slider's value
            return

        boxes = []
        for i in range(change['new']):
            name_widget = widgets.Text(value=f"Diagnosis {i+1}", description="Name:")
            pretest_widget = widgets.Dropdown(
                options=['HIGHLY_UNLIKELY', 'UNLIKELY', 'UNCERTAIN', 'LIKELY', 'HIGHLY_LIKELY'],
                value='UNCERTAIN',
                description="Pre-Test:"
            )
            major_pos_widget = widgets.IntSlider(value=0, min=0, max=10, step=1, description="Major+")
            major_neg_widget = widgets.IntSlider(value=0, min=0, max=10, step=1, description="Major-")
            minor_pos_widget = widgets.IntSlider(value=0, min=0, max=10, step=1, description="Minor+")
            minor_neg_widget = widgets.IntSlider(value=0, min=0, max=10, step=1, description="Minor-")
            boxes.append(VBox([name_widget, pretest_widget, major_pos_widget, major_neg_widget, minor_pos_widget, minor_neg_widget]))

        # Assign once: every assignment to `children` is a separate trait
        # change, so growing the tuple one box at a time repeats that work.
        container.children = tuple(boxes)

    diagnosis_count.observe(update_widgets, names='value')

    # Populate the form for the slider's initial value.
    update_widgets({'name': 'value', 'new': diagnosis_count.value})

    def gather_inputs(_=None):
        """Read the form, then print the table and draw the chart."""
        diagnoses = []
        for diagnosis_box in container.children:
            # Child order matches the VBox built in update_widgets.
            name, pretest, major_pos, major_neg, minor_pos, minor_neg = diagnosis_box.children
            diagnoses.append({
                'name': name.value,
                'pretest': pretest.value,
                'major_positive': major_pos.value,
                'major_negative': major_neg.value,
                'minor_positive': minor_pos.value,
                'minor_negative': minor_neg.value
            })

        # Drop the previous table/chart before producing the new one.
        chart_output.clear_output()

        with chart_output:
            results = generalized_predictor_wrapper(diagnoses)
            stacked_chart_pre_post(
                results['pre_test_probs'],
                results['predicted_probs'],
                results['names'],
                results['weights_of_evidence']
            )

    calculate_button = widgets.Button(description="Calculate")
    calculate_button.on_click(gather_inputs)

    return HBox([VBox([diagnosis_count, container, calculate_button]), chart_output])


# Build the interactive form and show it as this cell's output.
diagnosis_widget = create_diagnosis_widget()
# NOTE(review): `display` is not imported in this cell; presumably the name
# injected by IPython/Jupyter — confirm before running as a plain script.
display(diagnosis_widget)

HBox(children=(VBox(children=(IntSlider(value=1, description='Diagnoses', max=10, min=1), VBox(children=(VBox(…

In [2]:
import matplotlib.pyplot as plt
import scienceplots
import numpy as np
from ipywidgets import interact, widgets, VBox, HBox, Layout, Output
from dataclasses import dataclass
from typing import Optional, Dict, Any
from scipy.optimize import least_squares
from scipy.special import expit, logsumexp

# Select the 'science' Matplotlib style (presumably registered by the
# `scienceplots` import above — confirm against that package).
plt.style.use('science')
# Keep text rendering on Matplotlib's built-in engine instead of external LaTeX.
plt.rcParams['text.usetex'] = False

# --------------------------------------------------------------------
# Qualitative pretest → numeric prior (edit these if you recalibrate).
# The values equal the logistic function at -2, -1, 0, +1, +2, rounded
# to three decimals, so adjacent labels are one log-odds unit apart.
# --------------------------------------------------------------------
certainty_estimates = {
    "HIGHLY_UNLIKELY": 0.119,
    "UNLIKELY":        0.269,
    "UNCERTAIN":       0.500,
    "LIKELY":          0.731,
    "HIGHLY_LIKELY":   0.881
}

# ====================================================================
# Bayes‑coherent OVR projection (mixture‑aware coupling)
#   Inputs:
#     priors π_k  (sum to 1)
#     ovr_lr λ_k  (LR for k vs "not k")  -- here λ_k = exp(weight_of_evidence_k)
#   Output:
#     coherent posterior P_k ∝ π_k * exp(s_k)
#     fitted OVR LRs implied by the model
# ====================================================================

@dataclass
class OVRProjectionResult:
    """Outcome of projecting one-vs-rest likelihood ratios onto one coherent model."""

    posterior: np.ndarray                 # coherent P_k (sums to 1)
    log_scores: np.ndarray                # s_k (baseline fixed to 0)
    lr_vs_baseline: np.ndarray            # LR_{k:baseline} = exp(s_k)
    ovr_lr_fitted: np.ndarray             # coherent OVR LRs implied by s
    rmse: float                           # sqrt(mean(r^2)) of the weighted log-LR residuals (no penalty rows)
    success: bool                         # solver's convergence flag
    message: str                          # solver's termination message
    nfev: int                             # number of residual evaluations
    diagnostics: Dict[str, Any]           # K, baseline, sum_q_init, priors, input_ovr_lr


def _project_ovr_to_coherent(
    priors: np.ndarray,
    ovr_lr: np.ndarray,
    weights: Optional[np.ndarray] = None,
    baseline: Optional[int] = None,
    reg: float = 1e-6,
    max_nfev: int = 200,
    ftol: float = 1e-10,
    xtol: float = 1e-10,
    gtol: float = 1e-10,
) -> OVRProjectionResult:
    """
    Fit one coherent posterior to a set of one-vs-rest (OVR) likelihood ratios.

    The model is ``P_k ∝ π_k · exp(s_k)`` with ``s_baseline = 0``. The OVR
    log-LR it implies for class k is
    ``s_k − log(Σ_{j≠k} π_j e^{s_j}) + log(1 − π_k)``. The free scores are
    found by (optionally weighted) nonlinear least squares between the input
    and implied log-LRs, with an L2 penalty ``reg · ‖θ‖²`` on the free scores.

    Parameters:
        priors: Prior for each class. Values are clipped to [1e-12, 1 − 1e-12]
            and then renormalised to sum to 1.
        ovr_lr: Likelihood ratio of each class versus "not that class"; all > 0.
        weights: Optional per-class weights for the squared residuals
            (clipped below at 1e-12). Defaults to all ones.
        baseline: Index of the class whose score is pinned to 0. Defaults to
            the class with the largest prior. Negative indices count from the end.
        reg: L2 penalty strength; no penalty rows are added when ``reg <= 0``.
        max_nfev, ftol, xtol, gtol: Forwarded to ``scipy.optimize.least_squares``.

    Returns:
        OVRProjectionResult: posterior, scores, fitted OVR LRs, and solver status.
        ``rmse`` is computed from the K data residuals only; the penalty
        rows are excluded so the value stays on the log-LR scale.

    Raises:
        ValueError: Fewer than 2 classes, mismatched lengths, a non-positive
            LR, or a ``baseline`` outside ``[-K, K)``.
    """
    pi = np.asarray(priors, dtype=float).copy()
    lam = np.asarray(ovr_lr, dtype=float).copy()
    K = pi.size
    if K < 2:
        raise ValueError("Need at least 2 diagnoses for OVR.")
    if lam.size != K:
        raise ValueError("ovr_lr and priors must have the same length.")
    if np.any(lam <= 0):
        raise ValueError("All OVR LRs must be > 0.")

    eps = 1e-12
    pi = np.clip(pi, eps, 1 - eps)
    pi = pi / pi.sum()

    if baseline is None:
        baseline = int(np.argmax(pi))
    else:
        baseline = int(baseline)
        if not -K <= baseline < K:
            raise ValueError("baseline must be a valid class index.")
        # Normalise negative indices; otherwise no class would be excluded
        # from the free parameters and the baseline would not be pinned.
        baseline %= K

    if weights is None:
        w = np.ones(K, dtype=float)
    else:
        w = np.asarray(weights, dtype=float)
        if w.size != K:
            raise ValueError("weights must have length K.")
        w = np.clip(w, eps, None)
    sqrt_w = np.sqrt(w)

    log_pi = np.log(pi)
    log_one_minus_pi = np.log(1.0 - pi)
    log_lambda = np.log(lam)

    # Initialize with per-binary posteriors (the "slices"); exact if coherent.
    binary_logit = log_pi - log_one_minus_pi + log_lambda
    q_init = expit(binary_logit)
    # log(expit(x)) written as -log(1 + e^{-x}) so it stays finite when
    # expit(x) underflows to 0 for very small likelihood ratios.
    s0_full = -np.logaddexp(0.0, -binary_logit) - log_pi   # s up to a constant
    s0_full -= s0_full[baseline]                           # fix baseline to 0
    keep = np.flatnonzero(np.arange(K) != baseline)
    theta0 = s0_full[keep]

    def unpack(theta):
        """Expand the K-1 free scores into the full score vector."""
        s = np.empty(K, dtype=float)
        s[baseline] = 0.0
        s[keep] = theta
        return s

    def ovr_terms(s):
        """Return (model log-LR, shifted weights ∝ π_k e^{s_k}, their leave-one-out sums)."""
        z = log_pi + s
        shift = np.max(z)                              # keeps exp() from overflowing
        ez = np.exp(z - shift)
        s_minus = np.maximum(np.sum(ez) - ez, eps)     # ∑_{j≠k} π_j e^{s_j}, shifted
        model_loglr = s - (np.log(s_minus) + shift) + log_one_minus_pi
        return model_loglr, ez, s_minus

    def residuals(theta):
        model_loglr, _, _ = ovr_terms(unpack(theta))
        r = sqrt_w * (log_lambda - model_loglr)
        if reg > 0.0:
            r = np.concatenate([r, np.sqrt(reg) * theta])
        return r

    def jacobian(theta):
        _, ez, s_minus = ovr_terms(unpack(theta))
        # d r_k / d s_j = sqrt(w_k) · π_j e^{s_j} / S_{-k} for j ≠ k, and -sqrt(w_k) for j = k.
        J = sqrt_w[:, None] * (ez[None, :] / s_minus[:, None])
        np.fill_diagonal(J, -sqrt_w)
        J = J[:, keep]
        if reg > 0.0:
            J = np.vstack([J, np.sqrt(reg) * np.eye(theta.size)])
        return J

    res = least_squares(residuals,
                        theta0,
                        jac=jacobian,
                        method="trf",
                        max_nfev=max_nfev, ftol=ftol, xtol=xtol, gtol=gtol)

    s_hat = unpack(res.x)
    z_hat = log_pi + s_hat
    posterior = np.exp(z_hat - logsumexp(z_hat))

    # Reuse the exact expression the optimiser fitted, so the reported LRs
    # and rmse describe the same model (and share its overflow protection).
    model_loglr_hat, _, _ = ovr_terms(s_hat)
    data_residuals = sqrt_w * (log_lambda - model_loglr_hat)

    return OVRProjectionResult(
        posterior=posterior,
        log_scores=s_hat,
        lr_vs_baseline=np.exp(s_hat),
        ovr_lr_fitted=np.exp(model_loglr_hat),
        rmse=float(np.sqrt(np.mean(data_residuals**2))),
        success=bool(res.success),
        message=res.message,
        nfev=int(res.nfev),
        diagnostics={
            "K": K,
            "baseline": baseline,
            "sum_q_init": float(np.sum(q_init)),  # =1 if inputs coherent
            "priors": pi,
            "input_ovr_lr": lam,
        },
    )

# -------------------------------------------------------------
# Visualization: show Input vs Fitted OVR LR, pre vs post bars
# -------------------------------------------------------------
def _stacked_chart_pre_post(pre_test_probabilities, post_testing_probabilities, dx_names,
                            input_loglr, fitted_ovr_lr=None, figsize=(6, 6)):
    """
    Draw "Before Info" / "After Info" stacked bars with input and fitted OVR LRs in the legend.

    Both probability vectors are rescaled to sum to 1. Each diagnosis is one
    segment in both bars; dashed grey lines join the top and bottom edges of
    the same segment across the two bars.

    Parameters:
        pre_test_probabilities: Probabilities before the evidence.
        post_testing_probabilities: Probabilities after the evidence.
        dx_names: Diagnosis names used in the legend.
        input_loglr: Per-diagnosis input log OVR LR; shown exponentiated.
        fitted_ovr_lr: Optional per-diagnosis fitted OVR LR appended to the legend.
        figsize: Figure size passed to ``plt.subplots``.

    Raises:
        ValueError: If the first four sequences differ in length, or if either
            probability vector does not have a positive, finite sum.
    """
    if not (len(pre_test_probabilities) == len(post_testing_probabilities) == len(dx_names) == len(input_loglr)):
        raise ValueError("All input arrays must have the same length.")

    pre = np.array(pre_test_probabilities, dtype=float)
    post = np.array(post_testing_probabilities, dtype=float)

    # Refuse to normalise by a zero/NaN/inf total: that would plot NaN bars.
    pre_total = pre.sum()
    post_total = post.sum()
    if not (np.isfinite(pre_total) and pre_total > 0 and np.isfinite(post_total) and post_total > 0):
        raise ValueError("Probabilities must have a positive, finite sum.")
    pre = pre / pre_total
    post = post / post_total

    x_labels = ["Before Info", "After Info"]
    x = np.arange(len(x_labels))
    width = 0.12

    fig, ax = plt.subplots(figsize=figsize)

    bottoms = np.zeros(len(x_labels))  # Running stack height of each bar
    segment_edges = []                 # (bottom, top) pair of arrays per diagnosis

    for i, category in enumerate(dx_names):
        heights = np.array([pre[i], post[i]])
        legend_txt = f"Input OVR LR: {np.exp(input_loglr[i]):.2f}"
        if fitted_ovr_lr is not None:
            legend_txt += f" | Fitted OVR LR: {fitted_ovr_lr[i]:.2f}"
        bars = ax.bar(x, heights, width=width, bottom=bottoms,
                      label=f"{category} ({legend_txt})")
        for bar, value in zip(bars, heights):
            # Print the probability at the centre of its segment.
            ax.text(bar.get_x() + bar.get_width()/2,
                    bar.get_y() + bar.get_height()/2,
                    f"{value:.2f}",
                    ha="center", va="center", fontsize=10, color="white")
        tops = bottoms + heights
        segment_edges.append((bottoms, tops))
        bottoms = tops  # New array each time, so stored edges are never mutated

    # Dashed connectors between the same segment in the two bars.
    for lower, upper in segment_edges:
        ax.plot(x, upper, linestyle="--", color="gray", alpha=0.5)
        ax.plot(x, lower, linestyle="--", color="gray", alpha=0.5)

    ax.set_xticks(x)
    ax.set_xticklabels(x_labels)
    ax.set_ylabel("Probability")
    ax.set_title("Before-Data vs After-Data Probabilities")
    # Anchor the legend's top-centre at a point below the axes. With
    # loc="best" the corner attached to the anchor point is chosen
    # automatically, so the legend could end up overlapping the chart.
    ax.legend(title="Diagnoses", loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=1)
    plt.tight_layout()
    plt.show()

# ====================================================================
# Your app‑specific mapping from qualitative finding counts to weights
#   Here we interpret 'weight_of_evidence' as log OVR LR for class k.
#   Edit the coefficients if you recalibrate those qualitative bins.
# ====================================================================
def _compute_weight_of_evidence(major_positive, major_negative, minor_positive, minor_negative):
    """
    Combine the four finding counts into a single weight of evidence.

    Major positives count 1 each, major negatives -1/2, minor positives +1/3
    and minor negatives -1/6. The terms are accumulated in that order.
    """
    # Rebind rather than use augmented assignment so an array argument is
    # never modified in place.
    score = major_positive
    score = score - (major_negative / 2)
    score = score + (minor_positive / 3)
    score = score - (minor_negative / 6)
    return score

# ====================================================================
# Wrapper (OVR projection): assembles inputs, runs solver, prints table
# ====================================================================
def generalized_predictor_wrapper(diagnoses, reg=1e-6):
    """
    Run the OVR projection for a list of diagnoses and print a comparison table.

    diagnoses: list of dicts like
      {'name': str, 'pretest': one of certainty_estimates keys,
       'major_positive': int, 'major_negative': int, 'minor_positive': int, 'minor_negative': int}
    reg: L2 penalty strength forwarded to ``_project_ovr_to_coherent``.

    Returns a dict with keys 'names', 'pre_test_probs', 'norm_pre_test_probs',
    'weights_of_evidence', 'predicted_probs', 'fitted_ovr_lr' and 'fitted_loglr'.
    """
    names = []
    raw_priors = []
    raw_loglr = []

    # One pass per diagnosis: name, numeric prior, cumulative log OVR LR.
    for entry in diagnoses:
        names.append(entry['name'])
        raw_priors.append(certainty_estimates[entry['pretest']])
        raw_loglr.append(
            _compute_weight_of_evidence(
                entry['major_positive'],
                entry['major_negative'],
                entry['minor_positive'],
                entry['minor_negative'],
            )
        )

    pre = np.array(raw_priors, dtype=float)
    priors = pre / np.sum(pre)
    loglr_input = np.array(raw_loglr, dtype=float)
    lam_input = np.exp(loglr_input)

    # Reconcile the per-class OVR LRs into one coherent posterior.
    res = _project_ovr_to_coherent(priors=priors, ovr_lr=lam_input, reg=reg)

    post = res.posterior
    fitted_ovr = res.ovr_lr_fitted
    fitted_loglr = np.log(fitted_ovr)

    # Solver status lines, then the table header and one row per diagnosis.
    header = (
        f"{'Diagnosis':<20}{'Pre-Test':<12}{'Norm Pre':<12}"
        f"{'Input OVR LR':<15}{'Fitted OVR LR':<16}{'Fitted LogLR':<15}{'Post-Test':<12}"
    )
    print(f"OVR projection success: {res.success} (nfev={res.nfev}) | rmse(log-LR)={res.rmse:.4g}")
    print(f"Baseline class index: {res.diagnostics['baseline']} | sum q_init: {res.diagnostics['sum_q_init']:.6f}")
    print(header)
    print("-" * len(header))
    for row, label in enumerate(names):
        line = (
            f"{label:<20}"
            f"{pre[row]:<12.3f}{priors[row]:<12.3f}"
            f"{lam_input[row]:<15.3f}{fitted_ovr[row]:<16.3f}{fitted_loglr[row]:<15.3f}{post[row]:<12.3f}"
        )
        print(line)

    summary = {}
    summary['names'] = names
    summary['pre_test_probs'] = pre                 # raw priors (not yet normalized)
    summary['norm_pre_test_probs'] = priors         # normalized priors
    summary['weights_of_evidence'] = loglr_input    # input log OVR LR (for legend)
    summary['predicted_probs'] = post
    summary['fitted_ovr_lr'] = fitted_ovr
    summary['fitted_loglr'] = fitted_loglr
    return summary

# ====================================================================
# Interactive widget
# ====================================================================
def create_diagnosis_widget():
    """
    Build the interactive OVR form: count slider, regularisation slider,
    one input group per diagnosis, and a "Calculate" button.

    Pressing the button runs ``generalized_predictor_wrapper`` with the
    current inputs and regularisation value, then draws the stacked chart
    into an output area to the right of the form.
    """
    pretest_labels = ('HIGHLY_UNLIKELY', 'UNLIKELY', 'UNCERTAIN', 'LIKELY', 'HIGHLY_LIKELY')
    # Dictionary keys in the same order as the children of each diagnosis box.
    field_keys = ('name', 'pretest', 'major_positive', 'major_negative',
                  'minor_positive', 'minor_negative')

    # OVR needs at least 2 classes
    diagnosis_count = widgets.IntSlider(value=3, min=2, max=12, step=1, description="Diagnoses")
    reg_slider = widgets.FloatLogSlider(value=1e-6, base=10, min=-12, max=-2, step=0.5,
                                        description="Reg (L2)", readout_format=".1e")

    container = VBox()
    chart_output = Output()

    def count_slider(label):
        """A 0-10 integer slider for one kind of finding."""
        return widgets.IntSlider(value=0, min=0, max=10, step=1, description=label)

    def build_box(position):
        """Input group for the diagnosis at the given zero-based position."""
        return VBox([
            widgets.Text(value=f"Diagnosis {position+1}", description="Name:"),
            widgets.Dropdown(options=list(pretest_labels), value='UNCERTAIN', description="Pre-Test:"),
            count_slider("Major+"),
            count_slider("Major-"),
            count_slider("Minor+"),
            count_slider("Minor-"),
        ])

    def update_widgets(change):
        """Rebuild all input groups when the diagnosis count changes."""
        if change['name'] != 'value':
            return
        container.children = tuple(build_box(position) for position in range(change['new']))

    diagnosis_count.observe(update_widgets, names='value')
    update_widgets({'name': 'value', 'new': diagnosis_count.value})

    def gather_inputs(_=None):
        """Collect the form values, print the table and draw the chart."""
        diagnoses = [
            {key: child.value for key, child in zip(field_keys, box.children)}
            for box in container.children
        ]

        chart_output.clear_output()
        with chart_output:
            results = generalized_predictor_wrapper(diagnoses, reg=float(reg_slider.value))
            _stacked_chart_pre_post(
                results['pre_test_probs'],
                results['predicted_probs'],
                results['names'],
                results['weights_of_evidence'],     # input log OVR LR (for legend)
                fitted_ovr_lr=results['fitted_ovr_lr']
            )

    calculate_button = widgets.Button(description="Calculate")
    calculate_button.on_click(gather_inputs)

    form_column = VBox([diagnosis_count, reg_slider, container, calculate_button])
    return HBox([form_column, chart_output])


# Build the OVR form and show it as this cell's output.
diagnosis_widget = create_diagnosis_widget()
# NOTE(review): `display` is not imported in this cell; presumably the name
# injected by IPython/Jupyter — confirm before running as a plain script.
display(diagnosis_widget)

HBox(children=(VBox(children=(IntSlider(value=3, description='Diagnoses', max=12, min=2), FloatLogSlider(value…