diff --git a/pymc3/backends/ndarray.py b/pymc3/backends/ndarray.py index 4f173acee7..7a523175df 100644 --- a/pymc3/backends/ndarray.py +++ b/pymc3/backends/ndarray.py @@ -6,12 +6,16 @@ import json import os import shutil +from typing import Optional, Dict, Any import numpy as np -from ..backends import base +from pymc3.backends import base +from pymc3.backends.base import MultiTrace +from pymc3.model import Model +from pymc3.exceptions import TraceDirectoryError -def save_trace(trace, directory=None, overwrite=False): +def save_trace(trace: MultiTrace, directory: Optional[str]=None, overwrite=False) -> str: """Save multitrace to file. TODO: Also save warnings. @@ -54,7 +58,7 @@ def save_trace(trace, directory=None, overwrite=False): return directory -def load_trace(directory, model=None): +def load_trace(directory: str, model=None) -> MultiTrace: """Loads a multitrace that has been written to file. A the model used for the trace must be passed in, or the command @@ -72,17 +76,21 @@ def load_trace(directory, model=None): pm.Multitrace that was saved in the directory """ straces = [] - for directory in glob.glob(os.path.join(directory, '*')): - if os.path.isdir(directory): - straces.append(SerializeNDArray(directory).load(model)) + for subdir in glob.glob(os.path.join(directory, '*')): + if os.path.isdir(subdir): + straces.append(SerializeNDArray(subdir).load(model)) + if not straces: + raise TraceDirectoryError("%s is not a PyMC3 saved chain directory." % directory) return base.MultiTrace(straces) class SerializeNDArray: metadata_file = 'metadata.json' samples_file = 'samples.npz' + metadata_path = None # type: str + samples_path = None # type: str - def __init__(self, directory): + def __init__(self, directory: str): """Helper to save and load NDArray objects""" self.directory = directory self.metadata_path = os.path.join(self.directory, self.metadata_file) @@ -126,8 +134,11 @@ def save(self, ndarray): np.savez_compressed(self.samples_path, **ndarray.samples) - def load(self, model): + def load(self, model: Model) -> 'NDArray': """Load the saved ndarray from file""" + if not os.path.exists(self.samples_path) or not os.path.exists(self.metadata_path): + raise TraceDirectoryError("%s is not a trace directory" % self.directory) + new_trace = NDArray(model=model) with open(self.metadata_path, 'r') as buff: metadata = json.load(buff) @@ -165,7 +176,7 @@ def __init__(self, name=None, model=None, vars=None, test_point=None): # Sampling methods - def setup(self, draws, chain, sampler_vars=None): + def setup(self, draws, chain, sampler_vars=None) -> None: """Perform chain-specific setup. Parameters @@ -204,7 +215,7 @@ def setup(self, draws, chain, sampler_vars=None): if self._stats is None: self._stats = [] for sampler in sampler_vars: - data = dict() + data = dict() # type: Dict[str, np.ndarray] self._stats.append(data) for varname, dtype in sampler.items(): data[varname] = np.zeros(draws, dtype=dtype) @@ -218,7 +229,7 @@ def setup(self, draws, chain, sampler_vars=None): new = np.zeros(draws, dtype=dtype) data[varname] = np.concatenate([old, new]) - def record(self, point, sampler_stats=None): + def record(self, point, sampler_stats=None) -> None: """Record results of a sampling iteration. Parameters @@ -261,7 +272,7 @@ def __len__(self): return 0 return self.draw_idx - def get_values(self, varname, burn=0, thin=1): + def get_values(self, varname: str, burn=0, thin=1) -> np.ndarray: """Get values from trace. Parameters @@ -302,7 +313,7 @@ def _slice(self, idx): return sliced - def point(self, idx): + def point(self, idx) -> Dict[str, Any]: """Return dictionary of point values at `idx` for current chain with variable names as keys. """ diff --git a/pymc3/exceptions.py b/pymc3/exceptions.py index 62d275e2b9..b2ff9f0c52 100644 --- a/pymc3/exceptions.py +++ b/pymc3/exceptions.py @@ -1,4 +1,4 @@ -__all__ = ['SamplingError', 'IncorrectArgumentsError'] +__all__ = ['SamplingError', 'IncorrectArgumentsError', 'TraceDirectoryError'] class SamplingError(RuntimeError): @@ -7,3 +7,7 @@ class SamplingError(RuntimeError): class IncorrectArgumentsError(ValueError): pass + +class TraceDirectoryError(ValueError): + '''Error from trying to load a trace from an incorrectly-structured directory,''' + pass diff --git a/pymc3/tests/test_ndarray_backend.py b/pymc3/tests/test_ndarray_backend.py index be438c706b..b5ea2daa80 100644 --- a/pymc3/tests/test_ndarray_backend.py +++ b/pymc3/tests/test_ndarray_backend.py @@ -213,6 +213,11 @@ def test_save_and_load(self, tmpdir_factory): for var in ('x', 'z'): assert (self.trace[var] == trace2[var]).all() + def test_bad_load(self, tmpdir_factory): + directory = str(tmpdir_factory.mktemp('data')) + with pytest.raises(pm.TraceDirectoryError): + pm.load_trace(directory, model=TestSaveLoad.model()) + def test_sample_posterior_predictive(self, tmpdir_factory): directory = str(tmpdir_factory.mktemp('data')) save_dir = pm.save_trace(self.trace, directory, overwrite=True)