From db393ef861b970dcc52210812ef96673c8222d9d Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Wed, 11 Oct 2023 17:44:48 +0200 Subject: [PATCH] GH-37910: [Java][Integration] Implement C Data Interface integration testing --- ci/scripts/integration_arrow.sh | 4 +- dev/archery/archery/integration/datagen.py | 1 - .../archery/integration/tester_csharp.py | 9 + .../archery/integration/tester_java.py | 177 +++++++++++++++++- docker-compose.yml | 11 +- .../arrow/c/BufferImportTypeVisitor.java | 4 +- .../main/java/org/apache/arrow/c/Format.java | 4 + .../org/apache/arrow/c/SchemaImporter.java | 2 +- .../org/apache/arrow/vector/NullVector.java | 1 + .../vector/compare/RangeEqualsVisitor.java | 6 +- .../vector/dictionary/DictionaryProvider.java | 26 ++- .../arrow/vector/ipc/JsonFileReader.java | 37 +++- .../vector/ipc/message/ArrowRecordBatch.java | 4 +- .../apache/arrow/vector/util/Validator.java | 26 +++ 14 files changed, 281 insertions(+), 31 deletions(-) diff --git a/ci/scripts/integration_arrow.sh b/ci/scripts/integration_arrow.sh index 289d376a4db9b..2861b1c09d479 100755 --- a/ci/scripts/integration_arrow.sh +++ b/ci/scripts/integration_arrow.sh @@ -23,8 +23,8 @@ arrow_dir=${1} gold_dir=$arrow_dir/testing/data/arrow-ipc-stream/integration pip install -e $arrow_dir/dev/archery[integration] -# For C# C Data Interface testing -pip install pythonnet +# For C Data Interface testing +pip install jpype1 pythonnet # Get more detailed context on crashes export PYTHONFAULTHANDLER=1 diff --git a/dev/archery/archery/integration/datagen.py b/dev/archery/archery/integration/datagen.py index 01672fbe7488a..cecf4e386b800 100644 --- a/dev/archery/archery/integration/datagen.py +++ b/dev/archery/archery/integration/datagen.py @@ -1722,7 +1722,6 @@ def generate_dictionary_unsigned_case(): # TODO: JavaScript does not support uint64 dictionary indices, so disabled # for now - # dict3 = Dictionary(3, StringField('dictionary3'), size=5, name='DICT3') fields = [ DictionaryField('f0', get_field('', 'uint8'), dict0), diff --git a/dev/archery/archery/integration/tester_csharp.py b/dev/archery/archery/integration/tester_csharp.py index 83b07495f9907..ea585b19284fb 100644 --- a/dev/archery/archery/integration/tester_csharp.py +++ b/dev/archery/archery/integration/tester_csharp.py @@ -18,6 +18,7 @@ from contextlib import contextmanager import gc import os +import weakref from . import cdata from .tester import Tester, CDataExporter, CDataImporter @@ -72,6 +73,14 @@ def __init__(self, debug, args): self.ffi = cdata.ffi() _load_clr() + def _finalize(): + # Collect GC handles so as to call release functions from other + # exporters before it gets too late. + # TODO make this a run_gc() function? + from Apache.Arrow.IntegrationTest import CDataInterface + CDataInterface.RunGC() + weakref.finalize(self, _finalize) + def _pointer_to_int(self, c_ptr): return int(self.ffi.cast('uintptr_t', c_ptr)) diff --git a/dev/archery/archery/integration/tester_java.py b/dev/archery/archery/integration/tester_java.py index 45855079eb72e..cb994d273c81b 100644 --- a/dev/archery/archery/integration/tester_java.py +++ b/dev/archery/archery/integration/tester_java.py @@ -16,10 +16,12 @@ # under the License. import contextlib +import functools import os import subprocess -from .tester import Tester +from . import cdata +from .tester import Tester, CDataExporter, CDataImporter from .util import run_cmd, log from ..utils.source import ARROW_ROOT_DEFAULT @@ -42,18 +44,25 @@ def load_version_from_pom(): "ARROW_JAVA_INTEGRATION_JAR", os.path.join( ARROW_ROOT_DEFAULT, - "java/tools/target/arrow-tools-{}-" - "jar-with-dependencies.jar".format(_arrow_version), - ), + "java/tools/target", + f"arrow-tools-{_arrow_version}-jar-with-dependencies.jar" + ) +) +_ARROW_C_DATA_JAR = os.environ.get( + "ARROW_C_DATA_JAVA_INTEGRATION_JAR", + os.path.join( + ARROW_ROOT_DEFAULT, + "java/c/target", + f"arrow-c-data-{_arrow_version}.jar" + ) ) _ARROW_FLIGHT_JAR = os.environ.get( "ARROW_FLIGHT_JAVA_INTEGRATION_JAR", os.path.join( ARROW_ROOT_DEFAULT, - "java/flight/flight-integration-tests/target/" - "flight-integration-tests-{}-jar-with-dependencies.jar".format( - _arrow_version), - ), + "java/flight/flight-integration-tests/target", + f"flight-integration-tests-{_arrow_version}-jar-with-dependencies.jar" + ) ) _ARROW_FLIGHT_SERVER = ( "org.apache.arrow.flight.integration.tests.IntegrationTestServer" @@ -63,11 +72,157 @@ def load_version_from_pom(): ) +@functools.lru_cache +def setup_jpype(): + import jpype + jar_path = f"{_ARROW_TOOLS_JAR}:{_ARROW_C_DATA_JAR}" + # XXX Didn't manage to tone down the logging level here (DEBUG -> INFO) + jpype.startJVM(jpype.getDefaultJVMPath(), + "-Djava.class.path=" + jar_path, *_JAVA_OPTS) + + +class _CDataBase: + + def __init__(self, debug, args): + import jpype + self.debug = debug + self.args = args + self.ffi = cdata.ffi() + setup_jpype() + # JPype pointers to java.io, org.apache.arrow... + self.java_io = jpype.JPackage("java").io + self.java_arrow = jpype.JPackage("org").apache.arrow + self.java_allocator = self._make_java_allocator() + + def _pointer_to_int(self, c_ptr): + return int(self.ffi.cast('uintptr_t', c_ptr)) + + def _wrap_c_schema_ptr(self, c_schema_ptr): + return self.java_arrow.c.ArrowSchema.wrap( + self._pointer_to_int(c_schema_ptr)) + + def _wrap_c_array_ptr(self, c_array_ptr): + return self.java_arrow.c.ArrowArray.wrap( + self._pointer_to_int(c_array_ptr)) + + def _make_java_allocator(self): + # Return a new allocator + return self.java_arrow.memory.RootAllocator() + + def _assert_schemas_equal(self, expected, actual): + # XXX This is fragile for dictionaries, as Schema.equals compares + # dictionary ids. + self.java_arrow.vector.util.Validator.compareSchemas( + expected, actual) + + def _assert_batches_equal(self, expected, actual): + self.java_arrow.vector.util.Validator.compareVectorSchemaRoot( + expected, actual) + + def _assert_dict_providers_equal(self, expected, actual): + self.java_arrow.vector.util.Validator.compareDictionaryProviders( + expected, actual) + + +class JavaCDataExporter(CDataExporter, _CDataBase): + + def export_schema_from_json(self, json_path, c_schema_ptr): + json_file = self.java_io.File(json_path) + with self.java_arrow.vector.ipc.JsonFileReader( + json_file, self.java_allocator) as json_reader: + schema = json_reader.start() + dict_provider = json_reader + self.java_arrow.c.Data.exportSchema( + self.java_allocator, schema, dict_provider, + self._wrap_c_schema_ptr(c_schema_ptr) + ) + + def export_batch_from_json(self, json_path, num_batch, c_array_ptr): + json_file = self.java_io.File(json_path) + with self.java_arrow.vector.ipc.JsonFileReader( + json_file, self.java_allocator) as json_reader: + json_reader.start() + if num_batch > 0: + actually_skipped = json_reader.skip(num_batch) + assert actually_skipped == num_batch + with json_reader.read() as batch: + dict_provider = json_reader + self.java_arrow.c.Data.exportVectorSchemaRoot( + self.java_allocator, batch, dict_provider, + self._wrap_c_array_ptr(c_array_ptr)) + + @property + def supports_releasing_memory(self): + return True + + def record_allocation_state(self): + return self.java_allocator.getAllocatedMemory() + + def compare_allocation_state(self, recorded, gc_until): + def pred(): + return self.java_allocator.getAllocatedMemory() == recorded + + return gc_until(pred) + + +class JavaCDataImporter(CDataImporter, _CDataBase): + + def import_schema_and_compare_to_json(self, json_path, c_schema_ptr): + json_file = self.java_io.File(json_path) + with self.java_arrow.vector.ipc.JsonFileReader( + json_file, self.java_allocator) as json_reader: + json_schema = json_reader.start() + with self.java_arrow.c.CDataDictionaryProvider() as dict_provider: + imported_schema = self.java_arrow.c.Data.importSchema( + self.java_allocator, + self._wrap_c_schema_ptr(c_schema_ptr), + dict_provider) + self._assert_schemas_equal(json_schema, imported_schema) + + def import_batch_and_compare_to_json(self, json_path, num_batch, + c_array_ptr): + json_file = self.java_io.File(json_path) + with self.java_arrow.vector.ipc.JsonFileReader( + json_file, self.java_allocator) as json_reader: + schema = json_reader.start() + if num_batch > 0: + actually_skipped = json_reader.skip(num_batch) + assert actually_skipped == num_batch + with (json_reader.read() as batch, + self.java_arrow.vector.VectorSchemaRoot.create( + schema, self.java_allocator) as imported_batch): + # We need to pass a dict provider primed with dictionary ids + # matching those in the schema, hence an empty + # CDataDictionaryProvider would not work here! + dict_provider = (self.java_arrow.vector.dictionary + .DictionaryProvider.MapDictionaryProvider()) + dict_provider.copyStructureFrom(json_reader, self.java_allocator) + with dict_provider: + self.java_arrow.c.Data.importIntoVectorSchemaRoot( + self.java_allocator, + self._wrap_c_array_ptr(c_array_ptr), + imported_batch, dict_provider) + self._assert_batches_equal(batch, imported_batch) + self._assert_dict_providers_equal(json_reader, dict_provider) + + @property + def supports_releasing_memory(self): + return True + + def gc_until(self, predicate): + # No need to call the Java GC thanks to AutoCloseable (?) + return predicate() + + class JavaTester(Tester): PRODUCER = True CONSUMER = True FLIGHT_SERVER = True FLIGHT_CLIENT = True + C_DATA_SCHEMA_EXPORTER = True + C_DATA_SCHEMA_IMPORTER = True + C_DATA_ARRAY_EXPORTER = True + C_DATA_ARRAY_IMPORTER = True name = 'Java' @@ -186,3 +341,9 @@ def flight_server(self, scenario_name=None): finally: server.kill() server.wait(5) + + def make_c_data_exporter(self): + return JavaCDataExporter(self.debug, self.args) + + def make_c_data_importer(self): + return JavaCDataImporter(self.debug, self.args) diff --git a/docker-compose.yml b/docker-compose.yml index 10e2b9fa8e205..75b405d09d511 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1730,16 +1730,21 @@ services: volumes: *conda-volumes environment: <<: [*common, *ccache] - # tell archery where the arrow binaries are located + ARCHERY_INTEGRATION_WITH_RUST: 0 + # Tell Archery where the arrow C++ binaries are located ARROW_CPP_EXE_PATH: /build/cpp/debug ARROW_GO_INTEGRATION: 1 - ARCHERY_INTEGRATION_WITH_RUST: 0 + ARROW_JAVA_CDATA: "ON" + JAVA_JNI_CMAKE_ARGS: >- + -DARROW_JAVA_JNI_ENABLE_DEFAULT=OFF + -DARROW_JAVA_JNI_ENABLE_C=ON command: ["/arrow/ci/scripts/rust_build.sh /arrow /build && /arrow/ci/scripts/cpp_build.sh /arrow /build && /arrow/ci/scripts/csharp_build.sh /arrow /build && /arrow/ci/scripts/go_build.sh /arrow && - /arrow/ci/scripts/java_build.sh /arrow /build && + /arrow/ci/scripts/java_jni_build.sh /arrow $${ARROW_HOME} /build /tmp/dist/java/$$(arch) && + /arrow/ci/scripts/java_build.sh /arrow /build /tmp/dist/java && /arrow/ci/scripts/js_build.sh /arrow /build && /arrow/ci/scripts/integration_arrow.sh /arrow /build"] diff --git a/java/c/src/main/java/org/apache/arrow/c/BufferImportTypeVisitor.java b/java/c/src/main/java/org/apache/arrow/c/BufferImportTypeVisitor.java index 7408bf71136fa..cd2a464f4fa17 100644 --- a/java/c/src/main/java/org/apache/arrow/c/BufferImportTypeVisitor.java +++ b/java/c/src/main/java/org/apache/arrow/c/BufferImportTypeVisitor.java @@ -165,9 +165,9 @@ public List visit(ArrowType.Union type) { return Collections.singletonList(importFixedBytes(type, 0, UnionVector.TYPE_WIDTH)); case Dense: return Arrays.asList(importFixedBytes(type, 0, DenseUnionVector.TYPE_WIDTH), - importFixedBytes(type, 0, DenseUnionVector.OFFSET_WIDTH)); + importFixedBytes(type, 1, DenseUnionVector.OFFSET_WIDTH)); default: - throw new UnsupportedOperationException("Importing buffers for type: " + type); + throw new UnsupportedOperationException("Importing buffers for union type: " + type); } } diff --git a/java/c/src/main/java/org/apache/arrow/c/Format.java b/java/c/src/main/java/org/apache/arrow/c/Format.java index 315d3caad7da2..2875e46f749c4 100644 --- a/java/c/src/main/java/org/apache/arrow/c/Format.java +++ b/java/c/src/main/java/org/apache/arrow/c/Format.java @@ -138,6 +138,8 @@ static String asString(ArrowType arrowType) { return "tiD"; case YEAR_MONTH: return "tiM"; + case MONTH_DAY_NANO: + return "tin"; default: throw new UnsupportedOperationException( String.format("Interval type with unit %s is unsupported", type.getUnit())); @@ -277,6 +279,8 @@ static ArrowType asType(String format, long flags) return new ArrowType.Interval(IntervalUnit.YEAR_MONTH); case "tiD": return new ArrowType.Interval(IntervalUnit.DAY_TIME); + case "tin": + return new ArrowType.Interval(IntervalUnit.MONTH_DAY_NANO); case "+l": return new ArrowType.List(); case "+L": diff --git a/java/c/src/main/java/org/apache/arrow/c/SchemaImporter.java b/java/c/src/main/java/org/apache/arrow/c/SchemaImporter.java index 21d88f6cd4ba5..09a6afafa0a46 100644 --- a/java/c/src/main/java/org/apache/arrow/c/SchemaImporter.java +++ b/java/c/src/main/java/org/apache/arrow/c/SchemaImporter.java @@ -44,7 +44,7 @@ final class SchemaImporter { private static final Logger logger = LoggerFactory.getLogger(SchemaImporter.class); private static final int MAX_IMPORT_RECURSION_LEVEL = 64; - private long nextDictionaryID = 1L; + private long nextDictionaryID = 0L; private final BufferAllocator allocator; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/NullVector.java b/java/vector/src/main/java/org/apache/arrow/vector/NullVector.java index 6e4c2764bdcc4..1badf4b4ca808 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/NullVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/NullVector.java @@ -192,6 +192,7 @@ public List getChildrenFromFields() { @Override public void loadFieldBuffers(ArrowFieldNode fieldNode, List ownBuffers) { Preconditions.checkArgument(ownBuffers.isEmpty(), "Null vector has no buffers"); + valueCount = fieldNode.getLength(); } @Override diff --git a/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java b/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java index 698ddac466041..5323ddda838c8 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java @@ -121,9 +121,11 @@ public boolean rangeEquals(Range range) { "rightStart %s must be non negative.", range.getRightStart()); Preconditions.checkArgument(range.getRightStart() + range.getLength() <= right.getValueCount(), - "(rightStart + length) %s out of range[0, %s].", 0, right.getValueCount()); + "(rightStart + length) %s out of range[0, %s].", + range.getRightStart() + range.getLength(), right.getValueCount()); Preconditions.checkArgument(range.getLeftStart() + range.getLength() <= left.getValueCount(), - "(leftStart + length) %s out of range[0, %s].", 0, left.getValueCount()); + "(leftStart + length) %s out of range[0, %s].", + range.getLeftStart() + range.getLength(), left.getValueCount()); return left.accept(this, range); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java index 76e1eb9f66d25..06019904fc6de 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java @@ -21,6 +21,8 @@ import java.util.Map; import java.util.Set; +import org.apache.arrow.memory.BufferAllocator; + /** * A manager for association of dictionary IDs to their corresponding {@link Dictionary}. */ @@ -35,7 +37,7 @@ public interface DictionaryProvider { /** * Implementation of {@link DictionaryProvider} that is backed by a hash-map. */ - class MapDictionaryProvider implements DictionaryProvider { + class MapDictionaryProvider implements AutoCloseable, DictionaryProvider { private final Map map; @@ -49,6 +51,21 @@ public MapDictionaryProvider(Dictionary... dictionaries) { } } + /** + * Initialize the map structure from another provider, but with empty vectors. + * + * @param other the {@link DictionaryProvider} to copy the ids and fields from + * @param allocator allocator to create the empty vectors + */ + public void copyStructureFrom(DictionaryProvider other, BufferAllocator allocator) { + for (Long id : other.getDictionaryIds()) { + Dictionary otherDict = other.lookup(id); + Dictionary newDict = new Dictionary(otherDict.getVector().getField().createVector(allocator), + otherDict.getEncoding()); + put(newDict); + } + } + public void put(Dictionary dictionary) { map.put(dictionary.getEncoding().getId(), dictionary); } @@ -62,5 +79,12 @@ public final Set getDictionaryIds() { public Dictionary lookup(long id) { return map.get(id); } + + @Override + public void close() { + for (Dictionary dictionary : map.values()) { + dictionary.getVector().close(); + } + } } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java index 742daeef255f8..06b53580a6927 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java @@ -237,6 +237,27 @@ public VectorSchemaRoot read() throws IOException { } } + /** + * Skips a number of record batches in the file. + * + * @param numBatches the number of batches to skip + * @return the actual number of skipped batches. + */ + public int skip(int numBatches) throws IOException { + for (int i = 0; i < numBatches; ++i) { + JsonToken t = parser.nextToken(); + if (t == START_OBJECT) { + parser.skipChildren(); + assert parser.getCurrentToken() == END_OBJECT; + } else if (t == END_ARRAY) { + return i; + } else { + throw new IllegalArgumentException("Invalid token: " + t); + } + } + return numBatches; + } + private abstract class BufferReader { protected abstract ArrowBuf read(BufferAllocator allocator, int count) throws IOException; @@ -692,7 +713,8 @@ private ArrowBuf readIntoBuffer(BufferAllocator allocator, BufferType bufferType } private void readFromJsonIntoVector(Field field, FieldVector vector) throws JsonParseException, IOException { - TypeLayout typeLayout = TypeLayout.getTypeLayout(field.getType()); + ArrowType type = field.getType(); + TypeLayout typeLayout = TypeLayout.getTypeLayout(type); List vectorTypes = typeLayout.getBufferTypes(); ArrowBuf[] vectorBuffers = new ArrowBuf[vectorTypes.size()]; /* @@ -728,21 +750,18 @@ private void readFromJsonIntoVector(Field field, FieldVector vector) throws Json BufferType bufferType = vectorTypes.get(v); nextFieldIs(bufferType.getName()); int innerBufferValueCount = valueCount; - if (bufferType.equals(OFFSET) && !field.getType().getTypeID().equals(ArrowType.ArrowTypeID.Union)) { - /* offset buffer has 1 additional value capacity */ + if (bufferType.equals(OFFSET) && !(type instanceof ArrowType.Union)) { + /* offset buffer has 1 additional value capacity except for dense unions */ innerBufferValueCount = valueCount + 1; } vectorBuffers[v] = readIntoBuffer(allocator, bufferType, vector.getMinorType(), innerBufferValueCount); } - if (vectorBuffers.length == 0) { - readToken(END_OBJECT); - return; - } - int nullCount = 0; - if (!(vector.getField().getFieldType().getType() instanceof ArrowType.Union)) { + if (type instanceof ArrowType.Null) { + nullCount = valueCount; + } else if (!(type instanceof ArrowType.Union)) { nullCount = BitVectorHelper.getNullCount(vectorBuffers[0], valueCount); } final ArrowFieldNode fieldNode = new ArrowFieldNode(valueCount, nullCount); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/ArrowRecordBatch.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/ArrowRecordBatch.java index 83a8ece0bfb06..f81d049a9257f 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/ArrowRecordBatch.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/ArrowRecordBatch.java @@ -112,8 +112,8 @@ public ArrowRecordBatch( } long size = arrowBuf.readableBytes(); arrowBuffers.add(new ArrowBuffer(offset, size)); - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Buffer in RecordBatch at {}, length: {}", offset, size); + if (LOGGER.isTraceEnabled()) { + LOGGER.trace("Buffer in RecordBatch at {}, length: {}", offset, size); } offset += size; if (alignBuffers) { // align on 8 byte boundaries diff --git a/java/vector/src/main/java/org/apache/arrow/vector/util/Validator.java b/java/vector/src/main/java/org/apache/arrow/vector/util/Validator.java index 741972b4ad2a8..0c9ad1e2753f1 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/util/Validator.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/util/Validator.java @@ -17,6 +17,7 @@ package org.apache.arrow.vector.util; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Objects; @@ -85,6 +86,31 @@ public static void compareDictionaries( } } + /** + * Validate two dictionary providers are equal in structure and contents. + */ + public static void compareDictionaryProviders( + DictionaryProvider provider1, + DictionaryProvider provider2) { + List ids1 = new ArrayList(provider1.getDictionaryIds()); + List ids2 = new ArrayList(provider2.getDictionaryIds()); + java.util.Collections.sort(ids1); + java.util.Collections.sort(ids2); + if (!ids1.equals(ids2)) { + throw new IllegalArgumentException("Different ids in dictionary providers:\n" + + ids1 + "\n" + ids2); + } + for (long id : ids1) { + Dictionary dict1 = provider1.lookup(id); + Dictionary dict2 = provider2.lookup(id); + try { + compareFieldVectors(dict1.getVector(), dict2.getVector()); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Different dictionaries:\n" + dict1 + "\n" + dict2, e); + } + } + } + /** * Validate two arrow vectorSchemaRoot are equal. *