Skip to content

Commit

Permalink
Enable DomainTranslator to fold constants when extracting TupleDomains
Browse files Browse the repository at this point in the history
  • Loading branch information
erichwang committed Nov 14, 2015
1 parent 9a4fe1e commit 4f44780
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 94 deletions.
Expand Up @@ -15,28 +15,26 @@

import com.facebook.presto.Session;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.predicate.DiscreteValues;
import com.facebook.presto.spi.predicate.Domain;
import com.facebook.presto.spi.predicate.Marker;
import com.facebook.presto.spi.predicate.NullableValue;
import com.facebook.presto.spi.predicate.Range;
import com.facebook.presto.spi.predicate.Ranges;
import com.facebook.presto.spi.predicate.TupleDomain;
import com.facebook.presto.spi.predicate.ValueSet;
import com.facebook.presto.spi.type.DoubleType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.analyzer.ExpressionAnalyzer;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.tree.AstVisitor;
import com.facebook.presto.sql.tree.BetweenPredicate;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.DoubleLiteral;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.InListExpression;
import com.facebook.presto.sql.tree.InPredicate;
import com.facebook.presto.sql.tree.IsNotNullPredicate;
import com.facebook.presto.sql.tree.IsNullPredicate;
import com.facebook.presto.sql.tree.Literal;
import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.NotExpression;
Expand All @@ -49,13 +47,13 @@
import javax.annotation.Nullable;

import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.metadata.FunctionRegistry.getMagicLiteralFunctionSignature;
import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.spi.type.DateType.DATE;
import static com.facebook.presto.spi.type.TimestampType.TIMESTAMP;
import static com.facebook.presto.spi.type.DoubleType.DOUBLE;
import static com.facebook.presto.sql.ExpressionUtils.and;
import static com.facebook.presto.sql.ExpressionUtils.combineConjuncts;
import static com.facebook.presto.sql.ExpressionUtils.combineDisjunctsWithDefault;
Expand All @@ -70,22 +68,16 @@
import static com.facebook.presto.sql.tree.ComparisonExpression.Type.LESS_THAN;
import static com.facebook.presto.sql.tree.ComparisonExpression.Type.LESS_THAN_OR_EQUAL;
import static com.facebook.presto.sql.tree.ComparisonExpression.Type.NOT_EQUAL;
import static com.facebook.presto.type.ColorType.COLOR;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.common.primitives.Primitives.wrap;
import static java.math.RoundingMode.CEILING;
import static java.math.RoundingMode.FLOOR;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;

