# Binary classification risk control - Theoretical tests to validate implementation

# Protocol description
We test the theoretical guarantees of risk control in binary classification by using a logistic classifier and synthetic data. The aim is to evaluate the effectiveness of the BinaryClassificationController in maintaining a predefined risk level under different conditions.

Each test case corresponds to a unique combination of parameters. We repeat the experiment `n_repeats` times for each combination. The model remains the same across experiments, while the data is resampled for each repetition to account for variability.

Each experiment consists of the following steps:  

- **Calibrate the controller**  
  - We use a **BinaryClassificationController**, which returns the list of lambda values (prediction thresholds) that the **Learn Then Test (LTT)** procedure certifies as controlling the risk.  

- **Verify risk control**  
  - Since the data-generating process is known, we can sample a fresh test set and **estimate the risk** associated with each lambda value.  
  - We then check whether each lambda value from LTT actually controls the risk.  
  - If a lambda does not meet the risk guarantee, we count **one "error"**.  
  - **Note:** every lambda value must individually control the risk — it is not enough for only some to succeed.  

After repeating the experiment `n_repeats` times, we calculate the **proportion of errors**, which should remain below `delta` = 1 - `confidence_level`.

# Results
The LTT procedure generally controls the risk as expected, with most experiments marked as valid. However, experiments are mainly invalidated when the target level is set too high (e.g. 0.9) or when the calibration set is too small, resulting in an insufficient number of valid thresholds.

In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
from sklearn.datasets import make_classification
from sklearn.metrics import precision_score, recall_score, accuracy_score
from sklearn.utils import check_random_state
import numpy as np
import matplotlib.pyplot as plt
from mapie.risk_control import precision, accuracy, recall, BinaryClassificationController
from itertools import product
from decimal import Decimal

In [None]:
# Define a simple logistic classifier
class LogisticClassifier:
    """Deterministic sigmoid-based binary classifier.

    Nothing is fitted: the class-1 probability is a sigmoid of ``scale * x``
    rescaled so that it lies in ``[0.1, 1.0]``.

    Parameters
    ----------
    scale : float, default=1.0
        Slope applied to the input before the sigmoid.
    threshold : float, default=0.5
        Probability cut-off used by ``predict``.
    """

    def __init__(self, scale=1.0, threshold=0.5):
        self.scale = scale
        self.threshold = threshold

    def _get_prob(self, x):
        """Probability of class 1 for input x (scalar or NumPy array)."""
        inf_, sup_ = 0.1, 1.0
        return (sup_ - inf_) / (1 + np.exp(-self.scale * x)) + inf_

    def predict_proba(self, X):
        """Return probabilities [p(y=0), p(y=1)] for each sample in X."""
        # _get_prob is element-wise, so a single vectorized call replaces the
        # per-sample Python loop.
        probs = self._get_prob(np.asarray(X, dtype=float))
        return np.vstack([1 - probs, probs]).T

    def predict(self, X):
        """Return predicted class labels (0/1) based on ``threshold``."""
        probs = self.predict_proba(X)[:, 1]
        return (probs >= self.threshold).astype(int)

In [None]:
# Function to generate logistic data
def make_logistic_data(n_samples=200, scale=2.0, random_state=None):
    """Sample a 1D binary dataset whose labels follow a logistic model.

    Features are drawn uniformly on [-3, 3]; each label is a Bernoulli draw
    with success probability ``sigmoid(scale * x)``.

    Parameters
    ----------
    n_samples : int, default=200
        Number of samples to draw.
    scale : float, default=2.0
        Slope of the logistic model linking the feature to the label.
    random_state : optional
        Forwarded to ``check_random_state`` to obtain the random generator.

    Returns
    -------
    tuple
        ``(X, y)``: the sampled features and their binary labels.
    """
    generator = check_random_state(random_state)
    features = generator.uniform(-3, 3, size=n_samples)
    # Sigmoid of the logit scale * x gives P(y = 1 | x)
    p_positive = 1 / (1 + np.exp(-(scale * features)))
    labels = generator.binomial(1, p_positive)
    return features, labels

