ENH Introduce dtype preservation semantics in DistanceMetric objects. #27006

Merged
merged 10 commits into from Aug 10, 2023

@Micky774 Micky774 commented Aug 3, 2023

Reference Issues/PRs

What does this implement/fix? Explain your changes.

Preserves dtype when computing distances, under the assumption that the precision of the input data implies the preferred precision of the output data. Note that accumulation still largely occurs using float64_t, with some exceptions.
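
To illustrate the semantics described above, here is a minimal NumPy sketch, not the Cython code from this PR. The helper name `euclidean_pairwise` is hypothetical, and it assumes every accumulation happens in float64, whereas the PR notes some exceptions.

```python
import numpy as np


def euclidean_pairwise(X, Y):
    """Hypothetical helper: pairwise Euclidean distances in the input dtype."""
    X = np.asarray(X)
    Y = np.asarray(Y)
    if X.dtype != Y.dtype:
        raise ValueError("X and Y must share the same dtype")

    # The output buffer uses the input dtype (float32 in, float32 out).
    out = np.empty((X.shape[0], Y.shape[0]), dtype=X.dtype)
    Y64 = Y.astype(np.float64)
    for i in range(X.shape[0]):
        # Assumption: differences and sums are accumulated in float64.
        diff = X[i].astype(np.float64) - Y64
        acc = np.sqrt((diff * diff).sum(axis=1))
        # The cast down to the input dtype happens only on store.
        out[i] = acc
    return out


X32 = np.array([[0.0, 0.0], [3.0, 4.0]], dtype=np.float32)
print(euclidean_pairwise(X32, X32).dtype)
print(euclidean_pairwise(X32.astype(np.float64), X32.astype(np.float64)).dtype)
```

The result dtype follows the input dtype in both calls, while the intermediate sums stay in double precision.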

Any other comments?

Current benchmarks (generated here) suggest that there is no regression in the dense case (dist), and a 10-25% speedup in the sparse case (dist_csr).

Benchmark Plots (two images, not reproduced here)

Memory profiling indicates a reduction of memory usage in this script from 763MiB to 382MiB.
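
The roughly 2x reduction is what one would expect if the profiled script is dominated by a dense distance matrix that is now stored as float32 instead of float64. That is an assumption, since the script itself is not shown here. A small sketch of the arithmetic, with a hypothetical helper name:

```python
import numpy as np


def distance_matrix_mib(n_rows, n_cols, dtype):
    """Hypothetical helper: size in MiB of a dense (n_rows, n_cols) array."""
    return n_rows * n_cols * np.dtype(dtype).itemsize / 2**20


# Illustrative shape only; the shape used in the profiled script is not known.
n = 10_000
size64 = distance_matrix_mib(n, n, np.float64)
size32 = distance_matrix_mib(n, n, np.float32)
print(size32 / size64)
```

The ratio is 0.5 for any shape, which is consistent with the reported drop from 763MiB to 382MiB.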

cc: @jjerphan @OmarManzoor @thomasjpfan

github-actions bot commented Aug 3, 2023

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: f8b6027. Link to the linter CI: here


@jjerphan jjerphan left a comment

Thank you, @Micky774.

We indeed need to add dtype preservation to Cython implementations. Yet, I do not think we need to change the return type of methods to INPUT_DTYPE_t. One of my comments gives some details.

Resolved review threads on sklearn/metrics/_dist_metrics.pyx.tp and doc/whats_new/v1.4.rst (outdated).

@jjerphan jjerphan left a comment

I was concerned about the loss of precision for PairwiseDistancesReductions, but since the test suite passes I think we can accept one or the other: proceed as you prefer, @Micky774.

Micky774 commented Aug 9, 2023

> I was concerned about the loss of precision for PairwiseDistancesReductions, but since the test suite passes I think we can accept one or the other: proceed as you prefer, @Micky774.

I've reintroduced the dtype-dependent signatures. Do you think this warrants a separate changelog entry considering DistanceMetric objects now fully preserve dtype?

jjerphan commented Aug 9, 2023

I don't think this is necessary, since the DistanceMetric32 objects aren't public.


@OmarManzoor OmarManzoor left a comment

LGTM. Thanks @Micky774

@jjerphan jjerphan enabled auto-merge (squash) August 10, 2023 19:09

jjerphan commented

I merged main again to resolve the unrelated linter failure on the CI.

@jjerphan jjerphan merged commit acf60de into scikit-learn:main Aug 10, 2023
25 checks passed
@Micky774 Micky774 deleted the dm_float32 branch August 10, 2023 22:07
TamaraAtanasoska pushed a commit to TamaraAtanasoska/scikit-learn that referenced this pull request Aug 21, 2023
…s. (scikit-learn#27006)

Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>
REDVM pushed a commit to REDVM/scikit-learn that referenced this pull request Nov 16, 2023
…s. (scikit-learn#27006)

Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>

3 participants