Skip to content

Commit

Permalink
Handle NaN when pushing filter for JDBC connector
Browse files Browse the repository at this point in the history
  • Loading branch information
Praveen2112 committed May 11, 2024
1 parent 4b9fbb7 commit 8e50017
Show file tree
Hide file tree
Showing 15 changed files with 277 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc;

import com.google.inject.Inject;
import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.type.Type;

import java.util.Optional;
import java.util.function.Consumer;

import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.RealType.REAL;
import static java.lang.String.format;

public class NaNSpecificQueryBuilder
extends DefaultQueryBuilder
{
@Inject
public NaNSpecificQueryBuilder(RemoteQueryModifier queryModifier)
{
super(queryModifier);
}

@Override
protected String toPredicate(JdbcClient client, ConnectorSession session, JdbcColumnHandle column, JdbcTypeHandle jdbcType, Type type, WriteFunction writeFunction, String operator, Object value, Consumer<QueryParameter> accumulator)
{
if ((type == REAL || type == DOUBLE) && (operator.equals(">") || operator.equals(">="))) {
accumulator.accept(new QueryParameter(jdbcType, type, Optional.of(value)));
return format("((%s %s %s) AND (%s <> 'NaN'))", client.quoted(column.getColumnName()), operator, writeFunction.getBindExpression(), client.quoted(column.getColumnName()));
}

return super.toPredicate(client, session, column, jdbcType, type, writeFunction, operator, value, accumulator);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,59 @@ public void testCharTrailingSpace()
}
}

@Test
public void testSpecialValueOnApproximateNumericColumn()
{
try (TestTable table = new TestTable(
getQueryRunner()::execute,
"spl_approx_numeric",
"(c_int INTEGER, c_real REAL, c_real_2 REAL, c_double DOUBLE, c_double_2 DOUBLE)",
List.of(
"1, NaN(), REAL '1', Nan(), DOUBLE '1'",
"2, -Infinity(), REAL '1', -Infinity(), DOUBLE '1'",
"3, Infinity(), REAL '1', Infinity(), DOUBLE '1'"))) {
String tableName = table.getName();

assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > 1 AND c_double > 1"))
.isFullyPushedDown()
.skippingTypesCheck()
.matches("VALUES 3");
assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > c_real_2 AND c_double > c_double_2"))
.isFullyPushedDown()
.skippingTypesCheck()
.matches("VALUES 3");
assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real < 1 AND c_double < 1"))
.isFullyPushedDown()
.skippingTypesCheck()
.matches("VALUES 2");
assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real < Infinity() AND c_double < Infinity()"))
.isFullyPushedDown()
.skippingTypesCheck()
.matches("VALUES 2");
assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > -Infinity() AND c_double > -Infinity()"))
.isFullyPushedDown()
.skippingTypesCheck()
.matches("VALUES 3");

assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > c_real_2 OR c_double > c_double_2"))
.isFullyPushedDown()
.skippingTypesCheck()
.matches("VALUES 3");
assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real < 1 OR c_double < 1"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES 2");
assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real < Infinity() OR c_double < Infinity()"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES 2");
assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > -Infinity() OR c_double > -Infinity()"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES 3");
}
}

@Test
public void testAggregationPushdown()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.testing.QueryRunner;
import io.trino.testing.TestingConnectorBehavior;
import io.trino.testing.sql.JdbcSqlExecutor;
Expand All @@ -25,6 +26,7 @@
import org.junit.jupiter.api.parallel.Execution;
import org.junit.jupiter.api.parallel.ExecutionMode;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
Expand Down Expand Up @@ -278,6 +280,60 @@ public void testAddColumnConcurrently()
abort("TODO: Enable this test after finding the failure cause");
}

