Skip to content

Commit

Permalink
Refactor move sort expression to JoinNode
Browse files Browse the repository at this point in the history
This requires changing SortExpressionExtractor to work on SymbolReferences instead of FieldReferences
  • Loading branch information
pnowojski authored and martint committed May 14, 2017
1 parent a267d97 commit aeea5e2
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 47 deletions.
Expand Up @@ -1563,6 +1563,7 @@ private LookupSourceFactory createLookupSourceFactory(
Optional<JoinFilterFunctionFactory> filterFunctionFactory = node.getFilter()
.map(filterExpression -> compileJoinFilterFunction(
filterExpression,
node.getSortExpression(),
probeLayout,
buildSource.getLayout(),
context.getTypes(),
Expand Down Expand Up @@ -1596,6 +1597,7 @@ private LookupSourceFactory createLookupSourceFactory(

private JoinFilterFunctionFactory compileJoinFilterFunction(
Expression filterExpression,
Optional<Expression> sortExpression,
Map<Symbol, Integer> probeLayout,
Map<Symbol, Integer> buildLayout,
Map<Symbol, Type> types,
Expand All @@ -1607,8 +1609,10 @@ private JoinFilterFunctionFactory compileJoinFilterFunction(
.collect(toImmutableMap(Map.Entry::getValue, entry -> types.get(entry.getKey())));

Expression rewrittenFilter = new SymbolToInputRewriter(joinSourcesLayout).rewrite(filterExpression);
Optional<Expression> rewrittenSortExpression = sortExpression.map(
expression -> new SymbolToInputRewriter(buildLayout).rewrite(expression));

Optional<SortExpression> sortChannel = SortExpressionExtractor.extractSortExpression(buildLayout, rewrittenFilter);
Optional<SortExpression> sortChannel = rewrittenSortExpression.map(SortExpression::fromExpression);

IdentityLinkedHashMap<Expression, Type> expressionTypes = getExpressionTypesFromInput(
session,
Expand Down
Expand Up @@ -18,14 +18,15 @@
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FieldReference;
import com.facebook.presto.sql.tree.Node;
import com.google.common.collect.ImmutableSet;
import com.facebook.presto.sql.tree.SymbolReference;

import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.Objects.requireNonNull;

/**
Expand All @@ -49,24 +50,23 @@ public final class SortExpressionExtractor
{
private SortExpressionExtractor() {}

public static Optional<SortExpression> extractSortExpression(Map<Symbol, Integer> buildLayout, Expression filter)
public static Optional<Expression> extractSortExpression(Set<Symbol> buildSymbols, Expression filter)
{
Set<Integer> buildFields = ImmutableSet.copyOf(buildLayout.values());
if (filter instanceof ComparisonExpression) {
ComparisonExpression comparison = (ComparisonExpression) filter;
switch (comparison.getType()) {
case GREATER_THAN:
case GREATER_THAN_OR_EQUAL:
case LESS_THAN:
case LESS_THAN_OR_EQUAL:
Optional<Integer> sortChannel = asBuildFieldReference(buildFields, comparison.getRight());
boolean hasBuildReferencesOnOtherSide = hasBuildFieldReference(buildFields, comparison.getLeft());
Optional<SymbolReference> sortChannel = asBuildSymbolReference(buildSymbols, comparison.getRight());
boolean hasBuildReferencesOnOtherSide = hasBuildSymbolReference(buildSymbols, comparison.getLeft());
if (!sortChannel.isPresent()) {
sortChannel = asBuildFieldReference(buildFields, comparison.getLeft());
hasBuildReferencesOnOtherSide = hasBuildFieldReference(buildFields, comparison.getRight());
sortChannel = asBuildSymbolReference(buildSymbols, comparison.getLeft());
hasBuildReferencesOnOtherSide = hasBuildSymbolReference(buildSymbols, comparison.getRight());
}
if (sortChannel.isPresent() && !hasBuildReferencesOnOtherSide) {
return Optional.of(new SortExpression(sortChannel.get()));
return sortChannel.map(symbolReference -> (Expression) symbolReference);
}
return Optional.empty();
default:
Expand All @@ -77,30 +77,32 @@ public static Optional<SortExpression> extractSortExpression(Map<Symbol, Integer
return Optional.empty();
}

private static Optional<Integer> asBuildFieldReference(Set<Integer> buildLayout, Expression expression)
private static Optional<SymbolReference> asBuildSymbolReference(Set<Symbol> buildLayout, Expression expression)
{
if (expression instanceof FieldReference) {
FieldReference field = (FieldReference) expression;
if (buildLayout.contains(field.getFieldIndex())) {
return Optional.of(field.getFieldIndex());
if (expression instanceof SymbolReference) {
SymbolReference symbolReference = (SymbolReference) expression;
if (buildLayout.contains(new Symbol(symbolReference.getName()))) {
return Optional.of(symbolReference);
}
}
return Optional.empty();
}

private static boolean hasBuildFieldReference(Set<Integer> buildLayout, Expression expression)
private static boolean hasBuildSymbolReference(Set<Symbol> buildSymbols, Expression expression)
{
return new BuildFieldReferenceFinder(buildLayout).process(expression);
return new BuildSymbolReferenceFinder(buildSymbols).process(expression);
}

private static class BuildFieldReferenceFinder
private static class BuildSymbolReferenceFinder
extends AstVisitor<Boolean, Void>
{
private final Set<Integer> buildLayout;
private final Set<String> buildSymbols;

public BuildFieldReferenceFinder(Set<Integer> buildLayout)
public BuildSymbolReferenceFinder(Set<Symbol> buildSymbols)
{
this.buildLayout = ImmutableSet.copyOf(requireNonNull(buildLayout, "buildLayout is null"));
this.buildSymbols = requireNonNull(buildSymbols, "buildSymbols is null").stream()
.map(Symbol::getName)
.collect(toImmutableSet());
}

@Override
Expand All @@ -115,9 +117,9 @@ protected Boolean visitNode(Node node, Void context)
}

@Override
protected Boolean visitFieldReference(FieldReference fieldReference, Void context)
protected Boolean visitSymbolReference(SymbolReference symbolReference, Void context)
{
return buildLayout.contains(fieldReference.getFieldIndex());
return buildSymbols.contains(symbolReference.getName());
}
}

Expand Down Expand Up @@ -160,5 +162,11 @@ public String toString()
.add("channel", channel)
.toString();
}

public static SortExpression fromExpression(Expression expression)
{
checkState(expression instanceof FieldReference, "Unsupported expression type [%s]", expression);
return new SortExpression(((FieldReference) expression).getFieldIndex());
}
}
}
Expand Up @@ -19,6 +19,7 @@
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import javax.annotation.concurrent.Immutable;

Expand All @@ -27,6 +28,7 @@
import java.util.Optional;
import java.util.stream.Stream;

import static com.facebook.presto.sql.planner.SortExpressionExtractor.extractSortExpression;
import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
Expand Down Expand Up @@ -168,6 +170,11 @@ public Optional<Expression> getFilter()
return filter;
}

public Optional<Expression> getSortExpression()
{
return filter.map(filter -> extractSortExpression(ImmutableSet.copyOf(right.getOutputSymbols()), filter).orElse(null));
}

@JsonProperty("leftHashSymbol")
public Optional<Symbol> getLeftHashSymbol()
{
Expand Down
Expand Up @@ -13,76 +13,73 @@
*/
package com.facebook.presto.sql.planner;

import com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression;
import com.facebook.presto.sql.tree.ArithmeticBinaryExpression;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.ComparisonExpressionType;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FieldReference;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.testng.annotations.Test;

import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static org.testng.AssertJUnit.assertEquals;

public class TestSortExpressionExtractor
{
private static final Map<Symbol, Integer> BUILD_LAYOUT = ImmutableMap.of(
new Symbol("b1"), 1,
new Symbol("b2"), 2);
private static final Set<Symbol> BUILD_SYMBOLS = ImmutableSet.of(new Symbol("b1"), new Symbol("b2"));

@Test
public void testGetSortExpression()
{
assertGetSortExpression(
new ComparisonExpression(
ComparisonExpressionType.GREATER_THAN,
new FieldReference(11),
new FieldReference(1)),
1);
new SymbolReference("p1"),
new SymbolReference("b1")),
"b1");

assertGetSortExpression(
new ComparisonExpression(
ComparisonExpressionType.LESS_THAN_OR_EQUAL,
new FieldReference(2),
new FieldReference(11)),
2);
new SymbolReference("b2"),
new SymbolReference("p1")),
"b2");

assertGetSortExpression(
new ComparisonExpression(
ComparisonExpressionType.GREATER_THAN,
new FieldReference(2),
new FieldReference(11)),
2);
new SymbolReference("b2"),
new SymbolReference("p1")),
"b2");

assertGetSortExpression(
new ComparisonExpression(
ComparisonExpressionType.GREATER_THAN,
new FieldReference(1),
new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Type.ADD, new FieldReference(2), new FieldReference(11))));
new SymbolReference("b1"),
new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Type.ADD, new SymbolReference("b2"), new SymbolReference("p1"))));

