Description
I am trying to detect changepoints in several trials of a process with categorical emissions. I find that NUTS sampling is really slow (something like 4 s/it). Metropolis is faster, of course, and draws about 200 samples/s, but I need to run it for about 500k samples to get reasonable Gelman-Rubin convergence values. I just wanted to check whether there is something incorrect (or open to improvement) in the way I have set up my model, and whether NUTS can be made to work better in this case.
Here's one example of what I am trying to do: 32 trials of a categorical process with 17 emissions, where each trial has two changepoints. There are 7 unique emission states across the set of trials. Depending on the type of trial, a trial changes from a start state (the same for all trials) to 1 of 2 possible states at changepoint 1, and then changes again to 1 of 4 possible states at changepoint 2. Each trial has 150 emissions in total:
```python
with pm.Model() as model:
    # Dirichlet prior on the emission/spiking probabilities - 7 states
    # (1 start state, 2 from changepoint1 to changepoint2, 4 from
    # changepoint2 to the end of the trial)
    p = pm.Dirichlet('p', np.ones(num_emissions), shape=(7, num_emissions))

    # Uniform switch times
    # First changepoint
    t1 = pm.Uniform('t1', lower=20, upper=60, shape=num_trials)
    # Second changepoint
    t2 = pm.Uniform('t2', lower=t1 + 20, upper=130, shape=num_trials)

    # Get the actual state numbers based on the switch times
    states = []
    for i in range(num_trials):
        states1 = tt.switch(t1[i] >= np.arange(150), 0, set1[i])
        states2 = tt.switch(t2[i] >= np.arange(150), states1, set2[i])
        states.append(states2)

    # Define the log-likelihood function
    def logp(value):
        value = tt.cast(value, 'int32')
        loglik = 0
        for i in range(num_trials):
            loglik += tt.sum(tt.log(p[states[i], value[i, :]]))
        return loglik

    # Categorical observations
    obs = pm.DensityDist('obs', logp, observed={'value': data[0, :, :150]})

    # Inference button :D
    tr = pm.sample(500000, init=None, step=pm.Metropolis(), njobs=2,
                   start={'t1': 25.0, 't2': 120.0})
```
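To be explicit about what the `tt.switch` construction computes: each trial's state sequence is the start state (0) before `t1`, `set1[i]` between `t1` and `t2`, and `set2[i]` afterwards. In plain NumPy (with made-up state labels and changepoints, just for illustration) that is:

```python
import numpy as np

# Sketch of the deterministic state assignment for a single trial,
# mirroring the tt.switch logic in the model above. s1/s2 stand in
# for set1[i]/set2[i]; the values here are made up.
def trial_states(t1, t2, s1, s2, n=150):
    t = np.arange(n)
    states = np.where(t1 >= t, 0, s1)        # start state 0 up to t1, then s1
    states = np.where(t2 >= t, states, s2)   # keep s1 up to t2, then s2
    return states

states = trial_states(t1=40, t2=100, s1=2, s2=5)
```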
Initializing with find_MAP() sometimes gives errors with this model, so I stuck to the start values specified above; they generally work well with Metropolis, although I have to sample in the range of 500-600k samples to get reasonable convergence (as mentioned above).
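One thing I suspect about the NUTS slowness: because `tt.switch` produces a hard step, the log-likelihood is piecewise constant in `t1` and `t2`, so the gradient carries essentially no information about the changepoints. An idea I am considering is a smooth relaxation, where each time bin's emission probabilities become a sigmoid-weighted mixture of the three regime rows of `p` instead of a hard selection. A plain-NumPy sketch of just the weighting (`temp` is a hypothetical sharpness parameter, not part of the model above):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Smooth changepoint weights: w1 ramps from 0 to 1 around t1, w2 around t2.
# The three returned weights (start / middle / end regime) sum to 1 at every
# time bin, so mixing the regime rows of p with them still yields a valid
# probability vector, and the result is differentiable in t1 and t2.
def regime_weights(t, t1, t2, temp=1.0):
    w1 = sigmoid(temp * (t - t1))   # ~0 before t1, ~1 after
    w2 = sigmoid(temp * (t - t2))   # ~0 before t2, ~1 after
    w_start = 1.0 - w1
    w_mid = w1 * (1.0 - w2)
    w_end = w1 * w2
    return w_start, w_mid, w_end

t = np.arange(150)
ws, wm, we = regime_weights(t, t1=40.0, t2=100.0, temp=2.0)
```

At large `temp` this approaches the hard switch, while keeping the likelihood smooth enough for NUTS to get useful gradients.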
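For reference, the convergence criterion I am using amounts to the standard Gelman-Rubin statistic; a minimal NumPy version of it (a hypothetical helper for illustration, not the PyMC3 diagnostic itself) looks like:

```python
import numpy as np

# Minimal Gelman-Rubin (potential scale reduction) statistic computed on
# raw chains of shape (n_chains, n_samples).  Values near 1 indicate that
# the between-chain and within-chain variances agree.
def gelman_rubin(chains):
    m, n = chains.shape
    chain_means = chains.mean(axis=1)
    W = chains.var(axis=1, ddof=1).mean()   # mean within-chain variance
    B = n * chain_means.var(ddof=1)         # between-chain variance
    var_hat = (n - 1) / n * W + B / n       # pooled variance estimate
    return np.sqrt(var_hat / W)

rng = np.random.RandomState(0)
good = rng.normal(size=(2, 5000))                                # well-mixed chains
bad = np.stack([rng.normal(0, 1, 5000), rng.normal(5, 1, 5000)])  # stuck chains
```

On the `good` chains this gives a value close to 1, while the `bad` pair (two chains stuck in different modes) gives a value well above 1, which is the pattern I see before the 500k-sample mark.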