Skip to content

Commit

Permalink
Merge bbc191c into c014311
Browse files Browse the repository at this point in the history
  • Loading branch information
philastrophist committed Feb 2, 2017
2 parents c014311 + bbc191c commit 087b31f
Show file tree
Hide file tree
Showing 13 changed files with 601 additions and 69 deletions.
6 changes: 4 additions & 2 deletions pymc3/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,13 @@
For specific examples, see pymc3.backends.{ndarray,text,sqlite}.py.
"""
from ..backends.ndarray import NDArray
from ..backends.ndarray import NDArray, MultiNDArray
from ..backends.text import Text
from ..backends.sqlite import SQLite
from ..backends.sqlite import SQLite, MultiSQLite

_shortcuts = {'text': {'backend': Text,
'name': 'mcmc'},
'sqlite': {'backend': SQLite,
'name': 'mcmc.sqlite'}}

_particle_shortcuts = {'sqlite': {'backend': MultiSQLite, 'name': 'mcmc.sqlite'}}
20 changes: 20 additions & 0 deletions pymc3/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,16 @@ def point(self, idx):
raise NotImplementedError


class MultiMixin(object):
    """Mixin for trace backends that record several particles per step.

    The host class must provide ``self.fn``, a callable taking a point
    dictionary; ``iter_fn`` applies it to one particle at a time.
    """

    # Empty default only.  Being a class attribute, this list is shared by
    # every class and instance that does not rebind it, so it must never be
    # mutated in place -- assign a fresh sequence on the instance instead.
    chain_coverage = []

    def iter_fn(self, point, particle):
        """Apply ``self.fn`` to the slice of `point` belonging to one particle.

        Parameters
        ----------
        point : dict
            Maps names to values that can be indexed by particle
        particle : int
            Index selecting the particle in every value of `point`

        Returns
        -------
        Whatever ``self.fn`` returns for the single-particle dictionary
        """
        # ``items()`` exists on both Python 2 and 3; ``iteritems()`` was
        # removed in Python 3 and raised AttributeError there.
        single = {name: values[particle] for name, values in point.items()}
        return self.fn(single)


class MultiTrace(object):
"""Main interface for accessing values from MCMC results
Expand Down Expand Up @@ -152,6 +162,16 @@ class MultiTrace(object):

def __init__(self, straces):
self._straces = {}
particle_traces = []
for i, strace in enumerate(straces):
try:
for j in range(strace.nparticles):
particle_traces.append(strace.get_particle_trace(j))
del straces[i]
except AttributeError:
pass

