@@ -1018,12 +1018,32 @@ def test_qnn_backend_max_pool2d(self):
         sample_input = (torch.randn(4, 3, 24, 24),)
         self.lower_module_and_test_output(module, sample_input)
 
-    def test_qnn_backend_mean_dim(self):
-        modules = [MeanWKeppDim(), MeanWOKeppDim()]  # noqa: F405
-        sample_input = (torch.randn([2, 5, 1, 3]),)
-        for i, module in enumerate(modules):
+    def test_qnn_backend_mean(self):
+        test_comb = [
+            {
+                QCOM_MODULE: Mean(dim=(-1, -2), keepdim=True),  # keepdim=True
+                QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),),
+            },
+            {
+                QCOM_MODULE: Mean(dim=(-1, -2), keepdim=False),  # keepdim=False
+                QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),),
+            },
+            {
+                QCOM_MODULE: Mean(),  # default: reduce all dims
+                QCOM_SAMPLE_INPUTS: (torch.randn(10, 10),),
+            },
+            {
+                QCOM_MODULE: Mean(),  # scalar case
+                QCOM_SAMPLE_INPUTS: (torch.tensor([5.0]),),
+            },
+        ]
+
+        for i, test in enumerate(test_comb):
             with self.subTest(i=i):
-                self.lower_module_and_test_output(module, sample_input)
+                module = self.get_qdq_module(
+                    test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
+                )
+                self.lower_module_and_test_output(module, test[QCOM_SAMPLE_INPUTS])
 
     @unittest.skip("failed to lower in QNN 2.26")
     def test_qnn_backend_mha(self):
@@ -2666,13 +2686,33 @@ def test_qnn_backend_max_pool2d(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
-    def test_qnn_backend_mean_dim(self):
-        modules = [MeanWKeppDim(), MeanWOKeppDim()]  # noqa: F405
-        sample_input = (torch.randn([2, 5, 1, 3]),)
-        for i, module in enumerate(modules):
+    def test_qnn_backend_mean(self):
+        test_comb = [
+            {
+                QCOM_MODULE: Mean(dim=(-1, -2), keepdim=True),  # keepdim=True
+                QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),),
+            },
+            {
+                QCOM_MODULE: Mean(dim=(-1, -2), keepdim=False),  # keepdim=False
+                QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),),
+            },
+            {
+                QCOM_MODULE: Mean(),  # default: reduce all dims
+                QCOM_SAMPLE_INPUTS: (torch.randn(10, 10),),
+            },
+            {
+                QCOM_MODULE: Mean(),  # scalar case
+                QCOM_SAMPLE_INPUTS: (torch.tensor([5.0]),),
+            },
+        ]
+
+        for i, test in enumerate(test_comb):
             with self.subTest(i=i):
-                module = self.get_qdq_module(module, sample_input)
-                self.lower_module_and_test_output(module, sample_input)
+                module = self.get_qdq_module(
+                    test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
+                )
+                self.lower_module_and_test_output(module, test[QCOM_SAMPLE_INPUTS])
+
 
     def test_qnn_backend_mha(self):
         module = MultiheadAttention()  # noqa: F405
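Note: the new tests reference a Mean helper module and the QCOM_MODULE / QCOM_SAMPLE_INPUTS dictionary keys, which are defined elsewhere in the QNN test sources and are not part of this diff. A minimal sketch of what those definitions might look like is given below purely for orientation; the key string values and the module's exact shape are assumptions, not taken from this change.

# Hypothetical sketch only -- the real Mean module and QCOM_* constants live in
# the QNN backend test sources, not in this diff.
import torch
import torch.nn as nn

QCOM_MODULE = "module"                 # assumed key value
QCOM_SAMPLE_INPUTS = "sample_inputs"   # assumed key value


class Mean(nn.Module):
    def __init__(self, dim=None, keepdim=False):
        super().__init__()
        self.dim = dim
        self.keepdim = keepdim

    def forward(self, x):
        # dim=None reduces over every dimension, covering the "default" and
        # "scalar" entries in test_comb; otherwise reduce the requested dims.
        if self.dim is None:
            return torch.mean(x)
        return torch.mean(x, dim=self.dim, keepdim=self.keepdim)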