From 47634f3f1933402c45ed64af5f502c40611f385d Mon Sep 17 00:00:00 2001
From: Artem Yerofieiev
Date: Wed, 15 May 2024 18:41:42 +0000
Subject: [PATCH] #8364: Revert "skip concat tests which fallback to torch"

This reverts commit bf163e22d6693b20a26967c63dd003d029aef47d.
---
 tests/ttnn/unit_tests/operations/test_concat.py | 5 -----
 ttnn/ttnn/operations/data_movement.py           | 4 +---
 2 files changed, 1 insertion(+), 8 deletions(-)

diff --git a/tests/ttnn/unit_tests/operations/test_concat.py b/tests/ttnn/unit_tests/operations/test_concat.py
index 16663b660d8..642935cb4f8 100644
--- a/tests/ttnn/unit_tests/operations/test_concat.py
+++ b/tests/ttnn/unit_tests/operations/test_concat.py
@@ -15,9 +15,6 @@
 @pytest.mark.parametrize("width", [4, 32])
 @pytest.mark.parametrize("dim", [0, 1])
 def test_concat(device, height, width, dim):
-    if height % ttnn.TILE_SIZE != 0 or width % ttnn.TILE_SIZE != 0:
-        pytest.skip("ttnn.concat only supports tensors with Layout.TILE_LAYOUT without a padding")
-
     torch_input_tensor_a = torch.rand((height, width), dtype=torch.bfloat16)
     torch_input_tensor_b = torch.rand((height, width), dtype=torch.bfloat16)
     torch_output_tensor = torch.concat([torch_input_tensor_a, torch_input_tensor_b], dim=dim)
@@ -81,8 +78,6 @@ def test_concat(device, height, width, dim):
 def test_sharded_concat(
     device, input_shape_a, shard_shape_a, input_shape_b, shard_shape_b, output_shard_shape, shard_grid
 ):
-    pytest.skip("ttnn.concat only supports tensors with Layout.TILE_LAYOUT without a padding")
-
     input_a_sharded_memory_config = ttnn.create_sharded_memory_config(
         shard_shape_a,
         core_grid=shard_grid,
diff --git a/ttnn/ttnn/operations/data_movement.py b/ttnn/ttnn/operations/data_movement.py
index fac60f17cd7..47fbf911561 100644
--- a/ttnn/ttnn/operations/data_movement.py
+++ b/ttnn/ttnn/operations/data_movement.py
@@ -328,9 +328,7 @@ def concat(
             output_tensor = ttnn.squeeze(output_tensor, dim=0)
             return output_tensor
     else:
-        raise NotImplementedError(
-            "ttnn.concat only supports tensors with Layout.TILE_LAYOUT without a padding, with rank <= 4"
-        )
+        raise NotImplementedError


 def _golden_function(input_tensor, split_size, dim):
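
For context, the cases this revert re-enables are the test parametrizations where height or width is 4, i.e. not a multiple of ttnn.TILE_SIZE, which the removed guard used to skip. Below is a minimal sketch of that path, assuming the same ttnn.from_torch/ttnn.to_torch round-trip the rest of test_concat.py uses; the explicit ROW_MAJOR_LAYOUT and the open_device/close_device calls are illustrative assumptions, not taken from this patch (the test itself receives `device` from a fixture).

    # Sketch: ttnn.concat on 4-row bfloat16 inputs, a shape that is not
    # tile-aligned and was previously skipped by the reverted guard.
    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)  # assumed setup; the test uses a `device` fixture

    torch_a = torch.rand((4, 32), dtype=torch.bfloat16)
    torch_b = torch.rand((4, 32), dtype=torch.bfloat16)
    expected = torch.concat([torch_a, torch_b], dim=0)

    # ROW_MAJOR_LAYOUT avoids tile padding for the 4-row inputs
    # (an assumption for this sketch, not specified by the patch).
    a = ttnn.from_torch(torch_a, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
    b = ttnn.from_torch(torch_b, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
    output = ttnn.to_torch(ttnn.concat([a, b], dim=0))

    assert output.shape == expected.shape
    ttnn.close_device(device)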