# In [1]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, widgets, VBox, HBox, Layout, Output, Label

import scienceplots
# Apply the "science" Matplotlib style. NOTE(review): this presumably relies on
# the `import scienceplots` above having registered the style — confirm.
plt.style.use('science')
# Force plain (non-LaTeX) text rendering, overriding whatever the style sets,
# so figures render without a TeX installation.
plt.rcParams['text.usetex'] = False


def softmax_regression_predict(pre_test_probabilities, weights_of_evidence):
    """Update prior probabilities with log-likelihood ratios and renormalize.

    Each posterior is proportional to ``prior * exp(weight)``. The results are
    renormalized with a numerically stable softmax so they sum to 1. The priors
    therefore need not sum to 1 themselves.

    Parameters
    ----------
    pre_test_probabilities : sequence of float
        Non-negative prior weights, one per diagnosis. A zero prior is allowed
        and yields a zero posterior.
    weights_of_evidence : sequence of float
        Natural-log likelihood ratios (ln LR), one per diagnosis.

    Returns
    -------
    numpy.ndarray
        Posterior probabilities summing to 1 (an empty array for empty input).

    Raises
    ------
    ValueError
        If the two inputs differ in length, any prior is negative, or no prior
        is positive (the posterior would be undefined).
    """
    if len(pre_test_probabilities) != len(weights_of_evidence):
        raise ValueError("Lengths of pre_test_probabilities and weights_of_evidence must match.")

    priors = np.asarray(pre_test_probabilities, dtype=float)
    evidence = np.asarray(weights_of_evidence, dtype=float)

    if priors.size == 0:
        return np.array([], dtype=float)
    if np.any(priors < 0):
        raise ValueError("pre_test_probabilities must be non-negative.")
    if not np.any(priors > 0):
        raise ValueError("At least one pre-test probability must be positive.")

    # Work in log space: log(prior) + ln(LR). log(0) = -inf is intentional
    # (exp(-inf) = 0, so a zero prior stays zero), so silence only that warning.
    with np.errstate(divide="ignore"):
        combined = np.log(priors) + evidence

    # Subtract the max before exponentiating to avoid overflow (stability trick);
    # the max is finite because at least one prior is positive.
    exp_combined = np.exp(combined - np.max(combined))
    return exp_combined / np.sum(exp_combined)


