Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 48 additions & 10 deletions pints/_mcmc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,9 @@ def __init__(self, log_pdf, chains, x0, sigma0=None, method=None):
self._log_csv = False
self.set_log_interval()

# Storing chains and evaluations in memory
self._chains_in_memory = True

# Writing chains and evaluations to disk
self._chain_files = None
self._evaluation_files = None
Expand Down Expand Up @@ -524,9 +527,14 @@ def run(self):
sampler._log_init(logger)
logger.add_time('Time m:s')

# Create chains (pre-allocate)
samples = np.zeros(
(self._chains, self._max_iterations, self._n_parameters))
# Pre-allocate array for chains
if self._chains_in_memory:
# Store full chains
samples = np.zeros(
(self._chains, self._max_iterations, self._n_parameters))
else:
# Store only the current iteration
samples = np.zeros((self._chains, self._n_parameters))

# Some samplers need intermediate steps, where None is returned instead
# of a sample. Samplers can run asynchronously, so that one returns
Expand Down Expand Up @@ -565,19 +573,25 @@ def run(self):

# Update chains
if self._single_chain:
# Single chain

# Check and update the individual chains
xs_iterator = iter(xs)
fxs_iterator = iter(fxs)
for i in list(active): # (active may be modified)
for i in list(active): # new list: active may be modified
x = next(xs_iterator)
fx = next(fxs_iterator)
y = self._samplers[i].tell(fx)

if y is not None:
samples[i][n_samples[i]] = y
n_samples[i] += 1
# Store sample in memory
if self._chains_in_memory:
samples[i][n_samples[i]] = y
else:
samples[i] = y

# Stop adding samples if maximum number reached
n_samples[i] += 1
if n_samples[i] == self._max_iterations:
active.remove(i)

Expand All @@ -603,12 +617,18 @@ def run(self):
intermediate_step = min(n_samples) <= iteration

else:
# Multi-chain methods

# Get all chains samples at once
ys = self._samplers[0].tell(fxs)
intermediate_step = ys is None

if not intermediate_step:
samples[:, iteration] = ys
# Store samples in memory
if self._chains_in_memory:
samples[:, iteration] = ys
else:
samples = ys

# Write evaluations to disk
if self._evaluation_files:
Expand Down Expand Up @@ -636,8 +656,12 @@ def run(self):
continue

# Write samples to disk
for i, chain_logger in enumerate(chain_loggers):
chain_logger.log(*samples[i][iteration])
if self._chains_in_memory:
for i, chain_logger in enumerate(chain_loggers):
chain_logger.log(*samples[i][iteration])
else:
for i, chain_logger in enumerate(chain_loggers):
chain_logger.log(*samples[i])

# Show progress
if logging and iteration >= next_message:
Expand Down Expand Up @@ -674,7 +698,10 @@ def run(self):
print(halt_message)

# Return generated chains
return samples
if self._chains_in_memory:
return samples
else:
return None

def sampler(self):
"""
Expand Down Expand Up @@ -723,6 +750,17 @@ def set_chain_filename(self, chain_file):
b, e = os.path.splitext(str(chain_file))
self._chain_files = [b + '_' + str(i) + e for i in range(d)]

def set_chain_storage(self, store_in_memory=True):
    """
    Enable or disable in-memory storage of the generated chains.

    With ``store_in_memory`` set to a truthy value (the default), every
    generated sample is kept in memory and the full chains are returned by
    :meth:`run()`. With a falsy value only the most recent sample of each
    chain is held while running and :meth:`run()` returns ``None``; this
    can be useful for very large chains that are already being written to
    disk (see :meth:`set_chain_filename()`).
    """
    # Normalise any truthy/falsy argument to a strict boolean flag
    keep_in_memory = True if store_in_memory else False
    self._chains_in_memory = keep_in_memory

def set_log_pdf_filename(self, log_pdf_file):
"""
Write :class:`LogPDF` evaluations to disk as they are generated.
Expand Down
126 changes: 124 additions & 2 deletions pints/tests/test_mcmc_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ def setUpClass(cls):
cls.nchains = len(cls.xs)

def test_writing_chains_only(self):
""" Test writing chains - not evals to disk (using LogPosterior). """
""" Test writing chains - but not evals - to disk. """

mcmc = pints.MCMCController(self.log_posterior, self.nchains, self.xs)
mcmc.set_initial_phase_iterations(5)
Expand Down Expand Up @@ -534,6 +534,128 @@ def test_writing_chains_only(self):
self.assertNotIn('Writing evaluations to', text)
self.assertNotIn('evals_0.csv', text)

