Add experimental dims module with objects that follow dim-based semantics (like xarray) #7820

Draft
wants to merge 11 commits into
base: main
from

Conversation

ricardoV94
Member

@ricardoV94 ricardoV94 commented Jun 17, 2025

This builds on top of the PyTensor xtensor module to introduce distributions and model objects that follow xarray-like semantics. Example model:

import numpy as np
import pymc as pm
import pymc.dims as pmd

# Very realistic looking data!
observed_response_np = np.ones((5, 20), dtype=int)
coords = {
    "participant": range(5),
    "trial": range(20),
    "item": range(3),
}
with pm.Model(coords=coords) as dmodel:
    observed_response = pmd.Data(
        "observed_response", observed_response_np, dims=("participant", "trial")
    )

    # Participant constant preferences for each item
    participant_preference = pmd.ZeroSumNormal(
        "participant_preference", 
        core_dims="item", 
        dims=("participant", "item"),
    )

    # Shared time effects across all participants
    time_effects = pmd.Normal("time_effects", dims=("item", "trial"))

    trial_preference = pmd.Deterministic(
        "trial_preference",
        participant_preference + time_effects,
        dims=(...,),  # No need to specify, PyMC knows them
    )

    response = pmd.Categorical(
        "response",
        p=pmd.math.softmax(trial_preference, dim="item"),
        core_dims="item",
        observed=observed_response,
        dims=(...,), # No need to specify, PyMC knows them
    )

Equivalently, with the traditional API:

with pm.Model(coords=coords) as model:
    observed_response = pm.Data(
        "observed_response", observed_response_np, dims=("participant", "trial")
    )

    # Participant constant preferences for each item
    participant_preference = pm.ZeroSumNormal(
        "participant_preference",
        n_zerosum_axes=1,
        dims=("participant", "item"),
    )

    # Shared time effects across all participants
    time_effects = pm.Normal("time_effects", dims=("trial", "item"))

    trial_preference = pm.Deterministic(
        "trial_preference",
        participant_preference[:, None, :] + time_effects[None, :, :],
        dims=("participant", "trial", "item"),
    )

    response = pm.Categorical(
        "response",
        p=pm.math.softmax(trial_preference, axis=-1),
        observed=observed_response,
        dims=("participant", "trial"),
    )
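
To make the positional bookkeeping in the traditional version concrete, here is a minimal NumPy-only sketch of the same broadcasting and softmax steps. It does not use PyMC or `pymc.dims`; the random arrays and the `softmax` helper are stand-ins introduced for illustration only.

```python
import numpy as np

n_participant, n_trial, n_item = 5, 20, 3
rng = np.random.default_rng(0)

# Stand-ins for the random variables in the traditional model
participant_preference = rng.normal(size=(n_participant, n_item))  # (participant, item)
time_effects = rng.normal(size=(n_trial, n_item))                  # (trial, item)

# Positional broadcasting: the author has to insert the missing axes by hand
trial_preference = participant_preference[:, None, :] + time_effects[None, :, :]


def softmax(x, axis):
    """Numerically stable softmax along one axis (illustrative helper)."""
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)


# "item" happens to be the last axis, so axis=-1 normalizes over items
p = softmax(trial_preference, axis=-1)

# If time_effects were stored as (item, trial), the positional code would
# need an extra transpose to produce the same result
time_effects_it = time_effects.T  # (item, trial)
same = participant_preference[:, None, :] + time_effects_it.T[None, :, :]

print(trial_preference.shape)
```

With dim-based semantics the alignment is done by name, so neither the `None` insertions nor the transpose need to be written out.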

More details in the new core notebook

import datetime

day_of_conception = datetime.date(2025, 6, 17)
day_of_last_bug = datetime.date(2025, 6, 17)
today = datetime.date.today()
days_with_bugs = (day_of_last_bug - day_of_conception).days
Member

@twiecki twiecki Jun 18, 2025

wtf 😆

Member Author

This has two purposes: distract reviewers so they don't focus on the critical changes, and prove that OSS libraries can't be fun.

@ricardoV94 ricardoV94 force-pushed the model_with_dims branch 9 times, most recently from 1cfde5b to f571e5d Compare June 21, 2025 15:54
@twiecki
Member

twiecki commented Jun 21, 2025

Can this index using labels? x["a"]

@ricardoV94
Member Author

ricardoV94 commented Jun 21, 2025

> Can this index using labels? x["a"]

I don't know what x["a"] means :).

Is "a" a coordinate? x.loc["a"] would be the xarray syntax? You can't do that.

Like in xarray, you can do x.isel(dim=idxs) or x[{dim: idxs}].

You cannot do x.sel(dim=coords) or x.loc[coords]

The new PyTensor objects have dims but not coords. It's not trivial to encode coord-based operations in our backends.
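
For readers less familiar with the distinction, here is a small sketch in plain xarray (not the new `pymc.dims` objects) contrasting position-based and label-based indexing. The array, the labels, and the label-to-position translation at the end are illustrative assumptions; the translation is one possible workaround, not something this PR provides.

```python
import numpy as np
import xarray as xr

item_labels = ["a", "b", "c"]
x = xr.DataArray(
    np.arange(6).reshape(2, 3),
    dims=("participant", "item"),
    coords={"item": item_labels},
)

# Position-based indexing: only needs dims (the kind described as supported)
by_isel = x.isel(item=[0, 2])
by_dict = x[{"item": [0, 2]}]

# Label-based indexing: needs coords (xarray-only here)
by_sel = x.sel(item=["a", "c"])

# Possible workaround: translate labels to positions up front, then use isel
idxs = [item_labels.index(label) for label in ["a", "c"]]
by_translated = x.isel(item=idxs)
```

All four results select the same two item columns; only the first two forms rely purely on dimension names and integer positions.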

@ricardoV94 ricardoV94 force-pushed the model_with_dims branch 3 times, most recently from 3eb6738 to 67a0eda Compare June 30, 2025 07:20
@ricardoV94 ricardoV94 force-pushed the model_with_dims branch 2 times, most recently from 3234468 to d4017fb Compare June 30, 2025 12:54
@twiecki
Member

twiecki commented Jun 30, 2025

We should make this the 6.0 release.

@ricardoV94
Member Author

> We should make this the 6.0 release.

I agree, but I would perhaps wait until we have beta-tested it to the point where it no longer feels too experimental.

@ricardoV94 ricardoV94 changed the title Model with dims Add experimental dims module with objects that follow dim-based semantics (like xarray) Jun 30, 2025
@ricardoV94 ricardoV94 force-pushed the model_with_dims branch 4 times, most recently from f1f7478 to 56c05f1 Compare June 30, 2025 16:17
Member

@OriolAbril OriolAbril left a comment

Adding one extra comment about MyST syntax outside ReviewNB, because it always messes up the rendering of backticks.

I will try to play around and build/port some models one of these days, and look at the code itself while I do that.

Member

Inside the notebook you have to use myst syntax, so you'll have to replace the :reftype:`reftarget` with {reftype}`reftarget`
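
As a rough illustration of that substitution, the sketch below rewrites reStructuredText-style roles into the MyST form with a regular expression. The helper name `rst_roles_to_myst` and the pattern are hypothetical and only cover simple roles such as `:func:` or `:py:meth:`; treat the result as a starting point to review by hand.

```python
import re

# Matches :role:`target` or :domain:role:`target`, not preceded by a word
# character or backtick (to avoid touching unrelated colons)
_RST_ROLE = re.compile(r"(?<![\w`]):([A-Za-z]+(?::[A-Za-z]+)?):`([^`\n]+)`")


def rst_roles_to_myst(text: str) -> str:
    """Rewrite :reftype:`reftarget` as {reftype}`reftarget` (illustrative)."""
    return _RST_ROLE.sub(r"{\1}`\2`", text)


example = "See :func:`pymc.draw` and :py:meth:`Foo.dist`."
print(rst_roles_to_myst(example))
```

The target inside the backticks is left untouched, so a leading tilde in the target is carried over unchanged.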

@@ -0,0 +1,1666 @@
{
Member

@OriolAbril OriolAbril Jul 2, 2025

I think this is too high up in the notebook, since I envision it as something beginner users should not do. Still, it would be nice to show how to get the array values of an XTensorVariable into the corresponding DataArray, i.e. xr.DataArray(pm.draw(...), dims=outer_addition.type.dims). Maybe this could go into a dims FAQ? It could start at the bottom of this page and, if it grows, move to its own page.

Or start a global FAQ page with whichever questions/answers in the discord post are still relevant and add a pymc.dims section. Feel free to make an issue for this to keep the PR scope clear.
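
A sketch of what that recipe could look like, using only NumPy and xarray. The `values` array stands in for the output of `pm.draw(...)` and the `dims` tuple stands in for `outer_addition.type.dims`; both, along with the coordinate labels, are placeholders chosen for illustration.

```python
import numpy as np
import xarray as xr

# Stand-in for outer_addition.type.dims
dims = ("participant", "item")

# Stand-in for the model coords
coords = {"participant": np.arange(5), "item": ["a", "b", "c"]}

# Stand-in for pm.draw(outer_addition): a plain array with no labels attached
values = np.arange(5)[:, None] + np.arange(3)[None, :]

# Reattach dims (and, optionally, coords) on the xarray side
da = xr.DataArray(values, dims=dims, coords={d: coords[d] for d in dims})

# Label-based selection is now available on the drawn values
print(da.sel(participant=2, item="c").item())
```

The labels are attached only after the values have been computed, so the graph itself never needs to know about coordinates.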


Member Author

Good idea. I think we can sneak it in at the end, where I mention coordinates: add a "why are the variables not producing DataArrays directly" section with this recipe.

@@ -0,0 +1,1666 @@
{
Member

@OriolAbril OriolAbril Jul 2, 2025

Minor nits: it currently says isel won't be supported, instead of sel. Also, for the link to the dist method to work once we add API docs for the module, it needs to use the :meth: role.


@@ -0,0 +1,1666 @@
{
Member

@OriolAbril OriolAbril Jul 2, 2025

Add a tilde before pytensor.xtensor... (still inside the backticks) so it is rendered as values only instead of the full name. We should also double-check that the attribute is documented explicitly in the PyTensor docs; otherwise this won't be rendered as a cross-reference.


@@ -0,0 +1,1666 @@
{
Member

@OriolAbril OriolAbril Jul 2, 2025

Given we have just looked at the number of nodes as a measure of model complexity/efficiency, I think we should show that too for the dims_splines_model, to show it is roughly the same as the vectorized model, but much easier to read and (arguably?) easier to code too.


@@ -0,0 +1,1666 @@
{
Member

@OriolAbril OriolAbril Jul 2, 2025

I am skipping this last part for now to see if we converge on the custom object dim approach on pytensor which should make this a non-issue


@ricardoV94 ricardoV94 force-pushed the model_with_dims branch 3 times, most recently from 116f803 to a4273c0 Compare July 3, 2025 14:54