Skip to content

Commit

Permalink
Add option to pass existing logp_dlogp_function to sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Sep 17, 2020
1 parent 5d2f697 commit e4012b7
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions pymc3/step_methods/arraystep.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,20 +236,25 @@ def link_population(self, population, chain_index):

class GradientSharedStep(BlockedStep):
def __init__(self, vars, model=None, blocked=True,
dtype=None, **theano_kwargs):
dtype=None, logp_dlogp_func=None, **theano_kwargs):
model = modelcontext(model)
self.vars = vars
self.blocked = blocked

func = model.logp_dlogp_function(
vars, dtype=dtype, **theano_kwargs)
if logp_dlogp_func is None:
func = model.logp_dlogp_function(
vars, dtype=dtype, **theano_kwargs)
else:
func = logp_dlogp_func

# handle edge case discovered in #2948
try:
func.set_extra_values(model.test_point)
q = func.dict_to_array(model.test_point)
logp, dlogp = func(q)
except ValueError:
if logp_dlogp_func is not None:
raise
theano_kwargs.update(mode='FAST_COMPILE')
func = model.logp_dlogp_function(
vars, dtype=dtype, **theano_kwargs)
Expand Down

0 comments on commit e4012b7

Please sign in to comment.