Skip to content

Commit

Permalink
Use set_subtensor instead of inc_subtensor in Ordered transform
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Sep 13, 2023
1 parent 73dee50 commit eb7c3b6
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions pymc/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,14 @@ def __init__(self, ndim_supp=None):

def backward(self, value, *inputs):
x = pt.zeros(value.shape)
x = pt.inc_subtensor(x[..., 0], value[..., 0])
x = pt.inc_subtensor(x[..., 1:], pt.exp(value[..., 1:]))
x = pt.set_subtensor(x[..., 0], value[..., 0])
x = pt.set_subtensor(x[..., 1:], pt.exp(value[..., 1:]))
return pt.cumsum(x, axis=-1)

def forward(self, value, *inputs):
y = pt.zeros(value.shape)
y = pt.inc_subtensor(y[..., 0], value[..., 0])
y = pt.inc_subtensor(y[..., 1:], pt.log(value[..., 1:] - value[..., :-1]))
y = pt.set_subtensor(y[..., 0], value[..., 0])
y = pt.set_subtensor(y[..., 1:], pt.log(value[..., 1:] - value[..., :-1]))
return y

def log_jac_det(self, value, *inputs):
Expand Down

0 comments on commit eb7c3b6

Please sign in to comment.