
[MRG] Expose n_iter_ to BaseLibSVM #21408

Merged: 48 commits merged into scikit-learn:main on Dec 21, 2021

Conversation

jmloyola
Member

@jmloyola jmloyola commented Oct 22, 2021

Reference Issues/PRs

Fixes #18928

What does this implement/fix? Explain your changes.

This PR exposes n_iter_ to BaseLibSVM:

  • Add private attribute _num_iter to BaseLibSVM. This will always be a numpy array.
  • Add attribute n_iter_ to BaseLibSVM. When LibSVM optimizes just one model, this is an integer, otherwise, it is a numpy array with the number of iterations for each model.
  • Add documentation of n_iter_ to the classes that inherit from BaseLibSVM (SVR, NuSVR, OneClassSVM, SVC, NuSVC).
  • Add int num_iter to the class Solver in sklearn/svm/src/libsvm/svm.cpp. The Solve method from this class is responsible for carrying out the optimization. In this method, we save the final number of iterations run in optimization.
  • Add int num_iter to the struct decision_function in sklearn/svm/src/libsvm/svm.cpp. Since the svm models do not have direct access to the Solver class, we added the attribute num_iter to decision_function. This struct is returned by the function svm_train_one which the train functions call. The function svm_train_one calls the Solvers, thus having access to the number of iterations.
  • Add int *num_iter to svm_model and svm_csr_model structs in sklearn/svm/src/libsvm/svm.h. Since LibSVM can train several models, depending on the model type and the number of classes, here the attribute is an array of integers.
  • Fix bug in test test_libsvm_iris of sklearn/svm/tests/test_svm.py. Since the fit function returns a different set of attributes than the ones used by the predict function in sklearn/svm/_libsvm.pyx, we explicitly selected the correct ones. Note that this didn't fail previously because the fit_status variable took the place of the svm_type (changing the type of the model).
  • Check the attribute n_iter_ for the classes that inherit from BaseLibSVM in sklearn/utils/estimator_checks.py. For that I changed the tests check_non_transformer_estimators_n_iter and check_transformer_n_iter.
  • Add changelog
  • Test n_iter_ for all the estimators that inherit from BaseLibSVM.

Any other comments?

The pre-commit script removed the trailing white spaces in sklearn/svm/src/libsvm/svm.cpp.

@jmloyola jmloyola changed the title [WIP] Expose n_iter_ to BaseLibSVM [MRG] Expose n_iter_ to BaseLibSVM Oct 22, 2021
@jmloyola
Member Author

This PR is ready for review. Any comments or suggestions are more than welcome. 🤓

There is one failing test (test_check_estimator_pairwise > check_estimator > check_non_transformer_estimators_n_iter), but I don't think it is related to this PR's changes.
I did remove SVC from not_run_check_n_iter in check_non_transformer_estimators_n_iter and that made the test check the SVC(kernel='precomputed') with a non-square matrix X.

This is the traceback of the failing test:

============================================================== FAILURES ==============================================================
___________________________________________________ test_check_estimator_pairwise ____________________________________________________

    def test_check_estimator_pairwise():
        # check that check_estimator() works on estimator with _pairwise
        # kernel or metric
    
        # test precomputed kernel
        est = SVC(kernel="precomputed")
>       check_estimator(est)

sklearn/utils/tests/test_estimator_checks.py:680: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
sklearn/utils/estimator_checks.py:581: in check_estimator
    check(estimator)
sklearn/utils/_testing.py:313: in wrapper
    return fn(*args, **kwargs)
sklearn/utils/estimator_checks.py:3245: in check_non_transformer_estimators_n_iter
    estimator.fit(X, y_)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = SVC(kernel='precomputed', random_state=0)
X = array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
  ...],
       [6.3, 2.5, 5. , 1.9],
       [6.5, 3. , 5.2, 2. ],
       [6.2, 3.4, 5.4, 2.3],
       [5.9, 3. , 5.1, 1.8]])
y = array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., ...2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
       2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.])
sample_weight = array([], dtype=float64)

    def fit(self, X, y, sample_weight=None):
        """Fit the SVM model according to the given training data.
    
        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features) \
                or (n_samples, n_samples)
            Training vectors, where `n_samples` is the number of samples
            and `n_features` is the number of features.
            For kernel="precomputed", the expected shape of X is
            (n_samples, n_samples).
    
        y : array-like of shape (n_samples,)
            Target values (class labels in classification, real numbers in
            regression).
    
        sample_weight : array-like of shape (n_samples,), default=None
            Per-sample weights. Rescale C per sample. Higher weights
            force the classifier to put more emphasis on these points.
    
        Returns
        -------
        self : object
            Fitted estimator.
    
        Notes
        -----
        If X and y are not C-ordered and contiguous arrays of np.float64 and
        X is not a scipy.sparse.csr_matrix, X and/or y may be copied.
    
        If X is a dense array, then the other methods will not support sparse
        matrices as input.
        """
    
        rnd = check_random_state(self.random_state)
    
        sparse = sp.isspmatrix(X)
        if sparse and self.kernel == "precomputed":
            raise TypeError("Sparse precomputed kernels are not supported.")
        self._sparse = sparse and not callable(self.kernel)
    
        if hasattr(self, "decision_function_shape"):
            if self.decision_function_shape not in ("ovr", "ovo"):
                raise ValueError(
                    "decision_function_shape must be either 'ovr' or 'ovo', "
                    f"got {self.decision_function_shape}."
                )
    
        if callable(self.kernel):
            check_consistent_length(X, y)
        else:
            X, y = self._validate_data(
                X,
                y,
                dtype=np.float64,
                order="C",
                accept_sparse="csr",
                accept_large_sparse=False,
            )
    
        y = self._validate_targets(y)
    
        sample_weight = np.asarray(
            [] if sample_weight is None else sample_weight, dtype=np.float64
        )
        solver_type = LIBSVM_IMPL.index(self._impl)
    
        # input validation
        n_samples = _num_samples(X)
        if solver_type != 2 and n_samples != y.shape[0]:
            raise ValueError(
                "X and y have incompatible shapes.\n"
                + "X has %s samples, but y has %s." % (n_samples, y.shape[0])
            )
    
        if self.kernel == "precomputed" and n_samples != X.shape[1]:
>           raise ValueError(
                "Precomputed matrix must be a square matrix."
                " Input is a {}x{} matrix.".format(X.shape[0], X.shape[1])
            )
E           ValueError: Precomputed matrix must be a square matrix. Input is a 150x4 matrix.

sklearn/svm/_base.py:215: ValueError

@jmloyola
Member Author

jmloyola commented Oct 22, 2021

I'm working on fixing that failing test.

[Edited]
I thought of two options to solve this issue:

  • Deprecate the _pairwise attribute of BaseLibSVM as indicated in the sources and finally remove the SVC test from test_check_estimator_pairwise. Please, correct me if this is not right.
  • Add a special case in the test check_non_transformer_estimators_n_iter to handle SVC(kernel="precomputed"), that is, transform X using a linear kernel, for example.

Which of these options do you think works best?
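For reference, the second option could look like the following sketch (hypothetical adaptation, not code from this PR): precompute a linear-kernel Gram matrix so that the X handed to fit is square, as kernel="precomputed" requires:

```python
from sklearn.datasets import load_iris
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

# With kernel="precomputed", fit expects an (n_samples, n_samples) Gram
# matrix; here we use the linear kernel K[i, j] = <X[i], X[j]>.
K = X @ X.T
clf = SVC(kernel="precomputed").fit(K, y)
assert K.shape == (len(X), len(X))
```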

@ogrisel
Member

ogrisel commented Oct 23, 2021

Deprecate the _pairwise attribute of BaseLibSVM as indicated in the sources

Yes we can do that now (if needed).

and finally remove the SVC test from test_check_estimator_pairwise. Please, correct me if this is not right.

I think we need to keep this test. We still want SVC to work with precomputed pairwise kernels.

@ogrisel
Member

ogrisel commented Oct 23, 2021

Add a special case in the test check_non_transformer_estimators_n_iter to handle SVC(kernel="precomputed"), that is, transform X using a linear kernel, for example.

I need to have a closer look but I don't understand why it wouldn't be possible to compute a number of iterations when working with precomputed kernels.

@jmloyola
Member Author

Deprecate the _pairwise attribute of BaseLibSVM as indicated in the sources

Yes we can do that now (if needed).

Upon further inspection, and after talking to @ogrisel at the DataUmbrella sprint, this won't be necessary right now; it does not affect the test. The deprecation targets only the attribute, while the pairwise information remains available through more_tags. As you said later, we still need to test pairwise SVC with a precomputed kernel.

and finally remove the SVC test from test_check_estimator_pairwise. Please, correct me if this is not right.

I think we need to keep this test. We still want SVC to work with precomputed pairwise kernels.

You are right.

I will work to fix this during the week.

Member

@ogrisel ogrisel left a comment


I started the review but the trailing spaces change make it hard to review this PR. Please undo them to be able to finalize the review.

Member

@jjerphan jjerphan left a comment


The number of iterations spent in the optimization routine is indeed a useful information to report. 👍

Here are some comments and suggestions.

jmloyola and others added 8 commits October 26, 2021 16:06
Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>
Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>
Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
jmloyola and others added 2 commits October 28, 2021 11:07
Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>
@jjerphan
Member

This does not LGTM, but it does look very good to me.

@adrinjalali
Member

For the performance hit, WDYT @jeremiedbb @ogrisel ?

@jmloyola
Member Author

Were you able to look at this @ogrisel, @jeremiedbb? What do you think?

@jmloyola
Member Author

jmloyola commented Dec 8, 2021

I re-ran the benchmarks using the IPython magics to measure the time and memory for SVC. I used the same code as before (dense matrix code and sparse matrix code).

When using a dense matrix, we have:

| | scikit-learn:main | jmloyola:add_n_iter_libsvm |
| --- | --- | --- |
| fit time | 519 µs ± 1.92 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) | 519 µs ± 3.53 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) |
| fit memory | peak memory: 121.02 MiB, increment: 0.38 MiB | peak memory: 122.83 MiB, increment: 0.76 MiB |
| predict time | 230 µs ± 549 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) | 231 µs ± 734 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) |
| predict memory | peak memory: 121.06 MiB, increment: 0.00 MiB | peak memory: 122.86 MiB, increment: 0.00 MiB |

When using a sparse matrix, we have:

| | scikit-learn:main | jmloyola:add_n_iter_libsvm |
| --- | --- | --- |
| fit time | 1.39 ms ± 13.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) | 1.37 ms ± 16.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) |
| fit memory | peak memory: 120.99 MiB, increment: 0.18 MiB | peak memory: 121.02 MiB, increment: 0.21 MiB |
| predict time | 162 µs ± 285 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) | 168 µs ± 796 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) |
| predict memory | peak memory: 121.02 MiB, increment: 0.00 MiB | peak memory: 121.02 MiB, increment: 0.00 MiB |

I think we don't need to consider the time of the predict method, since that part of the code has barely changed: we only added a model->n_iter = NULL; to the set_model function.

Overall, I think the differences in time and space are negligible. Also, every time I run these benchmarks, we get different times. Do you recommend any other way to benchmark the changes?
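The benchmark scripts themselves are only linked above; a minimal stand-in (hypothetical, using the stdlib timeit module instead of the %timeit/%memit magics) could look like:

```python
import timeit

import numpy as np
from sklearn.svm import SVC

rng = np.random.RandomState(0)
X = rng.randn(100, 20)
y = rng.randint(0, 2, size=100)

clf = SVC()
# Rough equivalent of `%timeit clf.fit(X, y)`: take the best of 3 repeats
# of 10 fits each and report the mean time per fit.
fit_time = min(timeit.repeat(lambda: clf.fit(X, y), number=10, repeat=3)) / 10
print(f"fit time: {fit_time * 1e6:.0f} µs per loop")
```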

@adrinjalali
Member

Since it's been a while, would you mind merging the latest main to make sure we haven't got any new issues here?

@ogrisel
Member

ogrisel commented Dec 10, 2021

For the performance hit, WDYT @jeremiedbb @ogrisel ?

What performance hit? Reading the benchmark I see no significant change with main. Did I miss something?

Member

@ogrisel ogrisel left a comment


Sorry for being late to review this. The code looks good.

I think it would be nice to add a couple of SVM specific tests for n_iter_ in addition to the common test.

For instance, checking the shape of n_iter_ when fitting SVC or NuSVC on the iris dataset: the shape should be (3,), that is, (3 * (3 - 1)) // 2; when fitting on a binary classification problem (possibly derived from iris), the shape should be (1,).

I was also wondering if we couldn't check some properties of the values themselves, for instance the impact of changing C or gamma or the dataset on the values of n_iter_. For instance, for a toy binary classification problem, one could have a look at: https://en.wikipedia.org/wiki/Sequential_minimal_optimization#Optimization_problem

For instance it seems intuitive that the following toy dataset could converge in a single step whatever the value of C:

>>> SVC(kernel="linear").fit([[-1], [1]], [-1, 1]).n_iter_
array([1], dtype=int32)

but I am not 100% sure this is guaranteed.

Alternatively, looking at the iris dataset: https://scikit-learn.org/stable/auto_examples/decomposition/plot_pca_vs_lda.html, it seems fair that the 0 vs 1 (setosa vs versicolor) and 0 vs 2 (setosa vs virginica) subproblems are much easier than 1 vs 2 (versicolor vs virginica), and therefore the number of iterations of the last subproblem should be significantly higher. This is what we observe for various kernels and values of C, as long as C is not too low (regularization is not too high):

>>> from sklearn.datasets import load_iris
>>> from sklearn.svm import SVC
>>> X, y = load_iris(return_X_y=True)
>>> SVC(kernel="linear", C=1).fit(X, y).n_iter_
array([13,  3, 31], dtype=int32)
>>> SVC(kernel="linear", C=100).fit(X, y).n_iter_
array([ 13,   3, 206], dtype=int32)
>>> SVC(kernel="linear", C=0.1).fit(X, y).n_iter_
array([17,  5, 26], dtype=int32)
>>> SVC(kernel="linear", C=0.001).fit(X, y).n_iter_
array([50, 50, 50], dtype=int32)

We could therefore write a test such as:

X, y = load_iris(return_X_y=True)
n_iter_iris = SVC(kernel="linear", C=100).fit(X, y).n_iter_

# Looking at the iris dataset:
# https://scikit-learn.org/stable/auto_examples/decomposition/plot_pca_vs_lda.html
# one expects that the 0 vs 1 (setosa vs versicolor) and (0 vs 2) (setosa vs virginica)
# subproblems are much easier than (1 vs 2) (versicolor vs virginica) and therefore
# the number of iterations of the last subproblem should be significantly higher.
n_iter_0v1, n_iter_0v2, n_iter_1v2 = n_iter_iris
assert n_iter_0v1 < n_iter_1v2
assert n_iter_0v2 < n_iter_1v2

Other than the testing coverage that can be improved as suggested above, LGTM.

@ogrisel
Member

ogrisel commented Dec 10, 2021

I merged main to make sure that everything is in order.

@ogrisel
Member

ogrisel commented Dec 10, 2021

Another test idea: we could check the consistency of the n_iter_ values for the iris 3-way multiclass problem against the 3 n_iter_ values obtained by training 3 binary classifiers for the 0v1, 0v2 and 1v2 subproblems, constructed by fancy indexing the iris dataset on y as follows:

mask_0v1 = np.isin(y, [0, 1])  # setosa vs versicolor
n_iter_0v1 = SVC(kernel="linear", C=100).fit(X[mask_0v1], y[mask_0v1]).n_iter_
assert n_iter_0v1.shape == (1,)

mask_0v2 = np.isin(y, [0, 2])  # setosa vs virginica
n_iter_0v2 = SVC(kernel="linear", C=100).fit(X[mask_0v2], y[mask_0v2]).n_iter_
assert n_iter_0v2.shape == (1,)

mask_1v2 = np.isin(y, [1, 2])  # versicolor vs virginica
n_iter_1v2 = SVC(kernel="linear", C=100).fit(X[mask_1v2], y[mask_1v2]).n_iter_
assert n_iter_1v2.shape == (1,)

# setosa is much easier to separate from the other 2 classes
assert np.all(n_iter_1v2 > n_iter_0v1)
assert np.all(n_iter_1v2 > n_iter_0v2)

# check the consistency of the one-vs-one multiclass setting:
n_iter_iris = SVC(kernel="linear", C=100).fit(X, y).n_iter_
assert n_iter_iris.shape == (3,)  # (3 * (3 - 1)) // 2
assert_array_equal(
    n_iter_iris, np.concatenate([n_iter_0v1, n_iter_0v2, n_iter_1v2])
)

@ogrisel
Member

ogrisel commented Dec 10, 2021

Also note that the fact that over-regularized models have a number of iterations matching the number of samples per class can probably be explained as follows: in that regime, all samples are support vectors and the optimization problem is likely degenerate, since all dual coefficients hit the regularization constraint. This makes the problem trivial to optimize, with one iteration per sample, before reaching the KKT convergence criterion:

>>> C = 0.001
>>> svc = SVC(kernel="linear", C=C).fit(X, y)
>>> svc.n_iter_
array([50, 50, 50], dtype=int32)
>>> svc.n_support_
array([50, 50, 50], dtype=int32)
>>> np.allclose(np.abs(svc.dual_coef_), C)

@jmloyola
Member Author

Thanks for the review @ogrisel.
I added the coverage tests and realized they didn't pass. I had to convert np.int64 to int. Now they do 🤓.

Now, I'll work on the rest of the tests.
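The conversion mentioned above presumably amounts to casting a NumPy scalar to a built-in int before exposing it (illustrative sketch, not the PR's exact code):

```python
import numpy as np

# LibSVM reports iteration counts as fixed-width integers; for a single
# optimized model the scalar is cast to a plain Python int for n_iter_.
num_iter = np.int64(13)
n_iter_ = int(num_iter)
assert type(n_iter_) is int
```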

@@ -1059,16 +1105,17 @@ def test_svc_bad_kernel():
svc.fit(X, Y)


def test_timeout():
def test_libsvm_convergence_warnings():
Member Author


I changed the name of the test so that it is consistent with the test for liblinear (test_linear_svm_convergence_warnings)
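A sketch of what such a convergence-warning test might check (hypothetical; the PR's actual test may differ): capping max_iter should make libsvm stop early and emit a ConvergenceWarning:

```python
import warnings

from sklearn.datasets import load_iris
from sklearn.exceptions import ConvergenceWarning
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

# Force early termination: the solver cannot converge in a single iteration.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    SVC(max_iter=1).fit(X, y)
assert any(issubclass(w.category, ConvergenceWarning) for w in caught)
```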

Member

@ogrisel ogrisel left a comment


LGTM, thank you very much for the follow-up!

@ogrisel ogrisel merged commit efdc92d into scikit-learn:main Dec 21, 2021
@ogrisel
Member

ogrisel commented Dec 21, 2021

Note that #21408 (comment) could have been turned into a test but this is good enough the way it is (unless you are interested in a new PR with this new test).

@jmloyola
Member Author

Thanks so much, @ogrisel, @jjerphan, and @adrinjalali for the review. 🤓

I will keep working on a new PR to add more tests. I haven't had much time lately.

@jmloyola jmloyola deleted the add_n_iter_libsvm branch December 21, 2021 19:26
venkyyuvy pushed a commit to venkyyuvy/scikit-learn that referenced this pull request Jan 1, 2022

Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
mathijs02 pushed a commit to mathijs02/scikit-learn that referenced this pull request Dec 27, 2022

Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
Successfully merging this pull request may close these issues.

Get the number of iterations in SVR
5 participants