Handle NaN when pushing filter for JDBC connector #21923

Closed
wants to merge 3 commits into from
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
findepi marked this conversation as resolved.
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc;

import com.google.inject.Inject;
import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.type.Type;

import java.util.Optional;
import java.util.function.Consumer;

import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.RealType.REAL;
import static java.lang.String.format;

public class NaNSpecificQueryBuilder
        extends DefaultQueryBuilder
{
    @Inject
    public NaNSpecificQueryBuilder(RemoteQueryModifier queryModifier)
    {
        super(queryModifier);
    }

    @Override
    protected String toPredicate(JdbcClient client, ConnectorSession session, JdbcColumnHandle column, JdbcTypeHandle jdbcType, Type type, WriteFunction writeFunction, String operator, Object value, Consumer<QueryParameter> accumulator)
    {
        if ((type == REAL || type == DOUBLE) && (operator.equals(">") || operator.equals(">="))) {
Member

what about <, <=, =, !=?

Member Author

I don't think we need this check for the other operators, since they work as expected. In PostgreSQL, NaN is placed above Infinity, so the extra NaN exclusion is needed only for the > and >= operators.

Member

This should be a code comment.

Once you write down "in PostgreSQL ..." in a generic class, you'll realize that you're exploiting a (common) implementation detail.

  • What does the spec say about double and NaN ordering? (I think the spec omits the existence of NaNs, but I may be wrong.)
  • Even if the spec said something definitive, implementations vary (for example, Trino and PostgreSQL have different behavior). This means the behavior should be implemented in a connector-specific manner.
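
To make the ordering difference concrete, here is a small plain-Java sketch. It is not Trino or PostgreSQL code, and the class and method names are made up for illustration. Java's `>` operator follows IEEE 754, where every ordered comparison involving NaN is false; that matches what the tests in this PR expect for `c_double > 1`. `Double.compare` instead defines a total order that places NaN above positive infinity, which resembles the PostgreSQL ordering described in this thread.

```java
// Illustrative only: plain Java, not Trino or PostgreSQL code.
final class NanOrderingSketch
{
    private NanOrderingSketch() {}

    // IEEE 754 comparison: any ordered comparison involving NaN is false.
    static boolean ieeeGreaterThan(double left, double right)
    {
        return left > right;
    }

    // Total-order comparison: Double.compare treats NaN as greater than
    // every other value, including positive infinity.
    static boolean totalOrderGreaterThan(double left, double right)
    {
        return Double.compare(left, right) > 0;
    }

    public static void main(String[] args)
    {
        System.out.println(ieeeGreaterThan(Double.NaN, 1.0));
        System.out.println(totalOrderGreaterThan(Double.NaN, 1.0));
        System.out.println(totalOrderGreaterThan(Double.NaN, Double.POSITIVE_INFINITY));
    }
}
```

A remote database that sorts NaN like the second method returns NaN rows for a pushed-down `>` filter, which is why the pushed-down predicate needs an explicit NaN exclusion on such systems.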

            accumulator.accept(new QueryParameter(jdbcType, type, Optional.of(value)));
            return format("((%s %s %s) AND (%s <> 'NaN'))", client.quoted(column.getColumnName()), operator, writeFunction.getBindExpression(), client.quoted(column.getColumnName()));
Member

'NaN' as a varchar?
are you sure this is portable SQL?

Member Author

Judging from the tests, 'NaN' is not parsed as a varchar but is treated as a representation of NaN.

Member Author

Since it is used alongside double/real columns, I think it is coerced automatically, and the tests confirm this.

Member

This may work for PostgreSQL and may also work for some other databases, but I doubt this is truly portable behavior (for example, Trino does not support comparing double values with a 'NaN' literal, does it?), and this class is meant to be portable.
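
One way to picture the connector-specific approach suggested in this thread is sketched below. Everything here is hypothetical: `NanAwarePredicateSketch`, `NanGuard`, and `toPredicate` are invented names, not Trino APIs, and the sketch only builds SQL strings. The idea is that the generic code asks the connector how its database spells "this column is not NaN", and adds no guard when the connector supplies none.

```java
import java.util.Optional;

// Hypothetical sketch: none of these names exist in Trino.
final class NanAwarePredicateSketch
{
    private NanAwarePredicateSketch() {}

    // Supplied per connector. Returns the remote SQL expression that excludes NaN
    // for the given quoted column, or empty when no guard is needed.
    interface NanGuard
    {
        Optional<String> notNanExpression(String quotedColumn);
    }

    static String toPredicate(String quotedColumn, String operator, String bindExpression, NanGuard guard)
    {
        String comparison = String.format("%s %s %s", quotedColumn, operator, bindExpression);
        // Only > and >= can match NaN on databases that sort NaN above Infinity.
        if (!operator.equals(">") && !operator.equals(">=")) {
            return comparison;
        }
        return guard.notNanExpression(quotedColumn)
                .map(notNan -> String.format("((%s) AND (%s))", comparison, notNan))
                .orElse(comparison);
    }

    public static void main(String[] args)
    {
        // Assumption: a PostgreSQL-like database that accepts the 'NaN' literal.
        NanGuard postgresLike = column -> Optional.of(column + " <> 'NaN'");
        // prints (("c_double" > ?) AND ("c_double" <> 'NaN'))
        System.out.println(toPredicate("\"c_double\"", ">", "?", postgresLike));
    }
}
```

With this shape, the `'NaN'` literal lives only in the connector that is known to accept it, and the shared query builder stays free of dialect assumptions.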

        }

        return super.toPredicate(client, session, column, jdbcType, type, writeFunction, operator, value, accumulator);
    }
}
Original file line number Diff line number Diff line change
@@ -98,9 +98,11 @@
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_JOIN_PUSHDOWN_WITH_VARCHAR_INEQUALITY;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_LIMIT_PUSHDOWN;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_MERGE;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_NAN_INFINITY;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_NATIVE_QUERY;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_NOT_NULL_CONSTRAINT;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_PREDICATE_ARITHMETIC_EXPRESSION_PUSHDOWN;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN_WITH_LIKE;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_PREDICATE_PUSHDOWN;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY;
@@ -193,6 +195,127 @@ public void testCharTrailingSpace()
}
}

@Test
public void testSpecialValueOnApproximateNumericColumn()
{
if (!hasBehavior(SUPPORTS_NAN_INFINITY)) {
assertThatThrownBy(() -> assertUpdate("CREATE TABLE spl_values_nan AS SELECT CAST(Nan() AS REAL) real_col, Nan() AS double_col"))
.satisfies(this::verifyApproximateNumericSpecialValueFailure);
assertThatThrownBy(() -> assertUpdate("CREATE TABLE spl_values_infinity AS SELECT CAST(Infinity() AS REAL) real_col, Infinity() AS double_col"))
.satisfies(this::verifyApproximateNumericSpecialValueFailure);
assertThatThrownBy(() -> assertUpdate("CREATE TABLE spl_values_negative_infinity AS SELECT CAST(-Infinity() AS REAL) real_col, -Infinity() AS double_col"))
.satisfies(this::verifyApproximateNumericSpecialValueFailure);
return;
}
try (TestTable table = new TestTable(
getQueryRunner()::execute,
"spl_approx_numeric",
"(c_varchar VARCHAR, c_real REAL, c_real_2 REAL, c_double DOUBLE, c_double_2 DOUBLE)",
List.of(
"'1', NaN(), REAL '1', Nan(), DOUBLE '1'",
"'2', -Infinity(), REAL '1', -Infinity(), DOUBLE '1'",
"'3', Infinity(), REAL '1', Infinity(), DOUBLE '1'",
"'4', Nan(), Nan(), NaN(), Nan()"))) {
String tableName = table.getName();

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > 1 AND c_double > 1"))
.isFullyPushedDown()
.skippingTypesCheck()
.matches("VALUES '3'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real < 1 AND c_double < 1"))
.isFullyPushedDown()
.skippingTypesCheck()
.matches("VALUES '2'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real < Infinity() AND c_double < Infinity()"))
.isFullyPushedDown()
.skippingTypesCheck()
.matches("VALUES '2'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > -Infinity() AND c_double > -Infinity()"))
.isFullyPushedDown()
.skippingTypesCheck()
.matches("VALUES '3'");

if (hasBehavior(SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN)) {
assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > c_real_2 AND c_double > c_double_2"))
.isFullyPushedDown()
.skippingTypesCheck()
.matches("VALUES '3'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real IS DISTINCT FROM c_real_2 AND c_double IS DISTINCT FROM c_double_2"))
.isFullyPushedDown()
.skippingTypesCheck()
.matches("VALUES '1', '2', '3'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real_2 IS DISTINCT FROM c_real AND c_double_2 IS DISTINCT FROM c_double"))
.isFullyPushedDown()
.skippingTypesCheck()
.matches("VALUES '1', '2', '3'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > c_real_2 OR c_double > c_double_2"))
.isFullyPushedDown()
.skippingTypesCheck()
.matches("VALUES '3'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real = c_real_2 OR c_double = c_double_2"))
.isFullyPushedDown()
.returnsEmptyResult();

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real <> c_real_2 OR c_double <> c_double_2"))
.isFullyPushedDown()
.skippingTypesCheck()
.matches("VALUES '1', '2', '3', '4'");
}
else {
assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > c_real_2 AND c_double > c_double_2"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES '3'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real IS DISTINCT FROM c_real_2 AND c_double IS DISTINCT FROM c_double_2"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES '1', '2', '3'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real_2 IS DISTINCT FROM c_real AND c_double_2 IS DISTINCT FROM c_double"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES '1', '2', '3'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > c_real_2 OR c_double > c_double_2"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES '3'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real = c_real_2 OR c_double = c_double_2"))
.isNotFullyPushedDown(FilterNode.class)
.returnsEmptyResult();

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real <> c_real_2 OR c_double <> c_double_2"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES '1', '2', '3', '4'");
}

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real < 1 OR c_double < 1"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES '2'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real < Infinity() OR c_double < Infinity()"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES '2'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > -Infinity() OR c_double > -Infinity()"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES '3'");
}
}
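
The expected results in the test above can be reproduced with a small plain-Java model of the four test rows. This is only a sketch under stated assumptions: it uses Java's IEEE 754 comparison operators to stand in for the comparison semantics the test expects, and it models `IS DISTINCT FROM` with `Double.compare(...) != 0`, which treats NaN as equal to NaN. The class and field names are invented for illustration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;

// Plain-Java model of the test table; illustrative only, not Trino code.
final class SpecialValueExpectationsSketch
{
    private SpecialValueExpectationsSketch() {}

    // One row: c_varchar, c_double, c_double_2 (the REAL columns hold the same values).
    static final class Row
    {
        final String id;
        final double value;
        final double other;

        Row(String id, double value, double other)
        {
            this.id = id;
            this.value = value;
            this.other = other;
        }
    }

    static final List<Row> ROWS = List.of(
            new Row("1", Double.NaN, 1),
            new Row("2", Double.NEGATIVE_INFINITY, 1),
            new Row("3", Double.POSITIVE_INFINITY, 1),
            new Row("4", Double.NaN, Double.NaN));

    // Returns the c_varchar values of the rows matching the predicate.
    static List<String> select(Predicate<Row> predicate)
    {
        List<String> ids = new ArrayList<>();
        for (Row row : ROWS) {
            if (predicate.test(row)) {
                ids.add(row.id);
            }
        }
        return ids;
    }

    public static void main(String[] args)
    {
        System.out.println(select(row -> row.value > 1));                             // [3]
        System.out.println(select(row -> row.value < 1));                             // [2]
        System.out.println(select(row -> row.value == row.other));                    // []
        System.out.println(select(row -> row.value != row.other));                    // [1, 2, 3, 4]
        System.out.println(select(row -> Double.compare(row.value, row.other) != 0)); // [1, 2, 3]
    }
}
```

Row 4 is the interesting case: `NaN <> NaN` is true, so it appears in the `<>` result, yet it is excluded from the `IS DISTINCT FROM` result because the two NaN values are not considered distinct.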

@Test
public void testAggregationPushdown()
{
@@ -2204,6 +2327,12 @@ protected void assertDynamicFiltering(@Language("SQL") String sql, JoinDistribut
assertDynamicFiltering(sql, joinDistributionType, true);
}

protected void verifyApproximateNumericSpecialValueFailure(Throwable e)
{
throw new AssertionError("Unexpected special value", e);
}


private void assertNoDynamicFiltering(@Language("SQL") String sql)
{
assertDynamicFiltering(sql, PARTITIONED, false);
Original file line number Diff line number Diff line change
@@ -93,9 +93,9 @@ public void configure(Binder binder) {}
@Provides
@Singleton
@ForBaseJdbc
- public static JdbcClient provideJdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, IdentifierMapping identifierMapping)
+ public static JdbcClient provideJdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping)
{
- return new TestingH2JdbcClient(config, connectionFactory, identifierMapping);
+ return new TestingH2JdbcClient(config, connectionFactory, queryBuilder, identifierMapping);
}

@Provides
Original file line number Diff line number Diff line change
@@ -41,7 +41,7 @@ public class TestJdbcTableProperties
protected QueryRunner createQueryRunner()
throws Exception
{
- TestingH2JdbcModule module = new TestingH2JdbcModule((config, connectionFactory, identifierMapping) -> new TestingH2JdbcClient(config, connectionFactory, identifierMapping)
+ TestingH2JdbcModule module = new TestingH2JdbcModule((config, connectionFactory, queryBuilder, identifierMapping) -> new TestingH2JdbcClient(config, connectionFactory, queryBuilder, identifierMapping)
{
@Override
public Map<String, Object> getTableProperties(ConnectorSession session, JdbcTableHandle tableHandle)
Original file line number Diff line number Diff line change
@@ -94,12 +94,12 @@ class TestingH2JdbcClient

public TestingH2JdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFactory)
{
- this(config, connectionFactory, new DefaultIdentifierMapping());
+ this(config, connectionFactory, new NaNSpecificQueryBuilder(RemoteQueryModifier.NONE), new DefaultIdentifierMapping());
}

- public TestingH2JdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, IdentifierMapping identifierMapping)
+ public TestingH2JdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping)
{
- super("\"", connectionFactory, new DefaultQueryBuilder(RemoteQueryModifier.NONE), config.getJdbcTypesMappedToVarchar(), identifierMapping, RemoteQueryModifier.NONE, false);
findepi marked this conversation as resolved.
+ super("\"", connectionFactory, queryBuilder, config.getJdbcTypesMappedToVarchar(), identifierMapping, RemoteQueryModifier.NONE, false);
}

@Override
Original file line number Diff line number Diff line change
@@ -29,6 +29,7 @@
import java.util.concurrent.ThreadLocalRandom;

import static com.google.inject.multibindings.Multibinder.newSetBinder;
import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

@@ -51,14 +52,15 @@ public TestingH2JdbcModule(TestingH2JdbcClientFactory testingH2JdbcClientFactory
public void configure(Binder binder)
{
newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class).in(Scopes.SINGLETON);
newOptionalBinder(binder, QueryBuilder.class).setBinding().to(NaNSpecificQueryBuilder.class).in(Scopes.SINGLETON);
}

@Provides
@Singleton
@ForBaseJdbc
- public JdbcClient provideJdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, IdentifierMapping identifierMapping)
+ public JdbcClient provideJdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping)
{
- return testingH2JdbcClientFactory.create(config, connectionFactory, identifierMapping);
+ return testingH2JdbcClientFactory.create(config, connectionFactory, queryBuilder, identifierMapping);
}

@Provides
@@ -83,6 +85,6 @@ public static String createH2ConnectionUrl()

public interface TestingH2JdbcClientFactory
{
- TestingH2JdbcClient create(BaseJdbcConfig config, ConnectionFactory connectionFactory, IdentifierMapping identifierMapping);
+ TestingH2JdbcClient create(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping);
}
}
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@
import com.google.common.collect.ImmutableMap;
import io.trino.plugin.jdbc.BaseJdbcConnectorTest;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.testing.MaterializedResult;
import io.trino.testing.QueryRunner;
import io.trino.testing.TestingConnectorBehavior;
@@ -161,6 +162,103 @@ public void testRenameColumnName()
{
}

@Test
@Override
Member

why override? document

// ClickHouse doesn't support pushdown on REAL columns
public void testSpecialValueOnApproximateNumericColumn()
{
try (TestTable table = new TestTable(
getQueryRunner()::execute,
"spl_approx_numeric",
"(c_varchar VARCHAR, c_real REAL, c_real_2 REAL, c_double DOUBLE, c_double_2 DOUBLE)",
List.of(
"'1', NaN(), REAL '1', Nan(), DOUBLE '1'",
"'2', -Infinity(), REAL '1', -Infinity(), DOUBLE '1'",
"'3', Infinity(), REAL '1', Infinity(), DOUBLE '1'",
"'4', Nan(), Nan(), NaN(), Nan()"))) {
String tableName = table.getName();

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_double > 1"))
.isFullyPushedDown()
.skippingTypesCheck()
.matches("VALUES '3'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > 1 AND c_double > 1"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES '3'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_double < 1"))
.isFullyPushedDown()
.skippingTypesCheck()
.matches("VALUES '2'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real < 1 AND c_double < 1"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES '2'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_double < Infinity()"))
.isFullyPushedDown()
.skippingTypesCheck()
.matches("VALUES '2'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real < Infinity() AND c_double < Infinity()"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES '2'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > -Infinity() AND c_double > -Infinity()"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES '3'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_double > -Infinity()"))
.isFullyPushedDown()
.skippingTypesCheck()
.matches("VALUES '3'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > c_real_2 AND c_double > c_double_2"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES '3'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real IS DISTINCT FROM c_real_2 AND c_double IS DISTINCT FROM c_double_2"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES '1', '2', '3'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real_2 IS DISTINCT FROM c_real AND c_double_2 IS DISTINCT FROM c_double"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES '1', '2', '3'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > c_real_2 OR c_double > c_double_2"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES '3'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real = c_real_2 OR c_double = c_double_2"))
.isNotFullyPushedDown(FilterNode.class)
.returnsEmptyResult();

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real <> c_real_2 OR c_double <> c_double_2"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES '1', '2', '3', '4'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real < Infinity() OR c_double < Infinity()"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES '2'");

assertThat(query("SELECT c_varchar FROM " + tableName + " WHERE c_real > -Infinity() OR c_double > -Infinity()"))
.isNotFullyPushedDown(FilterNode.class)
.skippingTypesCheck()
.matches("VALUES '3'");
}
}

@Override
protected Optional<String> filterColumnNameTestData(String columnName)
{