Skip to content

Commit

Permalink
inductor: align inductor behavior with eager mode for split_with_sizes
Browse files Browse the repository at this point in the history
ghstack-source-id: e9597778b26b2a16a62cdf8912bcbd95fbfa899a
Pull Request resolved: #99702
  • Loading branch information
XiaobingSuper committed Apr 21, 2023
1 parent cddf768 commit fdb6e2f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
8 changes: 8 additions & 0 deletions test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1863,6 +1863,14 @@ def fn(a, sizes):
self.common(fn, (torch.randn(2, 2, 10), [4, 3, 3]))
self.common(fn, (torch.randn(2, 2, 10), [1, 2, 3, 4]))

def test_split_with_sizes_failed(self):
@torch._dynamo.optimize("inductor")
def fn(a):
return torch.split(a, [2, 1, 1], dim=1)

with self.assertRaisesRegex(RuntimeError, ""):
fn(torch.randn(1, 5))

def test_split(self):
def fn(a):
t = torch.split(a, 3, -1)
Expand Down
4 changes: 4 additions & 0 deletions torch/_decomp/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,6 +1091,10 @@ def prod(x: List[int]):
def split_with_sizes(
self: Tensor, split_sizes: List[int], dim: int = 0
) -> List[Tensor]:
if sum(split_sizes) != self.shape[dim]:
raise ValueError(
"Split sizes don't add up to the tensor's size in the given dimension"
)
num_splits = len(split_sizes)
splits = []
start_idx = 0
Expand Down

0 comments on commit fdb6e2f

Please sign in to comment.