diff --git a/presto-pinot-toolkit/src/main/java/com/facebook/presto/pinot/query/PinotAggregationProjectConverter.java b/presto-pinot-toolkit/src/main/java/com/facebook/presto/pinot/query/PinotAggregationProjectConverter.java index fd1988905413..a632b18d361c 100644 --- a/presto-pinot-toolkit/src/main/java/com/facebook/presto/pinot/query/PinotAggregationProjectConverter.java +++ b/presto-pinot-toolkit/src/main/java/com/facebook/presto/pinot/query/PinotAggregationProjectConverter.java @@ -51,14 +51,28 @@ public class PinotAggregationProjectConverter "/", "DIV"); private static final String FROM_UNIXTIME = "from_unixtime"; + private static final Map PRESTO_TO_PINOT_ARRAY_AGGREGATIONS = ImmutableMap.builder() + .put("array_min", "arrayMin") + .put("array_max", "arrayMax") + .put("array_average", "arrayAverage") + .put("array_sum", "arraySum") + .build(); + private final FunctionMetadataManager functionMetadataManager; private final ConnectorSession session; + private final VariableReferenceExpression arrayVariableHint; public PinotAggregationProjectConverter(TypeManager typeManager, FunctionMetadataManager functionMetadataManager, StandardFunctionResolution standardFunctionResolution, ConnectorSession session) + { + this(typeManager, functionMetadataManager, standardFunctionResolution, session, null); + } + + public PinotAggregationProjectConverter(TypeManager typeManager, FunctionMetadataManager functionMetadataManager, StandardFunctionResolution standardFunctionResolution, ConnectorSession session, VariableReferenceExpression arrayVariableHint) { super(typeManager, standardFunctionResolution); this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null"); this.session = requireNonNull(session, "session is null"); + this.arrayVariableHint = arrayVariableHint; } @Override @@ -217,17 +231,48 @@ private PinotExpression handleFunction( CallExpression function, Map context) { - switch (function.getDisplayName().toLowerCase(ENGLISH)) { + String functionName = function.getDisplayName().toLowerCase(ENGLISH); + switch (functionName) { case "date_trunc": boolean useDateTruncation = PinotSessionProperties.isUseDateTruncation(session); return useDateTruncation ? handleDateTruncationViaDateTruncation(function, context) : handleDateTruncationViaDateTimeConvert(function, context); + case "array_max": + case "array_min": + String pinotArrayFunctionName = PRESTO_TO_PINOT_ARRAY_AGGREGATIONS.get(functionName); + requireNonNull(pinotArrayFunctionName, "Converted Pinot array function is null for - " + functionName); + return derived(String.format( + "%s(%s)", + pinotArrayFunctionName, + function.getArguments().get(0).accept(this, context).getDefinition())); + // array_sum and array_reduce are translated to a reduce function with lambda functions, so we pass in + // this arrayVariableHint to help determine which array function it is. + case "reduce": + if (arrayVariableHint != null) { + String arrayFunctionName = getArrayFunctionName(arrayVariableHint); + if (arrayFunctionName != null) { + String inputColumn = function.getArguments().get(0).accept(this, context).getDefinition(); + return derived(String.format("%s(%s)", arrayFunctionName, inputColumn)); + } + } default: throw new PinotException(PINOT_UNSUPPORTED_EXPRESSION, Optional.empty(), format("function %s not supported yet", function.getDisplayName())); } } + // The array function variable names are in the format of `array_sum`, `array_average_0`, `array_sum_1`. + // So we can parse the array function name based on variable name. + private String getArrayFunctionName(VariableReferenceExpression variable) + { + String[] variableNameSplits = variable.getName().split("_"); + if (variableNameSplits.length < 2 || variableNameSplits.length > 3) { + return null; + } + String arrayFunctionName = String.format("%s_%s", variableNameSplits[0], variableNameSplits[1]); + return PRESTO_TO_PINOT_ARRAY_AGGREGATIONS.get(arrayFunctionName); + } + private static String getStringFromConstant(RowExpression expression) { if (expression instanceof ConstantExpression) { diff --git a/presto-pinot-toolkit/src/main/java/com/facebook/presto/pinot/query/PinotFilterExpressionConverter.java b/presto-pinot-toolkit/src/main/java/com/facebook/presto/pinot/query/PinotFilterExpressionConverter.java index 6173c471a63c..1dcd2f43d2c7 100644 --- a/presto-pinot-toolkit/src/main/java/com/facebook/presto/pinot/query/PinotFilterExpressionConverter.java +++ b/presto-pinot-toolkit/src/main/java/com/facebook/presto/pinot/query/PinotFilterExpressionConverter.java @@ -201,6 +201,24 @@ else if (timeFieldExpression instanceof ConstantExpression) { return Optional.of(timeValueString); } + private PinotExpression handleContains( + CallExpression contains, + Function context) + { + if (contains.getArguments().size() != 2) { + throw new PinotException(PINOT_UNSUPPORTED_EXPRESSION, Optional.empty(), format("Contains operator not supported: %s", contains)); + } + RowExpression left = contains.getArguments().get(0); + RowExpression right = contains.getArguments().get(1); + if (!(right instanceof ConstantExpression)) { + throw new PinotException(PINOT_UNSUPPORTED_EXPRESSION, Optional.empty(), format("Contains operator can not push down non-literal value: %s", right)); + } + return derived(format( + "(%s = %s)", + left.accept(this, context).getDefinition(), + right.accept(this, context).getDefinition())); + } + private PinotExpression handleBetween( CallExpression between, Function context) @@ -294,6 +312,9 @@ public PinotExpression visitCall(CallExpression call, Function s + x, s -> s)", DERIVED)) // derived column + .put(new VariableReferenceExpression("array_average_0", DOUBLE), new PinotQueryGeneratorContext.Selection("reduce(scores, CAST(ROW(0.0, 0) AS ROW(sum DOUBLE, count INTEGER)), (s,x) -> CAST(ROW(x + s.sum, s.count + 1) AS ROW(sum DOUBLE, count INTEGER)), s -> IF(s.count = 0, NULL, s.sum / s.count))", DERIVED)) // derived column .put(new VariableReferenceExpression("secondssinceepoch", BIGINT), new PinotQueryGeneratorContext.Selection("secondsSinceEpoch", TABLE_COLUMN)) // column for datetime functions .put(new VariableReferenceExpression("dayssinceepoch", DATE), new PinotQueryGeneratorContext.Selection("daysSinceEpoch", TABLE_COLUMN)) // column for date functions .put(new VariableReferenceExpression("millissinceepoch", TIMESTAMP), new PinotQueryGeneratorContext.Selection("millisSinceEpoch", TABLE_COLUMN)) // column for timestamp functions diff --git a/presto-pinot-toolkit/src/test/java/com/facebook/presto/pinot/query/TestPinotQueryGenerator.java b/presto-pinot-toolkit/src/test/java/com/facebook/presto/pinot/query/TestPinotQueryGenerator.java index 3c4701f97584..d9ad048eff89 100644 --- a/presto-pinot-toolkit/src/test/java/com/facebook/presto/pinot/query/TestPinotQueryGenerator.java +++ b/presto-pinot-toolkit/src/test/java/com/facebook/presto/pinot/query/TestPinotQueryGenerator.java @@ -44,6 +44,7 @@ import java.util.stream.Collectors; import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.common.type.TimestampType.TIMESTAMP; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static java.lang.String.format; @@ -146,6 +147,7 @@ private void testUnaryAggregationHelper(BiConsumer tableScan(planBuilder, pinotTable, regionId, secondsSinceEpoch, city, fare)); PlanNode filter = buildPlan(planBuilder -> filter(planBuilder, tableScan(planBuilder, pinotTable, regionId, secondsSinceEpoch, city, fare), getRowExpression("fare > 3", defaultSessionHolder))); PlanNode anotherFilter = buildPlan(planBuilder -> filter(planBuilder, tableScan(planBuilder, pinotTable, regionId, secondsSinceEpoch, city, fare), getRowExpression("secondssinceepoch between 200 and 300 and regionid >= 40", defaultSessionHolder))); + PlanNode filterWithMultiValue = buildPlan(planBuilder -> filter(planBuilder, tableScan(planBuilder, pinotTable, regionId, secondsSinceEpoch, city, fare, scores), getRowExpression("contains(scores, 100) OR contains(scores, 200)", defaultSessionHolder))); testPinotQuery( planBuilder -> planBuilder.aggregation(aggBuilder -> aggregationFunctionBuilder.accept(planBuilder, aggBuilder.source(justScan).globalGrouping())), format("SELECT %s FROM realtimeOnly", getExpectedAggOutput(expectedAggOutput, ""))); @@ -161,6 +163,9 @@ private void testUnaryAggregationHelper(BiConsumer planBuilder.aggregation(aggBuilder -> aggregationFunctionBuilder.accept(planBuilder, aggBuilder.source(anotherFilter).singleGroupingSet(variable("regionid"), variable("city")))), format("SELECT %s FROM realtimeOnly WHERE ((secondsSinceEpoch BETWEEN 200 AND 300) AND (regionId >= 40)) GROUP BY regionId, city %s 10000", getExpectedAggOutput(expectedAggOutput, "regionId, city"), getGroupByLimitKey())); + testPinotQuery( + planBuilder -> planBuilder.aggregation(aggBuilder -> aggregationFunctionBuilder.accept(planBuilder, aggBuilder.source(filterWithMultiValue).singleGroupingSet(variable("regionid"), variable("city")))), + format("SELECT %s FROM realtimeOnly WHERE ((scores = 100) OR (scores = 200)) GROUP BY regionId, city %s 10000", getExpectedAggOutput(expectedAggOutput, "regionId, city"), getGroupByLimitKey())); } protected String getGroupByLimitKey() @@ -287,6 +292,42 @@ public void testAggWithUDFInGroupBy() String.format("SELECT %s FROM realtimeOnly GROUP BY dateTimeConvert(SUB(secondsSinceEpoch, 50), '1:SECONDS:EPOCH', '1:MILLISECONDS:EPOCH', '1:DAYS'), city %s 10000", getExpectedAggOutput("count(*)", "dateTimeConvert(SUB(secondsSinceEpoch, 50), '1:SECONDS:EPOCH', '1:MILLISECONDS:EPOCH', '1:DAYS'), city"), getGroupByLimitKey())); } + @Test + public void testAggWithArrayFunctionsInGroupBy() + { + LinkedHashMap aggProjection = new LinkedHashMap<>(); + aggProjection.put("array_max_0", "array_max(scores)"); + PlanNode justMaxScores = buildPlan(planBuilder -> project(planBuilder, tableScan(planBuilder, pinotTable, regionId, secondsSinceEpoch, city, fare, scores), aggProjection, defaultSessionHolder)); + testPinotQuery( + planBuilder -> planBuilder.aggregation(aggBuilder -> aggBuilder.source(justMaxScores).singleGroupingSet(new VariableReferenceExpression("array_max_0", DOUBLE)).addAggregation(planBuilder.variable("agg"), getRowExpression("count(*)", defaultSessionHolder))), + String.format("SELECT %s FROM realtimeOnly GROUP BY arrayMax(scores) %s 10000", getExpectedAggOutput("count(*)", "arrayMax(scores)"), getGroupByLimitKey())); + aggProjection.put("city", "city"); + PlanNode newScanWithCity = buildPlan(planBuilder -> project(planBuilder, tableScan(planBuilder, pinotTable, regionId, secondsSinceEpoch, city, fare, scores), aggProjection, defaultSessionHolder)); + testPinotQuery( + planBuilder -> planBuilder.aggregation(aggBuilder -> aggBuilder.source(newScanWithCity).singleGroupingSet(new VariableReferenceExpression("array_max_0", DOUBLE), variable("city")).addAggregation(planBuilder.variable("agg"), getRowExpression("count(*)", defaultSessionHolder))), + String.format("SELECT %s FROM realtimeOnly GROUP BY arrayMax(scores), city %s 10000", getExpectedAggOutput("count(*)", "arrayMax(scores), city"), getGroupByLimitKey())); + } + + private void testAggWithArrayFunction(String functionVariable, String prestoFunctionExpression, String pinotFunctionExpression) + { + LinkedHashMap aggProjection = new LinkedHashMap<>(); + aggProjection.put("city", "city"); + aggProjection.put(functionVariable, prestoFunctionExpression); + PlanNode aggregationPlanNode = buildPlan(planBuilder -> project(planBuilder, tableScan(planBuilder, pinotTable, regionId, secondsSinceEpoch, city, fare, scores), aggProjection, defaultSessionHolder)); + testPinotQuery( + planBuilder -> planBuilder.aggregation(aggBuilder -> aggBuilder.source(aggregationPlanNode).singleGroupingSet(variable("city")).addAggregation(planBuilder.variable("agg"), getRowExpression(String.format("sum(%s)", functionVariable), defaultSessionHolder))), + String.format("SELECT %s FROM realtimeOnly GROUP BY city %s 10000", getExpectedAggOutput(String.format("sum(%s)", pinotFunctionExpression), "city"), getGroupByLimitKey())); + } + + @Test + public void testAggWithArrayFunctions() + { + testAggWithArrayFunction("array_min_0", "array_min(scores)", "arrayMin(scores)"); + testAggWithArrayFunction("array_max_0", "array_max(scores)", "arrayMax(scores)"); + testAggWithArrayFunction("array_sum_0", "reduce(scores, cast(0 as double), (s, x) -> s + x, s -> s)", "arraySum(scores)"); + testAggWithArrayFunction("array_average_0", "reduce(scores, CAST(ROW(0.0, 0) AS ROW(sum DOUBLE, count INTEGER)), (s,x) -> CAST(ROW(x + s.sum, s.count + 1) AS ROW(sum DOUBLE, count INTEGER)), s -> IF(s.count = 0, NULL, s.sum / s.count))", "arrayAverage(scores)"); + } + @Test public void testMultipleAggregatesWithOutGroupBy() {