Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for named dims (torchdim) #3347

Open
wants to merge 14 commits into
base: dev
Choose a base branch
from
Open

Add support for named dims (torchdim) #3347

wants to merge 14 commits into from

Conversation

ordabayevy
Copy link
Member

@ordabayevy ordabayevy commented Mar 26, 2024

Few notes about this PR

  • Importing functorch.dim messes up torch.Tensor a bit (e.g., the torch.Tensor.split method starts to fail). Therefore, I siloed the functorch.dim import from the main Pyro
    • That's why I also removed pyro/contrib/named from doctest in order to avoid the import of functorch.dim
    • And use hasattr(self.dim, "is_bound") instead of isinstance(self.dim, Dim)
  • Trace_ELBO implementation is actually similar to TraceGraph_ELBO with dependency tracking with the goal of having a single Trace_ELBO implementation that will generalize to TraceEnum_ELBO and others.
  • pyro.enable_validation needs to be set to False to avoid validation errors caused by if value.all() method
  • Distribution arguments all need to be bound by named dim, otherwise broadcasting attempt of parameters by a distribution will lead to segmentation fault

@ordabayevy ordabayevy added the WIP label Mar 27, 2024
]
if bind_named_dims:
result = result[bind_named_dims]
return result
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unit tests for distribution shapes for log_prob, mean, sample, rsample, entropy (fail when named and positional dims are mixed in the batch/event/sample shape; conflicting named dims)

Generalize named dim binding implementation.

Test transforms and support.

Shape inference.

@@ -215,3 +215,7 @@ def __init__(self, tensor):

def __getitem__(self, args):
return vindex(self._tensor, args)


def index_select(input, dim, index):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add type annotation. Move to contrib/named in the follow up PR.

@eb8680 eb8680 self-requested a review April 9, 2024 18:49
@fritzo fritzo removed their request for review April 16, 2024 16:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants