Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check load_trace argument #3534

Merged
merged 3 commits into from
Jun 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
37 changes: 24 additions & 13 deletions pymc3/backends/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@
import json
import os
import shutil
from typing import Optional, Dict, Any

import numpy as np
from ..backends import base
from pymc3.backends import base
from pymc3.backends.base import MultiTrace
from pymc3.model import Model
from pymc3.exceptions import TraceDirectoryError


def save_trace(trace, directory=None, overwrite=False):
def save_trace(trace: MultiTrace, directory: Optional[str]=None, overwrite=False) -> str:
"""Save multitrace to file.

TODO: Also save warnings.
Expand Down Expand Up @@ -54,7 +58,7 @@ def save_trace(trace, directory=None, overwrite=False):
return directory


def load_trace(directory, model=None):
def load_trace(directory: str, model=None) -> MultiTrace:
"""Loads a multitrace that has been written to file.

The model used for the trace must be passed in, or the command
Expand All @@ -72,17 +76,21 @@ def load_trace(directory, model=None):
pm.Multitrace that was saved in the directory
"""
straces = []
for directory in glob.glob(os.path.join(directory, '*')):
if os.path.isdir(directory):
straces.append(SerializeNDArray(directory).load(model))
for subdir in glob.glob(os.path.join(directory, '*')):
if os.path.isdir(subdir):
straces.append(SerializeNDArray(subdir).load(model))
if not straces:
raise TraceDirectoryError("%s is not a PyMC3 saved chain directory." % directory)
return base.MultiTrace(straces)


class SerializeNDArray:
metadata_file = 'metadata.json'
samples_file = 'samples.npz'
metadata_path = None # type: str
samples_path = None # type: str

def __init__(self, directory):
def __init__(self, directory: str):
"""Helper to save and load NDArray objects"""
self.directory = directory
self.metadata_path = os.path.join(self.directory, self.metadata_file)
Expand Down Expand Up @@ -126,8 +134,11 @@ def save(self, ndarray):

np.savez_compressed(self.samples_path, **ndarray.samples)

def load(self, model):
def load(self, model: Model) -> 'NDArray':
"""Load the saved ndarray from file"""
if not os.path.exists(self.samples_path) or not os.path.exists(self.metadata_path):
raise TraceDirectoryError("%s is not a trace directory" % self.directory)

new_trace = NDArray(model=model)
with open(self.metadata_path, 'r') as buff:
metadata = json.load(buff)
Expand Down Expand Up @@ -165,7 +176,7 @@ def __init__(self, name=None, model=None, vars=None, test_point=None):

# Sampling methods

def setup(self, draws, chain, sampler_vars=None):
def setup(self, draws, chain, sampler_vars=None) -> None:
"""Perform chain-specific setup.

Parameters
Expand Down Expand Up @@ -204,7 +215,7 @@ def setup(self, draws, chain, sampler_vars=None):
if self._stats is None:
self._stats = []
for sampler in sampler_vars:
data = dict()
data = dict() # type: Dict[str, np.ndarray]
self._stats.append(data)
for varname, dtype in sampler.items():
data[varname] = np.zeros(draws, dtype=dtype)
Expand All @@ -218,7 +229,7 @@ def setup(self, draws, chain, sampler_vars=None):
new = np.zeros(draws, dtype=dtype)
data[varname] = np.concatenate([old, new])

def record(self, point, sampler_stats=None):
def record(self, point, sampler_stats=None) -> None:
"""Record results of a sampling iteration.

Parameters
Expand Down Expand Up @@ -261,7 +272,7 @@ def __len__(self):
return 0
return self.draw_idx

def get_values(self, varname, burn=0, thin=1):
def get_values(self, varname: str, burn=0, thin=1) -> np.ndarray:
"""Get values from trace.

Parameters
Expand Down Expand Up @@ -302,7 +313,7 @@ def _slice(self, idx):

return sliced

def point(self, idx):
def point(self, idx) -> Dict[str, Any]:
"""Return dictionary of point values at `idx` for current chain
with variable names as keys.
"""
Expand Down
6 changes: 5 additions & 1 deletion pymc3/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__all__ = ['SamplingError', 'IncorrectArgumentsError']
__all__ = ['SamplingError', 'IncorrectArgumentsError', 'TraceDirectoryError']


class SamplingError(RuntimeError):
Expand All @@ -7,3 +7,7 @@ class SamplingError(RuntimeError):

class IncorrectArgumentsError(ValueError):
pass

class TraceDirectoryError(ValueError):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. I did not know this file existed!

'''Error from trying to load a trace from an incorrectly-structured directory.'''
pass
5 changes: 5 additions & 0 deletions pymc3/tests/test_ndarray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,11 @@ def test_save_and_load(self, tmpdir_factory):
for var in ('x', 'z'):
assert (self.trace[var] == trace2[var]).all()

def test_bad_load(self, tmpdir_factory):
directory = str(tmpdir_factory.mktemp('data'))
with pytest.raises(pm.TraceDirectoryError):
pm.load_trace(directory, model=TestSaveLoad.model())

def test_sample_posterior_predictive(self, tmpdir_factory):
directory = str(tmpdir_factory.mktemp('data'))
save_dir = pm.save_trace(self.trace, directory, overwrite=True)
Expand Down