Skip to content

MNT support cross 32bit/64bit pickles for HGBT #28074

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jan 15, 2024

Conversation

lorentzenchr
Copy link
Member

Reference Issues/PRs

Fixes #27952.

What does this implement/fix? Explain your changes.

This PR enables to fit and save (pickle dump) an HGBT model on a system with one bitness (e.g. 64 bit) and load and apply the model on another system with different bitness (e.g. 32 bit).

The crucial point are TreePredictor.nodes, an ndarray of PREDICTOR_RECORD_DTYPE. The field feature_idx is of dtype np.intp which is platform dependent.

Any other comments?

A common test for this would be nice.

Copy link

github-actions bot commented Jan 7, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 5e59d7b. Link to the linter CI: here

@ogrisel
Copy link
Member

ogrisel commented Jan 8, 2024

Shall we had a non-regression test for this?

One way would be to commit a few small pickle test files:

  • one HGBDT model fit on a 32-bit host;
  • one HGBDT model fit on a 64-bit host.

The generation of those models should be scripted (to allow reproducibility, e.g. to regenerate those files in case we refactor the internals of those estimators in the future) but the result would be committed in the sklearn.ensemble/_hist_gradient_boostin_tests/ folder or similar.

Then we would write a test to check that loading those pikcles and calling .predict (or even .set_params(warm_start=True, max_iter=est.n_iter_ + 1).fit()) works on all CI hosts (both 32-bit such as pyodide or 64-bit).

@betatim
Copy link
Member

betatim commented Jan 8, 2024

Why does this work? In the sense that if you are on a platform that does not support int64, I'd expect unpickling to fail. Because what type does the unpickling create when it reads the file if the file says "this is a int64" and there isn't one? Can someone explain where my thinking is going wrong?

Agree that having a test would be good, though we'd need to switch platform in between CI runs no? Or maybe we can include a pickle generated on a 32 and a 64 bit platform in the "static" data available to tests? That way you'd avoid having to re-generate it (and switch platforms) at the cost of having to store a binary file in the repo.

edit: Olivier just suggested this :D

@ogrisel
Copy link
Member

ogrisel commented Jan 8, 2024

In the sense that if you are on a platform that does not support int64.

32-bit platforms support loading int64 data, it's not the problem. The problem is that intp is an alias for int64 on 64 bit platform while it is an alias for int32 on 32-bit platforms. So Cython code that uses intp need to be fed platform-specific datastructures as inputs, hence the need for a cast in the Python code before calling the Cython code.

Copy link
Member

@ogrisel ogrisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Has someone manually checked that this fixes the problem with pyodide? If so +1 for merging on my side (with a changelog entry targeting either 1.4.0 or 1.4.1 as this is a bugfix).

I as commented above, I think it would be great to have non-regression tests for this but as writing such tests in a maintainable way might be more significantly more complex than the fix itself, I am fine with merging without tests for a start and maybe later design a common test for cross-bitness model serialization later.

@betatim
Copy link
Member

betatim commented Jan 8, 2024

Thanks for the explanation!

@lorentzenchr
Copy link
Member Author

lorentzenchr commented Jan 8, 2024

I also prefer to merge without a test. Creating a pickle file once, checking it in via git and using it in a test does also not work as we don't guarantee forward compatibility of saved models.

One take away for me is that it shows a design flaw that we have: authorative Python references tell you to validate inside __init__, but we are used to do that only inside fit. I guess that's one of the reasons (habit) to not validate inside TreePredictor.__init__.

Note that one alternative, a bit orthogonal to this PR, would be to use platform independent dtypes only, i.e. int32 instead of intp (I guess int32 would be enough by far!).

@ogrisel
Copy link
Member

ogrisel commented Jan 9, 2024

Note that one alternative, a bit orthogonal to this PR, would be to use platform independent dtypes only, i.e. int32 instead of intp (I guess int32 would be enough by far!).

Indeed, along with an inline comment to explain why we hardcode the use of int32 in such code.

@lesteve
Copy link
Member

lesteve commented Jan 9, 2024

Note that we added tests for the trees in #21552. Something similar could be done here.

