In [1]:
import numpy as np

import plotly.express as px
import plotly.graph_objects as go
import os
import plotly.io as pio
from plotly.subplots import make_subplots

from scipy.special import expit

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score as sk_balanced_accuracy_score
from sklearn.model_selection import train_test_split

# Plotly setup: the renderer name can be overridden with the PLOTLY_RENDERER
# environment variable and falls back to "notebook" otherwise.
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")
pio.templates.default = "plotly_white"

# Compact array printing: 4 decimals, no scientific notation for small values.
np.set_printoptions(suppress=True, precision=4)

# One seeded generator shared by every stochastic cell below.
rng = np.random.default_rng(seed=42)


In [2]:
def accuracy_score_np(y_true, y_pred, sample_weight=None) -> float:
    """Share of positions where the prediction equals the target.

    Parameters
    ----------
    y_true, y_pred : array-like
        Target and predicted labels, compared element-wise with ``==``.
    sample_weight : array-like of float, optional
        Per-sample weights. When omitted every sample counts equally.

    Returns
    -------
    float
        Mean of the 0/1 hit indicator, or ``sum(w * hit) / sum(w)`` when
        weights are supplied.
    """
    hits = (np.asarray(y_true) == np.asarray(y_pred)).astype(float)

    if sample_weight is not None:
        weights = np.asarray(sample_weight, dtype=float)
        return float((weights * hits).sum() / weights.sum())

    return float(hits.mean())


def per_class_recall_np(
    y_true,
    y_pred,
    labels=None,
    sample_weight=None,
    zero_division: float = 0.0,
):
    """Compute the (optionally weighted) recall of every class.

    For class ``k``:
        recall_k = weight of samples with true == k and pred == k
                   / weight of samples with true == k

    Parameters
    ----------
    y_true, y_pred : array-like
        Target and predicted labels.
    labels : array-like, optional
        Classes to report, in this order. Defaults to the sorted unique
        values of ``y_true``.
    sample_weight : array-like of float, optional
        Per-sample weights; defaults to a weight of 1 for every sample.
    zero_division : float
        Value reported for a class whose total true weight is zero.

    Returns
    -------
    (recalls, labels) : tuple of np.ndarray
        Float recalls aligned with the array of labels that was used.
    """
    truth = np.asarray(y_true)
    guess = np.asarray(y_pred)

    classes = np.asarray(np.unique(truth) if labels is None else labels)

    if sample_weight is not None:
        weights = np.asarray(sample_weight, dtype=float)
    else:
        weights = np.ones_like(truth, dtype=float)

    values = []
    for cls in classes:
        is_member = truth == cls
        support = float(weights[is_member].sum())
        if support == 0.0:
            # No true weight for this class: use the caller-provided fallback.
            values.append(zero_division)
            continue
        recovered = float(weights[is_member & (guess == cls)].sum())
        values.append(recovered / support)

    return np.array(values, dtype=float), classes


def balanced_accuracy_score_np(
    y_true,
    y_pred,
    *,
    labels=None,
    sample_weight=None,
    adjusted: bool = False,
    zero_division: float = 0.0,
) -> float:
    """Balanced accuracy: the unweighted mean of the per-class recalls.

    Parameters
    ----------
    y_true, y_pred : array-like
        Target and predicted labels.
    labels, sample_weight, zero_division
        Forwarded unchanged to ``per_class_recall_np``.
    adjusted : bool
        When true, rescale so that chance level (``1 / n_classes``) maps to 0
        and a perfect score maps to 1. With one class or fewer the adjusted
        score is reported as 1.0.

    Returns
    -------
    float
        The raw or chance-adjusted balanced accuracy.
    """
    recalls, classes = per_class_recall_np(
        y_true,
        y_pred,
        labels=labels,
        sample_weight=sample_weight,
        zero_division=zero_division,
    )
    macro_recall = float(np.mean(recalls))

    if adjusted:
        k = len(classes)
        if k <= 1:
            return 1.0
        baseline = 1.0 / k
        return float((macro_recall - baseline) / (1.0 - baseline))

    return macro_recall


