diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py index 0749b348f36..46d7f589c4a 100644 --- a/ttnn/ttnn/operations/core.py +++ b/ttnn/ttnn/operations/core.py @@ -39,7 +39,7 @@ def _golden_function(input_tensor: ttnn.Tensor, slices): validate_input_tensors=_getitem_validate_input_tensors, is_method=True, golden_function=_golden_function, - allow_to_fallback_to_golden_function_on_failure=True, + allow_to_fallback_to_golden_function_on_failure=False, ) def __getitem__(input_tensor: ttnn.Tensor, slices) -> ttnn.Tensor: input_rank = len(input_tensor.shape) @@ -69,7 +69,7 @@ def __getitem__(input_tensor: ttnn.Tensor, slices) -> ttnn.Tensor: if len(slices) > input_rank: raise RuntimeError(f"Too many slices for tensor of rank {input_rank}") - if ttnn.is_tensor_storage_on_device(input_tensor) and input_rank <= 4: + if input_rank <= 4: input_tensor = ttnn.unsqueeze_to_4D(input_tensor) while len(slices) != 4: @@ -89,7 +89,15 @@ def __getitem__(input_tensor: ttnn.Tensor, slices) -> ttnn.Tensor: output = input_tensor else: padded_slice_end_minus_1 = [x - 1 for x in padded_slice_end] - output = ttl.tensor.unpad(input_tensor, slice_start, padded_slice_end_minus_1) + if any([x < 0 for x in padded_slice_end_minus_1]): + raise RuntimeError("ttnn.Tensor.__getitem__: cannot return a scalar!") + + if ttnn.is_tensor_storage_on_device(input_tensor): + output = ttl.tensor.unpad(input_tensor, slice_start, padded_slice_end_minus_1) + else: + input_tensor = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT) + output = input_tensor.unpad(slice_start, padded_slice_end_minus_1) + output = ttnn.to_layout(output, input_layout) output_shape = [end - start for (start, end) in zip(slice_start, slice_end)][-input_rank:] padded_output_shape = list(output.shape.with_tile_padding())[-input_rank:]