About common tests, I don't think this is so easy for 32bit/64bit but it could be doable for little-endian vs big-endian, which could surface some of these issues, I had an old branch about this but I would need to revive it https://github.com/lesteve/scikit-learn/tree/test-common-cross-endianness-pickle ...

@ogrisel ogrisel added this to the 1.4 milestone Jan 9, 2024
@ogrisel
Copy link
Member

ogrisel commented Jan 9, 2024

In particular:

https://github.com/scikit-learn/scikit-learn/pull/21552/files#diff-e0c534f299457278700ee68add18acf0d4b2a1dd9391e1bcc49f3fe1b472fe67R2278-R2297

@lorentzenchr I can help to push such a non-regression test to your PR if you don't have the time to do it yourself.

@lorentzenchr
Copy link
Member Author

@lorentzenchr I can help to push such a non-regression test to your PR if you don't have the time to do it yourself.

I won't have time for it. I would prefer a common test for it - if possible - and merge here without dedicated test.

@ogrisel
Copy link
Member

ogrisel commented Jan 9, 2024

I won't have time for it. I would prefer a common test for it - if possible - and merge here without dedicated test.

It's hard to write a common test that does not involve the maintenance overhead of storing and managing platform specific pickle files somewhere.

The strategy used by @lesteve above on the other-hand does not require storing any pickle file anywhere but requires an estimator-specific monkeypatch during the test execution. The test itself is complex to write but significantly lower maintenance.

@lorentzenchr
Copy link
Member Author

@ogrisel I'd highly appreciate your help with such a test.

@lesteve
Copy link
Member

lesteve commented Jan 9, 2024

I can try to take a look at adding a test, since this fix is useful in a Pyodide context!

@lorentzenchr out of interest did you work on this, because you had a use case where the fix was needed, or mostly because you were curious?

@lorentzenchr
Copy link
Member Author

Just curious and trying to remove bugs that matter.

@lesteve
Copy link
Member

lesteve commented Jan 12, 2024

Here is what I have done:

  • added a test strongly inspired from the tree one in Support cross 32bit/64bit pickles for decision tree #21552
  • the test showed that the fix was not working, the reason is that __init__ is not called when loading a pickle I think. I put the dtype conversion in __setstate__ which is the proper place
  • tweaked changelog to target 1.4, not sure what the status of the release is sorry ... feel free to change this

@ogrisel
Copy link
Member

ogrisel commented Jan 12, 2024

tweaked changelog to target 1.4, not sure what the status of the release is sorry ... feel free to change this

I think we should target 1.4.0. I will do a review today.

Copy link
Member

@ogrisel ogrisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM (just a few suggestions to make the comments more explicit for people not to familiar with the interaction of numpy dtypes and platform bitness).

lesteve and others added 2 commits January 12, 2024 11:51
…oosting.py

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
…oosting.py

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
lesteve and others added 2 commits January 12, 2024 11:54
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Copy link
Member

@jeremiedbb jeremiedbb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Let's include it in 1.4.0. Thanks

Comment on lines +138 to +142
# TODO: consider always using platform agnostic dtypes for fitted
# estimator attributes. For this particular estimator, this would
# mean replacing the intp field of PREDICTOR_RECORD_DTYPE by an int32
# field. Ideally this should be done consistently throughout
# scikit-learn along with a common test.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

@jeremiedbb jeremiedbb merged commit 7f131e0 into scikit-learn:main Jan 15, 2024
jeremiedbb pushed a commit to jeremiedbb/scikit-learn that referenced this pull request Jan 17, 2024
Co-authored-by: Loïc Estève <loic.esteve@ymail.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
jeremiedbb pushed a commit that referenced this pull request Jan 17, 2024
Co-authored-by: Loïc Estève <loic.esteve@ymail.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
@ogrisel ogrisel deleted the mnt_32_64_pickle_hgbt branch January 30, 2024 10:52
glemaitre pushed a commit to glemaitre/scikit-learn that referenced this pull request Feb 10, 2024
Co-authored-by: Loïc Estève <loic.esteve@ymail.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

HistGradientBoosting pickle portability between 64bit and 32bit arch
5 participants