From aee8f9d6381f8dba0959da17e1af10f670285471 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Thu, 1 Feb 2024 23:08:21 +0530 Subject: [PATCH] GH-37841: [Java] Dictionary decoding not using the compression factory from the ArrowReader (#38371) ### Rationale for this change This PR addresses https://github.com/apache/arrow/issues/37841. ### What changes are included in this PR? Adding compression-based write and read for Dictionary data. ### Are these changes tested? Yes. ### Are there any user-facing changes? No * Closes: #37841 Lead-authored-by: Vibhatha Lakmal Abeykoon Co-authored-by: vibhatha Signed-off-by: David Li --- .../TestArrowReaderWriterWithCompression.java | 206 ++++++++++++++++-- .../apache/arrow/vector/ipc/ArrowReader.java | 2 +- .../apache/arrow/vector/ipc/ArrowWriter.java | 23 +- 3 files changed, 201 insertions(+), 30 deletions(-) diff --git a/java/compression/src/test/java/org/apache/arrow/compression/TestArrowReaderWriterWithCompression.java b/java/compression/src/test/java/org/apache/arrow/compression/TestArrowReaderWriterWithCompression.java index 6104cb1a132e4..af28333746290 100644 --- a/java/compression/src/test/java/org/apache/arrow/compression/TestArrowReaderWriterWithCompression.java +++ b/java/compression/src/test/java/org/apache/arrow/compression/TestArrowReaderWriterWithCompression.java @@ -18,7 +18,9 @@ package org.apache.arrow.compression; import java.io.ByteArrayOutputStream; +import java.io.IOException; import java.nio.channels.Channels; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -27,63 +29,223 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.GenerateSampleData; +import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.compression.CompressionUtil; import org.apache.arrow.vector.compression.NoCompressionCodec; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.ipc.ArrowFileReader; import org.apache.arrow.vector.ipc.ArrowFileWriter; +import org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; import org.apache.arrow.vector.ipc.message.IpcOption; import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel; +import org.junit.After; import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; public class TestArrowReaderWriterWithCompression { - @Test - public void testArrowFileZstdRoundTrip() throws Exception { - // Prepare sample data - final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); + private BufferAllocator allocator; + private ByteArrayOutputStream out; + private VectorSchemaRoot root; + + @BeforeEach + public void setup() { + if (allocator == null) { + allocator = new RootAllocator(Integer.MAX_VALUE); + } + out = new ByteArrayOutputStream(); + root = null; + } + + @After + public void tearDown() { + if (root != null) { + root.close(); + } + if (allocator != null) { + allocator.close(); + } + if (out != null) { + out.reset(); + } + + } + + private void createAndWriteArrowFile(DictionaryProvider provider, + CompressionUtil.CodecType codecType) throws IOException { List fields = new ArrayList<>(); fields.add(new Field("col", FieldType.notNullable(new ArrowType.Utf8()), new ArrayList<>())); - VectorSchemaRoot root = VectorSchemaRoot.create(new Schema(fields), allocator); + root = VectorSchemaRoot.create(new Schema(fields), allocator); + final int rowCount = 10; GenerateSampleData.generateTestData(root.getVector(0), rowCount); root.setRowCount(rowCount); - // Write an in-memory compressed arrow file - ByteArrayOutputStream out = new ByteArrayOutputStream(); - try (final ArrowFileWriter writer = - new ArrowFileWriter(root, null, Channels.newChannel(out), new HashMap<>(), - IpcOption.DEFAULT, CommonsCompressionFactory.INSTANCE, CompressionUtil.CodecType.ZSTD, Optional.of(7))) { + try (final ArrowFileWriter writer = new ArrowFileWriter(root, provider, Channels.newChannel(out), + new HashMap<>(), IpcOption.DEFAULT, CommonsCompressionFactory.INSTANCE, codecType, Optional.of(7))) { writer.start(); writer.writeBatch(); writer.end(); } + } + + private void createAndWriteArrowStream(DictionaryProvider provider, + CompressionUtil.CodecType codecType) throws IOException { + List fields = new ArrayList<>(); + fields.add(new Field("col", FieldType.notNullable(new ArrowType.Utf8()), new ArrayList<>())); + root = VectorSchemaRoot.create(new Schema(fields), allocator); + + final int rowCount = 10; + GenerateSampleData.generateTestData(root.getVector(0), rowCount); + root.setRowCount(rowCount); + + try (final ArrowStreamWriter writer = new ArrowStreamWriter(root, provider, Channels.newChannel(out), + IpcOption.DEFAULT, CommonsCompressionFactory.INSTANCE, codecType, Optional.of(7))) { + writer.start(); + writer.writeBatch(); + writer.end(); + } + } - // Read the in-memory compressed arrow file with CommonsCompressionFactory provided + private Dictionary createDictionary(VarCharVector dictionaryVector) { + setVector(dictionaryVector, + "foo".getBytes(StandardCharsets.UTF_8), + "bar".getBytes(StandardCharsets.UTF_8), + "baz".getBytes(StandardCharsets.UTF_8)); + + return new Dictionary(dictionaryVector, + new DictionaryEncoding(/*id=*/1L, /*ordered=*/false, /*indexType=*/null)); + } + + @Test + public void testArrowFileZstdRoundTrip() throws Exception { + createAndWriteArrowFile(null, CompressionUtil.CodecType.ZSTD); + // with compression + try (ArrowFileReader reader = + new ArrowFileReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator, + CommonsCompressionFactory.INSTANCE)) { + Assertions.assertEquals(1, reader.getRecordBlocks().size()); + Assertions.assertTrue(reader.loadNextBatch()); + Assertions.assertTrue(root.equals(reader.getVectorSchemaRoot())); + Assertions.assertFalse(reader.loadNextBatch()); + } + // without compression try (ArrowFileReader reader = - new ArrowFileReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), - allocator, CommonsCompressionFactory.INSTANCE)) { - Assert.assertEquals(1, reader.getRecordBlocks().size()); + new ArrowFileReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator, + NoCompressionCodec.Factory.INSTANCE)) { + Assertions.assertEquals(1, reader.getRecordBlocks().size()); + Exception exception = Assert.assertThrows(IllegalArgumentException.class, + reader::loadNextBatch); + Assertions.assertEquals("Please add arrow-compression module to use CommonsCompressionFactory for ZSTD", + exception.getMessage()); + } + } + + @Test + public void testArrowStreamZstdRoundTrip() throws Exception { + createAndWriteArrowStream(null, CompressionUtil.CodecType.ZSTD); + // with compression + try (ArrowStreamReader reader = + new ArrowStreamReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator, + CommonsCompressionFactory.INSTANCE)) { Assert.assertTrue(reader.loadNextBatch()); Assert.assertTrue(root.equals(reader.getVectorSchemaRoot())); Assert.assertFalse(reader.loadNextBatch()); } + // without compression + try (ArrowStreamReader reader = + new ArrowStreamReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator, + NoCompressionCodec.Factory.INSTANCE)) { + Exception exception = Assert.assertThrows(IllegalArgumentException.class, + reader::loadNextBatch); + Assert.assertEquals( + "Please add arrow-compression module to use CommonsCompressionFactory for ZSTD", + exception.getMessage() + ); + } + } - // Read the in-memory compressed arrow file without CompressionFactory provided + @Test + public void testArrowFileZstdRoundTripWithDictionary() throws Exception { + VarCharVector dictionaryVector = (VarCharVector) + FieldType.nullable(new ArrowType.Utf8()).createNewSingleVector("f1_file", allocator, null); + Dictionary dictionary = createDictionary(dictionaryVector); + DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider(); + provider.put(dictionary); + + createAndWriteArrowFile(provider, CompressionUtil.CodecType.ZSTD); + + // with compression + try (ArrowFileReader reader = + new ArrowFileReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator, + CommonsCompressionFactory.INSTANCE)) { + Assertions.assertEquals(1, reader.getRecordBlocks().size()); + Assertions.assertTrue(reader.loadNextBatch()); + Assertions.assertTrue(root.equals(reader.getVectorSchemaRoot())); + Assertions.assertFalse(reader.loadNextBatch()); + } + // without compression try (ArrowFileReader reader = - new ArrowFileReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), - allocator, NoCompressionCodec.Factory.INSTANCE)) { - Assert.assertEquals(1, reader.getRecordBlocks().size()); + new ArrowFileReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator, + NoCompressionCodec.Factory.INSTANCE)) { + Assertions.assertEquals(1, reader.getRecordBlocks().size()); + Exception exception = Assert.assertThrows(IllegalArgumentException.class, + reader::loadNextBatch); + Assertions.assertEquals("Please add arrow-compression module to use CommonsCompressionFactory for ZSTD", + exception.getMessage()); + } + dictionaryVector.close(); + } + + @Test + public void testArrowStreamZstdRoundTripWithDictionary() throws Exception { + VarCharVector dictionaryVector = (VarCharVector) + FieldType.nullable(new ArrowType.Utf8()).createNewSingleVector("f1_stream", allocator, null); + Dictionary dictionary = createDictionary(dictionaryVector); + DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider(); + provider.put(dictionary); + + createAndWriteArrowStream(provider, CompressionUtil.CodecType.ZSTD); + + // with compression + try (ArrowStreamReader reader = + new ArrowStreamReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator, + CommonsCompressionFactory.INSTANCE)) { + Assertions.assertTrue(reader.loadNextBatch()); + Assertions.assertTrue(root.equals(reader.getVectorSchemaRoot())); + Assertions.assertFalse(reader.loadNextBatch()); + } + // without compression + try (ArrowStreamReader reader = + new ArrowStreamReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator, + NoCompressionCodec.Factory.INSTANCE)) { + Exception exception = Assert.assertThrows(IllegalArgumentException.class, + reader::loadNextBatch); + Assertions.assertEquals("Please add arrow-compression module to use CommonsCompressionFactory for ZSTD", + exception.getMessage()); + } + dictionaryVector.close(); + } - Exception exception = Assert.assertThrows(IllegalArgumentException.class, () -> reader.loadNextBatch()); - String expectedMessage = "Please add arrow-compression module to use CommonsCompressionFactory for ZSTD"; - Assert.assertEquals(expectedMessage, exception.getMessage()); + public static void setVector(VarCharVector vector, byte[]... values) { + final int length = values.length; + vector.allocateNewSafe(); + for (int i = 0; i < length; i++) { + if (values[i] != null) { + vector.set(i, values[i]); + } } + vector.setValueCount(length); } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java index 04c57d7e82fef..01f4e925c69b3 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java @@ -251,7 +251,7 @@ private void load(ArrowDictionaryBatch dictionaryBatch, FieldVector vector) { VectorSchemaRoot root = new VectorSchemaRoot( Collections.singletonList(vector.getField()), Collections.singletonList(vector), 0); - VectorLoader loader = new VectorLoader(root); + VectorLoader loader = new VectorLoader(root, this.compressionFactory); try { loader.load(dictionaryBatch.getDictionary()); } finally { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java index a33c55de53f23..1cc201ae56f4b 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java @@ -61,9 +61,14 @@ public abstract class ArrowWriter implements AutoCloseable { private final DictionaryProvider dictionaryProvider; private final Set dictionaryIdsUsed = new HashSet<>(); + private final CompressionCodec.Factory compressionFactory; + private final CompressionUtil.CodecType codecType; + private final Optional compressionLevel; private boolean started = false; private boolean ended = false; + private final CompressionCodec codec; + protected IpcOption option; protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) { @@ -89,16 +94,19 @@ protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, Writab protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out, IpcOption option, CompressionCodec.Factory compressionFactory, CompressionUtil.CodecType codecType, Optional compressionLevel) { - this.unloader = new VectorUnloader( - root, /*includeNullCount*/ true, - compressionLevel.isPresent() ? - compressionFactory.createCodec(codecType, compressionLevel.get()) : - compressionFactory.createCodec(codecType), - /*alignBuffers*/ true); this.out = new WriteChannel(out); this.option = option; this.dictionaryProvider = provider; + this.compressionFactory = compressionFactory; + this.codecType = codecType; + this.compressionLevel = compressionLevel; + this.codec = this.compressionLevel.isPresent() ? + this.compressionFactory.createCodec(this.codecType, this.compressionLevel.get()) : + this.compressionFactory.createCodec(this.codecType); + this.unloader = new VectorUnloader(root, /*includeNullCount*/ true, codec, + /*alignBuffers*/ true); + List fields = new ArrayList<>(root.getSchema().getFields().size()); MetadataV4UnionChecker.checkForUnion(root.getSchema().getFields().iterator(), option.metadataVersion); @@ -133,7 +141,8 @@ protected void writeDictionaryBatch(Dictionary dictionary) throws IOException { Collections.singletonList(vector.getField()), Collections.singletonList(vector), count); - VectorUnloader unloader = new VectorUnloader(dictRoot); + VectorUnloader unloader = new VectorUnloader(dictRoot, /*includeNullCount*/ true, this.codec, + /*alignBuffers*/ true); ArrowRecordBatch batch = unloader.getRecordBatch(); ArrowDictionaryBatch dictionaryBatch = new ArrowDictionaryBatch(id, batch, false); try {