In [None]:
# Experiment grid: each element of the Cartesian product of the lists below
# is one test case.
N = [100, 5]  # size of the calibration set
risk = [
    {
        "name": "precision",
        "risk": precision,
        "risk_score": precision_score,
        "score_kwargs": {"zero_division": 0},
    },
    {
        "name": "recall",
        "risk": recall,
        "risk_score": recall_score,
        "score_kwargs": {"zero_division": 0},
    },
    {
        "name": "accuracy",
        "risk": accuracy,
        "risk_score": accuracy_score,
        "score_kwargs": {},
    },
]
predict_params = [np.linspace(0, 0.99, 100), np.array([0.5])]
target_level = [0.1, 0.9]
confidence_level = [0.8, 0.2]

n_repeats = 100
n_test_samples = 1000  # size of the fresh test set used to estimate the metric
invalid_experiment = False

# The model is the same for every experiment; only the data is resampled.
clf = LogisticClassifier(scale=2.0, threshold=0.5)

# The loop variables have their own names so that the grid lists above are not
# overwritten, which keeps the cell re-runnable.
for n_calib, rsk, params, tgt, conf in product(
    N, risk, predict_params, target_level, confidence_level
):
    delta = float(Decimal("1") - Decimal(str(conf)))  # to avoid floating point issues
    higher_is_better = rsk["risk"].higher_is_better

    # Number of repetitions where the risk is not controlled, i.e. where at
    # least one threshold returned by LTT violates the target on the test set.
    nb_errors = 0
    total_nb_valid_params = 0

    for _ in range(n_repeats):
        X_calibrate, y_calibrate = make_logistic_data(
            n_samples=n_calib, scale=2.0, random_state=None
        )

        controller = BinaryClassificationController(
            predict_function=clf.predict_proba,
            risk=rsk["risk"],
            target_level=tgt,
            confidence_level=conf,
        )
        # NOTE(review): the candidate thresholds are imposed by overwriting a
        # private attribute of the controller.
        controller._predict_params = params
        controller.calibrate(X_calibrate, y_calibrate)
        valid_parameters = controller.valid_predict_params
        total_nb_valid_params += len(valid_parameters)

        # Check that every threshold returned by LTT actually controls the
        # risk. The test set is drawn from the SAME distribution as the
        # calibration set; otherwise the guarantee would not apply.
        X_test, y_test = make_logistic_data(
            n_samples=n_test_samples, scale=2.0, random_state=None
        )
        probs = clf.predict_proba(X_test)[:, 1]

        # With no valid threshold the loop body never runs: nothing was
        # certified, so the repetition is not counted as an error.
        for lambda_ in valid_parameters:
            y_pred = (probs >= lambda_).astype(int)
            empirical_metric = rsk["risk_score"](y_test, y_pred, **rsk["score_kwargs"])

            # Meeting the target exactly counts as controlled in both directions.
            if higher_is_better:
                risk_violated = empirical_metric < tgt
            else:
                risk_violated = empirical_metric > tgt

            if risk_violated:
                # One failing threshold is enough to invalidate the repetition.
                nb_errors += 1
                break

    print(
        f"\nN={n_calib!r}, risk['name']={rsk['name']!r}, "
        f"len(predict_params)={len(params)!r}, target_level={tgt!r}, "
        f"confidence_level={conf!r}"
    )

    print(f"Proportion of times the risk is not controlled: {nb_errors/n_repeats}")
    print(f"Delta: {delta}")
    print(f"Mean number of valid thresholds found per iteration: {int(np.round(total_nb_valid_params/n_repeats))}")

    if nb_errors/n_repeats <= delta:
        print("Valid experiment")
    else:
        print("Invalid experiment")
        invalid_experiment = True

print("\n\n\n")
if invalid_experiment:
    print("Some experiments failed.")
else:
    print("All good!")

# Analysis of the distribution of the supremum of the risk

In [None]:
# Experiment parameters
risk = [
    {"name": "precision", "risk": precision, "risk_score": precision_score},
    {"name": "recall", "risk": recall, "risk_score": recall_score},
    {"name": "accuracy", "risk": accuracy, "risk_score": accuracy_score},
]
target_level = [0.1, 0.9]
confidence_level = [0.8, 0.2]

