From f03e2bb62a2916715f173d7e29998b1c0cbc8f77 Mon Sep 17 00:00:00 2001 From: Denis Khalikov Date: Fri, 14 Nov 2025 18:56:51 +0300 Subject: [PATCH] [KQP RBO] Add group by on expressions --- ydb/core/kqp/opt/rbo/kqp_rbo_transformer.cpp | 94 +++++++++++++++----- ydb/core/kqp/ut/rbo/kqp_rbo_ut.cpp | 20 ++++- 2 files changed, 90 insertions(+), 24 deletions(-) diff --git a/ydb/core/kqp/opt/rbo/kqp_rbo_transformer.cpp b/ydb/core/kqp/opt/rbo/kqp_rbo_transformer.cpp index 51de4d8b2822..9fa1ea1daccd 100644 --- a/ydb/core/kqp/opt/rbo/kqp_rbo_transformer.cpp +++ b/ydb/core/kqp/opt/rbo/kqp_rbo_transformer.cpp @@ -151,8 +151,9 @@ void BuildSimpleMapElementLambda(TExprNode::TPtr resultExpr, const TVector>& aggFieldsExpressionsMap, const TVector>& aggFieldsRenamesMap, - const TVector>& groupByKeysRenamesMap, TExprContext& ctx, - TPositionHandle pos) { + const TVector>& groupByKeysRenamesMap, + const THashMap>& groupByKeysExpressionsMap, + TExprContext& ctx, TPositionHandle pos) { // Add expressions TVector mapElements; for (const auto& [colName, expr] : aggFieldsExpressionsMap) { @@ -171,6 +172,19 @@ TExprNode::TPtr BuildExpressionMap(TExprNode::TPtr resultExpr, BuildSimpleMapElementLambda(resultExpr, aggFieldsRenamesMap, mapElements, ctx, pos); BuildSimpleMapElementLambda(resultExpr, groupByKeysRenamesMap, mapElements, ctx, pos); + // Add expressions for group by keys. + for (const auto& [_, pair] : groupByKeysExpressionsMap) { + // clang-format off + mapElements.push_back(Build(ctx, pos) + .Input(resultExpr) + .Variable() + .Value(pair.first.GetFullName()) + .Build() + .Lambda(pair.second) + .Done().Ptr()); + // clang-format on + } + // clang-format off return Build(ctx, pos) .Input(resultExpr) @@ -184,8 +198,8 @@ TExprNode::TPtr BuildExpressionMap(TExprNode::TPtr resultExpr, // clang-format on } -void BuildMapElementRename(TExprNode::TPtr resultExpr, const TVector>& renamesMap, TVector& mapElements, - TExprContext& ctx, TPositionHandle pos) { +void BuildMapElementRename(TExprNode::TPtr resultExpr, const TVector>& renamesMap, + TVector& mapElements, TExprContext& ctx, TPositionHandle pos) { for (const auto& [colName, newColName] : renamesMap) { // clang-format off mapElements.push_back(Build(ctx, pos) @@ -552,8 +566,9 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx, filterExpr = Build(ctx, node->Pos()).Done().Ptr(); } - // FIXME: Group by key can be an expression, we need to handle this case + THashSet aggregationColumnsRequireCastToPgType; TVector> groupByKeysRenamesMap; + THashMap> groupByKeysExpressionsMap; TVector groupByKeys; auto groupOps = GetSetting(setItem->Tail(), "group_exprs"); if (groupOps) { @@ -561,11 +576,28 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx, for (ui32 i = 0; i < groupByList->ChildrenSize(); ++i) { auto lambda = TCoLambda(ctx.DeepCopyLambda(*(groupByList->ChildPtr(i)->Child(1)))); auto body = lambda.Body().Ptr(); - TVector keys; - GetAllMembers(body, keys); - groupByKeys.insert(groupByKeys.end(), keys.begin(), keys.end()); - for (const auto &infoUnit : keys) { - groupByKeysRenamesMap.push_back({infoUnit, infoUnit}); + auto pgResolvedOp = GetPgCallable(lambda.Body().Ptr(), "PgResolvedOp"); + // Expression for group by keys. + if (pgResolvedOp) { + auto fromPg = ctx.NewCallable(node->Pos(), "FromPg", {pgResolvedOp}); + + // clang-format off + auto groupExprLambda = Build(ctx, node->Pos()) + .Args(lambda.Args()) + .Body(fromPg) + .Done().Ptr(); + // clang-format on + + const auto newColName = TInfoUnit(GenerateUniqueColumnName("_group_expr_")); + groupByKeysExpressionsMap[i] = std::make_pair(newColName, groupExprLambda); + groupByKeys.push_back(newColName); + } else { + TVector keys; + GetAllMembers(body, keys); + Y_ENSURE(keys.size() == 1, "Invalid size of the group keys."); + const auto groupKeyName = keys.front(); + groupByKeys.push_back(groupKeyName); + groupByKeysRenamesMap.push_back({groupKeyName, groupKeyName}); } } } @@ -575,8 +607,6 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx, Y_ENSURE(result); auto finalType = node->GetTypeAnn()->Cast()->GetItemType()->Cast(); - // This is a hack to enable convertion for aggregation columns. - THashSet aggregationColumns; THashSet columnNames; // Collect PgAgg for each result item at first pass. TVector aggTraitsList; @@ -598,14 +628,18 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx, TVector originalColNames; GetAllMembers(pgAgg, originalColNames); auto pgResolvedOp = GetPgCallable(lambda.Body().Ptr(), "PgResolvedOp"); - //Y_ENSURE(originalColNames.size() > 1 && pgResolvedOp, "Invalid column size for aggregation columns."); auto originalColName = originalColNames.front(); auto renamedColName = originalColName; if (pgResolvedOp) { auto fromPg = ctx.NewCallable(node->Pos(), "FromPg", {pgResolvedOp}); - auto exprLambda = Build(ctx, node->Pos()).Args(lambda.Args()).Body(fromPg).Done().Ptr(); + // clang-format off + auto exprLambda = Build(ctx, node->Pos()) + .Args(lambda.Args()) + .Body(fromPg) + .Done().Ptr(); + // clang-format on // Just any unique name for expression result, physical plan should be AsSturct(`unique_name (expression)) originalColName = TInfoUnit(GenerateUniqueColumnName("_expr_")); @@ -623,7 +657,7 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx, columnNames.insert(renamedColName.GetFullName()); Y_ENSURE(!GetAtom(pgAgg->ChildPtr(1), "distinct"), "Aggregation on distinct is not supported"); - aggregationColumns.insert(resultColName); + aggregationColumnsRequireCastToPgType.insert(resultColName); const TString aggFuncName = TString(pgAgg->ChildPtr(0)->Content()); auto aggregationTraits = BuildAggregationTraits(renamedColName.GetFullName(), resultColName, aggFuncName, aggFuncResultType, ctx, node->Pos()); @@ -637,7 +671,7 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx, } // This case covers distinct all on just columns without aggregation functions. } else if (!pgAgg && distinctAll) { - aggregationColumns.insert(resultColName); + aggregationColumnsRequireCastToPgType.insert(resultColName); Y_ENSURE(aggFuncResultType, "Cannot find type for aggregation result."); TVector originalColNames; GetAllMembers(resultItem->ChildPtr(2), originalColNames); @@ -658,9 +692,10 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx, if (needRenameMap) { resultExpr = BuildRenameMap(resultExpr, aggFieldsRenamesMap, groupByKeysRenamesMap, ctx, node->Pos()); } - // In case we have an expression for aggregation - f(a + b ..) - if (!aggFieldsExpressionsMap.empty()) { - resultExpr = BuildExpressionMap(resultExpr, aggFieldsExpressionsMap, aggFieldsRenamesMap, groupByKeysRenamesMap, ctx, node->Pos()); + // In case we have an expression for aggregation - f(a + b ..) or group by. + if (!aggFieldsExpressionsMap.empty() || !groupByKeysExpressionsMap.empty()) { + resultExpr = BuildExpressionMap(resultExpr, aggFieldsExpressionsMap, aggFieldsRenamesMap, groupByKeysRenamesMap, + groupByKeysExpressionsMap, ctx, node->Pos()); } resultExpr = BuildAggregate(resultExpr, groupByKeys, aggTraitsList, distinctAll, ctx, node->Pos()); } @@ -671,7 +706,6 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx, } finalColumnOrder.clear(); - THashMap aggProjectionMap; for (auto resultItem : result->Child(1)->Children()) { auto column = resultItem->Child(0); @@ -691,7 +725,7 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx, bool needPgCast = (expectedType->GetId() != actualPgTypeId); auto lambda = TCoLambda(ctx.DeepCopyLambda(*(resultItem->Child(2)))); - bool needPgCastForAgg = aggregationColumns.count(columnName); + bool needPgCastForAgg = aggregationColumnsRequireCastToPgType.count(columnName); auto pgAgg = GetPgCallable(lambda.Body().Ptr(), "PgAgg"); if (pgAgg) { @@ -715,14 +749,28 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx, // Eliminate `PgGroupRef` from projection lambda. auto pgGroupRef = GetPgCallable(lambda.Body().Ptr(), "PgGroupRef"); if (pgGroupRef) { - Y_ENSURE(pgGroupRef->ChildrenSize() == 4); + TString columnName; + if (pgGroupRef->ChildrenSize() == 4) { + columnName = TString(pgGroupRef->ChildPtr(3)->Content()); + } else if (pgGroupRef->ChildrenSize() == 3) { + // In this case we can get a column name from group expr map + const auto groupByKeyExprId = FromString(TString(pgGroupRef->ChildPtr(2)->Content())); + auto it = groupByKeysExpressionsMap.find(groupByKeyExprId); + Y_ENSURE(it != groupByKeysExpressionsMap.end(), "Group by key expression has invalid content."); + columnName = it->second.first.GetFullName(); + // Always need a pg cast for expressions. + needPgCast = true; + } else { + Y_ENSURE(false, "Invalid children size for `pgGroupRef`"); + } + // clang-format off lambda = Build(ctx, node->Pos()) .Args(lambda.Args()) .Body() .Struct(lambda.Args().Arg(0)) .Name() - .Value(pgGroupRef->ChildPtr(3)->Content()) + .Value(columnName) .Build() .Build() .Done(); diff --git a/ydb/core/kqp/ut/rbo/kqp_rbo_ut.cpp b/ydb/core/kqp/ut/rbo/kqp_rbo_ut.cpp index f7fec50e9582..d87b4de1f851 100644 --- a/ydb/core/kqp/ut/rbo/kqp_rbo_ut.cpp +++ b/ydb/core/kqp/ut/rbo/kqp_rbo_ut.cpp @@ -609,6 +609,21 @@ Y_UNIT_TEST_SUITE(KqpRbo) { SET TablePathPrefix = "/Root/"; select sum(t1.a + 1 + t1.c) as sumExpr0, sum(t1.c + 2) as sumExpr1 from t1 group by t1.b order by sumExpr0; )", + R"( + --!syntax_pg + SET TablePathPrefix = "/Root/"; + select sum(t1.c) as sum0, sum(t1.a + 3) as sum1 from t1 group by t1.b + 1 order by sum0; + )", + R"( + --!syntax_pg + SET TablePathPrefix = "/Root/"; + select sum(t1.c) as sum0, t1.b + 1, t1.c + 2 from t1 group by t1.b + 1, t1.c + 2 order by sum0; + )", + R"( + --!syntax_pg + SET TablePathPrefix = "/Root/"; + select sum(t1.c + 2) as sum0 from t1 group by t1.b + t1.a order by sum0; + )", }; std::vector results = { @@ -625,7 +640,10 @@ Y_UNIT_TEST_SUITE(KqpRbo) { R"([["0";"2"];["1";"1"];["2";"2"];["3";"1"];["4";"2"]])", R"([["4";"4"];["6";"6"]])", R"([["0";"4"];["1";"3"]])", - R"([["10";"8"];["15";"12"]])" + R"([["10";"8"];["15";"12"]])", + R"([["4";"10"];["6";"15"]])", + R"([["4";"2";"4"];["6";"3";"4"]])", + R"([["4"];["8"];["8"]])" }; for (ui32 i = 0; i < queries.size(); ++i) {