def confusion_matrix_np(y_true, y_pred, labels=None, sample_weight=None):
    """Build a (weighted) confusion matrix; a small helper mainly used for plotting.

    Parameters
    ----------
    y_true, y_pred : array-like
        Target and predicted labels.
    labels : array-like, optional
        Row/column order. Defaults to the sorted unique values found in
        either ``y_true`` or ``y_pred``.
    sample_weight : array-like of float, optional
        Per-sample weights; defaults to 1 for every sample.

    Returns
    -------
    (cm, labels) : tuple of np.ndarray
        ``cm[i, j]`` is the total weight of samples with true label
        ``labels[i]`` and predicted label ``labels[j]``. Samples whose true
        or predicted label is not listed in ``labels`` are left out.
    """
    truth = np.asarray(y_true)
    guess = np.asarray(y_pred)

    if labels is None:
        classes = np.unique(np.concatenate([truth, guess]))
    else:
        classes = np.asarray(labels)

    position = {cls: pos for pos, cls in enumerate(classes)}

    def to_positions(values):
        # -1 marks a label that is not part of `classes`.
        return np.array([position.get(v, -1) for v in values], dtype=int)

    rows = to_positions(truth)
    cols = to_positions(guess)

    if sample_weight is not None:
        weights = np.asarray(sample_weight, dtype=float)
    else:
        weights = np.ones(rows.shape, dtype=float)

    keep = (rows >= 0) & (cols >= 0)
    size = len(classes)
    counts = np.zeros((size, size), dtype=float)
    # np.add.at accumulates correctly when the same (row, col) pair repeats.
    np.add.at(counts, (rows[keep], cols[keep]), weights[keep])

    return counts, classes


# Quick sanity check vs scikit-learn: both implementations are evaluated on
# the same tiny example, printed for the reader, and then compared
# programmatically so a regression fails the cell instead of relying on
# someone eyeballing the two numbers.
_y_true = np.array([0, 0, 0, 1, 1, 1])
_y_pred = np.array([0, 0, 1, 0, 1, 1])
_ba_ours = balanced_accuracy_score_np(_y_true, _y_pred)
_ba_sklearn = sk_balanced_accuracy_score(_y_true, _y_pred)
print('ours:', _ba_ours)
print('sklearn:', _ba_sklearn)

np.testing.assert_allclose(_ba_ours, _ba_sklearn)


ours: 0.6666666666666666
sklearn: 0.6666666666666666


In [3]:
n_neg, n_pos = 990, 10

# Ground truth: n_neg zeros followed by n_pos ones. The "classifier" under
# test answers 0 for every sample.
y_true = np.concatenate([np.zeros(n_neg, dtype=int), np.ones(n_pos, dtype=int)])
y_pred = np.zeros_like(y_true)

recalls, labels = per_class_recall_np(y_true, y_pred)
acc = accuracy_score_np(y_true, y_pred)
bal = balanced_accuracy_score_np(y_true, y_pred)
bal_adj = balanced_accuracy_score_np(y_true, y_pred, adjusted=True)

print(f"accuracy:          {acc:.4f}")
print(f"balanced accuracy: {bal:.4f}")
print(f"adjusted BA:       {bal_adj:.4f}")
print("per-class recall:", dict(zip(labels.tolist(), recalls.tolist())))


accuracy:          0.9900
balanced accuracy: 0.5000
adjusted BA:       0.0000
per-class recall: {0: 1.0, 1: 0.0}


In [4]:
# Heatmap of the confusion matrix for the always-0 predictions of the previous cell.
cm, cm_labels = confusion_matrix_np(y_true, y_pred)

fig = px.imshow(
    cm,
    x=[f"pred={label}" for label in cm_labels],
    y=[f"true={label}" for label in cm_labels],
    color_continuous_scale="Blues",
    text_auto=True,
)
fig.update_layout(title="Confusion matrix: always predicting class 0")
fig.show()

# Bar chart of the per-class recalls computed in the previous cell.
fig = go.Figure()
fig.add_trace(
    go.Bar(
        x=[str(label) for label in labels],
        y=recalls,
        text=[format(value, ".2f") for value in recalls],
        textposition="auto",
    )
)
fig.update_layout(
    title="Per-class recall (balanced accuracy is the mean of these)",
    xaxis_title="class",
    yaxis_title="recall",
    yaxis={"range": [0, 1]},
)
fig.show()


