|
14 | 14 | OpNotSupportedPipeline, |
15 | 15 | TosaPipelineFP, |
16 | 16 | TosaPipelineINT, |
| 17 | + VgfPipeline, |
17 | 18 | ) |
18 | 19 |
|
19 | 20 |
|
@@ -138,3 +139,57 @@ def test_max_dim_tosa_FP_not_delegated(): |
138 | 139 | data, dim = Max.test_data["rank_4_dim_3"]() |
139 | 140 | pipeline = OpNotSupportedPipeline[Max.input_t](MaxWithIndex(dim), data, {}) |
140 | 141 | pipeline.run() |
| 142 | + |
| 143 | + |
| 144 | +@common.parametrize("test_data", Amax.test_data) |
| 145 | +@common.SkipIfNoModelConverter |
| 146 | +def test_amax_vgf_FP(test_data: Amax.input_t): |
| 147 | + data, dim, keep_dims = test_data() |
| 148 | + module = Amax(dim, keep_dims) |
| 149 | + pipeline = VgfPipeline[Amax.input_t]( |
| 150 | + module, |
| 151 | + data, |
| 152 | + Amax.aten_op, |
| 153 | + tosa_version="TOSA-1.0+FP", |
| 154 | + ) |
| 155 | + pipeline.run() |
| 156 | + |
| 157 | + |
| 158 | +@common.parametrize("test_data", Amax.test_data) |
| 159 | +@common.SkipIfNoModelConverter |
| 160 | +def test_amax_vgf_INT(test_data: Amax.input_t): |
| 161 | + data, dim, keep_dims = test_data() |
| 162 | + module = Amax(dim, keep_dims) |
| 163 | + pipeline = VgfPipeline[Amax.input_t]( |
| 164 | + module, |
| 165 | + data, |
| 166 | + Amax.aten_op, |
| 167 | + tosa_version="TOSA-1.0+INT", |
| 168 | + ) |
| 169 | + pipeline.run() |
| 170 | + |
| 171 | + |
| 172 | +@common.parametrize("test_data", Max.test_data) |
| 173 | +@common.SkipIfNoModelConverter |
| 174 | +def test_max_dim_vgf_FP_to_amax(test_data: Max.input_t): |
| 175 | + data, dim = test_data() |
| 176 | + pipeline = VgfPipeline[Max.input_t]( |
| 177 | + Max(dim), |
| 178 | + data, |
| 179 | + "torch.ops.aten.max", |
| 180 | + tosa_version="TOSA-1.0+FP", |
| 181 | + ) |
| 182 | + pipeline.run() |
| 183 | + |
| 184 | + |
| 185 | +@common.parametrize("test_data", Max.test_data) |
| 186 | +@common.SkipIfNoModelConverter |
| 187 | +def test_max_dim_vgf_INT_to_amax(test_data: Max.input_t): |
| 188 | + data, dim = test_data() |
| 189 | + pipeline = VgfPipeline[Max.input_t]( |
| 190 | + Max(dim), |
| 191 | + data, |
| 192 | + "torch.ops.aten.amax", |
| 193 | + tosa_version="TOSA-1.0+INT", |
| 194 | + ) |
| 195 | + pipeline.run() |
0 commit comments