In [1]:
import numpy as np
import plotly.graph_objects as go
import os
import plotly.io as pio
from plotly.subplots import make_subplots

# Global display settings and the notebook-wide random generator.
SEED = 42  # single place to change the randomness of the simulated data

pio.templates.default = "plotly_white"
# Renderer can be overridden via the PLOTLY_RENDERER env var (e.g. "browser", "png")
# without editing the notebook; defaults to inline notebook rendering.
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")
np.set_printoptions(precision=4, suppress=True)

# Data-simulation cells draw from this generator; the permutation tests use their own seeds.
rng = np.random.default_rng(SEED)


In [2]:
def _as_1d_float_array(x, name):
    x = np.asarray(x, dtype=float).reshape(-1)
    if x.size == 0:
        raise ValueError(f"{name} is empty")
    if np.isnan(x).any():
        raise ValueError(f"{name} contains NaN")
    return x


def _permutation_p_value(perm_stats, observed, alternative="two-sided"):
    perm_stats = np.asarray(perm_stats, dtype=float).reshape(-1)
    observed = float(observed)

    if alternative == "two-sided":
        extreme = np.sum(np.abs(perm_stats) >= abs(observed))
    elif alternative == "greater":
        extreme = np.sum(perm_stats >= observed)
    elif alternative == "less":
        extreme = np.sum(perm_stats <= observed)
    else:
        raise ValueError("alternative must be one of: 'two-sided', 'greater', 'less'")

    # +1 correction to avoid returning exactly 0
    return (extreme + 1) / (perm_stats.size + 1)


def permutation_test_two_sample(
    x,
    y,
    statistic_fn=None,
    alternative="two-sided",
    n_permutations=10_000,
    seed=0,
):
    """Two-sample permutation test (independent samples).

    Under H0 the group labels are exchangeable. The observed statistic is
    compared with its distribution over random re-assignments of the pooled
    values into groups of the original sizes (Monte Carlo, not exhaustive).

    Parameters
    ----------
    x, y : array-like
        The two groups. Each is flattened to 1-D float and must be non-empty
        and NaN-free.
    statistic_fn : callable or None
        Function with signature statistic_fn(x, y) -> float.
        Defaults to difference in means (x.mean() - y.mean()).
    alternative : {'two-sided', 'greater', 'less'}
        Tail(s) used for the p-value.
    n_permutations : int
        Number of Monte Carlo permutations. Python ints and NumPy integer
        scalars are accepted; bools are rejected.
    seed : int
        Seed for ``np.random.default_rng``; fixes the permutations drawn.

    Returns
    -------
    result : dict
        Keys: 'stat_obs', 'p_value', 'perm_stats', 'alternative', 'n_permutations'.

    Raises
    ------
    ValueError
        If ``alternative`` is not recognised, ``n_permutations`` is not a
        positive integer, or ``x`` / ``y`` is empty or contains NaN.
    """
    # Check the cheap scalar arguments first, so an invalid `alternative` is
    # reported before the permutation loop runs rather than after it.
    if alternative not in ("two-sided", "greater", "less"):
        raise ValueError("alternative must be one of: 'two-sided', 'greater', 'less'")
    # bool is a subclass of int, so exclude it explicitly; NumPy integer scalars
    # (e.g. taken from an array) are valid counts.
    if (
        isinstance(n_permutations, bool)
        or not isinstance(n_permutations, (int, np.integer))
        or n_permutations <= 0
    ):
        raise ValueError("n_permutations must be a positive integer")
    n_permutations = int(n_permutations)

    x = _as_1d_float_array(x, "x")
    y = _as_1d_float_array(y, "y")

    if statistic_fn is None:
        statistic_fn = lambda a, b: a.mean() - b.mean()

    stat_obs = float(statistic_fn(x, y))

    pooled = np.concatenate([x, y])
    n_x = x.size
    n_total = pooled.size

    rng_local = np.random.default_rng(seed)

    perm_stats = np.empty(n_permutations, dtype=float)
    for i in range(n_permutations):
        # Random relabelling: the first n_x shuffled positions form the new "x" group.
        idx = rng_local.permutation(n_total)
        x_star = pooled[idx[:n_x]]
        y_star = pooled[idx[n_x:]]
        perm_stats[i] = statistic_fn(x_star, y_star)

    p_value = _permutation_p_value(perm_stats, stat_obs, alternative=alternative)

    return {
        "stat_obs": stat_obs,
        "p_value": p_value,
        "perm_stats": perm_stats,
        "alternative": alternative,
        "n_permutations": n_permutations,
    }


