Skip to content

Commit

Permalink
Adding array functions pushdown to pinot
Browse files Browse the repository at this point in the history
  • Loading branch information
xiangfu0 committed Oct 4, 2020
1 parent 30e4606 commit 52f19fb
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,28 @@ public class PinotAggregationProjectConverter
"/", "DIV");
private static final String FROM_UNIXTIME = "from_unixtime";

// Lookup table from the Presto array function name (lower case, underscore separated) to the
// name of the Pinot function that the call is rewritten to when it is pushed down.
private static final Map<String, String> PRESTO_TO_PINOT_ARRAY_AGGREGATIONS = ImmutableMap.of(
        "array_min", "arrayMin",
        "array_max", "arrayMax",
        "array_average", "arrayAverage",
        "array_sum", "arraySum");

private final FunctionMetadataManager functionMetadataManager;
private final ConnectorSession session;
// Optional hint, may be null. When present, handleFunction parses this variable's name to decide
// which Pinot array function a `reduce` call should be rewritten to.
private final VariableReferenceExpression arrayVariableHint;

// Creates a converter without an array variable hint: delegates to the five-argument constructor
// and passes null for arrayVariableHint.
public PinotAggregationProjectConverter(TypeManager typeManager, FunctionMetadataManager functionMetadataManager, StandardFunctionResolution standardFunctionResolution, ConnectorSession session)
{
this(typeManager, functionMetadataManager, standardFunctionResolution, session, null);
}

// Creates a converter with an optional array variable hint.
// typeManager and standardFunctionResolution are forwarded to the superclass constructor.
// functionMetadataManager and session must be non-null (a NullPointerException is thrown otherwise).
// arrayVariableHint is stored as-is and is allowed to be null.
public PinotAggregationProjectConverter(TypeManager typeManager, FunctionMetadataManager functionMetadataManager, StandardFunctionResolution standardFunctionResolution, ConnectorSession session, VariableReferenceExpression arrayVariableHint)
{
super(typeManager, standardFunctionResolution);
this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null");
this.session = requireNonNull(session, "session is null");
// No null check here on purpose: the four-argument constructor passes null for the hint.
this.arrayVariableHint = arrayVariableHint;
}

