Skip to content

Commit

Permalink
Propagate partition properties for full outer join
Browse files Browse the repository at this point in the history
For full outer join, the outputs are effectively partitioned on a pairwise combination
of the partitioning columns involved in the equi-join criteria. The combination function
is equivalent to SQL's coalesce: the first non-null value between each pair of columns.

This is due to the following observations:

* If both input columns are null, they'll be reshuffled to the same node, so coalesce(null, null) => null
* Otherwise, for any combination of values of left & right, if a row is produced by the join, one of these
  must be true:
    * left is not null, right is not null and left = right, so both values are reshuffled to the same node
          => coalesce(left, right) = left = right
    * left is not null and there is no matching row => coalesce(left, null) = left, which can only be emitted
      by the node left was reshuffled to.
    * right is not null and there is no matching row => coalesce(null, right) = right, which can only be emitted
      by the node right was reshuffled to.

Extracted-from: https://github.com/prestosql/presto
  • Loading branch information
martint committed Feb 5, 2019
1 parent 890a6a0 commit c9f8a23
Show file tree
Hide file tree
Showing 6 changed files with 215 additions and 45 deletions.
105 changes: 77 additions & 28 deletions presto-main/src/main/java/io/prestosql/sql/planner/Partitioning.java
Expand Up @@ -20,6 +20,8 @@
import io.prestosql.Session;
import io.prestosql.metadata.Metadata;
import io.prestosql.spi.predicate.NullableValue;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.SymbolReference;

import javax.annotation.concurrent.Immutable;

Expand All @@ -32,6 +34,7 @@

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.Objects.requireNonNull;
Expand All @@ -51,7 +54,15 @@ private Partitioning(PartitioningHandle handle, List<ArgumentBinding> arguments)
public static Partitioning create(PartitioningHandle handle, List<Symbol> columns)
{
return new Partitioning(handle, columns.stream()
.map(ArgumentBinding::columnBinding)
.map(Symbol::toSymbolReference)
.map(ArgumentBinding::expressionBinding)
.collect(toImmutableList()));
}

/**
 * Creates a partitioning whose arguments are arbitrary expressions rather than plain columns.
 * Each expression is wrapped in an expression binding, preserving the order of {@code expressions}.
 *
 * @param handle the partitioning handle passed through to the new {@code Partitioning}
 * @param expressions the partitioning arguments; a null element is rejected by the binding factory
 * @return a partitioning with one expression binding per input expression
 */
public static Partitioning createWithExpressions(PartitioningHandle handle, List<Expression> expressions)
{
    ImmutableList.Builder<ArgumentBinding> bindings = ImmutableList.builder();
    for (Expression expression : expressions) {
        bindings.add(ArgumentBinding.expressionBinding(expression));
    }
    return new Partitioning(handle, bindings.build());
}

Expand Down Expand Up @@ -160,13 +171,20 @@ private static boolean isPartitionedWith(

public boolean isPartitionedOn(Collection<Symbol> columns, Set<Symbol> knownConstants)
{
// partitioned on (k_1, k_2, ..., k_n) => partitioned on (k_1, k_2, ..., k_n, k_n+1, ...)
// can safely ignore all constant columns when comparing partition properties
return arguments.stream()
.filter(ArgumentBinding::isVariable)
.map(ArgumentBinding::getColumn)
.filter(symbol -> !knownConstants.contains(symbol))
.allMatch(columns::contains);
for (ArgumentBinding argument : arguments) {
// partitioned on (k_1, k_2, ..., k_n) => partitioned on (k_1, k_2, ..., k_n, k_n+1, ...)
// can safely ignore all constant columns when comparing partition properties
if (argument.isConstant()) {
continue;
}
if (!argument.isVariable()) {
return false;
}
if (!knownConstants.contains(argument.getColumn()) && !columns.contains(argument.getColumn())) {
return false;
}
}
return true;
}

public boolean isEffectivelySinglePartition(Set<Symbol> knownConstants)
Expand Down Expand Up @@ -194,11 +212,11 @@ public Partitioning translate(Function<Symbol, Symbol> translator)
.collect(toImmutableList()));
}

