
Conversation

ColinPeppler
Contributor

@ColinPeppler ColinPeppler commented Apr 2, 2024

Context

Suppose we have two symbols, u0 and s0, where we know that u0 = s0. Now, let's say we try to look up the size hint for u0 + 1.

  • Before this PR, we would use a fallback hint if one was provided.

    def symbolic_hint(self, expr: Expr) -> Expr:
        # Substitute all hints into expr, but leave unbacked symints alone

  • With this PR, we would try to replace u0 with s0 via simplify() before using a fallback hint (a standalone sketch follows after this list).

    def simplify(self, expr: Expr):
        return sympy.expand(expr).xreplace(self.replacements)
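
To make the idea concrete, here is a minimal standalone sketch using plain sympy. It is not the SizeVarAllocator code; the `replacements` and `var_to_val` dicts below simply stand in for the ShapeEnv's recorded equalities and hints.

```python
import sympy

u0, s0 = sympy.Symbol("u0"), sympy.Symbol("s0")
replacements = {u0: s0}                  # we learned that u0 == s0
var_to_val = {s0: sympy.Integer(7)}      # hint for the backed symbol

expr = u0 + 1

# Before: substituting hints directly leaves u0 untouched, so the
# expression stays symbolic and the caller must use the fallback hint.
hint_before = expr.xreplace(var_to_val)                  # u0 + 1

# After: simplify first (replace u0 with s0), then substitute hints.
simplified = sympy.expand(expr).xreplace(replacements)   # s0 + 1
hint_after = simplified.xreplace(var_to_val)             # 8
print(hint_before, simplified, hint_after)
```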

Concrete Example

A scenario where this is useful is when we're running autotuning benchmarking on bmm with two input nodes: one whose batch size is s0 and one whose batch size is u0. During benchmarking, we create two example input tensors, and the input with u0 has to use a fallback hint for its batch size. This leads to a mismatch.

example_inputs_extern = [
    torch.as_strided(
        unique_example_inputs[input_node.get_name()],
        V.graph.sizevars.size_hints(
            input_node.get_size(),
            fallback=config.unbacked_symint_fallback,
        ),

Using the fallback hint (i.e. 8192) leads to a batch size mismatch.

# Note: s0 = 7 and u0 = 7 and fallback hint is 8192.
LoweringException: ErrorFromChoice: Expected size for first two dimensions of batch2 tensor to be: [7, 30] but got: [8192, 30].
From choice ExternKernelCaller(extern_kernels.bmm)
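
For illustration, here is a hypothetical repro sketch in the spirit of the scenario above. It is not the PR's test; the shapes, the use of mark_dynamic, capture_scalar_outputs, max-autotune, and the CUDA device are assumptions chosen so that the batch size is a backed symint (s0) on one input and an unbacked symint (u0) on the other.

```python
import torch
import torch._dynamo

# Assumption: capture .item() results as unbacked symints (u0).
torch._dynamo.config.capture_scalar_outputs = True

def fn(x, y, n):
    u0 = n.item()
    torch._check(u0 == x.size(0))     # tells the ShapeEnv that u0 == s0
    batch2 = y.expand(u0, 16, 32)     # batch dim of this input is u0
    return torch.bmm(x, batch2)       # autotuned bmm sees s0- and u0-sized inputs

x = torch.randn(7, 30, 16, device="cuda")
y = torch.randn(16, 32, device="cuda")
n = torch.tensor(7, device="cuda")
torch._dynamo.mark_dynamic(x, 0)      # batch size of x becomes a backed symint (s0)

compiled = torch.compile(fn, mode="max-autotune")
out = compiled(x, y, n)  # before this PR, benchmarking could hit the 8192 fallback for u0
```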

Differential Revision: D55619331

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @amjames @desertfire @chauhang

@pytorch-bot

pytorch-bot bot commented Apr 2, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/123140

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 35fcbbb with merge base f15fd65:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D55619331

ColinPeppler added a commit to ColinPeppler/pytorch that referenced this pull request Apr 2, 2024
Summary: Pull Request resolved: pytorch#123140

Test Plan: tbd

Differential Revision: D55619331
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D55619331

@ColinPeppler ColinPeppler changed the title from "[inductor] simplify expr before looking up hint" to "[inductor] simplify expr when looking up size hint" Apr 2, 2024
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D55619331

@ColinPeppler
Contributor Author

Hi @peterbell10, is this change okay, or does this look like a fix that should be happening in Dynamo?

Contributor

@aakhundov aakhundov left a comment

Thanks for the fix!

Contributor

Seems the bmm's shape is wrong in the comment? Should be [s0, 16, 32]?

Contributor Author

woops forgot to update lol

Contributor

Could you add a comment above this line mentioning that here s0 and u0 are unified?

@pytorch-bot pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Apr 3, 2024
@aakhundov
Contributor

> or does this look like a fix that should be happening in Dynamo?

I think it's fine to do the replacements inside size_hint? cc @ezyang.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D55619331

Summary:
## Context

Suppose we have two symbols: `u0` and `s0` where we know that `u0 = s0`. Now, let's say we tried to look up the size hint for `u0 + 1`. 
* Before this PR, we would use a fallback hint if one was provided.
https://github.com/pytorch/pytorch/blob/3f6acf65fd9b6094513cf28898a42b90dd1169a0/torch/_inductor/sizevars.py#L406-L407

* With this PR, we would try to replace `u0` with `s0` via `simplify()` before using a fallback hint. https://github.com/pytorch/pytorch/blob/3f6acf65fd9b6094513cf28898a42b90dd1169a0/torch/_inductor/sizevars.py#L46-L47

## Concrete Example
A scenario where this is useful is when we're running autotuning benchmarking on bmm with two input nodes: one who has `s0` as the batch size and one who has `u0` as the batch size. During benchmarking, we'll create two example input tensors where the input with `u0` has to use a fallback hint for batch size. This will lead to a mismatch.

https://github.com/pytorch/pytorch/blob/e3d80f2fa98d7ab02f88023d381b2e5981dd99ff/torch/_inductor/select_algorithm.py#L991-L997

Using the fallback hint (i.e. 8192) leads to a batch size mismatch.


```python
# Note: s0 = 7 and u0 = 7 and fallback hint is 8192.
LoweringException: ErrorFromChoice: Expected size for first two dimensions of batch2 tensor to be: [7, 30] but got: [8192, 30].
From choice ExternKernelCaller(extern_kernels.bmm)
```




Test Plan:
CI

```
$ CUDA_VISIBLE_DEVICES=0 python test/inductor/test_unbacked_symints.py -k test_equivalent_backed_unbacked_cuda

### Before ###
  File "torch/_inductor/select_algorithm.py", line 964, in __call__
    timings = do_autotuning(precompile_fn)
  File "torch/_inductor/select_algorithm.py", line 911, in do_autotuning
    timings = self.lookup(
  File "torch/_inductor/codecache.py", line 306, in lookup
    raise e
  File "torch/_inductor/codecache.py", line 297, in lookup
    timings = benchmark(choices)
  File "torch/_inductor/select_algorithm.py", line 897, in autotune
    return make_benchmark_fn()(choices)
  File "torch/_inductor/select_algorithm.py", line 1068, in benchmark_in_current_process
    raise ErrorFromChoice(msg, choice, debug_str())  # noqa: TRY200

torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: ErrorFromChoice: Expected size for first two dimensions of batch2 tensor to be: [7, 30] but got: [8192, 30].

From choice ExternKernelCaller(extern_kernels.bmm)
inputs = [
    torch.empty_strided((7, 30, 16), (480, 16, 1), dtype=torch.float32, device='cuda'),
    torch.empty_strided((30, 32), (32, 1), dtype=torch.float32, device='cuda'),
]


### After ###
----------------------------------------------------------------------
Ran 1 test in 4.627s

OK
```

Reviewed By: tissue3, aakhundov

Differential Revision: D55619331
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D55619331

@ColinPeppler
Contributor Author

CI passes in OSS and internally. Also, we were worried that self.simplify(expr) would fail if expr is an int, but sympy.expand(expr) handles the case where expr is an int.
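
For what it's worth, a quick standalone illustration of that point with plain sympy (the empty replacement dict here is just a stand-in):

```python
import sympy

# sympy.expand() sympifies a plain Python int, so passing an int through
# simplify()-style code doesn't crash:
expr = sympy.expand(7)      # Integer(7)
expr = expr.xreplace({})    # xreplace works because expr is now a sympy object
print(expr, type(expr))     # 7 <class 'sympy.core.numbers.Integer'>
```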

@ColinPeppler
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team: raised by workflow job

    return sympy_subs(expr, self.var_to_val)

def size_hint(self, expr: Expr, *, fallback: Optional[int] = None) -> int:
    expr = self.simplify(expr)
Collaborator

This should be in symbolic_hint, but otherwise this LGTM.
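
For context, a rough sketch of what the reviewer is suggesting, i.e. simplifying inside symbolic_hint rather than size_hint. This is pieced together from the snippets quoted in this PR and is not the actual sizevars.py source:

```python
def symbolic_hint(self, expr: Expr) -> Expr:
    # Substitute all hints into expr, but leave unbacked symints alone
    expr = self.simplify(expr)       # resolves u0 -> s0 via recorded replacements
    if not isinstance(expr, Expr):   # simplify() may already yield a plain int
        return expr
    return sympy_subs(expr, self.var_to_val)
```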

Collaborator

@ColinPeppler this wasn't fixed

@clee2000
Contributor

clee2000 commented Apr 4, 2024

@pytorchbot merge -f "merged internally"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024

Differential Revision: D55619331

Pull Request resolved: pytorch#123140
Approved by: https://github.com/aakhundov