In [1]:
import os
import time
import numpy as np
import emcee
from multiprocessing import Pool

In [2]:
# Presumably meant to stop OpenMP-backed numerical libraries from spawning their
# own threads, so the serial-vs-pool timing comparison below is not confounded.
# NOTE(review): numpy is already imported in an earlier cell; many BLAS/OpenMP
# runtimes read this variable only when first loaded, so it may not affect this
# kernel's process — consider setting it before `import numpy`.
os.environ.update(OMP_NUM_THREADS="1")

In [3]:
def log_prob(theta):
    """Isotropic Gaussian log-density (up to an additive constant), made slow on purpose.

    Before returning, the function spins in a busy loop for a random 5-8 ms.
    This makes each evaluation expensive, so wall-clock timing is dominated by
    log-probability calls.

    Parameters
    ----------
    theta : numpy.ndarray
        Parameter vector of a single walker.

    Returns
    -------
    float
        ``-0.5 * sum(theta**2)``.
    """
    deadline = time.time() + np.random.uniform(0.005, 0.008)
    # Deliberate busy-wait: burns CPU instead of sleeping, mimicking real work.
    while time.time() < deadline:
        pass
    return -0.5 * np.sum(theta ** 2)

In [4]:
# Reproducible starting positions: 32 walkers in a 5-dimensional parameter space.
np.random.seed(42)
initial = np.random.randn(32, 5)
nwalkers, ndim = initial.shape
nsteps = 100

# Baseline: run the sampler in the current process only.
sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob)
# perf_counter is monotonic and high-resolution, unlike time.time(), which can
# jump if the system clock is adjusted; it is the right clock for durations.
start = time.perf_counter()
sampler.run_mcmc(initial, nsteps, progress=True)
end = time.perf_counter()
serial_time = end - start
print("Serial took {0:.1f} seconds.".format(serial_time))

100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:22<00:00,  4.46it/s]

Serial took 22.7 seconds:





In [None]:
# Same run as the serial baseline, but with log-probability evaluations
# dispatched to a pool of worker processes.
# NOTE(review): `log_prob` is defined in this notebook's __main__; where the
# default start method is "spawn" (e.g. Windows, macOS), workers may fail to
# unpickle it — confirm, or move it into an importable module.
with Pool() as pool:
    sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, pool=pool)
    start = time.perf_counter()
    sampler.run_mcmc(initial, nsteps, progress=True)
    end = time.perf_counter()
    # Elapsed wall time of the parallel run (a duration, not a timestamp).
    multi_time = end - start
    print("Multiprocessing took {0:.1f} seconds.".format(multi_time))
    print("{0:.1f} times faster than serial".format(serial_time / multi_time))