Skip to content

Commit

Permalink
Merge pull request #973 from pymc-devs/joblib
Browse files Browse the repository at this point in the history
Parallel processing using Joblib
  • Loading branch information
Chris Fonnesbeck committed Feb 23, 2016
2 parents 9be4642 + c8afbdd commit 24fb35c
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 72 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Expand Up @@ -13,7 +13,7 @@ before_install:
install:
- conda create -n testenv --yes pip python=$TRAVIS_PYTHON_VERSION
- source activate testenv
- conda install --yes jupyter pyzmq numpy scipy nose matplotlib pandas Cython patsy statsmodels
- conda install --yes jupyter pyzmq numpy scipy nose matplotlib pandas Cython patsy statsmodels joblib

- if [ ${TRAVIS_PYTHON_VERSION:0:1} == "2" ]; then conda install --yes mock enum34; fi
- pip install --no-deps numdifftools
Expand Down
106 changes: 74 additions & 32 deletions pymc3/examples/rugby_analytics.ipynb

Large diffs are not rendered by default.

66 changes: 28 additions & 38 deletions pymc3/sampling.py
@@ -1,14 +1,17 @@
from . import backends
from .backends.base import merge_traces, BaseTrace, MultiTrace
from .backends.ndarray import NDArray
import multiprocessing as mp
from joblib import Parallel, delayed
from time import time
from .core import *
from .step_methods import *
from .progressbar import progress_bar
from numpy.random import randint, seed
from collections import defaultdict

import sys
sys.setrecursionlimit(10000)

__all__ = ['sample', 'iter_sample', 'sample_ppc']

def assign_step_methods(model, step=None,
Expand Down Expand Up @@ -124,37 +127,26 @@ def sample(draws, step=None, start=None, trace=None, chain=0, njobs=1, tune=None
step = assign_step_methods(model, step)

if njobs is None:
import multiprocessing
njobs = max(mp.cpu_count() - 2, 1)
if njobs > 1:
try:
if not len(random_seed) == njobs:
random_seeds = [random_seed] * njobs
else:
random_seeds = random_seed
except TypeError: # None, int
random_seeds = [random_seed] * njobs

chains = list(range(chain, chain + njobs))

pbars = [progressbar] + [False] * (njobs - 1)

argset = zip([draws] * njobs,
[step] * njobs,
[start] * njobs,
[trace] * njobs,
chains,
[tune] * njobs,
pbars,
[model] * njobs,
random_seeds)
argset = list(argset)

sample_args = {'draws':draws,
'step':step,
'start':start,
'trace':trace,
'chain':chain,
'tune':tune,
'progressbar':progressbar,
'model':model,
'random_seed':random_seed}

if njobs>1:
sample_func = _mp_sample
sample_args = [njobs, argset]
sample_args['njobs'] = njobs
else:
sample_func = _sample
sample_args = [draws, step, start, trace, chain,
tune, progressbar, model, random_seed]
return sample_func(*sample_args)

return sample_func(**sample_args)


def _sample(draws, step=None, start=None, trace=None, chain=0, tune=None,
Expand Down Expand Up @@ -273,10 +265,14 @@ def _choose_backend(trace, chain, shortcuts=None, **kwds):
raise ValueError('Argument `trace` is invalid.')


def _mp_sample(njobs, args):
p = mp.Pool(njobs)
traces = p.map(argsample, args)
p.close()
def _mp_sample(**kwargs):
    """Draw `njobs` chains in parallel with joblib and merge them.

    Expected keyword arguments are those accepted by ``_sample`` plus
    ``njobs``; ``chain`` is treated as the index of the first chain and
    consecutive chain numbers are assigned to the parallel workers.
    Only the first chain shows a progress bar.

    Returns
    -------
    MultiTrace
        The traces of all chains merged via ``merge_traces``.
    """
    njobs = kwargs.pop('njobs')
    chain = kwargs.pop('chain')
    chains = list(range(chain, chain + njobs))
    pbars = [kwargs.pop('progressbar')] + [False] * (njobs - 1)
    # Distribute one seed per chain.  Forwarding the same scalar seed to
    # every worker would make all chains draw identical samples, so
    # replicate the pre-joblib behavior: a sequence of length njobs is
    # used element-wise, anything else (scalar, None, wrong-length
    # sequence) is passed through unchanged to each chain.
    random_seed = kwargs.pop('random_seed', None)
    try:
        if not len(random_seed) == njobs:
            random_seeds = [random_seed] * njobs
        else:
            random_seeds = random_seed
    except TypeError:  # None or int
        random_seeds = [random_seed] * njobs
    traces = Parallel(n_jobs=njobs)(delayed(_sample)(chain=chains[i],
                                                     progressbar=pbars[i],
                                                     random_seed=random_seeds[i],
                                                     **kwargs)
                                    for i in range(njobs))
    return merge_traces(traces)


Expand All @@ -291,12 +287,6 @@ def stop_tuning(step):

return step


def argsample(args):
    """Unpack *args* into a ``_sample`` call.

    Defined at module top level (rather than as a closure) so that it
    can be pickled and shipped to multiprocessing workers.
    """
    result = _sample(*args)
    return result


def _soft_update(a, b):
"""As opposed to dict.update, don't overwrite keys if present.
"""
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -27,7 +27,7 @@
'Operating System :: OS Independent']

install_reqs = ['numpy>=1.7.1', 'scipy>=0.12.0', 'matplotlib>=1.2.1',
'Theano<=0.7.1dev', 'pandas>=0.15.0', 'patsy>=0.4.0']
'Theano<=0.7.1dev', 'pandas>=0.15.0', 'patsy>=0.4.0', 'joblib>=0.9']
if sys.version_info < (3, 4):
install_reqs.append('enum34')

Expand Down

0 comments on commit 24fb35c

Please sign in to comment.