From 8e5001761741486bec217e616efb6b1b6ac9e5c8 Mon Sep 17 00:00:00 2001 From: "praveenkrishna.d" Date: Fri, 10 May 2024 20:42:46 +0530 Subject: [PATCH] Handle NaN when pushing filter for JDBC connector --- .../plugin/jdbc/NaNSpecificQueryBuilder.java | 47 ++++++++++++++++ .../plugin/jdbc/BaseJdbcConnectorTest.java | 53 ++++++++++++++++++ .../plugin/jdbc/TestJdbcConnectorTest.java | 56 +++++++++++++++++++ .../plugin/jdbc/TestingH2JdbcClient.java | 2 +- .../plugin/jdbc/TestingH2JdbcModule.java | 2 + .../TestClickHouseConnectorTest.java | 55 ++++++++++++++++++ .../plugin/druid/TestDruidConnectorTest.java | 7 +++ .../plugin/ignite/IgniteClientModule.java | 3 + .../mariadb/BaseMariaDbConnectorTest.java | 8 +++ .../plugin/mysql/BaseMySqlConnectorTest.java | 8 +++ .../plugin/oracle/OracleClientModule.java | 3 + .../phoenix5/TestPhoenixConnectorTest.java | 8 +++ .../plugin/postgresql/PostgreSqlClient.java | 15 +++-- .../postgresql/PostgresAwareQueryBuilder.java | 7 +++ .../sqlserver/BaseSqlServerConnectorTest.java | 9 +++ 15 files changed, 277 insertions(+), 6 deletions(-) create mode 100644 plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/NaNSpecificQueryBuilder.java diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/NaNSpecificQueryBuilder.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/NaNSpecificQueryBuilder.java new file mode 100644 index 0000000000000..1da2cfea35a01 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/NaNSpecificQueryBuilder.java @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc; + +import com.google.inject.Inject; +import io.trino.plugin.jdbc.logging.RemoteQueryModifier; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.type.Type; + +import java.util.Optional; +import java.util.function.Consumer; + +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.RealType.REAL; +import static java.lang.String.format; + +public class NaNSpecificQueryBuilder + extends DefaultQueryBuilder +{ + @Inject + public NaNSpecificQueryBuilder(RemoteQueryModifier queryModifier) + { + super(queryModifier); + } + + @Override + protected String toPredicate(JdbcClient client, ConnectorSession session, JdbcColumnHandle column, JdbcTypeHandle jdbcType, Type type, WriteFunction writeFunction, String operator, Object value, Consumer accumulator) + { + if ((type == REAL || type == DOUBLE) && (operator.equals(">") || operator.equals(">="))) { + accumulator.accept(new QueryParameter(jdbcType, type, Optional.of(value))); + return format("((%s %s %s) AND (%s <> 'NaN'))", client.quoted(column.getColumnName()), operator, writeFunction.getBindExpression(), client.quoted(column.getColumnName())); + } + + return super.toPredicate(client, session, column, jdbcType, type, writeFunction, operator, value, accumulator); + } +} diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java index 0d11fa0dc1bd0..b9696728b6acc 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java @@ -193,6 +193,59 @@ public void testCharTrailingSpace() } } + @Test + public void testSpecialValueOnApproximateNumericColumn() + { + try (TestTable table = new TestTable( + getQueryRunner()::execute, + "spl_approx_numeric", + "(c_int INTEGER, c_real REAL, c_real_2 REAL, c_double DOUBLE, c_double_2 DOUBLE)", + List.of( + "1, NaN(), REAL '1', Nan(), DOUBLE '1'", + "2, -Infinity(), REAL '1', -Infinity(), DOUBLE '1'", + "3, Infinity(), REAL '1', Infinity(), DOUBLE '1'"))) { + String tableName = table.getName(); + + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > 1 AND c_double > 1")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES 3"); + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > c_real_2 AND c_double > c_double_2")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES 3"); + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real < 1 AND c_double < 1")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES 2"); + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real < Infinity() AND c_double < Infinity()")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES 2"); + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > -Infinity() AND c_double > -Infinity()")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES 3"); + + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > c_real_2 OR c_double > c_double_2")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES 3"); + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real < 1 OR c_double < 1")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES 2"); + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real < Infinity() OR c_double < Infinity()")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES 2"); + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > -Infinity() OR c_double > -Infinity()")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES 3"); + } + } + @Test public void testAggregationPushdown() { diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcConnectorTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcConnectorTest.java index 6751c1160dc0d..afcb99e7cd5f8 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcConnectorTest.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcConnectorTest.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.Session; +import io.trino.sql.planner.plan.FilterNode; import io.trino.testing.QueryRunner; import io.trino.testing.TestingConnectorBehavior; import io.trino.testing.sql.JdbcSqlExecutor; @@ -25,6 +26,7 @@ import org.junit.jupiter.api.parallel.Execution; import org.junit.jupiter.api.parallel.ExecutionMode; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.OptionalInt; @@ -278,6 +280,60 @@ public void testAddColumnConcurrently() abort("TODO: Enable this test after finding the failure cause"); } + @Test + @Override + public void testSpecialValueOnApproximateNumericColumn() + { + try (TestTable table = new TestTable( + getQueryRunner()::execute, + "spl_approx_numeric", + "(c_int INTEGER, c_real REAL, c_real_2 REAL, c_double DOUBLE, c_double_2 DOUBLE)", + List.of( + "1, NaN(), REAL '1', Nan(), DOUBLE '1'", + "2, -Infinity(), REAL '1', -Infinity(), DOUBLE '1'", + "3, Infinity(), REAL '1', Infinity(), DOUBLE '1'"))) { + String tableName = table.getName(); + + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > 1 AND c_double > 1")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES 3"); + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > c_real_2 AND c_double > c_double_2")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES 3"); + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real < 1 AND c_double < 1")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES 2"); + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real < Infinity() AND c_double < Infinity()")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES 2"); + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > -Infinity() AND c_double > -Infinity()")) + .isFullyPushedDown() + .skippingTypesCheck() + .matches("VALUES 3"); + + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > c_real_2 OR c_double > c_double_2")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES 3"); + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real < 1 OR c_double < 1")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES 2"); + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real < Infinity() OR c_double < Infinity()")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES 2"); + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > -Infinity() OR c_double > -Infinity()")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES 3"); + } + } + @Override protected void verifySetColumnTypeFailurePermissible(Throwable e) { diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcClient.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcClient.java index 6749eea8c367a..db07c64b705b8 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcClient.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcClient.java @@ -99,7 +99,7 @@ public TestingH2JdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFa public TestingH2JdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, IdentifierMapping identifierMapping) { - super("\"", connectionFactory, new DefaultQueryBuilder(RemoteQueryModifier.NONE), config.getJdbcTypesMappedToVarchar(), identifierMapping, RemoteQueryModifier.NONE, false); + super("\"", connectionFactory, new NaNSpecificQueryBuilder(RemoteQueryModifier.NONE), config.getJdbcTypesMappedToVarchar(), identifierMapping, RemoteQueryModifier.NONE, false); } @Override diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcModule.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcModule.java index 246c20cc5886b..3be187ec920df 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcModule.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcModule.java @@ -29,6 +29,7 @@ import java.util.concurrent.ThreadLocalRandom; import static com.google.inject.multibindings.Multibinder.newSetBinder; +import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -51,6 +52,7 @@ public TestingH2JdbcModule(TestingH2JdbcClientFactory testingH2JdbcClientFactory public void configure(Binder binder) { newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class).in(Scopes.SINGLETON); + newOptionalBinder(binder, QueryBuilder.class).setBinding().to(NaNSpecificQueryBuilder.class).in(Scopes.SINGLETON); } @Provides diff --git a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java index 793b243cb07ee..49414b2de11a0 100644 --- a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java +++ b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableMap; import io.trino.plugin.jdbc.BaseJdbcConnectorTest; import io.trino.sql.planner.plan.AggregationNode; +import io.trino.sql.planner.plan.FilterNode; import io.trino.testing.MaterializedResult; import io.trino.testing.QueryRunner; import io.trino.testing.TestingConnectorBehavior; @@ -161,6 +162,60 @@ public void testRenameColumnName() { } + @Test + @Override + public void testSpecialValueOnApproximateNumericColumn() + { + try (TestTable table = new TestTable( + getQueryRunner()::execute, + "spl_approx_numeric", + "(c_int INTEGER, c_real REAL, c_real_2 REAL, c_double DOUBLE, c_double_2 DOUBLE)", + List.of( + "1, NaN(), REAL '1', Nan(), DOUBLE '1'", + "2, -Infinity(), REAL '1', -Infinity(), DOUBLE '1'", + "3, Infinity(), REAL '1', Infinity(), DOUBLE '1'"))) { + String tableName = table.getName(); + + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > 1 AND c_double > 1")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES 3"); + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > c_real_2 AND c_double > c_double_2")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES 3"); + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real < 1 AND c_double < 1")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES 2"); + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real < Infinity() AND c_double < Infinity()")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES 2"); + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > -Infinity() AND c_double > -Infinity()")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES 3"); + + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > c_real_2 OR c_double > c_double_2")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES 3"); + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real < 1 OR c_double < 1")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES 2"); + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real < Infinity() OR c_double < Infinity()")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES 2"); + assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > -Infinity() OR c_double > -Infinity()")) + .isNotFullyPushedDown(FilterNode.class) + .skippingTypesCheck() + .matches("VALUES 3"); + } + } + @Override protected Optional filterColumnNameTestData(String columnName) { diff --git a/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidConnectorTest.java b/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidConnectorTest.java index caa22c92e7379..661b9dfc23ea4 100644 --- a/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidConnectorTest.java +++ b/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidConnectorTest.java @@ -210,6 +210,13 @@ public void testSelectAll() assertQuery("SELECT orderkey, custkey, orderstatus, totalprice, orderdate, orderpriority, clerk, shippriority, comment FROM orders"); } + @Test + @Override + public void testSpecialValueOnApproximateNumericColumn() + { + // Druid doesn't support writes and it also doesn't support NaN and Infinity + } + /** * This test verifies that the filtering we have in place to overcome Druid's limitation of * not handling the escaping of search characters like % and _, works correctly. diff --git a/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClientModule.java b/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClientModule.java index 14e41f61754db..14afeaefcd005 100644 --- a/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClientModule.java +++ b/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClientModule.java @@ -26,6 +26,8 @@ import io.trino.plugin.jdbc.ForBaseJdbc; import io.trino.plugin.jdbc.JdbcClient; import io.trino.plugin.jdbc.JdbcMetadataFactory; +import io.trino.plugin.jdbc.NaNSpecificQueryBuilder; +import io.trino.plugin.jdbc.QueryBuilder; import io.trino.plugin.jdbc.credential.CredentialProvider; import org.apache.ignite.IgniteJdbcThinDriver; @@ -41,6 +43,7 @@ public void configure(Binder binder) { binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(IgniteClient.class).in(Scopes.SINGLETON); newOptionalBinder(binder, JdbcMetadataFactory.class).setBinding().to(IgniteJdbcMetadataFactory.class).in(Scopes.SINGLETON); + newOptionalBinder(binder, QueryBuilder.class).setBinding().to(NaNSpecificQueryBuilder.class).in(Scopes.SINGLETON); configBinder(binder).bindConfig(IgniteJdbcConfig.class); bindTablePropertiesProvider(binder, IgniteTableProperties.class); binder.install(new DecimalModule()); diff --git a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbConnectorTest.java b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbConnectorTest.java index 768b12f8ba9fa..980d1dcf9dbc3 100644 --- a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbConnectorTest.java +++ b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/BaseMariaDbConnectorTest.java @@ -91,6 +91,14 @@ public void testShowColumns() assertThat(query("SHOW COLUMNS FROM orders")).result().matches(getDescribeOrdersResult()); } + @Test + @Override + public void testSpecialValueOnApproximateNumericColumn() + { + assertThatThrownBy(super::testSpecialValueOnApproximateNumericColumn) + .hasMessageContaining("Out of range value for column 'c_real' at row 1"); + } + @Override protected boolean isColumnNameRejected(Exception exception, String columnName, boolean delimited) { diff --git a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlConnectorTest.java b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlConnectorTest.java index 7a81b1da97701..d69ab4c3b0df5 100644 --- a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlConnectorTest.java +++ b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlConnectorTest.java @@ -107,6 +107,14 @@ public void testShowColumns() assertThat(query("SHOW COLUMNS FROM orders")).result().matches(getDescribeOrdersResult()); } + @Test + @Override + public void testSpecialValueOnApproximateNumericColumn() + { + assertThatThrownBy(super::testSpecialValueOnApproximateNumericColumn) + .hasMessageContaining("'NaN' is not a valid numeric or approximate numeric value"); + } + @Override protected boolean isColumnNameRejected(Exception exception, String columnName, boolean delimited) { diff --git a/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClientModule.java b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClientModule.java index 219d5d25f2970..7476d4c090bf2 100644 --- a/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClientModule.java +++ b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClientModule.java @@ -27,6 +27,8 @@ import io.trino.plugin.jdbc.ForBaseJdbc; import io.trino.plugin.jdbc.JdbcClient; import io.trino.plugin.jdbc.MaxDomainCompactionThreshold; +import io.trino.plugin.jdbc.NaNSpecificQueryBuilder; +import io.trino.plugin.jdbc.QueryBuilder; import io.trino.plugin.jdbc.RetryingConnectionFactory.RetryStrategy; import io.trino.plugin.jdbc.TimestampTimeZoneDomain; import io.trino.plugin.jdbc.credential.CredentialProvider; @@ -56,6 +58,7 @@ public void configure(Binder binder) bindSessionPropertiesProvider(binder, OracleSessionProperties.class); configBinder(binder).bindConfig(OracleConfig.class); newOptionalBinder(binder, Key.get(int.class, MaxDomainCompactionThreshold.class)).setBinding().toInstance(ORACLE_MAX_LIST_EXPRESSIONS); + newOptionalBinder(binder, QueryBuilder.class).setBinding().to(NaNSpecificQueryBuilder.class).in(Scopes.SINGLETON); newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class).in(Scopes.SINGLETON); newSetBinder(binder, RetryStrategy.class).addBinding().to(OracleRetryStrategy.class).in(Scopes.SINGLETON); } diff --git a/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestPhoenixConnectorTest.java b/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestPhoenixConnectorTest.java index fa7ec81c67a33..ccd919e8b5066 100644 --- a/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestPhoenixConnectorTest.java +++ b/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestPhoenixConnectorTest.java @@ -254,6 +254,14 @@ public void testCreateSchema() abort("test disabled until issue fixed"); // TODO https://github.com/trinodb/trino/issues/2348 } + @Test + @Override + public void testSpecialValueOnApproximateNumericColumn() + { + assertThatThrownBy(super::testSpecialValueOnApproximateNumericColumn) + .hasMessageContaining("Character N is neither a decimal digit number, decimal point, nor \"e\" notation exponential mark"); + } + @Override protected boolean isColumnNameRejected(Exception exception, String columnName, boolean delimited) { diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index b1fa936a5fb88..7ee65abb73aa7 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -308,14 +308,19 @@ public PostgreSqlClient( .addStandardRules(this::quoted) .add(new RewriteIn()) .withTypeClass("integer_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint")) - .withTypeClass("numeric_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint", "decimal", "real", "double")) + .withTypeClass("exact_numeric_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint", "decimal")) + .withTypeClass("approx_numeric_type", ImmutableSet.of("real", "double")) .map("$equal(left, right)").to("left = right") .map("$not_equal(left, right)").to("left <> right") .map("$identical(left, right)").to("left IS NOT DISTINCT FROM right") - .map("$less_than(left: numeric_type, right: numeric_type)").to("left < right") - .map("$less_than_or_equal(left: numeric_type, right: numeric_type)").to("left <= right") - .map("$greater_than(left: numeric_type, right: numeric_type)").to("left > right") - .map("$greater_than_or_equal(left: numeric_type, right: numeric_type)").to("left >= right") + .map("$less_than(left: exact_numeric_type, right: exact_numeric_type)").to("left < right") + .map("$less_than_or_equal(left: exact_numeric_type, right: exact_numeric_type)").to("left <= right") + .map("$greater_than(left: exact_numeric_type, right: exact_numeric_type)").to("left > right") + .map("$greater_than_or_equal(left: exact_numeric_type, right: exact_numeric_type)").to("left >= right") + .map("$less_than(left: approx_numeric_type, right: approx_numeric_type)").to("((left < right) AND right <> 'NaN')") + .map("$less_than_or_equal(left: approx_numeric_type, right: approx_numeric_type)").to("((left <= right) AND right <> 'NaN')") + .map("$greater_than(left: approx_numeric_type, right: approx_numeric_type)").to("((left > right) AND left <> 'NaN')") + .map("$greater_than_or_equal(left: approx_numeric_type, right: approx_numeric_type)").to("((left >= right) AND left <> 'NaN')") .map("$add(left: integer_type, right: integer_type)").to("left + right") .map("$subtract(left: integer_type, right: integer_type)").to("left - right") .map("$multiply(left: integer_type, right: integer_type)").to("left * right") diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgresAwareQueryBuilder.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgresAwareQueryBuilder.java index 8dbc3e098d5d4..b2a0e79d7a16f 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgresAwareQueryBuilder.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgresAwareQueryBuilder.java @@ -31,6 +31,8 @@ import static io.trino.plugin.postgresql.PostgreSqlClient.isCollatable; import static io.trino.plugin.postgresql.PostgreSqlSessionProperties.isEnableStringPushdownWithCollate; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.RealType.REAL; import static java.lang.String.format; public class PostgresAwareQueryBuilder @@ -69,6 +71,11 @@ protected String toPredicate(JdbcClient client, ConnectorSession session, JdbcCo return format("%s %s %s COLLATE \"C\"", client.quoted(column.getColumnName()), operator, writeFunction.getBindExpression()); } + if ((type == REAL || type == DOUBLE) && (operator.equals(">") || operator.equals(">="))) { + accumulator.accept(new QueryParameter(jdbcType, type, Optional.of(value))); + return format("((%s %s %s) AND (%s <> 'NaN'))", client.quoted(column.getColumnName()), operator, writeFunction.getBindExpression(), client.quoted(column.getColumnName())); + } + return super.toPredicate(client, session, column, jdbcType, type, writeFunction, operator, value, accumulator); } } diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java index 1ae62d92aabb0..04cada4e9b637 100644 --- a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java +++ b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/BaseSqlServerConnectorTest.java @@ -41,6 +41,7 @@ import static java.util.stream.Collectors.joining; import static java.util.stream.IntStream.range; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; public abstract class BaseSqlServerConnectorTest extends BaseJdbcConnectorTest @@ -119,6 +120,14 @@ public void testReadFromView() } } + @Test + @Override + public void testSpecialValueOnApproximateNumericColumn() + { + assertThatThrownBy(super::testSpecialValueOnApproximateNumericColumn) + .hasMessageContaining("Failed to insert data: The incoming tabular data stream (TDS) remote procedure call (RPC) protocol stream is incorrect."); + } + @Override protected void verifyAddNotNullColumnToNonEmptyTableFailurePermissible(Throwable e) {