In [5]:
# A simple probability simulation (overlapping scores + class imbalance)
#
# NOTE: every draw below consumes the shared `rng`, so the numbers depend on
# the cells having been executed in order from a fresh kernel.

n_neg, n_pos = 2000, 100

y_true = np.array([0] * n_neg + [1] * n_pos)

# Negatives tend to have lower predicted probabilities, positives higher, but overlapping.
# Beta(2, 8) has mean 0.2 and Beta(5, 5) has mean 0.5.
p_neg = rng.beta(2.0, 8.0, size=n_neg)
p_pos = rng.beta(5.0, 5.0, size=n_pos)

proba = np.concatenate([p_neg, p_pos])

# Shuffle together (the same permutation is applied to labels and scores)
perm = rng.permutation(len(y_true))
y_true = y_true[perm]
proba = proba[perm]

# Sweep 401 evenly spaced thresholds over [0, 1] and score each one.
thresholds = np.linspace(0.0, 1.0, 401)
accs = np.empty_like(thresholds)
bals = np.empty_like(thresholds)

for i, t in enumerate(thresholds):
    # Predict the positive class whenever the score reaches the threshold.
    y_pred = (proba >= t).astype(int)
    accs[i] = accuracy_score_np(y_true, y_pred)
    bals[i] = balanced_accuracy_score_np(y_true, y_pred)

# np.argmax returns the first maximum, so ties resolve to the lowest threshold.
best_t = float(thresholds[np.argmax(bals)])

fig = go.Figure()
fig.add_trace(go.Scatter(x=thresholds, y=accs, name="accuracy", mode="lines"))
fig.add_trace(go.Scatter(x=thresholds, y=bals, name="balanced accuracy", mode="lines"))
# Dashed vertical line at the threshold with the best balanced accuracy.
fig.add_vline(x=best_t, line_dash="dash", line_color="black")
fig.update_layout(
    title=f"Accuracy vs balanced accuracy as a function of threshold (best BA at t={best_t:.3f})",
    xaxis_title="threshold t",
    yaxis_title="score",
    yaxis=dict(range=[0, 1]),
)
fig.show()


In [6]:
# Synthetic 2D imbalanced dataset (mild overlap)
#
# NOTE: the draws below consume the shared `rng`, so the dataset depends on
# the cells having been executed in order from a fresh kernel.

n0, n1 = 1200, 80

# Two unit-variance Gaussian blobs: class 0 centred at (0, 0), class 1 at (1.2, 1.2).
X0 = rng.normal(loc=(0.0, 0.0), scale=1.0, size=(n0, 2))
X1 = rng.normal(loc=(1.2, 1.2), scale=1.0, size=(n1, 2))

X = np.vstack([X0, X1])
y = np.concatenate([np.zeros(n0, dtype=int), np.ones(n1, dtype=int)])

# Shuffle features and labels with the same permutation.
perm = rng.permutation(len(y))
X, y = X[perm], y[perm]

# 75/25 split; `stratify=y` keeps the class ratio the same in both parts.
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.25, random_state=0, stratify=y
)

# Labels are cast to str so the two classes get discrete colours.
fig = px.scatter(
    x=X[:, 0],
    y=X[:, 1],
    color=y.astype(str),
    opacity=0.7,
    title="Synthetic imbalanced dataset",
    labels={"x": "x1", "y": "x2", "color": "class"},
)
fig.show()

print('train class counts:', {0: int((y_train==0).sum()), 1: int((y_train==1).sum())})
print('val class counts:  ', {0: int((y_val==0).sum()), 1: int((y_val==1).sum())})


train class counts: {0: 900, 1: 60}
val class counts:   {0: 300, 1: 20}


In [7]:
def standardize_fit(X):
    """Return the column-wise ``(mean, std)`` used to standardize ``X``.

    A tiny constant (1e-12) is added to the standard deviation so that a
    constant column does not cause a division by zero later on.
    """
    tiny = 1e-12
    center = X.mean(axis=0)
    spread = X.std(axis=0) + tiny
    return center, spread