public final class DomainTranslator
{
private static final String DATE_LITERAL = getMagicLiteralFunctionSignature(DATE).getName();
private static final String TIMESTAMP_LITERAL = getMagicLiteralFunctionSignature(TIMESTAMP).getName();
private static final String COLOR_LITERAL = getMagicLiteralFunctionSignature(COLOR).getName();

private DomainTranslator()
{
}
Expand Down Expand Up @@ -239,13 +231,13 @@ private static class Visitor
extends AstVisitor<ExtractionResult, Boolean>
{
private final Metadata metadata;
private final ConnectorSession session;
private final Session session;
private final Map<Symbol, Type> types;

private Visitor(Metadata metadata, Session session, Map<Symbol, Type> types)
{
this.metadata = requireNonNull(metadata, "metadata is null");
this.session = requireNonNull(session, "session is null").toConnectorSession();
this.session = requireNonNull(session, "session is null");
this.types = ImmutableMap.copyOf(requireNonNull(types, "types is null"));
}

Expand Down Expand Up @@ -350,31 +342,26 @@ protected ExtractionResult visitNotExpression(NotExpression node, Boolean comple
@Override
protected ExtractionResult visitComparisonExpression(ComparisonExpression node, Boolean complement)
{
if (isSimpleMagicLiteralComparison(node)) {
node = normalizeSimpleComparison(node);
node = convertMagicLiteralComparison(node);
}
else if (isSimpleComparison(node)) {
node = normalizeSimpleComparison(node);
}
else {
Optional<NormalizedSimpleComparison> optionalNormalized = toNormalizedSimpleComparison(session, metadata, types, node);
if (!optionalNormalized.isPresent()) {
return super.visitComparisonExpression(node, complement);
}
NormalizedSimpleComparison normalized = optionalNormalized.get();

Symbol symbol = Symbol.fromQualifiedName(((QualifiedNameReference) node.getLeft()).getName());
Symbol symbol = Symbol.fromQualifiedName(normalized.getNameReference().getName());
Type type = checkedTypeLookup(symbol);
Object value = LiteralInterpreter.evaluate(metadata, session, node.getRight());
NullableValue value = normalized.getValue();

// Handle the cases where implicit coercions can happen in comparisons
// TODO: how to abstract this out
if (value instanceof Double && type.equals(BIGINT)) {
return process(coerceDoubleToLongComparison(node), complement);
if (value.getType().equals(DOUBLE) && type.equals(BIGINT)) {
return process(coerceDoubleToLongComparison(normalized), complement);
}
if (value instanceof Long && type.equals(DoubleType.DOUBLE)) {
value = ((Long) value).doubleValue();
if (value.getType().equals(BIGINT) && type.equals(DOUBLE)) {
value = NullableValue.of(DOUBLE, ((Long) value.getValue()).doubleValue());
}
verifyType(type, value);
return createComparisonExtractionResult(node.getType(), symbol, type, value, complement);
checkState(value.isNull() || value.getType().equals(type), "INVARIANT: comparison should be working on the same types");
return createComparisonExtractionResult(normalized.getComparisonType(), symbol, type, value.getValue(), complement);
}

private ExtractionResult createComparisonExtractionResult(ComparisonExpression.Type comparisonType, Symbol column, Type type, @Nullable Object value, boolean complement)
Expand Down Expand Up @@ -456,11 +443,6 @@ private static Domain extractEquatableDomain(ComparisonExpression.Type compariso
}
}

private static void verifyType(Type type, Object value)
{
checkState(value == null || wrap(type.getJavaType()).isInstance(value), "Value %s is not of expected type %s", value, type);
}

@Override
protected ExtractionResult visitInPredicate(InPredicate node, Boolean complement)
{
Expand Down Expand Up @@ -532,44 +514,68 @@ protected ExtractionResult visitNullLiteral(NullLiteral node, Boolean complement
}
}

private static boolean isSimpleComparison(ComparisonExpression comparison)
{
return (comparison.getLeft() instanceof QualifiedNameReference && comparison.getRight() instanceof Literal) ||
(comparison.getLeft() instanceof Literal && comparison.getRight() instanceof QualifiedNameReference);
}

/**
* Normalize a simple comparison between a QualifiedNameReference and a Literal such that the QualifiedNameReference will always be on the left and the Literal on the right.
* Extract a normalized simple comparison between a QualifiedNameReference and a native value if possible.
*/
private static ComparisonExpression normalizeSimpleComparison(ComparisonExpression comparison)
private static Optional<NormalizedSimpleComparison> toNormalizedSimpleComparison(Session session, Metadata metadata, Map<Symbol, Type> types, ComparisonExpression comparison)
{
if (comparison.getLeft() instanceof QualifiedNameReference) {
return comparison;
IdentityHashMap<Expression, Type> expressionTypes = ExpressionAnalyzer.getExpressionTypes(session, metadata, new SqlParser(), types, comparison);
Object left = ExpressionInterpreter.expressionOptimizer(comparison.getLeft(), metadata, session, expressionTypes).optimize(NoOpSymbolResolver.INSTANCE);
Object right = ExpressionInterpreter.expressionOptimizer(comparison.getRight(), metadata, session, expressionTypes).optimize(NoOpSymbolResolver.INSTANCE);

if (left instanceof QualifiedNameReference && !(right instanceof Expression)) {
return Optional.of(new NormalizedSimpleComparison((QualifiedNameReference) left, comparison.getType(), new NullableValue(expressionTypes.get(comparison.getRight()), right)));
}
if (comparison.getRight() instanceof QualifiedNameReference) {
return new ComparisonExpression(flipComparison(comparison.getType()), comparison.getRight(), comparison.getLeft());
if (right instanceof QualifiedNameReference && !(left instanceof Expression)) {
return Optional.of(new NormalizedSimpleComparison((QualifiedNameReference) right, flipComparison(comparison.getType()), new NullableValue(expressionTypes.get(comparison.getLeft()), left)));
}
throw new IllegalArgumentException("ComparisonExpression not a simple literal comparison: " + comparison);
return Optional.empty();
}

private static Expression coerceDoubleToLongComparison(ComparisonExpression comparison)
private static class NormalizedSimpleComparison
{
comparison = normalizeSimpleComparison(comparison);
private final QualifiedNameReference nameReference;
private final ComparisonExpression.Type comparisonType;
private final NullableValue value;

checkArgument(comparison.getLeft() instanceof QualifiedNameReference, "Left must be a QualifiedNameReference");
checkArgument(comparison.getRight() instanceof DoubleLiteral, "Right must be a DoubleLiteral");
public NormalizedSimpleComparison(QualifiedNameReference nameReference, ComparisonExpression.Type comparisonType, NullableValue value)
{
this.nameReference = requireNonNull(nameReference, "nameReference is null");
this.comparisonType = requireNonNull(comparisonType, "comparisonType is null");
this.value = requireNonNull(value, "value is null");
}

QualifiedNameReference reference = (QualifiedNameReference) comparison.getLeft();
Double value = ((DoubleLiteral) comparison.getRight()).getValue();
public QualifiedNameReference getNameReference()
{
return nameReference;
}

switch (comparison.getType()) {
public ComparisonExpression.Type getComparisonType()
{
return comparisonType;
}

public NullableValue getValue()
{
return value;
}
}

private static Expression coerceDoubleToLongComparison(NormalizedSimpleComparison normalized)
{
checkArgument(normalized.getValue().getType().equals(DOUBLE), "Value should be of DOUBLE type");
checkArgument(!normalized.getValue().isNull(), "Value should not be null");
QualifiedNameReference reference = normalized.getNameReference();
Double value = (Double) normalized.getValue().getValue();

switch (normalized.getComparisonType()) {
case GREATER_THAN_OR_EQUAL:
case LESS_THAN:
return new ComparisonExpression(comparison.getType(), reference, toExpression(DoubleMath.roundToLong(value, CEILING), BIGINT));
return new ComparisonExpression(normalized.getComparisonType(), reference, toExpression(DoubleMath.roundToLong(value, CEILING), BIGINT));

case GREATER_THAN:
case LESS_THAN_OR_EQUAL:
return new ComparisonExpression(comparison.getType(), reference, toExpression(DoubleMath.roundToLong(value, FLOOR), BIGINT));
return new ComparisonExpression(normalized.getComparisonType(), reference, toExpression(DoubleMath.roundToLong(value, FLOOR), BIGINT));

case EQUAL:
Long equalValue = DoubleMath.roundToLong(value, FLOOR);
Expand All @@ -578,7 +584,7 @@ private static Expression coerceDoubleToLongComparison(ComparisonExpression comp
return and(new ComparisonExpression(EQUAL, reference, new LongLiteral("0")),
new ComparisonExpression(NOT_EQUAL, reference, new LongLiteral("0")));
}
return new ComparisonExpression(comparison.getType(), reference, toExpression(equalValue, BIGINT));
return new ComparisonExpression(normalized.getComparisonType(), reference, toExpression(equalValue, BIGINT));

case NOT_EQUAL:
Long notEqualValue = DoubleMath.roundToLong(value, FLOOR);
Expand All @@ -587,17 +593,17 @@ private static Expression coerceDoubleToLongComparison(ComparisonExpression comp
return or(new ComparisonExpression(EQUAL, reference, new LongLiteral("0")),
new ComparisonExpression(NOT_EQUAL, reference, new LongLiteral("0")));
}
return new ComparisonExpression(comparison.getType(), reference, toExpression(notEqualValue, BIGINT));
return new ComparisonExpression(normalized.getComparisonType(), reference, toExpression(notEqualValue, BIGINT));

case IS_DISTINCT_FROM:
Long distinctValue = DoubleMath.roundToLong(value, FLOOR);
if (distinctValue.doubleValue() != value) {
return TRUE_LITERAL;
}
return new ComparisonExpression(comparison.getType(), reference, toExpression(distinctValue, BIGINT));
return new ComparisonExpression(normalized.getComparisonType(), reference, toExpression(distinctValue, BIGINT));

default:
throw new AssertionError("Unhandled type: " + comparison.getType());
throw new AssertionError("Unhandled type: " + normalized.getComparisonType());
}
}

Expand All @@ -622,35 +628,4 @@ public Expression getRemainingExpression()
return remainingExpression;
}
}

// TODO: remove this horrible hack
private static boolean isSimpleMagicLiteralComparison(ComparisonExpression node)
{
FunctionCall call;
if ((node.getLeft() instanceof QualifiedNameReference) && (node.getRight() instanceof FunctionCall)) {
call = (FunctionCall) node.getRight();
}
else if ((node.getLeft() instanceof FunctionCall) && (node.getRight() instanceof QualifiedNameReference)) {
call = (FunctionCall) node.getLeft();
}
else {
return false;
}

if (call.getName().getPrefix().isPresent()) {
return false;
}
String name = call.getName().getSuffix();

return name.equals(DATE_LITERAL) || name.equals(TIMESTAMP_LITERAL) || name.equals(COLOR_LITERAL);
}

private static ComparisonExpression convertMagicLiteralComparison(ComparisonExpression node)
{
// "magic literal" functions use the stack type value for the argument
checkArgument(isSimpleMagicLiteralComparison(node), "not a simple magic literal comparison");
FunctionCall call = (FunctionCall) node.getRight();
Expression value = call.getArguments().get(0);
return new ComparisonExpression(node.getType(), node.getLeft(), value);
}
}
Expand Up @@ -37,6 +37,9 @@
import com.facebook.presto.sql.tree.StringLiteral;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.io.BaseEncoding;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
import org.testng.annotations.Test;
Expand All @@ -53,6 +56,7 @@
import static com.facebook.presto.spi.type.DoubleType.DOUBLE;
import static com.facebook.presto.spi.type.HyperLogLogType.HYPER_LOG_LOG;
import static com.facebook.presto.spi.type.TimestampType.TIMESTAMP;
import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY;
import static com.facebook.presto.spi.type.VarcharType.VARCHAR;
import static com.facebook.presto.sql.ExpressionUtils.and;
import static com.facebook.presto.sql.ExpressionUtils.or;
Expand Down Expand Up @@ -86,6 +90,7 @@ public class TestDomainTranslator
private static final Symbol I = new Symbol("i");
private static final Symbol J = new Symbol("j");
private static final Symbol K = new Symbol("k");
private static final Symbol L = new Symbol("l");

private static final Map<Symbol, Type> TYPES = ImmutableMap.<Symbol, Type>builder()
.put(A, BIGINT)
Expand All @@ -99,6 +104,7 @@ public class TestDomainTranslator
.put(I, DATE)
.put(J, COLOR) // Equatable, but not orderable
.put(K, HYPER_LOG_LOG) // Not Equatable or orderable
.put(L, VARBINARY)
.build();

private static final long TIMESTAMP_VALUE = new DateTime(2013, 3, 30, 1, 5, 0, 0, DateTimeZone.UTC).getMillis();
Expand Down Expand Up @@ -1048,6 +1054,20 @@ public void testFromNullLiteralPredicate()
assertTrue(result.getTupleDomain().isNone());
}

@Test
public void testExpressionConstantFolding()
throws Exception
{
Expression originalExpression = comparison(GREATER_THAN, reference(L), function("from_hex", stringLiteral("123456")));
ExtractionResult result = fromPredicate(originalExpression);
assertEquals(result.getRemainingExpression(), TRUE_LITERAL);
Slice value = Slices.wrappedBuffer(BaseEncoding.base16().decode("123456"));
assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(L, Domain.create(ValueSet.ofRanges(Range.greaterThan(VARBINARY, value)), false))));

Expression expression = toPredicate(result.getTupleDomain());
assertEquals(expression, comparison(GREATER_THAN, reference(L), varbinaryLiteral(value)));
}

private static ExtractionResult fromPredicate(Expression originalPredicate)
{
return DomainTranslator.fromPredicate(METADATA, TEST_SESSION, originalPredicate, TYPES);
Expand Down Expand Up @@ -1169,4 +1189,14 @@ private static FunctionCall colorLiteral(long value)
{
return new FunctionCall(QualifiedName.of(getMagicLiteralFunctionSignature(COLOR).getName()), ImmutableList.<Expression>of(longLiteral(value)));
}

private static Expression varbinaryLiteral(Slice value)
{
return LiteralInterpreter.toExpression(value, VARBINARY);
}

private static FunctionCall function(String functionName, Expression... args)
{
return new FunctionCall(QualifiedName.of(functionName), ImmutableList.copyOf(args));
}
}

0 comments on commit 4f44780

Please sign in to comment.