-
Notifications
You must be signed in to change notification settings - Fork 268
[draft] Event dim labelling primitive for arviz dim labelling #2012
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
hey @fehiepsi I'm back with another complicated PR! this is more of a proposal than a PR, but let me know if you think its something numpyro is interested in supporting. Technically the code here works, but it probably has some unnecessary code right now, is missing some edge cases and tests, and I have some open questions I need to sort out |
return graph | ||
|
||
|
||
def get_site_dims( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
naming ideas appreciated, ie maybe it should be get_named_dims
instead?
return _get_dist_name(fn.base_dist) | ||
return type(fn).__name__ | ||
|
||
def get_trace(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is the same as the get_trace function in get_model_relations
. I could externalize it and call that within get_model_relations
and here to reduce code (at the cost of making these 2 inspection tools share a dependency)
self, | ||
name: str, | ||
size: int, | ||
subsample_size: Optional[int] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
still considering if size
and subsample_size
are needed. right now they dont do anything, but I could potentially use size
for some form of shape validation
yield | ||
|
||
|
||
class label_event_dim(plate): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
might be good to break the dependence on plate
and inherit Messenger instead - this was a quick hack to get things working
Thanks @kylejcaron! It is not clear to me why named event dimensions are needed. If arviz needs this, could we implement this in arviz instead? Current numpyro functionalities do not need named event dimensions, afaik. |
In addition, you can add more infos to the |
thanks for taking a look @fehiepsi! youre right numpyro itself wouldnt get additional functionality out of this, the idea would be to open up support for arviz to have automatically named dimensions (helpful for working with numpyro model output)
Ah this is a cool trick I didnt realize thank you! I just tried it out but the main problem with this approach is that it wouldnt work in cases where there are deterministic plates, which dont have an infer statement - these wouldnt be able to have named event dimensions Open to your suggestion here! Here are the options I can think of
|
I think we can add an argument in arviz.from_numpyro to specify event dimensions, like |
sounds good to me since its such a small edge case, I'll close this out - thanks for the feedback on this one! |
This PR isnt quite ready yet, but is this something numpyro would be interested in supporting?
Background
This PR would allow for az.from_numpyro to have a way to automatically get named dims based on all of the model plates (this issue arviz-devs/arviz#2022) and a new primitive for event dims.
ArviZ uses dims and coords to automatically created named dims in their xarray-based datasets. Currently, there is no way to automatically extract dims from a numpyro model, while PyMC does have this feature.
Numpyro plates provide everything we need to automatically extract dims from a model (name of the dimension, and the order of dimensions), but this doesnt cover event dimensions. For example with the ZeroSumNormal distribution, it would be ideal to have something like this:
so that Arviz knows that
alpha_country
has a dim named country.I went into more depth in this blog, but thats a bit long to read (just leaving it for additional context)
Current State
This is very much a draft - it gets the job done for making named event dims accessible for arviz, but I haven't put alot of thought into whether having a size argument is necessary (it probably isnt), and there might be some edge cases and unnecessary code in my implementation right now
Open Questions