In [1]:
from IPython.core.interactiveshell import InteractiveShell

InteractiveShell.ast_node_interactivity = "all"

%reload_ext autoreload
%autoreload 2

import time
import tracemalloc
from typing import Dict, List

import numpy as np
import polars as pl
import polars.selectors as cs
from wsri import somers_d_roc_auc, somers_d_two_pointers, wsri

## Benchmark vs. Monitoring Setup

* **Benchmark**: A fixed reference period (e.g., 2023 data) that we treat as the "gold standard" or baseline distribution of scores and outcomes.
* **Monitoring**: A newer period (e.g., Q1 2024 data) that we want to evaluate for stability relative to the benchmark.

The purpose of **wSRI** here is to ask: *if the score-label relationship in monitoring looks different from benchmark, how much of that difference reflects population shift rather than true deterioration in model rank-ordering?*

---

### Why Weight the Benchmark

Weighting only the **benchmark** population makes sense because:

* **Benchmark as reference**: The benchmark is the side we adjust; the monitoring data is evaluated *as is*, because it reflects the actual, current population distribution.
* **Benchmark rescaled**: By reweighting benchmark bins to match the marginal distribution of scores in monitoring, we make the benchmark look like, e.g., "what 2023 would have looked like under the 2024 score distribution."
* **Apples-to-apples comparability**: Without reweighting, Somers’ D differences could be confounded by the fact that the monitoring period simply has more or fewer accounts in certain score ranges (not necessarily because the rank-order power of the model changed).

So the weighting isolates the effect of **conditional rank performance** (score <-> label relationship) rather than **population score mix**.

---

### Why Not Weight Monitoring

If we weighted **monitoring** to benchmark:

* We’d be discarding real, current information about the new population distribution.
* The monitoring Somers’ D would be artificially “distorted back” to the 2023 distribution, making it less diagnostic for drift.
* We’d risk underestimating deterioration if population shift itself is part of what we want to detect.

---

### Covariate Shift

Formally, weighting only benchmark approximates a **covariate-shift correction** problem:

* Assume the conditional outcome distribution *P(y|score)* is stable across both periods (or at least, this is the hypothesis we want to test).
* But the marginal score distribution *P(score)* has shifted.
* Reweighting benchmark to monitoring ensures we compare *P(y|score)* fairly across periods.

This is consistent with the idea of “importance weighting” in domain adaptation:

$$
w(x) = \frac{P_{\text{monitoring}}(x)}{P_{\text{benchmark}}(x)}
$$

where here *x* = score bins.

It makes sense to weight the **benchmark** and not the monitoring because we want to transform the reference into the monitoring population’s shape, ensuring comparability. Monitoring stays unweighted so it reflects reality. This lets the ratio (monitoring Somers’ D / reweighted benchmark Somers’ D) isolate *rank-order stability*, not population differences.


## Global

In [2]:
# Seeded generator used by the data-generating cells below, so the synthetic
# data is reproducible when the notebook is run top to bottom.
rng: np.random.Generator = np.random.default_rng(seed=1227)

# Row counts to benchmark, growing 10x per step. The dtype is pinned so the
# array matches its annotation even where NumPy's default integer is 32-bit.
sample_sizes: np.typing.NDArray[np.int64] = np.array(
    [5_000, 50_000, 500_000, 5_000_000, 50_000_000], dtype=np.int64
)

# Column names used for every generated frame and passed through to `wsri`.
score_column: str = "score"
label_column: str = "label"

## JIT Warmup

In [3]:
# Warm-up for JIT since the first call to a numba JIT function includes a
# one-time compilation overhead. A tiny data set is enough to trigger it.
warmup_sample_size: int = 1_000

# `.lazy()` returns a LazyFrame, so both frames are annotated as such.
# `rng.uniform` already yields float64, so no extra cast (and copy) is needed.
benchmark: pl.LazyFrame = pl.DataFrame(
    {
        score_column: rng.uniform(0, 1, warmup_sample_size),
        label_column: rng.choice([0, 1], warmup_sample_size).astype(np.int8),
    }
).lazy()