@Test
@Override
public void testSpecialValueOnApproximateNumericColumn()
{
try (TestTable table = new TestTable(
getQueryRunner()::execute,
"spl_approx_numeric",
"(c_int INTEGER, c_real REAL, c_real_2 REAL, c_double DOUBLE, c_double_2 DOUBLE)",
List.of(
"1, NaN(), REAL '1', Nan(), DOUBLE '1'",
"2, -Infinity(), REAL '1', -Infinity(), DOUBLE '1'",
"3, Infinity(), REAL '1', Infinity(), DOUBLE '1'"))) {
String tableName = table.getName();

assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > 1 AND c_double > 1"))
.isFullyPushedDown()
.skippingTypesCheck()
.matches("VALUES 3");
assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > c_real_2 AND c_double > c_double_2"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES 3");
assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real < 1 AND c_double < 1"))
.isFullyPushedDown()
.skippingTypesCheck()
.matches("VALUES 2");
assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real < Infinity() AND c_double < Infinity()"))
.isFullyPushedDown()
.skippingTypesCheck()
.matches("VALUES 2");
assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > -Infinity() AND c_double > -Infinity()"))
.isFullyPushedDown()
.skippingTypesCheck()
.matches("VALUES 3");

assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > c_real_2 OR c_double > c_double_2"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES 3");
assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real < 1 OR c_double < 1"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES 2");
assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real < Infinity() OR c_double < Infinity()"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES 2");
assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > -Infinity() OR c_double > -Infinity()"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES 3");
}
}

@Override
protected void verifySetColumnTypeFailurePermissible(Throwable e)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ public TestingH2JdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFa

public TestingH2JdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, IdentifierMapping identifierMapping)
{
super("\"", connectionFactory, new DefaultQueryBuilder(RemoteQueryModifier.NONE), config.getJdbcTypesMappedToVarchar(), identifierMapping, RemoteQueryModifier.NONE, false);
super("\"", connectionFactory, new NaNSpecificQueryBuilder(RemoteQueryModifier.NONE), config.getJdbcTypesMappedToVarchar(), identifierMapping, RemoteQueryModifier.NONE, false);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.util.concurrent.ThreadLocalRandom;

import static com.google.inject.multibindings.Multibinder.newSetBinder;
import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

Expand All @@ -51,6 +52,7 @@ public TestingH2JdbcModule(TestingH2JdbcClientFactory testingH2JdbcClientFactory
public void configure(Binder binder)
{
newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class).in(Scopes.SINGLETON);
newOptionalBinder(binder, QueryBuilder.class).setBinding().to(NaNSpecificQueryBuilder.class).in(Scopes.SINGLETON);
}

@Provides
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.google.common.collect.ImmutableMap;
import io.trino.plugin.jdbc.BaseJdbcConnectorTest;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.testing.MaterializedResult;
import io.trino.testing.QueryRunner;
import io.trino.testing.TestingConnectorBehavior;
Expand Down Expand Up @@ -161,6 +162,60 @@ public void testRenameColumnName()
{
}

@Test
@Override
public void testSpecialValueOnApproximateNumericColumn()
{
try (TestTable table = new TestTable(
getQueryRunner()::execute,
"spl_approx_numeric",
"(c_int INTEGER, c_real REAL, c_real_2 REAL, c_double DOUBLE, c_double_2 DOUBLE)",
List.of(
"1, NaN(), REAL '1', Nan(), DOUBLE '1'",
"2, -Infinity(), REAL '1', -Infinity(), DOUBLE '1'",
"3, Infinity(), REAL '1', Infinity(), DOUBLE '1'"))) {
String tableName = table.getName();

assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > 1 AND c_double > 1"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES 3");
assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > c_real_2 AND c_double > c_double_2"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES 3");
assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real < 1 AND c_double < 1"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES 2");
assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real < Infinity() AND c_double < Infinity()"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES 2");
assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > -Infinity() AND c_double > -Infinity()"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES 3");

assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > c_real_2 OR c_double > c_double_2"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES 3");
assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real < 1 OR c_double < 1"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES 2");
assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real < Infinity() OR c_double < Infinity()"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES 2");
assertThat(query("SELECT c_int FROM " + tableName + " WHERE c_real > -Infinity() OR c_double > -Infinity()"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES 3");
}
}

@Override
protected Optional<String> filterColumnNameTestData(String columnName)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,13 @@ public void testSelectAll()
assertQuery("SELECT orderkey, custkey, orderstatus, totalprice, orderdate, orderpriority, clerk, shippriority, comment FROM orders");
}

@Test
@Override
public void testSpecialValueOnApproximateNumericColumn()
{
// Druid doesn't support writes and it also doesn't support NaN and Infinity
}

