From 263994ad2832e3be79fab8fa91f34c3c89b8e7e5 Mon Sep 17 00:00:00 2001 From: "praveenkrishna.d" Date: Fri, 17 May 2024 18:31:47 +0530 Subject: [PATCH 1/3] Use QueryBuilder injected by Guice in H2 connector --- .../io/trino/plugin/jdbc/TestJdbcConnectionCreation.java | 4 ++-- .../java/io/trino/plugin/jdbc/TestJdbcTableProperties.java | 2 +- .../test/java/io/trino/plugin/jdbc/TestingH2JdbcClient.java | 6 +++--- .../test/java/io/trino/plugin/jdbc/TestingH2JdbcModule.java | 6 +++--- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcConnectionCreation.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcConnectionCreation.java index cb4c090352743..9629bf8a9a44f 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcConnectionCreation.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcConnectionCreation.java @@ -93,9 +93,9 @@ public void configure(Binder binder) {} @Provides @Singleton @ForBaseJdbc - public static JdbcClient provideJdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, IdentifierMapping identifierMapping) + public static JdbcClient provideJdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping) { - return new TestingH2JdbcClient(config, connectionFactory, identifierMapping); + return new TestingH2JdbcClient(config, connectionFactory, queryBuilder, identifierMapping); } @Provides diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcTableProperties.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcTableProperties.java index 42283f8cf5025..12d20d1daafd4 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcTableProperties.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcTableProperties.java @@ -41,7 +41,7 @@ public class TestJdbcTableProperties protected QueryRunner createQueryRunner() throws Exception { - TestingH2JdbcModule module = new TestingH2JdbcModule((config, connectionFactory, identifierMapping) -> new TestingH2JdbcClient(config, connectionFactory, identifierMapping) + TestingH2JdbcModule module = new TestingH2JdbcModule((config, connectionFactory, queryBuilder, identifierMapping) -> new TestingH2JdbcClient(config, connectionFactory, queryBuilder, identifierMapping) { @Override public Map getTableProperties(ConnectorSession session, JdbcTableHandle tableHandle) diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcClient.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcClient.java index 6749eea8c367a..3fd55790ffd79 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcClient.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcClient.java @@ -94,12 +94,12 @@ class TestingH2JdbcClient public TestingH2JdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFactory) { - this(config, connectionFactory, new DefaultIdentifierMapping()); + this(config, connectionFactory, new DefaultQueryBuilder(RemoteQueryModifier.NONE), new DefaultIdentifierMapping()); } - public TestingH2JdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, IdentifierMapping identifierMapping) + public TestingH2JdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping) { - super("\"", connectionFactory, new DefaultQueryBuilder(RemoteQueryModifier.NONE), config.getJdbcTypesMappedToVarchar(), identifierMapping, RemoteQueryModifier.NONE, false); + super("\"", connectionFactory, queryBuilder, config.getJdbcTypesMappedToVarchar(), identifierMapping, RemoteQueryModifier.NONE, false); } @Override diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcModule.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcModule.java index 246c20cc5886b..babaae7e733ae 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcModule.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcModule.java @@ -56,9 +56,9 @@ public void configure(Binder binder) @Provides @Singleton @ForBaseJdbc - public JdbcClient provideJdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, IdentifierMapping identifierMapping) + public JdbcClient provideJdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping) { - return testingH2JdbcClientFactory.create(config, connectionFactory, identifierMapping); + return testingH2JdbcClientFactory.create(config, connectionFactory, queryBuilder, identifierMapping); } @Provides @@ -83,6 +83,6 @@ public static String createH2ConnectionUrl() public interface TestingH2JdbcClientFactory { - TestingH2JdbcClient create(BaseJdbcConfig config, ConnectionFactory connectionFactory, IdentifierMapping identifierMapping); + TestingH2JdbcClient create(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping); } } From 17ab0f5aeda2a63670416f6241572062dc36f4f4 Mon Sep 17 00:00:00 2001 From: "praveenkrishna.d" Date: Fri, 17 May 2024 20:38:06 +0530 Subject: [PATCH 2/3] Remove redundant usage of RewriteComparison rule For equals and not equals rewrite we capture them as GenericRewrite expression and usage of RewriteComparison rule is redundant here. --- .../src/main/java/io/trino/plugin/ignite/IgniteClient.java | 3 --- 1 file changed, 3 deletions(-) diff --git a/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java b/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java index 417dae74d08b4..0ead923826d5c 100644 --- a/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java +++ b/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java @@ -45,10 +45,8 @@ import io.trino.plugin.jdbc.aggregation.ImplementCountDistinct; import io.trino.plugin.jdbc.aggregation.ImplementMinMax; import io.trino.plugin.jdbc.aggregation.ImplementSum; -import io.trino.plugin.jdbc.expression.ComparisonOperator; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; import io.trino.plugin.jdbc.expression.ParameterizedExpression; -import io.trino.plugin.jdbc.expression.RewriteComparison; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.spi.TrinoException; import io.trino.spi.connector.AggregateFunction; @@ -164,7 +162,6 @@ public IgniteClient( JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() - .add(new RewriteComparison(ImmutableSet.of(ComparisonOperator.EQUAL, ComparisonOperator.NOT_EQUAL))) .addStandardRules(this::quoted) .map("$equal(left, right)").to("left = right") .map("$not_equal(left, right)").to("left <> right") From 1caffaff4cb2b0445b2b811f6859c24352f9cd47 Mon Sep 17 00:00:00 2001 From: "praveenkrishna.d" Date: Fri, 10 May 2024 20:42:46 +0530 Subject: [PATCH 3/3] Handle NaN when pushing filter for JDBC connector Some JDBC based sources like PotgreSql, Oracle, Ignite etc treats NaN values as equal, and greater than all non-NaN values. This is different from Trino's behaviour where NaN values are not equal, and they are not greater than or lesser than non-NaN values which results in in-correct results when they are pushed down to the underlying datasource. So we push down additional condition to ensure that NaN values are not considered in comparison operations. --- .../plugin/jdbc/NaNSpecificQueryBuilder.java | 47 +++++++ .../plugin/jdbc/BaseJdbcConnectorTest.java | 129 ++++++++++++++++++ .../plugin/jdbc/TestingH2JdbcClient.java | 2 +- .../plugin/jdbc/TestingH2JdbcModule.java | 2 + .../TestClickHouseConnectorTest.java | 98 +++++++++++++ .../plugin/druid/TestDruidConnectorTest.java | 7 + .../io/trino/plugin/ignite/IgniteClient.java | 22 ++- .../plugin/ignite/IgniteClientModule.java | 3 + .../ignite/TestIgniteConnectorTest.java | 2 + .../mariadb/BaseMariaDbConnectorTest.java | 7 + .../plugin/mysql/BaseMySqlConnectorTest.java | 7 + .../io/trino/plugin/oracle/OracleClient.java | 21 ++- .../plugin/oracle/OracleClientModule.java | 3 + .../oracle/BaseOracleConnectorTest.java | 82 +++++++++++ .../phoenix5/TestPhoenixConnectorTest.java | 7 + .../CollationAwareQueryBuilder.java | 4 +- .../plugin/postgresql/PostgreSqlClient.java | 22 ++- .../sqlserver/BaseSqlServerConnectorTest.java | 9 ++ .../testing/TestingConnectorBehavior.java | 1 + 19 files changed, 452 insertions(+), 23 deletions(-) create mode 100644 plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/NaNSpecificQueryBuilder.java diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/NaNSpecificQueryBuilder.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/NaNSpecificQueryBuilder.java new file mode 100644 index 0000000000000..1da2cfea35a01 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/NaNSpecificQueryBuilder.java @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc; + +import com.google.inject.Inject; +import io.trino.plugin.jdbc.logging.RemoteQueryModifier; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.type.Type; + +import java.util.Optional; +import java.util.function.Consumer; + +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.RealType.REAL; +import static java.lang.String.format; + +public class NaNSpecificQueryBuilder + extends DefaultQueryBuilder +{ + @Inject + public NaNSpecificQueryBuilder(RemoteQueryModifier queryModifier) + { + super(queryModifier); + } + + @Override + protected String toPredicate(JdbcClient client, ConnectorSession session, JdbcColumnHandle column, JdbcTypeHandle jdbcType, Type type, WriteFunction writeFunction, String operator, Object value, Consumer accumulator) + { + if ((type == REAL || type == DOUBLE) && (operator.equals(">") || operator.equals(">="))) { + accumulator.accept(new QueryParameter(jdbcType, type, Optional.of(value))); + return format("((%s %s %s) AND (%s <> 'NaN'))", client.quoted(column.getColumnName()), operator, writeFunction.getBindExpression(), client.quoted(column.getColumnName())); + } + + return super.toPredicate(client, session, column, jdbcType, type, writeFunction, operator, value, accumulator); + } +} diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java index 6684a6e11f21e..2d60ba77a1a18 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java @@ -98,9 +98,11 @@ import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_JOIN_PUSHDOWN_WITH_VARCHAR_INEQUALITY; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_LIMIT_PUSHDOWN; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_MERGE; +import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_NAN_INFINITY; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_NATIVE_QUERY; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_NOT_NULL_CONSTRAINT; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_PREDICATE_ARITHMETIC_EXPRESSION_PUSHDOWN; +import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN_WITH_LIKE; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_PREDICATE_PUSHDOWN; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY; @@ -193,6 +195,127 @@ public void testCharTrailingSpace() } } + @Test + public void testSpecialValueOnApproximateNumericColumn() + { + if (!hasBehavior(SUPPORTS_NAN_INFINITY)) { + assertThatThrownBy(() -> assertUpdate("CREATE TABLE spl_values_nan AS SELECT CAST(Nan() AS REAL) real_col, Nan() AS double_col")) + .satisfies(this::verifyApproximateNumericSpecialValueFailure); + assertThatThrownBy(() -> assertUpdate("CREATE TABLE spl_values_infinity AS SELECT CAST(Infinity() AS REAL) real_col, Infinity() AS double_col")) + .satisfies(this::verifyApproximateNumericSpecialValueFailure); + assertThatThrownBy(() -> assertUpdate("CREATE TABLE spl_values_negative_infinity AS SELECT CAST(-Infinity() AS REAL) real_col, -Infinity() AS double_col")) + .satisfies(this::verifyApproximateNumericSpecialValueFailure); + return; + } + try (TestTable table = new TestTable( + getQueryRunner()::execute, + "spl_approx_numeric", + "(c_varchar VARCHAR, c_real REAL, c_real_2 REAL, c_double DOUBLE, c_double_2 DOUBLE)", + List.of( + "'1', NaN(), REAL '1', Nan(), DOUBLE '1'", + "'2', -Infinity(), REAL '1', -Infinity(), DOUBLE '1'", + "'3', Infinity(), REAL '1', Infinity(), DOUBLE '1'", + "'4', Nan(), Nan(), NaN(), Nan()"))) { + String tableName = table.getName(); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > 1 AND c_double > 1")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES '3'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real < 1 AND c_double < 1")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES '2'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real < Infinity() AND c_double < Infinity()")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES '2'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > -Infinity() AND c_double > -Infinity()")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES '3'"); + + if (hasBehavior(SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN)) { + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > c_real_2 AND c_double > c_double_2")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES '3'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real IS DISTINCT FROM c_real_2 AND c_double IS DISTINCT FROM c_double_2")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES '1', '2', '3'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real_2 IS DISTINCT FROM c_real AND c_double_2 IS DISTINCT FROM c_double")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES '1', '2', '3'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > c_real_2 OR c_double > c_double_2")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES '3'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real = c_real_2 OR c_double = c_double_2")) + .isFullyPushedDown() + .returnsEmptyResult(); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real <> c_real_2 OR c_double <> c_double_2")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES '1', '2', '3', '4'"); + } + else { + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > c_real_2 AND c_double > c_double_2")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES '3'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real IS DISTINCT FROM c_real_2 AND c_double IS DISTINCT FROM c_double_2")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES '1', '2', '3'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real_2 IS DISTINCT FROM c_real AND c_double_2 IS DISTINCT FROM c_double")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES '1', '2', '3'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > c_real_2 OR c_double > c_double_2")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES '3'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real = c_real_2 OR c_double = c_double_2")) + .isNotFullyPushedDown(FilterNode.class) + .returnsEmptyResult(); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real <> c_real_2 OR c_double <> c_double_2")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES '1', '2', '3', '4'"); + } + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real < 1 OR c_double < 1")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES '2'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real < Infinity() OR c_double < Infinity()")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES '2'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > -Infinity() OR c_double > -Infinity()")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES '3'"); + } + } + @Test public void testAggregationPushdown() { @@ -2204,6 +2327,12 @@ protected void assertDynamicFiltering(@Language("SQL") String sql, JoinDistribut assertDynamicFiltering(sql, joinDistributionType, true); } + protected void verifyApproximateNumericSpecialValueFailure(Throwable e) + { + throw new AssertionError("Unexpected special value", e); + } + + private void assertNoDynamicFiltering(@Language("SQL") String sql) { assertDynamicFiltering(sql, PARTITIONED, false); diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcClient.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcClient.java index 3fd55790ffd79..cbf09abb58496 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcClient.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcClient.java @@ -94,7 +94,7 @@ class TestingH2JdbcClient public TestingH2JdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFactory) { - this(config, connectionFactory, new DefaultQueryBuilder(RemoteQueryModifier.NONE), new DefaultIdentifierMapping()); + this(config, connectionFactory, new NaNSpecificQueryBuilder(RemoteQueryModifier.NONE), new DefaultIdentifierMapping()); } public TestingH2JdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping) diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcModule.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcModule.java index babaae7e733ae..8c6b62c44588a 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcModule.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcModule.java @@ -29,6 +29,7 @@ import java.util.concurrent.ThreadLocalRandom; import static com.google.inject.multibindings.Multibinder.newSetBinder; +import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -51,6 +52,7 @@ public TestingH2JdbcModule(TestingH2JdbcClientFactory testingH2JdbcClientFactory public void configure(Binder binder) { newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class).in(Scopes.SINGLETON); + newOptionalBinder(binder, QueryBuilder.class).setBinding().to(NaNSpecificQueryBuilder.class).in(Scopes.SINGLETON); } @Provides diff --git a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java index 2e9e297d3884f..4af4d105d929b 100644 --- a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java +++ b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableMap; import io.trino.plugin.jdbc.BaseJdbcConnectorTest; import io.trino.sql.planner.plan.AggregationNode; +import io.trino.sql.planner.plan.FilterNode; import io.trino.testing.MaterializedResult; import io.trino.testing.QueryRunner; import io.trino.testing.TestingConnectorBehavior; @@ -161,6 +162,103 @@ public void testRenameColumnName() { } + @Test + @Override + // Clickhouse doesn't support push down on real columns + public void testSpecialValueOnApproximateNumericColumn() + { + try (TestTable table = new TestTable( + getQueryRunner()::execute, + "spl_approx_numeric", + "(c_varchar VARCHAR, c_real REAL, c_real_2 REAL, c_double DOUBLE, c_double_2 DOUBLE)", + List.of( + "'1', NaN(), REAL '1', Nan(), DOUBLE '1'", + "'2', -Infinity(), REAL '1', -Infinity(), DOUBLE '1'", + "'3', Infinity(), REAL '1', Infinity(), DOUBLE '1'", + "'4', Nan(), Nan(), NaN(), Nan()"))) { + String tableName = table.getName(); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_double > 1")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES '3'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > 1 AND c_double > 1")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES '3'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_double < 1")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES '2'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real < 1 AND c_double < 1")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES '2'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_double < Infinity()")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES '2'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real < Infinity() AND c_double < Infinity()")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES '2'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > -Infinity() AND c_double > -Infinity()")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES '3'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_double > -Infinity()")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES '3'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > c_real_2 AND c_double > c_double_2")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES '3'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real IS DISTINCT FROM c_real_2 AND c_double IS DISTINCT FROM c_double_2")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES '1', '2', '3'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real_2 IS DISTINCT FROM c_real AND c_double_2 IS DISTINCT FROM c_double")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES '1', '2', '3'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > c_real_2 OR c_double > c_double_2")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES '3'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real = c_real_2 OR c_double = c_double_2")) + .isNotFullyPushedDown(FilterNode.class) + .returnsEmptyResult(); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real <> c_real_2 OR c_double <> c_double_2")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES '1', '2', '3', '4'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real < Infinity() OR c_double < Infinity()")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES '2'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > -Infinity() OR c_double > -Infinity()")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES '3'"); + } + } + @Override protected Optional filterColumnNameTestData(String columnName) { diff --git a/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidConnectorTest.java b/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidConnectorTest.java index 027745779b8b2..fc0b8a24066f1 100644 --- a/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidConnectorTest.java +++ b/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidConnectorTest.java @@ -210,6 +210,13 @@ public void testSelectAll() assertQuery("SELECT orderkey, custkey, orderstatus, totalprice, orderdate, orderpriority, clerk, shippriority, comment FROM orders"); } + @Test + @Override + public void testSpecialValueOnApproximateNumericColumn() + { + abort("Druid connector does not support NaN and Infinity"); + } + /** * This test verifies that the filtering we have in place to overcome Druid's limitation of * not handling the escaping of search characters like % and _, works correctly. diff --git a/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java b/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java index 0ead923826d5c..719ec8bdf6d13 100644 --- a/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java +++ b/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java @@ -163,13 +163,23 @@ public IgniteClient( JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() .addStandardRules(this::quoted) - .map("$equal(left, right)").to("left = right") - .map("$not_equal(left, right)").to("left <> right") + .withTypeClass("non_approx_numeric_type", ImmutableSet.of("boolean", "tinyint", "smallint", "integer", "bigint", "decimal", "varchar", "date")) + .withTypeClass("approx_numeric_type", ImmutableSet.of("real", "double")) + .map("$equal(left: non_approx_numeric_type, right: non_approx_numeric_type)").to("left = right") + .map("$not_equal(left: non_approx_numeric_type, right: non_approx_numeric_type)").to("left <> right") + .map("$equal(left: approx_numeric_type, right: approx_numeric_type)").to("left = right AND left <> 'NaN' AND right <> 'NaN'") + .map("$not_equal(left: approx_numeric_type, right: approx_numeric_type)").to("left <> right OR (left = 'NaN' AND right = 'NaN')") + .map("$equal(left: approx_numeric_type, right: approx_numeric_type)").to("left = right") + .map("$not_equal(left: approx_numeric_type, right: approx_numeric_type)").to("left <> right OR (left = 'NaN' AND right = 'NaN')") .map("$identical(left, right)").to("left IS NOT DISTINCT FROM right") - .map("$less_than(left, right)").to("left < right") - .map("$less_than_or_equal(left, right)").to("left <= right") - .map("$greater_than(left, right)").to("left > right") - .map("$greater_than_or_equal(left, right)").to("left >= right") + .map("$less_than(left: non_approx_numeric_type, right: non_approx_numeric_type)").to("left < right") + .map("$less_than_or_equal(left: non_approx_numeric_type, right: non_approx_numeric_type)").to("left <= right") + .map("$greater_than(left: non_approx_numeric_type, right: non_approx_numeric_type)").to("left > right") + .map("$greater_than_or_equal(left: non_approx_numeric_type, right: non_approx_numeric_type)").to("left >= right") + .map("$less_than(left: approx_numeric_type, right: approx_numeric_type)").to("((left < right) AND right <> 'NaN')") + .map("$less_than_or_equal(left: approx_numeric_type, right: approx_numeric_type)").to("((left <= right) AND right <> 'NaN')") + .map("$greater_than(left: approx_numeric_type, right: approx_numeric_type)").to("((left > right) AND left <> 'NaN')") + .map("$greater_than_or_equal(left: approx_numeric_type, right: approx_numeric_type)").to("((left >= right) AND left <> 'NaN')") .map("$like(value: varchar, pattern: varchar): boolean").to("value LIKE pattern") .map("$not($is_null(value))").to("value IS NOT NULL") .map("$not(value: boolean)").to("NOT value") diff --git a/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClientModule.java b/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClientModule.java index 14e41f61754db..14afeaefcd005 100644 --- a/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClientModule.java +++ b/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClientModule.java @@ -26,6 +26,8 @@ import io.trino.plugin.jdbc.ForBaseJdbc; import io.trino.plugin.jdbc.JdbcClient; import io.trino.plugin.jdbc.JdbcMetadataFactory; +import io.trino.plugin.jdbc.NaNSpecificQueryBuilder; +import io.trino.plugin.jdbc.QueryBuilder; import io.trino.plugin.jdbc.credential.CredentialProvider; import org.apache.ignite.IgniteJdbcThinDriver; @@ -41,6 +43,7 @@ public void configure(Binder binder) { binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(IgniteClient.class).in(Scopes.SINGLETON); newOptionalBinder(binder, JdbcMetadataFactory.class).setBinding().to(IgniteJdbcMetadataFactory.class).in(Scopes.SINGLETON); + newOptionalBinder(binder, QueryBuilder.class).setBinding().to(NaNSpecificQueryBuilder.class).in(Scopes.SINGLETON); configBinder(binder).bindConfig(IgniteJdbcConfig.class); bindTablePropertiesProvider(binder, IgniteTableProperties.class); binder.install(new DecimalModule()); diff --git a/plugin/trino-ignite/src/test/java/io/trino/plugin/ignite/TestIgniteConnectorTest.java b/plugin/trino-ignite/src/test/java/io/trino/plugin/ignite/TestIgniteConnectorTest.java index b07d98755f7f0..159499fc60e4a 100644 --- a/plugin/trino-ignite/src/test/java/io/trino/plugin/ignite/TestIgniteConnectorTest.java +++ b/plugin/trino-ignite/src/test/java/io/trino/plugin/ignite/TestIgniteConnectorTest.java @@ -67,6 +67,7 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) return switch (connectorBehavior) { case SUPPORTS_AGGREGATION_PUSHDOWN, SUPPORTS_JOIN_PUSHDOWN, + SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN, SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN_WITH_LIKE, SUPPORTS_TOPN_PUSHDOWN_WITH_VARCHAR -> true; case SUPPORTS_ADD_COLUMN_NOT_NULL_CONSTRAINT, @@ -77,6 +78,7 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) SUPPORTS_AGGREGATION_PUSHDOWN_STDDEV, SUPPORTS_AGGREGATION_PUSHDOWN_VARIANCE, SUPPORTS_ARRAY, + SUPPORTS_PREDICATE_ARITHMETIC_EXPRESSION_PUSHDOWN, SUPPORTS_COMMENT_ON_COLUMN, SUPPORTS_COMMENT_ON_TABLE, SUPPORTS_DROP_NOT_NULL_CONSTRAINT, diff --git a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbConnectorTest.java b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbConnectorTest.java index be59ee6229a69..0be2d5050e058 100644 --- a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbConnectorTest.java +++ b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbConnectorTest.java @@ -47,6 +47,7 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) SUPPORTS_AGGREGATION_PUSHDOWN_COVARIANCE, SUPPORTS_AGGREGATION_PUSHDOWN_REGRESSION, SUPPORTS_ARRAY, + SUPPORTS_NAN_INFINITY, SUPPORTS_COMMENT_ON_COLUMN, SUPPORTS_CREATE_TABLE_WITH_COLUMN_COMMENT, SUPPORTS_DROP_NOT_NULL_CONSTRAINT, @@ -368,4 +369,10 @@ protected void verifyColumnNameLengthFailurePermissible(Throwable e) { assertThat(e).hasMessageMatching("(.*Identifier name '.*' is too long|.*Incorrect column name.*)"); } + + @Override + protected void verifyApproximateNumericSpecialValueFailure(Throwable e) + { + assertThat(e).hasMessageMatching("Failed to insert data: .* Unknown column '.*' in 'field list'"); + } } diff --git a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlConnectorTest.java b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlConnectorTest.java index bfb8eb2029cf2..f4cacf12a812e 100644 --- a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlConnectorTest.java +++ b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlConnectorTest.java @@ -63,6 +63,7 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) SUPPORTS_AGGREGATION_PUSHDOWN_COVARIANCE, SUPPORTS_AGGREGATION_PUSHDOWN_REGRESSION, SUPPORTS_ARRAY, + SUPPORTS_NAN_INFINITY, SUPPORTS_COMMENT_ON_COLUMN, SUPPORTS_CREATE_TABLE_WITH_COLUMN_COMMENT, SUPPORTS_DROP_NOT_NULL_CONSTRAINT, @@ -627,4 +628,10 @@ public void verifyMySqlJdbcDriverNegativeDateHandling() } } } + + @Override + protected void verifyApproximateNumericSpecialValueFailure(Throwable e) + { + assertThat(e).hasMessageMatching(".* is not a valid numeric or approximate numeric value"); + } } diff --git a/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java index 60cefaf68c418..9c246f1457b8d 100644 --- a/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java +++ b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java @@ -232,13 +232,20 @@ public OracleClient( this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() .addStandardRules(this::quoted) - .withTypeClass("numeric_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint", "decimal", "real", "double")) - .map("$equal(left: numeric_type, right: numeric_type)").to("left = right") - .map("$not_equal(left: numeric_type, right: numeric_type)").to("left <> right") - .map("$less_than(left: numeric_type, right: numeric_type)").to("left < right") - .map("$less_than_or_equal(left: numeric_type, right: numeric_type)").to("left <= right") - .map("$greater_than(left: numeric_type, right: numeric_type)").to("left > right") - .map("$greater_than_or_equal(left: numeric_type, right: numeric_type)").to("left >= right") + .withTypeClass("exact_numeric_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint", "decimal")) + .withTypeClass("approx_numeric_type", ImmutableSet.of("real", "double")) + .map("$equal(left: exact_numeric_type, right: exact_numeric_type)").to("left = right") + .map("$not_equal(left: exact_numeric_type, right: exact_numeric_type)").to("left <> right") + .map("$equal(left: approx_numeric_type, right: approx_numeric_type)").to("left = right AND left <> 'NaN' AND right <> 'NaN'") + .map("$not_equal(left: approx_numeric_type, right: approx_numeric_type)").to("left <> right OR (left = 'NaN' AND right = 'NaN')") + .map("$less_than(left: exact_numeric_type, right: exact_numeric_type)").to("left < right") + .map("$less_than_or_equal(left: exact_numeric_type, right: exact_numeric_type)").to("left <= right") + .map("$greater_than(left: exact_numeric_type, right: exact_numeric_type)").to("left > right") + .map("$greater_than_or_equal(left: exact_numeric_type, right: exact_numeric_type)").to("left >= right") + .map("$less_than(left: approx_numeric_type, right: approx_numeric_type)").to("left < right AND right <> 'NaN'") + .map("$less_than_or_equal(left: approx_numeric_type, right: approx_numeric_type)").to("left <= right AND right <> 'NaN'") + .map("$greater_than(left: approx_numeric_type, right: approx_numeric_type)").to("left > right AND left <> 'NaN'") + .map("$greater_than_or_equal(left: approx_numeric_type, right: approx_numeric_type)").to("left >= right AND left <> 'NaN'") .add(new RewriteStringComparison()) .build(); diff --git a/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClientModule.java b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClientModule.java index 219d5d25f2970..7476d4c090bf2 100644 --- a/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClientModule.java +++ b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClientModule.java @@ -27,6 +27,8 @@ import io.trino.plugin.jdbc.ForBaseJdbc; import io.trino.plugin.jdbc.JdbcClient; import io.trino.plugin.jdbc.MaxDomainCompactionThreshold; +import io.trino.plugin.jdbc.NaNSpecificQueryBuilder; +import io.trino.plugin.jdbc.QueryBuilder; import io.trino.plugin.jdbc.RetryingConnectionFactory.RetryStrategy; import io.trino.plugin.jdbc.TimestampTimeZoneDomain; import io.trino.plugin.jdbc.credential.CredentialProvider; @@ -56,6 +58,7 @@ public void configure(Binder binder) bindSessionPropertiesProvider(binder, OracleSessionProperties.class); configBinder(binder).bindConfig(OracleConfig.class); newOptionalBinder(binder, Key.get(int.class, MaxDomainCompactionThreshold.class)).setBinding().toInstance(ORACLE_MAX_LIST_EXPRESSIONS); + newOptionalBinder(binder, QueryBuilder.class).setBinding().to(NaNSpecificQueryBuilder.class).in(Scopes.SINGLETON); newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class).in(Scopes.SINGLETON); newSetBinder(binder, RetryStrategy.class).addBinding().to(OracleRetryStrategy.class).in(Scopes.SINGLETON); } diff --git a/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleConnectorTest.java b/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleConnectorTest.java index 21ceb79f87392..81605f6b3285c 100644 --- a/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleConnectorTest.java +++ b/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleConnectorTest.java @@ -416,6 +416,88 @@ public void testNativeQueryIncorrectSyntax() .failure().hasMessageContaining("Query not supported: ResultSetMetaData not available for query: some wrong syntax"); } + @Test + @Override + // Oracle doesn't support `IS DISTINCT FROM` expression pushdown + public void testSpecialValueOnApproximateNumericColumn() + { + try (TestTable table = new TestTable( + getQueryRunner()::execute, + "spl_approx_numeric", + "(c_varchar VARCHAR, c_real REAL, c_real_2 REAL, c_double DOUBLE, c_double_2 DOUBLE)", + List.of( + "'1', NaN(), REAL '1', Nan(), DOUBLE '1'", + "'2', -Infinity(), REAL '1', -Infinity(), DOUBLE '1'", + "'3', Infinity(), REAL '1', Infinity(), DOUBLE '1'", + "'4', Nan(), Nan(), NaN(), Nan()"))) { + String tableName = table.getName(); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > 1 AND c_double > 1")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES '3'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real < 1 AND c_double < 1")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES '2'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real < Infinity() AND c_double < Infinity()")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES '2'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > -Infinity() AND c_double > -Infinity()")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES '3'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > c_real_2 AND c_double > c_double_2")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES '3'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real IS DISTINCT FROM c_real_2 AND c_double IS DISTINCT FROM c_double_2")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES '1', '2', '3'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real_2 IS DISTINCT FROM c_real AND c_double_2 IS DISTINCT FROM c_double")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES '1', '2', '3'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > c_real_2 OR c_double > c_double_2")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES '3'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real = c_real_2 OR c_double = c_double_2")) + .isFullyPushedDown() + .returnsEmptyResult(); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real <> c_real_2 OR c_double <> c_double_2")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES '1', '2', '3', '4'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real < 1 OR c_double < 1")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES '2'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real < Infinity() OR c_double < Infinity()")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES '2'"); + + assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > -Infinity() OR c_double > -Infinity()")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES '3'"); + } + } + @Override protected TestTable simpleTable() { diff --git a/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestPhoenixConnectorTest.java b/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestPhoenixConnectorTest.java index 7e1dd2962b45c..22860bb7c5856 100644 --- a/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestPhoenixConnectorTest.java +++ b/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestPhoenixConnectorTest.java @@ -90,6 +90,7 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) SUPPORTS_UPDATE -> true; case SUPPORTS_ADD_COLUMN_WITH_COMMENT, SUPPORTS_AGGREGATION_PUSHDOWN, + SUPPORTS_NAN_INFINITY, SUPPORTS_COMMENT_ON_COLUMN, SUPPORTS_COMMENT_ON_TABLE, SUPPORTS_CREATE_TABLE_WITH_COLUMN_COMMENT, @@ -831,4 +832,10 @@ protected SqlExecutor onRemoteDatabase() } }; } + + @Override + protected void verifyApproximateNumericSpecialValueFailure(Throwable e) + { + assertThat(e).hasMessageMatching("Character .* is neither a decimal digit number, decimal point, nor \"e\" notation exponential mark\\."); + } } diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/CollationAwareQueryBuilder.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/CollationAwareQueryBuilder.java index 06f8d666bde23..4bd217c1e1af2 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/CollationAwareQueryBuilder.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/CollationAwareQueryBuilder.java @@ -14,11 +14,11 @@ package io.trino.plugin.postgresql; import com.google.inject.Inject; -import io.trino.plugin.jdbc.DefaultQueryBuilder; import io.trino.plugin.jdbc.JdbcClient; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcJoinCondition; import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.NaNSpecificQueryBuilder; import io.trino.plugin.jdbc.QueryParameter; import io.trino.plugin.jdbc.WriteFunction; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; @@ -34,7 +34,7 @@ import static java.lang.String.format; public class CollationAwareQueryBuilder - extends DefaultQueryBuilder + extends NaNSpecificQueryBuilder { @Inject public CollationAwareQueryBuilder(RemoteQueryModifier queryModifier) diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index b1fa936a5fb88..46b5e406c3a49 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -308,14 +308,22 @@ public PostgreSqlClient( .addStandardRules(this::quoted) .add(new RewriteIn()) .withTypeClass("integer_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint")) - .withTypeClass("numeric_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint", "decimal", "real", "double")) - .map("$equal(left, right)").to("left = right") - .map("$not_equal(left, right)").to("left <> right") + .withTypeClass("exact_numeric_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint", "decimal")) + .withTypeClass("approx_numeric_type", ImmutableSet.of("real", "double")) + .withTypeClass("non_approx_numeric_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint", "varchar", "char")) + .map("$equal(left: non_approx_numeric_type, right: non_approx_numeric_type)").to("left = right") + .map("$not_equal(left: non_approx_numeric_type, right: non_approx_numeric_type)").to("left <> right") + .map("$equal(left: approx_numeric_type, right: approx_numeric_type)").to("left = right AND left <> 'NaN' AND right <> 'NaN'") + .map("$not_equal(left: approx_numeric_type, right: approx_numeric_type)").to("left <> right OR (left = 'NaN' AND right = 'NaN')") .map("$identical(left, right)").to("left IS NOT DISTINCT FROM right") - .map("$less_than(left: numeric_type, right: numeric_type)").to("left < right") - .map("$less_than_or_equal(left: numeric_type, right: numeric_type)").to("left <= right") - .map("$greater_than(left: numeric_type, right: numeric_type)").to("left > right") - .map("$greater_than_or_equal(left: numeric_type, right: numeric_type)").to("left >= right") + .map("$less_than(left: exact_numeric_type, right: exact_numeric_type)").to("left < right") + .map("$less_than_or_equal(left: exact_numeric_type, right: exact_numeric_type)").to("left <= right") + .map("$greater_than(left: exact_numeric_type, right: exact_numeric_type)").to("left > right") + .map("$greater_than_or_equal(left: exact_numeric_type, right: exact_numeric_type)").to("left >= right") + .map("$less_than(left: approx_numeric_type, right: approx_numeric_type)").to("left < right AND right <> 'NaN'") + .map("$less_than_or_equal(left: approx_numeric_type, right: approx_numeric_type)").to("left <= right AND right <> 'NaN'") + .map("$greater_than(left: approx_numeric_type, right: approx_numeric_type)").to("left > right AND left <> 'NaN'") + .map("$greater_than_or_equal(left: approx_numeric_type, right: approx_numeric_type)").to("left >= right AND left <> 'NaN'") .map("$add(left: integer_type, right: integer_type)").to("left + right") .map("$subtract(left: integer_type, right: integer_type)").to("left - right") .map("$multiply(left: integer_type, right: integer_type)").to("left * right") diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java index 731403b0dc1e4..9e0a4ea537d5f 100644 --- a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java +++ b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java @@ -59,6 +59,7 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) SUPPORTS_AGGREGATION_PUSHDOWN_COVARIANCE, SUPPORTS_AGGREGATION_PUSHDOWN_REGRESSION, SUPPORTS_ARRAY, + SUPPORTS_NAN_INFINITY, SUPPORTS_COMMENT_ON_COLUMN, SUPPORTS_COMMENT_ON_TABLE, SUPPORTS_CREATE_TABLE_WITH_COLUMN_COMMENT, @@ -873,4 +874,12 @@ protected Session joinPushdownEnabled(Session session) .setCatalogSessionProperty(session.getCatalog().orElseThrow(), "join_pushdown_strategy", "EAGER") .build(); } + + @Override + protected void verifyApproximateNumericSpecialValueFailure(Throwable e) + { + assertThat(e).hasMessageMatching("Failed to insert data: The incoming tabular data stream \\(TDS\\) remote procedure call \\(RPC\\) protocol stream is incorrect\\. " + + "Parameter 3 \\(\"\"\\): The supplied value is not a valid instance of data type real\\. Check the source data for invalid values\\. " + + "An example of an invalid value is data of numeric type with scale greater than precision\\."); + } } diff --git a/testing/trino-testing/src/main/java/io/trino/testing/TestingConnectorBehavior.java b/testing/trino-testing/src/main/java/io/trino/testing/TestingConnectorBehavior.java index a233c8c6d244c..f5a50a1a420b5 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/TestingConnectorBehavior.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/TestingConnectorBehavior.java @@ -34,6 +34,7 @@ public enum TestingConnectorBehavior SUPPORTS_TRUNCATE(SUPPORTS_DELETE), SUPPORTS_ARRAY, + SUPPORTS_NAN_INFINITY, SUPPORTS_ROW_TYPE, SUPPORTS_PREDICATE_PUSHDOWN,