Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add broadcast_shapes() function and use it in MultivariateNormal #43935

Closed
wants to merge 13 commits into from

Conversation

fritzo
Copy link
Collaborator

@fritzo fritzo commented Sep 1, 2020

Fixes #43837

This adds a torch.broadcast_shapes() function similar to Pyro's broadcast_shape() and JAX's lax.broadcast_shapes(). This helper is useful e.g. in multivariate distributions that are parameterized by multiple tensors and we want to torch.broadcast_tensors() but the parameter tensors have different "event shape" (e.g. mean vectors and covariance matrices). This helper is already heavily used in Pyro's distribution codebase, and we would like to start using it in torch.distributions.

Tasks

  • add a torch.broadcast_shapes() function similar to Pyro's broadcast_shape().
  • refactor MultivariateNormal's expansion logic to use torch.broadcast_shapes()
  • add unit tests for torch.broadcast_shapes()
  • add docs

cc @neerajprad

@fritzo
Copy link
Collaborator Author

fritzo commented Sep 1, 2020

@albanD should I move this implementation into C++? Can you or anyone provide guidance on ensuring jit compatibility?

@dr-ci
Copy link

dr-ci bot commented Sep 1, 2020

💊 CI failures summary and remediations

As of commit 3047d7c (more details on the Dr. CI page):


None of the CI failures appear to be your fault 💚



🚧 1 ongoing upstream failure:

These were probably caused by upstream breakages that are not fixed yet:


🚧 1 fixed upstream failure:

These were probably caused by upstream breakages that were already fixed.

Please rebase on the viable/strict branch (expand for instructions)

If your commit is newer than viable/strict, you can try basing on an older, stable commit:

git fetch https://github.com/pytorch/pytorch viable/strict
git rebase --onto FETCH_HEAD $(git merge-base origin/master HEAD)

If your commit is older than viable/strict:

git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD

Check out the recency history of this "viable master" tracking branch.


This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group.

See how this bot performed.

This comment has been revised 53 times.

@fatvlady
Copy link

fatvlady commented Sep 1, 2020

Such code will be broken:

    def foo(loc, cov, target):
        dist = MultivariateNormal(loc=loc, covariance_matrix=cov)
        return dist.log_prob(target)

    loc = torch.randn(1000, 800)
    target = torch.randn(1000, 800)
    cov = target.transpose(-1, -2) @ target / 1000
    loc2 = loc[0, :]
    target2 = target[0, :]
    foo_traced = torch.jit.trace(foo, (loc2, cov, target2))
    print(foo_traced(loc, cov, target).shape)

So C++ implementation is probably needed.

@fatvlady
Copy link

fatvlady commented Sep 1, 2020

Just checked that previous version also bombs with such user code, so it was a bad example 🙂
It seems that it is current limitation of distributions that we cannot introduce new dimensions?
Anyway, here's corrected snippet which shows potential regression:

    def foo(loc, cov, target):
        dist = MultivariateNormal(loc=loc, covariance_matrix=cov)
        return dist.log_prob(target), dist.loc

    loc = torch.randn(1000, 800)
    target = torch.randn(1000, 800)
    cov = target.transpose(-1, -2) @ target / 1000
    loc2 = loc[:1, :]
    target2 = target[:1, :]
    foo_traced = torch.jit.trace(foo, (loc2, cov, target2))
    print(foo_traced(loc, cov, target)[0].shape)
    print(foo_traced(loc, cov, target)[1].shape)

Copy link
Collaborator

@albanD albanD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new function looks quite good to me.

What is the code sample printing? I am not sure to understand what the issue is here? Tracing is not working properly?

torch/functional.py Outdated Show resolved Hide resolved
torch/functional.py Outdated Show resolved Hide resolved
@ezyang
Copy link
Contributor

ezyang commented Sep 1, 2020

@albanD, I'm going to assume you've got this, shout if you need some more eyes.

@ezyang ezyang added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Sep 1, 2020
@fatvlady
Copy link

