Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate partitioning properties for full outer join #154

Merged
merged 2 commits into from Feb 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -136,7 +136,6 @@
import io.prestosql.sql.gen.OrderingCompiler;
import io.prestosql.sql.gen.PageFunctionCompiler;
import io.prestosql.sql.parser.SqlParser;
import io.prestosql.sql.planner.Partitioning.ArgumentBinding;
import io.prestosql.sql.planner.optimizations.IndexJoinOptimizer;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.AggregationNode.Aggregation;
Expand Down Expand Up @@ -393,8 +392,12 @@ public LocalExecutionPlan plan(
}
else {
partitionChannels = partitioningScheme.getPartitioning().getArguments().stream()
.map(ArgumentBinding::getColumn)
.map(outputLayout::indexOf)
.map(argument -> {
if (argument.isConstant()) {
return -1;
}
return outputLayout.indexOf(argument.getColumn());
})
.collect(toImmutableList());
partitionConstants = partitioningScheme.getPartitioning().getArguments().stream()
.map(argument -> {
Expand Down
105 changes: 77 additions & 28 deletions presto-main/src/main/java/io/prestosql/sql/planner/Partitioning.java
Expand Up @@ -20,6 +20,8 @@
import io.prestosql.Session;
import io.prestosql.metadata.Metadata;
import io.prestosql.spi.predicate.NullableValue;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.SymbolReference;

import javax.annotation.concurrent.Immutable;

Expand All @@ -32,6 +34,7 @@

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.Objects.requireNonNull;
Expand All @@ -51,7 +54,15 @@ private Partitioning(PartitioningHandle handle, List<ArgumentBinding> arguments)
public static Partitioning create(PartitioningHandle handle, List<Symbol> columns)
{
return new Partitioning(handle, columns.stream()
.map(ArgumentBinding::columnBinding)
.map(Symbol::toSymbolReference)
.map(ArgumentBinding::expressionBinding)
.collect(toImmutableList()));
}

/**
 * Creates a partitioning over the given handle whose arguments are arbitrary expressions.
 *
 * @param handle the partitioning handle
 * @param expressions the partitioning arguments, in order; each is wrapped in an expression binding
 * @return a partitioning with one expression-bound argument per input expression
 */
public static Partitioning createWithExpressions(PartitioningHandle handle, List<Expression> expressions)
{
    ImmutableList.Builder<ArgumentBinding> bindings = ImmutableList.builder();
    for (Expression expression : expressions) {
        bindings.add(ArgumentBinding.expressionBinding(expression));
    }
    return new Partitioning(handle, bindings.build());
}

Expand Down Expand Up @@ -160,13 +171,20 @@ private static boolean isPartitionedWith(

public boolean isPartitionedOn(Collection<Symbol> columns, Set<Symbol> knownConstants)
{
// partitioned on (k_1, k_2, ..., k_n) => partitioned on (k_1, k_2, ..., k_n, k_n+1, ...)
// can safely ignore all constant columns when comparing partition properties
return arguments.stream()
.filter(ArgumentBinding::isVariable)
.map(ArgumentBinding::getColumn)
.filter(symbol -> !knownConstants.contains(symbol))
.allMatch(columns::contains);
for (ArgumentBinding argument : arguments) {
// partitioned on (k_1, k_2, ..., k_n) => partitioned on (k_1, k_2, ..., k_n, k_n+1, ...)
// can safely ignore all constant columns when comparing partition properties
if (argument.isConstant()) {
continue;
}
if (!argument.isVariable()) {
return false;
}
if (!knownConstants.contains(argument.getColumn()) && !columns.contains(argument.getColumn())) {
return false;
}
}
return true;
}

public boolean isEffectivelySinglePartition(Set<Symbol> knownConstants)
Expand Down Expand Up @@ -194,11 +212,11 @@ public Partitioning translate(Function<Symbol, Symbol> translator)
.collect(toImmutableList()));
}

public Optional<Partitioning> translate(Function<Symbol, Optional<Symbol>> translator, Function<Symbol, Optional<NullableValue>> constants)
public Optional<Partitioning> translate(Translator translator)
{
ImmutableList.Builder<ArgumentBinding> newArguments = ImmutableList.builder();
for (ArgumentBinding argument : arguments) {
Optional<ArgumentBinding> newArgument = argument.translate(translator, constants);
Optional<ArgumentBinding> newArgument = argument.translate(translator);
if (!newArgument.isPresent()) {
return Optional.empty();
}
Expand Down Expand Up @@ -242,25 +260,43 @@ public String toString()
.toString();
}

/**
 * Bundles the three lookup functions consulted when translating partitioning arguments:
 * one mapping a symbol to a replacement symbol, one mapping a symbol to a known constant
 * value, and one mapping a non-symbol expression to a symbol.
 */
@Immutable
public static final class Translator
{
    private final Function<Symbol, Optional<Symbol>> columnTranslator;
    private final Function<Symbol, Optional<NullableValue>> constantTranslator;
    private final Function<Expression, Optional<Symbol>> expressionTranslator;

    /**
     * @param columnTranslator maps a symbol to its translated symbol, if any
     * @param constantTranslator maps a symbol to a constant value, if any
     * @param expressionTranslator maps an expression to a symbol, if any
     * @throws NullPointerException if any argument is null
     */
    public Translator(
            Function<Symbol, Optional<Symbol>> columnTranslator,
            Function<Symbol, Optional<NullableValue>> constantTranslator,
            Function<Expression, Optional<Symbol>> expressionTranslator)
    {
        // Validate in declaration order so the first null argument is the one reported.
        requireNonNull(columnTranslator, "columnTranslator is null");
        requireNonNull(constantTranslator, "constantTranslator is null");
        requireNonNull(expressionTranslator, "expressionTranslator is null");

        this.columnTranslator = columnTranslator;
        this.constantTranslator = constantTranslator;
        this.expressionTranslator = expressionTranslator;
    }
}

@Immutable
public static final class ArgumentBinding
{
private final Symbol column;
private final Expression expression;
private final NullableValue constant;

@JsonCreator
public ArgumentBinding(
@JsonProperty("column") Symbol column,
@JsonProperty("expression") Expression expression,
@JsonProperty("constant") NullableValue constant)
{
this.column = column;
this.expression = expression;
this.constant = constant;
checkArgument((column == null) != (constant == null), "Either column or constant must be set");
checkArgument((expression == null) != (constant == null), "Either expression or constant must be set");
}

public static ArgumentBinding columnBinding(Symbol column)
public static ArgumentBinding expressionBinding(Expression expression)
{
return new ArgumentBinding(requireNonNull(column, "column is null"), null);
return new ArgumentBinding(requireNonNull(expression, "expression is null"), null);
}

public static ArgumentBinding constantBinding(NullableValue constant)
Expand All @@ -275,13 +311,19 @@ public boolean isConstant()

public boolean isVariable()
{
return column != null;
return expression instanceof SymbolReference;
}

@JsonProperty
public Symbol getColumn()
{
return column;
verify(expression instanceof SymbolReference, "Expect the expression to be a SymbolReference");
return Symbol.from(expression);
}

/**
 * Returns the expression bound to this argument; exposed as a JSON property.
 * NOTE(review): presumably null when this binding holds a constant instead - confirm against the constructor check.
 */
@JsonProperty
public Expression getExpression()
{
return expression;
}

@JsonProperty
Expand All @@ -295,25 +337,31 @@ public ArgumentBinding translate(Function<Symbol, Symbol> translator)
if (isConstant()) {
return this;
}
return columnBinding(translator.apply(column));
return expressionBinding(translator.apply(Symbol.from(expression)).toSymbolReference());
}

public Optional<ArgumentBinding> translate(Function<Symbol, Optional<Symbol>> translator, Function<Symbol, Optional<NullableValue>> constants)
public Optional<ArgumentBinding> translate(Translator translator)
{
if (isConstant()) {
return Optional.of(this);
}

Optional<ArgumentBinding> newColumn = translator.apply(column)
.map(ArgumentBinding::columnBinding);
if (!isVariable()) {
return translator.expressionTranslator.apply(expression)
.map(Symbol::toSymbolReference)
.map(ArgumentBinding::expressionBinding);
}

Optional<ArgumentBinding> newColumn = translator.columnTranslator.apply(Symbol.from(expression))
.map(Symbol::toSymbolReference)
.map(ArgumentBinding::expressionBinding);
if (newColumn.isPresent()) {
return newColumn;
}

// As a last resort, check for a constant mapping for the symbol
// Note: this MUST be last because we want to favor the symbol representation
// as it makes further optimizations possible.
return constants.apply(column)
return translator.constantTranslator.apply(Symbol.from(expression))
.map(ArgumentBinding::constantBinding);
}

Expand All @@ -323,7 +371,8 @@ public String toString()
if (constant != null) {
return constant.toString();
}
return "\"" + column + "\"";

return expression.toString();
}

@Override
Expand All @@ -336,14 +385,14 @@ public boolean equals(Object o)
return false;
}
ArgumentBinding that = (ArgumentBinding) o;
return Objects.equals(column, that.column) &&
return Objects.equals(expression, that.expression) &&
Objects.equals(constant, that.constant);
}