def standardize_transform(X, mean, std):
    """Shift ``X`` by ``mean`` and scale it by ``std`` (both broadcast over rows)."""
    centered = X - mean
    return centered / std


def add_intercept(X):
    """Prepend a column of ones to ``X`` so the first weight acts as the bias."""
    bias_column = np.ones((X.shape[0], 1))
    return np.column_stack((bias_column, X))


def predict_proba_logreg(X, w):
    """Return the sigmoid of the linear score for every row of ``X``.

    ``w`` holds the intercept first, followed by one weight per column of ``X``.
    """
    logits = add_intercept(X) @ w
    return expit(logits)


def log_loss_binary(y, p, sample_weight=None, eps: float = 1e-12) -> float:
    """Binary cross-entropy between targets ``y`` and probabilities ``p``.

    Probabilities are clipped to ``[eps, 1 - eps]`` before taking logs so the
    result stays finite. Returns the plain mean of the per-sample losses, or
    their weighted mean ``sum(w * loss) / sum(w)`` when ``sample_weight`` is given.
    """
    targets = np.asarray(y)
    probs = np.clip(np.asarray(p), eps, 1.0 - eps)

    losses = -(targets * np.log(probs) + (1.0 - targets) * np.log(1.0 - probs))

    if sample_weight is not None:
        weights = np.asarray(sample_weight, dtype=float)
        return float(np.sum(weights * losses) / np.sum(weights))

    return float(losses.mean())


def fit_logreg_gd(
    X_train,
    y_train,
    X_val,
    y_val,
    *,
    lr: float = 0.2,
    n_epochs: int = 400,
    l2: float = 1e-2,
    sample_weight=None,
):
    """Fit binary logistic regression with full-batch gradient descent.

    Training always runs for all ``n_epochs`` updates. It is not cut short;
    instead the weights from the epoch with the highest validation balanced
    accuracy (at a 0.5 threshold) are remembered and returned, i.e. this is
    best-checkpoint selection rather than literal early stopping.

    Parameters
    ----------
    X_train, y_train : array-like
        Training features (without intercept column) and 0/1 targets.
    X_val, y_val : array-like
        Validation features and targets, used only for monitoring and for
        choosing the returned weights.
    lr : float
        Gradient-descent step size.
    n_epochs : int
        Number of full-batch updates.
    l2 : float
        L2 penalty strength; applied to every weight except the intercept.
    sample_weight : array-like of float, optional
        Per-sample training weights; defaults to 1 for every sample.

    Returns
    -------
    best_w : np.ndarray
        Weights (intercept first) from the best validation epoch.
    history : dict
        Per-epoch lists under ``"train_loss"``, ``"val_acc"`` and ``"val_bal_acc"``.
    best : dict
        ``{"epoch", "val_bal_acc", "w"}`` of the selected checkpoint; ``epoch``
        is the 0-based loop index and ties keep the earliest epoch.
    """
    Xb = add_intercept(X_train)
    # Loop-invariant: build the validation design matrix once, not every epoch.
    Xb_val = add_intercept(X_val)
    n, d = Xb.shape

    if sample_weight is None:
        sample_weight = np.ones(n, dtype=float)
    else:
        sample_weight = np.asarray(sample_weight, dtype=float)

    sw_sum = float(sample_weight.sum())
    w = np.zeros(d, dtype=float)

    history = {
        "train_loss": [],
        "val_acc": [],
        "val_bal_acc": [],
    }

    best = {
        "epoch": -1,
        "val_bal_acc": -np.inf,
        "w": w.copy(),
    }

    # Forward pass for the initial weights. Inside the loop the probabilities
    # are refreshed right after each update and reused for the next gradient,
    # so each set of weights is pushed through the model only once.
    p_train = expit(Xb @ w)

    for epoch in range(n_epochs):
        # Weighted mean gradient of the log loss, plus the L2 term (intercept excluded).
        grad = (Xb.T @ (sample_weight * (p_train - y_train))) / sw_sum
        grad[1:] += l2 * w[1:]

        w = w - lr * grad

        # Metrics are evaluated on the updated weights.
        p_train = expit(Xb @ w)
        train_loss = log_loss_binary(y_train, p_train, sample_weight=sample_weight) + 0.5 * l2 * float(
            np.sum(w[1:] ** 2)
        )

        p_val = expit(Xb_val @ w)
        y_val_hat = (p_val >= 0.5).astype(int)

        val_acc = accuracy_score_np(y_val, y_val_hat)
        val_bal_acc = balanced_accuracy_score_np(y_val, y_val_hat)

        history["train_loss"].append(train_loss)
        history["val_acc"].append(val_acc)
        history["val_bal_acc"].append(val_bal_acc)

        # Strict ">" keeps the earliest epoch when several share the best score.
        if val_bal_acc > best["val_bal_acc"]:
            best = {"epoch": epoch, "val_bal_acc": val_bal_acc, "w": w.copy()}

    return best["w"], history, best


