Skip to content

Commit

Permalink
Adding double to float coercer
Browse files Browse the repository at this point in the history
  • Loading branch information
hustnn authored and dain committed Apr 30, 2019
1 parent f0eb4f3 commit 4a456a8
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 8 deletions.
Expand Up @@ -71,7 +71,7 @@ public boolean canCoerce(HiveType fromHiveType, HiveType toHiveType)
return toHiveType.equals(HIVE_DOUBLE) || toType instanceof DecimalType;
}
if (fromHiveType.equals(HIVE_DOUBLE)) {
return toType instanceof DecimalType;
return toHiveType.equals(HIVE_FLOAT) || toType instanceof DecimalType;
}
if (fromType instanceof DecimalType) {
return toType instanceof DecimalType || toHiveType.equals(HIVE_FLOAT) || toHiveType.equals(HIVE_DOUBLE);
Expand Down
Expand Up @@ -15,6 +15,7 @@

import io.prestosql.plugin.hive.HivePageSourceProvider.BucketAdaptation;
import io.prestosql.plugin.hive.HivePageSourceProvider.ColumnMapping;
import io.prestosql.plugin.hive.coercions.DoubleToFloatCoercer;
import io.prestosql.plugin.hive.coercions.FloatToDoubleCoercer;
import io.prestosql.plugin.hive.coercions.IntegerNumberToVarcharCoercer;
import io.prestosql.plugin.hive.coercions.IntegerNumberUpscaleCoercer;
Expand Down Expand Up @@ -345,6 +346,9 @@ private static Function<Block, Block> createCoercer(TypeManager typeManager, Hiv
if (fromHiveType.equals(HIVE_FLOAT) && toHiveType.equals(HIVE_DOUBLE)) {
return new FloatToDoubleCoercer();
}
if (fromHiveType.equals(HIVE_DOUBLE) && toHiveType.equals(HIVE_FLOAT)) {
return new DoubleToFloatCoercer();
}
if (fromType instanceof DecimalType && toType instanceof DecimalType) {
return createDecimalToDecimalCoercer((DecimalType) fromType, (DecimalType) toType);
}
Expand Down
@@ -0,0 +1,37 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.prestosql.plugin.hive.coercions;

import io.prestosql.spi.block.Block;
import io.prestosql.spi.block.BlockBuilder;

import static io.prestosql.spi.type.DoubleType.DOUBLE;
import static io.prestosql.spi.type.RealType.REAL;
import static java.lang.Float.floatToRawIntBits;

public class DoubleToFloatCoercer
extends TypeCoercer
{
public DoubleToFloatCoercer()
{
super(DOUBLE, REAL);
}

@Override
protected void applyCoercedValue(BlockBuilder blockBuilder, Block block, int position)
{
REAL.writeLong(blockBuilder, floatToRawIntBits((float) DOUBLE.getDouble(block, position)));
}
}
Expand Up @@ -110,6 +110,7 @@ private static HiveTableDefinition.HiveTableDefinitionBuilder tableDefinitionBui
" int_to_bigint INT," +
" bigint_to_varchar BIGINT," +
" float_to_double " + floatType + "," +
" double_to_float DOUBLE," +
" shortdecimal_to_shortdecimal DECIMAL(10,2)," +
" shortdecimal_to_longdecimal DECIMAL(10,2)," +
" longdecimal_to_shortdecimal DECIMAL(20,12)," +
Expand Down Expand Up @@ -275,12 +276,12 @@ private void doTestHiveCoercion(HiveTableDefinition tableDefinition)
query(format(
"INSERT INTO %s\n" +
"VALUES\n" +
" (TINYINT '-1', TINYINT '2', TINYINT '-3', SMALLINT '100', SMALLINT '-101', INTEGER '2323', 12345, REAL '0.5', DECIMAL '12345678.12', DECIMAL '12345678.12', DECIMAL '12345678.123456123456', DECIMAL '12345678.123456123456', %s '12345.12345', DOUBLE '12345.12345', DECIMAL '12345.12345', DECIMAL '12345.12345',\n" +
" (TINYINT '-1', TINYINT '2', TINYINT '-3', SMALLINT '100', SMALLINT '-101', INTEGER '2323', 12345, REAL '0.5', DOUBLE '0.5', DECIMAL '12345678.12', DECIMAL '12345678.12', DECIMAL '12345678.123456123456', DECIMAL '12345678.123456123456', %s '12345.12345', DOUBLE '12345.12345', DECIMAL '12345.12345', DECIMAL '12345.12345',\n" +
" CAST(ROW ('as is', -1, 100, 2323, 12345) AS ROW(keep VARCHAR, ti2si TINYINT, si2int SMALLINT, int2bi INTEGER, bi2vc BIGINT)),\n" +
" ARRAY [CAST(ROW (2, -101, 12345, 'removed') AS ROW (ti2int TINYINT, si2bi SMALLINT, bi2vc BIGINT, remove VARCHAR))],\n" +
" MAP (ARRAY [TINYINT '2'], ARRAY [CAST(ROW (-3, 2323, REAL '0.5') AS ROW (ti2bi TINYINT, int2bi INTEGER, float2double %s))]),\n" +
" 1),\n" +
" (TINYINT '1', TINYINT '-2', NULL, SMALLINT '-100', SMALLINT '101', INTEGER '-2323', -12345, REAL '-1.5', DECIMAL '-12345678.12', DECIMAL '-12345678.12', DECIMAL '-12345678.123456123456', DECIMAL '-12345678.123456123456', %s '-12345.12345', DOUBLE '-12345.12345', DECIMAL '-12345.12345', DECIMAL '-12345.12345',\n" +
" (TINYINT '1', TINYINT '-2', NULL, SMALLINT '-100', SMALLINT '101', INTEGER '-2323', -12345, REAL '-1.5', DOUBLE '-1.5', DECIMAL '-12345678.12', DECIMAL '-12345678.12', DECIMAL '-12345678.123456123456', DECIMAL '-12345678.123456123456', %s '-12345.12345', DOUBLE '-12345.12345', DECIMAL '-12345.12345', DECIMAL '-12345.12345',\n" +
" CAST(ROW (NULL, 1, -100, -2323, -12345) AS ROW(keep VARCHAR, ti2si TINYINT, si2int SMALLINT, int2bi INTEGER, bi2vc BIGINT)),\n" +
" ARRAY [CAST(ROW (-2, 101, -12345, NULL) AS ROW (ti2int TINYINT, si2bi SMALLINT, bi2vc BIGINT, remove VARCHAR))],\n" +
" MAP (ARRAY [TINYINT '-2'], ARRAY [CAST(ROW (null, -2323, REAL '-1.5') AS ROW (ti2bi TINYINT, int2bi INTEGER, float2double %s))]),\n" +
Expand All @@ -305,6 +306,7 @@ private void doTestHiveCoercion(HiveTableDefinition tableDefinition)
2323L,
"12345",
0.5,
0.5,
new BigDecimal("12345678.1200"),
new BigDecimal("12345678.1200"),
new BigDecimal("12345678.12"),
Expand All @@ -326,6 +328,7 @@ private void doTestHiveCoercion(HiveTableDefinition tableDefinition)
-2323L,
"-12345",
-1.5,
-1.5,
new BigDecimal("-12345678.1200"),
new BigDecimal("-12345678.1200"),
new BigDecimal("-12345678.12"),
Expand All @@ -350,6 +353,7 @@ else if (usingTeradataJdbcDriver(connection)) {
2323L,
"12345",
0.5,
0.5,
12345678.1200,
12345678.1200,
12345678.12,
Expand All @@ -371,6 +375,7 @@ else if (usingTeradataJdbcDriver(connection)) {
-2323L,
"-12345",
-1.5,
-1.5,
-12345678.1200,
-12345678.1200,
-12345678.12,
Expand All @@ -388,19 +393,19 @@ else if (usingTeradataJdbcDriver(connection)) {
throw new IllegalStateException();
}
// test primitive values
assertThat(queryResult.project(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20)).containsOnly(project(expectedRows, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20));
assertThat(queryResult.project(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 21)).containsOnly(project(expectedRows, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 21));
// test structural values (tempto can't handle map and row)
assertEqualsIgnoreOrder(queryResult.column(17), column(expectedRows, 17), "row_to_row field is not equal");
assertEqualsIgnoreOrder(queryResult.column(18), column(expectedRows, 18), "row_to_row field is not equal");
if (usingPrestoJdbcDriver(connection)) {
assertEqualsIgnoreOrder(extract(queryResult.column(18)), column(expectedRows, 18), "list_to_list field is not equal");
assertEqualsIgnoreOrder(extract(queryResult.column(19)), column(expectedRows, 19), "list_to_list field is not equal");
}
else if (usingTeradataJdbcDriver(connection)) {
assertEqualsIgnoreOrder(queryResult.column(18), column(expectedRows, 18), "list_to_list field is not equal");
assertEqualsIgnoreOrder(queryResult.column(19), column(expectedRows, 19), "list_to_list field is not equal");
}
else {
throw new IllegalStateException();
}
assertEqualsIgnoreOrder(queryResult.column(19), column(expectedRows, 19), "map_to_map field is not equal");
assertEqualsIgnoreOrder(queryResult.column(20), column(expectedRows, 20), "map_to_map field is not equal");
}

private void assertProperAlteredTableSchema(String tableName)
Expand All @@ -416,6 +421,7 @@ private void assertProperAlteredTableSchema(String tableName)
row("int_to_bigint", "bigint"),
row("bigint_to_varchar", "varchar"),
row("float_to_double", "double"),
row("double_to_float", floatType),
row("shortdecimal_to_shortdecimal", "decimal(18,4)"),
row("shortdecimal_to_longdecimal", "decimal(20,4)"),
row("longdecimal_to_shortdecimal", "decimal(12,2)"),
Expand Down Expand Up @@ -445,6 +451,7 @@ private void assertColumnTypes(QueryResult queryResult, String tableName)
BIGINT,
VARCHAR,
DOUBLE,
floatType,
DECIMAL,
DECIMAL,
DECIMAL,
Expand All @@ -468,6 +475,7 @@ else if (usingTeradataJdbcDriver(connection)) {
BIGINT,
VARCHAR,
DOUBLE,
floatType,
DECIMAL,
DECIMAL,
DECIMAL,
Expand Down Expand Up @@ -498,6 +506,7 @@ private static void alterTableColumnTypes(String tableName)
onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN int_to_bigint int_to_bigint bigint", tableName));
onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN bigint_to_varchar bigint_to_varchar string", tableName));
onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN float_to_double float_to_double double", tableName));
onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN double_to_float double_to_float %s", tableName, floatType));
onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN shortdecimal_to_shortdecimal shortdecimal_to_shortdecimal DECIMAL(18,4)", tableName));
onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN shortdecimal_to_longdecimal shortdecimal_to_longdecimal DECIMAL(20,4)", tableName));
onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN longdecimal_to_shortdecimal longdecimal_to_shortdecimal DECIMAL(12,2)", tableName));
Expand Down

0 comments on commit 4a456a8

Please sign in to comment.