inductor: align inductor behavior with eager mode for split_with_sizes #99702

Closed
Changes from 2 commits
8 changes: 8 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -1863,6 +1863,14 @@ def fn(a, sizes):
        self.common(fn, (torch.randn(2, 2, 10), [4, 3, 3]))
        self.common(fn, (torch.randn(2, 2, 10), [1, 2, 3, 4]))

    def test_split_with_sizes_failed(self):
        @torch._dynamo.optimize("inductor")
        def fn(a):
            return torch.split(a, [2, 1, 1], dim=1)

        with self.assertRaisesRegex(RuntimeError, ""):
Collaborator: Specify the regex pattern here.

Collaborator: Actually, shouldn't this be a ValueError?

Author: I used ValueError, but it seems a RuntimeError is caught first, even though the error log is:

Traceback (most recent call last):
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1267, in run_node
    return node.target(*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/functional.py", line 189, in split
    return tensor.split(split_size_or_sections, dim)
  File "/home/xiaobing/pytorch-offical/torch/_tensor.py", line 804, in split
    return torch._VF.split_with_sizes(self, split_size, dim)
  File "/home/xiaobing/pytorch-offical/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_subclasses/fake_tensor.py", line 1095, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_subclasses/fake_tensor.py", line 1259, in dispatch
    return decomposition_table[func](*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_decomp/decompositions.py", line 1102, in split_with_sizes
    raise ValueError(
ValueError: Split sizes don't add up to the tensor's size in the given dimension

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1215, in get_fake_value
    return wrap_fake_exception(
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 835, in wrap_fake_exception
    return fn()
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1216, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1279, in run_node
    raise RuntimeError(
RuntimeError: Failed running call_function <function split at 0x7f45b8402ee0>(*(FakeTensor(..., size=(1, 5)), [2, 1, 1]), **{'dim': 1}):
Split sizes don't add up to the tensor's size in the given dimension
(scroll up for backtrace)

The above exception was the direct cause of the following exception:

Collaborator: ugh, I see

            fn(torch.randn(1, 5))

    def test_split(self):
        def fn(a):
            t = torch.split(a, 3, -1)
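To make the assertion concrete, as the first reviewer asks, the regex can pin down the decomposition's message; the traceback above shows that text survives dynamo's re-raise into the RuntimeError. A minimal sketch, assuming the outermost exception dynamo surfaces is still a RuntimeError carrying that message:

        # The decomposition raises ValueError, but dynamo's run_node chains it
        # into a RuntimeError, so the test catches RuntimeError and can still
        # match on the original message text.
        with self.assertRaisesRegex(
            RuntimeError, "Split sizes don't add up to the tensor's size"
        ):
            fn(torch.randn(1, 5))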
11 changes: 11 additions & 0 deletions torch/_decomp/decompositions.py
@@ -1087,10 +1087,21 @@ def prod(x: List[int]):
    return r


def sum(x: List[int]):
Collaborator: why not use the builtin sum?

Author: yes, changed.

    r = 0
    for i in x:
        r += i
    return r
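Presumably the follow-up commit ("yes, changed") deletes this helper so Python's builtin sum is no longer shadowed. A standalone sketch of the check using the builtin; check_split_sizes is a hypothetical name used only for illustration:

    from typing import List

    def check_split_sizes(split_sizes: List[int], dim_size: int) -> None:
        # Python's builtin sum already totals a List[int], so no local
        # helper is needed.
        if sum(split_sizes) != dim_size:
            raise ValueError(
                "Split sizes don't add up to the tensor's size in the given dimension"
            )

    check_split_sizes([2, 1, 1], 4)   # passes: 2 + 1 + 1 == 4
    # check_split_sizes([2, 1, 1], 5) would raise ValueError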


@register_decomposition([aten.split_with_sizes, aten.unsafe_split_with_sizes])
def split_with_sizes(
    self: Tensor, split_sizes: List[int], dim: int = 0
) -> List[Tensor]:
    if sum(split_sizes) != self.shape[dim]:
        raise ValueError(
            "Split sizes don't add up to the tensor's size in the given dimension"
        )
    num_splits = len(split_sizes)
    splits = []
    start_idx = 0
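For context, a quick repro of the eager behavior the decomposition is being aligned with. This is a sketch; the exact eager error text is not quoted anywhere in this thread, so only the exception type is assumed below:

    import torch

    x = torch.randn(1, 5)

    # Eager mode already rejects split sizes that don't sum to the size of
    # the given dimension (2 + 1 + 1 != 5); this PR makes the inductor
    # decomposition raise for the same inputs instead of accepting them.
    try:
        torch.split(x, [2, 1, 1], dim=1)
    except RuntimeError as e:
        print("eager raised:", e)

    # Well-formed sizes behave identically in both modes:
    a, b = torch.split(x, [2, 3], dim=1)
    print(a.shape, b.shape)  # torch.Size([1, 2]) torch.Size([1, 3])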