Add broadcast_shapes() function and use it in MultivariateNormal #43935
Conversation
@albanD should I move this implementation into C++? Can you or anyone provide guidance on ensuring jit compatibility?
💊 CI failures summary and remediations — as of commit 3047d7c (more details on the Dr. CI page): ✅ None of the CI failures appear to be your fault 💚
🚧 1 ongoing upstream failure: probably caused by upstream breakages that are not fixed yet.
🚧 1 fixed upstream failure: probably caused by upstream breakages that were already fixed. Please rebase on the
Code like this would break:
So a C++ implementation is probably needed.
Just checked that the previous version also bombs on such user code, so it was a bad example 🙂
The new function looks quite good to me.
What is the code sample printing? I'm not sure I understand what the issue is here — tracing is not working properly?
@albanD, I'm going to assume you've got this; shout if you need some more eyes.
@albanD Yes, tracing gets broken with this implementation. The snippet will not print anything but will crash with a bad expand.
@suo any idea who would be the best person to debug this tracing issue?
torch/functional.py
Outdated
```python
result = [1] * max(map(len, shapes))
for shape in shapes:
    for i in range(-1, -1 - len(shape), -1):
        if shape[i] == 1 or shape[i] == result[i]:
            continue
        if result[i] != 1:
            raise RuntimeError("Shape mismatch: objects cannot be broadcast to a single shape")
        result[i] = shape[i]
```
I'm not sure moving to C++ would help.
Also, this is just working with sizes, right? Not Tensors?
Here we're working with sizes, but `torch.Size.__getitem__` returns a zero-dimensional `torch.Tensor` during tracing, IIRC.
The point is that we can pass a tensor with batch shape 1 as the tracer argument, which selects the `continue` branch of the control flow; then, when the trace is replayed with a tensor whose batch shape is greater than one, the expansion will not trigger, leading to a wrong resulting shape.
The same applies to the exception raising a few lines below.
It seems that adding this function to TorchScript might work... However I am not sure if it will work with shapes.
Ah, I see: because the function does not use Tensors at all, it does nothing from the point of view of tracing, and the size that is computed is just used as a constant.
The old code, which used only Tensor-based functions, did not have this issue.
I could see two options here:
- special-case the MultivariateNormal code to use the more efficient path only when not tracing
- make this work with tracing, but I don't think that would be possible...
@suo might have a better idea?
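A minimal sketch of that constant-folding pitfall (my own illustrative example, not code from this PR): a size computed with plain Python arithmetic is frozen into the trace, so replaying the trace with differently shaped inputs keeps the old value.

```python
import torch

def fixed_rank_zeros(a):
    # len(a.shape) is a plain Python int, so the tracer records it as a constant
    n = len(a.shape)
    return torch.zeros(n)

traced = torch.jit.trace(fixed_rank_zeros, (torch.ones(2, 3),))

# Replaying with an input of a different rank still yields a length-2 result,
# because n was baked in at trace time.
print(traced(torch.ones(5, 6, 7)).shape)
```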
It is not exactly a constant; it seems tracing accounts for sizes as well:

```python
def foo(a):
    return a.shape[0]

traced = torch.jit.trace(foo, (torch.ones(3, 4, 5),))
traced(torch.ones(17))        # tensor(17)
traced(torch.zeros(7, 6, 5))  # tensor(7)
```

That is why I thought it might be possible to implement such a function in C++. However, there's clearly a limitation on a traced function's inputs. Hopefully @suo can help.
Thanks for the code sample; I see what you mean now by "`torch.Size.__getitem__` returns torch.Tensor with zero dimension IIRC". That is a dirty trick...
So I think the regular limitations of tracing apply here: the if statements and for loop won't be properly captured :/
It sounds to me like the easiest solution is to move this function into C++.
On second thought, the most expedient solution is to settle on an interface and use a workaround that creates a single-element temporary tensor. We can later move this to C++ if needed.
Can someone point me to the C++ files I would need to update to move this to C++ and thereby support usage in jit tracing and scripting? I believe I will need to:
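For reference, the single-element temporary-tensor workaround can be sketched roughly like this (a reconstruction, not the exact PR diff): routing the shape arithmetic through tensor ops makes every step visible to the tracer.

```python
import torch

def broadcast_shapes(*shapes):
    # Expand a scalar tensor to each shape and let broadcast_tensors do the
    # shape arithmetic; every step is a tensor op, so tracing records them.
    with torch.no_grad():
        scalar = torch.zeros((), device="cpu")
        tensors = [scalar.expand(shape) for shape in shapes]
        tensors = torch.broadcast_tensors(*tensors)
    return tensors[0].shape

print(broadcast_shapes((2, 1), (1, 3)))  # torch.Size([2, 3])
```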
You can add a Python binding here to make it show up in torch._C
Not sure what registration you are referring to, could you clarify?
@gmagogsfm This function needs to be not only added as a general binding (in that case it is basically the same as putting it in torch.functional, but in C++), but also registered as a TorchScript function so that tracing works. I looked through the codebase and it seems that the latter includes adding the function to
Hi @fatvlady, sorry, this somehow got off my radar. It is still not clear to me what you are looking for — do you want to chat over VC or Slack to clarify?
Hi @fritzo! Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention. You currently have a record in our system, but we do not have a signature on file. In order for us to review and merge your code, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!
@gmagogsfm Sorry that my response took so long. We can discuss that on Slack, but I am not invited to the PyTorch Slack. IIRC we need to add this function to the C++ registered operations so it is not inlined along the execution path when being traced with
torch/functional.py
Outdated
```python
    RuntimeError: If shapes are incompatible.
    """
    # TODO Consider moving this to C++ once the jit has better support for torch.Size.
    scalar = torch.zeros((), device="cpu")
```
I am not sure, but maybe it is better to wrap this in `with torch.no_grad()` here so that no extra autograd tracing is performed?
Also, looking at the original `broadcast_shapes`, there's an additional `handle_torch_function` call when scripting is disabled. I am not sure if there's some kind of function handler for `torch.Size`, so probably the only check needed is that the arguments are of type `torch.Size`. On the other hand, tuples of ints are fine as arguments too, so I'm hesitating here. What do you think?
Sure, I've added `torch.no_grad()`.
I don't understand the purpose of `handle_torch_function`. However, jit usage should be exercised in the jit tests of `MultivariateNormal`, so I guess we can see if all tests pass?
I think the distribution tests use `jit.trace`, so `jit.script` is uncovered. However, that means distributions might not be working with `jit.script` aside from this PR :)
@fritzo Thanks for reviving this PR! I added a few minor comments on your changes; otherwise it looks good to me 🙂
@fatvlady thanks for reviewing! I believe the remaining test failures are unrelated.
@neerajprad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@neerajprad merged this pull request in 313e77f.
…orch#43935)

Summary: Fixes pytorch#43837

This adds a `torch.broadcast_shapes()` function similar to Pyro's [broadcast_shape()](https://github.com/pyro-ppl/pyro/blob/7c2c22c10dffda8a33ffbd593cc8d58819959e40/pyro/distributions/util.py#L151) and JAX's [lax.broadcast_shapes()](https://jax.readthedocs.io/en/test-docs/_modules/jax/lax/lax.html). This helper is useful e.g. in multivariate distributions that are parameterized by multiple tensors where we want to `torch.broadcast_tensors()` but the parameter tensors have different "event shape" (e.g. mean vectors and covariance matrices). This helper is already heavily used in Pyro's distribution codebase, and we would like to start using it in `torch.distributions`.

- [x] refactor `MultivariateNormal`'s expansion logic to use `torch.broadcast_shapes()`
- [x] add unit tests for `torch.broadcast_shapes()`
- [x] add docs

cc neerajprad

Pull Request resolved: pytorch#43935
Reviewed By: bdhirsh
Differential Revision: D25275213
Pulled By: neerajprad
fbshipit-source-id: 1011fdd597d0a7a4ef744ebc359bbb3c3be2aadc
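The merged API can be exercised as follows (assuming a PyTorch build that includes this PR, i.e. version 1.8 or later):

```python
import torch

# NumPy-style broadcasting over bare shapes; no tensors allocated by the caller
shape = torch.broadcast_shapes((2, 1), (3,), (1, 1, 1))
print(shape)  # torch.Size([1, 2, 3])

# Incompatible shapes raise a RuntimeError
try:
    torch.broadcast_shapes((2,), (3,))
except RuntimeError as e:
    print("incompatible:", e)
```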