Skip to content

Commit

Permalink
Check load_trace argument (#3534)
Browse files Browse the repository at this point in the history
Check for errors from loading bad trace directory.
Add new exception class.
While doing this, added type annotations to ndarray.py
  • Loading branch information
rpgoldman committed Jun 28, 2019
1 parent 5ce6403 commit ff7d3cd
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 14 deletions.
37 changes: 24 additions & 13 deletions pymc3/backends/ndarray.py
Expand Up @@ -6,12 +6,16 @@
import json
import os
import shutil
from typing import Optional, Dict, Any

import numpy as np
from ..backends import base
from pymc3.backends import base
from pymc3.backends.base import MultiTrace
from pymc3.model import Model
from pymc3.exceptions import TraceDirectoryError


def save_trace(trace, directory=None, overwrite=False):
def save_trace(trace: MultiTrace, directory: Optional[str]=None, overwrite=False) -> str:
"""Save multitrace to file.
TODO: Also save warnings.
Expand Down Expand Up @@ -54,7 +58,7 @@ def save_trace(trace, directory=None, overwrite=False):
return directory


def load_trace(directory, model=None):
def load_trace(directory: str, model=None) -> MultiTrace:
"""Loads a multitrace that has been written to file.
The model used for the trace must be passed in, or the command
Expand All @@ -72,17 +76,21 @@ def load_trace(directory, model=None):
pm.MultiTrace that was saved in the directory
"""
straces = []
for directory in glob.glob(os.path.join(directory, '*')):
if os.path.isdir(directory):
straces.append(SerializeNDArray(directory).load(model))
for subdir in glob.glob(os.path.join(directory, '*')):
if os.path.isdir(subdir):
straces.append(SerializeNDArray(subdir).load(model))
if not straces:
raise TraceDirectoryError("%s is not a PyMC3 saved chain directory." % directory)
return base.MultiTrace(straces)


class SerializeNDArray:
metadata_file = 'metadata.json'
samples_file = 'samples.npz'
metadata_path = None # type: str
samples_path = None # type: str

def __init__(self, directory):
def __init__(self, directory: str):
"""Helper to save and load NDArray objects"""
self.directory = directory
self.metadata_path = os.path.join(self.directory, self.metadata_file)
Expand Down Expand Up @@ -126,8 +134,11 @@ def save(self, ndarray):

np.savez_compressed(self.samples_path, **ndarray.samples)

def load(self, model):
def load(self, model: Model) -> 'NDArray':
"""Load the saved ndarray from file"""
if not os.path.exists(self.samples_path) or not os.path.exists(self.metadata_path):
raise TraceDirectoryError("%s is not a trace directory" % self.directory)

new_trace = NDArray(model=model)
with open(self.metadata_path, 'r') as buff:
metadata = json.load(buff)
Expand Down Expand Up @@ -165,7 +176,7 @@ def __init__(self, name=None, model=None, vars=None, test_point=None):

# Sampling methods

def setup(self, draws, chain, sampler_vars=None):
def setup(self, draws, chain, sampler_vars=None) -> None:
"""Perform chain-specific setup.
Parameters
Expand Down Expand Up @@ -204,7 +215,7 @@ def setup(self, draws, chain, sampler_vars=None):
if self._stats is None:
self._stats = []
for sampler in sampler_vars:
data = dict()
data = dict() # type: Dict[str, np.ndarray]
self._stats.append(data)
for varname, dtype in sampler.items():
data[varname] = np.zeros(draws, dtype=dtype)
Expand All @@ -218,7 +229,7 @@ def setup(self, draws, chain, sampler_vars=None):
new = np.zeros(draws, dtype=dtype)
data[varname] = np.concatenate([old, new])

def record(self, point, sampler_stats=None):
def record(self, point, sampler_stats=None) -> None:
"""Record results of a sampling iteration.
Parameters
Expand Down Expand Up @@ -261,7 +272,7 @@ def __len__(self):
return 0
return self.draw_idx

def get_values(self, varname, burn=0, thin=1):
def get_values(self, varname: str, burn=0, thin=1) -> np.ndarray:
"""Get values from trace.
Parameters
Expand Down Expand Up @@ -302,7 +313,7 @@ def _slice(self, idx):

return sliced

def point(self, idx):
def point(self, idx) -> Dict[str, Any]:
"""Return dictionary of point values at `idx` for current chain
with variable names as keys.
"""
Expand Down
6 changes: 5 additions & 1 deletion pymc3/exceptions.py
@@ -1,4 +1,4 @@
__all__ = ['SamplingError', 'IncorrectArgumentsError']
__all__ = ['SamplingError', 'IncorrectArgumentsError', 'TraceDirectoryError']


class SamplingError(RuntimeError):
Expand All @@ -7,3 +7,7 @@ class SamplingError(RuntimeError):

class IncorrectArgumentsError(ValueError):
    """ValueError subclass exported by this module.

    NOTE(review): presumably raised for invalid argument combinations;
    confirm against the call sites that raise it.
    """
    pass

class TraceDirectoryError(ValueError):
    """Error from trying to load a trace from an incorrectly-structured directory."""
5 changes: 5 additions & 0 deletions pymc3/tests/test_ndarray_backend.py
Expand Up @@ -213,6 +213,11 @@ def test_save_and_load(self, tmpdir_factory):
for var in ('x', 'z'):
assert (self.trace[var] == trace2[var]).all()

def test_bad_load(self, tmpdir_factory):
    """Check that load_trace raises TraceDirectoryError for a directory with no saved trace."""
    # Nothing is written into this directory before loading, so load_trace
    # finds no chain subdirectories to deserialize.
    directory = str(tmpdir_factory.mktemp('data'))
    with pytest.raises(pm.TraceDirectoryError):
        pm.load_trace(directory, model=TestSaveLoad.model())

def test_sample_posterior_predictive(self, tmpdir_factory):
directory = str(tmpdir_factory.mktemp('data'))
save_dir = pm.save_trace(self.trace, directory, overwrite=True)
Expand Down

0 comments on commit ff7d3cd

Please sign in to comment.