Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions rehline/_sklearn_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def _fit_multiclass(self, X_aug, y, sample_weight=None):
class_pairs = []
for cls_i, cls_j in combinations(self.classes_, 2):
mask = np.isin(y, [cls_i, cls_j])
y_pm = np.where(y[mask] == cls_i, 1, -1).astype(np.float64)
y_pm = np.where(y[mask] == cls_j, 1, -1).astype(np.float64)
sw_sub = sample_weight[mask] if sample_weight is not None else None
tasks.append((X_aug[mask], y_pm, sw_sub))
class_pairs.append((cls_i, cls_j))
Expand Down Expand Up @@ -455,12 +455,12 @@ def predict(self, X):

# discrete vote: score > 0 favors cls_i, score <= 0 favors cls_j
pred = (scores[:, k] > 0).astype(int)
votes[:, i] += pred
votes[:, j] += 1 - pred
votes[:, j] += pred
votes[:, i] += 1 - pred

# continuous confidence: score > 0 means cls_i is more confident
sum_of_confidences[:, i] += scores[:, k]
sum_of_confidences[:, j] -= scores[:, k]
sum_of_confidences[:, j] += scores[:, k]
sum_of_confidences[:, i] -= scores[:, k]

# Monotonically transform to (-1/3, 1/3) to break ties without
# overriding any decision made by a difference of >= 1 vote
Expand Down
320 changes: 317 additions & 3 deletions tests/_test_multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,309 @@ def test_decision_function_shapes():
print("\n✓ All decision_function shape tests passed!")


def test_ovo_coef_sign_convention():
    """
    Test 5: OvO coefficient sign convention (regression test for the sign bug).

    The previous bug assigned cls_i -> +1 and cls_j -> -1 in each OvO subproblem,
    which is opposite to sklearn's LabelEncoder convention (cls_i -> -1, cls_j -> +1)
    since combinations() always yields sorted pairs (cls_i < cls_j).
    This caused every subproblem's coef_ to be fully negated (diff ≈ 2 * |β|).

    This test directly checks the sign direction of each OvO subproblem's coef_
    via dot product, rather than relying solely on accuracy, so the bug cannot
    silently reappear.
    """
    print("\n" + "="*60)
    print("Test 5: OvO Coefficient Sign Convention")
    print("="*60)

    np.random.seed(0)
    n_samples, n_features, n_classes = 2000, 6, 3
    C = 1.0

    X, y = make_classification(
        n_samples=n_samples,
        n_features=n_features,
        n_informative=4,
        n_redundant=1,
        n_classes=n_classes,
        class_sep=2.0,
        random_state=0
    )
    X = StandardScaler().fit_transform(X)

    # Reference: sklearn's one-vs-one wrapper around a hinge-loss LinearSVC.
    svc = LinearSVC(C=C, loss='hinge', fit_intercept=True,
                    max_iter=1000000, tol=1e-5, random_state=0)
    clf_skl = OneVsOneClassifier(svc)
    clf_skl.fit(X, y)

    # rehline OvO under test.
    clf_reh = plq_Ridge_Classifier(
        loss={'name': 'svm'}, C=C, multi_class='ovo',
        max_iter=1000000, tol=1e-5, verbose=0
    )
    clf_reh.fit(X, y)

    n_estimators = n_classes * (n_classes - 1) // 2
    print(f"\n{'Estimator':^12} {'dot(skl,reh)':^16} {'||skl||':^12} {'||reh||':^12} {'sign OK':^10}")
    print("-" * 65)

    sign_flags = []
    for k, est in enumerate(clf_skl.estimators_):
        w_ref = est.coef_.flatten()
        w_reh = clf_reh.coef_[k]
        dot = np.dot(w_ref, w_reh)
        # Matching sign conventions yield a positive dot product; a negated
        # subproblem yields a negative one.
        ok = dot > 0
        sign_flags.append(ok)
        print(f"{k:^12d} {dot:^16.4f} {np.linalg.norm(w_ref):^12.4f} {np.linalg.norm(w_reh):^12.4f} {'✓' if ok else '❌':^10}")

    assert all(sign_flags), \
        "OvO coef_ sign convention mismatch: at least one subproblem has reversed sign. " \
        "This is the sign-convention bug (cls_i/cls_j label encoding mismatch)."

    print("\n✓ OvO sign convention test passed!")


