In [1]:
import matplotlib.pyplot as plt
import scienceplots
import numpy as np
plt.style.use('science')
plt.rcParams['text.usetex'] = False
from ipywidgets import interact, widgets, VBox, HBox, Layout, Output

# Maps each qualitative pre-test certainty label to the numeric probability
# used as that diagnosis's prior. generalized_predictor_wrapper looks labels
# up here via diagnosis['pretest'].
# NOTE(review): the values coincide with the logistic function 1 / (1 + e^-x)
# at x = -2, -1, 0, 1, 2 rounded to three decimals, i.e. the labels look evenly
# spaced in log-odds -- presumably by design; confirm before editing entries.
certainty_estimates = {
    "HIGHLY_UNLIKELY": 0.119,
    "UNLIKELY": 0.269,
    "UNCERTAIN": 0.5,
    "LIKELY": 0.731,
    "HIGHLY_LIKELY": 0.881
}

def softmax_regression_predict(pre_test_probabilities, weights_of_evidence):
    """
    Predict probabilities using a multinomial logistic regression model (softmax regression).

    Each class is scored as ``log(pre-test probability) + weight of evidence`` and the
    scores are passed through a softmax. The pre-test probabilities therefore act as
    priors and do not need to sum to 1 themselves.

    Parameters:
        pre_test_probabilities (sequence of float): Pre-test probabilities (baseline
            probabilities before including predictors), each in [0, 1]. A value of 0
            rules that class out: its predicted probability is exactly 0.
        weights_of_evidence (sequence of float): Linear predictors (log-odds
            contributions) for each class.

    Returns:
        numpy.ndarray: Predicted probabilities for each class (summing to 1).

    Raises:
        ValueError: If the lengths of inputs don't match, if the inputs are empty,
            if any pre-test probability lies outside [0, 1], or if every pre-test
            probability is 0 (no class could receive any probability mass).
    """
    if len(pre_test_probabilities) != len(weights_of_evidence):
        raise ValueError("The lengths of pre_test_probabilities and weights_of_evidence must be the same (equal to the number of diagnoses).")

    if not all(0 <= prob <= 1 for prob in pre_test_probabilities):
        raise ValueError("All pre_test_probabilities must be between 0 and 1.")

    priors = np.asarray(pre_test_probabilities, dtype=float)
    if priors.size == 0:
        raise ValueError("At least one diagnosis is required.")
    if not np.any(priors > 0):
        # With every prior at 0 all scores are -inf and the softmax would be NaN.
        raise ValueError("At least one pre_test_probability must be greater than 0.")

    # log(0) = -inf is the intended score for a ruled-out class, so silence the
    # divide-by-zero warning NumPy would otherwise emit for it.
    with np.errstate(divide="ignore"):
        log_priors = np.log(priors)

    combined = log_priors + np.asarray(weights_of_evidence, dtype=float)

    # Subtracting the maximum score before exponentiating avoids overflow and
    # leaves the softmax result unchanged.
    exp_combined = np.exp(combined - np.max(combined))
    return exp_combined / np.sum(exp_combined)

