diff --git a/CHANGELOG.md b/CHANGELOG.md index daf5568fa7a77..2165fcb4700fb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - [AdmissionControl] Added changes for AdmissionControl Interceptor and AdmissionControlService for RateLimiting ([#9286](https://github.com/opensearch-project/OpenSearch/pull/9286)) - GHA to verify checklist items completion in PR descriptions ([#10800](https://github.com/opensearch-project/OpenSearch/pull/10800)) - [Remote cluster state] Restore cluster state version during remote state auto restore ([#10853](https://github.com/opensearch-project/OpenSearch/pull/10853)) +- Add back half_float BKD based sort query optimization ([#11024](https://github.com/opensearch-project/OpenSearch/pull/11024)) ### Dependencies - Bump `log4j-core` from 2.18.0 to 2.19.0 diff --git a/rest-api-spec/src/main/resources/rest-api-spec/test/indices.sort/10_basic.yml b/rest-api-spec/src/main/resources/rest-api-spec/test/indices.sort/10_basic.yml index b9089689b0cf1..3b7ea15164e9f 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/test/indices.sort/10_basic.yml +++ b/rest-api-spec/src/main/resources/rest-api-spec/test/indices.sort/10_basic.yml @@ -156,3 +156,23 @@ query: {"range": { "rank": { "from": 0 } } } track_total_hits: false size: 3 + +--- +"Index Sort half float": + - do: + catch: bad_request + indices.create: + index: test + body: + settings: + number_of_shards: 1 + number_of_replicas: 0 + index.sort.field: rank + mappings: + properties: + rank: + type: half_float + + # This should fail with 400 because half_float is not supported for index sort + - match: { status: 400 } + - match: { error.type: illegal_argument_exception } diff --git a/rest-api-spec/src/main/resources/rest-api-spec/test/search/260_sort_mixed.yml b/rest-api-spec/src/main/resources/rest-api-spec/test/search/260_sort_mixed.yml index ba2b18eb3b6d0..a04dc308b2a06 100644 --- 
a/rest-api-spec/src/main/resources/rest-api-spec/test/search/260_sort_mixed.yml +++ b/rest-api-spec/src/main/resources/rest-api-spec/test/search/260_sort_mixed.yml @@ -20,6 +20,7 @@ properties: counter: type: double + - do: bulk: refresh: true @@ -119,3 +120,87 @@ - match: { status: 400 } - match: { error.type: search_phase_execution_exception } - match: { error.caused_by.reason: "Can't do sort across indices, as a field has [unsigned_long] type in one index, and different type in another index!" } + +--- +"search across indices with mixed long and double and float numeric types": + - skip: + version: " - 2.10.99" + reason: half float was broken before 2.11 + + - do: + indices.create: + index: test_1 + body: + mappings: + properties: + counter: + type: long + + - do: + indices.create: + index: test_2 + body: + mappings: + properties: + counter: + type: double + + - do: + indices.create: + index: test_3 + body: + mappings: + properties: + counter: + type: half_float + + - do: + bulk: + refresh: true + body: + - index: + _index: test_1 + - counter: 223372036854775800 + - index: + _index: test_2 + - counter: 1223372036854775800.23 + - index: + _index: test_2 + - counter: 184.4 + - index: + _index: test_3 + - counter: 187.4 + - index: + _index: test_3 + - counter: 194.4 + + - do: + search: + index: test_* + rest_total_hits_as_int: true + body: + sort: [{ counter: desc }] + - match: { hits.total: 5 } + - length: { hits.hits: 5 } + - match: { hits.hits.0._index: test_2 } + - match: { hits.hits.0._source.counter: 1223372036854775800.23 } + - match: { hits.hits.0.sort.0: 1223372036854775800.23 } + - match: { hits.hits.1._index: test_1 } + - match: { hits.hits.1._source.counter: 223372036854775800 } + - match: { hits.hits.1.sort.0: 223372036854775800 } + - match: { hits.hits.2._index: test_3 } + - match: { hits.hits.2._source.counter: 194.4 } + + - do: + search: + index: test_* + rest_total_hits_as_int: true + body: + sort: [{ counter: asc }] + - match: { hits.total: 5 } + 
- length: { hits.hits: 5 } + - match: { hits.hits.0._index: test_2 } + - match: { hits.hits.0._source.counter: 184.4 } + - match: { hits.hits.0.sort.0: 184.4 } + - match: { hits.hits.1._index: test_3 } + - match: { hits.hits.1._source.counter: 187.4 } diff --git a/rest-api-spec/src/main/resources/rest-api-spec/test/search/90_search_after.yml b/rest-api-spec/src/main/resources/rest-api-spec/test/search/90_search_after.yml index 55e1566656faf..1563daba9de6d 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/test/search/90_search_after.yml +++ b/rest-api-spec/src/main/resources/rest-api-spec/test/search/90_search_after.yml @@ -320,3 +320,130 @@ - length: { hits.hits: 1 } - match: { hits.hits.0._index: test } - match: { hits.hits.0._source.population: null } + +--- +"half float": + - skip: + version: " - 2.10.99" + reason: half_float was broken for 2.10 and earlier + + - do: + indices.create: + index: test + body: + mappings: + properties: + population: + type: half_float + - do: + bulk: + refresh: true + index: test + body: | + {"index":{}} + {"population": 184.4} + {"index":{}} + {"population": 194.4} + {"index":{}} + {"population": 144.4} + {"index":{}} + {"population": 174.4} + {"index":{}} + {"population": 164.4} + + - do: + search: + index: test + rest_total_hits_as_int: true + body: + size: 3 + sort: [ { population: desc } ] + - match: { hits.total: 5 } + - length: { hits.hits: 3 } + - match: { hits.hits.0._index: test } + - match: { hits.hits.0._source.population: 194.4 } + - match: { hits.hits.1._index: test } + - match: { hits.hits.1._source.population: 184.4 } + - match: { hits.hits.2._index: test } + - match: { hits.hits.2._source.population: 174.4 } + + - do: + search: + index: test + rest_total_hits_as_int: true + body: + size: 3 + sort: [ { population: asc } ] + - match: { hits.total: 5 } + - length: { hits.hits: 3 } + - match: { hits.hits.0._index: test } + - match: { hits.hits.0._source.population: 144.4 } + - match: { hits.hits.1._index: test 
} + - match: { hits.hits.1._source.population: 164.4 } + - match: { hits.hits.2._index: test } + - match: { hits.hits.2._source.population: 174.4 } + + # search_after with the asc sort + - do: + search: + index: test + rest_total_hits_as_int: true + body: + size: 1 + sort: [ { population: asc } ] + search_after: [ 184.375 ] # this is rounded sort value in sort result + - match: { hits.total: 5 } + - length: { hits.hits: 1 } + - match: { hits.hits.0._index: test } + - match: { hits.hits.0._source.population: 194.4 } + + # search_after with the desc sort + - do: + search: + index: test + rest_total_hits_as_int: true + body: + size: 1 + sort: [ { population: desc } ] + search_after: [ 164.375 ] # this is rounded sort value in sort result + - match: { hits.total: 5 } + - length: { hits.hits: 1 } + - match: { hits.hits.0._index: test } + - match: { hits.hits.0._source.population: 144.4 } + + # search_after with the asc sort with missing + - do: + bulk: + refresh: true + index: test + body: | + {"index":{}} + {"population": null} + - do: + search: + index: test + rest_total_hits_as_int: true + body: + "size": 5 + "sort": [ { "population": { "order": "asc", "missing": "_last" } } ] + search_after: [ 200 ] # making it out of min/max so only missing value hit is qualified + + - match: { hits.total: 6 } + - length: { hits.hits: 1 } + - match: { hits.hits.0._index: test } + - match: { hits.hits.0._source.population: null } + + # search_after with the desc sort with missing + - do: + search: + index: test + rest_total_hits_as_int: true + body: + "size": 5 + "sort": [ { "population": { "order": "desc", "missing": "_last" } } ] + search_after: [ 100 ] # making it out of min/max so only missing value hit is qualified + + - match: { hits.total: 6 } + - length: { hits.hits: 1 } + - match: { hits.hits.0._index: test } + - match: { hits.hits.0._source.population: null } diff --git a/server/src/internalClusterTest/java/org/opensearch/search/sort/FieldSortIT.java 
b/server/src/internalClusterTest/java/org/opensearch/search/sort/FieldSortIT.java index bee242b933dfd..d4980a64a3977 100644 --- a/server/src/internalClusterTest/java/org/opensearch/search/sort/FieldSortIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/search/sort/FieldSortIT.java @@ -605,6 +605,9 @@ public void testSimpleSorts() throws Exception { .startObject("float_value") .field("type", "float") .endObject() + .startObject("half_float_value") + .field("type", "half_float") + .endObject() .startObject("double_value") .field("type", "double") .endObject() @@ -628,6 +631,7 @@ public void testSimpleSorts() throws Exception { .field("long_value", i) .field("unsigned_long_value", UNSIGNED_LONG_BASE.add(BigInteger.valueOf(10000 * i))) .field("float_value", 0.1 * i) + .field("half_float_value", 0.1 * i) .field("double_value", 0.1 * i) .endObject() ); @@ -794,6 +798,28 @@ public void testSimpleSorts() throws Exception { assertThat(searchResponse.toString(), not(containsString("error"))); + // HALF_FLOAT + size = 1 + random.nextInt(10); + searchResponse = client().prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort("half_float_value", SortOrder.ASC).get(); + + assertHitCount(searchResponse, 10L); + assertThat(searchResponse.getHits().getHits().length, equalTo(size)); + for (int i = 0; i < size; i++) { + assertThat(searchResponse.getHits().getAt(i).getId(), equalTo(Integer.toString(i))); + } + + assertThat(searchResponse.toString(), not(containsString("error"))); + size = 1 + random.nextInt(10); + searchResponse = client().prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort("half_float_value", SortOrder.DESC).get(); + + assertHitCount(searchResponse, 10); + assertThat(searchResponse.getHits().getHits().length, equalTo(size)); + for (int i = 0; i < size; i++) { + assertThat(searchResponse.getHits().getAt(i).getId(), equalTo(Integer.toString(9 - i))); + } + + assertThat(searchResponse.toString(), not(containsString("error"))); + // 
DOUBLE size = 1 + random.nextInt(10); searchResponse = client().prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort("double_value", SortOrder.ASC).get(); @@ -1330,6 +1356,9 @@ public void testSortMVField() throws Exception { .startObject("float_values") .field("type", "float") .endObject() + .startObject("half_float_values") + .field("type", "float") + .endObject() .startObject("double_values") .field("type", "double") .endObject() @@ -1351,6 +1380,7 @@ public void testSortMVField() throws Exception { .array("short_values", 1, 5, 10, 8) .array("byte_values", 1, 5, 10, 8) .array("float_values", 1f, 5f, 10f, 8f) + .array("half_float_values", 1f, 5f, 10f, 8f) .array("double_values", 1d, 5d, 10d, 8d) .array("string_values", "01", "05", "10", "08") .endObject() @@ -1365,6 +1395,7 @@ public void testSortMVField() throws Exception { .array("short_values", 11, 15, 20, 7) .array("byte_values", 11, 15, 20, 7) .array("float_values", 11f, 15f, 20f, 7f) + .array("half_float_values", 11f, 15f, 20f, 7f) .array("double_values", 11d, 15d, 20d, 7d) .array("string_values", "11", "15", "20", "07") .endObject() @@ -1379,6 +1410,7 @@ public void testSortMVField() throws Exception { .array("short_values", 2, 1, 3, -4) .array("byte_values", 2, 1, 3, -4) .array("float_values", 2f, 1f, 3f, -4f) + .array("half_float_values", 2f, 1f, 3f, -4f) .array("double_values", 2d, 1d, 3d, -4d) .array("string_values", "02", "01", "03", "!4") .endObject() @@ -1585,6 +1617,34 @@ public void testSortMVField() throws Exception { assertThat(searchResponse.getHits().getAt(2).getId(), equalTo(Integer.toString(3))); assertThat(((Number) searchResponse.getHits().getAt(2).getSortValues()[0]).floatValue(), equalTo(3f)); + searchResponse = client().prepareSearch().setQuery(matchAllQuery()).setSize(10).addSort("half_float_values", SortOrder.ASC).get(); + + assertThat(searchResponse.getHits().getTotalHits().value, equalTo(3L)); + assertThat(searchResponse.getHits().getHits().length, equalTo(3)); + + 
assertThat(searchResponse.getHits().getAt(0).getId(), equalTo(Integer.toString(3))); + assertThat(((Number) searchResponse.getHits().getAt(0).getSortValues()[0]).floatValue(), equalTo(-4f)); + + assertThat(searchResponse.getHits().getAt(1).getId(), equalTo(Integer.toString(1))); + assertThat(((Number) searchResponse.getHits().getAt(1).getSortValues()[0]).floatValue(), equalTo(1f)); + + assertThat(searchResponse.getHits().getAt(2).getId(), equalTo(Integer.toString(2))); + assertThat(((Number) searchResponse.getHits().getAt(2).getSortValues()[0]).floatValue(), equalTo(7f)); + + searchResponse = client().prepareSearch().setQuery(matchAllQuery()).setSize(10).addSort("half_float_values", SortOrder.DESC).get(); + + assertThat(searchResponse.getHits().getTotalHits().value, equalTo(3L)); + assertThat(searchResponse.getHits().getHits().length, equalTo(3)); + + assertThat(searchResponse.getHits().getAt(0).getId(), equalTo(Integer.toString(2))); + assertThat(((Number) searchResponse.getHits().getAt(0).getSortValues()[0]).floatValue(), equalTo(20f)); + + assertThat(searchResponse.getHits().getAt(1).getId(), equalTo(Integer.toString(1))); + assertThat(((Number) searchResponse.getHits().getAt(1).getSortValues()[0]).floatValue(), equalTo(10f)); + + assertThat(searchResponse.getHits().getAt(2).getId(), equalTo(Integer.toString(3))); + assertThat(((Number) searchResponse.getHits().getAt(2).getSortValues()[0]).floatValue(), equalTo(3f)); + searchResponse = client().prepareSearch().setQuery(matchAllQuery()).setSize(10).addSort("double_values", SortOrder.ASC).get(); assertThat(searchResponse.getHits().getTotalHits().value, equalTo(3L)); diff --git a/server/src/main/java/org/opensearch/index/fielddata/IndexNumericFieldData.java b/server/src/main/java/org/opensearch/index/fielddata/IndexNumericFieldData.java index 6fc074fe0de95..b0ff944d014de 100644 --- a/server/src/main/java/org/opensearch/index/fielddata/IndexNumericFieldData.java +++ 
b/server/src/main/java/org/opensearch/index/fielddata/IndexNumericFieldData.java @@ -42,6 +42,7 @@ import org.opensearch.index.fielddata.IndexFieldData.XFieldComparatorSource.Nested; import org.opensearch.index.fielddata.fieldcomparator.DoubleValuesComparatorSource; import org.opensearch.index.fielddata.fieldcomparator.FloatValuesComparatorSource; +import org.opensearch.index.fielddata.fieldcomparator.HalfFloatValuesComparatorSource; import org.opensearch.index.fielddata.fieldcomparator.IntValuesComparatorSource; import org.opensearch.index.fielddata.fieldcomparator.LongValuesComparatorSource; import org.opensearch.index.fielddata.fieldcomparator.UnsignedLongValuesComparatorSource; @@ -220,6 +221,8 @@ private XFieldComparatorSource comparatorSource( final XFieldComparatorSource source; switch (targetNumericType) { case HALF_FLOAT: + source = new HalfFloatValuesComparatorSource(this, missingValue, sortMode, nested); + break; case FLOAT: source = new FloatValuesComparatorSource(this, missingValue, sortMode, nested); break; @@ -242,7 +245,7 @@ private XFieldComparatorSource comparatorSource( assert !targetNumericType.isFloatingPoint(); source = new IntValuesComparatorSource(this, missingValue, sortMode, nested); } - if (targetNumericType != getNumericType() || getNumericType() == NumericType.HALF_FLOAT) { + if (targetNumericType != getNumericType()) { source.disableSkipping(); // disable skipping logic for cast of sort field } return source; diff --git a/server/src/main/java/org/opensearch/index/fielddata/fieldcomparator/HalfFloatValuesComparatorSource.java b/server/src/main/java/org/opensearch/index/fielddata/fieldcomparator/HalfFloatValuesComparatorSource.java new file mode 100644 index 0000000000000..7e3936be1d8a5 --- /dev/null +++ b/server/src/main/java/org/opensearch/index/fielddata/fieldcomparator/HalfFloatValuesComparatorSource.java @@ -0,0 +1,75 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * 
this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.index.fielddata.fieldcomparator; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.NumericDocValues; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.FieldComparator; +import org.apache.lucene.search.LeafFieldComparator; +import org.apache.lucene.util.BitSet; +import org.opensearch.index.fielddata.FieldData; +import org.opensearch.index.fielddata.IndexNumericFieldData; +import org.opensearch.index.fielddata.NumericDoubleValues; +import org.opensearch.index.fielddata.SortedNumericDoubleValues; +import org.opensearch.index.search.comparators.HalfFloatComparator; +import org.opensearch.search.MultiValueMode; + +import java.io.IOException; + +/** + * Comparator source for half_float values. + * + * @opensearch.internal + */ +public class HalfFloatValuesComparatorSource extends FloatValuesComparatorSource { + private final IndexNumericFieldData indexFieldData; + + public HalfFloatValuesComparatorSource( + IndexNumericFieldData indexFieldData, + Object missingValue, + MultiValueMode sortMode, + Nested nested + ) { + super(indexFieldData, missingValue, sortMode, nested); + this.indexFieldData = indexFieldData; + } + + @Override + public FieldComparator newComparator(String fieldname, int numHits, boolean enableSkipping, boolean reversed) { + assert indexFieldData == null || fieldname.equals(indexFieldData.getFieldName()); + + final float fMissingValue = (Float) missingObject(missingValue, reversed); + // NOTE: it's important to pass null as a missing value in the constructor so that + // the comparator doesn't check docsWithField since we replace missing values in select() + return new HalfFloatComparator(numHits, fieldname, null, reversed, enableSkipping && this.enableSkipping) { + @Override + public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws 
IOException { + return new HalfFloatLeafComparator(context) { + @Override + protected NumericDocValues getNumericDocValues(LeafReaderContext context, String field) throws IOException { + return HalfFloatValuesComparatorSource.this.getNumericDocValues(context, fMissingValue).getRawFloatValues(); + } + }; + } + }; + } + + private NumericDoubleValues getNumericDocValues(LeafReaderContext context, float missingValue) throws IOException { + final SortedNumericDoubleValues values = indexFieldData.load(context).getDoubleValues(); + if (nested == null) { + return FieldData.replaceMissing(sortMode.select(values), missingValue); + } else { + final BitSet rootDocs = nested.rootDocs(context); + final DocIdSetIterator innerDocs = nested.innerDocs(context); + final int maxChildren = nested.getNestedSort() != null ? nested.getNestedSort().getMaxChildren() : Integer.MAX_VALUE; + return sortMode.select(values, missingValue, rootDocs, innerDocs, context.reader().maxDoc(), maxChildren); + } + } +} diff --git a/server/src/main/java/org/opensearch/index/search/comparators/HalfFloatComparator.java b/server/src/main/java/org/opensearch/index/search/comparators/HalfFloatComparator.java new file mode 100644 index 0000000000000..6244fa647b042 --- /dev/null +++ b/server/src/main/java/org/opensearch/index/search/comparators/HalfFloatComparator.java @@ -0,0 +1,111 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.index.search.comparators; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.sandbox.document.HalfFloatPoint; +import org.apache.lucene.search.LeafFieldComparator; +import org.apache.lucene.search.comparators.NumericComparator; + +import java.io.IOException; + +/** + * The comparator for half_float numeric type. 
+ * Comparator based on {@link Float#compare} for {@code numHits}. This comparator provides a + * skipping functionality – an iterator that can skip over non-competitive documents. + */ +public class HalfFloatComparator extends NumericComparator { + private final float[] values; + protected float topValue; + protected float bottom; + + public HalfFloatComparator(int numHits, String field, Float missingValue, boolean reverse, boolean enableSkipping) { + super(field, missingValue != null ? missingValue : 0.0f, reverse, enableSkipping, HalfFloatPoint.BYTES); + values = new float[numHits]; + } + + @Override + public int compare(int slot1, int slot2) { + return Float.compare(values[slot1], values[slot2]); + } + + @Override + public void setTopValue(Float value) { + super.setTopValue(value); + topValue = value; + } + + @Override + public Float value(int slot) { + return Float.valueOf(values[slot]); + } + + @Override + public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws IOException { + return new HalfFloatLeafComparator(context); + } + + /** Leaf comparator for {@link HalfFloatComparator} that provides skipping functionality */ + public class HalfFloatLeafComparator extends NumericLeafComparator { + + public HalfFloatLeafComparator(LeafReaderContext context) throws IOException { + super(context); + } + + private float getValueForDoc(int doc) throws IOException { + if (docValues.advanceExact(doc)) { + return Float.intBitsToFloat((int) docValues.longValue()); + } else { + return missingValue; + } + } + + @Override + public void setBottom(int slot) throws IOException { + bottom = values[slot]; + super.setBottom(slot); + } + + @Override + public int compareBottom(int doc) throws IOException { + return Float.compare(bottom, getValueForDoc(doc)); + } + + @Override + public int compareTop(int doc) throws IOException { + return Float.compare(topValue, getValueForDoc(doc)); + } + + @Override + public void copy(int slot, int doc) throws IOException { + 
values[slot] = getValueForDoc(doc); + super.copy(slot, doc); + } + + @Override + protected int compareMissingValueWithBottomValue() { + return Float.compare(missingValue, bottom); + } + + @Override + protected int compareMissingValueWithTopValue() { + return Float.compare(missingValue, topValue); + } + + @Override + protected void encodeBottom(byte[] packedValue) { + HalfFloatPoint.encodeDimension(bottom, packedValue, 0); + } + + @Override + protected void encodeTop(byte[] packedValue) { + HalfFloatPoint.encodeDimension(topValue, packedValue, 0); + } + } +}