Skip to content

Commit

Permalink
Update on "inductor: align inductor behavior with eager mode for spli…
Browse files Browse the repository at this point in the history
…t_with_sizes"



Fix #99686, for eager mode, if the given sizes is not meet requirements, it will report an error, but inductor can run, I think we need align inductor behavior with eager mode, the behavior will be like after this PR:

```
Traceback (most recent call last):
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1267, in run_node
    return node.target(*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/functional.py", line 189, in split
    return tensor.split(split_size_or_sections, dim)
  File "/home/xiaobing/pytorch-offical/torch/_tensor.py", line 804, in split
    return torch._VF.split_with_sizes(self, split_size, dim)
  File "/home/xiaobing/pytorch-offical/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_subclasses/fake_tensor.py", line 1095, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_subclasses/fake_tensor.py", line 1259, in dispatch
    return decomposition_table[func](*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_decomp/decompositions.py", line 1102, in split_with_sizes
    raise ValueError(
ValueError: Split sizes don't add up to the tensor's size in the given dimension

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1215, in get_fake_value
    return wrap_fake_exception(
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 835, in wrap_fake_exception
    return fn()
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1216, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1279, in run_node
    raise RuntimeError(
RuntimeError: Failed running call_function <function split at 0x7f45b8402ee0>(*(FakeTensor(..., size=(1, 5)), [2, 1, 1]), **{'dim': 1}):
Split sizes don't add up to the tensor's size in the given dimension
(scroll up for backtrace)

The above exception was the direct cause of the following exception:
```

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
  • Loading branch information
XiaobingSuper committed Apr 21, 2023
2 parents 72dbccd + d23d25b commit 8ab1336
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torch/_inductor/fx_passes/post_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def is_valid_splitwithsizes_cat(match):
cat_items_args_order = [
get_arg_value(item_node, 1) for item_node in get_arg_value(cat_node, 0)
]
if cat_items_args_order != list(get_item_args):
if cat_items_args_order != list(range(len(split_sizes))):
return False

return True
Expand Down

0 comments on commit 8ab1336

Please sign in to comment.