[MRG+2] Make CD use fused types #6913

Merged
merged 50 commits into scikit-learn:master on Aug 25, 2016

@yenchenlin
Contributor

yenchenlin commented Jun 21, 2016

According to #5464, the current implementation of ElasticNet and Lasso in scikit-learn constrains the input to be np.float64, which is a waste of space.

This PR tries to make the CD algorithms support fused types when fitting np.float32 dense data, and therefore avoid a redundant data copy.

  • Make inline helper functions support fused types (see the sketch below this list)
  • Make dense CD ElasticNet support fused types
  • Add a warning when alpha is close to zero and X is np.float32
  • Add tests
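
A rough sketch of what the first item means (illustrative only, not necessarily the exact code in this PR): an inline helper declared on the Cython fused type floating is compiled once for float and once for double, so float32 data never gets promoted.

from cython cimport floating   # fused type: either float or double

cdef inline floating fmax(floating x, floating y) nogil:
    # the comparison stays in the input precision; no cast to double
    if x > y:
        return x
    return y

cdef inline floating fsign(floating f) nogil:
    if f == 0:
        return 0
    elif f > 0:
        return 1.0
    else:
        return -1.0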

**UPDATE 7/7**

Here are the memory profiling results when fitting np.float32 data:

  • master (memory profile screenshot)

  • this branch (memory profile screenshot)

**UPDATE 7/12**

Here are the memory profiling results when fitting sparse np.float32 data:

  • master (memory profile screenshot)

  • this branch (memory profile screenshot)

@agramfort


Member

agramfort commented Jun 21, 2016

Do you expect to merge just this, or is it a WIP?

@yenchenlin


Contributor

yenchenlin commented Jun 21, 2016

@agramfort Yeah, I am thinking of merging just these as a starting point.

@yenchenlin yenchenlin changed the title from Make helper functions in cd use fused types to [MRG] Make helper functions in cd use fused types Jun 21, 2016

@jnothman


Member

jnothman commented Jun 21, 2016

I'm not sure what value there is in merging this separately from something we can benchmark. For instance, you've fused fmax and fsign, but reused the libc fabs, which explicitly operates on a double (as opposed to fabsl). Are you sure we benefit from fused implementations of max and sign?

So I think we want to review this as a whole.
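
(For reference, one way the libc fabs issue could be handled, sketched here as an illustration rather than a quote of the PR: branch on the fused type and call the single-precision C routine for float.)

from cython cimport floating

cdef extern from "math.h":
    double fabs(double x) nogil
    float fabsf(float x) nogil

cdef inline floating fused_fabs(floating x) nogil:
    # use fabsf when floating is float so float32 values are not promoted to double
    if floating is float:
        return fabsf(x)
    return fabs(x)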

@yenchenlin


Contributor

yenchenlin commented Jun 23, 2016

Thanks @jnothman @agramfort.
Ah yeah, you are right. I am working on it! 💪

@yenchenlin yenchenlin changed the title from [MRG] Make helper functions in cd use fused types to [WIP] Make CD use fused types Jun 23, 2016

@yenchenlin


Contributor

yenchenlin commented Jul 1, 2016

**Updated 7/7**

~~Currently it is still not working.~~ It is now working!

Here is my test script:

import numpy as np
from sklearn.linear_model.coordinate_descent import ElasticNet

# @profile comes from the memory_profiler package;
# run with: python -m memory_profiler this_script.py
@profile
def fit_est():
    clf.fit(X, y)

np.random.seed(5)
X = np.float32(np.random.rand(2000000, 40))
y = np.float32(np.random.rand(2000000))
T = np.float32(np.random.rand(5, 40))

clf = ElasticNet(alpha=1e-7, l1_ratio=1.0, precompute=False)
fit_est()
pred = clf.predict(T)
print(pred)
# np.dot(R.T, y)
gap += (alpha * l1_norm - const * ddot(
if floating is double:


@jnothman

jnothman Jul 4, 2016

Member

Are you absolutely certain we can't do this with fused types? What prohibits it?


@yenchenlin

yenchenlin Jul 4, 2016

Contributor

This algorithm uses lots of C pointer casts such as <DOUBLE*>; we can't do <floating*>.


@jnothman

jnothman Jul 4, 2016

Member

You're saying Cython disallows fused type pointers, or typecasts? Is that incapability documented?

Could we use typecasts if we were working with typed memoryviews?
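
(For what it's worth, Cython does accept fused-type pointer declarations and casts once the fused type has been bound through an argument; a minimal sketch under that assumption, with a hypothetical function name:)

from cython cimport floating
cimport numpy as np

def first_element(np.ndarray[floating, ndim=1] x):
    # floating is specialized to float or double from x's dtype,
    # so both the pointer declaration and the cast below compile
    cdef floating *x_data = <floating*> x.data
    return x_data[0]

Calling first_element on a float32 array picks the float specialization; a float64 array picks the double one.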

@jnothman


Member

jnothman commented Jul 4, 2016

Could you use line-based memory profiling to see where that sharp increase in memory consumption is coming from?

@jnothman


Member

jnothman commented Jul 4, 2016

Sorry, that was silly; only appropriate if the bad memory usage is in Python code.

fit_intercept and not np.allclose(X_offset, np.zeros(n_features)) or
normalize and not np.allclose(X_scale, np.ones(n_features))):
fit_intercept and not np.allclose(X_offset, np.zeros(n_features))
or normalize and not np.allclose(X_scale, np.ones(n_features))):


@jnothman

jnothman Jul 5, 2016

Member

I don't see why this is an improvement, or why it's in this PR.

int max_iter, double tol,
def enet_coordinate_descent(np.ndarray[floating, ndim=1] w,
floating alpha, floating beta,
np.ndarray[floating, ndim=2, mode='fortran'] X,


@jnothman

jnothman Jul 5, 2016

Member

I assume your benchmarking script is using fortran-contiguous data.


@MechCoder

MechCoder Jul 9, 2016

Member

The parent classes force X to be Fortran-contiguous.

def enet_coordinate_descent(np.ndarray[floating, ndim=1] w,
floating alpha, floating beta,
np.ndarray[floating, ndim=2, mode='fortran'] X,
np.ndarray[floating, ndim=1, mode='c'] y,


@jnothman

jnothman Jul 5, 2016

Member

I'm not sure that we need to support y, alpha, beta, w, or tol with lower precision than DOUBLE. Or is a consistent type necessary to use BLAS?

cdef double l1_norm
cdef np.ndarray[floating, ndim=1] XtA
if floating is float:


@jnothman

jnothman Jul 5, 2016

Member

below you use if floating is double. Please be consistent.

# R = y - np.dot(X, w)
for i in range(n_samples):
R[i] = y[i] - ddot(n_features,
<DOUBLE*>(X.data + i * sizeof(DOUBLE)),


@jnothman

jnothman Jul 5, 2016

Member

I don't get why these lines were written like this in the first place. Surely defining cdef DOUBLE *X_data = X.data above, then doing &X_data[i] here, is clearer. I presume that would work with cdef floating *X_data = X.data too.
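
(A sketch of that tidier pointer style, using a hypothetical helper rather than the PR's actual loop, and assuming Fortran-ordered X as elsewhere in this file:)

from cython cimport floating
cimport numpy as np
import numpy as np

def residual(np.ndarray[floating, ndim=2, mode='fortran'] X,
             np.ndarray[floating, ndim=1] y,
             np.ndarray[floating, ndim=1] w):
    """R = y - np.dot(X, w), written with a plain fused-type pointer."""
    cdef int n_samples = X.shape[0]
    cdef int n_features = X.shape[1]
    cdef floating *X_data = <floating*> X.data   # declared once, up front
    cdef np.ndarray[floating, ndim=1] R = np.empty(n_samples, dtype=X.dtype)
    cdef int i, j
    for i in range(n_samples):
        R[i] = y[i]
        for j in range(n_features):
            # Fortran order: X[i, j] lives at X_data[i + j * n_samples],
            # so no sizeof() pointer arithmetic is needed
            R[i] = R[i] - X_data[i + j * n_samples] * w[j]
    return R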

1, <DOUBLE*>R.data, 1)
# tmp = (X[:,ii]*R).sum()
tmp = ddot(n_samples,


@jnothman

jnothman Jul 5, 2016

Member

If these BLAS functions are the reason you use the if/else case, you could locally do ddot if floating is double else sdot, or perhaps make a flag variable for if floating is double to make this briefer. Indeed, it might be possible to declare a function pointer with floating * parameters above, set it to dot = ddot if floating is double else sdot, and see if the C compiler optimises out the fact that the pointer is never changed.
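
(A hedged sketch of the function-pointer idea; the cblas declarations below are assumptions about the bundled headers, not copied from the PR:)

from cython cimport floating
cimport numpy as np

cdef extern from "cblas.h":
    double ddot "cblas_ddot"(int N, double *X, int incX,
                             double *Y, int incY) nogil
    float sdot "cblas_sdot"(int N, float *X, int incX,
                            float *Y, int incY) nogil

# the function-pointer type follows the fused type of the enclosing function
ctypedef floating (*DOT)(int N, floating *X, int incX,
                         floating *Y, int incY) nogil

def dot_product(np.ndarray[floating, ndim=1] x,
                np.ndarray[floating, ndim=1] y):
    cdef DOT dot
    if floating is double:   # resolved at compile time for each specialization
        dot = ddot
    else:
        dot = sdot
    return dot(<int> x.shape[0], <floating*> x.data, 1, <floating*> y.data, 1)

The assignment happens once, so a C compiler is likely to treat dot as a constant and inline the call, but that is exactly the thing worth benchmarking.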

@@ -375,13 +375,23 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
# We expect X and y to be already float64 Fortran ordered when bypassing


@jnothman

jnothman Jul 5, 2016

Member

Comment needs updating.

@yenchenlin yenchenlin changed the title from [WIP] Make CD use fused types to [WIP] Make dense CD use fused types Jul 7, 2016

@yenchenlin


Contributor

yenchenlin commented Jul 7, 2016

Hello @jnothman and @MechCoder, thanks a lot for your patience and comments.

I've updated the PR description (including new memory profiling results) and the code, and addressed the comments you gave before.

The remaining to-do tasks are listed in the main description of this PR.

@yenchenlin yenchenlin changed the title from [WIP] Make dense CD use fused types to [MRG] Make dense CD use fused types Jul 8, 2016

@yenchenlin


Contributor

yenchenlin commented Jul 8, 2016

I've added the tests and the user warning for a potential non-convergence error when fitting np.float32 data with a small alpha.

However, the CI failure looks weird; any idea?

@yenchenlin


Contributor

yenchenlin commented Jul 9, 2016

I've also completed the sparse case; should I include it in this PR?

@@ -273,6 +273,11 @@ def _set_intercept(self, X_offset, y_offset, X_scale):
"""Set the intercept_
"""
if self.fit_intercept:
if isinstance(self.coef_, np.ndarray):


@agramfort

agramfort Jul 9, 2016

Member

self.coef_ will always be an ndarray, otherwise the line below, self.coef_ / X_scale, would not work.


@jnothman

jnothman Jul 13, 2016

Member

I'm confused. Why is this method changing?


@yenchenlin

yenchenlin Jul 19, 2016

Contributor

It's because when fitting np.float32 data, X_offset, y_offset, and X_scale are np.float64 unless I explicitly set their type here, which in turn makes ElasticNet end up with an np.float64 coef_ even though it was fit on np.float32 data.


@MechCoder

MechCoder Jul 26, 2016

Member

Could you clarify why X_offset etc are of dtype np.float64?

In any case, since self.coef_ is always a numpy array, you can get the dtype without doing the if check, right?


@yenchenlin

yenchenlin Aug 12, 2016

Contributor

Hmm...

The following error shows up without this check:

  File "/Users/YenChen/Desktop/Python/scikit-learn/sklearn/linear_model/tests/test_least_angle.py", line 371, in test_multitarget
    estimator.fit(X, Y)
  File "/Users/YenChen/Desktop/Python/scikit-learn/sklearn/linear_model/least_angle.py", line 705, in fit
    self._set_intercept(X_offset, y_offset, X_scale)
  File "/Users/YenChen/Desktop/Python/scikit-learn/sklearn/linear_model/base.py", line 276, in _set_intercept
    dtype = self.coef_.dtype
AttributeError: 'list' object has no attribute 'dtype'


@MechCoder

MechCoder Aug 13, 2016

Member

This can be done away with, I believe. Look at _preprocess_data and check for places where X_offset can be explicitly or implicitly cast to float64.


@yenchenlin

yenchenlin Aug 15, 2016

Contributor

Since the error happened at dtype = self.coef_.dtype, with the message:

AttributeError: 'list' object has no attribute 'dtype'

I think self.coef_ is not always a numpy array?


@MechCoder

MechCoder Aug 15, 2016

Member

Yes, but that is not the issue right now. For example, this works:

a = [1, 2, 3]
b = np.array([1, 2, 3])
a / b
array([ 1.,  1.,  1.])

The issue is that X_offset is being cast to np.float64 in _preprocess_data, which should not happen.
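
(A toy illustration of why that upcast matters, with made-up values; this is not the actual _preprocess_data or _set_intercept code:)

import numpy as np

coef = np.ones(3, dtype=np.float32)       # what the float32 solver hands back
X_scale = np.ones(3, dtype=np.float64)    # unintentionally float64
X_offset = np.zeros(3, dtype=np.float64)

coef = coef / X_scale                     # float32 / float64 promotes to float64
intercept = 0.0 - np.dot(X_offset, coef)
print(coef.dtype)                         # float64, even though X and y were float32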

if Xy is not None:
# Xy should be a 1d contiguous array or a 2D C ordered array
Xy = check_array(Xy, dtype=np.float64, order='C', copy=False,
ensure_2d=False)


@agramfort

agramfort Jul 9, 2016

Member

This is a big code duplication. Just introduce a dtype variable.


@jnothman

jnothman Jul 13, 2016

Member

Please address this! And doesn't check_array allow a list of dtypes in order of preference, so that dtype=[np.float64, np.float32] should do the trick?


@jnothman

jnothman Jul 13, 2016

Member

Oh, of course: the arrays all need to be consistent, so that alone wouldn't do the trick.


@jnothman

jnothman Jul 13, 2016

Member

Still, using that for X and then using dtype=X.dtype for the others should work?


@yenchenlin

yenchenlin Jul 18, 2016

Contributor

Yes, thanks!

Although technically it's dtype=X.dtype.type.
Again, maybe check_array should support dtype=X.dtype?
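
(A small sketch of that pattern, with illustrative values; the keyword arguments are assumptions, not a quote of the PR:)

import numpy as np
from sklearn.utils import check_array

X32 = np.asfortranarray(np.random.rand(10, 3), dtype=np.float32)
y32 = np.random.rand(10).astype(np.float32)

# the first dtype in the list is the default; float32 input is kept as float32
X = check_array(X32, dtype=[np.float64, np.float32], order='F', copy=False)
# make the other arrays follow X's dtype
y = check_array(y32, dtype=X.dtype.type, order='C', copy=False, ensure_2d=False)
print(X.dtype, y.dtype)   # float32 float32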

@agramfort


Member

agramfort commented Jul 9, 2016

The PR is not huge, so yes, add the sparse case.

@MechCoder


Member

MechCoder commented Jul 9, 2016

@yenchenlin The Travis failure is not spurious. You have references to ATL_sasum and ATL_saxpy, which are defined in "atlas_refalias1.h" as ATL_srefasum and ATL_srefaxpy. You need to add ATL_srefasum.h and ATL_srefaxpy.h; you can get them from https://searchcode.com/codesearch/view/86268628/ and https://searchcode.com/codesearch/view/86268548/. I added them, rebuilt the package, and it works.

@MechCoder


Member

MechCoder commented Jul 9, 2016

Please add those and ping back. Thanks!

import warnings
ctypedef np.float64_t DOUBLE
ctypedef np.uint32_t UINT32_t
ctypedef floating (*DOT)(int N, floating *X, int incX, floating *Y,


@jnothman

jnothman Jul 10, 2016

Member

Nice; looks like you're getting the hang of this!

cdef AXPY axpy
cdef ASUM asum
if floating is float:


@jnothman

jnothman Jul 10, 2016

Member

I assume you get the idea that Cython compiles this as if there were no if here; the only if happens when detecting the type passed in from Python. Then we hope that the C optimiser recognises these as variables that are constant locally, being assigned values that are constant globally (at link time), and hence compiles this all to be identical to the former version...
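
(A minimal illustration of that compile-time behaviour, separate from the PR's code:)

from cython cimport floating

def which_specialization(floating[:] x):
    # Cython emits one C function per fused type; inside each one
    # "floating is float" is a constant, so the untaken branch disappears
    if floating is float:
        return "float32 specialization"
    else:
        return "float64 specialization"

Passing a float32 array returns the first string and a float64 array the second, with the dispatch decided from the buffer dtype at call time.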

@yenchenlin


Contributor

yenchenlin commented Aug 24, 2016

@jnothman done!

@jnothman


Member

jnothman commented Aug 24, 2016

Have you forgotten to push those last changes? whats_new does not appear to be updated, nor does the warning.

@@ -474,7 +474,8 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
warnings.warn('Objective did not converge.' +
' You might want' +
' to increase the number of iterations.' +
' Fitting data with alpha near zero, e.g., 1e-8,' +
' Fitting float32 data with alpha near zero,' +


@jnothman

jnothman Aug 24, 2016

Member

But this is only relevant if the data is float32, no?


@yenchenlin

yenchenlin Aug 24, 2016

Contributor

Actually, I think that with a really small alpha, e.g., 1e-20, even float64 data may not converge.


@jnothman

jnothman Aug 24, 2016

Member

Sure, so make the warning as relevant and useful as possible to a user who triggers it.


@yenchenlin

yenchenlin Aug 24, 2016

Contributor

So is simply removing the "float32" enough 😛?


@jnothman

jnothman Aug 24, 2016

Member

Not really, because alpha=1e-8 isn't ordinarily too small for normalized float64. Either remove the reference to the alpha value or check appropriate conditions for the message; then it will be a much more meaningful message. Also, "alpha near zero" would usually be phrased "very small alpha".


@yenchenlin

yenchenlin Aug 25, 2016

Contributor

I have to admit that I'm not very sure about all the factors that can cause convergence issues, and thus don't dare to pick a specific reference value.

Or we could remove the reference to the alpha value?


@jnothman

jnothman Aug 25, 2016

Member

Remove any specific value. Just say "near zero".



@yenchenlin

yenchenlin Aug 25, 2016

Contributor

done!

yenchenlin added some commits Aug 24, 2016

@@ -470,7 +473,10 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
if dual_gap_ > eps_:
warnings.warn('Objective did not converge.' +
' You might want' +
' to increase the number of iterations',
' to increase the number of iterations.' +
' Fitting data with alpha near zero,' +


@MechCoder

MechCoder Aug 24, 2016

Member

Did we not agree to add this message only when alpha is less than some heuristic value?


@yenchenlin

yenchenlin Aug 25, 2016

Contributor

It seems hard to determine a heuristic value 😭

There are too many factors.

@MechCoder


Member

MechCoder commented Aug 24, 2016

Your benchmarks in the description of the pull request suggest non-trivial speed gains. Do the speed gains also still hold?

@yenchenlin


Contributor

yenchenlin commented Aug 25, 2016

@MechCoder yes!

  • float32 (benchmark screenshot)
  • float64 (benchmark screenshot)
@MechCoder


Member

MechCoder commented Aug 25, 2016

Awesome! Merging with master and thanks a lot for your perseverance.

@MechCoder MechCoder merged commit 084ef97 into scikit-learn:master Aug 25, 2016

3 checks passed

ci/circleci: Your tests passed on CircleCI!
continuous-integration/appveyor/pr: AppVeyor build succeeded
continuous-integration/travis-ci/pr: The Travis CI build passed
@yenchenlin


Contributor

yenchenlin commented Aug 25, 2016

😭😭 😭

🍻🍻🍻

@MechCoder


Member

MechCoder commented Aug 25, 2016

It would be worth adding a note to src/cblas/README.txt explaining what changes have to be made to call cblas functions internally. Maybe @fabianp can do that?

@jnothman


Member

jnothman commented Aug 26, 2016

Hurrah! Thanks @fabianp for rescuing this. And to @MechCoder for inviting that saviour. And to @yenchenlin for winning.

@MechCoder


Member

MechCoder commented Aug 26, 2016

And to you the dark knight? :p

@jnothman


Member

jnothman commented Sep 6, 2016

Just a heads up that I'm a little concerned that these changes to Lasso changed its behaviour (for float64 data). It seems to have resulted in a test failure at #6717 (comment). For the dummy data I've tried so far, behaviour isn't changed, so this needs more verification.

TomDLT added a commit to TomDLT/scikit-learn that referenced this pull request Oct 3, 2016

[MRG+2] ElasticNet and Lasso now support float32 dtype input (scikit-learn#6913)

ElasticNet and Lasso no longer implicitly convert float32 dtype input to float64 internally.

* Make helper functions in cd use fused types

* Import cblas float functions

* Make enet_coordinate_descent support fused types

* Make dense case work

* Refactor format

* Remove redundant change

* Add cblas files

* Avoid redundant code

* Remove redundant c files and import

* Recover unnecessary change

* Update comment

* Make coef_ type consistent

* Test float32 input

* Add user warning when fitting float32 data with small alpha

* Fix bug

* Change variable to floating type

* Make cd sparse support fused types

* Make CD support fused types when data is sparse

* Add referenced src files

* Avoid duplicated code

* Avoid type casting

* Fix indentation in test

* Avoid type casting in sparse implementation

* Fix indentation

* Fix duplicated initialization code

* Follow PEP8

* Raise tmp precision to double

* Add 64 bit computer check

* Fix test

* Add constraint

* PEP 8

* Make saxpy have the same structure as daxpy

Hopefully this fixes the problems outlined in PR scikit-learn#6913

* Remove wrong hardware test

* Remove dsdot

* Remove redundant asarray

* Add test for fit_intercept

* Make _preprocess_data support other dtypes

* Add concrete value

* Workaround

* Fix error msg

* Move declaration

* Remove redundant comment

* Add tests

* Test normalize

* Delete warning

* Fix comment

* Add error msg

* Add error msg

* Add what's new

* Fix error msg