Skip to content

Commit

Permalink
refactor the tryTranslateToParametricAggregateFunction function in Ex…
Browse files Browse the repository at this point in the history
…pressionAnalyzer
  • Loading branch information
juntao-lei-timeplus committed May 21, 2024
1 parent 9f843ea commit e0647c4
Showing 1 changed file with 151 additions and 97 deletions.
248 changes: 151 additions & 97 deletions src/Interpreters/ExpressionAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,110 +125,164 @@ void tryTranslateToParametricAggregateFunction(
assert(node->arguments);
const ASTs & arguments = node->arguments->children;
const auto & lower_name = node->name;
if (lower_name == "min_k" || lower_name == "max_k" || lower_name == "__min_k_retract" || lower_name == "__max_k_retract")
{
/// Translate `min_k(key, num[, context...])` to `min_k(num)(key[, context...])`
/// Make the second argument as a const parameter
if (arguments.size() < 2)
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires at least two arguments.", node->name);

ASTPtr expression_list = std::make_shared<ASTExpressionList>();
expression_list->children.push_back(arguments[1]);
parameters = getAggregateFunctionParametersArray(expression_list, "", context);

argument_names.erase(argument_names.begin() + 1);
types.erase(types.begin() + 1);
}
else if (lower_name == "top_k" || lower_name == "top_k_exact")
{
/// Translate `top_k(key, num[, with_count, load_factor])` to `top_k(num[, with_count, load_factor])(key)`
/// Translate `top_k_exact(key, num[, with_count, limit_memory_size])` to `top_k_exact(num[, with_count, limit_memory_size])(key)`
auto size = arguments.size();
if (size < 2 || size > 4)
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires 2 to 4 arguments.", node->name);

ASTPtr expression_list = std::make_shared<ASTExpressionList>();
expression_list->children.assign(arguments.begin() + 1, arguments.end());
parameters = getAggregateFunctionParametersArray(expression_list, "", context);

argument_names = {argument_names[0]};
types = {types[0]};
}
else if (lower_name == "top_k_weighted" || lower_name == "top_k_exact_weighted")
{
/// Translate `top_k_weighted(key, weight, num[, with_count, load_factor])` to `top_k_weighted(num[, with_count, load_factor])(key, weight)`
/// Translate `top_k_exact_weighted(key, weight, num[, with_count, limit_memory_size])` to `top_k_exact_weighted(num[, with_count, limit_memory_size])(key, weight)`
auto size = arguments.size();
if (size < 3 || size > 5)
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires 3 to 5 arguments.", node->name);

ASTPtr expression_list = std::make_shared<ASTExpressionList>();
expression_list->children.assign(arguments.begin() + 2, arguments.end());
parameters = getAggregateFunctionParametersArray(expression_list, "", context);

argument_names = {argument_names[0], argument_names[1]};
types = {types[0], types[1]};
}
else if (lower_name == "quantile")
{
/// Translate `quantile(key, level)` to `quantile(level)(key)`; the default level is 0.5, and the median function is an alias of quantile(key, 0.5)
if (arguments.size() != 2 && arguments.size() != 1)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires one or two arguments", node->name);
if (arguments.size() == 2)
{
ASTPtr expression_list = std::make_shared<ASTExpressionList>();
expression_list->children.push_back(arguments[1]);
parameters = getAggregateFunctionParametersArray(expression_list, "", context);
}

argument_names = {argument_names[0]};
types = {types[0]};
}
else if (lower_name == "stochastic_linear_regression_state" || lower_name == "stochastic_logistic_regression_state")
{
/// The stochastic_linear_regression_state function needs 4 arguments (learning rate, l2 regularization coefficient, mini-batch size, method for updating weights) followed by any number of feature columns,
/// for example: stochastic_linear_regression_state(0.1, 0.1, 100, 'sgd', feature_col1, feature_col2, ...)
/// At least one feature column is required.
if (arguments.size() < 5)
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Aggregate function {} requires four arguments and at least 1 feature column",
node->name);
auto size = arguments.size();

/// put 4 arguments into parameters
ASTPtr expression_list = std::make_shared<ASTExpressionList>();
for (size_t i = 0; i < 4; ++i)
expression_list->children.push_back(arguments[i]);

parameters = getAggregateFunctionParametersArray(expression_list, "", context);

/// put feature columns into argument_names and types
Names feature_names;
DataTypes feature_types;
for (size_t i = 4; i < arguments.size(); ++i)
/// Helper function to create an expression list and assign parameters.
/// When called with only 'start', it moves the arguments from index 'start' up to the last argument into parameters.
/// When called with 'start' and 'end', it moves the arguments in the half-open range ['start', 'end') into parameters.
auto createExpressionListAndAssignParameters = [&](size_t start, size_t end = std::string::npos) {
if (end == std::string::npos)
{
feature_names.push_back(argument_names[i]);
feature_types.push_back(types[i]);
end = size; /// If 'end' is not provided, use the size of the arguments to signify the end.
}
argument_names = feature_names;
types = feature_types;
}
else if (lower_name == "group_uniq_array" || lower_name == "group_uniq_array_retract")
{
/// there are two cases for group_uniq_array function
/// 1. changelog stream: after StreamingFunctionData::visit() group_uniq_array(column, max_size) -> group_uniq_array(column, max_size, _tp_delta), we translate to group_uniq_array(max_size)(column)
/// 2. append-only stream: group_uniq_array(column, max_size) -> group_uniq_array(max_size)(column)
if (arguments.size() >= 2 && argument_names[1] != ProtonConsts::RESERVED_DELTA_FLAG)
if (end <= size) /// If size == 1, we don't need to moves arguments to parameters.
{
ASTPtr expression_list = std::make_shared<ASTExpressionList>();
expression_list->children.push_back(arguments[1]);
// Assign a subset of arguments to the expression list based on the start and end indices.
expression_list->children.assign(arguments.begin() + start, arguments.begin() + end);
// Get the parameters from the expression list and update the types and argument names accordingly.
parameters = getAggregateFunctionParametersArray(expression_list, "", context);
// Erase the arguments that have been moved to parameters from argument_names and types.
argument_names.erase(argument_names.begin() + start, argument_names.begin() + end);
types.erase(types.begin() + start, types.begin() + end);
}
argument_names = {argument_names[0]};
types = {types[0]};
};

/// Create a map to associate different aggregate function names with their respective handling logic.
/// Remove unnecessary arguments, such as group_uniq_array(column, max_size, _tp_delta) -> group_uniq_array(column, max_size)
const std::map<std::string, std::function<void()>> name_to_logic = {
{"min_k",
[&]() {
/// Translate `min_k(key, num[, context...])` to `min_k(num)(key[, context...])`
/// Make the second argument as a const parameter
if (size < 2)
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires at least two arguments.", node->name);
createExpressionListAndAssignParameters(1, 2);
}},
{"max_k",
[&]() {
/// Translate `max_k(key, num[, context...])` to `max_k(num)(key[, context...])`
/// Make the second argument as a const parameter
if (size < 2)
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires at least two arguments.", node->name);
createExpressionListAndAssignParameters(1, 2);
}},
{"__min_k_retract",
[&]() {
/// Translate `__min_k_retract(key, num[, context...])` to `__min_k_retract(num)(key[, context...])`
/// Make the second argument as a const parameter
if (size < 2)
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires at least two arguments.", node->name);
createExpressionListAndAssignParameters(1, 2);
}},
{"__max_k_retract",
[&]() {
/// Translate `__max_k_retract(key, num[, context...])` to `__max_k_retract(num)(key[, context...])`
/// Make the second argument as a const parameter
if (size < 2)
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires at least two arguments.", node->name);
createExpressionListAndAssignParameters(1, 2);
}},
{"top_k",
[&]() {
/// Translate `top_k(key, num[, with_count, load_factor])` to `top_k(num[, with_count, load_factor])(key)`
if (size < 2 || size > 4)
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires 2 to 4 arguments.", node->name);
createExpressionListAndAssignParameters(1);
}},
{"top_k_exact",
[&]() {
/// Translate `top_k_exact(key, num[, with_count, limit_memory_size])` to `top_k_exact(num[, with_count, limit_memory_size])(key)`
if (size < 2 || size > 4)
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires 2 to 4 arguments.", node->name);
createExpressionListAndAssignParameters(1);
}},
{"top_k_weighted",
[&]() {
/// Translate `top_k_weighted(key, weight, num[, with_count, load_factor])` to `top_k_weighted(num[, with_count, load_factor])(key, weight)`
if (size < 3 || size > 5)
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires 3 to 5 arguments.", node->name);
createExpressionListAndAssignParameters(2);
}},
{"top_k_exact_weighted",
[&]() {
/// Translate `top_k_exact_weighted(key, weight, num[, with_count, limit_memory_size])` to `top_k_exact_weighted(num[, with_count, limit_memory_size])(key, weight)`
if (size < 3 || size > 5)
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires 3 to 5 arguments.", node->name);
createExpressionListAndAssignParameters(2);
}},
{"quantile",
[&]() {
/// Translate `quantile(key, level)` to `quantile(level)(key)`, and the default level is 0.5
/// Median function is the alias of quantile(key, 0.5)
if (arguments.size() != 2 && arguments.size() != 1)
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires one or two arguments", node->name);
createExpressionListAndAssignParameters(1, 2);
}},
{"stochastic_linear_regression_state",
[&]() {
/// The stochastic_linear_regression_state function needs 4 arguments (learning rate, l2 regularization coefficient, mini-batch size, method for updating weights) followed by any number of feature columns,
/// for example: stochastic_linear_regression_state(0.1, 0.1, 100, 'sgd', feature_col1, feature_col2, ...)
/// At least one feature column is required.
if (arguments.size() < 5)
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Aggregate function {} requires four arguments and at least 1 feature column",
node->name);
createExpressionListAndAssignParameters(0, 4);
}},
{"stochastic_logistic_regression_state",
[&]() {
/// The stochastic_logistic_regression_state function needs 4 arguments (learning rate, l2 regularization coefficient, mini-batch size, method for updating weights) followed by any number of feature columns,
/// for example: stochastic_logistic_regression_state(0.1, 0.1, 100, 'sgd', feature_col1, feature_col2, ...)
/// At least one feature column is required.
if (arguments.size() < 5)
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Aggregate function {} requires four arguments and at least 1 feature column",
node->name);
createExpressionListAndAssignParameters(0, 4);
}},
{"group_uniq_array",
[&]() {
/// There are two cases for the group_uniq_array function:
/// 1. Changelog stream: StreamingFunctionData::visit() rewrites group_uniq_array(column, max_size) to group_uniq_array(column, max_size, _tp_delta), which we translate to group_uniq_array(max_size)(column)
/// 2. Append-only stream: group_uniq_array(column, max_size) -> group_uniq_array(max_size)(column)
if (size >= 2 && argument_names[1] != ProtonConsts::RESERVED_DELTA_FLAG)
if (size == 3)
{
argument_names.erase(argument_names.begin() + 2);
types.erase(types.begin() + 2);
}
createExpressionListAndAssignParameters(1, 2);
}},
{"group_uniq_array_retract", [&]() {
/// There are two cases for the group_uniq_array_retract function:
/// 1. Changelog stream: StreamingFunctionData::visit() rewrites group_uniq_array_retract(column, max_size) to group_uniq_array_retract(column, max_size, _tp_delta), which we translate to group_uniq_array_retract(max_size)(column)
/// 2. Append-only stream: group_uniq_array_retract(column, max_size) -> group_uniq_array_retract(max_size)(column)
if (size >= 2 && argument_names[1] != ProtonConsts::RESERVED_DELTA_FLAG)
if (size == 3)
{
argument_names.erase(argument_names.begin() + 2);
types.erase(types.begin() + 2);
}
createExpressionListAndAssignParameters(1, 2);
}}};

// Find and execute the logic associated with the function name if it exists.
auto logic_it = name_to_logic.find(lower_name);
if (logic_it != name_to_logic.end())
{
logic_it->second();
}
};

Expand Down

0 comments on commit e0647c4

Please sign in to comment.