def permutation_test_paired_sign_flip(
    before,
    after,
    statistic_fn=None,
    alternative="two-sided",
    n_permutations=10_000,
    seed=0,
):
    """Paired permutation test using sign flips on within-pair differences.

    Under H0 (no systematic effect), the sign of each pairwise difference is
    arbitrary. The observed statistic of ``d = after - before`` is compared
    with its distribution over random sign flips of ``d`` (Monte Carlo).

    Parameters
    ----------
    before, after : array-like
        Paired measurements. Each is flattened to 1-D float, must be
        non-empty and NaN-free, and both must have the same length.
    statistic_fn : callable or None
        Function with signature statistic_fn(d) -> float applied to the
        differences. Defaults to ``np.mean``, i.e. mean(after - before).
    alternative : {'two-sided', 'greater', 'less'}
        Tail(s) used for the p-value.
    n_permutations : int
        Number of random sign-flip patterns. Python ints and NumPy integer
        scalars are accepted; bools are rejected.
    seed : int
        Seed for ``np.random.default_rng``; fixes the sign patterns drawn.

    Returns
    -------
    result : dict
        Keys: 'stat_obs', 'p_value', 'perm_stats', 'alternative', 'n_permutations'.

    Raises
    ------
    ValueError
        If ``alternative`` is not recognised, ``n_permutations`` is not a
        positive integer, an input is empty or contains NaN, or the lengths
        differ.
    """
    # Check the cheap scalar arguments first, so an invalid `alternative` is
    # reported before the sign-flip loop runs rather than after it.
    if alternative not in ("two-sided", "greater", "less"):
        raise ValueError("alternative must be one of: 'two-sided', 'greater', 'less'")
    # bool is a subclass of int, so exclude it explicitly; NumPy integer scalars are valid.
    if (
        isinstance(n_permutations, bool)
        or not isinstance(n_permutations, (int, np.integer))
        or n_permutations <= 0
    ):
        raise ValueError("n_permutations must be a positive integer")
    n_permutations = int(n_permutations)

    before = _as_1d_float_array(before, "before")
    after = _as_1d_float_array(after, "after")

    if before.size != after.size:
        raise ValueError("before and after must have the same length")

    d = after - before

    if statistic_fn is None:
        statistic_fn = np.mean

    stat_obs = float(statistic_fn(d))

    rng_local = np.random.default_rng(seed)
    sign_values = np.array([-1.0, 1.0])  # loop-invariant; same draws as before

    perm_stats = np.empty(n_permutations, dtype=float)
    for i in range(n_permutations):
        # Each pair's difference independently keeps or flips its sign.
        signs = rng_local.choice(sign_values, size=d.size)
        perm_stats[i] = statistic_fn(signs * d)

    p_value = _permutation_p_value(perm_stats, stat_obs, alternative=alternative)

    return {
        "stat_obs": stat_obs,
        "p_value": p_value,
        "perm_stats": perm_stats,
        "alternative": alternative,
        "n_permutations": n_permutations,
    }


In [3]:
# Simulated A/B outcomes: lognormal (right-skewed) values, with the treatment
# shifted by +0.4 on the log scale.
n_control, n_treatment = 35, 35

control = rng.lognormal(mean=0.0, sigma=0.8, size=n_control)
treatment = rng.lognormal(mean=0.4, sigma=0.8, size=n_treatment)

# Location summaries per group: mean vs median.
dict(
    control_mean=control.mean(),
    treatment_mean=treatment.mean(),
    control_median=np.median(control),
    treatment_median=np.median(treatment),
)


{'control_mean': 1.28446912567606,
 'treatment_mean': 1.844945693858993,
 'control_median': 1.0542446701518544,
 'treatment_median': 1.6992762958243015}

In [4]:
fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=("Raw values (skewed)", "log1p(values)"),
)

for name, values, color in (
    ("Control", control, "#1f77b4"),
    ("Treatment", treatment, "#ff7f0e"),
):
    # Each group is drawn twice: raw scale (left, shown in the legend) and
    # log1p scale (right, legend entry suppressed to avoid duplicates).
    for col, y_vals in ((1, values), (2, np.log1p(values))):
        fig.add_trace(
            go.Violin(
                y=y_vals,
                name=name,
                box_visible=True,
                meanline_visible=True,
                points="all",
                jitter=0.25,
                marker_color=color,
                legendgroup=name,
                showlegend=(col == 1),
            ),
            row=1,
            col=col,
        )

fig.update_layout(
    title="A/B data: distribution of outcomes",
    violingap=0.25,
    violinmode="group",
)
for col, axis_title in ((1, "value"), (2, "log1p(value)")):
    fig.update_yaxes(title_text=axis_title, row=1, col=col)
fig.show()


In [5]:
def stat_mean_diff(x, y):
    """Difference in sample means, x minus y."""
    return x.mean() - y.mean()


# Treatment is passed as the first group, so positive statistics mean treatment > control.
result_mean = permutation_test_two_sample(
    treatment,
    control,
    statistic_fn=stat_mean_diff,
    alternative="two-sided",
    n_permutations=20_000,
    seed=123,
)