def test_ovo_predict_consistency():
    """
    Test 6: OvO predict / decision_function consistency.

    Verifies that predict() produces exactly the same result as manually
    reconstructing predictions from decision_function() using the voting logic,
    ensuring the sign convention in fit and predict are perfectly aligned.
    """
    print("\n" + "="*60)
    print("Test 6: OvO predict / decision_function Consistency")
    print("="*60)

    np.random.seed(7)
    n_samples, n_features, n_classes = 1500, 5, 4
    C = 1.0

    X, y = make_classification(
        n_samples=n_samples,
        n_features=n_features,
        n_informative=4,
        n_redundant=0,
        n_classes=n_classes,
        class_sep=1.5,
        random_state=7
    )
    X = StandardScaler().fit_transform(X)

    clf = plq_Ridge_Classifier(
        loss={'name': 'svm'}, C=C, multi_class='ovo',
        max_iter=1000000, tol=1e-5, verbose=0
    )
    clf.fit(X, y)

    # What predict() says.
    y_pred = clf.predict(X)

    # Rebuild the same answer by hand from decision_function, mirroring the
    # voting + confidence tie-break logic inside predict().
    df = clf.decision_function(X)
    n_cls = len(clf.classes_)
    ballots = np.zeros((n_samples, n_cls))
    conf_sum = np.zeros((n_samples, n_cls))
    for k, (_, _, cls_i, cls_j) in enumerate(clf.estimators_):
        idx_i = np.where(clf.classes_ == cls_i)[0][0]
        idx_j = np.where(clf.classes_ == cls_j)[0][0]
        wins_j = (df[:, k] > 0).astype(int)
        ballots[:, idx_j] += wins_j
        ballots[:, idx_i] += 1 - wins_j
        conf_sum[:, idx_j] += df[:, k]
        conf_sum[:, idx_i] -= df[:, k]
    # Monotone squash into (-1/3, 1/3) so confidences only break vote ties.
    tie_break = conf_sum / (3 * (np.abs(conf_sum) + 1))
    y_manual = clf.classes_[np.argmax(ballots + tie_break, axis=1)]

    n_disagree = np.sum(y_pred != y_manual)
    print(f"Disagreements between predict() and manual reconstruction: {n_disagree}")

    assert n_disagree == 0, \
        f"predict() and decision_function() are inconsistent: {n_disagree} samples disagree. " \
        "This indicates a mismatch between the sign convention in fit and predict."

    print("✓ OvO predict / decision_function consistency test passed!")