def stacked_chart_pre_post(pre_test_probabilities, post_testing_probabilities, dx_names, weights_of_evidence, figsize=(6, 6)):
    """
    Plot pre- and post-test probabilities as two stacked bars joined by dashed guide lines.

    Both probability vectors are rescaled to sum to 1 before plotting. Every diagnosis
    becomes one segment in each bar, labelled with its share (two decimals), and dashed
    gray lines connect the top and bottom edges of that segment across the two bars.

    Parameters:
        pre_test_probabilities: Probabilities shown in the "Before Info" bar.
        post_testing_probabilities: Probabilities shown in the "After Info" bar.
        dx_names: Diagnosis names, used for the legend entries.
        weights_of_evidence: Per-diagnosis weights, shown in the legend entries.
        figsize (tuple): Figure size forwarded to ``plt.subplots``.

    Raises:
        ValueError: If the four per-diagnosis inputs differ in length.
    """
    n_dx = len(dx_names)
    lengths_agree = (
        len(pre_test_probabilities) == len(post_testing_probabilities)
        and len(post_testing_probabilities) == n_dx
        and n_dx == len(weights_of_evidence)
    )
    if not lengths_agree:
        raise ValueError("All input arrays must have the same length.")

    # Rescale each column so its segments add up to exactly one.
    before = np.array(pre_test_probabilities) / np.sum(pre_test_probabilities)
    after = np.array(post_testing_probabilities) / np.sum(post_testing_probabilities)
    columns = (before, after)

    fig, ax = plt.subplots(figsize=figsize)

    positions = np.arange(len(columns))   # one x slot per column
    stack_base = np.zeros(len(columns))   # running height of each column
    bar_width = 0.175

    # Per-diagnosis [before, after] y-coordinates of the segment edges.
    top_edges = []
    bottom_edges = []

    for idx, dx_name in enumerate(dx_names):
        heights = [column[idx] for column in columns]
        rects = ax.bar(
            positions,
            heights,
            width=bar_width,
            bottom=stack_base,
            label=f"{dx_name} (WoE: {weights_of_evidence[idx]:.2f})",
        )

        tops = []
        bases = []
        for rect, height in zip(rects, heights):
            base = rect.get_y()
            extent = rect.get_height()
            # Value label centred inside the segment.
            ax.text(
                rect.get_x() + rect.get_width() / 2,
                base + extent / 2,
                f"{height:.2f}",
                ha="center",
                va="center",
                fontsize=10,
                color="white",
            )
            tops.append(base + extent)
            bases.append(base)

        top_edges.append(tops)
        bottom_edges.append(bases)
        stack_base += heights

    # Dashed guides: top edge first, then bottom edge, for every diagnosis.
    for tops, bases in zip(top_edges, bottom_edges):
        for edge in (tops, bases):
            ax.plot(positions, edge, linestyle="--", color="gray", alpha=0.5)

    ax.set_xticks(positions)
    ax.set_xticklabels(["Before Info", "After Info"])
    ax.set_ylabel("Probability")
    ax.set_title("Before-Data vs After-Data Probabilities")

    # The legend sits underneath the axes so it never covers the bars.
    ax.legend(title="Diagnoses", loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=4)

    plt.tight_layout()
    plt.show()

# Generalized wrapper function for n diagnoses
def generalized_predictor_wrapper(diagnoses):
    """
    Compute post-test probabilities for a list of diagnoses and print a summary table.

    For every diagnosis the pre-test label is looked up in ``certainty_estimates`` and
    the evidence counts are folded into a single weight of evidence:
    ``major_positive - major_negative / 2 + minor_positive / 3 - minor_negative / 6``.
    The raw pre-test probabilities and these weights are then handed to
    ``softmax_regression_predict``.

    Parameters:
        diagnoses (list): List of dictionaries with the keys 'name', 'pretest',
            'major_positive', 'major_negative', 'minor_positive' and 'minor_negative'.

    Returns:
        dict: 'names', 'pre_test_probs', 'norm_pre_test_probs' (pre-test values divided
        by their sum), 'weights_of_evidence' and 'predicted_probs'.
    """
    names = []
    pre_test_probs = []
    weights_of_evidence = []

    for entry in diagnoses:
        label = entry['name']
        prior = certainty_estimates[entry['pretest']]
        major_pos, major_neg = entry['major_positive'], entry['major_negative']
        minor_pos, minor_neg = entry['minor_positive'], entry['minor_negative']

        # Major findings weigh more than minor ones; positives more than negatives.
        evidence = major_pos - (major_neg / 2) + (minor_pos / 3) - (minor_neg / 6)

        names.append(label)
        pre_test_probs.append(prior)
        weights_of_evidence.append(evidence)

    # Normalized copy is for reporting only; the model receives the raw values.
    prior_total = sum(pre_test_probs)
    norm_pre_test_probs = [prior / prior_total for prior in pre_test_probs]

    predicted_probs = softmax_regression_predict(pre_test_probs, weights_of_evidence)

    # Summary table, one row per diagnosis.
    print(f"{'Diagnosis':<15}{'Pre-Test Prob':<15}{'Norm Pre-Test Prob':<20}{'WoE':<15}{'Post-Test Prob':<15}")
    print("-" * 80)
    table_rows = zip(names, pre_test_probs, norm_pre_test_probs, weights_of_evidence, predicted_probs)
    for label, prior, norm_prior, evidence, posterior in table_rows:
        print(f"{label:<15}{prior:<15.3f}{norm_prior:<20.3f}{evidence:<15.3f}{posterior:<15.3f}")

    return {
        'names': names,
        'pre_test_probs': pre_test_probs,
        'norm_pre_test_probs': norm_pre_test_probs,
        'weights_of_evidence': weights_of_evidence,
        'predicted_probs': predicted_probs,
    }


