BUG: sampling error when indexing over multiple dimensions #6380
Comments
Thanks for the clean reproducible bug report! Here is a more direct way to trigger the error:
model.compile_logp()(model.initial_point())
# AssertionError: SpecifyShape: Got shape (1, 3, 4), expected (3, 3, 4)
Tracking issue: pymc-devs/pytensor#98
For now, I think you can just remove the rightmost colons in the indexing expressions:
import pymc as pm
import numpy as np

rng = np.random.default_rng(5648)

# prior dimensions
N = 1
S = 2
T = 3

# indexers
a = np.repeat(np.arange(S), [1, 2])     # [0, 1, 1]
b = np.repeat(np.arange(T), [1, 2, 1])  # [0, 1, 1, 2]
K = len(a)
L = len(b)
y = rng.standard_normal((N, K, L))

with pm.Model() as model:
    mu = pm.Normal("mu", size=(N, S, T))
    mua = mu[:, a]      # shape (N, K, T)
    mub = mua[:, :, b]  # shape (N, K, L)
    y = pm.Normal("y", mub, sigma=1, observed=y)

model.compile_logp()(model.initial_point())
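As a sanity check on the shapes, the same two-step indexing can be reproduced with plain NumPy, using a hypothetical stand-in array `mu_np` in place of the random variable (a sketch, not part of the model). It also shows that a trailing full slice such as `mu[:, a, :]` selects exactly the same elements as `mu[:, a]`, so dropping trailing `:` slices does not change what the model computes.

```python
import numpy as np

# Dimensions and indexers taken from the reproducer above
N, S, T = 1, 2, 3
a = np.repeat(np.arange(S), [1, 2])     # [0, 1, 1]    -> K = 3
b = np.repeat(np.arange(T), [1, 2, 1])  # [0, 1, 1, 2] -> L = 4

# Hypothetical stand-in for the random variable `mu`, shape (N, S, T)
mu_np = np.arange(N * S * T, dtype=float).reshape(N, S, T)

# Two-step indexing, as in the model
mua = mu_np[:, a]    # (N, K, T)
mub = mua[:, :, b]   # (N, K, L)
print(mua.shape, mub.shape)  # (1, 3, 3) (1, 3, 4)

# A trailing full slice selects the same elements as omitting it
assert np.array_equal(mu_np[:, a, :], mu_np[:, a])
```

In NumPy the final shape is `(N, K, L)`, matching the observed data, which is consistent with the `SpecifyShape` failure being a shape-inference bug rather than a modeling mistake.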
This solved the problem, also in a more complex application! Lovely :)
Thanks for the bug report. Closing this, as it should be solved by pymc-devs/pytensor#101
Describe the issue:
When indexing (broadcasting) a random variable over multiple dimensions, sampling fails with a dimension-mismatch error. However, when the final array is evaluated, its shape matches that of the input array. The error does not occur when indexing is done over only one dimension. I'm not sure where the problem lies, or, preferably, whether there is a better way to accomplish this task.
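One possible alternative, sketched here in NumPy only, is to gather both dimensions in a single indexing step by broadcasting the two indexers against each other: `a[:, None]` has shape `(K, 1)` and `b` has shape `(L,)`, so together they select a `(K, L)` grid. The array `mu_np` is a hypothetical stand-in for the random variable; whether this form sidestepped the PyTensor shape bug in PyMC 4.4 is an untested assumption.

```python
import numpy as np

N, S, T = 1, 2, 3
a = np.repeat(np.arange(S), [1, 2])     # indexer for the S dimension, length K = 3
b = np.repeat(np.arange(T), [1, 2, 1])  # indexer for the T dimension, length L = 4

# Hypothetical stand-in for the random variable `mu`, shape (N, S, T)
mu_np = np.arange(N * S * T, dtype=float).reshape(N, S, T)

# Single-step gather: a[:, None] (K, 1) and b (L,) broadcast to (K, L),
# so element [n, k, l] is mu_np[n, a[k], b[l]] and the result is (N, K, L)
mub_single = mu_np[:, a[:, None], b]

# Reference: the two-step indexing used in the issue
mub_two_step = mu_np[:, a][:, :, b]

print(mub_single.shape)  # (1, 3, 4)
assert np.array_equal(mub_single, mub_two_step)
```

In the model this would read `mub = mu[:, a[:, None], b]`, assuming PyTensor's advanced indexing accepts broadcast integer index arrays the same way NumPy does.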
Reproducible code example:
Error message:
PyMC version information:
4.4
Context for the issue:
No response