Skip to content

Commit

Permalink
Fix TABLESAMPLE with DECIMAL
Browse files Browse the repository at this point in the history
  • Loading branch information
findepi committed Sep 4, 2020
1 parent 07eb8b2 commit e19d0f1
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
Expand Up @@ -56,6 +56,7 @@
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeNotFoundException;
import io.prestosql.spi.type.VarcharType;
import io.prestosql.sql.InterpretedFunctionInvoker;
import io.prestosql.sql.SqlPath;
import io.prestosql.sql.analyzer.Analysis.GroupingSetAnalysis;
import io.prestosql.sql.analyzer.Analysis.SelectExpression;
Expand Down Expand Up @@ -1473,7 +1474,8 @@ protected Scope visitSampledRelation(SampledRelation relation, Optional<Scope> s
analysis.isDescribe())
.getExpressionTypes();

if (!typeCoercion.canCoerce(expressionTypes.get(NodeRef.of(samplePercentage)), DOUBLE)) {
Type samplePercentageType = expressionTypes.get(NodeRef.of(samplePercentage));
if (!typeCoercion.canCoerce(samplePercentageType, DOUBLE)) {
throw semanticException(TYPE_MISMATCH, samplePercentage, "Sample percentage should be a numeric expression");
}

Expand All @@ -1487,8 +1489,14 @@ protected Scope visitSampledRelation(SampledRelation relation, Optional<Scope> s
throw semanticException(INVALID_ARGUMENTS, samplePercentage, "Sample percentage cannot be NULL");
}

verify(samplePercentageObject instanceof Number, "Sample percentage should evaluate to a Number");
double samplePercentageValue = ((Number) samplePercentageObject).doubleValue();
if (samplePercentageType != DOUBLE) {
ResolvedFunction coercion = metadata.getCoercion(samplePercentageType, DOUBLE);
InterpretedFunctionInvoker functionInvoker = new InterpretedFunctionInvoker(metadata);
samplePercentageObject = functionInvoker.invoke(coercion, session.toConnectorSession(), samplePercentageObject);
verify(samplePercentageObject != null, "Coercion from %s to %s returned null", samplePercentageType, DOUBLE);
}

double samplePercentageValue = (double) samplePercentageObject;

if (samplePercentageValue < 0.0) {
throw semanticException(NUMERIC_VALUE_OUT_OF_RANGE, samplePercentage, "Sample percentage must be greater than or equal to 0");
Expand Down
12 changes: 12 additions & 0 deletions presto-tests/src/test/java/io/prestosql/tests/TestTablesample.java
Expand Up @@ -65,6 +65,18 @@ public void testTablesample()
// 1%
assertThat(assertions.query("SELECT count(*) FROM tpch.tiny.orders TABLESAMPLE BERNOULLI (1)"))
.satisfies(result -> assertThat((Long) result.getOnlyValue()).isBetween(50L, 450L));

// 0.1%
assertThat(assertions.query("SELECT count(*) FROM tpch.tiny.orders TABLESAMPLE BERNOULLI (1e-1)"))
.satisfies(result -> assertThat((Long) result.getOnlyValue()).isBetween(5L, 45L));

// 0.1% as decimal
assertThat(assertions.query("SELECT count(*) FROM tpch.tiny.orders TABLESAMPLE BERNOULLI (0.1)"))
.satisfies(result -> assertThat((Long) result.getOnlyValue()).isBetween(5L, 45L));

// fraction as long decimal
assertThat(assertions.query("SELECT count(*) FROM tpch.tiny.orders TABLESAMPLE BERNOULLI (0.000000000000000000001)"))
.satisfies(result -> assertThat((Long) result.getOnlyValue()).isBetween(0L, 5L));
}

@Test
Expand Down

0 comments on commit e19d0f1

Please sign in to comment.