result_mean


{'stat_obs': 0.5604765681829329,
 'p_value': 0.01879906004699765,
 'perm_stats': array([-0.3733, -0.3385,  0.252 , ...,  0.0269, -0.2484,  0.0782]),
 'alternative': 'two-sided',
 'n_permutations': 20000}

In [6]:
t_obs = result_mean["stat_obs"]
p_value = result_mean["p_value"]

alpha = 0.05  # significance level for the reject / fail-to-reject decision
{
    "observed_mean_diff": t_obs,
    "p_value": p_value,
    # Key derived from `alpha` so the label stays truthful if alpha is changed.
    f"reject_at_{alpha}": p_value <= alpha,
}


{'observed_mean_diff': 0.5604765681829329,
 'p_value': 0.01879906004699765,
 'reject_at_0.05': True}

In [7]:
t_perm = result_mean["perm_stats"]
# Two-sided criterion, matching alternative="two-sided" used for result_mean.
mask_extreme = np.abs(t_perm) >= abs(t_obs)

# With barmode="overlay", plotly bins each histogram trace independently, so the
# gray and red traces would get different bin widths and incomparable bar heights.
# A shared `bingroup` makes both traces use compatible bin settings.
fig = go.Figure()
fig.add_trace(
    go.Histogram(
        x=t_perm[~mask_extreme],
        nbinsx=60,
        bingroup="perm",
        name="not as extreme",
        marker_color="lightgray",
        opacity=0.8,
    )
)
fig.add_trace(
    go.Histogram(
        x=t_perm[mask_extreme],
        nbinsx=60,
        bingroup="perm",
        name="as / more extreme",
        marker_color="#d62728",
        opacity=0.85,
    )
)

fig.add_vline(
    x=t_obs,
    line_width=3,
    line_dash="dash",
    line_color="black",
)

fig.update_layout(
    title="Permutation (null) distribution of Δ mean",
    xaxis_title="Δ mean (treatment - control)",
    yaxis_title="count",
    barmode="overlay",
)

fig.add_annotation(
    x=t_obs,
    y=0.98,
    xref="x",
    yref="paper",
    text=f"observed Δ={t_obs:.3f}<br>p={p_value:.4f}",
    showarrow=True,
    arrowhead=2,
    ax=40,
    ay=-30,
)

fig.show()


In [8]:
# Running Monte Carlo estimate after the first k permutations, using the same
# (count + 1) / (k + 1) correction as the final p-value.
extreme = np.abs(t_perm) >= abs(t_obs)
k = np.arange(t_perm.size) + 1
p_running = (extreme.cumsum() + 1) / (k + 1)

fig = go.Figure(go.Scatter(x=k, y=p_running, mode="lines", name="running p-value"))
# Dashed reference: the final estimate using all permutations.
fig.add_hline(y=p_value, line_dash="dash", line_color="black")
fig.update_layout(
    title="Monte Carlo p-value convergence",
    xaxis_title="# permutations used",
    yaxis_title="p-value estimate",
)
fig.show()


In [9]:
def stat_median_diff(x, y):
    """Difference in sample medians, x minus y."""
    return np.median(x) - np.median(y)


# Same inputs, seed and permutation count as the mean-difference test, so both
# statistics are evaluated on identical random relabellings.
result_median = permutation_test_two_sample(
    treatment,
    control,
    statistic_fn=stat_median_diff,
    alternative="two-sided",
    n_permutations=20_000,
    seed=123,
)

{
    "mean_diff": {"stat_obs": result_mean["stat_obs"], "p_value": result_mean["p_value"]},
    "median_diff": {"stat_obs": result_median["stat_obs"], "p_value": result_median["p_value"]},
}


{'mean_diff': {'stat_obs': 0.5604765681829329, 'p_value': 0.01879906004699765},
 'median_diff': {'stat_obs': 0.6450316256724471,
  'p_value': 0.011899405029748513}}

In [10]:
labels = ["mean diff", "median diff"]
p_vals = [result["p_value"] for result in (result_mean, result_median)]

fig = go.Figure(
    go.Bar(
        x=labels,
        y=p_vals,
        text=[format(p, ".4f") for p in p_vals],
        textposition="outside",
    )
)
# Reference line at the conventional 0.05 significance level.
fig.add_hline(y=0.05, line_dash="dash", line_color="black")
fig.update_layout(
    title="p-values depend on the chosen statistic",
    yaxis_title="p-value",
)
# Leave headroom so the 0.05 line and the outside bar labels stay visible.
fig.update_yaxes(range=[0, max(0.06, 1.2 * max(p_vals))])
fig.show()


