
#8364: Disable implicit fallback for Tensor.__getitem__
ayerofieiev-tt committed May 23, 2024
1 parent 9e5a8b6 commit 98de183
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions ttnn/ttnn/operations/core.py
@@ -39,7 +39,7 @@ def _golden_function(input_tensor: ttnn.Tensor, slices):
     validate_input_tensors=_getitem_validate_input_tensors,
     is_method=True,
     golden_function=_golden_function,
-    allow_to_fallback_to_golden_function_on_failure=True,
+    allow_to_fallback_to_golden_function_on_failure=False,
 )
 def __getitem__(input_tensor: ttnn.Tensor, slices) -> ttnn.Tensor:
     input_rank = len(input_tensor.shape)
@@ -69,7 +69,7 @@ def __getitem__(input_tensor: ttnn.Tensor, slices) -> ttnn.Tensor:
     if len(slices) > input_rank:
         raise RuntimeError(f"Too many slices for tensor of rank {input_rank}")
 
-    if ttnn.is_tensor_storage_on_device(input_tensor) and input_rank <= 4:
+    if input_rank <= 4:
         input_tensor = ttnn.unsqueeze_to_4D(input_tensor)
 
         while len(slices) != 4:
@@ -89,7 +89,15 @@ def __getitem__(input_tensor: ttnn.Tensor, slices) -> ttnn.Tensor:
         output = input_tensor
     else:
         padded_slice_end_minus_1 = [x - 1 for x in padded_slice_end]
-        output = ttl.tensor.unpad(input_tensor, slice_start, padded_slice_end_minus_1)
+        if any([x < 0 for x in padded_slice_end_minus_1]):
+            raise RuntimeError("ttnn.Tensor.__getitem__: cannot return a scalar!")
+
+        if ttnn.is_tensor_storage_on_device(input_tensor):
+            output = ttl.tensor.unpad(input_tensor, slice_start, padded_slice_end_minus_1)
+        else:
+            input_tensor = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT)
+            output = input_tensor.unpad(slice_start, padded_slice_end_minus_1)
+            output = ttnn.to_layout(output, input_layout)
 
     output_shape = [end - start for (start, end) in zip(slice_start, slice_end)][-input_rank:]
     padded_output_shape = list(output.shape.with_tile_padding())[-input_rank:]
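For context, a minimal usage sketch (not part of this commit) of how ttnn.Tensor.__getitem__ behaves after the change: device tensors are still sliced with ttl.tensor.unpad on device, host tensors are converted to ROW_MAJOR_LAYOUT, unpadded on host, and converted back, and with the fallback flag now False, a slice that neither path can handle surfaces as an error instead of silently running the torch golden function. The device id, shapes, and slice bounds below are illustrative assumptions.

import torch
import ttnn

# Assumed setup: a local ttnn build and a device available at id 0.
device = ttnn.open_device(device_id=0)

torch_input = torch.rand((1, 1, 64, 64), dtype=torch.bfloat16)

# Device path: __getitem__ slices via ttl.tensor.unpad on the device.
device_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)
device_slice = device_tensor[:, :, :32, :32]

# Host path: the tensor is converted to ROW_MAJOR_LAYOUT, unpadded on host,
# then converted back to its original layout.
host_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT)
host_slice = host_tensor[:, :, :32, :32]

# With allow_to_fallback_to_golden_function_on_failure=False, a slice that
# neither path supports now raises (e.g. the "cannot return a scalar!" error)
# instead of silently falling back to the torch golden function.

ttnn.close_device(device)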
