Skip to content

Commit

Permalink
Add VariableReferenceExpression to RowNumberNode
Browse files Browse the repository at this point in the history
  • Loading branch information
rongrong committed Jun 10, 2019
1 parent 887094f commit 88b8805
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 17 deletions.
Expand Up @@ -264,7 +264,7 @@ public PlanWithProperties visitRowNumber(RowNumberNode node, HashComputationSet
node.getId(), node.getId(),
child.getNode(), child.getNode(),
node.getPartitionBy(), node.getPartitionBy(),
node.getRowNumberSymbol(), node.getRowNumberVariable(),
node.getMaxRowCountPerPartition(), node.getMaxRowCountPerPartition(),
Optional.of(hashSymbol)), Optional.of(hashSymbol)),
child.getHashSymbols()); child.getHashSymbols());
Expand Down
Expand Up @@ -604,7 +604,7 @@ public PlanNode visitRowNumber(RowNumberNode node, RewriteContext<Set<Symbol>> c
} }
PlanNode source = context.rewrite(node.getSource(), expectedInputs.build()); PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());


return new RowNumberNode(node.getId(), source, node.getPartitionBy(), node.getRowNumberSymbol(), node.getMaxRowCountPerPartition(), node.getHashSymbol()); return new RowNumberNode(node.getId(), source, node.getPartitionBy(), node.getRowNumberVariable(), node.getMaxRowCountPerPartition(), node.getHashSymbol());
} }


@Override @Override
Expand Down
Expand Up @@ -412,7 +412,7 @@ public PlanNode visitTableFinish(TableFinishNode node, RewriteContext<Void> cont
@Override @Override
public PlanNode visitRowNumber(RowNumberNode node, RewriteContext<Void> context) public PlanNode visitRowNumber(RowNumberNode node, RewriteContext<Void> context)
{ {
return new RowNumberNode(node.getId(), context.rewrite(node.getSource()), canonicalizeAndDistinct(node.getPartitionBy()), canonicalize(node.getRowNumberSymbol()), node.getMaxRowCountPerPartition(), canonicalize(node.getHashSymbol())); return new RowNumberNode(node.getId(), context.rewrite(node.getSource()), canonicalizeAndDistinct(node.getPartitionBy()), canonicalize(node.getRowNumberVariable()), node.getMaxRowCountPerPartition(), canonicalize(node.getHashSymbol()));
} }


@Override @Override
Expand Down
Expand Up @@ -118,7 +118,7 @@ public PlanNode visitWindow(WindowNode node, RewriteContext<Void> context)
return new RowNumberNode(idAllocator.getNextId(), return new RowNumberNode(idAllocator.getNextId(),
rewrittenSource, rewrittenSource,
node.getPartitionBy(), node.getPartitionBy(),
getOnlyElement(node.getCreatedSymbols()), getOnlyElement(node.getWindowFunctions().keySet()),
Optional.empty(), Optional.empty(),
Optional.empty()); Optional.empty());
} }
Expand Down Expand Up @@ -257,7 +257,7 @@ private static RowNumberNode mergeLimit(RowNumberNode node, int newRowCountPerPa
if (node.getMaxRowCountPerPartition().isPresent()) { if (node.getMaxRowCountPerPartition().isPresent()) {
newRowCountPerPartition = Math.min(node.getMaxRowCountPerPartition().get(), newRowCountPerPartition); newRowCountPerPartition = Math.min(node.getMaxRowCountPerPartition().get(), newRowCountPerPartition);
} }
return new RowNumberNode(node.getId(), node.getSource(), node.getPartitionBy(), node.getRowNumberSymbol(), Optional.of(newRowCountPerPartition), node.getHashSymbol()); return new RowNumberNode(node.getId(), node.getSource(), node.getPartitionBy(), node.getRowNumberVariable(), Optional.of(newRowCountPerPartition), node.getHashSymbol());
} }


private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limit) private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limit)
Expand Down
Expand Up @@ -14,6 +14,7 @@
package com.facebook.presto.sql.planner.plan; package com.facebook.presto.sql.planner.plan;


import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.Symbol;
import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
Expand All @@ -25,7 +26,6 @@
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;


import static com.google.common.collect.Iterables.concat;
import static java.util.Objects.requireNonNull; import static java.util.Objects.requireNonNull;


@Immutable @Immutable
Expand All @@ -35,29 +35,29 @@ public final class RowNumberNode
private final PlanNode source; private final PlanNode source;
private final List<Symbol> partitionBy; private final List<Symbol> partitionBy;
private final Optional<Integer> maxRowCountPerPartition; private final Optional<Integer> maxRowCountPerPartition;
private final Symbol rowNumberSymbol; private final VariableReferenceExpression rowNumberVariable;
private final Optional<Symbol> hashSymbol; private final Optional<Symbol> hashSymbol;


