Skip to content

Commit

Permalink
inductor: align inductor behavior with eager mode for split_with_sizes
Browse files Browse the repository at this point in the history
ghstack-source-id: 47d7d6517a1590fafa45b17f5375a3f51f970c8d
Pull Request resolved: #99702
  • Loading branch information
XiaobingSuper committed Apr 21, 2023
1 parent 8c6e430 commit 20dd589
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
9 changes: 9 additions & 0 deletions test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1863,6 +1863,15 @@ def fn(a, sizes):
self.common(fn, (torch.randn(2, 2, 10), [4, 3, 3]))
self.common(fn, (torch.randn(2, 2, 10), [1, 2, 3, 4]))


def test_split_with_sizes_failed(self):
@torch._dynamo.optimize("inductor")
def fn(a):
return torch.split(a, [2, 1, 1], dim=1)

with self.assertRaisesRegex(RuntimeError, ""):
fn(torch.randn(1, 5))

def test_split(self):
def fn(a):
t = torch.split(a, 3, -1)
Expand Down
11 changes: 11 additions & 0 deletions torch/_decomp/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1087,10 +1087,21 @@ def prod(x: List[int]):
return r


def sum(x: List[int]):
r = 0
for i in x:
r += i
return r


@register_decomposition([aten.split_with_sizes, aten.unsafe_split_with_sizes])
def split_with_sizes(
self: Tensor, split_sizes: List[int], dim: int = 0
) -> List[Tensor]:
if sum(split_sizes) != self.shape[dim]:
raise ValueError(
"Split sizes don't add up to the tensor's size in the given dimension"
)
num_splits = len(split_sizes)
splits = []
start_idx = 0
Expand Down

0 comments on commit 20dd589

Please sign in to comment.