Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 71 additions & 23 deletions ydb/core/kqp/opt/rbo/kqp_rbo_transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,9 @@ void BuildSimpleMapElementLambda(TExprNode::TPtr resultExpr, const TVector<std::
TExprNode::TPtr BuildExpressionMap(TExprNode::TPtr resultExpr,
const TVector<std::pair<TInfoUnit, TExprNode::TPtr>>& aggFieldsExpressionsMap,
const TVector<std::pair<TInfoUnit, TInfoUnit>>& aggFieldsRenamesMap,
const TVector<std::pair<TInfoUnit, TInfoUnit>>& groupByKeysRenamesMap, TExprContext& ctx,
TPositionHandle pos) {
const TVector<std::pair<TInfoUnit, TInfoUnit>>& groupByKeysRenamesMap,
const THashMap<uint32_t, std::pair<TInfoUnit, TExprNode::TPtr>>& groupByKeysExpressionsMap,
TExprContext& ctx, TPositionHandle pos) {
// Add expressions
TVector<TExprNode::TPtr> mapElements;
for (const auto& [colName, expr] : aggFieldsExpressionsMap) {
Expand All @@ -171,6 +172,19 @@ TExprNode::TPtr BuildExpressionMap(TExprNode::TPtr resultExpr,
BuildSimpleMapElementLambda(resultExpr, aggFieldsRenamesMap, mapElements, ctx, pos);
BuildSimpleMapElementLambda(resultExpr, groupByKeysRenamesMap, mapElements, ctx, pos);

// Add expressions for group by keys.
for (const auto& [_, pair] : groupByKeysExpressionsMap) {
// clang-format off
mapElements.push_back(Build<TKqpOpMapElementLambda>(ctx, pos)
.Input(resultExpr)
.Variable()
.Value(pair.first.GetFullName())
.Build()
.Lambda(pair.second)
.Done().Ptr());
// clang-format on
}

// clang-format off
return Build<TKqpOpMap>(ctx, pos)
.Input(resultExpr)
Expand All @@ -184,8 +198,8 @@ TExprNode::TPtr BuildExpressionMap(TExprNode::TPtr resultExpr,
// clang-format on
}

void BuildMapElementRename(TExprNode::TPtr resultExpr, const TVector<std::pair<TInfoUnit, TInfoUnit>>& renamesMap, TVector<TExprNode::TPtr>& mapElements,
TExprContext& ctx, TPositionHandle pos) {
void BuildMapElementRename(TExprNode::TPtr resultExpr, const TVector<std::pair<TInfoUnit, TInfoUnit>>& renamesMap,
TVector<TExprNode::TPtr>& mapElements, TExprContext& ctx, TPositionHandle pos) {
for (const auto& [colName, newColName] : renamesMap) {
// clang-format off
mapElements.push_back(Build<TKqpOpMapElementRename>(ctx, pos)
Expand Down Expand Up @@ -552,20 +566,38 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
filterExpr = Build<TKqpOpEmptySource>(ctx, node->Pos()).Done().Ptr();
}

// FIXME: Group by key can be an expression, we need to handle this case
THashSet<TString> aggregationColumnsRequireCastToPgType;
TVector<std::pair<TInfoUnit, TInfoUnit>> groupByKeysRenamesMap;
THashMap<uint32_t, std::pair<TInfoUnit, TExprNode::TPtr>> groupByKeysExpressionsMap;
TVector<TInfoUnit> groupByKeys;
auto groupOps = GetSetting(setItem->Tail(), "group_exprs");
if (groupOps) {
const auto groupByList = groupOps->TailPtr();
for (ui32 i = 0; i < groupByList->ChildrenSize(); ++i) {
auto lambda = TCoLambda(ctx.DeepCopyLambda(*(groupByList->ChildPtr(i)->Child(1))));
auto body = lambda.Body().Ptr();
TVector<TInfoUnit> keys;
GetAllMembers(body, keys);
groupByKeys.insert(groupByKeys.end(), keys.begin(), keys.end());
for (const auto &infoUnit : keys) {
groupByKeysRenamesMap.push_back({infoUnit, infoUnit});
auto pgResolvedOp = GetPgCallable(lambda.Body().Ptr(), "PgResolvedOp");
// Expression for group by keys.
if (pgResolvedOp) {
auto fromPg = ctx.NewCallable(node->Pos(), "FromPg", {pgResolvedOp});

// clang-format off
auto groupExprLambda = Build<TCoLambda>(ctx, node->Pos())
.Args(lambda.Args())
.Body(fromPg)
.Done().Ptr();
// clang-format on

const auto newColName = TInfoUnit(GenerateUniqueColumnName("_group_expr_"));
groupByKeysExpressionsMap[i] = std::make_pair(newColName, groupExprLambda);
groupByKeys.push_back(newColName);
} else {
TVector<TInfoUnit> keys;
GetAllMembers(body, keys);
Y_ENSURE(keys.size() == 1, "Invalid size of the group keys.");
const auto groupKeyName = keys.front();
groupByKeys.push_back(groupKeyName);
groupByKeysRenamesMap.push_back({groupKeyName, groupKeyName});
}
}
}
Expand All @@ -575,8 +607,6 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
Y_ENSURE(result);
auto finalType = node->GetTypeAnn()->Cast<TListExprType>()->GetItemType()->Cast<TStructExprType>();

// This is a hack to enable convertion for aggregation columns.
THashSet<TString> aggregationColumns;
THashSet<TString> columnNames;
// Collect PgAgg for each result item at first pass.
TVector<TExprNode::TPtr> aggTraitsList;
Expand All @@ -598,14 +628,18 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
TVector<TInfoUnit> originalColNames;
GetAllMembers(pgAgg, originalColNames);
auto pgResolvedOp = GetPgCallable(lambda.Body().Ptr(), "PgResolvedOp");
//Y_ENSURE(originalColNames.size() > 1 && pgResolvedOp, "Invalid column size for aggregation columns.");

auto originalColName = originalColNames.front();
auto renamedColName = originalColName;

if (pgResolvedOp) {
auto fromPg = ctx.NewCallable(node->Pos(), "FromPg", {pgResolvedOp});
auto exprLambda = Build<TCoLambda>(ctx, node->Pos()).Args(lambda.Args()).Body(fromPg).Done().Ptr();
// clang-format off
auto exprLambda = Build<TCoLambda>(ctx, node->Pos())
.Args(lambda.Args())
.Body(fromPg)
.Done().Ptr();
// clang-format on

// Just any unique name for expression result, physical plan should be AsSturct(`unique_name (expression))
originalColName = TInfoUnit(GenerateUniqueColumnName("_expr_"));
Expand All @@ -623,7 +657,7 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
columnNames.insert(renamedColName.GetFullName());
Y_ENSURE(!GetAtom(pgAgg->ChildPtr(1), "distinct"), "Aggregation on distinct is not supported");

aggregationColumns.insert(resultColName);
aggregationColumnsRequireCastToPgType.insert(resultColName);
const TString aggFuncName = TString(pgAgg->ChildPtr(0)->Content());
auto aggregationTraits = BuildAggregationTraits(renamedColName.GetFullName(), resultColName, aggFuncName,
aggFuncResultType, ctx, node->Pos());
Expand All @@ -637,7 +671,7 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
}
// This case covers distinct all on just columns without aggregation functions.
} else if (!pgAgg && distinctAll) {
aggregationColumns.insert(resultColName);
aggregationColumnsRequireCastToPgType.insert(resultColName);
Y_ENSURE(aggFuncResultType, "Cannot find type for aggregation result.");
TVector<TInfoUnit> originalColNames;
GetAllMembers(resultItem->ChildPtr(2), originalColNames);
Expand All @@ -658,9 +692,10 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
if (needRenameMap) {
resultExpr = BuildRenameMap(resultExpr, aggFieldsRenamesMap, groupByKeysRenamesMap, ctx, node->Pos());
}
// In case we have an expression for aggregation - f(a + b ..)
if (!aggFieldsExpressionsMap.empty()) {
resultExpr = BuildExpressionMap(resultExpr, aggFieldsExpressionsMap, aggFieldsRenamesMap, groupByKeysRenamesMap, ctx, node->Pos());
// In case we have an expression for aggregation - f(a + b ..) or group by.
if (!aggFieldsExpressionsMap.empty() || !groupByKeysExpressionsMap.empty()) {
resultExpr = BuildExpressionMap(resultExpr, aggFieldsExpressionsMap, aggFieldsRenamesMap, groupByKeysRenamesMap,
groupByKeysExpressionsMap, ctx, node->Pos());
}
resultExpr = BuildAggregate(resultExpr, groupByKeys, aggTraitsList, distinctAll, ctx, node->Pos());
}
Expand All @@ -671,7 +706,6 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
}

finalColumnOrder.clear();
THashMap<TString, TExprNode::TPtr> aggProjectionMap;

for (auto resultItem : result->Child(1)->Children()) {
auto column = resultItem->Child(0);
Expand All @@ -691,7 +725,7 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,

bool needPgCast = (expectedType->GetId() != actualPgTypeId);
auto lambda = TCoLambda(ctx.DeepCopyLambda(*(resultItem->Child(2))));
bool needPgCastForAgg = aggregationColumns.count(columnName);
bool needPgCastForAgg = aggregationColumnsRequireCastToPgType.count(columnName);

auto pgAgg = GetPgCallable(lambda.Body().Ptr(), "PgAgg");
if (pgAgg) {
Expand All @@ -715,14 +749,28 @@ TExprNode::TPtr RewritePgSelect(const TExprNode::TPtr &node, TExprContext &ctx,
// Eliminate `PgGroupRef` from projection lambda.
auto pgGroupRef = GetPgCallable(lambda.Body().Ptr(), "PgGroupRef");
if (pgGroupRef) {
Y_ENSURE(pgGroupRef->ChildrenSize() == 4);
TString columnName;
if (pgGroupRef->ChildrenSize() == 4) {
columnName = TString(pgGroupRef->ChildPtr(3)->Content());
} else if (pgGroupRef->ChildrenSize() == 3) {
// In this case we can get a column name from group expr map
const auto groupByKeyExprId = FromString<uint32_t>(TString(pgGroupRef->ChildPtr(2)->Content()));
auto it = groupByKeysExpressionsMap.find(groupByKeyExprId);
Y_ENSURE(it != groupByKeysExpressionsMap.end(), "Group by key expression has invalid content.");
columnName = it->second.first.GetFullName();
// Always need a pg cast for expressions.
needPgCast = true;
} else {
Y_ENSURE(false, "Invalid children size for `pgGroupRef`");
}

// clang-format off
lambda = Build<TCoLambda>(ctx, node->Pos())
.Args(lambda.Args())
.Body<TCoMember>()
.Struct(lambda.Args().Arg(0))
.Name<TCoAtom>()
.Value(pgGroupRef->ChildPtr(3)->Content())
.Value(columnName)
.Build()
.Build()
.Done();
Expand Down
20 changes: 19 additions & 1 deletion ydb/core/kqp/ut/rbo/kqp_rbo_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,21 @@ Y_UNIT_TEST_SUITE(KqpRbo) {
SET TablePathPrefix = "/Root/";
select sum(t1.a + 1 + t1.c) as sumExpr0, sum(t1.c + 2) as sumExpr1 from t1 group by t1.b order by sumExpr0;
)",
R"(
--!syntax_pg
SET TablePathPrefix = "/Root/";
select sum(t1.c) as sum0, sum(t1.a + 3) as sum1 from t1 group by t1.b + 1 order by sum0;
)",
R"(
--!syntax_pg
SET TablePathPrefix = "/Root/";
select sum(t1.c) as sum0, t1.b + 1, t1.c + 2 from t1 group by t1.b + 1, t1.c + 2 order by sum0;
)",
R"(
--!syntax_pg
SET TablePathPrefix = "/Root/";
select sum(t1.c + 2) as sum0 from t1 group by t1.b + t1.a order by sum0;
)",
};

std::vector<std::string> results = {
Expand All @@ -625,7 +640,10 @@ Y_UNIT_TEST_SUITE(KqpRbo) {
R"([["0";"2"];["1";"1"];["2";"2"];["3";"1"];["4";"2"]])",
R"([["4";"4"];["6";"6"]])",
R"([["0";"4"];["1";"3"]])",
R"([["10";"8"];["15";"12"]])"
R"([["10";"8"];["15";"12"]])",
R"([["4";"10"];["6";"15"]])",
R"([["4";"2";"4"];["6";"3";"4"]])",
R"([["4"];["8"];["8"]])"
};

for (ui32 i = 0; i < queries.size(); ++i) {
Expand Down
Loading