-
-
Notifications
You must be signed in to change notification settings - Fork 987
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for named dims (torchdim
)
#3347
base: dev
Are you sure you want to change the base?
Conversation
] | ||
if bind_named_dims: | ||
result = result[bind_named_dims] | ||
return result |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unit tests for distribution shapes for log_prob, mean, sample, rsample, entropy (fail when named and positional dims are mixed in the batch/event/sample shape; conflicting named dims)
Generalize named dim binding implementation.
Test transforms and support.
Shape inference.
@@ -215,3 +215,7 @@ def __init__(self, tensor): | |||
|
|||
def __getitem__(self, args): | |||
return vindex(self._tensor, args) | |||
|
|||
|
|||
def index_select(input, dim, index): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add type annotation. Move to contrib/named in the follow up PR.
Few notes about this PR
functorch.dim
messes uptorch.Tensor
a bit (e.g., thetorch.Tensor.split
method starts to fail). Therefore, I siloed thefunctorch.dim
import from the main Pyropyro/contrib/named
from doctest in order to avoid the import offunctorch.dim
hasattr(self.dim, "is_bound")
instead ofisinstance(self.dim, Dim)
Trace_ELBO
implementation is actually similar toTraceGraph_ELBO
with dependency tracking with the goal of having a singleTrace_ELBO
implementation that will generalize toTraceEnum_ELBO
and others.pyro.enable_validation
needs to be set toFalse
to avoid validation errors caused byif value.all()
methodDistribution
arguments all need to be bound by named dim, otherwise broadcasting attempt of parameters by a distribution will lead to segmentation fault