Skip to content

Commit

Permalink
Take FunctionRegistry instead of entire Metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
maciejgrzybek authored and martint committed May 17, 2017
1 parent cddd1c8 commit d37e3a6
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 19 deletions.
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea
new RemoveUnreferencedScalarInputApplyNodes(), new RemoveUnreferencedScalarInputApplyNodes(),
new TransformUncorrelatedInPredicateSubqueryToSemiJoin(), new TransformUncorrelatedInPredicateSubqueryToSemiJoin(),
new TransformUncorrelatedScalarToJoin(), new TransformUncorrelatedScalarToJoin(),
new TransformCorrelatedScalarAggregationToJoin(metadata), new TransformCorrelatedScalarAggregationToJoin(metadata.getFunctionRegistry()),
new PredicatePushDown(metadata, sqlParser), new PredicatePushDown(metadata, sqlParser),
new PruneUnreferencedOutputs(), new PruneUnreferencedOutputs(),
new IterativeOptimizer( new IterativeOptimizer(
Expand Down
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
package com.facebook.presto.sql.planner.optimizations; package com.facebook.presto.sql.planner.optimizations;


import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.Signature;
import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.BooleanType; import com.facebook.presto.spi.type.BooleanType;
Expand Down Expand Up @@ -61,13 +60,15 @@


public class ScalarAggregationToJoinRewriter public class ScalarAggregationToJoinRewriter
{ {
private final Metadata metadata; private static final QualifiedName COUNT = QualifiedName.of("count");

private final FunctionRegistry functionRegistry;
private final SymbolAllocator symbolAllocator; private final SymbolAllocator symbolAllocator;
private final PlanNodeIdAllocator idAllocator; private final PlanNodeIdAllocator idAllocator;


public ScalarAggregationToJoinRewriter(Metadata metadata, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) public ScalarAggregationToJoinRewriter(FunctionRegistry functionRegistry, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator)
{ {
this.metadata = requireNonNull(metadata, "metadata is null"); this.functionRegistry = requireNonNull(functionRegistry, "metadata is null");
this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null");
this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
} }
Expand Down Expand Up @@ -162,19 +163,17 @@ private Optional<AggregationNode> createAggregationNode(
{ {
ImmutableMap.Builder<Symbol, FunctionCall> aggregations = ImmutableMap.builder(); ImmutableMap.Builder<Symbol, FunctionCall> aggregations = ImmutableMap.builder();
ImmutableMap.Builder<Symbol, Signature> functions = ImmutableMap.builder(); ImmutableMap.Builder<Symbol, Signature> functions = ImmutableMap.builder();
FunctionRegistry functionRegistry = metadata.getFunctionRegistry();
for (Map.Entry<Symbol, FunctionCall> entry : scalarAggregation.getAggregations().entrySet()) { for (Map.Entry<Symbol, FunctionCall> entry : scalarAggregation.getAggregations().entrySet()) {
FunctionCall call = entry.getValue(); FunctionCall call = entry.getValue();
QualifiedName count = QualifiedName.of("count");
Symbol symbol = entry.getKey(); Symbol symbol = entry.getKey();
if (call.getName().equals(count)) { if (call.getName().equals(COUNT)) {
aggregations.put(symbol, new FunctionCall( aggregations.put(symbol, new FunctionCall(
count, COUNT,
ImmutableList.of(nonNullableAggregationSourceSymbol.toSymbolReference()))); ImmutableList.of(nonNullableAggregationSourceSymbol.toSymbolReference())));
List<TypeSignature> scalarAggregationSourceTypeSignatures = ImmutableList.of( List<TypeSignature> scalarAggregationSourceTypeSignatures = ImmutableList.of(
symbolAllocator.getTypes().get(nonNullableAggregationSourceSymbol).getTypeSignature()); symbolAllocator.getTypes().get(nonNullableAggregationSourceSymbol).getTypeSignature());
functions.put(symbol, functionRegistry.resolveFunction( functions.put(symbol, functionRegistry.resolveFunction(
count, COUNT,
fromTypeSignatures(scalarAggregationSourceTypeSignatures))); fromTypeSignatures(scalarAggregationSourceTypeSignatures)));
} }
else { else {
Expand Down
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
package com.facebook.presto.sql.planner.optimizations; package com.facebook.presto.sql.planner.optimizations;


import com.facebook.presto.Session; import com.facebook.presto.Session;
import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.Symbol;
Expand Down Expand Up @@ -65,11 +65,11 @@
public class TransformCorrelatedScalarAggregationToJoin public class TransformCorrelatedScalarAggregationToJoin
implements PlanOptimizer implements PlanOptimizer
{ {
private final Metadata metadata; private final FunctionRegistry functionRegistry;


public TransformCorrelatedScalarAggregationToJoin(Metadata metadata) public TransformCorrelatedScalarAggregationToJoin(FunctionRegistry functionRegistry)
{ {
this.metadata = requireNonNull(metadata, "metadata is null"); this.functionRegistry = requireNonNull(functionRegistry, "functionRegistry is null");
} }


@Override @Override
Expand All @@ -80,21 +80,21 @@ public PlanNode optimize(
SymbolAllocator symbolAllocator, SymbolAllocator symbolAllocator,
PlanNodeIdAllocator idAllocator) PlanNodeIdAllocator idAllocator)
{ {
return rewriteWith(new Rewriter(idAllocator, symbolAllocator, metadata), plan, null); return rewriteWith(new Rewriter(idAllocator, symbolAllocator, functionRegistry), plan, null);
} }


private static class Rewriter private static class Rewriter
extends SimplePlanRewriter<PlanNode> extends SimplePlanRewriter<PlanNode>
{ {
private final PlanNodeIdAllocator idAllocator; private final PlanNodeIdAllocator idAllocator;
private final SymbolAllocator symbolAllocator; private final SymbolAllocator symbolAllocator;
private final Metadata metadata; private final FunctionRegistry functionRegistry;


public Rewriter(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Metadata metadata) public Rewriter(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, FunctionRegistry functionRegistry)
{ {
this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null");
this.metadata = requireNonNull(metadata, "metadata is null"); this.functionRegistry = requireNonNull(functionRegistry, "functionRegistry is null");
} }


@Override @Override
Expand All @@ -107,7 +107,7 @@ public PlanNode visitApply(ApplyNode node, RewriteContext<PlanNode> context)
.skipOnlyWhen(isInstanceOfAny(ProjectNode.class, EnforceSingleRowNode.class)) .skipOnlyWhen(isInstanceOfAny(ProjectNode.class, EnforceSingleRowNode.class))
.findFirst(); .findFirst();
if (aggregation.isPresent() && aggregation.get().getGroupingKeys().isEmpty()) { if (aggregation.isPresent() && aggregation.get().getGroupingKeys().isEmpty()) {
ScalarAggregationToJoinRewriter scalarAggregationToJoinRewriter = new ScalarAggregationToJoinRewriter(metadata, symbolAllocator, idAllocator); ScalarAggregationToJoinRewriter scalarAggregationToJoinRewriter = new ScalarAggregationToJoinRewriter(functionRegistry, symbolAllocator, idAllocator);
return scalarAggregationToJoinRewriter.rewriteScalarAggregation(rewrittenNode, aggregation.get()); return scalarAggregationToJoinRewriter.rewriteScalarAggregation(rewrittenNode, aggregation.get());
} }
} }
Expand Down

0 comments on commit d37e3a6

Please sign in to comment.