From e81103311b38316166e45bff7df156d01a4baa61 Mon Sep 17 00:00:00 2001 From: keerthana-r-mcw Date: Tue, 21 May 2024 12:56:47 +0530 Subject: [PATCH] #8332: Add unit tests for Bevdepth conv op --- .../ttnn/unit_tests/operations/test_conv2d.py | 107 ++++++++++++++++++ 1 file changed, 107 insertions(+) diff --git a/tests/ttnn/unit_tests/operations/test_conv2d.py b/tests/ttnn/unit_tests/operations/test_conv2d.py index 1c7b4db5c7e..9ae50f0cb73 100644 --- a/tests/ttnn/unit_tests/operations/test_conv2d.py +++ b/tests/ttnn/unit_tests/operations/test_conv2d.py @@ -1267,3 +1267,110 @@ def test_conv_core_nondivis( use_1d_systolic_array, config_override, ) +@pytest.mark.parametrize("device_l1_small_size", [16384], indirect=True) +@pytest.mark.parametrize( + "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array", + ( + (6, 64, 3, 256, 704, 7, 7, 2, 2, 3, 3, True), + (6, 64, 64, 64, 176, 1, 1, 1, 1, 0, 0, True), + (6, 256, 64, 64, 176, 1, 1, 1, 1, 0, 0, True), + (6, 64, 64, 64, 176, 3, 3, 1, 1, 1, 1, True), + (6, 64, 256, 64, 176, 1, 1, 1, 1, 0, 0, True), + (6, 512, 512, 1, 1, 1, 1, 1, 1, 0, 0, True), + (6, 128, 256, 64, 176, 1, 1, 1, 1, 0, 0, True), + (6, 512, 256, 64, 176, 1, 1, 2, 2, 0, 0, True), + (6, 128, 256, 64, 176, 4, 4, 4, 4, 0, 0, True), + (6, 128, 128, 64, 176, 3, 3, 2, 2, 1, 1, True), + (6, 512, 128, 32, 88, 1, 1, 1, 1, 0, 0, True), + (6, 128, 512, 32, 88, 1, 1, 1, 1, 0, 0, True), + (6, 128, 128, 32, 88, 3, 3, 1, 1, 1, 1, True), + (6, 256, 512, 32, 88, 1, 1, 1, 1, 0, 0, True), + (6, 1024, 512, 32, 88, 1, 1, 2, 2, 0, 0, True), + (6, 128, 512, 32, 88, 2, 2, 2, 2, 0, 0, True), + (6, 256, 256, 32, 88, 3, 3, 2, 2, 1, 1, True), + (6, 1024, 256, 16, 44, 1, 1, 1, 1, 0, 0, True), + (6, 256, 1024, 16, 44, 1, 1, 1, 1, 0, 0, True), + (6, 256, 256, 16, 44, 3, 3, 1, 1, 1, 1, True), + (6, 512, 1024, 16, 44, 1, 1, 1, 1, 0, 0, True), + (6, 2048, 1024, 16, 44, 1, 1, 2, 2, 0, 0, True), + (6, 128, 1024, 16, 44, 1, 1, 1, 1, 0, 0, True), + (6, 512, 512, 16, 44, 3, 3, 2, 2, 1, 1, True), + (6, 2048, 512, 8, 22, 1, 1, 1, 1, 0, 0, True), + (6, 512, 2048, 8, 22, 1, 1, 1, 1, 0, 0, True), + (6, 512, 512, 8, 22, 3, 3, 1, 1, 1, 1, True), + (6, 128, 2048, 8, 22, 2, 2, 2, 2, 0, 0, True), + (6, 512, 512, 16, 44, 3, 3, 1, 1, 1, 1, True), + (6, 80, 512, 16, 44, 1, 1, 1, 1, 0, 0, True), + (6, 512, 512, 16, 44, 1, 1, 1, 1, 0, 0, True), + (6, 512, 512, 16, 44, 3, 3, 1, 1, 6, 6, True), + (6, 512, 512, 16, 44, 3, 3, 1, 1, 12, 12, True), + (6, 512, 512, 16, 44, 3, 3, 1, 1, 18, 18, True), + (6, 512, 2560, 16, 44, 1, 1, 1, 1, 0, 0, True), + (6, 18, 512, 16, 44, 3, 3, 1, 1, 1, 1, True), + ), +) +@pytest.mark.parametrize( + "weights_dtype", + [ttnn.bfloat16, ttnn.bfloat8_b], +) +@pytest.mark.parametrize( + "activations_dtype", + [ttnn.bfloat16, ttnn.bfloat8_b], +) +@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) +def test_conv_bevdepth( + device, + use_program_cache, + math_fidelity, + activations_dtype, + weights_dtype, + batch_size, + output_channels, + input_channels, + input_height, + input_width, + filter_height, + filter_width, + stride_h, + stride_w, + pad_h, + pad_w, + use_1d_systolic_array, +): + if batch_size > 8 and (activations_dtype != ttnn.bfloat8_b or weights_dtype != ttnn.bfloat8_b): + pytest.skip("Batch > 8 must be run fully bfp8") + + if ( + activations_dtype == ttnn.bfloat16 + and batch_size == 20 + and ( + output_channels == 64 + or ( + stride_h == 2 + and (output_channels == 256 or (output_channels == 128 and weights_dtype == ttnn.bfloat16)) + ) + ) + ): + pytest.skip("Skipping test because it won't fit in L1!") + + run_conv( + device, + math_fidelity, + activations_dtype, + weights_dtype, + batch_size, + output_channels, + input_channels, + input_height, + input_width, + filter_height, + filter_width, + stride_h, + stride_w, + pad_h, + pad_w, + use_1d_systolic_array, + config_override={"act_block_h": 32}, + use_shallow_conv_variant=input_channels == 16, + padded_input_channels=16 if input_channels == 16 else None, + ) \ No newline at end of file