Skip to content

Commit

Permalink
Make slice sampler sample from 1D conditionals as it should (#2446)
Browse files Browse the repository at this point in the history
* Make Slice sampler sample from 1D conditionals 

In the previous implementation it would sample jointly from non-scalar variables, and hang for when the size is high (due to low probability to get a joint sample within the slice in high-D).

* slicer.py

Fix broken indentation due to copypaste

* Apply autopep8

* Delete a superfluous commented line

* Update the master sample for Slice in test_step.py
  • Loading branch information
madanh authored and Junpeng Lao committed Jul 28, 2017
1 parent 15b8595 commit b988ba9
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 48 deletions.
61 changes: 33 additions & 28 deletions pymc3/step_methods/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ def __init__(self, vars=None, w=1., tune=True, model=None, **kwargs):
self.model = modelcontext(model)
self.w = w
self.tune = tune
self.w_sum = 0
self.n_tunes = 0
self.n_tunes = 0.

if vars is None:
vars = self.model.cont_vars
Expand All @@ -44,33 +43,39 @@ def __init__(self, vars=None, w=1., tune=True, model=None, **kwargs):
super(Slice, self).__init__(vars, [self.model.fastlogp], **kwargs)

def astep(self, q0, logp):
self.w = np.resize(self.w, len(q0))
y = logp(q0) - nr.standard_exponential()

# Stepping out procedure
q_left = q0 - nr.uniform(0, self.w)
q_right = q_left + self.w

while (y < logp(q_left)).all():
q_left -= self.w

while (y < logp(q_right)).all():
q_right += self.w

q = nr.uniform(q_left, q_right, size=q_left.size) # new variable to avoid copies
while logp(q) <= y:
# Sample uniformly from slice
if (q > q0).all():
q_right = q
elif (q < q0).all():
q_left = q
q = nr.uniform(q_left, q_right, size=q_left.size)

self.w = np.resize(self.w, len(q0)) # this is a repmat
q = np.copy(q0) # TODO: find out if we need this
ql = np.copy(q0) # l for left boundary
qr = np.copy(q0) # r for right boudary
for i in range(len(q0)):
# uniformly sample from 0 to p(q), but in log space
y = logp(q) - nr.standard_exponential()
ql[i] = q[i] - nr.uniform(0, self.w[i])
qr[i] = q[i] + self.w[i]
# Stepping out procedure
while(y <= logp(ql)): # changed lt to leq for locally uniform posteriors
ql[i] -= self.w[i]
while(y <= logp(qr)):
qr[i] += self.w[i]

q[i] = nr.uniform(ql[i], qr[i])
while logp(q) < y: # Changed leq to lt, to accomodate for locally flat posteriors
# Sample uniformly from slice
if q[i] > q0[i]:
qr[i] = q[i]
elif q[i] < q0[i]:
ql[i] = q[i]
q[i] = nr.uniform(ql[i], qr[i])

if self.tune: # I was under impression from MacKays lectures that slice width can be tuned without
# breaking markovianness. Can we do it regardless of self.tune?(@madanh)
self.w[i] = self.w[i] * (self.n_tunes / (self.n_tunes + 1)) +\
(qr[i] - ql[i]) / (self.n_tunes + 1) # same as before
# unobvious and important: return qr and ql to the same point
qr[i] = q[i]
ql[i] = q[i]
if self.tune:
# Tune sampler parameters
self.w_sum += np.abs(q0 - q)
self.n_tunes += 1.
self.w = 2. * self.w_sum / self.n_tunes
self.n_tunes += 1
return q

@staticmethod
Expand Down
54 changes: 34 additions & 20 deletions pymc3/tests/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,40 @@
class TestStepMethods(object): # yield test doesn't work subclassing object
master_samples = {
Slice: np.array([
-8.13087389e-01, -3.08921856e-01, -6.79377098e-01, 6.50812585e-01, -7.63577596e-01,
-8.13199793e-01, -1.63823548e+00, -7.03863676e-02, 2.05107771e+00, 1.68598170e+00,
6.92463695e-01, -7.75120766e-01, -1.62296463e+00, 3.59722423e-01, -2.31421712e-01,
-7.80686956e-02, -6.05860731e-01, -1.13000202e-01, 1.55675942e-01, -6.78527612e-01,
6.31052333e-01, 6.09012517e-01, -1.56621643e+00, 5.04330883e-01, 3.14824082e-03,
-1.31287073e+00, 4.10706927e-01, 8.93815792e-01, 8.19317020e-01, 3.71900919e-01,
-2.62067312e+00, -3.47616592e+00, 1.50335041e+00, -1.05993351e+00, 2.41571723e-01,
-1.06258156e+00, 5.87999429e-01, -1.78480091e-01, -3.60278680e-01, 1.90615274e-01,
-1.24399204e-01, 4.03845589e-01, -1.47797573e-01, 7.90445804e-01, -1.21043819e+00,
-1.33964776e+00, 1.36366329e+00, -7.50175388e-01, 9.25241839e-01, -4.17493767e-01,
1.85311339e+00, -2.49715343e+00, -3.18571692e-01, -1.49099668e+00, -2.62079621e-01,
-5.82376852e-01, -2.53033395e+00, 2.07580503e+00, -9.82615856e-01, 6.00517782e-01,
-9.83941620e-01, -1.59014118e+00, -1.83931394e-03, -4.71163466e-01, 1.90073737e+00,
-2.08929125e-01, -6.98388847e-01, 1.64502092e+00, -1.19525944e+00, 1.44424109e+00,
1.52974876e+00, -5.70140077e-01, 5.08633322e-01, -1.70862492e-02, -1.69887948e-01,
5.19760297e-01, -4.15149647e-01, 8.63685174e-02, -3.66805233e-01, -9.24988952e-01,
2.33307122e+00, -2.60391496e-01, -5.86271814e-01, -5.01297170e-01, -1.53866195e+00,
5.71285373e-01, -1.30571830e+00, 8.59587795e-01, 6.72170694e-01, 9.12433943e-01,
7.04959179e-01, 8.37863464e-01, -5.24200836e-01, 1.28261340e+00, 9.08774240e-01,
8.80566763e-01, 7.82911967e-01, 8.01843432e-01, 7.09251098e-01, 5.73803618e-01]),
-5.95252353e-01, -1.81894861e-01, -4.98211488e-01,
-1.02262800e-01, -4.26726030e-01, 1.75446860e+00,
-1.30022548e+00, 8.35658004e-01, 8.95879638e-01,
-8.85214481e-01, -6.63530918e-01, -8.39303080e-01,
9.42792225e-01, 9.03554344e-01, 8.45254684e-01,
-1.43299803e+00, 9.04897201e-01, -1.74303131e-01,
-6.38611581e-01, 1.50013968e+00, 1.06864438e+00,
-4.80484421e-01, -7.52199709e-01, 1.95067495e+00,
-3.67960104e+00, 2.49291588e+00, -2.11039152e+00,
1.61674758e-01, -1.59564182e-01, 2.19089873e-01,
1.88643940e+00, 4.04098154e-01, -4.59352326e-01,
-9.06370675e-01, 5.42817654e-01, 6.99040611e-03,
1.66396391e-01, -4.74549281e-01, 8.19064437e-02,
1.69689952e+00, -1.62667304e+00, 1.61295808e+00,
1.30099144e+00, -5.46722750e-01, -7.87745494e-01,
7.91027521e-01, -2.35706976e-02, 1.68824376e+00,
7.10566880e-01, -7.23551374e-01, 8.85613069e-01,
-1.27300146e+00, 1.80274430e+00, 9.34266276e-01,
2.40427061e+00, -1.85132552e-01, 4.47234196e-01,
-9.81894859e-01, -2.83399706e-01, 1.84717533e+00,
-1.58593284e+00, 3.18027270e-02, 1.40566006e+00,
-9.45758714e-01, 1.18813188e-01, -1.19938604e+00,
-8.26038466e-01, 5.03469984e-01, -4.72742758e-01,
2.27820946e-01, -1.02608915e-03, -6.02507158e-01,
7.72739682e-01, 7.16064505e-01, -1.63693490e+00,
-3.97161966e-01, 1.17147944e+00, -2.87796982e+00,
-1.59533297e+00, 6.73096114e-01, -3.34397247e-01,
1.22357427e-01, -4.57299104e-02, 1.32005771e+00,
-1.29910645e+00, 8.16168850e-01, -1.47357594e+00,
1.34688446e+00, 1.06377551e+00, 4.34296696e-02,
8.23143354e-01, 8.40906324e-01, 1.88596864e+00,
5.77120694e-01, 2.71732927e-01, -1.36217979e+00,
2.41488213e+00, 4.68298379e-01, 4.86342250e-01,
-8.43949966e-01]),
HamiltonianMC: np.array([
-0.74925631, -0.2566773 , -2.12480977, 1.64328926, -1.39315913,
2.04200003, 0.00706711, 0.34240498, 0.44276674, -0.21368043,
Expand Down

0 comments on commit b988ba9

Please sign in to comment.