# Binary classification risk control - Theoretical tests prototype

In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up live, without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [3]:
import numpy as np
import itertools
from matplotlib import pyplot as plt
from collections import Counter

from mapie.risk_control_draft import BinaryClassificationController

In [4]:
class RandomClassifier:
    """Test-double binary classifier whose scores ignore the labels.

    The positive-class score of a sample is drawn from a ``RandomState``
    seeded with a hash of ``(sample, seed)``, so the same sample and seed
    always map to the same score (samples must therefore be hashable).
    """

    def __init__(self, seed=42, threshold=0.5):
        """Store the base seed and the decision threshold used by ``predict``."""
        self.seed, self.threshold = seed, threshold

    def _get_prob(self, x):
        """Return the positive-class score of ``x``, rounded to 2 decimals."""
        # Fold the hash into the 32-bit range before seeding RandomState.
        sample_seed = hash((x, self.seed)) % (2**32)
        return np.round(np.random.RandomState(sample_seed).rand(), 2)

    def predict_proba(self, X):
        """Return an array of shape (len(X), 2): columns are [P(0), P(1)]."""
        positive = np.array([self._get_prob(sample) for sample in X])
        negative = 1 - positive
        return np.vstack([negative, positive]).T

    def predict(self, X):
        """Return 1 where the positive score is >= ``threshold``, else 0."""
        positive = self.predict_proba(X)[:, 1]
        return (positive >= self.threshold).astype(int)

In [26]:
# --- Experiment configuration -------------------------------------------
N = 200  # size of the calibration set
p = 0.5  # proportion of positives in the calibration set
metric = "precision"  # metric to control: "precision" or "recall"
target_level = 0.98  # level the controlled metric must reach
# Grid of candidate thresholds: 0.00, 0.01, ..., 0.99 (100 points).
# The previous num=1 collapsed the grid to the single value 0.0, which made
# the ground-truth set of valid thresholds wrong whenever it is non-empty.
predict_params = np.linspace(0, 0.99, 100)
# Confidence level handed to the controller; the final check compares the
# observed error rate against 1 - confidence_level.
confidence_level = 0.1

n_repeats = 100  # number of independent calibration runs

In [27]:
clf = RandomClassifier()
# Number of iterations where the risk is not controlled, i.e. where at least
# one threshold returned by LTT is not actually valid.
nb_errors = 0
no_valid_params = 0  # number of iterations where LTT finds no valid threshold
nb_valid_params = 0  # total number of valid thresholds LTT finds over all iterations

# Ground truth: thresholds that are valid by construction of RandomClassifier,
# whose scores do not depend on the labels.
if metric == "precision":
    # Every threshold is treated as valid iff the target does not exceed the
    # proportion of positives; otherwise none is.
    if target_level <= p:
        actual_valid_parameters = predict_params
    else:
        actual_valid_parameters = []
elif metric == "recall":
    actual_valid_parameters = predict_params[predict_params <= np.round(1-target_level, 2)]
else:
    # Fail early instead of hitting a NameError inside the loop below.
    raise ValueError(f"Unsupported metric: {metric!r}")


def _is_actually_valid(threshold):
    """Return True if ``threshold`` matches a ground-truth valid threshold.

    Uses a tolerance instead of exact float equality so that representation
    noise between two threshold grids cannot produce spurious errors.
    """
    return bool(np.isclose(actual_valid_parameters, threshold).any())


# Seeded local generator so that the label shuffles (the only source of
# randomness across repetitions) are reproducible on Restart & Run All.
rng = np.random.default_rng(42)
n_positives = int(p*N)

for _ in range(n_repeats):

    X_calibrate = list(range(1, N+1))
    y_calibrate = [1] * n_positives + [0] * (N - n_positives)
    rng.shuffle(y_calibrate)

    controller = BinaryClassificationController(
        fitted_binary_classifier=clf,
        metric=metric,
        target_level=target_level,
        confidence_level=confidence_level,
    )
    controller.calibrate(X_calibrate, y_calibrate)
    valid_parameters = controller.valid_thresholds

    nb_valid_params += len(valid_parameters)

    if len(valid_parameters) == 0:
        no_valid_params += 1

    # An iteration is an error as soon as one returned threshold is not in
    # the ground-truth set.
    if not all(_is_actually_valid(x) for x in valid_parameters):
        nb_errors += 1

print(f"Mean number of valid thresholds found per iteration: {nb_valid_params/n_repeats}")
print(f"Proportion of times LTT finds no valid threshold: {no_valid_params/n_repeats}")
print(f"Proportion of times the risk is not controlled: {nb_errors/n_repeats}")

# NOTE(review): this check assumes the controller's guarantee holds with
# probability >= confidence_level - confirm against BinaryClassificationController.
if nb_errors/n_repeats <= 1 - confidence_level:
    print("Risk controlled")
else:
    print("Risk not controlled")

Mean number of valid thresholds found per iteration: 0.0
Proportion of times LTT finds no valid threshold: 1.0
Proportion of times the risk is not controlled: 0.0
Risk controlled