@Override
public int hashCode()
{
return Objects.hash(column, constant);
return Objects.hash(expression, constant);
}
}
}
Expand Up @@ -24,6 +24,7 @@
import io.prestosql.sql.planner.Partitioning;
import io.prestosql.sql.planner.PartitioningHandle;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.tree.Expression;

import javax.annotation.concurrent.Immutable;

Expand Down Expand Up @@ -153,17 +154,21 @@ public boolean isStreamRepartitionEffective(Collection<Symbol> keys)

public ActualProperties translate(Function<Symbol, Optional<Symbol>> translator)
{
Map<Symbol, NullableValue> translatedConstants = new HashMap<>();
for (Map.Entry<Symbol, NullableValue> entry : constants.entrySet()) {
Optional<Symbol> translatedKey = translator.apply(entry.getKey());
if (translatedKey.isPresent()) {
translatedConstants.put(translatedKey.get(), entry.getValue());
}
}
return builder()
.global(global.translate(translator, symbol -> Optional.ofNullable(constants.get(symbol))))
.global(global.translate(new Partitioning.Translator(translator, symbol -> Optional.ofNullable(constants.get(symbol)), expression -> Optional.empty())))
.local(LocalProperties.translate(localProperties, translator))
.constants(translatedConstants)
.constants(translateConstants(translator))
.build();
}