# Fixed classifier
clf = LogisticClassifier(scale=2.0, threshold=0.5)

# Number of experiment repetitions
n_repeats = 100

# Candidate thresholds, identical for every repetition
lambda_grid = np.linspace(0, 0.99, 100)

# Storage list: one entry per (risk, target, confidence) combination
results = []

for rsk, tgt, conf in product(risk, target_level, confidence_level):
    print(f"\nRunning {rsk['name']=}, {tgt=}, {conf=}")

    # Both lists receive exactly one value per repetition so they stay aligned.
    min_metrics = []
    best_lambdas = []

    # accuracy_score does not accept zero_division, unlike precision/recall.
    score_kwargs = {} if rsk["risk"] == accuracy else {"zero_division": 0}

    for _ in range(n_repeats):
        # Data generation: calibration and test sets from the same distribution
        X_calibrate, y_calibrate = make_logistic_data(
            n_samples=1000, scale=2.0, random_state=None
        )
        X_test, y_test = make_logistic_data(
            n_samples=1000, scale=2.0, random_state=None
        )

        # Calibration
        controller = BinaryClassificationController(
            predict_function=clf.predict_proba,
            risk=rsk["risk"],
            target_level=tgt,
            confidence_level=conf,
        )
        # NOTE(review): the candidate thresholds are imposed by overwriting a
        # private attribute of the controller.
        controller._predict_params = lambda_grid
        controller.calibrate(X_calibrate, y_calibrate)
        valid_parameters = controller.valid_predict_params

        if len(valid_parameters) == 0:
            # No certified threshold: record a single NaN in each list. NaN
            # (rather than -1) is what the plotting cell filters out.
            min_metrics.append(np.nan)
            best_lambdas.append(np.nan)
            continue

        probs = clf.predict_proba(X_test)[:, 1]
        empirical_metric_list = np.array([
            rsk["risk_score"](y_test, (probs >= lambda_).astype(int), **score_kwargs)
            for lambda_ in valid_parameters
        ])

        # The smallest metric over the valid thresholds is the worst case
        # (plotted as the supremum of the risk, 1 - metric).
        min_idx = np.argmin(empirical_metric_list)
        min_metrics.append(empirical_metric_list[min_idx])
        best_lambdas.append(valid_parameters[min_idx])

    results.append({
        "risk": rsk["name"],
        "target": tgt,
        "confidence": conf,
        "min_metrics": np.array(min_metrics),
        "best_lambdas": np.array(best_lambdas)
    })


In [None]:
# Plotting results
risks = ["precision", "recall", "accuracy"]
unique_targets = sorted({res["target"] for res in results})
unique_confidences = sorted({res["confidence"] for res in results})

# One column per (target, confidence) pair; at least one column so that the
# figure can still be created when `results` is empty.
panels = list(product(unique_targets, unique_confidences))
n_cols = max(1, len(panels))


def _drop_missing(values):
    """Return ``values`` as a float array without NaN or negative entries.

    Metrics and thresholds are non-negative, so NaN and negative values (such
    as a -1 placeholder) can only mark repetitions without a valid threshold.
    """
    values = np.asarray(values, dtype=float)
    values = values[~np.isnan(values)]
    return values[values >= 0]


