In [1]:
# Plot mean-difference distributions interactively using Plotly.


from typing import Any, List, Literal

import numpy as np
import plotly.graph_objects as go
from IPython.display import DisplayHandle, display
from ipywidgets import widgets
from ipywidgets.widgets.widget_string import LabelStyle
from numpy.typing import NDArray
from plotly.subplots import make_subplots
from scipy.stats import norm

# Allowed directions for the alternative hypothesis of the test.
Alternative = Literal["smaller", "larger", "two-sided"]
# Runtime list of the same options; used for validation and as dropdown options.
ALTERNATIVES: List[Alternative] = ["smaller", "larger", "two-sided"]

# (red, green, blue) triples used to build the plot fill colours.
TOMATO = (250, 97, 104)  # fill of the critical regions
MUSTARD = (255, 210, 46)  # fill of the confidence-interval rectangles


class ExperimentEffectViewer:
    """
    Interactive Plotly/ipywidgets view of a z-test on an average treatment effect.

    Two stacked panels are drawn: the standard normal distribution of the
    t-value (top) and the distribution of the mean difference under the null,
    N(0, standard_error) (bottom). Each panel shows the critical region(s) for
    ``alpha`` and ``alternative``, the confidence interval around the observed
    effect and a dashed line at the observed effect. Sliders and a dropdown
    change every input; the plots, p-value and significance label follow.

    Initially the plots are between ± 3 or the ate if it is
    outside this range but then default to the selection from the user.

    Parameters
    ----------
    ate : observed average treatment effect (mean difference).
    standard_error : standard error of ``ate``; must be strictly positive.
    alpha : significance level, between 0 and 0.5.
    alternative : one of ``ALTERNATIVES``.
    plot_x_min, plot_x_max : initial x range of the mean-difference panel; the
        range slider spans twice these values.
    n_points : number of points used to draw each density curve.
    min_ate, max_ate : bounds of the ATE slider.
    min_standard_error, max_standard_error : bounds of the SE slider. The lower
        bound is raised to a small positive value because a zero standard
        error makes the t-value undefined.
    fig_height : figure height in pixels.
    critical_opacity : alpha channel of the critical-region fill colour.
    confidence_opacity : alpha channel of the confidence-interval fill colour.

    Raises
    ------
    ValueError
        If ``alternative``, ``alpha`` or ``standard_error`` is invalid.
    """

    # Smallest standard error the slider may reach (one slider step): with a
    # standard error of zero, ate / standard_error cannot be computed.
    _MIN_STANDARD_ERROR = 0.001
    # One-sided confidence intervals are open-ended; the rectangle is pushed
    # at least this far out so it runs off the edge of the chart.
    _OPEN_INTERVAL_EXTENT = 100

    def __init__(
        self,
        ate: float = -0.5,
        standard_error: float = 0.5,
        alpha: float = 0.05,
        alternative: Alternative = "smaller",
        plot_x_min: float = -3,
        plot_x_max: float = 3,
        n_points: int = 1000,
        min_ate: float = -5,
        max_ate: float = 5,
        min_standard_error: float = 0.000,
        max_standard_error: float = 1,
        fig_height: int = 600,
        critical_opacity: float = 0.4,
        confidence_opacity: float = 0.4,
    ):
        self.ate = ate
        self.standard_error = standard_error
        self.alpha = alpha
        self.alternative = alternative
        self.plot_x_min = plot_x_min
        self.plot_x_max = plot_x_max
        self.n_points = n_points
        self.min_ate = min_ate
        self.max_ate = max_ate
        self.min_standard_error = min_standard_error
        self.max_standard_error = max_standard_error
        self.fig_height = fig_height
        self.critical_opacity = critical_opacity
        self.confidence_opacity = confidence_opacity

        self.validate()
        self._init_controls_and_fig()

    def validate(self):
        """Check the test inputs and raise ``ValueError`` on the first invalid one."""
        if self.alternative not in ALTERNATIVES:
            raise ValueError(f"Alternative must be one of {ALTERNATIVES}")
        if self.alpha < 0 or self.alpha > 0.5:
            # The message now states the bound that is actually enforced.
            raise ValueError("alpha must be between 0 and 0.5")
        # Zero is rejected as well: it would make t_value divide by zero.
        if self.standard_error <= 0:
            raise ValueError("standard error must be positive")

    # ------------------------------------------------------------------
    # Test statistics
    # ------------------------------------------------------------------
    @property
    def n_sides(self) -> Literal[1, 2]:
        """Number of tails of the test: 2 for "two-sided", otherwise 1."""
        return 2 if self.alternative == "two-sided" else 1

    @property
    def t_value(self) -> float:
        """Observed effect expressed in standard errors."""
        return self.ate / self.standard_error

    @property
    def p_value(self) -> float:
        """
        Probability, under the null, of a t-value at least as extreme as the
        observed one in the direction(s) given by ``alternative``.

        One-sided tests use the signed t-value so that an effect in the
        opposite direction is not reported as significant; this keeps the
        p-value consistent with the critical region that is drawn.
        """
        t_value = self.t_value
        if self.alternative == "smaller":
            return float(norm.cdf(t_value))
        if self.alternative == "larger":
            return float(norm.sf(t_value))
        return float(2 * norm.sf(abs(t_value)))

    @property
    def is_significant(self) -> bool:
        """True when the p-value is below ``alpha``."""
        return bool(self.p_value < self.alpha)

    # ------------------------------------------------------------------
    # Axis ranges and density curves
    # ------------------------------------------------------------------
    @property
    def min_x_std(self) -> float:
        """Left edge of the standardised panel; always shows the t-value."""
        return min(-3, self.t_value - 1)

    @property
    def max_x_std(self) -> float:
        """Right edge of the standardised panel; always shows the t-value."""
        return max(3, self.t_value + 1)

    @property
    def min_x(self) -> float:
        """Left edge of the mean-difference panel; always shows the ATE."""
        return min(self.plot_x_min, self.ate - self.standard_error)

    @property
    def max_x(self) -> float:
        """Right edge of the mean-difference panel; always shows the ATE."""
        return max(self.plot_x_max, self.ate + self.standard_error)

    @property
    def x_std(self) -> NDArray[np.float64]:
        """x grid of the standardised panel."""
        return np.linspace(self.min_x_std, self.max_x_std, self.n_points)

    @property
    def y_std(self) -> NDArray[np.float64]:
        """Standard normal density evaluated on ``x_std``."""
        return norm.pdf(self.x_std, loc=0, scale=1)

    @property
    def x(self) -> NDArray[np.float64]:
        """x grid of the mean-difference panel."""
        return np.linspace(self.min_x, self.max_x, self.n_points)

    @property
    def y(self) -> NDArray[np.float64]:
        """N(0, standard_error) density evaluated on ``x``."""
        return norm.pdf(self.x, loc=0, scale=self.standard_error)

    # ------------------------------------------------------------------
    # Critical values and confidence intervals
    # ------------------------------------------------------------------
    @property
    def critical_value_left_std(self) -> float:
        """Lower critical value on the standardised scale."""
        return norm.ppf(self.alpha / self.n_sides)

    @property
    def critical_value_right_std(self) -> float:
        """Upper critical value on the standardised scale."""
        return norm.ppf(1 - self.alpha / self.n_sides)

    @property
    def critical_value_left(self) -> float:
        """Lower critical value on the mean-difference scale."""
        return self.critical_value_left_std * self.standard_error

    @property
    def critical_value_right(self) -> float:
        """Upper critical value on the mean-difference scale."""
        return self.critical_value_right_std * self.standard_error

    @property
    def confidence_lower_std(self) -> float:
        """Lower confidence bound around the t-value."""
        lower = self.t_value + self.critical_value_left_std
        if self.alternative == "smaller":
            # extend to end of chart on left.
            return min(-self._OPEN_INTERVAL_EXTENT, lower)
        return lower

    @property
    def confidence_upper_std(self) -> float:
        """Upper confidence bound around the t-value."""
        upper = self.t_value + self.critical_value_right_std
        if self.alternative == "larger":
            # extend to end of chart on right.
            return max(self._OPEN_INTERVAL_EXTENT, upper)
        return upper

    @property
    def confidence_lower(self) -> float:
        """Lower confidence bound around the ATE."""
        lower = self.ate + self.critical_value_left
        if self.alternative == "smaller":
            # extend to end of chart on left.
            return min(-self._OPEN_INTERVAL_EXTENT, lower)
        return lower

    @property
    def confidence_upper(self) -> float:
        """Upper confidence bound around the ATE."""
        upper = self.ate + self.critical_value_right
        if self.alternative == "larger":
            # extend to end of chart on right.
            return max(self._OPEN_INTERVAL_EXTENT, upper)
        return upper

    @property
    def ate_height_std(self) -> float:
        """Height of the t-value marker: half the peak of the plotted curve."""
        return float(np.max(self.y_std)) / 2

    @property
    def ate_height(self) -> float:
        """Height of the ATE marker: half the peak of the plotted curve."""
        return float(np.max(self.y)) / 2

    # ------------------------------------------------------------------
    # Shaded critical regions
    # ------------------------------------------------------------------
    @property
    def x_lower_fill_std(self) -> NDArray[np.float64]:
        return np.linspace(self.min_x_std, self.critical_value_left_std, 100)

    @property
    def y_lower_fill_std(self) -> NDArray[np.float64]:
        return norm.pdf(self.x_lower_fill_std, loc=0, scale=1)

    @property
    def x_upper_fill_std(self) -> NDArray[np.float64]:
        return np.linspace(self.critical_value_right_std, self.max_x_std, 100)

    @property
    def y_upper_fill_std(self) -> NDArray[np.float64]:
        return norm.pdf(self.x_upper_fill_std, loc=0, scale=1)

    @property
    def x_lower_fill(self) -> NDArray[np.float64]:
        return np.linspace(self.min_x, self.critical_value_left, 100)

    @property
    def y_lower_fill(self) -> NDArray[np.float64]:
        return norm.pdf(self.x_lower_fill, loc=0, scale=self.standard_error)

    @property
    def x_upper_fill(self) -> NDArray[np.float64]:
        return np.linspace(self.critical_value_right, self.max_x, 100)

    @property
    def y_upper_fill(self) -> NDArray[np.float64]:
        return norm.pdf(self.x_upper_fill, loc=0, scale=self.standard_error)

    # ------------------------------------------------------------------
    # Widgets
    # ------------------------------------------------------------------
    def _alpha_slider(self) -> widgets.FloatSlider:
        return widgets.FloatSlider(
            value=self.alpha,
            min=0,
            max=0.5,
            step=0.0001,
            description="alpha: ",
            continuous_update=True,
            layout=widgets.Layout(width="80%"),
        )

    def _alternative_dropdown(self) -> widgets.Dropdown:
        return widgets.Dropdown(
            options=ALTERNATIVES,
            value=self.alternative,
            description="Alternative: ",
            layout=widgets.Layout(width="80%"),
        )

    def _ate_slider(self) -> widgets.FloatSlider:
        return widgets.FloatSlider(
            value=self.ate,
            min=self.min_ate,
            max=self.max_ate,
            step=0.001,
            description="ATE",
            continuous_update=True,
            layout=widgets.Layout(width="80%"),
        )

    def _standard_error_slider(self) -> widgets.FloatSlider:
        return widgets.FloatSlider(
            value=self.standard_error,
            # Never let the slider reach zero: update_fig would then fail with
            # a ZeroDivisionError when computing the t-value.
            min=max(self.min_standard_error, self._MIN_STANDARD_ERROR),
            max=self.max_standard_error,
            step=0.001,
            description="SE",
            continuous_update=True,
            layout=widgets.Layout(width="80%"),
        )

    def _x_range_slider(self) -> widgets.FloatRangeSlider:
        return widgets.FloatRangeSlider(
            value=[self.plot_x_min, self.plot_x_max],
            min=2 * self.plot_x_min,
            max=2 * self.plot_x_max,
            step=0.01,
            description="X range",
            continuous_update=True,
            layout=widgets.Layout(width="80%"),
        )

    def _p_value_label_widget(self) -> widgets.Label:
        return widgets.Label(value=f"p-value: {self.p_value:.4f}")

    def _is_significant_label_widget(self) -> widgets.Label:
        "Label is red if significant and blue if not"
        return widgets.Label(
            value=f"Significant: {self.is_significant}",
            style=LabelStyle(text_color="red" if self.is_significant else "blue"),
        )

    def _init_controls(self) -> None:
        """Create every control and stack them in ``controls_container``."""
        self.alpha_slider = self._alpha_slider()
        self.alternative_dropdown = self._alternative_dropdown()
        self.ate_slider = self._ate_slider()
        self.standard_error_slider = self._standard_error_slider()
        self.x_range_slider = self._x_range_slider()
        self.p_value_label_widget = self._p_value_label_widget()
        self.is_significant_label_widget = self._is_significant_label_widget()

        self.controls_container = widgets.VBox(
            [
                self.alpha_slider,
                self.alternative_dropdown,
                self.ate_slider,
                self.standard_error_slider,
                self.x_range_slider,
                self.p_value_label_widget,
                self.is_significant_label_widget,
            ],
            layout=widgets.Layout(width="80%"),
        )

    def _link_controls_to_fig(self) -> None:
        """Redraw the figure whenever the value of an input control changes."""
        input_controls = (
            self.alpha_slider,
            self.alternative_dropdown,
            self.ate_slider,
            self.standard_error_slider,
            self.x_range_slider,
        )
        for control in input_controls:
            # Only react to the value itself, not to every other widget trait.
            control.observe(self.update_fig, names="value")

    def _init_controls_and_fig(self) -> None:
        self._init_controls()
        self._init_fig()
        self._link_controls_to_fig()
        self.update_fig(None)
        self.container = widgets.VBox([self.controls_container, self.fig])

    # ------------------------------------------------------------------
    # Figure
    # ------------------------------------------------------------------
    @staticmethod
    def _rgba(rgb, opacity: float) -> str:
        """
        Build a CSS ``rgba(r,g,b,a)`` colour string from an (r, g, b) triple.

        ``rgba`` is used rather than ``rgb`` because only the former carries an
        alpha channel, which is what the opacity parameters are meant to set.
        """
        red, green, blue = rgb
        return "rgba({},{},{},{})".format(red, green, blue, opacity)

    def _critical_region_trace(
        self,
        x: NDArray[np.float64],
        y: NDArray[np.float64],
        showlegend: bool,
    ) -> go.Scatter:
        """Filled area under the density curve marking one critical region."""
        return go.Scatter(
            x=x,
            y=y,
            fill="tozeroy",
            mode="none",
            fillcolor=self._rgba(TOMATO, self.critical_opacity),
            name="Critical Region",
            showlegend=showlegend,
        )

    def _init_fig(self) -> None:
        """
        Build the two-row figure widget.

        The order in which traces and shapes are added is relied upon by
        ``update_fig``, which looks them up by position.
        """
        # Create the figure with titles and axis labels
        fig = go.FigureWidget(
            make_subplots(
                rows=2,
                cols=1,
                subplot_titles=(
                    "Normalised Mean Difference Distribution",
                    "Mean Difference Distribution",
                ),
            )
        )
        fig.layout.height = self.fig_height
        fig.layout.xaxis.title.text = "Scaled mean difference (t-value)"
        fig.layout.xaxis2.title.text = "mean difference"
        fig.layout.yaxis.title.text = "density"
        fig.layout.yaxis2.title.text = "density"

        # Traces 0-1: the two density curves
        fig.add_trace(
            go.Scatter(
                x=self.x_std,
                y=self.y_std,
                mode="lines",
                name="Standard Normal Distribution",
            ),
            row=1,
            col=1,
        )
        fig.add_trace(
            go.Scatter(
                x=self.x,
                y=self.y,
                mode="lines",
                name="Mean difference distribution",
            ),
            row=2,
            col=1,
        )

        # Traces 2-5: critical regions (left/right standardised, left/right raw).
        # Only one of the top-panel traces carries the legend entry so that
        # "Critical Region" is listed once.
        fig.add_trace(
            self._critical_region_trace(
                self.x_lower_fill_std,
                self.y_lower_fill_std,
                showlegend=self.alternative != "larger",
            ),
            row=1,
            col=1,
        )
        fig.add_trace(
            self._critical_region_trace(
                self.x_upper_fill_std,
                self.y_upper_fill_std,
                showlegend=self.alternative == "larger",
            ),
            row=1,
            col=1,
        )
        fig.add_trace(
            self._critical_region_trace(
                self.x_lower_fill, self.y_lower_fill, showlegend=False
            ),
            row=2,
            col=1,
        )
        fig.add_trace(
            self._critical_region_trace(
                self.x_upper_fill, self.y_upper_fill, showlegend=False
            ),
            row=2,
            col=1,
        )

        # Shapes 0-1: the confidence intervals
        confidence_fill = self._rgba(MUSTARD, self.confidence_opacity)
        fig.add_shape(
            type="rect",
            xref="x",
            yref="y",
            x0=self.confidence_lower_std,
            y0=0,
            x1=self.confidence_upper_std,
            y1=self.ate_height_std,
            fillcolor=confidence_fill,
            line=dict(width=0),
            name="Confidence Interval",
            row=1,
            col=1,
        )
        fig.add_shape(
            type="rect",
            xref="x",
            yref="y",
            x0=self.confidence_lower,
            y0=0,
            x1=self.confidence_upper,
            y1=self.ate_height,
            fillcolor=confidence_fill,
            line=dict(width=0),
            name="Confidence Interval",
            row=2,
            col=1,
        )

        # Shapes 2-3: dashed markers at the t-value and the ATE
        fig.add_shape(
            type="line",
            x0=self.t_value,
            y0=0,
            x1=self.t_value,
            y1=self.ate_height_std,
            line=dict(color="black", width=2, dash="dash"),
            name="t-value",
            row=1,
            col=1,
        )
        fig.add_shape(
            type="line",
            x0=self.ate,
            y0=0,
            x1=self.ate,
            y1=self.ate_height,
            line=dict(color="black", width=2, dash="dash"),
            name="ATE",
            row=2,
            col=1,
        )

        self.fig = fig

    def update_fig(self, change: Any = None) -> None:
        """
        Copy the control values onto the viewer and redraw labels and figure.

        ``change`` is the notification passed by the widget observer; it is not
        inspected, every control is simply re-read.
        """
        # get reference to plot elements (positions fixed by _init_fig)
        fig = self.fig
        (
            normal_dist_plot,
            mean_diff_plot,
            left_critical_std_plot,
            right_critical_std_plot,
            left_critical_plot,
            right_critical_plot,
        ) = fig.data[:6]
        shapes = fig.layout.shapes
        confidence_interval_std_plot = shapes[0]
        confidence_interval_plot = shapes[1]
        t_value_line_plot = shapes[2]
        ate_line_plot = shapes[3]

        # get and set slider values
        self.ate = self.ate_slider.value
        self.standard_error = self.standard_error_slider.value
        self.alpha = self.alpha_slider.value
        self.alternative = self.alternative_dropdown.value
        self.plot_x_min, self.plot_x_max = self.x_range_slider.value

        # set labels
        is_significant = self.is_significant
        self.p_value_label_widget.value = f"p-value: {self.p_value:.4f}"
        self.is_significant_label_widget.value = f"Significant: {is_significant}"
        self.is_significant_label_widget.style = LabelStyle(
            text_color="red" if is_significant else "blue"
        )

        # only view the critical region for correct alternative
        show_left = self.alternative != "larger"
        show_right = self.alternative != "smaller"

        ate_height_std = self.ate_height_std
        ate_height = self.ate_height

        # Send all changes to the front end as a single update.
        with fig.batch_update():
            left_critical_std_plot.visible = left_critical_plot.visible = show_left
            right_critical_std_plot.visible = right_critical_plot.visible = show_right
            # Keep exactly one visible trace carrying the legend entry.
            left_critical_std_plot.showlegend = show_left
            right_critical_std_plot.showlegend = not show_left

            normal_dist_plot.x = self.x_std
            normal_dist_plot.y = self.y_std
            mean_diff_plot.x = self.x
            mean_diff_plot.y = self.y
            left_critical_std_plot.x = self.x_lower_fill_std
            left_critical_std_plot.y = self.y_lower_fill_std
            right_critical_std_plot.x = self.x_upper_fill_std
            right_critical_std_plot.y = self.y_upper_fill_std
            left_critical_plot.x = self.x_lower_fill
            left_critical_plot.y = self.y_lower_fill
            right_critical_plot.x = self.x_upper_fill
            right_critical_plot.y = self.y_upper_fill

            confidence_interval_std_plot.x0 = self.confidence_lower_std
            confidence_interval_std_plot.x1 = self.confidence_upper_std
            # The rectangle height follows the curve, which changes with the
            # standard error, so it has to be refreshed too.
            confidence_interval_std_plot.y1 = ate_height_std
            confidence_interval_plot.x0 = self.confidence_lower
            confidence_interval_plot.x1 = self.confidence_upper
            confidence_interval_plot.y1 = ate_height

            t_value_line_plot.x0 = t_value_line_plot.x1 = self.t_value
            t_value_line_plot.y0 = 0
            t_value_line_plot.y1 = ate_height_std
            ate_line_plot.x0 = ate_line_plot.x1 = self.ate
            ate_line_plot.y0 = 0
            ate_line_plot.y1 = ate_height

            fig.layout.xaxis.range = [self.min_x_std, self.max_x_std]
            fig.layout.xaxis2.range = [self.min_x, self.max_x]

    def show(self) -> DisplayHandle:
        """Alias of :meth:`display`."""
        return self.display()

    def display(self) -> DisplayHandle:
        """Render the controls and the figure; returns the result of IPython's ``display``."""
        return display(self.container)

In [2]:
# Build the viewer with its default settings and render its controls and figure.
experiment_effect_viewer = ExperimentEffectViewer()
experiment_effect_viewer.display()


VBox(children=(VBox(children=(FloatSlider(value=0.05, description='alpha: ', layout=Layout(width='80%'), max=0…