/**
 * Translates these properties through the given symbol mapping, additionally allowing
 * non-symbol partitioning expressions to be mapped to symbols.
 *
 * @param translator maps a symbol to its translated symbol, if any
 * @param expressionTranslator maps a partitioning expression to a symbol, if any
 * @return properties whose global, local and constant parts have been translated
 */
public ActualProperties translate(
        Function<Symbol, Optional<Symbol>> translator,
        Function<Expression, Optional<Symbol>> expressionTranslator)
{
    // Symbols that cannot be translated directly may still be replaced by a known constant.
    Partitioning.Translator partitioningTranslator = new Partitioning.Translator(
            translator,
            symbol -> Optional.ofNullable(constants.get(symbol)),
            expressionTranslator);

    return builder()
            .global(global.translate(partitioningTranslator))
            .local(LocalProperties.translate(localProperties, translator))
            .constants(translateConstants(translator))
            .build();
}

Expand Down Expand Up @@ -199,6 +204,18 @@ public static Builder builderFrom(ActualProperties properties)
return new Builder(properties.global, properties.localProperties, properties.constants);
}

/**
 * Re-keys the known constants through the given symbol mapping.
 * Constants whose symbol has no translation are dropped.
 *
 * @param translator maps a symbol to its translated symbol, if any
 * @return a new mutable map from translated symbol to constant value
 */
private Map<Symbol, NullableValue> translateConstants(Function<Symbol, Optional<Symbol>> translator)
{
    Map<Symbol, NullableValue> result = new HashMap<>();
    constants.forEach((symbol, value) ->
            translator.apply(symbol).ifPresent(translated -> result.put(translated, value)));
    return result;
}