def test_ovo_fit_intercept_false():
    """
    Test 7: OvO with fit_intercept=False — correct coef_ shape and accuracy.

    Ensures that disabling the intercept still produces the correct coef_ shape,
    sets intercept_ to all zeros, and matches sklearn's solution.
    """
    print("\n" + "="*60)
    print("Test 7: OvO with fit_intercept=False")
    print("="*60)

    np.random.seed(13)
    n_samples, n_features, n_classes = 2000, 6, 3
    C = 1.0

    X, y = make_classification(
        n_samples=n_samples,
        n_features=n_features,
        n_informative=4,
        n_redundant=1,
        n_classes=n_classes,
        class_sep=2.0,
        random_state=13
    )
    X = StandardScaler().fit_transform(X)

    # sklearn reference, intercept disabled.
    svc = LinearSVC(C=C, loss='hinge', fit_intercept=False,
                    max_iter=1000000, tol=1e-5, random_state=13)
    clf_skl = OneVsOneClassifier(svc)
    clf_skl.fit(X, y)

    # rehline OvO, intercept disabled.
    clf_reh = plq_Ridge_Classifier(
        loss={'name': 'svm'}, C=C, multi_class='ovo',
        fit_intercept=False, max_iter=1000000, tol=1e-5, verbose=0
    )
    clf_reh.fit(X, y)

    n_estimators = n_classes * (n_classes - 1) // 2

    # Shape checks: one coefficient row per pairwise subproblem, zero intercepts.
    assert clf_reh.coef_.shape == (n_estimators, n_features), \
        f"Expected coef_ shape ({n_estimators}, {n_features}), got {clf_reh.coef_.shape}"
    assert np.all(clf_reh.intercept_ == 0.0), \
        "intercept_ should be all zeros when fit_intercept=False"

    # Accuracy checks against the sklearn solution, subproblem by subproblem.
    diffs = []
    for k, est in enumerate(clf_skl.estimators_):
        d = np.max(np.abs(est.coef_.flatten() - clf_reh.coef_[k]))
        diffs.append(d)
        print(f"Estimator {k}: max coef diff = {d:.6e}")
    max_diff = max(diffs)

    print(f"Overall max coef diff: {max_diff:.6e}")
    assert max_diff <= 1e-3, \
        f"fit_intercept=False OvO coef_ diff {max_diff:.6e} > 1e-3"

    print("✓ OvO fit_intercept=False test passed!")


def test_multiclass_invalid_multi_class():
    """
    Test 8: Invalid multi_class parameter should raise ValueError.

    Ensures that passing an unrecognised multi_class value causes fit() to raise
    a clear ValueError rather than silently failing or producing wrong results.
    """
    print("\n" + "="*60)
    print("Test 8: Invalid multi_class Parameter")
    print("="*60)

    np.random.seed(42)
    X = np.random.randn(200, 4)
    y = np.random.randint(0, 3, 200)

    clf = plq_Ridge_Classifier(
        loss={'name': 'svm'}, C=1.0, multi_class='invalid_option'
    )

    # Capture the exception instead of tracking a boolean flag.
    caught = None
    try:
        clf.fit(X, y)
    except ValueError as err:
        caught = err
        print(f"ValueError raised as expected: {err}")

    assert caught is not None, "Expected ValueError for invalid multi_class parameter, but none was raised."
    print("✓ Invalid multi_class parameter test passed!")


def test_ovo_more_classes():
    """
    Test 9: OvO correctness with 5 classes (10 subproblems).

    Verifies that the number of subproblems, coef_ shape, and coefficient
    accuracy are all correct when the number of classes grows, guarding against
    errors in the combinatorial subproblem construction logic.

    Returns
    -------
    tuple of (float, float, float)
        (sklearn test accuracy, rehline test accuracy, max coefficient diff),
        consumed by the summary table in ``__main__``.
    """
    print("\n" + "="*60)
    print("Test 9: OvO with 5 Classes (10 subproblems)")
    print("="*60)

    np.random.seed(99)
    n_samples = 3000
    n_features = 8
    n_classes = 5
    C = 1.0
    n_estimators = n_classes * (n_classes - 1) // 2  # 10

    X, y = make_classification(
        n_samples=n_samples,
        n_features=n_features,
        n_informative=6,
        n_redundant=1,
        n_classes=n_classes,
        class_sep=1.5,
        random_state=99
    )
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=99, stratify=y
    )
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    # sklearn
    base_clf = LinearSVC(C=C, loss='hinge', fit_intercept=True,
                         max_iter=1000000, tol=1e-5, random_state=99)
    clf_skl = OneVsOneClassifier(base_clf)
    clf_skl.fit(X_train, y_train)
    acc_skl = accuracy_score(y_test, clf_skl.predict(X_test))

    # rehline
    clf_reh = plq_Ridge_Classifier(
        loss={'name': 'svm'}, C=C, multi_class='ovo',
        max_iter=1000000, tol=1e-5, verbose=0
    )
    clf_reh.fit(X_train, y_train)
    acc_reh = accuracy_score(y_test, clf_reh.predict(X_test))

    # Shape checks
    assert clf_reh.coef_.shape == (n_estimators, n_features), \
        f"Expected coef_ shape ({n_estimators}, {n_features}), got {clf_reh.coef_.shape}"
    assert clf_reh.intercept_.shape == (n_estimators,), \
        f"Expected intercept_ shape ({n_estimators},), got {clf_reh.intercept_.shape}"
    assert len(clf_reh.estimators_) == n_estimators, \
        f"Expected {n_estimators} estimators, got {len(clf_reh.estimators_)}"

    # Accuracy checks
    max_diff = 0
    for k, est in enumerate(clf_skl.estimators_):
        diff = np.max(np.abs(est.coef_.flatten() - clf_reh.coef_[k]))
        max_diff = max(max_diff, diff)
    print(f"5-class OvO: {n_estimators} subproblems, max coef diff = {max_diff:.6e}")
    print(f"Accuracy: sklearn={acc_skl:.4f}, rehline={acc_reh:.4f}")

    assert max_diff <= 1e-3, \
        f"5-class OvO coef_ diff {max_diff:.6e} > 1e-3"

    print("✓ OvO 5-class test passed!")
    return acc_skl, acc_reh, max_diff


