[MRG] Changed implementation of Birch.predict to use pairwise_distances… #16149

Merged: 14 commits, Jan 29, 2020

Conversation

alexshacked
Contributor

@alexshacked alexshacked commented Jan 17, 2020

Fixes #16027

Attempting to reduce the memory footprint of Birch.predict. Please see the solution description in issue #16027 (comment).

Benchmark script to be added soon

@alexshacked
Contributor Author

alexshacked commented Jan 19, 2020

Benchmark script + output
/> python performance_16027.py SCIKIT-LEARN_HOME

performance_16027.txt
perf.log

Below are examples of running the benchmark on several sample sizes.
The initial memory size is caused by loading the samples. The peak memory size is caused by running the algorithm. The delta between peak and initial memory stays small for all sample sizes (roughly 17 to 72 MB), indicating that the chunking algorithm works as expected and that the memory increase is not a function of the sample matrix size.

/>python performance_16027.py /Users/ashacked/dev/python/scikit-learn
1. 2020-01-19 20:05:01.659986  Generating 1000000 samples
2. 2020-01-19 20:05:01.969912  Profiling Birch clustering algorithm on 1000000 samples
3. 2020-01-19 20:05:37.562192  Analyzing profiler results
initial memory[MB]:  129.5
peak memory[MB]:     162.3
final memory[MB]:    137.7
4. 2020-01-19 20:05:37.569448  End. Full profiler output at: /tmp/perf.log

/>python performance_16027.py /Users/ashacked/dev/python/scikit-learn
1. 2020-01-19 20:08:02.074358  Generating 2000000 samples
2. 2020-01-19 20:08:02.704236  Profiling Birch clustering algorithm on 2000000 samples
3. 2020-01-19 20:09:10.495158  Analyzing profiler results
initial memory[MB]:  177.4
peak memory[MB]:     230.6
final memory[MB]:    175.8
4. 2020-01-19 20:09:10.510405  End. Full profiler output at: /tmp/perf.log

/>python performance_16027.py /Users/ashacked/dev/python/scikit-learn
1. 2020-01-19 20:11:16.651208  Generating 4000000 samples
2. 2020-01-19 20:11:17.990072  Profiling Birch clustering algorithm on 4000000 samples
3. 2020-01-19 20:13:39.360781  Analyzing profiler results
initial memory[MB]:  272.4
peak memory[MB]:     343.9
final memory[MB]:    267.3
4. 2020-01-19 20:13:39.388682  End. Full profiler output at: /tmp/perf.log

/>python performance_16027.py /Users/ashacked/dev/python/scikit-learn
1. 2020-01-19 20:14:55.223833  Generating 8000000 samples
2. 2020-01-19 20:14:57.969499  Profiling Birch clustering algorithm on 8000000 samples
3. 2020-01-19 20:19:36.029547  Analyzing profiler results
initial memory[MB]:  426.3
peak memory[MB]:     443.2
final memory[MB]:    416.1
4. 2020-01-19 20:19:36.082528  End. Full profiler output at: /tmp/perf.log
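
The benchmark script itself is only attached, not shown, so the following is not the script used for the runs above. It is a minimal sketch of how initial, peak, and final memory figures of this kind could be collected with the standard-library `tracemalloc` module. The helper name `profile_memory` and the toy workload are hypothetical, and `tracemalloc` only sees allocations made through Python's allocators, so its numbers will not match process-level figures.

```python
import tracemalloc

MB = 1024 * 1024


def profile_memory(setup, work):
    """Return (result, initial_mb, peak_mb, final_mb) for work(setup()).

    Hypothetical helper: `setup` loads the data, `work` runs the algorithm.
    """
    tracemalloc.start()
    try:
        data = setup()
        # Memory held once the input data has been loaded.
        initial, _ = tracemalloc.get_traced_memory()
        result = work(data)
        # Current usage after the run, and the highest usage seen so far.
        final, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return result, initial / MB, peak / MB, final / MB


# Toy workload: the temporary list built inside `work` raises the peak,
# then is freed before the final measurement.
_, initial_mb, peak_mb, final_mb = profile_memory(
    lambda: list(range(200_000)),
    lambda data: sum([2 * x for x in data]),
)
print(f"initial memory[MB]: {initial_mb:.1f}")
print(f"peak memory[MB]:    {peak_mb:.1f}")
print(f"final memory[MB]:   {final_mb:.1f}")
```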

Is there a location in the scikit-learn source code tree where I can put the performance test?

Member

@jeremiedbb jeremiedbb left a comment

Thanks for the PR @alexshacked. I don't think we need to store the performance tests in the codebase. Your reports in the PR discussion should be enough.

@alexshacked alexshacked changed the title [WIP] Changed implementation of Birch.predict to use pairwise_distances… [MRG] Changed implementation of Birch.predict to use pairwise_distances… Jan 20, 2020
Member

@jeremiedbb jeremiedbb left a comment

lgtm. Thanks @alexshacked !

@alexshacked
Contributor Author

alexshacked commented Jan 20, 2020

You're right @jeremiedbb. For large inputs, the increase in memory doesn't go beyond the 1 GB chunk size.

python performance_16027.py /Users/ashacked/dev/python/scikit-learn
1. 2020-01-20 15:15:51.555499  Generating 4000000 samples
2. 2020-01-20 15:15:52.980514  Profiling Birch clustering algorithm on 4000000 samples
3. 2020-01-20 15:18:05.623137  Analyzing profiler results
initial memory[MB]:  267.4
peak memory[MB]:     1262.4
final memory[MB]:    318.3
4. 2020-01-20 15:18:05.625402  End. Full profiler output at: /tmp/perf.log

python performance_16027.py /Users/ashacked/dev/python/scikit-learn
1. 2020-01-20 15:23:11.744886  Generating 12000000 samples
2. 2020-01-20 15:23:15.761165  Profiling Birch clustering algorithm on 12000000 samples
3. 2020-01-20 15:30:25.432834  Analyzing profiler results
initial memory[MB]:  746.0
peak memory[MB]:     1560.2
final memory[MB]:    637.3
4. 2020-01-20 15:30:25.435119  End. Full profiler output at: /tmp/perf.log
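
To illustrate why the increase is capped by the chunk size rather than by the number of samples, here is a simplified NumPy sketch of chunked nearest-center assignment. It is not scikit-learn's implementation: the function name `chunked_nearest_center` and the `working_memory_mb` parameter are hypothetical stand-ins for the library's chunking, which is governed by its `working_memory` setting (1024 MiB by default).

```python
import numpy as np


def chunked_nearest_center(X, centers, working_memory_mb=1024):
    """Index of the nearest center for each row of X, computed in chunks.

    Hypothetical sketch: only one (chunk_rows, n_centers) block of
    distances exists at any time, so temporary memory is bounded by
    `working_memory_mb` no matter how many rows X has.
    """
    n_samples, n_centers = X.shape[0], centers.shape[0]
    row_bytes = 8 * n_centers  # one float64 row of the distance block
    chunk_rows = max(1, int(working_memory_mb * 2**20 // row_bytes))

    # Squared norms of the centers are computed once and reused per chunk.
    center_norms = (centers ** 2).sum(axis=1)

    labels = np.empty(n_samples, dtype=np.intp)
    for start in range(0, n_samples, chunk_rows):
        stop = min(start + chunk_rows, n_samples)
        # ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2; the ||x||^2 term is
        # constant within a row, so it can be dropped for the argmin.
        reduced = -2.0 * (X[start:stop] @ centers.T)
        reduced += center_norms
        labels[start:stop] = reduced.argmin(axis=1)
    return labels


rng = np.random.RandomState(0)
X = rng.rand(200, 3)
centers = rng.rand(5, 3)
# A tiny budget forces several chunks; the labels do not depend on it.
print(chunked_nearest_center(X, centers, working_memory_mb=0.001)[:10])
```

A smaller budget trades a little speed for a lower peak, since more, smaller blocks are processed one after another.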

@jnothman
Member

Should we be passing Y_norm_squared as a parameter for euclidean_distances to avoid recomputing them?
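
For context, the reasoning behind this question can be written out as the usual expansion of the squared Euclidean distance between a sample $x$ and a subcluster center $y$ (the symbols here are illustrative, not taken from the code):

```latex
\[
\lVert x - y \rVert^2 = \lVert x \rVert^2 - 2\, x \cdot y + \lVert y \rVert^2
\]
```

The $\lVert y \rVert^2$ term depends only on the centers, so it can be computed once and reused for every chunk of samples instead of being recomputed each time. When only the nearest center is needed, the $\lVert x \rVert^2$ term is the same for every candidate $y$ and does not affect the argmin.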

@alexshacked
Contributor Author

Right @jnothman. I get it. No need to calculate YY for each chunk.
Done

@@ -579,10 +580,9 @@ def predict(self, X):
"""
X = check_array(X, accept_sparse='csr')
self._check_fit(X)
reduced_distance = safe_sparse_dot(X, self.subcluster_centers_.T)
reduced_distance *= -2
reduced_distance += self._subcluster_norms
Member

I meant that we could pass these sub cluster norms into pairwise_distances_argmin

Contributor Author

@alexshacked alexshacked Jan 21, 2020

Something like this?

kwargs = {'Y_norm_squared': self._subcluster_norms}
return self.subcluster_labels_[
    pairwise_distances_argmin(X, self.subcluster_centers_, metric_kwargs=kwargs)
]

@alexshacked
Contributor Author

@jnothman _subcluster_norms is passed to pairwise_distances_argmin. pairwise.py is not changed by this PR
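
As a standalone illustration of that call pattern outside of Birch, the sketch below passes precomputed squared norms to `pairwise_distances_argmin` through `metric_kwargs`. The random `X` and `centers` arrays are stand-ins for real samples and for `subcluster_centers_`, and computing the norms with `row_norms(..., squared=True)` is an assumption made for this example.

```python
import numpy as np
from sklearn.metrics import pairwise_distances_argmin
from sklearn.utils.extmath import row_norms

rng = np.random.RandomState(0)
X = rng.rand(1000, 4)       # stand-in for the samples passed to predict
centers = rng.rand(10, 4)   # stand-in for Birch's subcluster_centers_

# Squared norms of the centers, computed once up front.
center_norms = row_norms(centers, squared=True)

# The norms are forwarded to the Euclidean metric, so they do not need
# to be recomputed for every chunk of X.
nearest = pairwise_distances_argmin(
    X, centers, metric_kwargs={"Y_norm_squared": center_norms}
)
print(nearest[:10])
```

The returned indices can then be used to look up a label per center, as `self.subcluster_labels_[...]` does in the snippet above.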

Member

@jnothman jnothman left a comment

LGTM, thanks!!

Please add an |Efficiency| entry to the change log at doc/whats_new/v0.23.rst. Like the other entries there, please reference this pull request with :pr: and credit yourself (and other contributors if applicable) with :user:

@thomasjpfan thomasjpfan merged commit 7d10be4 into scikit-learn:master Jan 29, 2020
@thomasjpfan
Member

Thank you for the PR @alexshacked !

thomasjpfan pushed a commit to thomasjpfan/scikit-learn that referenced this pull request Feb 22, 2020
panpiort8 pushed a commit to panpiort8/scikit-learn that referenced this pull request Mar 3, 2020
Development

Successfully merging this pull request may close these issues.

Risk of MemoryError when using Birch clustering algorithm with large datasets and a simple solution
4 participants