Skip to content

Conversation

kylejcaron
Copy link
Contributor

@kylejcaron kylejcaron commented Mar 29, 2025

This PR isnt quite ready yet, but is this something numpyro would be interested in supporting?

Background

This PR would allow for az.from_numpyro to have a way to automatically get named dims based on all of the model plates (this issue arviz-devs/arviz#2022) and a new primitive for event dims.

ArviZ uses dims and coords to automatically created named dims in their xarray-based datasets. Currently, there is no way to automatically extract dims from a numpyro model, while PyMC does have this feature.

Numpyro plates provide everything we need to automatically extract dims from a model (name of the dimension, and the order of dimensions), but this doesnt cover event dimensions. For example with the ZeroSumNormal distribution, it would be ideal to have something like this:

with numpyro.event_dim_label("country", n_countries):
   alpha_country = dist.ZeroSumNormal(1, event_shape=(n_countries,))

so that Arviz knows that alpha_country has a dim named country.

I went into more depth in this blog, but thats a bit long to read (just leaving it for additional context)

Current State

This is very much a draft - it gets the job done for making named event dims accessible for arviz, but I haven't put alot of thought into whether having a size argument is necessary (it probably isnt), and there might be some edge cases and unnecessary code in my implementation right now

Open Questions

  • Is there any justification for having size or subsample arguments?
  • Is it worth having a function on numpyro's end for extracting site_dims, or should it all exist on ArviZ's side? Ie inspect.get_named_dims(…)
  • Is numpyro even open to having a primitive like this?

@kylejcaron
Copy link
Contributor Author

kylejcaron commented Mar 29, 2025

hey @fehiepsi I'm back with another complicated PR!

this is more of a proposal than a PR, but let me know if you think its something numpyro is interested in supporting.

Technically the code here works, but it probably has some unnecessary code right now, is missing some edge cases and tests, and I have some open questions I need to sort out

return graph


def get_site_dims(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

naming ideas appreciated, ie maybe it should be get_named_dims instead?

return _get_dist_name(fn.base_dist)
return type(fn).__name__

def get_trace():
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the same as the get_trace function in get_model_relations. I could externalize it and call that within get_model_relations and here to reduce code (at the cost of making these 2 inspection tools share a dependency)

self,
name: str,
size: int,
subsample_size: Optional[int] = None,
Copy link
Contributor Author

@kylejcaron kylejcaron Mar 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

still considering if size and subsample_size are needed. right now they dont do anything, but I could potentially use size for some form of shape validation

yield


class label_event_dim(plate):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might be good to break the dependence on plate and inherit Messenger instead - this was a quick hack to get things working

@fehiepsi
Copy link
Member

Thanks @kylejcaron! It is not clear to me why named event dimensions are needed. If arviz needs this, could we implement this in arviz instead? Current numpyro functionalities do not need named event dimensions, afaik.

@fehiepsi
Copy link
Member

In addition, you can add more infos to the infer field of numpyro.sample statements and inspect it via the trace handler.

@kylejcaron
Copy link
Contributor Author

Thanks @kylejcaron! It is not clear to me why named event dimensions are needed. If arviz needs this, could we implement this in arviz instead? Current numpyro functionalities do not need named event dimensions, afaik.

thanks for taking a look @fehiepsi! youre right numpyro itself wouldnt get additional functionality out of this, the idea would be to open up support for arviz to have automatically named dimensions (helpful for working with numpyro model output)

In addition, you can add more infos to the infer field of numpyro.sample statements and inspect it via the trace handler.

Ah this is a cool trick I didnt realize thank you! I just tried it out but the main problem with this approach is that it wouldnt work in cases where there are deterministic plates, which dont have an infer statement - these wouldnt be able to have named event dimensions

Open to your suggestion here! Here are the options I can think of

  1. Add a primitive/handler for labelling event dims in numpyro like in the PR (adds overhead to numpyro)
    1. Make an external primitive/handler for labelling numpyro event dims that lives in ArviZ (might be a little weird to call an external primitive/handler in a numpyro model? unsure)
  2. Add an infer argument to numpyro.determinstic (might be confusing for people using infer discrete)
  3. Use infer for event dims and miss the edge cases where theres a deterministic site with event dims (which is probably a rare occurence anyway)

@fehiepsi
Copy link
Member

fehiepsi commented Mar 31, 2025

I think we can add an argument in arviz.from_numpyro to specify event dimensions, like event_names={"x": ["group", "cat"]}. That would work for the deterministic case I think.

@kylejcaron
Copy link
Contributor Author

I think we can add an argument in arviz.from_numpyro to specify event dimensions, like event_names={"x": ["group", "cat"]}. That would work for the deterministic case I think.

sounds good to me since its such a small edge case, I'll close this out - thanks for the feedback on this one!

@kylejcaron kylejcaron closed this Mar 31, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants