Skip to content

Commit

Permalink
Account for memory usage of dictionary in SliceDictionaryColumnReader
Browse files Browse the repository at this point in the history
SliceDictionaryColumnReader#dictionaryData could retain hundreds
MB memory, sometime can cause worker OOM, especially in the
situation of small worker and large concurrency.
  • Loading branch information
XuPengfei-1020 authored and sopel39 committed Jan 11, 2022
1 parent fcdf74d commit f722686
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 4 deletions.
Expand Up @@ -185,7 +185,7 @@ private Block readNullBlock(boolean[] isNull, int nonNullCount)
if (nonNullValueTemp.length < minNonNullValueSize) {
nonNullValueTemp = new int[minNonNullValueSize];
nonNullPositionList = new int[minNonNullValueSize];
systemMemoryContext.setBytes(sizeOf(nonNullValueTemp) + sizeOf(nonNullPositionList));
systemMemoryContext.setBytes(getRetainedSizeInBytes());
}

dataStream.next(nonNullValueTemp, nonNullCount);
Expand Down Expand Up @@ -219,6 +219,7 @@ private void setDictionaryBlockData(byte[] dictionaryData, int[] dictionaryOffse
dictionaryOffsets[positionCount] = dictionaryOffsets[positionCount - 1];
dictionaryBlock = new VariableWidthBlock(positionCount, wrappedBuffer(dictionaryData), dictionaryOffsets, Optional.of(isNullVector));
currentDictionaryData = dictionaryData;
systemMemoryContext.setBytes(getRetainedSizeInBytes());
}
}

Expand Down Expand Up @@ -363,6 +364,8 @@ public void close()
@Override
public long getRetainedSizeInBytes()
{
return INSTANCE_SIZE;
return INSTANCE_SIZE + sizeOf(nonNullValueTemp) + sizeOf(nonNullPositionList) + sizeOf(dictionaryData)
+ sizeOf(dictionaryLength) + sizeOf(dictionaryOffsetVector)
+ (currentDictionaryData == dictionaryData ? 0 : sizeOf(currentDictionaryData));
}
}
Expand Up @@ -32,6 +32,7 @@
import java.util.HashMap;

import static io.airlift.testing.Assertions.assertGreaterThan;
import static io.airlift.testing.Assertions.assertGreaterThanOrEqual;
import static io.trino.orc.OrcReader.INITIAL_BATCH_SIZE;
import static io.trino.orc.OrcReader.MAX_BATCH_SIZE;
import static io.trino.orc.OrcTester.Format.ORC_12;
Expand Down Expand Up @@ -75,8 +76,8 @@ public void testVarcharTypeWithoutNulls()

// StripeReader memory should increase after reading a block.
assertGreaterThan(reader.getCurrentStripeRetainedSizeInBytes(), stripeReaderRetainedSize);
// There are no local buffers needed.
assertEquals(reader.getStreamReaderRetainedSizeInBytes() - streamReaderRetainedSize, 0L);
// There may be some extra local buffers needed for dictionary data.
assertGreaterThanOrEqual(reader.getStreamReaderRetainedSizeInBytes(), streamReaderRetainedSize);
// The total retained size and system memory usage should be greater than 0 byte because of the instance sizes.
assertGreaterThan(reader.getRetainedSizeInBytes() - readerRetainedSize, 0L);
assertGreaterThan(reader.getSystemMemoryUsage() - readerSystemMemoryUsage, 0L);
Expand Down
@@ -0,0 +1,133 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.orc;

import com.google.common.collect.ImmutableSet;
import io.airlift.slice.Slices;
import io.trino.memory.context.AggregatedMemoryContext;
import io.trino.orc.metadata.Footer;
import io.trino.orc.metadata.OrcMetadataReader;
import io.trino.orc.metadata.StripeInformation;
import io.trino.orc.reader.SliceDictionaryColumnReader;
import org.testng.annotations.Test;

import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Random;

import static com.google.common.io.Files.createTempDir;
import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext;
import static io.trino.orc.OrcTester.writeOrcColumnTrino;
import static io.trino.orc.metadata.CompressionKind.NONE;
import static io.trino.orc.metadata.PostScript.HiveWriterVersion.ORIGINAL;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static java.nio.file.Files.readAllBytes;
import static java.time.ZoneOffset.UTC;
import static java.util.UUID.randomUUID;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;

public class TestSliceDictionaryColumnReader
{
public static final int ROWS = 100_000;
private static final int DICTIONARY = 22;
private static final int MAX_STRING = 19;

@Test
public void testDictionaryReaderUpdatesRetainedSize()
throws Exception
{
// create orc file
List<String> values = createValues();
File temporaryDirectory = createTempDir();
File orcFile = new File(temporaryDirectory, randomUUID().toString());
writeOrcColumnTrino(orcFile, NONE, VARCHAR, values.iterator(), new OrcWriterStats());

// prepare for read
OrcDataSource dataSource = new MemoryOrcDataSource(new OrcDataSourceId(orcFile.getPath()), Slices.wrappedBuffer(readAllBytes(orcFile.toPath())));
OrcReader orcReader = OrcReader.createOrcReader(dataSource, new OrcReaderOptions())
.orElseThrow(() -> new RuntimeException("File is empty"));
Footer footer = orcReader.getFooter();
List<OrcColumn> columns = orcReader.getRootColumn().getNestedColumns();
assertTrue(columns.size() == 1);
StripeReader stripeReader = new StripeReader(
dataSource,
UTC,
Optional.empty(),
footer.getTypes(),
ImmutableSet.copyOf(columns),
footer.getRowsInRowGroup(),
OrcPredicate.TRUE,
ORIGINAL,
new OrcMetadataReader(),
Optional.empty());
AggregatedMemoryContext memoryContext = newSimpleAggregatedMemoryContext();
SliceDictionaryColumnReader columnReader = new SliceDictionaryColumnReader(columns.get(0), memoryContext.newLocalMemoryContext(TestSliceDictionaryColumnReader.class.getSimpleName()), -1, false);

List<StripeInformation> stripeInformations = footer.getStripes();
for (StripeInformation stripeInformation : stripeInformations) {
Stripe stripe = stripeReader.readStripe(stripeInformation, newSimpleAggregatedMemoryContext());
List<RowGroup> rowGroups = stripe.getRowGroups();
columnReader.startStripe(stripe.getFileTimeZone(), stripe.getDictionaryStreamSources(), stripe.getColumnEncodings());

for (RowGroup rowGroup : rowGroups) {
columnReader.startRowGroup(rowGroup.getStreamSources());
columnReader.prepareNextRead(1000);
columnReader.readBlock();
// memory usage check
assertEquals(memoryContext.getBytes(), columnReader.getRetainedSizeInBytes());
}
}

columnReader.close();
assertTrue(memoryContext.getBytes() == 0);
}

private List<String> createValues()
{
Random random = new Random();
List<String> dictionary = createDictionary(random);

List<String> values = new ArrayList<>();
for (int i = 0; i < ROWS; ++i) {
if (random.nextBoolean()) {
values.add(dictionary.get(random.nextInt(dictionary.size())));
}
else {
values.add(null);
}
}
return values;
}

private List<String> createDictionary(Random random)
{
List<String> dictionary = new ArrayList<>();
for (int dictionaryIndex = 0; dictionaryIndex < DICTIONARY; dictionaryIndex++) {
dictionary.add(randomAsciiString(random));
}
return dictionary;
}

private String randomAsciiString(Random random)
{
char[] value = new char[random.nextInt(MAX_STRING)];
for (int i = 0; i < value.length; i++) {
value[i] = (char) random.nextInt(Byte.MAX_VALUE);
}
return new String(value);
}
}

0 comments on commit f722686

Please sign in to comment.