
Massive memory usage by parallel RandomForestClassifier #936

Closed
jni opened this issue Jul 6, 2012 · 32 comments

@jni commented Jul 6, 2012

I think this will be hard to fix without swapping out joblib (or maybe even the GIL ;), but basically the amount of memory used by RandomForestClassifier is exorbitant for n_jobs > 1. In my case, I have a dataset of about 1GB (300,000 samples × 415 features of 64-bit floats), but calling fit() on a RandomForestClassifier with n_jobs=16 results in 45GB of memory being used.

Does anyone have any ideas or is this hopeless without moving everything to C?

@glouppe (Contributor) commented Jul 7, 2012

The problem is that two copies of X (X and X_argsorted) are made for each job.

You can circumvent that by putting X into shared memory. I did that in one of my branches:
https://github.com/glouppe/scikit-learn/blob/cytomine/sklearn/ensemble/forest.py#L274

I used that module: https://bitbucket.org/cleemesser/numpy-sharedmem/issue/2/sharedmemory_sysvso-not-added-correctly-to

It was not merged into master because of the additional dependency, though.

@ogrisel (Member) commented Jul 7, 2012

To use shared memory you can memory-map your input set with joblib:

import numpy as np
from sklearn.externals import joblib

filename = '/tmp/dataset.joblib'
joblib.dump(np.asfortranarray(X), filename)
X = joblib.load(filename, mmap_mode='c')

IIRC the random forest model needs Fortran-layout data to work efficiently, hence the call to np.asfortranarray before serializing to disk.

@ogrisel (Member) commented Jul 7, 2012

BTW @glouppe, if the above strategy works as expected it would be great to make the RandomForest/ExtraTrees* classes able to do it automatically using a mmap_folder=None parameter. If provided, memmapping would be used when n_jobs > 1.

@ogrisel (Member) commented Jul 9, 2012

@jni any news on this? Have you tried any of the aforementioned solutions? If one works for you, we should devise a way to make it simpler to use, or at least better documented.

@jni (Author) commented Jul 9, 2012

Haven't tried it, busy weekend — I'll do it today! Thanks!

@jni (Author) commented Jul 9, 2012

Ok, two failures to report.

First, I tried to combine @glouppe's code with @ogrisel's joblib modification. This crashed and burned, and in any case didn't seem to affect memory usage much: it was up to 20GB before it crashed. I've made a gist with the diff against scikit-learn 0.11.X and the error for n_jobs=2.

I then tried @glouppe's cytomine branch directly, after installing sharedmem, but this also failed for some unknown reason.

... Any ideas?

@amueller (Member) commented Jul 9, 2012

@jni have you tried playing with "min_density"? That can really affect memory usage (and CPU usage too, though in a non-linear way).

@jni (Author) commented Jul 9, 2012

@amueller, I just tried min_density=0.001 and it still goes to 45GB on my dataset. By my understanding, the original data gets replicated to each process before any parameters like min_density take effect, so those parameters will only minimally affect the total memory footprint, which is dominated by n_jobs * dataset_size * c, with c in [2, 3]. (2 according to the above comments, but closer to 3 empirically. ;)

@jni (Author) commented Jul 9, 2012

I stand corrected: usage is >50GB with n_jobs=8 and min_density=1.0. But it looks like I can't make it go any lower.

@amueller (Member) commented Jul 9, 2012

Hm, ok, it was just an idea. I don't know where the additional copy comes from (3 copies instead of 2). You did take care of the memory layout, right?

@amueller (Member) commented Jul 9, 2012

I meant whether you use Fortran or C ordering. IIRC the forests want Fortran ordering, so if you provide C-ordered arrays, they'll make a copy.
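
For illustration (this snippet is not from the thread), checking and fixing the layout before calling fit would look something like the following, assuming X holds the training features:

import numpy as np

# Stand-in for the training data (C-ordered by default), for illustration.
X = np.random.rand(1000, 50)

# False here means the estimator would make its own Fortran-ordered copy.
print X.flags['F_CONTIGUOUS']

# Convert once in the parent process; this is a no-op if X is already
# Fortran-ordered.
X = np.asfortranarray(X)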

@jni (Author) commented Jul 9, 2012

I thought that's what you might have meant, so I ran rf.fit(np.asfortranarray(...)) instead, but got the same result (about 3×). =\

@glouppe (Contributor) commented Jul 10, 2012

For my branch to work, you need:

  1. to use bootstrap=False in your forests;

  2. to increase SHMMAX if you intend to share a large block of memory.

  • On OSX:
    sudo sysctl -w kern.sysv.shmmax=YOUR_VALUE
    sudo sysctl -w kern.sysv.shmall=YOUR_VALUE
  • On Linux:
    sudo sysctl -w kernel.shmmax=YOUR_VALUE
    sudo sysctl -w kernel.shmall=YOUR_VALUE
    sudo sysctl -p /etc/sysctl.conf

Hope this helps!

@jni (Author) commented Jul 10, 2012

Thanks, @glouppe! This'll help me, but I'm disappointed that it's of limited use if it won't make it into scikit-learn proper... It seems to me that if I can get it to work, this could be an optional dependency... I often use the following pattern:

import logging

try:
    import sharedmem as shm
    shm_available = True
except ImportError:
    logging.warning('sharedmem library is not available')
    shm_available = False

Otherwise, I would still be interested in getting joblib persistence to work...

Secondly, this may explain the close-to-3× memory usage when my data is copied, since it's not float32. It's probably a good idea to coerce the data within BaseForest before it is copied by multiprocessing, rather than having each tree coerce it after the copy. This'll bring the memory usage way down, and I'm sure it will speed things up too.
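
A caller-side sketch of that idea (illustrative only, assuming float32 and Fortran order are what the trees ultimately use), so that the cast happens once instead of once per forked worker:

import numpy as np

# Stand-in for the real dataset (float64, C-ordered), just for illustration.
features = np.random.rand(1000, 50)

# Cast once, up front, to the dtype/layout used inside the trees; otherwise
# each forked worker ends up holding a float64 copy plus its own float32
# conversion of it.
features = np.asarray(features, dtype=np.float32, order='F')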

I'll run these experiments this afternoon and report back. Thanks everyone!

@ogrisel (Member) commented Jul 10, 2012

@jni could you please report the error message you get with the memory-mapped file solution, or better, tell me whether it is the same as the following? https://gist.github.com/3084146

If so, I will take a deeper look and try to reproduce it in a pure joblib context, outside of scikit-learn. Maybe @GaelVaroquaux has an idea about the cause of the problem.

@jni (Author) commented Jul 10, 2012

@ogrisel, yes, it's the same error as you pointed out.

@ogrisel (Member) commented Jul 10, 2012

Alright, I'll try to see if there is an easy fix for this problem tonight, unless @GaelVaroquaux or someone else gets to it first.

@ogrisel (Member) commented Jul 10, 2012

@jni are you running Unix (Linux or OS X)? If so, just putting the data into the right memory layout before calling the fit method of the random forest might work, thanks to the copy-on-write semantics of the Unix fork() that backs the multiprocessing module of the standard library on those platforms:

#!/usr/bin/env python

import numpy as np
from sklearn.datasets.samples_generator import make_classification
from sklearn.externals import joblib
from sklearn.ensemble import RandomForestClassifier

print "generating dataset"
X, y = make_classification(n_samples=100000, n_features=500)

print "put data in the right layout"
X = np.asarray(X, dtype=np.float32, order='F')

print "fitting random forest:"
clf = RandomForestClassifier(n_estimators=100, n_jobs=2)
print clf.fit(X, y).score(X, y)

Can you tell us if it solves your issue?

Thanks to @larsmans for the heads up on COW unix forks.

@glouppe (Contributor) commented Jul 10, 2012

@jni In the end, does this work with my branch? I put it together a few months ago and it did solve my problems (as long as bootstrap=False). I remember that at the time I added extra checks to avoid coercing the data several times (i.e. it is coerced only once for the whole forest).

@jni (Author) commented Jul 10, 2012

@ogrisel, I'm on Linux (Fedora Core 16). I'm aware of the idea of copy-on-write in Unix fork(), but in my experience I have never been able to capitalise on it in Python. I believe we're running into the problem detailed here, namely that the objects themselves might not change, but the Python interpreter changes an object's metadata (e.g. its reference count), which results in the whole object getting copied.

To illustrate, here's some setup:

import numpy as np
from ray import classify # this is my own library
dat5 = classify.load_training_data_from_disk('training/multi-channel-graph-e05-5.trdat.h5')
from sklearn.ensemble import RandomForestClassifier
# using @glouppe's branch
features = np.asarray(dat5[0], dtype=np.float32, order='F')
labels = np.asarray(dat5[1][:, 0], dtype=np.float32)
features.shape
# (299351, 415)
float(features.nbytes) / 2**30
# 0.46279529109597206
labels.shape
np.unique(labels)
# array([-1.,  1.], dtype=float32)

Now we try with and without @glouppe's shared memory. If COW were working, there would be no difference in memory usage. But!

rf = RandomForestClassifier(100, max_depth=20, n_jobs=16, shared=False, bootstrap=False)
rf = rf.fit(features, labels)
# about 1GB/process
rf = RandomForestClassifier(100, max_depth=20, n_jobs=16, shared=True, bootstrap=False)
rf = rf.fit(features, labels)
# about 100MB/process!!!

So, in conclusion:

  • copy-on-write is a lovely theoretical construct that appears to fall on its face in Python in general and in sklearn.ensemble.RandomForestClassifier in particular
  • @glouppe's shared memory implementation is incredibly awesome and valuable (thanks @glouppe!). Incidentally, there was no time penalty in using shm.
  • It should probably be a priority to get some kind of shared memory implementation for fit()... Any sklearn admins care to comment?

@ogrisel (Member) commented Jul 10, 2012

Thanks for the COW check. It's good to know that it's not working and that it's not fixable. As for the shm module, we would rather avoid the maintenance burden of an external dependency (furthermore, it's probably quite experimental and not guaranteed to work on other platforms).

I would prefer a solution based on numpy.memmap'ed arrays or multiprocessing.Array, which are maintained outside of scikit-learn in existing dependencies (numpy and the standard Python library, respectively).
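
For reference, a minimal sketch of the multiprocessing.Array option mentioned above (illustrative, not scikit-learn code): allocate the buffer in shared memory before forking, then view it as a numpy array in each worker.

import numpy as np
import multiprocessing as mp

n_samples, n_features = 10000, 50

# Allocate a float32 buffer in shared memory (no lock needed for read-only use).
shared = mp.Array('f', n_samples * n_features, lock=False)
X = np.frombuffer(shared, dtype=np.float32).reshape(n_samples, n_features)
X[:] = np.random.rand(n_samples, n_features)  # fill it in the parent process


def total(_):
    # After fork, X refers to the same shared buffer: no per-worker copy.
    return X.sum()


if __name__ == '__main__':
    pool = mp.Pool(2)
    print pool.map(total, range(2))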

As for the problem with numpy memmap, it seems to be a known bug (a regression in numpy 1.5+):

http://projects.scipy.org/numpy/ticket/1809

It would be great to find a fix and then backport it into the sklearn.utils.fixes module as a monkey patch.

@ogrisel (Member) commented Jul 10, 2012

As for the joblib.load with mmap_mode issue, the real underlying problem is that numpy.memmap arrays are currently not correctly picklable (a known old issue). As it does not seem that easy to fix, we could improve the Parallel.dispatch and SafeFunction.__call__ methods of joblib.parallel to detect args/kwargs that are instances of numpy.memmap and recreate new instances in the new process using simple wrap/unwrap logic.
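
A rough sketch of what such wrap/unwrap logic could look like (illustrative only, not the actual joblib code): pickle just the metadata needed to reopen the memmap, and rebuild it in the worker. A real implementation would also have to cope with memmap views that have lost their filename attribute, which is the numpy regression mentioned in the previous comment.

import numpy as np


class WrappedMemmap(object):
    """Illustrative wrapper: stores only picklable metadata for a memmap."""

    def __init__(self, m):
        self.filename = m.filename
        self.dtype = m.dtype
        self.shape = m.shape
        self.mode = m.mode
        self.offset = m.offset

    def unwrap(self):
        # Re-open the same file-backed array inside the worker process;
        # never reopen with 'w+', which would truncate the backing file.
        mode = 'r+' if self.mode == 'w+' else self.mode
        return np.memmap(self.filename, dtype=self.dtype, mode=mode,
                         shape=self.shape, offset=self.offset)


def wrap_if_memmap(obj):
    # Hypothetical helper: only memmap instances need the wrapping treatment.
    return WrappedMemmap(obj) if isinstance(obj, np.memmap) else obj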

@ogrisel (Member) commented Jul 15, 2012

Hi @jni, FYI I have started a new branch in joblib to add support for numpy.memmap arrays to joblib.Parallel here:
joblib/joblib#40

This is not yet used in scikit-learn though: the joblib version embedded in scikit-learn will need to be synchronized with upstream once this PR is merged.

@jni (Author) commented Jul 15, 2012

Thanks @ogrisel! Is it actually fixed, i.e. do I just need to replace the bundled version with this branch? Or are you still working on it?

@ogrisel (Member) commented Jul 15, 2012

It should be fixed in my joblib branch. You can try to do the swap manually, but I am not sure whether other recent changes in joblib will impact its use in scikit-learn (they probably should not), as I have not tested it myself yet.

Then you can try something like:

#!/usr/bin/env python

import numpy as np
from sklearn.datasets.samples_generator import make_classification
from sklearn.externals import joblib
from sklearn.ensemble import RandomForestClassifier

print "generating dataset"
X, y = make_classification(n_samples=100000, n_features=500)

filename = '/tmp/dataset.joblib'
print "put data in the right layout and map to " + filename
joblib.dump(np.asarray(X, dtype=np.float32, order='F'), filename)
X = joblib.load(filename, mmap_mode='c')

print "fitting random forest:"
clf = RandomForestClassifier(n_estimators=100, n_jobs=2)
print clf.fit(X, y).score(X, y)

@ogrisel (Member) commented Jul 15, 2012

I have tried the previous script. I don't get the previous error anymore, but the memory usage does not seem to stay constant when I increase n_jobs, so there might be another part of the code triggering a memory copy. I will do more tests with just joblib.

@glouppe (Contributor) commented Jul 15, 2012

@ogrisel: This comes from RandomForestClassifier.

  1. Bootstrap samples are made from copies of X.
  2. X_argsorted should be put into a memmap array.

Note that I also fixed some useless data copies in #946 (not yet in master).
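
To illustrate the first point above (illustrative snippet, not taken from the branch): fancy indexing on a memmap materializes a regular in-memory array, so each worker's bootstrap sample becomes a private copy no matter how X itself is shared.

import numpy as np

# Small file-backed array standing in for the shared training data.
X = np.memmap('/tmp/X_shared.mmap', dtype=np.float32, mode='w+',
              shape=(10000, 50))

indices = np.random.randint(0, X.shape[0], X.shape[0])  # bootstrap sample
X_bootstrap = X[indices]

print np.may_share_memory(X, X_bootstrap)  # False: the sample is a private copy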

@bdholt1 (Member) commented Jul 25, 2012

@ogrisel I'm not sure if I'm doing something wrong here, but I've just imported your joblib parallel.py fixes and run the above code from #936 (comment).

generating dataset
put data in the right layout and map to /tmp/dataset.joblib
fitting random forest:
Traceback (most recent call last):
  File "test_forest_parallel_joblib.py", line 18, in <module>
    print clf.fit(X, y).score(X, y)
  File "/vol/vssp/signsrc/brian/python/scikit-learn/sklearn/ensemble/forest.py", line 288, in fit
    for i in xrange(n_jobs))
  File "/vol/vssp/signsrc/brian/python/scikit-learn/sklearn/externals/joblib/parallel.py", line 565, in __call__
    self.dispatch(function, args, kwargs)
  File "/vol/vssp/signsrc/brian/python/scikit-learn/sklearn/externals/joblib/parallel.py", line 397, in dispatch
    args, kwargs = wrap_mmap_args(args, kwargs)
  File "/vol/vssp/signsrc/brian/python/scikit-learn/sklearn/externals/joblib/parallel.py", line 94, in wrap_mmap_args
    args = [wrap_mmap(a) for a in args]
  File "/vol/vssp/signsrc/brian/python/scikit-learn/sklearn/externals/joblib/parallel.py", line 83, in wrap_mmap
    return WrappedMemmapArray(obj)
  File "/vol/vssp/signsrc/brian/python/scikit-learn/sklearn/externals/joblib/parallel.py", line 53, in __init__
    self.filename = mmap_array.filename
AttributeError: 'memmap' object has no attribute 'filename'

I'd love to get the shared memory working using your method instead of the sharedmem package...

@ogrisel (Member) commented Jul 26, 2012

Which branch are you using? I have started a new branch with another approach:

joblib/joblib#43

I am still working on it though.

@ogrisel (Member) commented Sep 2, 2012

FYI: I am working on a new approach to deal efficiently with shared memory in joblib.parallel / multiprocessing using numpy.memmap, here: joblib/joblib#44

@ogrisel (Member) commented Sep 9, 2012

@jni @bdholt1 I think my pull request is in a workable state, ready for testing on your use cases:

joblib/joblib#44

To test, you can just replace the joblib embedded within the sklearn source tree with the one from this repo/branch:

git clone https://github.com/ogrisel/joblib.git
(cd joblib && git checkout pickling-pool)
rm -rf scikit-learn/sklearn/externals/joblib
ln -s `pwd`/joblib/joblib scikit-learn/sklearn/externals/joblib

With this drop-in replacement, any numpy array larger than 1MB passed as an argument to a joblib.Parallel operation will be dumped to a temporary file and exposed as shared memory (an instance of a copy-on-write numpy.memmap), which should heavily reduce memory usage and also speed things up by reducing the number of redundant memory allocations, especially for large read-only data and large values of n_jobs.
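
A quick way to exercise that behaviour (a sketch assuming the pickling-pool branch has been swapped in as described above; column_sum is just an illustrative function):

import numpy as np
from sklearn.externals.joblib import Parallel, delayed


def column_sum(X, j):
    # With the auto-memmapping described above, X should arrive here as a
    # copy-on-write numpy.memmap backed by a temporary file rather than as a
    # private in-memory copy.
    return X[:, j].sum()


if __name__ == '__main__':
    X = np.random.rand(200000, 50)  # ~80MB, well above the 1MB threshold
    sums = Parallel(n_jobs=4)(delayed(column_sum)(X, j)
                              for j in range(X.shape[1]))
    print sums[:3]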

Please feel free to report any issue directly as comments to joblib/joblib#44 .

@glouppe (Contributor) commented Jul 22, 2013

Clone of #2179.