public Optional<Partitioning> translate(Function<Symbol, Optional<Symbol>> translator, Function<Symbol, Optional<NullableValue>> constants)
public Optional<Partitioning> translate(Translator translator)
{
ImmutableList.Builder<ArgumentBinding> newArguments = ImmutableList.builder();
for (ArgumentBinding argument : arguments) {
Optional<ArgumentBinding> newArgument = argument.translate(translator, constants);
Optional<ArgumentBinding> newArgument = argument.translate(translator);
if (!newArgument.isPresent()) {
return Optional.empty();
}
Expand Down Expand Up @@ -242,25 +260,43 @@ public String toString()
.toString();
}

/**
 * Bundles the three lookup functions used when translating partitioning arguments.
 * <ul>
 * <li>{@code columnTranslator}: applied to the symbol of a symbol-reference argument.</li>
 * <li>{@code constantTranslator}: consulted for that symbol only when the column
 * translation yields nothing.</li>
 * <li>{@code expressionTranslator}: applied to arguments that are neither constants nor
 * plain symbol references.</li>
 * </ul>
 * All three functions are required; a null function is rejected at construction time.
 */
@Immutable
public static final class Translator
{
    private final Function<Symbol, Optional<Symbol>> columnTranslator;
    private final Function<Symbol, Optional<NullableValue>> constantTranslator;
    private final Function<Expression, Optional<Symbol>> expressionTranslator;

    /**
     * @param columnTranslator maps an input symbol to its translated symbol, if any
     * @param constantTranslator maps an input symbol to a known constant value, if any
     * @param expressionTranslator maps a non-symbol expression to an output symbol, if any
     * @throws NullPointerException if any of the functions is null
     */
    public Translator(
            Function<Symbol, Optional<Symbol>> columnTranslator,
            Function<Symbol, Optional<NullableValue>> constantTranslator,
            Function<Expression, Optional<Symbol>> expressionTranslator)
    {
        this.columnTranslator = requireNonNull(columnTranslator, "columnTranslator is null");
        this.constantTranslator = requireNonNull(constantTranslator, "constantTranslator is null");
        this.expressionTranslator = requireNonNull(expressionTranslator, "expressionTranslator is null");
    }
}

@Immutable
public static final class ArgumentBinding
{
private final Symbol column;
private final Expression expression;
private final NullableValue constant;

@JsonCreator
public ArgumentBinding(
@JsonProperty("column") Symbol column,
@JsonProperty("expression") Expression expression,
@JsonProperty("constant") NullableValue constant)
{
this.column = column;
this.expression = expression;
this.constant = constant;
checkArgument((column == null) != (constant == null), "Either column or constant must be set");
checkArgument((expression == null) != (constant == null), "Either expression or constant must be set");
}

public static ArgumentBinding columnBinding(Symbol column)
public static ArgumentBinding expressionBinding(Expression expression)
{
return new ArgumentBinding(requireNonNull(column, "column is null"), null);
return new ArgumentBinding(requireNonNull(expression, "expression is null"), null);
}

public static ArgumentBinding constantBinding(NullableValue constant)
Expand All @@ -275,13 +311,19 @@ public boolean isConstant()

public boolean isVariable()
{
return column != null;
return expression instanceof SymbolReference;
}

@JsonProperty
public Symbol getColumn()
{
return column;
verify(expression instanceof SymbolReference, "Expect the expression to be a SymbolReference");
return Symbol.from(expression);
}

@JsonProperty
public Expression getExpression()
{
return expression;
}

@JsonProperty
Expand All @@ -295,25 +337,31 @@ public ArgumentBinding translate(Function<Symbol, Symbol> translator)
if (isConstant()) {
return this;
}
return columnBinding(translator.apply(column));
return expressionBinding(translator.apply(Symbol.from(expression)).toSymbolReference());
}

public Optional<ArgumentBinding> translate(Function<Symbol, Optional<Symbol>> translator, Function<Symbol, Optional<NullableValue>> constants)
public Optional<ArgumentBinding> translate(Translator translator)
{
if (isConstant()) {
return Optional.of(this);
}

Optional<ArgumentBinding> newColumn = translator.apply(column)
.map(ArgumentBinding::columnBinding);
if (!isVariable()) {
return translator.expressionTranslator.apply(expression)
.map(Symbol::toSymbolReference)
.map(ArgumentBinding::expressionBinding);
}

Optional<ArgumentBinding> newColumn = translator.columnTranslator.apply(Symbol.from(expression))
.map(Symbol::toSymbolReference)
.map(ArgumentBinding::expressionBinding);
if (newColumn.isPresent()) {
return newColumn;
}

// As a last resort, check for a constant mapping for the symbol
// Note: this MUST be last because we want to favor the symbol representation
// as it makes further optimizations possible.
return constants.apply(column)
return translator.constantTranslator.apply(Symbol.from(expression))
.map(ArgumentBinding::constantBinding);
}

Expand All @@ -323,7 +371,8 @@ public String toString()
if (constant != null) {
return constant.toString();
}
return "\"" + column + "\"";

return expression.toString();
}

