Reference: emcee parallelization tutorial — https://emcee.readthedocs.io/en/stable/tutorials/parallel/

In [1]:
import os
# Set before NumPy is imported (next cell). Presumably this keeps OpenMP-backed
# numerical libraries single-threaded so they do not compete with the
# process-level parallelism benchmarked further down — TODO confirm against
# the emcee parallelization tutorial.
os.environ["OMP_NUM_THREADS"] = "1"

In [2]:
import time
import numpy as np


def log_prob(theta):
    """Return ``-0.5 * sum(theta**2)`` after an artificial 5-8 ms busy-wait.

    The returned value is the log-density of an isotropic unit Gaussian up to
    an additive constant.

    Parameters
    ----------
    theta : array_like
        Point at which to evaluate the log-probability. Lists, tuples and
        scalars are accepted as well as NumPy arrays.

    Returns
    -------
    float
        ``-0.5 * np.sum(theta ** 2)``.

    Notes
    -----
    Each call draws the delay from NumPy's global random state, so calling
    this function advances that state.
    """
    # Spin instead of sleeping so the call keeps a CPU core busy for the whole
    # delay (presumably to emulate an expensive likelihood — confirm intent).
    # perf_counter() is monotonic, so a wall-clock adjustment cannot shorten
    # or stretch the wait the way time.time() could.
    deadline = time.perf_counter() + np.random.uniform(0.005, 0.008)
    while time.perf_counter() < deadline:
        pass
    # asarray lets plain Python sequences work; ndarrays pass through as-is.
    return -0.5 * np.sum(np.asarray(theta) ** 2)

In [3]:
import emcee

# Seed NumPy's global RNG so the starting positions are reproducible.
np.random.seed(42)
initial = np.random.randn(32, 5)  # one row per walker, one column per dimension
nwalkers, ndim = initial.shape
nsteps = 100

# Baseline: no pool is passed, so this is the single-process reference run.
sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob)
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# elapsed time; time.time() can jump if the system clock is adjusted.
start = time.perf_counter()
sampler.run_mcmc(initial, nsteps, progress=True)
end = time.perf_counter()
# serial_time is read by the later cells to report their relative speed-up.
serial_time = end - start
print("Serial took {0:.1f} seconds".format(serial_time))

You must install the tqdm library to use progress indicators with emcee


Serial took 21.7 seconds


In [4]:
from multiprocessing import Pool

# The pool is handed to the sampler via `pool=`; the context manager shuts the
# worker processes down when the block exits.
with Pool() as pool:
    sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, pool=pool)
    # perf_counter() is a monotonic, high-resolution clock intended for
    # measuring elapsed time; time.time() can jump if the system clock changes.
    # Only the sampling itself is timed, not the pool start-up above.
    start = time.perf_counter()
    sampler.run_mcmc(initial, nsteps, progress=True)
    end = time.perf_counter()
    multi_time = end - start
    print("Multiprocessing took {0:.1f} seconds".format(multi_time))
    # serial_time comes from the serial benchmark cell, which must run first.
    print("{0:.1f} times faster than serial".format(serial_time / multi_time))

You must install the tqdm library to use progress indicators with emcee


Multiprocessing took 20.4 seconds
1.1 times faster than serial


In [None]:
from executorlib import Executor

# Same benchmark as the multiprocessing cell, but with an executorlib Executor
# (default settings) supplied as the sampler's pool.
with Executor() as exe:
    sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, pool=exe)
    # perf_counter() is a monotonic, high-resolution clock intended for
    # measuring elapsed time; time.time() can jump if the system clock changes.
    start = time.perf_counter()
    sampler.run_mcmc(initial, nsteps, progress=True)
    end = time.perf_counter()
    # NOTE: this reuses (and overwrites) the name `multi_time` that the
    # multiprocessing cell also assigns; kept so existing readers of that
    # name keep working.
    multi_time = end - start
    # Label names the backend actually used here, so the output is not
    # mistaken for the multiprocessing.Pool result.
    print("executorlib took {0:.1f} seconds".format(multi_time))
    # serial_time comes from the serial benchmark cell, which must run first.
    print("{0:.1f} times faster than serial".format(serial_time / multi_time))

You must install the tqdm library to use progress indicators with emcee


In [None]:
from executorlib import Executor

# Same benchmark with explicit executor settings. Presumably this requests a
# fixed set of 8 local worker processes that are reused across submissions
# (block_allocation=True) — TODO confirm against the executorlib docs for the
# installed version.
with Executor(max_workers=8, backend="local", block_allocation=True) as exe:
    sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, pool=exe)
    # perf_counter() is a monotonic, high-resolution clock intended for
    # measuring elapsed time; time.time() can jump if the system clock changes.
    start = time.perf_counter()
    sampler.run_mcmc(initial, nsteps, progress=True)
    end = time.perf_counter()
    # NOTE: this reuses (and overwrites) the name `multi_time` that the
    # multiprocessing cell also assigns; kept so existing readers of that
    # name keep working.
    multi_time = end - start
    # Label names the backend and settings actually used here, so the output
    # is not mistaken for the multiprocessing.Pool result.
    print("executorlib (8 workers, block allocation) took {0:.1f} seconds".format(multi_time))
    # serial_time comes from the serial benchmark cell, which must run first.
    print("{0:.1f} times faster than serial".format(serial_time / multi_time))