Skip to content

Commit

Permalink
Add VariableReferenceExpression to TopNRowNumberNode
Browse files Browse the repository at this point in the history
  • Loading branch information
rongrong committed Jun 10, 2019
1 parent 88b8805 commit a8bb2c5
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 16 deletions.
Expand Up @@ -418,7 +418,7 @@ public PlanWithProperties visitTopNRowNumber(TopNRowNumberNode node, PreferredPr
idAllocator.getNextId(), idAllocator.getNextId(),
child.getNode(), child.getNode(),
node.getSpecification(), node.getSpecification(),
node.getRowNumberSymbol(), node.getRowNumberVariable(),
node.getMaxRowCountPerPartition(), node.getMaxRowCountPerPartition(),
true, true,
node.getHashSymbol()), node.getHashSymbol()),
Expand Down
Expand Up @@ -290,7 +290,7 @@ public PlanWithProperties visitTopNRowNumber(TopNRowNumberNode node, HashComputa
node.getId(), node.getId(),
child.getNode(), child.getNode(),
node.getSpecification(), node.getSpecification(),
node.getRowNumberSymbol(), node.getRowNumberVariable(),
node.getMaxRowCountPerPartition(), node.getMaxRowCountPerPartition(),
node.isPartial(), node.isPartial(),
Optional.of(hashSymbol)), Optional.of(hashSymbol)),
Expand Down
Expand Up @@ -623,7 +623,7 @@ public PlanNode visitTopNRowNumber(TopNRowNumberNode node, RewriteContext<Set<Sy
return new TopNRowNumberNode(node.getId(), return new TopNRowNumberNode(node.getId(),
source, source,
node.getSpecification(), node.getSpecification(),
node.getRowNumberSymbol(), node.getRowNumberVariable(),
node.getMaxRowCountPerPartition(), node.getMaxRowCountPerPartition(),
node.isPartial(), node.isPartial(),
node.getHashSymbol()); node.getHashSymbol());
Expand Down
Expand Up @@ -422,7 +422,7 @@ public PlanNode visitTopNRowNumber(TopNRowNumberNode node, RewriteContext<Void>
node.getId(), node.getId(),
context.rewrite(node.getSource()), context.rewrite(node.getSource()),
canonicalizeAndDistinct(node.getSpecification()), canonicalizeAndDistinct(node.getSpecification()),
canonicalize(node.getRowNumberSymbol()), canonicalize(node.getRowNumberVariable()),
node.getMaxRowCountPerPartition(), node.getMaxRowCountPerPartition(),
node.isPartial(), node.isPartial(),
canonicalize(node.getHashSymbol())); canonicalize(node.getHashSymbol()));
Expand Down
Expand Up @@ -265,7 +265,7 @@ private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limi
return new TopNRowNumberNode(idAllocator.getNextId(), return new TopNRowNumberNode(idAllocator.getNextId(),
windowNode.getSource(), windowNode.getSource(),
windowNode.getSpecification(), windowNode.getSpecification(),
getOnlyElement(windowNode.getCreatedSymbols()), getOnlyElement(windowNode.getCreatedVariable()),
limit, limit,
false, false,
Optional.empty()); Optional.empty());
Expand Down
Expand Up @@ -14,6 +14,7 @@
package com.facebook.presto.sql.planner.plan; package com.facebook.presto.sql.planner.plan;


import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.OrderingScheme; import com.facebook.presto.sql.planner.OrderingScheme;
import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.plan.WindowNode.Specification; import com.facebook.presto.sql.planner.plan.WindowNode.Specification;
Expand All @@ -28,7 +29,6 @@
import java.util.Optional; import java.util.Optional;


import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.Iterables.concat;
import static java.util.Objects.requireNonNull; import static java.util.Objects.requireNonNull;


