In [1]:
from IPython.core.interactiveshell import InteractiveShell

InteractiveShell.ast_node_interactivity = "all"

%reload_ext autoreload
%autoreload 2

import time
import tracemalloc
from typing import Dict, List

import numpy as np
import polars as pl
import polars.selectors as cs
from wsri import somers_d_roc_auc, somers_d_two_pointers, wsri

## Benchmark vs. Monitoring Setup

* **Benchmark**: A fixed reference period (e.g., 2023 data) that we treat as the "gold standard" or baseline distribution of scores and outcomes.
* **Monitoring**: A newer period (e.g., Q1 2024 data) that we want to evaluate for stability relative to the benchmark.

The purpose of **wSRI** here is to ask: *if the score-label relationship in monitoring looks different from benchmark, how much of that difference reflects population shift rather than true deterioration in model rank-ordering?*

---

### Why Weight the Benchmark

Weighting only the **benchmark** population makes sense because:

* **Monitoring stays as is**: We want monitoring data to be evaluated unweighted, because it reflects the actual, current population distribution.
* **Benchmark rescaled**: By reweighting benchmark bins to match the marginal distribution of scores in monitoring, we make the benchmark answer the question "what would 2023 have looked like under the 2024 score distribution?"
* **Apples-to-apples comparability**: Without reweighting, Somers’ D differences could be confounded by the monitoring period simply having more or fewer accounts in certain score ranges, not by any change in the rank-order power of the model.

So the weighting isolates the effect of **conditional rank performance** (score <-> label relationship) rather than **population score mix**.

---

### Why Not Weight Monitoring

If we weighted **monitoring** to benchmark:

* We’d be discarding real, current information about the new population distribution.
* The monitoring Somers’ D would be artificially “distorted back” to the 2023 distribution, making it less diagnostic for drift.
* We’d risk underestimating deterioration if population shift itself is part of what we want to detect.

---

### Covariate Shift

Formally, weighting only benchmark approximates a **covariate-shift correction** problem:

* Assume the conditional outcome distribution *P(y|score)* is stable across both periods (this is the hypothesis we want to test).
* But the marginal score distribution *P(score)* has shifted.
* Reweighting benchmark to monitoring ensures we compare *P(y|score)* fairly across periods.

This is consistent with the idea of “importance weighting” in domain adaptation:

$$
w(x) = \frac{P_{\text{monitoring}}(x)}{P_{\text{benchmark}}(x)}
$$

where here *x* = score bins.
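
The weight formula above can be sketched in NumPy as follows. This is a minimal illustration, not the `wsri` package's implementation: the helper name `bin_importance_weights` is hypothetical, and using benchmark quantiles as bin edges is an assumption (the notebook passes `quantiles=20`, but the package's actual binning is not shown here).

```python
import numpy as np


def bin_importance_weights(benchmark_scores, monitoring_scores, n_bins=20):
    """Per-row benchmark weights w(x) = P_monitoring(x) / P_benchmark(x) over score bins."""
    benchmark_scores = np.asarray(benchmark_scores, dtype=np.float64)
    monitoring_scores = np.asarray(monitoring_scores, dtype=np.float64)

    # Assumption: interior bin edges are benchmark quantiles.
    inner_edges = np.quantile(
        benchmark_scores, np.linspace(0.0, 1.0, n_bins + 1)[1:-1]
    )

    # Map every score to a bin id in [0, n_bins - 1].
    benchmark_bins = np.searchsorted(inner_edges, benchmark_scores, side="right")
    monitoring_bins = np.searchsorted(inner_edges, monitoring_scores, side="right")

    # Empirical bin probabilities in each period.
    p_benchmark = np.bincount(benchmark_bins, minlength=n_bins) / benchmark_scores.size
    p_monitoring = np.bincount(monitoring_bins, minlength=n_bins) / monitoring_scores.size

    # Bins that are empty in the benchmark get ratio 0 instead of a division by zero.
    ratio = np.divide(
        p_monitoring, p_benchmark, out=np.zeros(n_bins), where=p_benchmark > 0
    )

    # Broadcast the per-bin ratio back to one weight per benchmark row.
    return ratio[benchmark_bins]


rng = np.random.default_rng(0)
benchmark_scores = rng.uniform(0, 1, 10_000)
monitoring_scores = rng.beta(2, 1, 10_000)  # shifted toward higher scores
weights = bin_importance_weights(benchmark_scores, monitoring_scores)
```

When every bin is populated in the benchmark, these weights average to 1 over the benchmark rows, so the reweighting changes the score mix without changing the effective sample mass.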

In short, we weight the **benchmark** rather than the monitoring data in order to reshape the reference into the monitoring population’s score distribution, which makes the two periods comparable. Monitoring stays unweighted so it reflects reality. The ratio (monitoring Somers’ D / reweighted benchmark Somers’ D) then isolates *rank-order stability*, not population differences.


## Weighted Somers’ $D$ via Two-Pointer Sweep

### Variables

#### Inputs

* `scores`: 1D array of predicted scores; shape $(n,)$. Element $s_i \in \mathbb{R}$.
* `labels`: 1D array of binary ground-truth labels; shape $(n,)$. Element $y_i \in \{0,1\}$.
* `weights`: 1D array of non-negative sample weights; shape $(n,)$. Element $w_i \ge 0$.

#### Derived arrays (post-sort, contiguous for Numba)

* `sorted_indices`: permutation that sorts `scores` in non-decreasing order.
* `scores_array`: `scores[sorted_indices]`; ascending.
* `labels_array`: `labels[sorted_indices]`; aligned to `scores_array`.
* `weights_array`: `weights[sorted_indices]`; aligned to `scores_array`.

#### Scalar counts / masses

* `n_total`: integer $= n$, total number of observations.
* `n_pos`: float $= \sum_{i: y_i=1} w_i$, total **positive weight** (“mass”).
* `n_neg`: float $= \sum_{i: y_i=0} w_i$, total **negative weight** (“mass”).

#### Notations

* $P := \{j : y_j = 1\}$ (positive indices), $N := \{i : y_i = 0\}$ (negative indices).
* For an anchor threshold $t$:

  * $D(t) := \{j \in P : s_j < t\}$ (positives **below** $t$, discordant w.r.t. a negative at $t$).
  * $C(t) := \{j \in P : s_j > t\}$ (positives **above** $t$, concordant w.r.t. a negative at $t$).
* $\mathbb{1}[\cdot]$: indicator function (1 if condition holds, else 0).


#### Sweep-state (maintained during the pass)

* `discordant_contrib`: float; current $\sum_{j \in P: s_j < t} w_j$, the total positive mass **strictly below** the current anchor threshold $t$.
* `concordant_contrib`: float; current $\sum_{j \in P: s_j > t} w_j$, the total positive mass **strictly above** $t$.
* `numerator`: float; accumulates $\sum_{i \in N} w_i \left(\text{concordant\_contrib} - \text{discordant\_contrib}\right)$.

#### Pointers / indices

* `discordant_index`: integer pointer $r$; smallest index with $s_r \ge t$. All positive indices $< r$ have $s_j < t$ and are counted in `discordant_contrib`.
* `concordant_index`: integer pointer $q$; smallest index with $s_q > t$. All positive indices $\le q-1$ have $s_j \le t$ and have been **removed** from the concordant pool.
* `current_false_score`: float $t$ = score of the current **negative** anchor; i.e., the threshold relative to which positives are partitioned.

---

### Objective

Weighted Somers’ $D$ for binary outcomes is

$$
D \;=\;
\frac{\displaystyle \sum_{i \in N} \sum_{j \in P} w_i w_j \Big(\mathbb{1}[s_j > s_i] - \mathbb{1}[s_j < s_i]\Big)}
{\displaystyle \sum_{i \in N} \sum_{j \in P} w_i w_j}
\;\in [-1,1]
$$

Equivalently, if we anchor on a negative with threshold $t = s_i$, its contribution is

$$
w_i\left(
\sum_{j \in P: s_j > t} w_j
\;-\;
\sum_{j \in P: s_j < t} w_j
\right)
\;=\; w_i\big(\text{concordant mass} - \text{discordant mass}\big)
$$

Ties receive zero contribution because neither $s_j > t$ nor $s_j < t$ is true when $s_j = t$.
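
As a reference point for the definition above, here is a direct $O(|N| \cdot |P|)$ evaluation of the double sum. It is a sketch for checking small inputs only; the function name `somers_d_bruteforce` is hypothetical and not part of the `wsri` package.

```python
import numpy as np


def somers_d_bruteforce(scores, labels, weights):
    """Weighted Somers' D by explicit enumeration of all negative/positive pairs."""
    scores = np.asarray(scores, dtype=np.float64)
    labels = np.asarray(labels)
    weights = np.asarray(weights, dtype=np.float64)

    neg = labels == 0
    pos = labels == 1
    s_neg, w_neg = scores[neg], weights[neg]
    s_pos, w_pos = scores[pos], weights[pos]

    denominator = w_neg.sum() * w_pos.sum()
    if denominator == 0.0:
        return 0.0  # no cross-class pairs

    # sign[i, j] is +1 if positive j scores above negative i, -1 if below, 0 on a tie.
    sign = np.sign(s_pos[None, :] - s_neg[:, None])
    numerator = (w_neg[:, None] * w_pos[None, :] * sign).sum()
    return float(numerator / denominator)


# Two negatives (scores 0.1 and 0.4) and two positives (scores 0.4 and 0.8).
# Concordant pair weights: 1*2 + 1*1 + 1*1 = 4; the tied pair adds 0; denominator is 2 * 3 = 6.
d_example = somers_d_bruteforce(
    scores=[0.1, 0.4, 0.4, 0.8],
    labels=[0, 1, 0, 1],
    weights=[1.0, 2.0, 1.0, 1.0],
)
```

The pairwise matrix makes memory grow with $|N| \cdot |P|$, which is why this form is only practical as a test oracle.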

---

### Algorithm (two-pointer sweep)

1. **Sort once.** Compute `sorted_indices = argsort(scores)` and build contiguous `scores_array`, `labels_array`, `weights_array`.

2. **Accumulate class masses.**

   * $n_{\text{pos}} = \sum_{y_i=1} w_i$, $n_{\text{neg}} = \sum_{y_i=0} w_i$.
   * If $n_{\text{pos}}=0$ or $n_{\text{neg}}=0$, return $0.0$ (no cross-class pairs).

3. **Initialize sweep state.**

   * `discordant_contrib = 0.0` (no positives yet confirmed below $t$).
   * `concordant_contrib = n_pos` (initially, all positives are potentially $> t$).
   * `discordant_index = 0`, `concordant_index = 0`, `numerator = 0.0`.

4. **Single forward pass.** For each index $i = 0,\dots,n-1$ in score-sorted order, skip positives. If `labels_array[i] == 0` (a **negative** anchor), do the following:

   * Set $t = \text{current\_false\_score} = \text{scores\_array}[i]$.
   * **Advance `discordant_index`** while $\text{scores\_array}[r] < t$:

     * If $\text{labels\_array}[r] == 1$, add $\text{weights\_array}[r]$ to `discordant_contrib`.
     * Increment $r=$ `discordant_index`.
     * Result: `discordant_contrib` $= \sum_{j \in P: s_j < t} w_j$.
   * **Advance `concordant_index`** while $q < n$ and $\text{scores\_array}[q] \le t$ (the bound check matters when the anchor has the largest score):

     * If $\text{labels\_array}[q] == 1$, subtract $\text{weights\_array}[q]$ from `concordant_contrib`.
     * Increment $q=$ `concordant_index`.
     * Result: `concordant_contrib` $= \sum_{j \in P: s_j > t} w_j$.
   * **Accumulate anchor contribution:**

     $$
     \text{numerator} \;{+}{=}\; \text{weights\_array}[i]\,
     \big(\text{concordant\_contrib} - \text{discordant\_contrib}\big).
     $$

5. **Normalize.**

$$
\frac{\text{numerator}}{n_{\text{neg}} \cdot n_{\text{pos}}}
\;=\;
\frac{\sum_{i \in N} w_i \big(\sum_{j \in P: s_j > s_i} w_j - \sum_{j \in P: s_j < s_i} w_j\big)}
{\left(\sum_{i \in N} w_i\right)\left(\sum_{j \in P} w_j\right)}
$$
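
The five steps above can be sketched in plain Python as follows. This is an illustrative version under stated assumptions, not the package's `somers_d_two_pointers`: the function name `somers_d_sweep` is hypothetical, and the loop is left uncompiled for readability, so it will be far slower than a JIT-compiled version.

```python
import numpy as np


def somers_d_sweep(scores, labels, weights):
    """Weighted Somers' D via one sort and a single two-pointer pass."""
    scores = np.asarray(scores, dtype=np.float64)
    labels = np.asarray(labels)
    weights = np.asarray(weights, dtype=np.float64)

    # Step 1: sort once and align labels and weights to the sorted scores.
    order = np.argsort(scores, kind="stable")
    s, y, w = scores[order], labels[order], weights[order]
    n = s.size

    # Step 2: class masses; no cross-class pairs means D is defined as 0.0 here.
    n_pos = float(w[y == 1].sum())
    n_neg = float(w[y == 0].sum())
    if n_pos == 0.0 or n_neg == 0.0:
        return 0.0

    # Step 3: initial sweep state.
    discordant_contrib = 0.0
    concordant_contrib = n_pos
    discordant_index = 0
    concordant_index = 0
    numerator = 0.0

    # Step 4: single forward pass over negative anchors.
    for i in range(n):
        if y[i] != 0:
            continue
        t = s[i]
        # Positives strictly below t become discordant.
        while discordant_index < n and s[discordant_index] < t:
            if y[discordant_index] == 1:
                discordant_contrib += w[discordant_index]
            discordant_index += 1
        # Positives at or below t leave the concordant pool.
        while concordant_index < n and s[concordant_index] <= t:
            if y[concordant_index] == 1:
                concordant_contrib -= w[concordant_index]
            concordant_index += 1
        numerator += w[i] * (concordant_contrib - discordant_contrib)

    # Step 5: normalize by the total cross-class pair mass.
    return float(numerator / (n_neg * n_pos))


d_example = somers_d_sweep(
    scores=[0.1, 0.4, 0.4, 0.8],
    labels=[0, 1, 0, 1],
    weights=[1.0, 2.0, 1.0, 1.0],
)
```

Because `concordant_contrib` is maintained by repeated subtraction, results can differ from a pairwise enumeration in the last few floating-point digits.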

---

### Why pointers never reset

* Scores are processed in **non-decreasing** order. Let $t_k$ and $t_{k+1}$ be successive negative anchors; then $t_{k+1} \ge t_k$.
* As $t$ increases:

  * $D(t)$ (positives with $s_j < t$) is **monotone non-decreasing**; once a positive enters $D(t)$, it never leaves. This is exactly what `discordant_index` tracks using “<”.
  * $C(t)$ (positives with $s_j > t$) is **monotone non-increasing**; once a positive is no longer $> t$ (i.e., $s_j \le t$), it never re-enters. This is what `concordant_index` enforces using “$\le$”.
* Hence both pointers move forward at most $n$ steps total; no resets are necessary.

---

### Tie handling

* The discordant pointer uses **strict** “<” and the concordant pointer uses “$\le$”.
* For $s_j = t$:

  * Not counted into `discordant_contrib` (not < $t$).
  * Removed from `concordant_contrib` (no longer $>\,t$).
* Net effect: ties contribute $0$ to the numerator, matching the Somers’ $D$ definition.
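
A small worked example with illustrative numbers (not taken from the benchmark runs) shows what "zero contribution" means in practice. Take one negative anchor at $t = 0.4$ with weight $1$, and three positives with (score, weight) pairs $(0.2, 1)$, $(0.4, 2)$, $(0.8, 3)$. The concordant mass is $3$, the discordant mass is $1$, and the tied positive at $0.4$ adds nothing to the numerator. Its mass still counts in the denominator, which sums over all cross-class pairs:

$$
D \;=\; \frac{1 \cdot (3 - 1)}{1 \cdot (1 + 2 + 3)} \;=\; \frac{2}{6} \;=\; \frac{1}{3}
$$

If the tied positive were dropped from the data instead, the result would be $2/4 = 1/2$, so ties pull $D$ toward zero rather than being ignored.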

---

### Correctness invariants (maintained at each negative anchor)

For the current threshold $t$:

$$
\text{discordant\_contrib} = \sum_{j \in P: s_j < t} w_j,
\qquad
\text{concordant\_contrib} = \sum_{j \in P: s_j > t} w_j,
$$

and the instantaneous contribution is

$$
w_i\big(\text{concordant\_contrib} - \text{discordant\_contrib}\big).
$$

---

### Complexity and numerical notes

* **Time:** $O(n \log n)$ for sorting + $O(n)$ for the sweep; each element is touched by each pointer at most once.
* **Space:** $O(n)$ for the sorted copies; in-place reordering is avoided so that the inputs stay unchanged.
* **Numerics:** Use 64-bit floats for accumulators (`concordant_contrib`, `discordant_contrib`, `numerator`) to reduce rounding error on large weights. Guard against $n_{\text{pos}}=0$ or $n_{\text{neg}}=0$.


## Global

In [2]:
rng: np.random.Generator = np.random.default_rng(seed=1227)

sample_sizes: np.typing.NDArray[np.int64] = np.array(
    [5_000, 50_000, 500_000, 5_000_000, 50_000_000]
)

score_column: str = "score"
label_column: str = "label"

## JIT Warmup

In [3]:
# Warm-up for JIT since the first call to a numba JIT function includes a one-time compilation overhead
warmup_sample_size: int = 1_000
benchmark: pl.LazyFrame = pl.DataFrame(
    {
        score_column: rng.uniform(0, 1, warmup_sample_size).astype(np.float64),
        label_column: rng.choice([0, 1], warmup_sample_size).astype(np.int8),
    }
).lazy()

monitoring: pl.LazyFrame = pl.DataFrame(
    {
        score_column: rng.uniform(0, 1, warmup_sample_size).astype(np.float64),
        label_column: rng.choice([0, 1], warmup_sample_size).astype(np.int8),
    }
).lazy()

wsri_value: float = wsri(
    benchmark=benchmark,
    monitoring=monitoring,
    score_column=score_column,
    label_column=label_column,
    quantiles=20,
    callback=somers_d_two_pointers,
)

print(f"WSRI: {wsri_value}")

WSRI: 4.687046070922117


## Wall Time

In [4]:
wall_time_results: List[Dict[str, float]] = []
data_sets: List[Dict[str, pl.LazyFrame]] = []

for sample_size in sample_sizes:
    print(f"Sample size: {sample_size}")

    benchmark: pl.LazyFrame = pl.DataFrame(
        {
            score_column: rng.uniform(0, 1, sample_size).astype(np.float64),
            label_column: rng.choice([0, 1], sample_size).astype(np.int8),
        }
    ).lazy()

    monitoring: pl.LazyFrame = pl.DataFrame(
        {
            score_column: rng.uniform(0, 1, sample_size).astype(np.float64),
            label_column: rng.choice([0, 1], sample_size).astype(np.int8),
        }
    ).lazy()

    data_sets.append({"benchmark": benchmark, "monitoring": monitoring})

    start_time_two_pointers: float = time.perf_counter()
    wsri_two_pointers: float = wsri(
        benchmark=benchmark,
        monitoring=monitoring,
        score_column=score_column,
        label_column=label_column,
        quantiles=20,
        callback=somers_d_two_pointers,
    )
    time_two_pointers: float = time.perf_counter() - start_time_two_pointers

    start_time_roc_auc: float = time.perf_counter()
    wsri_roc_auc: float = wsri(
        benchmark=benchmark,
        monitoring=monitoring,
        score_column=score_column,
        label_column=label_column,
        quantiles=20,
        callback=somers_d_roc_auc,
    )
    time_roc_auc: float = time.perf_counter() - start_time_roc_auc

    wall_time_results.append(
        {
            "sample_size": sample_size,
            "wsri_two_pointers": wsri_two_pointers,
            "wsri_roc_auc": wsri_roc_auc,
            "time_two_pointers": time_two_pointers,
            "time_roc_auc": time_roc_auc,
        }
    )

Sample size: 5000
Sample size: 50000
Sample size: 500000
Sample size: 5000000
Sample size: 50000000


## Memory

In [5]:
memory_results: List[Dict[str, float]] = []

for sample_size, data_set in zip(sample_sizes, data_sets):
    print(f"Sample size: {sample_size}")

    benchmark: pl.LazyFrame = data_set["benchmark"]
    monitoring: pl.LazyFrame = data_set["monitoring"]

    tracemalloc.start()
    _ = wsri(
        benchmark=benchmark,
        monitoring=monitoring,
        score_column=score_column,
        label_column=label_column,
        quantiles=20,
        callback=somers_d_two_pointers,
    )
    _, peak_mem_two_pointers = tracemalloc.get_traced_memory()
    tracemalloc.stop()

    tracemalloc.start()
    _ = wsri(
        benchmark=benchmark,
        monitoring=monitoring,
        score_column=score_column,
        label_column=label_column,
        quantiles=20,
        callback=somers_d_roc_auc,
    )
    _, peak_mem_roc_auc = tracemalloc.get_traced_memory()
    tracemalloc.stop()

    memory_results.append(
        {
            "sample_size": sample_size,
            "peak_mem_two_pointers_mb": peak_mem_two_pointers / 10**6,
            "peak_mem_roc_auc_mb": peak_mem_roc_auc / 10**6,
        }
    )

Sample size: 5000
Sample size: 50000
Sample size: 500000
Sample size: 5000000
Sample size: 50000000
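
The start/measure/stop pattern used in the cell above can be wrapped in a small helper so that tracing is always stopped, even if the measured call raises. This is a sketch; `peak_traced_mb` is a hypothetical name, not part of the notebook or the `wsri` package. Note that `tracemalloc` only reports allocations routed through Python's memory allocators (NumPy registers its array buffers with it), so memory allocated natively by other libraries may not show up in these figures.

```python
import tracemalloc
from typing import Any, Callable, Tuple


def peak_traced_mb(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Tuple[Any, float]:
    """Run func and return (result, peak traced memory in MB, with 1 MB = 10**6 bytes)."""
    tracemalloc.start()
    try:
        result = func(*args, **kwargs)
        _, peak_bytes = tracemalloc.get_traced_memory()
    finally:
        # Always stop tracing so a failure does not leave tracemalloc running.
        tracemalloc.stop()
    return result, peak_bytes / 10**6


# Usage: allocating a 5,000,000-byte buffer should register as roughly 5 MB.
result, peak_mb = peak_traced_mb(lambda n: bytearray(n), 5_000_000)
```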


## Combine Results


In [6]:
wall_time_data: pl.DataFrame = pl.from_dicts(wall_time_results)
memory_data: pl.DataFrame = pl.from_dicts(memory_results)

combined_results: pl.DataFrame = wall_time_data.join(
    other=memory_data, on="sample_size", how="inner"
).with_columns(
    (pl.col("time_roc_auc") / pl.col("time_two_pointers")).alias(
        "speedup_two_pointers_vs_roc_auc"
    ),
    (pl.col("peak_mem_roc_auc_mb") / pl.col("peak_mem_two_pointers_mb")).alias(
        "memory_two_pointers_vs_roc_auc"
    ),
)

combined_results.with_columns(cs.float().round(4))

| sample_size | wsri_two_pointers | wsri_roc_auc | time_two_pointers | time_roc_auc | peak_mem_two_pointers_mb | peak_mem_roc_auc_mb | speedup_two_pointers_vs_roc_auc | memory_two_pointers_vs_roc_auc |
|---|---|---|---|---|---|---|---|---|
| i64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 |
| 5000 | -0.3375 | -0.3375 | 0.0068 | 0.0083 | 0.1305 | 0.502 | 1.2172 | 3.8457 |
| 50000 | 1.5518 | 1.5518 | 0.0219 | 0.0259 | 1.2534 | 4.5766 | 1.1857 | 3.6515 |
| 500000 | -0.3074 | -0.3074 | 0.2069 | 0.2441 | 12.5036 | 45.0766 | 1.18 | 3.6051 |
| 5000000 | 0.868 | 0.868 | 2.1829 | 2.7142 | 125.0035 | 450.0768 | 1.2434 | 3.6005 |
| 50000000 | -0.529 | -0.529 | 26.1408 | 31.7893 | 1250.0041 | 4500.0769 | 1.2161 | 3.6 |
