Skip to content

Commit

Permalink
#8364: skip tests in getitem
Browse files Browse the repository at this point in the history
  • Loading branch information
ayerofieiev-tt committed May 14, 2024
1 parent 5dc7cde commit f4e1c13
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tests/ttnn/unit_tests/test_getitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
@pytest.mark.parametrize("input_layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
@pytest.mark.parametrize("on_device", [True, False])
def test_getitem(device, batch_sizes, height, width, input_layout, on_device):
if not on_device:
pytest.skip("Tensor.__getitem__ only supports Tensors stored on Device")

torch_input_tensor = torch.rand((*batch_sizes, height, width), dtype=torch.bfloat16)

if batch_sizes:
Expand Down Expand Up @@ -52,6 +55,9 @@ def test_getitem(device, batch_sizes, height, width, input_layout, on_device):
@pytest.mark.parametrize("input_layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
@pytest.mark.parametrize("on_device", [True, False])
def test_getitem_2d(device, height, width, input_layout, on_device):
if not on_device:
pytest.skip("Tensor.__getitem__ only supports Tensors stored on Device")

torch_input_tensor = torch.rand((height, width), dtype=torch.bfloat16)

torch_output_tensor = torch_input_tensor[:32]
Expand All @@ -77,6 +83,7 @@ def test_getitem_2d(device, height, width, input_layout, on_device):


def test_getitem_scalar_output():
pytest.skip("Tensor.__getitem__ only supports Tensors stored on Device")
torch_input_tensor = torch.rand((16, 32), dtype=torch.bfloat16)

input_tensor = ttnn.from_torch(torch_input_tensor)
Expand Down

0 comments on commit f4e1c13

Please sign in to comment.