# Binary classification risk control - Theoretical tests to validate implementation

# Protocol description
Testing theoretical guarantees of risk control in binary classification using a random classifier and synthetic data.

Each test case is one combination of parameters, for which we repeat the experiment `n_repeats` times. The model is the same in every experiment (a classifier that outputs random scores), but the data is redrawn each time.

Each experiment consists of the following:
 - We calibrate a `BinaryClassificationController`. It gives us the list of lambda values (thresholds) that control the risk according to the Learn Then Test (LTT) procedure.
 - Because the model is random, we know the theoretical risk associated with each lambda value, so we can check whether the lambda values returned by LTT actually control the risk. If at least one of them does not, we count 1 "error". Note that *each* lambda value should control the risk, not just one of them.
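
The theoretical values used in the check can be sanity-checked with a quick simulation. The snippet below is a minimal sketch, not part of the test protocol: it assumes balanced labels and continuous uniform scores drawn independently of the labels (the notebook's classifier additionally rounds its scores to two decimals), and the helper `empirical_metrics` is a hypothetical name introduced for illustration. Under these assumptions, precision and accuracy should stay close to 0.5 for any threshold, and recall should be close to one minus the threshold.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000

# Balanced labels and scores drawn independently of the labels:
# this mimics a classifier that carries no information about y.
y = rng.integers(0, 2, size=n)
scores = rng.random(n)


def empirical_metrics(threshold):
    """Return (precision, recall, accuracy) of the rule `scores >= threshold`."""
    y_pred = scores >= threshold
    true_pos = np.sum(y_pred & (y == 1))
    precision = true_pos / np.sum(y_pred)
    recall = true_pos / np.sum(y == 1)
    accuracy = np.mean(y_pred == (y == 1))
    return precision, recall, accuracy


for threshold in (0.2, 0.5, 0.8):
    p, r, a = empirical_metrics(threshold)
    print(
        f"{threshold=}: precision={p:.3f}, recall={r:.3f} "
        f"(theory {1 - threshold:.1f}), accuracy={a:.3f}"
    )
```

This is why the check only needs the target level for precision and accuracy, but needs the threshold values themselves for recall.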

After `n_repeats` experiments, we compute the proportion of errors, which should be at most delta (1 - confidence_level).
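
Computing `1 - confidence_level` directly with floats can introduce a small rounding error, which matters when the result is later compared against other values. The sketch below shows the issue for one illustrative value and the `Decimal`-based workaround; the variable names `delta_float` and `delta_decimal` are introduced here for illustration only.

```python
from decimal import Decimal

confidence_level = 0.9

# Plain float subtraction picks up a rounding error.
delta_float = 1 - confidence_level

# Going through the decimal string representation gives the intended value.
delta_decimal = float(Decimal("1") - Decimal(str(confidence_level)))

print(delta_float)    # 0.09999999999999998
print(delta_decimal)  # 0.1
```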

# Results
The risk is controlled in all the test cases. Overall, LTT seems very conservative: to observe a high proportion of errors, we need to lower the confidence level drastically (to 0.01) and use a single threshold, which avoids the effect of the Bonferroni correction. This is likely due to the model being random and thus having a lot of variance. It would be interesting to see how this evolves with a better model.
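
To illustrate why the number of thresholds matters, here is a minimal sketch of a Bonferroni-corrected test. It assumes an exact binomial p-value on a 0/1 loss (such as misclassification); the library may compute its p-values differently, and the function names `binomial_p_value` and `is_declared_valid` are hypothetical. With `n_thresholds` candidate thresholds, each one is tested at level `delta / n_thresholds`, so the same calibration evidence can pass with one threshold and fail with a hundred.

```python
from math import comb


def binomial_p_value(n_misclassified, n, alpha):
    """P(Binomial(n, alpha) <= n_misclassified).

    p-value for the null hypothesis "the true risk is at least alpha".
    """
    return sum(
        comb(n, k) * alpha**k * (1 - alpha) ** (n - k)
        for k in range(n_misclassified + 1)
    )


def is_declared_valid(n_misclassified, n, alpha, delta, n_thresholds):
    """Bonferroni: each threshold is tested at level delta / n_thresholds."""
    return binomial_p_value(n_misclassified, n, alpha) <= delta / n_thresholds


n, alpha, delta = 100, 0.5, 0.2
for n_misclassified in (30, 40, 48):
    single = is_declared_valid(n_misclassified, n, alpha, delta, n_thresholds=1)
    many = is_declared_valid(n_misclassified, n, alpha, delta, n_thresholds=100)
    print(f"{n_misclassified=}: valid with 1 threshold: {single}, with 100: {many}")
```

In this toy setting, 40 misclassified samples out of 100 are enough to declare a single threshold valid, but not once the level is divided among 100 thresholds.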

In [80]:
%reload_ext autoreload
%autoreload 2

In [81]:
from sklearn.datasets import make_classification
import numpy as np
from mapie.risk_control import precision, accuracy, recall
from mapie.risk_control_draft import BinaryClassificationController
from itertools import product
from decimal import Decimal

In [82]:
# Using sklearn.dummy.DummyClassifier would be cleaner
class RandomClassifier:
    def __init__(self, seed=42, threshold=0.5):
        self.seed = seed
        self.threshold = threshold

    def _get_prob(self, x):
        local_seed = hash((x, self.seed)) % (2**32)
        rng = np.random.RandomState(local_seed)
        return np.round(rng.rand(), 2)

    def predict_proba(self, X):
        probs = np.array([self._get_prob(x) for x in X])
        return np.vstack([1 - probs, probs]).T

    def predict(self, X):
        probs = self.predict_proba(X)[:, 1]
        return (probs >= self.threshold).astype(int)

In [83]:
# Parameter grids. They are named differently from the loop variables
# unpacked below, so the loop does not shadow the lists it iterates over.
N_values = [100, 5]  # sizes of the calibration set
risks = [
    {"name": "precision", "risk": precision},
    {"name": "recall", "risk": recall},
    {"name": "accuracy", "risk": accuracy},
]
predict_params_values = [np.linspace(0, 0.99, 100), np.array([0.5])]
target_levels = [0.1, 0.9]
confidence_levels = [0.8, 0.2]

n_repeats = 100

for combination in product(N_values, risks, predict_params_values, target_levels, confidence_levels):
    N, risk, predict_params, target_level, confidence_level = combination
    alpha = float(Decimal("1") - Decimal(str(target_level))) # to avoid floating point issues
    delta = float(Decimal("1") - Decimal(str(confidence_level))) # to avoid floating point issues

    clf = RandomClassifier()
    # Number of iterations where the risk is not controlled, i.e. where at least
    # one threshold declared valid by LTT is not actually valid.
    nb_errors = 0
    total_nb_valid_params = 0

    for _ in range(n_repeats):

        X_calibrate, y_calibrate = make_classification(
            n_samples=N,
            n_features=1,
            n_informative=1,
            n_redundant=0,
            n_repeated=0,
            n_classes=2,
            n_clusters_per_class=1,
            weights=[0.5, 0.5],
            flip_y=0,
            random_state=None
        )
        X_calibrate = X_calibrate.squeeze()

        controller = BinaryClassificationController(
            predict_function=clf.predict_proba,
            risk=risk["risk"],
            target_level=target_level,
            confidence_level=confidence_level,
        )
        controller._thresholds = predict_params
        controller.calibrate(X_calibrate, y_calibrate)
        valid_parameters = controller.valid_thresholds
        total_nb_valid_params += len(valid_parameters)

        # In the following, we check that all the valid thresholds found by LTT actually control the risk.
        # Instead of sampling a large test set, we use the fact that we know the theoretical risk of a random classifier.
        # The calculations here are valid only for a balanced data generator.
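        if risk["risk"] == precision or risk["risk"] == accuracy:
            # Precision and accuracy of a random classifier are 0.5 whatever the threshold,
            # so any threshold declared valid for a target above 0.5 is an error.
            if target_level > 0.5 and len(valid_parameters) >= 1:
                nb_errors += 1
        elif risk["risk"] == recall:
            # Recall of a random classifier at threshold x is about 1 - x,
            # so a valid threshold must satisfy x <= alpha = 1 - target_level.
            if any(x > alpha for x in valid_parameters):
                nb_errors += 1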

    print(f"\n{N=}, {risk['name']=}, {len(predict_params)=}, {target_level=}, {confidence_level=}")

    print(f"Proportion of times the risk is not controlled: {nb_errors/n_repeats}")
    print(f"Delta: {delta}")
    print(f"Mean number of valid thresholds found per iteration: {int(np.round(total_nb_valid_params/n_repeats))}")

    if nb_errors/n_repeats <= delta:
        print("Valid experiment")
    else:
        print("Invalid experiment")


N=100, risk['name']='precision', len(predict_params)=1, target_level=0.5, confidence_level=0.01
Proportion of times the risk is not controlled: 0.0
Delta: 0.99
Mean number of valid thresholds found per iteration: 1
Valid experiment

N=100, risk['name']='precision', len(predict_params)=1, target_level=0.5, confidence_level=0.01
Proportion of times the risk is not controlled: 0.0
Delta: 0.99
Mean number of valid thresholds found per iteration: 1
Valid experiment

N=100, risk['name']='precision', len(predict_params)=1, target_level=0.45, confidence_level=0.01
Proportion of times the risk is not controlled: 0.0
Delta: 0.99
Mean number of valid thresholds found per iteration: 1
Valid experiment

N=100, risk['name']='precision', len(predict_params)=1, target_level=0.45, confidence_level=0.01
Proportion of times the risk is not controlled: 0.0
Delta: 0.99
Mean number of valid thresholds found per iteration: 1
Valid experiment

N=100, risk['name']='precision', len(predict_params)=1, target_le