Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inductor: align inductor behavior with eager mode for split_with_sizes #99702

Closed

Conversation

XiaobingSuper
Copy link
Collaborator

@XiaobingSuper XiaobingSuper commented Apr 21, 2023

Stack from ghstack (oldest at bottom):

Fix #99686, for eager mode, if the given sizes is not meet requirements, it will report an error, but inductor can run, I think we need align inductor behavior with eager mode, the behavior will be like after this PR:

Traceback (most recent call last):
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1267, in run_node
    return node.target(*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/functional.py", line 189, in split
    return tensor.split(split_size_or_sections, dim)
  File "/home/xiaobing/pytorch-offical/torch/_tensor.py", line 804, in split
    return torch._VF.split_with_sizes(self, split_size, dim)
  File "/home/xiaobing/pytorch-offical/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_subclasses/fake_tensor.py", line 1095, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_subclasses/fake_tensor.py", line 1259, in dispatch
    return decomposition_table[func](*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_decomp/decompositions.py", line 1102, in split_with_sizes
    raise ValueError(
ValueError: Split sizes don't add up to the tensor's size in the given dimension

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1215, in get_fake_value
    return wrap_fake_exception(
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 835, in wrap_fake_exception
    return fn()
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1216, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1279, in run_node
    raise RuntimeError(
RuntimeError: Failed running call_function <function split at 0x7f45b8402ee0>(*(FakeTensor(..., size=(1, 5)), [2, 1, 1]), **{'dim': 1}):
Split sizes don't add up to the tensor's size in the given dimension
(scroll up for backtrace)

The above exception was the direct cause of the following exception:

cc @soumith @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire

…t_with_sizes"



Fix #99686, for eager mode, if the given sizes is not meet requirements, it will report an error, but inductor can run, I think we need align inductor behavior with eager mode, the behavior will be like after this PR:

```
Traceback (most recent call last):
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1267, in run_node
    return node.target(*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/functional.py", line 189, in split
    return tensor.split(split_size_or_sections, dim)
  File "/home/xiaobing/pytorch-offical/torch/_tensor.py", line 804, in split
    return torch._VF.split_with_sizes(self, split_size, dim)
  File "/home/xiaobing/pytorch-offical/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_subclasses/fake_tensor.py", line 1095, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_subclasses/fake_tensor.py", line 1259, in dispatch
    return decomposition_table[func](*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_decomp/decompositions.py", line 1102, in split_with_sizes
    raise ValueError(
ValueError: Split sizes don't add up to the tensor's size in the given dimension

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1215, in get_fake_value
    return wrap_fake_exception(
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 835, in wrap_fake_exception
    return fn()
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1216, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1279, in run_node
    raise RuntimeError(
RuntimeError: Failed running call_function <function split at 0x7f45b8402ee0>(*(FakeTensor(..., size=(1, 5)), [2, 1, 1]), **{'dim': 1}):
Split sizes don't add up to the tensor's size in the given dimension
(scroll up for backtrace)

The above exception was the direct cause of the following exception:
```

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Apr 21, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/99702

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

✅ No Failures

As of commit 8ab1336:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

XiaobingSuper added a commit that referenced this pull request Apr 21, 2023
ghstack-source-id: 47d7d6517a1590fafa45b17f5375a3f51f970c8d
Pull Request resolved: #99702
@XiaobingSuper XiaobingSuper requested review from desertfire and removed request for albanD April 21, 2023 06:50
@XiaobingSuper XiaobingSuper linked an issue Apr 21, 2023 that may be closed by this pull request
XiaobingSuper added a commit that referenced this pull request Apr 21, 2023
ghstack-source-id: 840658dc2b23a87a6162cd09a68c21e03a4babfa
Pull Request resolved: #99702
…t_with_sizes"



Fix #99686, for eager mode, if the given sizes is not meet requirements, it will report an error, but inductor can run, I think we need align inductor behavior with eager mode, the behavior will be like after this PR:

```
Traceback (most recent call last):
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1267, in run_node
    return node.target(*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/functional.py", line 189, in split
    return tensor.split(split_size_or_sections, dim)
  File "/home/xiaobing/pytorch-offical/torch/_tensor.py", line 804, in split
    return torch._VF.split_with_sizes(self, split_size, dim)
  File "/home/xiaobing/pytorch-offical/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_subclasses/fake_tensor.py", line 1095, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_subclasses/fake_tensor.py", line 1259, in dispatch
    return decomposition_table[func](*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_decomp/decompositions.py", line 1102, in split_with_sizes
    raise ValueError(
ValueError: Split sizes don't add up to the tensor's size in the given dimension

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1215, in get_fake_value
    return wrap_fake_exception(
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 835, in wrap_fake_exception
    return fn()
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1216, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1279, in run_node
    raise RuntimeError(
RuntimeError: Failed running call_function <function split at 0x7f45b8402ee0>(*(FakeTensor(..., size=(1, 5)), [2, 1, 1]), **{'dim': 1}):
Split sizes don't add up to the tensor's size in the given dimension
(scroll up for backtrace)

The above exception was the direct cause of the following exception:
```

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
@@ -1087,10 +1087,21 @@ def prod(x: List[int]):
return r


def sum(x: List[int]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not use the builtin sum?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, changed.

def fn(a):
return torch.split(a, [2, 1, 1], dim=1)

with self.assertRaisesRegex(RuntimeError, ""):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Specify the regex pattern here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, shouldn't this be a ValueError?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used ValueError, but it seems catch RuntimeError firstly even the error log is:

Traceback (most recent call last):
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1267, in run_node
    return node.target(*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/functional.py", line 189, in split
    return tensor.split(split_size_or_sections, dim)
  File "/home/xiaobing/pytorch-offical/torch/_tensor.py", line 804, in split
    return torch._VF.split_with_sizes(self, split_size, dim)
  File "/home/xiaobing/pytorch-offical/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_subclasses/fake_tensor.py", line 1095, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_subclasses/fake_tensor.py", line 1259, in dispatch
    return decomposition_table[func](*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_decomp/decompositions.py", line 1102, in split_with_sizes
    raise ValueError(
ValueError: Split sizes don't add up to the tensor's size in the given dimension

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1215, in get_fake_value
    return wrap_fake_exception(
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 835, in wrap_fake_exception
    return fn()
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1216, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1279, in run_node
    raise RuntimeError(
RuntimeError: Failed running call_function <function split at 0x7f45b8402ee0>(*(FakeTensor(..., size=(1, 5)), [2, 1, 1]), **{'dim': 1}):
Split sizes don't add up to the tensor's size in the given dimension
(scroll up for backtrace)

The above exception was the direct cause of the following exception:

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ugh, I see

XiaobingSuper added a commit that referenced this pull request Apr 21, 2023
ghstack-source-id: 789217f504e980a424eb6567198081a610228fb7
Pull Request resolved: #99702
…t_with_sizes"



Fix #99686, for eager mode, if the given sizes is not meet requirements, it will report an error, but inductor can run, I think we need align inductor behavior with eager mode, the behavior will be like after this PR:

```
Traceback (most recent call last):
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1267, in run_node
    return node.target(*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/functional.py", line 189, in split
    return tensor.split(split_size_or_sections, dim)
  File "/home/xiaobing/pytorch-offical/torch/_tensor.py", line 804, in split
    return torch._VF.split_with_sizes(self, split_size, dim)
  File "/home/xiaobing/pytorch-offical/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_subclasses/fake_tensor.py", line 1095, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_subclasses/fake_tensor.py", line 1259, in dispatch
    return decomposition_table[func](*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_decomp/decompositions.py", line 1102, in split_with_sizes
    raise ValueError(
ValueError: Split sizes don't add up to the tensor's size in the given dimension

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1215, in get_fake_value
    return wrap_fake_exception(
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 835, in wrap_fake_exception
    return fn()
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1216, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1279, in run_node
    raise RuntimeError(
RuntimeError: Failed running call_function <function split at 0x7f45b8402ee0>(*(FakeTensor(..., size=(1, 5)), [2, 1, 1]), **{'dim': 1}):
Split sizes don't add up to the tensor's size in the given dimension
(scroll up for backtrace)

The above exception was the direct cause of the following exception:
```

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
XiaobingSuper added a commit that referenced this pull request Apr 21, 2023
ghstack-source-id: e9597778b26b2a16a62cdf8912bcbd95fbfa899a
Pull Request resolved: #99702
@EikanWang
Copy link
Collaborator

but inductor can run

Do you mean there is no correctness issue?

@EikanWang EikanWang self-requested a review April 21, 2023 10:02
@XiaobingSuper
Copy link
Collaborator Author

but inductor can run

Do you mean there is no correctness issue?

For the definition of split_with_size, we should report an error to the user for the given case.

@XiaobingSuper XiaobingSuper added the ciflow/trunk Trigger trunk jobs on your pull request label Apr 24, 2023
@XiaobingSuper
Copy link
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@facebook-github-bot facebook-github-bot deleted the gh/XiaobingSuper/100/head branch June 8, 2023 14:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

[torch.compile] WRONG VALUE for split+cat
7 participants