if __name__ == "__main__":
print("\n" + "="*70)
print("MULTI-CLASS CLASSIFICATION TEST SUITE")
Expand All @@ -391,14 +694,25 @@ def test_decision_function_shapes():
acc_skl_ovr, acc_reh_ovr, diff_ovr = test_multiclass_ovr_vs_sklearn()
acc_skl_ovo, acc_reh_ovo, diff_ovo = test_multiclass_ovo_vs_sklearn()
test_decision_function_shapes()

test_ovo_coef_sign_convention()
test_ovo_predict_consistency()
test_ovo_fit_intercept_false()
test_multiclass_invalid_multi_class()
acc_skl_ovo5, acc_reh_ovo5, diff_ovo5 = test_ovo_more_classes()

print("\n" + "="*70)
print("TEST SUMMARY")
print("="*70)
print(f"{'Test':^30} {'sklearn acc':^12} {'rehline acc':^12} {'max coef diff':^15}")
print("-" * 70)
print(f"{'Binary Classification':^30} {acc_skl_bin:^12.4f} {acc_reh_bin:^12.4f} {diff_bin:^15.2e}")
print(f"{'OvR Multi-class':^30} {acc_skl_ovr:^12.4f} {acc_reh_ovr:^12.4f} {diff_ovr:^15.2e}")
print(f"{'OvO Multi-class':^30} {acc_skl_ovo:^12.4f} {acc_reh_ovo:^12.4f} {diff_ovo:^15.2e}")
print(f"{'OvO Multi-class (3cls)':^30} {acc_skl_ovo:^12.4f} {acc_reh_ovo:^12.4f} {diff_ovo:^15.2e}")
print(f"{'OvO Multi-class (5cls)':^30} {acc_skl_ovo5:^12.4f} {acc_reh_ovo5:^12.4f} {diff_ovo5:^15.2e}")
print(f"{'Decision Func Shapes':^30} {'—':^12} {'—':^12} {'—':^15}")
print(f"{'OvO Sign Convention':^30} {'—':^12} {'—':^12} {'—':^15}")
print(f"{'OvO Predict Consistency':^30} {'—':^12} {'—':^12} {'—':^15}")
print(f"{'OvO No Intercept':^30} {'—':^12} {'—':^12} {'—':^15}")
print(f"{'Invalid multi_class':^30} {'—':^12} {'—':^12} {'—':^15}")
print("="*70)
print("\n✓ All tests passed successfully!")
print("\n✓ All 9 tests passed successfully!")