/**
* This test verifies that the filtering we have in place to overcome Druid's limitation of
* not handling the escaping of search characters like % and _, works correctly.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import io.trino.plugin.jdbc.ForBaseJdbc;
import io.trino.plugin.jdbc.JdbcClient;
import io.trino.plugin.jdbc.JdbcMetadataFactory;
import io.trino.plugin.jdbc.NaNSpecificQueryBuilder;
import io.trino.plugin.jdbc.QueryBuilder;
import io.trino.plugin.jdbc.credential.CredentialProvider;
import org.apache.ignite.IgniteJdbcThinDriver;

Expand All @@ -41,6 +43,7 @@ public void configure(Binder binder)
{
binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(IgniteClient.class).in(Scopes.SINGLETON);
newOptionalBinder(binder, JdbcMetadataFactory.class).setBinding().to(IgniteJdbcMetadataFactory.class).in(Scopes.SINGLETON);
newOptionalBinder(binder, QueryBuilder.class).setBinding().to(NaNSpecificQueryBuilder.class).in(Scopes.SINGLETON);
configBinder(binder).bindConfig(IgniteJdbcConfig.class);
bindTablePropertiesProvider(binder, IgniteTableProperties.class);
binder.install(new DecimalModule());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ public void testShowColumns()
assertThat(query("SHOW COLUMNS FROM orders")).result().matches(getDescribeOrdersResult());
}

@Test
@Override
public void testSpecialValueOnApproximateNumericColumn()
{
assertThatThrownBy(super::testSpecialValueOnApproximateNumericColumn)
.hasMessageContaining("Out of range value for column 'c_real' at row 1");
}

@Override
protected boolean isColumnNameRejected(Exception exception, String columnName, boolean delimited)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ public void testShowColumns()
assertThat(query("SHOW COLUMNS FROM orders")).result().matches(getDescribeOrdersResult());
}

@Test
@Override
public void testSpecialValueOnApproximateNumericColumn()
{
assertThatThrownBy(super::testSpecialValueOnApproximateNumericColumn)
.hasMessageContaining("'NaN' is not a valid numeric or approximate numeric value");
}

@Override
protected boolean isColumnNameRejected(Exception exception, String columnName, boolean delimited)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import io.trino.plugin.jdbc.ForBaseJdbc;
import io.trino.plugin.jdbc.JdbcClient;
import io.trino.plugin.jdbc.MaxDomainCompactionThreshold;
import io.trino.plugin.jdbc.NaNSpecificQueryBuilder;
import io.trino.plugin.jdbc.QueryBuilder;
import io.trino.plugin.jdbc.RetryingConnectionFactory.RetryStrategy;
import io.trino.plugin.jdbc.TimestampTimeZoneDomain;
import io.trino.plugin.jdbc.credential.CredentialProvider;
Expand Down Expand Up @@ -56,6 +58,7 @@ public void configure(Binder binder)
bindSessionPropertiesProvider(binder, OracleSessionProperties.class);
configBinder(binder).bindConfig(OracleConfig.class);
newOptionalBinder(binder, Key.get(int.class, MaxDomainCompactionThreshold.class)).setBinding().toInstance(ORACLE_MAX_LIST_EXPRESSIONS);
newOptionalBinder(binder, QueryBuilder.class).setBinding().to(NaNSpecificQueryBuilder.class).in(Scopes.SINGLETON);
newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class).in(Scopes.SINGLETON);
newSetBinder(binder, RetryStrategy.class).addBinding().to(OracleRetryStrategy.class).in(Scopes.SINGLETON);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,14 @@ public void testCreateSchema()
abort("test disabled until issue fixed"); // TODO https://github.com/trinodb/trino/issues/2348
}

@Test
@Override
public void testSpecialValueOnApproximateNumericColumn()
{
assertThatThrownBy(super::testSpecialValueOnApproximateNumericColumn)
.hasMessageContaining("Character N is neither a decimal digit number, decimal point, nor \"e\" notation exponential mark");
}

@Override
protected boolean isColumnNameRejected(Exception exception, String columnName, boolean delimited)
{
Expand Down

0 comments on commit 8e50017

Please sign in to comment.