fatvlady commented Sep 1, 2020

@albanD Yes,tracing gets broken with this implementation. The snippet will not print anything but crash with bad expand.

@albanD
Copy link
Collaborator

albanD commented Sep 1, 2020

@suo any idea who would be the best person to debug this tracing issue?

result = [1] * max(map(len, shapes))
for shape in shapes:
for i in range(-1, -1 - len(shape), -1):
if shape[i] == 1 or shape[i] == result[i]:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@albanD @fritzo I am pretty sure tracing issue comes from Tensor-to-bool conversion here. I am not sure if it is possible to fix that in python, I am not very familiar with the codebase.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure moving to c++ would help.
Also this is just working with sizes right? Not Tensor?

Copy link

@fatvlady fatvlady Sep 1, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we're working with sizes, but torch.Size.__getitem__ returns torch.Tensor with zero dimension IIRC.
The point here is that we can pass tensor with batch shape 1 as tracer argument, which will select continue control flow, and during trace replay with tensor whose batch shape is greater than one, this expansion will not trigger, leading to wrong resulting shape.
Same applies to exception raising few lines below.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that adding this function to TorchScript might work... However I am not sure if it will work with shapes.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ho i see, because the function does not use Tensors at all, it does nothing from the point of view of tracing and the size that is computed is just used as a constant.
While the old code that was using only Tensor based functions did not had this issue.

I could see two options here:

  • special case the MultivariateNormal code to use the more efficient code only if you're not tracing
  • make this work with tracing, but I don't think that would be possible...

@suo might have a better idea?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not exactly constant, seems like tracing accounts for sizes as well:

def foo(a):
    return a.shape[0]
traced = torch.jit.trace(foo, (torch.ones(3,4,5),))
traced(torch.ones(17))
>>>: tensor(17)
traced(torch.zeros(7,6,5))
>>>: tensor(7)

That is why I thought that it might be possible to make such function in C++. However, there's clearly a limitation on a traced function inputs. Hopefully @suo can help.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the code sample, I see what you mean now by "torch.Size.getitem returns torch.Tensor with zero dimension IIRC". That is a dirty trick...
So I think that the regular limitations of tracing apply here: the if statements and for loop won't be properly considered :/

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It sounds to me like the easiest solution is to move this function into C++.

Copy link
Collaborator Author

@fritzo fritzo Nov 30, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On second thought, the most expedient solution is to settle on an interface and use a workaround that creates a single-element temporary tensor. We can later move this to C++ if needed.

@fritzo
Copy link
Collaborator Author

fritzo commented Sep 3, 2020

Can someone point me to the C++ files I would need to update to move this to C++ and thereby support usage in jit tracing and scripting? I believe I will need to:

  • add an implementation somewhere
  • add a function registration so this shows up in torch._C
  • add some jit registrations

@gmagogsfm
Copy link
Contributor

gmagogsfm commented Sep 3, 2020

Can someone point me to the C++ files I would need to update to move this to C++ and thereby support usage in jit tracing and scripting? I believe I will need to:

  • add an implementation somewhere
  • add a function registration so this shows up in torch._C

You can add a python binding here to make it show up in torch._C

  • add some jit registrations

Not sure what registration you are referring to, could you clarify?

@fatvlady
Copy link

fatvlady commented Sep 4, 2020

@gmagogsfm This function needs to be not only added as general binding (in that case it is basically the same as putting it in torch.functional, but in C++), but also registered as TorchScript function so that tracing works. I looked through the codebase and it seems that the latter includes adding function to native_functions.yaml and also registering dispatch. I am not sure if there's a guide on native extensions or example to follow.

@gmagogsfm
Copy link
Contributor

@gmagogsfm This function needs to be not only added as general binding (in that case it is basically the same as putting it in torch.functional, but in C++), but also registered as TorchScript function so that tracing works. I looked through the codebase and it seems that the latter includes adding function to native_functions.yaml and also registering dispatch. I am not sure if there's a guide on native extensions or example to follow.