public static class Builder
{
private Global global;
Expand Down Expand Up @@ -448,11 +465,11 @@ private boolean isStreamRepartitionEffective(Collection<Symbol> keys, Set<Symbol
return (!streamPartitioning.isPresent() || streamPartitioning.get().isRepartitionEffective(keys, constants)) && !nullsAndAnyReplicated;
}

private Global translate(Function<Symbol, Optional<Symbol>> translator, Function<Symbol, Optional<NullableValue>> constants)
private Global translate(Partitioning.Translator translator)
{
return new Global(
nodePartitioning.flatMap(partitioning -> partitioning.translate(translator, constants)),
streamPartitioning.flatMap(partitioning -> partitioning.translate(translator, constants)),
nodePartitioning.flatMap(partitioning -> partitioning.translate(translator)),
streamPartitioning.flatMap(partitioning -> partitioning.translate(translator)),
nullsAndAnyReplicated);
}

Expand Down
Expand Up @@ -391,7 +391,7 @@ public Optional<PartitioningProperties> translate(Function<Symbol, Optional<Symb
return Optional.of(new PartitioningProperties(newPartitioningColumns, Optional.empty(), nullsAndAnyReplicated));
}

Optional<Partitioning> newPartitioning = partitioning.get().translate(translator, symbol -> Optional.empty());
Optional<Partitioning> newPartitioning = partitioning.get().translate(new Partitioning.Translator(translator, symbol -> Optional.empty(), coalesceSymbols -> Optional.empty()));
if (!newPartitioning.isPresent()) {
return Optional.empty();
}
Expand Down
Expand Up @@ -37,6 +37,7 @@
import io.prestosql.sql.planner.ExpressionInterpreter;
import io.prestosql.sql.planner.NoOpSymbolResolver;
import io.prestosql.sql.planner.OrderingScheme;
import io.prestosql.sql.planner.Partitioning;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.optimizations.ActualProperties.Global;
Expand Down Expand Up @@ -74,6 +75,7 @@
import io.prestosql.sql.planner.plan.UnnestNode;
import io.prestosql.sql.planner.plan.ValuesNode;
import io.prestosql.sql.planner.plan.WindowNode;
import io.prestosql.sql.tree.CoalesceExpression;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.NodeRef;
import io.prestosql.sql.tree.SymbolReference;
Expand Down Expand Up @@ -431,8 +433,22 @@ public ActualProperties visitJoin(JoinNode node, List<ActualProperties> inputPro
.unordered(unordered)
.build();
case FULL:
// We can't say anything about the partitioning scheme because any partition of
// a hash-partitioned join can produce nulls in case of a lack of matches
if (probeProperties.getNodePartitioning().isPresent()) {
Partitioning nodePartitioning = probeProperties.getNodePartitioning().get();
ImmutableList.Builder<Expression> coalesceExpressions = ImmutableList.builder();
for (Symbol column : nodePartitioning.getColumns()) {
for (JoinNode.EquiJoinClause equality : node.getCriteria()) {
if (equality.getLeft().equals(column) || equality.getRight().equals(column)) {
coalesceExpressions.add(new CoalesceExpression(ImmutableList.of(equality.getLeft().toSymbolReference(), equality.getRight().toSymbolReference())));
}
}
}

return ActualProperties.builder()
.global(partitionedOn(Partitioning.createWithExpressions(nodePartitioning.getHandle(), coalesceExpressions.build()), Optional.empty()))
.unordered(unordered)
.build();
}
return ActualProperties.builder()
.global(probeProperties.isSingleNode() ? singleStreamPartition() : arbitraryPartition())
.unordered(unordered)
Expand Down Expand Up @@ -615,7 +631,7 @@ public ActualProperties visitProject(ProjectNode node, List<ActualProperties> in

Map<Symbol, Symbol> identities = computeIdentityTranslations(node.getAssignments().getMap());

ActualProperties translatedProperties = properties.translate(column -> Optional.ofNullable(identities.get(column)));
ActualProperties translatedProperties = properties.translate(column -> Optional.ofNullable(identities.get(column)), expression -> rewriteExpression(node.getAssignments().getMap(), expression));

// Extract additional constants
Map<Symbol, NullableValue> constants = new HashMap<>();
Expand Down Expand Up @@ -833,4 +849,28 @@ else if (equality.getRight().equals(column) && columns.contains(equality.getLeft

return Optional.empty();
}

/**
 * Finds the output symbol of a projection whose assignment is equivalent to the given
 * coalesce expression.
 *
 * <p>Two coalesce expressions are treated as equivalent when every operand of both is a
 * plain symbol reference and they reference exactly the same set of symbols. Operand order
 * is ignored because the coalesce of full outer join keys yields the same non-null value
 * regardless of argument order.
 *
 * @param assignments the projection assignments (output symbol to expression)
 * @param expression the coalesce expression to look up
 * @return the symbol assigned an equivalent coalesce expression, or empty if there is none
 * @throws IllegalArgumentException if {@code expression} is not a {@link CoalesceExpression}
 */
public static Optional<Symbol> rewriteExpression(Map<Symbol, Expression> assignments, Expression expression)
{
    checkArgument(expression instanceof CoalesceExpression, "The rewrite can only handle CoalesceExpression");

    // Only coalesce expressions made purely of symbol references are supported. Ignoring
    // non-symbol operands would make e.g. coalesce(f(x), a, b) look equal to coalesce(a, b).
    List<Expression> operands = ((CoalesceExpression) expression).getOperands();
    if (!operands.stream().allMatch(SymbolReference.class::isInstance)) {
        return Optional.empty();
    }
    // Loop-invariant: computed once rather than per assignment.
    Set<Symbol> symbolsInExpression = operands.stream()
            .map(Symbol::from)
            .collect(toImmutableSet());

    for (Map.Entry<Symbol, Expression> entry : assignments.entrySet()) {
        if (!(entry.getValue() instanceof CoalesceExpression)) {
            continue;
        }
        List<Expression> candidateOperands = ((CoalesceExpression) entry.getValue()).getOperands();
        if (!candidateOperands.stream().allMatch(SymbolReference.class::isInstance)) {
            continue;
        }
        Set<Symbol> symbolsInAssignment = candidateOperands.stream()
                .map(Symbol::from)
                .collect(toImmutableSet());
        // Require exact set equality: a superset such as coalesce(x, a, b) may evaluate to x
        // and is therefore not interchangeable with coalesce(a, b).
        if (symbolsInAssignment.equals(symbolsInExpression)) {
            return Optional.of(entry.getKey());
        }
    }
    return Optional.empty();
}
}