@Override
Expand All @@ -336,14 +385,14 @@ public boolean equals(Object o)
return false;
}
ArgumentBinding that = (ArgumentBinding) o;
return Objects.equals(column, that.column) &&
return Objects.equals(expression, that.expression) &&
Objects.equals(constant, that.constant);
}

@Override
public int hashCode()
{
return Objects.hash(column, constant);
return Objects.hash(expression, constant);
}
}
}
Expand Up @@ -24,6 +24,7 @@
import io.prestosql.sql.planner.Partitioning;
import io.prestosql.sql.planner.PartitioningHandle;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.tree.Expression;

import javax.annotation.concurrent.Immutable;

Expand Down Expand Up @@ -153,17 +154,21 @@ public boolean isStreamRepartitionEffective(Collection<Symbol> keys)

public ActualProperties translate(Function<Symbol, Optional<Symbol>> translator)
{
Map<Symbol, NullableValue> translatedConstants = new HashMap<>();
for (Map.Entry<Symbol, NullableValue> entry : constants.entrySet()) {
Optional<Symbol> translatedKey = translator.apply(entry.getKey());
if (translatedKey.isPresent()) {
translatedConstants.put(translatedKey.get(), entry.getValue());
}
}
return builder()
.global(global.translate(translator, symbol -> Optional.ofNullable(constants.get(symbol))))
.global(global.translate(new Partitioning.Translator(translator, symbol -> Optional.ofNullable(constants.get(symbol)), expression -> Optional.empty())))
.local(LocalProperties.translate(localProperties, translator))
.constants(translatedConstants)
.constants(translateConstants(translator))
.build();
}

/**
 * Translates these properties into a new symbol space, additionally allowing
 * expression-based partitioning arguments to be mapped to output symbols.
 *
 * @param translator maps an existing symbol to its translated symbol, if any; also used for
 *        the local properties and for re-keying the known constants
 * @param expressionTranslator maps a partitioning expression to an output symbol, if any
 * @return the translated properties
 */
public ActualProperties translate(
        Function<Symbol, Optional<Symbol>> translator,
        Function<Expression, Optional<Symbol>> expressionTranslator)
{
    // Constants are looked up against the current (untranslated) constants map.
    Partitioning.Translator partitioningTranslator = new Partitioning.Translator(
            translator,
            symbol -> Optional.ofNullable(constants.get(symbol)),
            expressionTranslator);

    return builder()
            .global(global.translate(partitioningTranslator))
            .local(LocalProperties.translate(localProperties, translator))
            .constants(translateConstants(translator))
            .build();
}

