From 299c006330c28f0b6e21a959354e45f604a2a686 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 17 Oct 2017 14:56:47 +0200 Subject: [PATCH] Sample sequentially when pickle fails --- pymc3/sampling.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 74f741c3a50..c2d647d48bc 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -1,4 +1,5 @@ from collections import defaultdict +import pickle from joblib import Parallel, delayed import numpy as np @@ -318,19 +319,23 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None, 'random_seed': random_seed, 'live_plot': live_plot, 'live_plot_kwargs': live_plot_kwargs, + 'njobs': njobs, } sample_args.update(kwargs) - if njobs > 1 and chains > 1: - sample_func = _mp_sample - sample_args['njobs'] = njobs - else: - sample_func = _sample_many + parallel = njobs > 1 and chains > 1 + if parallel: + try: + trace = _mp_sample(**sample_args) + except pickle.PickleError: + pm._log.warn("Could not pickle model, sampling sequentially.") + parallel = False + if not parallel: + trace = _sample_many(**sample_args) discard = tune if discard_tuned_samples else 0 - - return sample_func(**sample_args)[discard:] + return trace[discard:] def _check_start_shape(model, start):