Skip to content

Commit

Permalink
Merge pull request #337 from rasbt/float32
Browse files Browse the repository at this point in the history
Allow lower precision array types in plot_decision_regions function
  • Loading branch information
rasbt committed Mar 11, 2018
2 parents eb0f665 + 2153211 commit f3850ba
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/sources/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ The CHANGELOG for the current development version is available at
##### Bug Fixes

- Fixed issue when class labels were provided to the `EnsembleVoteClassifier` when `refit` was set to `false`. ([#322](https://github.com/rasbt/mlxtend/issues/322))
- Allow arrays with 16-bit and 32-bit precision in `plot_decision_regions` function. ([#337](https://github.com/rasbt/mlxtend/issues/337))



Expand Down
2 changes: 1 addition & 1 deletion mlxtend/plotting/decision_regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def plot_decision_regions(X, y, clf,
if dim > 2:
for feature_idx in filler_feature_values:
X_predict[:, feature_idx] = filler_feature_values[feature_idx]
Z = clf.predict(X_predict)
Z = clf.predict(X_predict.astype(X.dtype))
Z = Z.reshape(xx.shape)
# Plot decisoin region
ax.contourf(xx, yy, Z,
Expand Down
4 changes: 2 additions & 2 deletions mlxtend/utils/checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ def check_Xy(X, y, y_int=True):
if not isinstance(y, np.ndarray):
raise ValueError('y must be a NumPy array. Found %s' % type(y))

if y_int and not np.issubdtype(y.dtype, np.integer):
if 'int' not in str(y.dtype):
raise ValueError('y must be an integer array. Found %s. '
'Try passing the array as y.astype(np.integer)'
% y.dtype)

if X.dtype not in (np.float, np.int):
if not ('float' in str(X.dtype) or 'int' in str(X.dtype)):
raise ValueError('X must be an integer or float array. Found %s.'
% X.dtype)

Expand Down
8 changes: 8 additions & 0 deletions mlxtend/utils/tests/test_checking_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ def test_invalid_type_X():
y)


def test_float16_X():
check_Xy(X.astype(np.float16), y)


def test_float16_y():
check_Xy(X, y.astype(np.int16))


def test_invalid_type_y():
expect = "y must be a NumPy array. Found <class 'list'>"
if (sys.version_info < (3, 0)):
Expand Down

0 comments on commit f3850ba

Please sign in to comment.