
[MRG] Avoid reference cycles in Tree #2790

Merged
merged 3 commits into scikit-learn:master on Jan 25, 2014

Conversation

jnothman
Member

Fixes #2787.

This is a WIP because we need to check whether wrapping value with an ndarray at predict time is too expensive. If so, we can implement our own take using memcpy (which I'd rather do than revert to the version where predict duplicates apply), but for the moment I'm having trouble getting that to work...

This also needs to be tested for any further memory leaks.
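For context, the failure mode being fixed can be sketched in pure Python. This is a hypothetical stand-in (`Owner` is not scikit-learn's `Tree`; the dict simulates an ndarray whose `base` points back at its owner): once an object's wrapper holds a reference back to the object, reference counting alone can never free the pair, and memory is reclaimed only when the cyclic garbage collector happens to run.

```python
import gc
import weakref

class Owner:
    """Hypothetical stand-in for an object exposing its buffer via a wrapper."""
    def __init__(self):
        self.buf = bytearray(80)
        # Simulate ndarray.base pointing back at the owner -> reference cycle.
        self.view = {"base": self}

gc.disable()                  # rule out automatic cycle collection
o = Owner()
ref = weakref.ref(o)
del o
assert ref() is not None      # refcounting alone cannot free the cycle
gc.collect()                  # only the cyclic collector reclaims it
assert ref() is None
gc.enable()
```

Avoiding the cycle altogether means the buffers are released promptly by reference counting, without depending on the collector.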

@coveralls
Copy link

Coverage Status

Coverage remained the same when pulling 86fbebd on jnothman:tree_without_cycles into f8e7cf1 on scikit-learn:master.

@glouppe
Contributor

glouppe commented Jan 24, 2014

This is great! Thanks for figuring this out @jnothman

@ogrisel Could you have a look to confirm that the leak is gone?

@glouppe
Contributor

glouppe commented Jan 24, 2014

On my box, using @ogrisel's script, I cannot reproduce the leak.

@glouppe
Contributor

glouppe commented Jan 24, 2014

Regarding predict, I did a quick benchmark using GradientBoostingRegressor with n_estimators=1000 and I cannot observe any significant performance decrease of predict execution time.

So overall, I am +1 for this. Any second opinion?

Thanks for the patch!

@jnothman
Member Author

GradientBoostingRegressor reimplements predict, so I don't think that's a useful assessment! RandomForest might be worth trying.

1000 trees fit on 50 features:

Predicting 1e5 samples, best of 3
At master: 9.2 s
At 86fbebd: 9.48 s

Similar 2-3% time increase for smaller samples. I assume this is small, but not insignificant? We can in any case accept this fix and speed up predict in another patch.
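The suspected source of that 2-3% is re-wrapping the raw value buffer as an ndarray on every predict call. A rough, hypothetical micro-benchmark of just that wrapping cost (plain numpy, not scikit-learn's code; `raw` stands in for the Tree's malloc'd C buffer):

```python
import timeit
import numpy as np

raw = bytearray(8 * 1000)                      # stands in for the Tree's C buffer
cached = np.frombuffer(raw, dtype=np.float64)  # wrap once, reuse

def take_rewrap(idx):
    # Re-wrap the buffer on every call, as a predict() round-trip would
    return np.frombuffer(raw, dtype=np.float64).take(idx)

def take_cached(idx):
    return cached.take(idx)

idx = np.arange(100)
t_rewrap = timeit.timeit(lambda: take_rewrap(idx), number=10_000)
t_cached = timeit.timeit(lambda: take_cached(idx), number=10_000)
print("rewrap: %.4fs  cached: %.4fs" % (t_rewrap, t_cached))
```

Both paths return identical values; only the per-call wrapping overhead differs, which is consistent with a small constant cost per predict call.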

@glouppe
Contributor

glouppe commented Jan 24, 2014

GradientBoostingRegressor reimplements predict, so I don't think that's a useful assessment!

Whoops :)

Similar 2-3% time increase for smaller samples. I assume this is small, but not insignificant? We can in any case accept this fix and speed up predict in another patch.

+1 for that.

@ogrisel
Member

ogrisel commented Jan 24, 2014

Weird, I still get the leak when I run the script of #2787 but now the leak is "stabilizing" at the last iteration:

(py27)0 [~/code/scikit-learn (pr/2790)]$ make in 1>/dev/null 2>&1 && python ~/tmp/check_memleak.py
70MB
155MB
234MB
234MB

on master I get:

(py27)0 [~/code/scikit-learn (master)]$ make in 1>/dev/null 2>&1 && python ~/tmp/check_memleak.py
70MB
155MB
235MB
315MB

@ogrisel
Member

ogrisel commented Jan 24, 2014

@jnothman does the script run without leaking on your box? I ran it on OSX with a Python 2.7 after a make clean. I also ran the tests on your branch and they pass.

BTW, maybe we could use PyMem_Malloc and friends instead of libc.malloc. That way we could use tracemalloc from Python HEAD to track the remaining leak: http://www.python.org/dev/peps/pep-0454/
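For reference, the tracemalloc workflow would look roughly like this (stdlib since Python 3.4; the "leak" here is simulated, not scikit-learn's). Note that tracemalloc only sees allocations routed through Python's allocators, which is exactly why switching away from libc.malloc would matter:

```python
import tracemalloc

tracemalloc.start()
before = tracemalloc.take_snapshot()

leaky = [bytearray(10_000) for _ in range(100)]  # ~1 MB simulated leak, kept alive

after = tracemalloc.take_snapshot()
# Differential snapshot: biggest growth first, with file:line of the allocation
for stat in after.compare_to(before, "lineno")[:3]:
    print(stat)
```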

@ogrisel
Member

ogrisel commented Jan 24, 2014

Actually we should stick to libc.malloc for Python < 3.4 and use PyMem_Malloc only for Python 3.4+, as PyMem_Malloc requires the GIL up to Python 3.3: http://www.python.org/dev/peps/pep-0445/#gil-free-pymem-malloc

It might be possible to select the right malloc at compile time with a preprocessor macro in our own .h file depending on the Python version using those constants: http://stackoverflow.com/a/12348545/163740

@GaelVaroquaux
Member

Actually we should stick to libc.malloc for Python < 3.4 and use PyMem_Malloc only for Python 3.4+

This is going to add a lot of complexity for little gains, I believe.

@ogrisel
Member

ogrisel commented Jan 24, 2014

This is going to add a lot of complexity for little gains, I believe.

Well it would give us very fine control over memory management by making it possible to use the tracemalloc tool that supports differential snapshots to spot leaks along with the traceback of the faulty allocation.

@glouppe
Contributor

glouppe commented Jan 24, 2014

Actually we should stick to libc.malloc for Python < 3.4 and use PyMem_Malloc only for Python 3.4+

This is going to add a lot of complexity for little gains, I believe.

I agree. Let's not make this more complicated than it has already become.

@glouppe
Contributor

glouppe commented Jan 24, 2014

@ogrisel This is really odd. Here are the results on my box:

# this branch 
 % python leak.py 
60MB
61MB
61MB
61MB

# master
 % python leak.py                                                      
60MB
140MB
220MB
300MB

    return
elif ret == 1:
    raise RuntimeError('Cannot resize tree after building is complete')
raise MemoryError()
Contributor

Couldn't you simply check self.locked here and not change the semantics of _resize_c?

Contributor

Anyway, can you comment on why we need this at all? This is internal API and the case this additional test covers should never happen. There is no need to be defensive in my opinion.

@ogrisel
Member

ogrisel commented Jan 24, 2014

I re-pulled the tree_without_cycles branch from @jnothman's repo to make sure I had the right code, I still leak some memory on the first 2 iterations of my script:

(py27)0 [~/code/scikit-learn (tree_without_cycles)]$ python ~/tmp/check_memleak.py
70MB
155MB
234MB
234MB
234MB
234MB
234MB
234MB
234MB
234MB
234MB

@arjoly
Member

arjoly commented Jan 24, 2014

I got the following results.

ajoly at arnaud-joly-002 in ~/git/scikit-learn on 34c4908! # 0.14.1
(sklearn) ± python leak.py     
70MB
115MB
155MB
195MB
235MB
275MB
275MB
275MB
275MB
275MB
275MB
ajoly at arnaud-joly-002 in ~/git/scikit-learn on master!
(sklearn) ± python leak.py
69MB
154MB
235MB
315MB
395MB
475MB
555MB
635MB
715MB
ajoly at arnaud-joly-002 in ~/git/scikit-learn on 86fbebd!
(sklearn) ± python leak.py               
70MB
155MB
234MB
234MB
234MB
234MB
234MB
234MB
234MB
234MB
234MB

@arjoly
Member

arjoly commented Jan 24, 2014

I saw the same behavior for all branches.

It looks like the tree from 0.14.1 consumes less memory than master. :(

@ogrisel
Member

ogrisel commented Jan 24, 2014

OK, so @arjoly, you have increasing yet stabilizing memory usage on 0.14.1 as well. So the leak might be fixed in this branch after all; Python might just be declining to return some RSS to the OS for another reason.

@arjoly
Member

arjoly commented Jan 24, 2014

I saw the same behavior for all branches.

I wasn't synced with master :-(
I corrected the bench.

@ogrisel
Member

ogrisel commented Jan 24, 2014

@arjoly On master you have the leak: it never stabilizes. On the other branch (0.14.1 and @jnothman's fix) you reach stability.

@ogrisel
Member

ogrisel commented Jan 24, 2014

Here is a new version of my script that tracks leaking objects (without references):

import gc
import os
import psutil
import objgraph
import numpy as np
from sklearn.tree import ExtraTreeRegressor

X = np.random.normal(size=(100, 50))
Y = np.random.normal(size=(100, int(5e4)))

p = psutil.Process(os.getpid())

initially_without_ref = objgraph.get_leaking_objects()

def print_mem():
    print("{:.0f}MB".format(p.get_memory_info().rss / 1e6))  # older psutil API; now p.memory_info()
    currently_without_ref = objgraph.get_leaking_objects()
    print([o for o in currently_without_ref
             if o not in initially_without_ref])

print_mem()

for i in range(3):
    et = ExtraTreeRegressor(max_features=1).fit(X, Y)
    del et
    gc.collect()
    print_mem()

Now the results:

  • on this branch:
73MB
[]
159MB
[<listiterator object at 0x10d744e10>, {140694180998368: [140694180488784]}]
159MB
[<listiterator object at 0x10d744e10>, {140694180998368: [140694180488784]}]
159MB
[<listiterator object at 0x10d744e10>, {140694180998368: [140694180488784]}]

I don't know what 140694180998368 is, but it's there just once, so it might be an artifact of the function calls or the loop. The listiterator is not a real leak; it's expected when we enter the loop.

  • on master, we are leaking sklearn._tree.Tree instances:
73MB
[]
158MB
[<listiterator object at 0x1044cce10>, <sklearn.tree._tree.Tree object at 0x1061c7650>, {140266067200112: [140266066904800]}]
238MB
[<listiterator object at 0x1044cce10>, <sklearn.tree._tree.Tree object at 0x1061c7650>, {140266067200112: [140266066904800]}, <sklearn.tree._tree.Tree object at 0x1061c7710>]
318MB
[<listiterator object at 0x1044cce10>, <sklearn.tree._tree.Tree object at 0x1061c7650>, {140266067200112: [140266066904800]}, <sklearn.tree._tree.Tree object at 0x1061c7710>, <sklearn.tree._tree.Tree object at 0x1061c77d0>]
  • on 0.14.1:
74MB
[]
119MB
[<listiterator object at 0x1087b6990>, {140731809630416: [140731782125488]}]
159MB
[<listiterator object at 0x1087b6990>, {140731809630416: [140731782125488]}]
159MB
[<listiterator object at 0x1087b6990>, {140731809630416: [140731782125488]}]

So in the end this branch seems to correctly fix the regression introduced in master.

@arjoly
Member

arjoly commented Jan 24, 2014

Could someone explain to me why we can't do as before and wrap the C array into a numpy array at the last moment?
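One hedged sketch of what "wrapping at the last moment" could look like (illustrative names only, not the actual `Tree` API): build a transient, zero-copy ndarray view per call, and store nothing on `self` that points back at the wrapper, so no owner-to-array cycle ever forms:

```python
import numpy as np

class TreeLike:
    """Illustrative only -- not scikit-learn's Tree."""
    def __init__(self, n):
        self._raw = bytearray(8 * n)   # stands in for the malloc'd C value array

    def value_ndarray(self):
        # Transient, zero-copy view over the raw buffer; self keeps no
        # reference to the wrapper, so no reference cycle is created.
        return np.frombuffer(self._raw, dtype=np.float64)

t = TreeLike(4)
t.value_ndarray()[:] = [1.0, 2.0, 3.0, 4.0]
print(t.value_ndarray())   # writes are visible through a fresh view
```

The trade-off under discussion is exactly the per-call cost of constructing that fresh view inside predict.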

@glouppe
Contributor

glouppe commented Jan 24, 2014

@ogrisel Thanks for the check. So in conclusion, this PR correctly fixes the issue.

self.nodes = NULL
self.locked = False
Member

This new attribute would deserve an inline comment to explain the motivation and how it is meant to be used.

Contributor

I would remove it (see my comment above).

@ogrisel
Member

ogrisel commented Jan 24, 2014

@ogrisel Thanks for the check. So in conclusion, this PR correctly fixes the issue.

Yes, but before merging I think this PR should at least have more inline comments to explain the reference-counting magic that happens under the hood.


# XXX using (size_t)(-1) is ugly, but SIZE_MAX is not available in C89
# (i.e., older MSVC).
cdef int _resize_c(self, SIZE_t capacity=<SIZE_t>(-1)) nogil:
    """Guts of _resize. Returns 0 for success, -1 for error."""
    if self.locked:
        return 1
Member

if 0 is success and -1 is error -- what is 1?

Member Author

I'm happy to get rid of locked, but will leave a note in there for anyone
who tries to add post-building tree expansion. Comments coming.

On 25 January 2014, Peter Prettenhofer wrote (in sklearn/tree/_tree.pyx):

if 0 is success and -1 is error -- what is 1?

@coveralls

Coverage Status

Coverage remained the same when pulling 87194f7 on jnothman:tree_without_cycles into f203953 on scikit-learn:master.

@glouppe
Contributor

glouppe commented Jan 25, 2014

Awesome, thanks for your work Joel! I am merging this.

glouppe added a commit that referenced this pull request Jan 25, 2014
[MRG] Avoid reference cycles in Tree
glouppe merged commit e8fdfa6 into scikit-learn:master on Jan 25, 2014
@jnothman
Member Author

Great! Let's hope we've ironed out everything :)

On the topic of which, implementing Tree.predict without _get_value_ndarray
(by reimplementing np.take, basically) doesn't appear to make it faster.


@larsmans
Member

Sorry for not being in time for a review, but this looks like an elegant solution. Thanks!

@arjoly
Member

arjoly commented Jan 25, 2014

Thanks for the fix !!!

@ogrisel
Member

ogrisel commented Jan 26, 2014

Thanks for the fix @jnothman.

Successfully merging this pull request may close these issues.

[Regression] Memory Leak in decision trees
8 participants