@Immutable @Immutable
Expand All @@ -37,7 +37,7 @@ public final class TopNRowNumberNode
{ {
private final PlanNode source; private final PlanNode source;
private final Specification specification; private final Specification specification;
private final Symbol rowNumberSymbol; private final VariableReferenceExpression rowNumberVariable;
private final int maxRowCountPerPartition; private final int maxRowCountPerPartition;
private final boolean partial; private final boolean partial;
private final Optional<Symbol> hashSymbol; private final Optional<Symbol> hashSymbol;
Expand All @@ -47,7 +47,7 @@ public TopNRowNumberNode(
@JsonProperty("id") PlanNodeId id, @JsonProperty("id") PlanNodeId id,
@JsonProperty("source") PlanNode source, @JsonProperty("source") PlanNode source,
@JsonProperty("specification") Specification specification, @JsonProperty("specification") Specification specification,
@JsonProperty("rowNumberSymbol") Symbol rowNumberSymbol, @JsonProperty("rowNumberVariable") VariableReferenceExpression rowNumberVariable,
@JsonProperty("maxRowCountPerPartition") int maxRowCountPerPartition, @JsonProperty("maxRowCountPerPartition") int maxRowCountPerPartition,
@JsonProperty("partial") boolean partial, @JsonProperty("partial") boolean partial,
@JsonProperty("hashSymbol") Optional<Symbol> hashSymbol) @JsonProperty("hashSymbol") Optional<Symbol> hashSymbol)
Expand All @@ -57,13 +57,13 @@ public TopNRowNumberNode(
requireNonNull(source, "source is null"); requireNonNull(source, "source is null");
requireNonNull(specification, "specification is null"); requireNonNull(specification, "specification is null");
checkArgument(specification.getOrderingScheme().isPresent(), "specification orderingScheme is absent"); checkArgument(specification.getOrderingScheme().isPresent(), "specification orderingScheme is absent");
requireNonNull(rowNumberSymbol, "rowNumberSymbol is null"); requireNonNull(rowNumberVariable, "rowNumberVariable is null");
checkArgument(maxRowCountPerPartition > 0, "maxRowCountPerPartition must be > 0"); checkArgument(maxRowCountPerPartition > 0, "maxRowCountPerPartition must be > 0");
requireNonNull(hashSymbol, "hashSymbol is null"); requireNonNull(hashSymbol, "hashSymbol is null");


this.source = source; this.source = source;
this.specification = specification; this.specification = specification;
this.rowNumberSymbol = rowNumberSymbol; this.rowNumberVariable = rowNumberVariable;
this.maxRowCountPerPartition = maxRowCountPerPartition; this.maxRowCountPerPartition = maxRowCountPerPartition;
this.partial = partial; this.partial = partial;
this.hashSymbol = hashSymbol; this.hashSymbol = hashSymbol;
Expand All @@ -78,10 +78,13 @@ public List<PlanNode> getSources()
@Override @Override
public List<Symbol> getOutputSymbols() public List<Symbol> getOutputSymbols()
{ {
ImmutableList.Builder<Symbol> builder = ImmutableList.<Symbol>builder().addAll(source.getOutputSymbols());

if (!partial) { if (!partial) {
return ImmutableList.copyOf(concat(source.getOutputSymbols(), ImmutableList.of(rowNumberSymbol))); builder.add(new Symbol(rowNumberVariable.getName()));
} }
return ImmutableList.copyOf(source.getOutputSymbols());
return builder.build();
} }


@JsonProperty @JsonProperty
Expand All @@ -106,10 +109,15 @@ public OrderingScheme getOrderingScheme()
return specification.getOrderingScheme().get(); return specification.getOrderingScheme().get();
} }


@JsonProperty
public Symbol getRowNumberSymbol() public Symbol getRowNumberSymbol()
{ {
return rowNumberSymbol; return new Symbol(rowNumberVariable.getName());
}

@JsonProperty
public VariableReferenceExpression getRowNumberVariable()
{
return rowNumberVariable;
} }


@JsonProperty @JsonProperty
Expand Down Expand Up @@ -139,6 +147,6 @@ public <R, C> R accept(InternalPlanVisitor<R, C> visitor, C context)
@Override @Override
public PlanNode replaceChildren(List<PlanNode> newChildren) public PlanNode replaceChildren(List<PlanNode> newChildren)
{ {
return new TopNRowNumberNode(getId(), Iterables.getOnlyElement(newChildren), specification, rowNumberSymbol, maxRowCountPerPartition, partial, hashSymbol); return new TopNRowNumberNode(getId(), Iterables.getOnlyElement(newChildren), specification, rowNumberVariable, maxRowCountPerPartition, partial, hashSymbol);
} }
} }
Expand Up @@ -108,6 +108,11 @@ public Set<Symbol> getCreatedSymbols()
return windowFunctions.keySet().stream().map(variable -> new Symbol(variable.getName())).collect(toImmutableSet()); return windowFunctions.keySet().stream().map(variable -> new Symbol(variable.getName())).collect(toImmutableSet());
} }


public Set<VariableReferenceExpression> getCreatedVariable()
{
return windowFunctions.keySet();
}

@JsonProperty @JsonProperty
public PlanNode getSource() public PlanNode getSource()
{ {
Expand Down
Expand Up @@ -625,7 +625,7 @@ public Void visitTopNRowNumber(TopNRowNumberNode node, Void context)
"TopNRowNumber", "TopNRowNumber",
format("[%s limit %s]%s", Joiner.on(", ").join(args), node.getMaxRowCountPerPartition(), formatHash(node.getHashSymbol()))); format("[%s limit %s]%s", Joiner.on(", ").join(args), node.getMaxRowCountPerPartition(), formatHash(node.getHashSymbol())));


nodeOutput.appendDetailsLine("%s := %s", node.getRowNumberSymbol(), "row_number()"); nodeOutput.appendDetailsLine("%s := %s", node.getRowNumberVariable(), "row_number()");


return processChildren(node, context); return processChildren(node, context);
} }
Expand Down

0 comments on commit a8bb2c5

Please sign in to comment.