Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix contrib.funsor.Trace_EnumELBO model enumeration #3063

Merged
merged 1 commit into from
Apr 9, 2022

Conversation

ordabayevy
Copy link
Member

Pair coded with @fritzo

Resolves #3046

Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉

@ordabayevy
Copy link
Member Author

I've found out why the tests used to pass before with an older version of partial_sum_product. The innermost plates would still be eliminated even though they are not in the eliminate. It happens on the second to last line below:

# old partial_sum_product
    while ordinal_to_factors:
        leaf = max(ordinal_to_factors, key=len)
        ...
                if new_plates == leaf:
                    raise ValueError("intractable!")
                f = f.reduce(prod_op, leaf - new_plates)  # innermost "data" plate in test_elbo_enumerate_plate_1 eliminated here
                ordinal_to_factors[new_plates].append(f)

which is probably less intuitive then the current version of partial_sum_product where all eliminated plates have to be passed explicitly in eliminate.

@fritzo fritzo merged commit c531396 into dev Apr 9, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Some pyro.contrib.funsor tests are failing in torch==1.11.0
2 participants