From 4c5d98ae19360add00aa303076a6e3d32112f607 Mon Sep 17 00:00:00 2001
From: Zhixiong Chen
Date: Wed, 27 Jul 2022 10:09:18 +0800
Subject: [PATCH] [SPARK-39148][SQL] DS V2 aggregate push down can work with OFFSET or LIMIT (#505)

* [SPARK-39139][SQL] DS V2 supports push down DS V2 UDF

### What changes were proposed in this pull request?
Currently, the Spark DS V2 push-down framework supports pushing down SQL to data sources, but it only pushes down built-in functions. Each database has many useful functions that are not supported by Spark. If we can push these functions down to the data source, it will reduce disk I/O and network I/O and improve performance when querying databases.

### Why are the changes needed?
1. Spark can leverage the functions supported by databases.
2. Improve the query performance.

### Does this PR introduce _any_ user-facing change?
'No'. New feature.

### How was this patch tested?
New tests.

Closes #36593 from beliefer/SPARK-39139.

Authored-by: Jiaan Geng
Signed-off-by: Wenchen Fan

* [SPARK-39453][SQL][TESTS][FOLLOWUP] Let `RAND` in filter is more meaningful

### What changes were proposed in this pull request?
https://github.com/apache/spark/pull/36830 made DS V2 support push-down of misc non-aggregate functions (non-ANSI), but the `Rand` in the test case is not meaningful.

### Why are the changes needed?
Make the `Rand` in the filter more meaningful.

### Does this PR introduce _any_ user-facing change?
'No'. Just updates a test case.

### How was this patch tested?
Just updates a test case.

Closes #37033 from beliefer/SPARK-39453_followup.

Authored-by: Jiaan Geng
Signed-off-by: Wenchen Fan

* [SPARK-37527][SQL][FOLLOWUP] Cannot compile COVAR_POP, COVAR_SAMP and CORR in `H2Dialect` if them with `DISTINCT`

### What changes were proposed in this pull request?
https://github.com/apache/spark/pull/35145 compiles COVAR_POP, COVAR_SAMP and CORR in H2Dialect. However, H2 does not support COVAR_POP, COVAR_SAMP and CORR with DISTINCT, so https://github.com/apache/spark/pull/35145 introduced a bug: it compiles COVAR_POP, COVAR_SAMP and CORR even when these aggregate functions are used with DISTINCT.

### Why are the changes needed?
Fix the bug of compiling COVAR_POP, COVAR_SAMP and CORR when these aggregate functions are used with DISTINCT.

### Does this PR introduce _any_ user-facing change?
'Yes'. The bug will be fixed.

### How was this patch tested?
New test cases.

Closes #37090 from beliefer/SPARK-37527_followup2.

Authored-by: Jiaan Geng
Signed-off-by: Wenchen Fan

* [SPARK-39627][SQL] DS V2 pushdown should unify the compile API

### What changes were proposed in this pull request?
Currently, `JdbcDialect` has two APIs, `compileAggregate` and `compileExpression`; we can unify them.

### Why are the changes needed?
Improve ease of use.

### Does this PR introduce _any_ user-facing change?
'No'. `compileAggregate` now just calls `compileExpression`; behavior is unchanged.

### How was this patch tested?
N/A

Closes #37047 from beliefer/SPARK-39627.

Authored-by: Jiaan Geng
Signed-off-by: Wenchen Fan

* [SPARK-39384][SQL] Compile built-in linear regression aggregate functions for JDBC dialect

### What changes were proposed in this pull request?
Recently, the Spark DS V2 pushdown framework started translating many standard linear regression aggregate functions. Currently, only H2Dialect compiles these functions. This PR compiles these standard linear regression aggregate functions for the other built-in JDBC dialects.
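An illustrative sketch (not part of this patch): it follows the pattern the patch applies to `DB2Dialect` further below — a dialect declares which aggregate functions its database can evaluate and compiles V2 expressions through the shared `JDBCSQLBuilder`, so the linear regression aggregates can be rendered as the database's SQL. The dialect name, JDBC URL prefix, supported-function set, and placement in `org.apache.spark.sql.jdbc` are assumptions made for the example.

```scala
package org.apache.spark.sql.jdbc

import java.util.Locale

import scala.util.control.NonFatal

import org.apache.spark.sql.connector.expressions.Expression

// Hypothetical dialect, modeled on the DB2Dialect changes in this patch.
private object ExampleDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean =
    url.toLowerCase(Locale.ROOT).startsWith("jdbc:example")

  // Aggregate functions the (assumed) database can evaluate, so Spark may push them down.
  private val supportedFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
    "REGR_INTERCEPT", "REGR_SLOPE", "REGR_R2", "REGR_SXY")

  override def isSupportedFunction(funcName: String): Boolean =
    supportedFunctions.contains(funcName)

  override def compileExpression(expr: Expression): Option[String] = {
    val builder = new JDBCSQLBuilder()
    try {
      // e.g. a translated REGR_SLOPE(bonus, bonus) is rendered as "REGR_SLOPE(bonus, bonus)"
      Some(builder.build(expr))
    } catch {
      case NonFatal(e) =>
        logWarning("Error occurs while compiling V2 expression", e)
        None
    }
  }
}
```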
### Why are the changes needed?
Make the built-in JDBC dialects support compiling linear regression aggregate push-down.

### Does this PR introduce _any_ user-facing change?
'No'. New feature.

### How was this patch tested?
New test cases.

Closes #37188 from beliefer/SPARK-39384.

Authored-by: Jiaan Geng
Signed-off-by: Sean Owen

* [SPARK-39148][SQL] DS V2 aggregate push down can work with OFFSET or LIMIT

### What changes were proposed in this pull request?
This PR refactors the v2 agg pushdown code. The main change is that we no longer build the `Scan` immediately when pushing the aggregate. We did so before because we wanted to know the data schema with the aggregate pushed, so that we could add casts when rewriting the query plan after pushdown. But the problem is that we build the `Scan` too early and can't push down any more operators, while it's common to see LIMIT/OFFSET after an aggregate.

The idea of the refactor is that we don't need to know the data schema with the aggregate pushed. We just give an expectation (the data types should be the same as those of the Spark aggregate functions), use it to define the output of `ScanBuilderHolder`, and then rewrite the query plan. Later on, when we build the `Scan` and replace `ScanBuilderHolder` with `DataSourceV2ScanRelation`, we check the actual data schema and add a `Project` to cast the types if necessary.

### Why are the changes needed?
Support pushing down LIMIT/OFFSET after an aggregate.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Updated tests.

Closes #37195 from cloud-fan/agg.

Lead-authored-by: Wenchen Fan
Co-authored-by: Wenchen Fan
Signed-off-by: Wenchen Fan
Co-authored-by: Jiaan Geng
Co-authored-by: Wenchen Fan
Co-authored-by: Wenchen Fan
---
 .../sql/jdbc/v2/DB2IntegrationSuite.scala | 4 +
 .../sql/jdbc/v2/OracleIntegrationSuite.scala | 4 +
 .../jdbc/v2/PostgresIntegrationSuite.scala | 8 +
 .../apache/spark/sql/jdbc/v2/V2JDBCTest.scala | 118 ++++-
 .../expressions/GeneralScalarExpression.java | 11 +-
 .../expressions/UserDefinedScalarFunc.java | 70 +++
 .../aggregate/GeneralAggregateFunc.java | 20 +-
 .../aggregate/UserDefinedAggregateFunc.java | 59 +++
 .../util/V2ExpressionSQLBuilder.java | 65 +++
 .../connector/ToStringSQLBuilder.scala | 38 ++
 .../catalyst/util/V2ExpressionBuilder.scala | 10 +-
 .../datasources/DataSourceStrategy.scala | 10 +-
 .../v2/V2ScanRelationPushDown.scala | 418 ++++++++++--------
 .../apache/spark/sql/jdbc/DB2Dialect.scala | 70 +--
 .../apache/spark/sql/jdbc/DerbyDialect.scala | 26 +-
 .../org/apache/spark/sql/jdbc/H2Dialect.scala | 59 +--
 .../apache/spark/sql/jdbc/JdbcDialects.scala | 38 +-
 .../spark/sql/jdbc/MsSqlServerDialect.scala | 52 ++-
 .../apache/spark/sql/jdbc/MySQLDialect.scala | 50 ++-
 .../apache/spark/sql/jdbc/OracleDialect.scala | 60 +--
 .../spark/sql/jdbc/PostgresDialect.scala | 43 +-
 .../spark/sql/jdbc/TeradataDialect.scala | 39 +-
 .../connector/DataSourceV2FunctionSuite.scala | 20 +
 .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 271 ++++++++++--
 24 files changed, 1029 insertions(+), 534 deletions(-)
 create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java
 create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ToStringSQLBuilder.scala

diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala
index 
35711e57d0b72..e66e2143dfa0f 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala @@ -106,4 +106,8 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { testStddevSamp(true) testCovarPop() testCovarSamp() + testRegrIntercept() + testRegrSlope() + testRegrR2() + testRegrSXY() } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala index b38f2675243e6..6d03b73766975 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala @@ -107,4 +107,8 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes testCovarPop() testCovarSamp() testCorr() + testRegrIntercept() + testRegrSlope() + testRegrR2() + testRegrSXY() } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index d76e13c1cd421..c0822ab87d00c 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -104,4 +104,12 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT testCovarSamp(true) testCorr() testCorr(true) + testRegrIntercept() + testRegrIntercept(true) + testRegrSlope() + testRegrSlope(true) + testRegrR2() + testRegrR2(true) + testRegrSXY() + testRegrSXY(true) } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index 7cab8cd77df66..d11d35f3ef8b7 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -386,9 +386,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu protected def caseConvert(tableName: String): String = tableName + private def withOrWithout(isDistinct: Boolean): String = if (isDistinct) "with" else "without" + protected def testVarPop(isDistinct: Boolean = false): Unit = { val distinct = if (isDistinct) "DISTINCT " else "" - test(s"scan with aggregate push-down: VAR_POP with distinct: $isDistinct") { + test(s"scan with aggregate push-down: VAR_POP ${withOrWithout(isDistinct)} DISTINCT") { val df = sql(s"SELECT VAR_POP(${distinct}bonus) FROM $catalogAndNamespace." 
+ s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") checkFilterPushed(df) @@ -396,15 +398,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu checkAggregatePushed(df, "VAR_POP") val row = df.collect() assert(row.length === 3) - assert(row(0).getDouble(0) === 10000d) - assert(row(1).getDouble(0) === 2500d) - assert(row(2).getDouble(0) === 0d) + assert(row(0).getDouble(0) === 10000.0) + assert(row(1).getDouble(0) === 2500.0) + assert(row(2).getDouble(0) === 0.0) } } protected def testVarSamp(isDistinct: Boolean = false): Unit = { val distinct = if (isDistinct) "DISTINCT " else "" - test(s"scan with aggregate push-down: VAR_SAMP with distinct: $isDistinct") { + test(s"scan with aggregate push-down: VAR_SAMP ${withOrWithout(isDistinct)} DISTINCT") { val df = sql( s"SELECT VAR_SAMP(${distinct}bonus) FROM $catalogAndNamespace." + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") @@ -413,15 +415,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu checkAggregatePushed(df, "VAR_SAMP") val row = df.collect() assert(row.length === 3) - assert(row(0).getDouble(0) === 20000d) - assert(row(1).getDouble(0) === 5000d) + assert(row(0).getDouble(0) === 20000.0) + assert(row(1).getDouble(0) === 5000.0) assert(row(2).isNullAt(0)) } } protected def testStddevPop(isDistinct: Boolean = false): Unit = { val distinct = if (isDistinct) "DISTINCT " else "" - test(s"scan with aggregate push-down: STDDEV_POP with distinct: $isDistinct") { + test(s"scan with aggregate push-down: STDDEV_POP ${withOrWithout(isDistinct)} DISTINCT") { val df = sql( s"SELECT STDDEV_POP(${distinct}bonus) FROM $catalogAndNamespace." + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") @@ -430,15 +432,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu checkAggregatePushed(df, "STDDEV_POP") val row = df.collect() assert(row.length === 3) - assert(row(0).getDouble(0) === 100d) - assert(row(1).getDouble(0) === 50d) - assert(row(2).getDouble(0) === 0d) + assert(row(0).getDouble(0) === 100.0) + assert(row(1).getDouble(0) === 50.0) + assert(row(2).getDouble(0) === 0.0) } } protected def testStddevSamp(isDistinct: Boolean = false): Unit = { val distinct = if (isDistinct) "DISTINCT " else "" - test(s"scan with aggregate push-down: STDDEV_SAMP with distinct: $isDistinct") { + test(s"scan with aggregate push-down: STDDEV_SAMP ${withOrWithout(isDistinct)} DISTINCT") { val df = sql( s"SELECT STDDEV_SAMP(${distinct}bonus) FROM $catalogAndNamespace." + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") @@ -447,15 +449,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu checkAggregatePushed(df, "STDDEV_SAMP") val row = df.collect() assert(row.length === 3) - assert(row(0).getDouble(0) === 141.4213562373095d) - assert(row(1).getDouble(0) === 70.71067811865476d) + assert(row(0).getDouble(0) === 141.4213562373095) + assert(row(1).getDouble(0) === 70.71067811865476) assert(row(2).isNullAt(0)) } } protected def testCovarPop(isDistinct: Boolean = false): Unit = { val distinct = if (isDistinct) "DISTINCT " else "" - test(s"scan with aggregate push-down: COVAR_POP with distinct: $isDistinct") { + test(s"scan with aggregate push-down: COVAR_POP ${withOrWithout(isDistinct)} DISTINCT") { val df = sql( s"SELECT COVAR_POP(${distinct}bonus, bonus) FROM $catalogAndNamespace." 
+ s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") @@ -464,15 +466,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu checkAggregatePushed(df, "COVAR_POP") val row = df.collect() assert(row.length === 3) - assert(row(0).getDouble(0) === 10000d) - assert(row(1).getDouble(0) === 2500d) - assert(row(2).getDouble(0) === 0d) + assert(row(0).getDouble(0) === 10000.0) + assert(row(1).getDouble(0) === 2500.0) + assert(row(2).getDouble(0) === 0.0) } } protected def testCovarSamp(isDistinct: Boolean = false): Unit = { val distinct = if (isDistinct) "DISTINCT " else "" - test(s"scan with aggregate push-down: COVAR_SAMP with distinct: $isDistinct") { + test(s"scan with aggregate push-down: COVAR_SAMP ${withOrWithout(isDistinct)} DISTINCT") { val df = sql( s"SELECT COVAR_SAMP(${distinct}bonus, bonus) FROM $catalogAndNamespace." + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") @@ -481,15 +483,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu checkAggregatePushed(df, "COVAR_SAMP") val row = df.collect() assert(row.length === 3) - assert(row(0).getDouble(0) === 20000d) - assert(row(1).getDouble(0) === 5000d) + assert(row(0).getDouble(0) === 20000.0) + assert(row(1).getDouble(0) === 5000.0) assert(row(2).isNullAt(0)) } } protected def testCorr(isDistinct: Boolean = false): Unit = { val distinct = if (isDistinct) "DISTINCT " else "" - test(s"scan with aggregate push-down: CORR with distinct: $isDistinct") { + test(s"scan with aggregate push-down: CORR ${withOrWithout(isDistinct)} DISTINCT") { val df = sql( s"SELECT CORR(${distinct}bonus, bonus) FROM $catalogAndNamespace." + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") @@ -498,9 +500,77 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu checkAggregatePushed(df, "CORR") val row = df.collect() assert(row.length === 3) - assert(row(0).getDouble(0) === 1d) - assert(row(1).getDouble(0) === 1d) + assert(row(0).getDouble(0) === 1.0) + assert(row(1).getDouble(0) === 1.0) + assert(row(2).isNullAt(0)) + } + } + + protected def testRegrIntercept(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: REGR_INTERCEPT ${withOrWithout(isDistinct)} DISTINCT") { + val df = sql( + s"SELECT REGR_INTERCEPT(${distinct}bonus, bonus) FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "REGR_INTERCEPT") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 0.0) + assert(row(1).getDouble(0) === 0.0) + assert(row(2).isNullAt(0)) + } + } + + protected def testRegrSlope(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: REGR_SLOPE ${withOrWithout(isDistinct)} DISTINCT") { + val df = sql( + s"SELECT REGR_SLOPE(${distinct}bonus, bonus) FROM $catalogAndNamespace." 
+ + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "REGR_SLOPE") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 1.0) + assert(row(1).getDouble(0) === 1.0) + assert(row(2).isNullAt(0)) + } + } + + protected def testRegrR2(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: REGR_R2 ${withOrWithout(isDistinct)} DISTINCT") { + val df = sql( + s"SELECT REGR_R2(${distinct}bonus, bonus) FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "REGR_R2") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 1.0) + assert(row(1).getDouble(0) === 1.0) assert(row(2).isNullAt(0)) } } + + protected def testRegrSXY(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: REGR_SXY ${withOrWithout(isDistinct)} DISTINCT") { + val df = sql( + s"SELECT REGR_SXY(${distinct}bonus, bonus) FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "REGR_SXY") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 20000.0) + assert(row(1).getDouble(0) === 5000.0) + assert(row(2).getDouble(0) === 0.0) + } + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java index 17e24fa7ad8da..6dfaad0d26eb4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java @@ -23,7 +23,7 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.expressions.filter.Predicate; -import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder; +import org.apache.spark.sql.internal.connector.ToStringSQLBuilder; /** * The general representation of SQL scalar expressions, which contains the upper-cased @@ -398,12 +398,7 @@ public int hashCode() { @Override public String toString() { - V2ExpressionSQLBuilder builder = new V2ExpressionSQLBuilder(); - try { - return builder.build(this); - } catch (Throwable e) { - return name + "(" + - Arrays.stream(children).map(child -> child.toString()).reduce((a,b) -> a + "," + b) + ")"; - } + ToStringSQLBuilder builder = new ToStringSQLBuilder(); + return builder.build(this); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java new file mode 100644 index 0000000000000..8e4155f81b87b --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.Objects; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.internal.connector.ToStringSQLBuilder; + +/** + * The general representation of user defined scalar function, which contains the upper-cased + * function name, canonical function name and all the children expressions. + * + * @since 3.4.0 + */ +@Evolving +public class UserDefinedScalarFunc implements Expression, Serializable { + private String name; + private String canonicalName; + private Expression[] children; + + public UserDefinedScalarFunc(String name, String canonicalName, Expression[] children) { + this.name = name; + this.canonicalName = canonicalName; + this.children = children; + } + + public String name() { return name; } + public String canonicalName() { return canonicalName; } + + @Override + public Expression[] children() { return children; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + UserDefinedScalarFunc that = (UserDefinedScalarFunc) o; + return Objects.equals(name, that.name) && Objects.equals(canonicalName, that.canonicalName) && + Arrays.equals(children, that.children); + } + + @Override + public int hashCode() { + return Objects.hash(name, canonicalName, children); + } + + @Override + public String toString() { + ToStringSQLBuilder builder = new ToStringSQLBuilder(); + return builder.build(this); + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java index 7016644543447..81838074fb136 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java @@ -17,11 +17,9 @@ package org.apache.spark.sql.connector.expressions.aggregate; -import java.util.Arrays; -import java.util.stream.Collectors; - import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.internal.connector.ToStringSQLBuilder; /** * The general implementation of {@link AggregateFunc}, which contains the upper-cased function @@ -47,27 +45,21 @@ public final class GeneralAggregateFunc implements AggregateFunc { private final boolean isDistinct; private final Expression[] children; - public String name() { return name; } - public boolean isDistinct() { return isDistinct; } - public GeneralAggregateFunc(String name, boolean isDistinct, Expression[] children) { this.name = name; this.isDistinct = isDistinct; this.children = children; } + public String name() { return name; } + public boolean isDistinct() { 
return isDistinct; } + @Override public Expression[] children() { return children; } @Override public String toString() { - String inputsString = Arrays.stream(children) - .map(Expression::describe) - .collect(Collectors.joining(", ")); - if (isDistinct) { - return name + "(DISTINCT " + inputsString + ")"; - } else { - return name + "(" + inputsString + ")"; - } + ToStringSQLBuilder builder = new ToStringSQLBuilder(); + return builder.build(this); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java new file mode 100644 index 0000000000000..9a89e7a89c9f9 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.aggregate; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.internal.connector.ToStringSQLBuilder; + +/** + * The general representation of user defined aggregate function, which implements + * {@link AggregateFunc}, contains the upper-cased function name, the canonical function name, + * the `isDistinct` flag and all the inputs. Note that Spark cannot push down aggregate with + * this function partially to the source, but can only push down the entire aggregate. 
+ * + * @since 3.4.0 + */ +@Evolving +public class UserDefinedAggregateFunc implements AggregateFunc { + private final String name; + private String canonicalName; + private final boolean isDistinct; + private final Expression[] children; + + public UserDefinedAggregateFunc( + String name, String canonicalName, boolean isDistinct, Expression[] children) { + this.name = name; + this.canonicalName = canonicalName; + this.isDistinct = isDistinct; + this.children = children; + } + + public String name() { return name; } + public String canonicalName() { return canonicalName; } + public boolean isDistinct() { return isDistinct; } + + @Override + public Expression[] children() { return children; } + + @Override + public String toString() { + ToStringSQLBuilder builder = new ToStringSQLBuilder(); + return builder.build(this); + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java index 528d7ea4cdb1d..60708ede19c8f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -27,6 +27,15 @@ import org.apache.spark.sql.connector.expressions.NamedReference; import org.apache.spark.sql.connector.expressions.GeneralScalarExpression; import org.apache.spark.sql.connector.expressions.Literal; +import org.apache.spark.sql.connector.expressions.UserDefinedScalarFunc; +import org.apache.spark.sql.connector.expressions.aggregate.Avg; +import org.apache.spark.sql.connector.expressions.aggregate.Max; +import org.apache.spark.sql.connector.expressions.aggregate.Min; +import org.apache.spark.sql.connector.expressions.aggregate.Count; +import org.apache.spark.sql.connector.expressions.aggregate.CountStar; +import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc; +import org.apache.spark.sql.connector.expressions.aggregate.Sum; +import org.apache.spark.sql.connector.expressions.aggregate.UserDefinedAggregateFunc; import org.apache.spark.sql.types.DataType; /** @@ -161,6 +170,40 @@ public String build(Expression expr) { default: return visitUnexpectedExpr(expr); } + } else if (expr instanceof Min) { + Min min = (Min) expr; + return visitAggregateFunction("MIN", false, + Arrays.stream(min.children()).map(c -> build(c)).toArray(String[]::new)); + } else if (expr instanceof Max) { + Max max = (Max) expr; + return visitAggregateFunction("MAX", false, + Arrays.stream(max.children()).map(c -> build(c)).toArray(String[]::new)); + } else if (expr instanceof Count) { + Count count = (Count) expr; + return visitAggregateFunction("COUNT", count.isDistinct(), + Arrays.stream(count.children()).map(c -> build(c)).toArray(String[]::new)); + } else if (expr instanceof Sum) { + Sum sum = (Sum) expr; + return visitAggregateFunction("SUM", sum.isDistinct(), + Arrays.stream(sum.children()).map(c -> build(c)).toArray(String[]::new)); + } else if (expr instanceof CountStar) { + return visitAggregateFunction("COUNT", false, new String[]{"*"}); + } else if (expr instanceof Avg) { + Avg avg = (Avg) expr; + return visitAggregateFunction("AVG", avg.isDistinct(), + Arrays.stream(avg.children()).map(c -> build(c)).toArray(String[]::new)); + } else if (expr instanceof GeneralAggregateFunc) { + GeneralAggregateFunc f = (GeneralAggregateFunc) expr; + return visitAggregateFunction(f.name(), f.isDistinct(), + 
Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new)); + } else if (expr instanceof UserDefinedScalarFunc) { + UserDefinedScalarFunc f = (UserDefinedScalarFunc) expr; + return visitUserDefinedScalarFunction(f.name(), f.canonicalName(), + Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new)); + } else if (expr instanceof UserDefinedAggregateFunc) { + UserDefinedAggregateFunc f = (UserDefinedAggregateFunc) expr; + return visitUserDefinedAggregateFunction(f.name(), f.canonicalName(), f.isDistinct(), + Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new)); } else { return visitUnexpectedExpr(expr); } @@ -273,6 +316,28 @@ protected String visitSQLFunction(String funcName, String[] inputs) { return funcName + "(" + Arrays.stream(inputs).collect(Collectors.joining(", ")) + ")"; } + protected String visitAggregateFunction( + String funcName, boolean isDistinct, String[] inputs) { + if (isDistinct) { + return funcName + + "(DISTINCT " + Arrays.stream(inputs).collect(Collectors.joining(", ")) + ")"; + } else { + return funcName + "(" + Arrays.stream(inputs).collect(Collectors.joining(", ")) + ")"; + } + } + + protected String visitUserDefinedScalarFunction( + String funcName, String canonicalName, String[] inputs) { + throw new UnsupportedOperationException( + this.getClass().getSimpleName() + " does not support user defined function: " + funcName); + } + + protected String visitUserDefinedAggregateFunction( + String funcName, String canonicalName, boolean isDistinct, String[] inputs) { + throw new UnsupportedOperationException(this.getClass().getSimpleName() + + " does not support user defined aggregate function: " + funcName); + } + protected String visitUnexpectedExpr(Expression expr) throws IllegalArgumentException { throw new IllegalArgumentException("Unexpected V2 expression: " + expr); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ToStringSQLBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ToStringSQLBuilder.scala new file mode 100644 index 0000000000000..889fdd4ebf291 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ToStringSQLBuilder.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal.connector + +import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder + +/** + * The builder to generate `toString` information of V2 expressions. 
+ */ +class ToStringSQLBuilder extends V2ExpressionSQLBuilder { + override protected def visitUserDefinedScalarFunction( + funcName: String, canonicalName: String, inputs: Array[String]) = + s"""$funcName(${inputs.mkString(", ")})""" + + override protected def visitUserDefinedAggregateFunction( + funcName: String, + canonicalName: String, + isDistinct: Boolean, + inputs: Array[String]): String = { + val distinct = if (isDistinct) "DISTINCT " else "" + s"""$funcName($distinct${inputs.mkString(", ")})""" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 041e7f369c513..8bb65a8804471 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.util import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue} +import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue, UserDefinedScalarFunc} import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate} import org.apache.spark.sql.types.{BooleanType, IntegerType} @@ -398,6 +398,14 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { case YearOfWeek(child) => generateExpression(child).map(v => new V2Extract("YEAR_OF_WEEK", v)) // TODO supports other expressions + case ApplyFunctionExpression(function, children) => + val childrenExpressions = children.flatMap(generateExpression(_)) + if (childrenExpressions.length == children.length) { + Some(new UserDefinedScalarFunc( + function.name(), function.canonicalName(), childrenExpressions.toArray[V2Expression])) + } else { + None + } case _ => None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 8d8e2c26e279e..5709e2e1484df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue} -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ @@ -738,6 +738,14 @@ object DataSourceStrategy 
PushableColumnWithoutNestedColumn(right), _) => Some(new GeneralAggregateFunc("CORR", agg.isDistinct, Array(FieldReference(left), FieldReference(right)))) + case aggregate.V2Aggregator(aggrFunc, children, _, _) => + val translatedExprs = children.flatMap(PushableExpression.unapply(_)) + if (translatedExprs.length == children.length) { + Some(new UserDefinedAggregateFunc(aggrFunc.name(), + aggrFunc.canonicalName(), agg.isDistinct, translatedExprs.toArray[V2Expression])) + } else { + None + } case _ => None } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index f69cf937b099b..e90f59f310fcb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -26,12 +26,12 @@ import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LimitAndOffset, LocalLimit, LogicalPlan, Offset, OffsetAndLimit, Project, Sample, Sort} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder} -import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, GeneralAggregateFunc, Sum} +import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, CountStar, Max, Min, Sum} import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources -import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StructType} +import org.apache.spark.sql.types.{DataType, DecimalType, IntegerType, StructType} import org.apache.spark.sql.util.SchemaUtils._ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper with AliasHelper { @@ -44,6 +44,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit pushDownFilters, pushDownAggregates, pushDownLimitAndOffset, + buildScanWithPushedAggregate, pruneColumns) pushdownRules.foldLeft(plan) { (newPlan, pushDownRule) => @@ -92,189 +93,201 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit def pushDownAggregates(plan: LogicalPlan): LogicalPlan = plan.transform { // update the scan builder with agg pushdown and return a new plan with agg pushed - case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) => - child match { - case ScanOperation(project, filters, sHolder: ScanBuilderHolder) - if filters.isEmpty && CollapseProject.canCollapseExpressions( - resultExpressions, project, alwaysInline = true) => - sHolder.builder match { - case r: SupportsPushDownAggregates => - val aliasMap = getAliasMap(project) - val actualResultExprs = resultExpressions.map(replaceAliasButKeepName(_, aliasMap)) - val actualGroupExprs = groupingExpressions.map(replaceAlias(_, aliasMap)) - - val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] - val aggregates = collectAggregates(actualResultExprs, aggExprToOutputOrdinal) - val normalizedAggregates = DataSourceStrategy.normalizeExprs( - aggregates, 
sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]] - val normalizedGroupingExpressions = DataSourceStrategy.normalizeExprs( - actualGroupExprs, sHolder.relation.output) - val translatedAggregates = DataSourceStrategy.translateAggregation( - normalizedAggregates, normalizedGroupingExpressions) - val (finalResultExpressions, finalAggregates, finalTranslatedAggregates) = { - if (translatedAggregates.isEmpty || - r.supportCompletePushDown(translatedAggregates.get) || - translatedAggregates.get.aggregateExpressions().forall(!_.isInstanceOf[Avg])) { - (actualResultExprs, aggregates, translatedAggregates) - } else { - // scalastyle:off - // The data source doesn't support the complete push-down of this aggregation. - // Here we translate `AVG` to `SUM / COUNT`, so that it's more likely to be - // pushed, completely or partially. - // e.g. TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT avg(c1) FROM t GROUP BY c2; - // The original logical plan is - // Aggregate [c2#10],[avg(c1#9) AS avg(c1)#19] - // +- ScanOperation[...] - // - // After convert avg(c1#9) to sum(c1#9)/count(c1#9) - // we have the following - // Aggregate [c2#10],[sum(c1#9)/count(c1#9) AS avg(c1)#19] - // +- ScanOperation[...] - // scalastyle:on - val newResultExpressions = actualResultExprs.map { expr => - expr.transform { - case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) => - val sum = aggregate.Sum(avg.child).toAggregateExpression(isDistinct) - val count = aggregate.Count(avg.child).toAggregateExpression(isDistinct) - avg.evaluateExpression transform { - case a: Attribute if a.semanticEquals(avg.sum) => - addCastIfNeeded(sum, avg.sum.dataType) - case a: Attribute if a.semanticEquals(avg.count) => - addCastIfNeeded(count, avg.count.dataType) - } - } - }.asInstanceOf[Seq[NamedExpression]] - // Because aggregate expressions changed, translate them again. - aggExprToOutputOrdinal.clear() - val newAggregates = - collectAggregates(newResultExpressions, aggExprToOutputOrdinal) - val newNormalizedAggregates = DataSourceStrategy.normalizeExprs( - newAggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]] - (newResultExpressions, newAggregates, DataSourceStrategy.translateAggregation( - newNormalizedAggregates, normalizedGroupingExpressions)) + case agg: Aggregate => rewriteAggregate(agg) + } + + private def rewriteAggregate(agg: Aggregate): LogicalPlan = agg.child match { + case ScanOperation(project, Nil, holder @ ScanBuilderHolder(_, _, + r: SupportsPushDownAggregates)) if CollapseProject.canCollapseExpressions( + agg.aggregateExpressions, project, alwaysInline = true) => + val aliasMap = getAliasMap(project) + val actualResultExprs = agg.aggregateExpressions.map(replaceAliasButKeepName(_, aliasMap)) + val actualGroupExprs = agg.groupingExpressions.map(replaceAlias(_, aliasMap)) + + val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] + val aggregates = collectAggregates(actualResultExprs, aggExprToOutputOrdinal) + val normalizedAggExprs = DataSourceStrategy.normalizeExprs( + aggregates, holder.relation.output).asInstanceOf[Seq[AggregateExpression]] + val normalizedGroupingExpr = DataSourceStrategy.normalizeExprs( + actualGroupExprs, holder.relation.output) + val translatedAggOpt = DataSourceStrategy.translateAggregation( + normalizedAggExprs, normalizedGroupingExpr) + if (translatedAggOpt.isEmpty) { + // Cannot translate the catalyst aggregate, return the query plan unchanged. 
+ return agg + } + + val (finalResultExprs, finalAggExprs, translatedAgg, canCompletePushDown) = { + if (r.supportCompletePushDown(translatedAggOpt.get)) { + (actualResultExprs, normalizedAggExprs, translatedAggOpt.get, true) + } else if (!translatedAggOpt.get.aggregateExpressions().exists(_.isInstanceOf[Avg])) { + (actualResultExprs, normalizedAggExprs, translatedAggOpt.get, false) + } else { + // scalastyle:off + // The data source doesn't support the complete push-down of this aggregation. + // Here we translate `AVG` to `SUM / COUNT`, so that it's more likely to be + // pushed, completely or partially. + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT avg(c1) FROM t GROUP BY c2; + // The original logical plan is + // Aggregate [c2#10],[avg(c1#9) AS avg(c1)#19] + // +- ScanOperation[...] + // + // After convert avg(c1#9) to sum(c1#9)/count(c1#9) + // we have the following + // Aggregate [c2#10],[sum(c1#9)/count(c1#9) AS avg(c1)#19] + // +- ScanOperation[...] + // scalastyle:on + val newResultExpressions = actualResultExprs.map { expr => + expr.transform { + case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) => + val sum = aggregate.Sum(avg.child).toAggregateExpression(isDistinct) + val count = aggregate.Count(avg.child).toAggregateExpression(isDistinct) + avg.evaluateExpression transform { + case a: Attribute if a.semanticEquals(avg.sum) => + addCastIfNeeded(sum, avg.sum.dataType) + case a: Attribute if a.semanticEquals(avg.count) => + addCastIfNeeded(count, avg.count.dataType) } - } + } + }.asInstanceOf[Seq[NamedExpression]] + // Because aggregate expressions changed, translate them again. + aggExprToOutputOrdinal.clear() + val newAggregates = + collectAggregates(newResultExpressions, aggExprToOutputOrdinal) + val newNormalizedAggExprs = DataSourceStrategy.normalizeExprs( + newAggregates, holder.relation.output).asInstanceOf[Seq[AggregateExpression]] + val newTranslatedAggOpt = DataSourceStrategy.translateAggregation( + newNormalizedAggExprs, normalizedGroupingExpr) + if (newTranslatedAggOpt.isEmpty) { + // Ideally we should never reach here. But if we end up with not able to translate + // new aggregate with AVG replaced by SUM/COUNT, revert to the original one. + (actualResultExprs, normalizedAggExprs, translatedAggOpt.get, false) + } else { + (newResultExpressions, newNormalizedAggExprs, newTranslatedAggOpt.get, + r.supportCompletePushDown(newTranslatedAggOpt.get)) + } + } + } - if (finalTranslatedAggregates.isEmpty) { - aggNode // return original plan node - } else if (!r.supportCompletePushDown(finalTranslatedAggregates.get) && - !supportPartialAggPushDown(finalTranslatedAggregates.get)) { - aggNode // return original plan node - } else { - val pushedAggregates = finalTranslatedAggregates.filter(r.pushAggregation) - if (pushedAggregates.isEmpty) { - aggNode // return original plan node - } else { - // No need to do column pruning because only the aggregate columns are used as - // DataSourceV2ScanRelation output columns. All the other columns are not - // included in the output. - val scan = sHolder.builder.build() - - // scalastyle:off - // use the group by columns and aggregate columns as the output columns - // e.g. 
TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT min(c1), max(c1) FROM t GROUP BY c2; - // Use c2, min(c1), max(c1) as output for DataSourceV2ScanRelation - // We want to have the following logical plan: - // == Optimized Logical Plan == - // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] - // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] - // scalastyle:on - val newOutput = scan.readSchema().toAttributes - assert(newOutput.length == groupingExpressions.length + finalAggregates.length) - val groupByExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] - val groupAttrs = normalizedGroupingExpressions.zip(newOutput).zipWithIndex.map { - case ((a: Attribute, b: Attribute), _) => b.withExprId(a.exprId) - case ((expr, attr), ordinal) => - if (!groupByExprToOutputOrdinal.contains(expr.canonicalized)) { - groupByExprToOutputOrdinal(expr.canonicalized) = ordinal - } - attr - } - val aggOutput = newOutput.drop(groupAttrs.length) - val output = groupAttrs ++ aggOutput - - logInfo( - s""" - |Pushing operators to ${sHolder.relation.name} - |Pushed Aggregate Functions: - | ${pushedAggregates.get.aggregateExpressions.mkString(", ")} - |Pushed Group by: - | ${pushedAggregates.get.groupByExpressions.mkString(", ")} - |Output: ${output.mkString(", ")} - """.stripMargin) - - val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates) - val scanRelation = - DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) - if (r.supportCompletePushDown(pushedAggregates.get)) { - val projectExpressions = finalResultExpressions.map { expr => - expr.transformDown { - case agg: AggregateExpression => - val ordinal = aggExprToOutputOrdinal(agg.canonicalized) - val child = - addCastIfNeeded(aggOutput(ordinal), agg.resultAttribute.dataType) - Alias(child, agg.resultAttribute.name)(agg.resultAttribute.exprId) - case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) => - val ordinal = groupByExprToOutputOrdinal(expr.canonicalized) - addCastIfNeeded(groupAttrs(ordinal), expr.dataType) - } - }.asInstanceOf[Seq[NamedExpression]] - Project(projectExpressions, scanRelation) + if (!canCompletePushDown && !supportPartialAggPushDown(translatedAgg)) { + return agg + } + if (!r.pushAggregation(translatedAgg)) { + return agg + } + + // scalastyle:off + // We name the output columns of group expressions and aggregate functions by + // ordinal: `group_col_0`, `group_col_1`, ..., `agg_func_0`, `agg_func_1`, ... + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t GROUP BY c2; + // Use group_col_0, agg_func_0, agg_func_1 as output for ScanBuilderHolder. + // We want to have the following logical plan: + // == Optimized Logical Plan == + // Aggregate [group_col_0#10], [min(agg_func_0#21) AS min(c1)#17, max(agg_func_1#22) AS max(c1)#18] + // +- ScanBuilderHolder[group_col_0#10, agg_func_0#21, agg_func_1#22] + // Later, we build the `Scan` instance and convert ScanBuilderHolder to DataSourceV2ScanRelation. 
+ // scalastyle:on + val groupOutput = normalizedGroupingExpr.zipWithIndex.map { case (e, i) => + AttributeReference(s"group_col_$i", e.dataType)() + } + val aggOutput = finalAggExprs.zipWithIndex.map { case (e, i) => + AttributeReference(s"agg_func_$i", e.dataType)() + } + val newOutput = groupOutput ++ aggOutput + val groupByExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] + normalizedGroupingExpr.zipWithIndex.foreach { case (expr, ordinal) => + if (!groupByExprToOutputOrdinal.contains(expr.canonicalized)) { + groupByExprToOutputOrdinal(expr.canonicalized) = ordinal + } + } + + holder.pushedAggregate = Some(translatedAgg) + holder.output = newOutput + logInfo( + s""" + |Pushing operators to ${holder.relation.name} + |Pushed Aggregate Functions: + | ${translatedAgg.aggregateExpressions().mkString(", ")} + |Pushed Group by: + | ${translatedAgg.groupByExpressions.mkString(", ")} + """.stripMargin) + + if (canCompletePushDown) { + val projectExpressions = finalResultExprs.map { expr => + expr.transformDown { + case agg: AggregateExpression => + val ordinal = aggExprToOutputOrdinal(agg.canonicalized) + Alias(aggOutput(ordinal), agg.resultAttribute.name)(agg.resultAttribute.exprId) + case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) => + val ordinal = groupByExprToOutputOrdinal(expr.canonicalized) + expr match { + case ne: NamedExpression => Alias(groupOutput(ordinal), ne.name)(ne.exprId) + case _ => groupOutput(ordinal) + } + } + }.asInstanceOf[Seq[NamedExpression]] + Project(projectExpressions, holder) + } else { + // scalastyle:off + // Change the optimized logical plan to reflect the pushed down aggregate + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t GROUP BY c2; + // The original logical plan is + // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c1#9, c2#10] ... + // + // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22] + // we have the following + // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... + // + // We want to change it to + // == Optimized Logical Plan == + // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... + // scalastyle:on + val aggExprs = finalResultExprs.map(_.transform { + case agg: AggregateExpression => + val ordinal = aggExprToOutputOrdinal(agg.canonicalized) + val aggAttribute = aggOutput(ordinal) + val aggFunction: aggregate.AggregateFunction = + agg.aggregateFunction match { + case max: aggregate.Max => + max.copy(child = aggAttribute) + case min: aggregate.Min => + min.copy(child = aggAttribute) + case sum: aggregate.Sum => + // To keep the dataType of `Sum` unchanged, we need to cast the + // data-source-aggregated result to `Sum.child.dataType` if it's decimal. + // See `SumBase.resultType` + val newChild = if (sum.dataType.isInstanceOf[DecimalType]) { + addCastIfNeeded(aggAttribute, sum.child.dataType) } else { - val plan = Aggregate(output.take(groupingExpressions.length), - finalResultExpressions, scanRelation) - - // scalastyle:off - // Change the optimized logical plan to reflect the pushed down aggregate - // e.g. TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT min(c1), max(c1) FROM t GROUP BY c2; - // The original logical plan is - // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] - // +- RelationV2[c1#9, c2#10] ... 
- // - // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22] - // we have the following - // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] - // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... - // - // We want to change it to - // == Optimized Logical Plan == - // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] - // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... - // scalastyle:on - plan.transformExpressions { - case agg: AggregateExpression => - val ordinal = aggExprToOutputOrdinal(agg.canonicalized) - val aggAttribute = aggOutput(ordinal) - val aggFunction: aggregate.AggregateFunction = - agg.aggregateFunction match { - case max: aggregate.Max => - max.copy(child = addCastIfNeeded(aggAttribute, max.child.dataType)) - case min: aggregate.Min => - min.copy(child = addCastIfNeeded(aggAttribute, min.child.dataType)) - case sum: aggregate.Sum => - sum.copy(child = addCastIfNeeded(aggAttribute, sum.child.dataType)) - case _: aggregate.Count => - aggregate.Sum(addCastIfNeeded(aggAttribute, LongType)) - case other => other - } - agg.copy(aggregateFunction = aggFunction) - case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) => - val ordinal = groupByExprToOutputOrdinal(expr.canonicalized) - addCastIfNeeded(groupAttrs(ordinal), expr.dataType) - } + aggAttribute } - } + sum.copy(child = newChild) + case _: aggregate.Count => + aggregate.Sum(aggAttribute) + case other => other } - case _ => aggNode - } - case _ => aggNode + agg.copy(aggregateFunction = aggFunction) + case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) => + val ordinal = groupByExprToOutputOrdinal(expr.canonicalized) + expr match { + case ne: NamedExpression => Alias(groupOutput(ordinal), ne.name)(ne.exprId) + case _ => groupOutput(ordinal) + } + }).asInstanceOf[Seq[NamedExpression]] + Aggregate(groupOutput, aggExprs, holder) } + + case _ => agg } - private def collectAggregates(resultExpressions: Seq[NamedExpression], + private def collectAggregates( + resultExpressions: Seq[NamedExpression], aggExprToOutputOrdinal: mutable.HashMap[Expression, Int]): Seq[AggregateExpression] = { var ordinal = 0 resultExpressions.flatMap { expr => @@ -292,14 +305,12 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit } private def supportPartialAggPushDown(agg: Aggregation): Boolean = { - // We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down. - // If `Sum`, `Count`, `Avg` with distinct, can't do partial agg push down. - agg.aggregateExpressions().isEmpty || agg.aggregateExpressions().exists { + // We can only partially push down min/max/sum/count without DISTINCT. + agg.aggregateExpressions().isEmpty || agg.aggregateExpressions().forall { case sum: Sum => !sum.isDistinct case count: Count => !count.isDistinct - case avg: Avg => !avg.isDistinct - case _: GeneralAggregateFunc => false - case _ => true + case _: Min | _: Max | _: CountStar => true + case _ => false } } @@ -310,6 +321,26 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit Cast(expression, expectedDataType) } + def buildScanWithPushedAggregate(plan: LogicalPlan): LogicalPlan = plan.transform { + case holder: ScanBuilderHolder if holder.pushedAggregate.isDefined => + // No need to do column pruning because only the aggregate columns are used as + // DataSourceV2ScanRelation output columns. All the other columns are not + // included in the output. 
+ val scan = holder.builder.build() + val realOutput = scan.readSchema().toAttributes + assert(realOutput.length == holder.output.length, + "The data source returns unexpected number of columns") + val wrappedScan = getWrappedScan(scan, holder) + val scanRelation = DataSourceV2ScanRelation(holder.relation, wrappedScan, realOutput) + val projectList = realOutput.zip(holder.output).map { case (a1, a2) => + // The data source may return columns with arbitrary data types and it's safer to cast them + // to the expected data type. + assert(Cast.canCast(a1.dataType, a2.dataType)) + Alias(addCastIfNeeded(a1, a2.dataType), a2.name)(a2.exprId) + } + Project(projectList, scanRelation) + } + def pruneColumns(plan: LogicalPlan): LogicalPlan = plan.transform { case ScanOperation(project, filters, sHolder: ScanBuilderHolder) => // column pruning @@ -324,7 +355,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit |Output: ${output.mkString(", ")} """.stripMargin) - val wrappedScan = getWrappedScan(scan, sHolder, Option.empty[Aggregation]) + val wrappedScan = getWrappedScan(scan, sHolder) val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) @@ -376,8 +407,12 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit } (operation, isPushed && !isPartiallyPushed) case s @ Sort(order, _, operation @ ScanOperation(project, filter, sHolder: ScanBuilderHolder)) - if filter.isEmpty && CollapseProject.canCollapseExpressions( - order, project, alwaysInline = true) => + // Without building the Scan, we do not know the resulting column names after aggregate + // push-down, and thus can't push down Top-N which needs to know the ordering column names. + // TODO: we can support simple cases like GROUP BY columns directly and ORDER BY the same + // columns, which we know the resulting column names: the original table columns. 
+ if sHolder.pushedAggregate.isEmpty && filter.isEmpty && + CollapseProject.canCollapseExpressions(order, project, alwaysInline = true) => val aliasMap = getAliasMap(project) val newOrder = order.map(replaceAlias(_, aliasMap)).asInstanceOf[Seq[SortOrder]] val normalizedOrders = DataSourceStrategy.normalizeExprs( @@ -478,10 +513,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit } } - private def getWrappedScan( - scan: Scan, - sHolder: ScanBuilderHolder, - aggregation: Option[Aggregation]): Scan = { + private def getWrappedScan(scan: Scan, sHolder: ScanBuilderHolder): Scan = { scan match { case v1: V1Scan => val pushedFilters = sHolder.builder match { @@ -489,7 +521,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit f.pushedFilters() case _ => Array.empty[sources.Filter] } - val pushedDownOperators = PushedDownOperators(aggregation, sHolder.pushedSample, + val pushedDownOperators = PushedDownOperators(sHolder.pushedAggregate, sHolder.pushedSample, sHolder.pushedLimit, sHolder.pushedOffset, sHolder.sortOrders, sHolder.pushedPredicates) V1ScanWrapper(v1, pushedFilters, pushedDownOperators) case _ => scan @@ -498,7 +530,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit } case class ScanBuilderHolder( - output: Seq[AttributeReference], + var output: Seq[AttributeReference], relation: DataSourceV2Relation, builder: ScanBuilder) extends LeafNode { var pushedLimit: Option[Int] = None @@ -510,6 +542,8 @@ case class ScanBuilderHolder( var pushedSample: Option[TableSampleInfo] = None var pushedPredicates: Seq[Predicate] = Seq.empty[Predicate] + + var pushedAggregate: Option[Aggregation] = None } // A wrapper for v1 scan to carry the translated filters and the handled ones, along with diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala index 9bf25aa0d633f..32a0b4639a83e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala @@ -20,9 +20,11 @@ package org.apache.spark.sql.jdbc import java.sql.{SQLException, Types} import java.util.Locale +import scala.util.control.NonFatal + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} +import org.apache.spark.sql.connector.expressions.Expression import org.apache.spark.sql.types._ private object DB2Dialect extends JdbcDialect { @@ -30,35 +32,47 @@ private object DB2Dialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:db2") + private val distinctUnsupportedAggregateFunctions = + Set("COVAR_POP", "COVAR_SAMP", "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY") + // See https://www.ibm.com/docs/en/db2/11.5?topic=functions-aggregate - override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { - super.compileAggregate(aggFunction).orElse( - aggFunction match { - case f: GeneralAggregateFunc if f.name() == "VAR_POP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VARIANCE($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - 
Some(s"VARIANCE_SAMP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_SAMP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false => - assert(f.children().length == 2) - Some(s"COVARIANCE(${f.children().head}, ${f.children().last})") - case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false => - assert(f.children().length == 2) - Some(s"COVARIANCE_SAMP(${f.children().head}, ${f.children().last})") - case _ => None + private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG", + "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP") ++ distinctUnsupportedAggregateFunctions + private val supportedFunctions = supportedAggregateFunctions + + override def isSupportedFunction(funcName: String): Boolean = + supportedFunctions.contains(funcName) + + class DB2SQLBuilder extends JDBCSQLBuilder { + override def visitAggregateFunction( + funcName: String, isDistinct: Boolean, inputs: Array[String]): String = + if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) { + throw new UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " + + s"support aggregate function: $funcName with DISTINCT"); + } else { + super.visitAggregateFunction(funcName, isDistinct, inputs) } - ) + + override def dialectFunctionName(funcName: String): String = funcName match { + case "VAR_POP" => "VARIANCE" + case "VAR_SAMP" => "VARIANCE_SAMP" + case "STDDEV_POP" => "STDDEV" + case "STDDEV_SAMP" => "STDDEV_SAMP" + case "COVAR_POP" => "COVARIANCE" + case "COVAR_SAMP" => "COVARIANCE_SAMP" + case _ => super.dialectFunctionName(funcName) + } + } + + override def compileExpression(expr: Expression): Option[String] = { + val db2SQLBuilder = new DB2SQLBuilder() + try { + Some(db2SQLBuilder.build(expr)) + } catch { + case NonFatal(e) => + logWarning("Error occurs while compiling V2 expression", e) + None + } } override def getCatalystType( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala index 36c3c6be4a05c..439e0697d9f3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.jdbc import java.sql.Types import java.util.Locale -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types._ @@ -31,25 +30,12 @@ private object DerbyDialect extends JdbcDialect { url.toLowerCase(Locale.ROOT).startsWith("jdbc:derby") // See https://db.apache.org/derby/docs/10.15/ref/index.html - override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { - super.compileAggregate(aggFunction).orElse( - aggFunction match { - case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false => - assert(f.children().length == 1) - Some(s"VAR_POP(${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false => - 
assert(f.children().length == 1) - Some(s"VAR_SAMP(${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false => - assert(f.children().length == 1) - Some(s"STDDEV_POP(${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false => - assert(f.children().length == 1) - Some(s"STDDEV_SAMP(${f.children().head})") - case _ => None - } - ) - } + private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG", + "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP") + private val supportedFunctions = supportedAggregateFunctions + + override def isSupportedFunction(funcName: String): Boolean = + supportedFunctions.contains(funcName) override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala index 945c25cad56b7..5a909b704e24c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.connector.catalog.functions.UnboundFunction import org.apache.spark.sql.connector.expressions.Expression -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DecimalType, ShortType, StringType} @@ -36,7 +35,13 @@ private[sql] object H2Dialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2") - private val supportedFunctions = + private val distinctUnsupportedAggregateFunctions = + Set("COVAR_POP", "COVAR_SAMP", "CORR", "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY") + + private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG", + "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP") ++ distinctUnsupportedAggregateFunctions + + private val supportedFunctions = supportedAggregateFunctions ++ Set("ABS", "COALESCE", "GREATEST", "LEAST", "RAND", "LOG", "LOG10", "LN", "EXP", "POWER", "SQRT", "FLOOR", "CEIL", "ROUND", "SIN", "SINH", "COS", "COSH", "TAN", "TANH", "COT", "ASIN", "ACOS", "ATAN", "ATAN2", "DEGREES", "RADIANS", "SIGN", @@ -45,42 +50,6 @@ private[sql] object H2Dialect extends JdbcDialect { override def isSupportedFunction(funcName: String): Boolean = supportedFunctions.contains(funcName) - override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { - super.compileAggregate(aggFunction).orElse( - aggFunction match { - case f: GeneralAggregateFunc if f.name() == "VAR_POP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_POP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_SAMP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_POP($distinct${f.children().head})") - 
case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_SAMP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "COVAR_POP" => - assert(f.children().length == 2) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"COVAR_POP($distinct${f.children().head}, ${f.children().last})") - case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" => - assert(f.children().length == 2) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"COVAR_SAMP($distinct${f.children().head}, ${f.children().last})") - case f: GeneralAggregateFunc if f.name() == "CORR" => - assert(f.children().length == 2) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"CORR($distinct${f.children().head}, ${f.children().last})") - case _ => None - } - ) - } - override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { case StringType => Option(JdbcType("CLOB", Types.CLOB)) case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN)) @@ -125,9 +94,9 @@ private[sql] object H2Dialect extends JdbcDialect { } override def compileExpression(expr: Expression): Option[String] = { - val jdbcSQLBuilder = new H2JDBCSQLBuilder() + val h2SQLBuilder = new H2SQLBuilder() try { - Some(jdbcSQLBuilder.build(expr)) + Some(h2SQLBuilder.build(expr)) } catch { case NonFatal(e) => logWarning("Error occurs while compiling V2 expression", e) @@ -135,7 +104,15 @@ private[sql] object H2Dialect extends JdbcDialect { } } - class H2JDBCSQLBuilder extends JDBCSQLBuilder { + class H2SQLBuilder extends JDBCSQLBuilder { + override def visitAggregateFunction( + funcName: String, isDistinct: Boolean, inputs: Array[String]): String = + if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) { + throw new UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " + + s"support aggregate function: $funcName with DISTINCT"); + } else { + super.visitAggregateFunction(funcName, isDistinct, inputs) + } override def visitExtract(field: String, source: String): String = { val newField = field match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index d42d4e8fc0bac..3af0e0de89444 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.connector.catalog.functions.UnboundFunction import org.apache.spark.sql.connector.catalog.index.TableIndex import org.apache.spark.sql.connector.expressions.{Expression, Literal, NamedReference} -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, Max, Min, Sum} +import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions, JdbcUtils} @@ -244,7 +244,7 @@ abstract class JdbcDialect extends Serializable with Logging{ override def visitSQLFunction(funcName: String, inputs: Array[String]): String = { if (isSupportedFunction(funcName)) { - s"""$funcName(${inputs.mkString(", ")})""" + s"""${dialectFunctionName(funcName)}(${inputs.mkString(", ")})""" } else { // 
The framework will catch the error and give up the push-down. // Please see `JdbcDialect.compileExpression(expr: Expression)` for more details. @@ -253,6 +253,18 @@ abstract class JdbcDialect extends Serializable with Logging{ } } + override def visitAggregateFunction( + funcName: String, isDistinct: Boolean, inputs: Array[String]): String = { + if (isSupportedFunction(funcName)) { + super.visitAggregateFunction(dialectFunctionName(funcName), isDistinct, inputs) + } else { + throw new UnsupportedOperationException( + s"${this.getClass.getSimpleName} does not support aggregate function: $funcName"); + } + } + + protected def dialectFunctionName(funcName: String): String = funcName + override def visitOverlay(inputs: Array[String]): String = { if (isSupportedFunction("OVERLAY")) { super.visitOverlay(inputs) @@ -303,26 +315,8 @@ abstract class JdbcDialect extends Serializable with Logging{ * @return Converted value. */ @Since("3.3.0") - def compileAggregate(aggFunction: AggregateFunc): Option[String] = { - aggFunction match { - case min: Min => - compileExpression(min.column).map(v => s"MIN($v)") - case max: Max => - compileExpression(max.column).map(v => s"MAX($v)") - case count: Count => - val distinct = if (count.isDistinct) "DISTINCT " else "" - compileExpression(count.column).map(v => s"COUNT($distinct$v)") - case sum: Sum => - val distinct = if (sum.isDistinct) "DISTINCT " else "" - compileExpression(sum.column).map(v => s"SUM($distinct$v)") - case _: CountStar => - Some("COUNT(*)") - case avg: Avg => - val distinct = if (avg.isDistinct) "DISTINCT " else "" - compileExpression(avg.column).map(v => s"AVG($distinct$v)") - case _ => None - } - } + @deprecated("use org.apache.spark.sql.jdbc.JdbcDialect.compileExpression instead.", "3.4.0") + def compileAggregate(aggFunction: AggregateFunc): Option[String] = compileExpression(aggFunction) /** * List the user-defined functions in jdbc dialect. 
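(Editor's illustrative sketch, not part of the patch.) The net effect of the JdbcDialects.scala refactor above is that a dialect no longer overrides `compileAggregate`: it lists the function names it supports, optionally renames them through `dialectFunctionName`, and lets the shared `JDBCSQLBuilder` produce the SQL. The sketch below is modeled on the DB2Dialect changes shown earlier; the dialect name, the `jdbc:example` URL prefix, and the VAR_POP-to-VARIANCE mapping are hypothetical placeholders, not code from this commit.

import scala.util.control.NonFatal

import org.apache.spark.sql.connector.expressions.Expression

private case object ExampleDialect extends JdbcDialect {
  // Hypothetical URL prefix; a real dialect matches its own driver scheme.
  override def canHandle(url: String): Boolean =
    url.toLowerCase(java.util.Locale.ROOT).startsWith("jdbc:example")

  // Only functions listed here pass visitSQLFunction / visitAggregateFunction;
  // everything else makes the builder throw and the push-down is given up.
  private val supportedFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG", "VAR_POP")

  override def isSupportedFunction(funcName: String): Boolean =
    supportedFunctions.contains(funcName)

  class ExampleSQLBuilder extends JDBCSQLBuilder {
    // Rename Spark's canonical function name to the database's spelling,
    // mirroring what DB2SQLBuilder does for VAR_POP -> VARIANCE.
    override def dialectFunctionName(funcName: String): String = funcName match {
      case "VAR_POP" => "VARIANCE"
      case _ => super.dialectFunctionName(funcName)
    }
  }

  override def compileExpression(expr: Expression): Option[String] = {
    val builder = new ExampleSQLBuilder()
    try {
      Some(builder.build(expr))
    } catch {
      case NonFatal(e) =>
        logWarning("Error occurs while compiling V2 expression", e)
        None
    }
  }
}

Existing third-party dialects that still override the deprecated `compileAggregate` keep working, since it now simply delegates to `compileExpression`.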
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 8d2fbec55f919..af35a0575ac8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -20,9 +20,11 @@ package org.apache.spark.sql.jdbc import java.sql.SQLException import java.util.Locale +import scala.util.control.NonFatal + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} +import org.apache.spark.sql.connector.expressions.Expression import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -43,28 +45,32 @@ private object MsSqlServerDialect extends JdbcDialect { // scalastyle:off line.size.limit // See https://docs.microsoft.com/en-us/sql/t-sql/functions/aggregate-functions-transact-sql?view=sql-server-ver15 // scalastyle:on line.size.limit - override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { - super.compileAggregate(aggFunction).orElse( - aggFunction match { - case f: GeneralAggregateFunc if f.name() == "VAR_POP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VARP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDEVP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDEV($distinct${f.children().head})") - case _ => None - } - ) + private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG", + "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP") + private val supportedFunctions = supportedAggregateFunctions + + override def isSupportedFunction(funcName: String): Boolean = + supportedFunctions.contains(funcName) + + class MsSqlServerSQLBuilder extends JDBCSQLBuilder { + override def dialectFunctionName(funcName: String): String = funcName match { + case "VAR_POP" => "VARP" + case "VAR_SAMP" => "VAR" + case "STDDEV_POP" => "STDEVP" + case "STDDEV_SAMP" => "STDEV" + case _ => super.dialectFunctionName(funcName) + } + } + + override def compileExpression(expr: Expression): Option[String] = { + val msSqlServerSQLBuilder = new MsSqlServerSQLBuilder() + try { + Some(msSqlServerSQLBuilder.build(expr)) + } catch { + case NonFatal(e) => + logWarning("Error occurs while compiling V2 expression", e) + None + } } override def getCatalystType( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index 24f9bac74f86d..d850b61383520 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -22,13 +22,13 @@ import java.util import java.util.Locale import scala.collection.mutable.ArrayBuilder +import 
scala.util.control.NonFatal import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException} import org.apache.spark.sql.connector.catalog.index.TableIndex -import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference} -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} +import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NamedReference} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.types.{BooleanType, DataType, FloatType, LongType, MetadataBuilder} @@ -38,25 +38,37 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { override def canHandle(url : String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:mysql") + private val distinctUnsupportedAggregateFunctions = + Set("VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP") + // See https://dev.mysql.com/doc/refman/8.0/en/aggregate-functions.html - override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { - super.compileAggregate(aggFunction).orElse( - aggFunction match { - case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false => - assert(f.children().length == 1) - Some(s"VAR_POP(${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false => - assert(f.children().length == 1) - Some(s"VAR_SAMP(${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false => - assert(f.children().length == 1) - Some(s"STDDEV_POP(${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false => - assert(f.children().length == 1) - Some(s"STDDEV_SAMP(${f.children().head})") - case _ => None + private val supportedAggregateFunctions = + Set("MAX", "MIN", "SUM", "COUNT", "AVG") ++ distinctUnsupportedAggregateFunctions + private val supportedFunctions = supportedAggregateFunctions + + override def isSupportedFunction(funcName: String): Boolean = + supportedFunctions.contains(funcName) + + class MySQLSQLBuilder extends JDBCSQLBuilder { + override def visitAggregateFunction( + funcName: String, isDistinct: Boolean, inputs: Array[String]): String = + if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) { + throw new UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " + + s"support aggregate function: $funcName with DISTINCT"); + } else { + super.visitAggregateFunction(funcName, isDistinct, inputs) } - ) + } + + override def compileExpression(expr: Expression): Option[String] = { + val mysqlSQLBuilder = new MySQLSQLBuilder() + try { + Some(mysqlSQLBuilder.build(expr)) + } catch { + case NonFatal(e) => + logWarning("Error occurs while compiling V2 expression", e) + None + } } override def getCatalystType( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 40333c1757c4a..79ac248d723e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql.jdbc import java.sql.{Date, Timestamp, Types} import java.util.{Locale, TimeZone} 
+import scala.util.control.NonFatal + import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} +import org.apache.spark.sql.connector.expressions.Expression import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -34,36 +36,40 @@ private case object OracleDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:oracle") + private val distinctUnsupportedAggregateFunctions = + Set("VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP", "COVAR_POP", "COVAR_SAMP", "CORR", + "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY") + // scalastyle:off line.size.limit // https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/Aggregate-Functions.html#GUID-62BE676B-AF18-4E63-BD14-25206FEA0848 // scalastyle:on line.size.limit - override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { - super.compileAggregate(aggFunction).orElse( - aggFunction match { - case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false => - assert(f.children().length == 1) - Some(s"VAR_POP(${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false => - assert(f.children().length == 1) - Some(s"VAR_SAMP(${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false => - assert(f.children().length == 1) - Some(s"STDDEV_POP(${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false => - assert(f.children().length == 1) - Some(s"STDDEV_SAMP(${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false => - assert(f.children().length == 2) - Some(s"COVAR_POP(${f.children().head}, ${f.children().last})") - case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false => - assert(f.children().length == 2) - Some(s"COVAR_SAMP(${f.children().head}, ${f.children().last})") - case f: GeneralAggregateFunc if f.name() == "CORR" && f.isDistinct == false => - assert(f.children().length == 2) - Some(s"CORR(${f.children().head}, ${f.children().last})") - case _ => None + private val supportedAggregateFunctions = + Set("MAX", "MIN", "SUM", "COUNT", "AVG") ++ distinctUnsupportedAggregateFunctions + private val supportedFunctions = supportedAggregateFunctions + + override def isSupportedFunction(funcName: String): Boolean = + supportedFunctions.contains(funcName) + + class OracleSQLBuilder extends JDBCSQLBuilder { + override def visitAggregateFunction( + funcName: String, isDistinct: Boolean, inputs: Array[String]): String = + if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) { + throw new UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " + + s"support aggregate function: $funcName with DISTINCT"); + } else { + super.visitAggregateFunction(funcName, isDistinct, inputs) } - ) + } + + override def compileExpression(expr: Expression): Option[String] = { + val oracleSQLBuilder = new OracleSQLBuilder() + try { + Some(oracleSQLBuilder.build(expr)) + } catch { + case NonFatal(e) => + logWarning("Error occurs while compiling V2 expression", e) + None + } } private def supportTimeZoneTypes: Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 
a668d66ee2f9a..6800b5f298c1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NonEmptyNamespaceException, NoSuchIndexException} import org.apache.spark.sql.connector.expressions.NamedReference -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.types._ @@ -37,41 +36,13 @@ private object PostgresDialect extends JdbcDialect with SQLConfHelper { url.toLowerCase(Locale.ROOT).startsWith("jdbc:postgresql") // See https://www.postgresql.org/docs/8.4/functions-aggregate.html - override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { - super.compileAggregate(aggFunction).orElse( - aggFunction match { - case f: GeneralAggregateFunc if f.name() == "VAR_POP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_POP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_SAMP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_POP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_SAMP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "COVAR_POP" => - assert(f.children().length == 2) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"COVAR_POP($distinct${f.children().head}, ${f.children().last})") - case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" => - assert(f.children().length == 2) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"COVAR_SAMP($distinct${f.children().head}, ${f.children().last})") - case f: GeneralAggregateFunc if f.name() == "CORR" => - assert(f.children().length == 2) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"CORR($distinct${f.children().head}, ${f.children().last})") - case _ => None - } - ) - } + private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG", + "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP", "COVAR_POP", "COVAR_SAMP", "CORR", + "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY") + private val supportedFunctions = supportedAggregateFunctions + + override def isSupportedFunction(funcName: String): Boolean = + supportedFunctions.contains(funcName) override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala index 79fb710cf03b3..fba5316c478a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala @@ -19,7 +19,6 @@ 
package org.apache.spark.sql.jdbc import java.util.Locale -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.types._ @@ -31,38 +30,12 @@ private case object TeradataDialect extends JdbcDialect { // scalastyle:off line.size.limit // See https://docs.teradata.com/r/Teradata-VantageTM-SQL-Functions-Expressions-and-Predicates/March-2019/Aggregate-Functions // scalastyle:on line.size.limit - override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { - super.compileAggregate(aggFunction).orElse( - aggFunction match { - case f: GeneralAggregateFunc if f.name() == "VAR_POP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_POP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_SAMP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_POP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_SAMP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false => - assert(f.children().length == 2) - Some(s"COVAR_POP(${f.children().head}, ${f.children().last})") - case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false => - assert(f.children().length == 2) - Some(s"COVAR_SAMP(${f.children().head}, ${f.children().last})") - case f: GeneralAggregateFunc if f.name() == "CORR" && f.isDistinct == false => - assert(f.children().length == 2) - Some(s"CORR(${f.children().head}, ${f.children().last})") - case _ => None - } - ) - } + private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG", + "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP", "COVAR_POP", "COVAR_SAMP", "CORR") + private val supportedFunctions = supportedAggregateFunctions + + override def isSupportedFunction(funcName: String): Boolean = + supportedFunctions.contains(funcName) override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { case StringType => Some(JdbcType("VARCHAR(255)", java.sql.Types.VARCHAR)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala index 63e43a3f46cd7..7a2f2700c582d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala @@ -35,6 +35,7 @@ import org.apache.spark.unsafe.types.UTF8String object IntAverage extends AggregateFunction[(Int, Int), Int] { override def name(): String = "iavg" + override def canonicalName(): String = "h2.iavg" override def inputTypes(): Array[DataType] = Array(IntegerType) override def resultType(): DataType = IntegerType @@ -63,6 +64,7 @@ object IntAverage extends AggregateFunction[(Int, Int), Int] { object LongAverage extends AggregateFunction[(Long, Long), Long] { override def name(): String = "iavg" + override def canonicalName(): String = "h2.iavg" override def inputTypes(): Array[DataType] = Array(LongType) override def 
resultType(): DataType = LongType @@ -111,6 +113,24 @@ object IntegralAverage extends UnboundFunction { | iavg(bigint) -> bigint""".stripMargin } +case class StrLen(impl: BoundFunction) extends UnboundFunction { + override def name(): String = "strlen" + + override def bind(inputType: StructType): BoundFunction = { + if (inputType.fields.length != 1) { + throw new UnsupportedOperationException("Expect exactly one argument"); + } + inputType.fields(0).dataType match { + case StringType => impl + case _ => + throw new UnsupportedOperationException("Expect StringType") + } + } + + override def description(): String = + "strlen: returns the length of the input string strlen(string) -> int" +} + class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { private val emptyProps: java.util.Map[String, String] = Collections.emptyMap[String, String] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 4c2f67166b9bb..be2c8ce057559 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -20,16 +20,22 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, DriverManager} import java.util.Properties +import scala.util.control.NonFatal + import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.sql.{AnalysisException, DataFrame, ExplainSuiteHelper, QueryTest, Row} +import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest, Row} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, Offset, Sort} -import org.apache.spark.sql.connector.IntegralAverage +import org.apache.spark.sql.connector.{IntegralAverage, StrLen} +import org.apache.spark.sql.connector.catalog.functions.{ScalarFunction, UnboundFunction} +import org.apache.spark.sql.connector.expressions.Expression import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.functions.{abs, acos, asin, atan, atan2, avg, ceil, coalesce, cos, cosh, count, count_distinct, degrees, exp, floor, lit, log => logarithm, log10, not, pow, radians, round, signum, sin, sinh, sqrt, sum, tan, tanh, udf, when} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{DataType, IntegerType, StringType} import org.apache.spark.util.Utils class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper { @@ -39,6 +45,63 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val url = s"jdbc:h2:${tempDir.getCanonicalPath};user=testUser;password=testPass" var conn: java.sql.Connection = null + val testH2Dialect = new JdbcDialect { + override def canHandle(url: String): Boolean = H2Dialect.canHandle(url) + + class H2SQLBuilder extends JDBCSQLBuilder { + override def visitUserDefinedScalarFunction( + funcName: String, canonicalName: String, inputs: Array[String]): String = { + canonicalName match { + case "h2.char_length" => + s"$funcName(${inputs.mkString(", ")})" + case _ => super.visitUserDefinedScalarFunction(funcName, canonicalName, inputs) + } + } + + override def visitUserDefinedAggregateFunction( + 
funcName: String, + canonicalName: String, + isDistinct: Boolean, + inputs: Array[String]): String = { + canonicalName match { + case "h2.iavg" => + if (isDistinct) { + s"AVG(DISTINCT ${inputs.mkString(", ")})" + } else { + s"AVG(${inputs.mkString(", ")})" + } + case _ => + super.visitUserDefinedAggregateFunction(funcName, canonicalName, isDistinct, inputs) + } + } + } + + override def compileExpression(expr: Expression): Option[String] = { + val h2SQLBuilder = new H2SQLBuilder() + try { + Some(h2SQLBuilder.build(expr)) + } catch { + case NonFatal(e) => + logWarning("Error occurs while compiling V2 expression", e) + None + } + } + + override def functions: Seq[(String, UnboundFunction)] = H2Dialect.functions + } + + case object CharLength extends ScalarFunction[Int] { + override def inputTypes(): Array[DataType] = Array(StringType) + override def resultType(): DataType = IntegerType + override def name(): String = "CHAR_LENGTH" + override def canonicalName(): String = "h2.char_length" + + override def produceResult(input: InternalRow): Int = { + val s = input.getString(0) + s.length + } + } + override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.h2", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.h2.url", url) @@ -116,6 +179,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel "('alex', '2022-05-18', '2022-05-18 00:00:00')").executeUpdate() } H2Dialect.registerFunction("my_avg", IntegralAverage) + H2Dialect.registerFunction("my_strlen", StrLen(CharLength)) } override def afterAll(): Unit = { @@ -200,9 +264,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .table("h2.test.employee") .groupBy("DEPT").sum("SALARY") .limit(1) - checkLimitRemoved(df4, false) + checkAggregateRemoved(df4) + checkLimitRemoved(df4) checkPushedInfo(df4, - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT], ") + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedLimit: LIMIT 1") checkAnswer(df4, Seq(Row(1, 19000.00))) val name = udf { (x: String) => x.matches("cat|dav|amy") } @@ -275,9 +343,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .table("h2.test.employee") .groupBy("DEPT").sum("SALARY") .offset(1) - checkOffsetRemoved(df5, false) + checkAggregateRemoved(df5) + checkLimitRemoved(df5) checkPushedInfo(df5, - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT], ") + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedOffset: OFFSET 1") checkAnswer(df5, Seq(Row(2, 22000.00), Row(6, 12000.00))) val name = udf { (x: String) => x.matches("cat|dav|amy") } @@ -412,10 +484,15 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .groupBy("DEPT").sum("SALARY") .limit(2) .offset(1) - checkLimitRemoved(df10, false) - checkOffsetRemoved(df10, false) + checkAggregateRemoved(df10) + checkLimitRemoved(df10) + checkOffsetRemoved(df10) checkPushedInfo(df10, - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT], ") + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedLimit: LIMIT 2", + "PushedOffset: OFFSET 1") checkAnswer(df10, Seq(Row(2, 22000.00))) val name = udf { (x: String) => x.matches("cat|dav|amy") } @@ -547,10 +624,15 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with 
ExplainSuiteHel checkAnswer(df9, Seq(Row(2, "david", 10000.00, 1300.0, true))) val df10 = sql("SELECT dept, sum(salary) FROM h2.test.employee group by dept LIMIT 1 OFFSET 1") - checkLimitRemoved(df10, false) - checkOffsetRemoved(df10, false) + checkAggregateRemoved(df10) + checkLimitRemoved(df10) + checkOffsetRemoved(df10) checkPushedInfo(df10, - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT], ") + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedLimit: LIMIT 2", + "PushedOffset: OFFSET 1") checkAnswer(df10, Seq(Row(2, 22000.00))) val name = udf { (x: String) => x.matches("cat|dav|amy") } @@ -791,11 +873,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val df11 = sql( """ |SELECT * FROM h2.test.employee - |WHERE GREATEST(bonus, 1100) > 1200 AND LEAST(salary, 10000) > 9000 AND RAND(1) < 1 + |WHERE GREATEST(bonus, 1100) > 1200 AND RAND(1) < bonus |""".stripMargin) checkFiltersRemoved(df11) checkPushedInfo(df11, "PushedFilters: " + - "[(GREATEST(BONUS, 1100.0)) > 1200.0, (LEAST(SALARY, 10000.00)) > 9000.00, RAND(1) < 1.0]") + "[BONUS IS NOT NULL, (GREATEST(BONUS, 1100.0)) > 1200.0, RAND(1) < BONUS]") checkAnswer(df11, Row(2, "david", 10000, 1300, true)) val df12 = sql( @@ -1051,6 +1133,33 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df8, Seq(Row("alex"))) } + test("scan with filter push-down with UDF") { + JdbcDialects.unregisterDialect(H2Dialect) + try { + JdbcDialects.registerDialect(testH2Dialect) + val df1 = sql("SELECT * FROM h2.test.people where h2.my_strlen(name) > 2") + checkFiltersRemoved(df1) + checkPushedInfo(df1, "PushedFilters: [CHAR_LENGTH(NAME) > 2],") + checkAnswer(df1, Seq(Row("fred", 1), Row("mary", 2))) + + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + val df2 = sql( + """ + |SELECT * + |FROM h2.test.people + |WHERE h2.my_strlen(CASE WHEN NAME = 'fred' THEN NAME ELSE "abc" END) > 2 + """.stripMargin) + checkFiltersRemoved(df2) + checkPushedInfo(df2, + "PushedFilters: [CHAR_LENGTH(CASE WHEN NAME = 'fred' THEN NAME ELSE 'abc' END) > 2],") + checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2))) + } + } finally { + JdbcDialects.unregisterDialect(testH2Dialect) + JdbcDialects.registerDialect(H2Dialect) + } + } + test("scan with column pruning") { val df = spark.table("h2.test.people").select("id") checkSchemaNames(df, Seq("ID")) @@ -1620,43 +1729,81 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } test("scan with aggregate push-down: VAR_POP VAR_SAMP with filter and group by") { - val df = sql("SELECT VAR_POP(bonus), VAR_SAMP(bonus) FROM h2.test.employee WHERE dept > 0" + - " GROUP BY DePt") + val df = sql( + """ + |SELECT + | VAR_POP(bonus), + | VAR_POP(DISTINCT bonus), + | VAR_SAMP(bonus), + | VAR_SAMP(DISTINCT bonus) + |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin) checkFiltersRemoved(df) checkAggregateRemoved(df) - checkPushedInfo(df, "PushedAggregates: [VAR_POP(BONUS), VAR_SAMP(BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]") - checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) + checkPushedInfo(df, + """ + |PushedAggregates: [VAR_POP(BONUS), VAR_POP(DISTINCT BONUS), + |VAR_SAMP(BONUS), VAR_SAMP(DISTINCT BONUS)], + |PushedFilters: [DEPT IS NOT NULL, DEPT > 0], + |PushedGroupByExpressions: [DEPT], + |""".stripMargin.replaceAll("\n", " ")) + checkAnswer(df, 
Seq(Row(10000d, 10000d, 20000d, 20000d), + Row(2500d, 2500d, 5000d, 5000d), Row(0d, 0d, null, null))) } test("scan with aggregate push-down: STDDEV_POP STDDEV_SAMP with filter and group by") { - val df = sql("SELECT STDDEV_POP(bonus), STDDEV_SAMP(bonus) FROM h2.test.employee" + - " WHERE dept > 0 GROUP BY DePt") + val df = sql( + """ + |SELECT + | STDDEV_POP(bonus), + | STDDEV_POP(DISTINCT bonus), + | STDDEV_SAMP(bonus), + | STDDEV_SAMP(DISTINCT bonus) + |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin) checkFiltersRemoved(df) checkAggregateRemoved(df) - checkPushedInfo(df, "PushedAggregates: [STDDEV_POP(BONUS), STDDEV_SAMP(BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]") - checkAnswer(df, Seq(Row(100d, 141.4213562373095d), Row(50d, 70.71067811865476d), Row(0d, null))) + checkPushedInfo(df, + """ + |PushedAggregates: [STDDEV_POP(BONUS), STDDEV_POP(DISTINCT BONUS), + |STDDEV_SAMP(BONUS), STDDEV_SAMP(DISTINCT BONUS)], + |PushedFilters: [DEPT IS NOT NULL, DEPT > 0], + |PushedGroupByExpressions: [DEPT], + |""".stripMargin.replaceAll("\n", " ")) + checkAnswer(df, Seq(Row(100d, 100d, 141.4213562373095d, 141.4213562373095d), + Row(50d, 50d, 70.71067811865476d, 70.71067811865476d), Row(0d, 0d, null, null))) } test("scan with aggregate push-down: COVAR_POP COVAR_SAMP with filter and group by") { - val df = sql("SELECT COVAR_POP(bonus, bonus), COVAR_SAMP(bonus, bonus)" + + val df1 = sql("SELECT COVAR_POP(bonus, bonus), COVAR_SAMP(bonus, bonus)" + " FROM h2.test.employee WHERE dept > 0 GROUP BY DePt") - checkFiltersRemoved(df) - checkAggregateRemoved(df) - checkPushedInfo(df, "PushedAggregates: [COVAR_POP(BONUS, BONUS), COVAR_SAMP(BONUS, BONUS)], " + + checkFiltersRemoved(df1) + checkAggregateRemoved(df1) + checkPushedInfo(df1, "PushedAggregates: [COVAR_POP(BONUS, BONUS), COVAR_SAMP(BONUS, BONUS)], " + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]") - checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) + checkAnswer(df1, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) + + val df2 = sql("SELECT COVAR_POP(DISTINCT bonus, bonus), COVAR_SAMP(DISTINCT bonus, bonus)" + + " FROM h2.test.employee WHERE dept > 0 GROUP BY DePt") + checkFiltersRemoved(df2) + checkAggregateRemoved(df2, false) + checkPushedInfo(df2, "PushedFilters: [DEPT IS NOT NULL, DEPT > 0]") + checkAnswer(df2, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) } test("scan with aggregate push-down: CORR with filter and group by") { - val df = sql("SELECT CORR(bonus, bonus) FROM h2.test.employee WHERE dept > 0" + + val df1 = sql("SELECT CORR(bonus, bonus) FROM h2.test.employee WHERE dept > 0" + " GROUP BY DePt") - checkFiltersRemoved(df) - checkAggregateRemoved(df) - checkPushedInfo(df, "PushedAggregates: [CORR(BONUS, BONUS)], " + + checkFiltersRemoved(df1) + checkAggregateRemoved(df1) + checkPushedInfo(df1, "PushedAggregates: [CORR(BONUS, BONUS)], " + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]") - checkAnswer(df, Seq(Row(1d), Row(1d), Row(null))) + checkAnswer(df1, Seq(Row(1d), Row(1d), Row(null))) + + val df2 = sql("SELECT CORR(DISTINCT bonus, bonus) FROM h2.test.employee WHERE dept > 0" + + " GROUP BY DePt") + checkFiltersRemoved(df2) + checkAggregateRemoved(df2, false) + checkPushedInfo(df2, "PushedFilters: [DEPT IS NOT NULL, DEPT > 0]") + checkAnswer(df2, Seq(Row(1d), Row(1d), Row(null))) } test("scan with aggregate push-down: aggregate over alias push down") { 
@@ -1984,17 +2131,51 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } } - test("register dialect specific functions") { - val df = sql("SELECT h2.my_avg(id) FROM h2.test.people") - checkAggregateRemoved(df, false) - checkAnswer(df, Row(1) :: Nil) - val e1 = intercept[AnalysisException] { - checkAnswer(sql("SELECT h2.test.my_avg2(id) FROM h2.test.people"), Seq.empty) - } - assert(e1.getMessage.contains("Undefined function: 'my_avg2'")) - val e2 = intercept[AnalysisException] { - checkAnswer(sql("SELECT h2.my_avg2(id) FROM h2.test.people"), Seq.empty) + test("scan with aggregate push-down: complete push-down UDAF") { + JdbcDialects.unregisterDialect(H2Dialect) + try { + JdbcDialects.registerDialect(testH2Dialect) + val df1 = sql("SELECT h2.my_avg(id) FROM h2.test.people") + checkAggregateRemoved(df1) + checkPushedInfo(df1, + "PushedAggregates: [iavg(ID)], PushedFilters: [], PushedGroupByExpressions: []") + checkAnswer(df1, Seq(Row(1))) + + val df2 = sql("SELECT name, h2.my_avg(id) FROM h2.test.people group by name") + checkAggregateRemoved(df2) + checkPushedInfo(df2, + "PushedAggregates: [iavg(ID)], PushedFilters: [], PushedGroupByExpressions: [NAME]") + checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2))) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + val df3 = sql( + """ + |SELECT + | h2.my_avg(CASE WHEN NAME = 'fred' THEN id + 1 ELSE id END) + |FROM h2.test.people + """.stripMargin) + checkAggregateRemoved(df3) + checkPushedInfo(df3, + "PushedAggregates: [iavg(CASE WHEN NAME = 'fred' THEN ID + 1 ELSE ID END)]," + + " PushedFilters: [], PushedGroupByExpressions: []") + checkAnswer(df3, Seq(Row(2))) + + val df4 = sql( + """ + |SELECT + | name, + | h2.my_avg(CASE WHEN NAME = 'fred' THEN id + 1 ELSE id END) + |FROM h2.test.people + |GROUP BY name + """.stripMargin) + checkAggregateRemoved(df4) + checkPushedInfo(df4, + "PushedAggregates: [iavg(CASE WHEN NAME = 'fred' THEN ID + 1 ELSE ID END)]," + + " PushedFilters: [], PushedGroupByExpressions: [NAME]") + checkAnswer(df4, Seq(Row("fred", 2), Row("mary", 2))) + } + } finally { + JdbcDialects.unregisterDialect(testH2Dialect) + JdbcDialects.registerDialect(H2Dialect) } - assert(e2.getMessage.contains("Undefined function: my_avg2")) } }
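(Editor's illustrative sketch, not part of the patch.) The query shape that motivates this refactor is an aggregate followed by LIMIT/OFFSET, as exercised by the df10 assertions in JDBCV2Suite above. The snippet assumes a SparkSession `spark` configured with the `h2` JDBC catalog exactly as in the suite's sparkConf; the variable name is hypothetical.

// Aggregate followed by LIMIT and OFFSET: with this change the whole pipeline
// is pushed to the JDBC source instead of stopping at the aggregate.
val pushedDown = spark.read
  .table("h2.test.employee")
  .groupBy("DEPT").sum("SALARY")
  .limit(2)
  .offset(1)
// Per the df10 assertions above, the optimized plan is expected to report:
//   PushedAggregates: [SUM(SALARY)], PushedGroupByExpressions: [DEPT],
//   PushedFilters: [], PushedLimit: LIMIT 2, PushedOffset: OFFSET 1
// and the result is the single row (2, 22000.00).
pushedDown.show()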