def test_writing_chains_only_no_memory_single(self):
    """
    Test writing chains - but not evals - to disk, without storing chains
    in memory, using a single-chain method.

    Checks that ``run()`` returns ``None`` when chain storage is disabled,
    that one chain file per chain is created (and no evaluation files),
    and that the chains read back from disk have the expected shape.
    """

    mcmc = pints.MCMCController(self.log_posterior, self.nchains, self.xs)
    mcmc.set_initial_phase_iterations(5)
    mcmc.set_max_iterations(20)
    mcmc.set_log_to_screen(True)
    mcmc.set_log_to_file(False)
    mcmc.set_chain_storage(False)

    with StreamCapture() as c:
        with TemporaryDirectory() as d:
            # Base names passed to the controller, and the per-chain
            # file names derived from them
            cpath = d.path('chain.csv')
            p0 = d.path('chain_0.csv')
            p1 = d.path('chain_1.csv')
            p2 = d.path('chain_2.csv')
            epath = d.path('evals.csv')
            p3 = d.path('evals_0.csv')
            p4 = d.path('evals_1.csv')
            p5 = d.path('evals_2.csv')

            # Test files aren't created before mcmc runs
            mcmc.set_chain_filename(cpath)
            mcmc.set_log_pdf_filename(None)
            self.assertFalse(os.path.exists(cpath))
            self.assertFalse(os.path.exists(epath))
            self.assertFalse(os.path.exists(p0))
            self.assertFalse(os.path.exists(p1))
            self.assertFalse(os.path.exists(p2))
            self.assertFalse(os.path.exists(p3))
            self.assertFalse(os.path.exists(p4))
            self.assertFalse(os.path.exists(p5))

            # Test files are created afterwards: only the per-chain files
            # should exist, never the base names or evaluation files
            chains1 = mcmc.run()
            self.assertFalse(os.path.exists(cpath))
            self.assertFalse(os.path.exists(epath))
            self.assertTrue(os.path.exists(p0))
            self.assertTrue(os.path.exists(p1))
            self.assertTrue(os.path.exists(p2))
            self.assertFalse(os.path.exists(p3))
            self.assertFalse(os.path.exists(p4))
            self.assertFalse(os.path.exists(p5))

            # Test chains weren't returned in memory
            self.assertIsNone(chains1)

            # Test disk contains chains
            import pints.io as io
            chains2 = np.array(io.load_samples(cpath, self.nchains))
            # The last dimension is the number of parameters, i.e. the
            # length of a single starting point. len(self.xs) would be the
            # number of chains instead.
            self.assertEqual(
                chains2.shape, (self.nchains, 20, len(self.xs[0])))

        text = c.text()
        self.assertIn('Writing chains to', text)
        self.assertIn('chain_0.csv', text)
        self.assertNotIn('Writing evaluations to', text)
        self.assertNotIn('evals_0.csv', text)

def test_writing_chains_only_no_memory_multi(self):
    """
    Test writing chains - but not evals - to disk, without storing chains
    in memory, using a multi-chain method.

    Checks that ``run()`` returns ``None`` when chain storage is disabled,
    that one chain file per chain is created (and no evaluation files),
    and that the chains read back from disk have the expected shape.
    """

    mcmc = pints.MCMCController(
        self.log_posterior, self.nchains, self.xs,
        method=pints.DifferentialEvolutionMCMC)
    mcmc.set_max_iterations(20)
    mcmc.set_log_to_screen(True)
    mcmc.set_log_to_file(False)
    mcmc.set_chain_storage(False)

    with StreamCapture() as c:
        with TemporaryDirectory() as d:
            # Base names passed to the controller, and the per-chain
            # file names derived from them
            cpath = d.path('chain.csv')
            p0 = d.path('chain_0.csv')
            p1 = d.path('chain_1.csv')
            p2 = d.path('chain_2.csv')
            epath = d.path('evals.csv')
            p3 = d.path('evals_0.csv')
            p4 = d.path('evals_1.csv')
            p5 = d.path('evals_2.csv')

            # Test files aren't created before mcmc runs
            mcmc.set_chain_filename(cpath)
            self.assertFalse(os.path.exists(cpath))
            self.assertFalse(os.path.exists(epath))
            self.assertFalse(os.path.exists(p0))
            self.assertFalse(os.path.exists(p1))
            self.assertFalse(os.path.exists(p2))
            self.assertFalse(os.path.exists(p3))
            self.assertFalse(os.path.exists(p4))
            self.assertFalse(os.path.exists(p5))

            # Test files are created afterwards: only the per-chain files
            # should exist, never the base names or evaluation files
            chains1 = mcmc.run()
            self.assertFalse(os.path.exists(cpath))
            self.assertFalse(os.path.exists(epath))
            self.assertTrue(os.path.exists(p0))
            self.assertTrue(os.path.exists(p1))
            self.assertTrue(os.path.exists(p2))
            self.assertFalse(os.path.exists(p3))
            self.assertFalse(os.path.exists(p4))
            self.assertFalse(os.path.exists(p5))

            # Test chains weren't returned in memory
            self.assertIsNone(chains1)

            # Test disk contains chains
            import pints.io as io
            chains2 = np.array(io.load_samples(cpath, self.nchains))
            # The last dimension is the number of parameters, i.e. the
            # length of a single starting point. len(self.xs) would be the
            # number of chains instead.
            self.assertEqual(
                chains2.shape, (self.nchains, 20, len(self.xs[0])))

        text = c.text()
        self.assertIn('Writing chains to', text)
        self.assertIn('chain_0.csv', text)
        # No evaluation file was configured, so none should be reported
        self.assertNotIn('Writing evaluations to', text)
        self.assertNotIn('evals_0.csv', text)

def test_writing_priors_and_likelihoods(self):
""" Test writing priors and loglikelihoods - not chains - to disk. """

Expand Down Expand Up @@ -867,7 +989,7 @@ def test_writing_chains_likelihoods_and_priors_one_chain(self):
self.assertIn('Writing evaluations to', text)
self.assertIn('evals_0.csv', text)

def test_disabling_storage(self):
def test_disabling_disk_storage(self):
""" Test if storage can be enabled and then disabled again. """
mcmc = pints.MCMCController(self.log_posterior, self.nchains, self.xs)
mcmc.set_initial_phase_iterations(5)
Expand Down