# Permutation Test (Randomization Test)

Permutation tests are **nonparametric** hypothesis tests that build a null distribution directly from your data by **shuffling labels**.
They’re especially useful when you don’t trust parametric assumptions (normality, equal variances), but you *can* assume the data are **exchangeable under the null**.

---

## Learning goals

By the end you should be able to:

- explain what a permutation test is (and what it is not)
- choose a sensible test statistic for your question
- implement a two-sample permutation test from scratch with NumPy
- interpret the permutation distribution and the p-value
- adapt the randomization scheme for **paired** data (sign-flip)

---

## Table of contents

1. What problem does it solve?
2. The core assumption: exchangeability
3. Two-sample permutation test: the recipe (NumPy from scratch)
4. Example: A/B test on skewed data (permutation distribution + p-value)
5. Choosing the statistic (mean vs median)
6. Paired designs (sign-flip permutation test)
7. Interpretation, pitfalls, and diagnostics


In [None]:
import os

import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots

pio.templates.default = "plotly_white"
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")
np.set_printoptions(precision=4, suppress=True)

rng = np.random.default_rng(42)


## 1) What problem does it solve?

You often have data like:

- **A/B tests**: did variant B increase average revenue vs A?
- **two conditions**: did a new model reduce latency vs the old one?
- **two groups**: do group 1 and group 2 differ in some outcome?

A permutation test answers:

> **If the null hypothesis were true**, how surprising is the difference (or association) we observed?

Instead of assuming a theoretical sampling distribution (like Student’s t), we **simulate the null** by repeatedly re-labeling the observed data in a way that would be valid under $H_0$.

This makes permutation tests a great fit when:

- sample sizes are small
- distributions are skewed / heavy-tailed
- you want a test for a custom statistic (median difference, trimmed mean, correlation, accuracy, …)

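Permutation tests are closely related to **randomization tests**: if treatment assignment was truly random, shuffling labels mimics the assignment mechanism, so the test is (conditionally) exact when all relabelings are enumerated, and it stays valid (slightly conservative) when a random subset of relabelings is used with the +1 correction described below.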


## 2) The core assumption: exchangeability

A permutation test is valid when the data are **exchangeable under the null**.

Two-sample case (independent groups):

- You observe two sets: $x = (x_1, \dots, x_{n_x})$ and $y = (y_1, \dots, y_{n_y})$.
- Null hypothesis (common version):

$$H_0: x \text{ and } y \text{ come from the same distribution}$$

If $H_0$ is true, then the labels “x-group” and “y-group” don’t matter: any reshuffling of the labels is just as plausible.

What can break exchangeability?

- dependence between observations (time series, clustered data)
- confounding in observational data (labels carry information not explained by chance)
- a design with blocks/strata where only certain permutations are allowed

The *randomization scheme* must match the *data collection design*.
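As a minimal sketch of a design-matched scheme, assume a hypothetical blocked experiment in which every observation carries a block id. Labels are then shuffled only within each block, never across blocks. The helper name `shuffle_labels_within_blocks` and the toy arrays are illustrative assumptions, not part of this notebook's functions.

```python
import numpy as np


def shuffle_labels_within_blocks(labels, blocks, rng):
    """Return a copy of `labels` permuted separately inside each block."""
    labels = np.asarray(labels)
    blocks = np.asarray(blocks)
    shuffled = labels.copy()
    for block in np.unique(blocks):
        idx = np.flatnonzero(blocks == block)  # positions belonging to this block
        shuffled[idx] = rng.permutation(labels[idx])
    return shuffled


rng_demo = np.random.default_rng(0)
labels = np.array(["A", "B", "A", "B", "A", "A", "B", "B"])
blocks = np.array([0, 0, 0, 0, 1, 1, 1, 1])
shuffled = shuffle_labels_within_blocks(labels, blocks, rng_demo)
print(shuffled)
```

Each block keeps its own count of `"A"` and `"B"` labels, so only relabelings that the blocked design could actually have produced are ever generated.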


## 3) Two-sample permutation test: the recipe

### Choose a test statistic

Pick a scalar statistic that answers your question, for example:

- difference in means: $T(x, y) = \bar{x} - \bar{y}$
- difference in medians
- difference in trimmed means
- KS distance, correlation, etc.

Compute the observed statistic:

$$T_\text{obs} = T(x, y)$$

### Build the permutation (null) distribution

1. Pool the values: $z = (x, y)$
2. Repeatedly **permute** $z$ and split back into two groups of sizes $n_x$ and $n_y$
3. Recompute the statistic each time: $T_1, \dots, T_B$

This yields a Monte Carlo approximation to the null distribution $T \mid H_0$.
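For the difference in means, the three steps above can also be sketched without a Python loop. The version below is an illustration under stated assumptions: it relies on `Generator.permuted` (available in NumPy 1.20 and later) to shuffle each row independently, and the function name `perm_mean_diffs_vectorized` and the toy data are made up for this example.

```python
import numpy as np


def perm_mean_diffs_vectorized(x, y, n_permutations=5_000, seed=0):
    """Monte Carlo null distribution of mean(x*) - mean(y*), built row-wise."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    pooled = np.concatenate([x, y])
    rng = np.random.default_rng(seed)
    # One row per permutation; each row is shuffled independently along axis 1.
    tiled = np.tile(pooled, (n_permutations, 1))
    shuffled = rng.permuted(tiled, axis=1)
    # First n_x columns play the role of x*, the remaining columns of y*.
    return shuffled[:, : x.size].mean(axis=1) - shuffled[:, x.size :].mean(axis=1)


x_demo = np.array([2.1, 3.4, 1.9, 5.0])
y_demo = np.array([1.0, 0.7, 2.2, 1.5, 0.9])
null_demo = perm_mean_diffs_vectorized(x_demo, y_demo, n_permutations=2_000, seed=1)
print(null_demo.shape, null_demo.mean())
```

The trade-off is memory: the intermediate array has `n_permutations × (n_x + n_y)` entries, so for large samples a loop or batched version may be the safer choice.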

### Turn it into a p-value

- two-sided: $p \approx P(|T| \ge |T_\text{obs}| \mid H_0)$
- greater: $p \approx P(T \ge T_\text{obs} \mid H_0)$
- less: $p \approx P(T \le T_\text{obs} \mid H_0)$

A common finite-sample correction counts the observed labeling as one more permutation, which avoids returning exactly 0:

$$\hat p = \frac{\#\{\text{as-extreme permutations}\} + 1}{B + 1}$$

Note: with $B$ permutations, the smallest possible p-value is $1/(B+1)$.
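When the samples are tiny, the Monte Carlo step can be replaced by full enumeration. The sketch below is a hedged illustration (the function name `exact_two_sample_p_value` and the toy numbers are assumptions for this example): it visits every way of choosing which pooled values form the first group and computes the exact two-sided p-value for the difference in means. No +1 correction is applied because the observed labeling is itself one of the enumerated splits.

```python
from itertools import combinations
from math import comb

import numpy as np


def exact_two_sample_p_value(x, y):
    """Exact two-sided p-value for the difference in means via full enumeration."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    pooled = np.concatenate([x, y])
    n_x, n_total = x.size, pooled.size
    t_obs = x.mean() - y.mean()

    total = pooled.sum()
    n_extreme = 0
    for idx in combinations(range(n_total), n_x):
        sum_x = pooled[list(idx)].sum()
        t = sum_x / n_x - (total - sum_x) / (n_total - n_x)
        # Small tolerance so relabelings that tie with t_obs count despite rounding.
        if abs(t) >= abs(t_obs) - 1e-12:
            n_extreme += 1
    return n_extreme / comb(n_total, n_x)


# 20 possible splits; only the observed split and its mirror image are as extreme.
p_exact = exact_two_sample_p_value([12.0, 15.0, 14.0], [8.0, 9.0, 7.0])
print(p_exact)  # 2 / 20 = 0.1
```

The number of splits grows as $\binom{n_x + n_y}{n_x}$, which is why Monte Carlo sampling is used for anything beyond very small samples.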


In [None]:
def _as_1d_float_array(x, name):
    x = np.asarray(x, dtype=float).reshape(-1)
    if x.size == 0:
        raise ValueError(f"{name} is empty")
    if np.isnan(x).any():
        raise ValueError(f"{name} contains NaN")
    return x


def _permutation_p_value(perm_stats, observed, alternative="two-sided"):
    perm_stats = np.asarray(perm_stats, dtype=float).reshape(-1)
    observed = float(observed)

    if alternative == "two-sided":
        extreme = np.sum(np.abs(perm_stats) >= abs(observed))
    elif alternative == "greater":
        extreme = np.sum(perm_stats >= observed)
    elif alternative == "less":
        extreme = np.sum(perm_stats <= observed)
    else:
        raise ValueError("alternative must be one of: 'two-sided', 'greater', 'less'")

    # +1 correction to avoid returning exactly 0
    return (extreme + 1) / (perm_stats.size + 1)


def permutation_test_two_sample(
    x,
    y,
    statistic_fn=None,
    alternative="two-sided",
    n_permutations=10_000,
    seed=0,
):
    """Two-sample permutation test (independent samples).

    Parameters
    ----------
    x, y : array-like
        The two groups.
    statistic_fn : callable or None
        Function with signature statistic_fn(x, y) -> float.
        Defaults to difference in means (x.mean() - y.mean()).
    alternative : {'two-sided', 'greater', 'less'}
        Tail(s) used for the p-value.
    n_permutations : int
        Number of Monte Carlo permutations.
    seed : int
        RNG seed for reproducibility.

    Returns
    -------
    result : dict
        Keys: 'stat_obs', 'p_value', 'perm_stats', 'alternative', 'n_permutations'.
    """

    x = _as_1d_float_array(x, "x")
    y = _as_1d_float_array(y, "y")

    if not isinstance(n_permutations, int) or n_permutations <= 0:
        raise ValueError("n_permutations must be a positive integer")

    if statistic_fn is None:
        statistic_fn = lambda a, b: a.mean() - b.mean()

    stat_obs = float(statistic_fn(x, y))

    pooled = np.concatenate([x, y])
    n_x = x.size
    n_total = pooled.size

    rng_local = np.random.default_rng(seed)

    perm_stats = np.empty(n_permutations, dtype=float)
    for i in range(n_permutations):
        idx = rng_local.permutation(n_total)
        x_star = pooled[idx[:n_x]]
        y_star = pooled[idx[n_x:]]
        perm_stats[i] = statistic_fn(x_star, y_star)

    p_value = _permutation_p_value(perm_stats, stat_obs, alternative=alternative)

    return {
        "stat_obs": stat_obs,
        "p_value": p_value,
        "perm_stats": perm_stats,
        "alternative": alternative,
        "n_permutations": n_permutations,
    }


def permutation_test_paired_sign_flip(
    before,
    after,
    statistic_fn=None,
    alternative="two-sided",
    n_permutations=10_000,
    seed=0,
):
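    """Paired permutation test using sign flips on within-pair differences.

    Under H0 (within-pair differences symmetric about zero), each difference
    is equally likely to be positive or negative, so its sign can be flipped.

    Note that statistic_fn takes a single 1-D array of differences
    (unlike the two-sample version). Default statistic: mean(after - before).
    """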

    before = _as_1d_float_array(before, "before")
    after = _as_1d_float_array(after, "after")

    if before.size != after.size:
        raise ValueError("before and after must have the same length")

    if not isinstance(n_permutations, int) or n_permutations <= 0:
        raise ValueError("n_permutations must be a positive integer")

    d = after - before

    if statistic_fn is None:
        statistic_fn = np.mean

    stat_obs = float(statistic_fn(d))

    rng_local = np.random.default_rng(seed)

    perm_stats = np.empty(n_permutations, dtype=float)
    for i in range(n_permutations):
        signs = rng_local.choice(np.array([-1.0, 1.0]), size=d.size)
        perm_stats[i] = statistic_fn(signs * d)

    p_value = _permutation_p_value(perm_stats, stat_obs, alternative=alternative)

    return {
        "stat_obs": stat_obs,
        "p_value": p_value,
        "perm_stats": perm_stats,
        "alternative": alternative,
        "n_permutations": n_permutations,
    }


## 4) Example: A/B test on skewed data

Imagine an A/B test where the metric is **user spend**.

- Spend is usually **right-skewed** (many small values, a few huge ones).
- You might still care about **average** spend, but you may not want to assume normality.

We’ll simulate two groups (with a true positive treatment effect) and run a two-sided test for a difference in means.


In [None]:
n_control = 35
n_treatment = 35

control = rng.lognormal(mean=0.0, sigma=0.8, size=n_control)
treatment = rng.lognormal(mean=0.4, sigma=0.8, size=n_treatment)

{
    "control_mean": control.mean(),
    "treatment_mean": treatment.mean(),
    "control_median": np.median(control),
    "treatment_median": np.median(treatment),
}


In [None]:
fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=("Raw values (skewed)", "log1p(values)"),
)

for name, values, color in [
    ("Control", control, "#1f77b4"),
    ("Treatment", treatment, "#ff7f0e"),
]:
    fig.add_trace(
        go.Violin(
            y=values,
            name=name,
            box_visible=True,
            meanline_visible=True,
            points="all",
            jitter=0.25,
            marker_color=color,
            legendgroup=name,
            showlegend=True,
        ),
        row=1,
        col=1,
    )
    fig.add_trace(
        go.Violin(
            y=np.log1p(values),
            name=name,
            box_visible=True,
            meanline_visible=True,
            points="all",
            jitter=0.25,
            marker_color=color,
            legendgroup=name,
            showlegend=False,
        ),
        row=1,
        col=2,
    )

fig.update_layout(
    title="A/B data: distribution of outcomes",
    violingap=0.25,
    violinmode="group",
)
fig.update_yaxes(title_text="value", row=1, col=1)
fig.update_yaxes(title_text="log1p(value)", row=1, col=2)
fig

### Run the permutation test (difference in means)

We’ll use:

$$T(x, y) = \bar{x} - \bar{y}$$

Here we interpret $x$ as **treatment** and $y$ as **control**, so a positive statistic means treatment has a higher mean.


In [None]:
stat_mean_diff = lambda x, y: x.mean() - y.mean()

result_mean = permutation_test_two_sample(
    treatment,
    control,
    statistic_fn=stat_mean_diff,
    alternative="two-sided",
    n_permutations=20_000,
    seed=123,
)

result_mean


In [None]:
t_obs = result_mean["stat_obs"]
p_value = result_mean["p_value"]

alpha = 0.05
{
    "observed_mean_diff": t_obs,
    "p_value": p_value,
    "reject_at_0.05": p_value <= alpha,
}


In [None]:
t_perm = result_mean["perm_stats"]
mask_extreme = np.abs(t_perm) >= abs(t_obs)

fig = go.Figure()
fig.add_trace(
    go.Histogram(
        x=t_perm[~mask_extreme],
        nbinsx=60,
        name="not as extreme",
        marker_color="lightgray",
        opacity=0.8,
    )
)
fig.add_trace(
    go.Histogram(
        x=t_perm[mask_extreme],
        nbinsx=60,
        name="as / more extreme",
        marker_color="#d62728",
        opacity=0.85,
    )
)

fig.add_vline(
    x=t_obs,
    line_width=3,
    line_dash="dash",
    line_color="black",
)

fig.update_layout(
    title="Permutation (null) distribution of Δ mean",
    xaxis_title="Δ mean (treatment - control)",
    yaxis_title="count",
    barmode="overlay",
)

fig.add_annotation(
    x=t_obs,
    y=0.98,
    xref="x",
    yref="paper",
    text=f"observed Δ={t_obs:.3f}<br>p={p_value:.4f}",
    showarrow=True,
    arrowhead=2,
    ax=40,
    ay=-30,
)

fig

In [None]:
extreme = np.abs(t_perm) >= abs(t_obs)

k = np.arange(1, t_perm.size + 1)
# +1 / (k+1) is the same finite-sample correction used for the final p-value
p_running = (np.cumsum(extreme) + 1) / (k + 1)

fig = go.Figure()
fig.add_trace(go.Scatter(x=k, y=p_running, mode="lines", name="running p-value"))
fig.add_hline(y=p_value, line_dash="dash", line_color="black")

fig.update_layout(
    title="Monte Carlo p-value convergence",
    xaxis_title="# permutations used",
    yaxis_title="p-value estimate",
)
fig

## 5) Choosing the statistic: mean vs median

Permutation tests let you choose *any* statistic — which is powerful, but also a responsibility.

- The **mean** is sensitive to outliers (which might be exactly what you care about for revenue).
- The **median** is more robust (focuses on the “typical” user).

Different questions → different statistics.
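The same freedom extends beyond the mean and the median. As a hedged sketch, a trimmed-mean difference sits between the two: it drops a fixed share of each tail before averaging. The helper names `trimmed_mean` and `stat_trimmed_mean_diff`, the 10% trimming proportion, and the toy sample are assumptions chosen for illustration.

```python
import numpy as np


def trimmed_mean(values, proportion=0.1):
    """Mean after dropping the lowest and highest `proportion` of the values.

    Assumes 0 <= proportion < 0.5 so that at least one value remains.
    """
    values = np.sort(np.asarray(values, dtype=float))
    k = int(proportion * values.size)  # number of values trimmed from each tail
    return values[k : values.size - k].mean()


def stat_trimmed_mean_diff(x, y):
    """Two-argument statistic: difference in 10% trimmed means."""
    return trimmed_mean(x) - trimmed_mean(y)


sample = np.array([1.0, 2.0, 3.0, 4.0, 100.0, 2.5, 3.5, 1.5, 2.2, 3.1])
plain = sample.mean()
robust = trimmed_mean(sample, proportion=0.1)
print(plain, robust)
```

Because `stat_trimmed_mean_diff` takes `(x, y)` and returns a scalar, it can be passed as `statistic_fn` to `permutation_test_two_sample` defined earlier in this notebook.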

Below we compute the mean and median versions, using the same permutation idea.


In [None]:
stat_median_diff = lambda x, y: np.median(x) - np.median(y)

result_median = permutation_test_two_sample(
    treatment,
    control,
    statistic_fn=stat_median_diff,
    alternative="two-sided",
    n_permutations=20_000,
    seed=123,
)

{
    "mean_diff": {"stat_obs": result_mean["stat_obs"], "p_value": result_mean["p_value"]},
    "median_diff": {"stat_obs": result_median["stat_obs"], "p_value": result_median["p_value"]},
}


In [None]:
labels = ["mean diff", "median diff"]
p_vals = [result_mean["p_value"], result_median["p_value"]]

fig = go.Figure(
    go.Bar(
        x=labels,
        y=p_vals,
        text=[f"{p:.4f}" for p in p_vals],
        textposition="outside",
    )
)
fig.add_hline(y=0.05, line_dash="dash", line_color="black")
fig.update_layout(
    title="p-values depend on the chosen statistic",
    yaxis_title="p-value",
)
fig.update_yaxes(range=[0, max(0.06, max(p_vals) * 1.2)])
fig

In [None]:
t_obs_med = result_median["stat_obs"]
p_med = result_median["p_value"]
t_perm_med = result_median["perm_stats"]
mask_extreme_med = np.abs(t_perm_med) >= abs(t_obs_med)

fig = go.Figure()
fig.add_trace(
    go.Histogram(
        x=t_perm_med[~mask_extreme_med],
        nbinsx=60,
        name="not as extreme",
        marker_color="lightgray",
        opacity=0.8,
    )
)
fig.add_trace(
    go.Histogram(
        x=t_perm_med[mask_extreme_med],
        nbinsx=60,
        name="as / more extreme",
        marker_color="#d62728",
        opacity=0.85,
    )
)

fig.add_vline(x=t_obs_med, line_width=3, line_dash="dash", line_color="black")

fig.update_layout(
    title="Permutation (null) distribution of Δ median",
    xaxis_title="Δ median (treatment - control)",
    yaxis_title="count",
    barmode="overlay",
)

fig.add_annotation(
    x=t_obs_med,
    y=0.98,
    xref="x",
    yref="paper",
    text=f"observed Δ={t_obs_med:.3f}<br>p={p_med:.4f}",
    showarrow=True,
    arrowhead=2,
    ax=40,
    ay=-30,
)

fig

## 6) Paired designs: sign-flip permutation test

If your data are **paired** (same user before/after, matched pairs, repeated measures), you can’t shuffle labels across all observations.

A common paired null is:

$$H_0: \text{the treatment has no systematic effect}$$

Let differences be $d_i = \text{after}_i - \text{before}_i$.

Under $H_0$, the differences are symmetric about zero, so each $d_i$ is equally likely to be positive or negative given its magnitude. We therefore build the null distribution by **randomly flipping signs**:

- draw random signs $s_i \in \{-1, +1\}$
- compute $T(s \odot d)$, e.g. the mean

This scheme keeps every pair intact and randomizes only the direction of each difference, which is exactly what the paired design allows under $H_0$.
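For a handful of pairs, all $2^n$ sign patterns can be enumerated instead of sampled. The following is a minimal sketch under that assumption (small $n$); the function name `exact_sign_flip_p_value` and the toy differences are illustrative only.

```python
from itertools import product

import numpy as np


def exact_sign_flip_p_value(d):
    """Exact two-sided sign-flip p-value for the mean of paired differences."""
    d = np.asarray(d, dtype=float)
    t_obs = d.mean()
    # All 2**n sign patterns, one per row (only feasible for small n).
    signs = np.array(list(product([-1.0, 1.0], repeat=d.size)))
    t_all = (signs * d).mean(axis=1)
    # Small tolerance so patterns that tie with t_obs count despite rounding.
    n_extreme = np.sum(np.abs(t_all) >= abs(t_obs) - 1e-12)
    return n_extreme / t_all.size


d_demo = np.array([1.0, 2.0, 3.0, 4.0])
p_exact_paired = exact_sign_flip_p_value(d_demo)
print(p_exact_paired)  # 2 / 16 = 0.125 (all-plus and all-minus patterns)
```

With four pairs that all point the same way, the smallest achievable two-sided p-value is already 0.125, which illustrates how little resolution a paired test has at very small $n$.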


In [None]:
n = 30
before = rng.normal(50, 10, size=n)
# Treatment tends to increase the metric by ~3 on average, but with noise
after = before + rng.normal(3.0, 8.0, size=n)

paired_result = permutation_test_paired_sign_flip(
    before,
    after,
    alternative="two-sided",
    n_permutations=20_000,
    seed=321,
)

paired_result


In [None]:
diff = after - before

fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=("Paired observations", "Within-pair differences"),
)

fig.add_trace(
    go.Scatter(x=before, y=after, mode="markers", name="pairs"),
    row=1,
    col=1,
)

lo = min(before.min(), after.min())
hi = max(before.max(), after.max())
line = np.linspace(lo, hi, 100)
fig.add_trace(
    go.Scatter(x=line, y=line, mode="lines", line=dict(dash="dash"), name="y=x"),
    row=1,
    col=1,
)

fig.add_trace(go.Histogram(x=diff, nbinsx=30, name="after - before"), row=1, col=2)

fig.update_xaxes(title_text="before", row=1, col=1)
fig.update_yaxes(title_text="after", row=1, col=1)
fig.update_xaxes(title_text="after - before", row=1, col=2)
fig.update_yaxes(title_text="count", row=1, col=2)

fig.update_layout(title="Paired data view")
fig

In [None]:
t_obs_p = paired_result["stat_obs"]
p_p = paired_result["p_value"]
t_perm_p = paired_result["perm_stats"]
mask_extreme_p = np.abs(t_perm_p) >= abs(t_obs_p)

fig = go.Figure()
fig.add_trace(
    go.Histogram(
        x=t_perm_p[~mask_extreme_p],
        nbinsx=60,
        name="not as extreme",
        marker_color="lightgray",
        opacity=0.8,
    )
)
fig.add_trace(
    go.Histogram(
        x=t_perm_p[mask_extreme_p],
        nbinsx=60,
        name="as / more extreme",
        marker_color="#d62728",
        opacity=0.85,
    )
)

fig.add_vline(x=t_obs_p, line_width=3, line_dash="dash", line_color="black")

fig.update_layout(
    title="Sign-flip permutation distribution (paired mean difference)",
    xaxis_title="mean(after - before)",
    yaxis_title="count",
    barmode="overlay",
)

fig.add_annotation(
    x=t_obs_p,
    y=0.98,
    xref="x",
    yref="paper",
    text=f"observed={t_obs_p:.3f}<br>p={p_p:.4f}",
    showarrow=True,
    arrowhead=2,
    ax=40,
    ay=-30,
)

fig

## 7) Interpretation: what the result means

A permutation test gives you:

- an **observed statistic** (e.g., $\bar{x} - \bar{y}$)
- a **null distribution** of that statistic (generated by re-labeling)
- a **p-value**: how often a null world produces a statistic at least as extreme

### What the p-value means

If $p = 0.02$ (two-sided), a precise reading is:

> *Assuming the null hypothesis and the permutation scheme are valid*, only about **2%** of random labelings would produce a difference at least as extreme as the one we observed.

### What the p-value does NOT mean

- it is **not** $P(H_0 \mid \text{data})$
- it is **not** the probability the result happened “by chance” in some vague sense
- it does **not** measure effect size (always report the effect itself)

### Decision rule

Choose a significance level $\alpha$ (commonly 0.05).

- if $p \le \alpha$: reject $H_0$ (evidence against the null)
- if $p > \alpha$: fail to reject $H_0$ (not enough evidence)

Failing to reject is not the same as “proving no effect” — it may just mean the test is underpowered.


## Pitfalls + diagnostics

- **Match the permutation to the design**: paired data needs sign-flips (or within-pair swaps); blocked experiments need restricted permutations.
- **Exchangeability is the real assumption**: permutation tests aren’t automatically valid for observational data with confounding.
- **Pick the statistic before peeking**: changing the statistic after seeing the data is p-hacking.
- **Monte Carlo error**: two runs with different seeds can give slightly different p-values; increase `n_permutations` to reduce noise.
- **p-value resolution**: with `B` permutations, the smallest possible p-value is `1/(B+1)`.
- **Report more than p**: include the observed effect (and ideally uncertainty / practical significance).
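
To put a rough number on the Monte Carlo error mentioned above, one can treat the count of as-extreme permutations as binomial, which gives a standard error of about $\sqrt{\hat p (1 - \hat p) / B}$. The sketch below applies that approximation; the function name `mc_p_value_standard_error` and the example value $\hat p = 0.05$ are assumptions for illustration.

```python
import math


def mc_p_value_standard_error(p_hat, n_permutations):
    """Approximate binomial standard error of a Monte Carlo p-value estimate."""
    return math.sqrt(p_hat * (1.0 - p_hat) / n_permutations)


# How the noise around p_hat = 0.05 shrinks as the number of permutations grows.
for b in (1_000, 10_000, 100_000):
    se = mc_p_value_standard_error(0.05, b)
    print(f"B={b:>7,d}  SE ~ {se:.4f}")
```

Since the error shrinks with $\sqrt{B}$, a tenfold increase in `n_permutations` cuts the noise by only about a factor of three, which matters most when the p-value sits close to $\alpha$.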


## Exercises

1. Implement a **permutation test for correlation**: keep `x` fixed and permute `y`, using $T=\mathrm{corr}(x,y)$.
2. Implement a **stratified** two-sample permutation test where labels are only permuted *within* strata.
3. Compare mean-difference permutation tests on:
   - normal data
   - heavy-tailed data
   What changes?
4. Increase `n_permutations` and plot how the running p-value stabilizes.


## References

- Fisher: *The Design of Experiments* (randomization tests)
- Good: *Permutation, Parametric and Bootstrap Tests of Hypotheses*
- Ernst (2004): “Permutation Methods: A Basis for Exact Inference”
- Manly: *Randomization, Bootstrap and Monte Carlo Methods in Biology*
- Efron & Tibshirani: *An Introduction to the Bootstrap* (related resampling perspective)
