Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX accept uint8 X in hist-gbdt.predict #18410

Merged
merged 3 commits into from Sep 26, 2020

Conversation

NicolasHug
Copy link
Member

Fixes #18408

is_binned = getattr(self, '_in_fit', False)
dtype = [X_BINNED_DTYPE] if is_binned else [X_DTYPE]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in master, X would still be of dtype X_BINNED_DTYPE (pass-through) but is_binned would be False so we'd call _predict_from_numeric_data which can only accept X_DTYPE, hence the failure

Copy link
Member

@thomasjpfan thomasjpfan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

rng = np.random.RandomState(0)

X = rng.randint(0, 100, size=(10, 2)).astype(np.uint8)
y = rng.randint(0, 2, size=10).astype(np.uint8)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does y need to be casted here?

Suggested change
y = rng.randint(0, 2, size=10).astype(np.uint8)
y = rng.randint(0, 2, size=10)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not what was causing the failure on master, but can't hurt to have it anyway?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, lets keep it.

Copy link
Member

@lorentzenchr lorentzenchr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@NicolasHug
Copy link
Member Author

wanna merge it @lorentzenchr ;) ?

@lorentzenchr lorentzenchr merged commit df61e9e into scikit-learn:master Sep 26, 2020
5 checks passed
@lorentzenchr lorentzenchr changed the title [MRG] Fix accept uint8 X in hist-gbdt.predict FIX accept uint8 X in hist-gbdt.predict Sep 26, 2020
amrcode pushed a commit to amrcode/scikit-learn that referenced this pull request Oct 19, 2020
* fixed dtype issue

* whatsnew

* dont pass lists
jayzed82 pushed a commit to jayzed82/scikit-learn that referenced this pull request Oct 22, 2020
* fixed dtype issue

* whatsnew

* dont pass lists
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Data type mismatch problem when calling HistGradientBoostingClassifier.predict()
4 participants