@JsonCreator @JsonCreator
public RowNumberNode( public RowNumberNode(
@JsonProperty("id") PlanNodeId id, @JsonProperty("id") PlanNodeId id,
@JsonProperty("source") PlanNode source, @JsonProperty("source") PlanNode source,
@JsonProperty("partitionBy") List<Symbol> partitionBy, @JsonProperty("partitionBy") List<Symbol> partitionBy,
@JsonProperty("rowNumberSymbol") Symbol rowNumberSymbol, @JsonProperty("rowNumberVariable") VariableReferenceExpression rowNumberVariable,
@JsonProperty("maxRowCountPerPartition") Optional<Integer> maxRowCountPerPartition, @JsonProperty("maxRowCountPerPartition") Optional<Integer> maxRowCountPerPartition,
@JsonProperty("hashSymbol") Optional<Symbol> hashSymbol) @JsonProperty("hashSymbol") Optional<Symbol> hashSymbol)
{ {
super(id); super(id);


requireNonNull(source, "source is null"); requireNonNull(source, "source is null");
requireNonNull(partitionBy, "partitionBy is null"); requireNonNull(partitionBy, "partitionBy is null");
requireNonNull(rowNumberSymbol, "rowNumberSymbol is null"); requireNonNull(rowNumberVariable, "rowNumberVariable is null");
requireNonNull(maxRowCountPerPartition, "maxRowCountPerPartition is null"); requireNonNull(maxRowCountPerPartition, "maxRowCountPerPartition is null");
requireNonNull(hashSymbol, "hashSymbol is null"); requireNonNull(hashSymbol, "hashSymbol is null");


this.source = source; this.source = source;
this.partitionBy = ImmutableList.copyOf(partitionBy); this.partitionBy = ImmutableList.copyOf(partitionBy);
this.rowNumberSymbol = rowNumberSymbol; this.rowNumberVariable = rowNumberVariable;
this.maxRowCountPerPartition = maxRowCountPerPartition; this.maxRowCountPerPartition = maxRowCountPerPartition;
this.hashSymbol = hashSymbol; this.hashSymbol = hashSymbol;
} }
Expand All @@ -71,7 +71,19 @@ public List<PlanNode> getSources()
@Override @Override
public List<Symbol> getOutputSymbols() public List<Symbol> getOutputSymbols()
{ {
return ImmutableList.copyOf(concat(source.getOutputSymbols(), ImmutableList.of(rowNumberSymbol))); return ImmutableList.<Symbol>builder()
.addAll(source.getOutputSymbols())
.add(new Symbol(rowNumberVariable.getName()))
.build();
}

@Override
public List<VariableReferenceExpression> getOutputVariables()
{
return ImmutableList.<VariableReferenceExpression>builder()
.addAll(source.getOutputVariables())
.add(rowNumberVariable)
.build();
} }


@JsonProperty @JsonProperty
Expand All @@ -86,10 +98,15 @@ public List<Symbol> getPartitionBy()
return partitionBy; return partitionBy;
} }


@JsonProperty
public Symbol getRowNumberSymbol() public Symbol getRowNumberSymbol()
{ {
return rowNumberSymbol; return new Symbol(rowNumberVariable.getName());
}

@JsonProperty
public VariableReferenceExpression getRowNumberVariable()
{
return rowNumberVariable;
} }


@JsonProperty @JsonProperty
Expand All @@ -113,6 +130,6 @@ public <R, C> R accept(InternalPlanVisitor<R, C> visitor, C context)
@Override @Override
public PlanNode replaceChildren(List<PlanNode> newChildren) public PlanNode replaceChildren(List<PlanNode> newChildren)
{ {
return new RowNumberNode(getId(), Iterables.getOnlyElement(newChildren), partitionBy, rowNumberSymbol, maxRowCountPerPartition, hashSymbol); return new RowNumberNode(getId(), Iterables.getOnlyElement(newChildren), partitionBy, rowNumberVariable, maxRowCountPerPartition, hashSymbol);
} }
} }
Expand Up @@ -646,7 +646,7 @@ public Void visitRowNumber(RowNumberNode node, Void context)
NodeRepresentation nodeOutput = addNode(node, NodeRepresentation nodeOutput = addNode(node,
"RowNumber", "RowNumber",
format("[%s]%s", Joiner.on(", ").join(args), formatHash(node.getHashSymbol()))); format("[%s]%s", Joiner.on(", ").join(args), formatHash(node.getHashSymbol())));
nodeOutput.appendDetailsLine("%s := %s", node.getRowNumberSymbol(), "row_number()"); nodeOutput.appendDetailsLine("%s := %s", node.getRowNumberVariable(), "row_number()");


