# Comparison between `error-parity`'s LP solver and a brute-force solver

Out of curiosity, this notebook compares the solution quality (accuracy under an equalized-odds constraint) and the runtime of the `error-parity` LP formulation against a baseline brute-force solver.

**NOTE**: this notebook has extra requirements, install them with:
```
pip install "error_parity[dev]"
```

In [None]:
%pip install "error-parity[dev]"

In [2]:
import logging
from itertools import product

import numpy as np
import cvxpy as cp
from scipy.spatial import ConvexHull
from sklearn.metrics import roc_curve

In [3]:
from error_parity import __version__
# Record the library version so results can be reproduced later
print("Notebook ran using `error-parity=={}`".format(__version__))

Notebook ran using `error-parity==0.3.11`


## Given some data (X, Y, S)

In [4]:
def generate_synthetic_data(n_samples: int, n_groups: int, prevalence: float, seed: int):
    """Helper to generate synthetic features/labels/predictions.

    Scores are drawn uniformly in [0, 1). Each sample is assigned to one of
    `n_groups` groups uniformly at random. Its label is obtained by adding
    group-specific gaussian noise to its score and thresholding at
    `1 - prevalence`. Since the noise level differs per group, error rates
    differ across groups.

    Parameters
    ----------
    n_samples : int
        Number of samples to generate.
    n_groups : int
        Number of sensitive groups; groups are labelled 0 .. n_groups-1.
    prevalence : float
        Target fraction of positive labels, must be strictly in (0, 1).
        The realized prevalence differs somewhat because of the label noise.
    seed : int
        Seed for the numpy random generator (outputs are reproducible).

    Returns
    -------
    X : np.ndarray of shape (n_samples, 1)
        Features: the index of each sample.
    y_true : np.ndarray of shape (n_samples,), integer dtype
        Binary labels in {0, 1}.
    y_score : np.ndarray of shape (n_samples,)
        Real-valued scores in [0, 1).
    group : np.ndarray of shape (n_samples,)
        Group membership of each sample.

    Raises
    ------
    ValueError
        If `prevalence` is not strictly between 0 and 1, or `n_groups < 1`.
    """
    # Validate inputs before drawing any random numbers (fail fast; unlike
    # `assert`, this check is not stripped when running with `python -O`)
    if not 0 < prevalence < 1:
        raise ValueError(f"prevalence must be in (0, 1), got {prevalence}")
    if n_groups < 1:
        raise ValueError(f"n_groups must be >= 1, got {n_groups}")

    # Construct numpy rng
    rng = np.random.default_rng(seed)

    # Different levels of gaussian noise per group (to induce some inequality in error rates)
    group_noise = [0.1 + 0.3 * rng.random() / (1 + idx) for idx in range(n_groups)]

    # Generate predictions (scores)
    y_score = rng.random(size=n_samples)

    # Generate labels
    # - define which samples belong to each group
    # - add different noise levels for each group
    group = rng.integers(low=0, high=n_groups, size=n_samples)

    # Labels are binary: keep them in an integer array (np.zeros defaults to float)
    y_true = np.zeros(n_samples, dtype=int)
    for i in range(n_groups):
        group_filter = group == i
        noisy_scores = (
            y_score[group_filter]
            + rng.normal(size=np.sum(group_filter), scale=group_noise[i])
        )
        y_true[group_filter] = (noisy_scores > (1 - prevalence)).astype(int)

    ### Generate features: just use the sample index
    # As we already have the y_scores, we can construct the features X
    # as the index of each sample, so we can construct a classifier that
    # simply maps this index to our pre-generated predictions for this clf.
    X = np.arange(n_samples).reshape((-1, 1))

    return X, y_true, y_score, group

In [5]:
# Synthetic-data configuration
N_GROUPS = 2
N_SAMPLES = 100_000  # e.g. 1_000_000 for a larger run (the brute-force search gets much slower)
SEED = 23
PREVALENCE = 0.25  # target fraction of positive labels (before label noise)

X, y_true, y_score, group = generate_synthetic_data(
    n_samples=N_SAMPLES,
    n_groups=N_GROUPS,
    prevalence=PREVALENCE,
    seed=SEED,
)

In [6]:
# Fraction of positive labels actually realized after adding label noise
actual_prevalence = np.mean(y_true)
print(f"Actual global prevalence: {actual_prevalence:.1%}")

Actual global prevalence: 27.2%


In [7]:
# Maximum allowed equalized-odds violation
# (set EPSILON_TOLERANCE = 1.0 to obtain the best unconstrained classifier)
EPSILON_TOLERANCE = 0.05

# Misclassification costs; with unit costs, minimizing cost == maximizing accuracy.
# NOTE(review): these constants are not passed to either solver in this notebook.
FALSE_POS_COST, FALSE_NEG_COST = 1, 1

---
## Given a trained predictor (that outputs real-valued scores)

In [8]:
# Example predictor that predicts the synthetically produced scores above
def predictor(idx):
    """Look up the pre-generated scores for the given sample indices, as a flat 1-D array."""
    return y_score[idx].ravel()

---
---
# Comparing LP vs brute-force solution

## 1. Brute-force solver

In [9]:
from itertools import product
from collections.abc import Iterable
from error_parity.evaluation import eval_accuracy_and_equalized_odds
from tqdm.auto import tqdm

def binarize_predictions(y_score, group_membership, group_thresholds: dict, seed: int = 42):
    """Binarizes score predictions using different group thresholds.

    For each group key, `group_thresholds[key]` is either:
    - a single threshold `t`: scores strictly above `t` -> 1, all others -> 0;
    - a pair `(low, high)`: scores below `low` -> 0, scores above `high` -> 1,
      and scores in `[low, high]` receive a 0/1 prediction drawn uniformly at
      random (skipped when `low` and `high` are numerically close, in which
      case in-between scores keep the default 0).

    Parameters
    ----------
    y_score : array-like of shape (n,)
        Real-valued scores.
    group_membership : array-like of shape (n,)
        Group key of each sample.
    group_thresholds : dict
        Maps group key -> threshold or `(low, high)` pair.
    seed : int, optional
        Seed of the generator used for the randomized band. A single generator
        is shared by all groups, consumed in the dict's iteration order.

    Returns
    -------
    np.ndarray of int, same shape as `group_membership`
        Binary predictions; samples whose group is not a key of
        `group_thresholds` are predicted 0.
    """
    # Accept lists as well as numpy arrays (element-wise comparisons below)
    y_score = np.asarray(y_score)
    group_membership = np.asarray(group_membership)

    rng = np.random.default_rng(seed)
    y_pred_binary = np.zeros_like(group_membership, dtype=int)

    for group_key, group_thrs in group_thresholds.items():

        if not isinstance(group_thrs, Iterable):
            # Single threshold provided (no randomization)
            low_thr, high_thr = group_thrs, group_thrs
        else:
            # Two thresholds provided (partial randomization)
            assert len(group_thrs) == 2, f"Provide exactly 2 thresholds, got {len(group_thrs)}"
            low_thr, high_thr = group_thrs

        # Boolean numpy filter for samples of the current group
        group_filter = group_membership == group_key

        # Below low_thr -> negative pred.
        y_pred_binary[group_filter & (y_score < low_thr)] = 0

        # Above high_thr -> positive pred.
        y_pred_binary[group_filter & (y_score > high_thr)] = 1

        # Between low_thr and high_thr -> random uniform prediction
        if not np.isclose(low_thr, high_thr):
            middle_filter = group_filter & (y_score >= low_thr) & (y_score <= high_thr)
            y_pred_binary[middle_filter] = rng.integers(
                low=0, high=2,  # sampled in [low, high)
                size=np.sum(middle_filter),
            )

    return y_pred_binary


def solve_brute_force(
        *,
        predictor,
        tolerance: float,
        data_tuple: tuple,
        threshold_ticks_step: float = 1e-2,
    ) -> dict:
    """Brute-force solution for equalized odds problem.

    Exhaustively searches every combination of per-group randomized thresholds
    `(low, high)` with `low <= high`, taken from a grid with spacing
    `threshold_ticks_step` starting at 0. It keeps the combination with the
    highest accuracy among those whose equalized-odds violation is at most
    `tolerance`.

    Parameters
    ----------
    predictor : callable
        Maps the features `X` to real-valued scores.
    tolerance : float
        Maximum allowed equalized-odds violation.
    data_tuple : tuple
        `(X, y, s)`: features, binary labels and sensitive-group membership.
    threshold_ticks_step : float, optional
        Spacing of the threshold grid (default 0.01). The number of
        combinations grows exponentially with the number of groups.

    Returns
    -------
    dict
        `group_thresholds`: tuple with one `(low, high)` pair per group, in the
        order of `np.unique(s)`; `accuracy`; `eq_odds_violation`.
        If no combination satisfies `tolerance`, the least-violating one is
        returned instead and a warning is logged.

    Raises
    ------
    ValueError
        If `threshold_ticks_step` is not positive.
    """
    if threshold_ticks_step <= 0:
        raise ValueError(f"threshold_ticks_step must be > 0, got {threshold_ticks_step}")

    # Unpack data tuple
    X_feats, y_labels, s_group = data_tuple

    # Score the data with the given predictor (rather than a global score array)
    scores = np.asarray(predictor(X_feats)).ravel()

    # Candidate (low, high) threshold pairs for a single group
    ticks = np.arange(0, 1 + threshold_ticks_step, threshold_ticks_step)
    per_group_pairs = [(lo_thr, hi_thr) for lo_thr, hi_thr in product(ticks, ticks) if lo_thr <= hi_thr]

    # One pair per group present in the data (not a global group count)
    unique_groups = np.unique(s_group)
    n_groups = len(unique_groups)
    group_threshold_combinations = product(per_group_pairs, repeat=n_groups)

    # Count from the actual grid: np.arange's length can differ from an
    # analytic formula due to floating-point rounding
    total_combinations = len(per_group_pairs) ** n_groups

    ### "best" = maximum accuracy among combinations with eq_odds <= tolerance.
    # Infeasible combinations never compete for "best"; the least-violating one
    # is only kept as a fallback in case nothing is feasible.
    best_feasible = None
    least_violating = None

    for combi in tqdm(group_threshold_combinations, total=total_combinations):
        # Binarize predictions with this threshold combination
        binarized_preds = binarize_predictions(
            y_score=scores,
            group_membership=s_group,
            group_thresholds=dict(zip(unique_groups, combi)),
        )

        # Evaluate results
        curr_accuracy, curr_eq_odds_violation = eval_accuracy_and_equalized_odds(
            y_true=y_labels, y_pred_binary=binarized_preds,
            sensitive_attr=s_group,
        )

        candidate = {
            "group_thresholds": combi,
            "accuracy": curr_accuracy,
            "eq_odds_violation": curr_eq_odds_violation,
        }

        if curr_eq_odds_violation <= tolerance:
            # Strict ">" keeps the first combination found among ties
            if best_feasible is None or curr_accuracy > best_feasible["accuracy"]:
                best_feasible = candidate
        elif least_violating is None or curr_eq_odds_violation < least_violating["eq_odds_violation"]:
            least_violating = candidate

    if best_feasible is not None:
        return best_feasible

    logging.warning(
        "No threshold combination satisfies tolerance=%s; returning the "
        "least-violating combination (violation=%s).",
        tolerance, least_violating["eq_odds_violation"],
    )
    return least_violating

Run the brute-force solver:

In [10]:
%%time
# Brute-force search on a coarse grid (step 0.1 -> 11 ticks per threshold,
# 66 (low, high) pairs with low <= high per group, 66**2 = 4356 combinations
# for 2 groups). Each combination is evaluated on the full dataset.
brute_force_solution = solve_brute_force(
    predictor=predictor,
    tolerance=EPSILON_TOLERANCE,
    data_tuple=(X, y_true, group),
    threshold_ticks_step=0.1,
)

brute_force_solution

  0%|          | 0/4356 [00:00<?, ?it/s]

CPU times: user 3min 18s, sys: 8.16 s, total: 3min 26s
Wall time: 4min 23s


{'group_thresholds': ((0.7000000000000001, 0.8), (0.7000000000000001, 0.9)),
 'accuracy': 0.80763,
 'eq_odds_violation': 0.04660537497114363}

## 2. LP solver

In [11]:
from error_parity import RelaxedThresholdOptimizer

def solve_lp(predictor, tolerance: float, data_tuple: tuple):
    """Fit error-parity's LP-based post-processor under an equalized-odds constraint.

    `data_tuple` is `(X, y, group)`; the fitted `RelaxedThresholdOptimizer`
    is returned.
    """
    fair_clf = RelaxedThresholdOptimizer(
        predictor=predictor,
        constraint="equalized_odds",
        tolerance=tolerance,
        max_roc_ticks=None,  # None -> use full precision
        seed=SEED,
    )

    features, labels, sensitive_attr = data_tuple
    fair_clf.fit(X=features, y=labels, group=sensitive_attr)

    return fair_clf

In [12]:
%%time
postproc_clf = solve_lp(
    predictor=predictor,
    tolerance=EPSILON_TOLERANCE,
    data_tuple=(X, y_true, group),
)

CPU times: user 104 ms, sys: 4.96 ms, total: 109 ms
Wall time: 108 ms


## Compare accuracy and constraint violation
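This comparison assumes unit misclassification costs (false-positive cost = false-negative cost = 1), so minimizing the expected cost is equivalent to maximizing accuracy.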

In [13]:
# Baseline: always predict the majority class
constant_clf_accuracy = max(np.mean(y_true == 0), np.mean(y_true == 1))
print(f"Accuracy for dummy constant classifier: {constant_clf_accuracy:.1%}")

Accuracy for dummy constant classifier: 72.8%


Evaluate the predictions realized by the LP solution.

In [14]:
# Realize (possibly randomized) binary predictions from the fitted LP post-processor
y_pred_binary_lp = postproc_clf.predict(X, group=group)

lp_acc, lp_eq_odds = eval_accuracy_and_equalized_odds(
    y_true=y_true,
    y_pred_binary=y_pred_binary_lp,
    sensitive_attr=group,
)

print(f"Realized LP accuracy: {lp_acc:.1%}")
print(f"Realized LP eq. odds violation: {lp_eq_odds:.1%}", end="\n\n")

Realized LP accuracy: 82.2%
Realized LP eq. odds violation: 5.0%



Evaluate the predictions realized by the brute-force solution.

In [15]:
# Map each group to its thresholds using the same group ordering (`np.unique`)
# that `solve_brute_force` used when pairing groups with threshold pairs,
# instead of assuming the groups are exactly 0..N_GROUPS-1.
bf_group_thresholds = dict(zip(np.unique(group), brute_force_solution["group_thresholds"]))

# Uses binarize_predictions' default seed, as the solver did, so the
# randomized predictions (and thus the metrics) match those found in the search.
y_pred_binary_brute_force = binarize_predictions(
    y_score=y_score, group_membership=group,
    group_thresholds=bf_group_thresholds,
)

bf_acc, bf_eq_odds = eval_accuracy_and_equalized_odds(y_true, y_pred_binary_brute_force, group)

print(f"Realized BF accuracy: {bf_acc:.1%}")
print(f"Realized BF eq. odds violation: {bf_eq_odds:.1%}")

Realized BF accuracy: 80.8%
Realized BF eq. odds violation: 4.7%


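**Conclusion:** the brute-force solver took over 4 minutes to exhaustively search 4356 threshold combinations on a coarse 0.1-step grid. The LP solver took about 0.1 s (108 ms wall time) and found a better solution: 82.2% vs. 80.8% accuracy, both within the 5% equalized-odds tolerance. The LP optimizes over the full-precision ROC curves rather than a coarse threshold grid. Exact timings are machine-dependent.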