make ProvenanceTensor behave more like a Tensor (closes #3218) #3220
Conversation
The data is now stored in the Tensor itself instead of in an attribute. This fixes `torch.to_tensor` returning empty tensors when called with a ProvenanceTensor and a device as arguments.
This is important when using Tensors as keys in a dict, e.g. the Pyro param store.
Thanks for this subtle fix! This looks good to me (after one minor comment), but I'm unsure how this will interact with other subclasses of tensor.
@ordabayevy could you also take a look as you've thought about this before?
pyro/ops/provenance.py
Outdated
```diff
@@ -46,15 +46,21 @@ def __new__(cls, data: torch.Tensor, provenance=frozenset(), **kwargs):
     assert not isinstance(data, ProvenanceTensor)
     if not provenance:
         return data
-    return super().__new__(cls)
+    ret = data.view(data.shape)
+    ret._t = data.view(data.shape)  # this makes sure that detach_provenance always
```
Could this line be simplified to `ret._t = data`, or would that break something?
Right, thanks. Took me about four tries to get all the tests to pass, this was still a remnant of an earlier attempt.
Would you be able to add a regression test and decorate it with @requires_cuda? It won't run on CI, but it might help future maintainers.
LGTM, thanks for adding a test.
I'll leave this up a couple days before merging in case @ordabayevy has any comments.
Thanks for holding it up. I'll have a look at this later tonight.
@ilia-kats thanks for fixing this! What about trying to use:

```diff
 class ProvenanceTensor(torch.Tensor):
     def __new__(cls, data: torch.Tensor, provenance=frozenset(), **kwargs):
         assert not isinstance(data, ProvenanceTensor)
         if not provenance:
             return data
-        return super().__new__(cls)
+        return torch.Tensor._make_subclass(cls, data)
```

And I believe we can remove the instance check from the `__init__`:

```diff
 def __init__(self, data, provenance=frozenset()):
     assert isinstance(provenance, frozenset)
-    if isinstance(data, ProvenanceTensor):
-        provenance |= data._provenance
-        data = data._t
     self._t = data
     self._provenance = provenance
```
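Put together, the suggested `_make_subclass` version would look roughly like this. This is a simplified sketch for illustration, not Pyro's full implementation (the real class lives in `pyro/ops/provenance.py` and does more):

```python
import torch

class ProvenanceTensor(torch.Tensor):
    """Simplified sketch of the suggested class, not the full Pyro code."""

    def __new__(cls, data, provenance=frozenset(), **kwargs):
        assert not isinstance(data, ProvenanceTensor)
        if not provenance:
            return data  # no provenance: hand back the plain tensor unchanged
        # Store the data in the Tensor itself, so torch ops see real values.
        return torch.Tensor._make_subclass(cls, data)

    def __init__(self, data, provenance=frozenset()):
        assert isinstance(provenance, frozenset)
        self._t = data
        self._provenance = provenance

x = torch.randn(4)
# With empty provenance, the input tensor is returned as-is:
assert ProvenanceTensor(x) is x
# With provenance, the subclass instance behaves like the underlying data:
p = ProvenanceTensor(x, provenance=frozenset({"x"}))
assert p.shape == x.shape
assert p._provenance == frozenset({"x"})
```

Note that when `__new__` returns the plain `data` tensor (empty provenance), `__init__` is never called, which is why the instance check there becomes unnecessary.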
also remove unnecessary check in `__init__`
@ordabayevy Thanks for the comment. I actually
@ordabayevy ready to merge? I'll release today or tomorrow and will include this PR in the release.
Yeah, lgtm.
This is super hacky, but I couldn't come up with a cleaner way. Note that this is the only way to use `pyro.infer.inspect.get_dependencies` when training on GPUs (I'm using it in a custom `Messenger` guide), since the `log_prob` function of some distributions (for example `Gamma`) calls `torch.to_tensor`.
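The empty-tensor symptom can be reproduced with a toy subclass. The description above says `torch.to_tensor`; I'm assuming the public `torch.as_tensor(data, device=...)` entry point here, and `OldStyle` is a hypothetical stand-in for the pre-fix behavior, not Pyro's actual class:

```python
import torch

class OldStyle(torch.Tensor):
    # Hypothetical stand-in for the pre-fix ProvenanceTensor:
    # the Tensor itself is empty and the real data only lives in ._t.
    def __new__(cls, data):
        ret = super().__new__(cls)
        ret._t = data
        return ret

class NewStyle(torch.Tensor):
    # Post-fix behavior: the data is stored in the Tensor itself.
    def __new__(cls, data):
        return torch.Tensor._make_subclass(cls, data)

x = torch.ones(2)
# Converting with an explicit device only sees the Tensor's own storage,
# so the old approach yields an empty result:
assert torch.as_tensor(OldStyle(x), device="cpu").numel() == 0
# With the data in the Tensor itself, the values round-trip correctly:
assert torch.as_tensor(NewStyle(x), device="cpu").numel() == 2
```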