def stacked_chart_pre_post(pre_test_probabilities, post_testing_probabilities, dx_names, weights_of_evidence, figsize=(6, 6)):
    """Plot stacked "before" and "after" probability bars, one segment per diagnosis.

    Dashed grey guides join matching segment boundaries of the two bars, so the
    shift in probability mass per diagnosis is easy to follow. Each segment is
    annotated with its probability, and the legend shows each diagnosis with
    its cumulative likelihood ratio ``exp(weight)``.

    Parameters
    ----------
    pre_test_probabilities : sequence of float
        Prior weights per diagnosis; rescaled here to sum to 1.
    post_testing_probabilities : sequence of float
        Posterior weights per diagnosis; rescaled here to sum to 1.
    dx_names : sequence of str
        Diagnosis labels used in the legend.
    weights_of_evidence : sequence of float
        Natural-log likelihood ratios; ``exp`` of each is shown in the legend.
    figsize : tuple of float, optional
        Figure size in inches passed to ``plt.subplots``.

    Raises
    ------
    ValueError
        If the inputs differ in length, or either probability vector does not
        have a positive sum (it could not be normalized).
    """
    # Check that input lengths match
    if not (len(pre_test_probabilities) == len(post_testing_probabilities) == len(dx_names) == len(weights_of_evidence)):
        raise ValueError("All input arrays must have the same length.")

    def _normalize(probs, label):
        """Rescale probs to sum to 1, rejecting vectors with no positive mass."""
        probs = np.asarray(probs, dtype=float)
        total = probs.sum()
        if not total > 0:
            raise ValueError(f"{label} must have a positive sum to be normalized.")
        return probs / total

    # One column of stacked segments per time point: before, after.
    data = [
        _normalize(pre_test_probabilities, "pre_test_probabilities"),
        _normalize(post_testing_probabilities, "post_testing_probabilities"),
    ]
    x_labels = ["Before Info", "After Info"]

    fig, ax = plt.subplots(figsize=figsize)

    x = np.arange(len(data))       # bar positions for "Before" and "After"
    bottoms = np.zeros(len(data))  # running stack height per column
    width = 0.12

    # Heights of every segment boundary per column. Boundary k is both the top
    # of segment k-1 and the bottom of segment k, so each is stored once.
    boundaries = [bottoms.copy()]

    for i, category in enumerate(dx_names):
        values = np.array([column[i] for column in data])
        bars = ax.bar(x, values, width=width, bottom=bottoms, label=f"{category} (Cumulative LR: {np.exp(weights_of_evidence[i]):.2f})")
        for bar, value in zip(bars, values):
            # Probability label centred inside the segment
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_y() + bar.get_height() / 2,
                f"{value:.2f}",
                ha="center",
                va="center",
                fontsize=10,
                color="white"
            )
        bottoms = bottoms + values
        boundaries.append(bottoms.copy())

    # Dashed guides joining each boundary across the two bars. Drawing each
    # boundary once keeps interior guides the same shade as the outer ones
    # (overlapping alpha=0.5 lines would otherwise darken them).
    for heights in boundaries:
        ax.plot(x, heights, linestyle="--", color="gray", alpha=0.5)

    ax.set_xticks(x)
    ax.set_xticklabels(x_labels)
    ax.set_ylabel("Probability")
    ax.set_title("Before-Data vs After-Data Probabilities")

    # Put the legend below the chart: its upper-centre corner is anchored just
    # under the axes ("best" ignores the intent and can land anywhere).
    ax.legend(title="Diagnoses", loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=1)

    plt.tight_layout()
    plt.show()


def generalized_predictor_wrapper(diagnoses):
    """Compute post-test probabilities for a list of diagnoses and print a table.

    diagnoses = list of dicts:
    [
      { 'name': <str>, 'pretest': <float from 0 to 1>, 'loglr': <float from -5 to +5> },
      ...
    ]

    'loglr' is the natural log of the total likelihood ratio. The pre-test
    values are normalized to sum to 1 before being combined with the evidence
    by ``softmax_regression_predict``.

    Returns
    -------
    dict
        'names', 'pre_test_probs' (as given, unnormalized), 'weights_of_evidence'
        and 'predicted_probs' (posterior probabilities, summing to 1).

    Raises
    ------
    ValueError
        If the pre-test probabilities do not sum to a positive value (for
        example, all zero or an empty list).
    """
    names = [diag['name'] for diag in diagnoses]
    pre_test_probs = [diag['pretest'] for diag in diagnoses]
    weights_of_evidence = [diag['loglr'] for diag in diagnoses]

    pre_test_sum = sum(pre_test_probs)
    if not pre_test_sum > 0:
        # Previously this surfaced as ZeroDivisionError when every slider was 0.
        raise ValueError("At least one pre-test probability must be greater than zero.")
    norm_pre_test_probs = [p / pre_test_sum for p in pre_test_probs]

    # The softmax renormalizes on its own; the normalized priors are kept for display.
    predicted_probs = softmax_regression_predict(norm_pre_test_probs, weights_of_evidence)

    # Results table: 20-char name column + five 15-char columns = 95 chars wide.
    print(f"{'Diagnosis':<20}{'Pre-Test Pr':<15}{'Norm. Pre-Test':<15}{'Total LR':<15}{'LogLR':<15}{'Post-Test Pr':<15}")
    print("-" * 95)
    for name, pre, norm_pre, loglr, post in zip(names, pre_test_probs, norm_pre_test_probs, weights_of_evidence, predicted_probs):
        print(f"{name:<20}{pre:<15.2f}{norm_pre:<15.2f}{np.exp(loglr):<15.2f}{loglr:<15.2f}{post:<15.2f}")

    return {
        'names': names,
        'pre_test_probs': pre_test_probs,
        'weights_of_evidence': weights_of_evidence,
        'predicted_probs': predicted_probs
    }


