File tree Expand file tree Collapse file tree 2 files changed +6
-10
lines changed
paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape Expand file tree Collapse file tree 2 files changed +6
-10
lines changed Original file line number Diff line number Diff line change @@ -1361,16 +1361,14 @@ bool MatrixRankTolOpInferSymbolicShape(
1361
1361
common::errors::InvalidArgument (
1362
1362
" The dims of input must be greater than 2" ));
1363
1363
bool hermitian = GetBoolAttr (op, " hermitian" );
1364
- const auto &GetProduct = [&](const auto &dim_exprs, const auto &Filter ) {
1364
+ const auto &GetProduct = [&](const auto &dim_exprs) {
1365
1365
symbol::DimExpr product{1 };
1366
1366
for (const auto &dim_expr : dim_exprs) {
1367
- if (Filter (dim_expr)) {
1368
- product = product * dim_expr;
1369
- }
1367
+ product = product * dim_expr;
1370
1368
}
1371
1369
return product;
1372
1370
};
1373
- const auto &x_numel = GetProduct (x_shape, []( const auto &) { return true ; } );
1371
+ const auto &x_numel = GetProduct (x_shape);
1374
1372
1375
1373
if (hermitian && x_numel != 0 ) {
1376
1374
infer_context->AddEqualCstr (x_shape[x_rank - 2 ], x_shape[x_rank - 1 ]);
Original file line number Diff line number Diff line change @@ -2168,16 +2168,14 @@ bool MatrixRankOpInferSymbolicShape(
2168
2168
infer_context->GetShapeOrDataForValue (op->operand_source (0 ));
2169
2169
const std::vector<symbol::DimExpr> &x_shape = x_shape_or_data.shape ();
2170
2170
2171
- const auto &GetProduct = [&](const auto &dim_exprs, const auto &Filter ) {
2171
+ const auto &GetProduct = [&](const auto &dim_exprs) {
2172
2172
symbol::DimExpr product{1 };
2173
2173
for (const auto &dim_expr : dim_exprs) {
2174
- if (Filter (dim_expr)) {
2175
- product = product * dim_expr;
2176
- }
2174
+ product = product * dim_expr;
2177
2175
}
2178
2176
return product;
2179
2177
};
2180
- const auto &x_numel = GetProduct (x_shape, []( const auto &) { return true ; } );
2178
+ const auto &x_numel = GetProduct (x_shape);
2181
2179
2182
2180
// 确保输入x的维度大于等于2
2183
2181
PADDLE_ENFORCE_GE (x_shape.size (),
You can’t perform that action at this time.
0 commit comments