assertGetSortExpression(
new ComparisonExpression(
ComparisonExpressionType.GREATER_THAN,
new FunctionCall(QualifiedName.of("sin"), ImmutableList.of(new FieldReference(1))),
new FieldReference(11)));
new FunctionCall(QualifiedName.of("sin"), ImmutableList.of(new SymbolReference("b1"))),
new SymbolReference("p1")));
}

private static void assertGetSortExpression(Expression expression)
{
Optional<SortExpression> actual = SortExpressionExtractor.extractSortExpression(BUILD_LAYOUT, expression);
Optional<Expression> actual = SortExpressionExtractor.extractSortExpression(BUILD_SYMBOLS, expression);
assertEquals(Optional.empty(), actual);
}

private static void assertGetSortExpression(Expression expression, int expectedChannel)
private static void assertGetSortExpression(Expression expression, String expectedSymbol)
{
Optional<SortExpression> expected = Optional.of(new SortExpression(expectedChannel));
Optional<SortExpression> actual = SortExpressionExtractor.extractSortExpression(BUILD_LAYOUT, expression);
Optional<Expression> expected = Optional.of(new SymbolReference(expectedSymbol));
Optional<Expression> actual = SortExpressionExtractor.extractSortExpression(BUILD_SYMBOLS, expression);
assertEquals(expected, actual);
}
}

0 comments on commit aeea5e2

Please sign in to comment.