diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala index 35711e57d0b72..e66e2143dfa0f 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala @@ -106,4 +106,8 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { testStddevSamp(true) testCovarPop() testCovarSamp() + testRegrIntercept() + testRegrSlope() + testRegrR2() + testRegrSXY() } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala index b38f2675243e6..6d03b73766975 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala @@ -107,4 +107,8 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes testCovarPop() testCovarSamp() testCorr() + testRegrIntercept() + testRegrSlope() + testRegrR2() + testRegrSXY() } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index d76e13c1cd421..c0822ab87d00c 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -104,4 +104,12 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT testCovarSamp(true) testCorr() testCorr(true) + testRegrIntercept() + testRegrIntercept(true) + testRegrSlope() + testRegrSlope(true) + testRegrR2() + testRegrR2(true) + testRegrSXY() + testRegrSXY(true) } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index 7cab8cd77df66..d11d35f3ef8b7 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -386,9 +386,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu protected def caseConvert(tableName: String): String = tableName + private def withOrWithout(isDistinct: Boolean): String = if (isDistinct) "with" else "without" + protected def testVarPop(isDistinct: Boolean = false): Unit = { val distinct = if (isDistinct) "DISTINCT " else "" - test(s"scan with aggregate push-down: VAR_POP with distinct: $isDistinct") { + test(s"scan with aggregate push-down: VAR_POP ${withOrWithout(isDistinct)} DISTINCT") { val df = sql(s"SELECT VAR_POP(${distinct}bonus) FROM $catalogAndNamespace." 
+ s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") checkFilterPushed(df) @@ -396,15 +398,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu checkAggregatePushed(df, "VAR_POP") val row = df.collect() assert(row.length === 3) - assert(row(0).getDouble(0) === 10000d) - assert(row(1).getDouble(0) === 2500d) - assert(row(2).getDouble(0) === 0d) + assert(row(0).getDouble(0) === 10000.0) + assert(row(1).getDouble(0) === 2500.0) + assert(row(2).getDouble(0) === 0.0) } } protected def testVarSamp(isDistinct: Boolean = false): Unit = { val distinct = if (isDistinct) "DISTINCT " else "" - test(s"scan with aggregate push-down: VAR_SAMP with distinct: $isDistinct") { + test(s"scan with aggregate push-down: VAR_SAMP ${withOrWithout(isDistinct)} DISTINCT") { val df = sql( s"SELECT VAR_SAMP(${distinct}bonus) FROM $catalogAndNamespace." + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") @@ -413,15 +415,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu checkAggregatePushed(df, "VAR_SAMP") val row = df.collect() assert(row.length === 3) - assert(row(0).getDouble(0) === 20000d) - assert(row(1).getDouble(0) === 5000d) + assert(row(0).getDouble(0) === 20000.0) + assert(row(1).getDouble(0) === 5000.0) assert(row(2).isNullAt(0)) } } protected def testStddevPop(isDistinct: Boolean = false): Unit = { val distinct = if (isDistinct) "DISTINCT " else "" - test(s"scan with aggregate push-down: STDDEV_POP with distinct: $isDistinct") { + test(s"scan with aggregate push-down: STDDEV_POP ${withOrWithout(isDistinct)} DISTINCT") { val df = sql( s"SELECT STDDEV_POP(${distinct}bonus) FROM $catalogAndNamespace." + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") @@ -430,15 +432,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu checkAggregatePushed(df, "STDDEV_POP") val row = df.collect() assert(row.length === 3) - assert(row(0).getDouble(0) === 100d) - assert(row(1).getDouble(0) === 50d) - assert(row(2).getDouble(0) === 0d) + assert(row(0).getDouble(0) === 100.0) + assert(row(1).getDouble(0) === 50.0) + assert(row(2).getDouble(0) === 0.0) } } protected def testStddevSamp(isDistinct: Boolean = false): Unit = { val distinct = if (isDistinct) "DISTINCT " else "" - test(s"scan with aggregate push-down: STDDEV_SAMP with distinct: $isDistinct") { + test(s"scan with aggregate push-down: STDDEV_SAMP ${withOrWithout(isDistinct)} DISTINCT") { val df = sql( s"SELECT STDDEV_SAMP(${distinct}bonus) FROM $catalogAndNamespace." + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") @@ -447,15 +449,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu checkAggregatePushed(df, "STDDEV_SAMP") val row = df.collect() assert(row.length === 3) - assert(row(0).getDouble(0) === 141.4213562373095d) - assert(row(1).getDouble(0) === 70.71067811865476d) + assert(row(0).getDouble(0) === 141.4213562373095) + assert(row(1).getDouble(0) === 70.71067811865476) assert(row(2).isNullAt(0)) } } protected def testCovarPop(isDistinct: Boolean = false): Unit = { val distinct = if (isDistinct) "DISTINCT " else "" - test(s"scan with aggregate push-down: COVAR_POP with distinct: $isDistinct") { + test(s"scan with aggregate push-down: COVAR_POP ${withOrWithout(isDistinct)} DISTINCT") { val df = sql( s"SELECT COVAR_POP(${distinct}bonus, bonus) FROM $catalogAndNamespace." 
+ s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") @@ -464,15 +466,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu checkAggregatePushed(df, "COVAR_POP") val row = df.collect() assert(row.length === 3) - assert(row(0).getDouble(0) === 10000d) - assert(row(1).getDouble(0) === 2500d) - assert(row(2).getDouble(0) === 0d) + assert(row(0).getDouble(0) === 10000.0) + assert(row(1).getDouble(0) === 2500.0) + assert(row(2).getDouble(0) === 0.0) } } protected def testCovarSamp(isDistinct: Boolean = false): Unit = { val distinct = if (isDistinct) "DISTINCT " else "" - test(s"scan with aggregate push-down: COVAR_SAMP with distinct: $isDistinct") { + test(s"scan with aggregate push-down: COVAR_SAMP ${withOrWithout(isDistinct)} DISTINCT") { val df = sql( s"SELECT COVAR_SAMP(${distinct}bonus, bonus) FROM $catalogAndNamespace." + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") @@ -481,15 +483,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu checkAggregatePushed(df, "COVAR_SAMP") val row = df.collect() assert(row.length === 3) - assert(row(0).getDouble(0) === 20000d) - assert(row(1).getDouble(0) === 5000d) + assert(row(0).getDouble(0) === 20000.0) + assert(row(1).getDouble(0) === 5000.0) assert(row(2).isNullAt(0)) } } protected def testCorr(isDistinct: Boolean = false): Unit = { val distinct = if (isDistinct) "DISTINCT " else "" - test(s"scan with aggregate push-down: CORR with distinct: $isDistinct") { + test(s"scan with aggregate push-down: CORR ${withOrWithout(isDistinct)} DISTINCT") { val df = sql( s"SELECT CORR(${distinct}bonus, bonus) FROM $catalogAndNamespace." + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") @@ -498,9 +500,77 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu checkAggregatePushed(df, "CORR") val row = df.collect() assert(row.length === 3) - assert(row(0).getDouble(0) === 1d) - assert(row(1).getDouble(0) === 1d) + assert(row(0).getDouble(0) === 1.0) + assert(row(1).getDouble(0) === 1.0) + assert(row(2).isNullAt(0)) + } + } + + protected def testRegrIntercept(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: REGR_INTERCEPT ${withOrWithout(isDistinct)} DISTINCT") { + val df = sql( + s"SELECT REGR_INTERCEPT(${distinct}bonus, bonus) FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "REGR_INTERCEPT") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 0.0) + assert(row(1).getDouble(0) === 0.0) + assert(row(2).isNullAt(0)) + } + } + + protected def testRegrSlope(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: REGR_SLOPE ${withOrWithout(isDistinct)} DISTINCT") { + val df = sql( + s"SELECT REGR_SLOPE(${distinct}bonus, bonus) FROM $catalogAndNamespace." 
+ + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "REGR_SLOPE") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 1.0) + assert(row(1).getDouble(0) === 1.0) + assert(row(2).isNullAt(0)) + } + } + + protected def testRegrR2(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: REGR_R2 ${withOrWithout(isDistinct)} DISTINCT") { + val df = sql( + s"SELECT REGR_R2(${distinct}bonus, bonus) FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "REGR_R2") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 1.0) + assert(row(1).getDouble(0) === 1.0) assert(row(2).isNullAt(0)) } } + + protected def testRegrSXY(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: REGR_SXY ${withOrWithout(isDistinct)} DISTINCT") { + val df = sql( + s"SELECT REGR_SXY(${distinct}bonus, bonus) FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "REGR_SXY") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 20000.0) + assert(row(1).getDouble(0) === 5000.0) + assert(row(2).getDouble(0) === 0.0) + } + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java index 17e24fa7ad8da..6dfaad0d26eb4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java @@ -23,7 +23,7 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.expressions.filter.Predicate; -import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder; +import org.apache.spark.sql.internal.connector.ToStringSQLBuilder; /** * The general representation of SQL scalar expressions, which contains the upper-cased @@ -398,12 +398,7 @@ public int hashCode() { @Override public String toString() { - V2ExpressionSQLBuilder builder = new V2ExpressionSQLBuilder(); - try { - return builder.build(this); - } catch (Throwable e) { - return name + "(" + - Arrays.stream(children).map(child -> child.toString()).reduce((a,b) -> a + "," + b) + ")"; - } + ToStringSQLBuilder builder = new ToStringSQLBuilder(); + return builder.build(this); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java new file mode 100644 index 0000000000000..8e4155f81b87b --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.Objects; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.internal.connector.ToStringSQLBuilder; + +/** + * The general representation of user defined scalar function, which contains the upper-cased + * function name, canonical function name and all the children expressions. + * + * @since 3.4.0 + */ +@Evolving +public class UserDefinedScalarFunc implements Expression, Serializable { + private String name; + private String canonicalName; + private Expression[] children; + + public UserDefinedScalarFunc(String name, String canonicalName, Expression[] children) { + this.name = name; + this.canonicalName = canonicalName; + this.children = children; + } + + public String name() { return name; } + public String canonicalName() { return canonicalName; } + + @Override + public Expression[] children() { return children; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + UserDefinedScalarFunc that = (UserDefinedScalarFunc) o; + return Objects.equals(name, that.name) && Objects.equals(canonicalName, that.canonicalName) && + Arrays.equals(children, that.children); + } + + @Override + public int hashCode() { + return Objects.hash(name, canonicalName, children); + } + + @Override + public String toString() { + ToStringSQLBuilder builder = new ToStringSQLBuilder(); + return builder.build(this); + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java index 7016644543447..81838074fb136 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java @@ -17,11 +17,9 @@ package org.apache.spark.sql.connector.expressions.aggregate; -import java.util.Arrays; -import java.util.stream.Collectors; - import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.internal.connector.ToStringSQLBuilder; /** * The general implementation of {@link AggregateFunc}, which contains the upper-cased function @@ -47,27 +45,21 @@ public final class GeneralAggregateFunc implements AggregateFunc { private final boolean isDistinct; private final Expression[] children; - public String name() { return name; } - public boolean isDistinct() { return isDistinct; } - public GeneralAggregateFunc(String name, boolean isDistinct, Expression[] children) { this.name = name; this.isDistinct = isDistinct; this.children = children; } + public String name() { return name; } + public boolean isDistinct() { 
return isDistinct; } + @Override public Expression[] children() { return children; } @Override public String toString() { - String inputsString = Arrays.stream(children) - .map(Expression::describe) - .collect(Collectors.joining(", ")); - if (isDistinct) { - return name + "(DISTINCT " + inputsString + ")"; - } else { - return name + "(" + inputsString + ")"; - } + ToStringSQLBuilder builder = new ToStringSQLBuilder(); + return builder.build(this); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java new file mode 100644 index 0000000000000..9a89e7a89c9f9 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.aggregate; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.internal.connector.ToStringSQLBuilder; + +/** + * The general representation of user defined aggregate function, which implements + * {@link AggregateFunc}, contains the upper-cased function name, the canonical function name, + * the `isDistinct` flag and all the inputs. Note that Spark cannot push down aggregate with + * this function partially to the source, but can only push down the entire aggregate. 
+ * + * @since 3.4.0 + */ +@Evolving +public class UserDefinedAggregateFunc implements AggregateFunc { + private final String name; + private String canonicalName; + private final boolean isDistinct; + private final Expression[] children; + + public UserDefinedAggregateFunc( + String name, String canonicalName, boolean isDistinct, Expression[] children) { + this.name = name; + this.canonicalName = canonicalName; + this.isDistinct = isDistinct; + this.children = children; + } + + public String name() { return name; } + public String canonicalName() { return canonicalName; } + public boolean isDistinct() { return isDistinct; } + + @Override + public Expression[] children() { return children; } + + @Override + public String toString() { + ToStringSQLBuilder builder = new ToStringSQLBuilder(); + return builder.build(this); + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java index 528d7ea4cdb1d..60708ede19c8f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -27,6 +27,15 @@ import org.apache.spark.sql.connector.expressions.NamedReference; import org.apache.spark.sql.connector.expressions.GeneralScalarExpression; import org.apache.spark.sql.connector.expressions.Literal; +import org.apache.spark.sql.connector.expressions.UserDefinedScalarFunc; +import org.apache.spark.sql.connector.expressions.aggregate.Avg; +import org.apache.spark.sql.connector.expressions.aggregate.Max; +import org.apache.spark.sql.connector.expressions.aggregate.Min; +import org.apache.spark.sql.connector.expressions.aggregate.Count; +import org.apache.spark.sql.connector.expressions.aggregate.CountStar; +import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc; +import org.apache.spark.sql.connector.expressions.aggregate.Sum; +import org.apache.spark.sql.connector.expressions.aggregate.UserDefinedAggregateFunc; import org.apache.spark.sql.types.DataType; /** @@ -161,6 +170,40 @@ public String build(Expression expr) { default: return visitUnexpectedExpr(expr); } + } else if (expr instanceof Min) { + Min min = (Min) expr; + return visitAggregateFunction("MIN", false, + Arrays.stream(min.children()).map(c -> build(c)).toArray(String[]::new)); + } else if (expr instanceof Max) { + Max max = (Max) expr; + return visitAggregateFunction("MAX", false, + Arrays.stream(max.children()).map(c -> build(c)).toArray(String[]::new)); + } else if (expr instanceof Count) { + Count count = (Count) expr; + return visitAggregateFunction("COUNT", count.isDistinct(), + Arrays.stream(count.children()).map(c -> build(c)).toArray(String[]::new)); + } else if (expr instanceof Sum) { + Sum sum = (Sum) expr; + return visitAggregateFunction("SUM", sum.isDistinct(), + Arrays.stream(sum.children()).map(c -> build(c)).toArray(String[]::new)); + } else if (expr instanceof CountStar) { + return visitAggregateFunction("COUNT", false, new String[]{"*"}); + } else if (expr instanceof Avg) { + Avg avg = (Avg) expr; + return visitAggregateFunction("AVG", avg.isDistinct(), + Arrays.stream(avg.children()).map(c -> build(c)).toArray(String[]::new)); + } else if (expr instanceof GeneralAggregateFunc) { + GeneralAggregateFunc f = (GeneralAggregateFunc) expr; + return visitAggregateFunction(f.name(), f.isDistinct(), + 
Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new)); + } else if (expr instanceof UserDefinedScalarFunc) { + UserDefinedScalarFunc f = (UserDefinedScalarFunc) expr; + return visitUserDefinedScalarFunction(f.name(), f.canonicalName(), + Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new)); + } else if (expr instanceof UserDefinedAggregateFunc) { + UserDefinedAggregateFunc f = (UserDefinedAggregateFunc) expr; + return visitUserDefinedAggregateFunction(f.name(), f.canonicalName(), f.isDistinct(), + Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new)); } else { return visitUnexpectedExpr(expr); } @@ -273,6 +316,28 @@ protected String visitSQLFunction(String funcName, String[] inputs) { return funcName + "(" + Arrays.stream(inputs).collect(Collectors.joining(", ")) + ")"; } + protected String visitAggregateFunction( + String funcName, boolean isDistinct, String[] inputs) { + if (isDistinct) { + return funcName + + "(DISTINCT " + Arrays.stream(inputs).collect(Collectors.joining(", ")) + ")"; + } else { + return funcName + "(" + Arrays.stream(inputs).collect(Collectors.joining(", ")) + ")"; + } + } + + protected String visitUserDefinedScalarFunction( + String funcName, String canonicalName, String[] inputs) { + throw new UnsupportedOperationException( + this.getClass().getSimpleName() + " does not support user defined function: " + funcName); + } + + protected String visitUserDefinedAggregateFunction( + String funcName, String canonicalName, boolean isDistinct, String[] inputs) { + throw new UnsupportedOperationException(this.getClass().getSimpleName() + + " does not support user defined aggregate function: " + funcName); + } + protected String visitUnexpectedExpr(Expression expr) throws IllegalArgumentException { throw new IllegalArgumentException("Unexpected V2 expression: " + expr); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ToStringSQLBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ToStringSQLBuilder.scala new file mode 100644 index 0000000000000..889fdd4ebf291 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ToStringSQLBuilder.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal.connector + +import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder + +/** + * The builder to generate `toString` information of V2 expressions. 
+ */ +class ToStringSQLBuilder extends V2ExpressionSQLBuilder { + override protected def visitUserDefinedScalarFunction( + funcName: String, canonicalName: String, inputs: Array[String]) = + s"""$funcName(${inputs.mkString(", ")})""" + + override protected def visitUserDefinedAggregateFunction( + funcName: String, + canonicalName: String, + isDistinct: Boolean, + inputs: Array[String]): String = { + val distinct = if (isDistinct) "DISTINCT " else "" + s"""$funcName($distinct${inputs.mkString(", ")})""" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 041e7f369c513..8bb65a8804471 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.util import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue} +import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue, UserDefinedScalarFunc} import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate} import org.apache.spark.sql.types.{BooleanType, IntegerType} @@ -398,6 +398,14 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { case YearOfWeek(child) => generateExpression(child).map(v => new V2Extract("YEAR_OF_WEEK", v)) // TODO supports other expressions + case ApplyFunctionExpression(function, children) => + val childrenExpressions = children.flatMap(generateExpression(_)) + if (childrenExpressions.length == children.length) { + Some(new UserDefinedScalarFunc( + function.name(), function.canonicalName(), childrenExpressions.toArray[V2Expression])) + } else { + None + } case _ => None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 8d8e2c26e279e..5709e2e1484df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue} -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ @@ -738,6 +738,14 @@ object DataSourceStrategy 
PushableColumnWithoutNestedColumn(right), _) => Some(new GeneralAggregateFunc("CORR", agg.isDistinct, Array(FieldReference(left), FieldReference(right)))) + case aggregate.V2Aggregator(aggrFunc, children, _, _) => + val translatedExprs = children.flatMap(PushableExpression.unapply(_)) + if (translatedExprs.length == children.length) { + Some(new UserDefinedAggregateFunc(aggrFunc.name(), + aggrFunc.canonicalName(), agg.isDistinct, translatedExprs.toArray[V2Expression])) + } else { + None + } case _ => None } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index f69cf937b099b..e90f59f310fcb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -26,12 +26,12 @@ import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LimitAndOffset, LocalLimit, LogicalPlan, Offset, OffsetAndLimit, Project, Sample, Sort} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder} -import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, GeneralAggregateFunc, Sum} +import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, CountStar, Max, Min, Sum} import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources -import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StructType} +import org.apache.spark.sql.types.{DataType, DecimalType, IntegerType, StructType} import org.apache.spark.sql.util.SchemaUtils._ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper with AliasHelper { @@ -44,6 +44,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit pushDownFilters, pushDownAggregates, pushDownLimitAndOffset, + buildScanWithPushedAggregate, pruneColumns) pushdownRules.foldLeft(plan) { (newPlan, pushDownRule) => @@ -92,189 +93,201 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit def pushDownAggregates(plan: LogicalPlan): LogicalPlan = plan.transform { // update the scan builder with agg pushdown and return a new plan with agg pushed - case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) => - child match { - case ScanOperation(project, filters, sHolder: ScanBuilderHolder) - if filters.isEmpty && CollapseProject.canCollapseExpressions( - resultExpressions, project, alwaysInline = true) => - sHolder.builder match { - case r: SupportsPushDownAggregates => - val aliasMap = getAliasMap(project) - val actualResultExprs = resultExpressions.map(replaceAliasButKeepName(_, aliasMap)) - val actualGroupExprs = groupingExpressions.map(replaceAlias(_, aliasMap)) - - val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] - val aggregates = collectAggregates(actualResultExprs, aggExprToOutputOrdinal) - val normalizedAggregates = DataSourceStrategy.normalizeExprs( - aggregates, 
sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]] - val normalizedGroupingExpressions = DataSourceStrategy.normalizeExprs( - actualGroupExprs, sHolder.relation.output) - val translatedAggregates = DataSourceStrategy.translateAggregation( - normalizedAggregates, normalizedGroupingExpressions) - val (finalResultExpressions, finalAggregates, finalTranslatedAggregates) = { - if (translatedAggregates.isEmpty || - r.supportCompletePushDown(translatedAggregates.get) || - translatedAggregates.get.aggregateExpressions().forall(!_.isInstanceOf[Avg])) { - (actualResultExprs, aggregates, translatedAggregates) - } else { - // scalastyle:off - // The data source doesn't support the complete push-down of this aggregation. - // Here we translate `AVG` to `SUM / COUNT`, so that it's more likely to be - // pushed, completely or partially. - // e.g. TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT avg(c1) FROM t GROUP BY c2; - // The original logical plan is - // Aggregate [c2#10],[avg(c1#9) AS avg(c1)#19] - // +- ScanOperation[...] - // - // After convert avg(c1#9) to sum(c1#9)/count(c1#9) - // we have the following - // Aggregate [c2#10],[sum(c1#9)/count(c1#9) AS avg(c1)#19] - // +- ScanOperation[...] - // scalastyle:on - val newResultExpressions = actualResultExprs.map { expr => - expr.transform { - case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) => - val sum = aggregate.Sum(avg.child).toAggregateExpression(isDistinct) - val count = aggregate.Count(avg.child).toAggregateExpression(isDistinct) - avg.evaluateExpression transform { - case a: Attribute if a.semanticEquals(avg.sum) => - addCastIfNeeded(sum, avg.sum.dataType) - case a: Attribute if a.semanticEquals(avg.count) => - addCastIfNeeded(count, avg.count.dataType) - } - } - }.asInstanceOf[Seq[NamedExpression]] - // Because aggregate expressions changed, translate them again. - aggExprToOutputOrdinal.clear() - val newAggregates = - collectAggregates(newResultExpressions, aggExprToOutputOrdinal) - val newNormalizedAggregates = DataSourceStrategy.normalizeExprs( - newAggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]] - (newResultExpressions, newAggregates, DataSourceStrategy.translateAggregation( - newNormalizedAggregates, normalizedGroupingExpressions)) + case agg: Aggregate => rewriteAggregate(agg) + } + + private def rewriteAggregate(agg: Aggregate): LogicalPlan = agg.child match { + case ScanOperation(project, Nil, holder @ ScanBuilderHolder(_, _, + r: SupportsPushDownAggregates)) if CollapseProject.canCollapseExpressions( + agg.aggregateExpressions, project, alwaysInline = true) => + val aliasMap = getAliasMap(project) + val actualResultExprs = agg.aggregateExpressions.map(replaceAliasButKeepName(_, aliasMap)) + val actualGroupExprs = agg.groupingExpressions.map(replaceAlias(_, aliasMap)) + + val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] + val aggregates = collectAggregates(actualResultExprs, aggExprToOutputOrdinal) + val normalizedAggExprs = DataSourceStrategy.normalizeExprs( + aggregates, holder.relation.output).asInstanceOf[Seq[AggregateExpression]] + val normalizedGroupingExpr = DataSourceStrategy.normalizeExprs( + actualGroupExprs, holder.relation.output) + val translatedAggOpt = DataSourceStrategy.translateAggregation( + normalizedAggExprs, normalizedGroupingExpr) + if (translatedAggOpt.isEmpty) { + // Cannot translate the catalyst aggregate, return the query plan unchanged. 
+ return agg + } + + val (finalResultExprs, finalAggExprs, translatedAgg, canCompletePushDown) = { + if (r.supportCompletePushDown(translatedAggOpt.get)) { + (actualResultExprs, normalizedAggExprs, translatedAggOpt.get, true) + } else if (!translatedAggOpt.get.aggregateExpressions().exists(_.isInstanceOf[Avg])) { + (actualResultExprs, normalizedAggExprs, translatedAggOpt.get, false) + } else { + // scalastyle:off + // The data source doesn't support the complete push-down of this aggregation. + // Here we translate `AVG` to `SUM / COUNT`, so that it's more likely to be + // pushed, completely or partially. + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT avg(c1) FROM t GROUP BY c2; + // The original logical plan is + // Aggregate [c2#10],[avg(c1#9) AS avg(c1)#19] + // +- ScanOperation[...] + // + // After convert avg(c1#9) to sum(c1#9)/count(c1#9) + // we have the following + // Aggregate [c2#10],[sum(c1#9)/count(c1#9) AS avg(c1)#19] + // +- ScanOperation[...] + // scalastyle:on + val newResultExpressions = actualResultExprs.map { expr => + expr.transform { + case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) => + val sum = aggregate.Sum(avg.child).toAggregateExpression(isDistinct) + val count = aggregate.Count(avg.child).toAggregateExpression(isDistinct) + avg.evaluateExpression transform { + case a: Attribute if a.semanticEquals(avg.sum) => + addCastIfNeeded(sum, avg.sum.dataType) + case a: Attribute if a.semanticEquals(avg.count) => + addCastIfNeeded(count, avg.count.dataType) } - } + } + }.asInstanceOf[Seq[NamedExpression]] + // Because aggregate expressions changed, translate them again. + aggExprToOutputOrdinal.clear() + val newAggregates = + collectAggregates(newResultExpressions, aggExprToOutputOrdinal) + val newNormalizedAggExprs = DataSourceStrategy.normalizeExprs( + newAggregates, holder.relation.output).asInstanceOf[Seq[AggregateExpression]] + val newTranslatedAggOpt = DataSourceStrategy.translateAggregation( + newNormalizedAggExprs, normalizedGroupingExpr) + if (newTranslatedAggOpt.isEmpty) { + // Ideally we should never reach here. But if we end up with not able to translate + // new aggregate with AVG replaced by SUM/COUNT, revert to the original one. + (actualResultExprs, normalizedAggExprs, translatedAggOpt.get, false) + } else { + (newResultExpressions, newNormalizedAggExprs, newTranslatedAggOpt.get, + r.supportCompletePushDown(newTranslatedAggOpt.get)) + } + } + } - if (finalTranslatedAggregates.isEmpty) { - aggNode // return original plan node - } else if (!r.supportCompletePushDown(finalTranslatedAggregates.get) && - !supportPartialAggPushDown(finalTranslatedAggregates.get)) { - aggNode // return original plan node - } else { - val pushedAggregates = finalTranslatedAggregates.filter(r.pushAggregation) - if (pushedAggregates.isEmpty) { - aggNode // return original plan node - } else { - // No need to do column pruning because only the aggregate columns are used as - // DataSourceV2ScanRelation output columns. All the other columns are not - // included in the output. - val scan = sHolder.builder.build() - - // scalastyle:off - // use the group by columns and aggregate columns as the output columns - // e.g. 
TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT min(c1), max(c1) FROM t GROUP BY c2; - // Use c2, min(c1), max(c1) as output for DataSourceV2ScanRelation - // We want to have the following logical plan: - // == Optimized Logical Plan == - // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] - // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] - // scalastyle:on - val newOutput = scan.readSchema().toAttributes - assert(newOutput.length == groupingExpressions.length + finalAggregates.length) - val groupByExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] - val groupAttrs = normalizedGroupingExpressions.zip(newOutput).zipWithIndex.map { - case ((a: Attribute, b: Attribute), _) => b.withExprId(a.exprId) - case ((expr, attr), ordinal) => - if (!groupByExprToOutputOrdinal.contains(expr.canonicalized)) { - groupByExprToOutputOrdinal(expr.canonicalized) = ordinal - } - attr - } - val aggOutput = newOutput.drop(groupAttrs.length) - val output = groupAttrs ++ aggOutput - - logInfo( - s""" - |Pushing operators to ${sHolder.relation.name} - |Pushed Aggregate Functions: - | ${pushedAggregates.get.aggregateExpressions.mkString(", ")} - |Pushed Group by: - | ${pushedAggregates.get.groupByExpressions.mkString(", ")} - |Output: ${output.mkString(", ")} - """.stripMargin) - - val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates) - val scanRelation = - DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) - if (r.supportCompletePushDown(pushedAggregates.get)) { - val projectExpressions = finalResultExpressions.map { expr => - expr.transformDown { - case agg: AggregateExpression => - val ordinal = aggExprToOutputOrdinal(agg.canonicalized) - val child = - addCastIfNeeded(aggOutput(ordinal), agg.resultAttribute.dataType) - Alias(child, agg.resultAttribute.name)(agg.resultAttribute.exprId) - case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) => - val ordinal = groupByExprToOutputOrdinal(expr.canonicalized) - addCastIfNeeded(groupAttrs(ordinal), expr.dataType) - } - }.asInstanceOf[Seq[NamedExpression]] - Project(projectExpressions, scanRelation) + if (!canCompletePushDown && !supportPartialAggPushDown(translatedAgg)) { + return agg + } + if (!r.pushAggregation(translatedAgg)) { + return agg + } + + // scalastyle:off + // We name the output columns of group expressions and aggregate functions by + // ordinal: `group_col_0`, `group_col_1`, ..., `agg_func_0`, `agg_func_1`, ... + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t GROUP BY c2; + // Use group_col_0, agg_func_0, agg_func_1 as output for ScanBuilderHolder. + // We want to have the following logical plan: + // == Optimized Logical Plan == + // Aggregate [group_col_0#10], [min(agg_func_0#21) AS min(c1)#17, max(agg_func_1#22) AS max(c1)#18] + // +- ScanBuilderHolder[group_col_0#10, agg_func_0#21, agg_func_1#22] + // Later, we build the `Scan` instance and convert ScanBuilderHolder to DataSourceV2ScanRelation. 
+ // scalastyle:on + val groupOutput = normalizedGroupingExpr.zipWithIndex.map { case (e, i) => + AttributeReference(s"group_col_$i", e.dataType)() + } + val aggOutput = finalAggExprs.zipWithIndex.map { case (e, i) => + AttributeReference(s"agg_func_$i", e.dataType)() + } + val newOutput = groupOutput ++ aggOutput + val groupByExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] + normalizedGroupingExpr.zipWithIndex.foreach { case (expr, ordinal) => + if (!groupByExprToOutputOrdinal.contains(expr.canonicalized)) { + groupByExprToOutputOrdinal(expr.canonicalized) = ordinal + } + } + + holder.pushedAggregate = Some(translatedAgg) + holder.output = newOutput + logInfo( + s""" + |Pushing operators to ${holder.relation.name} + |Pushed Aggregate Functions: + | ${translatedAgg.aggregateExpressions().mkString(", ")} + |Pushed Group by: + | ${translatedAgg.groupByExpressions.mkString(", ")} + """.stripMargin) + + if (canCompletePushDown) { + val projectExpressions = finalResultExprs.map { expr => + expr.transformDown { + case agg: AggregateExpression => + val ordinal = aggExprToOutputOrdinal(agg.canonicalized) + Alias(aggOutput(ordinal), agg.resultAttribute.name)(agg.resultAttribute.exprId) + case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) => + val ordinal = groupByExprToOutputOrdinal(expr.canonicalized) + expr match { + case ne: NamedExpression => Alias(groupOutput(ordinal), ne.name)(ne.exprId) + case _ => groupOutput(ordinal) + } + } + }.asInstanceOf[Seq[NamedExpression]] + Project(projectExpressions, holder) + } else { + // scalastyle:off + // Change the optimized logical plan to reflect the pushed down aggregate + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t GROUP BY c2; + // The original logical plan is + // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c1#9, c2#10] ... + // + // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22] + // we have the following + // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... + // + // We want to change it to + // == Optimized Logical Plan == + // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... + // scalastyle:on + val aggExprs = finalResultExprs.map(_.transform { + case agg: AggregateExpression => + val ordinal = aggExprToOutputOrdinal(agg.canonicalized) + val aggAttribute = aggOutput(ordinal) + val aggFunction: aggregate.AggregateFunction = + agg.aggregateFunction match { + case max: aggregate.Max => + max.copy(child = aggAttribute) + case min: aggregate.Min => + min.copy(child = aggAttribute) + case sum: aggregate.Sum => + // To keep the dataType of `Sum` unchanged, we need to cast the + // data-source-aggregated result to `Sum.child.dataType` if it's decimal. + // See `SumBase.resultType` + val newChild = if (sum.dataType.isInstanceOf[DecimalType]) { + addCastIfNeeded(aggAttribute, sum.child.dataType) } else { - val plan = Aggregate(output.take(groupingExpressions.length), - finalResultExpressions, scanRelation) - - // scalastyle:off - // Change the optimized logical plan to reflect the pushed down aggregate - // e.g. TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT min(c1), max(c1) FROM t GROUP BY c2; - // The original logical plan is - // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] - // +- RelationV2[c1#9, c2#10] ... 
- // - // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22] - // we have the following - // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] - // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... - // - // We want to change it to - // == Optimized Logical Plan == - // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] - // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... - // scalastyle:on - plan.transformExpressions { - case agg: AggregateExpression => - val ordinal = aggExprToOutputOrdinal(agg.canonicalized) - val aggAttribute = aggOutput(ordinal) - val aggFunction: aggregate.AggregateFunction = - agg.aggregateFunction match { - case max: aggregate.Max => - max.copy(child = addCastIfNeeded(aggAttribute, max.child.dataType)) - case min: aggregate.Min => - min.copy(child = addCastIfNeeded(aggAttribute, min.child.dataType)) - case sum: aggregate.Sum => - sum.copy(child = addCastIfNeeded(aggAttribute, sum.child.dataType)) - case _: aggregate.Count => - aggregate.Sum(addCastIfNeeded(aggAttribute, LongType)) - case other => other - } - agg.copy(aggregateFunction = aggFunction) - case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) => - val ordinal = groupByExprToOutputOrdinal(expr.canonicalized) - addCastIfNeeded(groupAttrs(ordinal), expr.dataType) - } + aggAttribute } - } + sum.copy(child = newChild) + case _: aggregate.Count => + aggregate.Sum(aggAttribute) + case other => other } - case _ => aggNode - } - case _ => aggNode + agg.copy(aggregateFunction = aggFunction) + case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) => + val ordinal = groupByExprToOutputOrdinal(expr.canonicalized) + expr match { + case ne: NamedExpression => Alias(groupOutput(ordinal), ne.name)(ne.exprId) + case _ => groupOutput(ordinal) + } + }).asInstanceOf[Seq[NamedExpression]] + Aggregate(groupOutput, aggExprs, holder) } + + case _ => agg } - private def collectAggregates(resultExpressions: Seq[NamedExpression], + private def collectAggregates( + resultExpressions: Seq[NamedExpression], aggExprToOutputOrdinal: mutable.HashMap[Expression, Int]): Seq[AggregateExpression] = { var ordinal = 0 resultExpressions.flatMap { expr => @@ -292,14 +305,12 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit } private def supportPartialAggPushDown(agg: Aggregation): Boolean = { - // We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down. - // If `Sum`, `Count`, `Avg` with distinct, can't do partial agg push down. - agg.aggregateExpressions().isEmpty || agg.aggregateExpressions().exists { + // We can only partially push down min/max/sum/count without DISTINCT. + agg.aggregateExpressions().isEmpty || agg.aggregateExpressions().forall { case sum: Sum => !sum.isDistinct case count: Count => !count.isDistinct - case avg: Avg => !avg.isDistinct - case _: GeneralAggregateFunc => false - case _ => true + case _: Min | _: Max | _: CountStar => true + case _ => false } } @@ -310,6 +321,26 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit Cast(expression, expectedDataType) } + def buildScanWithPushedAggregate(plan: LogicalPlan): LogicalPlan = plan.transform { + case holder: ScanBuilderHolder if holder.pushedAggregate.isDefined => + // No need to do column pruning because only the aggregate columns are used as + // DataSourceV2ScanRelation output columns. All the other columns are not + // included in the output. 
+ val scan = holder.builder.build() + val realOutput = scan.readSchema().toAttributes + assert(realOutput.length == holder.output.length, + "The data source returns unexpected number of columns") + val wrappedScan = getWrappedScan(scan, holder) + val scanRelation = DataSourceV2ScanRelation(holder.relation, wrappedScan, realOutput) + val projectList = realOutput.zip(holder.output).map { case (a1, a2) => + // The data source may return columns with arbitrary data types and it's safer to cast them + // to the expected data type. + assert(Cast.canCast(a1.dataType, a2.dataType)) + Alias(addCastIfNeeded(a1, a2.dataType), a2.name)(a2.exprId) + } + Project(projectList, scanRelation) + } + def pruneColumns(plan: LogicalPlan): LogicalPlan = plan.transform { case ScanOperation(project, filters, sHolder: ScanBuilderHolder) => // column pruning @@ -324,7 +355,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit |Output: ${output.mkString(", ")} """.stripMargin) - val wrappedScan = getWrappedScan(scan, sHolder, Option.empty[Aggregation]) + val wrappedScan = getWrappedScan(scan, sHolder) val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) @@ -376,8 +407,12 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit } (operation, isPushed && !isPartiallyPushed) case s @ Sort(order, _, operation @ ScanOperation(project, filter, sHolder: ScanBuilderHolder)) - if filter.isEmpty && CollapseProject.canCollapseExpressions( - order, project, alwaysInline = true) => + // Without building the Scan, we do not know the resulting column names after aggregate + // push-down, and thus can't push down Top-N which needs to know the ordering column names. + // TODO: we can support simple cases like GROUP BY columns directly and ORDER BY the same + // columns, which we know the resulting column names: the original table columns. 
+ if sHolder.pushedAggregate.isEmpty && filter.isEmpty && + CollapseProject.canCollapseExpressions(order, project, alwaysInline = true) => val aliasMap = getAliasMap(project) val newOrder = order.map(replaceAlias(_, aliasMap)).asInstanceOf[Seq[SortOrder]] val normalizedOrders = DataSourceStrategy.normalizeExprs( @@ -478,10 +513,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit } } - private def getWrappedScan( - scan: Scan, - sHolder: ScanBuilderHolder, - aggregation: Option[Aggregation]): Scan = { + private def getWrappedScan(scan: Scan, sHolder: ScanBuilderHolder): Scan = { scan match { case v1: V1Scan => val pushedFilters = sHolder.builder match { @@ -489,7 +521,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit f.pushedFilters() case _ => Array.empty[sources.Filter] } - val pushedDownOperators = PushedDownOperators(aggregation, sHolder.pushedSample, + val pushedDownOperators = PushedDownOperators(sHolder.pushedAggregate, sHolder.pushedSample, sHolder.pushedLimit, sHolder.pushedOffset, sHolder.sortOrders, sHolder.pushedPredicates) V1ScanWrapper(v1, pushedFilters, pushedDownOperators) case _ => scan @@ -498,7 +530,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit } case class ScanBuilderHolder( - output: Seq[AttributeReference], + var output: Seq[AttributeReference], relation: DataSourceV2Relation, builder: ScanBuilder) extends LeafNode { var pushedLimit: Option[Int] = None @@ -510,6 +542,8 @@ case class ScanBuilderHolder( var pushedSample: Option[TableSampleInfo] = None var pushedPredicates: Seq[Predicate] = Seq.empty[Predicate] + + var pushedAggregate: Option[Aggregation] = None } // A wrapper for v1 scan to carry the translated filters and the handled ones, along with diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala index 9bf25aa0d633f..32a0b4639a83e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala @@ -20,9 +20,11 @@ package org.apache.spark.sql.jdbc import java.sql.{SQLException, Types} import java.util.Locale +import scala.util.control.NonFatal + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} +import org.apache.spark.sql.connector.expressions.Expression import org.apache.spark.sql.types._ private object DB2Dialect extends JdbcDialect { @@ -30,35 +32,47 @@ private object DB2Dialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:db2") + private val distinctUnsupportedAggregateFunctions = + Set("COVAR_POP", "COVAR_SAMP", "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY") + // See https://www.ibm.com/docs/en/db2/11.5?topic=functions-aggregate - override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { - super.compileAggregate(aggFunction).orElse( - aggFunction match { - case f: GeneralAggregateFunc if f.name() == "VAR_POP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VARIANCE($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - 
Some(s"VARIANCE_SAMP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_SAMP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false => - assert(f.children().length == 2) - Some(s"COVARIANCE(${f.children().head}, ${f.children().last})") - case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false => - assert(f.children().length == 2) - Some(s"COVARIANCE_SAMP(${f.children().head}, ${f.children().last})") - case _ => None + private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG", + "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP") ++ distinctUnsupportedAggregateFunctions + private val supportedFunctions = supportedAggregateFunctions + + override def isSupportedFunction(funcName: String): Boolean = + supportedFunctions.contains(funcName) + + class DB2SQLBuilder extends JDBCSQLBuilder { + override def visitAggregateFunction( + funcName: String, isDistinct: Boolean, inputs: Array[String]): String = + if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) { + throw new UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " + + s"support aggregate function: $funcName with DISTINCT"); + } else { + super.visitAggregateFunction(funcName, isDistinct, inputs) } - ) + + override def dialectFunctionName(funcName: String): String = funcName match { + case "VAR_POP" => "VARIANCE" + case "VAR_SAMP" => "VARIANCE_SAMP" + case "STDDEV_POP" => "STDDEV" + case "STDDEV_SAMP" => "STDDEV_SAMP" + case "COVAR_POP" => "COVARIANCE" + case "COVAR_SAMP" => "COVARIANCE_SAMP" + case _ => super.dialectFunctionName(funcName) + } + } + + override def compileExpression(expr: Expression): Option[String] = { + val db2SQLBuilder = new DB2SQLBuilder() + try { + Some(db2SQLBuilder.build(expr)) + } catch { + case NonFatal(e) => + logWarning("Error occurs while compiling V2 expression", e) + None + } } override def getCatalystType( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala index 36c3c6be4a05c..439e0697d9f3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.jdbc import java.sql.Types import java.util.Locale -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types._ @@ -31,25 +30,12 @@ private object DerbyDialect extends JdbcDialect { url.toLowerCase(Locale.ROOT).startsWith("jdbc:derby") // See https://db.apache.org/derby/docs/10.15/ref/index.html - override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { - super.compileAggregate(aggFunction).orElse( - aggFunction match { - case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false => - assert(f.children().length == 1) - Some(s"VAR_POP(${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false => - 
assert(f.children().length == 1) - Some(s"VAR_SAMP(${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false => - assert(f.children().length == 1) - Some(s"STDDEV_POP(${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false => - assert(f.children().length == 1) - Some(s"STDDEV_SAMP(${f.children().head})") - case _ => None - } - ) - } + private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG", + "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP") + private val supportedFunctions = supportedAggregateFunctions + + override def isSupportedFunction(funcName: String): Boolean = + supportedFunctions.contains(funcName) override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala index 945c25cad56b7..5a909b704e24c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.connector.catalog.functions.UnboundFunction import org.apache.spark.sql.connector.expressions.Expression -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DecimalType, ShortType, StringType} @@ -36,7 +35,13 @@ private[sql] object H2Dialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2") - private val supportedFunctions = + private val distinctUnsupportedAggregateFunctions = + Set("COVAR_POP", "COVAR_SAMP", "CORR", "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY") + + private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG", + "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP") ++ distinctUnsupportedAggregateFunctions + + private val supportedFunctions = supportedAggregateFunctions ++ Set("ABS", "COALESCE", "GREATEST", "LEAST", "RAND", "LOG", "LOG10", "LN", "EXP", "POWER", "SQRT", "FLOOR", "CEIL", "ROUND", "SIN", "SINH", "COS", "COSH", "TAN", "TANH", "COT", "ASIN", "ACOS", "ATAN", "ATAN2", "DEGREES", "RADIANS", "SIGN", @@ -45,42 +50,6 @@ private[sql] object H2Dialect extends JdbcDialect { override def isSupportedFunction(funcName: String): Boolean = supportedFunctions.contains(funcName) - override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { - super.compileAggregate(aggFunction).orElse( - aggFunction match { - case f: GeneralAggregateFunc if f.name() == "VAR_POP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_POP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_SAMP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_POP($distinct${f.children().head})") - 
case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_SAMP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "COVAR_POP" => - assert(f.children().length == 2) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"COVAR_POP($distinct${f.children().head}, ${f.children().last})") - case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" => - assert(f.children().length == 2) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"COVAR_SAMP($distinct${f.children().head}, ${f.children().last})") - case f: GeneralAggregateFunc if f.name() == "CORR" => - assert(f.children().length == 2) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"CORR($distinct${f.children().head}, ${f.children().last})") - case _ => None - } - ) - } - override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { case StringType => Option(JdbcType("CLOB", Types.CLOB)) case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN)) @@ -125,9 +94,9 @@ private[sql] object H2Dialect extends JdbcDialect { } override def compileExpression(expr: Expression): Option[String] = { - val jdbcSQLBuilder = new H2JDBCSQLBuilder() + val h2SQLBuilder = new H2SQLBuilder() try { - Some(jdbcSQLBuilder.build(expr)) + Some(h2SQLBuilder.build(expr)) } catch { case NonFatal(e) => logWarning("Error occurs while compiling V2 expression", e) @@ -135,7 +104,15 @@ private[sql] object H2Dialect extends JdbcDialect { } } - class H2JDBCSQLBuilder extends JDBCSQLBuilder { + class H2SQLBuilder extends JDBCSQLBuilder { + override def visitAggregateFunction( + funcName: String, isDistinct: Boolean, inputs: Array[String]): String = + if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) { + throw new UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " + + s"support aggregate function: $funcName with DISTINCT"); + } else { + super.visitAggregateFunction(funcName, isDistinct, inputs) + } override def visitExtract(field: String, source: String): String = { val newField = field match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index d42d4e8fc0bac..3af0e0de89444 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.connector.catalog.functions.UnboundFunction import org.apache.spark.sql.connector.catalog.index.TableIndex import org.apache.spark.sql.connector.expressions.{Expression, Literal, NamedReference} -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, Max, Min, Sum} +import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions, JdbcUtils} @@ -244,7 +244,7 @@ abstract class JdbcDialect extends Serializable with Logging{ override def visitSQLFunction(funcName: String, inputs: Array[String]): String = { if (isSupportedFunction(funcName)) { - s"""$funcName(${inputs.mkString(", ")})""" + s"""${dialectFunctionName(funcName)}(${inputs.mkString(", ")})""" } else { // 
The framework will catch the error and give up the push-down. // Please see `JdbcDialect.compileExpression(expr: Expression)` for more details. @@ -253,6 +253,18 @@ abstract class JdbcDialect extends Serializable with Logging{ } } + override def visitAggregateFunction( + funcName: String, isDistinct: Boolean, inputs: Array[String]): String = { + if (isSupportedFunction(funcName)) { + super.visitAggregateFunction(dialectFunctionName(funcName), isDistinct, inputs) + } else { + throw new UnsupportedOperationException( + s"${this.getClass.getSimpleName} does not support aggregate function: $funcName"); + } + } + + protected def dialectFunctionName(funcName: String): String = funcName + override def visitOverlay(inputs: Array[String]): String = { if (isSupportedFunction("OVERLAY")) { super.visitOverlay(inputs) @@ -303,26 +315,8 @@ abstract class JdbcDialect extends Serializable with Logging{ * @return Converted value. */ @Since("3.3.0") - def compileAggregate(aggFunction: AggregateFunc): Option[String] = { - aggFunction match { - case min: Min => - compileExpression(min.column).map(v => s"MIN($v)") - case max: Max => - compileExpression(max.column).map(v => s"MAX($v)") - case count: Count => - val distinct = if (count.isDistinct) "DISTINCT " else "" - compileExpression(count.column).map(v => s"COUNT($distinct$v)") - case sum: Sum => - val distinct = if (sum.isDistinct) "DISTINCT " else "" - compileExpression(sum.column).map(v => s"SUM($distinct$v)") - case _: CountStar => - Some("COUNT(*)") - case avg: Avg => - val distinct = if (avg.isDistinct) "DISTINCT " else "" - compileExpression(avg.column).map(v => s"AVG($distinct$v)") - case _ => None - } - } + @deprecated("use org.apache.spark.sql.jdbc.JdbcDialect.compileExpression instead.", "3.4.0") + def compileAggregate(aggFunction: AggregateFunc): Option[String] = compileExpression(aggFunction) /** * List the user-defined functions in jdbc dialect. 
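
As a reference for the new extension points introduced above (isSupportedFunction, the JDBCSQLBuilder hooks dialectFunctionName and visitAggregateFunction, and compileExpression replacing the now-deprecated compileAggregate), the following is a minimal sketch of how a dialect is expected to wire them together. ExampleDialect, ExampleSQLBuilder and the jdbc:example URL prefix are hypothetical; the sketch assumes it sits alongside the other dialects in org.apache.spark.sql.jdbc and mirrors the DB2/H2/Oracle pattern in this diff, including the fallback behaviour where a thrown UnsupportedOperationException makes compileExpression return None so the aggregate stays in Spark.

package org.apache.spark.sql.jdbc

import java.util.Locale

import scala.util.control.NonFatal

import org.apache.spark.sql.connector.expressions.Expression

private object ExampleDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean =
    url.toLowerCase(Locale.ROOT).startsWith("jdbc:example")

  // Aggregates this (hypothetical) database only evaluates without DISTINCT.
  private val distinctUnsupportedAggregateFunctions = Set("COVAR_POP", "COVAR_SAMP")

  private val supportedFunctions =
    Set("MAX", "MIN", "SUM", "COUNT", "AVG", "VAR_POP") ++ distinctUnsupportedAggregateFunctions

  override def isSupportedFunction(funcName: String): Boolean =
    supportedFunctions.contains(funcName)

  class ExampleSQLBuilder extends JDBCSQLBuilder {
    // Map Spark's canonical aggregate name to the dialect-specific spelling.
    override def dialectFunctionName(funcName: String): String = funcName match {
      case "VAR_POP" => "VARIANCE"
      case _ => super.dialectFunctionName(funcName)
    }

    // Refuse DISTINCT where the database cannot handle it; the exception is caught
    // in compileExpression below, so the aggregate is simply not pushed down.
    override def visitAggregateFunction(
        funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
      if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) {
        throw new UnsupportedOperationException(
          s"$funcName with DISTINCT is not supported")
      } else {
        super.visitAggregateFunction(funcName, isDistinct, inputs)
      }
  }

  override def compileExpression(expr: Expression): Option[String] = {
    try {
      Some(new ExampleSQLBuilder().build(expr))
    } catch {
      case NonFatal(e) =>
        logWarning("Error occurs while compiling V2 expression", e)
        None
    }
  }
}
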
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 8d2fbec55f919..af35a0575ac8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -20,9 +20,11 @@ package org.apache.spark.sql.jdbc import java.sql.SQLException import java.util.Locale +import scala.util.control.NonFatal + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} +import org.apache.spark.sql.connector.expressions.Expression import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -43,28 +45,32 @@ private object MsSqlServerDialect extends JdbcDialect { // scalastyle:off line.size.limit // See https://docs.microsoft.com/en-us/sql/t-sql/functions/aggregate-functions-transact-sql?view=sql-server-ver15 // scalastyle:on line.size.limit - override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { - super.compileAggregate(aggFunction).orElse( - aggFunction match { - case f: GeneralAggregateFunc if f.name() == "VAR_POP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VARP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDEVP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDEV($distinct${f.children().head})") - case _ => None - } - ) + private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG", + "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP") + private val supportedFunctions = supportedAggregateFunctions + + override def isSupportedFunction(funcName: String): Boolean = + supportedFunctions.contains(funcName) + + class MsSqlServerSQLBuilder extends JDBCSQLBuilder { + override def dialectFunctionName(funcName: String): String = funcName match { + case "VAR_POP" => "VARP" + case "VAR_SAMP" => "VAR" + case "STDDEV_POP" => "STDEVP" + case "STDDEV_SAMP" => "STDEV" + case _ => super.dialectFunctionName(funcName) + } + } + + override def compileExpression(expr: Expression): Option[String] = { + val msSqlServerSQLBuilder = new MsSqlServerSQLBuilder() + try { + Some(msSqlServerSQLBuilder.build(expr)) + } catch { + case NonFatal(e) => + logWarning("Error occurs while compiling V2 expression", e) + None + } } override def getCatalystType( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index 24f9bac74f86d..d850b61383520 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -22,13 +22,13 @@ import java.util import java.util.Locale import scala.collection.mutable.ArrayBuilder +import 
scala.util.control.NonFatal import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException} import org.apache.spark.sql.connector.catalog.index.TableIndex -import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference} -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} +import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NamedReference} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.types.{BooleanType, DataType, FloatType, LongType, MetadataBuilder} @@ -38,25 +38,37 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { override def canHandle(url : String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:mysql") + private val distinctUnsupportedAggregateFunctions = + Set("VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP") + // See https://dev.mysql.com/doc/refman/8.0/en/aggregate-functions.html - override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { - super.compileAggregate(aggFunction).orElse( - aggFunction match { - case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false => - assert(f.children().length == 1) - Some(s"VAR_POP(${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false => - assert(f.children().length == 1) - Some(s"VAR_SAMP(${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false => - assert(f.children().length == 1) - Some(s"STDDEV_POP(${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false => - assert(f.children().length == 1) - Some(s"STDDEV_SAMP(${f.children().head})") - case _ => None + private val supportedAggregateFunctions = + Set("MAX", "MIN", "SUM", "COUNT", "AVG") ++ distinctUnsupportedAggregateFunctions + private val supportedFunctions = supportedAggregateFunctions + + override def isSupportedFunction(funcName: String): Boolean = + supportedFunctions.contains(funcName) + + class MySQLSQLBuilder extends JDBCSQLBuilder { + override def visitAggregateFunction( + funcName: String, isDistinct: Boolean, inputs: Array[String]): String = + if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) { + throw new UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " + + s"support aggregate function: $funcName with DISTINCT"); + } else { + super.visitAggregateFunction(funcName, isDistinct, inputs) } - ) + } + + override def compileExpression(expr: Expression): Option[String] = { + val mysqlSQLBuilder = new MySQLSQLBuilder() + try { + Some(mysqlSQLBuilder.build(expr)) + } catch { + case NonFatal(e) => + logWarning("Error occurs while compiling V2 expression", e) + None + } } override def getCatalystType( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 40333c1757c4a..79ac248d723e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql.jdbc import java.sql.{Date, Timestamp, Types} import java.util.{Locale, TimeZone} 
+import scala.util.control.NonFatal + import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} +import org.apache.spark.sql.connector.expressions.Expression import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -34,36 +36,40 @@ private case object OracleDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:oracle") + private val distinctUnsupportedAggregateFunctions = + Set("VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP", "COVAR_POP", "COVAR_SAMP", "CORR", + "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY") + // scalastyle:off line.size.limit // https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/Aggregate-Functions.html#GUID-62BE676B-AF18-4E63-BD14-25206FEA0848 // scalastyle:on line.size.limit - override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { - super.compileAggregate(aggFunction).orElse( - aggFunction match { - case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false => - assert(f.children().length == 1) - Some(s"VAR_POP(${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false => - assert(f.children().length == 1) - Some(s"VAR_SAMP(${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false => - assert(f.children().length == 1) - Some(s"STDDEV_POP(${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false => - assert(f.children().length == 1) - Some(s"STDDEV_SAMP(${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false => - assert(f.children().length == 2) - Some(s"COVAR_POP(${f.children().head}, ${f.children().last})") - case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false => - assert(f.children().length == 2) - Some(s"COVAR_SAMP(${f.children().head}, ${f.children().last})") - case f: GeneralAggregateFunc if f.name() == "CORR" && f.isDistinct == false => - assert(f.children().length == 2) - Some(s"CORR(${f.children().head}, ${f.children().last})") - case _ => None + private val supportedAggregateFunctions = + Set("MAX", "MIN", "SUM", "COUNT", "AVG") ++ distinctUnsupportedAggregateFunctions + private val supportedFunctions = supportedAggregateFunctions + + override def isSupportedFunction(funcName: String): Boolean = + supportedFunctions.contains(funcName) + + class OracleSQLBuilder extends JDBCSQLBuilder { + override def visitAggregateFunction( + funcName: String, isDistinct: Boolean, inputs: Array[String]): String = + if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) { + throw new UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " + + s"support aggregate function: $funcName with DISTINCT"); + } else { + super.visitAggregateFunction(funcName, isDistinct, inputs) } - ) + } + + override def compileExpression(expr: Expression): Option[String] = { + val oracleSQLBuilder = new OracleSQLBuilder() + try { + Some(oracleSQLBuilder.build(expr)) + } catch { + case NonFatal(e) => + logWarning("Error occurs while compiling V2 expression", e) + None + } } private def supportTimeZoneTypes: Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 
a668d66ee2f9a..6800b5f298c1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NonEmptyNamespaceException, NoSuchIndexException} import org.apache.spark.sql.connector.expressions.NamedReference -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.types._ @@ -37,41 +36,13 @@ private object PostgresDialect extends JdbcDialect with SQLConfHelper { url.toLowerCase(Locale.ROOT).startsWith("jdbc:postgresql") // See https://www.postgresql.org/docs/8.4/functions-aggregate.html - override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { - super.compileAggregate(aggFunction).orElse( - aggFunction match { - case f: GeneralAggregateFunc if f.name() == "VAR_POP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_POP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_SAMP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_POP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_SAMP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "COVAR_POP" => - assert(f.children().length == 2) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"COVAR_POP($distinct${f.children().head}, ${f.children().last})") - case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" => - assert(f.children().length == 2) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"COVAR_SAMP($distinct${f.children().head}, ${f.children().last})") - case f: GeneralAggregateFunc if f.name() == "CORR" => - assert(f.children().length == 2) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"CORR($distinct${f.children().head}, ${f.children().last})") - case _ => None - } - ) - } + private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG", + "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP", "COVAR_POP", "COVAR_SAMP", "CORR", + "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY") + private val supportedFunctions = supportedAggregateFunctions + + override def isSupportedFunction(funcName: String): Boolean = + supportedFunctions.contains(funcName) override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala index 79fb710cf03b3..fba5316c478a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala @@ -19,7 +19,6 @@ 
package org.apache.spark.sql.jdbc import java.util.Locale -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.types._ @@ -31,38 +30,12 @@ private case object TeradataDialect extends JdbcDialect { // scalastyle:off line.size.limit // See https://docs.teradata.com/r/Teradata-VantageTM-SQL-Functions-Expressions-and-Predicates/March-2019/Aggregate-Functions // scalastyle:on line.size.limit - override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { - super.compileAggregate(aggFunction).orElse( - aggFunction match { - case f: GeneralAggregateFunc if f.name() == "VAR_POP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_POP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_SAMP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_POP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => - assert(f.children().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_SAMP($distinct${f.children().head})") - case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false => - assert(f.children().length == 2) - Some(s"COVAR_POP(${f.children().head}, ${f.children().last})") - case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false => - assert(f.children().length == 2) - Some(s"COVAR_SAMP(${f.children().head}, ${f.children().last})") - case f: GeneralAggregateFunc if f.name() == "CORR" && f.isDistinct == false => - assert(f.children().length == 2) - Some(s"CORR(${f.children().head}, ${f.children().last})") - case _ => None - } - ) - } + private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG", + "VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP", "COVAR_POP", "COVAR_SAMP", "CORR") + private val supportedFunctions = supportedAggregateFunctions + + override def isSupportedFunction(funcName: String): Boolean = + supportedFunctions.contains(funcName) override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { case StringType => Some(JdbcType("VARCHAR(255)", java.sql.Types.VARCHAR)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala index 63e43a3f46cd7..7a2f2700c582d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala @@ -35,6 +35,7 @@ import org.apache.spark.unsafe.types.UTF8String object IntAverage extends AggregateFunction[(Int, Int), Int] { override def name(): String = "iavg" + override def canonicalName(): String = "h2.iavg" override def inputTypes(): Array[DataType] = Array(IntegerType) override def resultType(): DataType = IntegerType @@ -63,6 +64,7 @@ object IntAverage extends AggregateFunction[(Int, Int), Int] { object LongAverage extends AggregateFunction[(Long, Long), Long] { override def name(): String = "iavg" + override def canonicalName(): String = "h2.iavg" override def inputTypes(): Array[DataType] = Array(LongType) override def 
resultType(): DataType = LongType @@ -111,6 +113,24 @@ object IntegralAverage extends UnboundFunction { | iavg(bigint) -> bigint""".stripMargin } +case class StrLen(impl: BoundFunction) extends UnboundFunction { + override def name(): String = "strlen" + + override def bind(inputType: StructType): BoundFunction = { + if (inputType.fields.length != 1) { + throw new UnsupportedOperationException("Expect exactly one argument"); + } + inputType.fields(0).dataType match { + case StringType => impl + case _ => + throw new UnsupportedOperationException("Expect StringType") + } + } + + override def description(): String = + "strlen: returns the length of the input string strlen(string) -> int" +} + class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { private val emptyProps: java.util.Map[String, String] = Collections.emptyMap[String, String] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 4c2f67166b9bb..be2c8ce057559 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -20,16 +20,22 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, DriverManager} import java.util.Properties +import scala.util.control.NonFatal + import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.sql.{AnalysisException, DataFrame, ExplainSuiteHelper, QueryTest, Row} +import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest, Row} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, Offset, Sort} -import org.apache.spark.sql.connector.IntegralAverage +import org.apache.spark.sql.connector.{IntegralAverage, StrLen} +import org.apache.spark.sql.connector.catalog.functions.{ScalarFunction, UnboundFunction} +import org.apache.spark.sql.connector.expressions.Expression import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.functions.{abs, acos, asin, atan, atan2, avg, ceil, coalesce, cos, cosh, count, count_distinct, degrees, exp, floor, lit, log => logarithm, log10, not, pow, radians, round, signum, sin, sinh, sqrt, sum, tan, tanh, udf, when} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{DataType, IntegerType, StringType} import org.apache.spark.util.Utils class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper { @@ -39,6 +45,63 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val url = s"jdbc:h2:${tempDir.getCanonicalPath};user=testUser;password=testPass" var conn: java.sql.Connection = null + val testH2Dialect = new JdbcDialect { + override def canHandle(url: String): Boolean = H2Dialect.canHandle(url) + + class H2SQLBuilder extends JDBCSQLBuilder { + override def visitUserDefinedScalarFunction( + funcName: String, canonicalName: String, inputs: Array[String]): String = { + canonicalName match { + case "h2.char_length" => + s"$funcName(${inputs.mkString(", ")})" + case _ => super.visitUserDefinedScalarFunction(funcName, canonicalName, inputs) + } + } + + override def visitUserDefinedAggregateFunction( + 
funcName: String, + canonicalName: String, + isDistinct: Boolean, + inputs: Array[String]): String = { + canonicalName match { + case "h2.iavg" => + if (isDistinct) { + s"AVG(DISTINCT ${inputs.mkString(", ")})" + } else { + s"AVG(${inputs.mkString(", ")})" + } + case _ => + super.visitUserDefinedAggregateFunction(funcName, canonicalName, isDistinct, inputs) + } + } + } + + override def compileExpression(expr: Expression): Option[String] = { + val h2SQLBuilder = new H2SQLBuilder() + try { + Some(h2SQLBuilder.build(expr)) + } catch { + case NonFatal(e) => + logWarning("Error occurs while compiling V2 expression", e) + None + } + } + + override def functions: Seq[(String, UnboundFunction)] = H2Dialect.functions + } + + case object CharLength extends ScalarFunction[Int] { + override def inputTypes(): Array[DataType] = Array(StringType) + override def resultType(): DataType = IntegerType + override def name(): String = "CHAR_LENGTH" + override def canonicalName(): String = "h2.char_length" + + override def produceResult(input: InternalRow): Int = { + val s = input.getString(0) + s.length + } + } + override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.h2", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.h2.url", url) @@ -116,6 +179,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel "('alex', '2022-05-18', '2022-05-18 00:00:00')").executeUpdate() } H2Dialect.registerFunction("my_avg", IntegralAverage) + H2Dialect.registerFunction("my_strlen", StrLen(CharLength)) } override def afterAll(): Unit = { @@ -200,9 +264,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .table("h2.test.employee") .groupBy("DEPT").sum("SALARY") .limit(1) - checkLimitRemoved(df4, false) + checkAggregateRemoved(df4) + checkLimitRemoved(df4) checkPushedInfo(df4, - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT], ") + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedLimit: LIMIT 1") checkAnswer(df4, Seq(Row(1, 19000.00))) val name = udf { (x: String) => x.matches("cat|dav|amy") } @@ -275,9 +343,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .table("h2.test.employee") .groupBy("DEPT").sum("SALARY") .offset(1) - checkOffsetRemoved(df5, false) + checkAggregateRemoved(df5) + checkLimitRemoved(df5) checkPushedInfo(df5, - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT], ") + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedOffset: OFFSET 1") checkAnswer(df5, Seq(Row(2, 22000.00), Row(6, 12000.00))) val name = udf { (x: String) => x.matches("cat|dav|amy") } @@ -412,10 +484,15 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .groupBy("DEPT").sum("SALARY") .limit(2) .offset(1) - checkLimitRemoved(df10, false) - checkOffsetRemoved(df10, false) + checkAggregateRemoved(df10) + checkLimitRemoved(df10) + checkOffsetRemoved(df10) checkPushedInfo(df10, - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT], ") + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedLimit: LIMIT 2", + "PushedOffset: OFFSET 1") checkAnswer(df10, Seq(Row(2, 22000.00))) val name = udf { (x: String) => x.matches("cat|dav|amy") } @@ -547,10 +624,15 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with 
ExplainSuiteHel checkAnswer(df9, Seq(Row(2, "david", 10000.00, 1300.0, true))) val df10 = sql("SELECT dept, sum(salary) FROM h2.test.employee group by dept LIMIT 1 OFFSET 1") - checkLimitRemoved(df10, false) - checkOffsetRemoved(df10, false) + checkAggregateRemoved(df10) + checkLimitRemoved(df10) + checkOffsetRemoved(df10) checkPushedInfo(df10, - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT], ") + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedLimit: LIMIT 2", + "PushedOffset: OFFSET 1") checkAnswer(df10, Seq(Row(2, 22000.00))) val name = udf { (x: String) => x.matches("cat|dav|amy") } @@ -791,11 +873,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val df11 = sql( """ |SELECT * FROM h2.test.employee - |WHERE GREATEST(bonus, 1100) > 1200 AND LEAST(salary, 10000) > 9000 AND RAND(1) < 1 + |WHERE GREATEST(bonus, 1100) > 1200 AND RAND(1) < bonus |""".stripMargin) checkFiltersRemoved(df11) checkPushedInfo(df11, "PushedFilters: " + - "[(GREATEST(BONUS, 1100.0)) > 1200.0, (LEAST(SALARY, 10000.00)) > 9000.00, RAND(1) < 1.0]") + "[BONUS IS NOT NULL, (GREATEST(BONUS, 1100.0)) > 1200.0, RAND(1) < BONUS]") checkAnswer(df11, Row(2, "david", 10000, 1300, true)) val df12 = sql( @@ -1051,6 +1133,33 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df8, Seq(Row("alex"))) } + test("scan with filter push-down with UDF") { + JdbcDialects.unregisterDialect(H2Dialect) + try { + JdbcDialects.registerDialect(testH2Dialect) + val df1 = sql("SELECT * FROM h2.test.people where h2.my_strlen(name) > 2") + checkFiltersRemoved(df1) + checkPushedInfo(df1, "PushedFilters: [CHAR_LENGTH(NAME) > 2],") + checkAnswer(df1, Seq(Row("fred", 1), Row("mary", 2))) + + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + val df2 = sql( + """ + |SELECT * + |FROM h2.test.people + |WHERE h2.my_strlen(CASE WHEN NAME = 'fred' THEN NAME ELSE "abc" END) > 2 + """.stripMargin) + checkFiltersRemoved(df2) + checkPushedInfo(df2, + "PushedFilters: [CHAR_LENGTH(CASE WHEN NAME = 'fred' THEN NAME ELSE 'abc' END) > 2],") + checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2))) + } + } finally { + JdbcDialects.unregisterDialect(testH2Dialect) + JdbcDialects.registerDialect(H2Dialect) + } + } + test("scan with column pruning") { val df = spark.table("h2.test.people").select("id") checkSchemaNames(df, Seq("ID")) @@ -1620,43 +1729,81 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } test("scan with aggregate push-down: VAR_POP VAR_SAMP with filter and group by") { - val df = sql("SELECT VAR_POP(bonus), VAR_SAMP(bonus) FROM h2.test.employee WHERE dept > 0" + - " GROUP BY DePt") + val df = sql( + """ + |SELECT + | VAR_POP(bonus), + | VAR_POP(DISTINCT bonus), + | VAR_SAMP(bonus), + | VAR_SAMP(DISTINCT bonus) + |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin) checkFiltersRemoved(df) checkAggregateRemoved(df) - checkPushedInfo(df, "PushedAggregates: [VAR_POP(BONUS), VAR_SAMP(BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]") - checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) + checkPushedInfo(df, + """ + |PushedAggregates: [VAR_POP(BONUS), VAR_POP(DISTINCT BONUS), + |VAR_SAMP(BONUS), VAR_SAMP(DISTINCT BONUS)], + |PushedFilters: [DEPT IS NOT NULL, DEPT > 0], + |PushedGroupByExpressions: [DEPT], + |""".stripMargin.replaceAll("\n", " ")) + checkAnswer(df, 
Seq(Row(10000d, 10000d, 20000d, 20000d), + Row(2500d, 2500d, 5000d, 5000d), Row(0d, 0d, null, null))) } test("scan with aggregate push-down: STDDEV_POP STDDEV_SAMP with filter and group by") { - val df = sql("SELECT STDDEV_POP(bonus), STDDEV_SAMP(bonus) FROM h2.test.employee" + - " WHERE dept > 0 GROUP BY DePt") + val df = sql( + """ + |SELECT + | STDDEV_POP(bonus), + | STDDEV_POP(DISTINCT bonus), + | STDDEV_SAMP(bonus), + | STDDEV_SAMP(DISTINCT bonus) + |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin) checkFiltersRemoved(df) checkAggregateRemoved(df) - checkPushedInfo(df, "PushedAggregates: [STDDEV_POP(BONUS), STDDEV_SAMP(BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]") - checkAnswer(df, Seq(Row(100d, 141.4213562373095d), Row(50d, 70.71067811865476d), Row(0d, null))) + checkPushedInfo(df, + """ + |PushedAggregates: [STDDEV_POP(BONUS), STDDEV_POP(DISTINCT BONUS), + |STDDEV_SAMP(BONUS), STDDEV_SAMP(DISTINCT BONUS)], + |PushedFilters: [DEPT IS NOT NULL, DEPT > 0], + |PushedGroupByExpressions: [DEPT], + |""".stripMargin.replaceAll("\n", " ")) + checkAnswer(df, Seq(Row(100d, 100d, 141.4213562373095d, 141.4213562373095d), + Row(50d, 50d, 70.71067811865476d, 70.71067811865476d), Row(0d, 0d, null, null))) } test("scan with aggregate push-down: COVAR_POP COVAR_SAMP with filter and group by") { - val df = sql("SELECT COVAR_POP(bonus, bonus), COVAR_SAMP(bonus, bonus)" + + val df1 = sql("SELECT COVAR_POP(bonus, bonus), COVAR_SAMP(bonus, bonus)" + " FROM h2.test.employee WHERE dept > 0 GROUP BY DePt") - checkFiltersRemoved(df) - checkAggregateRemoved(df) - checkPushedInfo(df, "PushedAggregates: [COVAR_POP(BONUS, BONUS), COVAR_SAMP(BONUS, BONUS)], " + + checkFiltersRemoved(df1) + checkAggregateRemoved(df1) + checkPushedInfo(df1, "PushedAggregates: [COVAR_POP(BONUS, BONUS), COVAR_SAMP(BONUS, BONUS)], " + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]") - checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) + checkAnswer(df1, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) + + val df2 = sql("SELECT COVAR_POP(DISTINCT bonus, bonus), COVAR_SAMP(DISTINCT bonus, bonus)" + + " FROM h2.test.employee WHERE dept > 0 GROUP BY DePt") + checkFiltersRemoved(df2) + checkAggregateRemoved(df2, false) + checkPushedInfo(df2, "PushedFilters: [DEPT IS NOT NULL, DEPT > 0]") + checkAnswer(df2, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) } test("scan with aggregate push-down: CORR with filter and group by") { - val df = sql("SELECT CORR(bonus, bonus) FROM h2.test.employee WHERE dept > 0" + + val df1 = sql("SELECT CORR(bonus, bonus) FROM h2.test.employee WHERE dept > 0" + " GROUP BY DePt") - checkFiltersRemoved(df) - checkAggregateRemoved(df) - checkPushedInfo(df, "PushedAggregates: [CORR(BONUS, BONUS)], " + + checkFiltersRemoved(df1) + checkAggregateRemoved(df1) + checkPushedInfo(df1, "PushedAggregates: [CORR(BONUS, BONUS)], " + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]") - checkAnswer(df, Seq(Row(1d), Row(1d), Row(null))) + checkAnswer(df1, Seq(Row(1d), Row(1d), Row(null))) + + val df2 = sql("SELECT CORR(DISTINCT bonus, bonus) FROM h2.test.employee WHERE dept > 0" + + " GROUP BY DePt") + checkFiltersRemoved(df2) + checkAggregateRemoved(df2, false) + checkPushedInfo(df2, "PushedFilters: [DEPT IS NOT NULL, DEPT > 0]") + checkAnswer(df2, Seq(Row(1d), Row(1d), Row(null))) } test("scan with aggregate push-down: aggregate over alias push down") { 
@@ -1984,17 +2131,51 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } } - test("register dialect specific functions") { - val df = sql("SELECT h2.my_avg(id) FROM h2.test.people") - checkAggregateRemoved(df, false) - checkAnswer(df, Row(1) :: Nil) - val e1 = intercept[AnalysisException] { - checkAnswer(sql("SELECT h2.test.my_avg2(id) FROM h2.test.people"), Seq.empty) - } - assert(e1.getMessage.contains("Undefined function: 'my_avg2'")) - val e2 = intercept[AnalysisException] { - checkAnswer(sql("SELECT h2.my_avg2(id) FROM h2.test.people"), Seq.empty) + test("scan with aggregate push-down: complete push-down UDAF") { + JdbcDialects.unregisterDialect(H2Dialect) + try { + JdbcDialects.registerDialect(testH2Dialect) + val df1 = sql("SELECT h2.my_avg(id) FROM h2.test.people") + checkAggregateRemoved(df1) + checkPushedInfo(df1, + "PushedAggregates: [iavg(ID)], PushedFilters: [], PushedGroupByExpressions: []") + checkAnswer(df1, Seq(Row(1))) + + val df2 = sql("SELECT name, h2.my_avg(id) FROM h2.test.people group by name") + checkAggregateRemoved(df2) + checkPushedInfo(df2, + "PushedAggregates: [iavg(ID)], PushedFilters: [], PushedGroupByExpressions: [NAME]") + checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2))) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + val df3 = sql( + """ + |SELECT + | h2.my_avg(CASE WHEN NAME = 'fred' THEN id + 1 ELSE id END) + |FROM h2.test.people + """.stripMargin) + checkAggregateRemoved(df3) + checkPushedInfo(df3, + "PushedAggregates: [iavg(CASE WHEN NAME = 'fred' THEN ID + 1 ELSE ID END)]," + + " PushedFilters: [], PushedGroupByExpressions: []") + checkAnswer(df3, Seq(Row(2))) + + val df4 = sql( + """ + |SELECT + | name, + | h2.my_avg(CASE WHEN NAME = 'fred' THEN id + 1 ELSE id END) + |FROM h2.test.people + |GROUP BY name + """.stripMargin) + checkAggregateRemoved(df4) + checkPushedInfo(df4, + "PushedAggregates: [iavg(CASE WHEN NAME = 'fred' THEN ID + 1 ELSE ID END)]," + + " PushedFilters: [], PushedGroupByExpressions: [NAME]") + checkAnswer(df4, Seq(Row("fred", 2), Row("mary", 2))) + } + } finally { + JdbcDialects.unregisterDialect(testH2Dialect) + JdbcDialects.registerDialect(H2Dialect) } - assert(e2.getMessage.contains("Undefined function: my_avg2")) } }
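
For context on the new JDBCV2Suite cases above, here is a hedged usage sketch of the UDF push-down path. It reuses testH2Dialect, StrLen and CharLength exactly as defined in this diff; the table name (h2.test.people) and the Spark session mirror the suite's fixtures and are illustrative rather than prescriptive, and the expected PushedFilters output is taken from the test assertions.

// Swap in the test dialect whose H2SQLBuilder knows how to render the UDFs,
// and expose the UDF through the dialect's function catalog, as the suite does.
JdbcDialects.unregisterDialect(H2Dialect)
JdbcDialects.registerDialect(testH2Dialect)
H2Dialect.registerFunction("my_strlen", StrLen(CharLength))
try {
  // h2.my_strlen resolves against the dialect's functions; the builder recognizes
  // canonicalName "h2.char_length" and emits CHAR_LENGTH(NAME), so the predicate
  // is compiled and shipped to H2 instead of being evaluated by Spark.
  val df = spark.sql("SELECT * FROM h2.test.people WHERE h2.my_strlen(name) > 2")
  df.explain("formatted") // expected to show PushedFilters: [CHAR_LENGTH(NAME) > 2]
} finally {
  // Restore the built-in dialect so later queries see the default behaviour.
  JdbcDialects.unregisterDialect(testH2Dialect)
  JdbcDialects.registerDialect(H2Dialect)
}
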