# Binary classification risk control - Theoretical tests prototype

In [20]:
%reload_ext autoreload
%autoreload 2

In [21]:
from sklearn.datasets import make_classification
import numpy as np
from mapie.risk_control import precision, accuracy, recall
from mapie.risk_control_draft import BinaryClassificationController
from itertools import product

In [22]:
# NOTE(review): sklearn.dummy.DummyClassifier was suggested here as a clearer
# alternative, but its predict_proba returns constant ("prior"/"uniform") or
# one-hot ("stratified") rows rather than continuous per-sample scores, so it
# is not a drop-in replacement for this class -- confirm before swapping.
class RandomClassifier:
    """Label-agnostic binary classifier with deterministic per-sample scores.

    The positive-class score of a sample ``x`` is drawn from a
    ``np.random.RandomState`` seeded with ``hash((x, seed)) % 2**32`` and
    rounded to two decimals. The same ``x`` therefore always receives the
    same score. There is no ``fit``: labels are never used.

    Parameters
    ----------
    seed : int, default 42
        Salt mixed into every per-sample hash.
    threshold : float, default 0.5
        ``predict`` returns 1 for samples whose positive score is
        ``>= threshold``.
    """

    def __init__(self, seed=42, threshold=0.5):
        self.seed, self.threshold = seed, threshold

    def _get_prob(self, x):
        """Return the two-decimal positive-class score for one hashable sample."""
        sample_rng = np.random.RandomState(hash((x, self.seed)) % 2**32)
        return np.round(sample_rng.rand(), 2)

    def predict_proba(self, X):
        """Return an ``(n_samples, 2)`` array whose rows are ``[1 - p, p]``.

        Every element of ``X`` is hashed, so it must be hashable
        (e.g. the scalars of a 1-D array; a row of a 2-D array is not).
        """
        positive = np.array([self._get_prob(sample) for sample in X])
        return np.column_stack((1 - positive, positive))

    def predict(self, X):
        """Return integer labels: 1 where the positive score is >= ``threshold``, else 0."""
        scores = self.predict_proba(X)[:, 1]
        return (scores >= self.threshold).astype(int)

In [38]:
# Grid of experimental settings. The grids have their own names so that the
# per-combination loop variables below (N, risk, predict_params, target_level,
# confidence_level) do not overwrite them; otherwise N would be an int after
# one run and re-running this cell would fail inside product().
N_grid = [100, 5]  # sizes of the calibration set
risks = [
    {"name": "precision", "risk": precision},
    {"name": "recall", "risk": recall},
    {"name": "accuracy", "risk": accuracy},
]
predict_params_grid = [np.linspace(0, 0.99, 100), np.array([0.5])]  # candidate threshold grids
target_levels = [0.1, 0.9]
confidence_levels = [0.8, 0.2]

n_repeats = 100
# Indices (position in the product() enumeration) of combinations whose
# empirical failure rate exceeded 1 - confidence_level.
invalid_experiments = []


def _risk_not_controlled(risk_fn, target_level, valid_parameters):
    """Return True if some threshold reported as valid cannot actually be valid.

    These checks assume the calibration data are balanced and the classifier
    is uninformative (scores unrelated to labels). They also treat
    ``target_level`` as the minimum acceptable value of the metric:

    - precision / accuracy: expected to stay near 0.5 for every threshold, so
      no threshold should be reported valid when ``target_level > 0.5``;
    - recall: expected to be about ``1 - threshold``, so only thresholds
      ``<= round(1 - target_level, 2)`` can be valid.

    Risks not listed above are never counted as violations.
    """
    if risk_fn is precision or risk_fn is accuracy:
        return target_level > 0.5 and len(valid_parameters) >= 1
    if risk_fn is recall:
        max_valid_threshold = np.round(1 - target_level, 2)
        # any() is False for an empty sequence: no valid threshold is not an error.
        return any(t < 0 or t > max_valid_threshold for t in valid_parameters)
    return False


for i, combination in enumerate(
    product(N_grid, risks, predict_params_grid, target_levels, confidence_levels)
):
    N, risk, predict_params, target_level, confidence_level = combination

    clf = RandomClassifier()
    # Number of repeats where the risk is not controlled, i.e. not all the
    # thresholds returned as valid by LTT are actually valid.
    nb_errors = 0

    for _ in range(n_repeats):
        # Fresh balanced (weights=[0.5, 0.5]), label-noise-free (flip_y=0)
        # calibration set with a single feature.
        X_calibrate, y_calibrate = make_classification(
            n_samples=N,
            n_features=1,
            n_informative=1,
            n_redundant=0,
            n_repeated=0,
            n_classes=2,
            n_clusters_per_class=1,
            weights=[0.5, 0.5],
            flip_y=0,
            random_state=None
        )
        # RandomClassifier hashes each sample, so feed it scalars, not 1-element rows.
        X_calibrate = X_calibrate.squeeze()

        controller = BinaryClassificationController(
            predict_function=clf.predict_proba,
            risk=risk["risk"],
            target_level=target_level,
            confidence_level=confidence_level,
        )
        # NOTE(review): overrides a private attribute to impose the candidate
        # threshold grid; relies on BinaryClassificationController internals.
        controller._thresholds = predict_params
        controller.calibrate(X_calibrate, y_calibrate)
        valid_parameters = controller.valid_thresholds

        if _risk_not_controlled(risk["risk"], target_level, valid_parameters):
            nb_errors += 1

    error_rate = nb_errors / n_repeats
    print(f"Proportion of times the risk is not controlled: {error_rate}")
    print(f"Risk level: {1-confidence_level}")

    # A combination is flagged when the empirical failure rate exceeds the
    # allowed probability of not controlling the risk, 1 - confidence_level.
    if error_rate > 1 - confidence_level:
        print("Unvalid experiment")
        print(f"{N=} {risk['name']=} {predict_params=} {target_level=} {confidence_level=}")
        invalid_experiments.append(i)

Proportion of times the risk is not controlled: 0.0
Risk level: 0.19999999999999996
Proportion of times the risk is not controlled: 0.0
Risk level: 0.8
Proportion of times the risk is not controlled: 1.0
Risk level: 0.19999999999999996
Unvalid experiment
N=100 risk['name']='precision' predict_params=array([0.  , 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1 ,
       0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.2 , 0.21,
       0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.3 , 0.31, 0.32,
       0.33, 0.34, 0.35, 0.36, 0.37, 0.38, 0.39, 0.4 , 0.41, 0.42, 0.43,
       0.44, 0.45, 0.46, 0.47, 0.48, 0.49, 0.5 , 0.51, 0.52, 0.53, 0.54,
       0.55, 0.56, 0.57, 0.58, 0.59, 0.6 , 0.61, 0.62, 0.63, 0.64, 0.65,
       0.66, 0.67, 0.68, 0.69, 0.7 , 0.71, 0.72, 0.73, 0.74, 0.75, 0.76,
       0.77, 0.78, 0.79, 0.8 , 0.81, 0.82, 0.83, 0.84, 0.85, 0.86, 0.87,
       0.88, 0.89, 0.9 , 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98,
       0.99]) target_level=0.9 confidence_

In [39]:
print(invalid_experiments)

[2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23, 30, 38, 42, 46, 47]


In [35]:
[2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23, 30, 38, 42, 46, 47]

[2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23, 30, 38, 42, 46, 47]

In [40]:
print(i)

47


In [41]:
print(len(invalid_experiments))

17
