[MRG] FIX Avoid accumulating forest predictions in non-threadsafe manner (#9830)
jnothman committed Oct 17, 2017
1 parent 83411db commit 95fcde8
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions sklearn/ensemble/forest.py
@@ -43,6 +43,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
 
 import warnings
 from warnings import warn
+import threading
 
 from abc import ABCMeta, abstractmethod
 import numpy as np
@@ -378,13 +379,14 @@ def feature_importances_(self):
 # ForestClassifier or ForestRegressor, because joblib complains that it cannot
 # pickle it when placed there.
 
-def accumulate_prediction(predict, X, out):
+def accumulate_prediction(predict, X, out, lock):
     prediction = predict(X, check_input=False)
-    if len(out) == 1:
-        out[0] += prediction
-    else:
-        for i in range(len(out)):
-            out[i] += prediction[i]
+    with lock:
+        if len(out) == 1:
+            out[0] += prediction
+        else:
+            for i in range(len(out)):
+                out[i] += prediction[i]
 
 
 class ForestClassifier(six.with_metaclass(ABCMeta, BaseForest,
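Why the lock matters: with joblib's threading backend every worker thread adds its tree's prediction into the same output arrays, and an in-place += on a NumPy array is a read-modify-write that is not guaranteed to be atomic once NumPy releases the GIL, so unsynchronized workers can overwrite each other's partial sums. A minimal standalone sketch of the locked accumulation pattern (not part of this diff; the toy predict functions stand in for fitted trees):

    import threading
    import numpy as np

    def accumulate_prediction(predict, X, out, lock):
        prediction = predict(X)      # compute outside the lock: this part parallelizes freely
        with lock:                   # serialize only the in-place update of shared state
            out[0] += prediction

    X = np.ones((4, 3))
    predictors = [lambda X, w=w: w * X.sum(axis=1) for w in (1.0, 2.0, 3.0)]
    out = [np.zeros(X.shape[0])]
    lock = threading.Lock()
    threads = [threading.Thread(target=accumulate_prediction,
                                args=(p, X, out, lock)) for p in predictors]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(out[0] / len(predictors))  # averaged prediction: [6. 6. 6. 6.]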
@@ -581,8 +583,9 @@ class in a leaf.
         # avoid storing the output of every estimator by summing them here
         all_proba = [np.zeros((X.shape[0], j), dtype=np.float64)
                      for j in np.atleast_1d(self.n_classes_)]
+        lock = threading.Lock()
         Parallel(n_jobs=n_jobs, verbose=self.verbose, backend="threading")(
-            delayed(accumulate_prediction)(e.predict_proba, X, all_proba)
+            delayed(accumulate_prediction)(e.predict_proba, X, all_proba, lock)
             for e in self.estimators_)
 
         for proba in all_proba:
@@ -687,8 +690,9 @@ def predict(self, X):
         y_hat = np.zeros((X.shape[0]), dtype=np.float64)
 
         # Parallel loop
+        lock = threading.Lock()
         Parallel(n_jobs=n_jobs, verbose=self.verbose, backend="threading")(
-            delayed(accumulate_prediction)(e.predict, X, [y_hat])
+            delayed(accumulate_prediction)(e.predict, X, [y_hat], lock)
             for e in self.estimators_)
 
         y_hat /= len(self.estimators_)
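In both prediction paths the lock is created once per call and passed to every dispatched task; the regression path wraps its single output array in a one-element list so the same helper covers single- and multi-output accumulation. A hedged end-to-end sketch of that dispatch, using the standalone joblib package and dummy estimators in place of fitted trees (names here are illustrative, not from the commit):

    import threading
    import numpy as np
    from joblib import Parallel, delayed

    def accumulate_prediction(predict, X, out, lock):
        prediction = predict(X)
        with lock:                # one lock shared by all worker threads
            out[0] += prediction

    X = np.ones((4, 3))
    estimators = [lambda X, w=w: w * X.sum(axis=1) for w in (1.0, 2.0, 3.0)]

    y_hat = np.zeros(X.shape[0], dtype=np.float64)
    lock = threading.Lock()
    Parallel(n_jobs=2, backend="threading")(
        delayed(accumulate_prediction)(e, X, [y_hat], lock) for e in estimators)
    y_hat /= len(estimators)      # average over the ensemble, as predict() does
    print(y_hat)                  # [6. 6. 6. 6.]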
