Skip to content

Commit

Permalink
consistency between backends and multisqlite
Browse files Browse the repository at this point in the history
  • Loading branch information
philastrophist committed Feb 2, 2017
1 parent 52babab commit bbc191c
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 21 deletions.
6 changes: 4 additions & 2 deletions pymc3/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,13 @@
For specific examples, see pymc3.backends.{ndarray,text,sqlite}.py.
"""
from ..backends.ndarray import NDArray
from ..backends.ndarray import NDArray, MultiNDArray
from ..backends.text import Text
from ..backends.sqlite import SQLite
from ..backends.sqlite import SQLite, MultiSQLite

# Shortcut strings that may be given in place of a backend instance.
# Each entry maps the shortcut to the backend class to instantiate
# ('backend') and the default name handed to that class ('name').
_shortcuts = {'text': {'backend': Text,
'name': 'mcmc'},
'sqlite': {'backend': SQLite,
'name': 'mcmc.sqlite'}}

# Same layout as _shortcuts, but pointing at the multi-particle backend.
# NOTE(review): presumably selected when a particle count is supplied to the
# sampler -- confirm against the backend-selection code in sampling.py.
_particle_shortcuts = {'sqlite': {'backend': MultiSQLite, 'name': 'mcmc.sqlite'}}
10 changes: 10 additions & 0 deletions pymc3/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,16 @@ def point(self, idx):
raise NotImplementedError


class MultiMixin(object):
    """Mixin with helpers shared by the multi-particle trace backends.

    The host class is expected to provide ``self.fn``, a callable that
    accepts a dictionary mapping variable names to values.
    """

    # Class-level default. This list is shared by every class that inherits
    # the mixin, so rebind it on the instance instead of mutating it in place.
    chain_coverage = []

    def iter_fn(self, point, particle):
        """Evaluate ``self.fn`` on one particle's slice of ``point``.

        Parameters
        ----------
        point : dict
            Maps each variable name to an indexable container holding one
            entry per particle.
        particle : int
            Index of the particle to extract from every entry of ``point``.

        Returns
        -------
        Whatever ``self.fn`` returns for the single-particle dictionary.
        """
        # ``items()`` exists on both Python 2 and Python 3, whereas
        # ``iteritems()`` raises AttributeError on Python 3.
        single = {name: values[particle] for name, values in point.items()}
        return self.fn(single)


class MultiTrace(object):
"""Main interface for accessing values from MCMC results
Expand Down
40 changes: 22 additions & 18 deletions pymc3/backends/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,19 +120,12 @@ def point(self, idx):
for varname, values in self.samples.items()}


class MultiNDArray(NDArray):
def __init__(self, nparticles=1, name=None, model=None, vars=None):
class MultiNDArray(NDArray, base.MultiMixin):
def __init__(self, nparticles, name=None, model=None, vars=None):
super(MultiNDArray, self).__init__(name, model, vars)
self.nparticles = nparticles
self.samples = OrderedDict([])


def iter_fn(self, point, particle):
d = {}
for k, v in point.iteritems():
d[k] = v[particle]
return self.fn(d)