monitoring: pl.LazyFrame = pl.DataFrame(
    {
        score_column: rng.uniform(0, 1, warmup_sample_size),
        label_column: rng.choice([0, 1], warmup_sample_size).astype(np.int8),
    }
).lazy()

wsri_value: float = wsri(
    benchmark=benchmark,
    monitoring=monitoring,
    score_column=score_column,
    label_column=label_column,
    quantiles=20,
    callback=somers_d_two_pointers,
)

# Exercise the second benchmarked callback once as well, so that any
# first-call overhead it may have is not charged to its first timed run.
# NOTE(review): whether `somers_d_roc_auc` is JIT-compiled is not visible
# here; the extra call is cheap at this sample size and draws no random
# numbers, so later cells see the same `rng` state either way.
_ = wsri(
    benchmark=benchmark,
    monitoring=monitoring,
    score_column=score_column,
    label_column=label_column,
    quantiles=20,
    callback=somers_d_roc_auc,
)

print(f"WSRI: {wsri_value}")

WSRI: 4.687046070922117


## Wall Time

In [4]:
def _random_scored_frame(n_rows: int) -> pl.LazyFrame:
    """Build a lazy frame of ``n_rows`` random scores and binary labels.

    Scores are drawn uniformly from [0, 1) and labels uniformly from {0, 1},
    both from the notebook-level ``rng``. The score column is drawn before the
    label column, matching the order of the dict literal below.
    """
    return pl.DataFrame(
        {
            # `rng.uniform` already returns float64; casting again would only
            # copy the array, which is costly at the largest sample sizes.
            score_column: rng.uniform(0, 1, n_rows),
            label_column: rng.choice([0, 1], n_rows).astype(np.int8),
        }
    ).lazy()


def _time_wsri(
    benchmark: pl.LazyFrame, monitoring: pl.LazyFrame, callback
) -> Dict[str, float]:
    """Run ``wsri`` once with ``callback`` and time it.

    Returns a dict with the computed value under ``"wsri"`` and the elapsed
    wall-clock time in seconds (``time.perf_counter``) under ``"seconds"``.
    Only the ``wsri`` call itself sits between the two clock reads.
    """
    started: float = time.perf_counter()
    value: float = wsri(
        benchmark=benchmark,
        monitoring=monitoring,
        score_column=score_column,
        label_column=label_column,
        quantiles=20,
        callback=callback,
    )
    elapsed: float = time.perf_counter() - started
    return {"wsri": value, "seconds": elapsed}


wall_time_results: List[Dict[str, float]] = []
# The generated frames are kept so the memory cell can reuse the same data.
data_sets: List[Dict[str, pl.LazyFrame]] = []

for sample_size in sample_sizes:
    print(f"Sample size: {sample_size}")

    # Benchmark is generated before monitoring to keep the RNG draw order fixed.
    benchmark: pl.LazyFrame = _random_scored_frame(sample_size)
    monitoring: pl.LazyFrame = _random_scored_frame(sample_size)

    data_sets.append({"benchmark": benchmark, "monitoring": monitoring})

    two_pointers_run: Dict[str, float] = _time_wsri(
        benchmark, monitoring, somers_d_two_pointers
    )
    roc_auc_run: Dict[str, float] = _time_wsri(
        benchmark, monitoring, somers_d_roc_auc
    )

    # Key order determines the column order of the results table.
    wall_time_results.append(
        {
            # Store a plain Python int rather than a NumPy scalar.
            "sample_size": int(sample_size),
            "wsri_two_pointers": two_pointers_run["wsri"],
            "wsri_roc_auc": roc_auc_run["wsri"],
            "time_two_pointers": two_pointers_run["seconds"],
            "time_roc_auc": roc_auc_run["seconds"],
        }
    )

Sample size: 5000
Sample size: 50000
Sample size: 500000
Sample size: 5000000
Sample size: 50000000


## Memory