@Override
Expand Down Expand Up @@ -217,17 +231,48 @@ private PinotExpression handleFunction(
CallExpression function,
Map<VariableReferenceExpression, PinotQueryGeneratorContext.Selection> context)
{
switch (function.getDisplayName().toLowerCase(ENGLISH)) {
String functionName = function.getDisplayName().toLowerCase(ENGLISH);
switch (functionName) {
case "date_trunc":
boolean useDateTruncation = PinotSessionProperties.isUseDateTruncation(session);
return useDateTruncation ?
handleDateTruncationViaDateTruncation(function, context) :
handleDateTruncationViaDateTimeConvert(function, context);
case "array_max":
case "array_min":
String pinotArrayFunctionName = PRESTO_TO_PINOT_ARRAY_AGGREGATIONS.get(functionName);
requireNonNull(pinotArrayFunctionName, "Converted Pinot array function is null for - " + functionName);
return derived(String.format(
"%s(%s)",
pinotArrayFunctionName,
function.getArguments().get(0).accept(this, context).getDefinition()));
// array_sum and array_average are translated to a reduce function with lambda functions, so we pass in
// this arrayVariableHint to help determine which array function it is.
case "reduce":
if (arrayVariableHint != null) {
String arrayFunctionName = getArrayFunctionName(arrayVariableHint);
if (arrayFunctionName != null) {
String inputColumn = function.getArguments().get(0).accept(this, context).getDefinition();
return derived(String.format("%s(%s)", arrayFunctionName, inputColumn));
}
}
default:
throw new PinotException(PINOT_UNSUPPORTED_EXPRESSION, Optional.empty(), format("function %s not supported yet", function.getDisplayName()));
}
}

// Derives the Pinot array function from the name of the variable an expression is assigned to.
// Such variable names look like `array_sum`, `array_average_0` or `array_sum_1`: splitting on "_"
// yields two or three parts, and the first two parts joined by "_" form the Presto function name.
// Returns null when the name does not split into two or three parts, or when the reconstructed
// Presto name has no entry in PRESTO_TO_PINOT_ARRAY_AGGREGATIONS.
private String getArrayFunctionName(VariableReferenceExpression variable)
{
    String[] nameParts = variable.getName().split("_");
    boolean hasExpectedShape = nameParts.length == 2 || nameParts.length == 3;
    if (!hasExpectedShape) {
        return null;
    }
    String prestoFunctionName = nameParts[0] + "_" + nameParts[1];
    return PRESTO_TO_PINOT_ARRAY_AGGREGATIONS.get(prestoFunctionName);
}

private static String getStringFromConstant(RowExpression expression)
{
if (expression instanceof ConstantExpression) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,24 @@ else if (timeFieldExpression instanceof ConstantExpression) {
return Optional.of(timeValueString);
}

// Pushes down `contains(<expression>, <constant>)` as the predicate `(<expression> = <constant>)`.
// Throws PinotException with PINOT_UNSUPPORTED_EXPRESSION when the call does not have exactly two
// arguments, or when the second argument is not a ConstantExpression.
private PinotExpression handleContains(
        CallExpression contains,
        Function<VariableReferenceExpression, Selection> context)
{
    if (contains.getArguments().size() != 2) {
        throw new PinotException(PINOT_UNSUPPORTED_EXPRESSION, Optional.empty(), format("Contains operator not supported: %s", contains));
    }

    RowExpression container = contains.getArguments().get(0);
    RowExpression element = contains.getArguments().get(1);
    boolean elementIsLiteral = element instanceof ConstantExpression;
    if (!elementIsLiteral) {
        throw new PinotException(PINOT_UNSUPPORTED_EXPRESSION, Optional.empty(), format("Contains operator can not push down non-literal value: %s", element));
    }

    // Convert the container first and the element second, as operands of the equality predicate.
    String containerDefinition = container.accept(this, context).getDefinition();
    String elementDefinition = element.accept(this, context).getDefinition();
    return derived(format("(%s = %s)", containerDefinition, elementDefinition));
}

private PinotExpression handleBetween(
CallExpression between,
Function<VariableReferenceExpression, Selection> context)
Expand Down Expand Up @@ -294,6 +312,9 @@ public PinotExpression visitCall(CallExpression call, Function<VariableReference
return handleLogicalBinary(operatorType.getOperator(), call, context);
}
}
if ("contains".equals(functionMetadata.getName().getFunctionName())) {
return handleContains(call, context);
}
// Handle queries like `eventTimestamp < 1391126400000`.
// Otherwise TypeManager.canCoerce(...) will return false and directly fail this query.
if (functionMetadata.getName().getFunctionName().equalsIgnoreCase("$literal$timestamp") ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ public PinotQueryGeneratorContext visitProject(ProjectNode node, PinotQueryGener
RowExpression expression = node.getAssignments().get(variable);
PinotExpression pinotExpression = expression.accept(
contextIn.getVariablesInAggregation().contains(variable) ?
new PinotAggregationProjectConverter(typeManager, functionMetadataManager, standardFunctionResolution, session) :
new PinotAggregationProjectConverter(typeManager, functionMetadataManager, standardFunctionResolution, session, variable) :
pinotProjectExpressionConverter,
context.getSelections());
newSelections.put(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ public class TestPinotQueryBase
// Column handles shared by the Pinot query tests; each one names a column of the test table.
protected static PinotColumnHandle regionId = new PinotColumnHandle("regionId", BIGINT, REGULAR);
protected static PinotColumnHandle city = new PinotColumnHandle("city", VARCHAR, REGULAR);
protected static final PinotColumnHandle fare = new PinotColumnHandle("fare", DOUBLE, REGULAR);
// Built through the array(...) helper rather than the PinotColumnHandle constructor;
// presumably an array-of-DOUBLE column -- confirm against the helper's definition.
protected static final PinotColumnHandle scores = array(DOUBLE, "scores");
protected static final PinotColumnHandle secondsSinceEpoch = new PinotColumnHandle("secondsSinceEpoch", BIGINT, REGULAR);
protected static final PinotColumnHandle daysSinceEpoch = new PinotColumnHandle("daysSinceEpoch", DATE, REGULAR);
protected static final PinotColumnHandle millisSinceEpoch = new PinotColumnHandle("millisSinceEpoch", TIMESTAMP, REGULAR);
Expand All @@ -111,10 +112,15 @@ public class TestPinotQueryBase
.put(new VariableReferenceExpression("regionid", BIGINT), new PinotQueryGeneratorContext.Selection("regionId", TABLE_COLUMN)) // direct column reference
.put(new VariableReferenceExpression("regionid$distinct", BIGINT), new PinotQueryGeneratorContext.Selection("regionId", TABLE_COLUMN)) // distinct column reference
.put(new VariableReferenceExpression("city", VARCHAR), new PinotQueryGeneratorContext.Selection("city", TABLE_COLUMN)) // direct column reference
.put(new VariableReferenceExpression("scores", new ArrayType(DOUBLE)), new PinotQueryGeneratorContext.Selection("scores", TABLE_COLUMN)) // direct column reference
.put(new VariableReferenceExpression("fare", DOUBLE), new PinotQueryGeneratorContext.Selection("fare", TABLE_COLUMN)) // direct column reference
.put(new VariableReferenceExpression("totalfare", DOUBLE), new PinotQueryGeneratorContext.Selection("(fare + trip)", DERIVED)) // derived column
.put(new VariableReferenceExpression("count_regionid", BIGINT), new PinotQueryGeneratorContext.Selection("count(regionid)", DERIVED))// derived column
.put(new VariableReferenceExpression("sum_fare", BIGINT), new PinotQueryGeneratorContext.Selection("sum(fare)", DERIVED))// derived column
.put(new VariableReferenceExpression("array_min_0", DOUBLE), new PinotQueryGeneratorContext.Selection("array_min(scores)", DERIVED)) // derived column
.put(new VariableReferenceExpression("array_max_0", DOUBLE), new PinotQueryGeneratorContext.Selection("array_max(scores)", DERIVED)) // derived column
.put(new VariableReferenceExpression("array_sum_0", DOUBLE), new PinotQueryGeneratorContext.Selection("reduce(scores, cast(0 as double), (s, x) -> s + x, s -> s)", DERIVED)) // derived column
.put(new VariableReferenceExpression("array_average_0", DOUBLE), new PinotQueryGeneratorContext.Selection("reduce(scores, CAST(ROW(0.0, 0) AS ROW(sum DOUBLE, count INTEGER)), (s,x) -> CAST(ROW(x + s.sum, s.count + 1) AS ROW(sum DOUBLE, count INTEGER)), s -> IF(s.count = 0, NULL, s.sum / s.count))", DERIVED)) // derived column
.put(new VariableReferenceExpression("secondssinceepoch", BIGINT), new PinotQueryGeneratorContext.Selection("secondsSinceEpoch", TABLE_COLUMN)) // column for datetime functions
.put(new VariableReferenceExpression("dayssinceepoch", DATE), new PinotQueryGeneratorContext.Selection("daysSinceEpoch", TABLE_COLUMN)) // column for date functions
.put(new VariableReferenceExpression("millissinceepoch", TIMESTAMP), new PinotQueryGeneratorContext.Selection("millisSinceEpoch", TABLE_COLUMN)) // column for timestamp functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import java.util.stream.Collectors;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.TimestampType.TIMESTAMP;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static java.lang.String.format;
Expand Down Expand Up @@ -146,6 +147,7 @@ private void testUnaryAggregationHelper(BiConsumer<PlanBuilder, PlanBuilder.Aggr
PlanNode justScan = buildPlan(planBuilder -> tableScan(planBuilder, pinotTable, regionId, secondsSinceEpoch, city, fare));
PlanNode filter = buildPlan(planBuilder -> filter(planBuilder, tableScan(planBuilder, pinotTable, regionId, secondsSinceEpoch, city, fare), getRowExpression("fare > 3", defaultSessionHolder)));
PlanNode anotherFilter = buildPlan(planBuilder -> filter(planBuilder, tableScan(planBuilder, pinotTable, regionId, secondsSinceEpoch, city, fare), getRowExpression("secondssinceepoch between 200 and 300 and regionid >= 40", defaultSessionHolder)));
PlanNode filterWithMultiValue = buildPlan(planBuilder -> filter(planBuilder, tableScan(planBuilder, pinotTable, regionId, secondsSinceEpoch, city, fare, scores), getRowExpression("contains(scores, 100) OR contains(scores, 200)", defaultSessionHolder)));
testPinotQuery(
planBuilder -> planBuilder.aggregation(aggBuilder -> aggregationFunctionBuilder.accept(planBuilder, aggBuilder.source(justScan).globalGrouping())),
format("SELECT %s FROM realtimeOnly", getExpectedAggOutput(expectedAggOutput, "")));
Expand All @@ -161,6 +163,9 @@ private void testUnaryAggregationHelper(BiConsumer<PlanBuilder, PlanBuilder.Aggr
testPinotQuery(
planBuilder -> planBuilder.aggregation(aggBuilder -> aggregationFunctionBuilder.accept(planBuilder, aggBuilder.source(anotherFilter).singleGroupingSet(variable("regionid"), variable("city")))),
format("SELECT %s FROM realtimeOnly WHERE ((secondsSinceEpoch BETWEEN 200 AND 300) AND (regionId >= 40)) GROUP BY regionId, city %s 10000", getExpectedAggOutput(expectedAggOutput, "regionId, city"), getGroupByLimitKey()));
testPinotQuery(
planBuilder -> planBuilder.aggregation(aggBuilder -> aggregationFunctionBuilder.accept(planBuilder, aggBuilder.source(filterWithMultiValue).singleGroupingSet(variable("regionid"), variable("city")))),
format("SELECT %s FROM realtimeOnly WHERE ((scores = 100) OR (scores = 200)) GROUP BY regionId, city %s 10000", getExpectedAggOutput(expectedAggOutput, "regionId, city"), getGroupByLimitKey()));
}

protected String getGroupByLimitKey()
Expand Down Expand Up @@ -287,6 +292,42 @@ public void testAggWithUDFInGroupBy()
String.format("SELECT %s FROM realtimeOnly GROUP BY dateTimeConvert(SUB(secondsSinceEpoch, 50), '1:SECONDS:EPOCH', '1:MILLISECONDS:EPOCH', '1:DAYS'), city %s 10000", getExpectedAggOutput("count(*)", "dateTimeConvert(SUB(secondsSinceEpoch, 50), '1:SECONDS:EPOCH', '1:MILLISECONDS:EPOCH', '1:DAYS'), city"), getGroupByLimitKey()));
}

// Checks that an array function used as a grouping key is pushed down: grouping by the projected
// `array_max(scores)` must produce `GROUP BY arrayMax(scores)`, both alone and together with `city`.
@Test
public void testAggWithArrayFunctionsInGroupBy()
{
    LinkedHashMap<String, String> projections = new LinkedHashMap<>();

    // Case 1: group by the array function only.
    projections.put("array_max_0", "array_max(scores)");
    PlanNode maxScoresOnly = buildPlan(builder -> project(
            builder,
            tableScan(builder, pinotTable, regionId, secondsSinceEpoch, city, fare, scores),
            projections,
            defaultSessionHolder));
    String expectedGroupByMax = String.format(
            "SELECT %s FROM realtimeOnly GROUP BY arrayMax(scores) %s 10000",
            getExpectedAggOutput("count(*)", "arrayMax(scores)"),
            getGroupByLimitKey());
    testPinotQuery(
            planBuilder -> planBuilder.aggregation(aggBuilder -> aggBuilder
                    .source(maxScoresOnly)
                    .singleGroupingSet(new VariableReferenceExpression("array_max_0", DOUBLE))
                    .addAggregation(planBuilder.variable("agg"), getRowExpression("count(*)", defaultSessionHolder))),
            expectedGroupByMax);

    // Case 2: the same projection extended with `city`, grouped by both keys.
    projections.put("city", "city");
    PlanNode maxScoresWithCity = buildPlan(builder -> project(
            builder,
            tableScan(builder, pinotTable, regionId, secondsSinceEpoch, city, fare, scores),
            projections,
            defaultSessionHolder));
    String expectedGroupByMaxAndCity = String.format(
            "SELECT %s FROM realtimeOnly GROUP BY arrayMax(scores), city %s 10000",
            getExpectedAggOutput("count(*)", "arrayMax(scores), city"),
            getGroupByLimitKey());
    testPinotQuery(
            planBuilder -> planBuilder.aggregation(aggBuilder -> aggBuilder
                    .source(maxScoresWithCity)
                    .singleGroupingSet(new VariableReferenceExpression("array_max_0", DOUBLE), variable("city"))
                    .addAggregation(planBuilder.variable("agg"), getRowExpression("count(*)", defaultSessionHolder))),
            expectedGroupByMaxAndCity);
}

// Projects `city` plus one array function expression, aggregates `sum(<functionVariable>)` grouped
// by city, and asserts that the generated Pinot query sums the expected Pinot expression.
//   functionVariable         - name of the projected variable holding the array function result
//   prestoFunctionExpression - Presto expression assigned to that variable
//   pinotFunctionExpression  - Pinot expression expected inside sum(...) in the generated query
private void testAggWithArrayFunction(String functionVariable, String prestoFunctionExpression, String pinotFunctionExpression)
{
    LinkedHashMap<String, String> projections = new LinkedHashMap<>();
    projections.put("city", "city");
    projections.put(functionVariable, prestoFunctionExpression);

    PlanNode projectedScan = buildPlan(builder -> project(
            builder,
            tableScan(builder, pinotTable, regionId, secondsSinceEpoch, city, fare, scores),
            projections,
            defaultSessionHolder));

    String prestoAggregation = String.format("sum(%s)", functionVariable);
    String pinotAggregation = String.format("sum(%s)", pinotFunctionExpression);
    String expectedQuery = String.format(
            "SELECT %s FROM realtimeOnly GROUP BY city %s 10000",
            getExpectedAggOutput(pinotAggregation, "city"),
            getGroupByLimitKey());

    testPinotQuery(
            planBuilder -> planBuilder.aggregation(aggBuilder -> aggBuilder
                    .source(projectedScan)
                    .singleGroupingSet(variable("city"))
                    .addAggregation(planBuilder.variable("agg"), getRowExpression(prestoAggregation, defaultSessionHolder))),
            expectedQuery);
}

// Runs the array function pushdown check for min, max, sum and average. The sum and average cases
// are written as `reduce(...)` expressions assigned to `array_sum_0` / `array_average_0`.
@Test
public void testAggWithArrayFunctions()
{
    // Each row: {projected variable name, Presto expression, expected Pinot expression}.
    String[][] cases = {
            {"array_min_0", "array_min(scores)", "arrayMin(scores)"},
            {"array_max_0", "array_max(scores)", "arrayMax(scores)"},
            {"array_sum_0", "reduce(scores, cast(0 as double), (s, x) -> s + x, s -> s)", "arraySum(scores)"},
            {"array_average_0", "reduce(scores, CAST(ROW(0.0, 0) AS ROW(sum DOUBLE, count INTEGER)), (s,x) -> CAST(ROW(x + s.sum, s.count + 1) AS ROW(sum DOUBLE, count INTEGER)), s -> IF(s.count = 0, NULL, s.sum / s.count))", "arrayAverage(scores)"},
    };
    for (String[] testCase : cases) {
        testAggWithArrayFunction(testCase[0], testCase[1], testCase[2]);
    }
}

@Test
public void testMultipleAggregatesWithOutGroupBy()
{
Expand Down

0 comments on commit 52f19fb

Please sign in to comment.