diff --git a/connector/src/main/java/tech/ydb/spark/connector/YdbContext.java b/connector/src/main/java/tech/ydb/spark/connector/YdbContext.java index cbad50b..748dacc 100644 --- a/connector/src/main/java/tech/ydb/spark/connector/YdbContext.java +++ b/connector/src/main/java/tech/ydb/spark/connector/YdbContext.java @@ -13,7 +13,6 @@ import java.util.Objects; import com.google.common.io.ByteStreams; -import org.apache.spark.sql.util.CaseInsensitiveStringMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,7 +34,7 @@ * * @author Aleksandr Gorshenin */ -public class YdbContext implements Serializable { +public class YdbContext implements Serializable, AutoCloseable { private static final Logger logger = LoggerFactory.getLogger(YdbContext.class); private static final long serialVersionUID = 6522842483896983993L; @@ -62,7 +61,7 @@ public class YdbContext implements Serializable { private final int sessionPoolSize; - public YdbContext(CaseInsensitiveStringMap options) { + public YdbContext(Map options) { this.connectionString = ConnectionOption.URL.read(options); if (connectionString == null || this.connectionString .trim().isEmpty()) { throw new IllegalArgumentException("Incorrect value for property " + ConnectionOption.URL); @@ -121,6 +120,11 @@ public boolean equals(Object other) { sessionPoolSize == sessionPoolSize; } + @Override + public void close() { + YdbRegistry.closeExecutor(this); + } + public YdbExecutor getExecutor() { return YdbRegistry.getOrCreate(this, YdbContext::createExecutor); } diff --git a/connector/src/main/java/tech/ydb/spark/connector/YdbRegistry.java b/connector/src/main/java/tech/ydb/spark/connector/YdbRegistry.java index ff2095c..b1a211a 100644 --- a/connector/src/main/java/tech/ydb/spark/connector/YdbRegistry.java +++ b/connector/src/main/java/tech/ydb/spark/connector/YdbRegistry.java @@ -32,4 +32,11 @@ public static void closeAll() { } } } + + public static void closeExecutor(YdbContext ctx) { + YdbExecutor executor = EXECUTORS.remove(ctx); + if (executor != null) { + executor.close(); + } + } } diff --git a/connector/src/main/java/tech/ydb/spark/connector/YdbTableProvider.java b/connector/src/main/java/tech/ydb/spark/connector/YdbTableProvider.java index dee2226..dcfd5cd 100644 --- a/connector/src/main/java/tech/ydb/spark/connector/YdbTableProvider.java +++ b/connector/src/main/java/tech/ydb/spark/connector/YdbTableProvider.java @@ -21,6 +21,8 @@ * @author zinal */ public class YdbTableProvider implements TableProvider, DataSourceRegister { + private static final String SPARK_PATH_OPTION = "path"; + @Override public String shortName() { return "ydb"; @@ -34,10 +36,16 @@ public boolean supportsExternalMetadata() { private String exractTableName(CaseInsensitiveStringMap options) { // Check that table path is provided String table = OperationOption.DBTABLE.read(options); - if (table == null || table.trim().length() == 0) { - throw new IllegalArgumentException("Missing property: " + OperationOption.DBTABLE); + if (table != null && !table.trim().isEmpty()) { + return table.trim(); + } + + String path = options.get(SPARK_PATH_OPTION); + if (path != null && !path.trim().isEmpty()) { + return path.trim(); } - return table.trim(); + + throw new IllegalArgumentException("Missing property: " + OperationOption.DBTABLE); } @Override diff --git a/connector/src/main/java/tech/ydb/spark/connector/common/KeysRange.java b/connector/src/main/java/tech/ydb/spark/connector/common/KeysRange.java index 4c73772..4ee73e7 100644 --- a/connector/src/main/java/tech/ydb/spark/connector/common/KeysRange.java +++ b/connector/src/main/java/tech/ydb/spark/connector/common/KeysRange.java @@ -1,13 +1,19 @@ package tech.ydb.spark.connector.common; import java.io.Serializable; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; +import java.util.Objects; import java.util.Optional; +import org.sparkproject.guava.annotations.VisibleForTesting; + import tech.ydb.spark.connector.YdbTypes; import tech.ydb.table.description.KeyBound; import tech.ydb.table.description.KeyRange; import tech.ydb.table.values.TupleValue; +import tech.ydb.table.values.Type; import tech.ydb.table.values.Value; /** @@ -17,232 +23,242 @@ */ public class KeysRange implements Serializable { private static final long serialVersionUID = 5756661733369903758L; - - public static final Limit NO_LIMIT = new Limit(new Serializable[0], true); - public static final KeysRange UNRESTRICTED = new KeysRange(NO_LIMIT, NO_LIMIT); + public static final KeysRange UNRESTRICTED = new KeysRange(Limit.UNSTRICTED, Limit.UNSTRICTED); + public static final KeysRange EMPTY = new KeysRange((Limit) null, (Limit) null); private final Limit from; private final Limit to; - public KeysRange(Limit from, Limit to) { - this.from = validated(from); - this.to = validated(to); + public KeysRange(KeyRange kr, YdbTypes types) { + this(new Limit(kr.getFrom(), types), new Limit(kr.getTo(), types)); } - public KeysRange(KeyRange kr, YdbTypes types) { - this(convert(kr.getFrom(), types), convert(kr.getTo(), types)); + public KeysRange(Serializable[] fromValue, boolean fromInclusive, Serializable[] toValue, boolean toInclusive) { + this(new Limit(fromValue, fromInclusive), new Limit(toValue, toInclusive)); } -// public KeysRange(ArrayList from, ArrayList to) { -// this(new Limit(from, true), new Limit(to, false)); -// } + private KeysRange(Limit left, Limit right) { + if (isValidRange(left, right)) { + this.from = left; + this.to = right; + } else { + this.from = null; + this.to = null; + } + } -// public KeysRange(List from, List to) { -// this(new Limit(from, true), new Limit(to, false)); -// } + public boolean isEmpty() { + return from == null || to == null; + } - public KeysRange(Serializable[] from, boolean fromInclusive, Serializable[] to, boolean toInclusive) { - this(new Limit(from, fromInclusive), new Limit(to, toInclusive)); + public boolean hasFromValue() { + return from != null && from.values != null; } - public Limit getFrom() { - return from; + public boolean hasToValue() { + return to != null && to.values != null; } - public Limit getTo() { - return to; + public boolean isUnrestricted() { + return from != null && to != null && from.values == null && to.values == null; } - /** - * Empty range means left is greater than right. Missing values on left means MIN, on right - - * MAX. - * - * @return true for empty range, false otherwise - */ - public boolean isEmpty() { - for (int idx = 0; idx < from.values.length && idx < to.values.length; idx += 1) { - final Serializable o1 = from.values[idx]; - final Serializable o2 = to.values[idx]; - if (o1 == o2) { - continue; - } - if (o1 == null) { - return false; - } - if (o2 == null) { - return false; - } - if (o1.getClass() != o2.getClass()) { - throw new IllegalArgumentException("Incompatible data types " + o1.getClass().toString() - + " and " + o2.getClass().toString()); - } - if (!(o1 instanceof Comparable)) { - throw new IllegalArgumentException("Uncomparable data type " + o1.getClass().toString()); - } - @SuppressWarnings("unchecked") - int cmp = ((Comparable) o1).compareTo(o2); - if (cmp > 0) { - return true; - } - if (cmp < 0) { - return false; - } - } - if (!from.inclusive) { - if ((from.values.length < to.values.length) || to.inclusive) { - return true; - } - } - if (!to.inclusive) { - if ((from.values.length > to.values.length) || from.inclusive) { - return true; - } - } - return false; + public boolean includesFromValue() { + return from != null && from.inclusive; } - public boolean isUnrestricted() { - return from.isUnrestricted() && to.isUnrestricted(); + public boolean includesToValue() { + return to != null && to.inclusive; } - public static Limit convert(Optional v, YdbTypes types) { - if (!v.isPresent()) { - return null; - } + public TupleValue readFromValue(YdbTypes types, FieldInfo[] columns) { + return from.writeTuple(types, columns); + } - KeyBound kb = v.get(); - Value tx = kb.getValue(); - if (!(tx instanceof TupleValue)) { - throw new IllegalArgumentException(); + public TupleValue readToValue(YdbTypes types, FieldInfo[] columns) { + return to.writeTuple(types, columns); + } + + public KeysRange intersect(KeysRange other) { + if (isEmpty() || other == null || other.isUnrestricted()) { + return this; } - TupleValue t = (TupleValue) tx; - final int sz = t.size(); - Serializable[] out = new Serializable[sz]; - for (int i = 0; i < sz; ++i) { - out[i] = types.convertFromYdb(t.get(i)); + if (other.isEmpty() || isUnrestricted()) { + return other; } - return new Limit(out, kb.isInclusive()); + return new KeysRange( + from.leftMerge(other.from), + to.rightMerge(other.to) + ); } @Override public String toString() { - return (from.isInclusive() ? "[* " : "(* ") - + Arrays.toString(from.values) + " -> " + Arrays.toString(to.values) - + (to.isInclusive() ? " *]" : " *)"); + if (isEmpty()) { + return "(None)"; + } + char l = from.inclusive ? '[' : '('; + char r = to.inclusive ? ']' : ')'; + return l + toString(from.values, "-Inf") + " - " + toString(to.values, "+Inf") + r; } - public KeysRange intersect(KeysRange other) { - if (other == null || (other.from.isUnrestricted() && other.to.isUnrestricted())) { - return this; + @Override + public int hashCode() { + return Objects.hash(from, to); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; } - if (from.isUnrestricted() && to.isUnrestricted()) { - return other; + if (obj == null) { + return false; } - Limit outFrom = (from.compareTo(other.from, true) > 0) ? from : other.from; - Limit outTo = (to.compareTo(other.to, false) > 0) ? other.to : to; - KeysRange retval = new KeysRange(outFrom, outTo); - return retval; + if (getClass() != obj.getClass()) { + return false; + } + KeysRange other = (KeysRange) obj; + return Objects.equals(this.to, other.to) && Objects.equals(this.from, other.from); } - public static int compare(Serializable[] v1, Serializable[] v2, boolean nullsFirst) { - if (v1 == v2 || (v1 != null && v2 != null && v1.length == 0 && v2.length == 0)) { - return 0; + private static String toString(Serializable[] value, String nullValue) { + if (value == null || value.length == 0) { + return nullValue; } - if (v1 == null || v1.length == 0) { - return nullsFirst ? -1 : 1; + if (value.length == 1) { + return value[0].toString(); } - if (v2 == null || v2.length == 0) { - return nullsFirst ? 1 : -1; + + String[] ss = new String[value.length]; + for (int idx = 0; idx < value.length; idx += 1) { + ss[idx] = value[idx].toString(); + } + + return "(" + String.join(",", ss) + ")"; + } + + @VisibleForTesting + static int compareValues(Serializable[] v1, Serializable[] v2) { + if (v1 == v2) { + return 0; // the same values + } + + if (v1 == null) { + return -1; + } + + if (v2 == null) { + return 1; } for (int idx = 0; idx < v1.length && idx < v2.length; idx += 1) { Serializable o1 = v1[idx]; Serializable o2 = v2[idx]; - if (o1 == o2) { - continue; - } if (o1 == null) { - return nullsFirst ? -1 : 1; + return o2 == null ? 0 : -1; } + if (o2 == null) { - return nullsFirst ? 1 : -1; + return 1; + } + + if (o1 == o2) { + continue; } - if (o1.getClass() != o2.getClass()) { - throw new IllegalArgumentException("Incompatible data types " - + o1.getClass().toString() + " and " + o2.getClass().toString()); + Class type = o1.getClass(); + if (type != o2.getClass()) { + throw new IllegalArgumentException("Incompatible data types " + type + " and " + o2.getClass()); } - if (!(o1 instanceof Comparable)) { - throw new IllegalArgumentException("Uncomparable data type " - + o1.getClass().toString()); + if (!Comparable.class.isAssignableFrom(type)) { + throw new IllegalArgumentException("Uncomparable data type " + type); } + @SuppressWarnings("unchecked") final int cmp = ((Comparable) o1).compareTo(o2); if (cmp != 0) { return cmp; } } - if (v1.length < v2.length) { - return nullsFirst ? -1 : 1; + + if (v1.length < v2.length && v2[v1.length] != null) { + return -1; } - if (v1.length > v2.length) { - return nullsFirst ? 1 : -1; + + if (v2.length < v1.length && v1[v2.length] != null) { + return 1; } return 0; } - private static Limit validated(Limit v) { - if (v == null) { - return NO_LIMIT; + private static boolean isValidRange(Limit from, Limit to) { + // empty range + if (from == null || to == null) { + return false; } - if (v.isUnrestricted()) { - return NO_LIMIT; + // range with one or both unrestricted bounds is always correct + if (from.values == null || to.values == null) { + return true; } - int pos = v.values.length; - while (pos > 0) { - Object o = v.values[pos - 1]; - if (o != null) { - break; - } - pos -= 1; + + int cmp = compareValues(from.values, to.values); + if (cmp < 0) { // from < to is always valid range + return true; + } + if (cmp > 0) { // from > to is always invalid range + return false; } - if (pos == v.values.length) { - return v; + // range from single value valids only both bound are inclusive + return from.inclusive && to.inclusive; + } + + public static Serializable[] readTuple(Value value, YdbTypes types) { + if (!(value instanceof TupleValue)) { + throw new IllegalArgumentException(); } - if (pos < 1) { - return NO_LIMIT; + + TupleValue tv = (TupleValue) value; + int sz = tv.size(); + Serializable[] out = new Serializable[sz]; + + for (int i = 0; i < sz; ++i) { + out[i] = types.convertFromYdb(tv.get(i)); + + if (out[i] == null) { // can reduce tuple until first null + Serializable[] reduced = new Serializable[i]; + System.arraycopy(out, 0, reduced, 0, i); + return reduced; + } } - Serializable[] cleaned = new Serializable[pos]; - System.arraycopy(v.values, 0, cleaned, 0, pos); - return new Limit(cleaned, v.inclusive); + return out; } - public static class Limit implements Serializable { - private static final long serialVersionUID = -3278687786440323269L; + private static class Limit implements Serializable { + private static final Limit UNSTRICTED = new Limit(Optional.empty(), null); + private static final long serialVersionUID = -9050443235398158196L; private final Serializable[] values; private final boolean inclusive; - public Limit(Serializable[] values, boolean inclusive) { + private Limit(Serializable[] values, boolean inclusive) { this.values = values; - this.inclusive = inclusive; - } - - public Serializable[] getValues() { - return values; + this.inclusive = values != null && inclusive; // inf cannot be inclusive } - public boolean isInclusive() { - return inclusive; - } - - public boolean isUnrestricted() { - return values == null || values.length == 0; + private Limit(Optional key, YdbTypes types) { + if (key.isPresent()) { + this.values = readTuple(key.get().getValue(), types); + this.inclusive = key.get().isInclusive(); + } else { + this.values = null; + this.inclusive = false; + } } @Override public int hashCode() { - return 2 * Arrays.hashCode(values) + (this.inclusive ? 1 : 0); + return Objects.hash(inclusive, Arrays.hashCode(values)); } @Override @@ -256,31 +272,58 @@ public boolean equals(Object obj) { if (getClass() != obj.getClass()) { return false; } - final Limit other = (Limit) obj; - if (this.inclusive != other.inclusive) { - return false; + Limit other = (Limit) obj; + return this.inclusive == other.inclusive && Arrays.equals(this.values, other.values); + } + + public Limit leftMerge(Limit other) { + if (values == null) { + return other; + } + if (other.values == null) { + return this; + } + + int cmp = compareValues(values, other.values); + if (cmp < 0) { + return other; } - return Arrays.equals(this.values, other.values); + if (cmp > 0) { + return this; + } + + return new Limit(values, inclusive && other.inclusive); } - @Override - public String toString() { - return "{" + "value=" + Arrays.toString(values) + ", inclusive=" + inclusive + '}'; + public Limit rightMerge(Limit other) { + if (values == null) { + return other; + } + if (other.values == null) { + return this; + } + + int cmp = compareValues(values, other.values); + if (cmp < 0) { + return this; + } + if (cmp > 0) { + return other; + } + + return new Limit(values, inclusive && other.inclusive); } - public int compareTo(Limit t, boolean nullsFirst) { - int cmp = compare(this.values, t.values, nullsFirst); - if (cmp == 0) { - if (this.inclusive == t.inclusive) { - return 0; - } - if (this.inclusive) { - return nullsFirst ? -1 : 1; + public TupleValue writeTuple(YdbTypes types, FieldInfo[] columns) { + final List> l = new ArrayList<>(values.length); + for (int i = 0; i < values.length; ++i) { + Value v = types.convertToYdb(values[i], columns[i].getType()); + if (!v.getType().getKind().equals(Type.Kind.OPTIONAL)) { + v = v.makeOptional(); } - return nullsFirst ? 1 : -1; + l.add(v); } - return cmp; + return TupleValue.of(l); } } - } diff --git a/connector/src/main/java/tech/ydb/spark/connector/impl/YdbExecutor.java b/connector/src/main/java/tech/ydb/spark/connector/impl/YdbExecutor.java index 6f94066..d58c2c9 100644 --- a/connector/src/main/java/tech/ydb/spark/connector/impl/YdbExecutor.java +++ b/connector/src/main/java/tech/ydb/spark/connector/impl/YdbExecutor.java @@ -78,6 +78,10 @@ public CompletableFuture executeDataQuery(String query, Params params) { ); } + public CompletableFuture executeSchemeQuery(String query) { + return retryCtx.supplyStatus(s -> s.executeSchemeQuery(query)); + } + public boolean truncateTable(String tablePath) { final YdbTruncateTable action = new YdbTruncateTable(tablePath); retryCtx.supplyStatus(session -> action.run(session)).join().expectSuccess(); diff --git a/connector/src/main/java/tech/ydb/spark/connector/impl/YdbScanReadTable.java b/connector/src/main/java/tech/ydb/spark/connector/impl/YdbScanReadTable.java index fa8d9fc..5a50d2c 100644 --- a/connector/src/main/java/tech/ydb/spark/connector/impl/YdbScanReadTable.java +++ b/connector/src/main/java/tech/ydb/spark/connector/impl/YdbScanReadTable.java @@ -1,8 +1,6 @@ package tech.ydb.spark.connector.impl; -import java.io.Serializable; import java.time.Duration; -import java.util.ArrayList; import java.util.List; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.CompletableFuture; @@ -27,8 +25,6 @@ import tech.ydb.table.result.ResultSetReader; import tech.ydb.table.settings.ReadTableSettings; import tech.ydb.table.values.TupleValue; -import tech.ydb.table.values.Type; -import tech.ydb.table.values.Value; /** * YDB table or index scan implementation through the ReadTable call. @@ -44,7 +40,7 @@ public class YdbScanReadTable implements AutoCloseable { private final List outColumns; private final GrpcReadStream stream; - private final CompletableFuture readStatus; + private final CompletableFuture readStatus; private final ArrayBlockingQueue queue; private volatile QueueItem currentItem = null; @@ -55,8 +51,8 @@ public YdbScanReadTable(YdbTable table, YdbScanOptions options, KeysRange keysRa this.queue = new ArrayBlockingQueue<>(options.getScanQueueDepth()); FieldInfo[] keys = table.getKeyColumns(); - logger.debug("Configuring scan for table {}, range {}, columns {}", tablePath, keysRange, keys); ReadTableSettings.Builder rtsb = ReadTableSettings.newBuilder(); + rtsb.orderedRead(true); scala.collection.Iterator sfit = options.getReadSchema().toIterator(); if (sfit.isEmpty()) { // In case no fields are required, add the first field of the primary key. @@ -67,32 +63,30 @@ public YdbScanReadTable(YdbTable table, YdbScanOptions options, KeysRange keysRa } } - final KeysRange.Limit realLeft = keysRange.getFrom(); - final KeysRange.Limit realRight = keysRange.getTo(); - if (!realLeft.isUnrestricted()) { - TupleValue tv = makeRange(types, keys, realLeft.getValues()); - if (realLeft.isInclusive()) { + if (keysRange.hasFromValue()) { + TupleValue tv = keysRange.readFromValue(types, keys); + if (keysRange.includesFromValue()) { rtsb.fromKeyInclusive(tv); } else { rtsb.fromKeyExclusive(tv); } - logger.debug("fromKey: {} -> {}", realLeft, tv); } - if (!realRight.isUnrestricted()) { - TupleValue tv = makeRange(types, keys, realRight.getValues()); - if (realRight.isInclusive()) { + if (keysRange.hasToValue()) { + TupleValue tv = keysRange.readToValue(types, keys); + if (keysRange.includesToValue()) { rtsb.toKeyInclusive(tv); } else { rtsb.toKeyExclusive(tv); } - logger.debug("toKey: {} -> {}", realRight, tv); } if (options.getRowLimit() > 0) { - logger.debug("Setting row limit to {}", options.getRowLimit()); rtsb.rowLimit(options.getRowLimit()); } + logger.debug("Configuring scan for table {} with range {} and limit {}, columns {}", + tablePath, keysRange, options.getRowLimit(), keys); + // TODO: add setting for the maximum scan duration. rtsb.withRequestTimeout(Duration.ofHours(8)); @@ -113,17 +107,6 @@ public YdbScanReadTable(YdbTable table, YdbScanOptions options, KeysRange keysRa session.getValue().close(); } - private static TupleValue makeRange(YdbTypes types, FieldInfo[] keys, Serializable[] values) { - final List> l = new ArrayList<>(values.length); - for (int i = 0; i < values.length; ++i) { - Value v = types.convertToYdb(values[i], keys[i].getType()); - if (!v.getType().getKind().equals(Type.Kind.OPTIONAL)) { - v = v.makeOptional(); - } - l.add(v); - } - return TupleValue.of(l); - } private void onNextPart(ReadTablePart part) { QueueItem nextItem = new QueueItem(part.getResultSetReader()); try { diff --git a/connector/src/main/java/tech/ydb/spark/connector/read/YdbReaderFactory.java b/connector/src/main/java/tech/ydb/spark/connector/read/YdbReaderFactory.java index aa13588..b8ae00f 100644 --- a/connector/src/main/java/tech/ydb/spark/connector/read/YdbReaderFactory.java +++ b/connector/src/main/java/tech/ydb/spark/connector/read/YdbReaderFactory.java @@ -6,8 +6,6 @@ import org.apache.spark.sql.connector.read.InputPartition; import org.apache.spark.sql.connector.read.PartitionReader; import org.apache.spark.sql.connector.read.PartitionReaderFactory; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import tech.ydb.spark.connector.YdbTable; import tech.ydb.spark.connector.impl.YdbScanReadTable; @@ -17,7 +15,6 @@ * @author Aleksandr Gorshenin */ public class YdbReaderFactory implements PartitionReaderFactory { - private static final Logger logger = LoggerFactory.getLogger(YdbReaderFactory.class); private static final long serialVersionUID = 6815846949510430713L; private final YdbTable table; @@ -49,9 +46,7 @@ public LazyReader(YdbTable table, YdbScanOptions options, InputPartition partiti @Override public boolean next() throws IOException { if (scan == null) { - logger.debug("Preparing scan for table {} at partition {}", table.getTablePath(), partition); scan = new YdbScanReadTable(table, options, partition.getRange()); - logger.debug("Scan prepared, ready to fetch..."); } return scan.next(); } @@ -64,7 +59,6 @@ public InternalRow get() { @Override public void close() throws IOException { if (scan != null) { - logger.debug("Closing the scan."); scan.close(); } scan = null; diff --git a/connector/src/main/java/tech/ydb/spark/connector/write/YdbWriterFactory.java b/connector/src/main/java/tech/ydb/spark/connector/write/YdbWriterFactory.java index ece705e..43a162e 100644 --- a/connector/src/main/java/tech/ydb/spark/connector/write/YdbWriterFactory.java +++ b/connector/src/main/java/tech/ydb/spark/connector/write/YdbWriterFactory.java @@ -46,7 +46,7 @@ public YdbWriterFactory(YdbTable table, LogicalWriteInfo logical, PhysicalWriteI this.table = table; this.types = new YdbTypes(logical.options()); this.method = OperationOption.INGEST_METHOD.readEnum(logical.options(), IngestMethod.BULK_UPSERT); - this.maxBatchSize = OperationOption.BATCH_SIZE.readInt(logical.options(), 1000); + this.maxBatchSize = OperationOption.BATCH_SIZE.readInt(logical.options(), 10000); this.autoPkName = OperationOption.AUTO_PK.read(logical.options(), OperationOption.DEFAULT_AUTO_PK); this.schema = logical.schema(); } diff --git a/connector/src/test/java/tech/ydb/spark/connector/DataFramesTest.java b/connector/src/test/java/tech/ydb/spark/connector/DataFramesTest.java new file mode 100644 index 0000000..783c654 --- /dev/null +++ b/connector/src/test/java/tech/ydb/spark/connector/DataFramesTest.java @@ -0,0 +1,145 @@ +package tech.ydb.spark.connector; + +import java.util.Collections; + +import org.apache.spark.SparkConf; +import org.apache.spark.sql.SparkSession; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; + +import tech.ydb.spark.connector.YdbContext; +import tech.ydb.spark.connector.impl.YdbExecutor; +import tech.ydb.table.values.ListType; +import tech.ydb.table.values.ListValue; +import tech.ydb.table.values.PrimitiveType; +import tech.ydb.table.values.PrimitiveValue; +import tech.ydb.table.values.StructType; +import tech.ydb.test.junit4.YdbHelperRule; + +/** + * + * @author Aleksandr Gorshenin + */ +public class DataFramesTest { + @ClassRule + public static final YdbHelperRule YDB = new YdbHelperRule(); + + private static String ydbURL; + private static YdbContext ctx; + private static SparkSession spark; + + @BeforeClass + public static void prepare() { + StringBuilder url = new StringBuilder() + .append(YDB.useTls() ? "grpcs://" : "grpc://") + .append(YDB.endpoint()) + .append(YDB.database()); + + if (YDB.authToken() != null) { + url.append("?").append("token=").append(YDB.authToken()); + } + + ydbURL = url.toString(); + ctx = new YdbContext(Collections.singletonMap("url", ydbURL)); + + prepareTables(ctx.getExecutor()); + + SparkConf conf = new SparkConf() + .setMaster("local[4]") + .setAppName("ydb-spark-dataframes-test") + .set("spark.ui.enabled", "false"); + + spark = SparkSession.builder() + .config(conf) + .getOrCreate(); + } + + @AfterClass + public static void close() { + if (spark != null) { + spark.close(); + } + if (ctx != null) { + cleanTables(ctx.getExecutor()); + ctx.close(); + } + ctx.close(); + } + + private static void prepareTables(YdbExecutor executor) { + executor.makeDirectory(executor.extractPath("df_test_dir")); + executor.executeSchemeQuery("CREATE TABLE df_test_table (" + + " id Int32 NOT NULL," + + " value Text," + + " PRIMARY KEY(id) " + + ")").join().expectSuccess("cannot create test table"); + executor.executeSchemeQuery("CREATE TABLE `df_test_dir/splitted_table` (" + + " id Int32 NOT NULL," + + " value Text," + + " PRIMARY KEY(id) " + + ") WITH (" + + " AUTO_PARTITIONING_MIN_PARTITIONS_COUNT = 7, " + + " PARTITION_AT_KEYS = (11, 22, 33, 44, 55, 66) " + + ")").join().expectSuccess("cannot create test table"); + + StructType struct = StructType.of("id", PrimitiveType.Int32, "value", PrimitiveType.Text); + ListValue initValues = ListType.of(struct).newValueOwn( + struct.newValue("id", PrimitiveValue.newInt32(1), "value", PrimitiveValue.newText("v1")), + struct.newValue("id", PrimitiveValue.newInt32(2), "value", PrimitiveValue.newText("v2")), + struct.newValue("id", PrimitiveValue.newInt32(10), "value", PrimitiveValue.newText("v3")), + struct.newValue("id", PrimitiveValue.newInt32(11), "value", PrimitiveValue.newText("v4")), + struct.newValue("id", PrimitiveValue.newInt32(12), "value", PrimitiveValue.newText("v5")), + struct.newValue("id", PrimitiveValue.newInt32(50), "value", PrimitiveValue.newText("v6")), + struct.newValue("id", PrimitiveValue.newInt32(51), "value", PrimitiveValue.newText("v7")), + struct.newValue("id", PrimitiveValue.newInt32(65), "value", PrimitiveValue.newText("v8")), + struct.newValue("id", PrimitiveValue.newInt32(66), "value", PrimitiveValue.newText("v9")), + struct.newValue("id", PrimitiveValue.newInt32(67), "value", PrimitiveValue.newText("v10")) + ); + + executor.executeBulkUpsert(executor.extractPath("df_test_table"), initValues).join() + .expectSuccess("cannot insert data to df_test_dir"); + executor.executeBulkUpsert(executor.extractPath("df_test_dir/splitted_table"), initValues).join() + .expectSuccess("cannot insert data to df_test_dir/splitted_table"); + } + + private static void cleanTables(YdbExecutor executor) { + executor.executeSchemeQuery("DROP TABLE `df_test_dir/splitted_table`;").join(); + executor.executeSchemeQuery("DROP TABLE df_test_table;").join(); + executor.removeDirectory(executor.extractPath("df_test_dir")); + } + + @Test + public void readTableByOptionTest() { + long count1 = spark.read().format("ydb") + .option("url", ydbURL) + .option("dbtable", "df_test_table") + .load() + .count(); + Assert.assertEquals(10, count1); + + long count2 = spark.read().format("ydb") + .option("url", ydbURL) + .option("dbtable", "df_test_dir/splitted_table") + .load() + .count(); + Assert.assertEquals(10, count2); + } + + @Test + public void readTableByNameTest() { + long count1 = spark.read().format("ydb") + .option("url", ydbURL) + .load("df_test_table") + .count(); + Assert.assertEquals(10, count1); + + long count2 = spark.read().format("ydb") + .option("url", ydbURL) + .load("df_test_dir/splitted_table") + .count(); + Assert.assertEquals(10, count2); + } +} diff --git a/connector/src/test/java/tech/ydb/spark/connector/integration/IntegrationTest.java b/connector/src/test/java/tech/ydb/spark/connector/IntegrationTest.java similarity index 99% rename from connector/src/test/java/tech/ydb/spark/connector/integration/IntegrationTest.java rename to connector/src/test/java/tech/ydb/spark/connector/IntegrationTest.java index 9df8819..a997111 100644 --- a/connector/src/test/java/tech/ydb/spark/connector/integration/IntegrationTest.java +++ b/connector/src/test/java/tech/ydb/spark/connector/IntegrationTest.java @@ -1,4 +1,4 @@ -package tech.ydb.spark.connector.integration; +package tech.ydb.spark.connector; import java.util.ArrayList; import java.util.HashMap; @@ -65,7 +65,6 @@ public static void prepare() { retryCtx = SessionRetryContext.create(tableClient).build(); spark = SparkSession.builder() - .config(conf) .config(conf) .getOrCreate(); } diff --git a/connector/src/test/java/tech/ydb/spark/connector/YdbKeyRangeTest.java b/connector/src/test/java/tech/ydb/spark/connector/YdbKeyRangeTest.java deleted file mode 100644 index ff64124..0000000 --- a/connector/src/test/java/tech/ydb/spark/connector/YdbKeyRangeTest.java +++ /dev/null @@ -1,93 +0,0 @@ -package tech.ydb.spark.connector; - -import java.io.Serializable; - -import org.junit.Assert; -import org.junit.Test; - -import tech.ydb.spark.connector.common.KeysRange; - -/** - * - * @author mzinal - */ -public class YdbKeyRangeTest { - - private KeysRange.Limit makeExclusive(Serializable... vals) { - return new KeysRange.Limit(vals, false); - } - - private KeysRange.Limit makeInclusive(Serializable... vals) { - return new KeysRange.Limit(vals, true); - } - - @Test - public void testCompare() { - KeysRange.Limit x1; - KeysRange.Limit x2; - - x1 = makeExclusive("A", 10, 1L); - x2 = makeExclusive("A", 20, 1L); - Assert.assertEquals(-1, x1.compareTo(x2, true)); - Assert.assertEquals(-1, x1.compareTo(x2, false)); - Assert.assertEquals(1, x2.compareTo(x1, true)); - Assert.assertEquals(1, x2.compareTo(x1, false)); - - x2 = makeExclusive("A", 10, 1L); - Assert.assertEquals(0, x1.compareTo(x2, true)); - - x2 = makeExclusive("A", 10); - Assert.assertEquals(1, x1.compareTo(x2, true)); - Assert.assertEquals(-1, x1.compareTo(x2, false)); - Assert.assertEquals(-1, x2.compareTo(x1, true)); - Assert.assertEquals(1, x2.compareTo(x1, false)); - - x1 = makeInclusive("A", 10, 1L); - x2 = makeInclusive("A", 20, 1L); - Assert.assertEquals(-1, x1.compareTo(x2, true)); - Assert.assertEquals(-1, x1.compareTo(x2, false)); - Assert.assertEquals(1, x2.compareTo(x1, true)); - Assert.assertEquals(1, x2.compareTo(x1, false)); - - x2 = makeInclusive("A", 10, 1L); - Assert.assertEquals(0, x1.compareTo(x2, true)); - - x2 = makeInclusive("A", 10); - Assert.assertEquals(1, x1.compareTo(x2, true)); - Assert.assertEquals(-1, x1.compareTo(x2, false)); - Assert.assertEquals(-1, x2.compareTo(x1, true)); - Assert.assertEquals(1, x2.compareTo(x1, false)); - } - - @Test - public void testEmpty() { - KeysRange.Limit x1; - KeysRange.Limit x2; - - x1 = makeInclusive("A", 10, 1L); - x2 = makeInclusive("A", 10, 1L); - Assert.assertEquals(false, new KeysRange(x1, x2).isEmpty()); - - x1 = makeInclusive("A", 10, 1L); - x2 = makeExclusive("A", 10, 1L); - Assert.assertEquals(true, new KeysRange(x1, x2).isEmpty()); - - x1 = makeExclusive("A", 10, 1L); - x2 = makeExclusive("A", 10, 2L); - Assert.assertEquals(false, new KeysRange(x1, x2).isEmpty()); - - Assert.assertEquals(true, - new KeysRange(new Serializable[] {31000000L}, true, new Serializable[] {31000000L}, false).isEmpty()); - } - - @Test - public void testIntersect() { - KeysRange r1; - KeysRange r2, ro; - r1 = new KeysRange(new Serializable[] {31000000L}, true, new Serializable[] {32000000L}, false); - r2 = new KeysRange(new Serializable[] {46000000L}, true, new Serializable[] {46250000L}, false); - ro = r2.intersect(r1); - Assert.assertEquals(true, ro.isEmpty()); - } - -} diff --git a/connector/src/test/java/tech/ydb/spark/connector/common/KeysRangeTest.java b/connector/src/test/java/tech/ydb/spark/connector/common/KeysRangeTest.java new file mode 100644 index 0000000..85206a6 --- /dev/null +++ b/connector/src/test/java/tech/ydb/spark/connector/common/KeysRangeTest.java @@ -0,0 +1,167 @@ +package tech.ydb.spark.connector.common; + +import java.io.Serializable; + +import org.junit.Assert; +import org.junit.Test; + + + +/** + * + * @author Aleksandr Gorshenin + */ +public class KeysRangeTest { + + @Test + public void equalsTest() { + Serializable[] v1 = new Serializable[] {"A", 10, -1L}; + Serializable[] v2 = new Serializable[] {"A", 10, 1L}; + Serializable[] v3 = new Serializable[] {"A", 10, 1L}; + + KeysRange r1 = new KeysRange(v1, true, v2, true); + KeysRange r2 = new KeysRange(v1, true, v3, true); + KeysRange r3 = new KeysRange(v1, true, v2, false); + KeysRange r4 = new KeysRange(v2, true, v1, false); + + Assert.assertEquals(r1.hashCode(), r2.hashCode()); + Assert.assertNotEquals(r1.hashCode(), r3.hashCode()); + Assert.assertNotEquals(r1.hashCode(), KeysRange.EMPTY.hashCode()); + Assert.assertNotEquals(r1.hashCode(), KeysRange.UNRESTRICTED.hashCode()); + Assert.assertEquals(r4.hashCode(), KeysRange.EMPTY.hashCode()); + Assert.assertNotEquals(r4.hashCode(), KeysRange.UNRESTRICTED.hashCode()); + + Assert.assertEquals(r1, r2); + Assert.assertNotEquals(r1, r3); + Assert.assertNotEquals(r1, KeysRange.EMPTY); + Assert.assertNotEquals(r1, KeysRange.UNRESTRICTED); + Assert.assertEquals(r4, KeysRange.EMPTY); + Assert.assertNotEquals(r4, KeysRange.UNRESTRICTED); + + Assert.assertNotEquals(r1, null); + Assert.assertNotEquals(r1, "String"); + } + + @Test + public void compareValuesTest() { + Serializable[] v1 = new Serializable[] {"A", 10, 1L}; + Serializable[] v2 = new Serializable[] {"A", 10, 1L}; + Serializable[] v3 = new Serializable[] {"A", 10, -1L}; + Serializable[] v4 = new Serializable[] {"A", 10}; + Serializable[] v5 = new Serializable[] {"A", 10, null}; + Serializable[] v6 = new Serializable[] {"A", 10, null, 10}; + + Assert.assertEquals(0, KeysRange.compareValues(null, null)); + Assert.assertEquals(1, KeysRange.compareValues(v1, null)); + Assert.assertEquals(-1, KeysRange.compareValues(null, v1)); + + Assert.assertEquals(0, KeysRange.compareValues(v1, v1)); + Assert.assertEquals(0, KeysRange.compareValues(v1, v2)); + + Assert.assertEquals(1, KeysRange.compareValues(v1, v3)); + Assert.assertEquals(-1, KeysRange.compareValues(v3, v1)); + + Assert.assertEquals(-1, KeysRange.compareValues(v4, v3)); + Assert.assertEquals(1, KeysRange.compareValues(v3, v4)); + Assert.assertEquals(-1, KeysRange.compareValues(v5, v3)); + Assert.assertEquals(1, KeysRange.compareValues(v3, v5)); + + Assert.assertEquals(0, KeysRange.compareValues(v4, v5)); + Assert.assertEquals(0, KeysRange.compareValues(v5, v6)); + } + + @Test + public void compareIncompatibleValuesTest() { + Serializable[] v1 = new Serializable[] {"A"}; + Serializable[] v2 = new Serializable[] {10}; + + Assert.assertEquals("Incompatible data types class java.lang.String and class java.lang.Integer", + Assert.assertThrows( + IllegalArgumentException.class, + () -> KeysRange.compareValues(v1, v2) + ).getMessage()); + } + + @Test + public void compareUncoparableValuesTest() { + Serializable[] v1 = new Serializable[] {new Object[] {10}}; + Serializable[] v2 = new Serializable[] {new Object[] {11}}; + + Assert.assertEquals("Uncomparable data type class [Ljava.lang.Object;", + Assert.assertThrows( + IllegalArgumentException.class, + () -> KeysRange.compareValues(v1, v2) + ).getMessage()); + } + + @Test + public void emptyTest() { + Serializable[] v1 = new Serializable[] {"A", 10, -1L}; + Serializable[] v2 = new Serializable[] {"A", 10, 1L}; + + Assert.assertFalse(KeysRange.UNRESTRICTED.isEmpty()); + Assert.assertFalse(new KeysRange(null, true, null, true).isEmpty()); + Assert.assertFalse(new KeysRange(v1, true, null, true).isEmpty()); + Assert.assertFalse(new KeysRange(null, true, v1, true).isEmpty()); + Assert.assertFalse(new KeysRange(v1, true, v2, true).isEmpty()); + Assert.assertFalse(new KeysRange(v1, true, v1, true).isEmpty()); + + Assert.assertTrue(KeysRange.EMPTY.isEmpty()); + Assert.assertTrue(new KeysRange(v2, true, v1, true).isEmpty()); + Assert.assertTrue(new KeysRange(v1, false, v1, true).isEmpty()); + Assert.assertTrue(new KeysRange(v1, true, v1, false).isEmpty()); + } + + @Test + public void toStringTest() { + Serializable[] v1 = new Serializable[] {"A"}; + Serializable[] v2 = new Serializable[] {"B"}; + Serializable[] v3 = new Serializable[] {"A", 10, -1L}; + Serializable[] v4 = new Serializable[] {"A", 10, 1L}; + + Assert.assertEquals("(-Inf - +Inf)", KeysRange.UNRESTRICTED.toString()); + Assert.assertEquals("(None)", KeysRange.EMPTY.toString()); + Assert.assertEquals("(-Inf - A]", new KeysRange(null, false, v1, true).toString()); + Assert.assertEquals("(-Inf - (A,10,-1))", new KeysRange(null, true, v3, false).toString()); + Assert.assertEquals("(B - +Inf)", new KeysRange(v2, false, null, true).toString()); + Assert.assertEquals("[(A,10,1) - +Inf)", new KeysRange(v4, true, null, false).toString()); + } + + @Test + public void intersectTest() { + Serializable[] v1 = new Serializable[] {"A", 9}; + Serializable[] v2 = new Serializable[] {"A", 10}; + Serializable[] v3 = new Serializable[] {"A", 10, 1L}; + Serializable[] v4 = new Serializable[] {"B"}; + + KeysRange r1 = new KeysRange(v1, true, v2, true); // [(A,9) - (A,10)] + KeysRange r2 = new KeysRange(v1, false, v3, true); // ((A,9) - (A,10,1)] + KeysRange r3 = new KeysRange(v3, false, v4, false); // ((A,10,1) - B) + KeysRange r4 = new KeysRange(v3, true, null, false); // [(A,10,1) - +Inf) + KeysRange r5 = new KeysRange(null, false, v3, false); // (-Inf - (A,10,1)) + + Assert.assertEquals(KeysRange.EMPTY, r1.intersect(r3)); + Assert.assertEquals(KeysRange.EMPTY, r3.intersect(r1)); + Assert.assertEquals(KeysRange.EMPTY, r2.intersect(r3)); + + Assert.assertEquals("((A,9) - (A,10)]", r1.intersect(r2).toString()); + Assert.assertEquals("((A,9) - (A,10)]", r2.intersect(r1).toString()); + Assert.assertEquals("((A,10,1) - B)", r3.intersect(r4).toString()); + Assert.assertEquals("((A,10,1) - B)", r4.intersect(r3).toString()); + Assert.assertEquals("((A,9) - (A,10,1))", r2.intersect(r5).toString()); + Assert.assertEquals("((A,9) - (A,10,1))", r5.intersect(r2).toString()); + + Assert.assertEquals("[(A,10,1) - (A,10,1)]", r2.intersect(r4).toString()); + Assert.assertEquals("[(A,10,1) - (A,10,1)]", r4.intersect(r2).toString()); + Assert.assertEquals(KeysRange.EMPTY, r3.intersect(r2)); + Assert.assertEquals(KeysRange.EMPTY, r4.intersect(r5)); + + Assert.assertEquals(KeysRange.EMPTY, KeysRange.EMPTY.intersect(KeysRange.UNRESTRICTED)); + Assert.assertEquals(KeysRange.EMPTY, KeysRange.UNRESTRICTED.intersect(KeysRange.EMPTY)); + Assert.assertEquals(KeysRange.EMPTY, KeysRange.EMPTY.intersect(r1)); + Assert.assertEquals(KeysRange.EMPTY, r1.intersect(KeysRange.EMPTY)); + + Assert.assertEquals(r1, r1.intersect(KeysRange.UNRESTRICTED)); + Assert.assertEquals(r2, KeysRange.UNRESTRICTED.intersect(r2)); + } +} diff --git a/connector/src/test/resources/log4j2.xml b/connector/src/test/resources/log4j2.xml index d67bcbd..7fd9bb1 100644 --- a/connector/src/test/resources/log4j2.xml +++ b/connector/src/test/resources/log4j2.xml @@ -35,6 +35,10 @@ + + + + diff --git a/pom.xml b/pom.xml index 6a3a286..5b1bc92 100644 --- a/pom.xml +++ b/pom.xml @@ -120,7 +120,12 @@ org.apache.maven.plugins maven-surefire-plugin - 3.5.2 + 3.5.3 + + + true + + org.apache.maven.plugins