In [11]:
t_obs_med = result_median["stat_obs"]
p_med = result_median["p_value"]
t_perm_med = result_median["perm_stats"]
# Two-sided criterion, matching alternative="two-sided" used for result_median.
mask_extreme_med = np.abs(t_perm_med) >= abs(t_obs_med)

# With barmode="overlay", plotly bins each histogram trace independently; a shared
# `bingroup` gives both traces compatible bins so bar heights are comparable.
fig = go.Figure()
fig.add_trace(
    go.Histogram(
        x=t_perm_med[~mask_extreme_med],
        nbinsx=60,
        bingroup="perm",
        name="not as extreme",
        marker_color="lightgray",
        opacity=0.8,
    )
)
fig.add_trace(
    go.Histogram(
        x=t_perm_med[mask_extreme_med],
        nbinsx=60,
        bingroup="perm",
        name="as / more extreme",
        marker_color="#d62728",
        opacity=0.85,
    )
)

fig.add_vline(x=t_obs_med, line_width=3, line_dash="dash", line_color="black")

fig.update_layout(
    title="Permutation (null) distribution of Δ median",
    xaxis_title="Δ median (treatment - control)",
    yaxis_title="count",
    barmode="overlay",
)

fig.add_annotation(
    x=t_obs_med,
    y=0.98,
    xref="x",
    yref="paper",
    text=f"observed Δ={t_obs_med:.3f}<br>p={p_med:.4f}",
    showarrow=True,
    arrowhead=2,
    ax=40,
    ay=-30,
)

fig.show()


In [12]:
# Paired design: the same n units are measured before and after treatment.
n = 30
before = rng.normal(50, 10, size=n)
# Treatment tends to increase the metric by ~3 on average, but with noise
after = rng.normal(3.0, 8.0, size=n) + before

paired_result = permutation_test_paired_sign_flip(
    before=before,
    after=after,
    alternative="two-sided",
    n_permutations=20_000,
    seed=321,
)

paired_result


{'stat_obs': 2.5770377629999737,
 'p_value': 0.05504724763761812,
 'perm_stats': array([-1.2227, -0.5949, -0.2218, ...,  2.3792,  1.6762,  0.5078]),
 'alternative': 'two-sided',
 'n_permutations': 20000}

In [13]:
diff = after - before

fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=("Paired observations", "Within-pair differences"),
)

# Left panel: one marker per pair; points above the dashed y=x line have after > before.
fig.add_trace(go.Scatter(x=before, y=after, mode="markers", name="pairs"), row=1, col=1)

lo = min(before.min(), after.min())
hi = max(before.max(), after.max())
line = np.linspace(lo, hi, 100)
fig.add_trace(
    go.Scatter(x=line, y=line, mode="lines", line=dict(dash="dash"), name="y=x"),
    row=1,
    col=1,
)

# Right panel: distribution of the differences the sign-flip test operates on.
fig.add_trace(go.Histogram(x=diff, nbinsx=30, name="after - before"), row=1, col=2)

fig.update_xaxes(title_text="before", row=1, col=1)
fig.update_yaxes(title_text="after", row=1, col=1)
fig.update_xaxes(title_text="after - before", row=1, col=2)
fig.update_yaxes(title_text="count", row=1, col=2)

fig.update_layout(title="Paired data view")
fig.show()


In [14]:
t_obs_p = paired_result["stat_obs"]
p_p = paired_result["p_value"]
t_perm_p = paired_result["perm_stats"]
# Two-sided criterion, matching alternative="two-sided" used for paired_result.
mask_extreme_p = np.abs(t_perm_p) >= abs(t_obs_p)

# With barmode="overlay", plotly bins each histogram trace independently; a shared
# `bingroup` gives both traces compatible bins so bar heights are comparable.
fig = go.Figure()
fig.add_trace(
    go.Histogram(
        x=t_perm_p[~mask_extreme_p],
        nbinsx=60,
        bingroup="perm",
        name="not as extreme",
        marker_color="lightgray",
        opacity=0.8,
    )
)
fig.add_trace(
    go.Histogram(
        x=t_perm_p[mask_extreme_p],
        nbinsx=60,
        bingroup="perm",
        name="as / more extreme",
        marker_color="#d62728",
        opacity=0.85,
    )
)

fig.add_vline(x=t_obs_p, line_width=3, line_dash="dash", line_color="black")

fig.update_layout(
    title="Sign-flip permutation distribution (paired mean difference)",
    xaxis_title="mean(after - before)",
    yaxis_title="count",
    barmode="overlay",
)

fig.add_annotation(
    x=t_obs_p,
    y=0.98,
    xref="x",
    yref="paper",
    text=f"observed={t_obs_p:.3f}<br>p={p_p:.4f}",
    showarrow=True,
    arrowhead=2,
    ax=40,
    ay=-30,
)

fig.show()
