Support reading DWRF flat map
kevinwilfong authored and wenleix committed Nov 7, 2018
1 parent 0b93d49 commit 87f7e51
Showing 49 changed files with 1,646 additions and 291 deletions.
2 changes: 1 addition & 1 deletion pom.xml
@@ -412,7 +412,7 @@
<dependency>
<groupId>com.facebook.presto.orc</groupId>
<artifactId>orc-protobuf</artifactId>
<version>6</version>
<version>7</version>
</dependency>

<dependency>
presto-orc/src/main/java/com/facebook/presto/orc/OrcWriteValidation.java
@@ -56,8 +56,10 @@
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.SortedMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static com.facebook.presto.orc.OrcWriteValidation.OrcWriteValidationMode.BOTH;
@@ -217,7 +219,7 @@ public void validateStripeStatistics(OrcDataSourceId orcDataSourceId, long strip
validateColumnStatisticsEquivalent(orcDataSourceId, "Stripe at " + stripeOffset, actual, expected.getColumnStatistics());
}

public void validateRowGroupStatistics(OrcDataSourceId orcDataSourceId, long stripeOffset, Map<Integer, List<RowGroupIndex>> actualRowGroupStatistics)
public void validateRowGroupStatistics(OrcDataSourceId orcDataSourceId, long stripeOffset, Map<StreamId, List<RowGroupIndex>> actualRowGroupStatistics)
throws OrcCorruptionException
{
requireNonNull(actualRowGroupStatistics, "actualRowGroupStatistics is null");
@@ -227,7 +229,11 @@ public void validateRowGroupStatistics(OrcDataSourceId orcDataSourceId, long str
}

int rowGroupCount = expectedRowGroupStatistics.size();
for (Entry<Integer, List<RowGroupIndex>> entry : actualRowGroupStatistics.entrySet()) {
for (Entry<StreamId, List<RowGroupIndex>> entry : actualRowGroupStatistics.entrySet()) {
// TODO: Remove once the Presto writer supports flat map
if (entry.getKey().getSequence() > 0) {
throw new OrcCorruptionException(orcDataSourceId, "Unexpected sequence ID for column %s at offset %s", entry.getKey().getColumn(), stripeOffset);
}
if (entry.getValue().size() != rowGroupCount) {
throw new OrcCorruptionException(orcDataSourceId, "Unexpected row group count stripe in at offset %s", stripeOffset);
}
@@ -237,14 +243,15 @@ public void validateRowGroupStatistics(OrcDataSourceId orcDataSourceId, long str
RowGroupStatistics expectedRowGroup = expectedRowGroupStatistics.get(rowGroupIndex);
if (expectedRowGroup.getValidationMode() != HASHED) {
Map<Integer, ColumnStatistics> expectedStatistics = expectedRowGroup.getColumnStatistics();
if (!expectedStatistics.keySet().equals(actualRowGroupStatistics.keySet())) {
Set<Integer> actualColumns = actualRowGroupStatistics.keySet().stream()
.map(StreamId::getColumn)
.collect(Collectors.toSet());
if (!expectedStatistics.keySet().equals(actualColumns)) {
throw new OrcCorruptionException(orcDataSourceId, "Unexpected column in row group %s in stripe at offset %s", rowGroupIndex, stripeOffset);
}
for (Entry<Integer, ColumnStatistics> entry : expectedStatistics.entrySet()) {
int columnIndex = entry.getKey();
List<RowGroupIndex> actualRowGroup = actualRowGroupStatistics.get(columnIndex);
ColumnStatistics actual = actualRowGroup.get(rowGroupIndex).getColumnStatistics();
ColumnStatistics expected = entry.getValue();
for (Entry<StreamId, List<RowGroupIndex>> entry : actualRowGroupStatistics.entrySet()) {
ColumnStatistics actual = entry.getValue().get(rowGroupIndex).getColumnStatistics();
ColumnStatistics expected = expectedStatistics.get(entry.getKey().getColumn());
validateColumnStatisticsEquivalent(orcDataSourceId, "Row group " + rowGroupIndex + " in stripe at offset " + stripeOffset, actual, expected);
}
}
@@ -258,13 +265,13 @@ public void validateRowGroupStatistics(OrcDataSourceId orcDataSourceId, long str
}
}

private static RowGroupStatistics buildActualRowGroupStatistics(int rowGroupIndex, Map<Integer, List<RowGroupIndex>> actualRowGroupStatistics)
private static RowGroupStatistics buildActualRowGroupStatistics(int rowGroupIndex, Map<StreamId, List<RowGroupIndex>> actualRowGroupStatistics)
{
return new RowGroupStatistics(
BOTH,
IntStream.range(1, actualRowGroupStatistics.size() + 1)
.boxed()
.collect(toImmutableMap(identity(), columnIndex -> actualRowGroupStatistics.get(columnIndex).get(rowGroupIndex).getColumnStatistics())));
actualRowGroupStatistics.entrySet()
.stream()
.collect(Collectors.toMap(entry -> entry.getKey().getColumn(), entry -> entry.getValue().get(rowGroupIndex).getColumnStatistics())));
}

public void validateRowGroupStatistics(
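With flat maps, a single column can fan out into several sequences, each carrying its own index streams, so row-group statistics now arrive keyed by StreamId (column plus sequence) rather than by column index alone. The validation changes above therefore project the keys back down to columns before comparing key sets. A minimal, self-contained sketch of that projection, using a stand-in StreamId record rather than Presto's class:

import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

public class FlatMapKeySketch
{
    // Stand-in for com.facebook.presto.orc.StreamId.
    record StreamId(int column, int sequence) {}

    public static void main(String[] args)
    {
        // Column 3 is a flat map stored as two sequences; column 1 is ordinary.
        Map<StreamId, String> statsByStream = Map.of(
                new StreamId(1, 0), "stats for column 1",
                new StreamId(3, 1), "stats for column 3, key A",
                new StreamId(3, 2), "stats for column 3, key B");

        // Collapse stream-level keys to the column set the writer recorded.
        Set<Integer> actualColumns = statsByStream.keySet().stream()
                .map(StreamId::column)
                .collect(Collectors.toSet());

        System.out.println(actualColumns); // prints [1, 3] in some order
    }
}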
presto-orc/src/main/java/com/facebook/presto/orc/StreamDescriptor.java
@@ -18,22 +18,30 @@

import java.util.List;

import static com.facebook.presto.orc.metadata.ColumnEncoding.DEFAULT_SEQUENCE_ID;
import static com.google.common.base.MoreObjects.toStringHelper;
import static java.util.Objects.requireNonNull;

public final class StreamDescriptor
{
private final String streamName;
private final int streamId;
private final int sequence;
private final OrcTypeKind streamType;
private final String fieldName;
private final OrcDataSource orcDataSource;
private final List<StreamDescriptor> nestedStreams;

public StreamDescriptor(String streamName, int streamId, String fieldName, OrcTypeKind streamType, OrcDataSource orcDataSource, List<StreamDescriptor> nestedStreams)
{
this(streamName, streamId, fieldName, streamType, orcDataSource, nestedStreams, DEFAULT_SEQUENCE_ID);
}

public StreamDescriptor(String streamName, int streamId, String fieldName, OrcTypeKind streamType, OrcDataSource orcDataSource, List<StreamDescriptor> nestedStreams, int sequence)
{
this.streamName = requireNonNull(streamName, "streamName is null");
this.streamId = streamId;
this.sequence = sequence;
this.fieldName = requireNonNull(fieldName, "fieldName is null");
this.streamType = requireNonNull(streamType, "type is null");
this.orcDataSource = requireNonNull(orcDataSource, "orcDataSource is null");
@@ -50,6 +58,11 @@ public int getStreamId()
return streamId;
}

public int getSequence()
{
return sequence;
}

public OrcTypeKind getStreamType()
{
return streamType;
@@ -81,6 +94,7 @@ public String toString()
return toStringHelper(this)
.add("streamName", streamName)
.add("streamId", streamId)
.add("sequence", sequence)
.add("streamType", streamType)
.add("dataSource", orcDataSource.getId())
.toString();
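The constructor pair above preserves source compatibility: existing call sites keep the six-argument constructor, which now delegates with DEFAULT_SEQUENCE_ID, while flat-map-aware readers pass an explicit sequence. A stripped-down sketch of the same pattern (the names and the default value of 0 are illustrative, though DWRF does reserve sequence 0 for the top-level column):

public final class SequenceAwareDescriptor
{
    private static final int DEFAULT_SEQUENCE_ID = 0;

    private final int streamId;
    private final int sequence;

    // Old signature: delegates so existing callers compile unchanged.
    public SequenceAwareDescriptor(int streamId)
    {
        this(streamId, DEFAULT_SEQUENCE_ID);
    }

    // New signature: flat map readers pass the sequence explicitly.
    public SequenceAwareDescriptor(int streamId, int sequence)
    {
        this.streamId = streamId;
        this.sequence = sequence;
    }
}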
17 changes: 14 additions & 3 deletions presto-orc/src/main/java/com/facebook/presto/orc/StreamId.java
@@ -16,22 +16,27 @@
import com.facebook.presto.orc.metadata.Stream;
import com.facebook.presto.orc.metadata.Stream.StreamKind;

import java.util.Objects;

import static com.google.common.base.MoreObjects.toStringHelper;

public final class StreamId
{
private final int column;
private final int sequence;
private final StreamKind streamKind;

public StreamId(Stream stream)
{
this.column = stream.getColumn();
this.sequence = stream.getSequence();
this.streamKind = stream.getStreamKind();
}

public StreamId(int column, StreamKind streamKind)
public StreamId(int column, int sequence, StreamKind streamKind)
{
this.column = column;
this.sequence = sequence;
this.streamKind = streamKind;
}

@@ -40,6 +45,11 @@ public int getColumn()
return column;
}

public int getSequence()
{
return sequence;
}

public StreamKind getStreamKind()
{
return streamKind;
@@ -48,7 +58,7 @@ public StreamKind getStreamKind()
@Override
public int hashCode()
{
return 31 * column + streamKind.hashCode();
return Objects.hash(column, sequence, streamKind);
}

@Override
@@ -62,14 +72,15 @@ public boolean equals(Object obj)
}

StreamId other = (StreamId) obj;
return column == other.column && streamKind == other.streamKind;
return column == other.column && sequence == other.sequence && streamKind == other.streamKind;
}

@Override
public String toString()
{
return toStringHelper(this)
.add("column", column)
.add("sequence", sequence)
.add("streamKind", streamKind)
.toString();
}
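Including the sequence in equals and hashCode is what lets per-sequence streams coexist as distinct map keys; previously, two streams of the same kind on the same column would have collided. A quick usage sketch against the class above (StreamKind.DATA is a real value of com.facebook.presto.orc.metadata.Stream.StreamKind; the surrounding class is illustrative):

import java.util.HashMap;
import java.util.Map;

import com.facebook.presto.orc.StreamId;
import com.facebook.presto.orc.metadata.Stream.StreamKind;

public class StreamIdSketch
{
    public static void main(String[] args)
    {
        Map<StreamId, String> streams = new HashMap<>();
        streams.put(new StreamId(3, 1, StreamKind.DATA), "flat map value, sequence 1");
        streams.put(new StreamId(3, 2, StreamKind.DATA), "flat map value, sequence 2");
        // With sequence ignored in equals/hashCode, the second put would
        // have replaced the first entry instead of adding a new one.
        System.out.println(streams.size()); // prints 2
    }
}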
63 changes: 43 additions & 20 deletions presto-orc/src/main/java/com/facebook/presto/orc/StripeReader.java
@@ -18,6 +18,7 @@
import com.facebook.presto.orc.checkpoint.StreamCheckpoint;
import com.facebook.presto.orc.metadata.ColumnEncoding;
import com.facebook.presto.orc.metadata.ColumnEncoding.ColumnEncodingKind;
import com.facebook.presto.orc.metadata.DwrfSequenceEncoding;
import com.facebook.presto.orc.metadata.MetadataReader;
import com.facebook.presto.orc.metadata.OrcType;
import com.facebook.presto.orc.metadata.OrcType.OrcTypeKind;
@@ -44,6 +45,7 @@

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
@@ -115,9 +117,20 @@ public Stripe readStripe(StripeInformation stripe, AggregatedMemoryContext syste
if (includedOrcColumns.contains(stream.getColumn())) {
streams.put(new StreamId(stream), stream);

ColumnEncodingKind columnEncoding = columnEncodings.get(stream.getColumn()).getColumnEncodingKind();
if (columnEncoding == DICTIONARY && stream.getStreamKind() == StreamKind.IN_DICTIONARY) {
hasRowGroupDictionary = true;
if (stream.getStreamKind() == StreamKind.IN_DICTIONARY) {
ColumnEncoding columnEncoding = columnEncodings.get(stream.getColumn());

if (columnEncoding.getColumnEncodingKind() == DICTIONARY) {
hasRowGroupDictionary = true;
}

Optional<List<DwrfSequenceEncoding>> additionalSequenceEncodings = columnEncoding.getAdditionalSequenceEncodings();
if (additionalSequenceEncodings.isPresent()
&& additionalSequenceEncodings.get().stream()
.map(DwrfSequenceEncoding::getValueEncoding)
.anyMatch(encoding -> encoding.getColumnEncodingKind() == DICTIONARY)) {
hasRowGroupDictionary = true;
}
}
}
}
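The block above extends row-group-dictionary detection to flat maps: a column qualifies if its top-level encoding is DICTIONARY, or if any of its per-sequence value encodings is. A self-contained sketch of that anyMatch shape, with stand-in types rather than Presto's DwrfSequenceEncoding:

import java.util.List;
import java.util.Optional;

public class SequenceEncodingSketch
{
    enum Kind { DIRECT, DICTIONARY }

    // Stand-in for DwrfSequenceEncoding with its value encoding kind.
    record SequenceEncoding(Kind valueEncoding) {}

    static boolean anySequenceUsesDictionary(Optional<List<SequenceEncoding>> additionalSequenceEncodings)
    {
        return additionalSequenceEncodings.isPresent()
                && additionalSequenceEncodings.get().stream()
                        .map(SequenceEncoding::valueEncoding)
                        .anyMatch(kind -> kind == Kind.DICTIONARY);
    }

    public static void main(String[] args)
    {
        System.out.println(anySequenceUsesDictionary(Optional.empty())); // false
        System.out.println(anySequenceUsesDictionary(Optional.of(List.of(
                new SequenceEncoding(Kind.DIRECT),
                new SequenceEncoding(Kind.DICTIONARY))))); // true
    }
}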
@@ -133,10 +146,10 @@ public Stripe readStripe(StripeInformation stripe, AggregatedMemoryContext syste
Map<StreamId, OrcInputStream> streamsData = readDiskRanges(stripe.getOffset(), diskRanges, systemMemoryUsage);

// read the bloom filter for each column
Map<Integer, List<HiveBloomFilter>> bloomFilterIndexes = readBloomFilterIndexes(streams, streamsData);
Map<StreamId, List<HiveBloomFilter>> bloomFilterIndexes = readBloomFilterIndexes(streams, streamsData);

// read the row index for each column
Map<Integer, List<RowGroupIndex>> columnIndexes = readColumnIndexes(streams, streamsData, bloomFilterIndexes);
Map<StreamId, List<RowGroupIndex>> columnIndexes = readColumnIndexes(streams, streamsData, bloomFilterIndexes);
if (writeValidation.isPresent()) {
writeValidation.get().validateRowGroupStatistics(orcDataSource.getId(), stripe.getOffset(), columnIndexes);
}
@@ -262,7 +275,9 @@ private Map<StreamId, ValueInputStream<?>> createValueStreams(Map<StreamId, Stre
for (Entry<StreamId, Stream> entry : streams.entrySet()) {
StreamId streamId = entry.getKey();
Stream stream = entry.getValue();
ColumnEncodingKind columnEncoding = columnEncodings.get(stream.getColumn()).getColumnEncodingKind();
ColumnEncodingKind columnEncoding = columnEncodings.get(stream.getColumn())
.getColumnEncoding(stream.getSequence())
.getColumnEncodingKind();

// skip index and empty streams
if (isIndexStream(stream) || stream.getLength() == 0) {
@@ -286,7 +301,9 @@ public InputStreamSources createDictionaryStreamSources(Map<StreamId, Stream> st
int column = stream.getColumn();

// only process dictionary streams
ColumnEncodingKind columnEncoding = columnEncodings.get(column).getColumnEncodingKind();
ColumnEncodingKind columnEncoding = columnEncodings.get(column)
.getColumnEncoding(stream.getSequence())
.getColumnEncodingKind();
if (!isDictionary(stream, columnEncoding)) {
continue;
}
@@ -310,7 +327,7 @@ private List<RowGroup> createRowGroups(
int rowsInStripe,
Map<StreamId, Stream> streams,
Map<StreamId, ValueInputStream<?>> valueStreams,
Map<Integer, List<RowGroupIndex>> columnIndexes,
Map<StreamId, List<RowGroupIndex>> columnIndexes,
Set<Integer> selectedRowGroups,
List<ColumnEncoding> encodings)
throws InvalidCheckpointException
@@ -373,30 +390,30 @@ static boolean isIndexStream(Stream stream)
return stream.getStreamKind() == ROW_INDEX || stream.getStreamKind() == DICTIONARY_COUNT || stream.getStreamKind() == BLOOM_FILTER || stream.getStreamKind() == BLOOM_FILTER_UTF8;
}

private Map<Integer, List<HiveBloomFilter>> readBloomFilterIndexes(Map<StreamId, Stream> streams, Map<StreamId, OrcInputStream> streamsData)
private Map<StreamId, List<HiveBloomFilter>> readBloomFilterIndexes(Map<StreamId, Stream> streams, Map<StreamId, OrcInputStream> streamsData)
throws IOException
{
ImmutableMap.Builder<Integer, List<HiveBloomFilter>> bloomFilters = ImmutableMap.builder();
ImmutableMap.Builder<StreamId, List<HiveBloomFilter>> bloomFilters = ImmutableMap.builder();
for (Entry<StreamId, Stream> entry : streams.entrySet()) {
Stream stream = entry.getValue();
if (stream.getStreamKind() == BLOOM_FILTER) {
OrcInputStream inputStream = streamsData.get(entry.getKey());
bloomFilters.put(stream.getColumn(), metadataReader.readBloomFilterIndexes(inputStream));
bloomFilters.put(entry.getKey(), metadataReader.readBloomFilterIndexes(inputStream));
}
// TODO: add support for BLOOM_FILTER_UTF8
}
return bloomFilters.build();
}

private Map<Integer, List<RowGroupIndex>> readColumnIndexes(Map<StreamId, Stream> streams, Map<StreamId, OrcInputStream> streamsData, Map<Integer, List<HiveBloomFilter>> bloomFilterIndexes)
private Map<StreamId, List<RowGroupIndex>> readColumnIndexes(Map<StreamId, Stream> streams, Map<StreamId, OrcInputStream> streamsData, Map<StreamId, List<HiveBloomFilter>> bloomFilterIndexes)
throws IOException
{
ImmutableMap.Builder<Integer, List<RowGroupIndex>> columnIndexes = ImmutableMap.builder();
ImmutableMap.Builder<StreamId, List<RowGroupIndex>> columnIndexes = ImmutableMap.builder();
for (Entry<StreamId, Stream> entry : streams.entrySet()) {
Stream stream = entry.getValue();
if (stream.getStreamKind() == ROW_INDEX) {
OrcInputStream inputStream = streamsData.get(entry.getKey());
List<HiveBloomFilter> bloomFilters = bloomFilterIndexes.get(stream.getColumn());
List<HiveBloomFilter> bloomFilters = bloomFilterIndexes.get(entry.getKey());
List<RowGroupIndex> rowGroupIndexes = metadataReader.readRowIndexes(hiveWriterVersion, inputStream);
if (bloomFilters != null && !bloomFilters.isEmpty()) {
ImmutableList.Builder<RowGroupIndex> newRowGroupIndexes = ImmutableList.builder();
@@ -408,13 +425,13 @@ private Map<Integer, List<RowGroupIndex>> readColumnIndexes(Map<StreamId, Stream
}
rowGroupIndexes = newRowGroupIndexes.build();
}
columnIndexes.put(stream.getColumn(), rowGroupIndexes);
columnIndexes.put(entry.getKey(), rowGroupIndexes);
}
}
return columnIndexes.build();
}

private Set<Integer> selectRowGroups(StripeInformation stripe, Map<Integer, List<RowGroupIndex>> columnIndexes)
private Set<Integer> selectRowGroups(StripeInformation stripe, Map<StreamId, List<RowGroupIndex>> columnIndexes)
{
int rowsInStripe = toIntExact(stripe.getNumberOfRows());
int groupsInStripe = ceil(rowsInStripe, rowsInRowGroup);
@@ -432,18 +449,24 @@ private Set<Integer> selectRowGroups(StripeInformation stripe, Map<Integer, List
return selectedRowGroups.build();
}

private static Map<Integer, ColumnStatistics> getRowGroupStatistics(OrcType rootStructType, Map<Integer, List<RowGroupIndex>> columnIndexes, int rowGroup)
private static Map<Integer, ColumnStatistics> getRowGroupStatistics(OrcType rootStructType, Map<StreamId, List<RowGroupIndex>> columnIndexes, int rowGroup)
{
requireNonNull(rootStructType, "rootStructType is null");
checkArgument(rootStructType.getOrcTypeKind() == OrcTypeKind.STRUCT);
requireNonNull(columnIndexes, "columnIndexes is null");
checkArgument(rowGroup >= 0, "rowGroup is negative");

Map<Integer, List<ColumnStatistics>> groupedColumnStatistics = new HashMap<>();
for (Entry<StreamId, List<RowGroupIndex>> entry : columnIndexes.entrySet()) {
groupedColumnStatistics.computeIfAbsent(entry.getKey().getColumn(), key -> new ArrayList<>())
.add(entry.getValue().get(rowGroup).getColumnStatistics());
}

ImmutableMap.Builder<Integer, ColumnStatistics> statistics = ImmutableMap.builder();
for (int ordinal = 0; ordinal < rootStructType.getFieldCount(); ordinal++) {
List<RowGroupIndex> rowGroupIndexes = columnIndexes.get(rootStructType.getFieldTypeIndex(ordinal));
if (rowGroupIndexes != null) {
statistics.put(ordinal, rowGroupIndexes.get(rowGroup).getColumnStatistics());
List<ColumnStatistics> columnStatistics = groupedColumnStatistics.get(rootStructType.getFieldTypeIndex(ordinal));
if (columnStatistics != null) {
statistics.put(ordinal, ColumnStatistics.mergeColumnStatistics(columnStatistics));
}
}
return statistics.build();
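getRowGroupStatistics now has to merge rather than just look up: a flat map column contributes one RowGroupIndex list per sequence, so its per-row-group statistics are the merge of all sequences for that column. A simplified, runnable sketch of that group-then-merge step, with row counts standing in for ColumnStatistics and summation standing in for ColumnStatistics.mergeColumnStatistics:

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class GroupThenMergeSketch
{
    // Stand-in for com.facebook.presto.orc.StreamId.
    record StreamId(int column, int sequence) {}

    public static void main(String[] args)
    {
        Map<StreamId, Long> rowCountsByStream = Map.of(
                new StreamId(1, 0), 100L, // ordinary column, one sequence
                new StreamId(3, 1), 40L,  // flat map column, key A
                new StreamId(3, 2), 60L); // flat map column, key B

        // Group per-stream values by column, mirroring groupedColumnStatistics above.
        Map<Integer, List<Long>> grouped = new HashMap<>();
        for (Map.Entry<StreamId, Long> entry : rowCountsByStream.entrySet()) {
            grouped.computeIfAbsent(entry.getKey().column(), key -> new ArrayList<>())
                    .add(entry.getValue());
        }

        // Merge each column's contributions; the real code calls
        // ColumnStatistics.mergeColumnStatistics here.
        Map<Integer, Long> merged = new HashMap<>();
        grouped.forEach((column, counts) ->
                merged.put(column, counts.stream().mapToLong(Long::longValue).sum()));

        System.out.println(merged); // e.g. {1=100, 3=100}
    }
}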