return processChildren(node, context); return processChildren(node, context);
} }
Expand Down
Expand Up @@ -42,6 +42,7 @@ public void testSingleGroupingKey()
ImmutableList.of(pb.symbol("x", BIGINT)), ImmutableList.of(pb.symbol("x", BIGINT)),
Optional.empty(), Optional.empty(),
pb.symbol("z", BIGINT), pb.symbol("z", BIGINT),
pb.variable(pb.symbol("z", BIGINT)),
pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT)))) pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT))))
.withSourceStats(0, PlanNodeStatsEstimate.builder() .withSourceStats(0, PlanNodeStatsEstimate.builder()
.setOutputRowCount(10) .setOutputRowCount(10)
Expand All @@ -64,6 +65,7 @@ public void testSingleGroupingKey()
ImmutableList.of(pb.symbol("x", BIGINT)), ImmutableList.of(pb.symbol("x", BIGINT)),
Optional.of(1), Optional.of(1),
pb.symbol("z", BIGINT), pb.symbol("z", BIGINT),
pb.variable(pb.symbol("z", BIGINT)),
pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT)))) pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT))))
.withSourceStats(0, PlanNodeStatsEstimate.builder() .withSourceStats(0, PlanNodeStatsEstimate.builder()
.setOutputRowCount(10) .setOutputRowCount(10)
Expand All @@ -84,6 +86,7 @@ public void testSingleGroupingKey()
ImmutableList.of(pb.symbol("y", BIGINT)), ImmutableList.of(pb.symbol("y", BIGINT)),
Optional.empty(), Optional.empty(),
pb.symbol("z", BIGINT), pb.symbol("z", BIGINT),
pb.variable(pb.symbol("z", BIGINT)),
pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT)))) pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT))))
.withSourceStats(0, PlanNodeStatsEstimate.builder() .withSourceStats(0, PlanNodeStatsEstimate.builder()
.setOutputRowCount(60) .setOutputRowCount(60)
Expand All @@ -104,6 +107,7 @@ public void testSingleGroupingKey()
ImmutableList.of(pb.symbol("x", BIGINT)), ImmutableList.of(pb.symbol("x", BIGINT)),
Optional.of(1), Optional.of(1),
pb.symbol("z", BIGINT), pb.symbol("z", BIGINT),
pb.variable(pb.symbol("z", BIGINT)),
pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT)))) pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT))))
.withSourceStats(0, PlanNodeStatsEstimate.builder() .withSourceStats(0, PlanNodeStatsEstimate.builder()
.addSymbolStatistics(new Symbol("x"), xStats) .addSymbolStatistics(new Symbol("x"), xStats)
Expand All @@ -121,6 +125,7 @@ public void testMultipleGroupingKeys()
ImmutableList.of(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT)), ImmutableList.of(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT)),
Optional.empty(), Optional.empty(),
pb.symbol("z", BIGINT), pb.symbol("z", BIGINT),
pb.variable(pb.symbol("z", BIGINT)),
pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT)))) pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT))))
.withSourceStats(0, PlanNodeStatsEstimate.builder() .withSourceStats(0, PlanNodeStatsEstimate.builder()
.setOutputRowCount(60) .setOutputRowCount(60)
Expand All @@ -141,6 +146,7 @@ public void testMultipleGroupingKeys()
ImmutableList.of(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT)), ImmutableList.of(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT)),
Optional.empty(), Optional.empty(),
pb.symbol("z", BIGINT), pb.symbol("z", BIGINT),
pb.variable(pb.symbol("z", BIGINT)),
pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT)))) pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT))))
.withSourceStats(0, PlanNodeStatsEstimate.builder() .withSourceStats(0, PlanNodeStatsEstimate.builder()
.setOutputRowCount(20) .setOutputRowCount(20)
Expand All @@ -161,6 +167,7 @@ public void testMultipleGroupingKeys()
ImmutableList.of(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT)), ImmutableList.of(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT)),
Optional.empty(), Optional.empty(),
pb.symbol("z", BIGINT), pb.symbol("z", BIGINT),
pb.variable(pb.symbol("z", BIGINT)),
pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT)))) pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT))))
.withSourceStats(0, PlanNodeStatsEstimate.builder() .withSourceStats(0, PlanNodeStatsEstimate.builder()
.setOutputRowCount(20) .setOutputRowCount(20)
Expand Down
Expand Up @@ -746,13 +746,13 @@ public WindowNode window(WindowNode.Specification specification, Map<VariableRef
0); 0);
} }


public RowNumberNode rowNumber(List<Symbol> partitionBy, Optional<Integer> maxRowCountPerPartition, Symbol rowNumberSymbol, PlanNode source) public RowNumberNode rowNumber(List<Symbol> partitionBy, Optional<Integer> maxRowCountPerPartition, Symbol rowNumberSymbol, VariableReferenceExpression rownNumberVariable, PlanNode source)
{ {
return new RowNumberNode( return new RowNumberNode(
idAllocator.getNextId(), idAllocator.getNextId(),
source, source,
partitionBy, partitionBy,
rowNumberSymbol, rownNumberVariable,
maxRowCountPerPartition, maxRowCountPerPartition,
Optional.empty()); Optional.empty());
} }
Expand Down

0 comments on commit 88b8805

Please sign in to comment.