straces += particle_traces
for strace in straces:
if strace.chain in self._straces:
raise ValueError("Chains are not unique.")
Expand Down
138 changes: 133 additions & 5 deletions pymc3/backends/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
Store sampling values in memory as a NumPy array.
"""
from collections import OrderedDict

import numpy as np
from ..backends import base

Expand Down Expand Up @@ -118,7 +120,123 @@ def point(self, idx):
for varname, values in self.samples.items()}


def _slice_as_ndarray(strace, idx):
class MultiNDArray(NDArray, base.MultiMixin):
    """NDArray trace that stores several particles for every draw.

    Each entry of ``samples`` is an array of shape
    ``(nparticles, draws) + var_shape``: axis 0 indexes the particle and
    axis 1 indexes the draw.

    Parameters
    ----------
    nparticles : int
        Number of particles recorded at every draw
    name, model, vars
        Forwarded unchanged to ``NDArray.__init__``
    """

    def __init__(self, nparticles, name=None, model=None, vars=None):
        super(MultiNDArray, self).__init__(name, model, vars)
        self.nparticles = nparticles
        self.samples = OrderedDict([])

    def record(self, point):
        """Record the results of one sampling iteration for all particles.

        Parameters
        ----------
        point : dict
            Maps variable names to values indexable by particle number
        """
        for w in range(self.nparticles):
            for varname, value in zip(self.varnames, self.iter_fn(point, w)):
                self.samples[varname][w, self.draw_idx] = value
        # One draw is complete only after every particle has been written.
        self.draw_idx += 1

    def setup(self, draws, chain):
        """Perform chain-specific setup.

        Parameters
        ----------
        draws : int
            Expected number of draws
        chain : int
            Chain number of the first particle; particle ``w`` is exposed
            as chain ``chain + w`` through ``chain_coverage``
        """
        self.chain = chain
        self.chain_coverage = list(range(chain, chain + self.nparticles))
        if self.samples:  # Extend the arrays if samples are already present.
            old_draws = len(self)
            self.draws = old_draws + draws
            # Resume writing right after the existing draws (the attribute
            # read by ``record`` and ``close`` is ``draw_idx``).
            self.draw_idx = old_draws
            for varname, shape in self.var_shapes.items():
                old_var_samples = self.samples[varname]
                new_var_samples = np.zeros((self.nparticles, draws) + shape,
                                           self.var_dtypes[varname])
                # Draws live on axis 1; axis 0 is the particle axis and
                # must keep its length.
                self.samples[varname] = np.concatenate((old_var_samples,
                                                        new_var_samples),
                                                       axis=1)
        else:  # Otherwise, make array of zeros for each variable.
            self.draws = draws
            for varname, shape in self.var_shapes.items():
                self.samples[varname] = np.zeros(
                    (self.nparticles, draws) + shape,
                    dtype=self.var_dtypes[varname])

    def point(self, idx, ipx=None):
        """Return dictionary of point values at `idx` for the current trace
        with variable names as keys.

        Parameters
        ----------
        idx : int
            Draw index
        ipx : int, slice or index array, optional
            Particle selection applied to axis 0; all particles if None
        """
        if ipx is None:
            ipx = slice(None)
        idx = int(idx)
        return {varname: values[ipx, idx]
                for varname, values in self.samples.items()}

    def close(self):
        """Drop the unused tail of the sample arrays.

        Does nothing when all expected draws have been recorded.
        """
        if self.draw_idx == self.draws:
            return
        # Remove trailing zeros if interrupted before completing all
        # draws.
        self.samples = {var: vtrace[:, :self.draw_idx]
                        for var, vtrace in self.samples.items()}

    # Selection methods

    def __len__(self):
        """Return the number of draws held per particle."""
        if not self.samples:  # `setup` has not been called.
            return 0
        varname = self.varnames[0]
        return self.samples[varname].shape[1]

    def get_values(self, varname, burn=0, thin=1, particles=None):
        """Get values from trace.

        Parameters
        ----------
        varname : str
        burn : int
        thin : int
        particles : int, slice or index array, optional
            Particle selection applied to axis 0; all particles if None

        Returns
        -------
        A NumPy array
        """
        if particles is None:
            particles = slice(None)
        return self.samples[varname][particles, burn::thin]

    def _slice(self, step_idx, particle_idx=slice(None)):
        """Return a new MultiNDArray restricted to the given draws/particles.

        Parameters
        ----------
        step_idx : slice
            Selection along the draw axis
        particle_idx : slice, optional
            Selection along the particle axis
        """
        # Slicing directly instead of using _slice_as_ndarray to
        # support stop value in slice (which is needed by
        # iter_sample).
        # `nparticles` is a required constructor argument; omitting it
        # raised TypeError before any slicing happened.
        sliced = MultiNDArray(self.nparticles, model=self.model,
                              vars=self.vars)
        sliced.chain = self.chain
        sliced.chain_coverage = self.chain_coverage
        sliced.samples = {varname: values[particle_idx, step_idx]
                          for varname, values in self.samples.items()}
        return sliced

    def get_particle_trace(self, particle_index):
        """Return the draws of a single particle as a plain NDArray.

        The returned trace takes its chain number from
        ``chain_coverage[particle_index]``.
        """
        assert isinstance(particle_index, int)
        sliced = NDArray(model=self.model, vars=self.vars)
        sliced.chain = self.chain_coverage[particle_index]
        sliced.samples = {varname: values[particle_index]
                          for varname, values in self.samples.items()}
        return sliced

    def get_flat_trace(self):
        """Return all particles merged into one NDArray.

        The particle and draw axes are collapsed into a single leading axis
        (all draws of particle 0 first, then particle 1, ...).  The result
        carries the chain number ``chain_coverage[0]``.
        """
        sliced = NDArray(model=self.model, vars=self.vars)
        sliced.chain = self.chain_coverage[0]
        # Collapse only the first two axes.  A bare ``ravel()`` would also
        # flatten the variable's own shape, which differs from the original
        # output only for non-scalar variables.
        sliced.samples = {varname: values.reshape((-1,) + values.shape[2:])
                          for varname, values in self.samples.items()}
        return sliced

def _slice_as_ndarray(strace, idx, ipx=None):
if idx.start is None:
burn = 0
else:
Expand All @@ -128,8 +246,18 @@ def _slice_as_ndarray(strace, idx):
else:
thin = idx.step

sliced = NDArray(model=strace.model, vars=strace.vars)
sliced.chain = strace.chain
sliced.samples = {v: strace.get_values(v, burn=burn, thin=thin)
for v in strace.varnames}
if hasattr(strace, 'nparticles'):
sliced = MultiNDArray(strace.nparticles, model=strace.model, vars=strace.vars)
sliced.chain = strace.chain
sliced.chain_coverage = strace.chain_coverage
sliced.samples = {v: strace.get_values(v, burn=burn, thin=thin, particles=ipx)
for v in strace.varnames}
else:
sliced = NDArray(model=strace.model, vars=strace.vars)
sliced.chain = strace.chain
sliced.samples = {v: strace.get_values(v, burn=burn, thin=thin)
for v in strace.varnames}
return sliced


# TODO: Add a MultiText backend (MultiSQLite is implemented in backends/sqlite.py).
119 changes: 119 additions & 0 deletions pymc3/backends/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,122 @@ def point(self, idx):
return var_values


class MultiSQLite(SQLite, base.MultiMixin):
    """SQLite trace that stores every particle under its own chain number.

    Particle ``w`` of a trace set up with chain number ``chain`` is written
    to and read from chain ``chain + w``.

    Parameters
    ----------
    nparticles : int
        Number of particles recorded at every draw
    name, model, vars
        Forwarded unchanged to ``SQLite.__init__``
    """

    def __init__(self, nparticles, name, model=None, vars=None):
        super(MultiSQLite, self).__init__(name, model, vars)
        self.nparticles = nparticles

    def setup(self, draws, chain):
        """Run the SQLite setup and record which chains the particles use."""
        super(MultiSQLite, self).setup(draws, chain)
        self.chain_coverage = [chain + i for i in range(self.nparticles)]

    def record(self, point):
        """Record results of a sampling iteration.

        Parameters
        ----------
        point : dict
            Values mapped to variable names; every value is indexed by
            particle number
        """
        for w in range(self.nparticles):
            for varname, value in zip(self.varnames, self.iter_fn(point, w)):
                values = ((self.draw_idx, self.chain + w)
                          + tuple(np.ravel(value)))
                self._queue[varname].append(values)

        if len(self._queue[self.varnames[0]]) > self._queue_limit:
            self._execute_queue()
        self.draw_idx += 1

    def get_values(self, varname, burn=0, thin=1, particles=None):
        """Get values from trace.

        Parameters
        ----------
        varname : str
        burn : int
        thin : int
        particles : int, list, tuple or slice, optional
            Particle selection; all particles if None.  An int returns the
            values of that particle only, any other selection returns one
            stacked entry per selected particle.

        Returns
        -------
        A NumPy array

        Raises
        ------
        ValueError
            For negative `burn`, non-positive `thin`, or an unsupported
            `particles` selection
        """
        if burn < 0:
            raise ValueError('Negative burn values not supported '
                             'in SQLite backend.')
        if thin < 1:
            raise ValueError('Only positive thin values are supported '
                             'in SQLite backend.')
        varname = str(varname)

        if particles is None:
            particles = slice(None)
        if isinstance(particles, slice):
            # On Python 3 slicing a range yields another range, which is
            # neither list nor tuple; materialise it so the branch below
            # accepts it.
            particles = list(range(self.nparticles)[particles])
        if isinstance(particles, (list, tuple)):
            return np.asarray([self.get_values(varname, burn, thin, i)
                               for i in particles])
        if not isinstance(particles, (int, np.integer)):
            raise ValueError(
                '{} is not a valid particle slice'.format(particles))

        # Single particle: query the chain that holds it.  ``int()`` keeps
        # NumPy integers out of the SQL parameters.
        statement_args = {'chain': self.chain + int(particles)}
        if burn == 0 and thin == 1:
            action = 'select'
        elif thin == 1:
            action = 'select_burn'
            statement_args['burn'] = burn - 1
        elif burn == 0:
            action = 'select_thin'
            statement_args['thin'] = thin
        else:
            action = 'select_burn_thin'
            statement_args['burn'] = burn - 1
            statement_args['thin'] = thin

        self.db.connect()
        shape = (-1,) + self.var_shapes[varname]
        statement = TEMPLATES[action].format(table=varname)
        self.db.cursor.execute(statement, statement_args)
        values = _rows_to_ndarray(self.db.cursor)
        return values.reshape(shape)

    def _slice(self, idx):
        """Return an in-memory slice of this trace.

        Parameters
        ----------
        idx : slice or tuple
            Either a draw slice, or a tuple whose first item is the draw
            slice and whose optional second item is a particle selection

        Raises
        ------
        ValueError
            If the draw slice has a stop value
        """
        if isinstance(idx, tuple):
            step_idx = idx[0]
            ipx = idx[1] if len(idx) > 1 else None
        else:
            step_idx = idx
            ipx = None
        if step_idx.stop is not None:
            raise ValueError('Stop value in slice not supported.')
        # Hand over the draw slice itself (a tuple has no ``start``/``step``)
        # together with the particle selection.
        return ndarray._slice_as_ndarray(self, step_idx, ipx)

    def point(self, idx, ipx=None):
        """Return dictionary of point values at `idx` for current chain
        with variables names as keys.

        Parameters
        ----------
        idx : int
            Draw index; negative values count from the last draw
        ipx : int, list, tuple or slice, optional
            Particle selection; all particles if None.  For an int each
            value has the variable's shape; otherwise each value is stacked
            over the selected particles along a new leading axis.

        Raises
        ------
        ValueError
            For an unsupported `ipx` selection
        """
        if ipx is None:
            ipx = slice(None)
        if isinstance(ipx, slice):
            # See get_values: a sliced range is not a list on Python 3.
            ipx = list(range(self.nparticles)[ipx])
        if isinstance(ipx, (list, tuple)):
            per_particle = [self.point(idx, i) for i in ipx]
            # Return one dictionary, as documented, instead of an object
            # array of dictionaries.
            return {varname: np.asarray([pt[varname] for pt in per_particle])
                    for varname in self.varnames}
        if not isinstance(ipx, (int, np.integer)):
            raise ValueError('{} is not a valid particle slice'.format(ipx))

        idx = int(idx)
        if idx < 0:
            idx = self._get_max_draw(self.chain) + idx + 1
        statement = TEMPLATES['select_point']
        self.db.connect()
        var_values = {}
        statement_args = {'chain': self.chain + int(ipx), 'draw': idx}
        for varname in self.varnames:
            self.db.cursor.execute(statement.format(table=varname),
                                   statement_args)
            values = _rows_to_ndarray(self.db.cursor)
            var_values[varname] = values.reshape(self.var_shapes[varname])
        return var_values

class _SQLiteDB(object):

def __init__(self, name):
Expand Down Expand Up @@ -337,3 +453,6 @@ def _get_chain_list(cursor, varname):
def _rows_to_ndarray(cursor):
    """Build a NumPy array from the rows waiting on `cursor`.

    The first three columns of every fetched row are dropped, the remaining
    columns are stacked into one array, and length-one axes are squeezed
    away.
    """
    value_columns = [record[3:] for record in cursor.fetchall()]
    stacked = np.array(value_columns)
    return np.squeeze(stacked)


# TODO: store chain coverage in SQL
25 changes: 18 additions & 7 deletions pymc3/blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,19 @@ class ArrayOrdering(object):
An ordering for an array space
"""

