[MRG+2] Add float32 support for Linear Discriminant Analysis #13273
Changes from 7 commits
File: sklearn/discriminant_analysis.py

@@ -19,7 +19,7 @@
 from .linear_model.base import LinearClassifierMixin
 from .covariance import ledoit_wolf, empirical_covariance, shrunk_covariance
 from .utils.multiclass import unique_labels
-from .utils import check_array, check_X_y
+from .utils import check_array, check_X_y, as_float_array
 from .utils.validation import check_is_fitted
 from .utils.multiclass import check_classification_targets
 from .preprocessing import StandardScaler
@@ -427,7 +427,8 @@ def fit(self, X, y):
         Target values.
         """
         # FIXME: Future warning to be removed in 0.23
-        X, y = check_X_y(X, y, ensure_min_samples=2, estimator=self)
+        X = as_float_array(X)
+        X, y = check_X_y(X, y, ensure_min_samples=2, estimator=self, dtype=[np.float64, np.float32])
         self.classes_ = unique_labels(y)
         n_samples, _ = X.shape
         n_classes = len(self.classes_)

Review thread on the added X = as_float_array(X) line:

Comment: This line is necessary to get coefficients of type float32 output when X is of type int32. Removing this line gives coefficients of type float64.

Reply: This is strange, because the next check_X_y call should already handle the dtype.

Reply: Oh, I see why. So X will be converted to float64, and you expect it to be converted to float32 because X is of type int32. I would say that there is no real reason for that. @massich, what is the mechanism in the other estimators that you modified?

Reply: There was no mechanism set. I have the vague idea that we talked about it long ago, but I can't recall that we agreed on anything. To me it makes sense that someone who willingly creates the data in 32 bits wants to keep it in 32 bits. But if such an individual is out there, she/he will scream at us, or would have done so already. So I have no objection to defaulting to float64.

Reply: An int32 is a "very big" int, hence it makes sense to convert it to float64. People who want to save memory with ints use int16. Indeed, an int can exactly represent a larger range of integers than a float of the same memory cost.
@@ -485,9 +486,10 @@ def fit(self, X, y):
             raise ValueError("unknown solver {} (valid solvers are 'svd', "
                              "'lsqr', and 'eigen').".format(self.solver))
         if self.classes_.size == 2:  # treat binary case as a special case
-            self.coef_ = np.array(self.coef_[1, :] - self.coef_[0, :], ndmin=2)
+            self.coef_ = np.array(self.coef_[1, :] - self.coef_[0, :], ndmin=2,
+                                  dtype=X.dtype)
             self.intercept_ = np.array(self.intercept_[1] - self.intercept_[0],
-                                       ndmin=1)
+                                       ndmin=1, dtype=X.dtype)
         return self

     def transform(self, X):
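For readers following the dtype discussion in the thread above, here is a minimal sketch (not part of the PR, assuming the check_X_y helper importable from sklearn.utils) of how passing a list of dtypes behaves: an input whose dtype is already in the list is kept as-is, and anything else is cast to the first entry, which is why int32 ends up as float64.

    import numpy as np
    from sklearn.utils import check_X_y

    y = np.array([0, 0, 1, 1])
    for dt in (np.float64, np.float32, np.int32):
        X = np.zeros((4, 2), dtype=dt)
        # dtype already in the list -> preserved; otherwise -> cast to
        # the first entry of the list (np.float64).
        Xc, _ = check_X_y(X, y, dtype=[np.float64, np.float32])
        print(dt.__name__, '->', Xc.dtype)
    # float64 -> float64
    # float32 -> float32
    # int32   -> float64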
File: sklearn/tests/test_discriminant_analysis.py
@@ -7,6 +7,7 @@
 from sklearn.utils.testing import (assert_array_equal, assert_no_warnings,
                                    assert_warns_message)
 from sklearn.utils.testing import assert_array_almost_equal
+from sklearn.utils.testing import assert_allclose
 from sklearn.utils.testing import assert_equal
 from sklearn.utils.testing import assert_almost_equal
 from sklearn.utils.testing import assert_raises
@@ -296,6 +297,29 @@ def test_lda_dimension_warning(n_classes, n_features):
     assert_warns_message(FutureWarning, future_msg, lda.fit, X, y)


+@pytest.mark.parametrize("data_type, expected_type", [
+    (np.float32, np.float32), (np.float64, np.float64), (np.int32, np.float32),
+    (np.int64, np.float64)])
+def test_lda_dtype_match(data_type, expected_type):
+    for (solver, shrinkage) in solver_shrinkage:
+        clf = LinearDiscriminantAnalysis(solver=solver, shrinkage=shrinkage)
+        clf.fit(X.astype(data_type), y.astype(data_type))
+        assert clf.coef_.dtype == expected_type
+
+
+def test_lda_numeric_consistency_float32_float64():
+    for (solver, shrinkage) in solver_shrinkage:
+        clf_32 = LinearDiscriminantAnalysis(solver=solver, shrinkage=shrinkage)
+        clf_32.fit(X.astype(np.float32), y.astype(np.float32))
+        clf_64 = LinearDiscriminantAnalysis(solver=solver, shrinkage=shrinkage)
+        clf_64.fit(X.astype(np.float64), y.astype(np.float64))
+
+        # Check value consistency between types
+        rtol = 1e-6
+        assert_allclose(clf_32.coef_, clf_64.coef_.astype(np.float32),
+                        rtol=rtol)
+
+
 def test_qda():
     # QDA classification.
     # This checks that QDA implements fit and predict and returns

Review thread on the (np.int32, np.float32) case:

Comment: Convention for int32 is casting to float64.

Review thread on the parametrize formatting:

Comment: I find this one more readable:

    @pytest.mark.parametrize("data_type, expected_type", [
        (np.float32, np.float32),
        (np.float64, np.float64),
        (np.int32, np.float64),
        (np.int64, np.float64)
    ])

Review thread on the assert_allclose call:

Comment: Do not use the explicit astype; comparing directly is enough: assert_allclose(clf_32.coef_, clf_64.coef_, rtol=rtol)
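As an aside on that last comment, a small sketch (not part of the PR, assuming assert_allclose here is the one sklearn.utils.testing re-exports from numpy.testing): it compares arrays of different dtypes directly, upcasting as needed, so the float32 result can be checked against the float64 one without a manual cast.

    import numpy as np
    from numpy.testing import assert_allclose

    a32 = np.array([1.0, 2.0], dtype=np.float32)
    a64 = a32.astype(np.float64) + 1e-7   # small float64 perturbation

    # Both operands are upcast to a common dtype before comparing, so no
    # explicit astype is needed; rtol bounds the relative difference.
    assert_allclose(a32, a64, rtol=1e-6)  # passes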