for risk_name in risks:
    res_risk = [r for r in results if r["risk"] == risk_name]

    fig, axes = plt.subplots(
        2, n_cols, figsize=(3.5 * n_cols, 7), sharex=False, squeeze=False
    )
    fig.suptitle(f"{risk_name.capitalize()} — Risk Control Visualization", fontsize=14, fontweight="bold")

    # Explanatory text under the title. The shaded region starts at the
    # (1−δ)-quantile, so it holds the upper δ fraction of the distribution.
    legend_text = (
        "Shaded area: upper δ-tail of sup R(λ), i.e. above its (1−δ)-quantile. "
        "Risk controlled if the (1−δ)-quantile lies below α."
    )
    fig.text(0.5, 0.93, legend_text, ha="center", va="bottom", fontsize=9, color="gray")

    for idx, (t, c) in enumerate(panels):
        ax_top = axes[0, idx]
        ax_bottom = axes[1, idx]

        entry = next(
            (r for r in res_risk if r["target"] == t and r["confidence"] == c),
            None,
        )
        if entry is None:
            # No result for this combination: draw empty panels
            sup_r = np.array([])
            lambdas = np.array([])
        else:
            sup_r = 1 - _drop_missing(entry["min_metrics"])
            lambdas = _drop_missing(entry["best_lambdas"])

        # Top: histogram of the supremum of the risk
        ax_top.hist(sup_r, bins=30, alpha=0.7, color="steelblue", edgecolor="white")
        alpha_handle = ax_top.axvline(1 - t, color="red", linestyle="--", linewidth=2)
        if sup_r.size > 0:
            # (1 − δ) * 100 with δ = 1 − c is simply c * 100
            quantile = np.percentile(sup_r, c * 100)
            tail_handle = ax_top.axvspan(quantile, np.max(sup_r), color="salmon", alpha=0.3)
        else:
            # Zero-width span so the legend still has an entry
            tail_handle = ax_top.axvspan(0, 0, color="salmon", alpha=0.3)

        ax_top.set_title(f"target={t}, conf={c}", fontsize=10)
        ax_top.set_ylabel("Freq.")
        # Explicit handles tie each label to its artist instead of relying on
        # the order in which artists were added to the axes.
        ax_top.legend(
            [alpha_handle, tail_handle],
            ["$\\alpha = 1 - target$", "upper δ-tail"],
            loc="upper center", frameon=False, fontsize=8
        )

        # Bottom: λ histogram
        ax_bottom.hist(lambdas, bins=30, alpha=0.7, color="darkorange", edgecolor="white")
        ax_bottom.set_xlabel("λ (threshold)")
        ax_bottom.set_ylabel("Freq.")

    # Hide unused subplots (only happens when there is no panel at all)
    for j in range(len(panels), n_cols):
        axes[0, j].axis("off")
        axes[1, j].axis("off")

    plt.tight_layout(rect=[0, 0, 1, 0.9])
    plt.show()


In [None]:
# Draw a large two-class, single-feature sample with scikit-learn's generator
X_calibrate, y_calibrate = make_classification(
    n_samples=10000,
    n_features=1,
    n_informative=1,
    n_redundant=0,
    n_repeated=0,
    n_classes=2,
    n_clusters_per_class=1,
    weights=[0.5, 0.5],
    flip_y=0,
    random_state=None
)
X_calibrate = X_calibrate.squeeze()

# One horizontal strip of points per class: x is the feature, y the class index
plt.figure(figsize=(8, 3))
for class_id, colour, class_label in ((0, "blue", "Class 0"), (1, "red", "Class 1")):
    x_class = X_calibrate[y_calibrate == class_id]
    plt.scatter(x_class, np.full_like(x_class, class_id),
                color=colour, alpha=0.7, label=class_label)

# Name the two strips on the y-axis
plt.yticks([0, 1], ["Class 0", "Class 1"])
plt.xlabel("Feature value (X)")
plt.title("1D Binary Classification Dataset")
plt.legend()
plt.grid(True, linestyle="--", alpha=0.5)
plt.tight_layout()
plt.show()


In [None]:
# Compare sampled labels with the classifier's class-1 probability curve
X, y = make_logistic_data(n_samples=1000, scale=2.0, random_state=None)
clf = LogisticClassifier(scale=2.0, threshold=0.5)

# Sort the inputs so the probability curve is drawn from left to right
order = np.argsort(X)
X_sorted = X[order]
# Public API instead of the private _get_prob helper; column 1 is p(y=1)
probs_sorted = clf.predict_proba(X_sorted)[:, 1]


plt.figure(figsize=(7, 3))
plt.scatter(X, y, c=y, cmap="coolwarm", alpha=0.5, edgecolor="none", label="Samples")
plt.plot(X_sorted, probs_sorted, color="black", linewidth=2, label="Classifier p(y=1 | x)")
# Axis labels and a legend make the figure readable without the code
plt.xlabel("Feature value (X)")
plt.ylabel("Label / probability")
plt.legend()
plt.tight_layout()
plt.show()