def record(self, point):
"""
:param q:
Expand Down Expand Up @@ -172,12 +165,14 @@ def setup(self, draws, chain):
self.samples[varname] = np.zeros((self.nparticles, draws) + shape,
dtype=self.var_dtypes[varname])

def point(self, idx):
def point(self, idx, ipx=None):
"""Return dictionary of point values at `idx` for current chain
with variables names as keys.
"""
if ipx is None:
ipx = slice(None)
idx = int(idx)
return {varname: values[:, idx]
return {varname: values[ipx, idx]
for varname, values in self.samples.items()}


Expand All @@ -197,7 +192,7 @@ def __len__(self):
varname = self.varnames[0]
return self.samples[varname].shape[1]

def get_values(self, varname, burn=0, thin=1, particle_idx=slice(None)):
def get_values(self, varname, burn=0, thin=1, particles=None):
"""Get values from trace.
Parameters
Expand All @@ -211,7 +206,9 @@ def get_values(self, varname, burn=0, thin=1, particle_idx=slice(None)):
-------
A NumPy array
"""
return self.samples[varname][particle_idx, burn::thin]
if particles is None:
particles = slice(None)
return self.samples[varname][particles, burn::thin]

def _slice(self, step_idx, particle_idx=slice(None)):
# Slicing directly instead of using _slice_as_ndarray to
Expand Down Expand Up @@ -239,7 +236,7 @@ def get_flat_trace(self):
for varname, values in self.samples.items()}
return sliced

def _slice_as_ndarray(strace, idx):
def _slice_as_ndarray(strace, idx, ipx=None):
if idx.start is None:
burn = 0
else:
Expand All @@ -249,10 +246,17 @@ def _slice_as_ndarray(strace, idx):
else:
thin = idx.step

sliced = NDArray(model=strace.model, vars=strace.vars)
sliced.chain = strace.chain
sliced.samples = {v: strace.get_values(v, burn=burn, thin=thin)
for v in strace.varnames}
if hasattr(strace, 'nparticles'):
sliced = MultiNDArray(strace.nparticles, model=strace.model, vars=strace.vars)
sliced.chain = strace.chain
sliced.chain_coverage = strace.chain_coverage
sliced.samples = {v: strace.get_values(v, burn=burn, thin=thin, particles=ipx)
for v in strace.varnames}
else:
sliced = NDArray(model=strace.model, vars=strace.vars)
sliced.chain = strace.chain
sliced.samples = {v: strace.get_values(v, burn=burn, thin=thin)
for v in strace.varnames}
return sliced


Expand Down
119 changes: 119 additions & 0 deletions pymc3/backends/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,122 @@ def point(self, idx):
return var_values


class MultiSQLite(SQLite, base.MultiMixin):
    """SQLite trace backend that stores several particles per draw.

    Particle ``w`` is written under chain number ``self.chain + w``, so a
    single multi-particle trace occupies ``nparticles`` consecutive chain
    numbers; ``setup`` lists them in ``chain_coverage``.

    Parameters
    ----------
    nparticles : int
        Number of particles recorded for every draw.
    name : str
        Forwarded to ``SQLite.__init__``.
    model : optional
        Forwarded to ``SQLite.__init__``.
    vars : optional
        Forwarded to ``SQLite.__init__``.
    """

    def __init__(self, nparticles, name, model=None, vars=None):
        super(MultiSQLite, self).__init__(name, model, vars)
        self.nparticles = nparticles

    def setup(self, draws, chain):
        """Run the parent ``setup`` and record the chain numbers in use.

        Parameters
        ----------
        draws : int
        chain : int
            Chain number of the first particle; particle ``i`` uses
            ``chain + i``.
        """
        super(MultiSQLite, self).setup(draws, chain)
        self.chain_coverage = [chain + i for i in range(self.nparticles)]

    def record(self, point):
        """Record results of a sampling iteration.

        Parameters
        ----------
        point : dict
            Values mapped to variable names; every value holds one entry
            per particle.
        """
        for w in range(self.nparticles):
            # One row per (variable, particle); the particle is encoded in
            # the chain column as an offset from this trace's base chain.
            for varname, value in zip(self.varnames, self.iter_fn(point, w)):
                values = (self.draw_idx, self.chain + w) + tuple(np.ravel(value))
                self._queue[varname].append(values)

        # Flush once all particles of this draw have been queued.
        if len(self._queue[self.varnames[0]]) > self._queue_limit:
            self._execute_queue()
        self.draw_idx += 1

    def get_values(self, varname, burn=0, thin=1, particles=None):
        """Get values from trace.

        Parameters
        ----------
        varname : str
        burn : int
        thin : int
        particles : int, list, tuple, slice or None
            Particle(s) to fetch. ``None`` selects every particle.

        Returns
        -------
        A NumPy array. For a single integer particle the leading axis runs
        over draws; for several particles the per-particle arrays are
        stacked along a new leading axis.

        Raises
        ------
        ValueError
            If ``burn`` is negative, ``thin`` is not positive, or
            ``particles`` is of an unsupported type.
        """
        if burn < 0:
            raise ValueError('Negative burn values not supported '
                             'in SQLite backend.')
        if thin < 1:
            raise ValueError('Only positive thin values are supported '
                             'in SQLite backend.')
        varname = str(varname)

        if particles is None:
            particles = slice(None)
        if isinstance(particles, slice):
            # On Python 3 slicing a range yields another range, which is
            # neither list nor tuple; materialise it so the branch below
            # handles it on both Python 2 and Python 3.
            particles = list(range(self.nparticles)[particles])
        if isinstance(particles, (list, tuple)):
            return np.asarray([self.get_values(varname, burn, thin, i)
                               for i in particles])
        if not isinstance(particles, (int, np.integer)):
            raise ValueError('{} is not a valid particle slice'.format(particles))

        # Coerce to a builtin int so NumPy integer scalars can be bound as
        # an SQL parameter.
        statement_args = {'chain': self.chain + int(particles)}
        if burn == 0 and thin == 1:
            action = 'select'
        elif thin == 1:
            action = 'select_burn'
            statement_args['burn'] = burn - 1
        elif burn == 0:
            action = 'select_thin'
            statement_args['thin'] = thin
        else:
            action = 'select_burn_thin'
            statement_args['burn'] = burn - 1
            statement_args['thin'] = thin

        self.db.connect()
        shape = (-1,) + self.var_shapes[varname]
        statement = TEMPLATES[action].format(table=varname)
        self.db.cursor.execute(statement, statement_args)
        values = _rows_to_ndarray(self.db.cursor)
        return values.reshape(shape)

    def _slice(self, idx):
        """Return the sliced trace as an in-memory ndarray-backed trace.

        Parameters
        ----------
        idx : slice or tuple
            Either a slice over draws, or a tuple whose first element is
            that slice and whose optional second element selects particles.

        Raises
        ------
        ValueError
            If the draw slice has a stop value.
        """
        if isinstance(idx, tuple):
            step_idx = idx[0]
            ipx = idx[1] if len(idx) > 1 else None
        else:
            step_idx, ipx = idx, None
        if step_idx.stop is not None:
            raise ValueError('Stop value in slice not supported.')
        # Pass the draw slice itself (not the enclosing tuple, which has no
        # ``start``/``step``) and forward the particle selection.
        return ndarray._slice_as_ndarray(self, step_idx, ipx)

    def point(self, idx, ipx=None):
        """Return point values at draw `idx` with variable names as keys.

        Parameters
        ----------
        idx : int
            Draw index; negative values count back from the last draw.
        ipx : int, list, tuple, slice or None
            Particle(s) to fetch. ``None`` selects every particle.

        Returns
        -------
        For an integer ``ipx``, a dictionary mapping variable names to
        values. For several particles, ``np.asarray`` applied to the list
        of per-particle dictionaries.

        Raises
        ------
        ValueError
            If ``ipx`` is of an unsupported type.
        """
        # NOTE(review): the multi-particle return is an array of dicts, not
        # a dict of arrays -- confirm this is what callers expect.
        if ipx is None:
            ipx = slice(None)
        if isinstance(ipx, slice):
            # See get_values: a sliced range must be turned into a list to
            # be recognised on Python 3.
            ipx = list(range(self.nparticles)[ipx])
        if isinstance(ipx, (list, tuple)):
            return np.asarray([self.point(idx, i) for i in ipx])
        if not isinstance(ipx, (int, np.integer)):
            raise ValueError('{} is not a valid particle slice'.format(ipx))

        idx = int(idx)
        if idx < 0:
            idx = self._get_max_draw(self.chain) + idx + 1
        statement = TEMPLATES['select_point']
        self.db.connect()
        var_values = {}
        statement_args = {'chain': self.chain + int(ipx), 'draw': idx}
        for varname in self.varnames:
            self.db.cursor.execute(statement.format(table=varname),
                                   statement_args)
            values = _rows_to_ndarray(self.db.cursor)
            var_values[varname] = values.reshape(self.var_shapes[varname])
        return var_values

class _SQLiteDB(object):

def __init__(self, name):
Expand Down Expand Up @@ -337,3 +453,6 @@ def _get_chain_list(cursor, varname):
def _rows_to_ndarray(cursor):
"""Convert SQL row to NDArray."""
return np.squeeze(np.array([row[3:] for row in cursor.fetchall()]))


# TODO: store chain coverage in SQL
3 changes: 2 additions & 1 deletion pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,10 +331,11 @@ def _choose_backend(trace, chain, shortcuts=None, nparticles=None, **kwds):
if shortcuts is None:
shortcuts = pm.backends._shortcuts

kwds['nparticles'] = nparticles
try:
backend = shortcuts[trace]['backend']
name = shortcuts[trace]['name']
return backend(name, **kwds)
return backend(name=name, **kwds)
except TypeError:
if nparticles is None:
return NDArray(vars=trace, **kwds)
Expand Down

0 comments on commit bbc191c

Please sign in to comment.