
PERF Support converting 32-bit matrices directly to liblinear format … #14296

Merged
2 commits merged into scikit-learn:master from alexhenrie:float on Jul 31, 2019

Conversation

alexhenrie (Contributor) commented Jul 8, 2019

This decreases the memory required for regression by about 33% on 32-bit dense matrices and by about 42% on 32-bit CSR matrices while not noticeably affecting the running time in any 32-bit or 64-bit case. Direct support for 32-bit matrices is the ideal solution for my group because we only need 32-bit precision and cutting the memory requirement by a third will get us inside our servers' memory limits.
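For readers unfamiliar with the liblinear wrapper, here is a minimal C sketch of the idea behind the change, not the actual diff: the input buffer is read through an untyped pointer and each element is widened to double only while liblinear's feature_node structs are filled, so no intermediate float64 copy of X is allocated. The function name, the 1-based indexing, and the -1 sentinel are illustrative assumptions modeled on liblinear conventions.

#include <stdlib.h>

struct feature_node { int index; double value; };  /* value is a double, as declared in liblinear's linear.h */

struct feature_node *dense_row_to_nodes(const char *row, int n_features,
                                        int double_precision)
{
    /* One extra slot for the sentinel node that terminates the row. */
    struct feature_node *nodes = malloc((n_features + 1) * sizeof(*nodes));
    if (nodes == NULL)
        return NULL;

    for (int j = 0; j < n_features; ++j) {
        nodes[j].index = j + 1;  /* assumed 1-based, following liblinear convention */
        if (double_precision)
            nodes[j].value = ((const double *) row)[j];          /* 64-bit input: plain read */
        else
            nodes[j].value = (double) ((const float *) row)[j];  /* 32-bit input: widen on the fly */
    }
    nodes[n_features].index = -1;  /* sentinel terminator */
    return nodes;
}

Without such a flag, the caller has to pass X already cast to float64, which for a float32 matrix means holding the original and an equally large double copy in memory at the same time; with it, the cast happens one element at a time.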

jnothman (Member) left a comment

Nice! Could you please double check that this code path has test coverage?

T->value = *x;
T->index = j;
++ T;
if (double_precision) {

jnothman (Member), Jul 10, 2019:

Is the optimiser likely to compile this out of the loop?

alexhenrie (Author), Jul 10, 2019:

Yes, and even if it doesn't, the CPU's branch predictor will make the cost of the if statement negligible.
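As an aside, a minimal sketch (not the code under review; function names are illustrative) of the two forms being discussed: the loop-invariant flag tested on every iteration, and the hand-unswitched version an optimizer may produce by hoisting the test out of the loop.

/* Flag tested inside the loop: the branch outcome never changes, so it is
 * trivially predicted even if the compiler leaves it in place. */
void widen_copy(double *dst, const void *src, long n, int double_precision)
{
    for (long i = 0; i < n; ++i) {
        if (double_precision)
            dst[i] = ((const double *) src)[i];
        else
            dst[i] = (double) ((const float *) src)[i];
    }
}

/* Equivalent "unswitched" form: the test runs once and each loop body is
 * branch-free; loop unswitching produces this shape automatically. */
void widen_copy_unswitched(double *dst, const void *src, long n, int double_precision)
{
    if (double_precision) {
        const double *s = src;
        for (long i = 0; i < n; ++i)
            dst[i] = s[i];
    } else {
        const float *s = src;
        for (long i = 0; i < n; ++i)
            dst[i] = (double) s[i];
    }
}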

jnothman (Member) commented Jul 10, 2019

Otherwise lgtm

rth (Member) left a comment

Thanks @alexhenrie, I have not reviewed it in detail, but the results look nice!

Could you please double check that this code path has test coverage?

Yes, for such refactorings I think we need to be sure that this doesn't change the obtained coefficients and predictions. There are existing tests that check coef_ equality between solvers, but we need to ensure:

  • that the absolute tolerance used is tight enough to detect possible issues
  • that those tests are indeed running on sparse input and also on 32-bit input. Maybe adding test_dtype_match could be useful, or adding/changing the pytest parametrization of the existing tests.
jnothman (Member) commented Jul 10, 2019 (comment minimized; body not shown)

alexhenrie (Author) commented Jul 10, 2019

struct feature_node in linear.h represents the value in double precision, so I would expect a slight loss in precision if the inputs are limited to 32-bit originally.

What file do you want the new tests in?

alexhenrie force-pushed the alexhenrie:float branch several times on Jul 10, 2019 (e69e034 → a789378 → c4114eb → 9ad1a7e → c38b43b → 162e190 → f8087f2 → 0e1d075), each time adding a commit to alexhenrie/scikit-learn that referenced this pull request. One of those commits carried the message:

It is statistically valid for the logistic regression output to have a
higher precision than any single one of its inputs, so I have removed
the tests that limit the output precision to the input precision.
alexhenrie (Author) commented Jul 10, 2019

Ah, I see now. test_dtype_match is an existing function that should check that liblinear's output is the same with 32-bit dense input, 32-bit sparse input, 64-bit dense input, or 64-bit sparse input. I just pushed a commit that does exactly that.

test_dtype_match was also checking that the output precision was the same as the input precision. However, it is statistically correct for the logistic regression output to have a higher precision than any single one of its inputs: for example, if the input variables are all binary (0 or 1), it doesn't matter whether the inputs are floats or doubles; the output can still be calculated to full double precision. Therefore I have removed the tests that limit the output precision to the input precision.

The new tests exposed a bug in the 'saga' solver where using sparse input instead of dense input changes the output significantly. I have not figured out how to fix this bug.

jnothman (Member) left a comment

Sorry, my mistake: liblinear processes in float64 (double), but we are still doing the same cast as before, so I don't think we've changed the fit here, have we?

jnothman (Member) commented Jul 11, 2019

The new tests exposed a bug in the 'saga' solver where using sparse input instead of dense input changes the output significantly. I have not figured out how to fix this bug.

Sometimes the fit will be different, but not erroneous. I assume you mean that the sparse handling is wrong?

alexhenrie (Author) commented Jul 11, 2019

@jnothman Oh, I see what you mean now. If you have 32-bit inputs, the output before and after this patch is necessarily the same because the only difference is when the conversion to 64-bit happens.

I think there's a bug because according to the new tests, the output of the 'saga' solver varies much more than I would expect. I worked around the problem by increasing the tolerance when using 'saga'. Can you think of a better solution?

alexhenrie (Author) commented Jul 11, 2019

Strangely, when I set fit_intercept=False, the 'saga' solver gives consistent output. But I don't know if there is a bug in the intercept calculation or if calculating an intercept just magnifies a problem that comes from somewhere else.

alexhenrie force-pushed the alexhenrie:float branch from 0e1d075 to dd63bab and then from dd63bab to 8005681 on Jul 11, 2019, each time adding a commit to alexhenrie/scikit-learn that referenced this pull request, and force-pushed the branch 2 times, most recently from 0e1d075 to aaa44e5.
alexhenrie (Author) commented Jul 19, 2019

@jnothman Could you merge these commits now please? :-)

lr_32 = clone(lr_templ)
lr_32.fit(X_32, y_32)
assert lr_32.coef_.dtype == X_32.dtype
assert lr_32.coef_.dtype in [np.float32, np.float64]

rth (Member), Jul 19, 2019:

Why this change? The whole point of adding 32-bit support to the other solvers was to ensure that coef_ has the same dtype as X. We can't remove that constraint, at least not for the other solvers.

alexhenrie (Author), Jul 19, 2019:

As I explained in #14296 (comment), there is no reason to prevent any solver from using 32-bit input to produce 64-bit output.

lr_32_sparse = clone(lr_templ)
lr_32_sparse.fit(X_sparse_32, y_32)
assert lr_32_sparse.coef_.dtype == X_sparse_32.dtype
assert lr_32_sparse.coef_.dtype in [np.float32, np.float64]

rth (Member), Jul 19, 2019:

Same as above

rth (Member) commented Jul 24, 2019

As I explained in #14296 (comment), there is no reason to prevent any solver from using 32-bit input to produce 64-bit output.

That's what we are explicitly trying to prevent as part of #8769, and this would remove the associated tests. The goal is to reduce memory usage and allow faster BLAS operations (for BLAS implementations that support AVX, etc.) when float32 input is provided. So I think tests affecting other solvers should not be modified in this respect; if those tests don't work for liblinear, you can also create a separate test.

alexhenrie force-pushed the alexhenrie:float branch from 8d6911c to 2dbfbf4 on Jul 24, 2019 and added commits to alexhenrie/scikit-learn that referenced this pull request.
alexhenrie (Author) commented Jul 24, 2019

Fine, I just changed the tests back for all solvers except liblinear. Whatever it takes to get this pull request accepted.

jeremiedbb (Member) left a comment

liblinear uses double precision. You want to avoid a copy by directly converting float to double during the conversion to liblinear format. Did I understand correctly?


if solver == 'liblinear' and multi_class == 'multinomial':
return

jeremiedbb (Member), Jul 24, 2019:

Please use pytest.skip(some informative message) instead of return

alexhenrie (Author), Jul 24, 2019:

Done.

assert_allclose(lr_32.coef_, lr_64.coef_.astype(np.float32), atol=atol)

if solver == 'saga' and fit_intercept:
# FIXME

jeremiedbb (Member), Jul 24, 2019:

Maybe explain the reason for the FIXME by adding a comment saying that saga does not correctly fit the intercept on sparse data with the default tol and max_iter parameters.

alexhenrie (Author), Jul 24, 2019:

Done.

*/
static struct feature_node **csr_to_sparse(double *values, int *indices,
int *indptr, int n_samples, int n_features, int n_nonzero, double bias)
static struct feature_node **csr_to_sparse(void *x, int double_precision,

jeremiedbb (Member), Jul 24, 2019:

Maybe use a char* for consistency. Also, since (np.array).data gives a char*, we can be specific.

alexhenrie (Author), Jul 24, 2019:

This is exactly the kind of thing void* is for: In no case do we want to read the array data as chars because that would result in gibberish, and it would be a mistake to do so. void* prevents data access until it has been determined whether the array is an array of floats or an array of doubles.

Nonetheless, the machine code generated is the same whether x is defined as void* or char*. So if you want to give up the protection against accidentally accessing the array data as chars, it won't hurt anything to use char*. Let me know and I'll make the change.
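To make the trade-off concrete, here is a small illustrative sketch (not the PR's helper; names are invented) of the same element read written against a void* parameter and against a char* parameter. Once the explicit cast to float* or double* is made, the two compile to the same code.

/* With void*: the buffer cannot be dereferenced at all until it is cast to a
 * real element type, which is the safety argument made above. */
double read_value_void(const void *x, int double_precision, long i)
{
    return double_precision ? ((const double *) x)[i]
                            : (double) ((const float *) x)[i];
}

/* With char*: matches the type of an ndarray's .data pointer, at the cost of
 * allowing a (meaningless) byte-wise dereference by mistake. */
double read_value_char(const char *x, int double_precision, long i)
{
    return double_precision ? ((const double *) x)[i]
                            : (double) ((const float *) x)[i];
}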

jeremiedbb (Member) commented Jul 24, 2019

Why did you move Y to the last position? It makes the change log harder to read :)

alexhenrie (Author) commented Jul 24, 2019

liblinear uses double precision. You want to avoid a copy by directly converting float to double during the conversion to liblinear format. Did I understand correctly?

Yes, that is correct. The space wasted by having to represent the input datapoints as doubles instead of floats puts us over our server's memory limit.
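A minimal sketch of the same idea on the CSR side (again illustrative; the real csr_to_sparse in the PR differs in signature and bias handling): the CSR value array is read as float or double according to a flag while each row's feature_node array is filled, so no float64 copy of the data is allocated.

#include <stdlib.h>

struct feature_node { int index; double value; };

struct feature_node *csr_row_to_nodes(const char *values, const int *indices,
                                      const int *indptr, int row,
                                      int double_precision)
{
    int start = indptr[row], end = indptr[row + 1];
    int nnz = end - start;

    /* One extra slot for the sentinel node that terminates the row. */
    struct feature_node *nodes = malloc((nnz + 1) * sizeof(*nodes));
    if (nodes == NULL)
        return NULL;

    for (int k = 0; k < nnz; ++k) {
        nodes[k].index = indices[start + k] + 1;  /* assumed 1-based, as liblinear expects */
        nodes[k].value = double_precision
            ? ((const double *) values)[start + k]
            : (double) ((const float *) values)[start + k];  /* widen 32-bit values on the fly */
    }
    nodes[nnz].index = -1;  /* sentinel terminator */
    return nodes;
}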

alexhenrie (Author) commented Jul 24, 2019

Why did you move Y to the last position? It makes the change log harder to read :)

I wanted all of the variables describing X to be next to each other. I'll change it back if you prefer it the other way, although it won't make the diff much more readable.

alexhenrie force-pushed the alexhenrie:float branch from 2dbfbf4 to c34946b on Jul 24, 2019 and added commits to alexhenrie/scikit-learn that referenced this pull request.
# and that the output is approximately the same no matter the input format.

if solver == 'liblinear' and multi_class == 'multinomial':
pytest.skip('liblinear does not support multinomial classes')

jnothman (Member), Jul 24, 2019:

Suggested change:
- pytest.skip('liblinear does not support multinomial classes')
+ pytest.skip('liblinear does not support multinomial logistic')

alexhenrie (Author), Jul 25, 2019:

Done.

alexhenrie force-pushed the alexhenrie:float branch from c34946b to 4ac70e3 on Jul 25, 2019 and added commits to alexhenrie/scikit-learn that referenced this pull request.
alexhenrie (Author) commented Jul 29, 2019

@jeremiedbb Have I answered your concerns to your satisfaction?

jeremiedbb (Member) commented Jul 29, 2019

Sorry for the delay. You're right that void* is the better practice, but NumPy decided to use char*. Maybe there's a good reason; I have no idea. So I still think we should stick to char*. Otherwise it makes the code harder to follow: char* converted to void* converted to double*... :)

alexhenrie (Author) commented Jul 29, 2019

@jeremiedbb Okay, I just changed the two void * to char *. Anything else?

alexhenrie force-pushed the alexhenrie:float branch from 4ac70e3 to ea02d9c Jul 29, 2019
jeremiedbb merged commit 583e18f into scikit-learn:master Jul 31, 2019
14 of 15 checks passed:
  • ci/circleci: doc-min-dependencies (tests failed on CircleCI)
  • LGTM analysis: JavaScript (no code changes detected)
  • LGTM analysis: C/C++ (no new or fixed alerts)
  • LGTM analysis: Python (no new or fixed alerts)
  • ci/circleci: deploy (tests passed on CircleCI)
  • ci/circleci: doc (tests passed on CircleCI)
  • ci/circleci: lint (tests passed on CircleCI)
  • scikit-learn.scikit-learn Build #20190729.69 succeeded
  • scikit-learn.scikit-learn (Linux py35_conda_openblas) succeeded
  • scikit-learn.scikit-learn (Linux py35_ubuntu_atlas) succeeded
  • scikit-learn.scikit-learn (Linux pylatest_conda_mkl_pandas) succeeded
  • scikit-learn.scikit-learn (Linux32 py35_ubuntu_atlas_32bit) succeeded
  • scikit-learn.scikit-learn (Windows py35_pip_openblas_32bit) succeeded
  • scikit-learn.scikit-learn (Windows py37_conda_mkl) succeeded
  • scikit-learn.scikit-learn (macOS pylatest_conda_mkl) succeeded
jeremiedbb (Member) commented Jul 31, 2019

Thanks @alexhenrie!

alexhenrie (Author) commented Jul 31, 2019

Thank you so much! This is an enormous help to me and my colleagues :)

alexhenrie deleted the alexhenrie:float branch Jan 17, 2020
alexhenrie restored the alexhenrie:float branch Jan 17, 2020