Skip to content

Commit

Permalink
Move output layout to PartitionFunctionBinding
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed Jan 30, 2016
1 parent 72a2e2b commit d55f4ad
Show file tree
Hide file tree
Showing 20 changed files with 145 additions and 118 deletions.
Expand Up @@ -145,7 +145,6 @@ private SqlTaskExecution(
LocalExecutionPlan localExecutionPlan = planner.plan( LocalExecutionPlan localExecutionPlan = planner.plan(
taskContext.getSession(), taskContext.getSession(),
fragment.getRoot(), fragment.getRoot(),
fragment.getOutputLayout(),
fragment.getSymbols(), fragment.getSymbols(),
fragment.getPartitionFunction(), fragment.getPartitionFunction(),
sharedBuffer, sharedBuffer,
Expand Down
Expand Up @@ -255,13 +255,13 @@ public LocalExecutionPlanner(
public LocalExecutionPlan plan( public LocalExecutionPlan plan(
Session session, Session session,
PlanNode plan, PlanNode plan,
List<Symbol> outputLayout,
Map<Symbol, Type> types, Map<Symbol, Type> types,
PartitionFunctionBinding functionBinding, PartitionFunctionBinding functionBinding,
SharedBuffer sharedBuffer, SharedBuffer sharedBuffer,
boolean singleNode, boolean singleNode,
boolean allowLocalParallel) boolean allowLocalParallel)
{ {
List<Symbol> outputLayout = functionBinding.getOutputLayout();
if (functionBinding.getPartitioningHandle().equals(FIXED_BROADCAST_DISTRIBUTION)) { if (functionBinding.getPartitioningHandle().equals(FIXED_BROADCAST_DISTRIBUTION)) {
return plan(session, plan, outputLayout, types, new TaskOutputFactory(sharedBuffer), singleNode, allowLocalParallel); return plan(session, plan, outputLayout, types, new TaskOutputFactory(sharedBuffer), singleNode, allowLocalParallel);
} }
Expand Down
Expand Up @@ -16,36 +16,41 @@
import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;


import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.Optional; import java.util.Optional;


import static com.facebook.presto.util.ImmutableCollectors.toImmutableList;
import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull; import static java.util.Objects.requireNonNull;


public class PartitionFunctionBinding public class PartitionFunctionBinding
{ {
private final PartitioningHandle partitioningHandle; private final PartitioningHandle partitioningHandle;
private final List<Symbol> outputLayout;
private final List<Symbol> partitioningColumns; private final List<Symbol> partitioningColumns;
private final Optional<Symbol> hashColumn; private final Optional<Symbol> hashColumn;
private final boolean replicateNulls; private final boolean replicateNulls;
private final Optional<int[]> bucketToPartition; private final Optional<int[]> bucketToPartition;


public PartitionFunctionBinding(PartitioningHandle partitioningHandle, List<Symbol> partitioningColumns) public PartitionFunctionBinding(PartitioningHandle partitioningHandle, List<Symbol> outputLayout, List<Symbol> partitioningColumns)
{ {
this(partitioningHandle, this(partitioningHandle,
outputLayout,
partitioningColumns, partitioningColumns,
Optional.empty(), Optional.empty(),
false, false,
Optional.empty()); Optional.empty());
} }


public PartitionFunctionBinding(PartitioningHandle partitioningHandle, List<Symbol> partitioningColumns, Optional<Symbol> hashColumn) public PartitionFunctionBinding(PartitioningHandle partitioningHandle, List<Symbol> outputLayout, List<Symbol> partitioningColumns, Optional<Symbol> hashColumn)
{ {
this( this(
partitioningHandle, partitioningHandle,
outputLayout,
partitioningColumns, partitioningColumns,
hashColumn, hashColumn,
false, false,
Expand All @@ -55,14 +60,23 @@ public PartitionFunctionBinding(PartitioningHandle partitioningHandle, List<Symb
@JsonCreator @JsonCreator
public PartitionFunctionBinding( public PartitionFunctionBinding(
@JsonProperty("partitioningHandle") PartitioningHandle partitioningHandle, @JsonProperty("partitioningHandle") PartitioningHandle partitioningHandle,
@JsonProperty("outputLayout") List<Symbol> outputLayout,
@JsonProperty("partitioningColumns") List<Symbol> partitioningColumns, @JsonProperty("partitioningColumns") List<Symbol> partitioningColumns,
@JsonProperty("hashColumn") Optional<Symbol> hashColumn, @JsonProperty("hashColumn") Optional<Symbol> hashColumn,
@JsonProperty("replicateNulls") boolean replicateNulls, @JsonProperty("replicateNulls") boolean replicateNulls,
@JsonProperty("bucketToPartition") Optional<int[]> bucketToPartition) @JsonProperty("bucketToPartition") Optional<int[]> bucketToPartition)
{ {
this.partitioningHandle = requireNonNull(partitioningHandle, "partitioningHandle is null"); this.partitioningHandle = requireNonNull(partitioningHandle, "partitioningHandle is null");
this.outputLayout = ImmutableList.copyOf(requireNonNull(outputLayout, "outputLayout is null"));

this.partitioningColumns = ImmutableList.copyOf(requireNonNull(partitioningColumns, "partitioningColumns is null")); this.partitioningColumns = ImmutableList.copyOf(requireNonNull(partitioningColumns, "partitioningColumns is null"));
checkArgument(ImmutableSet.copyOf(outputLayout).containsAll(partitioningColumns),
"Output layout (%s) don't include all partition columns (%s)", outputLayout, partitioningColumns);

this.hashColumn = requireNonNull(hashColumn, "hashColumn is null"); this.hashColumn = requireNonNull(hashColumn, "hashColumn is null");
hashColumn.ifPresent(column -> checkArgument(outputLayout.contains(column),
"Output layout (%s) don't include hash column (%s)", outputLayout, column));

checkArgument(!replicateNulls || partitioningColumns.size() == 1, "size of partitioningColumns is not 1 when nullPartition is REPLICATE."); checkArgument(!replicateNulls || partitioningColumns.size() == 1, "size of partitioningColumns is not 1 when nullPartition is REPLICATE.");
this.replicateNulls = replicateNulls; this.replicateNulls = replicateNulls;
this.bucketToPartition = requireNonNull(bucketToPartition, "bucketToPartition is null"); this.bucketToPartition = requireNonNull(bucketToPartition, "bucketToPartition is null");
Expand All @@ -74,6 +88,12 @@ public PartitioningHandle getPartitioningHandle()
return partitioningHandle; return partitioningHandle;
} }


@JsonProperty
public List<Symbol> getOutputLayout()
{
return outputLayout;
}

@JsonProperty @JsonProperty
public List<Symbol> getPartitioningColumns() public List<Symbol> getPartitioningColumns()
{ {
Expand All @@ -100,7 +120,25 @@ public Optional<int[]> getBucketToPartition()


public PartitionFunctionBinding withBucketToPartition(Optional<int[]> bucketToPartition) public PartitionFunctionBinding withBucketToPartition(Optional<int[]> bucketToPartition)
{ {
return new PartitionFunctionBinding(partitioningHandle, partitioningColumns, hashColumn, replicateNulls, bucketToPartition); return new PartitionFunctionBinding(partitioningHandle, outputLayout, partitioningColumns, hashColumn, replicateNulls, bucketToPartition);
}

public PartitionFunctionBinding translateOutputLayout(List<Symbol> newOutputLayout)
{
requireNonNull(newOutputLayout, "newOutputLayout is null");

checkArgument(newOutputLayout.size() == outputLayout.size());

List<Symbol> newPartitioningColumns = partitioningColumns.stream()
.mapToInt(outputLayout::indexOf)
.mapToObj(newOutputLayout::get)
.collect(toImmutableList());

Optional<Symbol> newHashSymbol = hashColumn
.map(outputLayout::indexOf)
.map(newOutputLayout::get);

return new PartitionFunctionBinding(partitioningHandle, newOutputLayout, newPartitioningColumns, newHashSymbol, replicateNulls, bucketToPartition);
} }


@Override @Override
Expand All @@ -114,6 +152,7 @@ public boolean equals(Object o)
} }
PartitionFunctionBinding that = (PartitionFunctionBinding) o; PartitionFunctionBinding that = (PartitionFunctionBinding) o;
return Objects.equals(partitioningHandle, that.partitioningHandle) && return Objects.equals(partitioningHandle, that.partitioningHandle) &&
Objects.equals(outputLayout, that.outputLayout) &&
Objects.equals(partitioningColumns, that.partitioningColumns) && Objects.equals(partitioningColumns, that.partitioningColumns) &&
Objects.equals(hashColumn, that.hashColumn) && Objects.equals(hashColumn, that.hashColumn) &&
replicateNulls == that.replicateNulls && replicateNulls == that.replicateNulls &&
Expand All @@ -123,14 +162,15 @@ public boolean equals(Object o)
@Override @Override
public int hashCode() public int hashCode()
{ {
return Objects.hash(partitioningHandle, partitioningColumns, replicateNulls, bucketToPartition); return Objects.hash(partitioningHandle, outputLayout, partitioningColumns, replicateNulls, bucketToPartition);
} }


@Override @Override
public String toString() public String toString()
{ {
return toStringHelper(this) return toStringHelper(this)
.add("partitioningHandle", partitioningHandle) .add("partitioningHandle", partitioningHandle)
.add("outputLayout", outputLayout)
.add("partitioningChannels", partitioningColumns) .add("partitioningChannels", partitioningColumns)
.add("hashChannel", hashColumn) .add("hashChannel", hashColumn)
.add("replicateNulls", replicateNulls) .add("replicateNulls", replicateNulls)
Expand Down
Expand Up @@ -42,7 +42,6 @@ public class PlanFragment
private final PlanFragmentId id; private final PlanFragmentId id;
private final PlanNode root; private final PlanNode root;
private final Map<Symbol, Type> symbols; private final Map<Symbol, Type> symbols;
private final List<Symbol> outputLayout;
private final PartitioningHandle partitioning; private final PartitioningHandle partitioning;
private final PlanNodeId partitionedSource; private final PlanNodeId partitionedSource;
private final List<Type> types; private final List<Type> types;
Expand All @@ -55,22 +54,20 @@ public PlanFragment(
@JsonProperty("id") PlanFragmentId id, @JsonProperty("id") PlanFragmentId id,
@JsonProperty("root") PlanNode root, @JsonProperty("root") PlanNode root,
@JsonProperty("symbols") Map<Symbol, Type> symbols, @JsonProperty("symbols") Map<Symbol, Type> symbols,
@JsonProperty("outputLayout") List<Symbol> outputLayout,
@JsonProperty("partitioning") PartitioningHandle partitioning, @JsonProperty("partitioning") PartitioningHandle partitioning,
@JsonProperty("partitionedSource") PlanNodeId partitionedSource, @JsonProperty("partitionedSource") PlanNodeId partitionedSource,
@JsonProperty("partitionFunction") PartitionFunctionBinding partitionFunction) @JsonProperty("partitionFunction") PartitionFunctionBinding partitionFunction)
{ {
this.id = requireNonNull(id, "id is null"); this.id = requireNonNull(id, "id is null");
this.root = requireNonNull(root, "root is null"); this.root = requireNonNull(root, "root is null");
this.symbols = requireNonNull(symbols, "symbols is null"); this.symbols = requireNonNull(symbols, "symbols is null");
this.outputLayout = requireNonNull(outputLayout, "outputLayout is null"); this.partitioning = requireNonNull(partitioning, "partitioning is null");
this.partitioning = requireNonNull(partitioning, "distribution is null");
this.partitionedSource = partitionedSource; this.partitionedSource = partitionedSource;


checkArgument(ImmutableSet.copyOf(root.getOutputSymbols()).containsAll(outputLayout), checkArgument(ImmutableSet.copyOf(root.getOutputSymbols()).containsAll(partitionFunction.getOutputLayout()),
"Root node outputs (%s) don't include all fragment outputs (%s)", root.getOutputSymbols(), outputLayout); "Root node outputs (%s) does not include all fragment outputs (%s)", root.getOutputSymbols(), partitionFunction.getOutputLayout());


types = outputLayout.stream() types = partitionFunction.getOutputLayout().stream()
.map(symbols::get) .map(symbols::get)
.collect(toImmutableList()); .collect(toImmutableList());


Expand Down Expand Up @@ -101,12 +98,6 @@ public Map<Symbol, Type> getSymbols()
return symbols; return symbols;
} }


@JsonProperty
public List<Symbol> getOutputLayout()
{
return outputLayout;
}

@JsonProperty @JsonProperty
public PartitioningHandle getPartitioning() public PartitioningHandle getPartitioning()
{ {
Expand Down Expand Up @@ -171,15 +162,15 @@ private static void findRemoteSourceNodes(PlanNode node, Builder<RemoteSourceNod


public PlanFragment withBucketToPartition(Optional<int[]> bucketToPartition) public PlanFragment withBucketToPartition(Optional<int[]> bucketToPartition)
{ {
return new PlanFragment(id, root, symbols, outputLayout, partitioning, partitionedSource, partitionFunction.withBucketToPartition(bucketToPartition)); return new PlanFragment(id, root, symbols, partitioning, partitionedSource, partitionFunction.withBucketToPartition(bucketToPartition));
} }


@Override @Override
public String toString() public String toString()
{ {
return toStringHelper(this) return toStringHelper(this)
.add("id", id) .add("id", id)
.add("distribution", partitioning) .add("partitioning", partitioning)
.add("partitionedSource", partitionedSource) .add("partitionedSource", partitionedSource)
.add("partitionFunction", partitionFunction) .add("partitionFunction", partitionFunction)
.toString(); .toString();
Expand Down
Expand Up @@ -52,7 +52,8 @@ public SubPlan createSubPlans(Plan plan)
{ {
Fragmenter fragmenter = new Fragmenter(plan.getSymbolAllocator().getTypes()); Fragmenter fragmenter = new Fragmenter(plan.getSymbolAllocator().getTypes());


FragmentProperties properties = new FragmentProperties(new PartitionFunctionBinding(SINGLE_DISTRIBUTION, ImmutableList.of())); FragmentProperties properties = new FragmentProperties(new PartitionFunctionBinding(SINGLE_DISTRIBUTION, plan.getRoot().getOutputSymbols(), ImmutableList.of()))
.setSingleNodeDistribution();
PlanNode root = SimplePlanRewriter.rewriteWith(fragmenter, plan.getRoot(), properties); PlanNode root = SimplePlanRewriter.rewriteWith(fragmenter, plan.getRoot(), properties);


SubPlan result = fragmenter.buildRootFragment(root, properties); SubPlan result = fragmenter.buildRootFragment(root, properties);
Expand Down Expand Up @@ -92,7 +93,6 @@ private SubPlan buildFragment(PlanNode root, FragmentProperties properties, Plan
fragmentId, fragmentId,
root, root,
Maps.filterKeys(types, in(dependencies)), Maps.filterKeys(types, in(dependencies)),
properties.getOutputLayout(),
properties.getPartitioningHandle(), properties.getPartitioningHandle(),
properties.getDistributeBy(), properties.getDistributeBy(),
properties.getPartitionFunction()); properties.getPartitionFunction());
Expand All @@ -103,9 +103,7 @@ private SubPlan buildFragment(PlanNode root, FragmentProperties properties, Plan
@Override @Override
public PlanNode visitOutput(OutputNode node, RewriteContext<FragmentProperties> context) public PlanNode visitOutput(OutputNode node, RewriteContext<FragmentProperties> context)
{ {
context.get() context.get().setSingleNodeDistribution(); // TODO: add support for distributed output
.setSingleNodeDistribution() // TODO: add support for distributed output
.setOutputLayout(node.getOutputSymbols());


return context.defaultRewrite(node, context.get()); return context.defaultRewrite(node, context.get());
} }
Expand Down Expand Up @@ -141,30 +139,25 @@ public PlanNode visitValues(ValuesNode node, RewriteContext<FragmentProperties>
@Override @Override
public PlanNode visitExchange(ExchangeNode exchange, RewriteContext<FragmentProperties> context) public PlanNode visitExchange(ExchangeNode exchange, RewriteContext<FragmentProperties> context)
{ {
PartitionFunctionBinding partitionFunction = exchange.getPartitionFunction();

ImmutableList.Builder<SubPlan> builder = ImmutableList.builder(); ImmutableList.Builder<SubPlan> builder = ImmutableList.builder();
if (exchange.getType() == ExchangeNode.Type.GATHER) { if (exchange.getType() == ExchangeNode.Type.GATHER) {
context.get().setSingleNodeDistribution(); context.get().setSingleNodeDistribution();


for (int i = 0; i < exchange.getSources().size(); i++) { for (int i = 0; i < exchange.getSources().size(); i++) {
FragmentProperties childProperties = new FragmentProperties(exchange.getPartitionFunction()) FragmentProperties childProperties = new FragmentProperties(partitionFunction.translateOutputLayout(exchange.getInputs().get(i)));
.setOutputLayout(exchange.getInputs().get(i));

builder.add(buildSubPlan(exchange.getSources().get(i), childProperties, context)); builder.add(buildSubPlan(exchange.getSources().get(i), childProperties, context));
} }
} }
else if (exchange.getType() == ExchangeNode.Type.REPARTITION) { else if (exchange.getType() == ExchangeNode.Type.REPARTITION) {
PartitionFunctionBinding partitionFunction = exchange.getPartitionFunction();
context.get().setDistribution(partitionFunction.getPartitioningHandle()); context.get().setDistribution(partitionFunction.getPartitioningHandle());


FragmentProperties childProperties = new FragmentProperties(partitionFunction) FragmentProperties childProperties = new FragmentProperties(partitionFunction.translateOutputLayout(Iterables.getOnlyElement(exchange.getInputs())));
.setOutputLayout(Iterables.getOnlyElement(exchange.getInputs()));

builder.add(buildSubPlan(Iterables.getOnlyElement(exchange.getSources()), childProperties, context)); builder.add(buildSubPlan(Iterables.getOnlyElement(exchange.getSources()), childProperties, context));
} }
else if (exchange.getType() == ExchangeNode.Type.REPLICATE) { else if (exchange.getType() == ExchangeNode.Type.REPLICATE) {
FragmentProperties childProperties = new FragmentProperties(exchange.getPartitionFunction()) FragmentProperties childProperties = new FragmentProperties(partitionFunction.translateOutputLayout(Iterables.getOnlyElement(exchange.getInputs())));
.setOutputLayout(Iterables.getOnlyElement(exchange.getInputs()));

builder.add(buildSubPlan(Iterables.getOnlyElement(exchange.getSources()), childProperties, context)); builder.add(buildSubPlan(Iterables.getOnlyElement(exchange.getSources()), childProperties, context));
} }


Expand Down Expand Up @@ -192,7 +185,6 @@ private static class FragmentProperties
private final List<SubPlan> children = new ArrayList<>(); private final List<SubPlan> children = new ArrayList<>();


private final PartitionFunctionBinding partitionFunction; private final PartitionFunctionBinding partitionFunction;
private Optional<List<Symbol>> outputLayout = Optional.empty();


private Optional<PartitioningHandle> partitioningHandle = Optional.empty(); private Optional<PartitioningHandle> partitioningHandle = Optional.empty();
private PlanNodeId distributeBy; private PlanNodeId distributeBy;
Expand Down Expand Up @@ -279,29 +271,13 @@ public FragmentProperties setSourceDistribution(PlanNodeId source)
return this; return this;
} }


public FragmentProperties setOutputLayout(List<Symbol> layout)
{
outputLayout.ifPresent(current -> {
throw new IllegalStateException(String.format("Cannot overwrite output layout with %s (currently set to %s)", layout, current));
});

outputLayout = Optional.of(layout);

return this;
}

public FragmentProperties addChildren(List<SubPlan> children) public FragmentProperties addChildren(List<SubPlan> children)
{ {
this.children.addAll(children); this.children.addAll(children);


return this; return this;
} }


public List<Symbol> getOutputLayout()
{
return outputLayout.get();
}

public PartitionFunctionBinding getPartitionFunction() public PartitionFunctionBinding getPartitionFunction()
{ {
return partitionFunction; return partitionFunction;
Expand Down
Expand Up @@ -134,11 +134,11 @@ public static String textDistributedPlan(SubPlan plan, Metadata metadata, Sessio
fragment.getId(), fragment.getId(),
fragment.getPartitioning())); fragment.getPartitioning()));


PartitionFunctionBinding partitionFunction = fragment.getPartitionFunction();
builder.append(indentString(1)) builder.append(indentString(1))
.append(format("Output layout: [%s]\n", .append(format("Output layout: [%s]\n",
Joiner.on(", ").join(fragment.getOutputLayout()))); Joiner.on(", ").join(partitionFunction.getOutputLayout())));


PartitionFunctionBinding partitionFunction = fragment.getPartitionFunction();
boolean replicateNulls = partitionFunction.isReplicateNulls(); boolean replicateNulls = partitionFunction.isReplicateNulls();
List<Symbol> symbols = partitionFunction.getPartitioningColumns(); List<Symbol> symbols = partitionFunction.getPartitioningColumns();
builder.append(indentString(1)); builder.append(indentString(1));
Expand Down Expand Up @@ -166,10 +166,9 @@ public static String graphvizLogicalPlan(PlanNode plan, Map<Symbol, Type> types)
new PlanFragmentId("graphviz_plan"), new PlanFragmentId("graphviz_plan"),
plan, plan,
types, types,
plan.getOutputSymbols(),
SINGLE_DISTRIBUTION, SINGLE_DISTRIBUTION,
plan.getId(), plan.getId(),
new PartitionFunctionBinding(SINGLE_DISTRIBUTION, ImmutableList.of())); new PartitionFunctionBinding(SINGLE_DISTRIBUTION, plan.getOutputSymbols(), ImmutableList.of()));
return GraphvizPrinter.printLogical(ImmutableList.of(fragment)); return GraphvizPrinter.printLogical(ImmutableList.of(fragment));
} }


Expand Down
Expand Up @@ -413,9 +413,12 @@ public Void visitExchange(ExchangeNode node, Void context)
for (int i = 0; i < node.getSources().size(); i++) { for (int i = 0; i < node.getSources().size(); i++) {
PlanNode subplan = node.getSources().get(i); PlanNode subplan = node.getSources().get(i);
checkDependencies(subplan.getOutputSymbols(), node.getInputs().get(i), "EXCHANGE subplan must provide all of the necessary symbols"); checkDependencies(subplan.getOutputSymbols(), node.getInputs().get(i), "EXCHANGE subplan must provide all of the necessary symbols");
checkDependencies(subplan.getOutputSymbols(), node.getInputs().get(i), "EXCHANGE subplan must provide all of the necessary symbols");
subplan.accept(this, context); // visit child subplan.accept(this, context); // visit child
} }


checkDependencies(node.getOutputSymbols(), node.getPartitionFunction().getOutputLayout(), "EXCHANGE must provide all of the necessary symbols for partition function");

verifyUniqueId(node); verifyUniqueId(node);


return null; return null;
Expand Down

0 comments on commit d55f4ad

Please sign in to comment.