
ENH multiclass/multinomial newton cholesky for LogisticRegression #28840

Open · wants to merge 16 commits into main

Conversation

@lorentzenchr (Member) commented Apr 15, 2024

Reference Issues/PRs

In a way, this is a follow-up of #24767.

What does this implement/fix? Explain your changes.

This extends the "newton-cholesky" solver of LogisticRegression and LogisticRegressionCV to the full multinomial loss. In particular, the full multinomial Hessian is computed, so this solver no longer needs to resort to one-vs-rest (OvR) for multiclass targets.
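
A minimal usage sketch (synthetic data; dataset, shapes and tolerance are illustrative only). With a build that includes this branch, the solver below would fit one joint multinomial model instead of one-vs-rest:

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Illustrative multiclass problem; the parameters are arbitrary.
X, y = make_classification(
    n_samples=5_000, n_features=20, n_informative=10, n_classes=5, random_state=0
)
clf = LogisticRegression(solver="newton-cholesky", tol=1e-8).fit(X, y)
print(clf.n_iter_, clf.score(X, y))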

Any other comments?

There are two tricky parts:

  1. Some index bookkeeping: the coefficients are usually indexed hierarchically by n_features and n_classes, but in the end the Hessian must be (and is) a 2-dimensional matrix.
  2. The multinomial loss is over-parameterized in any unpenalized coefficient, so at least in the intercepts. We therefore choose the last class as the reference and set its intercept to zero (see the sketch below).
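
A minimal sketch of tricky part 2 (illustrative numbers, not code from this PR): adding the same constant to all class intercepts leaves the softmax probabilities unchanged, which is why one intercept can be pinned to zero.

import numpy as np
from scipy.special import softmax

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 3))
coef = rng.standard_normal((5, 3))      # (n_classes, n_features)
intercept = rng.standard_normal(5)

# Shifting every class intercept by the same constant does not change the
# predicted probabilities.
p1 = softmax(X @ coef.T + intercept, axis=1)
p2 = softmax(X @ coef.T + (intercept + 42.0), axis=1)
np.testing.assert_allclose(p1, p2)

# Hence the last class can serve as reference: subtracting its intercept from
# all of them pins the last intercept to exactly zero without changing the model.
p3 = softmax(X @ coef.T + (intercept - intercept[-1]), axis=1)
np.testing.assert_allclose(p1, p3)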

github-actions bot commented Apr 15, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 5f79028.

@lorentzenchr changed the title from "ENH multiclass newton cholesky for LogisticRegression" to "ENH multiclass/multinomial newton cholesky for LogisticRegression" on Apr 15, 2024
@lorentzenchr (Member Author):

Benchmark

As of 1de85b7

X_train.shape = (10000, 75)
sparse.issparse(X_train)=False
n_classes=12
[benchmark figure: suboptimality vs. n_iter and vs. train_time for the lbfgs, newton-cg and newton-cholesky solvers]

import warnings
from pathlib import Path
import numpy as np
from scipy import sparse
from sklearn._loss import HalfMultinomialLoss
from sklearn.compose import ColumnTransformer
from sklearn.datasets import fetch_openml
from sklearn.linear_model import PoissonRegressor, LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer, OneHotEncoder
from sklearn.preprocessing import StandardScaler, KBinsDiscretizer
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model._linear_loss import LinearModelLoss
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split
from time import perf_counter
import pandas as pd



def prepare_data():
    df = fetch_openml(data_id=41214, as_frame=True, parser='auto').frame
    df["Frequency"] = df["ClaimNb"] / df["Exposure"]
    log_scale_transformer = make_pipeline(
        FunctionTransformer(np.log, validate=False), StandardScaler()
    )
    linear_model_preprocessor = ColumnTransformer(
        [
            ("passthrough_numeric", "passthrough", ["BonusMalus"]),
            (
                "binned_numeric",
                KBinsDiscretizer(n_bins=10, subsample=None),
                ["VehAge", "DrivAge"],
            ),
            ("log_scaled_numeric", log_scale_transformer, ["Density"]),
            (
                "onehot_categorical",
                OneHotEncoder(),
                ["VehBrand", "VehPower", "VehGas", "Region", "Area"],
            ),
        ],
        remainder="drop",
    )
    y = df["Frequency"]
    w = df["Exposure"]
    X = linear_model_preprocessor.fit_transform(df)
    return X, y, w


X, y_orig, w = prepare_data()

print("binning the target...")
binner = KBinsDiscretizer(
    n_bins=300, encode="ordinal", strategy="quantile", subsample=int(2e5), random_state=0
)
y = binner.fit_transform(y_orig.to_numpy().reshape(-1, 1)).ravel().astype(float)

X = X.toarray()
X_train, X_test, y_train, y_test, w_train, w_test = train_test_split(
    X, y, w, train_size=10_000, test_size=10_000, random_state=0
)
print(f"{X_train.shape = }")
print(f"{sparse.issparse(X_train)=}")
n_classes = len(np.unique(y_train))
print(f"{n_classes=}")
print("y_train.value_counts() :")
print(pd.Series(y_train).value_counts())


results = []
slow_solvers = set()
loss_sw = np.full_like(y_train, fill_value=(1. / y_train.shape[0]))
alpha = 1e-6  # A bit larger than in the LSMR benchmarks to avoid ConvergenceWarnings
for tol in np.logspace(-1, -10, 10):
    for solver in ["lbfgs", "newton-cg", "newton-cholesky"]:
        if solver in slow_solvers:
            # skip slow solvers to keep the benchmark runtime reasonable
            continue
        tic = perf_counter()
        # with warnings.catch_warnings():
        #     warnings.filterwarnings("ignore", category=ConvergenceWarning)
        clf = LogisticRegression(
            C=1/alpha,
            solver=solver,
            tol=tol,
            max_iter=10_000 if solver=="lbfgs" else 1000,
        ).fit(X_train, y_train)
        toc = perf_counter()
        train_time = toc - tic
        n_iter = clf.n_iter_[0]
        if train_time > 200 or n_iter >= clf.max_iter:
            # skip this solver from now on...
            slow_solvers.add(solver)
        # Look inside _GeneralizedLinearRegressor to check the parameters.
        # Or run once with verbose=1 and compare to the reported loss.
        train_loss = LinearModelLoss(
            base_loss=HalfMultinomialLoss(n_classes=n_classes), fit_intercept=clf.fit_intercept
        ).loss(
            coef=np.c_[clf.coef_, clf.intercept_],
            X=X_train,
            y=y_train,
            l2_reg_strength=alpha / X_train.shape[0],
            sample_weight=loss_sw,
        )
        result = {
            "solver": solver,
            "tol": tol,
            "train_loss": train_loss,
            "train_time": train_time,
            "train_score": clf.score(X_train, y_train),
            "test_score": clf.score(X_test, y_test),
            "n_iter": n_iter,
            "converged": n_iter < clf.max_iter,
        }
        print(result)
        results.append(result)


results = pd.DataFrame.from_records(results)
filepath = Path().resolve() / "bench_multinomial_logistic_regression_mtpl_dense_newton_cholesky.csv"
results.to_csv(filepath)


import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt


filepath = Path().resolve() / "bench_multinomial_logistic_regression_mtpl_dense_newton_cholesky.csv"

results = pd.read_csv(filepath)
results["suboptimality"] = results["train_loss"] - results["train_loss"].min() + 1e-16

fig, axes = plt.subplots(ncols=2, figsize=(8*2, 6))
for label, group in results.groupby("solver"):
    group.sort_values("tol").plot(
        x="n_iter", y="suboptimality", loglog=True, marker="o", label=label, ax=axes[0]
    )
axes[0].set_ylabel("suboptimality")
axes[0].set_title("Suboptimality by iterations")

for label, group in results.groupby("solver"):
    group.sort_values("tol").plot(
        x="train_time", y="suboptimality", loglog=True, marker="o", label=label, ax=axes[1]
    )
axes[1].set_ylabel("suboptimality")
axes[1].set_title("Suboptimality by time")
plt.show()

@lorentzenchr lorentzenchr added this to the 1.5 milestone Apr 18, 2024
@ogrisel (Member) left a comment

Thanks @lorentzenchr, this is a very interesting PR. Here is a first pass of feedback.

nitpick: I think Hessian should always be capitalized in the docstrings and comments.

sklearn/linear_model/tests/test_logistic.py (review thread: outdated, resolved)
sklearn/linear_model/tests/test_logistic.py (review thread: resolved)
@@ -426,77 +450,83 @@ def gradient_hessian(
         gradient : ndarray of shape coef.shape
             The gradient of the loss.

-        hessian : ndarray
+        hessian : ndarray of shape (n_dof, n_dof) or \
+            (n_classes, n_dof, n_dof, n_classes)
             Hessian matrix.

         hessian_warning : bool
             True if pointwise hessian has more than half of its elements non-positive.
Member:

Suggested change
-            True if pointwise hessian has more than half of its elements non-positive.
+            True if pointwise Hessian has more than 25% of its elements non-positive.

@lorentzenchr (Member Author):

Within linear_loss.py, hessian is mostly written with a lowercase h.

Member:
I think we can use the upper case notation when we speak in English about the mathematical concept while keeping the lower case notation when mentioning specific Python variables in our code.

sklearn/linear_model/_linear_loss.py (review thread: outdated, resolved)
sklearn/linear_model/_linear_loss.py (review thread: resolved)
@lorentzenchr (Member Author):

> nitpick: I think Hessian should always be capitalized in the docstrings and comments.

That's right, but it is not the standard in our code base. If you wish to correct that, I propose a separate PR.

@ogrisel (Member) commented Apr 25, 2024

> That's right, but it is not the standard in our code base. If you wish to correct that, I propose a separate PR.

I think we can just make sure that we don't propagate this error in new docstrings / comments and use a follow-up PR to fix existing docstrings/comments that are not logically related to the scope of this PR.

# While a dedicated Cython routine could exploit the symmetry, it is very hard to
# beat BLAS GEMM, even though the latter cannot exploit the symmetry, unless one
# pays the price of taking square roots and implements
# sqrtWX = sqrt(W)[:, None] * X
@ogrisel (Member) commented Apr 25, 2024

Note that exploiting symmetry is not the only reason why a dedicated sandwich product kernel would make sense.

The line above would trigger a read/write round trip between RAM and CPU of the size of X (when X is too large to fit in the CPU cache, which is typically the case of interest). When n_samples >> n_features, a dedicated fused sandwich product kernel would only have to:

  • read n_samples * (n_features + 1) / 2 from RAM;
  • write n_features ** 2 to RAM.

while what you propose would:

  • read n_samples * (n_features + 1) from RAM,  # sqrtWX = sqrt(W)[:, None] * X
  • write n_samples * n_features to RAM,  # sqrtWX = sqrt(W)[:, None] * X
  • read n_samples * n_features / 2 from RAM,  # np.dot(sqrtWX.T, sqrtWX)
  • write n_features ** 2 to RAM.  # np.dot(sqrtWX.T, sqrtWX)

Assuming that this kernel is memory bound, I would expect a ~3x speed-up from the fused kernel over the 2-step numpy code.

The problem is that writing an efficient blocked sandwich product kernel in Cython with OpenMP threading and hardware adapted SIMD vector instructions is far from trivial.

For CPU, https://github.com/Quantco/tabmat already presumably does that.

For GPU, something like https://github.com/openai/triton/ might be able to do it in a vendor agnostic way.
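
For reference, a minimal sketch of the two-step GEMM-based sandwich product H = X.T @ diag(W) @ X being discussed (shapes and names are illustrative, not the PR's actual code path):

import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features = 10_000, 75
X = rng.standard_normal((n_samples, n_features))
W = rng.uniform(0.01, 1.0, size=n_samples)  # pointwise Hessian weights, all positive here

# Two-step variant: materialize sqrtWX (an extra n_samples x n_features array
# written to and read back from RAM), then let BLAS GEMM do the heavy lifting.
sqrtWX = np.sqrt(W)[:, None] * X
H_gemm = sqrtWX.T @ sqrtWX

# Reference sandwich product without square roots (still two passes over X plus
# one temporary of the size of X).
H_ref = X.T @ (W[:, None] * X)
np.testing.assert_allclose(H_gemm, H_ref, atol=1e-8)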

@lorentzenchr (Member Author):

I can remove some of that comment. I wanted to stress two facts:

  • this is the CPU bottleneck;
  • replacing it by some self-written BLAS-like function is a ludicrous undertaking (even tabmat is only faster when there are categorical features!). GEMM might be the algorithm on which the most human time has been spent writing and optimizing.

@lorentzenchr (Member Author):

Let’s not get off-topic too much.

@jjerphan jjerphan self-requested a review April 29, 2024 15:59