def __init__(self, vars, nparticles=None):
    """Lay the given variables out contiguously in a flat array space.

    Parameters
    ----------
    vars : iterable
        Variables providing ``dsize``, ``dshape`` and ``dtype``
    nparticles : int, optional
        When given, ``dimensions`` becomes the tuple ``(nparticles, dim)``
        instead of the plain total size ``dim``
    """
    self.vmap = []
    self.nparticles = nparticles
    dim = 0

    for var in vars:
        # Each variable occupies the next `dsize` slots.
        slc = slice(dim, dim + var.dsize)
        self.vmap.append(VarMap(str(var), slc, var.dshape, var.dtype))
        dim += var.dsize

    # Assigned exactly once: the earlier unconditional
    # ``self.dimensions = dim`` was always overwritten by this branch.
    if self.nparticles is not None:
        self.dimensions = (self.nparticles, dim)
    else:
        self.dimensions = dim


class DictToArrayBijection(object):
Expand All @@ -38,7 +41,7 @@ class DictToArrayBijection(object):
def __init__(self, ordering, dpoint):
self.ordering = ordering
self.dpt = dpoint

# determine smallest float dtype that will fit all data
if all([x.dtyp == 'float16' for x in ordering.vmap]):
self.array_dtype = 'float16'
Expand All @@ -56,8 +59,12 @@ def map(self, dpt):
dpt : dict
"""
apt = np.empty(self.ordering.dimensions, dtype=self.array_dtype)
for var, slc, _, _ in self.ordering.vmap:
apt[slc] = dpt[var].ravel()
for var, slc, shp, _ in self.ordering.vmap:
if self.ordering.nparticles is not None:
for d in range(self.ordering.nparticles):
apt[d, slc] = dpt[var][d].ravel()
else:
apt[slc] = dpt[var].ravel()
return apt

def rmap(self, apt):
Expand All @@ -71,7 +78,11 @@ def rmap(self, apt):
dpt = self.dpt.copy()

for var, slc, shp, dtyp in self.ordering.vmap:
dpt[var] = np.atleast_1d(apt)[slc].reshape(shp).astype(dtyp)
if self.ordering.nparticles is not None:
for d in range(self.ordering.nparticles):
dpt[var][d] = np.atleast_1d(apt)[d, slc].reshape(shp).astype(dtyp)
else:
dpt[var] = np.atleast_1d(apt)[slc].reshape(shp).astype(dtyp)

return dpt

Expand Down
Loading

0 comments on commit 087b31f

Please sign in to comment.