In [5]:
def _peak_traced_bytes(
    benchmark: pl.LazyFrame, monitoring: pl.LazyFrame, callback
) -> int:
    """Return the peak bytes traced by ``tracemalloc`` during one ``wsri`` call.

    Tracing starts immediately before the call and is always stopped
    afterwards, even if ``wsri`` raises, so a failed run cannot leave tracing
    enabled for later cells.

    NOTE(review): ``tracemalloc`` only reports allocations made through
    Python's memory allocators; memory allocated natively (e.g. inside polars)
    is presumably not counted -- confirm before reading these numbers as total
    process memory.
    """
    tracemalloc.start()
    try:
        wsri(
            benchmark=benchmark,
            monitoring=monitoring,
            score_column=score_column,
            label_column=label_column,
            quantiles=20,
            callback=callback,
        )
        # get_traced_memory() returns (current, peak); only the peak is used.
        _, peak_bytes = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return peak_bytes


memory_results: List[Dict[str, float]] = []

# Reuse the exact frames generated in the wall-time cell, paired by position.
for sample_size, data_set in zip(sample_sizes, data_sets):
    print(f"Sample size: {sample_size}")

    benchmark: pl.LazyFrame = data_set["benchmark"]
    monitoring: pl.LazyFrame = data_set["monitoring"]

    peak_mem_two_pointers: int = _peak_traced_bytes(
        benchmark, monitoring, somers_d_two_pointers
    )
    peak_mem_roc_auc: int = _peak_traced_bytes(
        benchmark, monitoring, somers_d_roc_auc
    )

    memory_results.append(
        {
            # Store a plain Python int rather than a NumPy scalar.
            "sample_size": int(sample_size),
            # Decimal megabytes (10**6 bytes), not MiB.
            "peak_mem_two_pointers_mb": peak_mem_two_pointers / 10**6,
            "peak_mem_roc_auc_mb": peak_mem_roc_auc / 10**6,
        }
    )

Sample size: 5000
Sample size: 50000
Sample size: 500000
Sample size: 5000000
Sample size: 50000000


## Combine Results


In [6]:
# Turn the per-run records collected by the two benchmark cells into frames.
wall_time_data: pl.DataFrame = pl.from_dicts(wall_time_results)
memory_data: pl.DataFrame = pl.from_dicts(memory_results)

# Both ratios are ROC-AUC divided by two-pointers, so a value above 1 means
# the ROC-AUC callback took longer / traced a higher peak.
time_ratio: pl.Expr = pl.col("time_roc_auc") / pl.col("time_two_pointers")
memory_ratio: pl.Expr = pl.col("peak_mem_roc_auc_mb") / pl.col(
    "peak_mem_two_pointers_mb"
)

# One row per sample size: timings and peak memory side by side.
joined_results: pl.DataFrame = wall_time_data.join(
    memory_data, on="sample_size", how="inner"
)
combined_results: pl.DataFrame = joined_results.with_columns(
    time_ratio.alias("speedup_two_pointers_vs_roc_auc"),
    memory_ratio.alias("memory_two_pointers_vs_roc_auc"),
)

# Round float columns for display only; `combined_results` keeps full precision.
combined_results.with_columns(cs.float().round(4))

sample_size,wsri_two_pointers,wsri_roc_auc,time_two_pointers,time_roc_auc,peak_mem_two_pointers_mb,peak_mem_roc_auc_mb,speedup_two_pointers_vs_roc_auc,memory_two_pointers_vs_roc_auc
i64,f64,f64,f64,f64,f64,f64,f64,f64
5000,-0.3375,-0.3375,0.0068,0.0083,0.1305,0.502,1.2172,3.8457
50000,1.5518,1.5518,0.0219,0.0259,1.2534,4.5766,1.1857,3.6515
500000,-0.3074,-0.3074,0.2069,0.2441,12.5036,45.0766,1.18,3.6051
5000000,0.868,0.868,2.1829,2.7142,125.0035,450.0768,1.2434,3.6005
50000000,-0.529,-0.529,26.1408,31.7893,1250.0041,4500.0769,1.2161,3.6