# Dynamic widget for n diagnoses
def create_diagnosis_widget():
    """
    Creates an interactive widget for entering diagnosis details and calculating post-test probabilities
    using softmax regression, displaying the stacked chart dynamically.

    A slider selects how many diagnoses are shown. Changing it keeps the entries that
    are already on screen (including everything typed into them) and only appends or
    drops boxes at the end, so adjusting the count never wipes existing input.
    Pressing "Calculate" reads every box, prints the results table and draws the chart
    inside the output area on the right.

    Returns:
        HBox: The input column (slider, diagnosis boxes, button) next to the chart output.
    """
    diagnosis_count = widgets.IntSlider(value=1, min=1, max=10, step=1, description="Diagnoses")
    container = VBox()
    chart_output = Output()  # Output widget to display the table and chart

    def make_evidence_slider(label):
        """Build one 0-10 counter for a category of findings."""
        return widgets.IntSlider(value=0, min=0, max=10, step=1, description=label)

    def make_diagnosis_box(index):
        """Build the input group for the diagnosis at zero-based position `index`."""
        # gather_inputs reads the children by position, so this order is significant.
        return VBox([
            widgets.Text(value=f"Diagnosis {index + 1}", description="Name:"),
            widgets.Dropdown(
                # Offer exactly the labels the calculation can look up.
                options=list(certainty_estimates),
                value='UNCERTAIN',
                description="Pre-Test:"
            ),
            make_evidence_slider("Major+"),
            make_evidence_slider("Major-"),
            make_evidence_slider("Minor+"),
            make_evidence_slider("Minor-"),
        ])

    def update_widgets(change):
        """Grow or shrink the list of diagnosis boxes to match the slider."""
        if change['name'] != 'value':  # Ensure the change is triggered by the slider
            return
        wanted = change['new']
        # Keep the boxes the user has already filled in; only touch the tail.
        kept = list(container.children[:wanted])
        added = [make_diagnosis_box(i) for i in range(len(kept), wanted)]
        # Assign once so the front end is updated a single time.
        container.children = tuple(kept + added)

    # Observe changes in the slider and update widgets accordingly
    diagnosis_count.observe(update_widgets, names='value')

    # Trigger an initial update
    update_widgets({'name': 'value', 'new': diagnosis_count.value})

    def gather_inputs():
        """Read every diagnosis box into the dict format the predictor expects."""
        keys = ('name', 'pretest', 'major_positive', 'major_negative', 'minor_positive', 'minor_negative')
        return [
            {key: field.value for key, field in zip(keys, diagnosis_box.children)}
            for diagnosis_box in container.children
        ]

    def on_calculate(_button):
        """Recompute the probabilities and redraw the table and chart."""
        diagnoses = gather_inputs()

        # Clear chart output before displaying a new one
        chart_output.clear_output()

        # Anything printed or plotted inside this context lands in chart_output.
        with chart_output:
            results = generalized_predictor_wrapper(diagnoses)
            stacked_chart_pre_post(
                results['pre_test_probs'],
                results['predicted_probs'],
                results['names'],
                results['weights_of_evidence']
            )

    # Add a button to trigger the calculation
    calculate_button = widgets.Button(description="Calculate")
    calculate_button.on_click(on_calculate)

    # Combine the slider, container, and calculate button
    return HBox([VBox([diagnosis_count, container, calculate_button]), chart_output])


# Build the interactive diagnosis form and render it as this cell's output.
# NOTE(review): `display` is not imported in this cell; presumably it is the
# name IPython/Jupyter injects into notebook sessions -- confirm before running
# this code as a plain script.
diagnosis_widget = create_diagnosis_widget()
display(diagnosis_widget)

HBox(children=(VBox(children=(IntSlider(value=1, description='Diagnoses', max=10, min=1), VBox(children=(VBox(…