Skip to content

Commit

Permalink
Fix Count(*) on empty relation returns NULL when optimize_mixed_disti…
Browse files Browse the repository at this point in the history
…nct_aggregation is turned on.
  • Loading branch information
kaka11chen authored and sopel39 committed Jan 7, 2019
1 parent fe6635b commit b1bb391
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 6 deletions.
Expand Up @@ -33,6 +33,7 @@
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.CoalesceExpression;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
Expand Down Expand Up @@ -151,6 +152,9 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext<Optional<A
// Change aggregate node to do second aggregation, handles this part of optimized plan mentioned above:
// SELECT a1, a2,..., an, arbitrary(if(group = 0, f1)),...., arbitrary(if(group = 0, fm)), F(if(group = 1, c))
ImmutableMap.Builder<Symbol, Aggregation> aggregations = ImmutableMap.builder();
// Add coalesce projection node to handle count(), count_if(), approx_distinct() functions return a
// non-null result without any input
ImmutableMap.Builder<Symbol, Symbol> coalesceSymbolsBuilder = ImmutableMap.builder();
for (Map.Entry<Symbol, Aggregation> entry : node.getAggregations().entrySet()) {
FunctionCall functionCall = entry.getValue().getCall();
if (entry.getValue().getMask().isPresent()) {
Expand All @@ -167,14 +171,25 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext<Optional<A
// Aggregations on non-distinct are already done by new node, just extract the non-null value
Symbol argument = aggregateInfo.getNewNonDistinctAggregateSymbols().get(entry.getKey());
QualifiedName functionName = QualifiedName.of("arbitrary");
aggregations.put(entry.getKey(), new Aggregation(
String signatureName = entry.getValue().getSignature().getName();
Aggregation aggregation = new Aggregation(
new FunctionCall(functionName, functionCall.getWindow(), false, ImmutableList.of(argument.toSymbolReference())),
getFunctionSignature(functionName, argument),
Optional.empty()));
Optional.empty());
if (signatureName.equals("count")
|| signatureName.equals("count_if") || signatureName.equals("approx_distinct")) {
Symbol newSymbol = symbolAllocator.newSymbol("expr", symbolAllocator.getTypes().get(entry.getKey()));
aggregations.put(newSymbol, aggregation);
coalesceSymbolsBuilder.put(newSymbol, entry.getKey());
}
else {
aggregations.put(entry.getKey(), aggregation);
}
}
}
Map<Symbol, Symbol> coalesceSymbols = coalesceSymbolsBuilder.build();

return new AggregationNode(
AggregationNode aggregationNode = new AggregationNode(
idAllocator.getNextId(),
source,
aggregations.build(),
Expand All @@ -183,6 +198,23 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext<Optional<A
node.getStep(),
Optional.empty(),
node.getGroupIdSymbol());

if (coalesceSymbols.isEmpty()) {
return aggregationNode;
}

Assignments.Builder outputSymbols = Assignments.builder();
for (Symbol symbol : aggregationNode.getOutputSymbols()) {
if (coalesceSymbols.containsKey(symbol)) {
Expression expression = new CoalesceExpression(symbol.toSymbolReference(), new Cast(new LongLiteral("0"), "bigint"));
outputSymbols.put(coalesceSymbols.get(symbol), expression);
}
else {
outputSymbols.putIdentity(symbol);
}
}

return new ProjectNode(idAllocator.getNextId(), aggregationNode, outputSymbols.build());
}

@Override
Expand Down
Expand Up @@ -28,10 +28,8 @@ public TestOptimizeMixedDistinctAggregations()
@Override
public void testCountDistinct()
{
// TODO https://github.com/prestodb/presto/issues/8894 . Once fixed, remove test override.

assertQuery("SELECT COUNT(DISTINCT custkey + 1) FROM orders", "SELECT COUNT(*) FROM (SELECT DISTINCT custkey + 1 FROM orders) t");
// assertQuery("SELECT COUNT(DISTINCT linenumber), COUNT(*) from lineitem where linenumber < 0");
assertQuery("SELECT COUNT(DISTINCT linenumber), COUNT(*) from lineitem where linenumber < 0");
}

// TODO add dedicated test cases and remove `extends AbstractTestAggregation`
Expand Down

0 comments on commit b1bb391

Please sign in to comment.