
Speed of NUTS sampling #2094

@narendramukherjee

Description


I am trying to detect changepoints in several trials of a process with categorical emissions. I find that NUTS sampling is really slow (something like 4 s/it). Metropolis is faster, of course, and draws about 200 samples/s, but I need to run it for about 500k samples to get reasonable Gelman-Rubin convergence values. I just wanted to check whether there is something incorrect (or open to improvement) in the way I have set up my model, and whether NUTS can be made to work better in this case.

Here's one example of what I am trying to do: 32 trials of a Categorical process with 17 possible emissions, where each trial has two changepoints. There are 7 unique emission states across the set of trials - depending on the type of trial, a trial switches from a start state (the same for all trials) to one of 2 possible states at changepoint1, and then again to one of 4 possible states at changepoint2. Each trial has 150 emissions in total.
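The bookkeeping variables the model below refers to (num_trials, num_emissions, set1, set2, data) look roughly like this - a sketch only, with dummy state assignments and dummy data standing in for the actual trial labels and recordings:

import numpy as np

num_trials = 32       # trials of the Categorical process
num_emissions = 17    # possible emissions in each time bin

# State labels: 0 is the common start state, 1-2 are the possible states after
# changepoint1, 3-6 are the possible states after changepoint2.
# Dummy per-trial assignments here; the real ones depend on the trial type.
set1 = np.random.randint(1, 3, size = num_trials)
set2 = np.random.randint(3, 7, size = num_trials)

# Stand-in for the recorded emissions: data[0, :, :150] used below is an array
# of shape (num_trials, 150) with integer emission labels in 0..num_emissions-1
data = np.random.randint(0, num_emissions, size = (1, num_trials, 150))

With that in place, the model itself is: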

import numpy as np
import pymc3 as pm
import theano.tensor as tt

with pm.Model() as model:
    # Dirichlet prior on the emission/spiking probabilities - 7 states
    # (1 start state, 2 from changepoint1 to changepoint2, 4 from changepoint2 to the end of the trial)
    p = pm.Dirichlet('p', np.ones(num_emissions), shape = (7, num_emissions))

    # Uniform switch times
    # First changepoint
    t1 = pm.Uniform('t1', lower = 20, upper = 60, shape = num_trials)
    # Second changepoint
    t2 = pm.Uniform('t2', lower = t1 + 20, upper = 130, shape = num_trials)

    # Get the actual state numbers based on the switch times
    states = []
    for i in range(num_trials):
        states1 = tt.switch(t1[i] >= np.arange(150), 0, set1[i])
        states2 = tt.switch(t2[i] >= np.arange(150), states1, set2[i])
        states.append(states2)

    # Log-likelihood: sum the log-probability of the observed emission in each
    # time bin under the state active in that bin
    def logp(value):
        value = tt.cast(value, 'int32')
        loglik = 0
        for i in range(num_trials):
            loglik += tt.sum(tt.log(p[states[i], value[i, :]]))
        return loglik

    # Categorical observations
    obs = pm.DensityDist('obs', logp, observed = {'value': data[0, :, :150]})

    # Inference button :D
    tr = pm.sample(500000, init = None, step = pm.Metropolis(), njobs = 2, start = {'t1': 25.0, 't2': 120.0})
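For comparison, the NUTS run that was giving ~4 s/it is the same model with the step method swapped out, along these lines (a sketch; I haven't kept the exact arguments I used):

    # Inside the same model context - NUTS instead of Metropolis
    with model:
        tr_nuts = pm.sample(2000, init = None, step = pm.NUTS(), njobs = 2,
                            start = {'t1': 25.0, 't2': 120.0})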

Initializing with find_MAP() sometimes gives errors with this model, so I stuck with the start values specified above; they generally seem to work well with Metropolis, although I have to draw on the order of 500-600k samples to get reasonable convergence (as mentioned above).
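The convergence values I quote are Gelman-Rubin statistics across the two chains, computed along these lines (sketch):

# R-hat across the njobs = 2 chains; values near 1.0 indicate convergence
print(pm.gelman_rubin(tr))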
