
Merge branch 'master' into transformed_init_fix
fonnesbeck committed Apr 20, 2017
2 parents b1866e9 + ec0cd55 commit 792dd0d
Showing 8 changed files with 194 additions and 315 deletions.
121 changes: 33 additions & 88 deletions pymc3/backends/smc_text.py
@@ -1,6 +1,4 @@
"""
Text file trace backend modified from pymc3 to work efficiently with
SMC
"""Text file trace backend modified from pymc3 to work efficiently with SMC
Store sampling values as CSV files.
@@ -36,9 +34,8 @@
import multiprocessing


def paripool(function, work, **kwargs):
"""
Initialises a pool of workers and executes a function in parallel by
def paripool(function, work, nprocs=None, chunksize=1):
"""Initialise a pool of workers and execute a function in parallel by
forking the process. Forking is done once during initialisation.
Parameters
@@ -53,13 +50,8 @@ def paripool(function, work, **kwargs):
chunksize : int
number of work packages to throw at workers in each instance
"""

nprocs = kwargs.get('nprocs', None)
chunksize = kwargs.get('chunksize', None)
verbose = kwargs.get('verbose', None)

if chunksize is None:
chunksize = 1
if nprocs is None:
nprocs = multiprocessing.cpu_count()

if nprocs == 1:
def pack_one_worker(*work):
@@ -75,9 +67,6 @@ def pack_one_worker(*work):

return

if nprocs is None:
nprocs = multiprocessing.cpu_count()

try:
pool = multiprocessing.Pool(processes=nprocs)
yield pool.map(function, work, chunksize=chunksize)
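With the pool options now explicit keyword arguments instead of entries pulled out of **kwargs, a call might look like the sketch below (illustrative only; the _square helper and the argument values are made up, and the single-worker branch is collapsed in this view):

def _square(x):
    return x * x

work = list(range(100))

# With nprocs == 1 paripool takes the single-worker path; otherwise it
# forks a multiprocessing.Pool once and maps the work in chunks.
for results in paripool(_square, work, nprocs=4, chunksize=10):
    print(results[:5])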
@@ -86,9 +75,7 @@ def pack_one_worker(*work):


class ArrayStepSharedLLK(BlockedStep):
"""
Modified ArrayStepShared To handle returned larger point including the
likelihood values.
"""Modified ArrayStepShared To handle returned larger point including the likelihood values.
Takes additionally a list of output vars including the likelihoods.
Parameters
@@ -103,7 +90,6 @@ class ArrayStepSharedLLK(BlockedStep):
blocked : boolean
(default True)
"""

def __init__(self, vars, out_vars, shared, blocked=True):
self.vars = vars
self.ordering = ArrayOrdering(vars)
@@ -155,23 +141,19 @@ def __init__(self, name, model=None, vars=None):
self.vars = vars
self.varnames = [var.name for var in vars]

## Get variable shapes. Most backends will need this
## information.
# Get variable shapes. Most backends will need this information.

self.var_shapes_list = [var.tag.test_value.shape for var in vars]
self.var_dtypes_list = [var.tag.test_value.dtype for var in vars]

self.var_shapes = {var: shape
for var, shape in zip(self.varnames, self.var_shapes_list)}
self.var_dtypes = {var: dtype
for var, dtype in zip(self.varnames, self.var_dtypes_list)}
self.var_shapes = dict(zip(self.varnames, self.var_shapes_list))
self.var_dtypes = dict(zip(self.varnames, self.var_dtypes_list))

self.chain = None

def __getitem__(self, idx):
if isinstance(idx, slice):
return self._slice(idx)

try:
return self.point(int(idx))
except (ValueError, TypeError): # Passed variable or variable name.
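For reference, the indexing behaviour implemented by __getitem__ might be exercised as in the sketch below (the model, the 'mu' variable and the directory name are assumptions, and the name lookup in the except branch is presumed to delegate to get_values):

import pymc3 as pm
from pymc3.backends.smc_text import Text

with pm.Model():
    mu = pm.Normal('mu', 0., 1.)
    trace = Text('test_chains')   # directory name is illustrative

# Once sampling has filled the chain CSVs:
# trace[5]      -> dict of values for draw 5 (integer index, via point())
# trace[10:20]  -> a sliced trace (slice, via _slice())
# trace['mu']   -> values of 'mu' (variable name, handled in the except branch)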
@@ -185,8 +167,7 @@ def __setstate__(self, state):


class Text(BaseSMCTrace):
"""
Text trace object
"""Text trace object
Parameters
----------
@@ -212,11 +193,8 @@ def __init__(self, name, model=None, vars=None):
self.df = None
self.corrupted_flag = False

## Sampling methods

def setup(self, draws, chain):
"""
Perform chain-specific setup.
"""Perform chain-specific setup.
Parameters
----------
@@ -225,7 +203,6 @@ def setup(self, draws, chain):
chain : int
Chain number
"""

self.chain = chain
self.filename = os.path.join(self.name, 'chain-{}.csv'.format(chain))

@@ -238,8 +215,7 @@ def setup(self, draws, chain):
fh.write(','.join(cnames) + '\n')
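The cnames construction is collapsed in this diff view, but the header columns are presumably the flattened variable names also used by dump further down; a quick illustration of pymc3's flat-name convention (exact formatting may differ between versions):

from pymc3.backends import tracetab as ttab

ttab.create_flat_names('mu', (2,))   # e.g. ['mu__0', 'mu__1'], one CSV column per element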

def record(self, lpoint):
"""
Record results of a sampling iteration.
"""Record results of a sampling iteration.
Parameters
----------
@@ -255,13 +231,11 @@ def _load_df(self):
try:
self.df = pd.read_csv(self.filename)
except pd.parser.EmptyDataError:
pm._log.warn('Trace %s is empty and needs to be resampled!' % \
self.filename)
pm._log.warn('Trace %s is empty and needs to be resampled!' % self.filename)
os.remove(self.filename)
self.corrupted_flag = True
except pd.io.common.CParserError:
pm._log.warn('Trace %s has wrong size!' % \
self.filename)
pm._log.warn('Trace %s has wrong size!' % self.filename)
self.corrupted_flag = True
os.remove(self.filename)

@@ -277,8 +251,7 @@ def __len__(self):
return self.df.shape[0]

def get_values(self, varname, burn=0, thin=1):
"""
Get values from trace.
"""Get values from trace.
Parameters
----------
@@ -308,8 +281,7 @@ def _slice(self, idx):
return ndarray._slice_as_ndarray(self, idx)

def point(self, idx):
"""
Get point of current chain with variables names as keys.
"""Get point of current chain with variables names as keys.
Parameters
----------
@@ -330,8 +302,7 @@ def point(self, idx):


def get_highest_sampled_stage(homedir, return_final=False):
"""
Return stage number of stage that has been sampled before the final stage.
"""Return stage number of stage that has been sampled before the final stage.
Parameters
----------
@@ -343,7 +314,6 @@ def get_highest_sampled_stage(homedir, return_final=False):
stage number : int
"""
stages = glob(os.path.join(homedir, 'stage_*'))

stagenumbers = []
for s in stages:
stage_ending = os.path.splitext(s)[0].rsplit('_', 1)[1]
@@ -354,14 +324,11 @@ def get_highest_sampled_stage(homedir, return_final=False):
if return_final:
return stage_ending

stagenumbers.sort()

return stagenumbers[-1]
return max(stagenumbers)
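The change from sort-and-take-last to max is behaviour-preserving for any non-empty list; a quick check (illustrative only; on an empty list max raises ValueError where the old indexing raised IndexError):

stagenumbers = [0, 3, 1, 2]
assert max(stagenumbers) == sorted(stagenumbers)[-1]   # both give 3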


def check_multitrace(mtrace, draws, n_chains):
"""
Check multitrace for incomplete sampling and return indexes from chains
"""Check multitrace for incomplete sampling and return indexes from chains
that need to be resampled.
Parameters
@@ -377,29 +344,22 @@ def check_multitrace(mtrace, draws, n_chains):
-------
list of indexes for chains that need to be resampled
"""

not_sampled_idx = []

for chain in range(n_chains):
if chain in mtrace.chains:
if len(mtrace._straces[chain]) != draws:
pm._log.warn('Trace number %i incomplete' % chain)
mtrace._straces[chain].corrupted_flag = True

else:
not_sampled_idx.append(chain)

flag_bool = [
mtrace._straces[chain].corrupted_flag for chain in mtrace.chains]

flag_bool = [mtrace._straces[chain].corrupted_flag for chain in mtrace.chains]
corrupted_idx = [i for i, x in enumerate(flag_bool) if x]

return corrupted_idx + not_sampled_idx
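A sketch of how the returned indexes might be consumed, based on the docstring above (mtrace, draws and n_chains are placeholders for the values of the original run):

# mtrace: a MultiTrace loaded from disk, e.g. via load() below
bad_chains = check_multitrace(mtrace, draws=1000, n_chains=500)
if bad_chains:
    print('chains to resample: %s' % bad_chains)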


def load(name, model=None):
"""
Load Text database.
"""Load Text database.
Parameters
----------
@@ -426,8 +386,7 @@ def load(name, model=None):


def dump(name, trace, chains=None):
"""
Store values from NDArray trace as CSV files.
"""Store values from NDArray trace as CSV files.
Parameters
----------
@@ -438,26 +397,22 @@ def dump(name, trace, chains=None):
chains : list
Chains to dump. If None, all chains are dumped.
"""

if not os.path.exists(name):
os.mkdir(name)
if chains is None:
chains = trace.chains

var_shapes = trace._straces[chains[0]].var_shapes
flat_names = {v: ttab.create_flat_names(v, shape)
for v, shape in var_shapes.items()}
flat_names = {v: ttab.create_flat_names(v, shape) for v, shape in var_shapes.items()}

for chain in chains:
filename = os.path.join(name, 'chain-{}.csv'.format(chain))
df = ttab.trace_to_dataframe(
trace, chains=chain, flat_names=flat_names)
df = ttab.trace_to_dataframe(trace, chains=chain, flat_names=flat_names)
df.to_csv(filename, index=False)
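Together, dump and load give a CSV round trip for a sampled trace; a usage sketch (model, trace and the directory name are assumptions):

from pymc3.backends.smc_text import dump, load

# 'model' and 'trace' come from an earlier sampling run
with model:
    dump('text_backup', trace)    # writes text_backup/chain-<n>.csv for each chain
    restored = load('text_backup')
    print(restored.varnames)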


def dump_objects(outpath, outlist):
"""
Dump objects in outlist into pickle file.
"""Dump objects in outlist into pickle file.
Parameters
----------
@@ -466,14 +421,12 @@ def dump_objects(outpath, outlist):
outlist : list
of objects to save to the pickle file
"""

with open(outpath, 'wb') as f:
pickle.dump(outlist, f, protocol=pickle.HIGHEST_PROTOCOL)


def load_objects(loadpath):
"""
Load (unpickle) saved (pickled) objects from specified loadpath.
"""Load (unpickle) saved (pickled) objects from specified loadpath.
Parameters
----------
@@ -484,18 +437,15 @@ def load_objects(loadpath):
objects : list
of saved objects
"""

try:
objects = pickle.load(open(loadpath, 'rb'))
with open(loadpath, 'rb') as buff:
return pickle.load(buff)
except IOError:
raise Exception(
'File %s does not exist! Data already imported?' % loadpath)
return objects
raise Exception('File %s does not exist! Data already imported?' % loadpath)
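A small round trip for the two pickle helpers (file name and payload are illustrative):

params = {'stage': 4, 'beta': 0.2}        # any picklable objects
dump_objects('atmip.params', [params])    # pickled with HIGHEST_PROTOCOL
loaded = load_objects('atmip.params')     # raises if the file is missing
assert loaded[0]['stage'] == 4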


def load_atmip_params(project_dir, stage_number, mode):
"""
Load saved parameters from given ATMIP stage.
"""Load saved parameters from given ATMIP stage.
Parameters
----------
@@ -506,17 +456,13 @@ def load_atmip_params(project_dir, stage_number, mode):
mode : str
problem mode that has been solved ('geometry', 'static', 'kinematic')
"""

stage_path = os.path.join(project_dir, mode, 'stage_%s' % stage_number,
'atmip.params')
stage_path = os.path.join(project_dir, mode, 'stage_%s' % stage_number, 'atmip.params')
step = load_objects(stage_path)
return step


def split_off_list(l, off_length):
"""
Split a list with length 'off_length' from the beginning of an input
list l.
"""Split a list with length 'off_length' from the beginning of an input list l.
Modifies input list!
Parameters
@@ -530,5 +476,4 @@ def split_off_list(l, off_length):
-------
list
"""

return [l.pop(0) for i in range(off_length)]
return [l.pop(0) for _ in range(off_length)]
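Since split_off_list mutates its argument, a quick illustration of the behaviour (values are illustrative):

l = [1, 2, 3, 4, 5]
head = split_off_list(l, 2)
# head == [1, 2] and l is now [3, 4, 5]; the helper pops from the front in place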