Hi @fatvlady ,

Sorry this somehow got off my radar. It is still not clear to me what you are looking for, do you want to chat over VC or slack to clarify?

@facebook-github-bot
Copy link
Contributor

Hi @fritzo!

Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but we do not have a signature on file.

In order for us to review and merge your code, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!

@fatvlady
Copy link

@gmagogsfm Sorry that my response took that long. We can discuss that on a Slack but I am not invited to PyTorch Slack. IIRC we need to add this function to C++ registered operations so it is not inlined along execution path when being traced with torch.jit.trace. When this function is in python, it will select dimensions with size > 1, forcing traced function to look only at these dimensions. If there is a dimension of size 1 in trace inputs it will be "ignored" in jitted code, leading to crash of jitted module when inputs have shape > 1 over this dimension.

@fritzo fritzo changed the title Add broadcast_shape() function and use it in MultivariateNormal Add broadcast_shapes() function and use it in MultivariateNormal Nov 30, 2020
@fritzo
Copy link
Collaborator Author

fritzo commented Nov 30, 2020

@fatvlady I've renamed and updated to the simple version we discussed.

torch/distributions/multivariate_normal.py Outdated Show resolved Hide resolved
RuntimeError: If shapes are incompatible.
"""
# TODO Consider moving this to C++ once the jit has better support for torch.Size.
scalar = torch.zeros((), device="cpu")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure but maybe it is better to add here with torch.no_grad() so that no extra autograd tracing is performed?
Also, looking at original broadcast_shapes, there's additional handle_torch_function call when scripting is disabled. I am not sure if there's some kind of function handler for torch.Size so probably the only check that is needed is that arguments are of type torch.Size. On the other hand, tuples of ints are ok arguments too, so I'm hesitating here. What do you think?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I've added torch.no_grad().

I don't understand the purpose of the handle_torch_function. However jit usage should be exercised in the jit tests of MultivariateNormal, so I guess we can see if all tests pass?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think distribution tests use jit.trace so jit.script is uncovered. However it means that distributions might be not working with jit.script aside from this PR :)

@fatvlady
Copy link

@fritzo Thanks for reviving this PR! I added few minor comments to your changes, otherwise looks good to me 🙂

@fritzo
Copy link
Collaborator Author

fritzo commented Dec 1, 2020

@fatvlady thanks for reviewing! I believe remaining test failures are unrelated.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@neerajprad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@neerajprad merged this pull request in 313e77f.

shaibagon pushed a commit to shaibagon/pytorch that referenced this pull request Dec 3, 2020
…orch#43935)

Summary:
Fixes pytorch#43837

This adds a `torch.broadcast_shapes()` function similar to Pyro's [broadcast_shape()](https://github.com/pyro-ppl/pyro/blob/7c2c22c10dffda8a33ffbd593cc8d58819959e40/pyro/distributions/util.py#L151) and JAX's [lax.broadcast_shapes()](https://jax.readthedocs.io/en/test-docs/_modules/jax/lax/lax.html). This helper is useful e.g. in multivariate distributions that are parameterized by multiple tensors and we want to `torch.broadcast_tensors()` but the parameter tensors have different "event shape" (e.g. mean vectors and covariance matrices). This helper is already heavily used in Pyro's distribution codebase, and we would like to start using it in `torch.distributions`.

- [x] refactor `MultivariateNormal`'s expansion logic to use `torch.broadcast_shapes()`
- [x] add unit tests for `torch.broadcast_shapes()`
- [x] add docs

cc neerajprad

Pull Request resolved: pytorch#43935

Reviewed By: bdhirsh

Differential Revision: D25275213

Pulled By: neerajprad

fbshipit-source-id: 1011fdd597d0a7a4ef744ebc359bbb3c3be2aadc
@facebook-github-bot facebook-github-bot deleted the broadcast-shape branch January 27, 2021 18:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed Merged module: distributions Related to torch.distributions open source triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

MultivariateNormal backprop performance issue related to broadcasting
8 participants