Expand Down Expand Up @@ -199,6 +204,18 @@ public static Builder builderFrom(ActualProperties properties)
return new Builder(properties.global, properties.localProperties, properties.constants);
}

/**
 * Re-keys the known constants through {@code translator}.
 * Entries whose symbol has no translation are dropped.
 *
 * @param translator maps an existing symbol to its translated symbol, if any
 * @return a new mutable map from translated symbols to their constant values
 */
private Map<Symbol, NullableValue> translateConstants(Function<Symbol, Optional<Symbol>> translator)
{
    Map<Symbol, NullableValue> result = new HashMap<>();
    for (Map.Entry<Symbol, NullableValue> constant : constants.entrySet()) {
        translator.apply(constant.getKey())
                .ifPresent(translatedSymbol -> result.put(translatedSymbol, constant.getValue()));
    }
    return result;
}

public static class Builder
{
private Global global;
Expand Down Expand Up @@ -448,11 +465,11 @@ private boolean isStreamRepartitionEffective(Collection<Symbol> keys, Set<Symbol
return (!streamPartitioning.isPresent() || streamPartitioning.get().isRepartitionEffective(keys, constants)) && !nullsAndAnyReplicated;
}

private Global translate(Function<Symbol, Optional<Symbol>> translator, Function<Symbol, Optional<NullableValue>> constants)
private Global translate(Partitioning.Translator translator)
{
return new Global(
nodePartitioning.flatMap(partitioning -> partitioning.translate(translator, constants)),
streamPartitioning.flatMap(partitioning -> partitioning.translate(translator, constants)),
nodePartitioning.flatMap(partitioning -> partitioning.translate(translator)),
streamPartitioning.flatMap(partitioning -> partitioning.translate(translator)),
nullsAndAnyReplicated);
}

Expand Down
Expand Up @@ -391,7 +391,7 @@ public Optional<PartitioningProperties> translate(Function<Symbol, Optional<Symb
return Optional.of(new PartitioningProperties(newPartitioningColumns, Optional.empty(), nullsAndAnyReplicated));
}

Optional<Partitioning> newPartitioning = partitioning.get().translate(translator, symbol -> Optional.empty());
Optional<Partitioning> newPartitioning = partitioning.get().translate(new Partitioning.Translator(translator, symbol -> Optional.empty(), coalesceSymbols -> Optional.empty()));
if (!newPartitioning.isPresent()) {
return Optional.empty();
}
Expand Down
Expand Up @@ -37,6 +37,7 @@
import io.prestosql.sql.planner.ExpressionInterpreter;
import io.prestosql.sql.planner.NoOpSymbolResolver;
import io.prestosql.sql.planner.OrderingScheme;
import io.prestosql.sql.planner.Partitioning;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.optimizations.ActualProperties.Global;
Expand Down Expand Up @@ -74,6 +75,7 @@
import io.prestosql.sql.planner.plan.UnnestNode;
import io.prestosql.sql.planner.plan.ValuesNode;
import io.prestosql.sql.planner.plan.WindowNode;
import io.prestosql.sql.tree.CoalesceExpression;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.NodeRef;
import io.prestosql.sql.tree.SymbolReference;
Expand Down Expand Up @@ -431,8 +433,22 @@ public ActualProperties visitJoin(JoinNode node, List<ActualProperties> inputPro
.unordered(unordered)
.build();
case FULL:
// We can't say anything about the partitioning scheme because any partition of
// a hash-partitioned join can produce nulls in case of a lack of matches
if (probeProperties.getNodePartitioning().isPresent()) {
Partitioning nodePartitioning = probeProperties.getNodePartitioning().get();
ImmutableList.Builder<Expression> coalesceExpressions = ImmutableList.builder();
for (Symbol column : nodePartitioning.getColumns()) {
for (JoinNode.EquiJoinClause equality : node.getCriteria()) {
if (equality.getLeft().equals(column) || equality.getRight().equals(column)) {
coalesceExpressions.add(new CoalesceExpression(ImmutableList.of(equality.getLeft().toSymbolReference(), equality.getRight().toSymbolReference())));
}
}
}

return ActualProperties.builder()
.global(partitionedOn(Partitioning.createWithExpressions(nodePartitioning.getHandle(), coalesceExpressions.build()), Optional.empty()))
.unordered(unordered)
.build();
}
return ActualProperties.builder()
.global(probeProperties.isSingleNode() ? singleStreamPartition() : arbitraryPartition())
.unordered(unordered)
Expand Down Expand Up @@ -615,7 +631,7 @@ public ActualProperties visitProject(ProjectNode node, List<ActualProperties> in

Map<Symbol, Symbol> identities = computeIdentityTranslations(node.getAssignments().getMap());

ActualProperties translatedProperties = properties.translate(column -> Optional.ofNullable(identities.get(column)));
ActualProperties translatedProperties = properties.translate(column -> Optional.ofNullable(identities.get(column)), expression -> rewriteExpression(node.getAssignments().getMap(), expression));

// Extract additional constants
Map<Symbol, NullableValue> constants = new HashMap<>();
Expand Down Expand Up @@ -833,4 +849,28 @@ else if (equality.getRight().equals(column) && columns.contains(equality.getLeft

return Optional.empty();
}

/**
 * Looks for a projection assignment that is a coalesce over (at least) the same symbols
 * as the given coalesce expression, and returns the symbol that assignment is bound to.
 *
 * @param assignments the projection assignments (output symbol to expression) to search
 * @param expression the expression to rewrite; must be a {@code CoalesceExpression}
 * @return the output symbol of the first matching coalesce assignment in iteration order,
 *         or empty if none matches
 * @throws IllegalArgumentException if {@code expression} is not a {@code CoalesceExpression}
 */
public static Optional<Symbol> rewriteExpression(Map<Symbol, Expression> assignments, Expression expression)
{
    checkArgument(expression instanceof CoalesceExpression, "The rewrite can only handle CoalesceExpression");
    // We rely on the property that the coalesce of full outer join keys yields the same
    // result regardless of the order of its arguments. Hence the symbols of each
    // CoalesceExpression are extracted and compared as sets instead of comparing the
    // CoalesceExpressions directly.
    //
    // NOTE(review): non-symbol operands are ignored and an assignment whose symbol set is a
    // superset of the expression's symbols also matches - confirm this is intended.

    // The symbols of the expression do not depend on the assignment being inspected,
    // so compute them once up front.
    Set<Symbol> symbolsInExpression = coalesceSymbolOperands((CoalesceExpression) expression);

    for (Map.Entry<Symbol, Expression> entry : assignments.entrySet()) {
        if (entry.getValue() instanceof CoalesceExpression) {
            Set<Symbol> symbolsInAssignment = coalesceSymbolOperands((CoalesceExpression) entry.getValue());
            if (symbolsInAssignment.containsAll(symbolsInExpression)) {
                return Optional.of(entry.getKey());
            }
        }
    }
    return Optional.empty();
}

// Collects the operands of a coalesce that are plain symbol references, as a set of symbols.
private static Set<Symbol> coalesceSymbolOperands(CoalesceExpression coalesce)
{
    return coalesce.getOperands().stream()
            .filter(SymbolReference.class::isInstance)
            .map(Symbol::from)
            .collect(toImmutableSet());
}
}
Expand Up @@ -319,7 +319,8 @@ public Void visitRemoteSource(RemoteSourceNode node, Void context)
public Void visitExchange(ExchangeNode node, Void context)
{
List<ArgumentBinding> symbols = node.getOutputSymbols().stream()
.map(ArgumentBinding::columnBinding)
.map(Symbol::toSymbolReference)
.map(ArgumentBinding::expressionBinding)
.collect(toImmutableList());
if (node.getType() == REPARTITION) {
symbols = node.getPartitioningScheme().getPartitioning().getArguments();
Expand Down

0 comments on commit c9f8a23

Please sign in to comment.