In [8]:
# Standardize features (important for GD stability); the statistics come from
# the training split only and are then applied to both splits.
mean, std = standardize_fit(X_train)
X_train_s, X_val_s = (standardize_transform(part, mean, std) for part in (X_train, X_val))

# Unweighted training
w_unw, hist_unw, best_unw = fit_logreg_gd(X_train_s, y_train, X_val_s, y_val)

# Balanced class weights: each class gets ~50% of total weight
n_train = len(y_train)
n_pos = int(np.count_nonzero(y_train == 1))
n_neg = int(np.count_nonzero(y_train == 0))

w_pos = n_train / (2.0 * n_pos)
w_neg = n_train / (2.0 * n_neg)
sw_bal = np.where(y_train == 1, w_pos, w_neg)

w_wt, hist_wt, best_wt = fit_logreg_gd(
    X_train_s,
    y_train,
    X_val_s,
    y_val,
    sample_weight=sw_bal,
)

print('best epoch (unweighted):', best_unw['epoch'], 'val BA:', f"{best_unw['val_bal_acc']:.4f}")
print('best epoch (weighted):  ', best_wt['epoch'], 'val BA:', f"{best_wt['val_bal_acc']:.4f}")


best epoch (unweighted): 179 val BA: 0.5250
best epoch (weighted):   53 val BA: 0.8367


In [9]:
# 1-based epoch axis shared by both runs.
epochs = np.arange(1, len(hist_unw["train_loss"]) + 1)

fig = make_subplots(
    rows=1,
    cols=3,
    subplot_titles=("Train log loss", "Validation accuracy", "Validation balanced accuracy"),
)

# (history key, legend suffix, subplot column)
panels = (
    ("train_loss", "loss", 1),
    ("val_acc", "acc", 2),
    ("val_bal_acc", "BA", 3),
)

for name, hist in [("unweighted", hist_unw), ("class-weighted", hist_wt)]:
    for key, suffix, col in panels:
        fig.add_trace(
            go.Scatter(x=epochs, y=hist[key], name=f"{name} {suffix}", mode="lines"),
            row=1,
            col=col,
        )

fig.update_layout(height=350, width=1100, title="Training curves (early stopping uses validation BA)")
# Both score panels share the fixed [0, 1] range.
for col in (2, 3):
    fig.update_yaxes(range=[0, 1], row=1, col=col)
fig.show()


In [10]:
def best_threshold_for_balanced_accuracy(y_true, proba, thresholds):
    """Scan ``thresholds`` and return the one with the highest balanced accuracy.

    A sample is predicted positive when ``proba >= t``. The result is a dict
    ``{"t": threshold, "ba": score}``; on ties the first threshold encountered
    wins, and if nothing improves on ``-inf`` (for example when ``thresholds``
    is empty) ``"t"`` stays ``None``.
    """
    top_t, top_ba = None, -np.inf
    for cutoff in thresholds:
        score = balanced_accuracy_score_np(y_true, (proba >= cutoff).astype(int))
        if score > top_ba:
            top_t, top_ba = float(cutoff), float(score)
    return {"t": top_t, "ba": top_ba}


thresholds = np.linspace(0.0, 1.0, 401)

# Validation probabilities from each model's selected weights.
p_unw, p_wt = (predict_proba_logreg(X_val_s, weights) for weights in (w_unw, w_wt))

