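# Binary classification risk control – theoretical tests prototype

This notebook repeatedly calibrates a `BinaryClassificationController` (Learn Then Test, LTT) on freshly generated data. It uses a random classifier whose truly valid thresholds are known analytically. It then checks that the fraction of runs in which LTT returns an invalid threshold stays below `1 - confidence_level`.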

In [1]:
%reload_ext autoreload
%autoreload 2

In [8]:
from sklearn.datasets import make_classification
import numpy as np
import itertools
from matplotlib import pyplot as plt

from mapie.risk_control_draft import BinaryClassificationController

In [9]:
class RandomClassifier:
    """Dummy binary classifier with pseudo-random but deterministic scores.

    The positive-class score of a sample is drawn from a ``RandomState`` seeded
    with a hash of ``(sample, seed)`` and rounded to 2 decimals. The same sample
    therefore always gets the same score, and scores never depend on any label.

    Parameters
    ----------
    seed : int
        Mixed into the per-sample hash; different seeds give different scores.
    threshold : float
        Decision threshold applied to the positive-class score by ``predict``.
    """

    def __init__(self, seed=42, threshold=0.5):
        self.seed = seed
        self.threshold = threshold

    @staticmethod
    def _sample_key(x):
        """Return a hashable key for one sample (a scalar or a feature vector)."""
        values = np.asarray(x)
        if values.size == 1:
            # Scalars and single-feature rows share the same key (and so the same
            # score), so squeezed 1D input and (n, 1) input behave identically.
            return values.item()
        # Multi-feature rows: arrays are unhashable, so hash their values as a tuple.
        return tuple(values.ravel().tolist())

    def _get_prob(self, x):
        """Deterministic pseudo-random positive-class score in [0, 1], rounded to 2 decimals."""
        local_seed = hash((self._sample_key(x), self.seed)) % (2**32)
        rng = np.random.RandomState(local_seed)
        return np.round(rng.rand(), 2)

    def predict_proba(self, X):
        """Return an ``(n_samples, 2)`` array of ``[P(class 0), P(class 1)]``.

        ``X`` may be a 1D array of scalar samples or a 2D
        ``(n_samples, n_features)`` array.
        """
        probs = np.array([self._get_prob(x) for x in X])
        return np.column_stack([1 - probs, probs])

    def predict(self, X):
        """Return 0/1 labels: 1 where the positive-class score is >= ``self.threshold``."""
        probs = self.predict_proba(X)[:, 1]
        return (probs >= self.threshold).astype(int)

In [10]:
# --- Data-generation settings ---
N = 100  # size of the calibration set
p = 0.5  # proportion of positives in the calibration set

# --- Risk-control settings ---
metric = "recall"  # metric whose value must reach target_level
target_level = 0.8
confidence_level = 0.7  # required probability that the risk is controlled

# Candidate decision thresholds: 0.00, 0.01, ..., 0.99
predict_params = np.linspace(start=0, stop=0.99, num=100)

# --- Monte-Carlo settings ---
n_repeats = 100  # number of independent calibration sets drawn

In [11]:
# Repeatedly calibrate the controller on fresh data and count how often LTT returns
# a threshold that is invalid in theory (i.e. the risk is not controlled).
if metric not in ("precision", "recall", "accuracy"):
    raise ValueError(f"No theoretical reference implemented for metric {metric!r}")

clf = RandomClassifier()
data_rng = np.random.RandomState(0)  # seeds make_classification so the experiment is reproducible
nb_errors = 0  # number of iterations where the risk is not controlled (i.e., at least one threshold returned by LTT is not actually valid)
no_valid_params = 0  # number of iterations where LTT finds no valid threshold
nb_valid_params = 0  # total number of valid thresholds LTT finds over all iterations


def theoretical_metric(threshold):
    """Population value of ``metric`` for ``clf`` at a given decision threshold.

    Assumes ``RandomClassifier`` scores are ``round(U, 2)`` with ``U ~ Uniform[0, 1)``,
    drawn independently of the label. With ``q = P(score >= threshold)``:
    recall = q, precision = p, accuracy = p * q + (1 - p) * (1 - q).
    """
    # Smallest score on the 0.01 grid that is >= threshold (inner round absorbs float noise).
    grid_threshold = np.ceil(np.round(threshold * 100, 6)) / 100
    # round(U, 2) >= g  <=>  U >= g - 0.005, for 0 < g <= 1
    if grid_threshold <= 0:
        q = 1.0
    elif grid_threshold > 1:
        q = 0.0
    else:
        q = 1.0 - grid_threshold + 0.005
    if metric == "recall":
        return q
    if metric == "precision":
        return p
    return p * q + (1 - p) * (1 - q)  # accuracy


for _ in range(n_repeats):

    # flip_y=0: with label noise the class proportions no longer follow `weights`,
    # so the positive rate would drift away from p (except for p = 0.5).
    X_calibrate, y_calibrate = make_classification(
        n_samples=N,
        n_features=1,
        n_informative=1,
        n_redundant=0,
        n_repeated=0,
        n_classes=2,
        n_clusters_per_class=1,
        weights=[1 - p, p],
        flip_y=0,
        random_state=data_rng,
    )
    X_calibrate = X_calibrate.squeeze()

    controller = BinaryClassificationController(
        fitted_binary_classifier=clf,
        metric=metric,
        target_level=target_level,
        confidence_level=confidence_level,
    )
    # NOTE(review): overrides a private attribute to impose the candidate threshold grid;
    # confirm this stays supported by the controller's API.
    controller._thresholds = predict_params
    controller.calibrate(X_calibrate, y_calibrate)
    valid_parameters = controller.valid_thresholds

    nb_valid_params += len(valid_parameters)

    if len(valid_parameters) == 0:
        no_valid_params += 1

    # Risk not controlled: at least one returned threshold misses the target in theory.
    if any(theoretical_metric(t) < target_level for t in valid_parameters):
        nb_errors += 1

print(f"Mean number of valid thresholds found per iteration: {nb_valid_params/n_repeats}")
print(f"Proportion of times LTT finds no valid threshold: {no_valid_params/n_repeats}")
print(f"Proportion of times the risk is not controlled: {nb_errors/n_repeats}")

if nb_errors/n_repeats <= 1 - confidence_level:
    print("Risk controlled")
else:
    print("Risk not controlled")

Mean number of valid thresholds found per iteration: 5.82
Proportion of times LTT finds no valid threshold: 0.0
Proportion of times the risk is not controlled: 0.0
Risk controlled
