Merge b57c144 into 3cb4570
springcoil committed Oct 29, 2018
2 parents 3cb4570 + b57c144 commit 5451a31
Showing 8 changed files with 986 additions and 333 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
@@ -9,6 +9,7 @@
- Add Incomplete Beta function `incomplete_beta(a, b, value)`
- Add log CDF functions to continuous distributions: `Beta`, `Cauchy`, `ExGaussian`, `Exponential`, `Flat`, `Gumbel`, `HalfCauchy`, `HalfFlat`, `HalfNormal`, `Laplace`, `Logistic`, `Lognormal`, `Normal`, `Pareto`, `StudentT`, `Triangular`, `Uniform`, `Wald`, `Weibull`.
- Behavior of `sample_posterior_predictive` is now to produce posterior predictive samples, in order, from all values of the `trace`. Previously, by default it would produce 1 chain worth of samples, using a random selection from the `trace` (#3212)
- Show diagnostics for initial energy errors in HMC and NUTS.

### Maintenance

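A brief aside on the release-note entries above: the new log CDF methods hang off the distribution objects, and the incomplete beta helper lives with the distribution math utilities. A minimal sketch, assuming PyMC3 3.6-era APIs and a working Theano backend; the import path for `incomplete_beta` is an assumption based on the release note:

    import pymc3 as pm
    from pymc3.distributions.dist_math import incomplete_beta  # assumed location

    # logcdf returns a symbolic Theano expression; .eval() gives a numeric value.
    normal = pm.Normal.dist(mu=0.0, sd=1.0)
    print(normal.logcdf(0.0).eval())  # log CDF of a standard normal at 0: log(0.5)

    # Regularized incomplete beta I_x(a, b), also symbolic.
    print(incomplete_beta(2.0, 3.0, 0.5).eval())
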
1 change: 1 addition & 0 deletions pymc3/backends/report.py
@@ -19,6 +19,7 @@ class WarningType(enum.Enum):
# Indications that chains did not converge, eg Rhat
CONVERGENCE = 6
BAD_ACCEPTANCE = 7
BAD_ENERGY = 8


SamplerWarning = namedtuple(
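The new `BAD_ENERGY` member gives step methods a dedicated warning kind for non-finite initial energy. A hedged sketch of constructing such a warning; the `SamplerWarning` field order (kind, message, level, step, exec_info, extra) is assumed from the surrounding module and may differ:

    from pymc3.backends.report import SamplerWarning, WarningType

    # Warning a step method might emit when the starting point has -inf/NaN logp
    # (field order assumed; only `kind` and `message` matter for this sketch).
    warning = SamplerWarning(
        WarningType.BAD_ENERGY,
        "Bad initial energy, check any log probabilities that are inf or -inf.",
        "critical",  # level
        None,        # step
        None,        # exec_info
        None,        # extra
    )
    print(warning.kind)
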
131 changes: 85 additions & 46 deletions pymc3/parallel_sampling.py
@@ -5,13 +5,23 @@
import logging
from collections import namedtuple
import traceback
from pymc3.exceptions import SamplingError

import six
import numpy as np

from . import theanof

logger = logging.getLogger('pymc3')
logger = logging.getLogger("pymc3")


class ParallelSamplingError(Exception):
def __init__(self, message, chain, warnings=None):
super(ParallelSamplingError, self).__init__(message)
if warnings is None:
warnings = []
self._chain = chain
self._warnings = warnings


# Taken from https://hg.python.org/cpython/rev/c4f92b597074
@@ -26,7 +36,7 @@ def __str__(self):
class ExceptionWithTraceback:
def __init__(self, exc, tb):
tb = traceback.format_exception(type(exc), exc, tb)
tb = ''.join(tb)
tb = "".join(tb)
self.exc = exc
self.tb = '\n"""\n%s"""' % tb

@@ -40,8 +50,8 @@ def rebuild_exc(exc, tb):


# Messages
# ('writing_done', is_last, sample_idx, tuning, stats)
# ('error', *exception_info)
# ('writing_done', is_last, sample_idx, tuning, stats, warns)
# ('error', warnings, *exception_info)

# ('abort', reason)
# ('write_next',)
@@ -50,12 +60,11 @@ def rebuild_exc(exc, tb):

class _Process(multiprocessing.Process):
"""Seperate process for each chain.
We communicate with the main process using a pipe,
and send finished samples using shared memory.
"""
def __init__(self, name, msg_pipe, step_method, shared_point,
draws, tune, seed):

def __init__(self, name, msg_pipe, step_method, shared_point, draws, tune, seed):
super(_Process, self).__init__(daemon=True, name=name)
self._msg_pipe = msg_pipe
self._step_method = step_method
@@ -75,7 +84,7 @@ def run(self):
pass
except BaseException as e:
e = ExceptionWithTraceback(e, e.__traceback__)
self._msg_pipe.send(('error', e))
self._msg_pipe.send(("error", None, e))
finally:
self._msg_pipe.close()

@@ -103,14 +112,19 @@ def _start_loop(self):
tuning = True

msg = self._recv_msg()
if msg[0] == 'abort':
if msg[0] == "abort":
raise KeyboardInterrupt()
if msg[0] != 'start':
raise ValueError('Unexpected msg ' + msg[0])
if msg[0] != "start":
raise ValueError("Unexpected msg " + msg[0])

while True:
if draw < self._draws + self._tune:
point, stats = self._compute_point()
try:
point, stats = self._compute_point()
except SamplingError as e:
warns = self._collect_warnings()
e = ExceptionWithTraceback(e, e.__traceback__)
self._msg_pipe.send(("error", warns, e))
else:
return

@@ -119,20 +133,21 @@ def _start_loop(self):
tuning = False

msg = self._recv_msg()
if msg[0] == 'abort':
if msg[0] == "abort":
raise KeyboardInterrupt()
elif msg[0] == 'write_next':
elif msg[0] == "write_next":
self._write_point(point)
is_last = draw + 1 == self._draws + self._tune
if is_last:
warns = self._collect_warnings()
else:
warns = None
self._msg_pipe.send(
('writing_done', is_last, draw, tuning, stats, warns))
("writing_done", is_last, draw, tuning, stats, warns)
)
draw += 1
else:
raise ValueError('Unknown message ' + msg[0])
raise ValueError("Unknown message " + msg[0])

def _compute_point(self):
if self._step_method.generates_stats:
@@ -143,14 +158,15 @@ def _compute_point(self):
return point, stats

def _collect_warnings(self):
if hasattr(self._step_method, 'warnings'):
if hasattr(self._step_method, "warnings"):
return self._step_method.warnings()
else:
return []


class ProcessAdapter(object):
"""Control a Chain process from the main thread."""

def __init__(self, draws, tune, step_method, chain, seed, start):
self.chain = chain
process_name = "worker_chain_%s" % chain
@@ -164,9 +180,9 @@ def __init__(self, draws, tune, step_method, chain, seed, start):
size *= int(dim)
size *= dtype.itemsize
if size != ctypes.c_size_t(size).value:
raise ValueError('Variable %s is too large' % name)
raise ValueError("Variable %s is too large" % name)

array = multiprocessing.sharedctypes.RawArray('c', size)
array = multiprocessing.sharedctypes.RawArray("c", size)
self._shared_point[name] = array
array_np = np.frombuffer(array, dtype).reshape(shape)
array_np[...] = start[name]
@@ -176,8 +192,14 @@ def __init__(self, draws, tune, step_method, chain, seed, start):
self._num_samples = 0

self._process = _Process(
process_name, remote_conn, step_method, self._shared_point,
draws, tune, seed)
process_name,
remote_conn,
step_method,
self._shared_point,
draws,
tune,
seed,
)
# We fork right away, so that the main process can start tqdm threads
self._process.start()

@@ -191,14 +213,14 @@ def shared_point_view(self):
return self._point

def start(self):
self._msg_pipe.send(('start',))
self._msg_pipe.send(("start",))

def write_next(self):
self._readable = False
self._msg_pipe.send(('write_next',))
self._msg_pipe.send(("write_next",))

def abort(self):
self._msg_pipe.send(('abort',))
self._msg_pipe.send(("abort",))

def join(self, timeout=None):
self._process.join(timeout)
@@ -209,24 +231,28 @@ def terminate(self):
@staticmethod
def recv_draw(processes, timeout=3600):
if not processes:
raise ValueError('No processes.')
raise ValueError("No processes.")
pipes = [proc._msg_pipe for proc in processes]
ready = multiprocessing.connection.wait(pipes)
if not ready:
raise multiprocessing.TimeoutError('No message from samplers.')
raise multiprocessing.TimeoutError("No message from samplers.")
idxs = {id(proc._msg_pipe): proc for proc in processes}
proc = idxs[id(ready[0])]
msg = ready[0].recv()

if msg[0] == 'error':
old = msg[1]
six.raise_from(RuntimeError('Chain %s failed.' % proc.chain), old)
elif msg[0] == 'writing_done':
if msg[0] == "error":
warns, old_error = msg[1:]
if warns is not None:
error = ParallelSamplingError(str(old_error), proc.chain, warns)
else:
error = RuntimeError("Chain %s failed." % proc.chain)
six.raise_from(error, old_error)
elif msg[0] == "writing_done":
proc._readable = True
proc._num_samples += 1
return (proc,) + msg[1:]
else:
raise ValueError('Sampler sent bad message.')
raise ValueError("Sampler sent bad message.")

@staticmethod
def terminate_all(processes, patience=2):
@@ -244,34 +270,46 @@ def terminate_all(processes, patience=2):
raise multiprocessing.TimeoutError()
process.join(timeout)
except multiprocessing.TimeoutError:
logger.warn('Chain processes did not terminate as expected. '
'Terminating forcefully...')
logger.warn(
"Chain processes did not terminate as expected. "
"Terminating forcefully..."
)
for process in processes:
process.terminate()
for process in processes:
process.join()


Draw = namedtuple(
'Draw',
['chain', 'is_last', 'draw_idx', 'tuning', 'stats', 'point', 'warnings']
"Draw", ["chain", "is_last", "draw_idx", "tuning", "stats", "point", "warnings"]
)


class ParallelSampler(object):
def __init__(self, draws, tune, chains, cores, seeds, start_points,
step_method, start_chain_num=0, progressbar=True):
def __init__(
self,
draws,
tune,
chains,
cores,
seeds,
start_points,
step_method,
start_chain_num=0,
progressbar=True,
):
if progressbar:
import tqdm

tqdm_ = tqdm.tqdm

if any(len(arg) != chains for arg in [seeds, start_points]):
raise ValueError(
'Number of seeds and start_points must be %s.' % chains)
raise ValueError("Number of seeds and start_points must be %s." % chains)

self._samplers = [
ProcessAdapter(draws, tune, step_method,
chain + start_chain_num, seed, start)
ProcessAdapter(
draws, tune, step_method, chain + start_chain_num, seed, start
)
for chain, seed, start in zip(range(chains), seeds, start_points)
]

@@ -286,8 +324,10 @@ def __init__(self, draws, tune, chains, cores, seeds, start_points,
self._progress = None
if progressbar:
self._progress = tqdm_(
total=chains * (draws + tune), unit='draws',
desc='Sampling %s chains' % chains)
total=chains * (draws + tune),
unit="draws",
desc="Sampling %s chains" % chains,
)

def _make_active(self):
while self._inactive and len(self._active) < self._max_active:
@@ -298,7 +338,7 @@ def _make_active(self):

def __iter__(self):
if not self._in_context:
raise ValueError('Use ParallelSampler as context manager.')
raise ValueError("Use ParallelSampler as context manager.")
self._make_active()

while self._active:
@@ -317,8 +357,7 @@ def __iter__(self):
# and only call proc.write_next() after the yield returns.
            # This seems to be faster overall though, as the worker
# loses less time waiting.
point = {name: val.copy()
for name, val in proc.shared_point_view.items()}
point = {name: val.copy() for name, val in proc.shared_point_view.items()}

# Already called for new proc in _make_active
if not is_last:
Expand Down
35 changes: 23 additions & 12 deletions pymc3/sampling.py
@@ -986,17 +986,28 @@ def _mp_sample(draws, tune, step, chains, cores, chain, random_seed,
draws, tune, chains, cores, random_seed, start, step,
chain, progressbar)
try:
with sampler:
for draw in sampler:
trace = traces[draw.chain - chain]
if trace.supports_sampler_stats and draw.stats is not None:
trace.record(draw.point, draw.stats)
else:
trace.record(draw.point)
if draw.is_last:
trace.close()
if draw.warnings is not None:
trace._add_warnings(draw.warnings)
try:
with sampler:
for draw in sampler:
trace = traces[draw.chain - chain]
if (trace.supports_sampler_stats
and draw.stats is not None):
trace.record(draw.point, draw.stats)
else:
trace.record(draw.point)
if draw.is_last:
trace.close()
if draw.warnings is not None:
trace._add_warnings(draw.warnings)
except ps.ParallelSamplingError as error:
trace = traces[error._chain - chain]
trace._add_warnings(error._warnings)
for trace in traces:
trace.close()

multitrace = MultiTrace(traces)
multitrace._report._log_summary()
raise
return MultiTrace(traces)
except KeyboardInterrupt:
traces, length = _choose_chains(traces, tune)
@@ -1512,4 +1523,4 @@ def init_nuts(init='auto', chains=1, n_init=500000, model=None,

step = pm.NUTS(potential=potential, model=model, **kwargs)

return start, step
return start, step
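
Taken together with the sampling.py changes, a chain that dies in parallel sampling now attaches its collected warnings to the corresponding trace and logs a diagnostic summary before the error is re-raised, so a "Bad initial energy" failure points at the offending log-probabilities. A hedged end-to-end sketch; the model is deliberately broken (an infinite observation forces a -inf log-probability), and the exact error text and behaviour depend on the PyMC3 version:

    import numpy as np
    import pymc3 as pm

    if __name__ == "__main__":
        with pm.Model():
            mu = pm.Normal("mu", mu=0.0, sd=1.0)
            # An infinite observation makes the initial log-probability -inf,
            # which triggers the bad-energy path in HMC/NUTS.
            pm.Normal("obs", mu=mu, sd=1.0, observed=np.array([0.5, np.inf]))
            try:
                pm.sample(chains=2, cores=2)
            except Exception as err:
                # With this commit the logged summary includes the BAD_ENERGY
                # warnings gathered from the failed chain before re-raising.
                print(type(err).__name__, err)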
