diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/SessionPropertyManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/SessionPropertyManager.java index 0bc434f5c703..5a1b391372d6 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/SessionPropertyManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/SessionPropertyManager.java @@ -30,6 +30,7 @@ import com.facebook.presto.type.MapType; import com.google.common.collect.ComparisonChain; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Ordering; import io.airlift.json.JsonCodec; import io.airlift.json.JsonCodecFactory; @@ -189,7 +190,7 @@ public T decodeProperty(String name, @Nullable String value, Class type) @NotNull public static Object evaluatePropertyValue(Expression expression, Type expectedType, Session session, Metadata metadata) { - Object value = evaluateConstantExpression(expression, expectedType, metadata, session); + Object value = evaluateConstantExpression(expression, expectedType, metadata, session, ImmutableSet.of()); // convert to object value type of SQL type BlockBuilder blockBuilder = expectedType.createBlockBuilder(new BlockBuilderStatus(), 1); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/TestingRowConstructor.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/TestingRowConstructor.java index f92d97a53cdc..50993862164a 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/TestingRowConstructor.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/TestingRowConstructor.java @@ -151,6 +151,20 @@ DOUBLE, new ArrayType(BIGINT), return toStackRepresentation(parameterTypes, arg1, arg2, arg3); } + @ScalarFunction("test_row") + @SqlType("row('col0','col1')>,row('col0','col1')>('col0','col1','col2')") + public static Block testNestedRowWithArray( + @Nullable @SqlType(StandardTypes.DOUBLE) Double arg1, + @Nullable @SqlType("array('col0','col1')>") Block arg2, + @Nullable @SqlType("row('col0','col1')") Block arg3) + { + List parameterTypes = ImmutableList.of( + DOUBLE, + new ArrayType(new RowType(ImmutableList.of(BIGINT, DOUBLE), Optional.of(ImmutableList.of("col0", "col1")))), + new RowType(ImmutableList.of(BIGINT, DOUBLE), Optional.of(ImmutableList.of("col0", "col1")))); + return toStackRepresentation(parameterTypes, arg1, arg2, arg3); + } + @ScalarFunction("test_row") @SqlType("row('col0')") public static Block testRowBigintBigint(@Nullable @SqlType(StandardTypes.TIMESTAMP) Long arg1) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/AggregationAnalyzer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/AggregationAnalyzer.java index e53cfd201e6d..276012af325b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/AggregationAnalyzer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/AggregationAnalyzer.java @@ -23,6 +23,7 @@ import com.facebook.presto.sql.tree.CoalesceExpression; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.CurrentTime; +import com.facebook.presto.sql.tree.DereferenceExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.Extract; import com.facebook.presto.sql.tree.FunctionCall; @@ -48,22 +49,24 @@ import com.facebook.presto.sql.tree.WhenClause; import com.facebook.presto.sql.tree.Window; import com.facebook.presto.sql.tree.WindowFrame; -import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import javax.annotation.Nullable; import java.util.List; import java.util.Optional; +import java.util.Set; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MUST_BE_AGGREGATE_OR_GROUP_BY; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NESTED_AGGREGATION; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NESTED_WINDOW; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; import static com.facebook.presto.util.ImmutableCollectors.toImmutableList; +import static com.facebook.presto.util.Types.checkType; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Predicates.equalTo; -import static com.google.common.base.Predicates.instanceOf; import static java.util.Objects.requireNonNull; /** @@ -76,18 +79,20 @@ public class AggregationAnalyzer private final List expressions; private final Metadata metadata; + private final Set columnReferences; private final TupleDescriptor tupleDescriptor; - public AggregationAnalyzer(List groupByExpressions, Metadata metadata, TupleDescriptor tupleDescriptor) + public AggregationAnalyzer(List groupByExpressions, Metadata metadata, TupleDescriptor tupleDescriptor, Set columnReferences) { requireNonNull(groupByExpressions, "groupByExpressions is null"); requireNonNull(metadata, "metadata is null"); requireNonNull(tupleDescriptor, "tupleDescriptor is null"); + requireNonNull(columnReferences, "columnReferences is null"); this.tupleDescriptor = tupleDescriptor; this.metadata = metadata; - + this.columnReferences = ImmutableSet.copyOf(columnReferences); this.expressions = groupByExpressions.stream() .filter(FieldOrExpression::isExpression) .map(FieldOrExpression::getExpression) @@ -103,18 +108,23 @@ public AggregationAnalyzer(List groupByExpressions, Metadata // For a query like "SELECT * FROM T GROUP BY a", groupByExpressions will contain "a", // and the '*' will be expanded to Field references. Therefore we translate all simple name expressions // in the group by clause to fields they reference so that the expansion from '*' can be matched against them - for (Expression expression : Iterables.filter(expressions, instanceOf(QualifiedNameReference.class))) { - QualifiedName name = ((QualifiedNameReference) expression).getName(); + for (Expression expression : Iterables.filter(expressions, columnReferences::contains)) { + QualifiedName name; + if (expression instanceof QualifiedNameReference) { + name = ((QualifiedNameReference) expression).getName(); + } + else { + name = DereferenceExpression.getQualifiedName(checkType(expression, DereferenceExpression.class, "expression")); + } List fields = tupleDescriptor.resolveFields(name); - Preconditions.checkState(fields.size() <= 1, "Found more than one field for name '%s': %s", name, fields); + checkState(fields.size() <= 1, "Found more than one field for name '%s': %s", name, fields); if (fields.size() == 1) { Field field = Iterables.getOnlyElement(fields); fieldIndexes.add(tupleDescriptor.indexOf(field)); } } - this.fieldIndexes = fieldIndexes.build(); } @@ -334,11 +344,25 @@ public Boolean visitWindowFrame(WindowFrame node, Void context) @Override protected Boolean visitQualifiedNameReference(QualifiedNameReference node, Void context) { - QualifiedName name = node.getName(); + return isField(node.getName()); + } - List fields = tupleDescriptor.resolveFields(name); - Preconditions.checkState(!fields.isEmpty(), "No fields for name '%s'", name); - Preconditions.checkState(fields.size() <= 1, "Found more than one field for name '%s': %s", name, fields); + @Override + protected Boolean visitDereferenceExpression(DereferenceExpression node, Void context) + { + if (columnReferences.contains(node)) { + return isField(DereferenceExpression.getQualifiedName(node)); + } + + // Allow SELECT col1.f1 FROM table1 GROUP BY col1 + return process(node.getBase(), context); + } + + private Boolean isField(QualifiedName qualifiedName) + { + List fields = tupleDescriptor.resolveFields(qualifiedName); + checkState(!fields.isEmpty(), "No fields for name '%s'", qualifiedName); + checkState(fields.size() <= 1, "Found more than one field for name '%s': %s", qualifiedName, fields); Field field = Iterables.getOnlyElement(fields); return fieldIndexes.contains(tupleDescriptor.indexOf(field)); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java index 287e8d26c835..a6fd033800ca 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java @@ -24,8 +24,6 @@ import com.facebook.presto.sql.tree.InPredicate; import com.facebook.presto.sql.tree.Join; import com.facebook.presto.sql.tree.Node; -import com.facebook.presto.sql.tree.QualifiedName; -import com.facebook.presto.sql.tree.QualifiedNameReference; import com.facebook.presto.sql.tree.Query; import com.facebook.presto.sql.tree.QuerySpecification; import com.facebook.presto.sql.tree.Relation; @@ -44,6 +42,7 @@ import java.util.Optional; import java.util.Set; +import static com.google.common.collect.Sets.newIdentityHashSet; import static java.util.Objects.requireNonNull; public class Analysis @@ -55,7 +54,7 @@ public class Analysis private TupleDescriptor outputDescriptor; private final IdentityHashMap outputDescriptors = new IdentityHashMap<>(); - private final IdentityHashMap> resolvedNames = new IdentityHashMap<>(); + private final IdentityHashMap> resolvedNames = new IdentityHashMap<>(); private final IdentityHashMap> aggregates = new IdentityHashMap<>(); private final IdentityHashMap> groupByExpressions = new IdentityHashMap<>(); @@ -71,11 +70,11 @@ public class Analysis private final IdentityHashMap tables = new IdentityHashMap<>(); - private final IdentityHashMap rowFieldReferences = new IdentityHashMap<>(); private final IdentityHashMap types = new IdentityHashMap<>(); private final IdentityHashMap coercions = new IdentityHashMap<>(); private final IdentityHashMap relationCoercions = new IdentityHashMap<>(); private final IdentityHashMap functionSignature = new IdentityHashMap<>(); + private final Set columnReferences = newIdentityHashSet(); private final IdentityHashMap columns = new IdentityHashMap<>(); @@ -122,12 +121,12 @@ public void setCreateTableAsSelectWithData(boolean createTableAsSelectWithData) this.createTableAsSelectWithData = createTableAsSelectWithData; } - public void addResolvedNames(Expression expression, Map mappings) + public void addResolvedNames(Expression expression, Map mappings) { resolvedNames.put(expression, mappings); } - public Map getResolvedNames(Expression expression) + public Map getResolvedNames(Expression expression) { return resolvedNames.get(expression); } @@ -147,11 +146,6 @@ public IdentityHashMap getTypes() return new IdentityHashMap<>(types); } - public boolean isRowFieldReference(QualifiedNameReference qualifiedNameReference) - { - return rowFieldReferences.containsKey(qualifiedNameReference); - } - public Type getType(Expression expression) { Preconditions.checkArgument(types.containsKey(expression), "Expression not analyzed: %s", expression); @@ -309,14 +303,19 @@ public void addFunctionSignatures(IdentityHashMap infos functionSignature.putAll(infos); } - public void addTypes(IdentityHashMap types) + public Set getColumnReferences() { - this.types.putAll(types); + return ImmutableSet.copyOf(columnReferences); } - public void addRowFieldReferences(IdentityHashMap rowFieldReferences) + public void addColumnReferences(Set references) { - this.rowFieldReferences.putAll(rowFieldReferences); + columnReferences.addAll(references); + } + + public void addTypes(IdentityHashMap types) + { + this.types.putAll(types); } public void addCoercion(Expression expression, Type type) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalysis.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalysis.java index e32c4bf3f4c0..6a0de761c086 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalysis.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalysis.java @@ -16,6 +16,7 @@ import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.InPredicate; +import com.google.common.collect.ImmutableSet; import java.util.IdentityHashMap; import java.util.Set; @@ -27,15 +28,18 @@ public class ExpressionAnalysis private final IdentityHashMap expressionTypes; private final IdentityHashMap expressionCoercions; private final Set subqueryInPredicates; + private final Set columnReferences; public ExpressionAnalysis( IdentityHashMap expressionTypes, IdentityHashMap expressionCoercions, - Set subqueryInPredicates) + Set subqueryInPredicates, + Set columnReferences) { this.expressionTypes = requireNonNull(expressionTypes, "expressionTypes is null"); this.expressionCoercions = requireNonNull(expressionCoercions, "expressionCoercions is null"); this.subqueryInPredicates = requireNonNull(subqueryInPredicates, "subqueryInPredicates is null"); + this.columnReferences = ImmutableSet.copyOf(requireNonNull(columnReferences, "columnReferences is null")); } public Type getType(Expression expression) @@ -57,4 +61,9 @@ public Set getSubqueryInPredicates() { return subqueryInPredicates; } + + public Set getColumnReferences() + { + return columnReferences; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java index 3745337f9e9c..239ff2e52c31 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java @@ -38,6 +38,7 @@ import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.CurrentTime; import com.facebook.presto.sql.tree.DefaultTraversalVisitor; +import com.facebook.presto.sql.tree.DereferenceExpression; import com.facebook.presto.sql.tree.DoubleLiteral; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.Extract; @@ -72,6 +73,7 @@ import com.facebook.presto.sql.tree.WindowFrame; import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import javax.annotation.Nullable; @@ -118,7 +120,6 @@ import static com.facebook.presto.util.DateTimeUtils.timeHasTimeZone; import static com.facebook.presto.util.DateTimeUtils.timestampHasTimeZone; import static com.facebook.presto.util.ImmutableCollectors.toImmutableList; -import static com.facebook.presto.util.Types.checkType; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.Sets.newIdentityHashSet; import static java.lang.String.format; @@ -129,11 +130,11 @@ public class ExpressionAnalyzer private final FunctionRegistry functionRegistry; private final TypeManager typeManager; private final Function statementAnalyzerFactory; - private final Map resolvedNames = new HashMap<>(); private final IdentityHashMap resolvedFunctions = new IdentityHashMap<>(); + private final Map resolvedNames = new HashMap<>(); private final IdentityHashMap expressionTypes = new IdentityHashMap<>(); + private final Set columnReferences = newIdentityHashSet(); private final IdentityHashMap expressionCoercions = new IdentityHashMap<>(); - private final IdentityHashMap rowFieldReferences = new IdentityHashMap<>(); private final Set subqueryInPredicates = newIdentityHashSet(); private final Session session; @@ -145,7 +146,7 @@ public ExpressionAnalyzer(FunctionRegistry functionRegistry, TypeManager typeMan this.session = requireNonNull(session, "session is null"); } - public Map getResolvedNames() + public Map getResolvedNames() { return resolvedNames; } @@ -165,14 +166,14 @@ public IdentityHashMap getExpressionCoercions() return expressionCoercions; } - public IdentityHashMap getRowFieldReferences() + public Set getSubqueryInPredicates() { - return rowFieldReferences; + return subqueryInPredicates; } - public Set getSubqueryInPredicates() + public Set getColumnReferences() { - return subqueryInPredicates; + return ImmutableSet.copyOf(columnReferences); } /** @@ -284,61 +285,68 @@ protected Type visitQualifiedNameReference(QualifiedNameReference node, Analysis { List matches = tupleDescriptor.resolveFields(node.getName()); if (matches.isEmpty()) { - // TODO This is kind of hacky, instead we should change the way QualifiedNameReferences are parsed - return tryVisitRowFieldAccessor(node); + throw createMissingAttributeException(node); } + if (matches.size() > 1) { throw new SemanticException(AMBIGUOUS_ATTRIBUTE, node, "Column '%s' is ambiguous", node.getName()); } Field field = Iterables.getOnlyElement(matches); int fieldIndex = tupleDescriptor.indexOf(field); - resolvedNames.put(node.getName(), fieldIndex); + resolvedNames.put(node, fieldIndex); expressionTypes.put(node, field.getType()); - + columnReferences.add(node); return field.getType(); } - private Type tryVisitRowFieldAccessor(QualifiedNameReference node) + @Override + protected Type visitDereferenceExpression(DereferenceExpression node, AnalysisContext context) { - if (node.getName().getParts().size() < 2) { - throw createMissingAttributeException(node); - } - QualifiedName base = new QualifiedName(node.getName().getParts().subList(0, node.getName().getParts().size() - 1)); - List matches = tupleDescriptor.resolveFields(base); - if (matches.isEmpty()) { - throw createMissingAttributeException(node); - } - if (matches.size() > 1) { - throw new SemanticException(AMBIGUOUS_ATTRIBUTE, node, "Column '%s' is ambiguous", node.getName()); - } + QualifiedName qualifiedName = DereferenceExpression.getQualifiedName(node); - Field field = Iterables.getOnlyElement(matches); - if (field.getType() instanceof RowType) { - RowType rowType = checkType(field.getType(), RowType.class, "field.getType()"); - Type rowFieldType = null; - for (RowField rowField : rowType.getFields()) { - if (rowField.getName().equals(Optional.of(node.getName().getSuffix()))) { - rowFieldType = rowField.getType(); - break; - } + // If this Dereference looks like column reference, try match it to column first. + if (qualifiedName != null) { + List matches = tupleDescriptor.resolveFields(qualifiedName); + if (matches.size() > 1) { + throw new SemanticException(AMBIGUOUS_ATTRIBUTE, node, "Column '%s' is ambiguous", node); } - if (rowFieldType == null) { - throw createMissingAttributeException(node); + + if (matches.size() == 1) { + Field field = Iterables.getOnlyElement(matches); + int fieldIndex = tupleDescriptor.indexOf(field); + resolvedNames.put(node, fieldIndex); + expressionTypes.put(node, field.getType()); + columnReferences.add(node); + return field.getType(); } - int fieldIndex = tupleDescriptor.indexOf(field); - resolvedNames.put(node.getName(), fieldIndex); - expressionTypes.put(node, rowFieldType); - rowFieldReferences.put(node, true); + } + + Type baseType = process(node.getBase(), context); + if (!(baseType instanceof RowType)) { + throw new SemanticException(SemanticErrorCode.TYPE_MISMATCH, node.getBase(), "Expression %s is not of type ROW", node.getBase()); + } - return rowFieldType; + RowType rowType = (RowType) baseType; + + Type rowFieldType = null; + for (RowField rowField : rowType.getFields()) { + if (rowField.getName().equals(Optional.of(node.getFieldName()))) { + rowFieldType = rowField.getType(); + break; + } } - throw createMissingAttributeException(node); + if (rowFieldType == null) { + throw createMissingAttributeException(node); + } + + expressionTypes.put(node, rowFieldType); + return rowFieldType; } - private SemanticException createMissingAttributeException(QualifiedNameReference node) + private SemanticException createMissingAttributeException(Expression node) { - return new SemanticException(MISSING_ATTRIBUTE, node, "Column '%s' cannot be resolved", node.getName()); + return new SemanticException(MISSING_ATTRIBUTE, node, "Column '%s' cannot be resolved", node); } @Override @@ -1021,7 +1029,8 @@ private static ExpressionAnalysis analyzeExpressions( return new ExpressionAnalysis( analyzer.getExpressionTypes(), analyzer.getExpressionCoercions(), - analyzer.getSubqueryInPredicates()); + analyzer.getSubqueryInPredicates(), + analyzer.getColumnReferences()); } public static ExpressionAnalysis analyzeExpression( @@ -1045,7 +1054,7 @@ public static ExpressionAnalysis analyzeExpression( analysis.addTypes(expressionTypes); analysis.addCoercions(expressionCoercions); analysis.addFunctionSignatures(resolvedFunctions); - analysis.addRowFieldReferences(analyzer.getRowFieldReferences()); + analysis.addColumnReferences(analyzer.getColumnReferences()); for (Expression subExpression : expressionTypes.keySet()) { analysis.addResolvedNames(subExpression, analyzer.getResolvedNames()); @@ -1053,7 +1062,7 @@ public static ExpressionAnalysis analyzeExpression( Set subqueryInPredicates = analyzer.getSubqueryInPredicates(); - return new ExpressionAnalysis(expressionTypes, expressionCoercions, subqueryInPredicates); + return new ExpressionAnalysis(expressionTypes, expressionCoercions, subqueryInPredicates, analyzer.getColumnReferences()); } public static ExpressionAnalyzer create( @@ -1070,6 +1079,7 @@ public static ExpressionAnalyzer create( node -> new StatementAnalyzer(analysis, metadata, sqlParser, accessControl, session, experimentalSyntaxEnabled, Optional.empty()), session); } + public static ExpressionAnalyzer createConstantAnalyzer(Metadata metadata, Session session) { return createWithoutSubqueries( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/TupleAnalyzer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/TupleAnalyzer.java index 182ab03e2171..5141beaa4c62 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/TupleAnalyzer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/TupleAnalyzer.java @@ -297,7 +297,7 @@ protected TupleDescriptor visitSampledRelation(final SampledRelation relation, A throw new SemanticException(NOT_SUPPORTED, relation, "STRATIFY ON is not yet implemented"); } - if (!DependencyExtractor.extractNames(relation.getSamplePercentage()).isEmpty()) { + if (!DependencyExtractor.extractNames(relation.getSamplePercentage(), analysis.getColumnReferences()).isEmpty()) { throw new SemanticException(NON_NUMERIC_SAMPLE_PERCENTAGE, relation.getSamplePercentage(), "Sample percentage cannot contain column references"); } @@ -366,7 +366,7 @@ protected TupleDescriptor visitQuerySpecification(QuerySpecification node, Analy List orderByExpressions = analyzeOrderBy(node, tupleDescriptor, context, outputExpressions); analyzeHaving(node, tupleDescriptor, context); - analyzeAggregations(node, tupleDescriptor, groupByExpressions, outputExpressions, orderByExpressions, context); + analyzeAggregations(node, tupleDescriptor, groupByExpressions, outputExpressions, orderByExpressions, context, analysis.getColumnReferences()); analyzeWindowFunctions(node, outputExpressions, orderByExpressions); TupleDescriptor descriptor = computeOutputDescriptor(node, tupleDescriptor); @@ -536,8 +536,8 @@ else if (criteria instanceof JoinOn) { } ComparisonExpression comparison = (ComparisonExpression) conjunct; - Set firstDependencies = DependencyExtractor.extractNames(comparison.getLeft()); - Set secondDependencies = DependencyExtractor.extractNames(comparison.getRight()); + Set firstDependencies = DependencyExtractor.extractNames(comparison.getLeft(), analyzer.getColumnReferences()); + Set secondDependencies = DependencyExtractor.extractNames(comparison.getRight(), analyzer.getColumnReferences()); Expression leftExpression; Expression rightExpression; @@ -1004,7 +1004,8 @@ private void analyzeAggregations(QuerySpecification node, List groupByExpressions, List outputExpressions, List orderByExpressions, - AnalysisContext context) + AnalysisContext context, + Set columnReferences) { List aggregates = extractAggregates(node); @@ -1022,11 +1023,11 @@ private void analyzeAggregations(QuerySpecification node, // SELECT f(a + 1) GROUP BY a + 1 // SELECT a + sum(b) GROUP BY a for (FieldOrExpression fieldOrExpression : Iterables.concat(outputExpressions, orderByExpressions)) { - verifyAggregations(node, groupByExpressions, tupleDescriptor, fieldOrExpression); + verifyAggregations(node, groupByExpressions, tupleDescriptor, fieldOrExpression, columnReferences); } if (node.getHaving().isPresent()) { - verifyAggregations(node, groupByExpressions, tupleDescriptor, new FieldOrExpression(node.getHaving().get())); + verifyAggregations(node, groupByExpressions, tupleDescriptor, new FieldOrExpression(node.getHaving().get()), columnReferences); } } } @@ -1054,9 +1055,14 @@ private List extractAggregates(QuerySpecification node) return aggregates; } - private void verifyAggregations(QuerySpecification node, List groupByExpressions, TupleDescriptor tupleDescriptor, FieldOrExpression fieldOrExpression) + private void verifyAggregations( + QuerySpecification node, + List groupByExpressions, + TupleDescriptor tupleDescriptor, + FieldOrExpression fieldOrExpression, + Set columnReferences) { - AggregationAnalyzer analyzer = new AggregationAnalyzer(groupByExpressions, metadata, tupleDescriptor); + AggregationAnalyzer analyzer = new AggregationAnalyzer(groupByExpressions, metadata, tupleDescriptor, columnReferences); if (fieldOrExpression.isExpression()) { analyzer.analyze(fieldOrExpression.getExpression()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/DependencyExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/DependencyExtractor.java index a9eec2cb0f8a..659612a6ffb3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/DependencyExtractor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/DependencyExtractor.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner; import com.facebook.presto.sql.tree.DefaultExpressionTraversalVisitor; +import com.facebook.presto.sql.tree.DereferenceExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.sql.tree.QualifiedNameReference; @@ -23,6 +24,8 @@ import java.util.List; import java.util.Set; +import static java.util.Objects.requireNonNull; + public final class DependencyExtractor { private DependencyExtractor() {} @@ -49,10 +52,10 @@ public static List extractAll(Expression expression) } // to extract qualified name with prefix - public static Set extractNames(Expression expression) + public static Set extractNames(Expression expression, Set columnReferences) { ImmutableSet.Builder builder = ImmutableSet.builder(); - new QualifiedNameBuilderVisitor().process(expression, builder); + new QualifiedNameBuilderVisitor(columnReferences).process(expression, builder); return builder.build(); } @@ -70,6 +73,25 @@ protected Void visitQualifiedNameReference(QualifiedNameReference node, Immutabl private static class QualifiedNameBuilderVisitor extends DefaultExpressionTraversalVisitor> { + private final Set columnReferences; + + private QualifiedNameBuilderVisitor(Set columnReferences) + { + this.columnReferences = requireNonNull(columnReferences, "columnReferences is null"); + } + + @Override + protected Void visitDereferenceExpression(DereferenceExpression node, ImmutableSet.Builder builder) + { + if (columnReferences.contains(node)) { + builder.add(DereferenceExpression.getQualifiedName(node)); + } + else { + process(node.getBase(), builder); + } + return null; + } + @Override protected Void visitQualifiedNameReference(QualifiedNameReference node, ImmutableSet.Builder builder) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java index 98c5af7872cc..ae3fc3fe3b26 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java @@ -43,6 +43,7 @@ import com.facebook.presto.sql.tree.CoalesceExpression; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.DefaultTraversalVisitor; +import com.facebook.presto.sql.tree.DereferenceExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionRewriter; import com.facebook.presto.sql.tree.ExpressionTreeRewriter; @@ -137,8 +138,10 @@ public static ExpressionInterpreter expressionOptimizer(Expression expression, M return new ExpressionInterpreter(expression, metadata, session, expressionTypes, true); } - public static Object evaluateConstantExpression(Expression expression, Type expectedType, Metadata metadata, Session session) + public static Object evaluateConstantExpression(Expression expression, Type expectedType, Metadata metadata, Session session, Set columnReferences) { + requireNonNull(columnReferences, "columnReferences is null"); + ExpressionAnalyzer analyzer = createConstantAnalyzer(metadata, session); analyzer.analyze(expression, new TupleDescriptor(), new AnalysisContext()); @@ -152,14 +155,27 @@ public static Object evaluateConstantExpression(Expression expression, Type expe IdentityHashMap coercions = new IdentityHashMap<>(); coercions.putAll(analyzer.getExpressionCoercions()); coercions.put(expression, expectedType); - return evaluateConstantExpression(expression, coercions, metadata, session); + return evaluateConstantExpression(expression, coercions, metadata, session, columnReferences); } - public static Object evaluateConstantExpression(Expression expression, IdentityHashMap coercions, Metadata metadata, Session session) + public static Object evaluateConstantExpression(Expression expression, IdentityHashMap coercions, Metadata metadata, Session session, Set columnReferences) { + requireNonNull(columnReferences, "columnReferences is null"); + // verify expression is constant expression.accept(new DefaultTraversalVisitor() { + @Override + protected Void visitDereferenceExpression(DereferenceExpression node, Void context) + { + if (columnReferences.contains(node)) { + throw new SemanticException(EXPRESSION_NOT_CONSTANT, expression, "Constant expression cannot contain column references"); + } + + process(node.getBase(), context); + return null; + } + @Override protected Void visitQualifiedNameReference(QualifiedNameReference node, Void context) { @@ -297,6 +313,13 @@ else if (javaType == Slice.class) { throw new UnsupportedOperationException("Inputs or cursor myst be set"); } + @Override + protected Object visitDereferenceExpression(DereferenceExpression node, Object context) + { + // Dereference is never a Symbol + return node; + } + @Override protected Object visitQualifiedNameReference(QualifiedNameReference node, Object context) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java index 87f5a2370a3a..54a38252997f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java @@ -361,6 +361,8 @@ private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node) // 2.a. Rewrite aggregates in terms of pre-projected inputs TranslationMap translations = new TranslationMap(subPlan.getRelationPlan(), analysis); + // Copy the TranslationMap to keep the expressionMappings we have so far + translations.copyMappingsFrom(subPlan.getTranslations()); boolean needPostProjectionCoercion = false; for (FunctionCall aggregate : analysis.getAggregates(node)) { Expression rewritten = subPlan.rewrite(aggregate); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index d6c5a72ab51a..25bfe785801c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -239,8 +239,8 @@ protected RelationPlan visitJoin(Join node, Void context) if (comparison.getType() != EQUAL && node.getType() != INNER) { throw new SemanticException(NOT_SUPPORTED, node, "Non-equi joins only supported for inner join: %s", conjunct); } - Set firstDependencies = DependencyExtractor.extractNames(comparison.getLeft()); - Set secondDependencies = DependencyExtractor.extractNames(comparison.getRight()); + Set firstDependencies = DependencyExtractor.extractNames(comparison.getLeft(), analysis.getColumnReferences()); + Set secondDependencies = DependencyExtractor.extractNames(comparison.getRight(), analysis.getColumnReferences()); Expression leftExpression; Expression rightExpression; @@ -418,12 +418,12 @@ protected RelationPlan visitValues(Values node, Void context) List items = ((Row) row).getItems(); for (int i = 0; i < items.size(); i++) { Expression expression = items.get(i); - Object constantValue = evaluateConstantExpression(expression, analysis.getCoercions(), metadata, session); + Object constantValue = evaluateConstantExpression(expression, analysis.getCoercions(), metadata, session, analysis.getColumnReferences()); values.add(LiteralInterpreter.toExpression(constantValue, descriptor.getFieldByIndex(i).getType())); } } else { - Object constantValue = evaluateConstantExpression(row, analysis.getCoercions(), metadata, session); + Object constantValue = evaluateConstantExpression(row, analysis.getCoercions(), metadata, session, analysis.getColumnReferences()); values.add(LiteralInterpreter.toExpression(constantValue, descriptor.getFieldByIndex(0).getType())); } @@ -451,7 +451,7 @@ protected RelationPlan visitUnnest(Unnest node, Void context) ImmutableMap.Builder> unnestSymbols = ImmutableMap.builder(); Iterator unnestedSymbolsIterator = unnestedSymbols.iterator(); for (Expression expression : node.getExpressions()) { - Object constantValue = evaluateConstantExpression(expression, analysis.getCoercions(), metadata, session); + Object constantValue = evaluateConstantExpression(expression, analysis.getCoercions(), metadata, session, analysis.getColumnReferences()); Type type = analysis.getType(expression); values.add(LiteralInterpreter.toExpression(constantValue, type)); Symbol inputSymbol = symbolAllocator.newSymbol(expression, type); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/TranslationMap.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/TranslationMap.java index 57db91e9617f..a171dbed92a9 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/TranslationMap.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/TranslationMap.java @@ -17,6 +17,7 @@ import com.facebook.presto.sql.analyzer.Analysis; import com.facebook.presto.sql.analyzer.FieldOrExpression; import com.facebook.presto.sql.tree.Cast; +import com.facebook.presto.sql.tree.DereferenceExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionRewriter; import com.facebook.presto.sql.tree.ExpressionTreeRewriter; @@ -31,6 +32,7 @@ import java.util.Map; import static com.facebook.presto.sql.QueryUtil.mangleFieldReference; +import static com.google.common.base.Preconditions.checkState; /** * Keeps track of fields and expressions and their mapping to symbols in the current plan @@ -111,7 +113,7 @@ public Expression rewrite(FieldOrExpression fieldOrExpression) if (fieldOrExpression.isFieldReference()) { int fieldIndex = fieldOrExpression.getFieldIndex(); Symbol symbol = fieldSymbols[fieldIndex]; - Preconditions.checkState(symbol != null, "No mapping for field '%s'", fieldIndex); + checkState(symbol != null, "No mapping for field '%s'", fieldIndex); return new QualifiedNameReference(symbol.toQualifiedName()); } @@ -127,7 +129,7 @@ public void put(Expression expression, Symbol symbol) // also update the field mappings if this expression is a simple field reference if (expression instanceof QualifiedNameReference) { - int fieldIndex = analysis.getResolvedNames(expression).get(((QualifiedNameReference) expression).getName()); + int fieldIndex = analysis.getResolvedNames(expression).get(expression); fieldSymbols[fieldIndex] = symbol; } } @@ -166,7 +168,7 @@ public Symbol get(FieldOrExpression fieldOrExpression) private Expression translateNamesToSymbols(Expression expression) { - final Map resolvedNames = analysis.getResolvedNames(expression); + final Map resolvedNames = analysis.getResolvedNames(expression); Preconditions.checkArgument(resolvedNames != null, "No resolved names for expression %s", expression); return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter() @@ -188,21 +190,38 @@ public Expression rewriteExpression(Expression node, Void context, ExpressionTre @Override public Expression rewriteQualifiedNameReference(QualifiedNameReference node, Void context, ExpressionTreeRewriter treeRewriter) { - QualifiedName name = node.getName(); + return rewriteExpressionWithResolvedName(node); + } - Integer fieldIndex = resolvedNames.get(name); - Preconditions.checkState(fieldIndex != null, "No field mapping for name '%s'", name); + private Expression rewriteExpressionWithResolvedName(Expression node) + { + Integer fieldIndex = resolvedNames.get(node); + checkState(fieldIndex != null, "No field mapping for node '%s'", node); Symbol symbol = rewriteBase.getSymbol(fieldIndex); - Preconditions.checkState(symbol != null, "No symbol mapping for name '%s' (%s)", name, fieldIndex); - + checkState(symbol != null, "No symbol mapping for node '%s' (%s)", node, fieldIndex); Expression rewrittenExpression = new QualifiedNameReference(symbol.toQualifiedName()); - if (analysis.isRowFieldReference(node)) { - QualifiedName mangledName = QualifiedName.of(mangleFieldReference(node.getName().getSuffix())); - rewrittenExpression = new FunctionCall(mangledName, ImmutableList.of(rewrittenExpression)); + // cast expression if coercion is registered + Type coercion = analysis.getCoercion(node); + if (coercion != null) { + rewrittenExpression = new Cast(rewrittenExpression, coercion.getTypeSignature().toString()); + } + return rewrittenExpression; + } + + @Override + public Expression rewriteDereferenceExpression(DereferenceExpression node, Void context, ExpressionTreeRewriter treeRewriter) + { + if (analysis.getColumnReferences().contains(node)) { + return rewriteExpressionWithResolvedName(node); } + // Rewrite all row field reference to function call. + Expression rewrittenBaseExpression = rewrite(node.getBase()); + QualifiedName mangledName = QualifiedName.of(mangleFieldReference(node.getFieldName())); + Expression rewrittenExpression = new FunctionCall(mangledName, ImmutableList.of(rewrittenBaseExpression)); + // cast expression if coercion is registered Type coercion = analysis.getCoercion(node); if (coercion != null) { diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java index 0ef5ded25c0a..df3be1b9020f 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java @@ -55,9 +55,12 @@ import com.facebook.presto.sql.relational.SqlToRowExpressionTranslator; import com.facebook.presto.sql.tree.Cast; import com.facebook.presto.sql.tree.DefaultTraversalVisitor; +import com.facebook.presto.sql.tree.DereferenceExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionRewriter; import com.facebook.presto.sql.tree.ExpressionTreeRewriter; +import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.sql.tree.QualifiedNameReference; import com.facebook.presto.testing.LocalQueryRunner; import com.facebook.presto.testing.MaterializedResult; @@ -92,6 +95,7 @@ import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.QueryUtil.mangleFieldReference; import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.analyzeExpressionsWithSymbols; import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput; import static com.facebook.presto.sql.planner.LocalExecutionPlanner.toTypes; @@ -417,6 +421,27 @@ public Expression rewriteExpression(Expression node, Void context, ExpressionTre return rewrittenExpression; } + + @Override + public Expression rewriteDereferenceExpression(DereferenceExpression node, Void context, ExpressionTreeRewriter treeRewriter) + { + if (analysis.getColumnReferences().contains(node)) { + return rewriteExpression(node, context, treeRewriter); + } + + // Otherwise rewrite to FunctionCall + Expression rewrittenBase = ExpressionTreeRewriter.rewriteWith(this, node.getBase()); + QualifiedName mangledName = QualifiedName.of(mangleFieldReference(node.getFieldName())); + Expression rewrittenExpression = new FunctionCall(mangledName, ImmutableList.of(rewrittenBase)); + + // cast expression if coercion is registered + Type coercion = analysis.getCoercion(node); + if (coercion != null) { + rewrittenExpression = new Cast(rewrittenExpression, coercion.getTypeSignature().toString()); + } + + return rewrittenExpression; + } }, parsedExpression); return canonicalizeExpression(rewrittenExpression); @@ -486,6 +511,7 @@ protected Void visitQualifiedNameReference(QualifiedNameReference node, Void con return null; } }, null); + return hasQualifiedNameReference.get(); } diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestRowOperators.java b/presto-main/src/test/java/com/facebook/presto/type/TestRowOperators.java index b242b7b37760..6e8e9be8caab 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestRowOperators.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestRowOperators.java @@ -81,6 +81,7 @@ public void testFieldAccessor() assertFunction("array[test_row(1, 2)][1].col1", BIGINT, 2); assertFunction("test_row(FALSE, ARRAY [1, 2], MAP(ARRAY[1, 3], ARRAY[2.0, 4.0])).col1", new ArrayType(BIGINT), ImmutableList.of(1L, 2L)); assertFunction("test_row(FALSE, ARRAY [1, 2], MAP(ARRAY[1, 3], ARRAY[2.0, 4.0])).col2", new MapType(BIGINT, DOUBLE), ImmutableMap.of(1L, 2.0, 3L, 4.0)); + assertFunction("test_row(1.0, ARRAY[test_row(31, 4.1), test_row(32, 4.2)], test_row(3, 4.0)).col1[2].col0", BIGINT, 32); } @Test diff --git a/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 b/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 index d78df175f045..ea15a09b3435 100644 --- a/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 +++ b/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 @@ -227,7 +227,6 @@ primaryExpression | STRING #stringLiteral | '(' expression (',' expression)+ ')' #rowConstructor | ROW '(' expression (',' expression)* ')' #rowConstructor - | qualifiedName #columnReference | qualifiedName '(' ASTERISK ')' over? #functionCall | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' over? #functionCall | '(' query ')' #subqueryExpression @@ -237,7 +236,8 @@ primaryExpression | TRY_CAST '(' expression AS type ')' #cast | ARRAY '[' (expression (',' expression)*)? ']' #arrayConstructor | value=primaryExpression '[' index=valueExpression ']' #subscript - | value=primaryExpression '.' fieldName=identifier #fieldReference + | identifier #columnReference + | base=primaryExpression '.' fieldName=identifier #dereference | name=CURRENT_DATE #specialDateTimeFunction | name=CURRENT_TIME ('(' precision=INTEGER_VALUE ')')? #specialDateTimeFunction | name=CURRENT_TIMESTAMP ('(' precision=INTEGER_VALUE ')')? #specialDateTimeFunction diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/ExpressionFormatter.java b/presto-parser/src/main/java/com/facebook/presto/sql/ExpressionFormatter.java index dfe324ff7cf7..3f6d5122375f 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/ExpressionFormatter.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/ExpressionFormatter.java @@ -24,6 +24,7 @@ import com.facebook.presto.sql.tree.CoalesceExpression; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.CurrentTime; +import com.facebook.presto.sql.tree.DereferenceExpression; import com.facebook.presto.sql.tree.DoubleLiteral; import com.facebook.presto.sql.tree.ExistsPredicate; import com.facebook.presto.sql.tree.Expression; @@ -227,6 +228,13 @@ protected String visitQualifiedNameReference(QualifiedNameReference node, Boolea return formatQualifiedName(node.getName()); } + @Override + protected String visitDereferenceExpression(DereferenceExpression node, Boolean unmangleNames) + { + String baseString = process(node.getBase(), unmangleNames); + return baseString + "." + formatIdentifier(node.getFieldName()); + } + private static String formatQualifiedName(QualifiedName name) { List parts = new ArrayList<>(); diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/TreePrinter.java b/presto-parser/src/main/java/com/facebook/presto/sql/TreePrinter.java index 491c233a69d1..e4e4bd5425ef 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/TreePrinter.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/TreePrinter.java @@ -20,6 +20,7 @@ import com.facebook.presto.sql.tree.BooleanLiteral; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.DefaultTraversalVisitor; +import com.facebook.presto.sql.tree.DereferenceExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.InPredicate; @@ -51,10 +52,10 @@ public class TreePrinter { private static final String INDENT = " "; - private final IdentityHashMap resolvedNameReferences; + private final IdentityHashMap resolvedNameReferences; private final PrintStream out; - public TreePrinter(IdentityHashMap resolvedNameReferences, PrintStream out) + public TreePrinter(IdentityHashMap resolvedNameReferences, PrintStream out) { this.resolvedNameReferences = new IdentityHashMap<>(resolvedNameReferences); this.out = out; @@ -251,6 +252,18 @@ protected Void visitQualifiedNameReference(QualifiedNameReference node, Integer return null; } + @Override + protected Void visitDereferenceExpression(DereferenceExpression node, Integer indentLevel) + { + QualifiedName resolved = resolvedNameReferences.get(node); + String resolvedName = ""; + if (resolved != null) { + resolvedName = "=>" + resolved.toString(); + } + print(indentLevel, "DereferenceExpression[" + node + resolvedName + "]"); + return null; + } + @Override protected Void visitFunctionCall(FunctionCall node, Integer indentLevel) { diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java b/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java index c995cd0ac993..2eefb0b109d3 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java @@ -32,6 +32,7 @@ import com.facebook.presto.sql.tree.CreateView; import com.facebook.presto.sql.tree.CurrentTime; import com.facebook.presto.sql.tree.Delete; +import com.facebook.presto.sql.tree.DereferenceExpression; import com.facebook.presto.sql.tree.DoubleLiteral; import com.facebook.presto.sql.tree.DropTable; import com.facebook.presto.sql.tree.DropView; @@ -123,8 +124,6 @@ import java.util.Optional; import java.util.stream.Collectors; -import static com.facebook.presto.sql.QueryUtil.mangleFieldReference; - class AstBuilder extends SqlBaseBaseVisitor { @@ -834,22 +833,21 @@ public Node visitSubscript(SqlBaseParser.SubscriptContext context) } @Override - public Node visitFieldReference(SqlBaseParser.FieldReferenceContext context) + public Node visitSubqueryExpression(SqlBaseParser.SubqueryExpressionContext context) { - // TODO: This should be done during the conversion to RowExpression - return new FunctionCall(new QualifiedName(mangleFieldReference(context.fieldName.getText())), ImmutableList.of((Expression) visit(context.value))); + return new SubqueryExpression((Query) visit(context.query())); } @Override - public Node visitSubqueryExpression(SqlBaseParser.SubqueryExpressionContext context) + public Node visitDereference(SqlBaseParser.DereferenceContext context) { - return new SubqueryExpression((Query) visit(context.query())); + return new DereferenceExpression((Expression) visit(context.base), context.fieldName.getText()); } @Override public Node visitColumnReference(SqlBaseParser.ColumnReferenceContext context) { - return new QualifiedNameReference(getQualifiedName(context.qualifiedName())); + return new QualifiedNameReference(new QualifiedName(context.getText())); } @Override diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/AstVisitor.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/AstVisitor.java index 5c7113208a07..21206fe4c285 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/tree/AstVisitor.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/AstVisitor.java @@ -257,6 +257,11 @@ protected R visitQualifiedNameReference(QualifiedNameReference node, C context) return visitExpression(node, context); } + protected R visitDereferenceExpression(DereferenceExpression node, C context) + { + return visitExpression(node, context); + } + protected R visitNullIfExpression(NullIfExpression node, C context) { return visitExpression(node, context); diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/DefaultTraversalVisitor.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/DefaultTraversalVisitor.java index af98f7250965..48785e395c50 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/tree/DefaultTraversalVisitor.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/DefaultTraversalVisitor.java @@ -165,6 +165,13 @@ protected R visitFunctionCall(FunctionCall node, C context) return null; } + @Override + protected R visitDereferenceExpression(DereferenceExpression node, C context) + { + process(node.getBase(), context); + return null; + } + @Override public R visitWindow(Window node, C context) { diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/DereferenceExpression.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/DereferenceExpression.java new file mode 100644 index 000000000000..67656aa6e89e --- /dev/null +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/DereferenceExpression.java @@ -0,0 +1,100 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.tree; + +import com.google.common.collect.Lists; + +import java.util.List; +import java.util.Objects; + +import static com.google.common.base.Preconditions.checkArgument; + +public class DereferenceExpression + extends Expression +{ + private final Expression base; + private final String fieldName; + + public DereferenceExpression(Expression base, String fieldName) + { + checkArgument(base != null, "base is null"); + checkArgument(fieldName != null, "fieldName is null"); + this.base = base; + this.fieldName = fieldName.toLowerCase(); + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitDereferenceExpression(this, context); + } + + public Expression getBase() + { + return base; + } + + public String getFieldName() + { + return fieldName; + } + + /** + * If this DereferenceExpression looks like a QualifiedName, return QualifiedName. + * Otherwise return null + */ + public static QualifiedName getQualifiedName(DereferenceExpression expression) + { + List parts = tryParseParts(expression.base, expression.fieldName); + return parts == null ? null : new QualifiedName(parts); + } + + private static List tryParseParts(Expression base, String fieldName) + { + if (base instanceof QualifiedNameReference) { + List newList = Lists.newArrayList(((QualifiedNameReference) base).getName().getParts()); + newList.add(fieldName); + return newList; + } + else if (base instanceof DereferenceExpression) { + QualifiedName baseQualifiedName = getQualifiedName((DereferenceExpression) base); + if (baseQualifiedName != null) { + List newList = Lists.newArrayList(baseQualifiedName.getParts()); + newList.add(fieldName); + return newList; + } + } + return null; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DereferenceExpression that = (DereferenceExpression) o; + return Objects.equals(base, that.base) && + Objects.equals(fieldName, that.fieldName); + } + + @Override + public int hashCode() + { + return Objects.hash(base, fieldName); + } +} diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/ExpressionRewriter.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/ExpressionRewriter.java index e58e30b23473..33cec08abc6a 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/tree/ExpressionRewriter.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/ExpressionRewriter.java @@ -140,6 +140,11 @@ public Expression rewriteQualifiedNameReference(QualifiedNameReference node, C c return rewriteExpression(node, context, treeRewriter); } + public Expression rewriteDereferenceExpression(DereferenceExpression node, C context, ExpressionTreeRewriter treeRewriter) + { + return rewriteExpression(node, context, treeRewriter); + } + public Expression rewriteExtract(Extract node, C context, ExpressionTreeRewriter treeRewriter) { return rewriteExpression(node, context, treeRewriter); diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/ExpressionTreeRewriter.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/ExpressionTreeRewriter.java index 9dbc84ea842c..4e32e9dd0c07 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/tree/ExpressionTreeRewriter.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/ExpressionTreeRewriter.java @@ -612,6 +612,24 @@ public Expression visitQualifiedNameReference(QualifiedNameReference node, Conte return node; } + @Override + public Expression visitDereferenceExpression(DereferenceExpression node, Context context) + { + if (!context.isDefaultRewrite()) { + Expression result = rewriter.rewriteDereferenceExpression(node, context.get(), ExpressionTreeRewriter.this); + if (result != null) { + return result; + } + } + + Expression base = rewrite(node.getBase(), context.get()); + if (base != node.getBase()) { + return new DereferenceExpression(base, node.getFieldName()); + } + + return node; + } + @Override protected Expression visitExtract(Extract node, Context context) { diff --git a/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParser.java b/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParser.java index fcbebe21ec7b..9b08c5ff609b 100644 --- a/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParser.java +++ b/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParser.java @@ -27,6 +27,7 @@ import com.facebook.presto.sql.tree.CreateView; import com.facebook.presto.sql.tree.CurrentTime; import com.facebook.presto.sql.tree.Delete; +import com.facebook.presto.sql.tree.DereferenceExpression; import com.facebook.presto.sql.tree.DoubleLiteral; import com.facebook.presto.sql.tree.DropTable; import com.facebook.presto.sql.tree.DropView; @@ -78,6 +79,7 @@ import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; import org.testng.annotations.Test; import java.util.Optional; @@ -691,6 +693,67 @@ public void testShowPartitions() Optional.of("ALL"))); } + @Test + public void testSelectWithRowType() + throws Exception + { + assertStatement("SELECT col1.f1, col2, col3.f1.f2.f3 FROM table1", + new Query( + Optional.empty(), + new QuerySpecification( + selectList( + new DereferenceExpression(new QualifiedNameReference(QualifiedName.of("col1")), "f1"), + new QualifiedNameReference(QualifiedName.of("col2")), + new DereferenceExpression( + new DereferenceExpression(new DereferenceExpression(new QualifiedNameReference(QualifiedName.of("col3")), "f1"), "f2"), "f3")), + Optional.of(new Table(QualifiedName.of("table1"))), + Optional.empty(), + ImmutableList.of(), + Optional.empty(), + ImmutableList.of(), + Optional.empty()), + ImmutableList.of(), + Optional.empty(), + Optional.empty())); + + assertStatement("SELECT col1.f1[0], col2, col3[2].f2.f3, col4[4] FROM table1", + new Query( + Optional.empty(), + new QuerySpecification( + selectList( + new SubscriptExpression(new DereferenceExpression(new QualifiedNameReference(QualifiedName.of("col1")), "f1"), new LongLiteral("0")), + new QualifiedNameReference(QualifiedName.of("col2")), + new DereferenceExpression(new DereferenceExpression(new SubscriptExpression(new QualifiedNameReference(QualifiedName.of("col3")), new LongLiteral("2")), "f2"), "f3"), + new SubscriptExpression(new QualifiedNameReference(QualifiedName.of("col4")), new LongLiteral("4")) + ), + Optional.of(new Table(QualifiedName.of("table1"))), + Optional.empty(), + ImmutableList.of(), + Optional.empty(), + ImmutableList.of(), + Optional.empty()), + ImmutableList.of(), + Optional.empty(), + Optional.empty())); + + assertStatement("SELECT test_row(11, 12).col0", + new Query( + Optional.empty(), + new QuerySpecification( + selectList( + new DereferenceExpression(new FunctionCall(QualifiedName.of("test_row"), Lists.newArrayList(new LongLiteral("11"), new LongLiteral("12"))), "col0") + ), + Optional.empty(), + Optional.empty(), + ImmutableList.of(), + Optional.empty(), + ImmutableList.of(), + Optional.empty()), + ImmutableList.of(), + Optional.empty(), + Optional.empty())); + } + @Test public void testCreateTable() throws Exception diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index bcc816fed7e7..7b04c6a9eef8 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -162,10 +162,145 @@ public void testVarbinary() public void testRowFieldAccessor() throws Exception { + //Dereference only assertQuery("SELECT a.col0 FROM (VALUES ROW (test_row(1, 2))) AS t (a)", "SELECT 1"); assertQuery("SELECT a.col0 FROM (VALUES ROW (test_row(1.0, 2.0))) AS t (a)", "SELECT 1.0"); assertQuery("SELECT a.col0 FROM (VALUES ROW (test_row(TRUE, FALSE))) AS t (a)", "SELECT TRUE"); assertQuery("SELECT a.col1 FROM (VALUES ROW (test_row(1.0, 'kittens'))) AS t (a)", "SELECT 'kittens'"); + assertQuery("SELECT a.col2.col1 FROM (VALUES ROW(test_row(1.0, ARRAY[2], test_row(3, 4.0)))) t(a)", "SELECT 4.0"); + + // Subscript + Dereference + assertQuery("SELECT a.col1[2] FROM (VALUES ROW(test_row(1.0, ARRAY[22, 33, 44, 55], test_row(3, 4.0)))) t(a)", "SELECT 33"); + assertQuery("SELECT a.col1[2].col0, a.col1[2].col1 FROM (VALUES ROW(test_row(1.0, ARRAY[test_row(31, 4.1), test_row(32, 4.2)], test_row(3, 4.0)))) t(a)", "SELECT 32, 4.2"); + + assertQuery("SELECT test_row(11, 12).col0", "SELECT 11"); + } + + @Test + public void testRowFieldAccessorInAggregate() + throws Exception + { + assertQuery("SELECT a.col0, SUM(a.col1[2]), SUM(a.col2.col0), SUM(a.col2.col1) FROM " + + "(VALUES " + + "(ROW(test_row(1.0, ARRAY[2, 13, 4], test_row(11, 4.1)))), " + + "(ROW(test_row(2.0, ARRAY[2, 23, 4], test_row(12, 14.0)))), " + + "(ROW(test_row(1.0, ARRAY[22, 33, 44], test_row(13, 5.0))))) t(a) " + + "GROUP BY a.col0", + "SELECT * FROM VALUES (1.0, 46, 24, 9.1), (2.0, 23, 12, 14.0)"); + + assertQuery("SELECT a.col2.col0, SUM(a.col0), SUM(a.col1[2]), SUM(a.col2.col1) FROM " + + "(VALUES " + + "(ROW(test_row(1.0, ARRAY[2, 13, 4], test_row(11, 4.1)))), " + + "(ROW(test_row(2.0, ARRAY[2, 23, 4], test_row(11, 14.0)))), " + + "(ROW(test_row(7.0, ARRAY[22, 33, 44], test_row(13, 5.0))))) t(a) " + + "GROUP BY a.col2.col0", + "SELECT * FROM VALUES (11, 3.0, 36, 18.1), (13, 7.0, 33, 5.0)"); + + assertQuery("SELECT a.col1[1].col0, SUM(a.col0), SUM(a.col1[1].col1), SUM(a.col1[2].col0), SUM(a.col2.col1) FROM " + + "(VALUES " + + "(ROW(test_row(1.0, ARRAY[test_row(31, 4.5), test_row(12, 4.2)], test_row(3, 4.0)))), " + + "(ROW(test_row(3.1, ARRAY[test_row(41, 3.1), test_row(32, 4.2)], test_row(6, 6.0)))), " + + "(ROW(test_row(2.2, ARRAY[test_row(31, 4.2), test_row(22, 4.2)], test_row(5, 4.0))))) t(a) " + + "GROUP BY a.col1[1].col0", + "SELECT * FROM VALUES (31, 3.2, 8.7, 34, 8.0), (41, 3.1, 3.1, 32, 6.0)"); + + assertQuery("SELECT a.col1[1].col0, SUM(a.col0), SUM(a.col1[1].col1), SUM(a.col1[2].col0), SUM(a.col2.col1) FROM " + + "(VALUES " + + "(ROW(test_row(2.2, ARRAY[test_row(31, 4.2), test_row(22, 4.2)], test_row(5, 4.0)))), " + + "(ROW(test_row(1.0, ARRAY[test_row(31, 4.5), test_row(12, 4.2)], test_row(3, 4.1)))), " + + "(ROW(test_row(3.1, ARRAY[test_row(41, 3.1), test_row(32, 4.2)], test_row(6, 6.0)))), " + + "(ROW(test_row(3.3, ARRAY[test_row(41, 3.1), test_row(32, 4.2)], test_row(6, 6.0)))) " + + ") t(a) " + + "GROUP BY a.col1[1]", + "SELECT * FROM VALUES (31, 2.2, 4.2, 22, 4.0), (31, 1.0, 4.5, 12, 4.1), (41, 6.4, 6.2, 64, 12.0)"); + + assertQuery("SELECT a.col1[2], SUM(a.col0), SUM(a.col1[1]), SUM(a.col2.col1) FROM " + + "(VALUES " + + "(ROW(test_row(1.0, ARRAY[2, 13, 4], test_row(11, 4.1)))), " + + "(ROW(test_row(2.0, ARRAY[2, 13, 4], test_row(12, 14.0)))), " + + "(ROW(test_row(7.0, ARRAY[22, 33, 44], test_row(13, 5.0))))) t(a) " + + "GROUP BY a.col1[2]", + "SELECT * FROM VALUES (13, 3.0, 4, 18.1), (33, 7.0, 22, 5.0)"); + + assertQuery("SELECT a.col2.col0, SUM(a.col2.col1) FROM " + + "(VALUES " + + "(ROW(test_row(2.2, ARRAY[test_row(31, 4.2), test_row(22, 4.2)], test_row(5, 4.0)))), " + + "(ROW(test_row(1.0, ARRAY[test_row(31, 4.5), test_row(12, 4.2)], test_row(3, 4.1)))), " + + "(ROW(test_row(3.1, ARRAY[test_row(41, 3.1), test_row(32, 4.2)], test_row(6, 6.0)))), " + + "(ROW(test_row(3.3, ARRAY[test_row(41, 3.1), test_row(32, 4.2)], test_row(6, 6.0)))) " + + ") t(a) " + + "GROUP BY a.col2", + "SELECT * FROM VALUES (5, 4.0), (3, 4.1), (6, 12.0)"); + + assertQuery("SELECT a.col2.col0, a.col0, SUM(a.col2.col1) FROM " + + "(VALUES " + + "(ROW(test_row(1.0, ARRAY[2, 13, 4], test_row(11, 4.1)))), " + + "(ROW(test_row(2.0, ARRAY[2, 23, 4], test_row(11, 14.0)))), " + + "(ROW(test_row(1.5, ARRAY[2, 13, 4], test_row(11, 4.1)))), " + + "(ROW(test_row(1.5, ARRAY[2, 13, 4], test_row(11, 4.1)))), " + + "(ROW(test_row(7.0, ARRAY[22, 33, 44], test_row(13, 5.0))))) t(a) " + + "GROUP BY 1, 2 ORDER BY 1", + "SELECT * FROM VALUES (11, 1.0, 4.1), (11, 1.5, 8.2), (11, 2.0, 14.0), (13, 7.0, 5.0)"); + } + + @Test + public void testRowFieldAccessorInWindowFunction() + throws Exception + { + assertQuery("SELECT a.col0, " + + "SUM(a.col1[1].col1) OVER(PARTITION BY a.col2.col0), " + + "SUM(a.col2.col1) OVER(PARTITION BY a.col2.col0) FROM " + + "(VALUES " + + "(ROW(test_row(1.0, ARRAY[test_row(31, 14.5), test_row(12, 4.2)], test_row(3, 4.0)))), " + + "(ROW(test_row(2.2, ARRAY[test_row(41, 13.1), test_row(32, 4.2)], test_row(6, 6.0)))), " + + "(ROW(test_row(2.2, ARRAY[test_row(41, 17.1), test_row(45, 4.2)], test_row(7, 16.0)))), " + + "(ROW(test_row(2.2, ARRAY[test_row(41, 13.1), test_row(32, 4.2)], test_row(6, 6.0)))), " + + "(ROW(test_row(3.1, ARRAY[test_row(41, 13.1), test_row(32, 4.2)], test_row(6, 6.0))))) t(a) ", + "SELECT * FROM VALUES (1.0, 14.5, 4.0), (2.2, 39.3, 18.0), (2.2, 39.3, 18.0), (2.2, 17.1, 16.0), (3.1, 39.3, 18.0)"); + + assertQuery("SELECT a.col1[1].col0, " + + "SUM(a.col0) OVER(PARTITION BY a.col1[1].col0), " + + "SUM(a.col1[1].col1) OVER(PARTITION BY a.col1[1].col0), " + + "SUM(a.col2.col1) OVER(PARTITION BY a.col1[1].col0) FROM " + + "(VALUES " + + "(ROW(test_row(1.0, ARRAY[test_row(31, 14.5), test_row(12, 4.2)], test_row(3, 4.0)))), " + + "(ROW(test_row(3.1, ARRAY[test_row(41, 13.1), test_row(32, 4.2)], test_row(6, 6.0)))), " + + "(ROW(test_row(2.2, ARRAY[test_row(31, 14.2), test_row(22, 4.2)], test_row(5, 4.0))))) t(a) ", + "SELECT * FROM VALUES (31, 3.2, 28.7, 8.0), (31, 3.2, 28.7, 8.0), (41, 3.1, 13.1, 6.0)"); + } + + @Test + public void testRowFieldAccessorInJoin() + throws Exception + { + assertQuery("" + + "SELECT t.a.col1, custkey, orderkey FROM " + + "(VALUES " + + "(ROW(test_row(1, 11))), " + + "(ROW(test_row(2, 22))), " + + "(ROW(test_row(3, 33)))) t(a) " + + "INNER JOIN orders " + + "ON t.a.col0 = orders.orderkey", + "SELECT * FROM VALUES (11, 370, 1), (22, 781, 2), (33, 1234, 3)"); + } + + @Test(expectedExceptions = RuntimeException.class, expectedExceptionsMessageRegExp = "'\"a\".\"col0\"' must be an aggregate expression or appear in GROUP BY clause") + public void testMissingRowFieldInGroupBy() + throws Exception + { + assertQuery("SELECT a.col0, count(*) FROM (VALUES ROW(test_row(1, 1))) t(a)"); + } + + @Test + public void testWhereWithRowField() + throws Exception + { + assertQuery("SELECT a.col0 FROM (VALUES ROW (test_row(1, 2))) AS t (a) WHERE a.col0 > 0", "SELECT 1"); + assertQuery("SELECT SUM(a.col0) FROM (VALUES ROW (test_row(1, 2))) AS t (a) WHERE a.col0 <= 0", "SELECT null"); + + assertQuery("SELECT a.col0 FROM (VALUES ROW (test_row(1, 2))) AS t (a) WHERE a.col0 < a.col1", "SELECT 1"); + assertQuery("SELECT SUM(a.col0) FROM (VALUES ROW (test_row(1, 2))) AS t (a) WHERE a.col0 < a.col1", "SELECT 1"); + assertQuery("SELECT SUM(a.col0) FROM (VALUES ROW (test_row(1, 2))) AS t (a) WHERE a.col0 > a.col1", "SELECT null"); } @Test