Skip to content

Commit

Permalink
Support backwards compatible reads for unannotated repeated primitive
Browse files Browse the repository at this point in the history
fields in Parquet

Per the Parquet Spec `A repeated field that is neither contained by a
LIST- or MAP-annotated group nor annotated by LIST or MAP should be
interpreted as a required list of required elements where the element
type is the type of the field`, however Trino currently throws an error:
`class org.apache.parquet.io.PrimitiveColumnIO cannot be cast to class
org.apache.parquet.io.GroupColumnIO
(org.apache.parquet.io.PrimitiveColumnIO and
org.apache.parquet.io.GroupColumnIO`. This commit adds support for these
backwards compatible reads.
  • Loading branch information
mxmarkovics committed Mar 6, 2024
1 parent dc73c85 commit e48dcb7
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.StandardErrorCode.TYPE_MISMATCH;
import static java.lang.String.format;
import static org.apache.parquet.schema.Type.Repetition.OPTIONAL;
import static org.apache.parquet.schema.Type.Repetition.REPEATED;
Expand Down Expand Up @@ -354,14 +355,38 @@ public static Optional<Field> constructField(Type type, ColumnIO columnIO)
return Optional.of(new GroupField(type, repetitionLevel, definitionLevel, required, ImmutableList.of(keyField, valueField)));
}
if (type instanceof ArrayType arrayType) {
GroupColumnIO groupColumnIO = (GroupColumnIO) columnIO;
if (groupColumnIO.getChildrenCount() != 1) {
return Optional.empty();
// Per the parquet spec (https://github.com/apache/parquet-format/blob/master/LogicalTypes.md):
// `A repeated field that is neither contained by a LIST- or MAP-annotated group nor annotated by LIST or MAP should be interpreted as a required list of required elements
// where the element type is the type of the field.`
//
// A parquet encoding for a required list of strings can be expressed in two ways, however for backwards compatibility they should be handled the same, so here we need
// to adjust repetition and definition levels when converting ColumnIOs to Fields.
// 1. required group colors (LIST) {
// repeated group list {
// required string element;
// }
// }
// 2. repeated binary colors (STRING);
if (columnIO.getType().getRepetition() == REPEATED && columnIO instanceof PrimitiveColumnIO && repetitionLevel > 0 && definitionLevel > 0) {
PrimitiveColumnIO primitiveColumnIO = (PrimitiveColumnIO) columnIO;
PrimitiveField primitiveFieldElement = new PrimitiveField(((ArrayType) type).getElementType(), true, primitiveColumnIO.getColumnDescriptor(), primitiveColumnIO.getId());
return Optional.of(new GroupField(type, repetitionLevel - 1, definitionLevel - 1, true, ImmutableList.of(Optional.of(primitiveFieldElement))));
}
else {
GroupColumnIO groupColumnIO = (GroupColumnIO) columnIO;
if (groupColumnIO.getChildrenCount() != 1) {
return Optional.empty();
}
Optional<Field> field = constructField(arrayType.getElementType(), getArrayElementColumn(groupColumnIO.getChild(0)));
return Optional.of(new GroupField(type, repetitionLevel, definitionLevel, required, ImmutableList.of(field)));
}
Optional<Field> field = constructField(arrayType.getElementType(), getArrayElementColumn(groupColumnIO.getChild(0)));
return Optional.of(new GroupField(type, repetitionLevel, definitionLevel, required, ImmutableList.of(field)));
}
PrimitiveColumnIO primitiveColumnIO = (PrimitiveColumnIO) columnIO;
if (primitiveColumnIO.getType().getRepetition() == REPEATED && primitiveColumnIO.getParent() instanceof MessageColumnIO) {
ColumnDescriptor columnDescriptor = primitiveColumnIO.getColumnDescriptor();
throw new TrinoException(TYPE_MISMATCH, format("Repeated field %s's type %s in a parquet file is incompatible with type %s defined in table schema",
columnDescriptor.getPrimitiveType().getName(), columnDescriptor.getPrimitiveType().getPrimitiveTypeName(), type.getDisplayName()));
}
return Optional.of(new PrimitiveField(type, required, primitiveColumnIO.getColumnDescriptor(), primitiveColumnIO.getId()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ private static <T> ColumnReader createColumnReader(

/**
 * Tells whether the given primitive field is treated as a "flat" column.
 * <p>
 * This is a diff hunk rendered without its +/- markers: the hunk replaces exactly one line,
 * so the first {@code return} below is the pre-change condition and the second is its replacement.
 *
 * @param field the primitive field to classify
 * @return {@code true} when the field qualifies as flat under the condition shown
 */
private static boolean isFlatColumn(PrimitiveField field)
{
// Pre-change condition (line replaced by this hunk): only the column path length was checked.
return field.getDescriptor().getPath().length == 1;
// Post-change condition: a single-element path additionally needs a repetition level of 0,
// so a top-level field that repeats is no longer classified as flat.
return field.getDescriptor().getPath().length == 1 && field.getRepetitionLevel() == 0;
}

private static boolean isLogicalUuid(LogicalTypeAnnotation annotation)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.parquet.reader;

import io.trino.parquet.PrimitiveField;
import io.trino.parquet.reader.flat.FlatColumnReader;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.schema.PrimitiveType;
import org.testng.annotations.Test;


import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext;
import static io.trino.spi.type.IntegerType.INTEGER;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT32;
import static org.apache.parquet.schema.Type.Repetition.OPTIONAL;
import static org.assertj.core.api.Assertions.assertThat;
import static org.joda.time.DateTimeZone.UTC;

/**
 * Checks which reader implementation {@link ColumnReaderFactory} picks for primitive fields
 * whose column path has a single element (top-level columns).
 */
public class TestColumnReaderFactory
{
    /**
     * A top-level primitive column must get a {@link NestedColumnReader} when its descriptor
     * carries a non-zero max repetition level, and a {@link FlatColumnReader} otherwise,
     * regardless of whether the non-repeated column is optional or required.
     */
    @Test
    public void testTopLevelPrimitiveFields()
    {
        ColumnReaderFactory columnReaderFactory = new ColumnReaderFactory(UTC);
        PrimitiveType primitiveType = new PrimitiveType(OPTIONAL, INT32, "test");

        // Max repetition level 1, max definition level 1: the column repeats, so it is not flat
        PrimitiveField topLevelRepeatedPrimitiveField = new PrimitiveField(
                INTEGER,
                true,
                new ColumnDescriptor(new String[] {"topLevelRepeatedPrimitiveField test"}, primitiveType, 1, 1),
                0);
        assertThat(columnReaderFactory.create(topLevelRepeatedPrimitiveField, newSimpleAggregatedMemoryContext())).isInstanceOf(NestedColumnReader.class);

        // Max repetition level 0, max definition level 1: optional but not repeated
        PrimitiveField topLevelOptionalPrimitiveField = new PrimitiveField(
                INTEGER,
                false,
                new ColumnDescriptor(new String[] {"topLevelOptionalPrimitiveField test"}, primitiveType, 0, 1),
                0);
        assertThat(columnReaderFactory.create(topLevelOptionalPrimitiveField, newSimpleAggregatedMemoryContext())).isInstanceOf(FlatColumnReader.class);

        // Max repetition level 0, max definition level 0: required and not repeated
        PrimitiveField topLevelRequiredPrimitiveField = new PrimitiveField(
                INTEGER,
                true,
                new ColumnDescriptor(new String[] {"topLevelRequiredPrimitiveField test"}, primitiveType, 0, 0),
                0);
        assertThat(columnReaderFactory.create(topLevelRequiredPrimitiveField, newSimpleAggregatedMemoryContext())).isInstanceOf(FlatColumnReader.class);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
import com.google.common.collect.DiscreteDomain;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Range;
import com.google.common.io.Resources;
import com.google.common.primitives.Shorts;
import io.airlift.units.DataSize;
import io.trino.plugin.hive.HiveQueryRunner;
import io.trino.plugin.hive.HiveTimestampPrecision;
import io.trino.spi.TrinoException;
import io.trino.spi.type.ArrayType;
Expand All @@ -32,6 +34,9 @@
import io.trino.spi.type.SqlTimestamp;
import io.trino.spi.type.SqlVarbinary;
import io.trino.spi.type.Type;
import io.trino.sql.query.QueryAssertions;
import io.trino.testing.DistributedQueryRunner;
import io.trino.testing.QueryFailedException;
import org.apache.hadoop.hive.common.type.Date;
import org.apache.hadoop.hive.common.type.HiveDecimal;
import org.apache.hadoop.hive.common.type.Timestamp;
Expand All @@ -43,6 +48,7 @@
import org.joda.time.DateTimeZone;
import org.junit.jupiter.api.Test;

import java.io.File;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -95,6 +101,7 @@
import static io.trino.spi.type.VarcharType.createUnboundedVarcharType;
import static io.trino.spi.type.VarcharType.createVarcharType;
import static io.trino.testing.DateTimeTestingUtils.sqlTimestampOf;
import static io.trino.testing.MaterializedResult.resultBuilder;
import static io.trino.testing.StructuralTestUtil.mapType;
import static java.lang.Math.floorDiv;
import static java.lang.Math.floorMod;
Expand Down Expand Up @@ -1988,6 +1995,85 @@ public void testMapMaxReadBytes()
tester.testMaxReadBytes(getStandardMapObjectInspector(javaStringObjectInspector, javaLongObjectInspector), values, values, mapType(VARCHAR, BIGINT), maxReadBlockSize);
}

/**
 * Reads a Parquet file whose {@code myString} column is an un-annotated repeated string
 * primitive through a table that declares the column as {@code array<varchar>}, and checks
 * that every row comes back as the expected list of strings.
 */
@Test
public void testBackwardsCompatibleRepeatedStringField()
        throws Exception
{
    File dataDirectory = new File(Resources.getResource("parquet_repeated_primitives/string/").toURI());
    String createTableSql = format("""
            CREATE TABLE table_with_repeated_primitive_string (
            myString array<varchar>)
            WITH (
            external_location = '%s',
            format = 'PARQUET')
            """,
            dataDirectory.getAbsolutePath());

    try (DistributedQueryRunner runner = HiveQueryRunner.builder().build();
            QueryAssertions queryAssertions = new QueryAssertions(runner)) {
        // Expose the resource directory as an external table
        runner.execute(createTableSql);

        assertThat(queryAssertions.query("SELECT myString FROM table_with_repeated_primitive_string"))
                .result()
                .matches(resultBuilder(runner.getDefaultSession(), new ArrayType(VARCHAR))
                        .row(Arrays.asList("hello", "world"))
                        .row(Arrays.asList("good", "bye"))
                        .row(Arrays.asList("one", "two", "three"))
                        .build());

        runner.execute("DROP TABLE table_with_repeated_primitive_string");
    }
}

/**
 * Reads a Parquet file whose {@code repeatedInt} column is an un-annotated repeated integer
 * primitive through a table that declares the column as {@code array<int>}, and checks that
 * the single row comes back as the expected list of integers.
 */
@Test
public void testBackwardsCompatibleRepeatedIntegerField()
        throws Exception
{
    File dataDirectory = new File(Resources.getResource("parquet_repeated_primitives/int/").toURI());
    String createTableSql = format("""
            CREATE TABLE table_with_repeated_primitive_int (
            repeatedInt array<int>)
            WITH (
            external_location = '%s',
            format = 'PARQUET')
            """,
            dataDirectory.getAbsolutePath());

    try (DistributedQueryRunner runner = HiveQueryRunner.builder().build();
            QueryAssertions queryAssertions = new QueryAssertions(runner)) {
        // Expose the resource directory as an external table
        runner.execute(createTableSql);

        assertThat(queryAssertions.query("SELECT repeatedInt FROM table_with_repeated_primitive_int"))
                .result()
                .matches(resultBuilder(runner.getDefaultSession(), new ArrayType(INTEGER))
                        .row(Arrays.asList(1, 2, 3))
                        .build());

        runner.execute("DROP TABLE table_with_repeated_primitive_int");
    }
}

@Test
public void testBackwardsCompatibleRepeatedPrimitiveFieldDefinedAsPrimitive()
throws Exception
{
DistributedQueryRunner queryRunner = HiveQueryRunner.builder().build();
File parquetFile = new File(Resources.getResource("parquet_repeated_primitives/int/").toURI());
boolean errorThrown = false;
try {
queryRunner.execute(format("""
CREATE TABLE table_with_repeated_primitive_int_bad_table_schema (
repeatedInt int)
WITH (
external_location = '%s',
format = 'PARQUET')
""",
parquetFile.getAbsolutePath()));
queryRunner.execute("SELECT repeatedInt FROM table_with_repeated_primitive_int_bad_table_schema");
}
catch (QueryFailedException e) {
if (e.getMessage().contains("Repeated field repeatedint's type INT32 in a parquet file is incompatible with type integer defined in table schema")) {
errorThrown = true;
}
}
finally {
queryRunner.execute("DROP TABLE IF EXISTS table_with_repeated_primitive_int_bad_table_schema");
}
assert(errorThrown);

Check failure on line 2074 in plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/AbstractTestParquetReader.java

View workflow job for this annotation

GitHub Actions / error-prone-checks

These grouping parentheses are unnecessary; it is unlikely the code will be misinterpreted without them

Check failure on line 2074 in plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/AbstractTestParquetReader.java

View workflow job for this annotation

GitHub Actions / error-prone-checks

These grouping parentheses are unnecessary; it is unlikely the code will be misinterpreted without them
}

private static <T> Iterable<T> repeatEach(int n, Iterable<T> iterable)
{
return () -> new AbstractIterator<>()
Expand Down
Binary file not shown.
Binary file not shown.

0 comments on commit e48dcb7

Please sign in to comment.