diff --git a/src/it/java/io/weaviate/integration/AggregationITest.java b/src/it/java/io/weaviate/integration/AggregationITest.java index 586cac1ea..e29da810f 100644 --- a/src/it/java/io/weaviate/integration/AggregationITest.java +++ b/src/it/java/io/weaviate/integration/AggregationITest.java @@ -16,9 +16,9 @@ import io.weaviate.client6.v1.api.collections.Property; import io.weaviate.client6.v1.api.collections.Vectorizers; import io.weaviate.client6.v1.api.collections.Vectors; +import io.weaviate.client6.v1.api.collections.aggregate.Aggregate; import io.weaviate.client6.v1.api.collections.aggregate.AggregateResponseGroup; import io.weaviate.client6.v1.api.collections.aggregate.AggregateResponseGrouped; -import io.weaviate.client6.v1.api.collections.aggregate.Aggregation; import io.weaviate.client6.v1.api.collections.aggregate.GroupBy; import io.weaviate.client6.v1.api.collections.aggregate.GroupedBy; import io.weaviate.client6.v1.api.collections.aggregate.IntegerAggregation; @@ -57,7 +57,7 @@ public void testOverAll() { var result = things.aggregate.overAll( with -> with .metrics( - Aggregation.integer("price", + Aggregate.integer("price", calculate -> calculate.median().max().count())) .includeTotalCount(true)); @@ -77,7 +77,7 @@ public void testOverAll_groupBy_category() { var result = things.aggregate.overAll( with -> with .metrics( - Aggregation.integer("price", + Aggregate.integer("price", calculate -> calculate.min().max().count())) .includeTotalCount(true), GroupBy.property("category")); @@ -115,7 +115,7 @@ public void testNearVector() { near -> near.limit(5), with -> with .metrics( - Aggregation.integer("price", + Aggregate.integer("price", calculate -> calculate.min().max().count())) .objectLimit(4) .includeTotalCount(true)); @@ -135,7 +135,7 @@ public void testNearVector_groupBy_category() { near -> near.distance(2f), with -> with .metrics( - Aggregation.integer("price", + Aggregate.integer("price", calculate -> calculate.min().max().median())) .objectLimit(9) .includeTotalCount(true), diff --git a/src/it/java/io/weaviate/integration/DataITest.java b/src/it/java/io/weaviate/integration/DataITest.java index 22071e321..cac817c17 100644 --- a/src/it/java/io/weaviate/integration/DataITest.java +++ b/src/it/java/io/weaviate/integration/DataITest.java @@ -1,7 +1,10 @@ package io.weaviate.integration; import java.io.IOException; +import java.time.OffsetDateTime; +import java.util.List; import java.util.Map; +import java.util.UUID; import org.assertj.core.api.Assertions; import org.assertj.core.api.InstanceOfAssertFactories; @@ -407,4 +410,61 @@ public void testDuplicateUuid() throws IOException { // Act things.data.insert(Map.of(), thing -> thing.uuid(thing_1.uuid())); } + + @Test + public void testDataTypes() throws IOException { + // Arrange + var nsDataTypes = ns("DataTypes"); + + // BLOB type is omitted because a base64-encoded image + // isn't doing the failure message any favours. + // It's tested in other test cases above. + client.collections.create( + nsDataTypes, c -> c + .properties( + Property.text("prop_text"), + Property.integer("prop_integer"), + Property.number("prop_number"), + Property.bool("prop_bool"), + Property.date("prop_date"), + Property.uuid("prop_uuid"), + Property.integerArray("prop_integer_array"), + Property.numberArray("prop_number_array"), + Property.boolArray("prop_bool_array"), + Property.dateArray("prop_date_array"), + Property.uuidArray("prop_uuid_array"), + Property.textArray("prop_text_array"))); + + var types = client.collections.use(nsDataTypes); + + var now = OffsetDateTime.now(); + var uuid = UUID.randomUUID(); + + Map want = Map.ofEntries( + Map.entry("prop_text", "Hello, World!"), + Map.entry("prop_integer", 1L), + Map.entry("prop_number", 1D), + Map.entry("prop_bool", true), + Map.entry("prop_date", now), + Map.entry("prop_uuid", uuid), + Map.entry("prop_integer_array", List.of(1L, 2L, 3L)), + Map.entry("prop_number_array", List.of(1D, 2D, 3D)), + Map.entry("prop_bool_array", List.of(true, false)), + Map.entry("prop_date_array", List.of(now, now)), + Map.entry("prop_uuid_array", List.of(uuid, uuid)), + Map.entry("prop_text_array", List.of("a", "b", "c"))); + var returnProperties = want.keySet().toArray(String[]::new); + + // Act + var object = types.data.insert(want); + var got = types.query.byId(object.uuid(), + q -> q.returnProperties(returnProperties)); + + // Assert + Assertions.assertThat(got).get() + .extracting(WeaviateObject::properties) + .asInstanceOf(InstanceOfAssertFactories.map(String.class, Object.class)) + .containsAllEntriesOf(want); + + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/DataType.java b/src/main/java/io/weaviate/client6/v1/api/collections/DataType.java index c114f0ab5..32858bca6 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/DataType.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/DataType.java @@ -6,8 +6,20 @@ public interface DataType { public static final String TEXT = "text"; + public static final String TEXT_ARRAY = "text[]"; public static final String INT = "int"; + public static final String INT_ARRAY = "int[]"; + public static final String NUMBER = "number"; + public static final String NUMBER_ARRAY = "number[]"; + public static final String BOOL = "boolean"; + public static final String BOOL_ARRAY = "boolean[]"; public static final String BLOB = "blob"; + public static final String DATE = "date"; + public static final String DATE_ARRAY = "date[]"; + public static final String UUID = "uuid"; + public static final String UUID_ARRAY = "uuid[]"; - public static final Set KNOWN_TYPES = ImmutableSet.of(TEXT, INT, BLOB); + public static final Set KNOWN_TYPES = ImmutableSet.of( + TEXT, INT, BLOB, BOOL, DATE, UUID, NUMBER, + TEXT_ARRAY, INT_ARRAY, NUMBER_ARRAY, BOOL_ARRAY, DATE_ARRAY, UUID_ARRAY); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Property.java b/src/main/java/io/weaviate/client6/v1/api/collections/Property.java index 08e638032..13cde9536 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/Property.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Property.java @@ -24,7 +24,15 @@ public static Property text(String name) { } public static Property text(String name, Function> fn) { - return fn.apply(new Builder(name, DataType.TEXT)).build(); + return newProperty(name, DataType.TEXT, fn); + } + + public static Property textArray(String name) { + return textArray(name, ObjectBuilder.identity()); + } + + public static Property textArray(String name, Function> fn) { + return newProperty(name, DataType.TEXT_ARRAY, fn); } public static Property integer(String name) { @@ -32,7 +40,15 @@ public static Property integer(String name) { } public static Property integer(String name, Function> fn) { - return fn.apply(new Builder(name, DataType.INT)).build(); + return newProperty(name, DataType.INT, fn); + } + + public static Property integerArray(String name) { + return integerArray(name, ObjectBuilder.identity()); + } + + public static Property integerArray(String name, Function> fn) { + return newProperty(name, DataType.INT_ARRAY, fn); } public static Property blob(String name) { @@ -40,7 +56,75 @@ public static Property blob(String name) { } public static Property blob(String name, Function> fn) { - return fn.apply(new Builder(name, DataType.BLOB)).build(); + return newProperty(name, DataType.BLOB, fn); + } + + public static Property bool(String name) { + return bool(name, ObjectBuilder.identity()); + } + + public static Property bool(String name, Function> fn) { + return newProperty(name, DataType.BOOL, fn); + } + + public static Property boolArray(String name) { + return boolArray(name, ObjectBuilder.identity()); + } + + public static Property boolArray(String name, Function> fn) { + return newProperty(name, DataType.BOOL_ARRAY, fn); + } + + public static Property date(String name) { + return date(name, ObjectBuilder.identity()); + } + + public static Property date(String name, Function> fn) { + return newProperty(name, DataType.DATE, fn); + } + + public static Property dateArray(String name) { + return dateArray(name, ObjectBuilder.identity()); + } + + public static Property dateArray(String name, Function> fn) { + return newProperty(name, DataType.DATE_ARRAY, fn); + } + + public static Property uuid(String name) { + return uuid(name, ObjectBuilder.identity()); + } + + public static Property uuid(String name, Function> fn) { + return newProperty(name, DataType.UUID, fn); + } + + public static Property uuidArray(String name) { + return uuidArray(name, ObjectBuilder.identity()); + } + + public static Property uuidArray(String name, Function> fn) { + return newProperty(name, DataType.UUID_ARRAY, fn); + } + + public static Property number(String name) { + return number(name, ObjectBuilder.identity()); + } + + public static Property number(String name, Function> fn) { + return newProperty(name, DataType.NUMBER, fn); + } + + public static Property numberArray(String name) { + return numberArray(name, ObjectBuilder.identity()); + } + + public static Property numberArray(String name, Function> fn) { + return newProperty(name, DataType.NUMBER_ARRAY, fn); + } + + private static Property newProperty(String name, String dataType, Function> fn) { + return fn.apply(new Builder(name, dataType)).build(); } public static ReferenceProperty reference(String name, String... collections) { diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/Aggregate.java b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/Aggregate.java new file mode 100644 index 000000000..a8de946b5 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/Aggregate.java @@ -0,0 +1,36 @@ +package io.weaviate.client6.v1.api.collections.aggregate; + +import java.util.function.Function; + +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public final class Aggregate { + /** Prevent public initialization. */ + private Aggregate() { + } + + public static final PropertyAggregation text(String property, + Function> fn) { + return TextAggregation.of(property, fn); + } + + public static final PropertyAggregation integer(String property, + Function> fn) { + return IntegerAggregation.of(property, fn); + } + + public static final PropertyAggregation bool(String property, + Function> fn) { + return BooleanAggregation.of(property, fn); + } + + public static final PropertyAggregation date(String property, + Function> fn) { + return DateAggregation.of(property, fn); + } + + public static final PropertyAggregation number(String property, + Function> fn) { + return NumberAggregation.of(property, fn); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AggregateRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AggregateRequest.java index c20ab920e..75d3046c9 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AggregateRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AggregateRequest.java @@ -4,6 +4,7 @@ import java.util.HashMap; import java.util.Map; +import io.weaviate.client6.v1.internal.DateUtil; import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; import io.weaviate.client6.v1.internal.grpc.Rpc; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateBlockingStub; @@ -62,9 +63,21 @@ static Rpc(property, groupBy.getInt()); + groupedBy = new GroupedBy<>(property, groupBy.getInt()); } else if (groupBy.hasText()) { - groupedBy = new GroupedBy(property, groupBy.getText()); + groupedBy = new GroupedBy<>(property, groupBy.getText()); + } else if (groupBy.hasBoolean()) { + groupedBy = new GroupedBy<>(property, groupBy.getBoolean()); + } else if (groupBy.hasNumber()) { + groupedBy = new GroupedBy<>(property, groupBy.getNumber()); + } else if (groupBy.hasTexts()) { + groupedBy = new GroupedBy<>(property, groupBy.getTexts().getValuesList().toArray(String[]::new)); + } else if (groupBy.hasInts()) { + groupedBy = new GroupedBy<>(property, groupBy.getInts().getValuesList().toArray(Long[]::new)); + } else if (groupBy.hasNumbers()) { + groupedBy = new GroupedBy<>(property, groupBy.getNumbers().getValuesList().toArray(Double[]::new)); + } else if (groupBy.hasBooleans()) { + groupedBy = new GroupedBy<>(property, groupBy.getBooleans().getValuesList().toArray(Boolean[]::new)); } else { assert false : "(aggregate) branch not covered"; } @@ -77,6 +90,7 @@ static Rpc rpc.method(), () -> rpc.methodAsync()); + } private static Map unmarshalAggregation(WeaviateProtoAggregate.AggregateReply.Aggregations result) { @@ -107,7 +121,32 @@ private static Map unmarshalAggregation(WeaviateProtoAggregate.A value = new TextAggregation.Values( metric.hasCount() ? metric.getCount() : null, topOccurrences); - + } else if (aggregation.hasBoolean()) { + var metric = aggregation.getBoolean(); + value = new BooleanAggregation.Values( + metric.hasCount() ? metric.getCount() : null, + metric.hasPercentageFalse() ? Float.valueOf((float) metric.getPercentageFalse()) : null, + metric.hasPercentageTrue() ? Float.valueOf((float) metric.getPercentageTrue()) : null, + metric.hasTotalFalse() ? metric.getTotalFalse() : null, + metric.hasTotalTrue() ? metric.getTotalTrue() : null); + } else if (aggregation.hasDate()) { + var metric = aggregation.getDate(); + value = new DateAggregation.Values( + metric.hasCount() ? metric.getCount() : null, + metric.hasMinimum() ? DateUtil.fromISO8601(metric.getMinimum()) : null, + metric.hasMaximum() ? DateUtil.fromISO8601(metric.getMaximum()) : null, + metric.hasMedian() ? DateUtil.fromISO8601(metric.getMedian()) : null, + metric.hasMode() ? DateUtil.fromISO8601(metric.getMode()) : null); + } else if (aggregation.hasNumber()) { + var metric = aggregation.getNumber(); + value = new NumberAggregation.Values( + metric.hasCount() ? metric.getCount() : null, + metric.hasMinimum() ? metric.getMinimum() : null, + metric.hasMaximum() ? metric.getMaximum() : null, + metric.hasMean() ? metric.getMean() : null, + metric.hasMedian() ? metric.getMedian() : null, + metric.hasMode() ? metric.getMode() : null, + metric.hasSum() ? metric.getSum() : null); } else { assert false : "branch not covered"; } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AggregateResponse.java b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AggregateResponse.java index 87b94db81..fb4671338 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AggregateResponse.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AggregateResponse.java @@ -9,7 +9,7 @@ public boolean isText(String name) { } public TextAggregation.Values text(String name) { - checkPropertyType(name, this::isText, "Text"); + checkPropertyType(name, this::isText, "TEXT"); return (TextAggregation.Values) this.properties.get(name); } @@ -18,10 +18,37 @@ public boolean isInteger(String name) { } public IntegerAggregation.Values integer(String name) { - checkPropertyType(name, this::isInteger, "Integer"); + checkPropertyType(name, this::isInteger, "INTEGER"); return (IntegerAggregation.Values) this.properties.get(name); } + public boolean isBool(String name) { + return properties.get(name) instanceof BooleanAggregation.Values; + } + + public BooleanAggregation.Values bool(String name) { + checkPropertyType(name, this::isBool, "BOOLEAN"); + return (BooleanAggregation.Values) this.properties.get(name); + } + + public boolean isDate(String name) { + return properties.get(name) instanceof DateAggregation.Values; + } + + public DateAggregation.Values date(String name) { + checkPropertyType(name, this::isDate, "DATE"); + return (DateAggregation.Values) this.properties.get(name); + } + + public boolean isNumber(String name) { + return properties.get(name) instanceof NumberAggregation.Values; + } + + public NumberAggregation.Values number(String name) { + checkPropertyType(name, this::isNumber, "NUMBER"); + return (NumberAggregation.Values) this.properties.get(name); + } + private void checkPropertyType(String name, Function check, String expected) { if (!check.apply(name)) { throw new IllegalStateException(name + "is not a " + expected + " property"); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AggregateResponseGroup.java b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AggregateResponseGroup.java index 7f28a84c3..8105b8830 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AggregateResponseGroup.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AggregateResponseGroup.java @@ -10,7 +10,7 @@ public boolean isText(String name) { } public TextAggregation.Values text(String name) { - checkPropertyType(name, this::isText, "Text"); + checkPropertyType(name, this::isText, "TEXT"); return (TextAggregation.Values) this.properties.get(name); } @@ -19,10 +19,37 @@ public boolean isInteger(String name) { } public IntegerAggregation.Values integer(String name) { - checkPropertyType(name, this::isInteger, "Integer"); + checkPropertyType(name, this::isInteger, "INTEGER"); return (IntegerAggregation.Values) this.properties.get(name); } + public boolean isBool(String name) { + return properties.get(name) instanceof BooleanAggregation.Values; + } + + public BooleanAggregation.Values bool(String name) { + checkPropertyType(name, this::isBool, "BOOLEAN"); + return (BooleanAggregation.Values) this.properties.get(name); + } + + public boolean isDate(String name) { + return properties.get(name) instanceof DateAggregation.Values; + } + + public DateAggregation.Values date(String name) { + checkPropertyType(name, this::isDate, "DATE"); + return (DateAggregation.Values) this.properties.get(name); + } + + public boolean isNumber(String name) { + return properties.get(name) instanceof NumberAggregation.Values; + } + + public NumberAggregation.Values number(String name) { + checkPropertyType(name, this::isNumber, "NUMBER"); + return (NumberAggregation.Values) this.properties.get(name); + } + private void checkPropertyType(String name, Function check, String expected) { if (!check.apply(name)) { throw new IllegalStateException(name + "is not a " + expected + " property"); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/Aggregation.java b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/Aggregation.java index 5246e7074..26d8809f5 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/Aggregation.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/Aggregation.java @@ -63,16 +63,6 @@ public Aggregation build() { } } - public static final PropertyAggregation text(String property, - Function> fn) { - return TextAggregation.of(property, fn); - } - - public static final PropertyAggregation integer(String property, - Function> fn) { - return IntegerAggregation.of(property, fn); - } - public void appendTo(WeaviateProtoAggregate.AggregateRequest.Builder req) { if (filter != null) { filter.appendTo(req); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/BooleanAggregation.java b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/BooleanAggregation.java new file mode 100644 index 000000000..223aad546 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/BooleanAggregation.java @@ -0,0 +1,68 @@ +package io.weaviate.client6.v1.api.collections.aggregate; + +import java.util.Set; +import java.util.function.Function; + +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoAggregate; + +public class BooleanAggregation + extends AbstractPropertyAggregation { + + public BooleanAggregation(String property, + Set> metrics) { + super(property, metrics); + } + + public static BooleanAggregation of(String property, Function> fn) { + return fn.apply(new Builder(property)).build(); + } + + public BooleanAggregation(Builder builder) { + this(builder.property, builder.metrics); + } + + public static class Builder extends + AbstractPropertyAggregation.Builder { + + public Builder(String property) { + super(property); + } + + public final Builder count() { + return addMetric(WeaviateProtoAggregate.AggregateRequest.Aggregation.Boolean.Builder::setCount); + } + + public final Builder percentageFalse() { + return addMetric(WeaviateProtoAggregate.AggregateRequest.Aggregation.Boolean.Builder::setPercentageFalse); + } + + public final Builder percentageTrue() { + return addMetric(WeaviateProtoAggregate.AggregateRequest.Aggregation.Boolean.Builder::setPercentageTrue); + } + + public final Builder totalFalse() { + return addMetric(WeaviateProtoAggregate.AggregateRequest.Aggregation.Boolean.Builder::setTotalFalse); + } + + public final Builder totalTrue() { + return addMetric(WeaviateProtoAggregate.AggregateRequest.Aggregation.Boolean.Builder::setTotalTrue); + } + + @Override + public final BooleanAggregation build() { + return new BooleanAggregation(this); + } + } + + public record Values(Long count, Float percentageFalse, Float percentageTrue, Long totalFalse, Long totalTrue) { + } + + @Override + public void appendTo(WeaviateProtoAggregate.AggregateRequest.Aggregation.Builder req) { + super.appendTo(req); + var bool = WeaviateProtoAggregate.AggregateRequest.Aggregation.Boolean.newBuilder(); + appendMetrics(bool); + req.setBoolean(bool); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/DateAggregation.java b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/DateAggregation.java new file mode 100644 index 000000000..3dc8c4ac0 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/DateAggregation.java @@ -0,0 +1,69 @@ +package io.weaviate.client6.v1.api.collections.aggregate; + +import java.time.OffsetDateTime; +import java.util.Set; +import java.util.function.Function; + +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoAggregate; + +public class DateAggregation + extends AbstractPropertyAggregation { + + public DateAggregation(String property, + Set> metrics) { + super(property, metrics); + } + + public static DateAggregation of(String property, Function> fn) { + return fn.apply(new Builder(property)).build(); + } + + public DateAggregation(Builder builder) { + this(builder.property, builder.metrics); + } + + public static class Builder extends + AbstractPropertyAggregation.Builder { + + public Builder(String property) { + super(property); + } + + public final Builder count() { + return addMetric(WeaviateProtoAggregate.AggregateRequest.Aggregation.Date.Builder::setCount); + } + + public Builder min() { + return addMetric(WeaviateProtoAggregate.AggregateRequest.Aggregation.Date.Builder::setMinimum); + } + + public Builder max() { + return addMetric(WeaviateProtoAggregate.AggregateRequest.Aggregation.Date.Builder::setMaximum); + } + + public Builder median() { + return addMetric(WeaviateProtoAggregate.AggregateRequest.Aggregation.Date.Builder::setMedian); + } + + public Builder mode() { + return addMetric(WeaviateProtoAggregate.AggregateRequest.Aggregation.Date.Builder::setMode); + } + + @Override + public final DateAggregation build() { + return new DateAggregation(this); + } + } + + public record Values(Long count, OffsetDateTime min, OffsetDateTime max, OffsetDateTime median, OffsetDateTime mode) { + } + + @Override + public void appendTo(WeaviateProtoAggregate.AggregateRequest.Aggregation.Builder req) { + super.appendTo(req); + var date = WeaviateProtoAggregate.AggregateRequest.Aggregation.Date.newBuilder(); + appendMetrics(date); + req.setDate(date); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/GroupedBy.java b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/GroupedBy.java index f853780c7..d17a0d218 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/GroupedBy.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/GroupedBy.java @@ -1,5 +1,7 @@ package io.weaviate.client6.v1.api.collections.aggregate; +import java.util.Arrays; +import java.util.List; import java.util.function.Supplier; public record GroupedBy(String property, T value) { @@ -8,7 +10,7 @@ public boolean isText() { } public String text() { - checkPropertyType(this::isText, "Text"); + checkPropertyType(this::isText, "TEXT"); return (String) value; } @@ -17,10 +19,68 @@ public boolean isInteger() { } public Long integer() { - checkPropertyType(this::isInteger, "Long"); + checkPropertyType(this::isInteger, "INTEGER"); return (Long) value; } + public boolean isBool() { + return value instanceof Boolean; + } + + public Boolean bool() { + checkPropertyType(this::isBool, "BOOLEAN"); + return (Boolean) value; + } + + public boolean isNumber() { + return value instanceof Double; + } + + public Double number() { + checkPropertyType(this::isNumber, "NUMBER"); + return (Double) value; + } + + public boolean isTextArray() { + return value instanceof String[]; + } + + @SuppressWarnings("unchecked") + public List textArray() { + checkPropertyType(this::isTextArray, "TEXT[]"); + return (List) Arrays.asList(value); + } + + public boolean isBoolArray() { + return value instanceof Boolean[]; + } + + @SuppressWarnings("unchecked") + public List boolArray() { + checkPropertyType(this::isBoolArray, "BOOLEAN[]"); + return (List) Arrays.asList(value); + } + + public boolean isIntegerArray() { + return value instanceof Long[]; + } + + @SuppressWarnings("unchecked") + public List integerArray() { + checkPropertyType(this::isIntegerArray, "INTEGER[]"); + return (List) Arrays.asList(value); + } + + public boolean isNumberArray() { + return value instanceof Double[]; + } + + @SuppressWarnings("unchecked") + public List numberArray() { + checkPropertyType(this::isNumberArray, "NUMBER[]"); + return (List) Arrays.asList(value); + } + private void checkPropertyType(Supplier check, String expected) { if (!check.get()) { throw new IllegalStateException(property + "is not a " + expected + " property"); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/IntegerAggregation.java b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/IntegerAggregation.java index 5f8f1db00..9d674e4b3 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/IntegerAggregation.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/IntegerAggregation.java @@ -10,8 +10,7 @@ public class IntegerAggregation extends AbstractPropertyAggregation { public IntegerAggregation(String property, - Set> metrics, - Integer topOccurrencesCutoff) { + Set> metrics) { super(property, metrics); } @@ -20,12 +19,11 @@ public static IntegerAggregation of(String property, Function { - private Integer topOccurrencesCutoff; public Builder(String property) { super(property); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/NumberAggregation.java b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/NumberAggregation.java new file mode 100644 index 000000000..55a4779ec --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/NumberAggregation.java @@ -0,0 +1,76 @@ +package io.weaviate.client6.v1.api.collections.aggregate; + +import java.util.Set; +import java.util.function.Function; + +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoAggregate; + +public class NumberAggregation + extends AbstractPropertyAggregation { + + public NumberAggregation(String property, + Set> metrics) { + super(property, metrics); + } + + public static NumberAggregation of(String property, Function> fn) { + return fn.apply(new Builder(property)).build(); + } + + public NumberAggregation(Builder builder) { + this(builder.property, builder.metrics); + } + + public static class Builder extends + AbstractPropertyAggregation.Builder { + + public Builder(String property) { + super(property); + } + + public final Builder count() { + return addMetric(WeaviateProtoAggregate.AggregateRequest.Aggregation.Number.Builder::setCount); + } + + public Builder min() { + return addMetric(WeaviateProtoAggregate.AggregateRequest.Aggregation.Number.Builder::setMinimum); + } + + public Builder max() { + return addMetric(WeaviateProtoAggregate.AggregateRequest.Aggregation.Number.Builder::setMaximum); + } + + public Builder mean() { + return addMetric(WeaviateProtoAggregate.AggregateRequest.Aggregation.Number.Builder::setMean); + } + + public Builder median() { + return addMetric(WeaviateProtoAggregate.AggregateRequest.Aggregation.Number.Builder::setMedian); + } + + public Builder mode() { + return addMetric(WeaviateProtoAggregate.AggregateRequest.Aggregation.Number.Builder::setMode); + } + + public Builder sum() { + return addMetric(WeaviateProtoAggregate.AggregateRequest.Aggregation.Number.Builder::setSum); + } + + @Override + public final NumberAggregation build() { + return new NumberAggregation(this); + } + } + + public record Values(Long count, Double min, Double max, Double mean, Double median, Double mode, Double sum) { + } + + @Override + public void appendTo(WeaviateProtoAggregate.AggregateRequest.Aggregation.Builder req) { + super.appendTo(req); + var number = WeaviateProtoAggregate.AggregateRequest.Aggregation.Number.newBuilder(); + appendMetrics(number); + req.setNumber(number); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java index a14be5794..80918fd43 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java @@ -1,10 +1,10 @@ package io.weaviate.client6.v1.api.collections.query; -import java.time.OffsetDateTime; import java.util.ArrayList; -import java.util.Date; +import java.util.Arrays; import java.util.HashMap; import java.util.List; +import java.util.UUID; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -13,6 +13,7 @@ import io.weaviate.client6.v1.api.collections.ObjectMetadata; import io.weaviate.client6.v1.api.collections.Vectors; import io.weaviate.client6.v1.api.collections.WeaviateObject; +import io.weaviate.client6.v1.internal.DateUtil; import io.weaviate.client6.v1.internal.grpc.ByteStringUtil; import io.weaviate.client6.v1.internal.grpc.Rpc; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateBlockingStub; @@ -179,12 +180,13 @@ private static WeaviateObject unmarshalWithRefere var vectors = new Vectors.Builder(); for (final var vector : metadataResult.getVectorsList()) { var vectorName = vector.getName(); + var vbytes = vector.getVectorBytes(); switch (vector.getType()) { case VECTOR_TYPE_SINGLE_FP32: - vectors.vector(vectorName, ByteStringUtil.decodeVectorSingle(vector.getVectorBytes())); + vectors.vector(vectorName, ByteStringUtil.decodeVectorSingle(vbytes)); break; case VECTOR_TYPE_MULTI_FP32: - vectors.vector(vectorName, ByteStringUtil.decodeVectorMulti(vector.getVectorBytes())); + vectors.vector(vectorName, ByteStringUtil.decodeVectorMulti(vbytes)); break; default: continue; @@ -213,12 +215,38 @@ private static void setProperty(String property, WeaviateProtoProperties.Val } else if (value.hasIntValue()) { builder.setInteger(property, value.getIntValue()); } else if (value.hasNumberValue()) { - builder.setNumber(property, value.getNumberValue()); + builder.setDouble(property, value.getNumberValue()); } else if (value.hasBlobValue()) { builder.setBlob(property, value.getBlobValue()); } else if (value.hasDateValue()) { - OffsetDateTime offsetDateTime = OffsetDateTime.parse(value.getDateValue()); - builder.setDate(property, Date.from(offsetDateTime.toInstant())); + builder.setOffsetDateTime(property, DateUtil.fromISO8601(value.getDateValue())); + } else if (value.hasUuidValue()) { + builder.setUuid(property, UUID.fromString(value.getUuidValue())); + } else if (value.hasListValue()) { + var list = value.getListValue(); + if (list.hasTextValues()) { + builder.setTextArray(property, list.getTextValues().getValuesList()); + } else if (list.hasIntValues()) { + var ints = Arrays.stream( + ByteStringUtil.decodeIntValues(list.getIntValues().getValues())) + .boxed().toList(); + builder.setLongArray(property, ints); + } else if (list.hasNumberValues()) { + var numbers = Arrays.stream( + ByteStringUtil.decodeNumberValues(list.getNumberValues().getValues())) + .boxed().toList(); + builder.setDoubleArray(property, numbers); + } else if (list.hasUuidValues()) { + var uuids = list.getUuidValues().getValuesList().stream() + .map(UUID::fromString).toList(); + builder.setUuidArray(property, uuids); + } else if (list.hasBoolValues()) { + builder.setBooleanArray(property, list.getBoolValues().getValuesList()); + } else if (list.hasDateValues()) { + var dates = list.getDateValues().getValuesList().stream() + .map(DateUtil::fromISO8601).toList(); + builder.setOffsetDateTimeArray(property, dates); + } } else { assert false : "(query) branch not covered"; } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/Where.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/Where.java index 09c460acb..bde53a695 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/Where.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/Where.java @@ -1,11 +1,9 @@ package io.weaviate.client6.v1.api.collections.query; +import java.time.OffsetDateTime; import java.util.Arrays; -import java.util.Date; import java.util.List; -import org.apache.commons.lang3.time.DateFormatUtils; - import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase.Filters; @@ -125,7 +123,7 @@ public Where eq(String... values) { return new Where(Operator.EQUAL, left, new TextArrayOperand(values)); } - public Where eq(Boolean value) { + public Where eq(boolean value) { return new Where(Operator.EQUAL, left, new BooleanOperand(value)); } @@ -133,27 +131,35 @@ public Where eq(Boolean... values) { return new Where(Operator.EQUAL, left, new BooleanArrayOperand(values)); } - public Where eq(Integer value) { + public Where eq(long value) { + return new Where(Operator.EQUAL, left, new IntegerOperand(value)); + } + + public Where eq(int value) { return new Where(Operator.EQUAL, left, new IntegerOperand(value)); } - public Where eq(Integer... values) { + public Where eq(Long... values) { return new Where(Operator.EQUAL, left, new IntegerArrayOperand(values)); } - public Where eq(Number value) { - return new Where(Operator.EQUAL, left, new NumberOperand(value.doubleValue())); + public Where eq(double value) { + return new Where(Operator.EQUAL, left, new NumberOperand(value)); + } + + public Where eq(float value) { + return new Where(Operator.EQUAL, left, new NumberOperand(value)); } - public Where eq(Number... values) { + public Where eq(Double... values) { return new Where(Operator.EQUAL, left, new NumberArrayOperand(values)); } - public Where eq(Date value) { + public Where eq(OffsetDateTime value) { return new Where(Operator.EQUAL, left, new DateOperand(value)); } - public Where eq(Date... values) { + public Where eq(OffsetDateTime... values) { return new Where(Operator.EQUAL, left, new DateArrayOperand(values)); } @@ -171,7 +177,7 @@ public Where ne(String... values) { return new Where(Operator.NOT_EQUAL, left, new TextArrayOperand(values)); } - public Where ne(Boolean value) { + public Where ne(boolean value) { return new Where(Operator.NOT_EQUAL, left, new BooleanOperand(value)); } @@ -179,27 +185,35 @@ public Where ne(Boolean... values) { return new Where(Operator.NOT_EQUAL, left, new BooleanArrayOperand(values)); } - public Where ne(Integer value) { + public Where ne(long value) { return new Where(Operator.NOT_EQUAL, left, new IntegerOperand(value)); } - public Where ne(Integer... values) { + public Where ne(int value) { + return new Where(Operator.NOT_EQUAL, left, new IntegerOperand(value)); + } + + public Where ne(Long... values) { return new Where(Operator.NOT_EQUAL, left, new IntegerArrayOperand(values)); } - public Where ne(Number value) { - return new Where(Operator.NOT_EQUAL, left, new NumberOperand(value.doubleValue())); + public Where ne(double value) { + return new Where(Operator.NOT_EQUAL, left, new NumberOperand(value)); } - public Where ne(Number... values) { + public Where ne(float value) { + return new Where(Operator.NOT_EQUAL, left, new NumberOperand(value)); + } + + public Where ne(Double... values) { return new Where(Operator.NOT_EQUAL, left, new NumberArrayOperand(values)); } - public Where ne(Date value) { + public Where ne(OffsetDateTime value) { return new Where(Operator.NOT_EQUAL, left, new DateOperand(value)); } - public Where ne(Date... values) { + public Where ne(OffsetDateTime... values) { return new Where(Operator.NOT_EQUAL, left, new DateArrayOperand(values)); } @@ -217,27 +231,35 @@ public Where lt(String... values) { return new Where(Operator.LESS_THAN, left, new TextArrayOperand(values)); } - public Where lt(Integer value) { + public Where lt(long value) { + return new Where(Operator.LESS_THAN, left, new IntegerOperand(value)); + } + + public Where lt(int value) { return new Where(Operator.LESS_THAN, left, new IntegerOperand(value)); } - public Where lt(Integer... values) { + public Where lt(Long... values) { return new Where(Operator.LESS_THAN, left, new IntegerArrayOperand(values)); } - public Where lt(Number value) { - return new Where(Operator.LESS_THAN, left, new NumberOperand(value.doubleValue())); + public Where lt(double value) { + return new Where(Operator.LESS_THAN, left, new NumberOperand(value)); + } + + public Where lt(float value) { + return new Where(Operator.LESS_THAN, left, new NumberOperand(value)); } - public Where lt(Number... values) { + public Where lt(Double... values) { return new Where(Operator.LESS_THAN, left, new NumberArrayOperand(values)); } - public Where lt(Date value) { + public Where lt(OffsetDateTime value) { return new Where(Operator.LESS_THAN, left, new DateOperand(value)); } - public Where lt(Date... values) { + public Where lt(OffsetDateTime... values) { return new Where(Operator.LESS_THAN, left, new DateArrayOperand(values)); } @@ -255,27 +277,35 @@ public Where lte(String... values) { return new Where(Operator.LESS_THAN_EQUAL, left, new TextArrayOperand(values)); } - public Where lte(Integer value) { + public Where lte(long value) { + return new Where(Operator.LESS_THAN_EQUAL, left, new IntegerOperand(value)); + } + + public Where lte(int value) { return new Where(Operator.LESS_THAN_EQUAL, left, new IntegerOperand(value)); } - public Where lte(Integer... values) { + public Where lte(Long... values) { return new Where(Operator.LESS_THAN_EQUAL, left, new IntegerArrayOperand(values)); } - public Where lte(Number value) { - return new Where(Operator.LESS_THAN_EQUAL, left, new NumberOperand(value.doubleValue())); + public Where lte(double value) { + return new Where(Operator.LESS_THAN_EQUAL, left, new NumberOperand(value)); + } + + public Where lte(float value) { + return new Where(Operator.LESS_THAN_EQUAL, left, new NumberOperand(value)); } - public Where lte(Number... values) { + public Where lte(Double... values) { return new Where(Operator.LESS_THAN_EQUAL, left, new NumberArrayOperand(values)); } - public Where lte(Date value) { + public Where lte(OffsetDateTime value) { return new Where(Operator.LESS_THAN_EQUAL, left, new DateOperand(value)); } - public Where lte(Date... values) { + public Where lte(OffsetDateTime... values) { return new Where(Operator.LESS_THAN_EQUAL, left, new DateArrayOperand(values)); } @@ -293,27 +323,35 @@ public Where gt(String... values) { return new Where(Operator.GREATER_THAN, left, new TextArrayOperand(values)); } - public Where gt(Integer value) { + public Where gt(long value) { + return new Where(Operator.GREATER_THAN, left, new IntegerOperand(value)); + } + + public Where gt(int value) { return new Where(Operator.GREATER_THAN, left, new IntegerOperand(value)); } - public Where gt(Integer... values) { + public Where gt(Long... values) { return new Where(Operator.GREATER_THAN, left, new IntegerArrayOperand(values)); } - public Where gt(Number value) { - return new Where(Operator.GREATER_THAN, left, new NumberOperand(value.doubleValue())); + public Where gt(double value) { + return new Where(Operator.GREATER_THAN, left, new NumberOperand(value)); } - public Where gt(Number... values) { + public Where gt(float value) { + return new Where(Operator.GREATER_THAN, left, new NumberOperand(value)); + } + + public Where gt(Double... values) { return new Where(Operator.GREATER_THAN, left, new NumberArrayOperand(values)); } - public Where gt(Date value) { + public Where gt(OffsetDateTime value) { return new Where(Operator.GREATER_THAN, left, new DateOperand(value)); } - public Where gt(Date... values) { + public Where gt(OffsetDateTime... values) { return new Where(Operator.GREATER_THAN, left, new DateArrayOperand(values)); } @@ -331,27 +369,35 @@ public Where gte(String... values) { return new Where(Operator.GREATER_THAN_EQUAL, left, new TextArrayOperand(values)); } - public Where gte(Integer value) { + public Where gte(long value) { + return new Where(Operator.GREATER_THAN_EQUAL, left, new IntegerOperand(value)); + } + + public Where gte(int value) { return new Where(Operator.GREATER_THAN_EQUAL, left, new IntegerOperand(value)); } - public Where gte(Integer... values) { + public Where gte(Long... values) { return new Where(Operator.GREATER_THAN_EQUAL, left, new IntegerArrayOperand(values)); } - public Where gte(Number value) { - return new Where(Operator.GREATER_THAN_EQUAL, left, new NumberOperand(value.doubleValue())); + public Where gte(double value) { + return new Where(Operator.GREATER_THAN_EQUAL, left, new NumberOperand(value)); + } + + public Where gte(float value) { + return new Where(Operator.GREATER_THAN_EQUAL, left, new NumberOperand(value)); } - public Where gte(Number... values) { + public Where gte(Double... values) { return new Where(Operator.GREATER_THAN_EQUAL, left, new NumberArrayOperand(values)); } - public Where gte(Date value) { + public Where gte(OffsetDateTime value) { return new Where(Operator.GREATER_THAN_EQUAL, left, new DateOperand(value)); } - public Where gte(Date... values) { + public Where gte(OffsetDateTime... values) { return new Where(Operator.GREATER_THAN_EQUAL, left, new DateArrayOperand(values)); } @@ -379,15 +425,15 @@ public Where containsAny(Boolean... values) { return new Where(Operator.CONTAINS_ANY, left, new BooleanArrayOperand(values)); } - public Where containsAny(Integer... values) { + public Where containsAny(Long... values) { return new Where(Operator.CONTAINS_ANY, left, new IntegerArrayOperand(values)); } - public Where containsAny(Number... values) { + public Where containsAny(Double... values) { return new Where(Operator.CONTAINS_ANY, left, new NumberArrayOperand(values)); } - public Where containsAny(Date... values) { + public Where containsAny(OffsetDateTime... values) { return new Where(Operator.CONTAINS_ANY, left, new DateArrayOperand(values)); } @@ -405,15 +451,15 @@ public Where containsAll(Boolean... values) { return new Where(Operator.CONTAINS_ALL, left, new BooleanArrayOperand(values)); } - public Where containsAll(Integer... values) { + public Where containsAll(Long... values) { return new Where(Operator.CONTAINS_ALL, left, new IntegerArrayOperand(values)); } - public Where containsAll(Number... values) { + public Where containsAll(Double... values) { return new Where(Operator.CONTAINS_ALL, left, new NumberArrayOperand(values)); } - public Where containsAll(Date... values) { + public Where containsAll(OffsetDateTime... values) { return new Where(Operator.CONTAINS_ALL, left, new DateArrayOperand(values)); } @@ -452,26 +498,30 @@ public void appendTo(WeaviateProtoBase.Filters.Builder where) { @SuppressWarnings("unchecked") static WhereOperand fromObject(Object value) { - if (value instanceof String) { - return new TextOperand((String) value); - } else if (value instanceof Boolean) { - return new BooleanOperand((Boolean) value); - } else if (value instanceof Integer) { - return new IntegerOperand((Integer) value); - } else if (value instanceof Number) { - return new NumberOperand((Number) value); - } else if (value instanceof Date) { - return new DateOperand((Date) value); - } else if (value instanceof String[]) { - return new TextArrayOperand((String[]) value); - } else if (value instanceof Boolean[]) { - return new BooleanArrayOperand((Boolean[]) value); - } else if (value instanceof Integer[]) { - return new IntegerArrayOperand((Integer[]) value); - } else if (value instanceof Number[]) { - return new NumberArrayOperand((Number[]) value); - } else if (value instanceof Date[]) { - return new DateArrayOperand((Date[]) value); + if (value instanceof String str) { + return new TextOperand(str); + } else if (value instanceof Boolean bool) { + return new BooleanOperand(bool); + } else if (value instanceof Long l) { + return new IntegerOperand(l); + } else if (value instanceof Integer i) { + return new IntegerOperand(i); + } else if (value instanceof Double dbl) { + return new NumberOperand(dbl); + } else if (value instanceof Float f) { + return new NumberOperand(f); + } else if (value instanceof OffsetDateTime date) { + return new DateOperand(date); + } else if (value instanceof String[] strarr) { + return new TextArrayOperand(strarr); + } else if (value instanceof Boolean[] boolarr) { + return new BooleanArrayOperand(boolarr); + } else if (value instanceof Long[] lngarr) { + return new IntegerArrayOperand(lngarr); + } else if (value instanceof Double[] dblarr) { + return new NumberArrayOperand(dblarr); + } else if (value instanceof OffsetDateTime[] datearr) { + return new DateArrayOperand(datearr); } else if (value instanceof List) { if (((List) value).isEmpty()) { throw new IllegalArgumentException( @@ -483,16 +533,16 @@ static WhereOperand fromObject(Object value) { return new TextArrayOperand((List) value); } else if (first instanceof Boolean) { return new BooleanArrayOperand((List) value); - } else if (first instanceof Integer) { - return new IntegerArrayOperand((List) value); - } else if (first instanceof Number) { - return new NumberArrayOperand((List) value); - } else if (first instanceof Date) { - return new DateArrayOperand((List) value); + } else if (first instanceof Long) { + return new IntegerArrayOperand((List) value); + } else if (first instanceof Double) { + return new NumberArrayOperand((List) value); + } else if (first instanceof OffsetDateTime) { + return new DateArrayOperand((List) value); } } throw new IllegalArgumentException( - "value must be either of String, Boolean, Date, Integer, Number, Array/List of these types"); + "value must be either of String, Boolean, OffsetDateTime, Long, Double, or Array/List of these types"); } private static class PathOperand implements WhereOperand { @@ -563,9 +613,9 @@ public String toString() { } private static class BooleanOperand implements WhereOperand { - private final Boolean value; + private final boolean value; - private BooleanOperand(Boolean value) { + private BooleanOperand(boolean value) { this.value = value; } @@ -576,7 +626,7 @@ public void appendTo(WeaviateProtoBase.Filters.Builder where) { @Override public String toString() { - return value.toString(); + return Boolean.toString(value); } } @@ -604,9 +654,13 @@ public String toString() { } private static class IntegerOperand implements WhereOperand { - private final Integer value; + private final long value; + + private IntegerOperand(long value) { + this.value = value; + } - private IntegerOperand(Integer value) { + private IntegerOperand(int value) { this.value = value; } @@ -617,29 +671,25 @@ public void appendTo(WeaviateProtoBase.Filters.Builder where) { @Override public String toString() { - return value.toString(); + return Long.toString(value); } } private static class IntegerArrayOperand implements WhereOperand { - private final List values; + private final List values; - private IntegerArrayOperand(List values) { + private IntegerArrayOperand(List values) { this.values = values; } @SafeVarargs - private IntegerArrayOperand(Integer... values) { + private IntegerArrayOperand(Long... values) { this(Arrays.asList(values)); } - private List toLongs() { - return values.stream().map(Integer::longValue).toList(); - } - @Override public void appendTo(WeaviateProtoBase.Filters.Builder where) { - where.setValueIntArray(WeaviateProtoBase.IntArray.newBuilder().addAllValues(toLongs())); + where.setValueIntArray(WeaviateProtoBase.IntArray.newBuilder().addAllValues(values)); } @Override @@ -649,42 +699,42 @@ public String toString() { } private static class NumberOperand implements WhereOperand { - private final Number value; + private final double value; + + private NumberOperand(double value) { + this.value = value; + } - private NumberOperand(Number value) { + private NumberOperand(float value) { this.value = value; } @Override public void appendTo(WeaviateProtoBase.Filters.Builder where) { - where.setValueNumber(value.doubleValue()); + where.setValueNumber(value); } @Override public String toString() { - return value.toString(); + return Double.toString(value); } } private static class NumberArrayOperand implements WhereOperand { - private final List values; + private final List values; - private NumberArrayOperand(List values) { + private NumberArrayOperand(List values) { this.values = values; } @SafeVarargs - private NumberArrayOperand(Number... values) { + private NumberArrayOperand(Double... values) { this(Arrays.asList(values)); } - private List toDoubles() { - return values.stream().map(Number::doubleValue).toList(); - } - @Override public void appendTo(WeaviateProtoBase.Filters.Builder where) { - where.setValueNumberArray(WeaviateProtoBase.NumberArray.newBuilder().addAllValues(toDoubles())); + where.setValueNumberArray(WeaviateProtoBase.NumberArray.newBuilder().addAllValues(values)); } @Override @@ -694,41 +744,37 @@ public String toString() { } private static class DateOperand implements WhereOperand { - private final Date value; + private final OffsetDateTime value; - private DateOperand(Date value) { + private DateOperand(OffsetDateTime value) { this.value = value; } - private static String format(Date date) { - return DateFormatUtils.format(date, "yyyy-MM-dd'T'HH:mm:ssZZZZZ"); - } - @Override public void appendTo(WeaviateProtoBase.Filters.Builder where) { - where.setValueText(format(value)); + where.setValueText(value.toString()); } @Override public String toString() { - return format(value); + return value.toString(); } } private static class DateArrayOperand implements WhereOperand { - private final List values; + private final List values; - private DateArrayOperand(List values) { + private DateArrayOperand(List values) { this.values = values; } @SafeVarargs - private DateArrayOperand(Date... values) { + private DateArrayOperand(OffsetDateTime... values) { this(Arrays.asList(values)); } private List formatted() { - return values.stream().map(date -> DateOperand.format(date)).toList(); + return values.stream().map(OffsetDateTime::toString).toList(); } @Override diff --git a/src/main/java/io/weaviate/client6/v1/internal/DateUtil.java b/src/main/java/io/weaviate/client6/v1/internal/DateUtil.java new file mode 100644 index 000000000..b103e05e7 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/DateUtil.java @@ -0,0 +1,48 @@ +package io.weaviate.client6.v1.internal; + +import java.io.IOException; +import java.time.OffsetDateTime; + +import com.google.gson.Gson; +import com.google.gson.TypeAdapter; +import com.google.gson.TypeAdapterFactory; +import com.google.gson.reflect.TypeToken; +import com.google.gson.stream.JsonReader; +import com.google.gson.stream.JsonWriter; + +public final class DateUtil { + /** Prevent public initialization. */ + private DateUtil() { + } + + /** Convert ISO8601-formatted time string to {@link OffsetDateTime}. */ + public static OffsetDateTime fromISO8601(String iso8601) { + return OffsetDateTime.parse(iso8601); + } + + public static enum CustomTypeAdapterFactory implements TypeAdapterFactory { + INSTANCE; + + @SuppressWarnings("unchecked") + @Override + public TypeAdapter create(Gson gson, TypeToken type) { + if (type.getRawType() != OffsetDateTime.class) { + return null; + } + + return (TypeAdapter) new TypeAdapter() { + + @Override + public void write(JsonWriter out, OffsetDateTime value) throws IOException { + out.value(value.toString()); + } + + @Override + public OffsetDateTime read(JsonReader in) throws IOException { + return OffsetDateTime.parse(in.nextString()); + } + + }.nullSafe(); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/ByteStringUtil.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/ByteStringUtil.java index 1d45bed0f..6bee6de0a 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/grpc/ByteStringUtil.java +++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/ByteStringUtil.java @@ -9,6 +9,10 @@ import com.google.protobuf.ByteString; public class ByteStringUtil { + /** Prevent public initialization. */ + private ByteStringUtil() { + } + private static final ByteOrder BYTE_ORDER = ByteOrder.LITTLE_ENDIAN; /** Decode ByteString to UUID. */ @@ -112,4 +116,36 @@ public static float[][] decodeVectorMulti(ByteString bs) { } return vectors; } + + /** + * Decode ByteString to {@code long[]}. + * + * @throws IllegalArgumentException if ByteString size is not + * a multiple of {@link Long#BYTES}. + */ + public static long[] decodeIntValues(ByteString bs) { + if (bs.size() % Long.BYTES != 0) { + throw new IllegalArgumentException( + "ByteString size " + bs.size() + " is not a multiple of " + String.valueOf(Long.BYTES) + " (Long.BYTES)"); + } + long[] vector = new long[bs.size() / Long.BYTES]; + bs.asReadOnlyByteBuffer().order(BYTE_ORDER).asLongBuffer().get(vector); + return vector; + } + + /** + * Decode ByteString to {@code double[]}. + * + * @throws IllegalArgumentException if ByteString size is not + * a multiple of {@link Double#BYTES}. + */ + public static double[] decodeNumberValues(ByteString bs) { + if (bs.size() % Double.BYTES != 0) { + throw new IllegalArgumentException( + "ByteString size " + bs.size() + " is not a multiple of " + String.valueOf(Double.BYTES) + " (Double.BYTES)"); + } + double[] vector = new double[bs.size() / Double.BYTES]; + bs.asReadOnlyByteBuffer().order(BYTE_ORDER).asDoubleBuffer().get(vector); + return vector; + } } diff --git a/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java b/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java index 960d40279..a56fce609 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java +++ b/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java @@ -27,6 +27,8 @@ public final class JSON { io.weaviate.client6.v1.api.collections.Reranker.CustomTypeAdapterFactory.INSTANCE); gsonBuilder.registerTypeAdapterFactory( io.weaviate.client6.v1.api.collections.Generative.CustomTypeAdapterFactory.INSTANCE); + gsonBuilder.registerTypeAdapterFactory( + io.weaviate.client6.v1.internal.DateUtil.CustomTypeAdapterFactory.INSTANCE); // TypeAdapters ----------------------------------------------------------- gsonBuilder.registerTypeAdapter( diff --git a/src/main/java/io/weaviate/client6/v1/internal/orm/MapBuilder.java b/src/main/java/io/weaviate/client6/v1/internal/orm/MapBuilder.java index c45bd57a4..0e2c94c99 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/orm/MapBuilder.java +++ b/src/main/java/io/weaviate/client6/v1/internal/orm/MapBuilder.java @@ -1,8 +1,10 @@ package io.weaviate.client6.v1.internal.orm; -import java.util.Date; +import java.time.OffsetDateTime; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.UUID; public class MapBuilder implements PropertiesBuilder> { private final Map properties = new HashMap<>(); @@ -28,7 +30,7 @@ public void setInteger(String property, Long value) { } @Override - public void setNumber(String property, Number value) { + public void setDouble(String property, Double value) { properties.put(property, value); } @@ -38,7 +40,42 @@ public void setBlob(String property, String value) { } @Override - public void setDate(String property, Date value) { + public void setOffsetDateTime(String property, OffsetDateTime value) { + properties.put(property, value); + } + + @Override + public void setUuid(String property, UUID value) { + properties.put(property, value); + } + + @Override + public void setTextArray(String property, List value) { + properties.put(property, value); + } + + @Override + public void setLongArray(String property, List value) { + properties.put(property, value); + } + + @Override + public void setDoubleArray(String property, List value) { + properties.put(property, value); + } + + @Override + public void setUuidArray(String property, List value) { + properties.put(property, value); + } + + @Override + public void setBooleanArray(String property, List value) { + properties.put(property, value); + } + + @Override + public void setOffsetDateTimeArray(String property, List value) { properties.put(property, value); } diff --git a/src/main/java/io/weaviate/client6/v1/internal/orm/PropertiesBuilder.java b/src/main/java/io/weaviate/client6/v1/internal/orm/PropertiesBuilder.java index 0d88d385c..dd46f87c9 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/orm/PropertiesBuilder.java +++ b/src/main/java/io/weaviate/client6/v1/internal/orm/PropertiesBuilder.java @@ -1,6 +1,8 @@ package io.weaviate.client6.v1.internal.orm; -import java.util.Date; +import java.time.OffsetDateTime; +import java.util.List; +import java.util.UUID; public interface PropertiesBuilder { void setNull(String property); @@ -11,11 +13,25 @@ public interface PropertiesBuilder { void setInteger(String property, Long value); - void setNumber(String property, Number value); + void setDouble(String property, Double value); void setBlob(String property, String value); - void setDate(String property, Date value); + void setOffsetDateTime(String property, OffsetDateTime value); + + void setUuid(String property, UUID value); + + void setTextArray(String property, List value); + + void setLongArray(String property, List value); + + void setDoubleArray(String property, List value); + + void setUuidArray(String property, List value); + + void setBooleanArray(String property, List value); + + void setOffsetDateTimeArray(String property, List value); T build(); } diff --git a/src/test/java/io/weaviate/client6/v1/internal/grpc/ByteStringUtilTest.java b/src/test/java/io/weaviate/client6/v1/internal/grpc/ByteStringUtilTest.java index f9c6d1f71..382dd9212 100644 --- a/src/test/java/io/weaviate/client6/v1/internal/grpc/ByteStringUtilTest.java +++ b/src/test/java/io/weaviate/client6/v1/internal/grpc/ByteStringUtilTest.java @@ -3,6 +3,7 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import org.assertj.core.api.Assertions; import org.junit.Test; import com.google.protobuf.ByteString; @@ -71,11 +72,27 @@ public void test_decodeVector_2d_empty() { @Test public void test_decodeVector_2d_dim_zero() { - byte[] bytes = new byte[] { 0, 0 }; + byte[] bytes = { 0, 0 }; float[][] got = ByteStringUtil.decodeVectorMulti(ByteString.copyFrom(bytes)); assertEquals(0, got.length); } + @Test + public void test_decodeIntValues() { + byte[] bytes = { 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0 }; + long[] want = { 1, 2, 3 }; + long[] got = ByteStringUtil.decodeIntValues(ByteString.copyFrom(bytes)); + assertArrayEquals(want, got); + } + + @Test + public void test_decodeNumberValues() { + byte[] bytes = { 0, 0, 0, 0, 0, 0, -16, 63, 0, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 8, 64 }; + double[] want = { 1, 2, 3 }; + double[] got = ByteStringUtil.decodeNumberValues(ByteString.copyFrom(bytes)); + Assertions.assertThat(got).isEqualTo(want); + } + @Test(expected = IllegalArgumentException.class) public void test_decodeVector_1d_illegal() { byte[] bytes = new byte[Float.BYTES - 1]; // must be a multiple of Float.BYTES @@ -93,4 +110,16 @@ public void test_decodeVector_2d_illegal() { ByteStringUtil.decodeVectorMulti(ByteString.copyFrom(bytes)); } + + @Test(expected = IllegalArgumentException.class) + public void test_decodeIntValues_illegal() { + byte[] bytes = new byte[Long.BYTES - 1]; // must be a multiple of Long.BYTES + ByteStringUtil.decodeIntValues(ByteString.copyFrom(bytes)); + } + + @Test(expected = IllegalArgumentException.class) + public void test_decodeNumberValues_illegal() { + byte[] bytes = new byte[Double.BYTES - 1]; // must be a multiple of Double.BYTES + ByteStringUtil.decodeNumberValues(ByteString.copyFrom(bytes)); + } }