def create_diagnosis_widget():
    """Build the interactive pre-test / post-test probability calculator.

    The widget contains:
    - an ``IntSlider`` that chooses how many diagnoses to enter (1-10);
    - one input group per diagnosis (name, pre-test probability, total LR);
    - a "Calculate" button.

    Clicking the button prints the results table and draws the stacked
    before/after chart into an ``Output`` area to the right of the inputs.

    Returns
    -------
    ipywidgets.HBox
        The assembled widget, ready to be displayed in a notebook.
    """
    diagnosis_count = widgets.IntSlider(
        value=1, min=1, max=10, step=1, description="Diagnoses"
    )
    container = VBox()
    chart_output = Output()

    def make_diagnosis_box(i):
        """Return the (name, pre-test, LR) input group for diagnosis number i+1."""
        name_widget = widgets.Text(
            value=f"Diagnosis {i+1}",
            description="Name:"
        )

        # Pre-test probability slider from 0.0 to 1.0, 2-decimal readout
        pretest_widget = widgets.FloatSlider(
            value=0.5,
            min=0.0,
            max=1.0,
            step=0.01,
            description="Pre-Test:",
            readout_format='.2f'
        )

        # Log-scale slider. min/max/step are exponents of base e (i.e. LogLR),
        # while .value and the readout are the LR itself, starting at LR = 1.
        loglr_widget = widgets.FloatLogSlider(
            value=1.0,
            base=np.e,
            min=-5.0,
            max=5.0,
            step=0.1,
            description="Total LR:",
            readout=True,
            readout_format=".2f"
        )

        return VBox([name_widget, pretest_widget, loglr_widget])

    def update_widgets(change):
        """Resize the list of input groups to match the diagnosis count.

        Existing groups are kept, so names and values the user already entered
        survive a count change; only the missing groups are created. The
        children tuple is assigned once, rather than once per added group.
        """
        if change['name'] != 'value':
            return
        count = change['new']
        kept = list(container.children[:count])
        added = [make_diagnosis_box(i) for i in range(len(kept), count)]
        container.children = tuple(kept + added)

    diagnosis_count.observe(update_widgets, names='value')

    # Force the initial layout
    update_widgets({'name': 'value', 'new': diagnosis_count.value})

    def gather_inputs(_):
        """Button callback: read every input group, compute, and redraw the output."""
        diagnoses = []
        for diagnosis_box in container.children:
            name_widget, pretest_widget, lr_widget = diagnosis_box.children
            diagnoses.append({
                'name': name_widget.value,
                'pretest': pretest_widget.value,
                # The slider holds the LR; the model works with ln(LR).
                'loglr': np.log(lr_widget.value),
            })

        chart_output.clear_output()
        with chart_output:
            results = generalized_predictor_wrapper(diagnoses)
            stacked_chart_pre_post(
                results['pre_test_probs'],
                results['predicted_probs'],
                results['names'],
                results['weights_of_evidence']
            )

    calculate_button = widgets.Button(description="Calculate")
    calculate_button.on_click(gather_inputs)

    return HBox([VBox([diagnosis_count, container, calculate_button]), chart_output])


# Build the interactive calculator and render it as this cell's output.
# NOTE(review): `display` is not imported in this file; it is presumably the
# IPython display function that Jupyter injects into the namespace, so this
# line only works inside a notebook kernel — confirm.
diagnosis_widget = create_diagnosis_widget()
display(diagnosis_widget)

# Output (widget plain-text repr, truncated): HBox(children=(VBox(children=(IntSlider(value=1, description='Diagnoses', max=10, min=1), VBox(children=(VBox(…