
[MRG+2] LogisticRegression convert to float64 (newton-cg) #8835

Merged
merged 26 commits into scikit-learn:master from massich:is/8769 on Jun 7, 2017

Conversation

@massich
Contributor

@massich massich commented May 5, 2017

Reference Issue

Fixes #8769

What does this implement/fix? Explain your changes.

Prevents logistic regression from aggressively casting the data to np.float64 when np.float32 is supplied.

Any other comments?

(only for the newton-cg case)
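For illustration, a minimal sketch of the behaviour this change targets (the dataset and settings below are just an example, not part of the PR):

import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)
X_32 = X.astype(np.float32)

# With this change, the newton-cg solver keeps float32 input in float32
# instead of silently upcasting it to float64.
lr = LogisticRegression(solver='newton-cg', multi_class='ovr')
lr.fit(X_32, y)
print(lr.coef_.dtype)   # float32 (was float64 before this PR)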

@massich
Contributor Author

@massich massich commented May 5, 2017

@GaelVaroquaux Actually, fixing self.coef_ was straightforward. Where do you want to go from here?

@massich massich changed the title Is/8769 [WIP] LogisticRegression convert to float64 May 5, 2017
@@ -1281,9 +1287,9 @@ def fit(self, X, y, sample_weight=None):
        self.n_iter_ = np.asarray(n_iter_, dtype=np.int32)[:, 0]

        if self.multi_class == 'multinomial':
            self.coef_ = fold_coefs_[0][0]
            self.coef_ = fold_coefs_[0][0].astype(np.float32)


@massich

massich May 5, 2017
Author Contributor

My bad, it should be _dtype.

@glemaitre
Contributor

@glemaitre glemaitre commented May 5, 2017

You can execute the PEP8 check locally:

bash ./build_tools/travis/flake8_diff.sh

That should be useful in the future.

@massich massich force-pushed the massich:is/8769 branch from 310084b to 14450a2 May 19, 2017
@massich massich changed the title [WIP] LogisticRegression convert to float64 [MRG] LogisticRegression convert to float64 May 19, 2017
_dtype = np.float64
if self.solver in ['newton-cg'] \
        and isinstance(X, np.ndarray) and X.dtype in [np.float32]:
    _dtype = np.float32


@GaelVaroquaux

GaelVaroquaux May 29, 2017
Member

check_X_y can take a list of acceptable dtypes as a dtype argument. I think that using this feature would be a better way of writing this code. The code would be something like

if self.solver in ['newton-cg']:
    _dtype = [np.float64, np.float32]
else:
    _dtype = np.float64

self.coef_ = np.asarray(fold_coefs_)
self.coef_ = np.asarray(fold_coefs_, dtype=_dtype)


@GaelVaroquaux

GaelVaroquaux May 29, 2017
Member

Is the conversion necessary here? In other words, if we get the code right, doesn't coefs_ get returned in the right dtype?

@GaelVaroquaux
Member

@GaelVaroquaux GaelVaroquaux commented May 29, 2017

I suspect that the problem isn't really solved: if you look a bit further in the code, you will see that inside 'logistic_regression_path', check_X_y is called again with the np.float64 dtype. And there might be other instances of this problem.

@massich massich changed the title [MRG] LogisticRegression convert to float64 [WIP] LogisticRegression convert to float64 May 30, 2017
Contributor Author

@massich massich left a comment

Indeed, logistic_regression_path has a check_array with np.float64 as dtype. However, logistic_regression_path is called with check_input=False, so X.dtype remains np.float32 (see here).

Still, w0 starts as an empty list and ends up being np.float64 (see here).
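
As a side note, the w0 issue boils down to NumPy defaults: allocations are float64 unless a dtype is passed explicitly, so deriving the dtype from X is what keeps everything in float32. A minimal illustration (variable names are only for illustration):

import numpy as np

X = np.random.rand(5, 3).astype(np.float32)

w0_default = np.zeros(X.shape[1] + 1)                  # dtype defaults to float64
w0_matched = np.zeros(X.shape[1] + 1, dtype=X.dtype)   # stays float32, matching X

print(w0_default.dtype, w0_matched.dtype)              # float64 float32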

Member

@raghavrv raghavrv left a comment

Thanks for the PR!

@@ -1203,7 +1205,12 @@ def fit(self, X, y, sample_weight=None):
raise ValueError("Tolerance for stopping criteria must be "
"positive; got (tol=%r)" % self.tol)

X, y = check_X_y(X, y, accept_sparse='csr', dtype=np.float64,
if self.solver in ['newton-cg']:
    _dtype = [np.float64, np.float32]


@raghavrv

raghavrv Jun 2, 2017
Member

Sorry if I am missing something, but why?


@massich

massich Jun 2, 2017
Author Contributor

The idea is that previously check_X_y was converting X and y into np.float64. This is fine if the user passes a list as X, but if the user willingly passes np.float32, converting it to np.float64 penalizes them in both memory and speed.

Therefore, we are trying to keep the data in np.float32 when the user provides it in that type.
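
To make both points concrete, a small illustration (array contents and sizes are arbitrary): check_X_y with a list of acceptable dtypes preserves float32 and only upcasts inputs that match none of the listed dtypes, and keeping float32 halves the memory footprint:

import numpy as np
from sklearn.utils import check_X_y

X_32 = np.random.rand(1000, 100).astype(np.float32)
y = np.random.randint(0, 2, size=1000)

# float32 is in the accepted list, so the data is left untouched
X_checked, y_checked = check_X_y(X_32, y, dtype=[np.float64, np.float32])
print(X_checked.dtype)        # float32

# a plain Python list matches none of the accepted dtypes, so it is
# converted to the first entry of the list, np.float64
X_list, y_list = check_X_y([[0., 1.], [1., 0.]], [0, 1], dtype=[np.float64, np.float32])
print(X_list.dtype)           # float64

# keeping float32 also halves the memory used by the design matrix
print(X_32.nbytes, X_32.astype(np.float64).nbytes)    # 400000 vs 800000 bytes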


@GaelVaroquaux

GaelVaroquaux Jun 2, 2017
Member

The fact that @raghavrv asks a question tells us that a short comment explaining the logic would probably be useful here.


@massich

massich Jun 2, 2017
Author Contributor

I think that @raghavrv was more concerned about the fact that we were passing a list rather than forcing one dtype or the other. Once we checked that check_X_y was taking care of it, he was OK with it.

@raghavrv any comments?


for solver in ['newton-cg']:
    for multi_class in ['ovr', 'multinomial']:


@raghavrv

raghavrv Jun 2, 2017
Member

can you remove this new line


def test_dtype_missmatch_to_profile():
    # Test that np.float32 input data is not cast to np.float64 when possible


@raghavrv

raghavrv Jun 2, 2017
Member

and this newline too

@@ -41,12 +41,17 @@ def compute_class_weight(class_weight, classes, y):
# Import error caused by circular imports.
from ..preprocessing import LabelEncoder

if y.dtype == np.float32:
    _dtype = np.float32


@raghavrv

raghavrv Jun 2, 2017
Member

Why not _dtype = y.dtype?

Is it so that you can have y.dtype be int while the weights are float?


# Check accuracy consistency
lr_64 = LogisticRegression(solver=solver, multi_class=multi_class)
lr_64.fit(X, Y1)


@raghavrv

raghavrv Jun 2, 2017
Member

Can you ensure (maybe using astype?) that X and Y1 are float64 before this test? (If it is changed in the future, this test will still pass.)

def test_dtype_match():
    # Test that np.float32 input data is not cast to np.float64 when possible

    X_ = np.array(X).astype(np.float32)


@raghavrv

raghavrv Jun 2, 2017
Member

X_32 = ... astype(32)
X_64 = ... astype(64)

assert_almost_equal(lr_32.coef_, lr_64.coef_.astype(np.float32))


def test_dtype_missmatch_to_profile():


@raghavrv

raghavrv Jun 2, 2017
Member

This test can be removed...

@@ -608,10 +610,10 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
    # and check length
    # Otherwise set them to 1 for all examples
    if sample_weight is not None:
        sample_weight = np.array(sample_weight, dtype=np.float64, order='C')
        sample_weight = np.array(sample_weight, dtype=X.dtype, order='C')
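
For illustration, the effect of the changed line above: building the weights with dtype=X.dtype keeps them in the same precision as the data instead of hard-coding float64 (values below are arbitrary):

import numpy as np

X = np.random.rand(4, 2).astype(np.float32)
sample_weight = [1.0, 2.0, 1.0, 0.5]

# weights take X's dtype instead of a hard-coded float64
sample_weight = np.array(sample_weight, dtype=X.dtype, order='C')
print(sample_weight.dtype)   # float32, matching X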


@raghavrv

raghavrv Jun 2, 2017
Member

Should it be y.dtype?

cc: @agramfort


@massich

massich Jun 2, 2017
Author Contributor

We were discussing with @glemaitre (and @GaelVaroquaux) whether to force X.dtype and y.dtype to be the same.


@GaelVaroquaux

GaelVaroquaux Jun 2, 2017
Member

Yes, I think the idea should be that the dtype of X conditions the dtype of the computation.

We should have an RFC about this, and include it in the docs.


@massich

massich Jun 2, 2017
Author Contributor

see #8976


# Check accuracy consistency
lr_64 = LogisticRegression(solver=solver, multi_class=multi_class)
lr_64.fit(X_64, y_64)


@TomDLT

TomDLT Jun 6, 2017
Member

Please add:

assert_equal(lr_64.coef_.dtype, X_64.dtype)

Otherwise this test passes even when we convert everything to 32 bits.
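
Putting the reviewers' suggestions together, the test could look roughly like the sketch below (not the exact test merged in this PR; data, solver settings and tolerance are illustrative):

import numpy as np
from numpy.testing import assert_almost_equal

from sklearn.linear_model import LogisticRegression


def test_dtype_match_sketch():
    rng = np.random.RandomState(0)
    X_64 = rng.rand(40, 3).astype(np.float64)
    X_32 = X_64.astype(np.float32)
    y = np.array([0, 1] * 20)

    lr_64 = LogisticRegression(solver='newton-cg').fit(X_64, y)
    lr_32 = LogisticRegression(solver='newton-cg').fit(X_32, y)

    # coef_ keeps the dtype of the input ...
    assert lr_64.coef_.dtype == np.float64
    assert lr_32.coef_.dtype == np.float32

    # ... and both precisions give (nearly) the same solution
    assert_almost_equal(lr_32.coef_, lr_64.coef_.astype(np.float32), decimal=4)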

Member

@TomDLT TomDLT left a comment

LGTM

@agramfort
Member

@agramfort agramfort commented Jun 6, 2017

+1 for MRG if travis is happy

@agramfort agramfort changed the title [MRG] LogisticRegression convert to float64 [MRG+1] LogisticRegression convert to float64 Jun 6, 2017
@TomDLT TomDLT changed the title [MRG+1] LogisticRegression convert to float64 [MRG+2] LogisticRegression convert to float64 Jun 6, 2017
@massich massich changed the title [MRG+2] LogisticRegression convert to float64 [MRG+2] LogisticRegression convert to float64 (newton-cg) Jun 6, 2017
@GaelVaroquaux
Member

@GaelVaroquaux GaelVaroquaux commented Jun 6, 2017

Does anybody have an idea what's wrong with AppVeyor? Cc @ogrisel @lesteve

@GaelVaroquaux
Member

@GaelVaroquaux GaelVaroquaux commented Jun 6, 2017

Before we merge, this warrants a whats_new entry.

@agramfort
Member

@agramfort agramfort commented Jun 6, 2017

AppVeyor is not happy :(

@agramfort
Member

@agramfort agramfort commented Jun 7, 2017

All green, merging.

@agramfort agramfort merged commit 39a4658 into scikit-learn:master Jun 7, 2017
3 of 5 checks passed:
codecov/patch: No report found to compare against
codecov/project: No report found to compare against
ci/circleci: Your tests passed on CircleCI!
continuous-integration/appveyor/pr: AppVeyor build succeeded
continuous-integration/travis-ci/pr: The Travis CI build passed
@GaelVaroquaux
Member

@GaelVaroquaux GaelVaroquaux commented Jun 7, 2017

Whats_new entry would have been good :)

Sundrique added a commit to Sundrique/scikit-learn that referenced this pull request Jun 14, 2017
…rn#8835)

* Add a test to ensure not changing the input's data type

Test that np.float32 input data is not cast to np.float64 when using LR + newton-cg

* [WIP] Force X to remain float32. (self.coef_ remains float64 even if X is not)

* [WIP] ensure self.coef_ same type as X

* keep the np.float32 when multi_class='multinomial'

* Avoid hardcoded type for multinomial

* pass flake8

* Ensure that the results in 32bits are the same as in 64

* Address Gael's comments for multi_class=='ovr'

* Add multi_class=='multinominal' to test

* Add support for multi_class=='multinominal'

* prefer float64 to float32

* Force X and y to have the same type

* Revert "Add support for multi_class=='multinominal'"

This reverts commit 4ac33e8.

* remvert more stuff

* clean up some commmented code

* allow class_weight to take advantage of float32

* Add a test where X.dtype is different of y.dtype

* Address @raghavrv comments

* address the rest of @raghavrv's comments

* Revert class_weight

* Avoid copying if dtype matches

* Address alex comment to the cast from inside _multinomial_loss_grad

* address alex comment

* add sparsity test

* Addressed Tom comment of checking that we keep the 64 aswell
dmohns added a commit to dmohns/scikit-learn that referenced this pull request Aug 7, 2017
dmohns added a commit to dmohns/scikit-learn that referenced this pull request Aug 7, 2017
NelleV added a commit to NelleV/scikit-learn that referenced this pull request Aug 11, 2017
paulha added a commit to paulha/scikit-learn that referenced this pull request Aug 19, 2017
AishwaryaRK added a commit to AishwaryaRK/scikit-learn that referenced this pull request Aug 29, 2017
maskani-moh added a commit to maskani-moh/scikit-learn that referenced this pull request Nov 15, 2017
jwjohnson314 pushed a commit to jwjohnson314/scikit-learn that referenced this pull request Dec 18, 2017