best_t_unw = best_threshold_for_balanced_accuracy(y_val, p_unw, thresholds)
best_t_wt = best_threshold_for_balanced_accuracy(y_val, p_wt, thresholds)

print('best threshold (unweighted):', best_t_unw)
print('best threshold (weighted):  ', best_t_wt)

# Visualize BA(t): one curve per model over the full threshold grid.
ba_unw, ba_wt = (
    [balanced_accuracy_score_np(y_val, (scores >= t).astype(int)) for t in thresholds]
    for scores in (p_unw, p_wt)
)

fig = go.Figure(
    data=[
        go.Scatter(x=thresholds, y=curve, name=label, mode="lines")
        for label, curve in (("unweighted", ba_unw), ("class-weighted", ba_wt))
    ]
)
# Dashed marker at each model's best threshold.
for picked, color in ((best_t_unw, "#1f77b4"), (best_t_wt, "#ff7f0e")):
    fig.add_vline(x=picked["t"], line_dash="dash", line_color=color)
fig.update_layout(
    title="Validation balanced accuracy as a function of the decision threshold",
    xaxis_title="threshold t",
    yaxis_title="balanced accuracy",
    yaxis={"range": [0, 1]},
)
fig.show()


best threshold (unweighted): {'t': 0.1525, 'ba': 0.8433333333333334}
best threshold (weighted):   {'t': 0.62, 'ba': 0.8416666666666667}


In [11]:
def summarize_threshold(y_true, proba, t):
    """Collect accuracy, balanced accuracy, per-class recalls and the 2x2
    confusion matrix obtained by predicting positive when ``proba >= t``.

    The confusion matrix is always laid out for labels ``[0, 1]``; the recall
    dict is keyed by the classes present in ``y_true``.
    """
    decisions = (proba >= t).astype(int)
    class_recalls, classes = per_class_recall_np(y_true, decisions)
    matrix, _ = confusion_matrix_np(y_true, decisions, labels=np.array([0, 1]))
    return {
        "t": float(t),
        "acc": float(accuracy_score_np(y_true, decisions)),
        "ba": float(balanced_accuracy_score_np(y_true, decisions)),
        "recalls": {cls: rec for cls, rec in zip(classes.tolist(), class_recalls.tolist())},
        "cm": matrix,
    }


# One summary per (model, threshold choice): the default 0.5 and the tuned t*.
summaries = {
    f"{model} @{tag}": summarize_threshold(y_val, scores, cutoff)
    for model, scores, tuned in (
        ("unweighted", p_unw, best_t_unw),
        ("weighted", p_wt, best_t_wt),
    )
    for tag, cutoff in (("0.5", 0.5), ("t*", tuned["t"]))
}

for k, v in summaries.items():
    print(k, {field: v[field] for field in ("t", "acc", "ba", "recalls")})

# Confusion matrices (2x2): rows=methods, cols=threshold choice
fig = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=(
        "Unweighted @0.5",
        "Unweighted @t*",
        "Weighted @0.5",
        "Weighted @t*",
    ),
)

