# Binary classification risk control - Theoretical tests to validate implementation

# Protocol description
We test the theoretical guarantees of risk control in binary classification by using a logistic classifier and synthetic data. The aim is to evaluate the effectiveness of the BinaryClassificationController in maintaining a predefined risk level under different conditions.

Each test case corresponds to a unique combination of parameters. We repeat the experiment `n_repeats` times for each combination. The model stays the same across experiments, while the data is resampled at each repetition to account for variability.

Each experiment consists of the following steps:  

- **Calibrate the controller**  
  - We use a **BinaryClassificationController**, which returns the list of lambda values (decision thresholds) that the **LTT** (Learn Then Test) procedure declares valid for controlling the risk.  

- **Verify risk control**  
  - Since the model is a known logistic model, we can compute the **theoretical risk** associated with each lambda value.  
  - We then check whether each lambda value from LTT actually controls the risk.  
  - If a lambda does not meet the risk guarantee, we count **one "error"**.  
  - **Note:** *every* lambda value must individually control the risk — it is not enough for only some to succeed.  
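
The "theoretical risk" step can be illustrated with a small sketch. It assumes an idealized random classifier: scores uniform on [0, 1], independent of balanced labels. Those assumptions, and the helper names `theoretical_metrics` and `simulated_metrics`, are ours and are not part of MAPIE. Under them, precision and accuracy equal 0.5 at every threshold and recall equals 1 - lambda, which a simulation can confirm:

```python
import numpy as np


def theoretical_metrics(lam):
    """Closed-form metrics at threshold `lam` under the idealized assumptions above."""
    return {"precision": 0.5, "recall": 1.0 - lam, "accuracy": 0.5}


def simulated_metrics(lam, n_samples=200_000, seed=0):
    """Monte Carlo estimate of the same metrics, used as a sanity check."""
    rng = np.random.default_rng(seed)
    y = rng.integers(0, 2, size=n_samples)          # balanced labels
    scores = rng.uniform(0.0, 1.0, size=n_samples)  # scores independent of y
    y_pred = scores >= lam                          # predicted positives
    positives = y == 1                              # actual positives
    return {
        "precision": float(np.mean(positives[y_pred])),
        "recall": float(np.mean(y_pred[positives])),
        "accuracy": float(np.mean(y_pred == positives)),
    }


for lam in (0.1, 0.5, 0.8):
    print(lam, theoretical_metrics(lam), simulated_metrics(lam))
```

With closed forms like these, a threshold returned by the controller can be checked against the target level directly, without drawing a large test set.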

After repeating the experiment `n_repeats` times, we compute the **proportion of errors**, which should not exceed `delta` = 1 - `confidence_level`.
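
As a minimal sketch of this final check, the hypothetical helper below (`summarize_errors` is an illustrative name, not a MAPIE function) turns a list of per-repetition error flags into the error proportion and compares it with `delta`. It computes `delta` with `Decimal`, as the notebook code does, so that values such as 1 - 0.8 do not pick up floating-point noise:

```python
from decimal import Decimal


def summarize_errors(error_flags, confidence_level):
    """Return (error proportion, delta, whether the proportion stays within delta)."""
    # Decimal arithmetic keeps 1 - confidence_level exact for decimal inputs.
    delta = float(Decimal("1") - Decimal(str(confidence_level)))
    proportion = sum(error_flags) / len(error_flags)
    return proportion, delta, proportion <= delta


# Example: 15 failed repetitions out of 100, with a confidence level of 0.8.
flags = [True] * 15 + [False] * 85
proportion, delta, is_valid = summarize_errors(flags, confidence_level=0.8)
print(proportion, delta, is_valid)
```

Because the proportion is estimated from a finite number of repetitions, a value close to `delta` may reflect sampling noise rather than a genuine violation.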

# Results
The risk is not controlled in all the test cases.

In [29]:
%reload_ext autoreload
%autoreload 2

In [30]:
from sklearn.datasets import make_classification
import numpy as np
from mapie.risk_control import precision, accuracy, recall, BinaryClassificationController
from itertools import product
from decimal import Decimal

In [31]:
# Define a simple logistic classifier
class LogisticClassifier:
    """Deterministic sigmoid-based binary classifier."""

    def __init__(self, scale=1.0, threshold=0.5):
        self.scale = scale
        self.threshold = threshold

    def _get_prob(self, x):
        """Probability of class 1 for input x."""
        return 1 / (1 + np.exp(-self.scale * x))

    def predict_proba(self, X):
        """Return probabilities [p(y=0), p(y=1)] for each sample in X."""
        probs = self._get_prob(np.asarray(X))  # vectorized sigmoid
        return np.vstack([1 - probs, probs]).T

    def predict(self, X):
        """Return predicted class labels based on threshold."""
        probs = self.predict_proba(X)[:, 1]
        return (probs >= self.threshold).astype(int)

In [32]:
# Parameter grid: every combination below defines one test case.
calibration_sizes = [100, 5]  # size of the calibration set
risks = [
    {"name": "precision", "risk": precision},
    {"name": "recall", "risk": recall},
    {"name": "accuracy", "risk": accuracy},
]
predict_params_grid = [np.linspace(0, 0.99, 100), np.array([0.5])]
target_levels = [0.1, 0.9]
confidence_levels = [0.8, 0.2]

n_repeats = 100
invalid_experiment = False

for N, risk, predict_params, target_level, confidence_level in product(
    calibration_sizes, risks, predict_params_grid, target_levels, confidence_levels
):
    # Decimal arithmetic avoids floating-point artifacts in 1 - x (e.g. 1 - 0.9).
    alpha = float(Decimal("1") - Decimal(str(target_level)))
    delta = float(Decimal("1") - Decimal(str(confidence_level)))

    clf = LogisticClassifier(scale=1.0, threshold=0.5)
    # Number of repetitions where the risk is not controlled, i.e. where at least
    # one threshold reported as valid by LTT is not actually valid.
    nb_errors = 0
    total_nb_valid_params = 0

    for _ in range(n_repeats):

        X_calibrate, y_calibrate = make_classification(
            n_samples=N,
            n_features=1,
            n_informative=1,
            n_redundant=0,
            n_repeated=0,
            n_classes=2,
            n_clusters_per_class=1,
            weights=[0.5, 0.5],
            flip_y=0,
            random_state=None
        )
        X_calibrate = X_calibrate.squeeze()

        controller = BinaryClassificationController(
            predict_function=clf.predict_proba,
            risk=risk["risk"],
            target_level=target_level,
            confidence_level=confidence_level,
        )
        controller._predict_params = predict_params
        controller.calibrate(X_calibrate, y_calibrate)
        valid_parameters = controller.valid_predict_params
        total_nb_valid_params += len(valid_parameters)

        # Check that every threshold declared valid by LTT actually controls the risk.
        # Instead of sampling a large test set, we rely on the known theoretical risk
        # of a random classifier. These calculations hold only for a balanced data generator.
        if risk["risk"] == precision or risk["risk"] == accuracy:
            # Precision and accuracy are treated as 0.5 at every threshold, so a target
            # above 0.5 cannot be met: any threshold declared valid is an error.
            if target_level > 0.5 and len(valid_parameters) >= 1:
                nb_errors += 1
        elif risk["risk"] == recall:
            # Recall is treated as 1 - threshold, so any threshold above
            # alpha = 1 - target_level violates the target.
            if any(x > alpha for x in valid_parameters):
                nb_errors += 1

    print(f"\n{N=}, {risk['name']=}, {len(predict_params)=}, {target_level=}, {confidence_level=}")

    print(f"Proportion of times the risk is not controlled: {nb_errors/n_repeats}")
    print(f"Delta: {delta}")
    print(f"Mean number of valid thresholds found per iteration: {int(np.round(total_nb_valid_params/n_repeats))}")

    if nb_errors/n_repeats <= delta:
        print("Valid experiment")
    else:
        print("Invalid experiment")
        invalid_experiment = True

print("\n\n\n")
if invalid_experiment:
    print("Some experiments failed.")
else:
    print("All good!")


N=100, risk['name']='precision', len(predict_params)=100, target_level=0.1, confidence_level=0.8
Proportion of times the risk is not controlled: 0.0
Delta: 0.2
Mean number of valid thresholds found per iteration: 85
Valid experiment

N=100, risk['name']='precision', len(predict_params)=100, target_level=0.1, confidence_level=0.2
Proportion of times the risk is not controlled: 0.0
Delta: 0.8
Mean number of valid thresholds found per iteration: 85
Valid experiment

N=100, risk['name']='precision', len(predict_params)=100, target_level=0.9, confidence_level=0.8
Proportion of times the risk is not controlled: 0.0
Delta: 0.2
Mean number of valid thresholds found per iteration: 0
Valid experiment

N=100, risk['name']='precision', len(predict_params)=100, target_level=0.9, confidence_level=0.2
Proportion of times the risk is not controlled: 0.6
Delta: 0.8
Mean number of valid thresholds found per iteration: 10
Valid experiment

N=100, risk['name']='precision', len(predict_params)=1, target_l