# (row, col, summary), filling the grid row by row in dict insertion order.
items = [
    (idx // 2 + 1, idx % 2 + 1, summary)
    for idx, summary in enumerate(summaries.values())
]

for r, c, s in items:
    cm = s["cm"]
    heatmap = go.Heatmap(
        z=cm,
        x=["pred=0", "pred=1"],
        y=["true=0", "true=1"],
        colorscale="Blues",
        showscale=False,
        text=cm.astype(int),
        texttemplate="%{text}",
    )
    fig.add_trace(heatmap, row=r, col=c)

fig.update_layout(height=650, width=900, title="Validation confusion matrices")
fig.show()


unweighted @0.5 {'t': 0.5, 'acc': 0.940625, 'ba': 0.525, 'recalls': {0: 1.0, 1: 0.05}}
unweighted @t* {'t': 0.1525, 'acc': 0.88125, 'ba': 0.8433333333333334, 'recalls': {0: 0.8866666666666667, 1: 0.8}}
weighted @0.5 {'t': 0.5, 'acc': 0.7375, 'ba': 0.8366666666666667, 'recalls': {0: 0.7233333333333334, 1: 0.95}}
weighted @t* {'t': 0.62, 'acc': 0.834375, 'ba': 0.8416666666666667, 'recalls': {0: 0.8333333333333334, 1: 0.85}}


In [12]:
# Decision boundary visualization (in original feature space)

def decision_boundary_figure(X_val, y_val, w, mean, std, threshold: float, title: str):
    """Plot the validation points together with the ``p == threshold`` contour.

    The probability surface is evaluated on a 200 x 200 grid that extends one
    unit beyond the range of the first two columns of ``X_val``. Grid points
    are standardized with ``mean``/``std`` before prediction, so the figure is
    drawn in the original feature space. Returns the plotly figure.
    """
    lows = X_val.min(axis=0) - 1.0
    highs = X_val.max(axis=0) + 1.0

    xs = np.linspace(lows[0], highs[0], 200)
    ys = np.linspace(lows[1], highs[1], 200)
    xx, yy = np.meshgrid(xs, ys)

    grid_s = standardize_transform(np.column_stack((xx.ravel(), yy.ravel())), mean, std)
    surface = predict_proba_logreg(grid_s, w).reshape(xx.shape)

    # A single contour level placed exactly at the decision threshold.
    boundary = go.Contour(
        x=xs,
        y=ys,
        z=surface,
        contours={"start": threshold, "end": threshold, "size": 1, "coloring": "lines"},
        line={"color": "black", "width": 3},
        showscale=False,
        name="decision boundary",
    )

    # Validation points coloured by their 0/1 label.
    points = go.Scatter(
        x=X_val[:, 0],
        y=X_val[:, 1],
        mode="markers",
        marker={
            "size": 6,
            "color": y_val,
            "colorscale": [[0, "#1f77b4"], [1, "#d62728"]],
            "opacity": 0.7,
            "line": {"width": 0},
        },
        name="validation points",
    )

    fig = go.Figure(data=[boundary, points])
    fig.update_layout(
        title=title,
        xaxis_title="x1",
        yaxis_title="x2",
        height=450,
        width=500,
        legend={"orientation": "h", "yanchor": "bottom", "y": 1.02, "xanchor": "right", "x": 1},
    )

    return fig


# One stand-alone figure per model, each drawn at that model's tuned threshold.
fig1 = decision_boundary_figure(
    X_val,
    y_val,
    w_unw,
    mean,
    std,
    threshold=best_t_unw["t"],
    title=f"Unweighted logistic regression (threshold t*={best_t_unw['t']:.2f})",
)
fig2 = decision_boundary_figure(
    X_val,
    y_val,
    w_wt,
    mean,
    std,
    threshold=best_t_wt["t"],
    title=f"Class-weighted logistic regression (threshold t*={best_t_wt['t']:.2f})",
)

# Copy the traces of both figures side by side into one row.
fig = make_subplots(rows=1, cols=2, subplot_titles=(fig1.layout.title.text, fig2.layout.title.text))
for panel_col, panel in enumerate((fig1, fig2), start=1):
    for tr in panel.data:
        fig.add_trace(tr, row=1, col=panel_col)

fig.update_layout(height=450, width=1050, title="Decision boundary tuned for balanced accuracy")
for panel_col in (1, 2):
    fig.update_xaxes(title_text="x1", row=1, col=panel_col)
    fig.update_yaxes(title_text="x2", row=1, col=panel_col)
fig.show()


In [13]:
# scikit-learn comparison on the same dataset: the same train/validation
# split, fitted on the unstandardized features, with and without balanced
# class weights.
clf_unw = LogisticRegression(max_iter=2000).fit(X_train, y_train)
clf_wt = LogisticRegression(max_iter=2000, class_weight="balanced").fit(X_train, y_train)

pred_unw, pred_wt = clf_unw.predict(X_val), clf_wt.predict(X_val)

print('sklearn unweighted BA:', sk_balanced_accuracy_score(y_val, pred_unw))
print('sklearn weighted BA:  ', sk_balanced_accuracy_score(y_val, pred_wt))


sklearn unweighted BA: 0.5483333333333333
sklearn weighted BA:   0.8183333333333334
