From 564549af8fe73e270f1341bbfc14b40cc9da4a1c Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Wed, 5 Jun 2024 11:43:07 +0100 Subject: [PATCH 01/30] Expose the bitset filter cache in the MappingParserContext (#109298) Add the bitset filter cache in the MappingParserContext --- .../index/mapper/MapperServiceFactory.java | 11 ++++++++++ .../search/QueryParserHelperBenchmark.java | 1 + .../metadata/IndexMetadataVerifier.java | 3 +++ .../org/elasticsearch/index/IndexModule.java | 3 +++ .../org/elasticsearch/index/IndexService.java | 3 ++- .../index/mapper/MapperService.java | 8 +++++++- .../index/mapper/MappingParserContext.java | 17 +++++++++++++--- .../elasticsearch/index/codec/CodecTests.java | 11 ++++++++++ .../index/mapper/MappingParserTests.java | 13 +++++++++++- .../index/mapper/ParametrizedMapperTests.java | 5 ++++- .../index/mapper/TypeParsersTests.java | 10 ++++++++-- .../query/SearchExecutionContextTests.java | 5 ++++- .../elasticsearch/index/MapperTestUtils.java | 11 ++++++++++ .../index/engine/TranslogHandler.java | 3 +++ .../index/mapper/MapperServiceTestCase.java | 9 +++++++++ .../mapper/TestDocumentParserContext.java | 5 ++++- .../aggregations/AggregatorTestCase.java | 5 ++++- .../test/AbstractBuilderTestCase.java | 20 ++++++++----------- 18 files changed, 119 insertions(+), 24 deletions(-) diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/index/mapper/MapperServiceFactory.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/index/mapper/MapperServiceFactory.java index 70e9fe424e77b..68b31481e17f3 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/index/mapper/MapperServiceFactory.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/index/mapper/MapperServiceFactory.java @@ -9,6 +9,7 @@ package org.elasticsearch.benchmark.index.mapper; import org.apache.lucene.analysis.standard.StandardAnalyzer; +import org.apache.lucene.util.Accountable; import org.elasticsearch.TransportVersion; import org.elasticsearch.cluster.ClusterModule; import org.elasticsearch.cluster.metadata.IndexMetadata; @@ -21,10 +22,12 @@ import org.elasticsearch.index.analysis.IndexAnalyzers; import org.elasticsearch.index.analysis.LowercaseNormalizer; import org.elasticsearch.index.analysis.NamedAnalyzer; +import org.elasticsearch.index.cache.bitset.BitsetFilterCache; import org.elasticsearch.index.mapper.MapperMetrics; import org.elasticsearch.index.mapper.MapperRegistry; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.ProvidedIdFieldMapper; +import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.similarity.SimilarityService; import org.elasticsearch.indices.IndicesModule; import org.elasticsearch.script.Script; @@ -52,6 +55,13 @@ public static MapperService create(String mappings) { MapperRegistry mapperRegistry = new IndicesModule(Collections.emptyList()).getMapperRegistry(); SimilarityService similarityService = new SimilarityService(indexSettings, null, Map.of()); + BitsetFilterCache bitsetFilterCache = new BitsetFilterCache(indexSettings, new BitsetFilterCache.Listener() { + @Override + public void onCache(ShardId shardId, Accountable accountable) {} + + @Override + public void onRemoval(ShardId shardId, Accountable accountable) {} + }); MapperService mapperService = new MapperService( () -> TransportVersion.current(), indexSettings, @@ -73,6 +83,7 @@ public T compile(Script script, ScriptContext scriptContext) { throw new UnsupportedOperationException(); } }, + bitsetFilterCache::getBitSetProducer, MapperMetrics.NOOP ); diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/search/QueryParserHelperBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/search/QueryParserHelperBenchmark.java index 14f6fe6501a73..cff15d9c36d34 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/search/QueryParserHelperBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/search/QueryParserHelperBenchmark.java @@ -189,6 +189,7 @@ public T compile(Script script, ScriptContext scriptContext) { throw new UnsupportedOperationException(); } }, + query -> { throw new UnsupportedOperationException(); }, MapperMetrics.NOOP ); diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadataVerifier.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadataVerifier.java index 0124f23a1156d..e774d7e4d552d 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadataVerifier.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadataVerifier.java @@ -187,6 +187,9 @@ protected TokenStreamComponents createComponents(String fieldName) { () -> null, indexSettings.getMode().idFieldMapperWithoutFieldData(), scriptService, + query -> { + throw new UnsupportedOperationException("IndexMetadataVerifier"); + }, mapperMetrics ) ) { diff --git a/server/src/main/java/org/elasticsearch/index/IndexModule.java b/server/src/main/java/org/elasticsearch/index/IndexModule.java index ff8db4bacef8c..fa2a9f0f35259 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexModule.java +++ b/server/src/main/java/org/elasticsearch/index/IndexModule.java @@ -652,6 +652,9 @@ public MapperService newIndexMapperService( }, indexSettings.getMode().idFieldMapperWithoutFieldData(), scriptService, + query -> { + throw new UnsupportedOperationException("no index query shard context available"); + }, mapperMetrics ); } diff --git a/server/src/main/java/org/elasticsearch/index/IndexService.java b/server/src/main/java/org/elasticsearch/index/IndexService.java index 1712f824a132c..0605e36b2ea4b 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexService.java +++ b/server/src/main/java/org/elasticsearch/index/IndexService.java @@ -212,6 +212,7 @@ public IndexService( this.indexAnalyzers = indexAnalyzers; if (needsMapperService(indexSettings, indexCreationContext)) { assert indexAnalyzers != null; + this.bitsetFilterCache = new BitsetFilterCache(indexSettings, new BitsetCacheListener(this)); this.mapperService = new MapperService( clusterService, indexSettings, @@ -223,6 +224,7 @@ public IndexService( () -> newSearchExecutionContext(0, 0, null, System::currentTimeMillis, null, emptyMap()), idFieldMapper, scriptService, + bitsetFilterCache::getBitSetProducer, mapperMetrics ); this.indexFieldData = new IndexFieldDataService(indexSettings, indicesFieldDataCache, circuitBreakerService); @@ -238,7 +240,6 @@ public IndexService( this.indexSortSupplier = () -> null; } indexFieldData.setListener(new FieldDataCacheListener(this)); - this.bitsetFilterCache = new BitsetFilterCache(indexSettings, new BitsetCacheListener(this)); this.warmer = new IndexWarmer(threadPool, indexFieldData, bitsetFilterCache.createListener(threadPool)); this.indexCache = new IndexCache(queryCache, bitsetFilterCache); } else { diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperService.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperService.java index d3665c3b978bd..e5dc95ddbc2a0 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MapperService.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperService.java @@ -8,6 +8,8 @@ package org.elasticsearch.index.mapper; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.join.BitSetProducer; import org.elasticsearch.TransportVersion; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.MappingMetadata; @@ -167,6 +169,7 @@ public MapperService( Supplier searchExecutionContextSupplier, IdFieldMapper idFieldMapper, ScriptCompiler scriptCompiler, + Function bitSetProducer, MapperMetrics mapperMetrics ) { this( @@ -179,6 +182,7 @@ public MapperService( searchExecutionContextSupplier, idFieldMapper, scriptCompiler, + bitSetProducer, mapperMetrics ); } @@ -194,6 +198,7 @@ public MapperService( Supplier searchExecutionContextSupplier, IdFieldMapper idFieldMapper, ScriptCompiler scriptCompiler, + Function bitSetProducer, MapperMetrics mapperMetrics ) { super(indexSettings); @@ -210,7 +215,8 @@ public MapperService( scriptCompiler, indexAnalyzers, indexSettings, - idFieldMapper + idFieldMapper, + bitSetProducer ); this.documentParser = new DocumentParser(parserConfiguration, this.mappingParserContextSupplier.get()); Map metadataMapperParsers = mapperRegistry.getMetadataMapperParsers( diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MappingParserContext.java b/server/src/main/java/org/elasticsearch/index/mapper/MappingParserContext.java index 88df87859ccc2..3f614d4346fd4 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MappingParserContext.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MappingParserContext.java @@ -8,6 +8,8 @@ package org.elasticsearch.index.mapper; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.join.BitSetProducer; import org.elasticsearch.TransportVersion; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.time.DateFormatter; @@ -37,6 +39,7 @@ public class MappingParserContext { private final IndexAnalyzers indexAnalyzers; private final IndexSettings indexSettings; private final IdFieldMapper idFieldMapper; + private final Function bitSetProducer; private final long mappingObjectDepthLimit; private long mappingObjectDepth = 0; @@ -50,7 +53,8 @@ public MappingParserContext( ScriptCompiler scriptCompiler, IndexAnalyzers indexAnalyzers, IndexSettings indexSettings, - IdFieldMapper idFieldMapper + IdFieldMapper idFieldMapper, + Function bitSetProducer ) { this.similarityLookupService = similarityLookupService; this.typeParsers = typeParsers; @@ -63,6 +67,7 @@ public MappingParserContext( this.indexSettings = indexSettings; this.idFieldMapper = idFieldMapper; this.mappingObjectDepthLimit = indexSettings.getMappingDepthLimit(); + this.bitSetProducer = bitSetProducer; } public IndexAnalyzers getIndexAnalyzers() { @@ -132,6 +137,10 @@ public ScriptCompiler scriptCompiler() { return scriptCompiler; } + public BitSetProducer bitSetProducer(Query query) { + return bitSetProducer.apply(query); + } + void incrementMappingObjectDepth() throws MapperParsingException { mappingObjectDepth++; if (mappingObjectDepth > mappingObjectDepthLimit) { @@ -159,7 +168,8 @@ private static class MultiFieldParserContext extends MappingParserContext { in.scriptCompiler, in.indexAnalyzers, in.indexSettings, - in.idFieldMapper + in.idFieldMapper, + in.bitSetProducer ); } @@ -188,7 +198,8 @@ private static class DynamicTemplateParserContext extends MappingParserContext { in.scriptCompiler, in.indexAnalyzers, in.indexSettings, - in.idFieldMapper + in.idFieldMapper, + in.bitSetProducer ); this.dateFormatter = dateFormatter; } diff --git a/server/src/test/java/org/elasticsearch/index/codec/CodecTests.java b/server/src/test/java/org/elasticsearch/index/codec/CodecTests.java index ffb3cc1943bff..3c687f1792d0d 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/CodecTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/CodecTests.java @@ -19,15 +19,18 @@ import org.apache.lucene.index.IndexWriter; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.util.LuceneTestCase.SuppressCodecs; +import org.apache.lucene.util.Accountable; import org.elasticsearch.TransportVersion; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.env.Environment; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.analysis.IndexAnalyzers; +import org.elasticsearch.index.cache.bitset.BitsetFilterCache; import org.elasticsearch.index.mapper.MapperMetrics; import org.elasticsearch.index.mapper.MapperRegistry; import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.similarity.SimilarityService; import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.script.ScriptCompiler; @@ -107,6 +110,13 @@ private CodecService createCodecService() throws IOException { Collections.emptyMap(), MapperPlugin.NOOP_FIELD_FILTER ); + BitsetFilterCache bitsetFilterCache = new BitsetFilterCache(settings, new BitsetFilterCache.Listener() { + @Override + public void onCache(ShardId shardId, Accountable accountable) {} + + @Override + public void onRemoval(ShardId shardId, Accountable accountable) {} + }); MapperService service = new MapperService( () -> TransportVersion.current(), settings, @@ -117,6 +127,7 @@ private CodecService createCodecService() throws IOException { () -> null, settings.getMode().idFieldMapperWithoutFieldData(), ScriptCompiler.NONE, + bitsetFilterCache::getBitSetProducer, MapperMetrics.NOOP ); return new CodecService(service, BigArrays.NON_RECYCLING_INSTANCE); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/MappingParserTests.java b/server/src/test/java/org/elasticsearch/index/mapper/MappingParserTests.java index abe8e820acae8..aa22a345c5cec 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/MappingParserTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/MappingParserTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.index.mapper; +import org.apache.lucene.util.Accountable; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.common.bytes.BytesReference; @@ -17,6 +18,8 @@ import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.analysis.IndexAnalyzers; +import org.elasticsearch.index.cache.bitset.BitsetFilterCache; +import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.similarity.SimilarityService; import org.elasticsearch.indices.IndicesModule; import org.elasticsearch.script.ScriptService; @@ -43,6 +46,13 @@ private static MappingParser createMappingParser(Settings settings, IndexVersion IndexAnalyzers indexAnalyzers = createIndexAnalyzers(); SimilarityService similarityService = new SimilarityService(indexSettings, scriptService, Collections.emptyMap()); MapperRegistry mapperRegistry = new IndicesModule(Collections.emptyList()).getMapperRegistry(); + BitsetFilterCache bitsetFilterCache = new BitsetFilterCache(indexSettings, new BitsetFilterCache.Listener() { + @Override + public void onCache(ShardId shardId, Accountable accountable) {} + + @Override + public void onRemoval(ShardId shardId, Accountable accountable) {} + }); Supplier mappingParserContextSupplier = () -> new MappingParserContext( similarityService::getSimilarity, type -> mapperRegistry.getMapperParser(type, indexSettings.getIndexVersionCreated()), @@ -55,7 +65,8 @@ private static MappingParser createMappingParser(Settings settings, IndexVersion scriptService, indexAnalyzers, indexSettings, - indexSettings.getMode().idFieldMapperWithoutFieldData() + indexSettings.getMode().idFieldMapperWithoutFieldData(), + bitsetFilterCache::getBitSetProducer ); Map metadataMapperParsers = mapperRegistry.getMetadataMapperParsers( diff --git a/server/src/test/java/org/elasticsearch/index/mapper/ParametrizedMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/ParametrizedMapperTests.java index b1b7f80ba865f..0ec1997ae652e 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/ParametrizedMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/ParametrizedMapperTests.java @@ -277,7 +277,10 @@ private static TestMapper fromMapping( ScriptCompiler.NONE, mapperService.getIndexAnalyzers(), mapperService.getIndexSettings(), - mapperService.getIndexSettings().getMode().idFieldMapperWithoutFieldData() + mapperService.getIndexSettings().getMode().idFieldMapperWithoutFieldData(), + query -> { + throw new UnsupportedOperationException(); + } ); if (fromDynamicTemplate) { pc = pc.createDynamicTemplateContext(null); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/TypeParsersTests.java b/server/src/test/java/org/elasticsearch/index/mapper/TypeParsersTests.java index 2b704a25e2232..035466d93ab06 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/TypeParsersTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/TypeParsersTests.java @@ -97,7 +97,10 @@ public void testMultiFieldWithinMultiField() throws IOException { ScriptCompiler.NONE, mapperService.getIndexAnalyzers(), mapperService.getIndexSettings(), - ProvidedIdFieldMapper.NO_FIELD_DATA + ProvidedIdFieldMapper.NO_FIELD_DATA, + query -> { + throw new UnsupportedOperationException(); + } ); TextFieldMapper.PARSER.parse("some-field", fieldNode, olderContext); @@ -128,7 +131,10 @@ public void testMultiFieldWithinMultiField() throws IOException { ScriptCompiler.NONE, mapperService.getIndexAnalyzers(), mapperService.getIndexSettings(), - ProvidedIdFieldMapper.NO_FIELD_DATA + ProvidedIdFieldMapper.NO_FIELD_DATA, + query -> { + throw new UnsupportedOperationException(); + } ); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> { diff --git a/server/src/test/java/org/elasticsearch/index/query/SearchExecutionContextTests.java b/server/src/test/java/org/elasticsearch/index/query/SearchExecutionContextTests.java index 6d8a22e7850e4..9cd1df700a618 100644 --- a/server/src/test/java/org/elasticsearch/index/query/SearchExecutionContextTests.java +++ b/server/src/test/java/org/elasticsearch/index/query/SearchExecutionContextTests.java @@ -548,7 +548,10 @@ private static MapperService createMapperService(IndexSettings indexSettings, Ma ScriptCompiler.NONE, indexAnalyzers, indexSettings, - indexSettings.getMode().buildIdFieldMapper(() -> true) + indexSettings.getMode().buildIdFieldMapper(() -> true), + query -> { + throw new UnsupportedOperationException(); + } ) ); when(mapperService.isMultiField(anyString())).then( diff --git a/test/framework/src/main/java/org/elasticsearch/index/MapperTestUtils.java b/test/framework/src/main/java/org/elasticsearch/index/MapperTestUtils.java index 5025299b09b64..913caba615a67 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/MapperTestUtils.java +++ b/test/framework/src/main/java/org/elasticsearch/index/MapperTestUtils.java @@ -8,15 +8,18 @@ package org.elasticsearch.index; +import org.apache.lucene.util.Accountable; import org.elasticsearch.TransportVersion; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.env.Environment; import org.elasticsearch.index.analysis.IndexAnalyzers; +import org.elasticsearch.index.cache.bitset.BitsetFilterCache; import org.elasticsearch.index.mapper.MapperMetrics; import org.elasticsearch.index.mapper.MapperRegistry; import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.similarity.SimilarityService; import org.elasticsearch.indices.IndicesModule; import org.elasticsearch.script.ScriptCompiler; @@ -58,6 +61,13 @@ public static MapperService newMapperService( IndexSettings indexSettings = IndexSettingsModule.newIndexSettings(indexName, finalSettings); IndexAnalyzers indexAnalyzers = createTestAnalysis(indexSettings, finalSettings).indexAnalyzers; SimilarityService similarityService = new SimilarityService(indexSettings, null, Collections.emptyMap()); + BitsetFilterCache bitsetFilterCache = new BitsetFilterCache(indexSettings, new BitsetFilterCache.Listener() { + @Override + public void onCache(ShardId shardId, Accountable accountable) {} + + @Override + public void onRemoval(ShardId shardId, Accountable accountable) {} + }); return new MapperService( () -> TransportVersion.current(), indexSettings, @@ -68,6 +78,7 @@ public static MapperService newMapperService( () -> null, indexSettings.getMode().idFieldMapperWithoutFieldData(), ScriptCompiler.NONE, + bitsetFilterCache::getBitSetProducer, MapperMetrics.NOOP ); } diff --git a/test/framework/src/main/java/org/elasticsearch/index/engine/TranslogHandler.java b/test/framework/src/main/java/org/elasticsearch/index/engine/TranslogHandler.java index c2da7a561c041..dc626a3228685 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/engine/TranslogHandler.java +++ b/test/framework/src/main/java/org/elasticsearch/index/engine/TranslogHandler.java @@ -55,6 +55,9 @@ public TranslogHandler(NamedXContentRegistry xContentRegistry, IndexSettings ind () -> null, indexSettings.getMode().idFieldMapperWithoutFieldData(), null, + query -> { + throw new UnsupportedOperationException("The bitset filter cache is not available in translog operations"); + }, MapperMetrics.NOOP ); } diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java index 7d13e33be0db7..388d8d6fa6ffd 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java @@ -254,6 +254,14 @@ public MapperService build() { getPlugins().stream().filter(p -> p instanceof MapperPlugin).map(p -> (MapperPlugin) p).collect(toList()) ).getMapperRegistry(); + BitsetFilterCache bitsetFilterCache = new BitsetFilterCache(indexSettings, new BitsetFilterCache.Listener() { + @Override + public void onCache(ShardId shardId, Accountable accountable) {} + + @Override + public void onRemoval(ShardId shardId, Accountable accountable) {} + }); + return new MapperService( () -> TransportVersion.current(), indexSettings, @@ -266,6 +274,7 @@ public MapperService build() { }, indexSettings.getMode().buildIdFieldMapper(idFieldDataEnabled), scriptCompiler, + bitsetFilterCache::getBitSetProducer, mapperMetrics ); } diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/TestDocumentParserContext.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/TestDocumentParserContext.java index d4c238322e28a..5243ef85cdb76 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/TestDocumentParserContext.java +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/TestDocumentParserContext.java @@ -63,7 +63,10 @@ private TestDocumentParserContext(MappingLookup mappingLookup, SourceToParse sou null, (type, name) -> Lucene.STANDARD_ANALYZER, MapperTestCase.createIndexSettings(IndexVersion.current(), settings), - null + null, + query -> { + throw new UnsupportedOperationException(); + } ), source, mappingLookup.getMapping().getRoot(), diff --git a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java index 134352a4f8af4..d39a8df80c26d 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java @@ -1284,7 +1284,10 @@ private static class MockParserContext extends MappingParserContext { ScriptCompiler.NONE, null, indexSettings, - null + null, + query -> { + throw new UnsupportedOperationException(); + } ); } diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java index 6ef8d3d8a6a1b..271df2a971fb1 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java @@ -467,6 +467,13 @@ private static class ServiceHolder implements Closeable { IndexAnalyzers indexAnalyzers = analysisModule.getAnalysisRegistry().build(IndexCreationContext.CREATE_INDEX, idxSettings); scriptService = new MockScriptService(Settings.EMPTY, scriptModule.engines, scriptModule.contexts); similarityService = new SimilarityService(idxSettings, null, Collections.emptyMap()); + this.bitsetFilterCache = new BitsetFilterCache(idxSettings, new BitsetFilterCache.Listener() { + @Override + public void onCache(ShardId shardId, Accountable accountable) {} + + @Override + public void onRemoval(ShardId shardId, Accountable accountable) {} + }); MapperRegistry mapperRegistry = indicesModule.getMapperRegistry(); mapperService = new MapperService( clusterService, @@ -478,23 +485,12 @@ private static class ServiceHolder implements Closeable { () -> createShardContext(null), idxSettings.getMode().idFieldMapperWithoutFieldData(), ScriptCompiler.NONE, + bitsetFilterCache::getBitSetProducer, MapperMetrics.NOOP ); IndicesFieldDataCache indicesFieldDataCache = new IndicesFieldDataCache(nodeSettings, new IndexFieldDataCache.Listener() { }); indexFieldDataService = new IndexFieldDataService(idxSettings, indicesFieldDataCache, new NoneCircuitBreakerService()); - bitsetFilterCache = new BitsetFilterCache(idxSettings, new BitsetFilterCache.Listener() { - @Override - public void onCache(ShardId shardId, Accountable accountable) { - - } - - @Override - public void onRemoval(ShardId shardId, Accountable accountable) { - - } - }); - if (registerType) { mapperService.merge( "_doc", From b8b6513fafb38fb34d5a6f5d1a955e6c811949bb Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Wed, 5 Jun 2024 12:48:30 +0200 Subject: [PATCH 02/30] Stop allocating BytesRefBuilder for each StreamOutput (#109251) We can use a thread-local here, no need to allocate a fresh instance for each output when they mostly go unused anyway. --- .../java/org/elasticsearch/common/io/stream/StreamOutput.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java b/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java index 33fb000c1bca2..833e7f27852c8 100644 --- a/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java +++ b/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java @@ -370,11 +370,12 @@ public void writeOptionalText(@Nullable Text text) throws IOException { } } - private final BytesRefBuilder spare = new BytesRefBuilder(); + private static final ThreadLocal spareBytesRefBuilder = ThreadLocal.withInitial(BytesRefBuilder::new); public void writeText(Text text) throws IOException { if (text.hasBytes() == false) { final String string = text.string(); + var spare = spareBytesRefBuilder.get(); spare.copyChars(string); writeInt(spare.length()); write(spare.bytes(), 0, spare.length()); From db130982d94ca9f5952c6c2a5436be048ccfbf38 Mon Sep 17 00:00:00 2001 From: Nik Everett Date: Wed, 5 Jun 2024 07:02:37 -0400 Subject: [PATCH 03/30] ESQL: Link to the function guide (#109351) This adds some links to the function guide in the javadoc for the superclass of scalar functions. --- .../expression/function/scalar/EsqlScalarFunction.java | 10 ++++++++++ .../scalar/convert/AbstractConvertFunction.java | 4 ++++ .../scalar/multivalue/AbstractMultivalueFunction.java | 4 ++++ 3 files changed, 18 insertions(+) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/EsqlScalarFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/EsqlScalarFunction.java index 797a4c31f0f6c..4f991af54ecff 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/EsqlScalarFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/EsqlScalarFunction.java @@ -14,6 +14,16 @@ import java.util.List; +/** + * A {@code ScalarFunction} is a {@code Function} that takes values from some + * operation and converts each to another value. An example would be + * {@code ABS()}, which takes one value at a time, applies a function to the + * value (abs) and returns a new value. + *

+ * We have a guide for writing these in the javadoc for + * {@link org.elasticsearch.xpack.esql.expression.function.scalar}. + *

+ */ public abstract class EsqlScalarFunction extends ScalarFunction implements EvaluatorMapper { protected EsqlScalarFunction(Source source) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/AbstractConvertFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/AbstractConvertFunction.java index 955ce1646813f..f1d0256a1f1c7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/AbstractConvertFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/AbstractConvertFunction.java @@ -38,6 +38,10 @@ /** * Base class for functions that converts a field into a function-specific type. + *

+ * We have a guide for writing these in the javadoc for + * {@link org.elasticsearch.xpack.esql.expression.function.scalar}. + *

*/ public abstract class AbstractConvertFunction extends UnaryScalarFunction { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/AbstractMultivalueFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/AbstractMultivalueFunction.java index 2ceedd14d6fd8..5aa6dad7b2a5b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/AbstractMultivalueFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/AbstractMultivalueFunction.java @@ -19,6 +19,10 @@ /** * Base class for functions that reduce multivalued fields into single valued fields. + *

+ * We have a guide for writing these in the javadoc for + * {@link org.elasticsearch.xpack.esql.expression.function.scalar}. + *

*/ public abstract class AbstractMultivalueFunction extends UnaryScalarFunction { protected AbstractMultivalueFunction(Source source, Expression field) { From a36fef011c6167b83d09293d3e9672598f567a01 Mon Sep 17 00:00:00 2001 From: Nik Everett Date: Wed, 5 Jun 2024 07:03:22 -0400 Subject: [PATCH 04/30] ESQL: Entirely drop `version` field (#109376) This field is no longer supported. --- .../xpack/esql/action/RequestXContent.java | 2 -- .../esql/action/EsqlQueryRequestTests.java | 33 +------------------ .../rest-api-spec/test/esql/10_basic.yml | 12 +++++++ 3 files changed, 13 insertions(+), 34 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/RequestXContent.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/RequestXContent.java index 9ffd48d9d0c3b..793f453d5ebf5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/RequestXContent.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/RequestXContent.java @@ -46,7 +46,6 @@ final class RequestXContent { PARAM_PARSER.declareString(constructorArg(), TYPE); } - static final ParseField ESQL_VERSION_FIELD = new ParseField("version"); static final ParseField QUERY_FIELD = new ParseField("query"); private static final ParseField COLUMNAR_FIELD = new ParseField("columnar"); private static final ParseField FILTER_FIELD = new ParseField("filter"); @@ -74,7 +73,6 @@ static EsqlQueryRequest parseAsync(XContentParser parser) { } private static void objectParserCommon(ObjectParser parser) { - parser.declareString((str, consumer) -> {}, ESQL_VERSION_FIELD); parser.declareString(EsqlQueryRequest::query, QUERY_FIELD); parser.declareBoolean(EsqlQueryRequest::columnar, COLUMNAR_FIELD); parser.declareObject(EsqlQueryRequest::filter, (p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p), FILTER_FIELD); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/action/EsqlQueryRequestTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/action/EsqlQueryRequestTests.java index 5fafafe34bd23..6328853eea3c6 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/action/EsqlQueryRequestTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/action/EsqlQueryRequestTests.java @@ -160,41 +160,10 @@ public void testRejectUnknownFields() { }""", "unknown field [asdf]"); } - public void testAnyVersionIsValid() throws IOException { - String validVersionString = randomAlphaOfLength(5); - - String json = String.format(Locale.ROOT, """ - { - "version": "%s", - "query": "ROW x = 1" - } - """, validVersionString); - - EsqlQueryRequest request = parseEsqlQueryRequest(json, randomBoolean()); - assertNull(request.validate()); - - request = parseEsqlQueryRequestAsync(json); - assertNull(request.validate()); - } - - public void testMissingVersionIsValid() throws IOException { - String missingVersion = randomBoolean() ? "" : ", \"version\": \"\""; - String json = String.format(Locale.ROOT, """ - { - "columnar": true, - "query": "row x = 1" - %s - }""", missingVersion); - - EsqlQueryRequest request = parseEsqlQueryRequest(json, randomBoolean()); - assertNull(request.validate()); - } - public void testMissingQueryIsNotValid() throws IOException { String json = """ { - "columnar": true, - "version": "snapshot" + "columnar": true }"""; EsqlQueryRequest request = parseEsqlQueryRequest(json, randomBoolean()); assertNotNull(request.validate()); diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/10_basic.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/10_basic.yml index 52d390e7b288b..ab0261d916630 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/10_basic.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/10_basic.yml @@ -333,3 +333,15 @@ setup: - match: {values.0: ["1",2.0,null,true,123,123]} - match: {values.1: ["1",2.0,null,true,123,123]} - match: {values.2: ["1",2.0,null,true,123,123]} + +--- +version is not allowed: + - requires: + cluster_features: ["gte_v8.14.0"] + reason: version allowed in 8.13.latest + - do: + catch: /unknown field \[version\]/ + esql.query: + body: + query: 'from test' + version: cat From 3ff27a999271f758d81d1db4aa21ec077a4634d1 Mon Sep 17 00:00:00 2001 From: Luigi Dell'Aquila Date: Wed, 5 Jun 2024 13:15:25 +0200 Subject: [PATCH 05/30] ES|QL: Fix MvAppend tests (#109383) --- muted-tests.yml | 3 --- .../function/scalar/multivalue/MvAppendTests.java | 8 ++++---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/muted-tests.yml b/muted-tests.yml index be4845ccec9e9..249541ff97926 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -53,9 +53,6 @@ tests: - class: "org.elasticsearch.client.RestClientSingleHostIntegTests" issue: "https://github.com/elastic/elasticsearch/issues/102717" method: "testRequestResetAndAbort" -- class: org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAppendTests - issue: https://github.com/elastic/elasticsearch/issues/109373 - method: testEvaluateBlockWithoutNulls - class: "org.elasticsearch.xpack.deprecation.DeprecationHttpIT" issue: "https://github.com/elastic/elasticsearch/issues/108628" method: "testDeprecatedSettingsReturnWarnings" diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvAppendTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvAppendTests.java index 07dab82c50607..6361360652a87 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvAppendTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvAppendTests.java @@ -235,8 +235,8 @@ private static void bytesRefs(List suppliers) { })); suppliers.add(new TestCaseSupplier(List.of(DataType.GEO_SHAPE, DataType.GEO_SHAPE), () -> { - List field1 = randomList(1, 10, () -> new BytesRef(GEO.asWkt(GeometryTestUtils.randomGeometry(randomBoolean())))); - List field2 = randomList(1, 10, () -> new BytesRef(GEO.asWkt(GeometryTestUtils.randomGeometry(randomBoolean())))); + List field1 = randomList(1, 5, () -> new BytesRef(GEO.asWkt(GeometryTestUtils.randomGeometry(randomBoolean())))); + List field2 = randomList(1, 5, () -> new BytesRef(GEO.asWkt(GeometryTestUtils.randomGeometry(randomBoolean())))); var result = new ArrayList<>(field1); result.addAll(field2); return new TestCaseSupplier.TestCase( @@ -251,8 +251,8 @@ private static void bytesRefs(List suppliers) { })); suppliers.add(new TestCaseSupplier(List.of(DataType.CARTESIAN_SHAPE, DataType.CARTESIAN_SHAPE), () -> { - List field1 = randomList(1, 10, () -> new BytesRef(CARTESIAN.asWkt(ShapeTestUtils.randomGeometry(randomBoolean())))); - List field2 = randomList(1, 10, () -> new BytesRef(CARTESIAN.asWkt(ShapeTestUtils.randomGeometry(randomBoolean())))); + List field1 = randomList(1, 5, () -> new BytesRef(CARTESIAN.asWkt(ShapeTestUtils.randomGeometry(randomBoolean())))); + List field2 = randomList(1, 5, () -> new BytesRef(CARTESIAN.asWkt(ShapeTestUtils.randomGeometry(randomBoolean())))); var result = new ArrayList<>(field1); result.addAll(field2); return new TestCaseSupplier.TestCase( From cd84749d8726b3b2d80d4731be9ba33d002161b2 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Wed, 5 Jun 2024 07:38:58 -0400 Subject: [PATCH 06/30] AwaitsFix: https://github.com/elastic/elasticsearch/issues/109391 --- muted-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/muted-tests.yml b/muted-tests.yml index 249541ff97926..3d64f87144bd3 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -56,6 +56,9 @@ tests: - class: "org.elasticsearch.xpack.deprecation.DeprecationHttpIT" issue: "https://github.com/elastic/elasticsearch/issues/108628" method: "testDeprecatedSettingsReturnWarnings" +- class: "org.elasticsearch.xpack.inference.InferenceCrudIT" + issue: "https://github.com/elastic/elasticsearch/issues/109391" + method: "testDeleteEndpointWhileReferencedByPipeline" # Examples: # From fdb5058b13abbbc5ce3a5d7c0ab1a7878c72176c Mon Sep 17 00:00:00 2001 From: Jonathan Buttner <56361221+jonathan-buttner@users.noreply.github.com> Date: Wed, 5 Jun 2024 08:25:25 -0400 Subject: [PATCH 07/30] [ML] Inference API rate limit queuing logic refactor (#107706) * Adding new executor * Adding in queuing logic * working tests * Added cleanup task * Update docs/changelog/107706.yaml * Updating yml * deregistering callbacks for settings changes * Cleaning up code * Update docs/changelog/107706.yaml * Fixing rate limit settings bug and only sleeping least amount * Removing debug logging * Removing commented code * Renaming feedback * fixing tests * Updating docs and validation * Fixing source blocks * Adjusting cancel logic * Reformatting ascii * Addressing feedback * adding rate limiting for google embeddings and mistral --------- Co-authored-by: Elastic Machine --- docs/changelog/107706.yaml | 5 + .../inference/put-inference.asciidoc | 289 ++++++---- .../org/elasticsearch/core/TimeValue.java | 7 + .../common/unit/TimeValueTests.java | 7 + .../action/cohere/CohereActionCreator.java | 1 + .../action/cohere/CohereEmbeddingsAction.java | 1 + ...eAiStudioChatCompletionRequestManager.java | 6 +- ...AzureAiStudioEmbeddingsRequestManager.java | 6 +- .../AzureOpenAiCompletionRequestManager.java | 6 +- .../AzureOpenAiEmbeddingsRequestManager.java | 6 +- .../http/sender/BaseRequestManager.java | 7 +- .../CohereCompletionRequestManager.java | 6 +- .../CohereEmbeddingsRequestManager.java | 6 +- .../sender/CohereRerankRequestManager.java | 6 +- .../sender/ExecutableInferenceRequest.java | 3 +- ...oogleAiStudioCompletionRequestManager.java | 6 +- ...oogleAiStudioEmbeddingsRequestManager.java | 6 +- .../http/sender/HttpRequestSender.java | 20 +- .../sender/HuggingFaceRequestManager.java | 14 +- .../http/sender/InferenceRequest.java | 4 +- .../MistralEmbeddingsRequestManager.java | 6 +- .../external/http/sender/NoopTask.java | 52 -- .../OpenAiCompletionRequestManager.java | 6 +- .../OpenAiEmbeddingsRequestManager.java | 6 +- .../http/sender/RequestExecutorService.java | 500 ++++++++++++------ .../RequestExecutorServiceSettings.java | 86 ++- .../external/http/sender/RequestManager.java | 8 +- .../external/http/sender/RequestTask.java | 2 +- .../http/sender/SingleRequestManager.java | 48 -- .../inference/services/SenderService.java | 2 +- .../azureaistudio/AzureAiStudioService.java | 2 +- .../AzureAiStudioServiceSettings.java | 10 +- .../azureopenai/AzureOpenAiService.java | 10 +- .../AzureOpenAiCompletionModel.java | 6 +- .../AzureOpenAiCompletionServiceSettings.java | 19 +- .../AzureOpenAiEmbeddingsServiceSettings.java | 12 +- .../services/cohere/CohereService.java | 15 +- .../cohere/CohereServiceSettings.java | 14 +- .../completion/CohereCompletionModel.java | 6 +- .../CohereCompletionServiceSettings.java | 14 +- .../googleaistudio/GoogleAiStudioService.java | 6 +- .../GoogleAiStudioCompletionModel.java | 6 +- ...ogleAiStudioCompletionServiceSettings.java | 14 +- .../GoogleAiStudioEmbeddingsModel.java | 6 +- ...ogleAiStudioEmbeddingsServiceSettings.java | 14 +- .../huggingface/HuggingFaceBaseService.java | 19 +- .../huggingface/HuggingFaceService.java | 15 +- .../HuggingFaceServiceSettings.java | 13 +- .../elser/HuggingFaceElserModel.java | 6 +- .../elser/HuggingFaceElserService.java | 6 +- .../HuggingFaceElserServiceSettings.java | 14 +- .../HuggingFaceEmbeddingsModel.java | 6 +- .../MistralEmbeddingsServiceSettings.java | 11 +- .../services/openai/OpenAiService.java | 3 +- .../completion/OpenAiChatCompletionModel.java | 6 +- .../OpenAiChatCompletionServiceSettings.java | 14 +- .../OpenAiEmbeddingsServiceSettings.java | 21 +- .../services/settings/RateLimitSettings.java | 15 +- .../AzureAiStudioActionAndCreatorTests.java | 5 +- .../AzureOpenAiActionCreatorTests.java | 19 +- .../AzureOpenAiCompletionActionTests.java | 3 +- .../AzureOpenAiEmbeddingsActionTests.java | 3 +- .../cohere/CohereActionCreatorTests.java | 7 +- .../cohere/CohereCompletionActionTests.java | 6 +- .../cohere/CohereEmbeddingsActionTests.java | 4 +- .../GoogleAiStudioCompletionActionTests.java | 4 +- .../GoogleAiStudioEmbeddingsActionTests.java | 2 +- .../HuggingFaceActionCreatorTests.java | 13 +- .../openai/OpenAiActionCreatorTests.java | 23 +- .../OpenAiChatCompletionActionTests.java | 5 +- .../openai/OpenAiEmbeddingsActionTests.java | 2 +- .../http/sender/BaseRequestManagerTests.java | 122 +++++ .../http/sender/HttpRequestSenderTests.java | 29 +- .../RequestExecutorServiceSettingsTests.java | 12 + .../sender/RequestExecutorServiceTests.java | 345 ++++++++---- ...torTests.java => RequestManagerTests.java} | 38 +- .../sender/SingleRequestManagerTests.java | 27 - .../services/SenderServiceTests.java | 9 +- .../AzureAiStudioServiceTests.java | 5 +- ...dioChatCompletionServiceSettingsTests.java | 3 +- ...iStudioEmbeddingsServiceSettingsTests.java | 2 +- .../azureopenai/AzureOpenAiServiceTests.java | 5 +- ...eOpenAiCompletionServiceSettingsTests.java | 16 +- ...eOpenAiEmbeddingsServiceSettingsTests.java | 4 +- .../services/cohere/CohereServiceTests.java | 5 +- .../CohereCompletionModelTests.java | 4 +- .../CohereCompletionServiceSettingsTests.java | 19 +- .../CohereEmbeddingsServiceSettingsTests.java | 15 - .../CohereRerankServiceSettingsTests.java | 14 - .../GoogleAiStudioServiceTests.java | 5 +- .../GoogleAiStudioCompletionModelTests.java | 4 +- ...iStudioCompletionServiceSettingsTests.java | 18 +- ...iStudioEmbeddingsServiceSettingsTests.java | 21 +- .../HuggingFaceBaseServiceTests.java | 9 +- .../HuggingFaceServiceSettingsTests.java | 38 +- .../HuggingFaceElserServiceSettingsTests.java | 33 +- .../services/mistral/MistralServiceTests.java | 5 +- ...MistralEmbeddingsServiceSettingsTests.java | 12 - .../services/openai/OpenAiServiceTests.java | 5 +- ...nAiChatCompletionServiceSettingsTests.java | 37 +- .../OpenAiEmbeddingsServiceSettingsTests.java | 4 +- .../settings/RateLimitSettingsTests.java | 20 +- 102 files changed, 1487 insertions(+), 925 deletions(-) create mode 100644 docs/changelog/107706.yaml delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/NoopTask.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/SingleRequestManager.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManagerTests.java rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/{ExecutableRequestCreatorTests.java => RequestManagerTests.java} (56%) delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/SingleRequestManagerTests.java diff --git a/docs/changelog/107706.yaml b/docs/changelog/107706.yaml new file mode 100644 index 0000000000000..76b7f662bf0e0 --- /dev/null +++ b/docs/changelog/107706.yaml @@ -0,0 +1,5 @@ +pr: 107706 +summary: Add rate limiting support for the Inference API +area: Machine Learning +type: enhancement +issues: [] diff --git a/docs/reference/inference/put-inference.asciidoc b/docs/reference/inference/put-inference.asciidoc index 354cee3f6a990..f805bc0cc92f7 100644 --- a/docs/reference/inference/put-inference.asciidoc +++ b/docs/reference/inference/put-inference.asciidoc @@ -7,21 +7,17 @@ experimental[] Creates an {infer} endpoint to perform an {infer} task. IMPORTANT: The {infer} APIs enable you to use certain services, such as built-in -{ml} models (ELSER, E5), models uploaded through Eland, Cohere, OpenAI, Azure -OpenAI, Google AI Studio or Hugging Face. For built-in models and models -uploaded though Eland, the {infer} APIs offer an alternative way to use and -manage trained models. However, if you do not plan to use the {infer} APIs to -use these models or if you want to use non-NLP models, use the +{ml} models (ELSER, E5), models uploaded through Eland, Cohere, OpenAI, Azure OpenAI, Google AI Studio or Hugging Face. +For built-in models and models uploaded though Eland, the {infer} APIs offer an alternative way to use and manage trained models. +However, if you do not plan to use the {infer} APIs to use these models or if you want to use non-NLP models, use the <>. - [discrete] [[put-inference-api-request]] ==== {api-request-title} `PUT /_inference//` - [discrete] [[put-inference-api-prereqs]] ==== {api-prereq-title} @@ -29,7 +25,6 @@ use these models or if you want to use non-NLP models, use the * Requires the `manage_inference` <> (the built-in `inference_admin` role grants this privilege) - [discrete] [[put-inference-api-desc]] ==== {api-description-title} @@ -48,25 +43,23 @@ The following services are available through the {infer} API: * Hugging Face * OpenAI - [discrete] [[put-inference-api-path-params]] ==== {api-path-parms-title} - ``:: (Required, string) The unique identifier of the {infer} endpoint. ``:: (Required, string) -The type of the {infer} task that the model will perform. Available task types: +The type of the {infer} task that the model will perform. +Available task types: * `completion`, * `rerank`, * `sparse_embedding`, * `text_embedding`. - [discrete] [[put-inference-api-request-body]] ==== {api-request-body-title} @@ -78,21 +71,18 @@ Available services: * `azureopenai`: specify the `completion` or `text_embedding` task type to use the Azure OpenAI service. * `azureaistudio`: specify the `completion` or `text_embedding` task type to use the Azure AI Studio service. -* `cohere`: specify the `completion`, `text_embedding` or the `rerank` task type to use the -Cohere service. -* `elasticsearch`: specify the `text_embedding` task type to use the E5 -built-in model or text embedding models uploaded by Eland. +* `cohere`: specify the `completion`, `text_embedding` or the `rerank` task type to use the Cohere service. +* `elasticsearch`: specify the `text_embedding` task type to use the E5 built-in model or text embedding models uploaded by Eland. * `elser`: specify the `sparse_embedding` task type to use the ELSER service. * `googleaistudio`: specify the `completion` task to use the Google AI Studio service. -* `hugging_face`: specify the `text_embedding` task type to use the Hugging Face -service. -* `openai`: specify the `completion` or `text_embedding` task type to use the -OpenAI service. +* `hugging_face`: specify the `text_embedding` task type to use the Hugging Face service. +* `openai`: specify the `completion` or `text_embedding` task type to use the OpenAI service. `service_settings`:: (Required, object) -Settings used to install the {infer} model. These settings are specific to the +Settings used to install the {infer} model. +These settings are specific to the `service` you specified. + .`service_settings` for the `azureaistudio` service @@ -104,11 +94,10 @@ Settings used to install the {infer} model. These settings are specific to the A valid API key of your Azure AI Studio model deployment. This key can be found on the overview page for your deployment in the management section of your https://ai.azure.com/[Azure AI Studio] account. -IMPORTANT: You need to provide the API key only once, during the {infer} model -creation. The <> does not retrieve your API key. After -creating the {infer} model, you cannot change the associated API key. If you -want to use a different API key, delete the {infer} model and recreate it with -the same name and the updated API key. +IMPORTANT: You need to provide the API key only once, during the {infer} model creation. +The <> does not retrieve your API key. +After creating the {infer} model, you cannot change the associated API key. +If you want to use a different API key, delete the {infer} model and recreate it with the same name and the updated API key. `target`::: (Required, string) @@ -142,11 +131,13 @@ For "real-time" endpoints which are billed per hour of usage, specify `realtime` By default, the `azureaistudio` service sets the number of requests allowed per minute to `240`. This helps to minimize the number of rate limit errors returned from Azure AI Studio. To modify this, set the `requests_per_minute` setting of this object in your service settings: -``` ++ +[source,text] +---- "rate_limit": { "requests_per_minute": <> } -``` +---- ===== + .`service_settings` for the `azureopenai` service @@ -181,6 +172,22 @@ Your Azure OpenAI deployments can be found though the https://oai.azure.com/[Azu The Azure API version ID to use. We recommend using the https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings[latest supported non-preview version]. +`rate_limit`::: +(Optional, object) +The `azureopenai` service sets a default number of requests allowed per minute depending on the task type. +For `text_embedding` it is set to `1440`. +For `completion` it is set to `120`. +This helps to minimize the number of rate limit errors returned from Azure. +To modify this, set the `requests_per_minute` setting of this object in your service settings: ++ +[source,text] +---- +"rate_limit": { + "requests_per_minute": <> +} +---- ++ +More information about the rate limits for Azure can be found in the https://learn.microsoft.com/en-us/azure/ai-services/openai/quotas-limits[Quota limits docs] and https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/quota?tabs=rest[How to change the quotas]. ===== + .`service_settings` for the `cohere` service @@ -188,24 +195,24 @@ We recommend using the https://learn.microsoft.com/en-us/azure/ai-services/opena ===== `api_key`::: (Required, string) -A valid API key of your Cohere account. You can find your Cohere API keys or you -can create a new one +A valid API key of your Cohere account. +You can find your Cohere API keys or you can create a new one https://dashboard.cohere.com/api-keys[on the API keys settings page]. -IMPORTANT: You need to provide the API key only once, during the {infer} model -creation. The <> does not retrieve your API key. After -creating the {infer} model, you cannot change the associated API key. If you -want to use a different API key, delete the {infer} model and recreate it with -the same name and the updated API key. +IMPORTANT: You need to provide the API key only once, during the {infer} model creation. +The <> does not retrieve your API key. +After creating the {infer} model, you cannot change the associated API key. +If you want to use a different API key, delete the {infer} model and recreate it with the same name and the updated API key. `embedding_type`:: (Optional, string) -Only for `text_embedding`. Specifies the types of embeddings you want to get -back. Defaults to `float`. +Only for `text_embedding`. +Specifies the types of embeddings you want to get back. +Defaults to `float`. Valid values are: - * `byte`: use it for signed int8 embeddings (this is a synonym of `int8`). - * `float`: use it for the default float embeddings. - * `int8`: use it for signed int8 embeddings. +* `byte`: use it for signed int8 embeddings (this is a synonym of `int8`). +* `float`: use it for the default float embeddings. +* `int8`: use it for signed int8 embeddings. `model_id`:: (Optional, string) @@ -214,50 +221,68 @@ To review the available `rerank` models, refer to the https://docs.cohere.com/reference/rerank-1[Cohere docs]. To review the available `text_embedding` models, refer to the -https://docs.cohere.com/reference/embed[Cohere docs]. The default value for +https://docs.cohere.com/reference/embed[Cohere docs]. +The default value for `text_embedding` is `embed-english-v2.0`. + +`rate_limit`::: +(Optional, object) +By default, the `cohere` service sets the number of requests allowed per minute to `10000`. +This value is the same for all task types. +This helps to minimize the number of rate limit errors returned from Cohere. +To modify this, set the `requests_per_minute` setting of this object in your service settings: ++ +[source,text] +---- +"rate_limit": { + "requests_per_minute": <> +} +---- ++ +More information about Cohere's rate limits can be found in https://docs.cohere.com/docs/going-live#production-key-specifications[Cohere's production key docs]. + ===== + .`service_settings` for the `elasticsearch` service [%collapsible%closed] ===== + `model_id`::: (Required, string) -The name of the model to use for the {infer} task. It can be the -ID of either a built-in model (for example, `.multilingual-e5-small` for E5) or -a text embedding model already +The name of the model to use for the {infer} task. +It can be the ID of either a built-in model (for example, `.multilingual-e5-small` for E5) or a text embedding model already {ml-docs}/ml-nlp-import-model.html#ml-nlp-import-script[uploaded through Eland]. `num_allocations`::: (Required, integer) -The number of model allocations to create. `num_allocations` must not exceed the -number of available processors per node divided by the `num_threads`. +The number of model allocations to create. `num_allocations` must not exceed the number of available processors per node divided by the `num_threads`. `num_threads`::: (Required, integer) -The number of threads to use by each model allocation. `num_threads` must not -exceed the number of available processors per node divided by the number of -allocations. Must be a power of 2. Max allowed value is 32. +The number of threads to use by each model allocation. `num_threads` must not exceed the number of available processors per node divided by the number of allocations. +Must be a power of 2. Max allowed value is 32. + ===== + .`service_settings` for the `elser` service [%collapsible%closed] ===== + `num_allocations`::: (Required, integer) -The number of model allocations to create. `num_allocations` must not exceed the -number of available processors per node divided by the `num_threads`. +The number of model allocations to create. `num_allocations` must not exceed the number of available processors per node divided by the `num_threads`. `num_threads`::: (Required, integer) -The number of threads to use by each model allocation. `num_threads` must not -exceed the number of available processors per node divided by the number of -allocations. Must be a power of 2. Max allowed value is 32. +The number of threads to use by each model allocation. `num_threads` must not exceed the number of available processors per node divided by the number of allocations. +Must be a power of 2. Max allowed value is 32. + ===== + .`service_settings` for the `googleiastudio` service [%collapsible%closed] ===== + `api_key`::: (Required, string) A valid API key for the Google Gemini API. @@ -274,76 +299,113 @@ This helps to minimize the number of rate limit errors returned from Google AI S To modify this, set the `requests_per_minute` setting of this object in your service settings: + -- -``` +[source,text] +---- "rate_limit": { "requests_per_minute": <> } -``` +---- -- + ===== + .`service_settings` for the `hugging_face` service [%collapsible%closed] ===== + `api_key`::: (Required, string) -A valid access token of your Hugging Face account. You can find your Hugging -Face access tokens or you can create a new one +A valid access token of your Hugging Face account. +You can find your Hugging Face access tokens or you can create a new one https://huggingface.co/settings/tokens[on the settings page]. -IMPORTANT: You need to provide the API key only once, during the {infer} model -creation. The <> does not retrieve your API key. After -creating the {infer} model, you cannot change the associated API key. If you -want to use a different API key, delete the {infer} model and recreate it with -the same name and the updated API key. +IMPORTANT: You need to provide the API key only once, during the {infer} model creation. +The <> does not retrieve your API key. +After creating the {infer} model, you cannot change the associated API key. +If you want to use a different API key, delete the {infer} model and recreate it with the same name and the updated API key. `url`::: (Required, string) The URL endpoint to use for the requests. + +`rate_limit`::: +(Optional, object) +By default, the `huggingface` service sets the number of requests allowed per minute to `3000`. +This helps to minimize the number of rate limit errors returned from Hugging Face. +To modify this, set the `requests_per_minute` setting of this object in your service settings: ++ +[source,text] +---- +"rate_limit": { + "requests_per_minute": <> +} +---- + ===== + .`service_settings` for the `openai` service [%collapsible%closed] ===== + `api_key`::: (Required, string) -A valid API key of your OpenAI account. You can find your OpenAI API keys in -your OpenAI account under the +A valid API key of your OpenAI account. +You can find your OpenAI API keys in your OpenAI account under the https://platform.openai.com/api-keys[API keys section]. -IMPORTANT: You need to provide the API key only once, during the {infer} model -creation. The <> does not retrieve your API key. After -creating the {infer} model, you cannot change the associated API key. If you -want to use a different API key, delete the {infer} model and recreate it with -the same name and the updated API key. +IMPORTANT: You need to provide the API key only once, during the {infer} model creation. +The <> does not retrieve your API key. +After creating the {infer} model, you cannot change the associated API key. +If you want to use a different API key, delete the {infer} model and recreate it with the same name and the updated API key. `model_id`::: (Required, string) -The name of the model to use for the {infer} task. Refer to the +The name of the model to use for the {infer} task. +Refer to the https://platform.openai.com/docs/guides/embeddings/what-are-embeddings[OpenAI documentation] for the list of available text embedding models. `organization_id`::: (Optional, string) -The unique identifier of your organization. You can find the Organization ID in -your OpenAI account under +The unique identifier of your organization. +You can find the Organization ID in your OpenAI account under https://platform.openai.com/account/organization[**Settings** > **Organizations**]. `url`::: (Optional, string) -The URL endpoint to use for the requests. Can be changed for testing purposes. +The URL endpoint to use for the requests. +Can be changed for testing purposes. Defaults to `https://api.openai.com/v1/embeddings`. +`rate_limit`::: +(Optional, object) +The `openai` service sets a default number of requests allowed per minute depending on the task type. +For `text_embedding` it is set to `3000`. +For `completion` it is set to `500`. +This helps to minimize the number of rate limit errors returned from Azure. +To modify this, set the `requests_per_minute` setting of this object in your service settings: ++ +[source,text] +---- +"rate_limit": { + "requests_per_minute": <> +} +---- ++ +More information about the rate limits for OpenAI can be found in your https://platform.openai.com/account/limits[Account limits]. + ===== `task_settings`:: (Optional, object) -Settings to configure the {infer} task. These settings are specific to the +Settings to configure the {infer} task. +These settings are specific to the `` you specified. + .`task_settings` for the `completion` task type [%collapsible%closed] ===== + `do_sample`::: (Optional, float) For the `azureaistudio` service only. @@ -358,8 +420,8 @@ Defaults to 64. `user`::: (Optional, string) -For `openai` service only. Specifies the user issuing the request, which can be -used for abuse detection. +For `openai` service only. +Specifies the user issuing the request, which can be used for abuse detection. `temperature`::: (Optional, float) @@ -378,45 +440,46 @@ Should not be used if `temperature` is specified. .`task_settings` for the `rerank` task type [%collapsible%closed] ===== + `return_documents`:: (Optional, boolean) -For `cohere` service only. Specify whether to return doc text within the -results. +For `cohere` service only. +Specify whether to return doc text within the results. `top_n`:: (Optional, integer) -The number of most relevant documents to return, defaults to the number of the -documents. +The number of most relevant documents to return, defaults to the number of the documents. + ===== + .`task_settings` for the `text_embedding` task type [%collapsible%closed] ===== + `input_type`::: (Optional, string) -For `cohere` service only. Specifies the type of input passed to the model. +For `cohere` service only. +Specifies the type of input passed to the model. Valid values are: - * `classification`: use it for embeddings passed through a text classifier. - * `clusterning`: use it for the embeddings run through a clustering algorithm. - * `ingest`: use it for storing document embeddings in a vector database. - * `search`: use it for storing embeddings of search queries run against a - vector database to find relevant documents. +* `classification`: use it for embeddings passed through a text classifier. +* `clusterning`: use it for the embeddings run through a clustering algorithm. +* `ingest`: use it for storing document embeddings in a vector database. +* `search`: use it for storing embeddings of search queries run against a vector database to find relevant documents. `truncate`::: (Optional, string) -For `cohere` service only. Specifies how the API handles inputs longer than the -maximum token length. Defaults to `END`. Valid values are: - * `NONE`: when the input exceeds the maximum input token length an error is - returned. - * `START`: when the input exceeds the maximum input token length the start of - the input is discarded. - * `END`: when the input exceeds the maximum input token length the end of - the input is discarded. +For `cohere` service only. +Specifies how the API handles inputs longer than the maximum token length. +Defaults to `END`. +Valid values are: +* `NONE`: when the input exceeds the maximum input token length an error is returned. +* `START`: when the input exceeds the maximum input token length the start of the input is discarded. +* `END`: when the input exceeds the maximum input token length the end of the input is discarded. `user`::: (optional, string) -For `openai`, `azureopenai` and `azureaistudio` services only. Specifies the user issuing the -request, which can be used for abuse detection. +For `openai`, `azureopenai` and `azureaistudio` services only. +Specifies the user issuing the request, which can be used for abuse detection. ===== [discrete] @@ -470,7 +533,6 @@ PUT _inference/completion/azure_ai_studio_completion The list of chat completion models that you can choose from in your deployment can be found in the https://ai.azure.com/explore/models?selectedTask=chat-completion[Azure AI Studio model explorer]. - [discrete] [[inference-example-azureopenai]] ===== Azure OpenAI service @@ -519,7 +581,6 @@ The list of chat completion models that you can choose from in your Azure OpenAI * https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-models[GPT-4 and GPT-4 Turbo models] * https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-35[GPT-3.5] - [discrete] [[inference-example-cohere]] ===== Cohere service @@ -565,7 +626,6 @@ PUT _inference/rerank/cohere-rerank For more examples, also review the https://docs.cohere.com/docs/elasticsearch-and-cohere#rerank-search-results-with-cohere-and-elasticsearch[Cohere documentation]. - [discrete] [[inference-example-e5]] ===== E5 via the `elasticsearch` service @@ -586,10 +646,9 @@ PUT _inference/text_embedding/my-e5-model } ------------------------------------------------------------ // TEST[skip:TBD] -<1> The `model_id` must be the ID of one of the built-in E5 models. Valid values -are `.multilingual-e5-small` and `.multilingual-e5-small_linux-x86_64`. For -further details, refer to the {ml-docs}/ml-nlp-e5.html[E5 model documentation]. - +<1> The `model_id` must be the ID of one of the built-in E5 models. +Valid values are `.multilingual-e5-small` and `.multilingual-e5-small_linux-x86_64`. +For further details, refer to the {ml-docs}/ml-nlp-e5.html[E5 model documentation]. [discrete] [[inference-example-elser]] @@ -597,8 +656,7 @@ further details, refer to the {ml-docs}/ml-nlp-e5.html[E5 model documentation]. The following example shows how to create an {infer} endpoint called `my-elser-model` to perform a `sparse_embedding` task type. -Refer to the {ml-docs}/ml-nlp-elser.html[ELSER model documentation] for more -info. +Refer to the {ml-docs}/ml-nlp-elser.html[ELSER model documentation] for more info. [source,console] ------------------------------------------------------------ @@ -672,16 +730,17 @@ PUT _inference/text_embedding/hugging-face-embeddings } ------------------------------------------------------------ // TEST[skip:TBD] -<1> A valid Hugging Face access token. You can find on the +<1> A valid Hugging Face access token. +You can find on the https://huggingface.co/settings/tokens[settings page of your account]. <2> The {infer} endpoint URL you created on Hugging Face. Create a new {infer} endpoint on -https://ui.endpoints.huggingface.co/[the Hugging Face endpoint page] to get an -endpoint URL. Select the model you want to use on the new endpoint creation page -- for example `intfloat/e5-small-v2` - then select the `Sentence Embeddings` -task under the Advanced configuration section. Create the endpoint. Copy the URL -after the endpoint initialization has been finished. +https://ui.endpoints.huggingface.co/[the Hugging Face endpoint page] to get an endpoint URL. +Select the model you want to use on the new endpoint creation page - for example `intfloat/e5-small-v2` - then select the `Sentence Embeddings` +task under the Advanced configuration section. +Create the endpoint. +Copy the URL after the endpoint initialization has been finished. [discrete] [[inference-example-hugging-face-supported-models]] @@ -695,7 +754,6 @@ The list of recommended models for the Hugging Face service: * https://huggingface.co/intfloat/multilingual-e5-base[multilingual-e5-base] * https://huggingface.co/intfloat/multilingual-e5-small[multilingual-e5-small] - [discrete] [[inference-example-eland]] ===== Models uploaded by Eland via the elasticsearch service @@ -716,11 +774,9 @@ PUT _inference/text_embedding/my-msmarco-minilm-model } ------------------------------------------------------------ // TEST[skip:TBD] -<1> The `model_id` must be the ID of a text embedding model which has already -been +<1> The `model_id` must be the ID of a text embedding model which has already been {ml-docs}/ml-nlp-import-model.html#ml-nlp-import-script[uploaded through Eland]. - [discrete] [[inference-example-openai]] ===== OpenAI service @@ -756,4 +812,3 @@ PUT _inference/completion/openai-completion } ------------------------------------------------------------ // TEST[skip:TBD] - diff --git a/libs/core/src/main/java/org/elasticsearch/core/TimeValue.java b/libs/core/src/main/java/org/elasticsearch/core/TimeValue.java index df7c47943289d..26d93bca6b09a 100644 --- a/libs/core/src/main/java/org/elasticsearch/core/TimeValue.java +++ b/libs/core/src/main/java/org/elasticsearch/core/TimeValue.java @@ -88,6 +88,13 @@ public static TimeValue timeValueDays(long days) { return new TimeValue(days, TimeUnit.DAYS); } + /** + * @return the {@link TimeValue} object that has the least duration. + */ + public static TimeValue min(TimeValue time1, TimeValue time2) { + return time1.compareTo(time2) < 0 ? time1 : time2; + } + /** * @return the unit used for the this time value, see {@link #duration()} */ diff --git a/libs/core/src/test/java/org/elasticsearch/common/unit/TimeValueTests.java b/libs/core/src/test/java/org/elasticsearch/common/unit/TimeValueTests.java index b6481db9b9951..dd2755ac1f9f7 100644 --- a/libs/core/src/test/java/org/elasticsearch/common/unit/TimeValueTests.java +++ b/libs/core/src/test/java/org/elasticsearch/common/unit/TimeValueTests.java @@ -17,6 +17,7 @@ import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.object.HasToString.hasToString; @@ -231,6 +232,12 @@ public void testRejectsNegativeValuesAtCreation() { assertThat(ex.getMessage(), containsString("duration cannot be negative")); } + public void testMin() { + assertThat(TimeValue.min(TimeValue.ZERO, TimeValue.timeValueNanos(1)), is(TimeValue.timeValueNanos(0))); + assertThat(TimeValue.min(TimeValue.MAX_VALUE, TimeValue.timeValueNanos(1)), is(TimeValue.timeValueNanos(1))); + assertThat(TimeValue.min(TimeValue.MINUS_ONE, TimeValue.timeValueHours(1)), is(TimeValue.MINUS_ONE)); + } + private TimeUnit randomTimeUnitObject() { return randomFrom( TimeUnit.NANOSECONDS, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java index 140c08ceef80f..81bc90433d34a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java @@ -26,6 +26,7 @@ public class CohereActionCreator implements CohereActionVisitor { private final ServiceComponents serviceComponents; public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) { + // TODO Batching - accept a class that can handle batching this.sender = Objects.requireNonNull(sender); this.serviceComponents = Objects.requireNonNull(serviceComponents); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java index 63e51d99a8cee..b4815f8f0d1bf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java @@ -36,6 +36,7 @@ public CohereEmbeddingsAction(Sender sender, CohereEmbeddingsModel model, Thread model.getServiceSettings().getCommonSettings().uri(), "Cohere embeddings" ); + // TODO - Batching pass the batching class on to the CohereEmbeddingsRequestManager requestCreator = CohereEmbeddingsRequestManager.of(model, threadPool); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java index deff410aebaa8..002fa71b7fb5d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -37,17 +36,16 @@ public AzureAiStudioChatCompletionRequestManager(AzureAiStudioChatCompletionMode } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { AzureAiStudioChatCompletionRequest request = new AzureAiStudioChatCompletionRequest(model, input); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } private static ResponseHandler createCompletionHandler() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java index a2b363151a417..ec5ab2fee6a57 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -41,17 +40,16 @@ public AzureAiStudioEmbeddingsRequestManager(AzureAiStudioEmbeddingsModel model, } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); AzureAiStudioEmbeddingsRequest request = new AzureAiStudioEmbeddingsRequest(truncator, truncatedInput, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } private static ResponseHandler createEmbeddingsHandler() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java index 2811155f6f357..5206d6c2c23cc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -43,16 +42,15 @@ public AzureOpenAiCompletionRequestManager(AzureOpenAiCompletionModel model, Thr } @Override - public Runnable create( + public void execute( @Nullable String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(input, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java index 06152b50822aa..e0fcee30e5af3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -55,16 +54,15 @@ public AzureOpenAiEmbeddingsRequestManager(AzureOpenAiEmbeddingsModel model, Tru } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); AzureOpenAiEmbeddingsRequest request = new AzureOpenAiEmbeddingsRequest(truncator, truncatedInput, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManager.java index abca0ce0d049b..a015716b81032 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManager.java @@ -38,11 +38,16 @@ public String inferenceEntityId() { @Override public Object rateLimitGrouping() { - return rateLimitGroup; + // It's possible that two inference endpoints have the same information defining the group but have different + // rate limits then they should be in different groups otherwise whoever initially created the group will set + // the rate and the other inference endpoint's rate will be ignored + return new EndpointGrouping(rateLimitGroup, rateLimitSettings); } @Override public RateLimitSettings rateLimitSettings() { return rateLimitSettings; } + + private record EndpointGrouping(Object group, RateLimitSettings settings) {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java index 255d4a3f3879f..8a4b0e45b93fa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -46,16 +45,15 @@ private CohereCompletionRequestManager(CohereCompletionModel model, ThreadPool t } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { CohereCompletionRequest request = new CohereCompletionRequest(input, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java index 0bf1c11285adb..a51910f1d0a67 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -44,16 +43,15 @@ private CohereEmbeddingsRequestManager(CohereEmbeddingsModel model, ThreadPool t } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(input, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java index 1778663a194e8..1351eec406569 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -44,16 +43,15 @@ private CohereRerankRequestManager(CohereRerankModel model, ThreadPool threadPoo } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { CohereRerankRequest request = new CohereRerankRequest(query, input, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ExecutableInferenceRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ExecutableInferenceRequest.java index 53f30773cbfe3..214eba4ee3485 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ExecutableInferenceRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ExecutableInferenceRequest.java @@ -23,7 +23,6 @@ record ExecutableInferenceRequest( RequestSender requestSender, Logger logger, Request request, - HttpClientContext context, ResponseHandler responseHandler, Supplier hasFinished, ActionListener listener @@ -34,7 +33,7 @@ public void run() { var inferenceEntityId = request.createHttpRequest().inferenceEntityId(); try { - requestSender.send(logger, request, context, hasFinished, responseHandler, listener); + requestSender.send(logger, request, HttpClientContext.create(), hasFinished, responseHandler, listener); } catch (Exception e) { var errorMessage = Strings.format("Failed to send request from inference entity id [%s]", inferenceEntityId); logger.warn(errorMessage, e); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java index eb9baa680446a..2b191b046477b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -42,15 +41,14 @@ public GoogleAiStudioCompletionRequestManager(GoogleAiStudioCompletionModel mode } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { GoogleAiStudioCompletionRequest request = new GoogleAiStudioCompletionRequest(input, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java index 15c2825e7d043..6436e0231ab48 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -48,17 +47,16 @@ public GoogleAiStudioEmbeddingsRequestManager(GoogleAiStudioEmbeddingsModel mode } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); GoogleAiStudioEmbeddingsRequest request = new GoogleAiStudioEmbeddingsRequest(truncator, truncatedInput, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java index 21a758a3db248..d1e309a774ab7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java @@ -15,6 +15,8 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.RequestExecutor; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings; import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -39,30 +41,28 @@ public static class Factory { private final ServiceComponents serviceComponents; private final HttpClientManager httpClientManager; private final ClusterService clusterService; - private final SingleRequestManager requestManager; + private final RequestSender requestSender; public Factory(ServiceComponents serviceComponents, HttpClientManager httpClientManager, ClusterService clusterService) { this.serviceComponents = Objects.requireNonNull(serviceComponents); this.httpClientManager = Objects.requireNonNull(httpClientManager); this.clusterService = Objects.requireNonNull(clusterService); - var requestSender = new RetryingHttpSender( + requestSender = new RetryingHttpSender( this.httpClientManager.getHttpClient(), serviceComponents.throttlerManager(), new RetrySettings(serviceComponents.settings(), clusterService), serviceComponents.threadPool() ); - requestManager = new SingleRequestManager(requestSender); } - public Sender createSender(String serviceName) { + public Sender createSender() { return new HttpRequestSender( - serviceName, serviceComponents.threadPool(), httpClientManager, clusterService, serviceComponents.settings(), - requestManager + requestSender ); } } @@ -71,26 +71,24 @@ public Sender createSender(String serviceName) { private final ThreadPool threadPool; private final HttpClientManager manager; - private final RequestExecutorService service; + private final RequestExecutor service; private final AtomicBoolean started = new AtomicBoolean(false); private final CountDownLatch startCompleted = new CountDownLatch(1); private HttpRequestSender( - String serviceName, ThreadPool threadPool, HttpClientManager httpClientManager, ClusterService clusterService, Settings settings, - SingleRequestManager requestManager + RequestSender requestSender ) { this.threadPool = Objects.requireNonNull(threadPool); this.manager = Objects.requireNonNull(httpClientManager); service = new RequestExecutorService( - serviceName, threadPool, startCompleted, new RequestExecutorServiceSettings(settings, clusterService), - requestManager + requestSender ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java index 7c09e0c67c1c6..6c8fc446d5243 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -55,26 +54,17 @@ private HuggingFaceRequestManager(HuggingFaceModel model, ResponseHandler respon } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { var truncatedInput = truncate(input, model.getTokenLimit()); var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model); - return new ExecutableInferenceRequest( - requestSender, - logger, - request, - context, - responseHandler, - hasRequestCompletedFunction, - listener - ); + execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener)); } record RateLimitGrouping(int accountHash) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceRequest.java index 3c711bb79717c..6199a75a41a7d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceRequest.java @@ -19,9 +19,9 @@ public interface InferenceRequest { /** - * Returns the creator that handles building an executable request based on the input provided. + * Returns the manager that handles building and executing an inference request. */ - RequestManager getRequestCreator(); + RequestManager getRequestManager(); /** * Returns the query associated with this request. Used for Rerank tasks. diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java index f31a633581705..ab6a1bfb31372 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -51,18 +50,17 @@ public MistralEmbeddingsRequestManager(MistralEmbeddingsModel model, Truncator t } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); MistralEmbeddingsRequest request = new MistralEmbeddingsRequest(truncator, truncatedInput, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } record RateLimitGrouping(int keyHashCode) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/NoopTask.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/NoopTask.java deleted file mode 100644 index 0355880b3f714..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/NoopTask.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.http.sender; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.inference.InferenceServiceResults; - -import java.util.List; -import java.util.function.Supplier; - -class NoopTask implements RejectableTask { - - @Override - public RequestManager getRequestCreator() { - return null; - } - - @Override - public String getQuery() { - return null; - } - - @Override - public List getInput() { - return null; - } - - @Override - public ActionListener getListener() { - return null; - } - - @Override - public boolean hasCompleted() { - return true; - } - - @Override - public Supplier getRequestCompletedFunction() { - return () -> true; - } - - @Override - public void onRejection(Exception e) { - - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java index 9c6c216c61272..7bc09fd76736b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -43,17 +42,16 @@ private OpenAiCompletionRequestManager(OpenAiChatCompletionModel model, ThreadPo } @Override - public Runnable create( + public void execute( @Nullable String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { OpenAiChatCompletionRequest request = new OpenAiChatCompletionRequest(input, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } private static ResponseHandler createCompletionHandler() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManager.java index 3a0a8fd64a656..41f91d2b89ee5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -55,17 +54,16 @@ private OpenAiEmbeddingsRequestManager(OpenAiEmbeddingsModel model, Truncator tr } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); OpenAiEmbeddingsRequest request = new OpenAiEmbeddingsRequest(truncator, truncatedInput, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java index d5a13c2e0771d..38d47aec68eb6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -17,21 +16,31 @@ import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.Scheduler; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.common.AdjustableCapacityBlockingQueue; +import org.elasticsearch.xpack.inference.common.RateLimiter; import org.elasticsearch.xpack.inference.external.http.RequestExecutor; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import java.time.Clock; +import java.time.Instant; import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Consumer; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; /** * A service for queuing and executing {@link RequestTask}. This class is useful because the @@ -45,7 +54,18 @@ * {@link org.apache.http.client.config.RequestConfig.Builder#setConnectionRequestTimeout} for more info. */ class RequestExecutorService implements RequestExecutor { - private static final AdjustableCapacityBlockingQueue.QueueCreator QUEUE_CREATOR = + + /** + * Provides dependency injection mainly for testing + */ + interface Sleeper { + void sleep(TimeValue sleepTime) throws InterruptedException; + } + + // default for tests + static final Sleeper DEFAULT_SLEEPER = sleepTime -> sleepTime.timeUnit().sleep(sleepTime.duration()); + // default for tests + static final AdjustableCapacityBlockingQueue.QueueCreator DEFAULT_QUEUE_CREATOR = new AdjustableCapacityBlockingQueue.QueueCreator<>() { @Override public BlockingQueue create(int capacity) { @@ -65,86 +85,116 @@ public BlockingQueue create() { } }; + /** + * Provides dependency injection mainly for testing + */ + interface RateLimiterCreator { + RateLimiter create(double accumulatedTokensLimit, double tokensPerTimeUnit, TimeUnit unit); + } + + // default for testing + static final RateLimiterCreator DEFAULT_RATE_LIMIT_CREATOR = RateLimiter::new; private static final Logger logger = LogManager.getLogger(RequestExecutorService.class); - private final String serviceName; - private final AdjustableCapacityBlockingQueue queue; - private final AtomicBoolean running = new AtomicBoolean(true); - private final CountDownLatch terminationLatch = new CountDownLatch(1); - private final HttpClientContext httpContext; + private static final TimeValue RATE_LIMIT_GROUP_CLEANUP_INTERVAL = TimeValue.timeValueDays(1); + + private final ConcurrentMap rateLimitGroupings = new ConcurrentHashMap<>(); private final ThreadPool threadPool; private final CountDownLatch startupLatch; - private final BlockingQueue controlQueue = new LinkedBlockingQueue<>(); - private final SingleRequestManager requestManager; + private final CountDownLatch terminationLatch = new CountDownLatch(1); + private final RequestSender requestSender; + private final RequestExecutorServiceSettings settings; + private final Clock clock; + private final AtomicBoolean shutdown = new AtomicBoolean(false); + private final AdjustableCapacityBlockingQueue.QueueCreator queueCreator; + private final Sleeper sleeper; + private final RateLimiterCreator rateLimiterCreator; + private final AtomicReference cancellableCleanupTask = new AtomicReference<>(); + private final AtomicBoolean started = new AtomicBoolean(false); RequestExecutorService( - String serviceName, ThreadPool threadPool, @Nullable CountDownLatch startupLatch, RequestExecutorServiceSettings settings, - SingleRequestManager requestManager + RequestSender requestSender ) { - this(serviceName, threadPool, QUEUE_CREATOR, startupLatch, settings, requestManager); + this( + threadPool, + DEFAULT_QUEUE_CREATOR, + startupLatch, + settings, + requestSender, + Clock.systemUTC(), + DEFAULT_SLEEPER, + DEFAULT_RATE_LIMIT_CREATOR + ); } - /** - * This constructor should only be used directly for testing. - */ RequestExecutorService( - String serviceName, ThreadPool threadPool, - AdjustableCapacityBlockingQueue.QueueCreator createQueue, + AdjustableCapacityBlockingQueue.QueueCreator queueCreator, @Nullable CountDownLatch startupLatch, RequestExecutorServiceSettings settings, - SingleRequestManager requestManager + RequestSender requestSender, + Clock clock, + Sleeper sleeper, + RateLimiterCreator rateLimiterCreator ) { - this.serviceName = Objects.requireNonNull(serviceName); this.threadPool = Objects.requireNonNull(threadPool); - this.httpContext = HttpClientContext.create(); - this.queue = new AdjustableCapacityBlockingQueue<>(createQueue, settings.getQueueCapacity()); + this.queueCreator = Objects.requireNonNull(queueCreator); this.startupLatch = startupLatch; - this.requestManager = Objects.requireNonNull(requestManager); + this.requestSender = Objects.requireNonNull(requestSender); + this.settings = Objects.requireNonNull(settings); + this.clock = Objects.requireNonNull(clock); + this.sleeper = Objects.requireNonNull(sleeper); + this.rateLimiterCreator = Objects.requireNonNull(rateLimiterCreator); + } - Objects.requireNonNull(settings); - settings.registerQueueCapacityCallback(this::onCapacityChange); + public void shutdown() { + if (shutdown.compareAndSet(false, true)) { + if (cancellableCleanupTask.get() != null) { + logger.debug(() -> "Stopping clean up thread"); + cancellableCleanupTask.get().cancel(); + } + } } - private void onCapacityChange(int capacity) { - logger.debug(() -> Strings.format("Setting queue capacity to [%s]", capacity)); + public boolean isShutdown() { + return shutdown.get(); + } - var enqueuedCapacityCommand = controlQueue.offer(() -> updateCapacity(capacity)); - if (enqueuedCapacityCommand == false) { - logger.warn("Failed to change request batching service queue capacity. Control queue was full, please try again later."); - } else { - // ensure that the task execution loop wakes up - queue.offer(new NoopTask()); - } + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + return terminationLatch.await(timeout, unit); } - private void updateCapacity(int newCapacity) { - try { - queue.setCapacity(newCapacity); - } catch (Exception e) { - logger.warn( - format("Failed to set the capacity of the task queue to [%s] for request batching service [%s]", newCapacity, serviceName), - e - ); - } + public boolean isTerminated() { + return terminationLatch.getCount() == 0; + } + + public int queueSize() { + return rateLimitGroupings.values().stream().mapToInt(RateLimitingEndpointHandler::queueSize).sum(); } /** * Begin servicing tasks. + *

+ * Note: This should only be called once for the life of the object. + *

*/ public void start() { try { + assert started.get() == false : "start() can only be called once"; + started.set(true); + + startCleanupTask(); signalStartInitiated(); - while (running.get()) { + while (isShutdown() == false) { handleTasks(); } } catch (InterruptedException e) { Thread.currentThread().interrupt(); } finally { - running.set(false); + shutdown(); notifyRequestsOfShutdown(); terminationLatch.countDown(); } @@ -156,108 +206,68 @@ private void signalStartInitiated() { } } - /** - * Protects the task retrieval logic from an unexpected exception. - * - * @throws InterruptedException rethrows the exception if it occurred retrieving a task because the thread is likely attempting to - * shut down - */ - private void handleTasks() throws InterruptedException { - try { - RejectableTask task = queue.take(); + private void startCleanupTask() { + assert cancellableCleanupTask.get() == null : "The clean up task can only be set once"; + cancellableCleanupTask.set(startCleanupThread(RATE_LIMIT_GROUP_CLEANUP_INTERVAL)); + } - var command = controlQueue.poll(); - if (command != null) { - command.run(); - } + private Scheduler.Cancellable startCleanupThread(TimeValue interval) { + logger.debug(() -> Strings.format("Clean up task scheduled with interval [%s]", interval)); - // TODO add logic to complete pending items in the queue before shutting down - if (running.get() == false) { - logger.debug(() -> format("Http executor service [%s] exiting", serviceName)); - rejectTaskBecauseOfShutdown(task); - } else { - executeTask(task); - } - } catch (InterruptedException e) { - throw e; - } catch (Exception e) { - logger.warn(format("Http executor service [%s] failed while retrieving task for execution", serviceName), e); - } + return threadPool.scheduleWithFixedDelay(this::removeStaleGroupings, interval, threadPool.executor(UTILITY_THREAD_POOL_NAME)); } - private void executeTask(RejectableTask task) { - try { - requestManager.execute(task, httpContext); - } catch (Exception e) { - logger.warn(format("Http executor service [%s] failed to execute request [%s]", serviceName, task), e); + // default for testing + void removeStaleGroupings() { + var now = Instant.now(clock); + for (var iter = rateLimitGroupings.values().iterator(); iter.hasNext();) { + var endpoint = iter.next(); + + // if the current time is after the last time the endpoint enqueued a request + allowed stale period then we'll remove it + if (now.isAfter(endpoint.timeOfLastEnqueue().plus(settings.getRateLimitGroupStaleDuration()))) { + endpoint.close(); + iter.remove(); + } } } - private synchronized void notifyRequestsOfShutdown() { - assert isShutdown() : "Requests should only be notified if the executor is shutting down"; - - try { - List notExecuted = new ArrayList<>(); - queue.drainTo(notExecuted); - - rejectTasks(notExecuted, this::rejectTaskBecauseOfShutdown); - } catch (Exception e) { - logger.warn(format("Failed to notify tasks of queuing service [%s] shutdown", serviceName)); + private void handleTasks() throws InterruptedException { + var timeToWait = settings.getTaskPollFrequency(); + for (var endpoint : rateLimitGroupings.values()) { + timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait); } - } - private void rejectTaskBecauseOfShutdown(RejectableTask task) { - try { - task.onRejection( - new EsRejectedExecutionException( - format("Failed to send request, queue service [%s] has shutdown prior to executing request", serviceName), - true - ) - ); - } catch (Exception e) { - logger.warn( - format("Failed to notify request [%s] for service [%s] of rejection after queuing service shutdown", task, serviceName) - ); - } + sleeper.sleep(timeToWait); } - private void rejectTasks(List tasks, Consumer rejectionFunction) { - for (var task : tasks) { - rejectionFunction.accept(task); + private void notifyRequestsOfShutdown() { + assert isShutdown() : "Requests should only be notified if the executor is shutting down"; + + for (var endpoint : rateLimitGroupings.values()) { + endpoint.notifyRequestsOfShutdown(); } } - public int queueSize() { - return queue.size(); - } + // default for testing + Integer remainingQueueCapacity(RequestManager requestManager) { + var endpoint = rateLimitGroupings.get(requestManager.rateLimitGrouping()); - @Override - public void shutdown() { - if (running.compareAndSet(true, false)) { - // if this fails because the queue is full, that's ok, we just want to ensure that queue.take() returns - queue.offer(new NoopTask()); + if (endpoint == null) { + return null; } - } - @Override - public boolean isShutdown() { - return running.get() == false; - } - - @Override - public boolean isTerminated() { - return terminationLatch.getCount() == 0; + return endpoint.remainingCapacity(); } - @Override - public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { - return terminationLatch.await(timeout, unit); + // default for testing + int numberOfRateLimitGroups() { + return rateLimitGroupings.size(); } /** * Execute the request at some point in the future. * - * @param requestCreator the http request to send + * @param requestManager the http request to send * @param inferenceInputs the inputs to send in the request * @param timeout the maximum time to wait for this request to complete (failing or succeeding). Once the time elapses, the * listener::onFailure is called with a {@link org.elasticsearch.ElasticsearchTimeoutException}. @@ -265,13 +275,13 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE * @param listener an {@link ActionListener} for the response or failure */ public void execute( - RequestManager requestCreator, + RequestManager requestManager, InferenceInputs inferenceInputs, @Nullable TimeValue timeout, ActionListener listener ) { var task = new RequestTask( - requestCreator, + requestManager, inferenceInputs, timeout, threadPool, @@ -280,38 +290,230 @@ public void execute( ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext()) ); - completeExecution(task); + var endpoint = rateLimitGroupings.computeIfAbsent(requestManager.rateLimitGrouping(), key -> { + var endpointHandler = new RateLimitingEndpointHandler( + Integer.toString(requestManager.rateLimitGrouping().hashCode()), + queueCreator, + settings, + requestSender, + clock, + requestManager.rateLimitSettings(), + this::isShutdown, + rateLimiterCreator + ); + + endpointHandler.init(); + return endpointHandler; + }); + + endpoint.enqueue(task); } - private void completeExecution(RequestTask task) { - if (isShutdown()) { - EsRejectedExecutionException rejected = new EsRejectedExecutionException( - format("Failed to enqueue task because the http executor service [%s] has already shutdown", serviceName), - true + /** + * Provides rate limiting functionality for requests. A single {@link RateLimitingEndpointHandler} governs a group of requests. + * This allows many requests to be serialized if they are being sent too fast. If the rate limit has not been met they will be sent + * as soon as a thread is available. + */ + private static class RateLimitingEndpointHandler { + + private static final TimeValue NO_TASKS_AVAILABLE = TimeValue.MAX_VALUE; + private static final TimeValue EXECUTED_A_TASK = TimeValue.ZERO; + private static final Logger logger = LogManager.getLogger(RateLimitingEndpointHandler.class); + private static final int ACCUMULATED_TOKENS_LIMIT = 1; + + private final AdjustableCapacityBlockingQueue queue; + private final Supplier isShutdownMethod; + private final RequestSender requestSender; + private final String id; + private final AtomicReference timeOfLastEnqueue = new AtomicReference<>(); + private final Clock clock; + private final RateLimiter rateLimiter; + private final RequestExecutorServiceSettings requestExecutorServiceSettings; + + RateLimitingEndpointHandler( + String id, + AdjustableCapacityBlockingQueue.QueueCreator createQueue, + RequestExecutorServiceSettings settings, + RequestSender requestSender, + Clock clock, + RateLimitSettings rateLimitSettings, + Supplier isShutdownMethod, + RateLimiterCreator rateLimiterCreator + ) { + this.requestExecutorServiceSettings = Objects.requireNonNull(settings); + this.id = Objects.requireNonNull(id); + this.queue = new AdjustableCapacityBlockingQueue<>(createQueue, settings.getQueueCapacity()); + this.requestSender = Objects.requireNonNull(requestSender); + this.clock = Objects.requireNonNull(clock); + this.isShutdownMethod = Objects.requireNonNull(isShutdownMethod); + + Objects.requireNonNull(rateLimitSettings); + Objects.requireNonNull(rateLimiterCreator); + rateLimiter = rateLimiterCreator.create( + ACCUMULATED_TOKENS_LIMIT, + rateLimitSettings.requestsPerTimeUnit(), + rateLimitSettings.timeUnit() ); - task.onRejection(rejected); - return; } - boolean added = queue.offer(task); - if (added == false) { - EsRejectedExecutionException rejected = new EsRejectedExecutionException( - format("Failed to execute task because the http executor service [%s] queue is full", serviceName), - false - ); + public void init() { + requestExecutorServiceSettings.registerQueueCapacityCallback(id, this::onCapacityChange); + } - task.onRejection(rejected); - } else if (isShutdown()) { - // It is possible that a shutdown and notification request occurred after we initially checked for shutdown above - // If the task was added after the queue was already drained it could sit there indefinitely. So let's check again if - // we shut down and if so we'll redo the notification - notifyRequestsOfShutdown(); + private void onCapacityChange(int capacity) { + logger.debug(() -> Strings.format("Executor service grouping [%s] setting queue capacity to [%s]", id, capacity)); + + try { + queue.setCapacity(capacity); + } catch (Exception e) { + logger.warn(format("Executor service grouping [%s] failed to set the capacity of the task queue to [%s]", id, capacity), e); + } } - } - // default for testing - int remainingQueueCapacity() { - return queue.remainingCapacity(); + public int queueSize() { + return queue.size(); + } + + public boolean isShutdown() { + return isShutdownMethod.get(); + } + + public Instant timeOfLastEnqueue() { + return timeOfLastEnqueue.get(); + } + + public synchronized TimeValue executeEnqueuedTask() { + try { + return executeEnqueuedTaskInternal(); + } catch (Exception e) { + logger.warn(format("Executor service grouping [%s] failed to execute request", id), e); + // we tried to do some work but failed, so we'll say we did something to try looking for more work + return EXECUTED_A_TASK; + } + } + + private TimeValue executeEnqueuedTaskInternal() { + var timeBeforeAvailableToken = rateLimiter.timeToReserve(1); + if (shouldExecuteImmediately(timeBeforeAvailableToken) == false) { + return timeBeforeAvailableToken; + } + + var task = queue.poll(); + + // TODO Batching - in a situation where no new tasks are queued we'll want to execute any prepared tasks + // So we'll need to check for null and call a helper method executePreparedTasks() + + if (shouldExecuteTask(task) == false) { + return NO_TASKS_AVAILABLE; + } + + // We should never have to wait because we checked above + var reserveRes = rateLimiter.reserve(1); + assert shouldExecuteImmediately(reserveRes) : "Reserving request tokens required a sleep when it should not have"; + + task.getRequestManager() + .execute(task.getQuery(), task.getInput(), requestSender, task.getRequestCompletedFunction(), task.getListener()); + return EXECUTED_A_TASK; + } + + private static boolean shouldExecuteTask(RejectableTask task) { + return task != null && isNoopRequest(task) == false && task.hasCompleted() == false; + } + + private static boolean isNoopRequest(InferenceRequest inferenceRequest) { + return inferenceRequest.getRequestManager() == null + || inferenceRequest.getInput() == null + || inferenceRequest.getListener() == null; + } + + private static boolean shouldExecuteImmediately(TimeValue delay) { + return delay.duration() == 0; + } + + public void enqueue(RequestTask task) { + timeOfLastEnqueue.set(Instant.now(clock)); + + if (isShutdown()) { + EsRejectedExecutionException rejected = new EsRejectedExecutionException( + format( + "Failed to enqueue task for inference id [%s] because the request service [%s] has already shutdown", + task.getRequestManager().inferenceEntityId(), + id + ), + true + ); + + task.onRejection(rejected); + return; + } + + var addedToQueue = queue.offer(task); + + if (addedToQueue == false) { + EsRejectedExecutionException rejected = new EsRejectedExecutionException( + format( + "Failed to execute task for inference id [%s] because the request service [%s] queue is full", + task.getRequestManager().inferenceEntityId(), + id + ), + false + ); + + task.onRejection(rejected); + } else if (isShutdown()) { + notifyRequestsOfShutdown(); + } + } + + public synchronized void notifyRequestsOfShutdown() { + assert isShutdown() : "Requests should only be notified if the executor is shutting down"; + + try { + List notExecuted = new ArrayList<>(); + queue.drainTo(notExecuted); + + rejectTasks(notExecuted); + } catch (Exception e) { + logger.warn(format("Failed to notify tasks of executor service grouping [%s] shutdown", id)); + } + } + + private void rejectTasks(List tasks) { + for (var task : tasks) { + rejectTaskForShutdown(task); + } + } + + private void rejectTaskForShutdown(RejectableTask task) { + try { + task.onRejection( + new EsRejectedExecutionException( + format( + "Failed to send request, request service [%s] for inference id [%s] has shutdown prior to executing request", + id, + task.getRequestManager().inferenceEntityId() + ), + true + ) + ); + } catch (Exception e) { + logger.warn( + format( + "Failed to notify request for inference id [%s] of rejection after executor service grouping [%s] shutdown", + task.getRequestManager().inferenceEntityId(), + id + ) + ); + } + } + + public int remainingCapacity() { + return queue.remainingCapacity(); + } + + public void close() { + requestExecutorServiceSettings.deregisterQueueCapacityCallback(id); + } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceSettings.java index 86825035f2d05..616ef7a40068b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceSettings.java @@ -10,9 +10,12 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; -import java.util.ArrayList; +import java.time.Duration; import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import java.util.function.Consumer; public class RequestExecutorServiceSettings { @@ -29,37 +32,108 @@ public class RequestExecutorServiceSettings { Setting.Property.Dynamic ); + private static final TimeValue DEFAULT_TASK_POLL_FREQUENCY_TIME = TimeValue.timeValueMillis(50); + /** + * Defines how often all the rate limit groups are polled for tasks. Setting this to very low number could result + * in a busy loop if there are no tasks available to handle. + */ + static final Setting TASK_POLL_FREQUENCY_SETTING = Setting.timeSetting( + "xpack.inference.http.request_executor.task_poll_frequency", + DEFAULT_TASK_POLL_FREQUENCY_TIME, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + private static final TimeValue DEFAULT_RATE_LIMIT_GROUP_CLEANUP_INTERVAL = TimeValue.timeValueDays(1); + /** + * Defines how often a thread will check for rate limit groups that are stale. + */ + static final Setting RATE_LIMIT_GROUP_CLEANUP_INTERVAL_SETTING = Setting.timeSetting( + "xpack.inference.http.request_executor.rate_limit_group_cleanup_interval", + DEFAULT_RATE_LIMIT_GROUP_CLEANUP_INTERVAL, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + private static final TimeValue DEFAULT_RATE_LIMIT_GROUP_STALE_DURATION = TimeValue.timeValueDays(10); + /** + * Defines the amount of time it takes to classify a rate limit group as stale. Once it is classified as stale, + * it can be removed when the cleanup thread executes. + */ + static final Setting RATE_LIMIT_GROUP_STALE_DURATION_SETTING = Setting.timeSetting( + "xpack.inference.http.request_executor.rate_limit_group_stale_duration", + DEFAULT_RATE_LIMIT_GROUP_STALE_DURATION, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + public static List> getSettingsDefinitions() { - return List.of(TASK_QUEUE_CAPACITY_SETTING); + return List.of( + TASK_QUEUE_CAPACITY_SETTING, + TASK_POLL_FREQUENCY_SETTING, + RATE_LIMIT_GROUP_CLEANUP_INTERVAL_SETTING, + RATE_LIMIT_GROUP_STALE_DURATION_SETTING + ); } private volatile int queueCapacity; - private final List> queueCapacityCallbacks = new ArrayList>(); + private volatile TimeValue taskPollFrequency; + private volatile Duration rateLimitGroupStaleDuration; + private final ConcurrentMap> queueCapacityCallbacks = new ConcurrentHashMap<>(); public RequestExecutorServiceSettings(Settings settings, ClusterService clusterService) { queueCapacity = TASK_QUEUE_CAPACITY_SETTING.get(settings); + taskPollFrequency = TASK_POLL_FREQUENCY_SETTING.get(settings); + setRateLimitGroupStaleDuration(RATE_LIMIT_GROUP_STALE_DURATION_SETTING.get(settings)); addSettingsUpdateConsumers(clusterService); } private void addSettingsUpdateConsumers(ClusterService clusterService) { clusterService.getClusterSettings().addSettingsUpdateConsumer(TASK_QUEUE_CAPACITY_SETTING, this::setQueueCapacity); + clusterService.getClusterSettings().addSettingsUpdateConsumer(TASK_POLL_FREQUENCY_SETTING, this::setTaskPollFrequency); + clusterService.getClusterSettings() + .addSettingsUpdateConsumer(RATE_LIMIT_GROUP_STALE_DURATION_SETTING, this::setRateLimitGroupStaleDuration); } // default for testing void setQueueCapacity(int queueCapacity) { this.queueCapacity = queueCapacity; - for (var callback : queueCapacityCallbacks) { + for (var callback : queueCapacityCallbacks.values()) { callback.accept(queueCapacity); } } - void registerQueueCapacityCallback(Consumer onChangeCapacityCallback) { - queueCapacityCallbacks.add(onChangeCapacityCallback); + private void setTaskPollFrequency(TimeValue taskPollFrequency) { + this.taskPollFrequency = taskPollFrequency; + } + + private void setRateLimitGroupStaleDuration(TimeValue staleDuration) { + rateLimitGroupStaleDuration = toDuration(staleDuration); + } + + private static Duration toDuration(TimeValue timeValue) { + return Duration.of(timeValue.duration(), timeValue.timeUnit().toChronoUnit()); + } + + void registerQueueCapacityCallback(String id, Consumer onChangeCapacityCallback) { + queueCapacityCallbacks.put(id, onChangeCapacityCallback); + } + + void deregisterQueueCapacityCallback(String id) { + queueCapacityCallbacks.remove(id); } int getQueueCapacity() { return queueCapacity; } + + TimeValue getTaskPollFrequency() { + return taskPollFrequency; + } + + Duration getRateLimitGroupStaleDuration() { + return rateLimitGroupStaleDuration; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java index 7d3cca596f1d0..79ef1b56ad231 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; @@ -21,14 +20,17 @@ * A contract for constructing a {@link Runnable} to handle sending an inference request to a 3rd party service. */ public interface RequestManager extends RateLimitable { - Runnable create( + void execute( @Nullable String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ); + // TODO For batching we'll add 2 new method: prepare(query, input, ...) which will allow the individual + // managers to implement their own batching + // executePreparedRequest() which will execute all prepared requests aka sends the batch + String inferenceEntityId(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java index 738592464232c..7a5f482412289 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java @@ -111,7 +111,7 @@ public void onRejection(Exception e) { } @Override - public RequestManager getRequestCreator() { + public RequestManager getRequestManager() { return requestCreator; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/SingleRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/SingleRequestManager.java deleted file mode 100644 index 494c77964080f..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/SingleRequestManager.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.http.sender; - -import org.apache.http.client.protocol.HttpClientContext; -import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender; - -import java.util.Objects; - -/** - * Handles executing a single inference request at a time. - */ -public class SingleRequestManager { - - protected RetryingHttpSender requestSender; - - public SingleRequestManager(RetryingHttpSender requestSender) { - this.requestSender = Objects.requireNonNull(requestSender); - } - - public void execute(InferenceRequest inferenceRequest, HttpClientContext context) { - if (isNoopRequest(inferenceRequest) || inferenceRequest.hasCompleted()) { - return; - } - - inferenceRequest.getRequestCreator() - .create( - inferenceRequest.getQuery(), - inferenceRequest.getInput(), - requestSender, - inferenceRequest.getRequestCompletedFunction(), - context, - inferenceRequest.getListener() - ) - .run(); - } - - private static boolean isNoopRequest(InferenceRequest inferenceRequest) { - return inferenceRequest.getRequestCreator() == null - || inferenceRequest.getInput() == null - || inferenceRequest.getListener() == null; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 24c0ab2cd893e..1c64f505402d8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -31,7 +31,7 @@ public abstract class SenderService implements InferenceService { public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { Objects.requireNonNull(factory); - sender = factory.createSender(name()); + sender = factory.createSender(); this.serviceComponents = Objects.requireNonNull(serviceComponents); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index c488eac422401..f30773962854a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -56,7 +56,7 @@ public class AzureAiStudioService extends SenderService { - private static final String NAME = "azureaistudio"; + static final String NAME = "azureaistudio"; public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { super(factory, serviceComponents); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceSettings.java index 10c57e19b6403..03034ae70c2b6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceSettings.java @@ -44,7 +44,13 @@ protected static BaseAzureAiStudioCommonFields fromMap( ConfigurationParseContext context ) { String target = extractRequiredString(map, TARGET_FIELD, ModelConfigurations.SERVICE_SETTINGS, validationException); - RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + AzureAiStudioService.NAME, + context + ); AzureAiStudioEndpointType endpointType = extractRequiredEnum( map, ENDPOINT_TYPE_FIELD, @@ -118,13 +124,13 @@ public void writeTo(StreamOutput out) throws IOException { protected void addXContentFields(XContentBuilder builder, Params params) throws IOException { this.addExposedXContentFields(builder, params); - rateLimitSettings.toXContent(builder, params); } protected void addExposedXContentFields(XContentBuilder builder, Params params) throws IOException { builder.field(TARGET_FIELD, this.target); builder.field(PROVIDER_FIELD, this.provider); builder.field(ENDPOINT_TYPE_FIELD, this.endpointType); + rateLimitSettings.toXContent(builder, params); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index e0e48ab20a86b..26bf6f1648d97 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -135,7 +135,15 @@ private static AzureOpenAiModel createModel( ); } case COMPLETION -> { - return new AzureOpenAiCompletionModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings); + return new AzureOpenAiCompletionModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + context + ); } default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModel.java index 05cb663453542..c4146b2ba2d30 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModel.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionVisitor; import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiModel; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings; @@ -37,13 +38,14 @@ public AzureOpenAiCompletionModel( String service, Map serviceSettings, Map taskSettings, - @Nullable Map secrets + @Nullable Map secrets, + ConfigurationParseContext context ) { this( inferenceEntityId, taskType, service, - AzureOpenAiCompletionServiceSettings.fromMap(serviceSettings), + AzureOpenAiCompletionServiceSettings.fromMap(serviceSettings, context), AzureOpenAiCompletionTaskSettings.fromMap(taskSettings), AzureOpenAiSecretSettings.fromMap(secrets) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettings.java index ba503b2bbdc4b..92dc461d9008c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettings.java @@ -17,7 +17,9 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -55,10 +57,10 @@ public class AzureOpenAiCompletionServiceSettings extends FilteredXContentObject */ private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(120); - public static AzureOpenAiCompletionServiceSettings fromMap(Map map) { + public static AzureOpenAiCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); - var settings = fromMap(map, validationException); + var settings = fromMap(map, validationException, context); if (validationException.validationErrors().isEmpty() == false) { throw validationException; @@ -69,12 +71,19 @@ public static AzureOpenAiCompletionServiceSettings fromMap(Map m private static AzureOpenAiCompletionServiceSettings.CommonFields fromMap( Map map, - ValidationException validationException + ValidationException validationException, + ConfigurationParseContext context ) { String resourceName = extractRequiredString(map, RESOURCE_NAME, ModelConfigurations.SERVICE_SETTINGS, validationException); String deploymentId = extractRequiredString(map, DEPLOYMENT_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); String apiVersion = extractRequiredString(map, API_VERSION, ModelConfigurations.SERVICE_SETTINGS, validationException); - RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + AzureOpenAiService.NAME, + context + ); return new AzureOpenAiCompletionServiceSettings.CommonFields(resourceName, deploymentId, apiVersion, rateLimitSettings); } @@ -137,7 +146,6 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par builder.startObject(); toXContentFragmentOfExposedFields(builder, params); - rateLimitSettings.toXContent(builder, params); builder.endObject(); return builder; @@ -148,6 +156,7 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil builder.field(RESOURCE_NAME, resourceName); builder.field(DEPLOYMENT_ID, deploymentId); builder.field(API_VERSION, apiVersion); + rateLimitSettings.toXContent(builder, params); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java index 33bb0fdb07c58..1c426815a83c0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java @@ -21,6 +21,7 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -90,7 +91,13 @@ private static CommonFields fromMap( Integer dims = removeAsType(map, DIMENSIONS, Integer.class); Integer maxTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); - RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + AzureOpenAiService.NAME, + context + ); Boolean dimensionsSetByUser = extractOptionalBoolean(map, DIMENSIONS_SET_BY_USER, validationException); @@ -245,8 +252,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); toXContentFragmentOfExposedFields(builder, params); - - rateLimitSettings.toXContent(builder, params); builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser); builder.endObject(); @@ -268,6 +273,7 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil if (similarity != null) { builder.field(SIMILARITY, similarity); } + rateLimitSettings.toXContent(builder, params); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index 11dbf673ab7bd..4c673026d7efb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -51,6 +51,11 @@ public class CohereService extends SenderService { public static final String NAME = "cohere"; + // TODO Batching - We'll instantiate a batching class within the services that want to support it and pass it through to + // the Cohere*RequestManager via the CohereActionCreator class + // The reason it needs to be done here is that the batching logic needs to hold state but the *RequestManagers are instantiated + // on every request + public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { super(factory, serviceComponents); } @@ -131,7 +136,15 @@ private static CohereModel createModel( context ); case RERANK -> new CohereRerankModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context); - case COMPLETION -> new CohereCompletionModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings); + case COMPLETION -> new CohereCompletionModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + context + ); default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); }; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java index b23f6f188d8c5..d477a8c5a5f55 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java @@ -58,7 +58,13 @@ public static CohereServiceSettings fromMap(Map map, Configurati Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); String oldModelId = extractOptionalString(map, OLD_MODEL_ID_FIELD, ModelConfigurations.SERVICE_SETTINGS, validationException); - RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + CohereService.NAME, + context + ); String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); @@ -173,10 +179,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } public XContentBuilder toXContentFragment(XContentBuilder builder, Params params) throws IOException { - toXContentFragmentOfExposedFields(builder, params); - rateLimitSettings.toXContent(builder, params); - - return builder; + return toXContentFragmentOfExposedFields(builder, params); } @Override @@ -196,6 +199,7 @@ public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder if (modelId != null) { builder.field(MODEL_ID, modelId); } + rateLimitSettings.toXContent(builder, params); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModel.java index 761081d4d723c..bec4f5a0b5c85 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModel.java @@ -16,6 +16,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.cohere.CohereActionVisitor; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.cohere.CohereModel; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -30,13 +31,14 @@ public CohereCompletionModel( String service, Map serviceSettings, Map taskSettings, - @Nullable Map secrets + @Nullable Map secrets, + ConfigurationParseContext context ) { this( modelId, taskType, service, - CohereCompletionServiceSettings.fromMap(serviceSettings), + CohereCompletionServiceSettings.fromMap(serviceSettings, context), EmptyTaskSettings.INSTANCE, DefaultSecretSettings.fromMap(secrets) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettings.java index 2a22f6333f1a2..ba9e81b461f9f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettings.java @@ -16,7 +16,9 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.cohere.CohereRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.cohere.CohereService; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -39,12 +41,18 @@ public class CohereCompletionServiceSettings extends FilteredXContentObject impl // 10K requests per minute private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000); - public static CohereCompletionServiceSettings fromMap(Map map) { + public static CohereCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); - RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + CohereService.NAME, + context + ); String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); if (validationException.validationErrors().isEmpty() == false) { @@ -94,7 +102,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); toXContentFragmentOfExposedFields(builder, params); - rateLimitSettings.toXContent(builder, params); builder.endObject(); return builder; @@ -127,6 +134,7 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil if (modelId != null) { builder.field(MODEL_ID, modelId); } + rateLimitSettings.toXContent(builder, params); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index f8720448b0f4f..cfa8566495143 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -108,7 +108,8 @@ private static GoogleAiStudioModel createModel( NAME, serviceSettings, taskSettings, - secretSettings + secretSettings, + context ); case TEXT_EMBEDDING -> new GoogleAiStudioEmbeddingsModel( inferenceEntityId, @@ -116,7 +117,8 @@ private static GoogleAiStudioModel createModel( NAME, serviceSettings, taskSettings, - secretSettings + secretSettings, + context ); default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); }; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModel.java index eafb0c372202c..8fa2ac0148716 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModel.java @@ -18,6 +18,7 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionVisitor; import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioUtils; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioModel; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -37,13 +38,14 @@ public GoogleAiStudioCompletionModel( String service, Map serviceSettings, Map taskSettings, - Map secrets + Map secrets, + ConfigurationParseContext context ) { this( inferenceEntityId, taskType, service, - GoogleAiStudioCompletionServiceSettings.fromMap(serviceSettings), + GoogleAiStudioCompletionServiceSettings.fromMap(serviceSettings, context), EmptyTaskSettings.INSTANCE, DefaultSecretSettings.fromMap(secrets) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettings.java index f8f343be8eb4c..7c0b812ee213b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettings.java @@ -16,7 +16,9 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -40,11 +42,17 @@ public class GoogleAiStudioCompletionServiceSettings extends FilteredXContentObj */ private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(360); - public static GoogleAiStudioCompletionServiceSettings fromMap(Map map) { + public static GoogleAiStudioCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); String model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); - RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + GoogleAiStudioService.NAME, + context + ); if (validationException.validationErrors().isEmpty() == false) { throw validationException; @@ -82,7 +90,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); toXContentFragmentOfExposedFields(builder, params); - rateLimitSettings.toXContent(builder, params); builder.endObject(); return builder; @@ -107,6 +114,7 @@ public void writeTo(StreamOutput out) throws IOException { @Override protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { builder.field(MODEL_ID, modelId); + rateLimitSettings.toXContent(builder, params); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModel.java index ad106797de51b..af19e26f3e97a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModel.java @@ -18,6 +18,7 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionVisitor; import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioUtils; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioModel; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -37,13 +38,14 @@ public GoogleAiStudioEmbeddingsModel( String service, Map serviceSettings, Map taskSettings, - Map secrets + Map secrets, + ConfigurationParseContext context ) { this( inferenceEntityId, taskType, service, - GoogleAiStudioEmbeddingsServiceSettings.fromMap(serviceSettings), + GoogleAiStudioEmbeddingsServiceSettings.fromMap(serviceSettings, context), EmptyTaskSettings.INSTANCE, DefaultSecretSettings.fromMap(secrets) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsServiceSettings.java index 07d07dc533f06..7608f48d0638d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsServiceSettings.java @@ -18,7 +18,9 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -47,7 +49,7 @@ public class GoogleAiStudioEmbeddingsServiceSettings extends FilteredXContentObj */ private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(360); - public static GoogleAiStudioEmbeddingsServiceSettings fromMap(Map map) { + public static GoogleAiStudioEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); String model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); @@ -59,7 +61,13 @@ public static GoogleAiStudioEmbeddingsServiceSettings fromMap(Map config) { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - return createModel(inferenceEntityId, taskType, serviceSettingsMap, null, parsePersistedConfigErrorMsg(inferenceEntityId, name())); + return createModel( + inferenceEntityId, + taskType, + serviceSettingsMap, + null, + parsePersistedConfigErrorMsg(inferenceEntityId, name()), + ConfigurationParseContext.PERSISTENT + ); } protected abstract HuggingFaceModel createModel( @@ -105,7 +115,8 @@ protected abstract HuggingFaceModel createModel( TaskType taskType, Map serviceSettings, Map secretSettings, - String failureMessage + String failureMessage, + ConfigurationParseContext context ); @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index d8c383d2b4a67..c0438b3759a65 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -16,6 +16,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; @@ -36,11 +37,19 @@ protected HuggingFaceModel createModel( TaskType taskType, Map serviceSettings, @Nullable Map secretSettings, - String failureMessage + String failureMessage, + ConfigurationParseContext context ) { return switch (taskType) { - case TEXT_EMBEDDING -> new HuggingFaceEmbeddingsModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings); - case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings); + case TEXT_EMBEDDING -> new HuggingFaceEmbeddingsModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + secretSettings, + context + ); + case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context); default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); }; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java index af2c433663ac4..fc31b1e518dd9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java @@ -18,6 +18,7 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -43,14 +44,20 @@ public class HuggingFaceServiceSettings extends FilteredXContentObject implement // 3000 requests per minute private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000); - public static HuggingFaceServiceSettings fromMap(Map map) { + public static HuggingFaceServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); var uri = extractUri(map, URL, validationException); SimilarityMeasure similarityMeasure = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); Integer dims = removeAsType(map, DIMENSIONS, Integer.class); Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); - RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + HuggingFaceService.NAME, + context + ); if (validationException.validationErrors().isEmpty() == false) { throw validationException; @@ -119,7 +126,6 @@ public HuggingFaceServiceSettings(StreamInput in) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); toXContentFragmentOfExposedFields(builder, params); - rateLimitSettings.toXContent(builder, params); builder.endObject(); return builder; } @@ -136,6 +142,7 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil if (maxInputTokens != null) { builder.field(MAX_INPUT_TOKENS, maxInputTokens); } + rateLimitSettings.toXContent(builder, params); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserModel.java index 9010571ea2e55..8132089d8dc99 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserModel.java @@ -13,6 +13,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionVisitor; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -24,13 +25,14 @@ public HuggingFaceElserModel( TaskType taskType, String service, Map serviceSettings, - @Nullable Map secrets + @Nullable Map secrets, + ConfigurationParseContext context ) { this( inferenceEntityId, taskType, service, - HuggingFaceElserServiceSettings.fromMap(serviceSettings), + HuggingFaceElserServiceSettings.fromMap(serviceSettings, context), DefaultSecretSettings.fromMap(secrets) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 2587b2737e164..d3099e96ee7c1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -14,6 +14,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceBaseService; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel; @@ -38,10 +39,11 @@ protected HuggingFaceModel createModel( TaskType taskType, Map serviceSettings, @Nullable Map secretSettings, - String failureMessage + String failureMessage, + ConfigurationParseContext context ) { return switch (taskType) { - case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings); + case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context); default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); }; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettings.java index 1f337de450ef9..8b4bd61649de0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettings.java @@ -15,7 +15,9 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -40,10 +42,16 @@ public class HuggingFaceElserServiceSettings extends FilteredXContentObject // 3000 requests per minute private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000); - public static HuggingFaceElserServiceSettings fromMap(Map map) { + public static HuggingFaceElserServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); var uri = extractUri(map, URL, validationException); - RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + HuggingFaceService.NAME, + context + ); if (validationException.validationErrors().isEmpty() == false) { throw validationException; @@ -93,7 +101,6 @@ public int maxInputTokens() { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); toXContentFragmentOfExposedFields(builder, params); - rateLimitSettings.toXContent(builder, params); builder.endObject(); return builder; @@ -103,6 +110,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { builder.field(URL, uri.toString()); builder.field(MAX_INPUT_TOKENS, ELSER_TOKEN_LIMIT); + rateLimitSettings.toXContent(builder, params); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModel.java index 1cee26558b490..fedd6380d035f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModel.java @@ -13,6 +13,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionVisitor; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -25,13 +26,14 @@ public HuggingFaceEmbeddingsModel( TaskType taskType, String service, Map serviceSettings, - @Nullable Map secrets + @Nullable Map secrets, + ConfigurationParseContext context ) { this( inferenceEntityId, taskType, service, - HuggingFaceServiceSettings.fromMap(serviceSettings), + HuggingFaceServiceSettings.fromMap(serviceSettings, context), DefaultSecretSettings.fromMap(secrets) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java index d2ea8ccbd18bd..62d06a4e0029c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java @@ -18,6 +18,7 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.mistral.MistralService; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -59,7 +60,13 @@ public static MistralEmbeddingsServiceSettings fromMap(Map map, ModelConfigurations.SERVICE_SETTINGS, validationException ); - RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + MistralService.NAME, + context + ); Integer dims = removeAsType(map, DIMENSIONS, Integer.class); if (validationException.validationErrors().isEmpty() == false) { @@ -141,7 +148,6 @@ public void writeTo(StreamOutput out) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); this.toXContentFragmentOfExposedFields(builder, params); - rateLimitSettings.toXContent(builder, params); builder.endObject(); return builder; } @@ -159,6 +165,7 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil if (this.maxInputTokens != null) { builder.field(MAX_INPUT_TOKENS, this.maxInputTokens); } + rateLimitSettings.toXContent(builder, params); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 84dfac8903678..04b6ae94d6b53 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -138,7 +138,8 @@ private static OpenAiModel createModel( NAME, serviceSettings, taskSettings, - secretSettings + secretSettings, + context ); default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); }; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java index b1b670c0911f5..7ca93684bc680 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java @@ -13,6 +13,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionVisitor; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.openai.OpenAiModel; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -35,13 +36,14 @@ public OpenAiChatCompletionModel( String service, Map serviceSettings, Map taskSettings, - @Nullable Map secrets + @Nullable Map secrets, + ConfigurationParseContext context ) { this( inferenceEntityId, taskType, service, - OpenAiChatCompletionServiceSettings.fromMap(serviceSettings), + OpenAiChatCompletionServiceSettings.fromMap(serviceSettings, context), OpenAiChatCompletionTaskSettings.fromMap(taskSettings), DefaultSecretSettings.fromMap(secrets) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettings.java index 5105bb59e048f..04f77da1b1463 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettings.java @@ -16,7 +16,9 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.openai.OpenAiRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.openai.OpenAiService; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -47,7 +49,7 @@ public class OpenAiChatCompletionServiceSettings extends FilteredXContentObject // 500 requests per minute private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(500); - public static OpenAiChatCompletionServiceSettings fromMap(Map map) { + public static OpenAiChatCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); @@ -58,7 +60,13 @@ public static OpenAiChatCompletionServiceSettings fromMap(Map ma Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); - RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + OpenAiService.NAME, + context + ); if (validationException.validationErrors().isEmpty() == false) { throw validationException; @@ -142,7 +150,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); toXContentFragmentOfExposedFields(builder, params); - rateLimitSettings.toXContent(builder, params); builder.endObject(); return builder; @@ -163,6 +170,7 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil if (maxInputTokens != null) { builder.field(MAX_INPUT_TOKENS, maxInputTokens); } + rateLimitSettings.toXContent(builder, params); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java index fc479009d3334..080251bf1ba3a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java @@ -20,6 +20,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.openai.OpenAiRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.openai.OpenAiService; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -66,7 +67,7 @@ private static OpenAiEmbeddingsServiceSettings fromPersistentMap(Map map) { ValidationException validationException = new ValidationException(); - var commonFields = fromMap(map, validationException); + var commonFields = fromMap(map, validationException, ConfigurationParseContext.REQUEST); if (validationException.validationErrors().isEmpty() == false) { throw validationException; @@ -89,7 +90,11 @@ private static OpenAiEmbeddingsServiceSettings fromRequestMap(Map map, ValidationException validationException) { + private static CommonFields fromMap( + Map map, + ValidationException validationException, + ConfigurationParseContext context + ) { String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); String organizationId = extractOptionalString(map, ORGANIZATION, ModelConfigurations.SERVICE_SETTINGS, validationException); @@ -98,7 +103,13 @@ private static CommonFields fromMap(Map map, ValidationException Integer dims = removeAsType(map, DIMENSIONS, Integer.class); URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); - RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + OpenAiService.NAME, + context + ); return new CommonFields(modelId, uri, organizationId, similarity, maxInputTokens, dims, rateLimitSettings); } @@ -258,7 +269,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); toXContentFragmentOfExposedFields(builder, params); - rateLimitSettings.toXContent(builder, params); if (dimensionsSetByUser != null) { builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser); @@ -286,6 +296,7 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil if (maxInputTokens != null) { builder.field(MAX_INPUT_TOKENS, maxInputTokens); } + rateLimitSettings.toXContent(builder, params); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettings.java index cfc375a525dd6..f593ca4e0c603 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettings.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.xcontent.ToXContentFragment; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import java.io.IOException; import java.util.Map; @@ -21,19 +22,29 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveLong; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; public class RateLimitSettings implements Writeable, ToXContentFragment { - public static final String FIELD_NAME = "rate_limit"; public static final String REQUESTS_PER_MINUTE_FIELD = "requests_per_minute"; private final long requestsPerTimeUnit; private final TimeUnit timeUnit; - public static RateLimitSettings of(Map map, RateLimitSettings defaultValue, ValidationException validationException) { + public static RateLimitSettings of( + Map map, + RateLimitSettings defaultValue, + ValidationException validationException, + String serviceName, + ConfigurationParseContext context + ) { Map settings = removeFromMapOrDefaultEmpty(map, FIELD_NAME); var requestsPerMinute = extractOptionalPositiveLong(settings, REQUESTS_PER_MINUTE_FIELD, FIELD_NAME, validationException); + if (ConfigurationParseContext.isRequestContext(context)) { + throwIfNotEmptyMap(settings, serviceName); + } + return requestsPerMinute == null ? defaultValue : new RateLimitSettings(requestsPerMinute); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java index 88d408d309a7b..8792234102a94 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java @@ -44,6 +44,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.API_KEY_HEADER; import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; @@ -92,7 +93,7 @@ public void testEmbeddingsRequestAction() throws IOException { TruncatorTests.createTruncator() ); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testEmbeddingsTokenResponseJson)); @@ -141,7 +142,7 @@ public void testChatCompletionRequestAction() throws IOException { TruncatorTests.createTruncator() ); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testCompletionTokenResponseJson)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java index 0a2a00143b205..72124a6221254 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java @@ -44,6 +44,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests.createCompletionModel; @@ -82,7 +83,7 @@ public void shutdown() throws IOException { public void testCreate_AzureOpenAiEmbeddingsModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -132,7 +133,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel() throws IOException { public void testCreate_AzureOpenAiEmbeddingsModel_WithoutUser() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -183,7 +184,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel_FailsFromInvalidResponseFormat // timeout as zero for no retries var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, ZERO_TIMEOUT_SETTINGS); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -237,7 +238,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel_FailsFromInvalidResponseFormat public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusCode() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); // note - there is no complete documentation on Azure's error messages @@ -313,7 +314,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusCode() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); // note - there is no complete documentation on Azure's error messages @@ -389,7 +390,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC public void testExecute_TruncatesInputBeforeSending() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -440,7 +441,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { public void testInfer_AzureOpenAiCompletion_WithOverriddenUser() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -498,7 +499,7 @@ public void testInfer_AzureOpenAiCompletion_WithOverriddenUser() throws IOExcept public void testInfer_AzureOpenAiCompletionModel_WithoutUser() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -554,7 +555,7 @@ public void testInfer_AzureOpenAiCompletionModel_FailsFromInvalidResponseFormat( // timeout as zero for no retries var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, ZERO_TIMEOUT_SETTINGS); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); // "choices" missing diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java index 96127841c17a8..7d52616402405 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java @@ -44,6 +44,7 @@ import static org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionCreatorTests.getContentOfMessageInRequestMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests.createCompletionModel; import static org.hamcrest.Matchers.hasSize; @@ -77,7 +78,7 @@ public void shutdown() throws IOException { public void testExecute_ReturnsSuccessfulResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java index 89cc847321796..4cc7b7c0d9cfc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java @@ -43,6 +43,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests.createModel; @@ -81,7 +82,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { mockClusterServiceEmpty() ); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java index 9b0371ad51f8c..9ec34e7d8e5c5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java @@ -42,6 +42,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; @@ -73,7 +74,7 @@ public void shutdown() throws IOException { public void testCreate_CohereEmbeddingsModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -154,7 +155,7 @@ public void testCreate_CohereEmbeddingsModel() throws IOException { public void testCreate_CohereCompletionModel_WithModelSpecified() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -214,7 +215,7 @@ public void testCreate_CohereCompletionModel_WithModelSpecified() throws IOExcep public void testCreate_CohereCompletionModel_WithoutModelSpecified() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java index 12c3d132d1244..0a604980f6c83 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java @@ -77,7 +77,7 @@ public void shutdown() throws IOException { public void testExecute_ReturnsSuccessfulResponse_WithModelSpecified() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) { + try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -138,7 +138,7 @@ public void testExecute_ReturnsSuccessfulResponse_WithModelSpecified() throws IO public void testExecute_ReturnsSuccessfulResponse_WithoutModelSpecified() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) { + try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -290,7 +290,7 @@ public void testExecute_ThrowsExceptionWithNullUrl() { public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java index dbc97fa2e13d8..9cf6de27b93bc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java @@ -81,7 +81,7 @@ public void shutdown() throws IOException { public void testExecute_ReturnsSuccessfulResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) { + try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -162,7 +162,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) { + try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java index 09ef5351eb1fc..9dd465e0276f4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java @@ -74,7 +74,7 @@ public void shutdown() throws IOException { public void testExecute_ReturnsSuccessfulResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) { + try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -206,7 +206,7 @@ public void testExecute_ThrowsException() { public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioEmbeddingsActionTests.java index a55b3c5f5030c..7e98b9b31f6ed 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioEmbeddingsActionTests.java @@ -79,7 +79,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var input = "input"; var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = senderFactory.createSender()) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java index fceea8810f6c2..b3ec565b3146a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java @@ -42,6 +42,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.Matchers.contains; @@ -75,7 +76,7 @@ public void shutdown() throws IOException { public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -131,7 +132,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx ); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -187,7 +188,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -239,7 +240,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForEmbeddingsAction() throws ); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); // this will fail because the only valid formats are {"embeddings": [[...]]} or [[...]] @@ -292,7 +293,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForEmbeddingsAction() throws public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJsonContentTooLarge = """ @@ -357,7 +358,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc public void testExecute_TruncatesInputBeforeSending() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java index 496238eaad0e4..b6d7eb673b7f0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java @@ -38,6 +38,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; @@ -74,7 +75,7 @@ public void shutdown() throws IOException { public void testCreate_OpenAiEmbeddingsModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -127,7 +128,7 @@ public void testCreate_OpenAiEmbeddingsModel() throws IOException { public void testCreate_OpenAiEmbeddingsModel_WithoutUser() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -179,7 +180,7 @@ public void testCreate_OpenAiEmbeddingsModel_WithoutUser() throws IOException { public void testCreate_OpenAiEmbeddingsModel_WithoutOrganization() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -238,7 +239,7 @@ public void testCreate_OpenAiEmbeddingsModel_FailsFromInvalidResponseFormat() th ); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -292,7 +293,7 @@ public void testCreate_OpenAiEmbeddingsModel_FailsFromInvalidResponseFormat() th public void testCreate_OpenAiChatCompletionModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -355,7 +356,7 @@ public void testCreate_OpenAiChatCompletionModel() throws IOException { public void testCreate_OpenAiChatCompletionModel_WithoutUser() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -417,7 +418,7 @@ public void testCreate_OpenAiChatCompletionModel_WithoutUser() throws IOExceptio public void testCreate_OpenAiChatCompletionModel_WithoutOrganization() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -486,7 +487,7 @@ public void testCreate_OpenAiChatCompletionModel_FailsFromInvalidResponseFormat( ); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -552,7 +553,7 @@ public void testCreate_OpenAiChatCompletionModel_FailsFromInvalidResponseFormat( public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusCode() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); var contentTooLargeErrorMessage = @@ -635,7 +636,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusCode() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); var contentTooLargeErrorMessage = @@ -718,7 +719,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC public void testExecute_TruncatesInputBeforeSending() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java index 914ff12db259a..42b062667f770 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java @@ -43,6 +43,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; @@ -80,7 +81,7 @@ public void shutdown() throws IOException { public void testExecute_ReturnsSuccessfulResponse() throws IOException { var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -234,7 +235,7 @@ public void testExecute_ThrowsExceptionWithNullUrl() { public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java index 15b7417912ef5..03c0b4d146b2e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java @@ -79,7 +79,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { mockClusterServiceEmpty() ); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = senderFactory.createSender()) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManagerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManagerTests.java new file mode 100644 index 0000000000000..03838896b879d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManagerTests.java @@ -0,0 +1,122 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.mockito.Mockito.mock; + +public class BaseRequestManagerTests extends ESTestCase { + public void testRateLimitGrouping_DifferentObjectReferences_HaveSameGroup() { + int val1 = 1; + int val2 = 1; + + var manager1 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1)) { + @Override + public void execute( + String query, + List input, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + + } + }; + + var manager2 = new BaseRequestManager(mock(ThreadPool.class), "id", val2, new RateLimitSettings(1)) { + @Override + public void execute( + String query, + List input, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + + } + }; + + assertThat(manager1.rateLimitGrouping(), is(manager2.rateLimitGrouping())); + } + + public void testRateLimitGrouping_DifferentSettings_HaveDifferentGroup() { + int val1 = 1; + + var manager1 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1)) { + @Override + public void execute( + String query, + List input, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + + } + }; + + var manager2 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(2)) { + @Override + public void execute( + String query, + List input, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + + } + }; + + assertThat(manager1.rateLimitGrouping(), not(manager2.rateLimitGrouping())); + } + + public void testRateLimitGrouping_DifferentSettingsTimeUnit_HaveDifferentGroup() { + int val1 = 1; + + var manager1 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1, TimeUnit.MILLISECONDS)) { + @Override + public void execute( + String query, + List input, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + + } + }; + + var manager2 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1, TimeUnit.DAYS)) { + @Override + public void execute( + String query, + List input, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + + } + }; + + assertThat(manager1.rateLimitGrouping(), not(manager2.rateLimitGrouping())); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java index 368745b310884..2b8b5f178b3de 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java @@ -79,7 +79,7 @@ public void shutdown() throws IOException, InterruptedException { public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception { var senderFactory = createSenderFactory(clientManager, threadRef); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSender(senderFactory)) { sender.start(); String responseJson = """ @@ -135,11 +135,11 @@ public void testHttpRequestSender_Throws_WhenCallingSendBeforeStart() throws Exc mockClusterServiceEmpty() ); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = senderFactory.createSender()) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( AssertionError.class, - () -> sender.send(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener) + () -> sender.send(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener) ); assertThat(thrownException.getMessage(), is("call start() before sending a request")); } @@ -155,17 +155,12 @@ public void testHttpRequestSender_Throws_WhenATimeoutOccurs() throws Exception { mockClusterServiceEmpty() ); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = senderFactory.createSender()) { assertThat(sender, instanceOf(HttpRequestSender.class)); sender.start(); PlainActionFuture listener = new PlainActionFuture<>(); - sender.send( - ExecutableRequestCreatorTests.createMock(), - new DocumentsOnlyInput(List.of()), - TimeValue.timeValueNanos(1), - listener - ); + sender.send(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), TimeValue.timeValueNanos(1), listener); var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT)); @@ -186,16 +181,11 @@ public void testHttpRequestSenderWithTimeout_Throws_WhenATimeoutOccurs() throws mockClusterServiceEmpty() ); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = senderFactory.createSender()) { sender.start(); PlainActionFuture listener = new PlainActionFuture<>(); - sender.send( - ExecutableRequestCreatorTests.createMock(), - new DocumentsOnlyInput(List.of()), - TimeValue.timeValueNanos(1), - listener - ); + sender.send(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), TimeValue.timeValueNanos(1), listener); var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT)); @@ -220,6 +210,7 @@ private static HttpRequestSender.Factory createSenderFactory(HttpClientManager c when(mockThreadPool.executor(anyString())).thenReturn(mockExecutorService); when(mockThreadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); when(mockThreadPool.schedule(any(Runnable.class), any(), any())).thenReturn(mock(Scheduler.ScheduledCancellable.class)); + when(mockThreadPool.scheduleWithFixedDelay(any(Runnable.class), any(), any())).thenReturn(mock(Scheduler.Cancellable.class)); return new HttpRequestSender.Factory( ServiceComponentsTests.createWithEmptySettings(mockThreadPool), @@ -248,7 +239,7 @@ public static HttpRequestSender.Factory createSenderFactory( ); } - public static Sender createSenderWithSingleRequestManager(HttpRequestSender.Factory factory, String serviceName) { - return factory.createSender(serviceName); + public static Sender createSender(HttpRequestSender.Factory factory) { + return factory.createSender(); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceSettingsTests.java index c0c0bdd49f617..489b502c04110 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceSettingsTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; import static org.elasticsearch.xpack.inference.Utils.mockClusterService; @@ -18,12 +19,23 @@ public static RequestExecutorServiceSettings createRequestExecutorServiceSetting } public static RequestExecutorServiceSettings createRequestExecutorServiceSettings(@Nullable Integer queueCapacity) { + return createRequestExecutorServiceSettings(queueCapacity, null); + } + + public static RequestExecutorServiceSettings createRequestExecutorServiceSettings( + @Nullable Integer queueCapacity, + @Nullable TimeValue staleDuration + ) { var settingsBuilder = Settings.builder(); if (queueCapacity != null) { settingsBuilder.put(RequestExecutorServiceSettings.TASK_QUEUE_CAPACITY_SETTING.getKey(), queueCapacity); } + if (staleDuration != null) { + settingsBuilder.put(RequestExecutorServiceSettings.RATE_LIMIT_GROUP_STALE_DURATION_SETTING.getKey(), staleDuration); + } + return createRequestExecutorServiceSettings(settingsBuilder.build()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java index ff88ba221d985..9a45e10007643 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java @@ -18,13 +18,19 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.Scheduler; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.common.RateLimiter; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender; import org.junit.After; import org.junit.Before; import org.mockito.ArgumentCaptor; import java.io.IOException; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.CountDownLatch; @@ -42,10 +48,13 @@ import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; public class RequestExecutorServiceTests extends ESTestCase { @@ -70,7 +79,7 @@ public void testQueueSize_IsEmpty() { public void testQueueSize_IsOne() { var service = createRequestExecutorServiceWithMocks(); - service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); + service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); assertThat(service.queueSize(), is(1)); } @@ -92,7 +101,20 @@ public void testIsTerminated_IsTrue() throws InterruptedException { assertTrue(service.isTerminated()); } - public void testIsTerminated_AfterStopFromSeparateThread() throws Exception { + public void testCallingStartTwice_ThrowsAssertionException() throws InterruptedException { + var latch = new CountDownLatch(1); + var service = createRequestExecutorService(latch, mock(RetryingHttpSender.class)); + + service.shutdown(); + service.start(); + latch.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS); + + assertTrue(service.isTerminated()); + var exception = expectThrows(AssertionError.class, service::start); + assertThat(exception.getMessage(), is("start() can only be called once")); + } + + public void testIsTerminated_AfterStopFromSeparateThread() { var waitToShutdown = new CountDownLatch(1); var waitToReturnFromSend = new CountDownLatch(1); @@ -127,41 +149,48 @@ public void testIsTerminated_AfterStopFromSeparateThread() throws Exception { assertTrue(service.isTerminated()); } - public void testSend_AfterShutdown_Throws() { + public void testExecute_AfterShutdown_Throws() { var service = createRequestExecutorServiceWithMocks(); service.shutdown(); + var requestManager = RequestManagerTests.createMock("id"); var listener = new PlainActionFuture(); - service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener); + service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener); var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); assertThat( thrownException.getMessage(), - is("Failed to enqueue task because the http executor service [test_service] has already shutdown") + is( + Strings.format( + "Failed to enqueue task for inference id [id] because the request service [%s] has already shutdown", + requestManager.rateLimitGrouping().hashCode() + ) + ) ); assertTrue(thrownException.isExecutorShutdown()); } - public void testSend_Throws_WhenQueueIsFull() { - var service = new RequestExecutorService( - "test_service", - threadPool, - null, - createRequestExecutorServiceSettings(1), - new SingleRequestManager(mock(RetryingHttpSender.class)) - ); + public void testExecute_Throws_WhenQueueIsFull() { + var service = new RequestExecutorService(threadPool, null, createRequestExecutorServiceSettings(1), mock(RetryingHttpSender.class)); - service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); + service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); + + var requestManager = RequestManagerTests.createMock("id"); var listener = new PlainActionFuture(); - service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener); + service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener); var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); assertThat( thrownException.getMessage(), - is("Failed to execute task because the http executor service [test_service] queue is full") + is( + Strings.format( + "Failed to execute task for inference id [id] because the request service [%s] queue is full", + requestManager.rateLimitGrouping().hashCode() + ) + ) ); assertFalse(thrownException.isExecutorShutdown()); } @@ -203,16 +232,11 @@ public void testShutdown_AllowsMultipleCalls() { assertTrue(service.isShutdown()); } - public void testSend_CallsOnFailure_WhenRequestTimesOut() { + public void testExecute_CallsOnFailure_WhenRequestTimesOut() { var service = createRequestExecutorServiceWithMocks(); var listener = new PlainActionFuture(); - service.execute( - ExecutableRequestCreatorTests.createMock(), - new DocumentsOnlyInput(List.of()), - TimeValue.timeValueNanos(1), - listener - ); + service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), TimeValue.timeValueNanos(1), listener); var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT)); @@ -222,7 +246,7 @@ public void testSend_CallsOnFailure_WhenRequestTimesOut() { ); } - public void testSend_PreservesThreadContext() throws InterruptedException, ExecutionException, TimeoutException { + public void testExecute_PreservesThreadContext() throws InterruptedException, ExecutionException, TimeoutException { var headerKey = "not empty"; var headerValue = "value"; @@ -270,7 +294,7 @@ public void onFailure(Exception e) { } }; - service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener); + service.execute(RequestManagerTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener); Future executorTermination = submitShutdownRequest(waitToShutdown, waitToReturnFromSend, service); @@ -280,11 +304,12 @@ public void onFailure(Exception e) { finishedOnResponse.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS); } - public void testSend_NotifiesTasksOfShutdown() { + public void testExecute_NotifiesTasksOfShutdown() { var service = createRequestExecutorServiceWithMocks(); + var requestManager = RequestManagerTests.createMock(mock(RequestSender.class), "id"); var listener = new PlainActionFuture(); - service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener); + service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener); service.shutdown(); service.start(); @@ -293,47 +318,62 @@ public void testSend_NotifiesTasksOfShutdown() { assertThat( thrownException.getMessage(), - is("Failed to send request, queue service [test_service] has shutdown prior to executing request") + is( + Strings.format( + "Failed to send request, request service [%s] for inference id [id] has shutdown prior to executing request", + requestManager.rateLimitGrouping().hashCode() + ) + ) ); assertTrue(thrownException.isExecutorShutdown()); assertTrue(service.isTerminated()); } - public void testQueueTake_DoesNotCauseServiceToTerminate_WhenItThrows() throws InterruptedException { + public void testQueuePoll_DoesNotCauseServiceToTerminate_WhenItThrows() throws InterruptedException { @SuppressWarnings("unchecked") BlockingQueue queue = mock(LinkedBlockingQueue.class); + var requestSender = mock(RetryingHttpSender.class); + var service = new RequestExecutorService( - getTestName(), threadPool, mockQueueCreator(queue), null, createRequestExecutorServiceSettingsEmpty(), - new SingleRequestManager(mock(RetryingHttpSender.class)) + requestSender, + Clock.systemUTC(), + RequestExecutorService.DEFAULT_SLEEPER, + RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR ); - when(queue.take()).thenThrow(new ElasticsearchException("failed")).thenAnswer(invocation -> { + PlainActionFuture listener = new PlainActionFuture<>(); + var requestManager = RequestManagerTests.createMock(requestSender, "id"); + service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener); + + when(queue.poll()).thenThrow(new ElasticsearchException("failed")).thenAnswer(invocation -> { service.shutdown(); return null; }); service.start(); assertTrue(service.isTerminated()); - verify(queue, times(2)).take(); } - public void testQueueTake_ThrowingInterruptedException_TerminatesService() throws Exception { + public void testSleep_ThrowingInterruptedException_TerminatesService() throws Exception { @SuppressWarnings("unchecked") BlockingQueue queue = mock(LinkedBlockingQueue.class); - when(queue.take()).thenThrow(new InterruptedException("failed")); + var sleeper = mock(RequestExecutorService.Sleeper.class); + doThrow(new InterruptedException("failed")).when(sleeper).sleep(any()); var service = new RequestExecutorService( - getTestName(), threadPool, mockQueueCreator(queue), null, createRequestExecutorServiceSettingsEmpty(), - new SingleRequestManager(mock(RetryingHttpSender.class)) + mock(RetryingHttpSender.class), + Clock.systemUTC(), + sleeper, + RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR ); Future executorTermination = threadPool.generic().submit(() -> { @@ -347,66 +387,30 @@ public void testQueueTake_ThrowingInterruptedException_TerminatesService() throw executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS); assertTrue(service.isTerminated()); - verify(queue, times(1)).take(); - } - - public void testQueueTake_RejectsTask_WhenServiceShutsDown() throws Exception { - var mockTask = mock(RejectableTask.class); - @SuppressWarnings("unchecked") - BlockingQueue queue = mock(LinkedBlockingQueue.class); - - var service = new RequestExecutorService( - "test_service", - threadPool, - mockQueueCreator(queue), - null, - createRequestExecutorServiceSettingsEmpty(), - new SingleRequestManager(mock(RetryingHttpSender.class)) - ); - - doAnswer(invocation -> { - service.shutdown(); - return mockTask; - }).doReturn(new NoopTask()).when(queue).take(); - - service.start(); - - assertTrue(service.isTerminated()); - verify(queue, times(1)).take(); - - ArgumentCaptor argument = ArgumentCaptor.forClass(Exception.class); - verify(mockTask, times(1)).onRejection(argument.capture()); - assertThat(argument.getValue(), instanceOf(EsRejectedExecutionException.class)); - assertThat( - argument.getValue().getMessage(), - is("Failed to send request, queue service [test_service] has shutdown prior to executing request") - ); - - var rejectionException = (EsRejectedExecutionException) argument.getValue(); - assertTrue(rejectionException.isExecutorShutdown()); } public void testChangingCapacity_SetsCapacityToTwo() throws ExecutionException, InterruptedException, TimeoutException { var requestSender = mock(RetryingHttpSender.class); var settings = createRequestExecutorServiceSettings(1); - var service = new RequestExecutorService("test_service", threadPool, null, settings, new SingleRequestManager(requestSender)); + var service = new RequestExecutorService(threadPool, null, settings, requestSender); - service.execute( - ExecutableRequestCreatorTests.createMock(requestSender), - new DocumentsOnlyInput(List.of()), - null, - new PlainActionFuture<>() - ); + service.execute(RequestManagerTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); assertThat(service.queueSize(), is(1)); PlainActionFuture listener = new PlainActionFuture<>(); - service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener); + var requestManager = RequestManagerTests.createMock(requestSender, "id"); + service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener); var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); assertThat( thrownException.getMessage(), - is("Failed to execute task because the http executor service [test_service] queue is full") + is( + Strings.format( + "Failed to execute task for inference id [id] because the request service [%s] queue is full", + requestManager.rateLimitGrouping().hashCode() + ) + ) ); settings.setQueueCapacity(2); @@ -426,7 +430,7 @@ public void testChangingCapacity_SetsCapacityToTwo() throws ExecutionException, executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS); assertTrue(service.isTerminated()); - assertThat(service.remainingQueueCapacity(), is(2)); + assertThat(service.remainingQueueCapacity(requestManager), is(2)); } public void testChangingCapacity_DoesNotRejectsOverflowTasks_BecauseOfQueueFull() throws ExecutionException, InterruptedException, @@ -434,23 +438,24 @@ public void testChangingCapacity_DoesNotRejectsOverflowTasks_BecauseOfQueueFull( var requestSender = mock(RetryingHttpSender.class); var settings = createRequestExecutorServiceSettings(3); - var service = new RequestExecutorService("test_service", threadPool, null, settings, new SingleRequestManager(requestSender)); + var service = new RequestExecutorService(threadPool, null, settings, requestSender); service.execute( - ExecutableRequestCreatorTests.createMock(requestSender), + RequestManagerTests.createMock(requestSender, "id"), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>() ); service.execute( - ExecutableRequestCreatorTests.createMock(requestSender), + RequestManagerTests.createMock(requestSender, "id"), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>() ); PlainActionFuture listener = new PlainActionFuture<>(); - service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener); + var requestManager = RequestManagerTests.createMock(requestSender, "id"); + service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener); assertThat(service.queueSize(), is(3)); settings.setQueueCapacity(1); @@ -470,7 +475,7 @@ public void testChangingCapacity_DoesNotRejectsOverflowTasks_BecauseOfQueueFull( executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS); assertTrue(service.isTerminated()); - assertThat(service.remainingQueueCapacity(), is(1)); + assertThat(service.remainingQueueCapacity(requestManager), is(1)); assertThat(service.queueSize(), is(0)); var thrownException = expectThrows( @@ -479,7 +484,12 @@ public void testChangingCapacity_DoesNotRejectsOverflowTasks_BecauseOfQueueFull( ); assertThat( thrownException.getMessage(), - is("Failed to send request, queue service [test_service] has shutdown prior to executing request") + is( + Strings.format( + "Failed to send request, request service [%s] for inference id [id] has shutdown prior to executing request", + requestManager.rateLimitGrouping().hashCode() + ) + ) ); assertTrue(thrownException.isExecutorShutdown()); } @@ -489,23 +499,24 @@ public void testChangingCapacity_ToZero_SetsQueueCapacityToUnbounded() throws IO var requestSender = mock(RetryingHttpSender.class); var settings = createRequestExecutorServiceSettings(1); - var service = new RequestExecutorService("test_service", threadPool, null, settings, new SingleRequestManager(requestSender)); + var service = new RequestExecutorService(threadPool, null, settings, requestSender); + var requestManager = RequestManagerTests.createMock(requestSender); - service.execute( - ExecutableRequestCreatorTests.createMock(requestSender), - new DocumentsOnlyInput(List.of()), - null, - new PlainActionFuture<>() - ); + service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); assertThat(service.queueSize(), is(1)); PlainActionFuture listener = new PlainActionFuture<>(); - service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener); + service.execute(RequestManagerTests.createMock(requestSender, "id"), new DocumentsOnlyInput(List.of()), null, listener); var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); assertThat( thrownException.getMessage(), - is("Failed to execute task because the http executor service [test_service] queue is full") + is( + Strings.format( + "Failed to execute task for inference id [id] because the request service [%s] queue is full", + requestManager.rateLimitGrouping().hashCode() + ) + ) ); settings.setQueueCapacity(0); @@ -525,7 +536,133 @@ public void testChangingCapacity_ToZero_SetsQueueCapacityToUnbounded() throws IO executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS); assertTrue(service.isTerminated()); - assertThat(service.remainingQueueCapacity(), is(Integer.MAX_VALUE)); + assertThat(service.remainingQueueCapacity(requestManager), is(Integer.MAX_VALUE)); + } + + public void testDoesNotExecuteTask_WhenCannotReserveTokens() { + var mockRateLimiter = mock(RateLimiter.class); + RequestExecutorService.RateLimiterCreator rateLimiterCreator = (a, b, c) -> mockRateLimiter; + + var requestSender = mock(RetryingHttpSender.class); + var settings = createRequestExecutorServiceSettings(1); + var service = new RequestExecutorService( + threadPool, + RequestExecutorService.DEFAULT_QUEUE_CREATOR, + null, + settings, + requestSender, + Clock.systemUTC(), + RequestExecutorService.DEFAULT_SLEEPER, + rateLimiterCreator + ); + var requestManager = RequestManagerTests.createMock(requestSender); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener); + + doAnswer(invocation -> { + service.shutdown(); + return TimeValue.timeValueDays(1); + }).when(mockRateLimiter).timeToReserve(anyInt()); + + service.start(); + + verifyNoInteractions(requestSender); + } + + public void testDoesNotExecuteTask_WhenCannotReserveTokens_AndThenCanReserve_AndExecutesTask() { + var mockRateLimiter = mock(RateLimiter.class); + when(mockRateLimiter.reserve(anyInt())).thenReturn(TimeValue.timeValueDays(0)); + + RequestExecutorService.RateLimiterCreator rateLimiterCreator = (a, b, c) -> mockRateLimiter; + + var requestSender = mock(RetryingHttpSender.class); + var settings = createRequestExecutorServiceSettings(1); + var service = new RequestExecutorService( + threadPool, + RequestExecutorService.DEFAULT_QUEUE_CREATOR, + null, + settings, + requestSender, + Clock.systemUTC(), + RequestExecutorService.DEFAULT_SLEEPER, + rateLimiterCreator + ); + var requestManager = RequestManagerTests.createMock(requestSender); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener); + + when(mockRateLimiter.timeToReserve(anyInt())).thenReturn(TimeValue.timeValueDays(1)).thenReturn(TimeValue.timeValueDays(0)); + + doAnswer(invocation -> { + service.shutdown(); + return Void.TYPE; + }).when(requestSender).send(any(), any(), any(), any(), any(), any()); + + service.start(); + + verify(requestSender, times(1)).send(any(), any(), any(), any(), any(), any()); + } + + public void testRemovesRateLimitGroup_AfterStaleDuration() { + var now = Instant.now(); + var clock = mock(Clock.class); + when(clock.instant()).thenReturn(now); + + var requestSender = mock(RetryingHttpSender.class); + var settings = createRequestExecutorServiceSettings(2, TimeValue.timeValueDays(1)); + var service = new RequestExecutorService( + threadPool, + RequestExecutorService.DEFAULT_QUEUE_CREATOR, + null, + settings, + requestSender, + clock, + RequestExecutorService.DEFAULT_SLEEPER, + RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR + ); + var requestManager = RequestManagerTests.createMock(requestSender, "id1"); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener); + + assertThat(service.numberOfRateLimitGroups(), is(1)); + // the time is moved to after the stale duration, so now we should remove this grouping + when(clock.instant()).thenReturn(now.plus(Duration.ofDays(2))); + service.removeStaleGroupings(); + assertThat(service.numberOfRateLimitGroups(), is(0)); + + var requestManager2 = RequestManagerTests.createMock(requestSender, "id2"); + service.execute(requestManager2, new DocumentsOnlyInput(List.of()), null, listener); + + assertThat(service.numberOfRateLimitGroups(), is(1)); + } + + public void testStartsCleanupThread() { + var mockThreadPool = mock(ThreadPool.class); + + when(mockThreadPool.scheduleWithFixedDelay(any(Runnable.class), any(), any())).thenReturn(mock(Scheduler.Cancellable.class)); + + var requestSender = mock(RetryingHttpSender.class); + var settings = createRequestExecutorServiceSettings(2, TimeValue.timeValueDays(1)); + var service = new RequestExecutorService( + mockThreadPool, + RequestExecutorService.DEFAULT_QUEUE_CREATOR, + null, + settings, + requestSender, + Clock.systemUTC(), + RequestExecutorService.DEFAULT_SLEEPER, + RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR + ); + + service.shutdown(); + service.start(); + + ArgumentCaptor argument = ArgumentCaptor.forClass(TimeValue.class); + verify(mockThreadPool, times(1)).scheduleWithFixedDelay(any(Runnable.class), argument.capture(), any()); + assertThat(argument.getValue(), is(TimeValue.timeValueDays(1))); } private Future submitShutdownRequest( @@ -552,12 +689,6 @@ private RequestExecutorService createRequestExecutorServiceWithMocks() { } private RequestExecutorService createRequestExecutorService(@Nullable CountDownLatch startupLatch, RetryingHttpSender requestSender) { - return new RequestExecutorService( - "test_service", - threadPool, - startupLatch, - createRequestExecutorServiceSettingsEmpty(), - new SingleRequestManager(requestSender) - ); + return new RequestExecutorService(threadPool, startupLatch, createRequestExecutorServiceSettingsEmpty(), requestSender); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/ExecutableRequestCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManagerTests.java similarity index 56% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/ExecutableRequestCreatorTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManagerTests.java index 31297ed432ef5..291de740aca34 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/ExecutableRequestCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManagerTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.request.RequestTests; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyList; @@ -21,34 +22,47 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class ExecutableRequestCreatorTests { +public class RequestManagerTests { public static RequestManager createMock() { - var mockCreator = mock(RequestManager.class); - when(mockCreator.create(any(), anyList(), any(), any(), any(), any())).thenReturn(() -> {}); + return createMock(mock(RequestSender.class)); + } - return mockCreator; + public static RequestManager createMock(String inferenceEntityId) { + return createMock(mock(RequestSender.class), inferenceEntityId); } public static RequestManager createMock(RequestSender requestSender) { - return createMock(requestSender, "id"); + return createMock(requestSender, "id", new RateLimitSettings(1)); + } + + public static RequestManager createMock(RequestSender requestSender, String inferenceEntityId) { + return createMock(requestSender, inferenceEntityId, new RateLimitSettings(1)); } - public static RequestManager createMock(RequestSender requestSender, String modelId) { - var mockCreator = mock(RequestManager.class); + public static RequestManager createMock(RequestSender requestSender, String inferenceEntityId, RateLimitSettings settings) { + var mockManager = mock(RequestManager.class); doAnswer(invocation -> { @SuppressWarnings("unchecked") - ActionListener listener = (ActionListener) invocation.getArguments()[5]; - return (Runnable) () -> requestSender.send( + ActionListener listener = (ActionListener) invocation.getArguments()[4]; + requestSender.send( mock(Logger.class), - RequestTests.mockRequest(modelId), + RequestTests.mockRequest(inferenceEntityId), HttpClientContext.create(), () -> false, mock(ResponseHandler.class), listener ); - }).when(mockCreator).create(any(), anyList(), any(), any(), any(), any()); - return mockCreator; + return Void.TYPE; + }).when(mockManager).execute(any(), anyList(), any(), any(), any()); + + // just return something consistent so the hashing works + when(mockManager.rateLimitGrouping()).thenReturn(inferenceEntityId); + + when(mockManager.rateLimitSettings()).thenReturn(settings); + when(mockManager.inferenceEntityId()).thenReturn(inferenceEntityId); + + return mockManager; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/SingleRequestManagerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/SingleRequestManagerTests.java deleted file mode 100644 index 55965bc2354d3..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/SingleRequestManagerTests.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.http.sender; - -import org.apache.http.client.protocol.HttpClientContext; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender; - -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verifyNoInteractions; -import static org.mockito.Mockito.when; - -public class SingleRequestManagerTests extends ESTestCase { - public void testExecute_DoesNotCallRequestCreatorCreate_WhenInputIsNull() { - var requestCreator = mock(RequestManager.class); - var request = mock(InferenceRequest.class); - when(request.getRequestCreator()).thenReturn(requestCreator); - - new SingleRequestManager(mock(RetryingHttpSender.class)).execute(mock(InferenceRequest.class), HttpClientContext.create()); - verifyNoInteractions(requestCreator); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index ee3403492c423..974b31e73b499 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -33,7 +33,6 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -59,7 +58,7 @@ public void testStart_InitializesTheSender() throws IOException { var sender = mock(Sender.class); var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender(anyString())).thenReturn(sender); + when(factory.createSender()).thenReturn(sender); try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool))) { PlainActionFuture listener = new PlainActionFuture<>(); @@ -67,7 +66,7 @@ public void testStart_InitializesTheSender() throws IOException { listener.actionGet(TIMEOUT); verify(sender, times(1)).start(); - verify(factory, times(1)).createSender(anyString()); + verify(factory, times(1)).createSender(); } verify(sender, times(1)).close(); @@ -79,7 +78,7 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep var sender = mock(Sender.class); var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender(anyString())).thenReturn(sender); + when(factory.createSender()).thenReturn(sender); try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool))) { PlainActionFuture listener = new PlainActionFuture<>(); @@ -89,7 +88,7 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep service.start(mock(Model.class), listener); listener.actionGet(TIMEOUT); - verify(factory, times(1)).createSender(anyString()); + verify(factory, times(1)).createSender(); verify(sender, times(2)).start(); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index 5869366ac2e22..cacbba82446f1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -76,7 +76,6 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -819,7 +818,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureAiStudioModel() throws IOExc var sender = mock(Sender.class); var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender(anyString())).thenReturn(sender); + when(factory.createSender()).thenReturn(sender); var mockModel = getInvalidModel("model_id", "service_name"); @@ -841,7 +840,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureAiStudioModel() throws IOExc is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") ); - verify(factory, times(1)).createSender(anyString()); + verify(factory, times(1)).createSender(); verify(sender, times(1)).start(); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionServiceSettingsTests.java index 79d6e384d7693..d46a5f190017a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionServiceSettingsTests.java @@ -112,7 +112,8 @@ public void testToFilteredXContent_WritesAllValues() throws IOException { String xContentResult = Strings.toString(builder); assertThat(xContentResult, CoreMatchers.is(""" - {"target":"target_value","provider":"openai","endpoint_type":"token"}""")); + {"target":"target_value","provider":"openai","endpoint_type":"token",""" + """ + "rate_limit":{"requests_per_minute":3}}""")); } public static HashMap createRequestSettingsMap(String target, String provider, String endpointType) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsServiceSettingsTests.java index 283bfa1490df2..a592dd6e1f956 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsServiceSettingsTests.java @@ -295,7 +295,7 @@ public void testToFilteredXContent_WritesAllValues_ExceptDimensionsSetByUser() t assertThat(xContentResult, CoreMatchers.is(""" {"target":"target_value","provider":"openai","endpoint_type":"token",""" + """ - "dimensions":1024,"max_input_tokens":512}""")); + "rate_limit":{"requests_per_minute":3},"dimensions":1024,"max_input_tokens":512}""")); } public static HashMap createRequestSettingsMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index 9fe8b472b22a5..bb3407056d573 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -73,7 +73,6 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -594,7 +593,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureOpenAiModel() throws IOExcep var sender = mock(Sender.class); var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender(anyString())).thenReturn(sender); + when(factory.createSender()).thenReturn(sender); var mockModel = getInvalidModel("model_id", "service_name"); @@ -616,7 +615,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureOpenAiModel() throws IOExcep is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") ); - verify(factory, times(1)).createSender(anyString()); + verify(factory, times(1)).createSender(); verify(sender, times(1)).start(); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettingsTests.java index 46e514c8b16c4..797cad8f300ae 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettingsTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields; import java.io.IOException; @@ -46,7 +47,8 @@ public void testFromMap_Request_CreatesSettingsCorrectly() { AzureOpenAiServiceFields.API_VERSION, apiVersion ) - ) + ), + ConfigurationParseContext.PERSISTENT ); assertThat(serviceSettings, is(new AzureOpenAiCompletionServiceSettings(resourceName, deploymentId, apiVersion, null))); @@ -63,18 +65,6 @@ public void testToXContent_WritesAllValues() throws IOException { {"resource_name":"resource","deployment_id":"deployment","api_version":"2024","rate_limit":{"requests_per_minute":120}}""")); } - public void testToFilteredXContent_WritesAllValues_Except_RateLimit() throws IOException { - var entity = new AzureOpenAiCompletionServiceSettings("resource", "deployment", "2024", null); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - var filteredXContent = entity.getFilteredXContentObject(); - filteredXContent.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, is(""" - {"resource_name":"resource","deployment_id":"deployment","api_version":"2024"}""")); - } - @Override protected Writeable.Reader instanceReader() { return AzureOpenAiCompletionServiceSettings::new; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettingsTests.java index f4c6f9b2a4f07..cbb9eea223802 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettingsTests.java @@ -389,7 +389,7 @@ public void testToXContent_WritesAllValues() throws IOException { "dimensions":1024,"max_input_tokens":512,"rate_limit":{"requests_per_minute":3},"dimensions_set_by_user":false}""")); } - public void testToFilteredXContent_WritesAllValues_Except_DimensionsSetByUser_RateLimit() throws IOException { + public void testToFilteredXContent_WritesAllValues_Except_DimensionsSetByUser() throws IOException { var entity = new AzureOpenAiEmbeddingsServiceSettings( "resource", "deployment", @@ -408,7 +408,7 @@ public void testToFilteredXContent_WritesAllValues_Except_DimensionsSetByUser_Ra assertThat(xContentResult, is(""" {"resource_name":"resource","deployment_id":"deployment","api_version":"apiVersion",""" + """ - "dimensions":1024,"max_input_tokens":512}""")); + "dimensions":1024,"max_input_tokens":512,"rate_limit":{"requests_per_minute":1}}""")); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index f06fee4b0b9c4..902d96be29738 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -73,7 +73,6 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -613,7 +612,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotCohereModel() throws IOException var sender = mock(Sender.class); var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender(anyString())).thenReturn(sender); + when(factory.createSender()).thenReturn(sender); var mockModel = getInvalidModel("model_id", "service_name"); @@ -635,7 +634,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotCohereModel() throws IOException is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") ); - verify(factory, times(1)).createSender(anyString()); + verify(factory, times(1)).createSender(); verify(sender, times(1)).start(); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModelTests.java index aac04e301ece7..b9fc7ee7b9952 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModelTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import java.util.HashMap; @@ -28,7 +29,8 @@ public void testCreateModel_AlwaysWithEmptyTaskSettings() { "service", new HashMap<>(Map.of()), new HashMap<>(Map.of("model", "overridden model")), - null + null, + ConfigurationParseContext.PERSISTENT ); assertThat(model.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettingsTests.java index f4cab3c2b0f1e..ed8bc90d32140 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettingsTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; @@ -34,7 +35,8 @@ public void testFromMap_WithRateLimitSettingsNull() { var model = "model"; var serviceSettings = CohereCompletionServiceSettings.fromMap( - new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.MODEL_ID, model)) + new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.MODEL_ID, model)), + ConfigurationParseContext.PERSISTENT ); assertThat(serviceSettings, is(new CohereCompletionServiceSettings(url, model, null))); @@ -55,7 +57,8 @@ public void testFromMap_WithRateLimitSettings() { RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, requestsPerMinute)) ) - ) + ), + ConfigurationParseContext.PERSISTENT ); assertThat(serviceSettings, is(new CohereCompletionServiceSettings(url, model, new RateLimitSettings(requestsPerMinute)))); @@ -72,18 +75,6 @@ public void testToXContent_WritesAllValues() throws IOException { {"url":"url","model_id":"model","rate_limit":{"requests_per_minute":3}}""")); } - public void testToXContent_WithFilteredObject_WritesAllValues_Except_RateLimit() throws IOException { - var serviceSettings = new CohereCompletionServiceSettings("url", "model", new RateLimitSettings(3)); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - var filteredXContent = serviceSettings.getFilteredXContentObject(); - filteredXContent.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, is(""" - {"url":"url","model_id":"model"}""")); - } - @Override protected Writeable.Reader instanceReader() { return CohereCompletionServiceSettings::new; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java index 6f8fe6344b57f..73ebd6c6c0505 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java @@ -331,21 +331,6 @@ public void testToXContent_WritesAllValues() throws IOException { "rate_limit":{"requests_per_minute":3},"embedding_type":"byte"}""")); } - public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException { - var serviceSettings = new CohereEmbeddingsServiceSettings( - new CohereServiceSettings("url", SimilarityMeasure.COSINE, 5, 10, "model_id", new RateLimitSettings(3)), - CohereEmbeddingType.INT8 - ); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - var filteredXContent = serviceSettings.getFilteredXContentObject(); - filteredXContent.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - assertThat(xContentResult, is(""" - {"url":"url","similarity":"cosine","dimensions":5,"max_input_tokens":10,"model_id":"model_id",""" + """ - "embedding_type":"byte"}""")); - } - @Override protected Writeable.Reader instanceReader() { return CohereEmbeddingsServiceSettings::new; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettingsTests.java index 4943ddf74fda1..1ce5a9fb12833 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettingsTests.java @@ -51,20 +51,6 @@ public void testToXContent_WritesAllValues() throws IOException { "rate_limit":{"requests_per_minute":3}}""")); } - public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException { - var serviceSettings = new CohereRerankServiceSettings( - new CohereServiceSettings("url", SimilarityMeasure.COSINE, 5, 10, "model_id", new RateLimitSettings(3)) - ); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - var filteredXContent = serviceSettings.getFilteredXContentObject(); - filteredXContent.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - // TODO we probably shouldn't allow configuring these fields for reranking - assertThat(xContentResult, is(""" - {"url":"url","similarity":"cosine","dimensions":5,"max_input_tokens":10,"model_id":"model_id"}""")); - } - @Override protected Writeable.Reader instanceReader() { return CohereRerankServiceSettings::new; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index 32e912ff8529a..110276e63d077 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -73,7 +73,6 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.endsWith; import static org.hamcrest.Matchers.hasSize; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -494,7 +493,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotGoogleAiStudioModel() throws IOEx var sender = mock(Sender.class); var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender(anyString())).thenReturn(sender); + when(factory.createSender()).thenReturn(sender); var mockModel = getInvalidModel("model_id", "service_name"); @@ -516,7 +515,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotGoogleAiStudioModel() throws IOEx is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") ); - verify(factory, times(1)).createSender(anyString()); + verify(factory, times(1)).createSender(); verify(sender, times(1)).start(); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModelTests.java index 025317fbe025a..f4c13db78c4bc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModelTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import java.net.URISyntaxException; @@ -28,7 +29,8 @@ public void testCreateModel_AlwaysWithEmptyTaskSettings() { "service", new HashMap<>(Map.of("model_id", "model")), new HashMap<>(Map.of()), - null + null, + ConfigurationParseContext.PERSISTENT ); assertThat(model.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettingsTests.java index 46e6e60af493c..6652af26e09e1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettingsTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; @@ -31,7 +32,10 @@ public static GoogleAiStudioCompletionServiceSettings createRandom() { public void testFromMap_Request_CreatesSettingsCorrectly() { var model = "some model"; - var serviceSettings = GoogleAiStudioCompletionServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.MODEL_ID, model))); + var serviceSettings = GoogleAiStudioCompletionServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, model)), + ConfigurationParseContext.PERSISTENT + ); assertThat(serviceSettings, is(new GoogleAiStudioCompletionServiceSettings(model, null))); } @@ -47,18 +51,6 @@ public void testToXContent_WritesAllValues() throws IOException { {"model_id":"model","rate_limit":{"requests_per_minute":360}}""")); } - public void testToFilteredXContent_WritesAllValues_Except_RateLimit() throws IOException { - var entity = new GoogleAiStudioCompletionServiceSettings("model", null); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - var filteredXContent = entity.getFilteredXContentObject(); - filteredXContent.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, is(""" - {"model_id":"model"}""")); - } - @Override protected Writeable.Reader instanceReader() { return GoogleAiStudioCompletionServiceSettings::new; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsServiceSettingsTests.java index b5fbd28b476ba..cc195333adfd4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsServiceSettingsTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; @@ -55,7 +56,8 @@ public void testFromMap_Request_CreatesSettingsCorrectly() { ServiceFields.SIMILARITY, similarity.toString() ) - ) + ), + ConfigurationParseContext.PERSISTENT ); assertThat(serviceSettings, is(new GoogleAiStudioEmbeddingsServiceSettings(model, maxInputTokens, dims, similarity, null))); @@ -80,23 +82,6 @@ public void testToXContent_WritesAllValues() throws IOException { }""")); } - public void testToFilteredXContent_WritesAllValues_Except_RateLimit() throws IOException { - var entity = new GoogleAiStudioEmbeddingsServiceSettings("model", 1024, 8, SimilarityMeasure.DOT_PRODUCT, null); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - var filteredXContent = entity.getFilteredXContentObject(); - filteredXContent.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" - { - "model_id":"model", - "max_input_tokens": 1024, - "dimensions": 8, - "similarity": "dot_product" - }""")); - } - @Override protected Writeable.Reader instanceReader() { return GoogleAiStudioEmbeddingsServiceSettings::new; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java index 398b21312a03a..fd7e1b48b7e03 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.junit.After; import org.junit.Before; @@ -33,7 +34,6 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.CoreMatchers.is; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -59,7 +59,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotHuggingFaceModel() throws IOExcep var sender = mock(Sender.class); var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender(anyString())).thenReturn(sender); + when(factory.createSender()).thenReturn(sender); var mockModel = getInvalidModel("model_id", "service_name"); @@ -81,7 +81,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotHuggingFaceModel() throws IOExcep is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") ); - verify(factory, times(1)).createSender(anyString()); + verify(factory, times(1)).createSender(); verify(sender, times(1)).start(); } @@ -111,7 +111,8 @@ protected HuggingFaceModel createModel( TaskType taskType, Map serviceSettings, Map secretSettings, - String failureMessage + String failureMessage, + ConfigurationParseContext context ) { return null; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettingsTests.java index 91b91593adee7..04e9697b08877 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettingsTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -57,7 +58,10 @@ public void testFromMap() { var dims = 384; var maxInputTokens = 128; { - var serviceSettings = HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url))); + var serviceSettings = HuggingFaceServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.URL, url)), + ConfigurationParseContext.PERSISTENT + ); assertThat(serviceSettings, is(new HuggingFaceServiceSettings(url))); } { @@ -73,7 +77,8 @@ public void testFromMap() { ServiceFields.MAX_INPUT_TOKENS, maxInputTokens ) - ) + ), + ConfigurationParseContext.PERSISTENT ); assertThat( serviceSettings, @@ -95,7 +100,8 @@ public void testFromMap() { RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 3)) ) - ) + ), + ConfigurationParseContext.PERSISTENT ); assertThat( serviceSettings, @@ -105,7 +111,10 @@ public void testFromMap() { } public void testFromMap_MissingUrl_ThrowsError() { - var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceServiceSettings.fromMap(new HashMap<>())); + var thrownException = expectThrows( + ValidationException.class, + () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(), ConfigurationParseContext.PERSISTENT) + ); assertThat( thrownException.getMessage(), @@ -118,7 +127,7 @@ public void testFromMap_MissingUrl_ThrowsError() { public void testFromMap_EmptyUrl_ThrowsError() { var thrownException = expectThrows( ValidationException.class, - () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, ""))) + () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, "")), ConfigurationParseContext.PERSISTENT) ); assertThat( @@ -136,7 +145,7 @@ public void testFromMap_InvalidUrl_ThrowsError() { var url = "https://www.abc^.com"; var thrownException = expectThrows( ValidationException.class, - () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url))) + () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url)), ConfigurationParseContext.PERSISTENT) ); assertThat( @@ -152,7 +161,10 @@ public void testFromMap_InvalidSimilarity_ThrowsError() { var similarity = "by_size"; var thrownException = expectThrows( ValidationException.class, - () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.SIMILARITY, similarity))) + () -> HuggingFaceServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.SIMILARITY, similarity)), + ConfigurationParseContext.PERSISTENT + ) ); assertThat( @@ -175,18 +187,6 @@ public void testToXContent_WritesAllValues() throws IOException { {"url":"url","rate_limit":{"requests_per_minute":3}}""")); } - public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException { - var serviceSettings = new HuggingFaceServiceSettings(ServiceUtils.createUri("url"), null, null, null, new RateLimitSettings(3)); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - var filteredXContent = serviceSettings.getFilteredXContentObject(); - filteredXContent.toXContent(builder, null); - String xContentResult = org.elasticsearch.common.Strings.toString(builder); - - assertThat(xContentResult, is(""" - {"url":"url"}""")); - } - @Override protected Writeable.Reader instanceReader() { return HuggingFaceServiceSettings::new; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettingsTests.java index 57f9c59b65e12..2a44429687fb3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettingsTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -32,7 +33,10 @@ public static HuggingFaceElserServiceSettings createRandom() { public void testFromMap() { var url = "https://www.abc.com"; - var serviceSettings = HuggingFaceElserServiceSettings.fromMap(new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, url))); + var serviceSettings = HuggingFaceElserServiceSettings.fromMap( + new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, url)), + ConfigurationParseContext.PERSISTENT + ); assertThat(new HuggingFaceElserServiceSettings(url), is(serviceSettings)); } @@ -40,7 +44,10 @@ public void testFromMap() { public void testFromMap_EmptyUrl_ThrowsError() { var thrownException = expectThrows( ValidationException.class, - () -> HuggingFaceElserServiceSettings.fromMap(new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, ""))) + () -> HuggingFaceElserServiceSettings.fromMap( + new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, "")), + ConfigurationParseContext.PERSISTENT + ) ); assertThat( @@ -55,7 +62,10 @@ public void testFromMap_EmptyUrl_ThrowsError() { } public void testFromMap_MissingUrl_ThrowsError() { - var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceElserServiceSettings.fromMap(new HashMap<>())); + var thrownException = expectThrows( + ValidationException.class, + () -> HuggingFaceElserServiceSettings.fromMap(new HashMap<>(), ConfigurationParseContext.PERSISTENT) + ); assertThat( thrownException.getMessage(), @@ -72,7 +82,10 @@ public void testFromMap_InvalidUrl_ThrowsError() { var url = "https://www.abc^.com"; var thrownException = expectThrows( ValidationException.class, - () -> HuggingFaceElserServiceSettings.fromMap(new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, url))) + () -> HuggingFaceElserServiceSettings.fromMap( + new HashMap<>(Map.of(HuggingFaceElserServiceSettings.URL, url)), + ConfigurationParseContext.PERSISTENT + ) ); assertThat( @@ -98,18 +111,6 @@ public void testToXContent_WritesAllValues() throws IOException { {"url":"url","max_input_tokens":512,"rate_limit":{"requests_per_minute":3}}""")); } - public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException { - var serviceSettings = new HuggingFaceElserServiceSettings(ServiceUtils.createUri("url"), new RateLimitSettings(3)); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - var filteredXContent = serviceSettings.getFilteredXContentObject(); - filteredXContent.toXContent(builder, null); - String xContentResult = org.elasticsearch.common.Strings.toString(builder); - - assertThat(xContentResult, is(""" - {"url":"url","max_input_tokens":512}""")); - } - @Override protected Writeable.Reader instanceReader() { return HuggingFaceElserServiceSettings::new; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index 3ead273e78110..624b24e611340 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -67,7 +67,6 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -393,7 +392,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotMistralEmbeddingsModel() throws I var sender = mock(Sender.class); var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender(anyString())).thenReturn(sender); + when(factory.createSender()).thenReturn(sender); var mockModel = getInvalidModel("model_id", "service_name"); @@ -415,7 +414,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotMistralEmbeddingsModel() throws I is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") ); - verify(factory, times(1)).createSender(anyString()); + verify(factory, times(1)).createSender(); verify(sender, times(1)).start(); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettingsTests.java index 13f43a5f31ad3..076986acdcee6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettingsTests.java @@ -98,18 +98,6 @@ public void testToXContent_WritesAllValues() throws IOException { "rate_limit":{"requests_per_minute":3}}""")); } - public void testToFilteredXContent_WritesFilteredValues() throws IOException { - var entity = new MistralEmbeddingsServiceSettings("model_name", 1024, 512, null, new RateLimitSettings(3)); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - var filteredXContent = entity.getFilteredXContentObject(); - filteredXContent.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, CoreMatchers.is(""" - {"model":"model_name","dimensions":1024,"max_input_tokens":512}""")); - } - public void testStreamInputAndOutput_WritesValuesCorrectly() throws IOException { var outputBuffer = new BytesStreamOutput(); var settings = new MistralEmbeddingsServiceSettings("model_name", 1024, 512, null, new RateLimitSettings(3)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index cbac29c452772..41995235565df 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -72,7 +72,6 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -675,7 +674,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotOpenAiModel() throws IOException var sender = mock(Sender.class); var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender(anyString())).thenReturn(sender); + when(factory.createSender()).thenReturn(sender); var mockModel = getInvalidModel("model_id", "service_name"); @@ -697,7 +696,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotOpenAiModel() throws IOException is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") ); - verify(factory, times(1)).createSender(anyString()); + verify(factory, times(1)).createSender(); verify(sender, times(1)).start(); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettingsTests.java index 186ca89426418..051a9bc6d9bef 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettingsTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields; @@ -48,7 +49,8 @@ public void testFromMap_Request_CreatesSettingsCorrectly() { ServiceFields.MAX_INPUT_TOKENS, maxInputTokens ) - ) + ), + ConfigurationParseContext.PERSISTENT ); assertThat( @@ -77,7 +79,8 @@ public void testFromMap_Request_CreatesSettingsCorrectly_WithRateLimit() { RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, rateLimit)) ) - ) + ), + ConfigurationParseContext.PERSISTENT ); assertThat( @@ -101,7 +104,8 @@ public void testFromMap_MissingUrl_DoesNotThrowException() { ServiceFields.MAX_INPUT_TOKENS, maxInputTokens ) - ) + ), + ConfigurationParseContext.PERSISTENT ); assertNull(serviceSettings.uri()); @@ -113,7 +117,10 @@ public void testFromMap_MissingUrl_DoesNotThrowException() { public void testFromMap_EmptyUrl_ThrowsError() { var thrownException = expectThrows( ValidationException.class, - () -> OpenAiChatCompletionServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, "", ServiceFields.MODEL_ID, "model"))) + () -> OpenAiChatCompletionServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.URL, "", ServiceFields.MODEL_ID, "model")), + ConfigurationParseContext.PERSISTENT + ) ); assertThat( @@ -132,7 +139,8 @@ public void testFromMap_MissingOrganization_DoesNotThrowException() { var maxInputTokens = 8192; var serviceSettings = OpenAiChatCompletionServiceSettings.fromMap( - new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId, ServiceFields.MAX_INPUT_TOKENS, maxInputTokens)) + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId, ServiceFields.MAX_INPUT_TOKENS, maxInputTokens)), + ConfigurationParseContext.PERSISTENT ); assertNull(serviceSettings.uri()); @@ -144,7 +152,8 @@ public void testFromMap_EmptyOrganization_ThrowsError() { var thrownException = expectThrows( ValidationException.class, () -> OpenAiChatCompletionServiceSettings.fromMap( - new HashMap<>(Map.of(OpenAiServiceFields.ORGANIZATION, "", ServiceFields.MODEL_ID, "model")) + new HashMap<>(Map.of(OpenAiServiceFields.ORGANIZATION, "", ServiceFields.MODEL_ID, "model")), + ConfigurationParseContext.PERSISTENT ) ); @@ -164,7 +173,8 @@ public void testFromMap_InvalidUrl_ThrowsError() { var thrownException = expectThrows( ValidationException.class, () -> OpenAiChatCompletionServiceSettings.fromMap( - new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.MODEL_ID, "model")) + new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.MODEL_ID, "model")), + ConfigurationParseContext.PERSISTENT ) ); @@ -213,19 +223,6 @@ public void testToXContent_DoesNotWriteOptionalValues() throws IOException { {"model_id":"model","rate_limit":{"requests_per_minute":500}}""")); } - public void testToXContent_WritesAllValues_Except_RateLimit() throws IOException { - var serviceSettings = new OpenAiChatCompletionServiceSettings("model", "url", "org", 1024, new RateLimitSettings(2)); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - var filteredXContent = serviceSettings.getFilteredXContentObject(); - filteredXContent.toXContent(builder, null); - String xContentResult = org.elasticsearch.common.Strings.toString(builder); - - assertThat(xContentResult, is(""" - {"model_id":"model","url":"url","organization_id":"org",""" + """ - "max_input_tokens":1024}""")); - } - @Override protected Writeable.Reader instanceReader() { return OpenAiChatCompletionServiceSettings::new; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java index 438f895fe48ad..cc0004a2d678c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java @@ -406,7 +406,7 @@ public void testToFilteredXContent_WritesAllValues_ExceptDimensionsSetByUser() t assertThat(xContentResult, is(""" {"model_id":"model","url":"url","organization_id":"org","similarity":"dot_product",""" + """ - "dimensions":1,"max_input_tokens":2}""")); + "dimensions":1,"max_input_tokens":2,"rate_limit":{"requests_per_minute":3000}}""")); } public void testToFilteredXContent_WritesAllValues_WithSpecifiedRateLimit() throws IOException { @@ -428,7 +428,7 @@ public void testToFilteredXContent_WritesAllValues_WithSpecifiedRateLimit() thro assertThat(xContentResult, is(""" {"model_id":"model","url":"url","organization_id":"org","similarity":"dot_product",""" + """ - "dimensions":1,"max_input_tokens":2}""")); + "dimensions":1,"max_input_tokens":2,"rate_limit":{"requests_per_minute":2000}}""")); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettingsTests.java index cdee7c452ff52..7e3bdd6b8e5dc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettingsTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.settings; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.Writeable; @@ -14,6 +15,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import java.io.IOException; import java.util.HashMap; @@ -49,7 +51,7 @@ public void testOf() { Map settings = new HashMap<>( Map.of(RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 100))) ); - var res = RateLimitSettings.of(settings, new RateLimitSettings(1), validation); + var res = RateLimitSettings.of(settings, new RateLimitSettings(1), validation, "test", ConfigurationParseContext.PERSISTENT); assertThat(res, is(new RateLimitSettings(100))); assertTrue(validation.validationErrors().isEmpty()); @@ -60,7 +62,7 @@ public void testOf_UsesDefaultValue_WhenRateLimit_IsAbsent() { Map settings = new HashMap<>( Map.of("abc", new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 100))) ); - var res = RateLimitSettings.of(settings, new RateLimitSettings(1), validation); + var res = RateLimitSettings.of(settings, new RateLimitSettings(1), validation, "test", ConfigurationParseContext.PERSISTENT); assertThat(res, is(new RateLimitSettings(1))); assertTrue(validation.validationErrors().isEmpty()); @@ -69,12 +71,24 @@ public void testOf_UsesDefaultValue_WhenRateLimit_IsAbsent() { public void testOf_UsesDefaultValue_WhenRequestsPerMinute_IsAbsent() { var validation = new ValidationException(); Map settings = new HashMap<>(Map.of(RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of("abc", 100)))); - var res = RateLimitSettings.of(settings, new RateLimitSettings(1), validation); + var res = RateLimitSettings.of(settings, new RateLimitSettings(1), validation, "test", ConfigurationParseContext.PERSISTENT); assertThat(res, is(new RateLimitSettings(1))); assertTrue(validation.validationErrors().isEmpty()); } + public void testOf_ThrowsException_WithUnknownField_InRequestContext() { + var validation = new ValidationException(); + Map settings = new HashMap<>(Map.of(RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of("abc", 100)))); + + var exception = expectThrows( + ElasticsearchStatusException.class, + () -> RateLimitSettings.of(settings, new RateLimitSettings(1), validation, "test", ConfigurationParseContext.REQUEST) + ); + + assertThat(exception.getMessage(), is("Model configuration contains settings [{abc=100}] unknown to the [test] service")); + } + public void testToXContent() throws IOException { var settings = new RateLimitSettings(100); From dc13b75656146b59d5e52be7fd78b288237cf7bd Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Wed, 5 Jun 2024 15:35:31 +0200 Subject: [PATCH 08/30] [Inference API] Add `text_embedding` task type to Google AI Studio docs (#109307) --- docs/reference/inference/put-inference.asciidoc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/reference/inference/put-inference.asciidoc b/docs/reference/inference/put-inference.asciidoc index f805bc0cc92f7..e7d66e930e81f 100644 --- a/docs/reference/inference/put-inference.asciidoc +++ b/docs/reference/inference/put-inference.asciidoc @@ -74,7 +74,7 @@ Available services: * `cohere`: specify the `completion`, `text_embedding` or the `rerank` task type to use the Cohere service. * `elasticsearch`: specify the `text_embedding` task type to use the E5 built-in model or text embedding models uploaded by Eland. * `elser`: specify the `sparse_embedding` task type to use the ELSER service. -* `googleaistudio`: specify the `completion` task to use the Google AI Studio service. +* `googleaistudio`: specify the `completion` or `text_embeddig` task to use the Google AI Studio service. * `hugging_face`: specify the `text_embedding` task type to use the Hugging Face service. * `openai`: specify the `completion` or `text_embedding` task type to use the OpenAI service. From 19eabb928dffb232b914926f05cb8eac14a88997 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Wed, 5 Jun 2024 15:07:01 +0000 Subject: [PATCH 09/30] Bump versions after 8.14.0 release --- .buildkite/pipelines/intake.yml | 2 +- .buildkite/pipelines/periodic-packaging.yml | 12 ++++++------ .buildkite/pipelines/periodic.yml | 16 ++++++++-------- .ci/bwcVersions | 4 ++-- .ci/snapshotBwcVersions | 3 +-- .../src/main/java/org/elasticsearch/Version.java | 2 +- .../org/elasticsearch/TransportVersions.csv | 1 + .../org/elasticsearch/index/IndexVersions.csv | 1 + 8 files changed, 21 insertions(+), 20 deletions(-) diff --git a/.buildkite/pipelines/intake.yml b/.buildkite/pipelines/intake.yml index 944230377d077..c5b079c39fbc1 100644 --- a/.buildkite/pipelines/intake.yml +++ b/.buildkite/pipelines/intake.yml @@ -56,7 +56,7 @@ steps: timeout_in_minutes: 300 matrix: setup: - BWC_VERSION: ["7.17.22", "8.13.5", "8.14.0", "8.15.0"] + BWC_VERSION: ["7.17.22", "8.14.1", "8.15.0"] agents: provider: gcp image: family/elasticsearch-ubuntu-2004 diff --git a/.buildkite/pipelines/periodic-packaging.yml b/.buildkite/pipelines/periodic-packaging.yml index 5ac361c810627..378a7c5c9c5d2 100644 --- a/.buildkite/pipelines/periodic-packaging.yml +++ b/.buildkite/pipelines/periodic-packaging.yml @@ -529,8 +529,8 @@ steps: env: BWC_VERSION: 8.12.2 - - label: "{{matrix.image}} / 8.13.5 / packaging-tests-upgrade" - command: ./.ci/scripts/packaging-test.sh -Dbwc.checkout.align=true destructiveDistroUpgradeTest.v8.13.5 + - label: "{{matrix.image}} / 8.13.4 / packaging-tests-upgrade" + command: ./.ci/scripts/packaging-test.sh -Dbwc.checkout.align=true destructiveDistroUpgradeTest.v8.13.4 timeout_in_minutes: 300 matrix: setup: @@ -543,10 +543,10 @@ steps: machineType: custom-16-32768 buildDirectory: /dev/shm/bk env: - BWC_VERSION: 8.13.5 + BWC_VERSION: 8.13.4 - - label: "{{matrix.image}} / 8.14.0 / packaging-tests-upgrade" - command: ./.ci/scripts/packaging-test.sh -Dbwc.checkout.align=true destructiveDistroUpgradeTest.v8.14.0 + - label: "{{matrix.image}} / 8.14.1 / packaging-tests-upgrade" + command: ./.ci/scripts/packaging-test.sh -Dbwc.checkout.align=true destructiveDistroUpgradeTest.v8.14.1 timeout_in_minutes: 300 matrix: setup: @@ -559,7 +559,7 @@ steps: machineType: custom-16-32768 buildDirectory: /dev/shm/bk env: - BWC_VERSION: 8.14.0 + BWC_VERSION: 8.14.1 - label: "{{matrix.image}} / 8.15.0 / packaging-tests-upgrade" command: ./.ci/scripts/packaging-test.sh -Dbwc.checkout.align=true destructiveDistroUpgradeTest.v8.15.0 diff --git a/.buildkite/pipelines/periodic.yml b/.buildkite/pipelines/periodic.yml index 7ba46f0f0951c..1726f0f29fa92 100644 --- a/.buildkite/pipelines/periodic.yml +++ b/.buildkite/pipelines/periodic.yml @@ -591,8 +591,8 @@ steps: - signal_reason: agent_stop limit: 3 - - label: 8.13.5 / bwc - command: .ci/scripts/run-gradle.sh -Dbwc.checkout.align=true v8.13.5#bwcTest + - label: 8.13.4 / bwc + command: .ci/scripts/run-gradle.sh -Dbwc.checkout.align=true v8.13.4#bwcTest timeout_in_minutes: 300 agents: provider: gcp @@ -601,7 +601,7 @@ steps: buildDirectory: /dev/shm/bk preemptible: true env: - BWC_VERSION: 8.13.5 + BWC_VERSION: 8.13.4 retry: automatic: - exit_status: "-1" @@ -610,8 +610,8 @@ steps: - signal_reason: agent_stop limit: 3 - - label: 8.14.0 / bwc - command: .ci/scripts/run-gradle.sh -Dbwc.checkout.align=true v8.14.0#bwcTest + - label: 8.14.1 / bwc + command: .ci/scripts/run-gradle.sh -Dbwc.checkout.align=true v8.14.1#bwcTest timeout_in_minutes: 300 agents: provider: gcp @@ -620,7 +620,7 @@ steps: buildDirectory: /dev/shm/bk preemptible: true env: - BWC_VERSION: 8.14.0 + BWC_VERSION: 8.14.1 retry: automatic: - exit_status: "-1" @@ -714,7 +714,7 @@ steps: setup: ES_RUNTIME_JAVA: - openjdk17 - BWC_VERSION: ["7.17.22", "8.13.5", "8.14.0", "8.15.0"] + BWC_VERSION: ["7.17.22", "8.14.1", "8.15.0"] agents: provider: gcp image: family/elasticsearch-ubuntu-2004 @@ -762,7 +762,7 @@ steps: - openjdk21 - openjdk22 - openjdk23 - BWC_VERSION: ["7.17.22", "8.13.5", "8.14.0", "8.15.0"] + BWC_VERSION: ["7.17.22", "8.14.1", "8.15.0"] agents: provider: gcp image: family/elasticsearch-ubuntu-2004 diff --git a/.ci/bwcVersions b/.ci/bwcVersions index b9afdcf23b858..3aa17cc370296 100644 --- a/.ci/bwcVersions +++ b/.ci/bwcVersions @@ -30,6 +30,6 @@ BWC_VERSION: - "8.10.4" - "8.11.4" - "8.12.2" - - "8.13.5" - - "8.14.0" + - "8.13.4" + - "8.14.1" - "8.15.0" diff --git a/.ci/snapshotBwcVersions b/.ci/snapshotBwcVersions index 213e4e93bc81d..f802829f6ec8a 100644 --- a/.ci/snapshotBwcVersions +++ b/.ci/snapshotBwcVersions @@ -1,5 +1,4 @@ BWC_VERSION: - "7.17.22" - - "8.13.5" - - "8.14.0" + - "8.14.1" - "8.15.0" diff --git a/server/src/main/java/org/elasticsearch/Version.java b/server/src/main/java/org/elasticsearch/Version.java index dc161766b7954..06e4a1dd5368d 100644 --- a/server/src/main/java/org/elasticsearch/Version.java +++ b/server/src/main/java/org/elasticsearch/Version.java @@ -174,8 +174,8 @@ public class Version implements VersionId, ToXContentFragment { public static final Version V_8_13_2 = new Version(8_13_02_99); public static final Version V_8_13_3 = new Version(8_13_03_99); public static final Version V_8_13_4 = new Version(8_13_04_99); - public static final Version V_8_13_5 = new Version(8_13_05_99); public static final Version V_8_14_0 = new Version(8_14_00_99); + public static final Version V_8_14_1 = new Version(8_14_01_99); public static final Version V_8_15_0 = new Version(8_15_00_99); public static final Version CURRENT = V_8_15_0; diff --git a/server/src/main/resources/org/elasticsearch/TransportVersions.csv b/server/src/main/resources/org/elasticsearch/TransportVersions.csv index 526f327b91c19..ef0c641bed04f 100644 --- a/server/src/main/resources/org/elasticsearch/TransportVersions.csv +++ b/server/src/main/resources/org/elasticsearch/TransportVersions.csv @@ -120,3 +120,4 @@ 8.13.2,8595000 8.13.3,8595000 8.13.4,8595001 +8.14.0,8636001 diff --git a/server/src/main/resources/org/elasticsearch/index/IndexVersions.csv b/server/src/main/resources/org/elasticsearch/index/IndexVersions.csv index 39f2a701726af..73f60f2e5ea7e 100644 --- a/server/src/main/resources/org/elasticsearch/index/IndexVersions.csv +++ b/server/src/main/resources/org/elasticsearch/index/IndexVersions.csv @@ -120,3 +120,4 @@ 8.13.2,8503000 8.13.3,8503000 8.13.4,8503000 +8.14.0,8505000 From b286921a36cc0f729771987e59ec0112d5597668 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Wed, 5 Jun 2024 15:07:52 +0000 Subject: [PATCH 10/30] Prune changelogs after 8.14.0 release --- docs/changelog/103542.yaml | 7 ------- docs/changelog/104711.yaml | 5 ----- docs/changelog/104830.yaml | 5 ----- docs/changelog/104907.yaml | 6 ------ docs/changelog/105063.yaml | 5 ----- docs/changelog/105067.yaml | 5 ----- docs/changelog/105168.yaml | 5 ----- docs/changelog/105360.yaml | 6 ------ docs/changelog/105393.yaml | 5 ----- docs/changelog/105421.yaml | 5 ----- docs/changelog/105439.yaml | 6 ------ docs/changelog/105449.yaml | 6 ------ docs/changelog/105454.yaml | 5 ----- docs/changelog/105470.yaml | 5 ----- docs/changelog/105477.yaml | 6 ------ docs/changelog/105501.yaml | 5 ----- docs/changelog/105517.yaml | 5 ----- docs/changelog/105617.yaml | 5 ----- docs/changelog/105622.yaml | 5 ----- docs/changelog/105629.yaml | 5 ----- docs/changelog/105636.yaml | 5 ----- docs/changelog/105660.yaml | 5 ----- docs/changelog/105670.yaml | 5 ----- docs/changelog/105674.yaml | 6 ------ docs/changelog/105689.yaml | 6 ------ docs/changelog/105693.yaml | 6 ------ docs/changelog/105709.yaml | 5 ----- docs/changelog/105714.yaml | 5 ----- docs/changelog/105717.yaml | 5 ----- docs/changelog/105745.yaml | 6 ------ docs/changelog/105757.yaml | 5 ----- docs/changelog/105768.yaml | 5 ----- docs/changelog/105779.yaml | 5 ----- docs/changelog/105781.yaml | 5 ----- docs/changelog/105791.yaml | 5 ----- docs/changelog/105797.yaml | 5 ----- docs/changelog/105847.yaml | 5 ----- docs/changelog/105860.yaml | 5 ----- docs/changelog/105893.yaml | 5 ----- docs/changelog/105894.yaml | 5 ----- docs/changelog/105985.yaml | 5 ----- docs/changelog/106031.yaml | 13 ------------- docs/changelog/106036.yaml | 5 ----- docs/changelog/106053.yaml | 5 ----- docs/changelog/106063.yaml | 5 ----- docs/changelog/106065.yaml | 6 ------ docs/changelog/106068.yaml | 21 --------------------- docs/changelog/106094.yaml | 5 ----- docs/changelog/106102.yaml | 5 ----- docs/changelog/106133.yaml | 19 ------------------- docs/changelog/106150.yaml | 5 ----- docs/changelog/106171.yaml | 6 ------ docs/changelog/106172.yaml | 5 ----- docs/changelog/106186.yaml | 6 ------ docs/changelog/106189.yaml | 6 ------ docs/changelog/106243.yaml | 5 ----- docs/changelog/106244.yaml | 5 ----- docs/changelog/106259.yaml | 5 ----- docs/changelog/106285.yaml | 5 ----- docs/changelog/106306.yaml | 6 ------ docs/changelog/106315.yaml | 5 ----- docs/changelog/106327.yaml | 5 ----- docs/changelog/106338.yaml | 6 ------ docs/changelog/106361.yaml | 5 ----- docs/changelog/106373.yaml | 5 ----- docs/changelog/106377.yaml | 5 ----- docs/changelog/106378.yaml | 5 ----- docs/changelog/106381.yaml | 5 ----- docs/changelog/106396.yaml | 6 ------ docs/changelog/106413.yaml | 6 ------ docs/changelog/106429.yaml | 5 ----- docs/changelog/106435.yaml | 6 ------ docs/changelog/106472.yaml | 6 ------ docs/changelog/106503.yaml | 5 ----- docs/changelog/106511.yaml | 5 ----- docs/changelog/106514.yaml | 6 ------ docs/changelog/106516.yaml | 5 ----- docs/changelog/106526.yaml | 5 ----- docs/changelog/106531.yaml | 5 ----- docs/changelog/106563.yaml | 5 ----- docs/changelog/106575.yaml | 5 ----- docs/changelog/106579.yaml | 5 ----- docs/changelog/106638.yaml | 5 ----- docs/changelog/106654.yaml | 6 ------ docs/changelog/106685.yaml | 5 ----- docs/changelog/106691.yaml | 6 ------ docs/changelog/106708.yaml | 6 ------ docs/changelog/106714.yaml | 5 ----- docs/changelog/106720.yaml | 5 ----- docs/changelog/106731.yaml | 5 ----- docs/changelog/106745.yaml | 5 ----- docs/changelog/106767.yaml | 5 ----- docs/changelog/106796.yaml | 5 ----- docs/changelog/106808.yaml | 5 ----- docs/changelog/106810.yaml | 5 ----- docs/changelog/106836.yaml | 5 ----- docs/changelog/106840.yaml | 6 ------ docs/changelog/106851.yaml | 5 ----- docs/changelog/106852.yaml | 6 ------ docs/changelog/106860.yaml | 5 ----- docs/changelog/106862.yaml | 5 ----- docs/changelog/106866.yaml | 5 ----- docs/changelog/106889.yaml | 5 ----- docs/changelog/106899.yaml | 6 ------ docs/changelog/106919.yaml | 6 ------ docs/changelog/106934.yaml | 5 ----- docs/changelog/106952.yaml | 5 ----- docs/changelog/106989.yaml | 7 ------- docs/changelog/107007.yaml | 5 ----- docs/changelog/107016.yaml | 5 ----- docs/changelog/107038.yaml | 5 ----- docs/changelog/107041.yaml | 6 ------ docs/changelog/107046.yaml | 6 ------ docs/changelog/107050.yaml | 5 ----- docs/changelog/107107.yaml | 5 ----- docs/changelog/107121.yaml | 6 ------ docs/changelog/107129.yaml | 5 ----- docs/changelog/107158.yaml | 5 ----- docs/changelog/107178.yaml | 5 ----- docs/changelog/107183.yaml | 5 ----- docs/changelog/107196.yaml | 5 ----- docs/changelog/107224.yaml | 6 ------ docs/changelog/107232.yaml | 6 ------ docs/changelog/107242.yaml | 5 ----- docs/changelog/107253.yaml | 5 ----- docs/changelog/107272.yaml | 5 ----- docs/changelog/107287.yaml | 6 ------ docs/changelog/107291.yaml | 6 ------ docs/changelog/107303.yaml | 5 ----- docs/changelog/107312.yaml | 5 ----- docs/changelog/107334.yaml | 5 ----- docs/changelog/107358.yaml | 6 ------ docs/changelog/107370.yaml | 5 ----- docs/changelog/107377.yaml | 13 ------------- docs/changelog/107383.yaml | 6 ------ docs/changelog/107411.yaml | 5 ----- docs/changelog/107414.yaml | 7 ------- docs/changelog/107447.yaml | 5 ----- docs/changelog/107449.yaml | 5 ----- docs/changelog/107467.yaml | 5 ----- docs/changelog/107494.yaml | 6 ------ docs/changelog/107517.yaml | 5 ----- docs/changelog/107533.yaml | 5 ----- docs/changelog/107551.yaml | 5 ----- docs/changelog/107577.yaml | 6 ------ docs/changelog/107578.yaml | 5 ----- docs/changelog/107598.yaml | 5 ----- docs/changelog/107655.yaml | 6 ------ docs/changelog/107678.yaml | 6 ------ docs/changelog/107743.yaml | 5 ----- docs/changelog/107828.yaml | 6 ------ docs/changelog/107865.yaml | 5 ----- docs/changelog/107891.yaml | 6 ------ docs/changelog/107902.yaml | 5 ----- docs/changelog/107969.yaml | 5 ----- docs/changelog/108007.yaml | 5 ----- docs/changelog/108031.yaml | 6 ------ docs/changelog/108041.yaml | 7 ------- docs/changelog/108101.yaml | 6 ------ docs/changelog/108238.yaml | 6 ------ docs/changelog/108257.yaml | 5 ----- docs/changelog/108365.yaml | 5 ----- docs/changelog/108431.yaml | 5 ----- docs/changelog/108518.yaml | 5 ----- docs/changelog/108562.yaml | 6 ------ docs/changelog/108571.yaml | 5 ----- docs/changelog/108600.yaml | 15 --------------- docs/changelog/108654.yaml | 5 ----- docs/changelog/108736.yaml | 5 ----- docs/changelog/108802.yaml | 5 ----- docs/changelog/108834.yaml | 6 ------ docs/changelog/108854.yaml | 5 ----- docs/changelog/108867.yaml | 6 ------ docs/changelog/108900.yaml | 6 ------ docs/changelog/109020.yaml | 6 ------ docs/changelog/109034.yaml | 5 ----- docs/changelog/109048.yaml | 6 ------ docs/changelog/109097.yaml | 6 ------ docs/changelog/109148.yaml | 6 ------ docs/changelog/109173.yaml | 5 ----- docs/changelog/97072.yaml | 5 ----- docs/changelog/97561.yaml | 5 ----- docs/changelog/99048.yaml | 6 ------ 183 files changed, 1033 deletions(-) delete mode 100644 docs/changelog/103542.yaml delete mode 100644 docs/changelog/104711.yaml delete mode 100644 docs/changelog/104830.yaml delete mode 100644 docs/changelog/104907.yaml delete mode 100644 docs/changelog/105063.yaml delete mode 100644 docs/changelog/105067.yaml delete mode 100644 docs/changelog/105168.yaml delete mode 100644 docs/changelog/105360.yaml delete mode 100644 docs/changelog/105393.yaml delete mode 100644 docs/changelog/105421.yaml delete mode 100644 docs/changelog/105439.yaml delete mode 100644 docs/changelog/105449.yaml delete mode 100644 docs/changelog/105454.yaml delete mode 100644 docs/changelog/105470.yaml delete mode 100644 docs/changelog/105477.yaml delete mode 100644 docs/changelog/105501.yaml delete mode 100644 docs/changelog/105517.yaml delete mode 100644 docs/changelog/105617.yaml delete mode 100644 docs/changelog/105622.yaml delete mode 100644 docs/changelog/105629.yaml delete mode 100644 docs/changelog/105636.yaml delete mode 100644 docs/changelog/105660.yaml delete mode 100644 docs/changelog/105670.yaml delete mode 100644 docs/changelog/105674.yaml delete mode 100644 docs/changelog/105689.yaml delete mode 100644 docs/changelog/105693.yaml delete mode 100644 docs/changelog/105709.yaml delete mode 100644 docs/changelog/105714.yaml delete mode 100644 docs/changelog/105717.yaml delete mode 100644 docs/changelog/105745.yaml delete mode 100644 docs/changelog/105757.yaml delete mode 100644 docs/changelog/105768.yaml delete mode 100644 docs/changelog/105779.yaml delete mode 100644 docs/changelog/105781.yaml delete mode 100644 docs/changelog/105791.yaml delete mode 100644 docs/changelog/105797.yaml delete mode 100644 docs/changelog/105847.yaml delete mode 100644 docs/changelog/105860.yaml delete mode 100644 docs/changelog/105893.yaml delete mode 100644 docs/changelog/105894.yaml delete mode 100644 docs/changelog/105985.yaml delete mode 100644 docs/changelog/106031.yaml delete mode 100644 docs/changelog/106036.yaml delete mode 100644 docs/changelog/106053.yaml delete mode 100644 docs/changelog/106063.yaml delete mode 100644 docs/changelog/106065.yaml delete mode 100644 docs/changelog/106068.yaml delete mode 100644 docs/changelog/106094.yaml delete mode 100644 docs/changelog/106102.yaml delete mode 100644 docs/changelog/106133.yaml delete mode 100644 docs/changelog/106150.yaml delete mode 100644 docs/changelog/106171.yaml delete mode 100644 docs/changelog/106172.yaml delete mode 100644 docs/changelog/106186.yaml delete mode 100644 docs/changelog/106189.yaml delete mode 100644 docs/changelog/106243.yaml delete mode 100644 docs/changelog/106244.yaml delete mode 100644 docs/changelog/106259.yaml delete mode 100644 docs/changelog/106285.yaml delete mode 100644 docs/changelog/106306.yaml delete mode 100644 docs/changelog/106315.yaml delete mode 100644 docs/changelog/106327.yaml delete mode 100644 docs/changelog/106338.yaml delete mode 100644 docs/changelog/106361.yaml delete mode 100644 docs/changelog/106373.yaml delete mode 100644 docs/changelog/106377.yaml delete mode 100644 docs/changelog/106378.yaml delete mode 100644 docs/changelog/106381.yaml delete mode 100644 docs/changelog/106396.yaml delete mode 100644 docs/changelog/106413.yaml delete mode 100644 docs/changelog/106429.yaml delete mode 100644 docs/changelog/106435.yaml delete mode 100644 docs/changelog/106472.yaml delete mode 100644 docs/changelog/106503.yaml delete mode 100644 docs/changelog/106511.yaml delete mode 100644 docs/changelog/106514.yaml delete mode 100644 docs/changelog/106516.yaml delete mode 100644 docs/changelog/106526.yaml delete mode 100644 docs/changelog/106531.yaml delete mode 100644 docs/changelog/106563.yaml delete mode 100644 docs/changelog/106575.yaml delete mode 100644 docs/changelog/106579.yaml delete mode 100644 docs/changelog/106638.yaml delete mode 100644 docs/changelog/106654.yaml delete mode 100644 docs/changelog/106685.yaml delete mode 100644 docs/changelog/106691.yaml delete mode 100644 docs/changelog/106708.yaml delete mode 100644 docs/changelog/106714.yaml delete mode 100644 docs/changelog/106720.yaml delete mode 100644 docs/changelog/106731.yaml delete mode 100644 docs/changelog/106745.yaml delete mode 100644 docs/changelog/106767.yaml delete mode 100644 docs/changelog/106796.yaml delete mode 100644 docs/changelog/106808.yaml delete mode 100644 docs/changelog/106810.yaml delete mode 100644 docs/changelog/106836.yaml delete mode 100644 docs/changelog/106840.yaml delete mode 100644 docs/changelog/106851.yaml delete mode 100644 docs/changelog/106852.yaml delete mode 100644 docs/changelog/106860.yaml delete mode 100644 docs/changelog/106862.yaml delete mode 100644 docs/changelog/106866.yaml delete mode 100644 docs/changelog/106889.yaml delete mode 100644 docs/changelog/106899.yaml delete mode 100644 docs/changelog/106919.yaml delete mode 100644 docs/changelog/106934.yaml delete mode 100644 docs/changelog/106952.yaml delete mode 100644 docs/changelog/106989.yaml delete mode 100644 docs/changelog/107007.yaml delete mode 100644 docs/changelog/107016.yaml delete mode 100644 docs/changelog/107038.yaml delete mode 100644 docs/changelog/107041.yaml delete mode 100644 docs/changelog/107046.yaml delete mode 100644 docs/changelog/107050.yaml delete mode 100644 docs/changelog/107107.yaml delete mode 100644 docs/changelog/107121.yaml delete mode 100644 docs/changelog/107129.yaml delete mode 100644 docs/changelog/107158.yaml delete mode 100644 docs/changelog/107178.yaml delete mode 100644 docs/changelog/107183.yaml delete mode 100644 docs/changelog/107196.yaml delete mode 100644 docs/changelog/107224.yaml delete mode 100644 docs/changelog/107232.yaml delete mode 100644 docs/changelog/107242.yaml delete mode 100644 docs/changelog/107253.yaml delete mode 100644 docs/changelog/107272.yaml delete mode 100644 docs/changelog/107287.yaml delete mode 100644 docs/changelog/107291.yaml delete mode 100644 docs/changelog/107303.yaml delete mode 100644 docs/changelog/107312.yaml delete mode 100644 docs/changelog/107334.yaml delete mode 100644 docs/changelog/107358.yaml delete mode 100644 docs/changelog/107370.yaml delete mode 100644 docs/changelog/107377.yaml delete mode 100644 docs/changelog/107383.yaml delete mode 100644 docs/changelog/107411.yaml delete mode 100644 docs/changelog/107414.yaml delete mode 100644 docs/changelog/107447.yaml delete mode 100644 docs/changelog/107449.yaml delete mode 100644 docs/changelog/107467.yaml delete mode 100644 docs/changelog/107494.yaml delete mode 100644 docs/changelog/107517.yaml delete mode 100644 docs/changelog/107533.yaml delete mode 100644 docs/changelog/107551.yaml delete mode 100644 docs/changelog/107577.yaml delete mode 100644 docs/changelog/107578.yaml delete mode 100644 docs/changelog/107598.yaml delete mode 100644 docs/changelog/107655.yaml delete mode 100644 docs/changelog/107678.yaml delete mode 100644 docs/changelog/107743.yaml delete mode 100644 docs/changelog/107828.yaml delete mode 100644 docs/changelog/107865.yaml delete mode 100644 docs/changelog/107891.yaml delete mode 100644 docs/changelog/107902.yaml delete mode 100644 docs/changelog/107969.yaml delete mode 100644 docs/changelog/108007.yaml delete mode 100644 docs/changelog/108031.yaml delete mode 100644 docs/changelog/108041.yaml delete mode 100644 docs/changelog/108101.yaml delete mode 100644 docs/changelog/108238.yaml delete mode 100644 docs/changelog/108257.yaml delete mode 100644 docs/changelog/108365.yaml delete mode 100644 docs/changelog/108431.yaml delete mode 100644 docs/changelog/108518.yaml delete mode 100644 docs/changelog/108562.yaml delete mode 100644 docs/changelog/108571.yaml delete mode 100644 docs/changelog/108600.yaml delete mode 100644 docs/changelog/108654.yaml delete mode 100644 docs/changelog/108736.yaml delete mode 100644 docs/changelog/108802.yaml delete mode 100644 docs/changelog/108834.yaml delete mode 100644 docs/changelog/108854.yaml delete mode 100644 docs/changelog/108867.yaml delete mode 100644 docs/changelog/108900.yaml delete mode 100644 docs/changelog/109020.yaml delete mode 100644 docs/changelog/109034.yaml delete mode 100644 docs/changelog/109048.yaml delete mode 100644 docs/changelog/109097.yaml delete mode 100644 docs/changelog/109148.yaml delete mode 100644 docs/changelog/109173.yaml delete mode 100644 docs/changelog/97072.yaml delete mode 100644 docs/changelog/97561.yaml delete mode 100644 docs/changelog/99048.yaml diff --git a/docs/changelog/103542.yaml b/docs/changelog/103542.yaml deleted file mode 100644 index 74e713eb2f606..0000000000000 --- a/docs/changelog/103542.yaml +++ /dev/null @@ -1,7 +0,0 @@ -pr: 103542 -summary: Flatten object mappings when subobjects is false -area: Mapping -type: feature -issues: - - 99860 - - 103497 diff --git a/docs/changelog/104711.yaml b/docs/changelog/104711.yaml deleted file mode 100644 index f0f9bf7f10e45..0000000000000 --- a/docs/changelog/104711.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 104711 -summary: "Fixing NPE when requesting [_none_] for `stored_fields`" -area: Search -type: bug -issues: [] diff --git a/docs/changelog/104830.yaml b/docs/changelog/104830.yaml deleted file mode 100644 index c056f3d618b75..0000000000000 --- a/docs/changelog/104830.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 104830 -summary: All new `shard_seed` parameter for `random_sampler` agg -area: Aggregations -type: enhancement -issues: [] diff --git a/docs/changelog/104907.yaml b/docs/changelog/104907.yaml deleted file mode 100644 index 0d8592ae29526..0000000000000 --- a/docs/changelog/104907.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 104907 -summary: Support ST_INTERSECTS between geometry column and other geometry or string -area: "ES|QL" -type: enhancement -issues: -- 104874 diff --git a/docs/changelog/105063.yaml b/docs/changelog/105063.yaml deleted file mode 100644 index 668f8ac104493..0000000000000 --- a/docs/changelog/105063.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105063 -summary: Infrastructure for metering the update requests -area: Infra/Metrics -type: enhancement -issues: [] diff --git a/docs/changelog/105067.yaml b/docs/changelog/105067.yaml deleted file mode 100644 index 562e8271f5502..0000000000000 --- a/docs/changelog/105067.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105067 -summary: "ESQL: Use faster field caps" -area: ES|QL -type: enhancement -issues: [] diff --git a/docs/changelog/105168.yaml b/docs/changelog/105168.yaml deleted file mode 100644 index 0f3792b832f55..0000000000000 --- a/docs/changelog/105168.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105168 -summary: Add ?master_timeout query parameter to ccr apis -area: CCR -type: bug -issues: [] diff --git a/docs/changelog/105360.yaml b/docs/changelog/105360.yaml deleted file mode 100644 index 41a7ea24e5500..0000000000000 --- a/docs/changelog/105360.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 105360 -summary: Cross-cluster painless/execute actions should check permissions only on target - remote cluster -area: Search -type: bug -issues: [] diff --git a/docs/changelog/105393.yaml b/docs/changelog/105393.yaml deleted file mode 100644 index 4a4cc299b7bd7..0000000000000 --- a/docs/changelog/105393.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105393 -summary: Adding support for hex-encoded byte vectors on knn-search -area: Vector Search -type: feature -issues: [] diff --git a/docs/changelog/105421.yaml b/docs/changelog/105421.yaml deleted file mode 100644 index 2ff9ef008c803..0000000000000 --- a/docs/changelog/105421.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105421 -summary: "ESQL: Add timers to many status results" -area: ES|QL -type: enhancement -issues: [] diff --git a/docs/changelog/105439.yaml b/docs/changelog/105439.yaml deleted file mode 100644 index 45bbede469542..0000000000000 --- a/docs/changelog/105439.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 105439 -summary: Support Profile Activate with JWTs with client authn -area: Authentication -type: enhancement -issues: - - 105342 diff --git a/docs/changelog/105449.yaml b/docs/changelog/105449.yaml deleted file mode 100644 index b565d6c782bd9..0000000000000 --- a/docs/changelog/105449.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 105449 -summary: Don't stop checking if the `HealthNode` persistent task is present -area: Health -type: bug -issues: - - 98926 diff --git a/docs/changelog/105454.yaml b/docs/changelog/105454.yaml deleted file mode 100644 index fc814a343c46b..0000000000000 --- a/docs/changelog/105454.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105454 -summary: "ESQL: Sum of constants" -area: ES|QL -type: enhancement -issues: [] diff --git a/docs/changelog/105470.yaml b/docs/changelog/105470.yaml deleted file mode 100644 index 56425de6c88e4..0000000000000 --- a/docs/changelog/105470.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105470 -summary: Add retrievers using the parser-only approach -area: Ranking -type: enhancement -issues: [] diff --git a/docs/changelog/105477.yaml b/docs/changelog/105477.yaml deleted file mode 100644 index f994d38a3f671..0000000000000 --- a/docs/changelog/105477.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 105477 -summary: "ESQL: Introduce expression validation phase" -area: ES|QL -type: enhancement -issues: - - 105425 diff --git a/docs/changelog/105501.yaml b/docs/changelog/105501.yaml deleted file mode 100644 index 2e5e375764640..0000000000000 --- a/docs/changelog/105501.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105501 -summary: Support non-keyword dimensions as routing fields in TSDB -area: TSDB -type: enhancement -issues: [] diff --git a/docs/changelog/105517.yaml b/docs/changelog/105517.yaml deleted file mode 100644 index 7cca86d1cff6e..0000000000000 --- a/docs/changelog/105517.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105517 -summary: Upgrade to Netty 4.1.107 -area: Network -type: upgrade -issues: [] diff --git a/docs/changelog/105617.yaml b/docs/changelog/105617.yaml deleted file mode 100644 index 7fd8203336fff..0000000000000 --- a/docs/changelog/105617.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105617 -summary: Fix HTTP corner-case response leaks -area: Network -type: bug -issues: [] diff --git a/docs/changelog/105622.yaml b/docs/changelog/105622.yaml deleted file mode 100644 index 33093f5ffceb5..0000000000000 --- a/docs/changelog/105622.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105622 -summary: Distinguish different snapshot failures by log level -area: Snapshot/Restore -type: enhancement -issues: [] diff --git a/docs/changelog/105629.yaml b/docs/changelog/105629.yaml deleted file mode 100644 index 00fa73a759558..0000000000000 --- a/docs/changelog/105629.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105629 -summary: Show owner `realm_type` for returned API keys -area: Security -type: enhancement -issues: [] diff --git a/docs/changelog/105636.yaml b/docs/changelog/105636.yaml deleted file mode 100644 index 01f27199771d4..0000000000000 --- a/docs/changelog/105636.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105636 -summary: Flip dynamic mapping condition when create tsid -area: TSDB -type: bug -issues: [] diff --git a/docs/changelog/105660.yaml b/docs/changelog/105660.yaml deleted file mode 100644 index 1b30a25417906..0000000000000 --- a/docs/changelog/105660.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105660 -summary: "Text structure endpoints to determine the structure of a list of messages and of an indexed field" -area: Machine Learning -type: feature -issues: [] diff --git a/docs/changelog/105670.yaml b/docs/changelog/105670.yaml deleted file mode 100644 index 234f4b6af5a73..0000000000000 --- a/docs/changelog/105670.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105670 -summary: "Painless: Apply true regex limit factor with FIND and MATCH operation" -area: Infra/Scripting -type: bug -issues: [] diff --git a/docs/changelog/105674.yaml b/docs/changelog/105674.yaml deleted file mode 100644 index 7b8d04f4687a3..0000000000000 --- a/docs/changelog/105674.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 105674 -summary: Health monitor concurrency fixes -area: Health -type: bug -issues: - - 105065 diff --git a/docs/changelog/105689.yaml b/docs/changelog/105689.yaml deleted file mode 100644 index e76281f1b2fc7..0000000000000 --- a/docs/changelog/105689.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 105689 -summary: Fix `uri_parts` processor behaviour for missing extensions -area: Ingest Node -type: bug -issues: - - 105612 diff --git a/docs/changelog/105693.yaml b/docs/changelog/105693.yaml deleted file mode 100644 index 8d14d611e19a3..0000000000000 --- a/docs/changelog/105693.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 105693 -summary: Fix error 500 on invalid `ParentIdQuery` -area: Search -type: bug -issues: - - 105366 diff --git a/docs/changelog/105709.yaml b/docs/changelog/105709.yaml deleted file mode 100644 index 41b6e749d9270..0000000000000 --- a/docs/changelog/105709.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105709 -summary: Apply stricter Document Level Security (DLS) rules for the validate query API with the rewrite parameter. -area: Security -type: bug -issues: [] diff --git a/docs/changelog/105714.yaml b/docs/changelog/105714.yaml deleted file mode 100644 index b6ab5e128c72c..0000000000000 --- a/docs/changelog/105714.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105714 -summary: Apply stricter Document Level Security (DLS) rules for terms aggregations when min_doc_count is set to 0. -area: "Aggregations" -type: bug -issues: [] diff --git a/docs/changelog/105717.yaml b/docs/changelog/105717.yaml deleted file mode 100644 index c75bc4fe65798..0000000000000 --- a/docs/changelog/105717.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105717 -summary: Upgrade jna to 5.12.1 -area: Infra/Core -type: upgrade -issues: [] diff --git a/docs/changelog/105745.yaml b/docs/changelog/105745.yaml deleted file mode 100644 index e9a61f692d94d..0000000000000 --- a/docs/changelog/105745.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 105745 -summary: Fix `noop_update_total` is not being updated when using the `_bulk` -area: CRUD -type: bug -issues: - - 105742 diff --git a/docs/changelog/105757.yaml b/docs/changelog/105757.yaml deleted file mode 100644 index f11aed2b2d96b..0000000000000 --- a/docs/changelog/105757.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105757 -summary: Add pluggable `BuildVersion` in `NodeMetadata` -area: Infra/Core -type: enhancement -issues: [] diff --git a/docs/changelog/105768.yaml b/docs/changelog/105768.yaml deleted file mode 100644 index 49d7f1f15c453..0000000000000 --- a/docs/changelog/105768.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105768 -summary: Add two new OGC functions ST_X and ST_Y -area: "ES|QL" -type: enhancement -issues: [] diff --git a/docs/changelog/105779.yaml b/docs/changelog/105779.yaml deleted file mode 100644 index 3699ca0e2f246..0000000000000 --- a/docs/changelog/105779.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105779 -summary: "[Profiling] Speed up serialization of flamegraph" -area: Application -type: enhancement -issues: [] diff --git a/docs/changelog/105781.yaml b/docs/changelog/105781.yaml deleted file mode 100644 index c3ae7f0035904..0000000000000 --- a/docs/changelog/105781.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105781 -summary: CCS with `minimize_roundtrips` performs incremental merges of each `SearchResponse` -area: Search -type: enhancement -issues: [] diff --git a/docs/changelog/105791.yaml b/docs/changelog/105791.yaml deleted file mode 100644 index f18b5e6b8fdd7..0000000000000 --- a/docs/changelog/105791.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105791 -summary: "Bugfix: Disable eager loading `BitSetFilterCache` on Indexing Nodes" -area: Search -type: bug -issues: [] diff --git a/docs/changelog/105797.yaml b/docs/changelog/105797.yaml deleted file mode 100644 index 7c832e2e5e63c..0000000000000 --- a/docs/changelog/105797.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105797 -summary: Enable retrying on 500 error response from Cohere text embedding API -area: Machine Learning -type: enhancement -issues: [] diff --git a/docs/changelog/105847.yaml b/docs/changelog/105847.yaml deleted file mode 100644 index a731395bc9a81..0000000000000 --- a/docs/changelog/105847.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105847 -summary: (API+) CAT Nodes alias for shard header to match CAT Allocation -area: Stats -type: enhancement -issues: [] diff --git a/docs/changelog/105860.yaml b/docs/changelog/105860.yaml deleted file mode 100644 index 71f3544a02a1f..0000000000000 --- a/docs/changelog/105860.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105860 -summary: "ESQL: Re-enable logical dependency check" -area: ES|QL -type: bug -issues: [] diff --git a/docs/changelog/105893.yaml b/docs/changelog/105893.yaml deleted file mode 100644 index c88736f5dda3d..0000000000000 --- a/docs/changelog/105893.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105893 -summary: Specialize serialization for `ArrayVectors` -area: ES|QL -type: enhancement -issues: [] diff --git a/docs/changelog/105894.yaml b/docs/changelog/105894.yaml deleted file mode 100644 index a1a99eaa6259b..0000000000000 --- a/docs/changelog/105894.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105894 -summary: Add allocation stats -area: Allocation -type: enhancement -issues: [] diff --git a/docs/changelog/105985.yaml b/docs/changelog/105985.yaml deleted file mode 100644 index 2f2a8c1394070..0000000000000 --- a/docs/changelog/105985.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 105985 -summary: Wait forever for `IndexTemplateRegistry` asset installation -area: Indices APIs -type: enhancement -issues: [] diff --git a/docs/changelog/106031.yaml b/docs/changelog/106031.yaml deleted file mode 100644 index d0a0303e74164..0000000000000 --- a/docs/changelog/106031.yaml +++ /dev/null @@ -1,13 +0,0 @@ -pr: 106031 -summary: Deprecate allowing `fields` in scenarios where it is ignored -area: Mapping -type: deprecation -issues: [] -deprecation: - title: Deprecate allowing `fields` in scenarios where it is ignored - area: Mapping - details: The following mapped types have always ignored `fields` when using multi-fields. - This deprecation makes this clearer and we will completely disallow `fields` for - these mapped types in the future. - impact: "In the future, `join`, `aggregate_metric_double`, and `constant_keyword`,\ - \ will all disallow supplying `fields` as a parameter in the mapping." diff --git a/docs/changelog/106036.yaml b/docs/changelog/106036.yaml deleted file mode 100644 index 7b129c6c0a7a3..0000000000000 --- a/docs/changelog/106036.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106036 -summary: Add status for enrich operator -area: ES|QL -type: enhancement -issues: [] diff --git a/docs/changelog/106053.yaml b/docs/changelog/106053.yaml deleted file mode 100644 index 72cfe0207795d..0000000000000 --- a/docs/changelog/106053.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106053 -summary: Speed up serialization of `BytesRefArray` -area: ES|QL -type: enhancement -issues: [] diff --git a/docs/changelog/106063.yaml b/docs/changelog/106063.yaml deleted file mode 100644 index 57c05370a943f..0000000000000 --- a/docs/changelog/106063.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106063 -summary: Consider `ShardRouting` roles when calculating shard copies in shutdown status -area: Infra/Node Lifecycle -type: bug -issues: [] diff --git a/docs/changelog/106065.yaml b/docs/changelog/106065.yaml deleted file mode 100644 index b87f4848fb574..0000000000000 --- a/docs/changelog/106065.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 106065 -summary: "ESQL: Values aggregation function" -area: ES|QL -type: feature -issues: - - 103600 diff --git a/docs/changelog/106068.yaml b/docs/changelog/106068.yaml deleted file mode 100644 index 51bcc2bcf98b0..0000000000000 --- a/docs/changelog/106068.yaml +++ /dev/null @@ -1,21 +0,0 @@ -pr: 106068 -summary: Add `modelId` and `modelText` to `KnnVectorQueryBuilder` -area: Search -type: enhancement -issues: [] -highlight: - title: Query phase KNN now supports query_vector_builder - body: |- - It is now possible to pass `model_text` and `model_id` within a `knn` query - in the [query DSL](https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-knn-query.html) to convert a text query into a dense vector and run the - nearest neighbor query on it, instead of requiring the dense vector to be - directly passed (within the `query_vector` parameter). Similar to the - [top-level knn query](https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html) (executed in the DFS phase), it is possible to supply - a `query_vector_builder` object containing a `text_embedding` object with - `model_text` (the text query to be converted into a dense vector) and - `model_id` (the identifier of a deployed model responsible for transforming - the text query into a dense vector). Note that an embedding model with the - referenced `model_id` needs to be [deployed on a ML node](https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html). - in the cluster. - notable: true - diff --git a/docs/changelog/106094.yaml b/docs/changelog/106094.yaml deleted file mode 100644 index 4341164222338..0000000000000 --- a/docs/changelog/106094.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106094 -summary: "ESQL: Support partially folding CASE" -area: ES|QL -type: enhancement -issues: [] diff --git a/docs/changelog/106102.yaml b/docs/changelog/106102.yaml deleted file mode 100644 index b7c13514f6715..0000000000000 --- a/docs/changelog/106102.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106102 -summary: Specialize serialization of array blocks -area: ES|QL -type: enhancement -issues: [] diff --git a/docs/changelog/106133.yaml b/docs/changelog/106133.yaml deleted file mode 100644 index 6dd7bf6cea086..0000000000000 --- a/docs/changelog/106133.yaml +++ /dev/null @@ -1,19 +0,0 @@ -pr: 106133 -summary: Add a SIMD (Neon) optimised vector distance function for int8 -area: Search -type: enhancement -issues: [] -highlight: - title: A SIMD (Neon) optimised vector distance function for merging int8 Scalar Quantized vectors has been added - body: |- - An optimised int8 vector distance implementation for aarch64 has been added. - This implementation is currently only used during merging. - The vector distance implementation outperforms Lucene's Pamana Vector - implementation for binary comparisons by approx 5x (depending on the number - of dimensions). It does so by means of SIMD (Neon) intrinsics compiled into a - separate native library and link by Panama's FFI. Comparisons are performed on - off-heap mmap'ed vector data. - Macro benchmarks, SO_Dense_Vector with scalar quantization enabled, shows - significant improvements in merge times, approximately 3 times faster. - notable: true - diff --git a/docs/changelog/106150.yaml b/docs/changelog/106150.yaml deleted file mode 100644 index 05bd8b06987c6..0000000000000 --- a/docs/changelog/106150.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106150 -summary: Use correct system index bulk executor -area: CRUD -type: bug -issues: [] diff --git a/docs/changelog/106171.yaml b/docs/changelog/106171.yaml deleted file mode 100644 index 9daf1b9acd994..0000000000000 --- a/docs/changelog/106171.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 106171 -summary: Do not log error on node restart when the transform is already failed -area: Transform -type: enhancement -issues: - - 106168 diff --git a/docs/changelog/106172.yaml b/docs/changelog/106172.yaml deleted file mode 100644 index 80d80b9d7f299..0000000000000 --- a/docs/changelog/106172.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106172 -summary: "[Profiling] Allow to override index settings" -area: Application -type: enhancement -issues: [] diff --git a/docs/changelog/106186.yaml b/docs/changelog/106186.yaml deleted file mode 100644 index 097639dd28f1b..0000000000000 --- a/docs/changelog/106186.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 106186 -summary: Expand support for ENRICH to full set supported by ES ingest processors -area: ES|QL -type: enhancement -issues: - - 106162 diff --git a/docs/changelog/106189.yaml b/docs/changelog/106189.yaml deleted file mode 100644 index ec485f0e60efb..0000000000000 --- a/docs/changelog/106189.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 106189 -summary: Fix numeric sorts in `_cat/nodes` -area: CAT APIs -type: bug -issues: - - 48070 diff --git a/docs/changelog/106243.yaml b/docs/changelog/106243.yaml deleted file mode 100644 index 6b02e3f1699d4..0000000000000 --- a/docs/changelog/106243.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106243 -summary: "[Transform] Auto retry Transform start" -area: "Transform" -type: bug -issues: [] diff --git a/docs/changelog/106244.yaml b/docs/changelog/106244.yaml deleted file mode 100644 index fe03f575b9efb..0000000000000 --- a/docs/changelog/106244.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106244 -summary: Support ES|QL requests through the `NodeClient::execute` -area: ES|QL -type: feature -issues: [] diff --git a/docs/changelog/106259.yaml b/docs/changelog/106259.yaml deleted file mode 100644 index d56b5e5a5e379..0000000000000 --- a/docs/changelog/106259.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106259 -summary: Add data stream lifecycle to kibana reporting template -area: Data streams -type: enhancement -issues: [] diff --git a/docs/changelog/106285.yaml b/docs/changelog/106285.yaml deleted file mode 100644 index 37a7e67fe9395..0000000000000 --- a/docs/changelog/106285.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106285 -summary: Add a check for the same feature being declared regular and historical -area: Infra/Core -type: bug -issues: [] diff --git a/docs/changelog/106306.yaml b/docs/changelog/106306.yaml deleted file mode 100644 index 571fe73c31a3e..0000000000000 --- a/docs/changelog/106306.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 99961 -summary: "added fix for inconsistent text trimming in Unified Highlighter" -area: Highlighting -type: bug -issues: - - 101803 diff --git a/docs/changelog/106315.yaml b/docs/changelog/106315.yaml deleted file mode 100644 index 57c41c8024d20..0000000000000 --- a/docs/changelog/106315.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106315 -summary: Updating the tika version to 2.9.1 in the ingest attachment plugin -area: Ingest Node -type: upgrade -issues: [] diff --git a/docs/changelog/106327.yaml b/docs/changelog/106327.yaml deleted file mode 100644 index 2b4b811ece40b..0000000000000 --- a/docs/changelog/106327.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106327 -summary: Serialize big array vectors -area: ES|QL -type: enhancement -issues: [] diff --git a/docs/changelog/106338.yaml b/docs/changelog/106338.yaml deleted file mode 100644 index c05826d87a11f..0000000000000 --- a/docs/changelog/106338.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 106338 -summary: Text fields are stored by default in TSDB indices -area: TSDB -type: enhancement -issues: - - 97039 diff --git a/docs/changelog/106361.yaml b/docs/changelog/106361.yaml deleted file mode 100644 index a4cd608279c12..0000000000000 --- a/docs/changelog/106361.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106361 -summary: Add a `PriorityQueue` backed by `BigArrays` -area: Aggregations -type: enhancement -issues: [] diff --git a/docs/changelog/106373.yaml b/docs/changelog/106373.yaml deleted file mode 100644 index e838c7b1a660d..0000000000000 --- a/docs/changelog/106373.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106373 -summary: Serialize big array blocks -area: ES|QL -type: enhancement -issues: [] diff --git a/docs/changelog/106377.yaml b/docs/changelog/106377.yaml deleted file mode 100644 index 7f0f18d43b440..0000000000000 --- a/docs/changelog/106377.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106377 -summary: Add transport version for search load autoscaling -area: Search -type: enhancement -issues: [] diff --git a/docs/changelog/106378.yaml b/docs/changelog/106378.yaml deleted file mode 100644 index b54760553d184..0000000000000 --- a/docs/changelog/106378.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106378 -summary: Add Cohere rerank to `_inference` service -area: Machine Learning -type: feature -issues: [] diff --git a/docs/changelog/106381.yaml b/docs/changelog/106381.yaml deleted file mode 100644 index 500f6d5416822..0000000000000 --- a/docs/changelog/106381.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106381 -summary: Dedupe terms in terms queries -area: Mapping -type: bug -issues: [] diff --git a/docs/changelog/106396.yaml b/docs/changelog/106396.yaml deleted file mode 100644 index 7aa06566c75e7..0000000000000 --- a/docs/changelog/106396.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 106396 -summary: "Check preTags and postTags params for empty values" -area: Highlighting -type: bug -issues: - - 69009 diff --git a/docs/changelog/106413.yaml b/docs/changelog/106413.yaml deleted file mode 100644 index 8e13a839bc41e..0000000000000 --- a/docs/changelog/106413.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 106413 -summary: Consolidate permissions checks -area: Transform -type: bug -issues: - - 105794 diff --git a/docs/changelog/106429.yaml b/docs/changelog/106429.yaml deleted file mode 100644 index 7ac524d13909b..0000000000000 --- a/docs/changelog/106429.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106429 -summary: "ESQL: Regex improvements" -area: ES|QL -type: enhancement -issues: [] diff --git a/docs/changelog/106435.yaml b/docs/changelog/106435.yaml deleted file mode 100644 index 5bfe0087a93d3..0000000000000 --- a/docs/changelog/106435.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 106435 -summary: "ENRICH support for TEXT fields" -area: ES|QL -type: enhancement -issues: - - 105384 diff --git a/docs/changelog/106472.yaml b/docs/changelog/106472.yaml deleted file mode 100644 index 120286c4cd8c7..0000000000000 --- a/docs/changelog/106472.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 106472 -summary: "Fix the position of spike, dip and distribution changes bucket when the\ - \ sibling aggregation includes empty buckets" -area: Machine Learning -type: bug -issues: [] diff --git a/docs/changelog/106503.yaml b/docs/changelog/106503.yaml deleted file mode 100644 index 1b7e78d8ffc27..0000000000000 --- a/docs/changelog/106503.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106503 -summary: "Support ST_CONTAINS and ST_WITHIN" -area: "ES|QL" -type: enhancement -issues: [] diff --git a/docs/changelog/106511.yaml b/docs/changelog/106511.yaml deleted file mode 100644 index bdef7f1aea225..0000000000000 --- a/docs/changelog/106511.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106511 -summary: Wait indefintely for http connections on shutdown by default -area: Infra/Node Lifecycle -type: bug -issues: [] diff --git a/docs/changelog/106514.yaml b/docs/changelog/106514.yaml deleted file mode 100644 index 5b25f40db2742..0000000000000 --- a/docs/changelog/106514.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 106514 -summary: Add granular error list to alias action response -area: Indices APIs -type: feature -issues: - - 94478 diff --git a/docs/changelog/106516.yaml b/docs/changelog/106516.yaml deleted file mode 100644 index 905896fb0ef03..0000000000000 --- a/docs/changelog/106516.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106516 -summary: "ESQL: perform a reduction on the data node" -area: ES|QL -type: enhancement -issues: [] diff --git a/docs/changelog/106526.yaml b/docs/changelog/106526.yaml deleted file mode 100644 index ac98454b5d8b4..0000000000000 --- a/docs/changelog/106526.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106526 -summary: Enhance search tier GC options -area: Infra/CLI -type: enhancement -issues: [] diff --git a/docs/changelog/106531.yaml b/docs/changelog/106531.yaml deleted file mode 100644 index 631d74185d2d8..0000000000000 --- a/docs/changelog/106531.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106531 -summary: Get and Query API Key with profile uid -area: Security -type: feature -issues: [] diff --git a/docs/changelog/106563.yaml b/docs/changelog/106563.yaml deleted file mode 100644 index 79476f909a04c..0000000000000 --- a/docs/changelog/106563.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106563 -summary: Improve short-circuiting downsample execution -area: TSDB -type: enhancement -issues: [] diff --git a/docs/changelog/106575.yaml b/docs/changelog/106575.yaml deleted file mode 100644 index fb5230a9edb3d..0000000000000 --- a/docs/changelog/106575.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106575 -summary: Unable to retrieve multiple stored field values -area: "Search" -type: bug -issues: [] diff --git a/docs/changelog/106579.yaml b/docs/changelog/106579.yaml deleted file mode 100644 index 104ed3066a6f6..0000000000000 --- a/docs/changelog/106579.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106579 -summary: "ESQL: Allow grouping key inside stats expressions" -area: ES|QL -type: enhancement -issues: [] diff --git a/docs/changelog/106638.yaml b/docs/changelog/106638.yaml deleted file mode 100644 index 019800bf03157..0000000000000 --- a/docs/changelog/106638.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106638 -summary: Allow users to get status of own async search tasks -area: Authorization -type: enhancement -issues: [] diff --git a/docs/changelog/106654.yaml b/docs/changelog/106654.yaml deleted file mode 100644 index 3443b68482443..0000000000000 --- a/docs/changelog/106654.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 106654 -summary: "ES|QL: Fix usage of IN operator with TEXT fields" -area: ES|QL -type: bug -issues: - - 105379 diff --git a/docs/changelog/106685.yaml b/docs/changelog/106685.yaml deleted file mode 100644 index ed4a16ba0666c..0000000000000 --- a/docs/changelog/106685.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106685 -summary: '`SharedBlobCacheService.maybeFetchRegion` should use `computeCacheFileRegionSize`' -area: Snapshot/Restore -type: bug -issues: [] diff --git a/docs/changelog/106691.yaml b/docs/changelog/106691.yaml deleted file mode 100644 index cbae9796e38c7..0000000000000 --- a/docs/changelog/106691.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 106691 -summary: Fix range queries for float/half_float fields when bounds are out of type's - range -area: Search -type: bug -issues: [] diff --git a/docs/changelog/106708.yaml b/docs/changelog/106708.yaml deleted file mode 100644 index b8fdd37e5f03f..0000000000000 --- a/docs/changelog/106708.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 106708 -summary: Improve error message when rolling over DS alias -area: Data streams -type: bug -issues: - - 106137 diff --git a/docs/changelog/106714.yaml b/docs/changelog/106714.yaml deleted file mode 100644 index 65b0acd77d764..0000000000000 --- a/docs/changelog/106714.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106714 -summary: Add non-indexed fields to ecs templates -area: Data streams -type: bug -issues: [] diff --git a/docs/changelog/106720.yaml b/docs/changelog/106720.yaml deleted file mode 100644 index 93358ed1d3dff..0000000000000 --- a/docs/changelog/106720.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106720 -summary: "ESQL: Fix treating all fields as MV in COUNT pushdown" -area: ES|QL -type: bug -issues: [] diff --git a/docs/changelog/106731.yaml b/docs/changelog/106731.yaml deleted file mode 100644 index 0d8e16a8f9616..0000000000000 --- a/docs/changelog/106731.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106731 -summary: Fix field caps and field level security -area: Security -type: bug -issues: [] diff --git a/docs/changelog/106745.yaml b/docs/changelog/106745.yaml deleted file mode 100644 index a6cb035bd267a..0000000000000 --- a/docs/changelog/106745.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106745 -summary: Fix `AffixSetting.exists` to include secure settings -area: Infra/Core -type: bug -issues: [] diff --git a/docs/changelog/106767.yaml b/docs/changelog/106767.yaml deleted file mode 100644 index 8541e1b14f275..0000000000000 --- a/docs/changelog/106767.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106767 -summary: Handle pass-through subfields with deep nesting -area: Mapping -type: bug -issues: [] diff --git a/docs/changelog/106796.yaml b/docs/changelog/106796.yaml deleted file mode 100644 index 83eb99dba1603..0000000000000 --- a/docs/changelog/106796.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106796 -summary: Bulk loading enrich fields in ESQL -area: ES|QL -type: enhancement -issues: [] diff --git a/docs/changelog/106808.yaml b/docs/changelog/106808.yaml deleted file mode 100644 index 287477fc302fd..0000000000000 --- a/docs/changelog/106808.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106808 -summary: Make OpenAI embeddings parser more flexible -area: Machine Learning -type: bug -issues: [] diff --git a/docs/changelog/106810.yaml b/docs/changelog/106810.yaml deleted file mode 100644 index e93e5cf1e5361..0000000000000 --- a/docs/changelog/106810.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106810 -summary: "ES|QL: Improve support for TEXT fields in functions" -area: ES|QL -type: bug -issues: [] diff --git a/docs/changelog/106836.yaml b/docs/changelog/106836.yaml deleted file mode 100644 index f561f44d9bb2d..0000000000000 --- a/docs/changelog/106836.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106836 -summary: Make int8_hnsw our default index for new dense-vector fields -area: Mapping -type: enhancement -issues: [] diff --git a/docs/changelog/106840.yaml b/docs/changelog/106840.yaml deleted file mode 100644 index 3f6831e4907ca..0000000000000 --- a/docs/changelog/106840.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 106840 -summary: Add total size in bytes to doc stats -area: Stats -type: enhancement -issues: - - 97670 diff --git a/docs/changelog/106851.yaml b/docs/changelog/106851.yaml deleted file mode 100644 index 2ada6a6a4e088..0000000000000 --- a/docs/changelog/106851.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106851 -summary: Catching `StackOverflowErrors` from bad regexes in `GsubProcessor` -area: Ingest Node -type: bug -issues: [] diff --git a/docs/changelog/106852.yaml b/docs/changelog/106852.yaml deleted file mode 100644 index 2161b1ea22f30..0000000000000 --- a/docs/changelog/106852.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 106852 -summary: Introduce ordinal bytesref block -area: ES|QL -type: enhancement -issues: - - 106387 diff --git a/docs/changelog/106860.yaml b/docs/changelog/106860.yaml deleted file mode 100644 index 376f8753023b9..0000000000000 --- a/docs/changelog/106860.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106860 -summary: "[Profiling] Add TopN Functions API" -area: Application -type: enhancement -issues: [] diff --git a/docs/changelog/106862.yaml b/docs/changelog/106862.yaml deleted file mode 100644 index 3ca2660fc3f73..0000000000000 --- a/docs/changelog/106862.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106862 -summary: Extend support of `allowedFields` to `getMatchingFieldNames` and `getAllFields` -area: "Mapping" -type: bug -issues: [] diff --git a/docs/changelog/106866.yaml b/docs/changelog/106866.yaml deleted file mode 100644 index ffc34e5962850..0000000000000 --- a/docs/changelog/106866.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106866 -summary: Add ES|QL signum function -area: ES|QL -type: enhancement -issues: [] diff --git a/docs/changelog/106889.yaml b/docs/changelog/106889.yaml deleted file mode 100644 index 7755081d09036..0000000000000 --- a/docs/changelog/106889.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106889 -summary: Slightly better geoip `databaseType` validation -area: Ingest Node -type: bug -issues: [] diff --git a/docs/changelog/106899.yaml b/docs/changelog/106899.yaml deleted file mode 100644 index a2db24236a47e..0000000000000 --- a/docs/changelog/106899.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 106899 -summary: Add ES|QL Locate function -area: ES|QL -type: enhancement -issues: - - 106818 diff --git a/docs/changelog/106919.yaml b/docs/changelog/106919.yaml deleted file mode 100644 index d8288095590de..0000000000000 --- a/docs/changelog/106919.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 106919 -summary: Fix downsample action request serialization -area: Downsampling -type: bug -issues: - - 106917 diff --git a/docs/changelog/106934.yaml b/docs/changelog/106934.yaml deleted file mode 100644 index fbfce3118e8a6..0000000000000 --- a/docs/changelog/106934.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106934 -summary: Adjust array resizing in block builder -area: ES|QL -type: enhancement -issues: [] diff --git a/docs/changelog/106952.yaml b/docs/changelog/106952.yaml deleted file mode 100644 index 1b45bf6ca28a2..0000000000000 --- a/docs/changelog/106952.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 106952 -summary: Add Lucene spanish plural stemmer -area: Search -type: enhancement -issues: [] diff --git a/docs/changelog/106989.yaml b/docs/changelog/106989.yaml deleted file mode 100644 index 47df5fe5b47d7..0000000000000 --- a/docs/changelog/106989.yaml +++ /dev/null @@ -1,7 +0,0 @@ -pr: 106989 -summary: Make force-stopping the transform always remove persistent task from cluster - state -area: Transform -type: bug -issues: - - 106811 diff --git a/docs/changelog/107007.yaml b/docs/changelog/107007.yaml deleted file mode 100644 index b2a755171725b..0000000000000 --- a/docs/changelog/107007.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107007 -summary: "ESQL: Support ST_DISJOINT" -area: ES|QL -type: enhancement -issues: [] diff --git a/docs/changelog/107016.yaml b/docs/changelog/107016.yaml deleted file mode 100644 index a2e32749a8008..0000000000000 --- a/docs/changelog/107016.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107016 -summary: "ESQL: Enable VALUES agg for datetime" -area: Aggregations -type: bug -issues: [] diff --git a/docs/changelog/107038.yaml b/docs/changelog/107038.yaml deleted file mode 100644 index e00b0d45a8a3a..0000000000000 --- a/docs/changelog/107038.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107038 -summary: Replace `UnsupportedOperationException` with `IllegalArgumentException` for non-existing columns -area: Search -type: bug -issues: [] diff --git a/docs/changelog/107041.yaml b/docs/changelog/107041.yaml deleted file mode 100644 index b8b4f3d7c5690..0000000000000 --- a/docs/changelog/107041.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 107041 -summary: '`DocumentParsingObserver` to accept an `indexName` to allow skipping system - indices' -area: Infra/Metrics -type: enhancement -issues: [] diff --git a/docs/changelog/107046.yaml b/docs/changelog/107046.yaml deleted file mode 100644 index 6c1373e09d17c..0000000000000 --- a/docs/changelog/107046.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 107046 -summary: "[Security Solution] Add `read` permission for third party agent indices\ - \ for `kibana_system`" -area: Authorization -type: enhancement -issues: [] diff --git a/docs/changelog/107050.yaml b/docs/changelog/107050.yaml deleted file mode 100644 index ecb375967ae44..0000000000000 --- a/docs/changelog/107050.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107050 -summary: Fix support for infinite `?master_timeout` -area: Cluster Coordination -type: bug -issues: [] diff --git a/docs/changelog/107107.yaml b/docs/changelog/107107.yaml deleted file mode 100644 index 5ca611befeb5d..0000000000000 --- a/docs/changelog/107107.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107107 -summary: Increase KDF iteration count in `KeyStoreWrapper` -area: Infra/CLI -type: enhancement -issues: [] diff --git a/docs/changelog/107121.yaml b/docs/changelog/107121.yaml deleted file mode 100644 index d46b1d58e9dfb..0000000000000 --- a/docs/changelog/107121.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 107121 -summary: Add a flag to re-enable writes on the final index after an ILM shrink action. -area: ILM+SLM -type: enhancement -issues: - - 106599 diff --git a/docs/changelog/107129.yaml b/docs/changelog/107129.yaml deleted file mode 100644 index 6c9b9094962c1..0000000000000 --- a/docs/changelog/107129.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107129 -summary: Track ongoing search tasks -area: Search -type: enhancement -issues: [] diff --git a/docs/changelog/107158.yaml b/docs/changelog/107158.yaml deleted file mode 100644 index 9589fe7e7264b..0000000000000 --- a/docs/changelog/107158.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107158 -summary: "ESQL: allow sorting by expressions and not only regular fields" -area: ES|QL -type: feature -issues: [] diff --git a/docs/changelog/107178.yaml b/docs/changelog/107178.yaml deleted file mode 100644 index 94a91357d38e6..0000000000000 --- a/docs/changelog/107178.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107178 -summary: "Add support for Azure OpenAI embeddings to inference service" -area: Machine Learning -type: feature -issues: [ ] diff --git a/docs/changelog/107183.yaml b/docs/changelog/107183.yaml deleted file mode 100644 index 226d036456858..0000000000000 --- a/docs/changelog/107183.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107183 -summary: ES|QL fix no-length substring with supplementary (4-byte) character -area: ES|QL -type: bug -issues: [] diff --git a/docs/changelog/107196.yaml b/docs/changelog/107196.yaml deleted file mode 100644 index 9892ccf71856f..0000000000000 --- a/docs/changelog/107196.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107196 -summary: Add metric for calculating index flush time excluding waiting on locks -area: Engine -type: enhancement -issues: [] diff --git a/docs/changelog/107224.yaml b/docs/changelog/107224.yaml deleted file mode 100644 index b0d40c09b758a..0000000000000 --- a/docs/changelog/107224.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 107224 -summary: "Enable 'encoder' and 'tags_schema' highlighting settings at field level" -area: Highlighting -type: enhancement -issues: - - 94028 diff --git a/docs/changelog/107232.yaml b/docs/changelog/107232.yaml deleted file mode 100644 index 1422848cb1c91..0000000000000 --- a/docs/changelog/107232.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 107232 -summary: Only trigger action once per thread -area: Transform -type: bug -issues: - - 107215 diff --git a/docs/changelog/107242.yaml b/docs/changelog/107242.yaml deleted file mode 100644 index 4a5e9821a1fa9..0000000000000 --- a/docs/changelog/107242.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107242 -summary: Added a timeout parameter to the inference API -area: Machine Learning -type: enhancement -issues: [ ] diff --git a/docs/changelog/107253.yaml b/docs/changelog/107253.yaml deleted file mode 100644 index 6961b59231ea3..0000000000000 --- a/docs/changelog/107253.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107253 -summary: "[Connector API] Support cleaning up sync jobs when deleting a connector" -area: Application -type: feature -issues: [] diff --git a/docs/changelog/107272.yaml b/docs/changelog/107272.yaml deleted file mode 100644 index eb9e0c5e8bab8..0000000000000 --- a/docs/changelog/107272.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107272 -summary: "ESQL: extend BUCKET with spans" -area: ES|QL -type: enhancement -issues: [] diff --git a/docs/changelog/107287.yaml b/docs/changelog/107287.yaml deleted file mode 100644 index 791f07fd1c729..0000000000000 --- a/docs/changelog/107287.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 107287 -summary: Add support for the 'Anonymous IP' database to the geoip processor -area: Ingest Node -type: enhancement -issues: - - 90789 diff --git a/docs/changelog/107291.yaml b/docs/changelog/107291.yaml deleted file mode 100644 index 3274fb77ef8c8..0000000000000 --- a/docs/changelog/107291.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 107291 -summary: Support data streams in enrich policy indices -area: Ingest Node -type: enhancement -issues: - - 98836 diff --git a/docs/changelog/107303.yaml b/docs/changelog/107303.yaml deleted file mode 100644 index 2e04ce6be3627..0000000000000 --- a/docs/changelog/107303.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107303 -summary: Create default word based chunker -area: Machine Learning -type: feature -issues: [] diff --git a/docs/changelog/107312.yaml b/docs/changelog/107312.yaml deleted file mode 100644 index 6ecd4179596e5..0000000000000 --- a/docs/changelog/107312.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107312 -summary: Fix NPE in ML assignment notifier -area: Machine Learning -type: bug -issues: [] diff --git a/docs/changelog/107334.yaml b/docs/changelog/107334.yaml deleted file mode 100644 index d1e8df2fa9c40..0000000000000 --- a/docs/changelog/107334.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107334 -summary: Adding `cache_stats` to geoip stats API -area: Ingest Node -type: enhancement -issues: [] diff --git a/docs/changelog/107358.yaml b/docs/changelog/107358.yaml deleted file mode 100644 index edb6deeffd100..0000000000000 --- a/docs/changelog/107358.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 107358 -summary: Check node shutdown before fail -area: Transform -type: enhancement -issues: - - 100891 diff --git a/docs/changelog/107370.yaml b/docs/changelog/107370.yaml deleted file mode 100644 index e7bdeef68cffe..0000000000000 --- a/docs/changelog/107370.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107370 -summary: Fork when handling remote field-caps responses -area: Search -type: bug -issues: [] diff --git a/docs/changelog/107377.yaml b/docs/changelog/107377.yaml deleted file mode 100644 index a97f8b78dcce9..0000000000000 --- a/docs/changelog/107377.yaml +++ /dev/null @@ -1,13 +0,0 @@ -pr: 107377 -summary: Add support for the 'Enterprise' database to the geoip processor -area: Ingest Node -type: enhancement -issues: [] -highlight: - title: "Preview: Support for the 'Anonymous IP' and 'Enterprise' databases in the geoip processor" - body: |- - As a Technical Preview, the {ref}/geoip-processor.html[`geoip`] processor can now use the commercial - https://www.maxmind.com/en/solutions/geoip2-enterprise-product-suite/enterprise-database[GeoIP2 'Enterprise'] - and - https://www.maxmind.com/en/solutions/geoip2-enterprise-product-suite/anonymous-ip-database[GeoIP2 'Anonymous IP'] - databases from MaxMind. diff --git a/docs/changelog/107383.yaml b/docs/changelog/107383.yaml deleted file mode 100644 index 07886ac96180c..0000000000000 --- a/docs/changelog/107383.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 107383 -summary: Users with monitor privileges can access async_search/status endpoint - even when setting keep_alive -area: Authorization -type: bug -issues: [] diff --git a/docs/changelog/107411.yaml b/docs/changelog/107411.yaml deleted file mode 100644 index fda040bcdab80..0000000000000 --- a/docs/changelog/107411.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107411 -summary: Invalidating cross cluster API keys requires `manage_security` -area: Security -type: enhancement -issues: [] diff --git a/docs/changelog/107414.yaml b/docs/changelog/107414.yaml deleted file mode 100644 index 60e31f22ca834..0000000000000 --- a/docs/changelog/107414.yaml +++ /dev/null @@ -1,7 +0,0 @@ -pr: 107414 -summary: "ESQL: median, count and `count_distinct` over constants" -area: ES|QL -type: bug -issues: - - 105248 - - 104900 diff --git a/docs/changelog/107447.yaml b/docs/changelog/107447.yaml deleted file mode 100644 index 6ace513013e3e..0000000000000 --- a/docs/changelog/107447.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107447 -summary: "Fix regression in get index settings (human=true) where the version was not displayed in human-readable format" -area: Infra/Core -type: bug -issues: [] diff --git a/docs/changelog/107449.yaml b/docs/changelog/107449.yaml deleted file mode 100644 index 7f0b1bb826e94..0000000000000 --- a/docs/changelog/107449.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107449 -summary: Leverage ordinals in enrich lookup -area: ES|QL -type: enhancement -issues: [] diff --git a/docs/changelog/107467.yaml b/docs/changelog/107467.yaml deleted file mode 100644 index e775e5928770d..0000000000000 --- a/docs/changelog/107467.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107467 -summary: "[Connector API] Fix bug with filtering validation toXContent" -area: Application -type: bug -issues: [] diff --git a/docs/changelog/107494.yaml b/docs/changelog/107494.yaml deleted file mode 100644 index 1d71ce284a4a8..0000000000000 --- a/docs/changelog/107494.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 107494 -summary: Handle infinity during synthetic source construction for scaled float field -area: Mapping -type: bug -issues: - - 107101 diff --git a/docs/changelog/107517.yaml b/docs/changelog/107517.yaml deleted file mode 100644 index 4d7830699ad49..0000000000000 --- a/docs/changelog/107517.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107517 -summary: Add GET `_inference` for all inference endpoints -area: Machine Learning -type: enhancement -issues: [] diff --git a/docs/changelog/107533.yaml b/docs/changelog/107533.yaml deleted file mode 100644 index da95cfd5b312e..0000000000000 --- a/docs/changelog/107533.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107533 -summary: Add setting for max connections to S3 -area: Snapshot/Restore -type: enhancement -issues: [] diff --git a/docs/changelog/107551.yaml b/docs/changelog/107551.yaml deleted file mode 100644 index 78e64cc526638..0000000000000 --- a/docs/changelog/107551.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107551 -summary: Avoid attempting to load the same empty field twice in fetch phase -area: Search -type: bug -issues: [] diff --git a/docs/changelog/107577.yaml b/docs/changelog/107577.yaml deleted file mode 100644 index a9a3c36a0e04d..0000000000000 --- a/docs/changelog/107577.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 107577 -summary: "ESQL: Fix MV_DEDUPE when using data from an index" -area: ES|QL -type: bug -issues: - - 104745 diff --git a/docs/changelog/107578.yaml b/docs/changelog/107578.yaml deleted file mode 100644 index 30746aeee6986..0000000000000 --- a/docs/changelog/107578.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107578 -summary: "ESQL: Allow reusing BUCKET grouping expressions in aggs" -area: ES|QL -type: bug -issues: [] diff --git a/docs/changelog/107598.yaml b/docs/changelog/107598.yaml deleted file mode 100644 index 125bbe759d2ea..0000000000000 --- a/docs/changelog/107598.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107598 -summary: Fix bulk NPE when retrying failure redirect after cluster block -area: Data streams -type: bug -issues: [] diff --git a/docs/changelog/107655.yaml b/docs/changelog/107655.yaml deleted file mode 100644 index 7091224d211f1..0000000000000 --- a/docs/changelog/107655.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 107655 -summary: "Use #addWithoutBreaking when adding a negative number of bytes to the circuit\ - \ breaker in `SequenceMatcher`" -area: EQL -type: bug -issues: [] diff --git a/docs/changelog/107678.yaml b/docs/changelog/107678.yaml deleted file mode 100644 index 9be55dd4d6b96..0000000000000 --- a/docs/changelog/107678.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 107678 -summary: Validate stats formatting in standard `InternalStats` constructor -area: Aggregations -type: bug -issues: - - 107671 diff --git a/docs/changelog/107743.yaml b/docs/changelog/107743.yaml deleted file mode 100644 index fad45040330d2..0000000000000 --- a/docs/changelog/107743.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107743 -summary: Validate `model_id` is required when using the `learning_to_rank` rescorer -area: Search -type: bug -issues: [] diff --git a/docs/changelog/107828.yaml b/docs/changelog/107828.yaml deleted file mode 100644 index ba0d44029203d..0000000000000 --- a/docs/changelog/107828.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 107828 -summary: Update several references to `IndexVersion.toString` to use `toReleaseVersion` -area: Infra/Core -type: bug -issues: - - 107821 diff --git a/docs/changelog/107865.yaml b/docs/changelog/107865.yaml deleted file mode 100644 index f7bb1d869eed5..0000000000000 --- a/docs/changelog/107865.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107865 -summary: Fix docs generation of signatures for variadic functions -area: ES|QL -type: bug -issues: [] diff --git a/docs/changelog/107891.yaml b/docs/changelog/107891.yaml deleted file mode 100644 index deb3fbd2258ff..0000000000000 --- a/docs/changelog/107891.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 107891 -summary: Fix `startOffset` must be non-negative error in XLMRoBERTa tokenizer -area: Machine Learning -type: bug -issues: - - 104626 diff --git a/docs/changelog/107902.yaml b/docs/changelog/107902.yaml deleted file mode 100644 index 6b25f8c12df60..0000000000000 --- a/docs/changelog/107902.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107902 -summary: Update several references to `TransportVersion.toString` to use `toReleaseVersion` -area: Infra/Core -type: bug -issues: [] diff --git a/docs/changelog/107969.yaml b/docs/changelog/107969.yaml deleted file mode 100644 index ed63513d8d57d..0000000000000 --- a/docs/changelog/107969.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 107969 -summary: Disable PIT for remote clusters -area: Transform -type: bug -issues: [] diff --git a/docs/changelog/108007.yaml b/docs/changelog/108007.yaml deleted file mode 100644 index 5d24f8c87597c..0000000000000 --- a/docs/changelog/108007.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 108007 -summary: Allow `typed_keys` for search application Search API -area: Application -type: feature -issues: [] diff --git a/docs/changelog/108031.yaml b/docs/changelog/108031.yaml deleted file mode 100644 index 0d02ddddbd472..0000000000000 --- a/docs/changelog/108031.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 108031 -summary: Fix lingering license warning header -area: License -type: bug -issues: - - 107573 diff --git a/docs/changelog/108041.yaml b/docs/changelog/108041.yaml deleted file mode 100644 index a82e0798dba21..0000000000000 --- a/docs/changelog/108041.yaml +++ /dev/null @@ -1,7 +0,0 @@ -pr: 108041 -summary: Handle parallel calls to `createWeight` when profiling is on -area: Search -type: bug -issues: - - 104131 - - 104235 diff --git a/docs/changelog/108101.yaml b/docs/changelog/108101.yaml deleted file mode 100644 index e935ec1beecd6..0000000000000 --- a/docs/changelog/108101.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 108101 -summary: "ESQL: Fix error message when failing to resolve aggregate groupings" -area: ES|QL -type: bug -issues: - - 108053 diff --git a/docs/changelog/108238.yaml b/docs/changelog/108238.yaml deleted file mode 100644 index 607979c2eb0ac..0000000000000 --- a/docs/changelog/108238.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 108238 -summary: "Nativeaccess: try to load all located libsystemds" -area: Infra/Core -type: bug -issues: - - 107878 diff --git a/docs/changelog/108257.yaml b/docs/changelog/108257.yaml deleted file mode 100644 index ce2c72353af82..0000000000000 --- a/docs/changelog/108257.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 108257 -summary: "ESQL: Log queries at debug level" -area: ES|QL -type: enhancement -issues: [] diff --git a/docs/changelog/108365.yaml b/docs/changelog/108365.yaml deleted file mode 100644 index d94486e2f3ea7..0000000000000 --- a/docs/changelog/108365.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 108365 -summary: "[Bugfix] Connector API - fix status serialisation issue in termquery" -area: Application -type: bug -issues: [] diff --git a/docs/changelog/108431.yaml b/docs/changelog/108431.yaml deleted file mode 100644 index 84607b1b99ac3..0000000000000 --- a/docs/changelog/108431.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 108431 -summary: "ESQL: Disable quoting in FROM command" -area: ES|QL -type: bug -issues: [] diff --git a/docs/changelog/108518.yaml b/docs/changelog/108518.yaml deleted file mode 100644 index aad823ccc89f6..0000000000000 --- a/docs/changelog/108518.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 108518 -summary: Remove leading is_ prefix from Enterprise geoip docs -area: Ingest Node -type: bug -issues: [] diff --git a/docs/changelog/108562.yaml b/docs/changelog/108562.yaml deleted file mode 100644 index 2a0047fe807fd..0000000000000 --- a/docs/changelog/108562.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 108562 -summary: Add `internalClusterTest` for and fix leak in `ExpandSearchPhase` -area: Search -type: bug -issues: - - 108369 diff --git a/docs/changelog/108571.yaml b/docs/changelog/108571.yaml deleted file mode 100644 index b863ac90d9e5f..0000000000000 --- a/docs/changelog/108571.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 108571 -summary: Workaround G1 bug for JDK 22 and 22.0.1 -area: Infra/CLI -type: bug -issues: [] diff --git a/docs/changelog/108600.yaml b/docs/changelog/108600.yaml deleted file mode 100644 index 59177bf34114c..0000000000000 --- a/docs/changelog/108600.yaml +++ /dev/null @@ -1,15 +0,0 @@ -pr: 108600 -summary: "Prevent DLS/FLS if `replication` is assigned" -area: Security -type: breaking -issues: [ ] -breaking: - title: "Prevent DLS/FLS if `replication` is assigned" - area: REST API - details: For cross-cluster API keys, {es} no longer allows specifying document-level security (DLS) - or field-level security (FLS) in the `search` field, if `replication` is also specified. - {es} likewise blocks the use of any existing cross-cluster API keys that meet this condition. - impact: Remove any document-level security (DLS) or field-level security (FLS) definitions from the `search` field - for cross-cluster API keys that also have a `replication` field, or create two separate cross-cluster API keys, - one for search and one for replication. - notable: false diff --git a/docs/changelog/108654.yaml b/docs/changelog/108654.yaml deleted file mode 100644 index 9afae6a19ca80..0000000000000 --- a/docs/changelog/108654.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 108654 -summary: Update bundled JDK to Java 22 (again) -area: Packaging -type: upgrade -issues: [] diff --git a/docs/changelog/108736.yaml b/docs/changelog/108736.yaml deleted file mode 100644 index 41e4084021e00..0000000000000 --- a/docs/changelog/108736.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 108736 -summary: Harden field-caps request dispatcher -area: Search -type: bug -issues: [] diff --git a/docs/changelog/108802.yaml b/docs/changelog/108802.yaml deleted file mode 100644 index 7c28a81a1b353..0000000000000 --- a/docs/changelog/108802.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 108802 -summary: Fix multithreading copies in lib vec -area: Vector Search -type: bug -issues: [] diff --git a/docs/changelog/108834.yaml b/docs/changelog/108834.yaml deleted file mode 100644 index 044056fa9a9da..0000000000000 --- a/docs/changelog/108834.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 108834 -summary: "[ESQL] Mark `date_diff` as requiring all three arguments" -area: ES|QL -type: bug -issues: - - 108383 diff --git a/docs/changelog/108854.yaml b/docs/changelog/108854.yaml deleted file mode 100644 index d6a880830f0d9..0000000000000 --- a/docs/changelog/108854.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 108854 -summary: "[Connector API] Fix bug with parsing *_doc_count nullable fields" -area: Application -type: bug -issues: [] diff --git a/docs/changelog/108867.yaml b/docs/changelog/108867.yaml deleted file mode 100644 index 545349dd84aeb..0000000000000 --- a/docs/changelog/108867.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 108867 -summary: Fix for raw mapping merge of fields named "properties" -area: Mapping -type: bug -issues: - - 108866 diff --git a/docs/changelog/108900.yaml b/docs/changelog/108900.yaml deleted file mode 100644 index 2a182f03ff8ce..0000000000000 --- a/docs/changelog/108900.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 108900 -summary: Treat 404 as empty register in `AzureBlobStore` -area: Snapshot/Restore -type: bug -issues: - - 108504 diff --git a/docs/changelog/109020.yaml b/docs/changelog/109020.yaml deleted file mode 100644 index c3efb1a1409bf..0000000000000 --- a/docs/changelog/109020.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 109020 -summary: Only skip deleting a downsampled index if downsampling is in progress as - part of DSL retention -area: Data streams -type: bug -issues: [] diff --git a/docs/changelog/109034.yaml b/docs/changelog/109034.yaml deleted file mode 100644 index cdf1f6fe28d8d..0000000000000 --- a/docs/changelog/109034.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 109034 -summary: Fix IOOBE in TTest aggregation when using filters -area: Aggregations -type: bug -issues: [] diff --git a/docs/changelog/109048.yaml b/docs/changelog/109048.yaml deleted file mode 100644 index 8bae082404ecd..0000000000000 --- a/docs/changelog/109048.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 109048 -summary: Guard against a null scorer in painless execute -area: Infra/Scripting -type: bug -issues: - - 43541 diff --git a/docs/changelog/109097.yaml b/docs/changelog/109097.yaml deleted file mode 100644 index a7520f4eaa9be..0000000000000 --- a/docs/changelog/109097.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 109097 -summary: "[Connector API] Fix bug with with wrong target index for access control\ - \ sync" -area: Application -type: bug -issues: [] diff --git a/docs/changelog/109148.yaml b/docs/changelog/109148.yaml deleted file mode 100644 index 902da6f1a1db3..0000000000000 --- a/docs/changelog/109148.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 109148 -summary: Fix double-pausing shard snapshot -area: Snapshot/Restore -type: bug -issues: - - 109143 diff --git a/docs/changelog/109173.yaml b/docs/changelog/109173.yaml deleted file mode 100644 index 9f4f73a6f74c8..0000000000000 --- a/docs/changelog/109173.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 109173 -summary: Wrap "Pattern too complex" exception into an `IllegalArgumentException` -area: Mapping -type: bug -issues: [] diff --git a/docs/changelog/97072.yaml b/docs/changelog/97072.yaml deleted file mode 100644 index 686b30952b646..0000000000000 --- a/docs/changelog/97072.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 97072 -summary: Log when update AffixSetting using addAffixMapUpdateConsumer -area: Infra/Logging -type: bug -issues: [] diff --git a/docs/changelog/97561.yaml b/docs/changelog/97561.yaml deleted file mode 100644 index cacefbf7e4ca3..0000000000000 --- a/docs/changelog/97561.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 97561 -summary: Add index forecasts to /_cat/allocation output -area: Allocation -type: enhancement -issues: [] diff --git a/docs/changelog/99048.yaml b/docs/changelog/99048.yaml deleted file mode 100644 index 722c145dae78f..0000000000000 --- a/docs/changelog/99048.yaml +++ /dev/null @@ -1,6 +0,0 @@ -pr: 99048 -summary: String sha512() painless function -area: Infra/Scripting -type: enhancement -issues: - - 97691 From b1c798237879dab3261f4c31f365e16024e2cd1e Mon Sep 17 00:00:00 2001 From: Mark Tozzi Date: Wed, 5 Jun 2024 11:14:03 -0400 Subject: [PATCH 11/30] [ESQL] Migrate Optimizer rules and associated tests (#109216) This PR tries to reconcile the tests for the logical plan optimizer rules. I deleted tests and rules in esql-core that had already been pulled into esql, and pulled in most of the remaining tests, and several of the rules. Already migrated and removed from core: - CombineDisjunction - InferIsNotNull - PropagateEquals Newly migrated rules: - ReplaceRegexMatch - BinaryComparisonSimplification - CombineBinaryComparisons This enables removing the (hopefully) unused operator base classes (Add, Equals, etc) in core, or at least removes one blocker to that. There may be more tests in core that reference these classes and need to be migrated, I haven't exhaustively looked yet. --- .../esql/core/optimizer/OptimizerRules.java | 1049 +---------- .../core/optimizer/OptimizerRulesTests.java | 1657 ----------------- .../esql/optimizer/LogicalPlanOptimizer.java | 25 +- .../xpack/esql/optimizer/OptimizerRules.java | 6 - .../esql/optimizer/OptimizerRulesTests.java | 383 ++++ 5 files changed, 409 insertions(+), 2711 deletions(-) diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/optimizer/OptimizerRules.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/optimizer/OptimizerRules.java index 7759b62d3d187..137f440f03b7e 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/optimizer/OptimizerRules.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/optimizer/OptimizerRules.java @@ -8,8 +8,6 @@ import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.xpack.esql.core.expression.Alias; -import org.elasticsearch.xpack.esql.core.expression.Attribute; -import org.elasticsearch.xpack.esql.core.expression.AttributeMap; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.Literal; @@ -22,7 +20,6 @@ import org.elasticsearch.xpack.esql.core.expression.predicate.BinaryPredicate; import org.elasticsearch.xpack.esql.core.expression.predicate.Negatable; import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates; -import org.elasticsearch.xpack.esql.core.expression.predicate.Range; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.BinaryLogic; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Not; @@ -31,27 +28,18 @@ import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNull; import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison; import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.Equals; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.GreaterThan; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.GreaterThanOrEqual; import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.In; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.LessThan; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.LessThanOrEqual; import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.NotEquals; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.NullEquals; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RegexMatch; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.StringPattern; import org.elasticsearch.xpack.esql.core.plan.logical.Filter; import org.elasticsearch.xpack.esql.core.plan.logical.Limit; import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.core.rule.Rule; import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.core.util.CollectionUtils; import org.elasticsearch.xpack.esql.core.util.ReflectionUtils; import java.time.ZoneId; import java.util.ArrayList; -import java.util.Iterator; import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.LinkedList; @@ -60,7 +48,6 @@ import java.util.Set; import java.util.function.BiFunction; -import static java.util.Collections.emptySet; import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE; import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE; import static org.elasticsearch.xpack.esql.core.expression.predicate.Predicates.combineAnd; @@ -248,43 +235,6 @@ protected Expression maybeSimplifyNegatable(Expression e) { } } - public static class BinaryComparisonSimplification extends OptimizerExpressionRule { - - public BinaryComparisonSimplification() { - super(TransformDirection.DOWN); - } - - @Override - protected Expression rule(BinaryComparison bc) { - Expression l = bc.left(); - Expression r = bc.right(); - - // true for equality - if (bc instanceof Equals || bc instanceof GreaterThanOrEqual || bc instanceof LessThanOrEqual) { - if (l.nullable() == Nullability.FALSE && r.nullable() == Nullability.FALSE && l.semanticEquals(r)) { - return new Literal(bc.source(), Boolean.TRUE, DataType.BOOLEAN); - } - } - if (bc instanceof NullEquals) { - if (l.semanticEquals(r)) { - return new Literal(bc.source(), Boolean.TRUE, DataType.BOOLEAN); - } - if (Expressions.isNull(r)) { - return new IsNull(bc.source(), l); - } - } - - // false for equality - if (bc instanceof NotEquals || bc instanceof GreaterThan || bc instanceof LessThan) { - if (l.nullable() == Nullability.FALSE && r.nullable() == Nullability.FALSE && l.semanticEquals(r)) { - return new Literal(bc.source(), Boolean.FALSE, DataType.BOOLEAN); - } - } - - return bc; - } - } - public static final class LiteralsOnTheRight extends OptimizerExpressionRule> { public LiteralsOnTheRight() { @@ -297,887 +247,6 @@ public LiteralsOnTheRight() { } } - /** - * Propagate Equals to eliminate conjuncted Ranges or BinaryComparisons. - * When encountering a different Equals, non-containing {@link Range} or {@link BinaryComparison}, the conjunction becomes false. - * When encountering a containing {@link Range}, {@link BinaryComparison} or {@link NotEquals}, these get eliminated by the equality. - * - * Since this rule can eliminate Ranges and BinaryComparisons, it should be applied before {@link CombineBinaryComparisons}. - * - * This rule doesn't perform any promotion of {@link BinaryComparison}s, that is handled by - * {@link CombineBinaryComparisons} on purpose as the resulting Range might be foldable - * (which is picked by the folding rule on the next run). - */ - public static final class PropagateEquals extends OptimizerExpressionRule { - - public PropagateEquals() { - super(TransformDirection.DOWN); - } - - @Override - public Expression rule(BinaryLogic e) { - if (e instanceof And) { - return propagate((And) e); - } else if (e instanceof Or) { - return propagate((Or) e); - } - return e; - } - - // combine conjunction - private static Expression propagate(And and) { - List ranges = new ArrayList<>(); - // Only equalities, not-equalities and inequalities with a foldable .right are extracted separately; - // the others go into the general 'exps'. - List equals = new ArrayList<>(); - List notEquals = new ArrayList<>(); - List inequalities = new ArrayList<>(); - List exps = new ArrayList<>(); - - boolean changed = false; - - for (Expression ex : Predicates.splitAnd(and)) { - if (ex instanceof Range) { - ranges.add((Range) ex); - } else if (ex instanceof Equals || ex instanceof NullEquals) { - BinaryComparison otherEq = (BinaryComparison) ex; - // equals on different values evaluate to FALSE - // ignore date/time fields as equality comparison might actually be a range check - if (otherEq.right().foldable() && DataType.isDateTime(otherEq.left().dataType()) == false) { - for (BinaryComparison eq : equals) { - if (otherEq.left().semanticEquals(eq.left())) { - Integer comp = BinaryComparison.compare(eq.right().fold(), otherEq.right().fold()); - if (comp != null) { - // var cannot be equal to two different values at the same time - if (comp != 0) { - return new Literal(and.source(), Boolean.FALSE, DataType.BOOLEAN); - } - } - } - } - equals.add(otherEq); - } else { - exps.add(otherEq); - } - } else if (ex instanceof GreaterThan - || ex instanceof GreaterThanOrEqual - || ex instanceof LessThan - || ex instanceof LessThanOrEqual) { - BinaryComparison bc = (BinaryComparison) ex; - if (bc.right().foldable()) { - inequalities.add(bc); - } else { - exps.add(ex); - } - } else if (ex instanceof NotEquals otherNotEq) { - if (otherNotEq.right().foldable()) { - notEquals.add(otherNotEq); - } else { - exps.add(ex); - } - } else { - exps.add(ex); - } - } - - // check - for (BinaryComparison eq : equals) { - Object eqValue = eq.right().fold(); - - for (Iterator iterator = ranges.iterator(); iterator.hasNext();) { - Range range = iterator.next(); - - if (range.value().semanticEquals(eq.left())) { - // if equals is outside the interval, evaluate the whole expression to FALSE - if (range.lower().foldable()) { - Integer compare = BinaryComparison.compare(range.lower().fold(), eqValue); - if (compare != null && ( - // eq outside the lower boundary - compare > 0 || - // eq matches the boundary but should not be included - (compare == 0 && range.includeLower() == false))) { - return new Literal(and.source(), Boolean.FALSE, DataType.BOOLEAN); - } - } - if (range.upper().foldable()) { - Integer compare = BinaryComparison.compare(range.upper().fold(), eqValue); - if (compare != null && ( - // eq outside the upper boundary - compare < 0 || - // eq matches the boundary but should not be included - (compare == 0 && range.includeUpper() == false))) { - return new Literal(and.source(), Boolean.FALSE, DataType.BOOLEAN); - } - } - - // it's in the range and thus, remove it - iterator.remove(); - changed = true; - } - } - - // evaluate all NotEquals against the Equal - for (Iterator iter = notEquals.iterator(); iter.hasNext();) { - NotEquals neq = iter.next(); - if (eq.left().semanticEquals(neq.left())) { - Integer comp = BinaryComparison.compare(eqValue, neq.right().fold()); - if (comp != null) { - if (comp == 0) { // clashing and conflicting: a = 1 AND a != 1 - return new Literal(and.source(), Boolean.FALSE, DataType.BOOLEAN); - } else { // clashing and redundant: a = 1 AND a != 2 - iter.remove(); - changed = true; - } - } - } - } - - // evaluate all inequalities against the Equal - for (Iterator iter = inequalities.iterator(); iter.hasNext();) { - BinaryComparison bc = iter.next(); - if (eq.left().semanticEquals(bc.left())) { - Integer compare = BinaryComparison.compare(eqValue, bc.right().fold()); - if (compare != null) { - if (bc instanceof LessThan || bc instanceof LessThanOrEqual) { // a = 2 AND a />= ? - if ((compare == 0 && bc instanceof GreaterThan) || // a = 2 AND a > 2 - compare < 0) { // a = 2 AND a >/>= 3 - return new Literal(and.source(), Boolean.FALSE, DataType.BOOLEAN); - } - } - - iter.remove(); - changed = true; - } - } - } - } - - return changed ? Predicates.combineAnd(CollectionUtils.combine(exps, equals, notEquals, inequalities, ranges)) : and; - } - - // combine disjunction: - // a = 2 OR a > 3 -> nop; a = 2 OR a > 1 -> a > 1 - // a = 2 OR a < 3 -> a < 3; a = 2 OR a < 1 -> nop - // a = 2 OR 3 < a < 5 -> nop; a = 2 OR 1 < a < 3 -> 1 < a < 3; a = 2 OR 0 < a < 1 -> nop - // a = 2 OR a != 2 -> TRUE; a = 2 OR a = 5 -> nop; a = 2 OR a != 5 -> a != 5 - private static Expression propagate(Or or) { - List exps = new ArrayList<>(); - List equals = new ArrayList<>(); // foldable right term Equals - List notEquals = new ArrayList<>(); // foldable right term NotEquals - List ranges = new ArrayList<>(); - List inequalities = new ArrayList<>(); // foldable right term (=limit) BinaryComparision - - // split expressions by type - for (Expression ex : Predicates.splitOr(or)) { - if (ex instanceof Equals eq) { - if (eq.right().foldable()) { - equals.add(eq); - } else { - exps.add(ex); - } - } else if (ex instanceof NotEquals neq) { - if (neq.right().foldable()) { - notEquals.add(neq); - } else { - exps.add(ex); - } - } else if (ex instanceof Range) { - ranges.add((Range) ex); - } else if (ex instanceof BinaryComparison bc) { - if (bc.right().foldable()) { - inequalities.add(bc); - } else { - exps.add(ex); - } - } else { - exps.add(ex); - } - } - - boolean updated = false; // has the expression been modified? - - // evaluate the impact of each Equal over the different types of Expressions - for (Iterator iterEq = equals.iterator(); iterEq.hasNext();) { - Equals eq = iterEq.next(); - Object eqValue = eq.right().fold(); - boolean removeEquals = false; - - // Equals OR NotEquals - for (NotEquals neq : notEquals) { - if (eq.left().semanticEquals(neq.left())) { // a = 2 OR a != ? -> ... - Integer comp = BinaryComparison.compare(eqValue, neq.right().fold()); - if (comp != null) { - if (comp == 0) { // a = 2 OR a != 2 -> TRUE - return TRUE; - } else { // a = 2 OR a != 5 -> a != 5 - removeEquals = true; - break; - } - } - } - } - if (removeEquals) { - iterEq.remove(); - updated = true; - continue; - } - - // Equals OR Range - for (int i = 0; i < ranges.size(); i++) { // might modify list, so use index loop - Range range = ranges.get(i); - if (eq.left().semanticEquals(range.value())) { - Integer lowerComp = range.lower().foldable() ? BinaryComparison.compare(eqValue, range.lower().fold()) : null; - Integer upperComp = range.upper().foldable() ? BinaryComparison.compare(eqValue, range.upper().fold()) : null; - - if (lowerComp != null && lowerComp == 0) { - if (range.includeLower() == false) { // a = 2 OR 2 < a < ? -> 2 <= a < ? - ranges.set( - i, - new Range( - range.source(), - range.value(), - range.lower(), - true, - range.upper(), - range.includeUpper(), - range.zoneId() - ) - ); - } // else : a = 2 OR 2 <= a < ? -> 2 <= a < ? - removeEquals = true; // update range with lower equality instead or simply superfluous - break; - } else if (upperComp != null && upperComp == 0) { - if (range.includeUpper() == false) { // a = 2 OR ? < a < 2 -> ? < a <= 2 - ranges.set( - i, - new Range( - range.source(), - range.value(), - range.lower(), - range.includeLower(), - range.upper(), - true, - range.zoneId() - ) - ); - } // else : a = 2 OR ? < a <= 2 -> ? < a <= 2 - removeEquals = true; // update range with upper equality instead - break; - } else if (lowerComp != null && upperComp != null) { - if (0 < lowerComp && upperComp < 0) { // a = 2 OR 1 < a < 3 - removeEquals = true; // equality is superfluous - break; - } - } - } - } - if (removeEquals) { - iterEq.remove(); - updated = true; - continue; - } - - // Equals OR Inequality - for (int i = 0; i < inequalities.size(); i++) { - BinaryComparison bc = inequalities.get(i); - if (eq.left().semanticEquals(bc.left())) { - Integer comp = BinaryComparison.compare(eqValue, bc.right().fold()); - if (comp != null) { - if (bc instanceof GreaterThan || bc instanceof GreaterThanOrEqual) { - if (comp < 0) { // a = 1 OR a > 2 -> nop - continue; - } else if (comp == 0 && bc instanceof GreaterThan) { // a = 2 OR a > 2 -> a >= 2 - inequalities.set(i, new GreaterThanOrEqual(bc.source(), bc.left(), bc.right(), bc.zoneId())); - } // else (0 < comp || bc instanceof GreaterThanOrEqual) : - // a = 3 OR a > 2 -> a > 2; a = 2 OR a => 2 -> a => 2 - - removeEquals = true; // update range with equality instead or simply superfluous - break; - } else if (bc instanceof LessThan || bc instanceof LessThanOrEqual) { - if (comp > 0) { // a = 2 OR a < 1 -> nop - continue; - } - if (comp == 0 && bc instanceof LessThan) { // a = 2 OR a < 2 -> a <= 2 - inequalities.set(i, new LessThanOrEqual(bc.source(), bc.left(), bc.right(), bc.zoneId())); - } // else (comp < 0 || bc instanceof LessThanOrEqual) : a = 2 OR a < 3 -> a < 3; a = 2 OR a <= 2 -> a <= 2 - removeEquals = true; // update range with equality instead or simply superfluous - break; - } - } - } - } - if (removeEquals) { - iterEq.remove(); - updated = true; - } - } - - return updated ? Predicates.combineOr(CollectionUtils.combine(exps, equals, notEquals, inequalities, ranges)) : or; - } - } - - public static final class CombineBinaryComparisons extends OptimizerExpressionRule { - - public CombineBinaryComparisons() { - super(TransformDirection.DOWN); - } - - @Override - public Expression rule(BinaryLogic e) { - if (e instanceof And) { - return combine((And) e); - } else if (e instanceof Or) { - return combine((Or) e); - } - return e; - } - - // combine conjunction - private static Expression combine(And and) { - List ranges = new ArrayList<>(); - List bcs = new ArrayList<>(); - List exps = new ArrayList<>(); - - boolean changed = false; - - List andExps = Predicates.splitAnd(and); - // Ranges need to show up before BinaryComparisons in list, to allow the latter be optimized away into a Range, if possible. - // NotEquals need to be last in list, to have a complete set of Ranges (ranges) and BinaryComparisons (bcs) and allow these to - // optimize the NotEquals away. - andExps.sort((o1, o2) -> { - if (o1 instanceof Range && o2 instanceof Range) { - return 0; // keep ranges' order - } else if (o1 instanceof Range || o2 instanceof Range) { - return o2 instanceof Range ? 1 : -1; // push Ranges down - } else if (o1 instanceof NotEquals && o2 instanceof NotEquals) { - return 0; // keep NotEquals' order - } else if (o1 instanceof NotEquals || o2 instanceof NotEquals) { - return o1 instanceof NotEquals ? 1 : -1; // push NotEquals up - } else { - return 0; // keep non-Ranges' and non-NotEquals' order - } - }); - for (Expression ex : andExps) { - if (ex instanceof Range r) { - if (findExistingRange(r, ranges, true)) { - changed = true; - } else { - ranges.add(r); - } - } else if (ex instanceof BinaryComparison bc && (ex instanceof Equals || ex instanceof NotEquals) == false) { - - if (bc.right().foldable() && (findConjunctiveComparisonInRange(bc, ranges) || findExistingComparison(bc, bcs, true))) { - changed = true; - } else { - bcs.add(bc); - } - } else if (ex instanceof NotEquals neq) { - if (neq.right().foldable() && notEqualsIsRemovableFromConjunction(neq, ranges, bcs)) { - // the non-equality can simply be dropped: either superfluous or has been merged with an updated range/inequality - changed = true; - } else { // not foldable OR not overlapping - exps.add(ex); - } - } else { - exps.add(ex); - } - } - - // finally try combining any left BinaryComparisons into possible Ranges - // this could be a different rule but it's clearer here wrt the order of comparisons - - for (int i = 0, step = 1; i < bcs.size() - 1; i += step, step = 1) { - BinaryComparison main = bcs.get(i); - - for (int j = i + 1; j < bcs.size(); j++) { - BinaryComparison other = bcs.get(j); - - if (main.left().semanticEquals(other.left())) { - // >/>= AND />= - else if ((other instanceof GreaterThan || other instanceof GreaterThanOrEqual) - && (main instanceof LessThan || main instanceof LessThanOrEqual)) { - bcs.remove(j); - bcs.remove(i); - - ranges.add( - new Range( - and.source(), - main.left(), - other.right(), - other instanceof GreaterThanOrEqual, - main.right(), - main instanceof LessThanOrEqual, - main.zoneId() - ) - ); - - changed = true; - step = 0; - break; - } - } - } - } - - return changed ? Predicates.combineAnd(CollectionUtils.combine(exps, bcs, ranges)) : and; - } - - // combine disjunction - private static Expression combine(Or or) { - List bcs = new ArrayList<>(); - List ranges = new ArrayList<>(); - List exps = new ArrayList<>(); - - boolean changed = false; - - for (Expression ex : Predicates.splitOr(or)) { - if (ex instanceof Range r) { - if (findExistingRange(r, ranges, false)) { - changed = true; - } else { - ranges.add(r); - } - } else if (ex instanceof BinaryComparison bc) { - if (bc.right().foldable() && findExistingComparison(bc, bcs, false)) { - changed = true; - } else { - bcs.add(bc); - } - } else { - exps.add(ex); - } - } - - return changed ? Predicates.combineOr(CollectionUtils.combine(exps, bcs, ranges)) : or; - } - - private static boolean findExistingRange(Range main, List ranges, boolean conjunctive) { - if (main.lower().foldable() == false && main.upper().foldable() == false) { - return false; - } - // NB: the loop modifies the list (hence why the int is used) - for (int i = 0; i < ranges.size(); i++) { - Range other = ranges.get(i); - - if (main.value().semanticEquals(other.value())) { - - // make sure the comparison was done - boolean compared = false; - - boolean lower = false; - boolean upper = false; - // boundary equality (useful to differentiate whether a range is included or not) - // and thus whether it should be preserved or ignored - boolean lowerEq = false; - boolean upperEq = false; - - // evaluate lower - if (main.lower().foldable() && other.lower().foldable()) { - compared = true; - - Integer comp = BinaryComparison.compare(main.lower().fold(), other.lower().fold()); - // values are comparable - if (comp != null) { - // boundary equality - lowerEq = comp == 0 && main.includeLower() == other.includeLower(); - // AND - if (conjunctive) { - // (2 < a < 3) AND (1 < a < 3) -> (2 < a < 3) - lower = comp > 0 || - // (2 < a < 3) AND (2 <= a < 3) -> (2 < a < 3) - (comp == 0 && main.includeLower() == false && other.includeLower()); - } - // OR - else { - // (1 < a < 3) OR (2 < a < 3) -> (1 < a < 3) - lower = comp < 0 || - // (2 <= a < 3) OR (2 < a < 3) -> (2 <= a < 3) - (comp == 0 && main.includeLower() && other.includeLower() == false) || lowerEq; - } - } - } - // evaluate upper - if (main.upper().foldable() && other.upper().foldable()) { - compared = true; - - Integer comp = BinaryComparison.compare(main.upper().fold(), other.upper().fold()); - // values are comparable - if (comp != null) { - // boundary equality - upperEq = comp == 0 && main.includeUpper() == other.includeUpper(); - - // AND - if (conjunctive) { - // (1 < a < 2) AND (1 < a < 3) -> (1 < a < 2) - upper = comp < 0 || - // (1 < a < 2) AND (1 < a <= 2) -> (1 < a < 2) - (comp == 0 && main.includeUpper() == false && other.includeUpper()); - } - // OR - else { - // (1 < a < 3) OR (1 < a < 2) -> (1 < a < 3) - upper = comp > 0 || - // (1 < a <= 3) OR (1 < a < 3) -> (2 < a < 3) - (comp == 0 && main.includeUpper() && other.includeUpper() == false) || upperEq; - } - } - } - - // AND - at least one of lower or upper - if (conjunctive) { - // can tighten range - if (lower || upper) { - ranges.set( - i, - new Range( - main.source(), - main.value(), - lower ? main.lower() : other.lower(), - lower ? main.includeLower() : other.includeLower(), - upper ? main.upper() : other.upper(), - upper ? main.includeUpper() : other.includeUpper(), - main.zoneId() - ) - ); - } - - // range was comparable - return compared; - } - // OR - needs both upper and lower to loosen range - else { - // can loosen range - if (lower && upper) { - ranges.set( - i, - new Range( - main.source(), - main.value(), - main.lower(), - main.includeLower(), - main.upper(), - main.includeUpper(), - main.zoneId() - ) - ); - return true; - } - - // if the range in included, no need to add it - return compared && (((lower && lowerEq == false) || (upper && upperEq == false)) == false); - } - } - } - return false; - } - - private static boolean findConjunctiveComparisonInRange(BinaryComparison main, List ranges) { - Object value = main.right().fold(); - - // NB: the loop modifies the list (hence why the int is used) - for (int i = 0; i < ranges.size(); i++) { - Range other = ranges.get(i); - - if (main.left().semanticEquals(other.value())) { - - if (main instanceof GreaterThan || main instanceof GreaterThanOrEqual) { - if (other.lower().foldable()) { - Integer comp = BinaryComparison.compare(value, other.lower().fold()); - if (comp != null) { - // 2 < a AND (2 <= a < 3) -> 2 < a < 3 - boolean lowerEq = comp == 0 && other.includeLower() && main instanceof GreaterThan; - // 2 < a AND (1 < a < 3) -> 2 < a < 3 - boolean lower = comp > 0 || lowerEq; - - if (lower) { - ranges.set( - i, - new Range( - other.source(), - other.value(), - main.right(), - lowerEq ? false : main instanceof GreaterThanOrEqual, - other.upper(), - other.includeUpper(), - other.zoneId() - ) - ); - } - - // found a match - return true; - } - } - } else if (main instanceof LessThan || main instanceof LessThanOrEqual) { - if (other.upper().foldable()) { - Integer comp = BinaryComparison.compare(value, other.upper().fold()); - if (comp != null) { - // a < 2 AND (1 < a <= 2) -> 1 < a < 2 - boolean upperEq = comp == 0 && other.includeUpper() && main instanceof LessThan; - // a < 2 AND (1 < a < 3) -> 1 < a < 2 - boolean upper = comp < 0 || upperEq; - - if (upper) { - ranges.set( - i, - new Range( - other.source(), - other.value(), - other.lower(), - other.includeLower(), - main.right(), - upperEq ? false : main instanceof LessThanOrEqual, - other.zoneId() - ) - ); - } - - // found a match - return true; - } - } - } - - return false; - } - } - return false; - } - - /** - * Find commonalities between the given comparison in the given list. - * The method can be applied both for conjunctive (AND) or disjunctive purposes (OR). - */ - private static boolean findExistingComparison(BinaryComparison main, List bcs, boolean conjunctive) { - Object value = main.right().fold(); - - // NB: the loop modifies the list (hence why the int is used) - for (int i = 0; i < bcs.size(); i++) { - BinaryComparison other = bcs.get(i); - // skip if cannot evaluate - if (other.right().foldable() == false) { - continue; - } - // if bc is a higher/lower value or gte vs gt, use it instead - if ((other instanceof GreaterThan || other instanceof GreaterThanOrEqual) - && (main instanceof GreaterThan || main instanceof GreaterThanOrEqual)) { - - if (main.left().semanticEquals(other.left())) { - Integer compare = BinaryComparison.compare(value, other.right().fold()); - - if (compare != null) { - // AND - if ((conjunctive && - // a > 3 AND a > 2 -> a > 3 - (compare > 0 || - // a > 2 AND a >= 2 -> a > 2 - (compare == 0 && main instanceof GreaterThan && other instanceof GreaterThanOrEqual))) || - // OR - (conjunctive == false && - // a > 2 OR a > 3 -> a > 2 - (compare < 0 || - // a >= 2 OR a > 2 -> a >= 2 - (compare == 0 && main instanceof GreaterThanOrEqual && other instanceof GreaterThan)))) { - bcs.remove(i); - bcs.add(i, main); - } - // found a match - return true; - } - - return false; - } - } - // if bc is a lower/higher value or lte vs lt, use it instead - else if ((other instanceof LessThan || other instanceof LessThanOrEqual) - && (main instanceof LessThan || main instanceof LessThanOrEqual)) { - - if (main.left().semanticEquals(other.left())) { - Integer compare = BinaryComparison.compare(value, other.right().fold()); - - if (compare != null) { - // AND - if ((conjunctive && - // a < 2 AND a < 3 -> a < 2 - (compare < 0 || - // a < 2 AND a <= 2 -> a < 2 - (compare == 0 && main instanceof LessThan && other instanceof LessThanOrEqual))) || - // OR - (conjunctive == false && - // a < 2 OR a < 3 -> a < 3 - (compare > 0 || - // a <= 2 OR a < 2 -> a <= 2 - (compare == 0 && main instanceof LessThanOrEqual && other instanceof LessThan)))) { - bcs.remove(i); - bcs.add(i, main); - - } - // found a match - return true; - } - - return false; - } - } - } - - return false; - } - - private static boolean notEqualsIsRemovableFromConjunction(NotEquals notEquals, List ranges, List bcs) { - Object neqVal = notEquals.right().fold(); - Integer comp; - - // check on "condition-overlapping" ranges: - // a != 2 AND 3 < a < 5 -> 3 < a < 5; a != 2 AND 0 < a < 1 -> 0 < a < 1 (discard NotEquals) - // a != 2 AND 2 <= a < 3 -> 2 < a < 3; a != 3 AND 2 < a <= 3 -> 2 < a < 3 (discard NotEquals, plus update Range) - // a != 2 AND 1 < a < 3 -> nop (do nothing) - for (int i = 0; i < ranges.size(); i++) { - Range range = ranges.get(i); - - if (notEquals.left().semanticEquals(range.value())) { - comp = range.lower().foldable() ? BinaryComparison.compare(neqVal, range.lower().fold()) : null; - if (comp != null) { - if (comp <= 0) { - if (comp == 0 && range.includeLower()) { // a != 2 AND 2 <= a < ? -> 2 < a < ? - ranges.set( - i, - new Range( - range.source(), - range.value(), - range.lower(), - false, - range.upper(), - range.includeUpper(), - range.zoneId() - ) - ); - } - // else: !.includeLower() : a != 2 AND 2 < a < 3 -> 2 < a < 3; or: - // else: comp < 0 : a != 2 AND 3 < a < ? -> 3 < a < ? - - return true; - } else { // comp > 0 : a != 4 AND 2 < a < ? : can only remove NotEquals if outside the range - comp = range.upper().foldable() ? BinaryComparison.compare(neqVal, range.upper().fold()) : null; - if (comp != null && comp >= 0) { - if (comp == 0 && range.includeUpper()) { // a != 4 AND 2 < a <= 4 -> 2 < a < 4 - ranges.set( - i, - new Range( - range.source(), - range.value(), - range.lower(), - range.includeLower(), - range.upper(), - false, - range.zoneId() - ) - ); - } - // else: !.includeUpper() : a != 4 AND 2 < a < 4 -> 2 < a < 4 - // else: comp > 0 : a != 4 AND 2 < a < 3 -> 2 < a < 3 - - return true; - } - // else: comp < 0 : a != 4 AND 2 < a < 5 -> nop; or: - // else: comp == null : upper bound not comparable -> nop - } - } // else: comp == null : lower bound not comparable: evaluate upper bound, in case non-equality value is ">=" - - comp = range.upper().foldable() ? BinaryComparison.compare(neqVal, range.upper().fold()) : null; - if (comp != null && comp >= 0) { - if (comp == 0 && range.includeUpper()) { // a != 3 AND ?? < a <= 3 -> ?? < a < 3 - ranges.set( - i, - new Range( - range.source(), - range.value(), - range.lower(), - range.includeLower(), - range.upper(), - false, - range.zoneId() - ) - ); - } - // else: !.includeUpper() : a != 3 AND ?? < a < 3 -> ?? < a < 3 - // else: comp > 0 : a != 3 and ?? < a < 2 -> ?? < a < 2 - - return true; - } - // else: comp < 0 : a != 3 AND ?? < a < 4 -> nop, as a decision can't be drawn; or: - // else: comp == null : a != 3 AND ?? < a < ?? -> nop - } - } - - // check on "condition-overlapping" inequalities: - // a != 2 AND a > 3 -> a > 3 (discard NotEquals) - // a != 2 AND a >= 2 -> a > 2 (discard NotEquals plus update inequality) - // a != 2 AND a > 1 -> nop (do nothing) - // - // a != 2 AND a < 3 -> nop - // a != 2 AND a <= 2 -> a < 2 - // a != 2 AND a < 1 -> a < 1 - for (int i = 0; i < bcs.size(); i++) { - BinaryComparison bc = bcs.get(i); - - if (notEquals.left().semanticEquals(bc.left())) { - if (bc instanceof LessThan || bc instanceof LessThanOrEqual) { - comp = bc.right().foldable() ? BinaryComparison.compare(neqVal, bc.right().fold()) : null; - if (comp != null) { - if (comp >= 0) { - if (comp == 0 && bc instanceof LessThanOrEqual) { // a != 2 AND a <= 2 -> a < 2 - bcs.set(i, new LessThan(bc.source(), bc.left(), bc.right(), bc.zoneId())); - } // else : comp > 0 (a != 2 AND a a a < 2) - return true; - } // else: comp < 0 : a != 2 AND a nop - } // else: non-comparable, nop - } else if (bc instanceof GreaterThan || bc instanceof GreaterThanOrEqual) { - comp = bc.right().foldable() ? BinaryComparison.compare(neqVal, bc.right().fold()) : null; - if (comp != null) { - if (comp <= 0) { - if (comp == 0 && bc instanceof GreaterThanOrEqual) { // a != 2 AND a >= 2 -> a > 2 - bcs.set(i, new GreaterThan(bc.source(), bc.left(), bc.right(), bc.zoneId())); - } // else: comp < 0 (a != 2 AND a >/>= 3 -> a >/>= 3), or == 0 && bc i.of ">" (a != 2 AND a > 2 -> a > 2) - return true; - } // else: comp > 0 : a != 2 AND a >/>= 1 -> nop - } // else: non-comparable, nop - } // else: other non-relevant type - } - } - - return false; - } - - } - /** * Combine disjunctions on the same field into an In expression. * This rule looks for both simple equalities: @@ -1376,33 +445,6 @@ protected LogicalPlan rule(Limit limit) { protected abstract LogicalPlan skipPlan(Limit limit); } - public static class ReplaceRegexMatch extends OptimizerExpressionRule> { - - public ReplaceRegexMatch() { - super(TransformDirection.DOWN); - } - - @Override - protected Expression rule(RegexMatch regexMatch) { - Expression e = regexMatch; - StringPattern pattern = regexMatch.pattern(); - if (pattern.matchesAll()) { - e = new IsNotNull(e.source(), regexMatch.field()); - } else { - String match = pattern.exactMatch(); - if (match != null) { - Literal literal = new Literal(regexMatch.source(), match, DataType.KEYWORD); - e = regexToEquals(regexMatch, literal); - } - } - return e; - } - - protected Expression regexToEquals(RegexMatch regexMatch, Literal literal) { - return new Equals(regexMatch.source(), regexMatch.field(), literal); - } - } - public static class FoldNull extends OptimizerExpressionRule { public FoldNull() { @@ -1410,7 +452,7 @@ public FoldNull() { } @Override - protected Expression rule(Expression e) { + public Expression rule(Expression e) { Expression result = tryReplaceIsNullIsNotNull(e); if (result != e) { return result; @@ -1450,7 +492,7 @@ public PropagateNullable() { } @Override - protected Expression rule(And and) { + public Expression rule(And and) { List splits = Predicates.splitAnd(and); Set nullExpressions = new LinkedHashSet<>(); @@ -1529,93 +571,6 @@ protected Expression nonNullify(Expression exp, Expression nonNullExp) { } } - /** - * Simplify IsNotNull targets by resolving the underlying expression to its root fields with unknown - * nullability. - * e.g. - * (x + 1) / 2 IS NOT NULL --> x IS NOT NULL AND (x+1) / 2 IS NOT NULL - * SUBSTRING(x, 3) > 4 IS NOT NULL --> x IS NOT NULL AND SUBSTRING(x, 3) > 4 IS NOT NULL - * When dealing with multiple fields, a conjunction/disjunction based on the predicate: - * (x + y) / 4 IS NOT NULL --> x IS NOT NULL AND y IS NOT NULL AND (x + y) / 4 IS NOT NULL - * This handles the case of fields nested inside functions or expressions in order to avoid: - * - having to evaluate the whole expression - * - not pushing down the filter due to expression evaluation - * IS NULL cannot be simplified since it leads to a disjunction which prevents the filter to be - * pushed down: - * (x + 1) IS NULL --> x IS NULL OR x + 1 IS NULL - * and x IS NULL cannot be pushed down - *
- * Implementation-wise this rule goes bottom-up, keeping an alias up to date to the current plan - * and then looks for replacing the target. - */ - public static class InferIsNotNull extends Rule { - - @Override - public LogicalPlan apply(LogicalPlan plan) { - // the alias map is shared across the whole plan - AttributeMap aliases = new AttributeMap<>(); - // traverse bottom-up to pick up the aliases as we go - plan = plan.transformUp(p -> inspectPlan(p, aliases)); - return plan; - } - - private LogicalPlan inspectPlan(LogicalPlan plan, AttributeMap aliases) { - // inspect just this plan properties - plan.forEachExpression(Alias.class, a -> aliases.put(a.toAttribute(), a.child())); - // now go about finding isNull/isNotNull - LogicalPlan newPlan = plan.transformExpressionsOnlyUp(IsNotNull.class, inn -> inferNotNullable(inn, aliases)); - return newPlan; - } - - private Expression inferNotNullable(IsNotNull inn, AttributeMap aliases) { - Expression result = inn; - Set refs = resolveExpressionAsRootAttributes(inn.field(), aliases); - // no refs found or could not detect - return the original function - if (refs.size() > 0) { - // add IsNull for the filters along with the initial inn - var innList = CollectionUtils.combine(refs.stream().map(r -> (Expression) new IsNotNull(inn.source(), r)).toList(), inn); - result = Predicates.combineAnd(innList); - } - return result; - } - - /** - * Unroll the expression to its references to get to the root fields - * that really matter for filtering. - */ - protected Set resolveExpressionAsRootAttributes(Expression exp, AttributeMap aliases) { - Set resolvedExpressions = new LinkedHashSet<>(); - boolean changed = doResolve(exp, aliases, resolvedExpressions); - return changed ? resolvedExpressions : emptySet(); - } - - private boolean doResolve(Expression exp, AttributeMap aliases, Set resolvedExpressions) { - boolean changed = false; - // check if the expression can be skipped or is not nullabe - if (skipExpression(exp)) { - resolvedExpressions.add(exp); - } else { - for (Expression e : exp.references()) { - Expression resolved = aliases.resolve(e, e); - // found a root attribute, bail out - if (resolved instanceof Attribute a && resolved == e) { - resolvedExpressions.add(a); - // don't mark things as change if the original expression hasn't been broken down - changed |= resolved != exp; - } else { - // go further - changed |= doResolve(resolved, aliases, resolvedExpressions); - } - } - } - return changed; - } - - protected boolean skipExpression(Expression e) { - return e.nullable() == Nullability.FALSE; - } - } - public static final class SetAsOptimized extends Rule { @Override diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/optimizer/OptimizerRulesTests.java b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/optimizer/OptimizerRulesTests.java index 84586ed851824..12dbb23a86c59 100644 --- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/optimizer/OptimizerRulesTests.java +++ b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/optimizer/OptimizerRulesTests.java @@ -8,94 +8,24 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.core.TestUtils; -import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.Nullability; -import org.elasticsearch.xpack.esql.core.expression.predicate.BinaryOperator; -import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates; import org.elasticsearch.xpack.esql.core.expression.predicate.Range; -import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; -import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Not; -import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; -import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNotNull; -import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNull; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.Add; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.Div; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.Mod; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.Mul; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.Sub; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.Equals; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.GreaterThan; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.GreaterThanOrEqual; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.In; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.LessThan; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.LessThanOrEqual; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.NotEquals; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.NullEquals; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.Like; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.LikePattern; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLike; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardLike; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardPattern; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.BinaryComparisonSimplification; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.BooleanFunctionEqualsElimination; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.BooleanSimplification; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.CombineBinaryComparisons; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.ConstantFolding; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.FoldNull; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.LiteralsOnTheRight; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.PropagateEquals; -import org.elasticsearch.xpack.esql.core.plan.logical.EsRelation; -import org.elasticsearch.xpack.esql.core.plan.logical.Filter; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.core.util.StringUtils; -import java.time.ZoneId; import java.util.Collections; import java.util.List; -import static java.util.Arrays.asList; -import static java.util.Collections.singletonList; -import static org.elasticsearch.xpack.esql.core.TestUtils.equalsOf; -import static org.elasticsearch.xpack.esql.core.TestUtils.fieldAttribute; -import static org.elasticsearch.xpack.esql.core.TestUtils.greaterThanOf; -import static org.elasticsearch.xpack.esql.core.TestUtils.greaterThanOrEqualOf; -import static org.elasticsearch.xpack.esql.core.TestUtils.lessThanOf; -import static org.elasticsearch.xpack.esql.core.TestUtils.lessThanOrEqualOf; -import static org.elasticsearch.xpack.esql.core.TestUtils.notEqualsOf; -import static org.elasticsearch.xpack.esql.core.TestUtils.nullEqualsOf; import static org.elasticsearch.xpack.esql.core.TestUtils.of; import static org.elasticsearch.xpack.esql.core.TestUtils.rangeOf; -import static org.elasticsearch.xpack.esql.core.TestUtils.relation; -import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE; -import static org.elasticsearch.xpack.esql.core.expression.Literal.NULL; -import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE; -import static org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.CombineDisjunctionsToIn; -import static org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.PropagateNullable; -import static org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.ReplaceRegexMatch; -import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN; -import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; -import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; -import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD; -import static org.hamcrest.Matchers.contains; public class OptimizerRulesTests extends ESTestCase { - private static final Expression DUMMY_EXPRESSION = new DummyBooleanExpression(EMPTY, 0); - - private static final Literal ONE = L(1); - private static final Literal TWO = L(2); - private static final Literal THREE = L(3); - private static final Literal FOUR = L(4); private static final Literal FIVE = L(5); private static final Literal SIX = L(6); @@ -152,198 +82,6 @@ private static FieldAttribute getFieldAttribute() { return TestUtils.getFieldAttribute("a"); } - // - // Constant folding - // - - public void testConstantFolding() { - Expression exp = new Add(EMPTY, TWO, THREE); - - assertTrue(exp.foldable()); - Expression result = new ConstantFolding().rule(exp); - assertTrue(result instanceof Literal); - assertEquals(5, ((Literal) result).value()); - - // check now with an alias - result = new ConstantFolding().rule(new Alias(EMPTY, "a", exp)); - assertEquals("a", Expressions.name(result)); - assertEquals(Alias.class, result.getClass()); - } - - public void testConstantFoldingBinaryComparison() { - assertEquals(FALSE, new ConstantFolding().rule(greaterThanOf(TWO, THREE)).canonical()); - assertEquals(FALSE, new ConstantFolding().rule(greaterThanOrEqualOf(TWO, THREE)).canonical()); - assertEquals(FALSE, new ConstantFolding().rule(equalsOf(TWO, THREE)).canonical()); - assertEquals(FALSE, new ConstantFolding().rule(nullEqualsOf(TWO, THREE)).canonical()); - assertEquals(FALSE, new ConstantFolding().rule(nullEqualsOf(TWO, NULL)).canonical()); - assertEquals(TRUE, new ConstantFolding().rule(notEqualsOf(TWO, THREE)).canonical()); - assertEquals(TRUE, new ConstantFolding().rule(lessThanOrEqualOf(TWO, THREE)).canonical()); - assertEquals(TRUE, new ConstantFolding().rule(lessThanOf(TWO, THREE)).canonical()); - } - - public void testConstantFoldingBinaryLogic() { - assertEquals(FALSE, new ConstantFolding().rule(new And(EMPTY, greaterThanOf(TWO, THREE), TRUE)).canonical()); - assertEquals(TRUE, new ConstantFolding().rule(new Or(EMPTY, greaterThanOrEqualOf(TWO, THREE), TRUE)).canonical()); - } - - public void testConstantFoldingBinaryLogic_WithNullHandling() { - assertEquals(Nullability.TRUE, new ConstantFolding().rule(new And(EMPTY, NULL, TRUE)).canonical().nullable()); - assertEquals(Nullability.TRUE, new ConstantFolding().rule(new And(EMPTY, TRUE, NULL)).canonical().nullable()); - assertEquals(FALSE, new ConstantFolding().rule(new And(EMPTY, NULL, FALSE)).canonical()); - assertEquals(FALSE, new ConstantFolding().rule(new And(EMPTY, FALSE, NULL)).canonical()); - assertEquals(Nullability.TRUE, new ConstantFolding().rule(new And(EMPTY, NULL, NULL)).canonical().nullable()); - - assertEquals(TRUE, new ConstantFolding().rule(new Or(EMPTY, NULL, TRUE)).canonical()); - assertEquals(TRUE, new ConstantFolding().rule(new Or(EMPTY, TRUE, NULL)).canonical()); - assertEquals(Nullability.TRUE, new ConstantFolding().rule(new Or(EMPTY, NULL, FALSE)).canonical().nullable()); - assertEquals(Nullability.TRUE, new ConstantFolding().rule(new Or(EMPTY, FALSE, NULL)).canonical().nullable()); - assertEquals(Nullability.TRUE, new ConstantFolding().rule(new Or(EMPTY, NULL, NULL)).canonical().nullable()); - } - - public void testConstantFoldingRange() { - assertEquals(true, new ConstantFolding().rule(rangeOf(FIVE, FIVE, true, L(10), false)).fold()); - assertEquals(false, new ConstantFolding().rule(rangeOf(FIVE, FIVE, false, L(10), false)).fold()); - } - - public void testConstantNot() { - assertEquals(FALSE, new ConstantFolding().rule(new Not(EMPTY, TRUE))); - assertEquals(TRUE, new ConstantFolding().rule(new Not(EMPTY, FALSE))); - } - - public void testConstantFoldingLikes() { - assertEquals(TRUE, new ConstantFolding().rule(new Like(EMPTY, of("test_emp"), new LikePattern("test%", (char) 0))).canonical()); - assertEquals(TRUE, new ConstantFolding().rule(new WildcardLike(EMPTY, of("test_emp"), new WildcardPattern("test*"))).canonical()); - assertEquals(TRUE, new ConstantFolding().rule(new RLike(EMPTY, of("test_emp"), new RLikePattern("test.emp"))).canonical()); - } - - public void testArithmeticFolding() { - assertEquals(10, foldOperator(new Add(EMPTY, L(7), THREE))); - assertEquals(4, foldOperator(new Sub(EMPTY, L(7), THREE))); - assertEquals(21, foldOperator(new Mul(EMPTY, L(7), THREE))); - assertEquals(2, foldOperator(new Div(EMPTY, L(7), THREE))); - assertEquals(1, foldOperator(new Mod(EMPTY, L(7), THREE))); - } - - private static Object foldOperator(BinaryOperator b) { - return ((Literal) new ConstantFolding().rule(b)).value(); - } - - // - // Logical simplifications - // - - public void testLiteralsOnTheRight() { - Alias a = new Alias(EMPTY, "a", L(10)); - Expression result = new LiteralsOnTheRight().rule(equalsOf(FIVE, a)); - assertTrue(result instanceof Equals); - Equals eq = (Equals) result; - assertEquals(a, eq.left()); - assertEquals(FIVE, eq.right()); - - a = new Alias(EMPTY, "a", L(10)); - result = new LiteralsOnTheRight().rule(nullEqualsOf(FIVE, a)); - assertTrue(result instanceof NullEquals); - NullEquals nullEquals = (NullEquals) result; - assertEquals(a, nullEquals.left()); - assertEquals(FIVE, nullEquals.right()); - } - - public void testBoolSimplifyOr() { - BooleanSimplification simplification = new BooleanSimplification(); - - assertEquals(TRUE, simplification.rule(new Or(EMPTY, TRUE, TRUE))); - assertEquals(TRUE, simplification.rule(new Or(EMPTY, TRUE, DUMMY_EXPRESSION))); - assertEquals(TRUE, simplification.rule(new Or(EMPTY, DUMMY_EXPRESSION, TRUE))); - - assertEquals(FALSE, simplification.rule(new Or(EMPTY, FALSE, FALSE))); - assertEquals(DUMMY_EXPRESSION, simplification.rule(new Or(EMPTY, FALSE, DUMMY_EXPRESSION))); - assertEquals(DUMMY_EXPRESSION, simplification.rule(new Or(EMPTY, DUMMY_EXPRESSION, FALSE))); - } - - public void testBoolSimplifyAnd() { - BooleanSimplification simplification = new BooleanSimplification(); - - assertEquals(TRUE, simplification.rule(new And(EMPTY, TRUE, TRUE))); - assertEquals(DUMMY_EXPRESSION, simplification.rule(new And(EMPTY, TRUE, DUMMY_EXPRESSION))); - assertEquals(DUMMY_EXPRESSION, simplification.rule(new And(EMPTY, DUMMY_EXPRESSION, TRUE))); - - assertEquals(FALSE, simplification.rule(new And(EMPTY, FALSE, FALSE))); - assertEquals(FALSE, simplification.rule(new And(EMPTY, FALSE, DUMMY_EXPRESSION))); - assertEquals(FALSE, simplification.rule(new And(EMPTY, DUMMY_EXPRESSION, FALSE))); - } - - public void testBoolCommonFactorExtraction() { - BooleanSimplification simplification = new BooleanSimplification(); - - Expression a1 = new DummyBooleanExpression(EMPTY, 1); - Expression a2 = new DummyBooleanExpression(EMPTY, 1); - Expression b = new DummyBooleanExpression(EMPTY, 2); - Expression c = new DummyBooleanExpression(EMPTY, 3); - - Or actual = new Or(EMPTY, new And(EMPTY, a1, b), new And(EMPTY, a2, c)); - And expected = new And(EMPTY, a1, new Or(EMPTY, b, c)); - - assertEquals(expected, simplification.rule(actual)); - } - - public void testBinaryComparisonSimplification() { - assertEquals(TRUE, new BinaryComparisonSimplification().rule(equalsOf(FIVE, FIVE))); - assertEquals(TRUE, new BinaryComparisonSimplification().rule(nullEqualsOf(FIVE, FIVE))); - assertEquals(TRUE, new BinaryComparisonSimplification().rule(nullEqualsOf(NULL, NULL))); - assertEquals(FALSE, new BinaryComparisonSimplification().rule(notEqualsOf(FIVE, FIVE))); - assertEquals(TRUE, new BinaryComparisonSimplification().rule(greaterThanOrEqualOf(FIVE, FIVE))); - assertEquals(TRUE, new BinaryComparisonSimplification().rule(lessThanOrEqualOf(FIVE, FIVE))); - - assertEquals(FALSE, new BinaryComparisonSimplification().rule(greaterThanOf(FIVE, FIVE))); - assertEquals(FALSE, new BinaryComparisonSimplification().rule(lessThanOf(FIVE, FIVE))); - } - - public void testNullEqualsWithNullLiteralBecomesIsNull() { - LiteralsOnTheRight swapLiteralsToRight = new LiteralsOnTheRight(); - BinaryComparisonSimplification bcSimpl = new BinaryComparisonSimplification(); - FieldAttribute fa = getFieldAttribute(); - Source source = new Source(1, 10, "IS_NULL(a)"); - - Expression e = bcSimpl.rule((BinaryComparison) swapLiteralsToRight.rule(new NullEquals(source, fa, NULL, randomZone()))); - assertEquals(IsNull.class, e.getClass()); - IsNull isNull = (IsNull) e; - assertEquals(source, isNull.source()); - - e = bcSimpl.rule((BinaryComparison) swapLiteralsToRight.rule(new NullEquals(source, NULL, fa, randomZone()))); - assertEquals(IsNull.class, e.getClass()); - isNull = (IsNull) e; - assertEquals(source, isNull.source()); - } - - public void testBoolEqualsSimplificationOnExpressions() { - BooleanFunctionEqualsElimination s = new BooleanFunctionEqualsElimination(); - Expression exp = new GreaterThan(EMPTY, getFieldAttribute(), L(0), null); - - assertEquals(exp, s.rule(new Equals(EMPTY, exp, TRUE))); - assertEquals(new Not(EMPTY, exp), s.rule(new Equals(EMPTY, exp, FALSE))); - } - - public void testBoolEqualsSimplificationOnFields() { - BooleanFunctionEqualsElimination s = new BooleanFunctionEqualsElimination(); - - FieldAttribute field = getFieldAttribute(); - - List comparisons = asList( - new Equals(EMPTY, field, TRUE), - new Equals(EMPTY, field, FALSE), - notEqualsOf(field, TRUE), - notEqualsOf(field, FALSE), - new Equals(EMPTY, NULL, TRUE), - new Equals(EMPTY, NULL, FALSE), - notEqualsOf(NULL, TRUE), - notEqualsOf(NULL, FALSE) - ); - - for (BinaryComparison comparison : comparisons) { - assertEquals(comparison, s.rule(comparison)); - } - } - // // Range optimization // @@ -368,1399 +106,4 @@ public void testFoldExcludingRangeWithDifferentTypesToFalse() { // Conjunction - public void testCombineBinaryComparisonsNotComparable() { - FieldAttribute fa = getFieldAttribute(); - LessThanOrEqual lte = lessThanOrEqualOf(fa, SIX); - LessThan lt = lessThanOf(fa, FALSE); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - And and = new And(EMPTY, lte, lt); - Expression exp = rule.rule(and); - assertEquals(exp, and); - } - - // a <= 6 AND a < 5 -> a < 5 - public void testCombineBinaryComparisonsUpper() { - FieldAttribute fa = getFieldAttribute(); - LessThanOrEqual lte = lessThanOrEqualOf(fa, SIX); - LessThan lt = lessThanOf(fa, FIVE); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - - Expression exp = rule.rule(new And(EMPTY, lte, lt)); - assertEquals(LessThan.class, exp.getClass()); - LessThan r = (LessThan) exp; - assertEquals(FIVE, r.right()); - } - - // 6 <= a AND 5 < a -> 6 <= a - public void testCombineBinaryComparisonsLower() { - FieldAttribute fa = getFieldAttribute(); - GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, SIX); - GreaterThan gt = greaterThanOf(fa, FIVE); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - - Expression exp = rule.rule(new And(EMPTY, gte, gt)); - assertEquals(GreaterThanOrEqual.class, exp.getClass()); - GreaterThanOrEqual r = (GreaterThanOrEqual) exp; - assertEquals(SIX, r.right()); - } - - // 5 <= a AND 5 < a -> 5 < a - public void testCombineBinaryComparisonsInclude() { - FieldAttribute fa = getFieldAttribute(); - GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, FIVE); - GreaterThan gt = greaterThanOf(fa, FIVE); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - - Expression exp = rule.rule(new And(EMPTY, gte, gt)); - assertEquals(GreaterThan.class, exp.getClass()); - GreaterThan r = (GreaterThan) exp; - assertEquals(FIVE, r.right()); - } - - // 2 < a AND (2 <= a < 3) -> 2 < a < 3 - public void testCombineBinaryComparisonsAndRangeLower() { - FieldAttribute fa = getFieldAttribute(); - - GreaterThan gt = greaterThanOf(fa, TWO); - Range range = rangeOf(fa, TWO, true, THREE, false); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(new And(EMPTY, gt, range)); - assertEquals(Range.class, exp.getClass()); - Range r = (Range) exp; - assertEquals(TWO, r.lower()); - assertFalse(r.includeLower()); - assertEquals(THREE, r.upper()); - assertFalse(r.includeUpper()); - } - - // a < 4 AND (1 < a < 3) -> 1 < a < 3 - public void testCombineBinaryComparisonsAndRangeUpper() { - FieldAttribute fa = getFieldAttribute(); - - LessThan lt = lessThanOf(fa, FOUR); - Range range = rangeOf(fa, ONE, false, THREE, false); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(new And(EMPTY, range, lt)); - assertEquals(Range.class, exp.getClass()); - Range r = (Range) exp; - assertEquals(ONE, r.lower()); - assertFalse(r.includeLower()); - assertEquals(THREE, r.upper()); - assertFalse(r.includeUpper()); - } - - // a <= 2 AND (1 < a < 3) -> 1 < a <= 2 - public void testCombineBinaryComparisonsAndRangeUpperEqual() { - FieldAttribute fa = getFieldAttribute(); - - LessThanOrEqual lte = lessThanOrEqualOf(fa, TWO); - Range range = rangeOf(fa, ONE, false, THREE, false); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(new And(EMPTY, lte, range)); - assertEquals(Range.class, exp.getClass()); - Range r = (Range) exp; - assertEquals(ONE, r.lower()); - assertFalse(r.includeLower()); - assertEquals(TWO, r.upper()); - assertTrue(r.includeUpper()); - } - - // 3 <= a AND 4 < a AND a <= 7 AND a < 6 -> 4 < a < 6 - public void testCombineMultipleBinaryComparisons() { - FieldAttribute fa = getFieldAttribute(); - GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, THREE); - GreaterThan gt = greaterThanOf(fa, FOUR); - LessThanOrEqual lte = lessThanOrEqualOf(fa, L(7)); - LessThan lt = lessThanOf(fa, SIX); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - - Expression exp = rule.rule(new And(EMPTY, gte, new And(EMPTY, gt, new And(EMPTY, lt, lte)))); - assertEquals(Range.class, exp.getClass()); - Range r = (Range) exp; - assertEquals(FOUR, r.lower()); - assertFalse(r.includeLower()); - assertEquals(SIX, r.upper()); - assertFalse(r.includeUpper()); - } - - // 3 <= a AND TRUE AND 4 < a AND a != 5 AND a <= 7 -> 4 < a <= 7 AND a != 5 AND TRUE - public void testCombineMixedMultipleBinaryComparisons() { - FieldAttribute fa = getFieldAttribute(); - GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, THREE); - GreaterThan gt = greaterThanOf(fa, FOUR); - LessThanOrEqual lte = lessThanOrEqualOf(fa, L(7)); - Expression ne = new Not(EMPTY, equalsOf(fa, FIVE)); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - - // TRUE AND a != 5 AND 4 < a <= 7 - Expression exp = rule.rule(new And(EMPTY, gte, new And(EMPTY, TRUE, new And(EMPTY, gt, new And(EMPTY, ne, lte))))); - assertEquals(And.class, exp.getClass()); - And and = ((And) exp); - assertEquals(Range.class, and.right().getClass()); - Range r = (Range) and.right(); - assertEquals(FOUR, r.lower()); - assertFalse(r.includeLower()); - assertEquals(L(7), r.upper()); - assertTrue(r.includeUpper()); - } - - // 1 <= a AND a < 5 -> 1 <= a < 5 - public void testCombineComparisonsIntoRange() { - FieldAttribute fa = getFieldAttribute(); - GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, ONE); - LessThan lt = lessThanOf(fa, FIVE); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(new And(EMPTY, gte, lt)); - assertEquals(Range.class, exp.getClass()); - - Range r = (Range) exp; - assertEquals(ONE, r.lower()); - assertTrue(r.includeLower()); - assertEquals(FIVE, r.upper()); - assertFalse(r.includeUpper()); - } - - // 1 < a AND a < 3 AND 2 < b AND b < 4 AND c < 4 -> (1 < a < 3) AND (2 < b < 4) AND c < 4 - public void testCombineMultipleComparisonsIntoRange() { - FieldAttribute fa = TestUtils.getFieldAttribute("a"); - FieldAttribute fb = TestUtils.getFieldAttribute("b"); - FieldAttribute fc = TestUtils.getFieldAttribute("c"); - - ZoneId zoneId = randomZone(); - GreaterThan agt1 = new GreaterThan(EMPTY, fa, ONE, zoneId); - LessThan alt3 = new LessThan(EMPTY, fa, THREE, zoneId); - GreaterThan bgt2 = new GreaterThan(EMPTY, fb, TWO, zoneId); - LessThan blt4 = new LessThan(EMPTY, fb, FOUR, zoneId); - LessThan clt4 = new LessThan(EMPTY, fc, FOUR, zoneId); - - Expression inputAnd = Predicates.combineAnd(asList(agt1, alt3, bgt2, blt4, clt4)); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression outputAnd = rule.rule((And) inputAnd); - - Range agt1lt3 = new Range(EMPTY, fa, ONE, false, THREE, false, zoneId); - Range bgt2lt4 = new Range(EMPTY, fb, TWO, false, FOUR, false, zoneId); - - // The actual outcome is (c < 4) AND (1 < a < 3) AND (2 < b < 4), due to the way the Expression types are combined in the Optimizer - Expression expectedAnd = Predicates.combineAnd(asList(clt4, agt1lt3, bgt2lt4)); - - assertTrue(outputAnd.semanticEquals(expectedAnd)); - } - - // (2 < a < 3) AND (1 < a < 4) -> (2 < a < 3) - public void testCombineBinaryComparisonsConjunctionOfIncludedRange() { - FieldAttribute fa = getFieldAttribute(); - - Range r1 = rangeOf(fa, TWO, false, THREE, false); - Range r2 = rangeOf(fa, ONE, false, FOUR, false); - - And and = new And(EMPTY, r1, r2); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(and); - assertEquals(r1, exp); - } - - // (2 < a < 3) AND a < 2 -> 2 < a < 2 - public void testCombineBinaryComparisonsConjunctionOfNonOverlappingBoundaries() { - FieldAttribute fa = getFieldAttribute(); - - Range r1 = rangeOf(fa, TWO, false, THREE, false); - Range r2 = rangeOf(fa, ONE, false, TWO, false); - - And and = new And(EMPTY, r1, r2); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(and); - assertEquals(Range.class, exp.getClass()); - Range r = (Range) exp; - assertEquals(TWO, r.lower()); - assertFalse(r.includeLower()); - assertEquals(TWO, r.upper()); - assertFalse(r.includeUpper()); - assertEquals(Boolean.FALSE, r.fold()); - } - - // (2 < a < 3) AND (2 < a <= 3) -> 2 < a < 3 - public void testCombineBinaryComparisonsConjunctionOfUpperEqualsOverlappingBoundaries() { - FieldAttribute fa = getFieldAttribute(); - - Range r1 = rangeOf(fa, TWO, false, THREE, false); - Range r2 = rangeOf(fa, TWO, false, THREE, true); - - And and = new And(EMPTY, r1, r2); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(and); - assertEquals(r1, exp); - } - - // (2 < a < 3) AND (1 < a < 3) -> 2 < a < 3 - public void testCombineBinaryComparisonsConjunctionOverlappingUpperBoundary() { - FieldAttribute fa = getFieldAttribute(); - - Range r2 = rangeOf(fa, TWO, false, THREE, false); - Range r1 = rangeOf(fa, ONE, false, THREE, false); - - And and = new And(EMPTY, r1, r2); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(and); - assertEquals(r2, exp); - } - - // (2 < a <= 3) AND (1 < a < 3) -> 2 < a < 3 - public void testCombineBinaryComparisonsConjunctionWithDifferentUpperLimitInclusion() { - FieldAttribute fa = getFieldAttribute(); - - Range r1 = rangeOf(fa, ONE, false, THREE, false); - Range r2 = rangeOf(fa, TWO, false, THREE, true); - - And and = new And(EMPTY, r1, r2); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(and); - assertEquals(Range.class, exp.getClass()); - Range r = (Range) exp; - assertEquals(TWO, r.lower()); - assertFalse(r.includeLower()); - assertEquals(THREE, r.upper()); - assertFalse(r.includeUpper()); - } - - // (0 < a <= 1) AND (0 <= a < 2) -> 0 < a <= 1 - public void testRangesOverlappingConjunctionNoLowerBoundary() { - FieldAttribute fa = getFieldAttribute(); - - Range r1 = rangeOf(fa, L(0), false, ONE, true); - Range r2 = rangeOf(fa, L(0), true, TWO, false); - - And and = new And(EMPTY, r1, r2); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(and); - assertEquals(r1, exp); - } - - // a != 2 AND 3 < a < 5 -> 3 < a < 5 - public void testCombineBinaryComparisonsConjunction_Neq2AndRangeGt3Lt5() { - FieldAttribute fa = getFieldAttribute(); - - NotEquals neq = notEqualsOf(fa, TWO); - Range range = rangeOf(fa, THREE, false, FIVE, false); - And and = new And(EMPTY, range, neq); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(and); - assertEquals(Range.class, exp.getClass()); - Range r = (Range) exp; - assertEquals(THREE, r.lower()); - assertFalse(r.includeLower()); - assertEquals(FIVE, r.upper()); - assertFalse(r.includeUpper()); - } - - // a != 2 AND 0 < a < 1 -> 0 < a < 1 - public void testCombineBinaryComparisonsConjunction_Neq2AndRangeGt0Lt1() { - FieldAttribute fa = getFieldAttribute(); - - NotEquals neq = notEqualsOf(fa, TWO); - Range range = rangeOf(fa, L(0), false, ONE, false); - And and = new And(EMPTY, neq, range); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(and); - assertEquals(Range.class, exp.getClass()); - Range r = (Range) exp; - assertEquals(L(0), r.lower()); - assertFalse(r.includeLower()); - assertEquals(ONE, r.upper()); - assertFalse(r.includeUpper()); - } - - // a != 2 AND 2 <= a < 3 -> 2 < a < 3 - public void testCombineBinaryComparisonsConjunction_Neq2AndRangeGte2Lt3() { - FieldAttribute fa = getFieldAttribute(); - - NotEquals neq = notEqualsOf(fa, TWO); - Range range = rangeOf(fa, TWO, true, THREE, false); - And and = new And(EMPTY, neq, range); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(and); - assertEquals(Range.class, exp.getClass()); - Range r = (Range) exp; - assertEquals(TWO, r.lower()); - assertFalse(r.includeLower()); - assertEquals(THREE, r.upper()); - assertFalse(r.includeUpper()); - } - - // a != 3 AND 2 < a <= 3 -> 2 < a < 3 - public void testCombineBinaryComparisonsConjunction_Neq3AndRangeGt2Lte3() { - FieldAttribute fa = getFieldAttribute(); - - NotEquals neq = notEqualsOf(fa, THREE); - Range range = rangeOf(fa, TWO, false, THREE, true); - And and = new And(EMPTY, neq, range); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(and); - assertEquals(Range.class, exp.getClass()); - Range r = (Range) exp; - assertEquals(TWO, r.lower()); - assertFalse(r.includeLower()); - assertEquals(THREE, r.upper()); - assertFalse(r.includeUpper()); - } - - // a != 2 AND 1 < a < 3 - public void testCombineBinaryComparisonsConjunction_Neq2AndRangeGt1Lt3() { - FieldAttribute fa = getFieldAttribute(); - - NotEquals neq = notEqualsOf(fa, TWO); - Range range = rangeOf(fa, ONE, false, THREE, false); - And and = new And(EMPTY, neq, range); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(and); - assertEquals(And.class, exp.getClass()); // can't optimize - } - - // a != 2 AND a > 3 -> a > 3 - public void testCombineBinaryComparisonsConjunction_Neq2AndGt3() { - FieldAttribute fa = getFieldAttribute(); - - NotEquals neq = notEqualsOf(fa, TWO); - GreaterThan gt = greaterThanOf(fa, THREE); - And and = new And(EMPTY, neq, gt); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(and); - assertEquals(gt, exp); - } - - // a != 2 AND a >= 2 -> a > 2 - public void testCombineBinaryComparisonsConjunction_Neq2AndGte2() { - FieldAttribute fa = getFieldAttribute(); - - NotEquals neq = notEqualsOf(fa, TWO); - GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, TWO); - And and = new And(EMPTY, neq, gte); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(and); - assertEquals(GreaterThan.class, exp.getClass()); - GreaterThan gt = (GreaterThan) exp; - assertEquals(TWO, gt.right()); - } - - // a != 2 AND a >= 1 -> nop - public void testCombineBinaryComparisonsConjunction_Neq2AndGte1() { - FieldAttribute fa = getFieldAttribute(); - - NotEquals neq = notEqualsOf(fa, TWO); - GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, ONE); - And and = new And(EMPTY, neq, gte); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(and); - assertEquals(And.class, exp.getClass()); // can't optimize - } - - // a != 2 AND a <= 3 -> nop - public void testCombineBinaryComparisonsConjunction_Neq2AndLte3() { - FieldAttribute fa = getFieldAttribute(); - - NotEquals neq = notEqualsOf(fa, TWO); - LessThanOrEqual lte = lessThanOrEqualOf(fa, THREE); - And and = new And(EMPTY, neq, lte); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(and); - assertEquals(and, exp); // can't optimize - } - - // a != 2 AND a <= 2 -> a < 2 - public void testCombineBinaryComparisonsConjunction_Neq2AndLte2() { - FieldAttribute fa = getFieldAttribute(); - - NotEquals neq = notEqualsOf(fa, TWO); - LessThanOrEqual lte = lessThanOrEqualOf(fa, TWO); - And and = new And(EMPTY, neq, lte); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(and); - assertEquals(LessThan.class, exp.getClass()); - LessThan lt = (LessThan) exp; - assertEquals(TWO, lt.right()); - } - - // a != 2 AND a <= 1 -> a <= 1 - public void testCombineBinaryComparisonsConjunction_Neq2AndLte1() { - FieldAttribute fa = getFieldAttribute(); - - NotEquals neq = notEqualsOf(fa, TWO); - LessThanOrEqual lte = lessThanOrEqualOf(fa, ONE); - And and = new And(EMPTY, neq, lte); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(and); - assertEquals(lte, exp); - } - - // Disjunction - - public void testCombineBinaryComparisonsDisjunctionNotComparable() { - FieldAttribute fa = getFieldAttribute(); - - GreaterThan gt1 = greaterThanOf(fa, ONE); - GreaterThan gt2 = greaterThanOf(fa, FALSE); - - Or or = new Or(EMPTY, gt1, gt2); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(or); - assertEquals(exp, or); - } - - // 2 < a OR 1 < a OR 3 < a -> 1 < a - public void testCombineBinaryComparisonsDisjunctionLowerBound() { - FieldAttribute fa = getFieldAttribute(); - - GreaterThan gt1 = greaterThanOf(fa, ONE); - GreaterThan gt2 = greaterThanOf(fa, TWO); - GreaterThan gt3 = greaterThanOf(fa, THREE); - - Or or = new Or(EMPTY, gt1, new Or(EMPTY, gt2, gt3)); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(or); - assertEquals(GreaterThan.class, exp.getClass()); - - GreaterThan gt = (GreaterThan) exp; - assertEquals(ONE, gt.right()); - } - - // 2 < a OR 1 < a OR 3 <= a -> 1 < a - public void testCombineBinaryComparisonsDisjunctionIncludeLowerBounds() { - FieldAttribute fa = getFieldAttribute(); - - GreaterThan gt1 = greaterThanOf(fa, ONE); - GreaterThan gt2 = greaterThanOf(fa, TWO); - GreaterThanOrEqual gte3 = greaterThanOrEqualOf(fa, THREE); - - Or or = new Or(EMPTY, new Or(EMPTY, gt1, gt2), gte3); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(or); - assertEquals(GreaterThan.class, exp.getClass()); - - GreaterThan gt = (GreaterThan) exp; - assertEquals(ONE, gt.right()); - } - - // a < 1 OR a < 2 OR a < 3 -> a < 3 - public void testCombineBinaryComparisonsDisjunctionUpperBound() { - FieldAttribute fa = getFieldAttribute(); - - LessThan lt1 = lessThanOf(fa, ONE); - LessThan lt2 = lessThanOf(fa, TWO); - LessThan lt3 = lessThanOf(fa, THREE); - - Or or = new Or(EMPTY, new Or(EMPTY, lt1, lt2), lt3); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(or); - assertEquals(LessThan.class, exp.getClass()); - - LessThan lt = (LessThan) exp; - assertEquals(THREE, lt.right()); - } - - // a < 2 OR a <= 2 OR a < 1 -> a <= 2 - public void testCombineBinaryComparisonsDisjunctionIncludeUpperBounds() { - FieldAttribute fa = getFieldAttribute(); - - LessThan lt1 = lessThanOf(fa, ONE); - LessThan lt2 = lessThanOf(fa, TWO); - LessThanOrEqual lte2 = lessThanOrEqualOf(fa, TWO); - - Or or = new Or(EMPTY, lt2, new Or(EMPTY, lte2, lt1)); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(or); - assertEquals(LessThanOrEqual.class, exp.getClass()); - - LessThanOrEqual lte = (LessThanOrEqual) exp; - assertEquals(TWO, lte.right()); - } - - // a < 2 OR 3 < a OR a < 1 OR 4 < a -> a < 2 OR 3 < a - public void testCombineBinaryComparisonsDisjunctionOfLowerAndUpperBounds() { - FieldAttribute fa = getFieldAttribute(); - - LessThan lt1 = lessThanOf(fa, ONE); - LessThan lt2 = lessThanOf(fa, TWO); - - GreaterThan gt3 = greaterThanOf(fa, THREE); - GreaterThan gt4 = greaterThanOf(fa, FOUR); - - Or or = new Or(EMPTY, new Or(EMPTY, lt2, gt3), new Or(EMPTY, lt1, gt4)); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(or); - assertEquals(Or.class, exp.getClass()); - - Or ro = (Or) exp; - - assertEquals(LessThan.class, ro.left().getClass()); - LessThan lt = (LessThan) ro.left(); - assertEquals(TWO, lt.right()); - assertEquals(GreaterThan.class, ro.right().getClass()); - GreaterThan gt = (GreaterThan) ro.right(); - assertEquals(THREE, gt.right()); - } - - // (2 < a < 3) OR (1 < a < 4) -> (1 < a < 4) - public void testCombineBinaryComparisonsDisjunctionOfIncludedRangeNotComparable() { - FieldAttribute fa = getFieldAttribute(); - - Range r1 = rangeOf(fa, TWO, false, THREE, false); - Range r2 = rangeOf(fa, ONE, false, FALSE, false); - - Or or = new Or(EMPTY, r1, r2); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(or); - assertEquals(or, exp); - } - - // (2 < a < 3) OR (1 < a < 4) -> (1 < a < 4) - public void testCombineBinaryComparisonsDisjunctionOfIncludedRange() { - FieldAttribute fa = getFieldAttribute(); - - Range r1 = rangeOf(fa, TWO, false, THREE, false); - Range r2 = rangeOf(fa, ONE, false, FOUR, false); - - Or or = new Or(EMPTY, r1, r2); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(or); - assertEquals(Range.class, exp.getClass()); - - Range r = (Range) exp; - assertEquals(ONE, r.lower()); - assertFalse(r.includeLower()); - assertEquals(FOUR, r.upper()); - assertFalse(r.includeUpper()); - } - - // (2 < a < 3) OR (1 < a < 2) -> same - public void testCombineBinaryComparisonsDisjunctionOfNonOverlappingBoundaries() { - FieldAttribute fa = getFieldAttribute(); - - Range r1 = rangeOf(fa, TWO, false, THREE, false); - Range r2 = rangeOf(fa, ONE, false, TWO, false); - - Or or = new Or(EMPTY, r1, r2); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(or); - assertEquals(or, exp); - } - - // (2 < a < 3) OR (2 < a <= 3) -> 2 < a <= 3 - public void testCombineBinaryComparisonsDisjunctionOfUpperEqualsOverlappingBoundaries() { - FieldAttribute fa = getFieldAttribute(); - - Range r1 = rangeOf(fa, TWO, false, THREE, false); - Range r2 = rangeOf(fa, TWO, false, THREE, true); - - Or or = new Or(EMPTY, r1, r2); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(or); - assertEquals(r2, exp); - } - - // (2 < a < 3) OR (1 < a < 3) -> 1 < a < 3 - public void testCombineBinaryComparisonsOverlappingUpperBoundary() { - FieldAttribute fa = getFieldAttribute(); - - Range r2 = rangeOf(fa, TWO, false, THREE, false); - Range r1 = rangeOf(fa, ONE, false, THREE, false); - - Or or = new Or(EMPTY, r1, r2); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(or); - assertEquals(r1, exp); - } - - // (2 < a <= 3) OR (1 < a < 3) -> same (the <= prevents the ranges from being combined) - public void testCombineBinaryComparisonsWithDifferentUpperLimitInclusion() { - FieldAttribute fa = getFieldAttribute(); - - Range r1 = rangeOf(fa, ONE, false, THREE, false); - Range r2 = rangeOf(fa, TWO, false, THREE, true); - - Or or = new Or(EMPTY, r1, r2); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(or); - assertEquals(or, exp); - } - - // (a = 1 AND b = 3 AND c = 4) OR (a = 2 AND b = 3 AND c = 4) -> (b = 3 AND c = 4) AND (a = 1 OR a = 2) - public void testBooleanSimplificationCommonExpressionSubstraction() { - FieldAttribute fa = TestUtils.getFieldAttribute("a"); - FieldAttribute fb = TestUtils.getFieldAttribute("b"); - FieldAttribute fc = TestUtils.getFieldAttribute("c"); - - Expression a1 = equalsOf(fa, ONE); - Expression a2 = equalsOf(fa, TWO); - And common = new And(EMPTY, equalsOf(fb, THREE), equalsOf(fc, FOUR)); - And left = new And(EMPTY, a1, common); - And right = new And(EMPTY, a2, common); - Or or = new Or(EMPTY, left, right); - - Expression exp = new BooleanSimplification().rule(or); - assertEquals(new And(EMPTY, common, new Or(EMPTY, a1, a2)), exp); - } - - // (0 < a <= 1) OR (0 < a < 2) -> 0 < a < 2 - public void testRangesOverlappingNoLowerBoundary() { - FieldAttribute fa = getFieldAttribute(); - - Range r2 = rangeOf(fa, L(0), false, TWO, false); - Range r1 = rangeOf(fa, L(0), false, ONE, true); - - Or or = new Or(EMPTY, r1, r2); - - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(or); - assertEquals(r2, exp); - } - - public void testBinaryComparisonAndOutOfRangeNotEqualsDifferentFields() { - FieldAttribute doubleOne = fieldAttribute("double", DOUBLE); - FieldAttribute doubleTwo = fieldAttribute("double2", DOUBLE); - FieldAttribute intOne = fieldAttribute("int", INTEGER); - FieldAttribute datetimeOne = fieldAttribute("datetime", INTEGER); - FieldAttribute keywordOne = fieldAttribute("keyword", KEYWORD); - FieldAttribute keywordTwo = fieldAttribute("keyword2", KEYWORD); - - List testCases = asList( - // double > 10 AND integer != -10 - new And(EMPTY, greaterThanOf(doubleOne, L(10)), notEqualsOf(intOne, L(-10))), - // keyword > '5' AND keyword2 != '48' - new And(EMPTY, greaterThanOf(keywordOne, L("5")), notEqualsOf(keywordTwo, L("48"))), - // keyword != '2021' AND datetime <= '2020-12-04T17:48:22.954240Z' - new And(EMPTY, notEqualsOf(keywordOne, L("2021")), lessThanOrEqualOf(datetimeOne, L("2020-12-04T17:48:22.954240Z"))), - // double > 10.1 AND double2 != -10.1 - new And(EMPTY, greaterThanOf(doubleOne, L(10.1d)), notEqualsOf(doubleTwo, L(-10.1d))) - ); - - for (And and : testCases) { - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(and); - assertEquals("Rule should not have transformed [" + and.nodeString() + "]", and, exp); - } - } - - // Equals & NullEquals - - // 1 <= a < 10 AND a == 1 -> a == 1 - public void testEliminateRangeByEqualsInInterval() { - FieldAttribute fa = getFieldAttribute(); - Equals eq1 = equalsOf(fa, ONE); - Range r = rangeOf(fa, ONE, true, L(10), false); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq1, r)); - assertEquals(eq1, exp); - } - - // 1 <= a < 10 AND a <=> 1 -> a <=> 1 - public void testEliminateRangeByNullEqualsInInterval() { - FieldAttribute fa = getFieldAttribute(); - NullEquals eq1 = nullEqualsOf(fa, ONE); - Range r = rangeOf(fa, ONE, true, L(10), false); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq1, r)); - assertEquals(eq1, exp); - } - - // The following tests should work only to simplify filters and - // not if the expressions are part of a projection - // See: https://github.com/elastic/elasticsearch/issues/35859 - - // a == 1 AND a == 2 -> FALSE - public void testDualEqualsConjunction() { - FieldAttribute fa = getFieldAttribute(); - Equals eq1 = equalsOf(fa, ONE); - Equals eq2 = equalsOf(fa, TWO); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq1, eq2)); - assertEquals(FALSE, exp); - } - - // a <=> 1 AND a <=> 2 -> FALSE - public void testDualNullEqualsConjunction() { - FieldAttribute fa = getFieldAttribute(); - NullEquals eq1 = nullEqualsOf(fa, ONE); - NullEquals eq2 = nullEqualsOf(fa, TWO); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq1, eq2)); - assertEquals(FALSE, exp); - } - - // 1 < a < 10 AND a == 10 -> FALSE - public void testEliminateRangeByEqualsOutsideInterval() { - FieldAttribute fa = getFieldAttribute(); - Equals eq1 = equalsOf(fa, L(10)); - Range r = rangeOf(fa, ONE, false, L(10), false); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq1, r)); - assertEquals(FALSE, exp); - } - - // 1 < a < 10 AND a <=> 10 -> FALSE - public void testEliminateRangeByNullEqualsOutsideInterval() { - FieldAttribute fa = getFieldAttribute(); - NullEquals eq1 = nullEqualsOf(fa, L(10)); - Range r = rangeOf(fa, ONE, false, L(10), false); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq1, r)); - assertEquals(FALSE, exp); - } - - // a != 3 AND a = 3 -> FALSE - public void testPropagateEquals_VarNeq3AndVarEq3() { - FieldAttribute fa = getFieldAttribute(); - NotEquals neq = notEqualsOf(fa, THREE); - Equals eq = equalsOf(fa, THREE); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, neq, eq)); - assertEquals(FALSE, exp); - } - - // a != 4 AND a = 3 -> a = 3 - public void testPropagateEquals_VarNeq4AndVarEq3() { - FieldAttribute fa = getFieldAttribute(); - NotEquals neq = notEqualsOf(fa, FOUR); - Equals eq = equalsOf(fa, THREE); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, neq, eq)); - assertEquals(Equals.class, exp.getClass()); - assertEquals(eq, exp); - } - - // a = 2 AND a < 2 -> FALSE - public void testPropagateEquals_VarEq2AndVarLt2() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - LessThan lt = lessThanOf(fa, TWO); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq, lt)); - assertEquals(FALSE, exp); - } - - // a = 2 AND a <= 2 -> a = 2 - public void testPropagateEquals_VarEq2AndVarLte2() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - LessThanOrEqual lt = lessThanOrEqualOf(fa, TWO); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq, lt)); - assertEquals(eq, exp); - } - - // a = 2 AND a <= 1 -> FALSE - public void testPropagateEquals_VarEq2AndVarLte1() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - LessThanOrEqual lt = lessThanOrEqualOf(fa, ONE); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq, lt)); - assertEquals(FALSE, exp); - } - - // a = 2 AND a > 2 -> FALSE - public void testPropagateEquals_VarEq2AndVarGt2() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - GreaterThan gt = greaterThanOf(fa, TWO); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq, gt)); - assertEquals(FALSE, exp); - } - - // a = 2 AND a >= 2 -> a = 2 - public void testPropagateEquals_VarEq2AndVarGte2() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, TWO); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq, gte)); - assertEquals(eq, exp); - } - - // a = 2 AND a > 3 -> FALSE - public void testPropagateEquals_VarEq2AndVarLt3() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - GreaterThan gt = greaterThanOf(fa, THREE); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq, gt)); - assertEquals(FALSE, exp); - } - - // a = 2 AND a < 3 AND a > 1 AND a != 4 -> a = 2 - public void testPropagateEquals_VarEq2AndVarLt3AndVarGt1AndVarNeq4() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - LessThan lt = lessThanOf(fa, THREE); - GreaterThan gt = greaterThanOf(fa, ONE); - NotEquals neq = notEqualsOf(fa, FOUR); - - PropagateEquals rule = new PropagateEquals(); - Expression and = Predicates.combineAnd(asList(eq, lt, gt, neq)); - Expression exp = rule.rule((And) and); - assertEquals(eq, exp); - } - - // a = 2 AND 1 < a < 3 AND a > 0 AND a != 4 -> a = 2 - public void testPropagateEquals_VarEq2AndVarRangeGt1Lt3AndVarGt0AndVarNeq4() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - Range range = rangeOf(fa, ONE, false, THREE, false); - GreaterThan gt = greaterThanOf(fa, L(0)); - NotEquals neq = notEqualsOf(fa, FOUR); - - PropagateEquals rule = new PropagateEquals(); - Expression and = Predicates.combineAnd(asList(eq, range, gt, neq)); - Expression exp = rule.rule((And) and); - assertEquals(eq, exp); - } - - // a = 2 OR a > 1 -> a > 1 - public void testPropagateEquals_VarEq2OrVarGt1() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - GreaterThan gt = greaterThanOf(fa, ONE); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, gt)); - assertEquals(gt, exp); - } - - // a = 2 OR a > 2 -> a >= 2 - public void testPropagateEquals_VarEq2OrVarGte2() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - GreaterThan gt = greaterThanOf(fa, TWO); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, gt)); - assertEquals(GreaterThanOrEqual.class, exp.getClass()); - GreaterThanOrEqual gte = (GreaterThanOrEqual) exp; - assertEquals(TWO, gte.right()); - } - - // a = 2 OR a < 3 -> a < 3 - public void testPropagateEquals_VarEq2OrVarLt3() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - LessThan lt = lessThanOf(fa, THREE); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, lt)); - assertEquals(lt, exp); - } - - // a = 3 OR a < 3 -> a <= 3 - public void testPropagateEquals_VarEq3OrVarLt3() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, THREE); - LessThan lt = lessThanOf(fa, THREE); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, lt)); - assertEquals(LessThanOrEqual.class, exp.getClass()); - LessThanOrEqual lte = (LessThanOrEqual) exp; - assertEquals(THREE, lte.right()); - } - - // a = 2 OR 1 < a < 3 -> 1 < a < 3 - public void testPropagateEquals_VarEq2OrVarRangeGt1Lt3() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - Range range = rangeOf(fa, ONE, false, THREE, false); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, range)); - assertEquals(range, exp); - } - - // a = 2 OR 2 < a < 3 -> 2 <= a < 3 - public void testPropagateEquals_VarEq2OrVarRangeGt2Lt3() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - Range range = rangeOf(fa, TWO, false, THREE, false); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, range)); - assertEquals(Range.class, exp.getClass()); - Range r = (Range) exp; - assertEquals(TWO, r.lower()); - assertTrue(r.includeLower()); - assertEquals(THREE, r.upper()); - assertFalse(r.includeUpper()); - } - - // a = 3 OR 2 < a < 3 -> 2 < a <= 3 - public void testPropagateEquals_VarEq3OrVarRangeGt2Lt3() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, THREE); - Range range = rangeOf(fa, TWO, false, THREE, false); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, range)); - assertEquals(Range.class, exp.getClass()); - Range r = (Range) exp; - assertEquals(TWO, r.lower()); - assertFalse(r.includeLower()); - assertEquals(THREE, r.upper()); - assertTrue(r.includeUpper()); - } - - // a = 2 OR a != 2 -> TRUE - public void testPropagateEquals_VarEq2OrVarNeq2() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - NotEquals neq = notEqualsOf(fa, TWO); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, neq)); - assertEquals(TRUE, exp); - } - - // a = 2 OR a != 5 -> a != 5 - public void testPropagateEquals_VarEq2OrVarNeq5() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - NotEquals neq = notEqualsOf(fa, FIVE); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, neq)); - assertEquals(NotEquals.class, exp.getClass()); - NotEquals ne = (NotEquals) exp; - assertEquals(FIVE, ne.right()); - } - - // a = 2 OR 3 < a < 4 OR a > 2 OR a!= 2 -> TRUE - public void testPropagateEquals_VarEq2OrVarRangeGt3Lt4OrVarGt2OrVarNe2() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - Range range = rangeOf(fa, THREE, false, FOUR, false); - GreaterThan gt = greaterThanOf(fa, TWO); - NotEquals neq = notEqualsOf(fa, TWO); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule((Or) Predicates.combineOr(asList(eq, range, neq, gt))); - assertEquals(TRUE, exp); - } - - // a == 1 AND a == 2 -> nop for date/time fields - public void testPropagateEquals_ignoreDateTimeFields() { - FieldAttribute fa = TestUtils.getFieldAttribute("a", DataType.DATETIME); - Equals eq1 = equalsOf(fa, ONE); - Equals eq2 = equalsOf(fa, TWO); - And and = new And(EMPTY, eq1, eq2); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(and); - assertEquals(and, exp); - } - - // - // Like / Regex - // - public void testMatchAllLikeToExist() throws Exception { - for (String s : asList("%", "%%", "%%%")) { - LikePattern pattern = new LikePattern(s, (char) 0); - FieldAttribute fa = getFieldAttribute(); - Like l = new Like(EMPTY, fa, pattern); - Expression e = new ReplaceRegexMatch().rule(l); - assertEquals(IsNotNull.class, e.getClass()); - IsNotNull inn = (IsNotNull) e; - assertEquals(fa, inn.field()); - } - } - - public void testMatchAllWildcardLikeToExist() throws Exception { - for (String s : asList("*", "**", "***")) { - WildcardPattern pattern = new WildcardPattern(s); - FieldAttribute fa = getFieldAttribute(); - WildcardLike l = new WildcardLike(EMPTY, fa, pattern); - Expression e = new ReplaceRegexMatch().rule(l); - assertEquals(IsNotNull.class, e.getClass()); - IsNotNull inn = (IsNotNull) e; - assertEquals(fa, inn.field()); - } - } - - public void testMatchAllRLikeToExist() throws Exception { - RLikePattern pattern = new RLikePattern(".*"); - FieldAttribute fa = getFieldAttribute(); - RLike l = new RLike(EMPTY, fa, pattern); - Expression e = new ReplaceRegexMatch().rule(l); - assertEquals(IsNotNull.class, e.getClass()); - IsNotNull inn = (IsNotNull) e; - assertEquals(fa, inn.field()); - } - - public void testExactMatchLike() throws Exception { - for (String s : asList("ab", "ab0%", "ab0_c")) { - LikePattern pattern = new LikePattern(s, '0'); - FieldAttribute fa = getFieldAttribute(); - Like l = new Like(EMPTY, fa, pattern); - Expression e = new ReplaceRegexMatch().rule(l); - assertEquals(Equals.class, e.getClass()); - Equals eq = (Equals) e; - assertEquals(fa, eq.left()); - assertEquals(s.replace("0", StringUtils.EMPTY), eq.right().fold()); - } - } - - public void testExactMatchWildcardLike() throws Exception { - String s = "ab"; - WildcardPattern pattern = new WildcardPattern(s); - FieldAttribute fa = getFieldAttribute(); - WildcardLike l = new WildcardLike(EMPTY, fa, pattern); - Expression e = new ReplaceRegexMatch().rule(l); - assertEquals(Equals.class, e.getClass()); - Equals eq = (Equals) e; - assertEquals(fa, eq.left()); - assertEquals(s, eq.right().fold()); - } - - public void testExactMatchRLike() throws Exception { - RLikePattern pattern = new RLikePattern("abc"); - FieldAttribute fa = getFieldAttribute(); - RLike l = new RLike(EMPTY, fa, pattern); - Expression e = new ReplaceRegexMatch().rule(l); - assertEquals(Equals.class, e.getClass()); - Equals eq = (Equals) e; - assertEquals(fa, eq.left()); - assertEquals("abc", eq.right().fold()); - } - - // - // CombineDisjunction in Equals - // - public void testTwoEqualsWithOr() throws Exception { - FieldAttribute fa = getFieldAttribute(); - - Or or = new Or(EMPTY, equalsOf(fa, ONE), equalsOf(fa, TWO)); - Expression e = new CombineDisjunctionsToIn().rule(or); - assertEquals(In.class, e.getClass()); - In in = (In) e; - assertEquals(fa, in.value()); - assertThat(in.list(), contains(ONE, TWO)); - } - - public void testTwoEqualsWithSameValue() throws Exception { - FieldAttribute fa = getFieldAttribute(); - - Or or = new Or(EMPTY, equalsOf(fa, ONE), equalsOf(fa, ONE)); - Expression e = new CombineDisjunctionsToIn().rule(or); - assertEquals(Equals.class, e.getClass()); - Equals eq = (Equals) e; - assertEquals(fa, eq.left()); - assertEquals(ONE, eq.right()); - } - - public void testOneEqualsOneIn() throws Exception { - FieldAttribute fa = getFieldAttribute(); - - Or or = new Or(EMPTY, equalsOf(fa, ONE), new In(EMPTY, fa, singletonList(TWO))); - Expression e = new CombineDisjunctionsToIn().rule(or); - assertEquals(In.class, e.getClass()); - In in = (In) e; - assertEquals(fa, in.value()); - assertThat(in.list(), contains(ONE, TWO)); - } - - public void testOneEqualsOneInWithSameValue() throws Exception { - FieldAttribute fa = getFieldAttribute(); - - Or or = new Or(EMPTY, equalsOf(fa, ONE), new In(EMPTY, fa, asList(ONE, TWO))); - Expression e = new CombineDisjunctionsToIn().rule(or); - assertEquals(In.class, e.getClass()); - In in = (In) e; - assertEquals(fa, in.value()); - assertThat(in.list(), contains(ONE, TWO)); - } - - public void testSingleValueInToEquals() throws Exception { - FieldAttribute fa = getFieldAttribute(); - - Equals equals = equalsOf(fa, ONE); - Or or = new Or(EMPTY, equals, new In(EMPTY, fa, singletonList(ONE))); - Expression e = new CombineDisjunctionsToIn().rule(or); - assertEquals(equals, e); - } - - public void testEqualsBehindAnd() throws Exception { - FieldAttribute fa = getFieldAttribute(); - - And and = new And(EMPTY, equalsOf(fa, ONE), equalsOf(fa, TWO)); - Filter dummy = new Filter(EMPTY, relation(), and); - LogicalPlan transformed = new CombineDisjunctionsToIn().apply(dummy); - assertSame(dummy, transformed); - assertEquals(and, ((Filter) transformed).condition()); - } - - public void testTwoEqualsDifferentFields() throws Exception { - FieldAttribute fieldOne = TestUtils.getFieldAttribute("ONE"); - FieldAttribute fieldTwo = TestUtils.getFieldAttribute("TWO"); - - Or or = new Or(EMPTY, equalsOf(fieldOne, ONE), equalsOf(fieldTwo, TWO)); - Expression e = new CombineDisjunctionsToIn().rule(or); - assertEquals(or, e); - } - - public void testMultipleIn() throws Exception { - FieldAttribute fa = getFieldAttribute(); - - Or firstOr = new Or(EMPTY, new In(EMPTY, fa, singletonList(ONE)), new In(EMPTY, fa, singletonList(TWO))); - Or secondOr = new Or(EMPTY, firstOr, new In(EMPTY, fa, singletonList(THREE))); - Expression e = new CombineDisjunctionsToIn().rule(secondOr); - assertEquals(In.class, e.getClass()); - In in = (In) e; - assertEquals(fa, in.value()); - assertThat(in.list(), contains(ONE, TWO, THREE)); - } - - public void testOrWithNonCombinableExpressions() throws Exception { - FieldAttribute fa = getFieldAttribute(); - - Or firstOr = new Or(EMPTY, new In(EMPTY, fa, singletonList(ONE)), lessThanOf(fa, TWO)); - Or secondOr = new Or(EMPTY, firstOr, new In(EMPTY, fa, singletonList(THREE))); - Expression e = new CombineDisjunctionsToIn().rule(secondOr); - assertEquals(Or.class, e.getClass()); - Or or = (Or) e; - assertEquals(or.left(), firstOr.right()); - assertEquals(In.class, or.right().getClass()); - In in = (In) or.right(); - assertEquals(fa, in.value()); - assertThat(in.list(), contains(ONE, THREE)); - } - - // Null folding - - public void testNullFoldingIsNull() { - FoldNull foldNull = new FoldNull(); - assertEquals(true, foldNull.rule(new IsNull(EMPTY, NULL)).fold()); - assertEquals(false, foldNull.rule(new IsNull(EMPTY, TRUE)).fold()); - } - - public void testGenericNullableExpression() { - FoldNull rule = new FoldNull(); - // arithmetic - assertNullLiteral(rule.rule(new Add(EMPTY, getFieldAttribute(), NULL))); - // comparison - assertNullLiteral(rule.rule(greaterThanOf(getFieldAttribute(), NULL))); - // regex - assertNullLiteral(rule.rule(new RLike(EMPTY, NULL, new RLikePattern("123")))); - } - - public void testNullFoldingDoesNotApplyOnLogicalExpressions() { - FoldNull rule = new FoldNull(); - - Or or = new Or(EMPTY, NULL, TRUE); - assertEquals(or, rule.rule(or)); - or = new Or(EMPTY, NULL, NULL); - assertEquals(or, rule.rule(or)); - - And and = new And(EMPTY, NULL, TRUE); - assertEquals(and, rule.rule(and)); - and = new And(EMPTY, NULL, NULL); - assertEquals(and, rule.rule(and)); - } - - // - // Propagate nullability (IS NULL / IS NOT NULL) - // - - // a IS NULL AND a IS NOT NULL => false - public void testIsNullAndNotNull() throws Exception { - FieldAttribute fa = getFieldAttribute(); - - And and = new And(EMPTY, new IsNull(EMPTY, fa), new IsNotNull(EMPTY, fa)); - assertEquals(FALSE, new PropagateNullable().rule(and)); - } - - // a IS NULL AND b IS NOT NULL AND c IS NULL AND d IS NOT NULL AND e IS NULL AND a IS NOT NULL => false - public void testIsNullAndNotNullMultiField() throws Exception { - FieldAttribute fa = getFieldAttribute(); - - And andOne = new And(EMPTY, new IsNull(EMPTY, fa), new IsNotNull(EMPTY, getFieldAttribute())); - And andTwo = new And(EMPTY, new IsNull(EMPTY, getFieldAttribute()), new IsNotNull(EMPTY, getFieldAttribute())); - And andThree = new And(EMPTY, new IsNull(EMPTY, getFieldAttribute()), new IsNotNull(EMPTY, fa)); - - And and = new And(EMPTY, andOne, new And(EMPTY, andThree, andTwo)); - - assertEquals(FALSE, new PropagateNullable().rule(and)); - } - - // a IS NULL AND a > 1 => a IS NULL AND false - public void testIsNullAndComparison() throws Exception { - FieldAttribute fa = getFieldAttribute(); - IsNull isNull = new IsNull(EMPTY, fa); - - And and = new And(EMPTY, isNull, greaterThanOf(fa, ONE)); - assertEquals(new And(EMPTY, isNull, nullOf(BOOLEAN)), new PropagateNullable().rule(and)); - } - - // a IS NULL AND b < 1 AND c < 1 AND a < 1 => a IS NULL AND b < 1 AND c < 1 => a IS NULL AND b < 1 AND c < 1 - public void testIsNullAndMultipleComparison() throws Exception { - FieldAttribute fa = getFieldAttribute(); - IsNull isNull = new IsNull(EMPTY, fa); - - And nestedAnd = new And( - EMPTY, - lessThanOf(TestUtils.getFieldAttribute("b"), ONE), - lessThanOf(TestUtils.getFieldAttribute("c"), ONE) - ); - And and = new And(EMPTY, isNull, nestedAnd); - And top = new And(EMPTY, and, lessThanOf(fa, ONE)); - - Expression optimized = new PropagateNullable().rule(top); - Expression expected = new And(EMPTY, and, nullOf(BOOLEAN)); - assertEquals(Predicates.splitAnd(expected), Predicates.splitAnd(optimized)); - } - - // ((a+1)/2) > 1 AND a + 2 AND a IS NULL AND b < 3 => NULL AND NULL AND a IS NULL AND b < 3 - public void testIsNullAndDeeplyNestedExpression() throws Exception { - FieldAttribute fa = getFieldAttribute(); - IsNull isNull = new IsNull(EMPTY, fa); - - Expression nullified = new And( - EMPTY, - greaterThanOf(new Div(EMPTY, new Add(EMPTY, fa, ONE), TWO), ONE), - greaterThanOf(new Add(EMPTY, fa, TWO), ONE) - ); - Expression kept = new And(EMPTY, isNull, lessThanOf(TestUtils.getFieldAttribute("b"), THREE)); - And and = new And(EMPTY, nullified, kept); - - Expression optimized = new PropagateNullable().rule(and); - Expression expected = new And(EMPTY, new And(EMPTY, nullOf(BOOLEAN), nullOf(BOOLEAN)), kept); - - assertEquals(Predicates.splitAnd(expected), Predicates.splitAnd(optimized)); - } - - // a IS NULL OR a IS NOT NULL => no change - // a IS NULL OR a > 1 => no change - public void testIsNullInDisjunction() throws Exception { - FieldAttribute fa = getFieldAttribute(); - - Or or = new Or(EMPTY, new IsNull(EMPTY, fa), new IsNotNull(EMPTY, fa)); - Filter dummy = new Filter(EMPTY, relation(), or); - LogicalPlan transformed = new PropagateNullable().apply(dummy); - assertSame(dummy, transformed); - assertEquals(or, ((Filter) transformed).condition()); - - or = new Or(EMPTY, new IsNull(EMPTY, fa), greaterThanOf(fa, ONE)); - dummy = new Filter(EMPTY, relation(), or); - transformed = new PropagateNullable().apply(dummy); - assertSame(dummy, transformed); - assertEquals(or, ((Filter) transformed).condition()); - } - - // a + 1 AND (a IS NULL OR a > 3) => no change - public void testIsNullDisjunction() throws Exception { - FieldAttribute fa = getFieldAttribute(); - IsNull isNull = new IsNull(EMPTY, fa); - - Or or = new Or(EMPTY, isNull, greaterThanOf(fa, THREE)); - And and = new And(EMPTY, new Add(EMPTY, fa, ONE), or); - - assertEquals(and, new PropagateNullable().rule(and)); - } - - public void testIsNotNullOnIsNullField() { - EsRelation relation = relation(); - var fieldA = TestUtils.getFieldAttribute("a"); - Expression inn = isNotNull(fieldA); - Filter f = new Filter(EMPTY, relation, inn); - - assertEquals(f, new OptimizerRules.InferIsNotNull().apply(f)); - } - - public void testIsNotNullOnOperatorWithOneField() { - EsRelation relation = relation(); - var fieldA = TestUtils.getFieldAttribute("a"); - Expression inn = isNotNull(new Add(EMPTY, fieldA, ONE)); - Filter f = new Filter(EMPTY, relation, inn); - Filter expected = new Filter(EMPTY, relation, new And(EMPTY, isNotNull(fieldA), inn)); - - assertEquals(expected, new OptimizerRules.InferIsNotNull().apply(f)); - } - - public void testIsNotNullOnOperatorWithTwoFields() { - EsRelation relation = relation(); - var fieldA = TestUtils.getFieldAttribute("a"); - var fieldB = TestUtils.getFieldAttribute("b"); - Expression inn = isNotNull(new Add(EMPTY, fieldA, fieldB)); - Filter f = new Filter(EMPTY, relation, inn); - Filter expected = new Filter(EMPTY, relation, new And(EMPTY, new And(EMPTY, isNotNull(fieldA), isNotNull(fieldB)), inn)); - - assertEquals(expected, new OptimizerRules.InferIsNotNull().apply(f)); - } - - public void testIsNotNullOnFunctionWithTwoField() {} - - private IsNotNull isNotNull(Expression field) { - return new IsNotNull(EMPTY, field); - } - - private IsNull isNull(Expression field) { - return new IsNull(EMPTY, field); - } - - private Literal nullOf(DataType dataType) { - return new Literal(Source.EMPTY, null, dataType); - } - - private void assertNullLiteral(Expression expression) { - assertEquals(Literal.class, expression.getClass()); - assertNull(expression.fold()); - } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java index 62a4a25445ac7..b5a288408809f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java @@ -32,7 +32,9 @@ import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; +import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNotNull; import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RegexMatch; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.StringPattern; import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.ConstantFolding; import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.LiteralsOnTheRight; @@ -1293,7 +1295,28 @@ protected LogicalPlan rule(LogicalPlan plan, LogicalOptimizerContext context) { } } - public static class ReplaceRegexMatch extends OptimizerRules.ReplaceRegexMatch { + public static class ReplaceRegexMatch extends org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.OptimizerExpressionRule< + RegexMatch> { + + ReplaceRegexMatch() { + super(org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.TransformDirection.DOWN); + } + + @Override + public Expression rule(RegexMatch regexMatch) { + Expression e = regexMatch; + StringPattern pattern = regexMatch.pattern(); + if (pattern.matchesAll()) { + e = new IsNotNull(e.source(), regexMatch.field()); + } else { + String match = pattern.exactMatch(); + if (match != null) { + Literal literal = new Literal(regexMatch.source(), match, DataType.KEYWORD); + e = regexToEquals(regexMatch, literal); + } + } + return e; + } protected Expression regexToEquals(RegexMatch regexMatch, Literal literal) { return new Equals(regexMatch.source(), regexMatch.field(), literal); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRules.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRules.java index fdddbfc837a51..2041c08acbca6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRules.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRules.java @@ -294,12 +294,6 @@ protected Expression rule(BinaryComparison bc) { * When encountering a different Equals, non-containing {@link Range} or {@link BinaryComparison}, the conjunction becomes false. * When encountering a containing {@link Range}, {@link BinaryComparison} or {@link NotEquals}, these get eliminated by the equality. * - * Since this rule can eliminate Ranges and BinaryComparisons, it should be applied before - * {@link org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.CombineBinaryComparisons}. - * - * This rule doesn't perform any promotion of {@link BinaryComparison}s, that is handled by - * {@link org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.CombineBinaryComparisons} on purpose as the resulting Range might - * be foldable (which is picked by the folding rule on the next run). */ public static final class PropagateEquals extends org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.OptimizerExpressionRule< BinaryLogic> { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java index 6ca1a638fa0ee..b5400237bfeef 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java @@ -9,19 +9,40 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.core.TestUtils; +import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.Nullability; +import org.elasticsearch.xpack.esql.core.expression.predicate.BinaryOperator; import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates; import org.elasticsearch.xpack.esql.core.expression.predicate.Range; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Not; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; +import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNotNull; +import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNull; import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.Like; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.LikePattern; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLike; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardLike; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardPattern; +import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.ConstantFolding; +import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.FoldNull; +import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.PropagateNullable; import org.elasticsearch.xpack.esql.core.plan.logical.Filter; import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.util.StringUtils; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mod; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Sub; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual; @@ -29,16 +50,21 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals; +import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer.ReplaceRegexMatch; import java.util.List; import static java.util.Arrays.asList; +import static org.elasticsearch.xpack.esql.core.TestUtils.nullEqualsOf; +import static org.elasticsearch.xpack.esql.core.TestUtils.of; import static org.elasticsearch.xpack.esql.core.TestUtils.rangeOf; import static org.elasticsearch.xpack.esql.core.TestUtils.relation; import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE; import static org.elasticsearch.xpack.esql.core.expression.Literal.NULL; import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE; import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; +import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN; +import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; import static org.hamcrest.Matchers.contains; public class OptimizerRulesTests extends ESTestCase { @@ -47,6 +73,9 @@ public class OptimizerRulesTests extends ESTestCase { private static final Literal THREE = new Literal(Source.EMPTY, 3, DataType.INTEGER); private static final Literal FOUR = new Literal(Source.EMPTY, 4, DataType.INTEGER); private static final Literal FIVE = new Literal(Source.EMPTY, 5, DataType.INTEGER); + private static final Literal SIX = new Literal(Source.EMPTY, 6, DataType.INTEGER); + private static final Expression DUMMY_EXPRESSION = + new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 0); private static Equals equalsOf(Expression left, Expression right) { return new Equals(EMPTY, left, right, null); @@ -76,6 +105,82 @@ private static FieldAttribute getFieldAttribute() { return TestUtils.getFieldAttribute("a"); } + // + // Constant folding + // + + public void testConstantFolding() { + Expression exp = new Add(EMPTY, TWO, THREE); + + assertTrue(exp.foldable()); + Expression result = new ConstantFolding().rule(exp); + assertTrue(result instanceof Literal); + assertEquals(5, ((Literal) result).value()); + + // check now with an alias + result = new ConstantFolding().rule(new Alias(EMPTY, "a", exp)); + assertEquals("a", Expressions.name(result)); + assertEquals(Alias.class, result.getClass()); + } + + public void testConstantFoldingBinaryComparison() { + assertEquals(FALSE, new ConstantFolding().rule(greaterThanOf(TWO, THREE)).canonical()); + assertEquals(FALSE, new ConstantFolding().rule(greaterThanOrEqualOf(TWO, THREE)).canonical()); + assertEquals(FALSE, new ConstantFolding().rule(equalsOf(TWO, THREE)).canonical()); + assertEquals(FALSE, new ConstantFolding().rule(nullEqualsOf(TWO, THREE)).canonical()); + assertEquals(FALSE, new ConstantFolding().rule(nullEqualsOf(TWO, NULL)).canonical()); + assertEquals(TRUE, new ConstantFolding().rule(notEqualsOf(TWO, THREE)).canonical()); + assertEquals(TRUE, new ConstantFolding().rule(lessThanOrEqualOf(TWO, THREE)).canonical()); + assertEquals(TRUE, new ConstantFolding().rule(lessThanOf(TWO, THREE)).canonical()); + } + + public void testConstantFoldingBinaryLogic() { + assertEquals(FALSE, new ConstantFolding().rule(new And(EMPTY, greaterThanOf(TWO, THREE), TRUE)).canonical()); + assertEquals(TRUE, new ConstantFolding().rule(new Or(EMPTY, greaterThanOrEqualOf(TWO, THREE), TRUE)).canonical()); + } + + public void testConstantFoldingBinaryLogic_WithNullHandling() { + assertEquals(Nullability.TRUE, new ConstantFolding().rule(new And(EMPTY, NULL, TRUE)).canonical().nullable()); + assertEquals(Nullability.TRUE, new ConstantFolding().rule(new And(EMPTY, TRUE, NULL)).canonical().nullable()); + assertEquals(FALSE, new ConstantFolding().rule(new And(EMPTY, NULL, FALSE)).canonical()); + assertEquals(FALSE, new ConstantFolding().rule(new And(EMPTY, FALSE, NULL)).canonical()); + assertEquals(Nullability.TRUE, new ConstantFolding().rule(new And(EMPTY, NULL, NULL)).canonical().nullable()); + + assertEquals(TRUE, new ConstantFolding().rule(new Or(EMPTY, NULL, TRUE)).canonical()); + assertEquals(TRUE, new ConstantFolding().rule(new Or(EMPTY, TRUE, NULL)).canonical()); + assertEquals(Nullability.TRUE, new ConstantFolding().rule(new Or(EMPTY, NULL, FALSE)).canonical().nullable()); + assertEquals(Nullability.TRUE, new ConstantFolding().rule(new Or(EMPTY, FALSE, NULL)).canonical().nullable()); + assertEquals(Nullability.TRUE, new ConstantFolding().rule(new Or(EMPTY, NULL, NULL)).canonical().nullable()); + } + + public void testConstantFoldingRange() { + assertEquals(true, new ConstantFolding().rule(rangeOf(FIVE, FIVE, true, new Literal(EMPTY, 10, DataType.INTEGER), false)).fold()); + assertEquals(false, new ConstantFolding().rule(rangeOf(FIVE, FIVE, false, new Literal(EMPTY, 10, DataType.INTEGER), false)).fold()); + } + + public void testConstantNot() { + assertEquals(FALSE, new ConstantFolding().rule(new Not(EMPTY, TRUE))); + assertEquals(TRUE, new ConstantFolding().rule(new Not(EMPTY, FALSE))); + } + + public void testConstantFoldingLikes() { + assertEquals(TRUE, new ConstantFolding().rule(new Like(EMPTY, of("test_emp"), new LikePattern("test%", (char) 0))).canonical()); + assertEquals(TRUE, new ConstantFolding().rule(new WildcardLike(EMPTY, of("test_emp"), new WildcardPattern("test*"))).canonical()); + assertEquals(TRUE, new ConstantFolding().rule(new RLike(EMPTY, of("test_emp"), new RLikePattern("test.emp"))).canonical()); + } + + public void testArithmeticFolding() { + assertEquals(10, foldOperator(new Add(EMPTY, new Literal(EMPTY, 7, DataType.INTEGER), THREE))); + assertEquals(4, foldOperator(new Sub(EMPTY, new Literal(EMPTY, 7, DataType.INTEGER), THREE))); + assertEquals(21, foldOperator(new Mul(EMPTY, new Literal(EMPTY, 7, DataType.INTEGER), THREE))); + assertEquals(2, foldOperator(new Div(EMPTY, new Literal(EMPTY, 7, DataType.INTEGER), THREE))); + assertEquals(1, foldOperator(new Mod(EMPTY, new Literal(EMPTY, 7, DataType.INTEGER), THREE))); + } + + private static Object foldOperator(BinaryOperator b) { + return ((Literal) new ConstantFolding().rule(b)).value(); + } + // // CombineDisjunction in Equals // @@ -500,4 +605,282 @@ public void testEliminateRangeByEqualsInInterval() { Expression exp = rule.rule(new And(EMPTY, eq1, r)); assertEquals(eq1, exp); } + // + // Null folding + + public void testNullFoldingIsNull() { + FoldNull foldNull = new FoldNull(); + assertEquals(true, foldNull.rule(new IsNull(EMPTY, NULL)).fold()); + assertEquals(false, foldNull.rule(new IsNull(EMPTY, TRUE)).fold()); + } + + public void testGenericNullableExpression() { + FoldNull rule = new FoldNull(); + // arithmetic + assertNullLiteral(rule.rule(new Add(EMPTY, getFieldAttribute(), NULL))); + // comparison + assertNullLiteral(rule.rule(greaterThanOf(getFieldAttribute(), NULL))); + // regex + assertNullLiteral(rule.rule(new RLike(EMPTY, NULL, new RLikePattern("123")))); + } + + public void testNullFoldingDoesNotApplyOnLogicalExpressions() { + org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.FoldNull rule = + new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.FoldNull(); + + Or or = new Or(EMPTY, NULL, TRUE); + assertEquals(or, rule.rule(or)); + or = new Or(EMPTY, NULL, NULL); + assertEquals(or, rule.rule(or)); + + And and = new And(EMPTY, NULL, TRUE); + assertEquals(and, rule.rule(and)); + and = new And(EMPTY, NULL, NULL); + assertEquals(and, rule.rule(and)); + } + + // + // Propagate nullability (IS NULL / IS NOT NULL) + // + + // a IS NULL AND a IS NOT NULL => false + public void testIsNullAndNotNull() { + FieldAttribute fa = getFieldAttribute(); + + And and = new And(EMPTY, new IsNull(EMPTY, fa), new IsNotNull(EMPTY, fa)); + assertEquals(FALSE, new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.PropagateNullable().rule(and)); + } + + // a IS NULL AND b IS NOT NULL AND c IS NULL AND d IS NOT NULL AND e IS NULL AND a IS NOT NULL => false + public void testIsNullAndNotNullMultiField() { + FieldAttribute fa = getFieldAttribute(); + + And andOne = new And(EMPTY, new IsNull(EMPTY, fa), new IsNotNull(EMPTY, getFieldAttribute())); + And andTwo = new And(EMPTY, new IsNull(EMPTY, getFieldAttribute()), new IsNotNull(EMPTY, getFieldAttribute())); + And andThree = new And(EMPTY, new IsNull(EMPTY, getFieldAttribute()), new IsNotNull(EMPTY, fa)); + + And and = new And(EMPTY, andOne, new And(EMPTY, andThree, andTwo)); + + assertEquals(FALSE, new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.PropagateNullable().rule(and)); + } + + // a IS NULL AND a > 1 => a IS NULL AND false + public void testIsNullAndComparison() { + FieldAttribute fa = getFieldAttribute(); + IsNull isNull = new IsNull(EMPTY, fa); + + And and = new And(EMPTY, isNull, greaterThanOf(fa, ONE)); + assertEquals(new And(EMPTY, isNull, nullOf(BOOLEAN)), new PropagateNullable().rule(and)); + } + + // a IS NULL AND b < 1 AND c < 1 AND a < 1 => a IS NULL AND b < 1 AND c < 1 => a IS NULL AND b < 1 AND c < 1 + public void testIsNullAndMultipleComparison() { + FieldAttribute fa = getFieldAttribute(); + IsNull isNull = new IsNull(EMPTY, fa); + + And nestedAnd = new And( + EMPTY, + lessThanOf(TestUtils.getFieldAttribute("b"), ONE), + lessThanOf(TestUtils.getFieldAttribute("c"), ONE) + ); + And and = new And(EMPTY, isNull, nestedAnd); + And top = new And(EMPTY, and, lessThanOf(fa, ONE)); + + Expression optimized = new PropagateNullable().rule(top); + Expression expected = new And(EMPTY, and, nullOf(BOOLEAN)); + assertEquals(Predicates.splitAnd(expected), Predicates.splitAnd(optimized)); + } + + // ((a+1)/2) > 1 AND a + 2 AND a IS NULL AND b < 3 => NULL AND NULL AND a IS NULL AND b < 3 + public void testIsNullAndDeeplyNestedExpression() { + FieldAttribute fa = getFieldAttribute(); + IsNull isNull = new IsNull(EMPTY, fa); + + Expression nullified = new And( + EMPTY, + greaterThanOf(new Div(EMPTY, new Add(EMPTY, fa, ONE), TWO), ONE), + greaterThanOf(new Add(EMPTY, fa, TWO), ONE) + ); + Expression kept = new And(EMPTY, isNull, lessThanOf(TestUtils.getFieldAttribute("b"), THREE)); + And and = new And(EMPTY, nullified, kept); + + Expression optimized = new PropagateNullable().rule(and); + Expression expected = new And(EMPTY, new And(EMPTY, nullOf(BOOLEAN), nullOf(BOOLEAN)), kept); + + assertEquals(Predicates.splitAnd(expected), Predicates.splitAnd(optimized)); + } + + // a IS NULL OR a IS NOT NULL => no change + // a IS NULL OR a > 1 => no change + public void testIsNullInDisjunction() { + FieldAttribute fa = getFieldAttribute(); + + Or or = new Or(EMPTY, new IsNull(EMPTY, fa), new IsNotNull(EMPTY, fa)); + Filter dummy = new Filter(EMPTY, relation(), or); + LogicalPlan transformed = new PropagateNullable().apply(dummy); + assertSame(dummy, transformed); + assertEquals(or, ((Filter) transformed).condition()); + + or = new Or(EMPTY, new IsNull(EMPTY, fa), greaterThanOf(fa, ONE)); + dummy = new Filter(EMPTY, relation(), or); + transformed = new PropagateNullable().apply(dummy); + assertSame(dummy, transformed); + assertEquals(or, ((Filter) transformed).condition()); + } + + // a + 1 AND (a IS NULL OR a > 3) => no change + public void testIsNullDisjunction() { + FieldAttribute fa = getFieldAttribute(); + IsNull isNull = new IsNull(EMPTY, fa); + + Or or = new Or(EMPTY, isNull, greaterThanOf(fa, THREE)); + And and = new And(EMPTY, new Add(EMPTY, fa, ONE), or); + + assertEquals(and, new PropagateNullable().rule(and)); + } + + // + // Like / Regex + // + public void testMatchAllLikeToExist() { + for (String s : asList("%", "%%", "%%%")) { + LikePattern pattern = new LikePattern(s, (char) 0); + FieldAttribute fa = getFieldAttribute(); + Like l = new Like(EMPTY, fa, pattern); + Expression e = new ReplaceRegexMatch().rule(l); + assertEquals(IsNotNull.class, e.getClass()); + IsNotNull inn = (IsNotNull) e; + assertEquals(fa, inn.field()); + } + } + + public void testMatchAllWildcardLikeToExist() { + for (String s : asList("*", "**", "***")) { + WildcardPattern pattern = new WildcardPattern(s); + FieldAttribute fa = getFieldAttribute(); + WildcardLike l = new WildcardLike(EMPTY, fa, pattern); + Expression e = new ReplaceRegexMatch().rule(l); + assertEquals(IsNotNull.class, e.getClass()); + IsNotNull inn = (IsNotNull) e; + assertEquals(fa, inn.field()); + } + } + + public void testMatchAllRLikeToExist() { + RLikePattern pattern = new RLikePattern(".*"); + FieldAttribute fa = getFieldAttribute(); + RLike l = new RLike(EMPTY, fa, pattern); + Expression e = new ReplaceRegexMatch().rule(l); + assertEquals(IsNotNull.class, e.getClass()); + IsNotNull inn = (IsNotNull) e; + assertEquals(fa, inn.field()); + } + + public void testExactMatchLike() { + for (String s : asList("ab", "ab0%", "ab0_c")) { + LikePattern pattern = new LikePattern(s, '0'); + FieldAttribute fa = getFieldAttribute(); + Like l = new Like(EMPTY, fa, pattern); + Expression e = new ReplaceRegexMatch().rule(l); + assertEquals(Equals.class, e.getClass()); + Equals eq = (Equals) e; + assertEquals(fa, eq.left()); + assertEquals(s.replace("0", StringUtils.EMPTY), eq.right().fold()); + } + } + + public void testExactMatchWildcardLike() { + String s = "ab"; + WildcardPattern pattern = new WildcardPattern(s); + FieldAttribute fa = getFieldAttribute(); + WildcardLike l = new WildcardLike(EMPTY, fa, pattern); + Expression e = new ReplaceRegexMatch().rule(l); + assertEquals(Equals.class, e.getClass()); + Equals eq = (Equals) e; + assertEquals(fa, eq.left()); + assertEquals(s, eq.right().fold()); + } + + public void testExactMatchRLike() { + RLikePattern pattern = new RLikePattern("abc"); + FieldAttribute fa = getFieldAttribute(); + RLike l = new RLike(EMPTY, fa, pattern); + Expression e = new ReplaceRegexMatch().rule(l); + assertEquals(Equals.class, e.getClass()); + Equals eq = (Equals) e; + assertEquals(fa, eq.left()); + assertEquals("abc", eq.right().fold()); + } + + private void assertNullLiteral(Expression expression) { + assertEquals(Literal.class, expression.getClass()); + assertNull(expression.fold()); + } + + private IsNotNull isNotNull(Expression field) { + return new IsNotNull(EMPTY, field); + } + + private IsNull isNull(Expression field) { + return new IsNull(EMPTY, field); + } + + private Literal nullOf(DataType dataType) { + return new Literal(Source.EMPTY, null, dataType); + } + // + // Logical simplifications + // + + public void testLiteralsOnTheRight() { + Alias a = new Alias(EMPTY, "a", new Literal(EMPTY, 10, INTEGER)); + Expression result = new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.LiteralsOnTheRight().rule(equalsOf(FIVE, a)); + assertTrue(result instanceof Equals); + Equals eq = (Equals) result; + assertEquals(a, eq.left()); + assertEquals(FIVE, eq.right()); + + // Note: Null Equals test removed here + } + + public void testBoolSimplifyOr() { + org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.BooleanSimplification simplification = + new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.BooleanSimplification(); + + assertEquals(TRUE, simplification.rule(new Or(EMPTY, TRUE, TRUE))); + assertEquals(TRUE, simplification.rule(new Or(EMPTY, TRUE, DUMMY_EXPRESSION))); + assertEquals(TRUE, simplification.rule(new Or(EMPTY, DUMMY_EXPRESSION, TRUE))); + + assertEquals(FALSE, simplification.rule(new Or(EMPTY, FALSE, FALSE))); + assertEquals(DUMMY_EXPRESSION, simplification.rule(new Or(EMPTY, FALSE, DUMMY_EXPRESSION))); + assertEquals(DUMMY_EXPRESSION, simplification.rule(new Or(EMPTY, DUMMY_EXPRESSION, FALSE))); + } + + public void testBoolSimplifyAnd() { + org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.BooleanSimplification simplification = + new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.BooleanSimplification(); + + assertEquals(TRUE, simplification.rule(new And(EMPTY, TRUE, TRUE))); + assertEquals(DUMMY_EXPRESSION, simplification.rule(new And(EMPTY, TRUE, DUMMY_EXPRESSION))); + assertEquals(DUMMY_EXPRESSION, simplification.rule(new And(EMPTY, DUMMY_EXPRESSION, TRUE))); + + assertEquals(FALSE, simplification.rule(new And(EMPTY, FALSE, FALSE))); + assertEquals(FALSE, simplification.rule(new And(EMPTY, FALSE, DUMMY_EXPRESSION))); + assertEquals(FALSE, simplification.rule(new And(EMPTY, DUMMY_EXPRESSION, FALSE))); + } + + public void testBoolCommonFactorExtraction() { + org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.BooleanSimplification simplification = + new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.BooleanSimplification(); + + Expression a1 = new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 1); + Expression a2 = new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 1); + Expression b = new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 2); + Expression c = new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 3); + + Or actual = new Or(EMPTY, new And(EMPTY, a1, b), new And(EMPTY, a2, c)); + And expected = new And(EMPTY, a1, new Or(EMPTY, b, c)); + + assertEquals(expected, simplification.rule(actual)); + } } From aab2db57170473536e8b407f7718421bcfbaf0f3 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 5 Jun 2024 16:14:35 +0100 Subject: [PATCH 12/30] [ML] Fix serialising inference delete response (#109384) --- docs/changelog/109384.yaml | 5 +++++ .../action/DeleteInferenceEndpointAction.java | 2 +- .../TransportDeleteInferenceEndpointAction.java | 14 +++++++++----- 3 files changed, 15 insertions(+), 6 deletions(-) create mode 100644 docs/changelog/109384.yaml diff --git a/docs/changelog/109384.yaml b/docs/changelog/109384.yaml new file mode 100644 index 0000000000000..303da23d57d8e --- /dev/null +++ b/docs/changelog/109384.yaml @@ -0,0 +1,5 @@ +pr: 109384 +summary: Fix serialising inference delete response +area: Machine Learning +type: bug +issues: [] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java index be4c87195e9d1..19542ef466156 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java @@ -20,7 +20,7 @@ import java.util.Objects; import java.util.Set; -public class DeleteInferenceEndpointAction extends ActionType { +public class DeleteInferenceEndpointAction extends ActionType { public static final DeleteInferenceEndpointAction INSTANCE = new DeleteInferenceEndpointAction(); public static final String NAME = "cluster:admin/xpack/inference/delete"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java index 5ebca2bc512a0..07d5e1e618578 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java @@ -13,8 +13,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.SubscribableListener; -import org.elasticsearch.action.support.master.AcknowledgedResponse; -import org.elasticsearch.action.support.master.AcknowledgedTransportMasterNodeAction; +import org.elasticsearch.action.support.master.TransportMasterNodeAction; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.block.ClusterBlockLevel; @@ -36,7 +35,9 @@ import java.util.Set; -public class TransportDeleteInferenceEndpointAction extends AcknowledgedTransportMasterNodeAction { +public class TransportDeleteInferenceEndpointAction extends TransportMasterNodeAction< + DeleteInferenceEndpointAction.Request, + DeleteInferenceEndpointAction.Response> { private final ModelRegistry modelRegistry; private final InferenceServiceRegistry serviceRegistry; @@ -60,6 +61,7 @@ public TransportDeleteInferenceEndpointAction( actionFilters, DeleteInferenceEndpointAction.Request::new, indexNameExpressionResolver, + DeleteInferenceEndpointAction.Response::new, EsExecutors.DIRECT_EXECUTOR_SERVICE ); this.modelRegistry = modelRegistry; @@ -71,7 +73,7 @@ protected void masterOperation( Task task, DeleteInferenceEndpointAction.Request request, ClusterState state, - ActionListener masterListener + ActionListener masterListener ) { SubscribableListener.newForked(modelConfigListener -> { // Get the model from the registry @@ -123,7 +125,9 @@ && endpointIsReferencedInPipelines(state, request.getInferenceEndpointId(), list } }) .addListener( - masterListener.delegateFailure((l3, didDeleteModel) -> masterListener.onResponse(AcknowledgedResponse.of(didDeleteModel))) + masterListener.delegateFailure( + (l3, didDeleteModel) -> masterListener.onResponse(new DeleteInferenceEndpointAction.Response(didDeleteModel, Set.of())) + ) ); } From cccac69344d2d0df2bfda2d67692ec7ba93a87b8 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 5 Jun 2024 16:16:53 +0100 Subject: [PATCH 13/30] [ML] Use the multi node routing action for internal inference services (#109358) The Elser and internal Elasticsearch inference services were calling the wrong inference action that did not distribute the work among all available nodes, rather work was duplicated by all nodes. This bug does not apply when the inference ingest processor is used and is only noticeable when an inference request contains many inputs. --- docs/changelog/109358.yaml | 5 ++ .../org/elasticsearch/TransportVersions.java | 1 + .../core/ml/action/InferModelAction.java | 32 ++++++++- .../action/InferModelActionRequestTests.java | 21 +++++- .../ElasticsearchInternalService.java | 67 +++++++++++++------ .../services/elser/ElserInternalService.java | 37 ++++++---- .../ElasticsearchInternalServiceTests.java | 39 ++++++++--- .../elser/ElserInternalServiceTests.java | 12 ++-- .../TransportInternalInferModelAction.java | 1 + 9 files changed, 163 insertions(+), 52 deletions(-) create mode 100644 docs/changelog/109358.yaml diff --git a/docs/changelog/109358.yaml b/docs/changelog/109358.yaml new file mode 100644 index 0000000000000..af47b4129d874 --- /dev/null +++ b/docs/changelog/109358.yaml @@ -0,0 +1,5 @@ +pr: 109358 +summary: Use the multi node routing action for internal inference services +area: Machine Learning +type: bug +issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 07579161a85c8..e8a33217b937d 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -183,6 +183,7 @@ static TransportVersion def(int id) { public static final TransportVersion ML_INFERENCE_ENHANCE_DELETE_ENDPOINT = def(8_674_00_0); public static final TransportVersion ML_INFERENCE_GOOGLE_AI_STUDIO_EMBEDDINGS_ADDED = def(8_675_00_0); public static final TransportVersion ADD_MISTRAL_EMBEDDINGS_INFERENCE = def(8_676_00_0); + public static final TransportVersion ML_CHUNK_INFERENCE_OPTION = def(8_677_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java index eb5f1d4f086d0..e6b580f62fdd3 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java @@ -90,6 +90,7 @@ public static Builder parseRequest(String id, XContentParser parser) { private final List textInput; private boolean highPriority; private TrainedModelPrefixStrings.PrefixType prefixType = TrainedModelPrefixStrings.PrefixType.NONE; + private boolean chunked = false; /** * Build a request from a list of documents as maps. @@ -197,6 +198,11 @@ public Request(StreamInput in) throws IOException { } else { prefixType = TrainedModelPrefixStrings.PrefixType.NONE; } + if (in.getTransportVersion().onOrAfter(TransportVersions.ML_CHUNK_INFERENCE_OPTION)) { + chunked = in.readBoolean(); + } else { + chunked = false; + } } public int numberOfDocuments() { @@ -247,6 +253,14 @@ public TrainedModelPrefixStrings.PrefixType getPrefixType() { return prefixType; } + public void setChunked(boolean chunked) { + this.chunked = chunked; + } + + public boolean isChunked() { + return chunked; + } + @Override public ActionRequestValidationException validate() { return null; @@ -271,6 +285,9 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) { out.writeEnum(prefixType); } + if (out.getTransportVersion().onOrAfter(TransportVersions.ML_CHUNK_INFERENCE_OPTION)) { + out.writeBoolean(chunked); + } } @Override @@ -285,7 +302,8 @@ public boolean equals(Object o) { && Objects.equals(objectsToInfer, that.objectsToInfer) && Objects.equals(textInput, that.textInput) && (highPriority == that.highPriority) - && (prefixType == that.prefixType); + && (prefixType == that.prefixType) + && (chunked == that.chunked); } @Override @@ -295,7 +313,17 @@ public Task createTask(long id, String type, String action, TaskId parentTaskId, @Override public int hashCode() { - return Objects.hash(id, objectsToInfer, update, previouslyLicensed, inferenceTimeout, textInput, highPriority, prefixType); + return Objects.hash( + id, + objectsToInfer, + update, + previouslyLicensed, + inferenceTimeout, + textInput, + highPriority, + prefixType, + chunked + ); } public static class Builder { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java index 983e5d43a946d..2e4689de787b3 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java @@ -73,6 +73,7 @@ protected Request createTestInstance() { if (randomBoolean()) { request.setPrefixType(randomFrom(TrainedModelPrefixStrings.PrefixType.values())); } + request.setChunked(randomBoolean()); return request; } @@ -87,8 +88,9 @@ protected Request mutateInstance(Request instance) { var previouslyLicensed = instance.isPreviouslyLicensed(); var timeout = instance.getInferenceTimeout(); var prefixType = instance.getPrefixType(); + var chunked = instance.isChunked(); - int change = randomIntBetween(0, 7); + int change = randomIntBetween(0, 8); switch (change) { case 0: modelId = modelId + "foo"; @@ -123,6 +125,9 @@ protected Request mutateInstance(Request instance) { prefixType = TrainedModelPrefixStrings.PrefixType.values()[(prefixType.ordinal() + 1) % TrainedModelPrefixStrings.PrefixType .values().length]; break; + case 8: + chunked = chunked == false; + break; default: throw new IllegalStateException(); } @@ -130,6 +135,7 @@ protected Request mutateInstance(Request instance) { var r = new Request(modelId, update, objectsToInfer, textInput, timeout, previouslyLicensed); r.setHighPriority(highPriority); r.setPrefixType(prefixType); + r.setChunked(chunked); return r; } @@ -246,6 +252,19 @@ protected Request mutateInstanceForVersion(Request instance, TransportVersion ve r.setHighPriority(instance.isHighPriority()); r.setPrefixType(TrainedModelPrefixStrings.PrefixType.NONE); return r; + } else if (version.before(TransportVersions.ML_CHUNK_INFERENCE_OPTION)) { + var r = new Request( + instance.getId(), + adjustedUpdate, + instance.getObjectsToInfer(), + instance.getTextInput(), + instance.getInferenceTimeout(), + instance.isPreviouslyLicensed() + ); + r.setHighPriority(instance.isHighPriority()); + r.setPrefixType(instance.getPrefixType()); + r.setChunked(false); // r.setChunked(instance.isChunked()); for the next version + return r; } return instance; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 408e3ec1ccbca..67a45ba8b1295 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -33,13 +33,15 @@ import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; -import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate; @@ -239,7 +241,7 @@ public void infer( if (TaskType.TEXT_EMBEDDING.equals(taskType)) { inferTextEmbedding(model, input, inputType, timeout, listener); } else if (TaskType.RERANK.equals(taskType)) { - inferRerank(model, query, input, timeout, taskSettings, listener); + inferRerank(model, query, input, inputType, timeout, taskSettings, listener); } else { throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST); } @@ -247,22 +249,26 @@ public void infer( public void inferTextEmbedding( Model model, - List input, + List inputs, InputType inputType, TimeValue timeout, ActionListener listener ) { - var request = InferTrainedModelDeploymentAction.Request.forTextInput( + var request = buildInferenceRequest( model.getConfigurations().getInferenceEntityId(), TextEmbeddingConfigUpdate.EMPTY_INSTANCE, - input, - timeout + inputs, + inputType, + timeout, + false ); client.execute( - InferTrainedModelDeploymentAction.INSTANCE, + InferModelAction.INSTANCE, request, - listener.delegateFailureAndWrap((l, inferenceResult) -> l.onResponse(TextEmbeddingResults.of(inferenceResult.getResults()))) + listener.delegateFailureAndWrap( + (l, inferenceResult) -> l.onResponse(TextEmbeddingResults.of(inferenceResult.getInferenceResults())) + ) ); } @@ -270,16 +276,18 @@ public void inferRerank( Model model, String query, List inputs, + InputType inputType, TimeValue timeout, Map requestTaskSettings, ActionListener listener ) { - var config = new TextSimilarityConfigUpdate(query); - var request = InferTrainedModelDeploymentAction.Request.forTextInput( + var request = buildInferenceRequest( model.getConfigurations().getInferenceEntityId(), - config, + new TextSimilarityConfigUpdate(query), inputs, - timeout + inputType, + timeout, + false ); var modelSettings = (CustomElandRerankTaskSettings) model.getTaskSettings(); @@ -289,10 +297,12 @@ public void inferRerank( Function inputSupplier = returnDocs == Boolean.TRUE ? inputs::get : i -> null; client.execute( - InferTrainedModelDeploymentAction.INSTANCE, + InferModelAction.INSTANCE, request, listener.delegateFailureAndWrap( - (l, inferenceResult) -> l.onResponse(textSimilarityResultsToRankedDocs(inferenceResult.getResults(), inputSupplier)) + (l, inferenceResult) -> l.onResponse( + textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier) + ) ) ); } @@ -331,18 +341,21 @@ public void chunkedInfer( ? new TokenizationConfigUpdate(chunkingOptions.windowSize(), chunkingOptions.span()) : new TokenizationConfigUpdate(null, null); - var request = InferTrainedModelDeploymentAction.Request.forTextInput( + var request = buildInferenceRequest( model.getConfigurations().getInferenceEntityId(), configUpdate, input, - timeout + inputType, + timeout, + true ); - request.setChunkResults(true); client.execute( - InferTrainedModelDeploymentAction.INSTANCE, + InferModelAction.INSTANCE, request, - listener.delegateFailureAndWrap((l, inferenceResult) -> l.onResponse(translateToChunkedResults(inferenceResult.getResults()))) + listener.delegateFailureAndWrap( + (l, inferenceResult) -> l.onResponse(translateToChunkedResults(inferenceResult.getInferenceResults())) + ) ); } @@ -532,4 +545,20 @@ private RankedDocsResults textSimilarityResultsToRankedDocs( return new RankedDocsResults(rankings); } + public static InferModelAction.Request buildInferenceRequest( + String id, + InferenceConfigUpdate update, + List inputs, + InputType inputType, + TimeValue timeout, + boolean chunk + ) { + var request = InferModelAction.Request.forTextInput(id, update, inputs, true, timeout); + request.setPrefixType( + InputType.SEARCH == inputType ? TrainedModelPrefixStrings.PrefixType.SEARCH : TrainedModelPrefixStrings.PrefixType.INGEST + ); + request.setHighPriority(InputType.SEARCH == inputType); + request.setChunked(chunk); + return request; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java index 01829bfee5241..a19e377d59c18 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java @@ -34,7 +34,7 @@ import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; -import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction; @@ -58,6 +58,7 @@ import static org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus.State.STARTED; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.buildInferenceRequest; public class ElserInternalService implements InferenceService { @@ -259,7 +260,7 @@ public void stop(String inferenceEntityId, ActionListener listener) { public void infer( Model model, @Nullable String query, - List input, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -274,16 +275,21 @@ public void infer( return; } - var request = InferTrainedModelDeploymentAction.Request.forTextInput( + var request = buildInferenceRequest( model.getConfigurations().getInferenceEntityId(), TextExpansionConfigUpdate.EMPTY_UPDATE, - input, - timeout + inputs, + inputType, + timeout, + false // chunk ); + client.execute( - InferTrainedModelDeploymentAction.INSTANCE, + InferModelAction.INSTANCE, request, - listener.delegateFailureAndWrap((l, inferenceResult) -> l.onResponse(SparseEmbeddingResults.of(inferenceResult.getResults()))) + listener.delegateFailureAndWrap( + (l, inferenceResult) -> l.onResponse(SparseEmbeddingResults.of(inferenceResult.getInferenceResults())) + ) ); } @@ -303,7 +309,7 @@ public void chunkedInfer( public void chunkedInfer( Model model, @Nullable String query, - List input, + List inputs, Map taskSettings, InputType inputType, @Nullable ChunkingOptions chunkingOptions, @@ -321,18 +327,21 @@ public void chunkedInfer( ? new TokenizationConfigUpdate(chunkingOptions.windowSize(), chunkingOptions.span()) : new TokenizationConfigUpdate(null, null); - var request = InferTrainedModelDeploymentAction.Request.forTextInput( + var request = buildInferenceRequest( model.getConfigurations().getInferenceEntityId(), configUpdate, - input, - timeout + inputs, + inputType, + timeout, + true // chunk ); - request.setChunkResults(true); client.execute( - InferTrainedModelDeploymentAction.INSTANCE, + InferModelAction.INSTANCE, request, - listener.delegateFailureAndWrap((l, inferenceResult) -> l.onResponse(translateChunkedResults(inferenceResult.getResults()))) + listener.delegateFailureAndWrap( + (l, inferenceResult) -> l.onResponse(translateChunkedResults(inferenceResult.getInferenceResults())) + ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index ea11e9d0343e3..b06f8b0027caf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -30,10 +30,13 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.settings.InternalServiceSettings; @@ -468,21 +471,16 @@ public void testChunkInfer() { mlTrainedModelResults.add(ChunkedTextEmbeddingResultsTests.createRandomResults()); mlTrainedModelResults.add(ChunkedTextEmbeddingResultsTests.createRandomResults()); mlTrainedModelResults.add(new ErrorInferenceResults(new RuntimeException("boom"))); - var response = new InferTrainedModelDeploymentAction.Response(mlTrainedModelResults); + var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true); ThreadPool threadpool = new TestThreadPool("test"); Client client = mock(Client.class); when(client.threadPool()).thenReturn(threadpool); doAnswer(invocationOnMock -> { - var listener = (ActionListener) invocationOnMock.getArguments()[2]; + var listener = (ActionListener) invocationOnMock.getArguments()[2]; listener.onResponse(response); return null; - }).when(client) - .execute( - same(InferTrainedModelDeploymentAction.INSTANCE), - any(InferTrainedModelDeploymentAction.Request.class), - any(ActionListener.class) - ); + }).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class)); var model = new MultilingualE5SmallModel( "foo", @@ -644,6 +642,31 @@ public void testParsePersistedConfig_Rerank() { } } + public void testBuildInferenceRequest() { + var id = randomAlphaOfLength(5); + var inputs = randomList(1, 3, () -> randomAlphaOfLength(4)); + var inputType = randomFrom(InputType.SEARCH, InputType.INGEST); + var timeout = randomTimeValue(); + var chunk = randomBoolean(); + var request = ElasticsearchInternalService.buildInferenceRequest( + id, + TextEmbeddingConfigUpdate.EMPTY_INSTANCE, + inputs, + inputType, + timeout, + chunk + ); + + assertEquals(id, request.getId()); + assertEquals(inputs, request.getTextInput()); + assertEquals( + inputType == InputType.INGEST ? TrainedModelPrefixStrings.PrefixType.INGEST : TrainedModelPrefixStrings.PrefixType.SEARCH, + request.getPrefixType() + ); + assertEquals(timeout, request.getInferenceTimeout()); + assertEquals(chunk, request.isChunked()); + } + private ElasticsearchInternalService createService(Client client) { var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client); return new ElasticsearchInternalService(context); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java index dcbb523cceed9..2fdb208a56e1b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java @@ -25,6 +25,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; @@ -337,21 +338,16 @@ public void testChunkInfer() { mlTrainedModelResults.add(ChunkedTextExpansionResultsTests.createRandomResults()); mlTrainedModelResults.add(ChunkedTextExpansionResultsTests.createRandomResults()); mlTrainedModelResults.add(new ErrorInferenceResults(new RuntimeException("boom"))); - var response = new InferTrainedModelDeploymentAction.Response(mlTrainedModelResults); + var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true); ThreadPool threadpool = new TestThreadPool("test"); Client client = mock(Client.class); when(client.threadPool()).thenReturn(threadpool); doAnswer(invocationOnMock -> { - var listener = (ActionListener) invocationOnMock.getArguments()[2]; + var listener = (ActionListener) invocationOnMock.getArguments()[2]; listener.onResponse(response); return null; - }).when(client) - .execute( - same(InferTrainedModelDeploymentAction.INSTANCE), - any(InferTrainedModelDeploymentAction.Request.class), - any(ActionListener.class) - ); + }).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class)); var model = new ElserInternalModel( "foo", diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java index d54cac9dca496..004d87d643962 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java @@ -290,6 +290,7 @@ private void inferAgainstAllocatedModel( deploymentRequest.setPrefixType(request.getPrefixType()); deploymentRequest.setNodes(node.v1()); deploymentRequest.setParentTask(parentTaskId); + deploymentRequest.setChunkResults(request.isChunked()); startPos += node.v2(); From c80a32e76fd99c3f8961be6d35696c1fa8a20b0b Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 5 Jun 2024 16:19:37 +0100 Subject: [PATCH 14/30] [ML] Reset retryable index requests after failures (#109320) Fixes the `autoGeneratedTimestamp should not be set externally` error --- docs/changelog/109320.yaml | 5 ++++ .../persistence/ResultsPersisterService.java | 7 +++-- .../ResultsPersisterServiceTests.java | 29 +++++++++++++++++++ 3 files changed, 39 insertions(+), 2 deletions(-) create mode 100644 docs/changelog/109320.yaml diff --git a/docs/changelog/109320.yaml b/docs/changelog/109320.yaml new file mode 100644 index 0000000000000..84aff5b1d769d --- /dev/null +++ b/docs/changelog/109320.yaml @@ -0,0 +1,5 @@ +pr: 109320 +summary: Reset retryable index requests after failures +area: Machine Learning +type: bug +issues: [] diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterService.java index 82d19f9d72273..83572b02f754d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterService.java @@ -325,7 +325,7 @@ private static boolean isIrrecoverable(Exception ex) { } @SuppressWarnings("NonAtomicOperationOnVolatileField") - private static class BulkRequestRewriter { + static class BulkRequestRewriter { private volatile BulkRequest bulkRequest; BulkRequestRewriter(BulkRequest initialRequest) { @@ -533,7 +533,7 @@ public void cancel(Exception e) { } } - private static BulkRequest buildNewRequestFromFailures(BulkRequest bulkRequest, BulkResponse bulkResponse) { + static BulkRequest buildNewRequestFromFailures(BulkRequest bulkRequest, BulkResponse bulkResponse) { // If we failed, lets set the bulkRequest to be a collection of the failed requests BulkRequest bulkRequestOfFailures = new BulkRequest(); Set failedDocIds = Arrays.stream(bulkResponse.getItems()) @@ -542,6 +542,9 @@ private static BulkRequest buildNewRequestFromFailures(BulkRequest bulkRequest, .collect(Collectors.toSet()); bulkRequest.requests().forEach(docWriteRequest -> { if (failedDocIds.contains(docWriteRequest.id())) { + if (docWriteRequest instanceof IndexRequest ir) { + ir.reset(); + } bulkRequestOfFailures.add(docWriteRequest); } }); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterServiceTests.java index 2acf2e3da3cf6..e109f2995d215 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterServiceTests.java @@ -60,6 +60,7 @@ import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; import static org.mockito.ArgumentMatchers.any; @@ -375,6 +376,34 @@ public void testBulkRequestRetriesMsgHandlerIsCalled() { assertThat(lastMessage.get(), containsString("failed to index after [1] attempts. Will attempt again")); } + public void testBuildNewRequestFromFailures_resetsId() { + var bulkRequest = new BulkRequest(); + var indexRequestAutoGeneratedId = new IndexRequest("index-foo"); + indexRequestAutoGeneratedId.autoGenerateId(); + var autoGenId = indexRequestAutoGeneratedId.id(); + var plainIndexRequest = new IndexRequest("index-foo2").id("id-set"); + + bulkRequest.add(indexRequestAutoGeneratedId); + bulkRequest.add(plainIndexRequest); + + var bulkResponse = mock(BulkResponse.class); + + var failed = mock(BulkItemResponse.class); + when(failed.isFailed()).thenReturn(Boolean.TRUE); + when(failed.getId()).thenReturn(autoGenId); + + var sucessful = mock(BulkItemResponse.class); + when(sucessful.isFailed()).thenReturn(Boolean.FALSE); + + when(bulkResponse.getItems()).thenReturn(new BulkItemResponse[] { failed, sucessful }); + + var modifiedRequestForRetry = ResultsPersisterService.buildNewRequestFromFailures(bulkRequest, bulkResponse); + assertThat(modifiedRequestForRetry.requests(), hasSize(1)); // only the failed item is in the new request + assertThat(modifiedRequestForRetry.requests().get(0), instanceOf(IndexRequest.class)); + var ir = (IndexRequest) modifiedRequestForRetry.requests().get(0); + assertEquals(ir.getAutoGeneratedTimestamp(), -1L); // failed request was reset + } + private static Stubber doAnswerWithResponses(Response response1, Response response2) { return doAnswer(withResponse(response1)).doAnswer(withResponse(response2)); } From 80a22ec04678e32e8a501930579a98920cf9df27 Mon Sep 17 00:00:00 2001 From: "Mark J. Hoy" Date: Wed, 5 Jun 2024 11:23:29 -0400 Subject: [PATCH 15/30] [Inference API] Add Docs for Mistral Embedding Support for the Inference API (#109319) * Initial docs for put-inference for Mistral * adds mistral embeddings to tutorial; add changelog * update mistral text and dimensions * fix mistral spelling error * fix azure AI studio; fix Mistral label * fix auto-formatted items * change pipeline button back to azure openai * put proper Azure AI Studio include in * fix missing azure-openai; fix huggingface hidden * fix mistral tab for reindex * re-add Mistral service settings to put inference --- docs/changelog/109194.yaml | 5 ++ .../inference/put-inference.asciidoc | 61 ++++++++++++++++- .../semantic-search-inference.asciidoc | 64 ++++++++---------- .../infer-api-ingest-pipeline-widget.asciidoc | 17 +++++ .../infer-api-ingest-pipeline.asciidoc | 26 ++++++++ .../infer-api-mapping-widget.asciidoc | 20 +++++- .../inference-api/infer-api-mapping.asciidoc | 34 ++++++++++ .../infer-api-reindex-widget.asciidoc | 21 +++++- .../inference-api/infer-api-reindex.asciidoc | 23 +++++++ .../infer-api-requirements-widget.asciidoc | 17 +++++ .../infer-api-requirements.asciidoc | 6 ++ .../infer-api-search-widget.asciidoc | 17 +++++ .../inference-api/infer-api-search.asciidoc | 65 +++++++++++++++++++ .../infer-api-task-widget.asciidoc | 17 +++++ .../inference-api/infer-api-task.asciidoc | 20 ++++++ 15 files changed, 372 insertions(+), 41 deletions(-) create mode 100644 docs/changelog/109194.yaml diff --git a/docs/changelog/109194.yaml b/docs/changelog/109194.yaml new file mode 100644 index 0000000000000..bf50139547f62 --- /dev/null +++ b/docs/changelog/109194.yaml @@ -0,0 +1,5 @@ +pr: 109194 +summary: "[Inference API] Add Mistral Embeddings Support to Inference API" +area: Machine Learning +type: enhancement +issues: [ ] diff --git a/docs/reference/inference/put-inference.asciidoc b/docs/reference/inference/put-inference.asciidoc index e7d66e930e81f..5060e47447f03 100644 --- a/docs/reference/inference/put-inference.asciidoc +++ b/docs/reference/inference/put-inference.asciidoc @@ -7,7 +7,7 @@ experimental[] Creates an {infer} endpoint to perform an {infer} task. IMPORTANT: The {infer} APIs enable you to use certain services, such as built-in -{ml} models (ELSER, E5), models uploaded through Eland, Cohere, OpenAI, Azure OpenAI, Google AI Studio or Hugging Face. +{ml} models (ELSER, E5), models uploaded through Eland, Cohere, OpenAI, Mistral, Azure OpenAI, Google AI Studio or Hugging Face. For built-in models and models uploaded though Eland, the {infer} APIs offer an alternative way to use and manage trained models. However, if you do not plan to use the {infer} APIs to use these models or if you want to use non-NLP models, use the <>. @@ -41,6 +41,7 @@ The following services are available through the {infer} API: * ELSER * Google AI Studio * Hugging Face +* Mistral * OpenAI [discrete] @@ -76,6 +77,7 @@ Available services: * `elser`: specify the `sparse_embedding` task type to use the ELSER service. * `googleaistudio`: specify the `completion` or `text_embeddig` task to use the Google AI Studio service. * `hugging_face`: specify the `text_embedding` task type to use the Hugging Face service. +* `mistral`: specify the `text_embedding` task type to use the Mistral service. * `openai`: specify the `completion` or `text_embedding` task type to use the OpenAI service. @@ -341,6 +343,41 @@ To modify this, set the `requests_per_minute` setting of this object in your ser } ---- +===== ++ +.`service_settings` for the `mistral` service +[%collapsible%closed] +===== + +`api_key`::: +(Required, string) +A valid API key for your Mistral account. +You can find your Mistral API keys or you can create a new one +https://console.mistral.ai/api-keys/[on the API Keys page]. + +`model`::: +(Required, string) +The name of the model to use for the {infer} task. +Refer to the https://docs.mistral.ai/getting-started/models/[Mistral models documentation] +for the list of available text embedding models. + +`max_input_tokens`::: +(Optional, integer) +Allows you to specify the maximum number of tokens per input before chunking occurs. + +`rate_limit`::: +(Optional, object) +By default, the `mistral` service sets the number of requests allowed per minute to `240`. +This helps to minimize the number of rate limit errors returned from the Mistral API. +To modify this, set the `requests_per_minute` setting of this object in your service settings: ++ +[source,text] +---- +"rate_limit": { + "requests_per_minute": <> +} +---- + ===== + .`service_settings` for the `openai` service @@ -777,6 +814,28 @@ PUT _inference/text_embedding/my-msmarco-minilm-model <1> The `model_id` must be the ID of a text embedding model which has already been {ml-docs}/ml-nlp-import-model.html#ml-nlp-import-script[uploaded through Eland]. +[discrete] +[[inference-example-mistral]] +===== Mistral Service + +The following example shows how to create an {infer} endpoint called +`mistral-embeddings-test` to perform a `text_embedding` task type. + +[source,console] +------------------------------------------------------------ +PUT _inference/text_embedding/mistral-embeddings-test +{ + "service": "mistral", + "service_settings": { + "api_key": "", + "model": "mistral-embed" <1> + } +} +------------------------------------------------------------ +// TEST[skip:TBD] +<1> The `model` must be the ID of a text embedding model which can be found in the +https://docs.mistral.ai/getting-started/models/[Mistral models documentation] + [discrete] [[inference-example-openai]] ===== OpenAI service diff --git a/docs/reference/search/search-your-data/semantic-search-inference.asciidoc b/docs/reference/search/search-your-data/semantic-search-inference.asciidoc index 89464d46744b2..e53d895882b4e 100644 --- a/docs/reference/search/search-your-data/semantic-search-inference.asciidoc +++ b/docs/reference/search/search-your-data/semantic-search-inference.asciidoc @@ -1,20 +1,22 @@ [[semantic-search-inference]] === Tutorial: semantic search with the {infer} API + ++++ Semantic search with the {infer} API ++++ -The instructions in this tutorial shows you how to use the {infer} API with -various services to perform semantic search on your data. The following examples -use Cohere's `embed-english-v3.0` model, the `all-mpnet-base-v2` model from -HuggingFace, and OpenAI's `text-embedding-ada-002` second generation embedding -model. You can use any Cohere and OpenAI models, they are all supported by the -{infer} API. For a list of supported models available on HuggingFace, refer to +The instructions in this tutorial shows you how to use the {infer} API with various services to perform semantic search on your data. +The following examples use Cohere's `embed-english-v3.0` model, the `all-mpnet-base-v2` model from HuggingFace, and OpenAI's `text-embedding-ada-002` second generation embedding model. +You can use any Cohere and OpenAI models, they are all supported by the +{infer} API. +For a list of supported models available on HuggingFace, refer to <>. -Click the name of the service you want to use on any of the widgets below to -review the corresponding instructions. +Azure based examples use models available through https://ai.azure.com/explore/models?selectedTask=embeddings[Azure AI Studio] +or https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models[Azure OpenAI]. +Mistral examples use the `mistral-embed` model from https://docs.mistral.ai/getting-started/models/[the Mistral API]. +Click the name of the service you want to use on any of the widgets below to review the corresponding instructions. [discrete] [[infer-service-requirements]] @@ -22,7 +24,6 @@ review the corresponding instructions. include::{es-ref-dir}/tab-widgets/inference-api/infer-api-requirements-widget.asciidoc[] - [discrete] [[infer-text-embedding-task]] ==== Create an inference endpoint @@ -31,49 +32,42 @@ Create an {infer} endpoint by using the <>: include::{es-ref-dir}/tab-widgets/inference-api/infer-api-task-widget.asciidoc[] - [discrete] [[infer-service-mappings]] ==== Create the index mapping -The mapping of the destination index - the index that contains the embeddings -that the model will create based on your input text - must be created. The -destination index must have a field with the <> +The mapping of the destination index - the index that contains the embeddings that the model will create based on your input text - must be created. +The destination index must have a field with the <> field type to index the output of the used model. include::{es-ref-dir}/tab-widgets/inference-api/infer-api-mapping-widget.asciidoc[] - [discrete] [[infer-service-inference-ingest-pipeline]] ==== Create an ingest pipeline with an inference processor Create an <> with an -<> and use the model you created above to -infer against the data that is being ingested in the pipeline. +<> and use the model you created above to infer against the data that is being ingested in the pipeline. include::{es-ref-dir}/tab-widgets/inference-api/infer-api-ingest-pipeline-widget.asciidoc[] - [discrete] [[infer-load-data]] ==== Load data -In this step, you load the data that you later use in the {infer} ingest -pipeline to create embeddings from it. +In this step, you load the data that you later use in the {infer} ingest pipeline to create embeddings from it. -Use the `msmarco-passagetest2019-top1000` data set, which is a subset of the MS -MARCO Passage Ranking data set. It consists of 200 queries, each accompanied by -a list of relevant text passages. All unique passages, along with their IDs, -have been extracted from that data set and compiled into a +Use the `msmarco-passagetest2019-top1000` data set, which is a subset of the MS MARCO Passage Ranking data set. +It consists of 200 queries, each accompanied by a list of relevant text passages. +All unique passages, along with their IDs, have been extracted from that data set and compiled into a https://github.com/elastic/stack-docs/blob/main/docs/en/stack/ml/nlp/data/msmarco-passagetest2019-unique.tsv[tsv file]. Download the file and upload it to your cluster using the {kibana-ref}/connect-to-elasticsearch.html#upload-data-kibana[Data Visualizer] -in the {ml-app} UI. Assign the name `id` to the first column and `content` to -the second column. The index name is `test-data`. Once the upload is complete, -you can see an index named `test-data` with 182469 documents. - +in the {ml-app} UI. +Assign the name `id` to the first column and `content` to the second column. +The index name is `test-data`. +Once the upload is complete, you can see an index named `test-data` with 182469 documents. [discrete] [[reindexing-data-infer]] @@ -92,8 +86,7 @@ GET _tasks/ ---- // TEST[skip:TBD] -You can also cancel the reindexing process if you don't want to wait until the -reindexing process is fully complete which might take hours for large data sets: +You can also cancel the reindexing process if you don't want to wait until the reindexing process is fully complete which might take hours for large data sets: [source,console] ---- @@ -106,17 +99,14 @@ POST _tasks//_cancel [[infer-semantic-search]] ==== Semantic search -After the data set has been enriched with the embeddings, you can query the data -using {ref}/knn-search.html#knn-semantic-search[semantic search]. Pass a -`query_vector_builder` to the k-nearest neighbor (kNN) vector search API, and -provide the query text and the model you have used to create the embeddings. +After the data set has been enriched with the embeddings, you can query the data using {ref}/knn-search.html#knn-semantic-search[semantic search]. +Pass a +`query_vector_builder` to the k-nearest neighbor (kNN) vector search API, and provide the query text and the model you have used to create the embeddings. -NOTE: If you cancelled the reindexing process, you run the query only a part of -the data which affects the quality of your results. +NOTE: If you cancelled the reindexing process, you run the query only a part of the data which affects the quality of your results. include::{es-ref-dir}/tab-widgets/inference-api/infer-api-search-widget.asciidoc[] - [discrete] [[infer-interactive-tutorials]] ==== Interactive tutorials @@ -124,4 +114,4 @@ include::{es-ref-dir}/tab-widgets/inference-api/infer-api-search-widget.asciidoc You can also find tutorials in an interactive Colab notebook format using the {es} Python client: * https://colab.research.google.com/github/elastic/elasticsearch-labs/blob/main/notebooks/integrations/cohere/inference-cohere.ipynb[Cohere {infer} tutorial notebook] -* https://colab.research.google.com/github/elastic/elasticsearch-labs/blob/main/notebooks/search/07-inference.ipynb[OpenAI {infer} tutorial notebook] \ No newline at end of file +* https://colab.research.google.com/github/elastic/elasticsearch-labs/blob/main/notebooks/search/07-inference.ipynb[OpenAI {infer} tutorial notebook] diff --git a/docs/reference/tab-widgets/inference-api/infer-api-ingest-pipeline-widget.asciidoc b/docs/reference/tab-widgets/inference-api/infer-api-ingest-pipeline-widget.asciidoc index 80f6da2cf602a..c8a42c4d0585a 100644 --- a/docs/reference/tab-widgets/inference-api/infer-api-ingest-pipeline-widget.asciidoc +++ b/docs/reference/tab-widgets/inference-api/infer-api-ingest-pipeline-widget.asciidoc @@ -31,6 +31,12 @@ id="infer-api-ingest-azure-ai-studio"> Azure AI Studio +
+
diff --git a/docs/reference/tab-widgets/inference-api/infer-api-ingest-pipeline.asciidoc b/docs/reference/tab-widgets/inference-api/infer-api-ingest-pipeline.asciidoc index 4f85c69c7605e..a239c79e5a6d1 100644 --- a/docs/reference/tab-widgets/inference-api/infer-api-ingest-pipeline.asciidoc +++ b/docs/reference/tab-widgets/inference-api/infer-api-ingest-pipeline.asciidoc @@ -138,3 +138,29 @@ PUT _ingest/pipeline/azure_ai_studio_embeddings and the `output_field` that will contain the {infer} results. // end::azure-ai-studio[] + +// tag::mistral[] + +[source,console] +-------------------------------------------------- +PUT _ingest/pipeline/mistral_embeddings +{ + "processors": [ + { + "inference": { + "model_id": "mistral_embeddings", <1> + "input_output": { <2> + "input_field": "content", + "output_field": "content_embedding" + } + } + } + ] +} +-------------------------------------------------- +<1> The name of the inference endpoint you created by using the +<>, it's referred to as `inference_id` in that step. +<2> Configuration object that defines the `input_field` for the {infer} process +and the `output_field` that will contain the {infer} results. + +// end::mistral[] diff --git a/docs/reference/tab-widgets/inference-api/infer-api-mapping-widget.asciidoc b/docs/reference/tab-widgets/inference-api/infer-api-mapping-widget.asciidoc index f6aa44a2b60a7..80c7c7ef23ee3 100644 --- a/docs/reference/tab-widgets/inference-api/infer-api-mapping-widget.asciidoc +++ b/docs/reference/tab-widgets/inference-api/infer-api-mapping-widget.asciidoc @@ -31,6 +31,12 @@ id="infer-api-mapping-azure-ai-studio"> Azure AI Studio +
+ aria-labelledby="infer-api-mapping-hf" + hidden=""> ++++ include::infer-api-mapping.asciidoc[tag=hugging-face] @@ -83,6 +90,17 @@ include::infer-api-mapping.asciidoc[tag=azure-openai] include::infer-api-mapping.asciidoc[tag=azure-ai-studio] +++++ +
+ diff --git a/docs/reference/tab-widgets/inference-api/infer-api-mapping.asciidoc b/docs/reference/tab-widgets/inference-api/infer-api-mapping.asciidoc index 8038dea713117..a1bce38a02ad2 100644 --- a/docs/reference/tab-widgets/inference-api/infer-api-mapping.asciidoc +++ b/docs/reference/tab-widgets/inference-api/infer-api-mapping.asciidoc @@ -173,3 +173,37 @@ the {infer} pipeline configuration in the next step. <6> The field type which is text in this example. // end::azure-ai-studio[] + +// tag::mistral[] + +[source,console] +-------------------------------------------------- +PUT mistral-embeddings +{ + "mappings": { + "properties": { + "content_embedding": { <1> + "type": "dense_vector", <2> + "dims": 1024, <3> + "element_type": "float", + "similarity": "dot_product" <4> + }, + "content": { <5> + "type": "text" <6> + } + } + } +} +-------------------------------------------------- +<1> The name of the field to contain the generated tokens. It must be referenced +in the {infer} pipeline configuration in the next step. +<2> The field to contain the tokens is a `dense_vector` field. +<3> The output dimensions of the model. This value may be found on the https://docs.mistral.ai/getting-started/models/[Mistral model reference]. +<4> For Mistral embeddings, the `dot_product` function should be used to +calculate similarity. +<5> The name of the field from which to create the dense vector representation. +In this example, the name of the field is `content`. It must be referenced in +the {infer} pipeline configuration in the next step. +<6> The field type which is text in this example. + +// end::mistral[] diff --git a/docs/reference/tab-widgets/inference-api/infer-api-reindex-widget.asciidoc b/docs/reference/tab-widgets/inference-api/infer-api-reindex-widget.asciidoc index a35ea4e3b0207..4face6a105819 100644 --- a/docs/reference/tab-widgets/inference-api/infer-api-reindex-widget.asciidoc +++ b/docs/reference/tab-widgets/inference-api/infer-api-reindex-widget.asciidoc @@ -27,10 +27,16 @@ +
+
diff --git a/docs/reference/tab-widgets/inference-api/infer-api-reindex.asciidoc b/docs/reference/tab-widgets/inference-api/infer-api-reindex.asciidoc index a862e864fb068..927e47ea4d67c 100644 --- a/docs/reference/tab-widgets/inference-api/infer-api-reindex.asciidoc +++ b/docs/reference/tab-widgets/inference-api/infer-api-reindex.asciidoc @@ -131,3 +131,26 @@ might affect the throughput of the reindexing process. If this happens, change `size` to `3` or a similar value in magnitude. // end::azure-ai-studio[] + +// tag::mistral[] + +[source,console] +---- +POST _reindex?wait_for_completion=false +{ + "source": { + "index": "test-data", + "size": 50 <1> + }, + "dest": { + "index": "mistral-embeddings", + "pipeline": "mistral_embeddings" + } +} +---- +// TEST[skip:TBD] +<1> The default batch size for reindexing is 1000. Reducing `size` to a smaller +number makes the update of the reindexing process quicker which enables you to +follow the progress closely and detect errors early. + +// end::mistral[] diff --git a/docs/reference/tab-widgets/inference-api/infer-api-requirements-widget.asciidoc b/docs/reference/tab-widgets/inference-api/infer-api-requirements-widget.asciidoc index 85b15678d1681..9981eb90d4929 100644 --- a/docs/reference/tab-widgets/inference-api/infer-api-requirements-widget.asciidoc +++ b/docs/reference/tab-widgets/inference-api/infer-api-requirements-widget.asciidoc @@ -31,6 +31,12 @@ id="infer-api-requirements-azure-ai-studio"> Azure AI Studio +
+
diff --git a/docs/reference/tab-widgets/inference-api/infer-api-requirements.asciidoc b/docs/reference/tab-widgets/inference-api/infer-api-requirements.asciidoc index 3ffcc6e4dd2b1..435e53bbc0bc0 100644 --- a/docs/reference/tab-widgets/inference-api/infer-api-requirements.asciidoc +++ b/docs/reference/tab-widgets/inference-api/infer-api-requirements.asciidoc @@ -33,3 +33,9 @@ You can apply for access to Azure OpenAI by completing the form at https://aka.m * A deployed https://ai.azure.com/explore/models?selectedTask=embeddings[embeddings] or https://ai.azure.com/explore/models?selectedTask=chat-completion[chat completion] model. // end::azure-ai-studio[] + +// tag::mistral[] +* A Mistral Account on https://console.mistral.ai/[La Plateforme] +* An API key generated for your account + +// end::mistral[] diff --git a/docs/reference/tab-widgets/inference-api/infer-api-search-widget.asciidoc b/docs/reference/tab-widgets/inference-api/infer-api-search-widget.asciidoc index 17b747e86be4a..6a67b28f91601 100644 --- a/docs/reference/tab-widgets/inference-api/infer-api-search-widget.asciidoc +++ b/docs/reference/tab-widgets/inference-api/infer-api-search-widget.asciidoc @@ -31,6 +31,12 @@ id="infer-api-search-azure-ai-studio"> Azure AI Studio +
+
diff --git a/docs/reference/tab-widgets/inference-api/infer-api-search.asciidoc b/docs/reference/tab-widgets/inference-api/infer-api-search.asciidoc index 4f1a24959de9f..523c2301e75ff 100644 --- a/docs/reference/tab-widgets/inference-api/infer-api-search.asciidoc +++ b/docs/reference/tab-widgets/inference-api/infer-api-search.asciidoc @@ -340,3 +340,68 @@ query from the `azure-ai-studio-embeddings` index sorted by their proximity to t // NOTCONSOLE // end::azure-ai-studio[] + +// tag::mistral[] + +[source,console] +-------------------------------------------------- +GET mistral-embeddings/_search +{ + "knn": { + "field": "content_embedding", + "query_vector_builder": { + "text_embedding": { + "model_id": "mistral_embeddings", + "model_text": "Calculate fuel cost" + } + }, + "k": 10, + "num_candidates": 100 + }, + "_source": [ + "id", + "content" + ] +} +-------------------------------------------------- +// TEST[skip:TBD] + +As a result, you receive the top 10 documents that are closest in meaning to the +query from the `mistral-embeddings` index sorted by their proximity to the query: + +[source,consol-result] +-------------------------------------------------- +"hits": [ + { + "_index": "mistral-embeddings", + "_id": "DDd5OowBHxQKHyc3TDSC", + "_score": 0.83704096, + "_source": { + "id": 862114, + "body": "How to calculate fuel cost for a road trip. By Tara Baukus Mello • Bankrate.com. Dear Driving for Dollars, My family is considering taking a long road trip to finish off the end of the summer, but I'm a little worried about gas prices and our overall fuel cost.It doesn't seem easy to calculate since we'll be traveling through many states and we are considering several routes.y family is considering taking a long road trip to finish off the end of the summer, but I'm a little worried about gas prices and our overall fuel cost. It doesn't seem easy to calculate since we'll be traveling through many states and we are considering several routes." + } + }, + { + "_index": "mistral-embeddings", + "_id": "ajd5OowBHxQKHyc3TDSC", + "_score": 0.8345704, + "_source": { + "id": 820622, + "body": "Home Heating Calculator. Typically, approximately 50% of the energy consumed in a home annually is for space heating. When deciding on a heating system, many factors will come into play: cost of fuel, installation cost, convenience and life style are all important.This calculator can help you estimate the cost of fuel for different heating appliances.hen deciding on a heating system, many factors will come into play: cost of fuel, installation cost, convenience and life style are all important. This calculator can help you estimate the cost of fuel for different heating appliances." + } + }, + { + "_index": "mistral-embeddings", + "_id": "Djd5OowBHxQKHyc3TDSC", + "_score": 0.8327426, + "_source": { + "id": 8202683, + "body": "Fuel is another important cost. This cost will depend on your boat, how far you travel, and how fast you travel. A 33-foot sailboat traveling at 7 knots should be able to travel 300 miles on 50 gallons of diesel fuel.If you are paying $4 per gallon, the trip would cost you $200.Most boats have much larger gas tanks than cars.uel is another important cost. This cost will depend on your boat, how far you travel, and how fast you travel. A 33-foot sailboat traveling at 7 knots should be able to travel 300 miles on 50 gallons of diesel fuel." + } + }, + (...) + ] +-------------------------------------------------- +// NOTCONSOLE + +// end::mistral[] diff --git a/docs/reference/tab-widgets/inference-api/infer-api-task-widget.asciidoc b/docs/reference/tab-widgets/inference-api/infer-api-task-widget.asciidoc index 3bccb140d44f6..1f3ad645d7c29 100644 --- a/docs/reference/tab-widgets/inference-api/infer-api-task-widget.asciidoc +++ b/docs/reference/tab-widgets/inference-api/infer-api-task-widget.asciidoc @@ -31,6 +31,12 @@ id="infer-api-task-azure-ai-studio"> Azure AI Studio +
+
diff --git a/docs/reference/tab-widgets/inference-api/infer-api-task.asciidoc b/docs/reference/tab-widgets/inference-api/infer-api-task.asciidoc index 5692388a18531..18fa3ba541bff 100644 --- a/docs/reference/tab-widgets/inference-api/infer-api-task.asciidoc +++ b/docs/reference/tab-widgets/inference-api/infer-api-task.asciidoc @@ -157,3 +157,23 @@ Also, when using this model the recommended similarity measure to use in the `dense_vector` field mapping is `dot_product`. // end::azure-ai-studio[] + +// tag::mistral[] + +[source,console] +------------------------------------------------------------ +PUT _inference/text_embedding/mistral_embeddings <1> +{ + "service": "mistral", + "service_settings": { + "api_key": "", <2> + "model": "" <3> + } +} +------------------------------------------------------------ +// TEST[skip:TBD] +<1> The task type is `text_embedding` in the path and the `inference_id` which is the unique identifier of the {infer} endpoint is `mistral_embeddings`. +<2> The API key for accessing the Mistral API. You can find this in your Mistral account's API Keys page. +<3> The Mistral embeddings model name, for example `mistral-embed`. + +// end::mistral[] From ac6c0eecc1f2cb75a6d040306b6373d666f64577 Mon Sep 17 00:00:00 2001 From: Martijn van Groningen Date: Wed, 5 Jun 2024 17:32:14 +0200 Subject: [PATCH 16/30] Ensure synthetic source and dv codec are enabled with logs index mode (attempt 2). (#109382) This was initially muted via #109365, because of a failing newly introduced assert. Original PR #109269 --- .../test/aggregations/time_series.yml | 6 +-- .../rest-api-spec/test/logsdb/10_settings.yml | 4 ++ .../cluster/metadata/IndexMetadata.java | 7 +-- .../index/mapper/SourceFieldMapper.java | 43 +++++++++++++---- .../index/LogsIndexModeTests.java | 31 +++++++----- .../index/codec/PerFieldMapperCodecTests.java | 47 +++++++++++++++---- .../index/mapper/SourceFieldMapperTests.java | 7 +++ .../index/mapper/MapperServiceTestCase.java | 5 ++ 8 files changed, 116 insertions(+), 34 deletions(-) diff --git a/modules/aggregations/src/yamlRestTest/resources/rest-api-spec/test/aggregations/time_series.yml b/modules/aggregations/src/yamlRestTest/resources/rest-api-spec/test/aggregations/time_series.yml index 7800923ab1580..1703d4908a753 100644 --- a/modules/aggregations/src/yamlRestTest/resources/rest-api-spec/test/aggregations/time_series.yml +++ b/modules/aggregations/src/yamlRestTest/resources/rest-api-spec/test/aggregations/time_series.yml @@ -294,11 +294,11 @@ setup: --- "Configure with no synthetic source": - requires: - cluster_features: ["gte_v8.9.0"] - reason: "Error message fix in 8.9" + cluster_features: ["gte_v8.15.0"] + reason: "Error message changed in 8.15.0" - do: - catch: '/Time series indices only support synthetic source./' + catch: '/Indices with with index mode \[time_series\] only support synthetic source/' indices.create: index: tsdb_error body: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/logsdb/10_settings.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/logsdb/10_settings.yml index 95075da20fe5e..128903f4faac8 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/logsdb/10_settings.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/logsdb/10_settings.yml @@ -77,6 +77,10 @@ create logs index: - is_true: test - match: { test.settings.index.mode: "logs" } + - do: + indices.get_mapping: + index: test + - match: { test.mappings._source.mode: synthetic } --- using default timestamp field mapping: diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 681dcb3e314e3..64809c963cb6d 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -2267,8 +2267,9 @@ IndexMetadata build(boolean repair) { } final boolean isSearchableSnapshot = SearchableSnapshotsSettings.isSearchableSnapshotStore(settings); - final String indexMode = settings.get(IndexSettings.MODE.getKey()); - final boolean isTsdb = indexMode != null && IndexMode.TIME_SERIES.getName().equals(indexMode.toLowerCase(Locale.ROOT)); + String indexModeString = settings.get(IndexSettings.MODE.getKey()); + final IndexMode indexMode = indexModeString != null ? IndexMode.fromString(indexModeString.toLowerCase(Locale.ROOT)) : null; + final boolean isTsdb = indexMode == IndexMode.TIME_SERIES; return new IndexMetadata( new Index(index, uuid), version, @@ -2308,7 +2309,7 @@ IndexMetadata build(boolean repair) { AutoExpandReplicas.SETTING.get(settings), isSearchableSnapshot, isSearchableSnapshot && settings.getAsBoolean(SEARCHABLE_SNAPSHOT_PARTIAL_SETTING_KEY, false), - isTsdb ? IndexMode.TIME_SERIES : null, + indexMode, isTsdb ? IndexSettings.TIME_SERIES_START_TIME.get(settings) : null, isTsdb ? IndexSettings.TIME_SERIES_END_TIME.get(settings) : null, SETTING_INDEX_VERSION_COMPATIBILITY.get(settings), diff --git a/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java index b819ffb0ef6ad..d6a15ff9ec47a 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java @@ -69,6 +69,14 @@ private enum Mode { IndexMode.TIME_SERIES ); + private static final SourceFieldMapper LOGS_DEFAULT = new SourceFieldMapper( + Mode.SYNTHETIC, + Explicit.IMPLICIT_TRUE, + Strings.EMPTY_ARRAY, + Strings.EMPTY_ARRAY, + IndexMode.LOGS + ); + /* * Synthetic source was added as the default for TSDB in v.8.7. The legacy field mapper below * is used in bwc tests and mixed clusters containing time series indexes created in an earlier version. @@ -156,7 +164,8 @@ protected Parameter[] getParameters() { private boolean isDefault() { Mode m = mode.get(); - if (m != null && (((indexMode == IndexMode.TIME_SERIES && m == Mode.SYNTHETIC) == false) || m == Mode.DISABLED)) { + if (m != null + && (((indexMode != null && indexMode.isSyntheticSourceEnabled() && m == Mode.SYNTHETIC) == false) || m == Mode.DISABLED)) { return false; } return enabled.get().value() && includes.getValue().isEmpty() && excludes.getValue().isEmpty(); @@ -165,15 +174,19 @@ private boolean isDefault() { @Override public SourceFieldMapper build() { if (enabled.getValue().explicit()) { - if (indexMode == IndexMode.TIME_SERIES) { - throw new MapperParsingException("Time series indices only support synthetic source"); + if (indexMode != null && indexMode.isSyntheticSourceEnabled()) { + throw new MapperParsingException("Indices with with index mode [" + indexMode + "] only support synthetic source"); } if (mode.get() != null) { throw new MapperParsingException("Cannot set both [mode] and [enabled] parameters"); } } if (isDefault()) { - return indexMode == IndexMode.TIME_SERIES ? TSDB_DEFAULT : DEFAULT; + return switch (indexMode) { + case TIME_SERIES -> TSDB_DEFAULT; + case LOGS -> LOGS_DEFAULT; + default -> DEFAULT; + }; } if (supportsNonDefaultParameterValues == false) { List disallowed = new ArrayList<>(); @@ -212,10 +225,21 @@ public SourceFieldMapper build() { } - public static final TypeParser PARSER = new ConfigurableTypeParser( - c -> c.getIndexSettings().getMode() == IndexMode.TIME_SERIES - ? c.getIndexSettings().getIndexVersionCreated().onOrAfter(IndexVersions.V_8_7_0) ? TSDB_DEFAULT : TSDB_LEGACY_DEFAULT - : DEFAULT, + public static final TypeParser PARSER = new ConfigurableTypeParser(c -> { + var indexMode = c.getIndexSettings().getMode(); + if (indexMode.isSyntheticSourceEnabled()) { + if (indexMode == IndexMode.TIME_SERIES) { + if (c.getIndexSettings().getIndexVersionCreated().onOrAfter(IndexVersions.V_8_7_0)) { + return TSDB_DEFAULT; + } else { + return TSDB_LEGACY_DEFAULT; + } + } else if (indexMode == IndexMode.LOGS) { + return LOGS_DEFAULT; + } + } + return DEFAULT; + }, c -> new Builder( c.getIndexSettings().getMode(), c.getSettings(), @@ -323,6 +347,9 @@ public void preParse(DocumentParserContext context) throws IOException { final BytesReference adaptedSource = applyFilters(originalSource, contentType); if (adaptedSource != null) { + assert context.indexSettings().getIndexVersionCreated().before(IndexVersions.V_8_7_0) + || indexMode == null + || indexMode.isSyntheticSourceEnabled() == false; final BytesRef ref = adaptedSource.toBytesRef(); context.doc().add(new StoredField(fieldType().name(), ref.bytes, ref.offset, ref.length)); } diff --git a/server/src/test/java/org/elasticsearch/index/LogsIndexModeTests.java b/server/src/test/java/org/elasticsearch/index/LogsIndexModeTests.java index fd73a8c9f8f52..caddc7d5ea5af 100644 --- a/server/src/test/java/org/elasticsearch/index/LogsIndexModeTests.java +++ b/server/src/test/java/org/elasticsearch/index/LogsIndexModeTests.java @@ -10,12 +10,13 @@ import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.index.mapper.MapperServiceTestCase; -import org.hamcrest.Matchers; +import org.elasticsearch.test.ESTestCase; -public class LogsIndexModeTests extends MapperServiceTestCase { +import static org.hamcrest.Matchers.equalTo; + +public class LogsIndexModeTests extends ESTestCase { public void testLogsIndexModeSetting() { - assertThat(IndexSettings.MODE.get(buildSettings()), Matchers.equalTo(IndexMode.LOGS)); + assertThat(IndexSettings.MODE.get(buildSettings()), equalTo(IndexMode.LOGS)); } public void testSortField() { @@ -24,8 +25,10 @@ public void testSortField() { .put(IndexSortConfig.INDEX_SORT_FIELD_SETTING.getKey(), "agent_id") .build(); final IndexMetadata metadata = IndexSettingsTests.newIndexMeta("test", sortSettings); + assertThat(metadata.getIndexMode(), equalTo(IndexMode.LOGS)); final IndexSettings settings = new IndexSettings(metadata, Settings.EMPTY); - assertThat("agent_id", Matchers.equalTo(getIndexSetting(settings, IndexSortConfig.INDEX_SORT_FIELD_SETTING.getKey()))); + assertThat(settings.getMode(), equalTo(IndexMode.LOGS)); + assertThat("agent_id", equalTo(getIndexSetting(settings, IndexSortConfig.INDEX_SORT_FIELD_SETTING.getKey()))); } public void testSortMode() { @@ -35,9 +38,11 @@ public void testSortMode() { .put(IndexSortConfig.INDEX_SORT_MODE_SETTING.getKey(), "max") .build(); final IndexMetadata metadata = IndexSettingsTests.newIndexMeta("test", sortSettings); + assertThat(metadata.getIndexMode(), equalTo(IndexMode.LOGS)); final IndexSettings settings = new IndexSettings(metadata, Settings.EMPTY); - assertThat("agent_id", Matchers.equalTo(getIndexSetting(settings, IndexSortConfig.INDEX_SORT_FIELD_SETTING.getKey()))); - assertThat("max", Matchers.equalTo(getIndexSetting(settings, IndexSortConfig.INDEX_SORT_MODE_SETTING.getKey()))); + assertThat(settings.getMode(), equalTo(IndexMode.LOGS)); + assertThat("agent_id", equalTo(getIndexSetting(settings, IndexSortConfig.INDEX_SORT_FIELD_SETTING.getKey()))); + assertThat("max", equalTo(getIndexSetting(settings, IndexSortConfig.INDEX_SORT_MODE_SETTING.getKey()))); } public void testSortOrder() { @@ -47,9 +52,11 @@ public void testSortOrder() { .put(IndexSortConfig.INDEX_SORT_ORDER_SETTING.getKey(), "desc") .build(); final IndexMetadata metadata = IndexSettingsTests.newIndexMeta("test", sortSettings); + assertThat(metadata.getIndexMode(), equalTo(IndexMode.LOGS)); final IndexSettings settings = new IndexSettings(metadata, Settings.EMPTY); - assertThat("agent_id", Matchers.equalTo(getIndexSetting(settings, IndexSortConfig.INDEX_SORT_FIELD_SETTING.getKey()))); - assertThat("desc", Matchers.equalTo(getIndexSetting(settings, IndexSortConfig.INDEX_SORT_ORDER_SETTING.getKey()))); + assertThat(settings.getMode(), equalTo(IndexMode.LOGS)); + assertThat("agent_id", equalTo(getIndexSetting(settings, IndexSortConfig.INDEX_SORT_FIELD_SETTING.getKey()))); + assertThat("desc", equalTo(getIndexSetting(settings, IndexSortConfig.INDEX_SORT_ORDER_SETTING.getKey()))); } public void testSortMissing() { @@ -59,9 +66,11 @@ public void testSortMissing() { .put(IndexSortConfig.INDEX_SORT_MISSING_SETTING.getKey(), "_last") .build(); final IndexMetadata metadata = IndexSettingsTests.newIndexMeta("test", sortSettings); + assertThat(metadata.getIndexMode(), equalTo(IndexMode.LOGS)); final IndexSettings settings = new IndexSettings(metadata, Settings.EMPTY); - assertThat("agent_id", Matchers.equalTo(getIndexSetting(settings, IndexSortConfig.INDEX_SORT_FIELD_SETTING.getKey()))); - assertThat("_last", Matchers.equalTo(getIndexSetting(settings, IndexSortConfig.INDEX_SORT_MISSING_SETTING.getKey()))); + assertThat(settings.getMode(), equalTo(IndexMode.LOGS)); + assertThat("agent_id", equalTo(getIndexSetting(settings, IndexSortConfig.INDEX_SORT_FIELD_SETTING.getKey()))); + assertThat("_last", equalTo(getIndexSetting(settings, IndexSortConfig.INDEX_SORT_MISSING_SETTING.getKey()))); } private Settings buildSettings() { diff --git a/server/src/test/java/org/elasticsearch/index/codec/PerFieldMapperCodecTests.java b/server/src/test/java/org/elasticsearch/index/codec/PerFieldMapperCodecTests.java index 74657842488b5..525fa31673494 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/PerFieldMapperCodecTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/PerFieldMapperCodecTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.compress.CompressedXContent; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.index.IndexMode; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.MapperTestUtils; import org.elasticsearch.index.codec.bloomfilter.ES87BloomFilterPostingsFormat; @@ -61,6 +62,28 @@ public class PerFieldMapperCodecTests extends ESTestCase { } """; + private static final String MAPPING_3 = """ + { + "_data_stream_timestamp": { + "enabled": true + }, + "properties": { + "@timestamp": { + "type": "date" + }, + "hostname": { + "type": "keyword" + }, + "response_size": { + "type": "long" + }, + "message": { + "type": "text" + } + } + } + """; + public void testUseBloomFilter() throws IOException { PerFieldFormatSupplier perFieldMapperCodec = createFormatSupplier(false, randomBoolean(), false); assertThat(perFieldMapperCodec.useBloomFilter("_id"), is(true)); @@ -103,13 +126,13 @@ public void testDoNotUseES87TSDBEncodingForTimestampFieldNonTimeSeriesIndex() th } public void testEnableES87TSDBCodec() throws IOException { - PerFieldFormatSupplier perFieldMapperCodec = createFormatSupplier(true, true, MAPPING_1); + PerFieldFormatSupplier perFieldMapperCodec = createFormatSupplier(true, IndexMode.TIME_SERIES, MAPPING_1); assertThat((perFieldMapperCodec.useTSDBDocValuesFormat("gauge")), is(true)); assertThat((perFieldMapperCodec.useTSDBDocValuesFormat("@timestamp")), is(true)); } public void testDisableES87TSDBCodec() throws IOException { - PerFieldFormatSupplier perFieldMapperCodec = createFormatSupplier(false, true, MAPPING_1); + PerFieldFormatSupplier perFieldMapperCodec = createFormatSupplier(false, IndexMode.TIME_SERIES, MAPPING_1); assertThat((perFieldMapperCodec.useTSDBDocValuesFormat("gauge")), is(false)); assertThat((perFieldMapperCodec.useTSDBDocValuesFormat("@timestamp")), is(false)); } @@ -144,31 +167,37 @@ private PerFieldFormatSupplier createFormatSupplier(boolean timestampField, bool } public void testUseES87TSDBEncodingSettingDisabled() throws IOException { - PerFieldFormatSupplier perFieldMapperCodec = createFormatSupplier(false, true, MAPPING_2); + PerFieldFormatSupplier perFieldMapperCodec = createFormatSupplier(false, IndexMode.TIME_SERIES, MAPPING_2); assertThat((perFieldMapperCodec.useTSDBDocValuesFormat("@timestamp")), is(false)); assertThat((perFieldMapperCodec.useTSDBDocValuesFormat("counter")), is(false)); assertThat((perFieldMapperCodec.useTSDBDocValuesFormat("gauge")), is(false)); } public void testUseTimeSeriesModeDisabledCodecDisabled() throws IOException { - PerFieldFormatSupplier perFieldMapperCodec = createFormatSupplier(true, false, MAPPING_2); + PerFieldFormatSupplier perFieldMapperCodec = createFormatSupplier(true, IndexMode.STANDARD, MAPPING_2); assertThat((perFieldMapperCodec.useTSDBDocValuesFormat("@timestamp")), is(false)); assertThat((perFieldMapperCodec.useTSDBDocValuesFormat("counter")), is(false)); assertThat((perFieldMapperCodec.useTSDBDocValuesFormat("gauge")), is(false)); } public void testUseTimeSeriesModeAndCodecEnabled() throws IOException { - PerFieldFormatSupplier perFieldMapperCodec = createFormatSupplier(true, true, MAPPING_2); + PerFieldFormatSupplier perFieldMapperCodec = createFormatSupplier(true, IndexMode.TIME_SERIES, MAPPING_2); assertThat((perFieldMapperCodec.useTSDBDocValuesFormat("@timestamp")), is(true)); assertThat((perFieldMapperCodec.useTSDBDocValuesFormat("counter")), is(true)); assertThat((perFieldMapperCodec.useTSDBDocValuesFormat("gauge")), is(true)); } - private PerFieldFormatSupplier createFormatSupplier(boolean enableES87TSDBCodec, boolean timeSeries, String mapping) - throws IOException { + public void testLogsIndexMode() throws IOException { + PerFieldFormatSupplier perFieldMapperCodec = createFormatSupplier(true, IndexMode.LOGS, MAPPING_3); + assertThat((perFieldMapperCodec.useTSDBDocValuesFormat("@timestamp")), is(true)); + assertThat((perFieldMapperCodec.useTSDBDocValuesFormat("hostname")), is(true)); + assertThat((perFieldMapperCodec.useTSDBDocValuesFormat("response_size")), is(true)); + } + + private PerFieldFormatSupplier createFormatSupplier(boolean enableES87TSDBCodec, IndexMode mode, String mapping) throws IOException { Settings.Builder settings = Settings.builder(); - if (timeSeries) { - settings.put(IndexSettings.MODE.getKey(), "time_series"); + settings.put(IndexSettings.MODE.getKey(), mode); + if (mode == IndexMode.TIME_SERIES) { settings.put(IndexMetadata.INDEX_ROUTING_PATH.getKey(), "field"); } settings.put(IndexSettings.TIME_SERIES_ES87TSDB_CODEC_ENABLED_SETTING.getKey(), enableES87TSDBCodec); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/SourceFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/SourceFieldMapperTests.java index 802a18645eab6..d0350c1d92a83 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/SourceFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/SourceFieldMapperTests.java @@ -243,6 +243,13 @@ public void testSyntheticSourceInTimeSeries() throws IOException { assertEquals("{\"_source\":{\"mode\":\"synthetic\"}}", mapper.sourceMapper().toString()); } + public void testSyntheticSourceWithLogsIndexMode() throws IOException { + XContentBuilder mapping = fieldMapping(b -> { b.field("type", "keyword"); }); + DocumentMapper mapper = createLogsModeDocumentMapper(mapping); + assertTrue(mapper.sourceMapper().isSynthetic()); + assertEquals("{\"_source\":{\"mode\":\"synthetic\"}}", mapper.sourceMapper().toString()); + } + public void testSupportsNonDefaultParameterValues() throws IOException { Settings settings = Settings.builder().put(SourceFieldMapper.LOSSY_PARAMETERS_ALLOWED_SETTING_NAME, false).build(); { diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java index 388d8d6fa6ffd..dfd4a59e2c3a1 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java @@ -148,6 +148,11 @@ protected final DocumentMapper createTimeSeriesModeDocumentMapper(XContentBuilde return createMapperService(settings, mappings).documentMapper(); } + protected final DocumentMapper createLogsModeDocumentMapper(XContentBuilder mappings) throws IOException { + Settings settings = Settings.builder().put(IndexSettings.MODE.getKey(), "logs").build(); + return createMapperService(settings, mappings).documentMapper(); + } + protected final DocumentMapper createDocumentMapper(IndexVersion version, XContentBuilder mappings) throws IOException { return createMapperService(version, mappings).documentMapper(); } From 17c6230b9f314c6abeba88162981cfd458abbd09 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine <58790826+elasticsearchmachine@users.noreply.github.com> Date: Thu, 6 Jun 2024 01:32:17 +1000 Subject: [PATCH 17/30] Remove 8.13 from branches.json --- branches.json | 3 --- 1 file changed, 3 deletions(-) diff --git a/branches.json b/branches.json index daf6d249f7268..2794b545facc6 100644 --- a/branches.json +++ b/branches.json @@ -7,9 +7,6 @@ { "branch": "8.14" }, - { - "branch": "8.13" - }, { "branch": "7.17" } From 307cac564802960a61a5626f1a40f59301d82e46 Mon Sep 17 00:00:00 2001 From: Nik Everett Date: Wed, 5 Jun 2024 14:08:44 -0400 Subject: [PATCH 18/30] ESQL: Move `NamedExpression` serialization (#109380) This moves the serialization for the remaining `NamedExpression` subclass into the class itself, and switches all direct serialization of `NamedExpression`s to `readNamedWriteable` and friends. All other `NamedExpression` subclasses extend from `Attribute` who's serialization was moved ealier. They are already registered under the "category class" for `Attribute`. This also registers them as `NamedExpression`s. --- .../xpack/esql/core/expression/Alias.java | 50 ++++++- .../xpack/esql/core/expression/Attribute.java | 3 +- .../esql/core/expression/NamedExpression.java | 13 +- .../esql/core/expression/UnresolvedAlias.java | 78 ----------- .../esql/core/expression/UnresolvedStar.java | 12 ++ .../xpack/esql/core/util/PlanStreamInput.java | 8 ++ .../esql/core/util/PlanStreamOutput.java | 27 ++++ .../esql/enrich/EnrichLookupService.java | 7 +- .../esql/expression/NamedExpressions.java | 1 - .../expression/UnresolvedNamePattern.java | 12 ++ .../function/UnsupportedAttribute.java | 6 + .../xpack/esql/io/stream/PlanNamedTypes.java | 122 +++++++----------- .../xpack/esql/io/stream/PlanStreamInput.java | 6 +- .../esql/io/stream/PlanStreamOutput.java | 8 +- .../xpack/esql/plugin/EsqlPlugin.java | 3 + .../xpack/esql/SerializationTestUtils.java | 3 + .../xpack/esql/expression/AliasTests.java | 86 ++++++++++++ .../function/ReferenceAttributeTests.java | 8 +- .../esql/io/stream/PlanNamedTypesTests.java | 11 -- 19 files changed, 276 insertions(+), 188 deletions(-) delete mode 100644 x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/UnresolvedAlias.java create mode 100644 x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/PlanStreamOutput.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/AliasTests.java diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Alias.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Alias.java index 58203c8a0072e..d9f99b6d92318 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Alias.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Alias.java @@ -6,11 +6,18 @@ */ package org.elasticsearch.xpack.esql.core.expression; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.util.PlanStreamInput; +import org.elasticsearch.xpack.esql.core.util.PlanStreamOutput; +import java.io.IOException; import java.util.List; +import java.util.Objects; import static java.util.Collections.singletonList; @@ -22,7 +29,8 @@ * And in {@code SELECT col AS x} "col" is a named expression that gets renamed to "x" through an alias. * */ -public class Alias extends NamedExpression { +public final class Alias extends NamedExpression { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(NamedExpression.class, "Alias", Alias::new); private final Expression child; private final String qualifier; @@ -51,6 +59,32 @@ public Alias(Source source, String name, String qualifier, Expression child, Nam this.qualifier = qualifier; } + public Alias(StreamInput in) throws IOException { + this( + Source.readFrom((StreamInput & PlanStreamInput) in), + in.readString(), + in.readOptionalString(), + ((PlanStreamInput) in).readExpression(), + NameId.readFrom((StreamInput & PlanStreamInput) in), + in.readBoolean() + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + Source.EMPTY.writeTo(out); + out.writeString(name()); + out.writeOptionalString(qualifier()); + ((PlanStreamOutput) out).writeExpression(child()); + id().writeTo(out); + out.writeBoolean(synthetic()); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override protected NodeInfo info() { return NodeInfo.create(this, Alias::new, name(), qualifier, child, id(), synthetic()); @@ -113,4 +147,18 @@ public String nodeString() { public static Expression unwrap(Expression e) { return e instanceof Alias as ? as.child() : e; } + + @Override + public boolean equals(Object obj) { + if (super.equals(obj) == false) { + return false; + } + Alias other = (Alias) obj; + return Objects.equals(qualifier, other.qualifier); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), qualifier); + } } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Attribute.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Attribute.java index 5326825ec1105..e89f39294a28b 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Attribute.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Attribute.java @@ -6,7 +6,6 @@ */ package org.elasticsearch.xpack.esql.core.expression; -import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.core.Tuple; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -30,7 +29,7 @@ * is a named expression (an {@code Alias} will be created automatically for it). * The rest are not as they are not part of the projection and thus are not part of the derived table. */ -public abstract class Attribute extends NamedExpression implements NamedWriteable { +public abstract class Attribute extends NamedExpression { public static List getNamedWriteables() { // TODO add UnsupportedAttribute when these are moved to the same project return List.of(FieldAttribute.ENTRY, MetadataAttribute.ENTRY, ReferenceAttribute.ENTRY); diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/NamedExpression.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/NamedExpression.java index 4a3666c8b8aa7..e3e9a60180da7 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/NamedExpression.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/NamedExpression.java @@ -6,8 +6,11 @@ */ package org.elasticsearch.xpack.esql.core.expression; +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.xpack.esql.core.tree.Source; +import java.util.ArrayList; import java.util.List; import java.util.Objects; @@ -15,7 +18,15 @@ * An expression that has a name. Named expressions can be used as a result * (by converting to an attribute). */ -public abstract class NamedExpression extends Expression { +public abstract class NamedExpression extends Expression implements NamedWriteable { + public static List getNamedWriteables() { + List entries = new ArrayList<>(); + for (NamedWriteableRegistry.Entry e : Attribute.getNamedWriteables()) { + entries.add(new NamedWriteableRegistry.Entry(NamedExpression.class, e.name, in -> (NamedExpression) e.reader.read(in))); + } + entries.add(Alias.ENTRY); + return entries; + } private final String name; private final NameId id; diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/UnresolvedAlias.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/UnresolvedAlias.java deleted file mode 100644 index a4b0d06f54b83..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/UnresolvedAlias.java +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -package org.elasticsearch.xpack.esql.core.expression; - -import org.elasticsearch.xpack.esql.core.capabilities.UnresolvedException; -import org.elasticsearch.xpack.esql.core.tree.NodeInfo; -import org.elasticsearch.xpack.esql.core.tree.Source; - -import java.util.List; -import java.util.Objects; - -import static java.util.Collections.singletonList; - -public class UnresolvedAlias extends UnresolvedNamedExpression { - - private final Expression child; - - public UnresolvedAlias(Source source, Expression child) { - super(source, singletonList(child)); - this.child = child; - } - - @Override - protected NodeInfo info() { - return NodeInfo.create(this, UnresolvedAlias::new, child); - } - - @Override - public Expression replaceChildren(List newChildren) { - return new UnresolvedAlias(source(), newChildren.get(0)); - } - - public Expression child() { - return child; - } - - @Override - public String unresolvedMessage() { - return "Unknown alias [" + name() + "]"; - } - - @Override - public Nullability nullable() { - throw new UnresolvedException("nullable", this); - } - - @Override - public int hashCode() { - return Objects.hash(child); - } - - @Override - public boolean equals(Object obj) { - /* - * Intentionally not calling the superclass - * equals because it uses id which we always - * mutate when we make a clone. - */ - if (obj == null || obj.getClass() != getClass()) { - return false; - } - return Objects.equals(child, ((UnresolvedAlias) obj).child); - } - - @Override - public String toString() { - return child + " AS ?"; - } - - @Override - public String nodeString() { - return toString(); - } -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/UnresolvedStar.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/UnresolvedStar.java index 198016c710ce3..f3b52cfcccf90 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/UnresolvedStar.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/UnresolvedStar.java @@ -6,10 +6,12 @@ */ package org.elasticsearch.xpack.esql.core.expression; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.capabilities.UnresolvedException; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; +import java.io.IOException; import java.util.List; import java.util.Objects; @@ -25,6 +27,16 @@ public UnresolvedStar(Source source, UnresolvedAttribute qualifier) { this.qualifier = qualifier; } + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new UnsupportedOperationException("doesn't escape the node"); + } + + @Override + public String getWriteableName() { + throw new UnsupportedOperationException("doesn't escape the node"); + } + @Override protected NodeInfo info() { return NodeInfo.create(this, UnresolvedStar::new, qualifier); diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/PlanStreamInput.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/PlanStreamInput.java index df8fac06dd478..485084bac60b3 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/PlanStreamInput.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/PlanStreamInput.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.esql.core.util; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.NameId; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -33,4 +35,10 @@ public interface PlanStreamInput { * the same result. */ NameId mapNameId(long id) throws IOException; + + /** + * Read an {@link Expression} from the stream. This will soon be replaced with + * {@link StreamInput#readNamedWriteable}. + */ + Expression readExpression() throws IOException; } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/PlanStreamOutput.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/PlanStreamOutput.java new file mode 100644 index 0000000000000..6a3d8fb77316c --- /dev/null +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/PlanStreamOutput.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.core.util; + +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xpack.esql.core.expression.Expression; + +import java.io.IOException; + +/** + * Interface for streams that can serialize plan components. This exists so + * ESQL proper can expose streaming capability to ESQL-core. If the world is kind + * and just we'll remove this when we flatten everything from ESQL-core into + * ESQL proper. + */ +public interface PlanStreamOutput { + /** + * Write an {@link Expression} to the stream. This will soon be replaced with + * {@link StreamOutput#writeNamedWriteable}. + */ + void writeExpression(Expression expression) throws IOException; +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java index 69d988c958169..05b78c8b5f309 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java @@ -96,9 +96,6 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; -import static org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry.PlanReader.readerFromPlanReader; -import static org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry.PlanWriter.writerFromPlanWriter; - /** * {@link EnrichLookupService} performs enrich lookup for a given input page. The lookup process consists of three stages: * - Stage 1: Finding matching document IDs for the input page. This stage is done by the {@link EnrichQuerySourceOperator} or its variants. @@ -460,7 +457,7 @@ private static class LookupRequest extends TransportRequest implements IndicesRe } this.toRelease = inputPage; PlanStreamInput planIn = new PlanStreamInput(in, PlanNameRegistry.INSTANCE, in.namedWriteableRegistry(), null); - this.extractFields = planIn.readCollectionAsList(readerFromPlanReader(PlanStreamInput::readNamedExpression)); + this.extractFields = planIn.readNamedWriteableCollectionAsList(NamedExpression.class); } @Override @@ -475,7 +472,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(matchField); out.writeWriteable(inputPage); PlanStreamOutput planOut = new PlanStreamOutput(out, PlanNameRegistry.INSTANCE, null); - planOut.writeCollection(extractFields, writerFromPlanWriter(PlanStreamOutput::writeNamedExpression)); + planOut.writeNamedWriteableCollection(extractFields); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/NamedExpressions.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/NamedExpressions.java index cb6aaf879f3cb..d0c8adfd3c858 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/NamedExpressions.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/NamedExpressions.java @@ -17,7 +17,6 @@ import java.util.Map; public class NamedExpressions { - /** * Calculates the actual output of a command given the new attributes plus the existing inputs that are emitted as outputs * @param fields the fields added by the command diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/UnresolvedNamePattern.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/UnresolvedNamePattern.java index 7df28f0648318..98282b5dec0eb 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/UnresolvedNamePattern.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/UnresolvedNamePattern.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.expression; import org.apache.lucene.util.automaton.CharacterRunAutomaton; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.capabilities.UnresolvedException; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Nullability; @@ -16,6 +17,7 @@ import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.util.CollectionUtils; +import java.io.IOException; import java.util.List; import java.util.Objects; @@ -42,6 +44,16 @@ public UnresolvedNamePattern(Source source, CharacterRunAutomaton automaton, Str this.name = name; } + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new UnsupportedOperationException("doesn't escape the node"); + } + + @Override + public String getWriteableName() { + throw new UnsupportedOperationException("doesn't escape the node"); + } + public boolean match(String string) { return automaton.run(string); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/UnsupportedAttribute.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/UnsupportedAttribute.java index fe6db916f7a0d..79dcc6a3d3920 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/UnsupportedAttribute.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/UnsupportedAttribute.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.NameId; +import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -36,6 +37,11 @@ public final class UnsupportedAttribute extends FieldAttribute implements Unreso "UnsupportedAttribute", UnsupportedAttribute::new ); + public static final NamedWriteableRegistry.Entry NAMED_EXPRESSION_ENTRY = new NamedWriteableRegistry.Entry( + NamedExpression.class, + ENTRY.name, + UnsupportedAttribute::new + ); private final String message; private final boolean hasCustomMessage; // TODO remove me and just use message != null? diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java index fe2c704743f5c..624c9f5c65ca5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java @@ -27,7 +27,6 @@ import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; -import org.elasticsearch.xpack.esql.core.expression.NameId; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.expression.Order; import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; @@ -299,12 +298,12 @@ public static List namedTypeEntries() { of(LogicalPlan.class, Project.class, PlanNamedTypes::writeProject, PlanNamedTypes::readProject), of(LogicalPlan.class, TopN.class, PlanNamedTypes::writeTopN, PlanNamedTypes::readTopN), // Attributes - of(NamedExpression.class, FieldAttribute.class, (o, a) -> a.writeTo(o), FieldAttribute::new), - of(NamedExpression.class, ReferenceAttribute.class, (o, a) -> a.writeTo(o), ReferenceAttribute::new), - of(NamedExpression.class, MetadataAttribute.class, (o, a) -> a.writeTo(o), MetadataAttribute::new), - of(NamedExpression.class, UnsupportedAttribute.class, (o, a) -> a.writeTo(o), UnsupportedAttribute::new), + of(Expression.class, FieldAttribute.class, (o, a) -> a.writeTo(o), FieldAttribute::new), + of(Expression.class, ReferenceAttribute.class, (o, a) -> a.writeTo(o), ReferenceAttribute::new), + of(Expression.class, MetadataAttribute.class, (o, a) -> a.writeTo(o), MetadataAttribute::new), + of(Expression.class, UnsupportedAttribute.class, (o, a) -> a.writeTo(o), UnsupportedAttribute::new), // NamedExpressions - of(NamedExpression.class, Alias.class, PlanNamedTypes::writeAlias, PlanNamedTypes::readAlias), + of(Expression.class, Alias.class, (o, a) -> a.writeTo(o), Alias::new), // BinaryComparison of(EsqlBinaryComparison.class, Equals.class, PlanNamedTypes::writeBinComparison, PlanNamedTypes::readBinComparison), of(EsqlBinaryComparison.class, NotEquals.class, PlanNamedTypes::writeBinComparison, PlanNamedTypes::readBinComparison), @@ -454,7 +453,7 @@ static AggregateExec readAggregateExec(PlanStreamInput in) throws IOException { Source.readFrom(in), in.readPhysicalPlanNode(), in.readCollectionAsList(readerFromPlanReader(PlanStreamInput::readExpression)), - readNamedExpressions(in), + in.readNamedWriteableCollectionAsList(NamedExpression.class), in.readEnum(AggregateExec.Mode.class), in.readOptionalVInt() ); @@ -464,7 +463,7 @@ static void writeAggregateExec(PlanStreamOutput out, AggregateExec aggregateExec Source.EMPTY.writeTo(out); out.writePhysicalPlanNode(aggregateExec.child()); out.writeCollection(aggregateExec.groupings(), writerFromPlanWriter(PlanStreamOutput::writeExpression)); - writeNamedExpressions(out, aggregateExec.aggregates()); + out.writeNamedWriteableCollection(aggregateExec.aggregates()); out.writeEnum(aggregateExec.getMode()); out.writeOptionalVInt(aggregateExec.estimatedRowSize()); } @@ -547,19 +546,19 @@ static void writeIndexMode(StreamOutput out, IndexMode indexMode) throws IOExcep } static EvalExec readEvalExec(PlanStreamInput in) throws IOException { - return new EvalExec(Source.readFrom(in), in.readPhysicalPlanNode(), readAliases(in)); + return new EvalExec(Source.readFrom(in), in.readPhysicalPlanNode(), in.readCollectionAsList(Alias::new)); } static void writeEvalExec(PlanStreamOutput out, EvalExec evalExec) throws IOException { Source.EMPTY.writeTo(out); out.writePhysicalPlanNode(evalExec.child()); - writeAliases(out, evalExec.fields()); + out.writeCollection(evalExec.fields()); } static EnrichExec readEnrichExec(PlanStreamInput in) throws IOException { final Source source = Source.readFrom(in); final PhysicalPlan child = in.readPhysicalPlanNode(); - final NamedExpression matchField = in.readNamedExpression(); + final NamedExpression matchField = in.readNamedWriteable(NamedExpression.class); final String policyName = in.readString(); final String matchType = (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_EXTENDED_ENRICH_TYPES)) ? in.readString() @@ -587,14 +586,14 @@ static EnrichExec readEnrichExec(PlanStreamInput in) throws IOException { policyName, policyMatchField, concreteIndices, - readNamedExpressions(in) + in.readNamedWriteableCollectionAsList(NamedExpression.class) ); } static void writeEnrichExec(PlanStreamOutput out, EnrichExec enrich) throws IOException { Source.EMPTY.writeTo(out); out.writePhysicalPlanNode(enrich.child()); - out.writeNamedExpression(enrich.matchField()); + out.writeNamedWriteable(enrich.matchField()); out.writeString(enrich.policyName()); if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_EXTENDED_ENRICH_TYPES)) { out.writeString(enrich.matchType()); @@ -611,7 +610,7 @@ static void writeEnrichExec(PlanStreamOutput out, EnrichExec enrich) throws IOEx throw new IllegalStateException("expected a single concrete enrich index; got " + enrich.concreteIndices()); } } - writeNamedExpressions(out, enrich.enrichFields()); + out.writeNamedWriteableCollection(enrich.enrichFields()); } static ExchangeExec readExchangeExec(PlanStreamInput in) throws IOException { @@ -728,7 +727,7 @@ static MvExpandExec readMvExpandExec(PlanStreamInput in) throws IOException { return new MvExpandExec( Source.readFrom(in), in.readPhysicalPlanNode(), - in.readNamedExpression(), + in.readNamedWriteable(NamedExpression.class), in.readNamedWriteable(Attribute.class) ); } @@ -736,7 +735,7 @@ static MvExpandExec readMvExpandExec(PlanStreamInput in) throws IOException { static void writeMvExpandExec(PlanStreamOutput out, MvExpandExec mvExpandExec) throws IOException { Source.EMPTY.writeTo(out); out.writePhysicalPlanNode(mvExpandExec.child()); - out.writeNamedExpression(mvExpandExec.target()); + out.writeNamedWriteable(mvExpandExec.target()); out.writeNamedWriteable(mvExpandExec.expanded()); } @@ -755,23 +754,27 @@ static void writeOrderExec(PlanStreamOutput out, OrderExec orderExec) throws IOE } static ProjectExec readProjectExec(PlanStreamInput in) throws IOException { - return new ProjectExec(Source.readFrom(in), in.readPhysicalPlanNode(), readNamedExpressions(in)); + return new ProjectExec( + Source.readFrom(in), + in.readPhysicalPlanNode(), + in.readNamedWriteableCollectionAsList(NamedExpression.class) + ); } static void writeProjectExec(PlanStreamOutput out, ProjectExec projectExec) throws IOException { Source.EMPTY.writeTo(out); out.writePhysicalPlanNode(projectExec.child()); - writeNamedExpressions(out, projectExec.projections()); + out.writeNamedWriteableCollection(projectExec.projections()); } static RowExec readRowExec(PlanStreamInput in) throws IOException { - return new RowExec(Source.readFrom(in), readAliases(in)); + return new RowExec(Source.readFrom(in), in.readCollectionAsList(Alias::new)); } static void writeRowExec(PlanStreamOutput out, RowExec rowExec) throws IOException { assert rowExec.children().size() == 0; Source.EMPTY.writeTo(out); - writeAliases(out, rowExec.fields()); + out.writeCollection(rowExec.fields()); } @SuppressWarnings("unchecked") @@ -813,7 +816,7 @@ static Aggregate readAggregate(PlanStreamInput in) throws IOException { Source.readFrom(in), in.readLogicalPlanNode(), in.readCollectionAsList(readerFromPlanReader(PlanStreamInput::readExpression)), - readNamedExpressions(in) + in.readNamedWriteableCollectionAsList(NamedExpression.class) ); } @@ -821,7 +824,7 @@ static void writeAggregate(PlanStreamOutput out, Aggregate aggregate) throws IOE Source.EMPTY.writeTo(out); out.writeLogicalPlanNode(aggregate.child()); out.writeCollection(aggregate.groupings(), writerFromPlanWriter(PlanStreamOutput::writeExpression)); - writeNamedExpressions(out, aggregate.aggregates()); + out.writeNamedWriteableCollection(aggregate.aggregates()); } static Dissect readDissect(PlanStreamInput in) throws IOException { @@ -890,13 +893,13 @@ private static void writeEsSourceOptions(PlanStreamOutput out) throws IOExceptio } static Eval readEval(PlanStreamInput in) throws IOException { - return new Eval(Source.readFrom(in), in.readLogicalPlanNode(), readAliases(in)); + return new Eval(Source.readFrom(in), in.readLogicalPlanNode(), in.readCollectionAsList(Alias::new)); } static void writeEval(PlanStreamOutput out, Eval eval) throws IOException { Source.EMPTY.writeTo(out); out.writeLogicalPlanNode(eval.child()); - writeAliases(out, eval.fields()); + out.writeCollection(eval.fields()); } static Enrich readEnrich(PlanStreamInput in) throws IOException { @@ -907,7 +910,7 @@ static Enrich readEnrich(PlanStreamInput in) throws IOException { final Source source = Source.readFrom(in); final LogicalPlan child = in.readLogicalPlanNode(); final Expression policyName = in.readExpression(); - final NamedExpression matchField = in.readNamedExpression(); + final NamedExpression matchField = in.readNamedWriteable(NamedExpression.class); if (in.getTransportVersion().before(TransportVersions.V_8_13_0)) { in.readString(); // discard the old policy name } @@ -922,7 +925,16 @@ static Enrich readEnrich(PlanStreamInput in) throws IOException { } concreteIndices = Map.of(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, Iterables.get(esIndex.concreteIndices(), 0)); } - return new Enrich(source, child, mode, policyName, matchField, policy, concreteIndices, readNamedExpressions(in)); + return new Enrich( + source, + child, + mode, + policyName, + matchField, + policy, + concreteIndices, + in.readNamedWriteableCollectionAsList(NamedExpression.class) + ); } static void writeEnrich(PlanStreamOutput out, Enrich enrich) throws IOException { @@ -933,7 +945,7 @@ static void writeEnrich(PlanStreamOutput out, Enrich enrich) throws IOException Source.EMPTY.writeTo(out); out.writeLogicalPlanNode(enrich.child()); out.writeExpression(enrich.policyName()); - out.writeNamedExpression(enrich.matchField()); + out.writeNamedWriteable(enrich.matchField()); if (out.getTransportVersion().before(TransportVersions.V_8_13_0)) { out.writeString(BytesRefs.toString(enrich.policyName().fold())); // old policy name } @@ -950,17 +962,17 @@ static void writeEnrich(PlanStreamOutput out, Enrich enrich) throws IOException throw new IllegalStateException("expected a single enrich index; got " + concreteIndices); } } - writeNamedExpressions(out, enrich.enrichFields()); + out.writeNamedWriteableCollection(enrich.enrichFields()); } static EsqlProject readEsqlProject(PlanStreamInput in) throws IOException { - return new EsqlProject(Source.readFrom(in), in.readLogicalPlanNode(), readNamedExpressions(in)); + return new EsqlProject(Source.readFrom(in), in.readLogicalPlanNode(), in.readNamedWriteableCollectionAsList(NamedExpression.class)); } static void writeEsqlProject(PlanStreamOutput out, EsqlProject project) throws IOException { Source.EMPTY.writeTo(out); out.writeLogicalPlanNode(project.child()); - writeNamedExpressions(out, project.projections()); + out.writeNamedWriteableCollection(project.projections()); } static Filter readFilter(PlanStreamInput in) throws IOException { @@ -1006,7 +1018,7 @@ static MvExpand readMvExpand(PlanStreamInput in) throws IOException { return new MvExpand( Source.readFrom(in), in.readLogicalPlanNode(), - in.readNamedExpression(), + in.readNamedWriteable(NamedExpression.class), in.readNamedWriteable(Attribute.class) ); } @@ -1014,7 +1026,7 @@ static MvExpand readMvExpand(PlanStreamInput in) throws IOException { static void writeMvExpand(PlanStreamOutput out, MvExpand mvExpand) throws IOException { Source.EMPTY.writeTo(out); out.writeLogicalPlanNode(mvExpand.child()); - out.writeNamedExpression(mvExpand.target()); + out.writeNamedWriteable(mvExpand.target()); out.writeNamedWriteable(mvExpand.expanded()); } @@ -1033,13 +1045,13 @@ static void writeOrderBy(PlanStreamOutput out, OrderBy order) throws IOException } static Project readProject(PlanStreamInput in) throws IOException { - return new Project(Source.readFrom(in), in.readLogicalPlanNode(), readNamedExpressions(in)); + return new Project(Source.readFrom(in), in.readLogicalPlanNode(), in.readNamedWriteableCollectionAsList(NamedExpression.class)); } static void writeProject(PlanStreamOutput out, Project project) throws IOException { Source.EMPTY.writeTo(out); out.writeLogicalPlanNode(project.child()); - writeNamedExpressions(out, project.projections()); + out.writeNamedWriteableCollection(project.projections()); } static TopN readTopN(PlanStreamInput in) throws IOException { @@ -1058,26 +1070,6 @@ static void writeTopN(PlanStreamOutput out, TopN topN) throws IOException { out.writeExpression(topN.limit()); } - // - // -- Attributes - // - - private static List readNamedExpressions(PlanStreamInput in) throws IOException { - return in.readCollectionAsList(readerFromPlanReader(PlanStreamInput::readNamedExpression)); - } - - static void writeNamedExpressions(PlanStreamOutput out, List namedExpressions) throws IOException { - out.writeCollection(namedExpressions, writerFromPlanWriter(PlanStreamOutput::writeNamedExpression)); - } - - private static List readAliases(PlanStreamInput in) throws IOException { - return in.readCollectionAsList(readerFromPlanReader(PlanNamedTypes::readAlias)); - } - - static void writeAliases(PlanStreamOutput out, List aliases) throws IOException { - out.writeCollection(aliases, writerFromPlanWriter(PlanNamedTypes::writeAlias)); - } - // -- BinaryComparison static EsqlBinaryComparison readBinComparison(PlanStreamInput in, String name) throws IOException { @@ -1683,28 +1675,6 @@ static void writeMvConcat(PlanStreamOutput out, MvConcat fn) throws IOException out.writeExpression(fn.right()); } - // -- NamedExpressions - - static Alias readAlias(PlanStreamInput in) throws IOException { - return new Alias( - Source.readFrom(in), - in.readString(), - in.readOptionalString(), - in.readNamed(Expression.class), - NameId.readFrom(in), - in.readBoolean() - ); - } - - static void writeAlias(PlanStreamOutput out, Alias alias) throws IOException { - Source.EMPTY.writeTo(out); - out.writeString(alias.name()); - out.writeOptionalString(alias.qualifier()); - out.writeExpression(alias.child()); - alias.id().writeTo(out); - out.writeBoolean(alias.synthetic()); - } - // -- Expressions (other) static Literal readLiteral(PlanStreamInput in) throws IOException { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInput.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInput.java index e7f1fbd6e1460..0b671d6b90c7e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInput.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInput.java @@ -25,7 +25,6 @@ import org.elasticsearch.xpack.esql.Column; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.NameId; -import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry.PlanNamedReader; import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry.PlanReader; @@ -93,14 +92,11 @@ public PhysicalPlan readOptionalPhysicalPlanNode() throws IOException { return readOptionalNamed(PhysicalPlan.class); } + @Override public Expression readExpression() throws IOException { return readNamed(Expression.class); } - public NamedExpression readNamedExpression() throws IOException { - return readNamed(NamedExpression.class); - } - public T readNamed(Class type) throws IOException { String name = readString(); @SuppressWarnings("unchecked") diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutput.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutput.java index 05dc7ab919868..f7380588fbd77 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutput.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutput.java @@ -20,7 +20,6 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.xpack.esql.Column; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry.PlanWriter; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; @@ -35,7 +34,7 @@ * A customized stream output used to serialize ESQL physical plan fragments. Complements stream * output with methods that write plan nodes, Attributes, Expressions, etc. */ -public final class PlanStreamOutput extends StreamOutput { +public final class PlanStreamOutput extends StreamOutput implements org.elasticsearch.xpack.esql.core.util.PlanStreamOutput { /** * Cache of written blocks. We use an {@link IdentityHashMap} for this @@ -94,14 +93,11 @@ public void writeOptionalPhysicalPlanNode(PhysicalPlan physicalPlan) throws IOEx } } + @Override public void writeExpression(Expression expression) throws IOException { writeNamed(Expression.class, expression); } - public void writeNamedExpression(NamedExpression namedExpression) throws IOException { - writeNamed(NamedExpression.class, namedExpression); - } - public void writeOptionalExpression(Expression expression) throws IOException { if (expression == null) { writeBoolean(false); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlPlugin.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlPlugin.java index 6059b61031d1e..4fdc0bdab5ade 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlPlugin.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlPlugin.java @@ -56,6 +56,7 @@ import org.elasticsearch.xpack.esql.action.RestEsqlGetAsyncResultAction; import org.elasticsearch.xpack.esql.action.RestEsqlQueryAction; import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.index.IndexResolver; import org.elasticsearch.xpack.esql.core.type.EsField; import org.elasticsearch.xpack.esql.enrich.EnrichLookupOperator; @@ -195,6 +196,8 @@ public List getNamedWriteables() { entries.addAll(EsField.getNamedWriteables()); entries.addAll(Attribute.getNamedWriteables()); entries.add(UnsupportedAttribute.ENTRY); // TODO combine with above once these are in the same project + entries.addAll(NamedExpression.getNamedWriteables()); + entries.add(UnsupportedAttribute.NAMED_EXPRESSION_ENTRY); return entries; } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/SerializationTestUtils.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/SerializationTestUtils.java index 6ef33b7ae5eb8..a614ff3c621f8 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/SerializationTestUtils.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/SerializationTestUtils.java @@ -25,6 +25,7 @@ import org.elasticsearch.test.EqualsHashCodeTestUtils; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.type.EsField; import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; @@ -118,6 +119,8 @@ public static NamedWriteableRegistry writableRegistry() { entries.addAll(EsField.getNamedWriteables()); entries.addAll(Attribute.getNamedWriteables()); entries.add(UnsupportedAttribute.ENTRY); + entries.addAll(NamedExpression.getNamedWriteables()); + entries.add(UnsupportedAttribute.NAMED_EXPRESSION_ENTRY); return new NamedWriteableRegistry(entries); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/AliasTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/AliasTests.java new file mode 100644 index 0000000000000..ce7aa789f89b1 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/AliasTests.java @@ -0,0 +1,86 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.test.AbstractWireTestCase; +import org.elasticsearch.xpack.esql.core.expression.Alias; +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.NameId; +import org.elasticsearch.xpack.esql.core.expression.NamedExpression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.tree.SourceTests; +import org.elasticsearch.xpack.esql.core.type.EsField; +import org.elasticsearch.xpack.esql.expression.function.ReferenceAttributeTests; +import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; +import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.in; + +public class AliasTests extends AbstractWireTestCase { + @Override + protected Alias createTestInstance() { + Source source = SourceTests.randomSource(); + String name = randomAlphaOfLength(5); + String qualifier = randomBoolean() ? null : randomAlphaOfLength(3); + // TODO better randomChild + Expression child = ReferenceAttributeTests.randomReferenceAttribute(); + boolean synthetic = randomBoolean(); + return new Alias(source, name, qualifier, child, new NameId(), synthetic); + } + + @Override + protected Alias mutateInstance(Alias instance) throws IOException { + Source source = instance.source(); + String name = instance.name(); + String qualifier = instance.qualifier(); + Expression child = instance.child(); + boolean synthetic = instance.synthetic(); + switch (between(0, 3)) { + case 0 -> name = randomAlphaOfLength(name.length() + 1); + case 1 -> qualifier = randomValueOtherThan(qualifier, () -> randomBoolean() ? null : randomAlphaOfLength(3)); + case 2 -> child = randomValueOtherThan(child, ReferenceAttributeTests::randomReferenceAttribute); + case 3 -> synthetic = false == synthetic; + } + return new Alias(source, name, qualifier, child, instance.id(), synthetic); + } + + @Override + protected Alias copyInstance(Alias instance, TransportVersion version) throws IOException { + return copyInstance( + instance, + getNamedWriteableRegistry(), + (out, v) -> new PlanStreamOutput(out, new PlanNameRegistry(), null).writeNamedWriteable(v), + in -> { + PlanStreamInput pin = new PlanStreamInput(in, new PlanNameRegistry(), in.namedWriteableRegistry(), null); + Alias deser = (Alias) pin.readNamedWriteable(NamedExpression.class); + assertThat(deser.id(), equalTo(pin.mapNameId(Long.parseLong(instance.id().toString())))); + return deser; + }, + version + ); + } + + @Override + protected final NamedWriteableRegistry getNamedWriteableRegistry() { + List entries = new ArrayList<>(NamedExpression.getNamedWriteables()); + entries.addAll(Attribute.getNamedWriteables()); + entries.add(UnsupportedAttribute.ENTRY); + entries.addAll(EsField.getNamedWriteables()); + return new NamedWriteableRegistry(entries); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/ReferenceAttributeTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/ReferenceAttributeTests.java index 716d4fa1f5cce..31d1018bacc91 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/ReferenceAttributeTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/ReferenceAttributeTests.java @@ -14,8 +14,7 @@ import org.elasticsearch.xpack.esql.core.type.DataType; public class ReferenceAttributeTests extends AbstractAttributeTestCase { - @Override - protected ReferenceAttribute create() { + public static ReferenceAttribute randomReferenceAttribute() { Source source = Source.EMPTY; String name = randomAlphaOfLength(5); DataType type = randomFrom(DataType.types()); @@ -25,6 +24,11 @@ protected ReferenceAttribute create() { return new ReferenceAttribute(source, name, type, qualifier, nullability, new NameId(), synthetic); } + @Override + protected ReferenceAttribute create() { + return randomReferenceAttribute(); + } + @Override protected ReferenceAttribute mutate(ReferenceAttribute instance) { Source source = instance.source(); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java index 55c3811f9a870..2abc70b4ecc88 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java @@ -296,17 +296,6 @@ public void testPowSimple() throws IOException { EqualsHashCodeTestUtils.checkEqualsAndHashCode(orig, unused -> deser); } - public void testAliasSimple() throws IOException { - var orig = new Alias(Source.EMPTY, "alias_name", field("a", DataType.LONG)); - BytesStreamOutput bso = new BytesStreamOutput(); - PlanStreamOutput out = new PlanStreamOutput(bso, planNameRegistry, null); - PlanNamedTypes.writeAlias(out, orig); - var in = planStreamInput(bso); - var deser = PlanNamedTypes.readAlias(in); - EqualsHashCodeTestUtils.checkEqualsAndHashCode(orig, unused -> deser); - assertThat(deser.id(), equalTo(in.mapNameId(Long.parseLong(orig.id().toString())))); - } - public void testLiteralSimple() throws IOException { var orig = new Literal(Source.EMPTY, 1, DataType.INTEGER); BytesStreamOutput bso = new BytesStreamOutput(); From 21ffdac4a747103751c258063a9e1e3547b18f46 Mon Sep 17 00:00:00 2001 From: Ryan Ernst Date: Wed, 5 Jun 2024 11:13:55 -0700 Subject: [PATCH 19/30] Patch ImmutableCollections for tests (#109271) ImmutableCollections uses a seed, set early during JVM startup, which affects the iteration order of collections. Although we do not want to rely on the iteration order of Map and Set collections, bugs do sometimes occur. In order to reproduce those bugs to fix them, it is important the test seed for Elasticsearch matches the seed used in ImmutableCollections. Unfortunately ImmutableCollections is internal to the JDK, and the seed used is private and final. This commit works around these limitations by creating a patched version of ImmutableCollections which allows access to the seed member. ESTestCase is then able to reflectively set the seed at runtime based on the Elasticsearch seed. Note that this only affects tests. ImmutableCollections remains is unchanged for production code. relates #94946 --- .../internal/ElasticsearchTestBasePlugin.java | 26 +++++++++ .../gradle/internal/MrjarPlugin.java | 1 - settings.gradle | 3 +- .../bootstrap/BootstrapForTesting.java | 1 + .../org/elasticsearch/test/ESTestCase.java | 51 +++++++++++++--- test/immutable-collections-patch/build.gradle | 49 ++++++++++++++++ .../patch/ImmutableCollectionsPatcher.java | 58 +++++++++++++++++++ 7 files changed, 180 insertions(+), 9 deletions(-) create mode 100644 test/immutable-collections-patch/build.gradle create mode 100644 test/immutable-collections-patch/src/main/java/org/elasticsearch/jdk/patch/ImmutableCollectionsPatcher.java diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/ElasticsearchTestBasePlugin.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/ElasticsearchTestBasePlugin.java index ed2dfb577e038..d344b4694a5b5 100644 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/ElasticsearchTestBasePlugin.java +++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/ElasticsearchTestBasePlugin.java @@ -30,6 +30,7 @@ import org.gradle.api.tasks.testing.Test; import java.io.File; +import java.util.List; import java.util.Map; import static org.elasticsearch.gradle.util.FileUtils.mkdirs; @@ -100,6 +101,7 @@ public void execute(Task t) { "-Xmx" + System.getProperty("tests.heap.size", "512m"), "-Xms" + System.getProperty("tests.heap.size", "512m"), "-Djava.security.manager=allow", + "--add-opens=java.base/java.util=ALL-UNNAMED", // TODO: only open these for mockito when it is modularized "--add-opens=java.base/java.security.cert=ALL-UNNAMED", "--add-opens=java.base/java.nio.channels=ALL-UNNAMED", @@ -199,5 +201,29 @@ public void execute(Task t) { } }); }); + configureImmutableCollectionsPatch(project); + } + + private void configureImmutableCollectionsPatch(Project project) { + String patchProject = ":test:immutable-collections-patch"; + if (project.findProject(patchProject) == null) { + return; // build tests may not have this project, just skip + } + String configurationName = "immutableCollectionsPatch"; + FileCollection patchedFileCollection = project.getConfigurations() + .create(configurationName, config -> config.setCanBeConsumed(false)); + var deps = project.getDependencies(); + deps.add(configurationName, deps.project(Map.of("path", patchProject, "configuration", "patch"))); + project.getTasks().withType(Test.class).matching(task -> task.getName().equals("test")).configureEach(test -> { + test.getInputs().files(patchedFileCollection); + test.systemProperty("tests.hackImmutableCollections", "true"); + test.getJvmArgumentProviders() + .add( + () -> List.of( + "--patch-module=java.base=" + patchedFileCollection.getSingleFile() + "/java.base", + "--add-opens=java.base/java.util=ALL-UNNAMED" + ) + ); + }); } } diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/MrjarPlugin.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/MrjarPlugin.java index 16c286bfdd3f2..756d1ea48610b 100644 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/MrjarPlugin.java +++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/MrjarPlugin.java @@ -172,7 +172,6 @@ private void createTestTask(Project project, SourceSet sourceSet, int javaVersio testTask.getJavaLauncher() .set(javaToolchains.launcherFor(spec -> spec.getLanguageVersion().set(JavaLanguageVersion.of(javaVersion)))); } - }); project.getTasks().named("check").configure(checkTask -> checkTask.dependsOn(testTaskProvider)); diff --git a/settings.gradle b/settings.gradle index 6ed340b27da65..ef758a7205cd0 100644 --- a/settings.gradle +++ b/settings.gradle @@ -105,7 +105,8 @@ List projects = [ 'test:test-clusters', 'test:x-content', 'test:yaml-rest-runner', - 'test:metadata-extractor' + 'test:metadata-extractor', + 'test:immutable-collections-patch' ] /** diff --git a/test/framework/src/main/java/org/elasticsearch/bootstrap/BootstrapForTesting.java b/test/framework/src/main/java/org/elasticsearch/bootstrap/BootstrapForTesting.java index 8a3f36ebb1f8a..30623c6bafd6b 100644 --- a/test/framework/src/main/java/org/elasticsearch/bootstrap/BootstrapForTesting.java +++ b/test/framework/src/main/java/org/elasticsearch/bootstrap/BootstrapForTesting.java @@ -73,6 +73,7 @@ public class BootstrapForTesting { // without making things complex??? static { + // make sure java.io.tmpdir exists always (in case code uses it in a static initializer) Path javaTmpDir = PathUtils.get( Objects.requireNonNull(System.getProperty("java.io.tmpdir"), "please set ${java.io.tmpdir} in pom.xml") diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java index 42cc719a904cd..6920083f2a1a6 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java @@ -79,6 +79,7 @@ import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Booleans; import org.elasticsearch.core.CheckedRunnable; import org.elasticsearch.core.PathUtils; import org.elasticsearch.core.PathUtilsForTesting; @@ -145,6 +146,7 @@ import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import java.lang.invoke.MethodHandles; import java.math.BigInteger; import java.net.InetAddress; import java.net.UnknownHostException; @@ -257,8 +259,9 @@ public static void resetPortCounter() { private static final SetOnce WARN_SECURE_RANDOM_FIPS_NOT_DETERMINISTIC = new SetOnce<>(); static { + Random random = initTestSeed(); TEST_WORKER_VM_ID = System.getProperty(TEST_WORKER_SYS_PROPERTY, DEFAULT_TEST_WORKER_ID); - setTestSysProps(); + setTestSysProps(random); // TODO: consolidate logging initialization for tests so it all occurs in logconfigurator LogConfigurator.loadLog4jPlugins(); LogConfigurator.configureESLogging(); @@ -359,8 +362,46 @@ public void append(LogEvent event) { JAVA_ZONE_IDS = ZoneId.getAvailableZoneIds().stream().filter(unsupportedZoneIdsPredicate.negate()).sorted().toList(); } + static Random initTestSeed() { + String inputSeed = System.getProperty("tests.seed"); + long seed; + if (inputSeed == null) { + // when running tests in intellij, we don't have a seed. Setup the seed early here, before getting to RandomizedRunner, + // so that we can use it in ESTestCase static init + seed = System.nanoTime(); + setTestSeed(Long.toHexString(seed)); + } else { + String[] seedParts = inputSeed.split("[\\:]"); + seed = Long.parseUnsignedLong(seedParts[0], 16); + } + + if (Booleans.parseBoolean(System.getProperty("tests.hackImmutableCollections", "false"))) { + forceImmutableCollectionsSeed(seed); + } + + return new Random(seed); + } + + @SuppressForbidden(reason = "set tests.seed for intellij") + static void setTestSeed(String seed) { + System.setProperty("tests.seed", seed); + } + + private static void forceImmutableCollectionsSeed(long seed) { + try { + MethodHandles.Lookup lookup = MethodHandles.lookup(); + Class collectionsClass = Class.forName("java.util.ImmutableCollections"); + var salt32l = lookup.findStaticVarHandle(collectionsClass, "SALT32L", long.class); + var reverse = lookup.findStaticVarHandle(collectionsClass, "REVERSE", boolean.class); + salt32l.set(seed & 0xFFFF_FFFFL); + reverse.set((seed & 1) == 0); + } catch (Exception e) { + throw new AssertionError(e); + } + } + @SuppressForbidden(reason = "force log4j and netty sysprops") - private static void setTestSysProps() { + private static void setTestSysProps(Random random) { System.setProperty("log4j.shutdownHookEnabled", "false"); System.setProperty("log4j2.disable.jmx", "true"); @@ -377,11 +418,7 @@ private static void setTestSysProps() { System.setProperty("es.set.netty.runtime.available.processors", "false"); // sometimes use the java.time date formatters - // we can't use randomBoolean here, the random context isn't set properly - // so read it directly from the test seed in an unfortunately hacky way - String testSeed = System.getProperty("tests.seed", "0"); - boolean firstBit = (Integer.parseInt(testSeed.substring(testSeed.length() - 1), 16) & 1) == 1; - if (firstBit) { + if (random.nextBoolean()) { System.setProperty("es.datetime.java_time_parsers", "true"); } } diff --git a/test/immutable-collections-patch/build.gradle b/test/immutable-collections-patch/build.gradle new file mode 100644 index 0000000000000..2d42215b3e02c --- /dev/null +++ b/test/immutable-collections-patch/build.gradle @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +import org.elasticsearch.gradle.OS +import org.elasticsearch.gradle.VersionProperties +import org.elasticsearch.gradle.internal.info.BuildParams + +apply plugin: 'elasticsearch.java' + +configurations { + patch +} + +dependencies { + implementation 'org.ow2.asm:asm:9.7' + implementation 'org.ow2.asm:asm-tree:9.7' +} + +def outputDir = layout.buildDirectory.dir("jdk-patches") +def generatePatch = tasks.register("generatePatch", JavaExec) +generatePatch.configure { + dependsOn tasks.named("compileJava") + inputs.property("java-home-set", BuildParams.getIsRuntimeJavaHomeSet()) + inputs.property("java-version", BuildParams.runtimeJavaVersion) + outputs.dir(outputDir) + + classpath = sourceSets.main.runtimeClasspath + mainClass = 'org.elasticsearch.jdk.patch.ImmutableCollectionsPatcher' + if (BuildParams.getIsRuntimeJavaHomeSet()) { + executable = "${BuildParams.runtimeJavaHome}/bin/java" + (OS.current() == OS.WINDOWS ? '.exe' : '') + } else { + javaLauncher = javaToolchains.launcherFor { + languageVersion = JavaLanguageVersion.of(BuildParams.runtimeJavaVersion.majorVersion) + vendor = VersionProperties.bundledJdkVendor == "openjdk" ? + JvmVendorSpec.ORACLE : + JvmVendorSpec.matching(VersionProperties.bundledJdkVendor) + } + } + doFirst { + args outputDir.get().getAsFile().toString() + } +} + +artifacts.add("patch", generatePatch); diff --git a/test/immutable-collections-patch/src/main/java/org/elasticsearch/jdk/patch/ImmutableCollectionsPatcher.java b/test/immutable-collections-patch/src/main/java/org/elasticsearch/jdk/patch/ImmutableCollectionsPatcher.java new file mode 100644 index 0000000000000..b98df1b3d2e17 --- /dev/null +++ b/test/immutable-collections-patch/src/main/java/org/elasticsearch/jdk/patch/ImmutableCollectionsPatcher.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.jdk.patch; + +import org.objectweb.asm.ClassReader; +import org.objectweb.asm.ClassVisitor; +import org.objectweb.asm.ClassWriter; +import org.objectweb.asm.FieldVisitor; +import org.objectweb.asm.Opcodes; + +import java.net.URI; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; + +/** + * Loads ImmutableCollections.class from the current jdk and writes it out + * as a public class with SALT32L and REVERSE as public, non-final static fields. + * + * By exposing ImmutableCollections, tests run with this patched version can + * hook in the existing test seed to ensure consistent iteration of immutable collections. + * Note that the consistency is for reproducing dependencies on iteration + * order, so that the code can be fixed. + */ +public class ImmutableCollectionsPatcher { + private static final String CLASSFILE = "java.base/java/util/ImmutableCollections.class"; + + public static void main(String[] args) throws Exception { + Path outputDir = Paths.get(args[0]); + byte[] originalClassFile = Files.readAllBytes(Paths.get(URI.create("jrt:/" + CLASSFILE))); + + ClassReader classReader = new ClassReader(originalClassFile); + ClassWriter classWriter = new ClassWriter(classReader, 0); + classReader.accept(new ClassVisitor(Opcodes.ASM9, classWriter) { + @Override + public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) { + super.visit(version, Opcodes.ACC_PUBLIC, name, signature, superName, interfaces); + } + + @Override + public FieldVisitor visitField(int access, String name, String descriptor, String signature, Object value) { + if (name.equals("SALT32L") || name.equals("REVERSE")) { + access = Opcodes.ACC_PUBLIC | Opcodes.ACC_STATIC; + } + return super.visitField(access, name, descriptor, signature, value); + } + }, 0); + Path outputFile = outputDir.resolve(CLASSFILE); + Files.createDirectories(outputFile.getParent()); + Files.write(outputFile, classWriter.toByteArray()); + } +} From 02a6c831e17869eb99187924dc1a80f758d2ff70 Mon Sep 17 00:00:00 2001 From: Lorenzo Verardo Date: Wed, 5 Jun 2024 19:51:07 +0100 Subject: [PATCH 20/30] Limit the value in prefix query (#108537) Reuse the setting index.max_regex_length for the max length in a prefix query. Closes #108486 --- docs/changelog/108537.yaml | 6 +++ docs/reference/index-modules.asciidoc | 2 +- .../rest-api-spec/test/search/30_limits.yml | 28 +++++++++++++ .../search/simple/SimpleSearchIT.java | 40 +++++++++++++++++++ .../index/query/PrefixQueryBuilder.java | 15 +++++++ 5 files changed, 90 insertions(+), 1 deletion(-) create mode 100644 docs/changelog/108537.yaml diff --git a/docs/changelog/108537.yaml b/docs/changelog/108537.yaml new file mode 100644 index 0000000000000..1c0228a71d449 --- /dev/null +++ b/docs/changelog/108537.yaml @@ -0,0 +1,6 @@ +pr: 108537 +summary: Limit the value in prefix query +area: Search +type: enhancement +issues: + - 108486 diff --git a/docs/reference/index-modules.asciidoc b/docs/reference/index-modules.asciidoc index e826956440497..40b4ff4bb9dc8 100644 --- a/docs/reference/index-modules.asciidoc +++ b/docs/reference/index-modules.asciidoc @@ -304,7 +304,7 @@ are ignored for this index. [[index-max-regex-length]] `index.max_regex_length`:: - The maximum length of regex that can be used in Regexp Query. + The maximum length of value that can be used in `regexp` or `prefix` query. Defaults to `1000`. [[index-query-default-field]] diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/30_limits.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/30_limits.yml index bea52c22e151f..f14614a820176 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/30_limits.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/30_limits.yml @@ -161,3 +161,31 @@ setup: ]))|\\[([^\\[\\]\\r\\\\]|\\\\.)*\\](?:(?:\\r\\n)?[\\t])*))*\\>(?:(?:\\r\\n)?[ \\t])*)(?:,\\s*( | \".\\[\\]]))|\"(?:[^\\\"\\r\\\\]|\\\\.|(?:(?:\\r\\n)?[\\t]))*\"(?:(?:\\r\\n)?[ \\t])*)(?:\\.(?:( | \\[\"()<>@,;:\\\\\".\\[\\]]))|\"(?:[^\\\"\\r\\\\]|\\\\.|(?:(?:\\r\\n)?[\\t]))*\"(?:(?:\\r\\n)?[\\t/" + +--- +"Prefix length limit": + + - requires: + cluster_features: "gte_v8.15.0" + reason: "Limit for value in prefix query was introduced in 8.15" + + - do: + catch: /The length of prefix \[1110\] used in the Prefix Query request has exceeded the allowed maximum of \[1000\]\. This maximum can be set by changing the \[index.max_regex_length\] index level setting\./ + search: + rest_total_hits_as_int: true + index: test_1 + body: + query: + prefix: + foo: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa + aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa + aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa + aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa + aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa + aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa + aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa + aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa + aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa + aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa + aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa + aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/simple/SimpleSearchIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/simple/SimpleSearchIT.java index 543f45b58279e..dd8cf5e527055 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/simple/SimpleSearchIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/simple/SimpleSearchIT.java @@ -490,6 +490,46 @@ public void testTooLongRegexInRegexpQuery() throws Exception { ); } + public void testTooLongPrefixInPrefixQuery() throws Exception { + createIndex("idx"); + + // Ensure the field `num` exists in the mapping + client().admin() + .indices() + .preparePutMapping("idx") + .setSource("{\"properties\":{\"num\":{\"type\":\"keyword\"}}}", XContentType.JSON) + .get(); + + // Index a simple document to ensure the field `num` is in the index + indexRandom(true, prepareIndex("idx").setSource("{\"num\":\"test\"}", XContentType.JSON)); + + int defaultMaxRegexLength = IndexSettings.MAX_REGEX_LENGTH_SETTING.get(Settings.EMPTY); + StringBuilder prefix = new StringBuilder(defaultMaxRegexLength); + + while (prefix.length() <= defaultMaxRegexLength) { + prefix.append("a"); + } + + SearchPhaseExecutionException e = expectThrows( + SearchPhaseExecutionException.class, + () -> client().prepareSearch("idx").setQuery(QueryBuilders.prefixQuery("num", prefix.toString())).get() + ); + assertThat( + e.getRootCause().getMessage(), + containsString( + "The length of prefix [" + + prefix.length() + + "] used in the Prefix Query request has exceeded " + + "the allowed maximum of [" + + defaultMaxRegexLength + + "]. " + + "This maximum can be set by changing the [" + + IndexSettings.MAX_REGEX_LENGTH_SETTING.getKey() + + "] index level setting." + ) + ); + } + public void testStrictlyCountRequest() throws Exception { createIndex("test_count_1"); indexRandom( diff --git a/server/src/main/java/org/elasticsearch/index/query/PrefixQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/PrefixQueryBuilder.java index 5042ab358a96c..e64a424e86052 100644 --- a/server/src/main/java/org/elasticsearch/index/query/PrefixQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/PrefixQueryBuilder.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.mapper.ConstantFieldType; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.query.support.QueryParsers; @@ -209,6 +210,20 @@ protected QueryBuilder doIndexMetadataRewrite(QueryRewriteContext context) throw @Override protected Query doToQuery(SearchExecutionContext context) throws IOException { + final int maxAllowedRegexLength = context.getIndexSettings().getMaxRegexLength(); + if (value.length() > maxAllowedRegexLength) { + throw new IllegalArgumentException( + "The length of prefix [" + + value.length() + + "] used in the Prefix Query request has exceeded " + + "the allowed maximum of [" + + maxAllowedRegexLength + + "]. " + + "This maximum can be set by changing the [" + + IndexSettings.MAX_REGEX_LENGTH_SETTING.getKey() + + "] index level setting." + ); + } MultiTermQuery.RewriteMethod method = QueryParsers.parseRewriteMethod(rewrite, null, LoggingDeprecationHandler.INSTANCE); MappedFieldType fieldType = context.getFieldType(fieldName); From f8291f8e8308e8e55177455de3ba9f695e2c8f4b Mon Sep 17 00:00:00 2001 From: elasticsearchmachine <58790826+elasticsearchmachine@users.noreply.github.com> Date: Thu, 6 Jun 2024 04:52:38 +1000 Subject: [PATCH 21/30] Forward port release notes for v8.14.0 (#109403) --- .../reference/migration/migrate_8_14.asciidoc | 49 ++- docs/reference/release-notes/8.14.0.asciidoc | 346 +++++++++++++++++- .../release-notes/highlights.asciidoc | 50 ++- 3 files changed, 414 insertions(+), 31 deletions(-) diff --git a/docs/reference/migration/migrate_8_14.asciidoc b/docs/reference/migration/migrate_8_14.asciidoc index bdff8bef246b5..2e6cd439ebed0 100644 --- a/docs/reference/migration/migrate_8_14.asciidoc +++ b/docs/reference/migration/migrate_8_14.asciidoc @@ -21,8 +21,25 @@ and prevent them from operating normally. Before upgrading to 8.14, review these changes and take the described steps to mitigate the impact. + +There are no notable breaking changes in {es} 8.14. +But there are some less critical breaking changes. + [discrete] -[[breaking-changes-8.14-0]] +[[breaking_814_rest_api_changes]] +==== REST API changes + +[[prevent_dls_fls_if_replication_assigned]] +.Prevent DLS/FLS if `replication` is assigned +[%collapsible] +==== +*Details* + +For cross-cluster API keys, {es} no longer allows specifying document-level security (DLS) or field-level security (FLS) in the `search` field, if `replication` is also specified. {es} likewise blocks the use of any existing cross-cluster API keys that meet this condition. + +*Impact* + +Remove any document-level security (DLS) or field-level security (FLS) definitions from the `search` field for cross-cluster API keys that also have a `replication` field, or create two separate cross-cluster API keys, one for search and one for replication. +==== + [discrete] [[breaking_814_dls_changes]] @@ -41,3 +58,33 @@ When Document Level Security (DLS) is applied to the validate query API with the *Impact* + If needed, test workflows with DLS enabled to ensure that the stricter security rules do not impact your application. ==== + + +[discrete] +[[deprecated-8.14]] +=== Deprecations + +The following functionality has been deprecated in {es} 8.14 +and will be removed in a future version. +While this won't have an immediate impact on your applications, +we strongly encourage you to take the described steps to update your code +after upgrading to 8.14. + +To find out if you are using any deprecated functionality, +enable <>. + +[discrete] +[[deprecations_814_mapping]] +==== Mapping deprecations + +[[deprecate_allowing_fields_in_scenarios_where_it_ignored]] +.Deprecate allowing `fields` in scenarios where it is ignored +[%collapsible] +==== +*Details* + +The following mapped types have always ignored `fields` when using multi-fields. This deprecation makes this clearer and we will completely disallow `fields` for these mapped types in the future. + +*Impact* + +In the future, `join`, `aggregate_metric_double`, and `constant_keyword`, will all disallow supplying `fields` as a parameter in the mapping. +==== + diff --git a/docs/reference/release-notes/8.14.0.asciidoc b/docs/reference/release-notes/8.14.0.asciidoc index a203c983927cd..42f2f86a123ed 100644 --- a/docs/reference/release-notes/8.14.0.asciidoc +++ b/docs/reference/release-notes/8.14.0.asciidoc @@ -1,8 +1,350 @@ [[release-notes-8.14.0]] == {es} version 8.14.0 -coming[8.14.0] - Also see <>. +[[breaking-8.14.0]] +[float] +=== Breaking changes + +Security:: +* Prevent DLS/FLS if `replication` is assigned {es-pull}108600[#108600] +* Apply stricter Document Level Security (DLS) rules for the validate query API with the rewrite parameter {es-pull}105709[#105709] +* Apply stricter Document Level Security (DLS) rules for terms aggregations when min_doc_count is set to 0 {es-pull}105714[#105714] + +[[bug-8.14.0]] +[float] +=== Bug fixes + +Aggregations:: +* Cross check livedocs for terms aggs when index access control list is non-null {es-pull}105714[#105714] +* ESQL: Enable VALUES agg for datetime {es-pull}107016[#107016] +* Fix IOOBE in TTest aggregation when using filters {es-pull}109034[#109034] +* Validate stats formatting in standard `InternalStats` constructor {es-pull}107678[#107678] (issue: {es-issue}107671[#107671]) + +Application:: +* [Bugfix] Connector API - fix status serialisation issue in termquery {es-pull}108365[#108365] +* [Connector API] Fix bug with filtering validation toXContent {es-pull}107467[#107467] +* [Connector API] Fix bug with parsing *_doc_count nullable fields {es-pull}108854[#108854] +* [Connector API] Fix bug with with wrong target index for access control sync {es-pull}109097[#109097] + +Authorization:: +* Users with monitor privileges can access async_search/status endpoint even when setting keep_alive {es-pull}107383[#107383] + +CAT APIs:: +* Fix numeric sorts in `_cat/nodes` {es-pull}106189[#106189] (issue: {es-issue}48070[#48070]) + +CCR:: +* Add ?master_timeout query parameter to ccr apis {es-pull}105168[#105168] + +CRUD:: +* Fix `noop_update_total` is not being updated when using the `_bulk` {es-pull}105745[#105745] (issue: {es-issue}105742[#105742]) +* Use correct system index bulk executor {es-pull}106150[#106150] + +Cluster Coordination:: +* Fix support for infinite `?master_timeout` {es-pull}107050[#107050] + +Data streams:: +* Add non-indexed fields to ecs templates {es-pull}106714[#106714] +* Fix bulk NPE when retrying failure redirect after cluster block {es-pull}107598[#107598] +* Improve error message when rolling over DS alias {es-pull}106708[#106708] (issue: {es-issue}106137[#106137]) +* Only skip deleting a downsampled index if downsampling is in progress as part of DSL retention {es-pull}109020[#109020] + +Downsampling:: +* Fix downsample action request serialization {es-pull}106919[#106919] (issue: {es-issue}106917[#106917]) + +EQL:: +* Use #addWithoutBreaking when adding a negative number of bytes to the circuit breaker in `SequenceMatcher` {es-pull}107655[#107655] + +ES|QL:: +* ESQL: Allow reusing BUCKET grouping expressions in aggs {es-pull}107578[#107578] +* ESQL: Disable quoting in FROM command {es-pull}108431[#108431] +* ESQL: Fix MV_DEDUPE when using data from an index {es-pull}107577[#107577] (issue: {es-issue}104745[#104745]) +* ESQL: Fix error message when failing to resolve aggregate groupings {es-pull}108101[#108101] (issue: {es-issue}108053[#108053]) +* ESQL: Fix treating all fields as MV in COUNT pushdown {es-pull}106720[#106720] +* ESQL: Re-enable logical dependency check {es-pull}105860[#105860] +* ESQL: median, count and `count_distinct` over constants {es-pull}107414[#107414] (issues: {es-issue}105248[#105248], {es-issue}104900[#104900]) +* ES|QL fix no-length substring with supplementary (4-byte) character {es-pull}107183[#107183] +* ES|QL: Fix usage of IN operator with TEXT fields {es-pull}106654[#106654] (issue: {es-issue}105379[#105379]) +* ES|QL: Improve support for TEXT fields in functions {es-pull}106810[#106810] +* Fix docs generation of signatures for variadic functions {es-pull}107865[#107865] +* [ESQL] Mark `date_diff` as requiring all three arguments {es-pull}108834[#108834] (issue: {es-issue}108383[#108383]) + +Health:: +* Don't stop checking if the `HealthNode` persistent task is present {es-pull}105449[#105449] (issue: {es-issue}98926[#98926]) +* Health monitor concurrency fixes {es-pull}105674[#105674] (issue: {es-issue}105065[#105065]) + +Highlighting:: +* Check preTags and postTags params for empty values {es-pull}106396[#106396] (issue: {es-issue}69009[#69009]) +* added fix for inconsistent text trimming in Unified Highlighter {es-pull}99961[#99961] (issue: {es-issue}101803[#101803]) + +Infra/CLI:: +* Workaround G1 bug for JDK 22 and 22.0.1 {es-pull}108571[#108571] + +Infra/Core:: +* Add a check for the same feature being declared regular and historical {es-pull}106285[#106285] +* Fix `AffixSetting.exists` to include secure settings {es-pull}106745[#106745] +* Fix regression in get index settings (human=true) where the version was not displayed in human-readable format {es-pull}107447[#107447] +* Nativeaccess: try to load all located libsystemds {es-pull}108238[#108238] (issue: {es-issue}107878[#107878]) +* Update several references to `IndexVersion.toString` to use `toReleaseVersion` {es-pull}107828[#107828] (issue: {es-issue}107821[#107821]) +* Update several references to `TransportVersion.toString` to use `toReleaseVersion` {es-pull}107902[#107902] + +Infra/Logging:: +* Log when update AffixSetting using addAffixMapUpdateConsumer {es-pull}97072[#97072] + +Infra/Node Lifecycle:: +* Consider `ShardRouting` roles when calculating shard copies in shutdown status {es-pull}106063[#106063] +* Wait indefintely for http connections on shutdown by default {es-pull}106511[#106511] + +Infra/Scripting:: +* Guard against a null scorer in painless execute {es-pull}109048[#109048] (issue: {es-issue}43541[#43541]) +* Painless: Apply true regex limit factor with FIND and MATCH operation {es-pull}105670[#105670] + +Ingest Node:: +* Catching `StackOverflowErrors` from bad regexes in `GsubProcessor` {es-pull}106851[#106851] +* Fix `uri_parts` processor behaviour for missing extensions {es-pull}105689[#105689] (issue: {es-issue}105612[#105612]) +* Remove leading is_ prefix from Enterprise geoip docs {es-pull}108518[#108518] +* Slightly better geoip `databaseType` validation {es-pull}106889[#106889] + +License:: +* Fix lingering license warning header {es-pull}108031[#108031] (issue: {es-issue}107573[#107573]) + +Machine Learning:: +* Fix NPE in ML assignment notifier {es-pull}107312[#107312] +* Fix `startOffset` must be non-negative error in XLMRoBERTa tokenizer {es-pull}107891[#107891] (issue: {es-issue}104626[#104626]) +* Fix the position of spike, dip and distribution changes bucket when the sibling aggregation includes empty buckets {es-pull}106472[#106472] +* Make OpenAI embeddings parser more flexible {es-pull}106808[#106808] + +Mapping:: +* Dedupe terms in terms queries {es-pull}106381[#106381] +* Extend support of `allowedFields` to `getMatchingFieldNames` and `getAllFields` {es-pull}106862[#106862] +* Fix for raw mapping merge of fields named "properties" {es-pull}108867[#108867] (issue: {es-issue}108866[#108866]) +* Handle infinity during synthetic source construction for scaled float field {es-pull}107494[#107494] (issue: {es-issue}107101[#107101]) +* Handle pass-through subfields with deep nesting {es-pull}106767[#106767] +* Wrap "Pattern too complex" exception into an `IllegalArgumentException` {es-pull}109173[#109173] + +Network:: +* Fix HTTP corner-case response leaks {es-pull}105617[#105617] + +Search:: +* Add `internalClusterTest` for and fix leak in `ExpandSearchPhase` {es-pull}108562[#108562] (issue: {es-issue}108369[#108369]) +* Avoid attempting to load the same empty field twice in fetch phase {es-pull}107551[#107551] +* Bugfix: Disable eager loading `BitSetFilterCache` on Indexing Nodes {es-pull}105791[#105791] +* Cross-cluster painless/execute actions should check permissions only on target remote cluster {es-pull}105360[#105360] +* Fix error 500 on invalid `ParentIdQuery` {es-pull}105693[#105693] (issue: {es-issue}105366[#105366]) +* Fix range queries for float/half_float fields when bounds are out of type's range {es-pull}106691[#106691] +* Fixing NPE when requesting [_none_] for `stored_fields` {es-pull}104711[#104711] +* Fork when handling remote field-caps responses {es-pull}107370[#107370] +* Handle parallel calls to `createWeight` when profiling is on {es-pull}108041[#108041] (issues: {es-issue}104131[#104131], {es-issue}104235[#104235]) +* Harden field-caps request dispatcher {es-pull}108736[#108736] +* Replace `UnsupportedOperationException` with `IllegalArgumentException` for non-existing columns {es-pull}107038[#107038] +* Unable to retrieve multiple stored field values {es-pull}106575[#106575] +* Validate `model_id` is required when using the `learning_to_rank` rescorer {es-pull}107743[#107743] + +Security:: +* Disable validate when rewrite parameter is sent and the index access control list is non-null {es-pull}105709[#105709] +* Fix field caps and field level security {es-pull}106731[#106731] + +Snapshot/Restore:: +* Fix double-pausing shard snapshot {es-pull}109148[#109148] (issue: {es-issue}109143[#109143]) +* Treat 404 as empty register in `AzureBlobStore` {es-pull}108900[#108900] (issue: {es-issue}108504[#108504]) +* `SharedBlobCacheService.maybeFetchRegion` should use `computeCacheFileRegionSize` {es-pull}106685[#106685] + +TSDB:: +* Flip dynamic mapping condition when create tsid {es-pull}105636[#105636] + +Transform:: +* Consolidate permissions checks {es-pull}106413[#106413] (issue: {es-issue}105794[#105794]) +* Disable PIT for remote clusters {es-pull}107969[#107969] +* Make force-stopping the transform always remove persistent task from cluster state {es-pull}106989[#106989] (issue: {es-issue}106811[#106811]) +* Only trigger action once per thread {es-pull}107232[#107232] (issue: {es-issue}107215[#107215]) +* [Transform] Auto retry Transform start {es-pull}106243[#106243] + +Vector Search:: +* Fix multithreading copies in lib vec {es-pull}108802[#108802] +* [8.14] Fix multithreading copies in lib vec {es-pull}108810[#108810] + +[[deprecation-8.14.0]] +[float] +=== Deprecations + +Mapping:: +* Deprecate allowing `fields` in scenarios where it is ignored {es-pull}106031[#106031] + +[[enhancement-8.14.0]] +[float] +=== Enhancements + +Aggregations:: +* Add a `PriorityQueue` backed by `BigArrays` {es-pull}106361[#106361] +* All new `shard_seed` parameter for `random_sampler` agg {es-pull}104830[#104830] + +Allocation:: +* Add allocation stats {es-pull}105894[#105894] +* Add index forecasts to /_cat/allocation output {es-pull}97561[#97561] + +Application:: +* [Profiling] Add TopN Functions API {es-pull}106860[#106860] +* [Profiling] Allow to override index settings {es-pull}106172[#106172] +* [Profiling] Speed up serialization of flamegraph {es-pull}105779[#105779] + +Authentication:: +* Support Profile Activate with JWTs with client authn {es-pull}105439[#105439] (issue: {es-issue}105342[#105342]) + +Authorization:: +* Allow users to get status of own async search tasks {es-pull}106638[#106638] +* [Security Solution] Add `read` permission for third party agent indices for `kibana_system` {es-pull}107046[#107046] + +Data streams:: +* Add data stream lifecycle to kibana reporting template {es-pull}106259[#106259] + +ES|QL:: +* Add ES|QL Locate function {es-pull}106899[#106899] (issue: {es-issue}106818[#106818]) +* Add ES|QL signum function {es-pull}106866[#106866] +* Add status for enrich operator {es-pull}106036[#106036] +* Add two new OGC functions ST_X and ST_Y {es-pull}105768[#105768] +* Adjust array resizing in block builder {es-pull}106934[#106934] +* Bulk loading enrich fields in ESQL {es-pull}106796[#106796] +* ENRICH support for TEXT fields {es-pull}106435[#106435] (issue: {es-issue}105384[#105384]) +* ESQL: Add timers to many status results {es-pull}105421[#105421] +* ESQL: Allow grouping key inside stats expressions {es-pull}106579[#106579] +* ESQL: Introduce expression validation phase {es-pull}105477[#105477] (issue: {es-issue}105425[#105425]) +* ESQL: Log queries at debug level {es-pull}108257[#108257] +* ESQL: Regex improvements {es-pull}106429[#106429] +* ESQL: Sum of constants {es-pull}105454[#105454] +* ESQL: Support ST_DISJOINT {es-pull}107007[#107007] +* ESQL: Support partially folding CASE {es-pull}106094[#106094] +* ESQL: Use faster field caps {es-pull}105067[#105067] +* ESQL: extend BUCKET with spans {es-pull}107272[#107272] +* ESQL: perform a reduction on the data node {es-pull}106516[#106516] +* Expand support for ENRICH to full set supported by ES ingest processors {es-pull}106186[#106186] (issue: {es-issue}106162[#106162]) +* Introduce ordinal bytesref block {es-pull}106852[#106852] (issue: {es-issue}106387[#106387]) +* Leverage ordinals in enrich lookup {es-pull}107449[#107449] +* Serialize big array blocks {es-pull}106373[#106373] +* Serialize big array vectors {es-pull}106327[#106327] +* Specialize serialization for `ArrayVectors` {es-pull}105893[#105893] +* Specialize serialization of array blocks {es-pull}106102[#106102] +* Speed up serialization of `BytesRefArray` {es-pull}106053[#106053] +* Support ST_CONTAINS and ST_WITHIN {es-pull}106503[#106503] +* Support ST_INTERSECTS between geometry column and other geometry or string {es-pull}104907[#104907] (issue: {es-issue}104874[#104874]) + +Engine:: +* Add metric for calculating index flush time excluding waiting on locks {es-pull}107196[#107196] + +Highlighting:: +* Enable 'encoder' and 'tags_schema' highlighting settings at field level {es-pull}107224[#107224] (issue: {es-issue}94028[#94028]) + +ILM+SLM:: +* Add a flag to re-enable writes on the final index after an ILM shrink action. {es-pull}107121[#107121] (issue: {es-issue}106599[#106599]) + +Indices APIs:: +* Wait forever for `IndexTemplateRegistry` asset installation {es-pull}105985[#105985] + +Infra/CLI:: +* Enhance search tier GC options {es-pull}106526[#106526] +* Increase KDF iteration count in `KeyStoreWrapper` {es-pull}107107[#107107] + +Infra/Core:: +* Add pluggable `BuildVersion` in `NodeMetadata` {es-pull}105757[#105757] + +Infra/Metrics:: +* Infrastructure for metering the update requests {es-pull}105063[#105063] +* `DocumentParsingObserver` to accept an `indexName` to allow skipping system indices {es-pull}107041[#107041] + +Infra/Scripting:: +* String sha512() painless function {es-pull}99048[#99048] (issue: {es-issue}97691[#97691]) + +Ingest Node:: +* Add support for the 'Anonymous IP' database to the geoip processor {es-pull}107287[#107287] (issue: {es-issue}90789[#90789]) +* Add support for the 'Enterprise' database to the geoip processor {es-pull}107377[#107377] +* Adding `cache_stats` to geoip stats API {es-pull}107334[#107334] +* Support data streams in enrich policy indices {es-pull}107291[#107291] (issue: {es-issue}98836[#98836]) + +Machine Learning:: +* Add GET `_inference` for all inference endpoints {es-pull}107517[#107517] +* Added a timeout parameter to the inference API {es-pull}107242[#107242] +* Enable retrying on 500 error response from Cohere text embedding API {es-pull}105797[#105797] + +Mapping:: +* Make int8_hnsw our default index for new dense-vector fields {es-pull}106836[#106836] + +Ranking:: +* Add retrievers using the parser-only approach {es-pull}105470[#105470] + +Search:: +* Add Lucene spanish plural stemmer {es-pull}106952[#106952] +* Add `modelId` and `modelText` to `KnnVectorQueryBuilder` {es-pull}106068[#106068] +* Add a SIMD (Neon) optimised vector distance function for int8 {es-pull}106133[#106133] +* Add transport version for search load autoscaling {es-pull}106377[#106377] +* CCS with `minimize_roundtrips` performs incremental merges of each `SearchResponse` {es-pull}105781[#105781] +* Track ongoing search tasks {es-pull}107129[#107129] + +Security:: +* Invalidating cross cluster API keys requires `manage_security` {es-pull}107411[#107411] +* Show owner `realm_type` for returned API keys {es-pull}105629[#105629] + +Snapshot/Restore:: +* Add setting for max connections to S3 {es-pull}107533[#107533] +* Distinguish different snapshot failures by log level {es-pull}105622[#105622] + +Stats:: +* (API+) CAT Nodes alias for shard header to match CAT Allocation {es-pull}105847[#105847] +* Add total size in bytes to doc stats {es-pull}106840[#106840] (issue: {es-issue}97670[#97670]) + +TSDB:: +* Improve short-circuiting downsample execution {es-pull}106563[#106563] +* Support non-keyword dimensions as routing fields in TSDB {es-pull}105501[#105501] +* Text fields are stored by default in TSDB indices {es-pull}106338[#106338] (issue: {es-issue}97039[#97039]) + +Transform:: +* Check node shutdown before fail {es-pull}107358[#107358] (issue: {es-issue}100891[#100891]) +* Do not log error on node restart when the transform is already failed {es-pull}106171[#106171] (issue: {es-issue}106168[#106168]) + +[[feature-8.14.0]] +[float] +=== New features + +Application:: +* Allow `typed_keys` for search application Search API {es-pull}108007[#108007] +* [Connector API] Support cleaning up sync jobs when deleting a connector {es-pull}107253[#107253] + +ES|QL:: +* ESQL: Values aggregation function {es-pull}106065[#106065] (issue: {es-issue}103600[#103600]) +* ESQL: allow sorting by expressions and not only regular fields {es-pull}107158[#107158] +* Support ES|QL requests through the `NodeClient::execute` {es-pull}106244[#106244] + +Indices APIs:: +* Add granular error list to alias action response {es-pull}106514[#106514] (issue: {es-issue}94478[#94478]) + +Machine Learning:: +* Add Cohere rerank to `_inference` service {es-pull}106378[#106378] +* Add support for Azure OpenAI embeddings to inference service {es-pull}107178[#107178] +* Create default word based chunker {es-pull}107303[#107303] +* Text structure endpoints to determine the structure of a list of messages and of an indexed field {es-pull}105660[#105660] + +Mapping:: +* Flatten object mappings when subobjects is false {es-pull}103542[#103542] (issues: {es-issue}99860[#99860], {es-issue}103497[#103497]) + +Security:: +* Get and Query API Key with profile uid {es-pull}106531[#106531] + +Vector Search:: +* Adding support for hex-encoded byte vectors on knn-search {es-pull}105393[#105393] + +[[upgrade-8.14.0]] +[float] +=== Upgrades + +Infra/Core:: +* Upgrade jna to 5.12.1 {es-pull}105717[#105717] + +Ingest Node:: +* Updating the tika version to 2.9.1 in the ingest attachment plugin {es-pull}106315[#106315] + +Network:: +* Upgrade to Netty 4.1.107 {es-pull}105517[#105517] + +Packaging:: +* Update bundled JDK to Java 22 (again) {es-pull}108654[#108654] + diff --git a/docs/reference/release-notes/highlights.asciidoc b/docs/reference/release-notes/highlights.asciidoc index d39be07b0bf02..e6016fe438e24 100644 --- a/docs/reference/release-notes/highlights.asciidoc +++ b/docs/reference/release-notes/highlights.asciidoc @@ -44,38 +44,32 @@ faster indexing and similar retrieval latencies. {es-pull}103374[#103374] -[discrete] -[[query_phase_knn_supports_query_vector_builder]] -=== Query phase KNN now supports query_vector_builder -It is now possible to pass `model_text` and `model_id` within a `knn` query -in the [query DSL](https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-knn-query.html) to convert a text query into a dense vector and run the -nearest neighbor query on it, instead of requiring the dense vector to be -directly passed (within the `query_vector` parameter). Similar to the -[top-level knn query](https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html) (executed in the DFS phase), it is possible to supply -a `query_vector_builder` object containing a `text_embedding` object with -`model_text` (the text query to be converted into a dense vector) and -`model_id` (the identifier of a deployed model responsible for transforming -the text query into a dense vector). Note that an embedding model with the -referenced `model_id` needs to be [deployed on a ML node](https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html). -in the cluster. +// end::notable-highlights[] -{es-pull}106068[#106068] [discrete] -[[simd_neon_optimised_vector_distance_function_for_merging_int8_scalar_quantized_vectors_has_been_added]] -=== A SIMD (Neon) optimised vector distance function for merging int8 Scalar Quantized vectors has been added -An optimised int8 vector distance implementation for aarch64 has been added. -This implementation is currently only used during merging. -The vector distance implementation outperforms Lucene's Pamana Vector -implementation for binary comparisons by approx 5x (depending on the number -of dimensions). It does so by means of SIMD (Neon) intrinsics compiled into a -separate native library and link by Panama's FFI. Comparisons are performed on -off-heap mmap'ed vector data. -Macro benchmarks, SO_Dense_Vector with scalar quantization enabled, shows -significant improvements in merge times, approximately 3 times faster. +[[new_custom_parser_for_iso_8601_datetimes]] +=== New custom parser for ISO-8601 datetimes +This introduces a new custom parser for ISO-8601 datetimes, for the `iso8601`, `strict_date_optional_time`, and +`strict_date_optional_time_nanos` built-in date formats. This provides a performance improvement over the +default Java date-time parsing. Whilst it maintains much of the same behaviour, +the new parser does not accept nonsensical date-time strings that have multiple fractional seconds fields +or multiple timezone specifiers. If the new parser fails to parse a string, it will then use the previous parser +to parse it. If a large proportion of the input data consists of these invalid strings, this may cause +a small performance degradation. If you wish to force the use of the old parsers regardless, +set the JVM property `es.datetime.java_time_parsers=true` on all ES nodes. -{es-pull}106133[#106133] +{es-pull}106486[#106486] -// end::notable-highlights[] +[discrete] +[[preview_support_for_connection_type_domain_isp_databases_in_geoip_processor]] +=== Preview: Support for the 'Connection Type, 'Domain', and 'ISP' databases in the geoip processor +As a Technical Preview, the {ref}/geoip-processor.html[`geoip`] processor can now use the commercial +https://dev.maxmind.com/geoip/docs/databases/connection-type[GeoIP2 'Connection Type'], +https://dev.maxmind.com/geoip/docs/databases/domain[GeoIP2 'Domain'], +and +https://dev.maxmind.com/geoip/docs/databases/isp[GeoIP2 'ISP'] +databases from MaxMind. +{es-pull}108683[#108683] From b19bab2e00b6911d7791dafa8730d8eb5a8429db Mon Sep 17 00:00:00 2001 From: Nik Everett Date: Wed, 5 Jun 2024 15:22:11 -0400 Subject: [PATCH 22/30] ESQL: Fix a folding bug with MV_ZIP (#109404) In the process of migrating all tests off of the deprecated `AbstractScalarFunctionTestCase` I had to add some extra null tests to a few functions. This discovered a bug in MV_EXPAND where the explicit `MV_EXPAND(null, ["a", "b"])` would give different results then `a = null; MV_EXPAND(a, ["a", "b"])`. This fixes that and completes the migration off of `AbstractScalarFunctionTestCase`. --- .../src/main/resources/string.csv-spec | 35 ++++ .../function/scalar/multivalue/MvZip.java | 7 + .../AbstractScalarFunctionTestCase.java | 196 ------------------ .../function/scalar/math/AbsTests.java | 14 +- .../function/scalar/math/CeilTests.java | 22 +- .../function/scalar/math/LogTests.java | 14 +- .../function/scalar/math/PowTests.java | 14 +- .../scalar/multivalue/MvZipTests.java | 105 ++++++---- 8 files changed, 121 insertions(+), 286 deletions(-) delete mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/AbstractScalarFunctionTestCase.java diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/string.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/string.csv-spec index 063b74584a28b..53d7d1fd0d352 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/string.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/string.csv-spec @@ -839,6 +839,41 @@ emp_no:integer | full_name:keyword | full_name_2:keyword | job_positions:keyword 10005 | Kyoichi Maliniak | Maliniak,Kyoichi | null | [-2.14,13.07] | [-2.14,13.07] ; +mvZipLiteralNullDelim +required_capability: mv_sort + +FROM employees +| EVAL full_name = mv_zip(first_name, last_name, null) +| KEEP emp_no, full_name +| SORT emp_no +| LIMIT 5; + +emp_no:integer | full_name:keyword +10001 | null +10002 | null +10003 | null +10004 | null +10005 | null +; + +mvZipLiteralLongDelim +required_capability: mv_sort + +FROM employees +| EVAL full_name = mv_zip(first_name, last_name, " words words words ") +| KEEP emp_no, full_name +| SORT emp_no +| LIMIT 5; + +emp_no:integer | full_name:keyword +10001 | Georgi words words words Facello +10002 | Bezalel words words words Simmel +10003 | Parto words words words Bamford +10004 | Chirstian words words words Koblick +10005 | Kyoichi words words words Maliniak +; + + showTextFields from hosts | sort description, card, ip0, ip1 | where host == "beta" | keep host, host_group, description; ignoreOrder:true diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvZip.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvZip.java index 6298057b16013..4f42858cbedba 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvZip.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvZip.java @@ -14,6 +14,7 @@ import org.elasticsearch.compute.operator.EvalOperator; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.expression.function.OptionalArgument; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -93,6 +94,12 @@ public boolean foldable() { return mvLeft.foldable() && mvRight.foldable() && (delim == null || delim.foldable()); } + @Override + public Nullability nullable() { + // Nullability.TRUE means if *any* parameter is null we return null. We're only null if the first two are null. + return Nullability.FALSE; + } + @Override public EvalOperator.ExpressionEvaluator.Factory toEvaluator( Function toEvaluator diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/AbstractScalarFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/AbstractScalarFunctionTestCase.java deleted file mode 100644 index c8fe9e536beea..0000000000000 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/AbstractScalarFunctionTestCase.java +++ /dev/null @@ -1,196 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.esql.expression.function.scalar; - -import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.Literal; -import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; -import org.elasticsearch.xpack.esql.core.tree.Location; -import org.elasticsearch.xpack.esql.core.tree.Source; -import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; -import org.elasticsearch.xpack.esql.type.EsqlDataTypes; -import org.hamcrest.Matcher; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Comparator; -import java.util.LinkedHashSet; -import java.util.List; -import java.util.Locale; -import java.util.Set; -import java.util.stream.Stream; - -import static org.hamcrest.Matchers.equalTo; - -/** - * Base class for function tests. - * @deprecated extends from {@link AbstractFunctionTestCase} instead - * and {@link AbstractFunctionTestCase#errorsForCasesWithoutExamples}. - */ -@Deprecated -public abstract class AbstractScalarFunctionTestCase extends AbstractFunctionTestCase { - /** - * Describe supported arguments. Build each argument with - * {@link #required} or {@link #optional}. - */ - protected abstract List argSpec(); - - /** - * The data type that applying this function to arguments of this type should produce. - */ - protected abstract DataType expectedType(List argTypes); - - /** - * Define a required argument. - */ - protected final ArgumentSpec required(DataType... validTypes) { - return new ArgumentSpec(false, withNullAndSorted(validTypes)); - } - - /** - * Define an optional argument. - */ - protected final ArgumentSpec optional(DataType... validTypes) { - return new ArgumentSpec(true, withNullAndSorted(validTypes)); - } - - private Set withNullAndSorted(DataType[] validTypes) { - Set realValidTypes = new LinkedHashSet<>(); - Arrays.stream(validTypes).sorted(Comparator.comparing(DataType::nameUpper)).forEach(realValidTypes::add); - realValidTypes.add(DataType.NULL); - return realValidTypes; - } - - public Set sortedTypesSet(DataType[] validTypes, DataType... additionalTypes) { - Set mergedSet = new LinkedHashSet<>(); - Stream.concat(Stream.of(validTypes), Stream.of(additionalTypes)) - .sorted(Comparator.comparing(DataType::nameUpper)) - .forEach(mergedSet::add); - return mergedSet; - } - - /** - * All integer types (long, int, short, byte). For passing to {@link #required} or {@link #optional}. - */ - protected static DataType[] integers() { - return DataType.types().stream().filter(DataType::isInteger).toArray(DataType[]::new); - } - - /** - * All rational types (double, float, whatever). For passing to {@link #required} or {@link #optional}. - */ - protected static DataType[] rationals() { - return DataType.types().stream().filter(DataType::isRational).toArray(DataType[]::new); - } - - /** - * All numeric types (integers and rationals.) For passing to {@link #required} or {@link #optional}. - */ - protected static DataType[] numerics() { - return DataType.types().stream().filter(DataType::isNumeric).toArray(DataType[]::new); - } - - protected final DataType[] representableNumerics() { - // TODO numeric should only include representable numbers but that is a change for a followup - return DataType.types().stream().filter(DataType::isNumeric).filter(EsqlDataTypes::isRepresentable).toArray(DataType[]::new); - } - - protected record ArgumentSpec(boolean optional, Set validTypes) {} - - public final void testResolveType() { - List specs = argSpec(); - for (int mutArg = 0; mutArg < specs.size(); mutArg++) { - for (DataType mutArgType : DataType.types()) { - List args = new ArrayList<>(specs.size()); - for (int arg = 0; arg < specs.size(); arg++) { - if (mutArg == arg) { - args.add(new Literal(new Source(Location.EMPTY, "arg" + arg), "", mutArgType)); - } else { - args.add(new Literal(new Source(Location.EMPTY, "arg" + arg), "", specs.get(arg).validTypes.iterator().next())); - } - } - assertResolution(specs, args, mutArg, mutArgType, specs.get(mutArg).validTypes.contains(mutArgType)); - int optionalIdx = specs.size() - 1; - while (optionalIdx > 0 && specs.get(optionalIdx).optional()) { - args.remove(optionalIdx--); - assertResolution( - specs, - args, - mutArg, - mutArgType, - args.size() <= mutArg || specs.get(mutArg).validTypes.contains(mutArgType) - ); - } - } - } - } - - private void assertResolution(List specs, List args, int mutArg, DataType mutArgType, boolean shouldBeValid) { - Expression exp = build(new Source(Location.EMPTY, "exp"), args); - logger.info("checking {} is {}", exp.nodeString(), shouldBeValid ? "valid" : "invalid"); - if (shouldBeValid) { - assertResolveTypeValid(exp, expectedType(args.stream().map(Expression::dataType).toList())); - return; - } - Expression.TypeResolution resolution = exp.typeResolved(); - assertFalse(exp.nodeString(), resolution.resolved()); - assertThat(exp.nodeString(), resolution.message(), badTypeError(specs, mutArg, mutArgType)); - } - - protected Matcher badTypeError(List spec, int badArgPosition, DataType badArgType) { - String ordinal = spec.size() == 1 - ? "" - : TypeResolutions.ParamOrdinal.fromIndex(badArgPosition).name().toLowerCase(Locale.ROOT) + " "; - return equalTo( - ordinal - + "argument of [exp] must be [" - + expectedTypeName(spec.get(badArgPosition).validTypes()) - + "], found value [arg" - + badArgPosition - + "] type [" - + badArgType.typeName() - + "]" - ); - } - - private String expectedTypeName(Set validTypes) { - List withoutNull = validTypes.stream().filter(t -> t != DataType.NULL).toList(); - if (withoutNull.equals(Arrays.asList(strings()))) { - return "string"; - } - if (withoutNull.equals(Arrays.asList(integers())) || withoutNull.equals(List.of(DataType.INTEGER))) { - return "integer"; - } - if (withoutNull.equals(Arrays.asList(rationals()))) { - return "double"; - } - if (withoutNull.equals(Arrays.asList(numerics())) || withoutNull.equals(Arrays.asList(representableNumerics()))) { - return "numeric"; - } - if (withoutNull.equals(List.of(DataType.DATETIME))) { - return "datetime"; - } - if (withoutNull.equals(List.of(DataType.IP))) { - return "ip"; - } - List negations = Stream.concat(Stream.of(numerics()), Stream.of(DataType.DATE_PERIOD, DataType.TIME_DURATION)) - .sorted(Comparator.comparing(DataType::nameUpper)) - .toList(); - if (withoutNull.equals(negations)) { - return "numeric, date_period or time_duration"; - } - if (validTypes.equals(Set.copyOf(Arrays.asList(representableTypes())))) { - return "representable"; - } - if (validTypes.equals(Set.copyOf(Arrays.asList(representableNonSpatialTypes())))) { - return "representableNonSpatial"; - } - throw new IllegalArgumentException("can't guess expected type for " + validTypes); - } -} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/AbsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/AbsTests.java index 5158fb9aad372..63642a01fa117 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/AbsTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/AbsTests.java @@ -13,8 +13,8 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; -import org.elasticsearch.xpack.esql.expression.function.scalar.AbstractScalarFunctionTestCase; import java.math.BigInteger; import java.util.ArrayList; @@ -23,7 +23,7 @@ import static org.hamcrest.Matchers.equalTo; -public class AbsTests extends AbstractScalarFunctionTestCase { +public class AbsTests extends AbstractFunctionTestCase { @ParametersFactory public static Iterable parameters() { List suppliers = new ArrayList<>(); @@ -74,14 +74,4 @@ public AbsTests(@Name("TestCase") Supplier testCaseSu protected Expression build(Source source, List args) { return new Abs(source, args.get(0)); } - - @Override - protected List argSpec() { - return List.of(required(numerics())); - } - - @Override - protected DataType expectedType(List argTypes) { - return argTypes.get(0); - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/CeilTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/CeilTests.java index f562ccbf0071b..735113c34ca1b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/CeilTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/CeilTests.java @@ -13,8 +13,8 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; -import org.elasticsearch.xpack.esql.expression.function.scalar.AbstractScalarFunctionTestCase; import java.math.BigInteger; import java.util.ArrayList; @@ -23,7 +23,7 @@ import static org.hamcrest.Matchers.equalTo; -public class CeilTests extends AbstractScalarFunctionTestCase { +public class CeilTests extends AbstractFunctionTestCase { public CeilTests(@Name("TestCase") Supplier testCaseSupplier) { this.testCase = testCaseSupplier.get(); } @@ -31,7 +31,7 @@ public CeilTests(@Name("TestCase") Supplier testCaseS @ParametersFactory public static Iterable parameters() { List suppliers = new ArrayList<>(); - suppliers.addAll(List.of(new TestCaseSupplier("large double value", () -> { + suppliers.addAll(List.of(new TestCaseSupplier("large double value", List.of(DataType.DOUBLE), () -> { double arg = 1 / randomDouble(); return new TestCaseSupplier.TestCase( List.of(new TestCaseSupplier.TypedData(arg, DataType.DOUBLE, "arg")), @@ -39,7 +39,7 @@ public static Iterable parameters() { DataType.DOUBLE, equalTo(Math.ceil(arg)) ); - }), new TestCaseSupplier("integer value", () -> { + }), new TestCaseSupplier("integer value", List.of(DataType.INTEGER), () -> { int arg = randomInt(); return new TestCaseSupplier.TestCase( List.of(new TestCaseSupplier.TypedData(arg, DataType.INTEGER, "arg")), @@ -47,7 +47,7 @@ public static Iterable parameters() { DataType.INTEGER, equalTo(arg) ); - }), new TestCaseSupplier("long value", () -> { + }), new TestCaseSupplier("long value", List.of(DataType.LONG), () -> { long arg = randomLong(); return new TestCaseSupplier.TestCase( List.of(new TestCaseSupplier.TypedData(arg, DataType.LONG, "arg")), @@ -66,17 +66,7 @@ public static Iterable parameters() { UNSIGNED_LONG_MAX, List.of() ); - return parameterSuppliersFromTypedData(suppliers); - } - - @Override - protected DataType expectedType(List argTypes) { - return argTypes.get(0); - } - - @Override - protected List argSpec() { - return List.of(required(numerics())); + return parameterSuppliersFromTypedData(errorsForCasesWithoutExamples(anyNullIsNull(false, suppliers))); } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/LogTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/LogTests.java index a25fc66ab2d73..ce53fdbfc1851 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/LogTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/LogTests.java @@ -13,13 +13,13 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; -import org.elasticsearch.xpack.esql.expression.function.scalar.AbstractScalarFunctionTestCase; import java.util.List; import java.util.function.Supplier; -public class LogTests extends AbstractScalarFunctionTestCase { +public class LogTests extends AbstractFunctionTestCase { public LogTests(@Name("TestCase") Supplier testCaseSupplier) { this.testCase = testCaseSupplier.get(); } @@ -194,16 +194,6 @@ public static Iterable parameters() { return parameterSuppliersFromTypedData(errorsForCasesWithoutExamples(suppliers)); } - @Override - protected DataType expectedType(List argTypes) { - return DataType.DOUBLE; - } - - @Override - protected List argSpec() { - return List.of(optional(numerics()), required(numerics())); - } - @Override protected Expression build(Source source, List args) { return new Log(source, args.get(0), args.size() > 1 ? args.get(1) : null); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowTests.java index 855e3070d442f..545e7c14ff2b2 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowTests.java @@ -13,13 +13,13 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; -import org.elasticsearch.xpack.esql.expression.function.scalar.AbstractScalarFunctionTestCase; import java.util.List; import java.util.function.Supplier; -public class PowTests extends AbstractScalarFunctionTestCase { +public class PowTests extends AbstractFunctionTestCase { public PowTests(@Name("TestCase") Supplier testCaseSupplier) { this.testCase = testCaseSupplier.get(); } @@ -80,16 +80,6 @@ public static Iterable parameters() { return parameterSuppliersFromTypedData(errorsForCasesWithoutExamples(suppliers)); } - @Override - protected DataType expectedType(List argTypes) { - return DataType.DOUBLE; - } - - @Override - protected List argSpec() { - return List.of(required(numerics()), required(numerics())); - } - @Override protected Expression build(Source source, List args) { return new Pow(source, args.get(0), args.get(1)); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvZipTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvZipTests.java index e465d72555e4e..30fe420f29960 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvZipTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvZipTests.java @@ -15,8 +15,8 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; -import org.elasticsearch.xpack.esql.expression.function.scalar.AbstractScalarFunctionTestCase; import java.util.ArrayList; import java.util.List; @@ -25,59 +25,79 @@ import static java.lang.Math.max; import static org.hamcrest.Matchers.equalTo; -public class MvZipTests extends AbstractScalarFunctionTestCase { +public class MvZipTests extends AbstractFunctionTestCase { public MvZipTests(@Name("TestCase") Supplier testCaseSupplier) { this.testCase = testCaseSupplier.get(); } @ParametersFactory public static Iterable parameters() { + // Note that any null is *not* null, so we explicitly test with nulls List suppliers = new ArrayList<>(); - suppliers.add(new TestCaseSupplier(List.of(DataType.KEYWORD, DataType.KEYWORD, DataType.KEYWORD), () -> { - List left = randomList(1, 3, () -> randomLiteral(DataType.KEYWORD).value()); - List right = randomList(1, 3, () -> randomLiteral(DataType.KEYWORD).value()); - String delim = randomAlphaOfLengthBetween(1, 1); + for (DataType leftType : DataType.types()) { + if (leftType != DataType.NULL && DataType.isString(leftType) == false) { + continue; + } + for (DataType rightType : DataType.types()) { + if (rightType != DataType.NULL && DataType.isString(rightType) == false) { + continue; + } + for (DataType delimType : DataType.types()) { + if (delimType != DataType.NULL && DataType.isString(delimType) == false) { + continue; + } + suppliers.add(supplier(leftType, rightType, delimType)); + } + suppliers.add(supplier(leftType, rightType)); + } + } + + return parameterSuppliersFromTypedData(errorsForCasesWithoutExamples(suppliers)); + } + + private static TestCaseSupplier supplier(DataType leftType, DataType rightType, DataType delimType) { + return new TestCaseSupplier(List.of(leftType, rightType, delimType), () -> { + List left = randomList(leftType); + List right = randomList(rightType); + BytesRef delim = delimType == DataType.NULL ? null : new BytesRef(randomAlphaOfLength(1)); + List expected = calculateExpected(left, right, delim); return new TestCaseSupplier.TestCase( List.of( - new TestCaseSupplier.TypedData(left, DataType.KEYWORD, "mvLeft"), - new TestCaseSupplier.TypedData(right, DataType.KEYWORD, "mvRight"), - new TestCaseSupplier.TypedData(delim, DataType.KEYWORD, "delim") + new TestCaseSupplier.TypedData(left, leftType, "mvLeft"), + new TestCaseSupplier.TypedData(right, rightType, "mvRight"), + new TestCaseSupplier.TypedData(delim, delimType, "delim") ), "MvZipEvaluator[leftField=Attribute[channel=0], rightField=Attribute[channel=1], delim=Attribute[channel=2]]", DataType.KEYWORD, - equalTo(expected.size() == 1 ? expected.iterator().next() : expected) + equalTo(expected == null ? null : expected.size() == 1 ? expected.iterator().next() : expected) ); - })); + }); + } - suppliers.add(new TestCaseSupplier(List.of(DataType.TEXT, DataType.TEXT, DataType.TEXT), () -> { - List left = randomList(1, 10, () -> randomLiteral(DataType.TEXT).value()); - List right = randomList(1, 10, () -> randomLiteral(DataType.TEXT).value()); - String delim = randomAlphaOfLengthBetween(1, 1); - List expected = calculateExpected(left, right, delim); + private static TestCaseSupplier supplier(DataType leftType, DataType rightType) { + return new TestCaseSupplier(List.of(leftType, rightType), () -> { + List left = randomList(leftType); + List right = randomList(rightType); + + List expected = calculateExpected(left, right, new BytesRef(",")); return new TestCaseSupplier.TestCase( List.of( - new TestCaseSupplier.TypedData(left, DataType.TEXT, "mvLeft"), - new TestCaseSupplier.TypedData(right, DataType.TEXT, "mvRight"), - new TestCaseSupplier.TypedData(delim, DataType.TEXT, "delim") + new TestCaseSupplier.TypedData(left, leftType, "mvLeft"), + new TestCaseSupplier.TypedData(right, rightType, "mvRight") ), - "MvZipEvaluator[leftField=Attribute[channel=0], rightField=Attribute[channel=1], delim=Attribute[channel=2]]", + "MvZipEvaluator[leftField=Attribute[channel=0], rightField=Attribute[channel=1], delim=LiteralsEvaluator[lit=,]]", DataType.KEYWORD, - equalTo(expected.size() == 1 ? expected.iterator().next() : expected) + equalTo(expected == null ? null : expected.size() == 1 ? expected.iterator().next() : expected) ); - })); - - return parameterSuppliersFromTypedData(suppliers); + }); } - @Override - protected DataType expectedType(List argTypes) { - return DataType.KEYWORD; - } - - @Override - protected List argSpec() { - return List.of(required(strings()), required(strings()), optional(strings())); + private static List randomList(DataType type) { + if (type == DataType.NULL) { + return null; + } + return randomList(1, 3, () -> new BytesRef(randomAlphaOfLength(5))); } @Override @@ -85,27 +105,36 @@ protected Expression build(Source source, List args) { return new MvZip(source, args.get(0), args.get(1), args.size() > 2 ? args.get(2) : null); } - private static List calculateExpected(List left, List right, String delim) { + private static List calculateExpected(List left, List right, BytesRef delim) { + if (delim == null) { + return null; + } + if (left == null) { + return right; + } + if (right == null) { + return left; + } List expected = new ArrayList<>(max(left.size(), right.size())); int i = 0, j = 0; while (i < left.size() && j < right.size()) { BytesRefBuilder work = new BytesRefBuilder(); - work.append((BytesRef) left.get(i)); - work.append(new BytesRef(delim)); - work.append((BytesRef) right.get(j)); + work.append(left.get(i)); + work.append(delim); + work.append(right.get(j)); expected.add(work.get()); i++; j++; } while (i < left.size()) { BytesRefBuilder work = new BytesRefBuilder(); - work.append((BytesRef) left.get(i)); + work.append(left.get(i)); expected.add(work.get()); i++; } while (j < right.size()) { BytesRefBuilder work = new BytesRefBuilder(); - work.append((BytesRef) right.get(j)); + work.append(right.get(j)); expected.add(work.get()); j++; } From 7ea0d4be5480520962193032e52b354239b0ef3a Mon Sep 17 00:00:00 2001 From: James Baiera Date: Wed, 5 Jun 2024 16:51:22 -0400 Subject: [PATCH 23/30] Add support for failure stores in ILM (#108741) Failure stores on a data stream will inherit the ILM policy of the parent data stream by default. This PR adds logic to ensure failure stores are properly accounted for in data stream related ILM operations. --- .../cluster/metadata/DataStream.java | 34 +++ .../cluster/metadata/DataStreamTests.java | 81 ++++++ .../MetadataDeleteIndexServiceTests.java | 70 +++++ .../ilm/CheckNotDataStreamWriteIndexStep.java | 7 +- .../xpack/core/ilm/DeleteStep.java | 18 +- .../ReplaceDataStreamBackingIndexStep.java | 13 +- .../xpack/core/ilm/RolloverStep.java | 18 +- .../core/ilm/WaitForActiveShardsStep.java | 13 +- .../core/ilm/WaitForRolloverReadyStep.java | 32 ++- .../CheckNoDataStreamWriteIndexStepTests.java | 52 +++- .../xpack/core/ilm/DeleteStepTests.java | 251 +++++++++++++++++- ...eplaceDataStreamBackingIndexStepTests.java | 116 ++++++-- .../xpack/core/ilm/RolloverStepTests.java | 48 +++- .../core/ilm/WaitForActiveShardsTests.java | 46 +++- .../ilm/WaitForRolloverReadyStepTests.java | 70 ++++- 15 files changed, 782 insertions(+), 87 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/DataStream.java b/server/src/main/java/org/elasticsearch/cluster/metadata/DataStream.java index ae01b7c064749..bf1d9462ab89f 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/DataStream.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/DataStream.java @@ -636,6 +636,40 @@ public DataStream replaceBackingIndex(Index existingBackingIndex, Index newBacki .build(); } + /** + * Replaces the specified failure store index with a new index and returns a new {@code DataStream} instance with + * the modified backing indices. An {@code IllegalArgumentException} is thrown if the index to be replaced + * is not a failure store index for this data stream or if it is the {@code DataStream}'s failure store write index. + * + * @param existingFailureIndex the failure store index to be replaced + * @param newFailureIndex the new index that will be part of the {@code DataStream} + * @return new {@code DataStream} instance with failure store indices that contain replacement index instead of the specified + * existing index. + */ + public DataStream replaceFailureStoreIndex(Index existingFailureIndex, Index newFailureIndex) { + List currentFailureIndices = new ArrayList<>(failureIndices.indices); + int failureIndexPosition = currentFailureIndices.indexOf(existingFailureIndex); + if (failureIndexPosition == -1) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "index [%s] is not part of data stream [%s] failure store", existingFailureIndex.getName(), name) + ); + } + if (failureIndices.indices.size() == (failureIndexPosition + 1)) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "cannot replace failure index [%s] of data stream [%s] because it is the failure store write index", + existingFailureIndex.getName(), + name + ) + ); + } + currentFailureIndices.set(failureIndexPosition, newFailureIndex); + return copy().setFailureIndices(this.failureIndices.copy().setIndices(currentFailureIndices).build()) + .setGeneration(generation + 1) + .build(); + } + /** * Adds the specified index as a backing index and returns a new {@code DataStream} instance with the new combination * of backing indices. diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamTests.java index 1c1f6b314fa70..0277855db9c4c 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamTests.java @@ -731,6 +731,15 @@ public void testReplaceBackingIndexThrowsExceptionIfIndexNotPartOfDataStream() { expectThrows(IllegalArgumentException.class, () -> original.replaceBackingIndex(standaloneIndex, newBackingIndex)); } + public void testReplaceBackingIndexThrowsExceptionIfIndexPartOfFailureStore() { + DataStream original = createRandomDataStream(); + int indexToReplace = randomIntBetween(1, original.getFailureIndices().getIndices().size() - 1) - 1; + + Index failureIndex = original.getFailureIndices().getIndices().get(indexToReplace); + Index newBackingIndex = new Index("replacement-index", UUIDs.randomBase64UUID(random())); + expectThrows(IllegalArgumentException.class, () -> original.replaceBackingIndex(failureIndex, newBackingIndex)); + } + public void testReplaceBackingIndexThrowsExceptionIfReplacingWriteIndex() { int numBackingIndices = randomIntBetween(2, 32); int writeIndexPosition = numBackingIndices - 1; @@ -761,6 +770,78 @@ public void testReplaceBackingIndexThrowsExceptionIfReplacingWriteIndex() { ); } + public void testReplaceFailureIndex() { + DataStream original = createRandomDataStream(); + int indexToReplace = randomIntBetween(1, original.getFailureIndices().getIndices().size() - 1) - 1; + + Index newFailureIndex = new Index("replacement-index", UUIDs.randomBase64UUID(random())); + DataStream updated = original.replaceFailureStoreIndex( + original.getFailureIndices().getIndices().get(indexToReplace), + newFailureIndex + ); + assertThat(updated.getName(), equalTo(original.getName())); + assertThat(updated.getGeneration(), equalTo(original.getGeneration() + 1)); + assertThat(updated.getFailureIndices().getIndices().size(), equalTo(original.getFailureIndices().getIndices().size())); + assertThat(updated.getFailureIndices().getIndices().get(indexToReplace), equalTo(newFailureIndex)); + + for (int i = 0; i < original.getFailureIndices().getIndices().size(); i++) { + if (i != indexToReplace) { + assertThat(updated.getFailureIndices().getIndices().get(i), equalTo(original.getFailureIndices().getIndices().get(i))); + } + } + } + + public void testReplaceFailureIndexThrowsExceptionIfIndexNotPartOfDataStream() { + DataStream original = createRandomDataStream(); + + Index standaloneIndex = new Index("index-foo", UUIDs.randomBase64UUID(random())); + Index newFailureIndex = new Index("replacement-index", UUIDs.randomBase64UUID(random())); + expectThrows(IllegalArgumentException.class, () -> original.replaceFailureStoreIndex(standaloneIndex, newFailureIndex)); + } + + public void testReplaceFailureIndexThrowsExceptionIfIndexPartOfBackingIndices() { + DataStream original = createRandomDataStream(); + int indexToReplace = randomIntBetween(1, original.getIndices().size() - 1) - 1; + + Index backingIndex = original.getIndices().get(indexToReplace); + Index newFailureIndex = new Index("replacement-index", UUIDs.randomBase64UUID(random())); + expectThrows(IllegalArgumentException.class, () -> original.replaceFailureStoreIndex(backingIndex, newFailureIndex)); + } + + public void testReplaceFailureIndexThrowsExceptionIfReplacingWriteIndex() { + int numFailureIndices = randomIntBetween(2, 32); + int writeIndexPosition = numFailureIndices - 1; + String dataStreamName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT); + long ts = System.currentTimeMillis(); + + List indices = new ArrayList<>(1); + indices.add(new Index(DataStream.getDefaultBackingIndexName(dataStreamName, 1, ts), UUIDs.randomBase64UUID(random()))); + + List failureIndices = new ArrayList<>(numFailureIndices); + for (int i = 1; i <= numFailureIndices; i++) { + failureIndices.add(new Index(DataStream.getDefaultFailureStoreName(dataStreamName, i, ts), UUIDs.randomBase64UUID(random()))); + } + int generation = randomBoolean() ? numFailureIndices : numFailureIndices + randomIntBetween(1, 5); + DataStream original = newInstance(dataStreamName, indices, generation, null, false, null, failureIndices); + + Index newBackingIndex = new Index("replacement-index", UUIDs.randomBase64UUID(random())); + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> original.replaceFailureStoreIndex(failureIndices.get(writeIndexPosition), newBackingIndex) + ); + assertThat( + e.getMessage(), + equalTo( + String.format( + Locale.ROOT, + "cannot replace failure index [%s] of data stream [%s] because it is the failure store write index", + failureIndices.get(writeIndexPosition).getName(), + dataStreamName + ) + ) + ); + } + public void testSnapshot() { var preSnapshotDataStream = DataStreamTestHelper.randomInstance(); var indicesToRemove = randomSubsetOf(preSnapshotDataStream.getIndices()); diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataDeleteIndexServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataDeleteIndexServiceTests.java index b7bd54eef2c70..344acb7a8ff40 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataDeleteIndexServiceTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataDeleteIndexServiceTests.java @@ -274,6 +274,76 @@ public void testDeleteCurrentWriteIndexForDataStream() { ); } + public void testDeleteMultipleFailureIndexForDataStream() { + int numBackingIndices = randomIntBetween(3, 5); + int numBackingIndicesToDelete = randomIntBetween(2, numBackingIndices - 1); + String dataStreamName = randomAlphaOfLength(6).toLowerCase(Locale.ROOT); + long ts = System.currentTimeMillis(); + ClusterState before = DataStreamTestHelper.getClusterStateWithDataStreams( + List.of(new Tuple<>(dataStreamName, numBackingIndices)), + List.of(), + ts, + Settings.EMPTY, + 1, + false, + true + ); + + List indexNumbersToDelete = randomSubsetOf( + numBackingIndicesToDelete, + IntStream.rangeClosed(1, numBackingIndices - 1).boxed().toList() + ); + + Set indicesToDelete = new HashSet<>(); + for (int k : indexNumbersToDelete) { + indicesToDelete.add(before.metadata().index(DataStream.getDefaultFailureStoreName(dataStreamName, k, ts)).getIndex()); + } + ClusterState after = MetadataDeleteIndexService.deleteIndices(before, indicesToDelete, Settings.EMPTY); + + DataStream dataStream = after.metadata().dataStreams().get(dataStreamName); + assertThat(dataStream, notNullValue()); + assertThat(dataStream.getFailureIndices().getIndices().size(), equalTo(numBackingIndices - indexNumbersToDelete.size())); + for (Index i : indicesToDelete) { + assertThat(after.metadata().getIndices().get(i.getName()), nullValue()); + assertFalse(dataStream.getFailureIndices().getIndices().contains(i)); + } + assertThat(after.metadata().getIndices().size(), equalTo((2 * numBackingIndices) - indexNumbersToDelete.size())); + } + + public void testDeleteCurrentWriteFailureIndexForDataStream() { + int numBackingIndices = randomIntBetween(1, 5); + String dataStreamName = randomAlphaOfLength(6).toLowerCase(Locale.ROOT); + long ts = System.currentTimeMillis(); + ClusterState before = DataStreamTestHelper.getClusterStateWithDataStreams( + List.of(new Tuple<>(dataStreamName, numBackingIndices)), + List.of(), + ts, + Settings.EMPTY, + 1, + false, + true + ); + + Index indexToDelete = before.metadata() + .index(DataStream.getDefaultFailureStoreName(dataStreamName, numBackingIndices, ts)) + .getIndex(); + Exception e = expectThrows( + IllegalArgumentException.class, + () -> MetadataDeleteIndexService.deleteIndices(before, Set.of(indexToDelete), Settings.EMPTY) + ); + + assertThat( + e.getMessage(), + containsString( + "index [" + + indexToDelete.getName() + + "] is the failure store write index for data stream [" + + dataStreamName + + "] and cannot be deleted" + ) + ); + } + private ClusterState clusterState(String index) { IndexMetadata indexMetadata = IndexMetadata.builder(index) .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersionUtils.randomVersion(random()))) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/CheckNotDataStreamWriteIndexStep.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/CheckNotDataStreamWriteIndexStep.java index e716a18738bca..28b04bc9614bb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/CheckNotDataStreamWriteIndexStep.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/CheckNotDataStreamWriteIndexStep.java @@ -62,14 +62,15 @@ public Result isConditionMet(Index index, ClusterState clusterState) { assert indexAbstraction != null : "invalid cluster metadata. index [" + indexName + "] was not found"; DataStream dataStream = indexAbstraction.getParentDataStream(); if (dataStream != null) { - assert dataStream.getWriteIndex() != null : dataStream.getName() + " has no write index"; - if (dataStream.getWriteIndex().equals(index)) { + boolean isFailureStoreWriteIndex = index.equals(dataStream.getFailureStoreWriteIndex()); + if (isFailureStoreWriteIndex || dataStream.getWriteIndex().equals(index)) { String errorMessage = String.format( Locale.ROOT, - "index [%s] is the write index for data stream [%s], pausing " + "index [%s] is the%s write index for data stream [%s], pausing " + "ILM execution of lifecycle [%s] until this index is no longer the write index for the data stream via manual or " + "automated rollover", indexName, + isFailureStoreWriteIndex ? " failure store" : "", dataStream.getName(), policyName ); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/DeleteStep.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/DeleteStep.java index ba6b6f9366c61..282f32da28a6b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/DeleteStep.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/DeleteStep.java @@ -17,6 +17,7 @@ import org.elasticsearch.cluster.metadata.IndexAbstraction; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.Index; import java.util.Locale; @@ -40,13 +41,17 @@ public void performDuringNoSnapshot(IndexMetadata indexMetadata, ClusterState cu DataStream dataStream = indexAbstraction.getParentDataStream(); if (dataStream != null) { - assert dataStream.getWriteIndex() != null : dataStream.getName() + " has no write index"; + Index failureStoreWriteIndex = dataStream.getFailureStoreWriteIndex(); + boolean isFailureStoreWriteIndex = failureStoreWriteIndex != null && indexName.equals(failureStoreWriteIndex.getName()); // using index name equality across this if/else branch as the UUID of the index might change via restoring a data stream // with one index from snapshot - if (dataStream.getIndices().size() == 1 && dataStream.getWriteIndex().getName().equals(indexName)) { - // This is the last index in the data stream, the entire stream - // needs to be deleted, because we can't have an empty data stream + if (dataStream.getIndices().size() == 1 + && isFailureStoreWriteIndex == false + && dataStream.getWriteIndex().getName().equals(indexName)) { + // This is the last backing index in the data stream, and it's being deleted because the policy doesn't have a rollover + // phase. The entire stream needs to be deleted, because we can't have an empty list of data stream backing indices. + // We do this even if there are multiple failure store indices because otherwise we would never delete the index. DeleteDataStreamAction.Request deleteReq = new DeleteDataStreamAction.Request(new String[] { dataStream.getName() }); getClient().execute( DeleteDataStreamAction.INSTANCE, @@ -54,13 +59,14 @@ public void performDuringNoSnapshot(IndexMetadata indexMetadata, ClusterState cu listener.delegateFailureAndWrap((l, response) -> l.onResponse(null)) ); return; - } else if (dataStream.getWriteIndex().getName().equals(indexName)) { + } else if (isFailureStoreWriteIndex || dataStream.getWriteIndex().getName().equals(indexName)) { String errorMessage = String.format( Locale.ROOT, - "index [%s] is the write index for data stream [%s]. " + "index [%s] is the%s write index for data stream [%s]. " + "stopping execution of lifecycle [%s] as a data stream's write index cannot be deleted. manually rolling over the" + " index will resume the execution of the policy as the index will not be the data stream's write index anymore", indexName, + isFailureStoreWriteIndex ? " failure store" : "", dataStream.getName(), policyName ); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/ReplaceDataStreamBackingIndexStep.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/ReplaceDataStreamBackingIndexStep.java index 9de08c8693a12..3962768e94212 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/ReplaceDataStreamBackingIndexStep.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/ReplaceDataStreamBackingIndexStep.java @@ -86,14 +86,15 @@ public ClusterState performAction(Index index, ClusterState clusterState) { throw new IllegalStateException(errorMessage); } - assert dataStream.getWriteIndex() != null : dataStream.getName() + " has no write index"; - if (dataStream.getWriteIndex().equals(index)) { + boolean isFailureStoreWriteIndex = index.equals(dataStream.getFailureStoreWriteIndex()); + if (isFailureStoreWriteIndex || dataStream.getWriteIndex().equals(index)) { String errorMessage = String.format( Locale.ROOT, - "index [%s] is the write index for data stream [%s], pausing " + "index [%s] is the%s write index for data stream [%s], pausing " + "ILM execution of lifecycle [%s] until this index is no longer the write index for the data stream via manual or " + "automated rollover", originalIndex, + isFailureStoreWriteIndex ? " failure store" : "", dataStream.getName(), policyName ); @@ -114,8 +115,10 @@ public ClusterState performAction(Index index, ClusterState clusterState) { throw new IllegalStateException(errorMessage); } - Metadata.Builder newMetaData = Metadata.builder(clusterState.getMetadata()) - .put(dataStream.replaceBackingIndex(index, targetIndexMetadata.getIndex())); + DataStream updatedDataStream = dataStream.isFailureStoreIndex(originalIndex) + ? dataStream.replaceFailureStoreIndex(index, targetIndexMetadata.getIndex()) + : dataStream.replaceBackingIndex(index, targetIndexMetadata.getIndex()); + Metadata.Builder newMetaData = Metadata.builder(clusterState.getMetadata()).put(updatedDataStream); return ClusterState.builder(clusterState).metadata(newMetaData).build(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/RolloverStep.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/RolloverStep.java index 26300f646d617..3e6c00eeadba4 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/RolloverStep.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/RolloverStep.java @@ -11,6 +11,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.admin.indices.rollover.RolloverRequest; import org.elasticsearch.action.support.ActiveShardCount; +import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterStateObserver; @@ -57,13 +58,16 @@ public void performAction( IndexAbstraction indexAbstraction = currentClusterState.metadata().getIndicesLookup().get(indexName); assert indexAbstraction != null : "expected the index " + indexName + " to exist in the lookup but it didn't"; final String rolloverTarget; + final boolean targetFailureStore; DataStream dataStream = indexAbstraction.getParentDataStream(); if (dataStream != null) { - assert dataStream.getWriteIndex() != null : "datastream " + dataStream.getName() + " has no write index"; - if (dataStream.getWriteIndex().equals(indexMetadata.getIndex()) == false) { + boolean isFailureStoreWriteIndex = indexMetadata.getIndex().equals(dataStream.getFailureStoreWriteIndex()); + targetFailureStore = dataStream.isFailureStoreIndex(indexMetadata.getIndex().getName()); + if (isFailureStoreWriteIndex == false && dataStream.getWriteIndex().equals(indexMetadata.getIndex()) == false) { logger.warn( - "index [{}] is not the write index for data stream [{}]. skipping rollover for policy [{}]", + "index [{}] is not the {}write index for data stream [{}]. skipping rollover for policy [{}]", indexName, + targetFailureStore ? "failure store " : "", dataStream.getName(), indexMetadata.getLifecyclePolicyName() ); @@ -115,10 +119,18 @@ public void performAction( } rolloverTarget = rolloverAlias; + targetFailureStore = false; } // Calling rollover with no conditions will always roll over the index RolloverRequest rolloverRequest = new RolloverRequest(rolloverTarget, null).masterNodeTimeout(TimeValue.MAX_VALUE); + if (targetFailureStore) { + rolloverRequest.setIndicesOptions( + IndicesOptions.builder(rolloverRequest.indicesOptions()) + .failureStoreOptions(opts -> opts.includeFailureIndices(true).includeRegularIndices(false)) + .build() + ); + } // We don't wait for active shards when we perform the rollover because the // {@link org.elasticsearch.xpack.core.ilm.WaitForActiveShardsStep} step will do so rolloverRequest.setWaitForActiveShards(ActiveShardCount.NONE); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/WaitForActiveShardsStep.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/WaitForActiveShardsStep.java index b6cf8b0bdd663..71c99d7f21848 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/WaitForActiveShardsStep.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/WaitForActiveShardsStep.java @@ -30,7 +30,7 @@ import static org.elasticsearch.cluster.metadata.IndexMetadata.parseIndexNameCounter; /** - * After we performed the index rollover we wait for the the configured number of shards for the rolled over index (ie. newly created + * After we performed the index rollover we wait for the configured number of shards for the rolled over index (ie. newly created * index) to become available. */ public class WaitForActiveShardsStep extends ClusterStateWaitStep { @@ -84,10 +84,17 @@ public Result isConditionMet(Index index, ClusterState clusterState) { if (dataStream != null) { IndexAbstraction dataStreamAbstraction = metadata.getIndicesLookup().get(dataStream.getName()); assert dataStreamAbstraction != null : dataStream.getName() + " datastream is not present in the metadata indices lookup"; - if (dataStreamAbstraction.getWriteIndex() == null) { + // Determine which write index we care about right now: + final Index rolledIndex; + if (dataStream.isFailureStoreIndex(index.getName())) { + rolledIndex = dataStream.getFailureStoreWriteIndex(); + } else { + rolledIndex = dataStream.getWriteIndex(); + } + if (rolledIndex == null) { return getErrorResultOnNullMetadata(getKey(), index); } - IndexMetadata rolledIndexMeta = metadata.index(dataStreamAbstraction.getWriteIndex()); + IndexMetadata rolledIndexMeta = metadata.index(rolledIndex); rolledIndexName = rolledIndexMeta.getIndex().getName(); waitForActiveShardsSettingValue = rolledIndexMeta.getSettings().get(IndexMetadata.SETTING_WAIT_FOR_ACTIVE_SHARDS.getKey()); } else { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/WaitForRolloverReadyStep.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/WaitForRolloverReadyStep.java index acb36bd015e4b..7b751994222b1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/WaitForRolloverReadyStep.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/WaitForRolloverReadyStep.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.admin.indices.rollover.RolloverConditions; import org.elasticsearch.action.admin.indices.rollover.RolloverRequest; +import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.metadata.DataStream; import org.elasticsearch.cluster.metadata.IndexAbstraction; @@ -83,13 +84,16 @@ public void evaluateCondition(Metadata metadata, Index index, Listener listener, IndexAbstraction indexAbstraction = metadata.getIndicesLookup().get(index.getName()); assert indexAbstraction != null : "invalid cluster metadata. index [" + index.getName() + "] was not found"; final String rolloverTarget; + final boolean targetFailureStore; DataStream dataStream = indexAbstraction.getParentDataStream(); if (dataStream != null) { - assert dataStream.getWriteIndex() != null : "datastream " + dataStream.getName() + " has no write index"; - if (dataStream.getWriteIndex().equals(index) == false) { + targetFailureStore = dataStream.isFailureStoreIndex(index.getName()); + boolean isFailureStoreWriteIndex = index.equals(dataStream.getFailureStoreWriteIndex()); + if (isFailureStoreWriteIndex == false && dataStream.getWriteIndex().equals(index) == false) { logger.warn( - "index [{}] is not the write index for data stream [{}]. skipping rollover for policy [{}]", + "index [{}] is not the {}write index for data stream [{}]. skipping rollover for policy [{}]", index.getName(), + targetFailureStore ? "failure store " : "", dataStream.getName(), metadata.index(index).getLifecyclePolicyName() ); @@ -194,12 +198,18 @@ public void evaluateCondition(Metadata metadata, Index index, Listener listener, } rolloverTarget = rolloverAlias; + targetFailureStore = false; } // if we should only rollover if not empty, *and* if neither an explicit min_docs nor an explicit min_primary_shard_docs // has been specified on this policy, then inject a default min_docs: 1 condition so that we do not rollover empty indices boolean rolloverOnlyIfHasDocuments = LifecycleSettings.LIFECYCLE_ROLLOVER_ONLY_IF_HAS_DOCUMENTS_SETTING.get(metadata.settings()); - RolloverRequest rolloverRequest = createRolloverRequest(rolloverTarget, masterTimeout, rolloverOnlyIfHasDocuments); + RolloverRequest rolloverRequest = createRolloverRequest( + rolloverTarget, + masterTimeout, + rolloverOnlyIfHasDocuments, + targetFailureStore + ); getClient().admin().indices().rolloverIndex(rolloverRequest, ActionListener.wrap(response -> { final var conditionStatus = response.getConditionStatus(); @@ -226,10 +236,22 @@ public void evaluateCondition(Metadata metadata, Index index, Listener listener, * @return A RolloverRequest suitable for passing to {@code rolloverIndex(...) }. */ // visible for testing - RolloverRequest createRolloverRequest(String rolloverTarget, TimeValue masterTimeout, boolean rolloverOnlyIfHasDocuments) { + RolloverRequest createRolloverRequest( + String rolloverTarget, + TimeValue masterTimeout, + boolean rolloverOnlyIfHasDocuments, + boolean targetFailureStore + ) { RolloverRequest rolloverRequest = new RolloverRequest(rolloverTarget, null).masterNodeTimeout(masterTimeout); rolloverRequest.dryRun(true); rolloverRequest.setConditions(applyDefaultConditions(conditions, rolloverOnlyIfHasDocuments)); + if (targetFailureStore) { + rolloverRequest.setIndicesOptions( + IndicesOptions.builder(rolloverRequest.indicesOptions()) + .failureStoreOptions(opts -> opts.includeFailureIndices(true).includeRegularIndices(false)) + .build() + ); + } return rolloverRequest; } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/CheckNoDataStreamWriteIndexStepTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/CheckNoDataStreamWriteIndexStepTests.java index 33d571fbe8599..e0957239e33a8 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/CheckNoDataStreamWriteIndexStepTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/CheckNoDataStreamWriteIndexStepTests.java @@ -65,29 +65,45 @@ public void testStepCompleteIfIndexIsNotPartOfDataStream() { public void testStepIncompleteIfIndexIsTheDataStreamWriteIndex() { String dataStreamName = randomAlphaOfLength(10); - String indexName = DataStream.getDefaultBackingIndexName(dataStreamName, 1); + long ts = System.currentTimeMillis(); + String indexName = DataStream.getDefaultBackingIndexName(dataStreamName, 1, ts); + String failureIndexName = DataStream.getDefaultFailureStoreName(dataStreamName, 1, ts); String policyName = "test-ilm-policy"; IndexMetadata indexMetadata = IndexMetadata.builder(indexName) .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) .numberOfShards(randomIntBetween(1, 5)) .numberOfReplicas(randomIntBetween(0, 5)) .build(); + IndexMetadata failureIndexMetadata = IndexMetadata.builder(failureIndexName) + .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); ClusterState clusterState = ClusterState.builder(emptyClusterState()) .metadata( - Metadata.builder().put(indexMetadata, true).put(newInstance(dataStreamName, List.of(indexMetadata.getIndex()))).build() + Metadata.builder() + .put(indexMetadata, true) + .put(failureIndexMetadata, true) + .put(newInstance(dataStreamName, List.of(indexMetadata.getIndex()), List.of(failureIndexMetadata.getIndex()))) + .build() ) .build(); - ClusterStateWaitStep.Result result = createRandomInstance().isConditionMet(indexMetadata.getIndex(), clusterState); + boolean useFailureStore = randomBoolean(); + IndexMetadata indexToOperateOn = useFailureStore ? failureIndexMetadata : indexMetadata; + String expectedIndexName = indexToOperateOn.getIndex().getName(); + ClusterStateWaitStep.Result result = createRandomInstance().isConditionMet(indexToOperateOn.getIndex(), clusterState); assertThat(result.isComplete(), is(false)); SingleMessageFieldInfo info = (SingleMessageFieldInfo) result.getInfomationContext(); assertThat( info.getMessage(), is( "index [" - + indexName - + "] is the write index for data stream [" + + expectedIndexName + + "] is the " + + (useFailureStore ? "failure store " : "") + + "write index for data stream [" + dataStreamName + "], " + "pausing ILM execution of lifecycle [" @@ -100,33 +116,51 @@ public void testStepIncompleteIfIndexIsTheDataStreamWriteIndex() { public void testStepCompleteIfPartOfDataStreamButNotWriteIndex() { String dataStreamName = randomAlphaOfLength(10); - String indexName = DataStream.getDefaultBackingIndexName(dataStreamName, 1); + long ts = System.currentTimeMillis(); + String indexName = DataStream.getDefaultBackingIndexName(dataStreamName, 1, ts); + String failureIndexName = DataStream.getDefaultFailureStoreName(dataStreamName, 1, ts); String policyName = "test-ilm-policy"; IndexMetadata indexMetadata = IndexMetadata.builder(indexName) .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) .numberOfShards(randomIntBetween(1, 5)) .numberOfReplicas(randomIntBetween(0, 5)) .build(); + IndexMetadata failureIndexMetadata = IndexMetadata.builder(failureIndexName) + .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); - String writeIndexName = DataStream.getDefaultBackingIndexName(dataStreamName, 2); + String writeIndexName = DataStream.getDefaultBackingIndexName(dataStreamName, 2, ts); + String failureStoreWriteIndexName = DataStream.getDefaultFailureStoreName(dataStreamName, 2, ts); IndexMetadata writeIndexMetadata = IndexMetadata.builder(writeIndexName) .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) .numberOfShards(randomIntBetween(1, 5)) .numberOfReplicas(randomIntBetween(0, 5)) .build(); + IndexMetadata failureStoreWriteIndexMetadata = IndexMetadata.builder(failureStoreWriteIndexName) + .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); List backingIndices = List.of(indexMetadata.getIndex(), writeIndexMetadata.getIndex()); + List failureIndices = List.of(failureIndexMetadata.getIndex(), failureStoreWriteIndexMetadata.getIndex()); ClusterState clusterState = ClusterState.builder(emptyClusterState()) .metadata( Metadata.builder() .put(indexMetadata, true) .put(writeIndexMetadata, true) - .put(newInstance(dataStreamName, backingIndices)) + .put(failureIndexMetadata, true) + .put(failureStoreWriteIndexMetadata, true) + .put(newInstance(dataStreamName, backingIndices, failureIndices)) .build() ) .build(); - ClusterStateWaitStep.Result result = createRandomInstance().isConditionMet(indexMetadata.getIndex(), clusterState); + boolean useFailureStore = randomBoolean(); + IndexMetadata indexToOperateOn = useFailureStore ? failureIndexMetadata : indexMetadata; + ClusterStateWaitStep.Result result = createRandomInstance().isConditionMet(indexToOperateOn.getIndex(), clusterState); assertThat(result.isComplete(), is(true)); assertThat(result.getInfomationContext(), is(nullValue())); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/DeleteStepTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/DeleteStepTests.java index 7445e82da3ecf..af4dc67d5dcbd 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/DeleteStepTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/DeleteStepTests.java @@ -8,6 +8,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.admin.indices.delete.DeleteIndexRequest; +import org.elasticsearch.action.datastreams.DeleteDataStreamAction; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.cluster.ClusterState; @@ -130,10 +131,11 @@ public void testPerformActionCallsFailureListenerIfIndexIsTheDataStreamWriteInde String policyName = "test-ilm-policy"; String dataStreamName = randomAlphaOfLength(10); + long ts = System.currentTimeMillis(); IndexMetadata index1; { - String indexName = DataStream.getDefaultBackingIndexName(dataStreamName, 1); + String indexName = DataStream.getDefaultBackingIndexName(dataStreamName, 1, ts); index1 = IndexMetadata.builder(indexName) .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) .numberOfShards(randomIntBetween(1, 5)) @@ -142,25 +144,258 @@ public void testPerformActionCallsFailureListenerIfIndexIsTheDataStreamWriteInde } IndexMetadata sourceIndexMetadata; { - - String indexName = DataStream.getDefaultBackingIndexName(dataStreamName, 2); + String indexName = DataStream.getDefaultBackingIndexName(dataStreamName, 2, ts); sourceIndexMetadata = IndexMetadata.builder(indexName) .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) .numberOfShards(randomIntBetween(1, 5)) .numberOfReplicas(randomIntBetween(0, 5)) .build(); } + IndexMetadata failureIndex1; + { + String indexName = DataStream.getDefaultFailureStoreName(dataStreamName, 1, ts); + failureIndex1 = IndexMetadata.builder(indexName) + .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); + } + IndexMetadata failureSourceIndexMetadata; + { + String indexName = DataStream.getDefaultFailureStoreName(dataStreamName, 2, ts); + failureSourceIndexMetadata = IndexMetadata.builder(indexName) + .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); + } DataStream dataStream = DataStreamTestHelper.newInstance( dataStreamName, - List.of(index1.getIndex(), sourceIndexMetadata.getIndex()) + List.of(index1.getIndex(), sourceIndexMetadata.getIndex()), + List.of(failureIndex1.getIndex(), failureSourceIndexMetadata.getIndex()) + ); + ClusterState clusterState = ClusterState.builder(emptyClusterState()) + .metadata( + Metadata.builder() + .put(index1, false) + .put(sourceIndexMetadata, false) + .put(failureIndex1, false) + .put(failureSourceIndexMetadata, false) + .put(dataStream) + .build() + ) + .build(); + + AtomicBoolean listenerCalled = new AtomicBoolean(false); + final boolean useFailureStore = randomBoolean(); + final IndexMetadata indexToOperateOn = useFailureStore ? failureSourceIndexMetadata : sourceIndexMetadata; + createRandomInstance().performDuringNoSnapshot(indexToOperateOn, clusterState, new ActionListener<>() { + @Override + public void onResponse(Void complete) { + listenerCalled.set(true); + fail("unexpected listener callback"); + } + + @Override + public void onFailure(Exception e) { + listenerCalled.set(true); + assertThat( + e.getMessage(), + is( + "index [" + + indexToOperateOn.getIndex().getName() + + "] is the " + + (useFailureStore ? "failure store " : "") + + "write index for data stream [" + + dataStreamName + + "]. stopping execution of lifecycle [test-ilm-policy] as a data stream's write index cannot be deleted. " + + "manually rolling over the index will resume the execution of the policy as the index will not be the " + + "data stream's write index anymore" + ) + ); + } + }); + + assertThat(listenerCalled.get(), is(true)); + } + + public void testDeleteWorksIfWriteIndexIsTheOnlyIndexInDataStream() throws Exception { + String policyName = "test-ilm-policy"; + String dataStreamName = randomAlphaOfLength(10); + long ts = System.currentTimeMillis(); + + // Single backing index + IndexMetadata index1; + { + String indexName = DataStream.getDefaultBackingIndexName(dataStreamName, 1, ts); + index1 = IndexMetadata.builder(indexName) + .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); + } + + DataStream dataStream = DataStreamTestHelper.newInstance(dataStreamName, List.of(index1.getIndex()), List.of()); + + ClusterState clusterState = ClusterState.builder(emptyClusterState()) + .metadata(Metadata.builder().put(index1, false).put(dataStream).build()) + .build(); + + Mockito.doAnswer(invocation -> { + DeleteDataStreamAction.Request request = (DeleteDataStreamAction.Request) invocation.getArguments()[1]; + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + assertNotNull(request); + assertEquals(1, request.getNames().length); + assertEquals(dataStreamName, request.getNames()[0]); + listener.onResponse(null); + return null; + }).when(client).execute(any(), any(), any()); + + // Try on the normal data stream - It should delete the data stream + DeleteStep step = createRandomInstance(); + PlainActionFuture.get(f -> step.performAction(index1, clusterState, null, f)); + + Mockito.verify(client, Mockito.only()).execute(any(), any(), any()); + Mockito.verify(adminClient, Mockito.never()).indices(); + Mockito.verify(indicesClient, Mockito.never()).delete(any(), any()); + } + + public void testDeleteWorksIfWriteIndexIsTheOnlyIndexInDataStreamWithFailureStore() throws Exception { + String policyName = "test-ilm-policy"; + String dataStreamName = randomAlphaOfLength(10); + long ts = System.currentTimeMillis(); + + // Single backing index + IndexMetadata index1; + { + String indexName = DataStream.getDefaultBackingIndexName(dataStreamName, 1, ts); + index1 = IndexMetadata.builder(indexName) + .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); + } + + // Multiple failure indices + IndexMetadata failureIndex1; + { + String indexName = DataStream.getDefaultFailureStoreName(dataStreamName, 1, ts); + failureIndex1 = IndexMetadata.builder(indexName) + .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); + } + IndexMetadata failureSourceIndexMetadata; + { + String indexName = DataStream.getDefaultFailureStoreName(dataStreamName, 2, ts); + failureSourceIndexMetadata = IndexMetadata.builder(indexName) + .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); + } + + DataStream dataStreamWithFailureIndices = DataStreamTestHelper.newInstance( + dataStreamName, + List.of(index1.getIndex()), + List.of(failureIndex1.getIndex(), failureSourceIndexMetadata.getIndex()) ); + ClusterState clusterState = ClusterState.builder(emptyClusterState()) - .metadata(Metadata.builder().put(index1, false).put(sourceIndexMetadata, false).put(dataStream).build()) + .metadata( + Metadata.builder() + .put(index1, false) + .put(failureIndex1, false) + .put(failureSourceIndexMetadata, false) + .put(dataStreamWithFailureIndices) + .build() + ) + .build(); + + Mockito.doAnswer(invocation -> { + DeleteDataStreamAction.Request request = (DeleteDataStreamAction.Request) invocation.getArguments()[1]; + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + assertNotNull(request); + assertEquals(1, request.getNames().length); + assertEquals(dataStreamName, request.getNames()[0]); + listener.onResponse(null); + return null; + }).when(client).execute(any(), any(), any()); + + // Again, the deletion should work since the data stream would be fully deleted anyway if the failure store were disabled. + DeleteStep step = createRandomInstance(); + PlainActionFuture.get(f -> step.performAction(index1, clusterState, null, f)); + + Mockito.verify(client, Mockito.only()).execute(any(), any(), any()); + Mockito.verify(adminClient, Mockito.never()).indices(); + Mockito.verify(indicesClient, Mockito.never()).delete(any(), any()); + } + + public void testDeletingFailureStoreWriteIndexOnDataStreamWithSingleBackingIndex() { + doThrow( + new IllegalStateException( + "the client must not be called in this test as we should fail in the step validation phase before we call the delete API" + ) + ).when(indicesClient).delete(any(DeleteIndexRequest.class), anyActionListener()); + + String policyName = "test-ilm-policy"; + String dataStreamName = randomAlphaOfLength(10); + long ts = System.currentTimeMillis(); + + // Single backing index + IndexMetadata index1; + { + String indexName = DataStream.getDefaultBackingIndexName(dataStreamName, 1, ts); + index1 = IndexMetadata.builder(indexName) + .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); + } + + // Multiple failure indices + IndexMetadata failureIndex1; + { + String indexName = DataStream.getDefaultFailureStoreName(dataStreamName, 1, ts); + failureIndex1 = IndexMetadata.builder(indexName) + .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); + } + IndexMetadata failureSourceIndexMetadata; + { + String indexName = DataStream.getDefaultFailureStoreName(dataStreamName, 2, ts); + failureSourceIndexMetadata = IndexMetadata.builder(indexName) + .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); + } + + DataStream dataStreamWithFailureIndices = DataStreamTestHelper.newInstance( + dataStreamName, + List.of(index1.getIndex()), + List.of(failureIndex1.getIndex(), failureSourceIndexMetadata.getIndex()) + ); + + ClusterState clusterState = ClusterState.builder(emptyClusterState()) + .metadata( + Metadata.builder() + .put(index1, false) + .put(failureIndex1, false) + .put(failureSourceIndexMetadata, false) + .put(dataStreamWithFailureIndices) + .build() + ) .build(); AtomicBoolean listenerCalled = new AtomicBoolean(false); - createRandomInstance().performDuringNoSnapshot(sourceIndexMetadata, clusterState, new ActionListener<>() { + createRandomInstance().performDuringNoSnapshot(failureSourceIndexMetadata, clusterState, new ActionListener<>() { @Override public void onResponse(Void complete) { listenerCalled.set(true); @@ -174,8 +409,8 @@ public void onFailure(Exception e) { e.getMessage(), is( "index [" - + sourceIndexMetadata.getIndex().getName() - + "] is the write index for data stream [" + + failureSourceIndexMetadata.getIndex().getName() + + "] is the failure store write index for data stream [" + dataStreamName + "]. stopping execution of lifecycle [test-ilm-policy] as a data stream's write index cannot be deleted. " + "manually rolling over the index will resume the execution of the policy as the index will not be the " diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/ReplaceDataStreamBackingIndexStepTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/ReplaceDataStreamBackingIndexStepTests.java index 2a49be703574b..a3318e68305c6 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/ReplaceDataStreamBackingIndexStepTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/ReplaceDataStreamBackingIndexStepTests.java @@ -68,55 +68,85 @@ public void testPerformActionThrowsExceptionIfIndexIsNotPartOfDataStream() { public void testPerformActionThrowsExceptionIfIndexIsTheDataStreamWriteIndex() { String dataStreamName = randomAlphaOfLength(10); - String indexName = DataStream.getDefaultBackingIndexName(dataStreamName, 1); + long ts = System.currentTimeMillis(); + String indexName = DataStream.getDefaultBackingIndexName(dataStreamName, 1, ts); + String failureIndexName = DataStream.getDefaultFailureStoreName(dataStreamName, 1, ts); String policyName = "test-ilm-policy"; IndexMetadata sourceIndexMetadata = IndexMetadata.builder(indexName) .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) .numberOfShards(randomIntBetween(1, 5)) .numberOfReplicas(randomIntBetween(0, 5)) .build(); + IndexMetadata failureSourceIndexMetadata = IndexMetadata.builder(failureIndexName) + .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); ClusterState clusterState = ClusterState.builder(emptyClusterState()) .metadata( Metadata.builder() .put(sourceIndexMetadata, true) - .put(newInstance(dataStreamName, List.of(sourceIndexMetadata.getIndex()))) + .put(failureSourceIndexMetadata, true) + .put( + newInstance(dataStreamName, List.of(sourceIndexMetadata.getIndex()), List.of(failureSourceIndexMetadata.getIndex())) + ) .build() ) .build(); - expectThrows(IllegalStateException.class, () -> createRandomInstance().performAction(sourceIndexMetadata.getIndex(), clusterState)); + boolean useFailureStore = randomBoolean(); + IndexMetadata indexToOperateOn = useFailureStore ? failureSourceIndexMetadata : sourceIndexMetadata; + expectThrows(IllegalStateException.class, () -> createRandomInstance().performAction(indexToOperateOn.getIndex(), clusterState)); } public void testPerformActionThrowsExceptionIfTargetIndexIsMissing() { String dataStreamName = randomAlphaOfLength(10); - String indexName = DataStream.getDefaultBackingIndexName(dataStreamName, 1); + long ts = System.currentTimeMillis(); + String indexName = DataStream.getDefaultBackingIndexName(dataStreamName, 1, ts); + String failureIndexName = DataStream.getDefaultFailureStoreName(dataStreamName, 1, ts); String policyName = "test-ilm-policy"; IndexMetadata sourceIndexMetadata = IndexMetadata.builder(indexName) .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) .numberOfShards(randomIntBetween(1, 5)) .numberOfReplicas(randomIntBetween(0, 5)) .build(); + IndexMetadata failureSourceIndexMetadata = IndexMetadata.builder(failureIndexName) + .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); - String writeIndexName = DataStream.getDefaultBackingIndexName(dataStreamName, 2); + String writeIndexName = DataStream.getDefaultBackingIndexName(dataStreamName, 2, ts); + String failureWriteIndexName = DataStream.getDefaultFailureStoreName(dataStreamName, 2, ts); IndexMetadata writeIndexMetadata = IndexMetadata.builder(writeIndexName) .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) .numberOfShards(randomIntBetween(1, 5)) .numberOfReplicas(randomIntBetween(0, 5)) .build(); + IndexMetadata failureWriteIndexMetadata = IndexMetadata.builder(failureWriteIndexName) + .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); List backingIndices = List.of(sourceIndexMetadata.getIndex(), writeIndexMetadata.getIndex()); + List failureIndices = List.of(failureSourceIndexMetadata.getIndex(), failureWriteIndexMetadata.getIndex()); ClusterState clusterState = ClusterState.builder(emptyClusterState()) .metadata( Metadata.builder() .put(sourceIndexMetadata, true) .put(writeIndexMetadata, true) - .put(newInstance(dataStreamName, backingIndices)) + .put(failureSourceIndexMetadata, true) + .put(failureWriteIndexMetadata, true) + .put(newInstance(dataStreamName, backingIndices, failureIndices)) .build() ) .build(); - expectThrows(IllegalStateException.class, () -> createRandomInstance().performAction(sourceIndexMetadata.getIndex(), clusterState)); + boolean useFailureStore = randomBoolean(); + IndexMetadata indexToOperateOn = useFailureStore ? failureSourceIndexMetadata : sourceIndexMetadata; + expectThrows(IllegalStateException.class, () -> createRandomInstance().performAction(indexToOperateOn.getIndex(), clusterState)); } public void testPerformActionIsNoOpIfIndexIsMissing() { @@ -129,23 +159,39 @@ public void testPerformActionIsNoOpIfIndexIsMissing() { public void testPerformAction() { String dataStreamName = randomAlphaOfLength(10); - String indexName = DataStream.getDefaultBackingIndexName(dataStreamName, 1); + long ts = System.currentTimeMillis(); + String indexName = DataStream.getDefaultBackingIndexName(dataStreamName, 1, ts); + String failureIndexName = DataStream.getDefaultFailureStoreName(dataStreamName, 1, ts); String policyName = "test-ilm-policy"; IndexMetadata sourceIndexMetadata = IndexMetadata.builder(indexName) .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) .numberOfShards(randomIntBetween(1, 5)) .numberOfReplicas(randomIntBetween(0, 5)) .build(); + IndexMetadata failureSourceIndexMetadata = IndexMetadata.builder(failureIndexName) + .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); - String writeIndexName = DataStream.getDefaultBackingIndexName(dataStreamName, 2); + String writeIndexName = DataStream.getDefaultBackingIndexName(dataStreamName, 2, ts); + String failureWriteIndexName = DataStream.getDefaultFailureStoreName(dataStreamName, 2, ts); IndexMetadata writeIndexMetadata = IndexMetadata.builder(writeIndexName) .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) .numberOfShards(randomIntBetween(1, 5)) .numberOfReplicas(randomIntBetween(0, 5)) .build(); + IndexMetadata failureWriteIndexMetadata = IndexMetadata.builder(failureWriteIndexName) + .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); + + boolean useFailureStore = randomBoolean(); + String indexNameToUse = useFailureStore ? failureIndexName : indexName; String indexPrefix = "test-prefix-"; - String targetIndex = indexPrefix + indexName; + String targetIndex = indexPrefix + indexNameToUse; IndexMetadata targetIndexMetadata = IndexMetadata.builder(targetIndex) .settings(settings(IndexVersion.current())) @@ -154,12 +200,15 @@ public void testPerformAction() { .build(); List backingIndices = List.of(sourceIndexMetadata.getIndex(), writeIndexMetadata.getIndex()); + List failureIndices = List.of(failureSourceIndexMetadata.getIndex(), failureWriteIndexMetadata.getIndex()); ClusterState clusterState = ClusterState.builder(emptyClusterState()) .metadata( Metadata.builder() .put(sourceIndexMetadata, true) .put(writeIndexMetadata, true) - .put(newInstance(dataStreamName, backingIndices)) + .put(failureSourceIndexMetadata, true) + .put(failureWriteIndexMetadata, true) + .put(newInstance(dataStreamName, backingIndices, failureIndices)) .put(targetIndexMetadata, true) .build() ) @@ -168,12 +217,16 @@ public void testPerformAction() { ReplaceDataStreamBackingIndexStep replaceSourceIndexStep = new ReplaceDataStreamBackingIndexStep( randomStepKey(), randomStepKey(), - (index, state) -> indexPrefix + index + (index, state) -> indexPrefix + indexNameToUse ); - ClusterState newState = replaceSourceIndexStep.performAction(sourceIndexMetadata.getIndex(), clusterState); + IndexMetadata indexToOperateOn = useFailureStore ? failureSourceIndexMetadata : sourceIndexMetadata; + ClusterState newState = replaceSourceIndexStep.performAction(indexToOperateOn.getIndex(), clusterState); DataStream updatedDataStream = newState.metadata().dataStreams().get(dataStreamName); - assertThat(updatedDataStream.getIndices().size(), is(2)); - assertThat(updatedDataStream.getIndices().get(0), is(targetIndexMetadata.getIndex())); + DataStream.DataStreamIndices resultIndices = useFailureStore + ? updatedDataStream.getFailureIndices() + : updatedDataStream.getBackingIndices(); + assertThat(resultIndices.getIndices().size(), is(2)); + assertThat(resultIndices.getIndices().get(0), is(targetIndexMetadata.getIndex())); } /** @@ -181,23 +234,38 @@ public void testPerformAction() { */ public void testPerformActionSameOriginalTargetError() { String dataStreamName = randomAlphaOfLength(10); - String writeIndexName = DataStream.getDefaultBackingIndexName(dataStreamName, 2); + long ts = System.currentTimeMillis(); + String writeIndexName = DataStream.getDefaultBackingIndexName(dataStreamName, 2, ts); + String failureWriteIndexName = DataStream.getDefaultFailureStoreName(dataStreamName, 2, ts); String indexName = writeIndexName; + String failureIndexName = failureWriteIndexName; String policyName = "test-ilm-policy"; IndexMetadata sourceIndexMetadata = IndexMetadata.builder(indexName) .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) .numberOfShards(randomIntBetween(1, 5)) .numberOfReplicas(randomIntBetween(0, 5)) .build(); + IndexMetadata failureSourceIndexMetadata = IndexMetadata.builder(failureIndexName) + .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); IndexMetadata writeIndexMetadata = IndexMetadata.builder(writeIndexName) .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) .numberOfShards(randomIntBetween(1, 5)) .numberOfReplicas(randomIntBetween(0, 5)) .build(); + IndexMetadata failureWriteIndexMetadata = IndexMetadata.builder(failureWriteIndexName) + .settings(settings(IndexVersion.current()).put(LifecycleSettings.LIFECYCLE_NAME, policyName)) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); String indexPrefix = "test-prefix-"; - String targetIndex = indexPrefix + indexName; + boolean useFailureStore = randomBoolean(); + String indexNameToUse = useFailureStore ? failureIndexName : indexName; + String targetIndex = indexPrefix + indexNameToUse; IndexMetadata targetIndexMetadata = IndexMetadata.builder(targetIndex) .settings(settings(IndexVersion.current())) @@ -206,12 +274,15 @@ public void testPerformActionSameOriginalTargetError() { .build(); List backingIndices = List.of(writeIndexMetadata.getIndex()); + List failureIndices = List.of(failureWriteIndexMetadata.getIndex()); ClusterState clusterState = ClusterState.builder(emptyClusterState()) .metadata( Metadata.builder() .put(sourceIndexMetadata, true) .put(writeIndexMetadata, true) - .put(newInstance(dataStreamName, backingIndices)) + .put(failureSourceIndexMetadata, true) + .put(failureWriteIndexMetadata, true) + .put(newInstance(dataStreamName, backingIndices, failureIndices)) .put(targetIndexMetadata, true) .build() ) @@ -222,14 +293,17 @@ public void testPerformActionSameOriginalTargetError() { randomStepKey(), (index, state) -> indexPrefix + index ); + IndexMetadata indexToOperateOn = useFailureStore ? failureSourceIndexMetadata : sourceIndexMetadata; IllegalStateException ex = expectThrows( IllegalStateException.class, - () -> replaceSourceIndexStep.performAction(sourceIndexMetadata.getIndex(), clusterState) + () -> replaceSourceIndexStep.performAction(indexToOperateOn.getIndex(), clusterState) ); assertEquals( "index [" - + writeIndexName - + "] is the write index for data stream [" + + indexNameToUse + + "] is the " + + (useFailureStore ? "failure store " : "") + + "write index for data stream [" + dataStreamName + "], pausing ILM execution of lifecycle [test-ilm-policy] until this index is no longer the write index for the data " + "stream via manual or automated rollover", diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/RolloverStepTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/RolloverStepTests.java index 1fcfc1fb287c4..f25a862362540 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/RolloverStepTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/RolloverStepTests.java @@ -96,7 +96,13 @@ public void testPerformAction() throws Exception { public void testPerformActionOnDataStream() throws Exception { String dataStreamName = "test-datastream"; - IndexMetadata indexMetadata = IndexMetadata.builder(DataStream.getDefaultBackingIndexName(dataStreamName, 1)) + long ts = System.currentTimeMillis(); + IndexMetadata indexMetadata = IndexMetadata.builder(DataStream.getDefaultBackingIndexName(dataStreamName, 1, ts)) + .settings(settings(IndexVersion.current())) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); + IndexMetadata failureIndexMetadata = IndexMetadata.builder(DataStream.getDefaultFailureStoreName(dataStreamName, 1, ts)) .settings(settings(IndexVersion.current())) .numberOfShards(randomIntBetween(1, 5)) .numberOfReplicas(randomIntBetween(0, 5)) @@ -107,9 +113,16 @@ public void testPerformActionOnDataStream() throws Exception { mockClientRolloverCall(dataStreamName); ClusterState clusterState = ClusterState.builder(ClusterName.DEFAULT) - .metadata(Metadata.builder().put(newInstance(dataStreamName, List.of(indexMetadata.getIndex()))).put(indexMetadata, true)) + .metadata( + Metadata.builder() + .put(newInstance(dataStreamName, List.of(indexMetadata.getIndex()), List.of(failureIndexMetadata.getIndex()))) + .put(indexMetadata, true) + .put(failureIndexMetadata, true) + ) .build(); - PlainActionFuture.get(f -> step.performAction(indexMetadata, clusterState, null, f)); + boolean useFailureStore = randomBoolean(); + IndexMetadata indexToOperateOn = useFailureStore ? failureIndexMetadata : indexMetadata; + PlainActionFuture.get(f -> step.performAction(indexToOperateOn, clusterState, null, f)); Mockito.verify(client, Mockito.only()).admin(); Mockito.verify(adminClient, Mockito.only()).indices(); @@ -118,13 +131,24 @@ public void testPerformActionOnDataStream() throws Exception { public void testSkipRolloverIfDataStreamIsAlreadyRolledOver() throws Exception { String dataStreamName = "test-datastream"; - IndexMetadata firstGenerationIndex = IndexMetadata.builder(DataStream.getDefaultBackingIndexName(dataStreamName, 1)) + long ts = System.currentTimeMillis(); + IndexMetadata firstGenerationIndex = IndexMetadata.builder(DataStream.getDefaultBackingIndexName(dataStreamName, 1, ts)) + .settings(settings(IndexVersion.current())) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); + IndexMetadata failureFirstGenerationIndex = IndexMetadata.builder(DataStream.getDefaultFailureStoreName(dataStreamName, 1, ts)) .settings(settings(IndexVersion.current())) .numberOfShards(randomIntBetween(1, 5)) .numberOfReplicas(randomIntBetween(0, 5)) .build(); - IndexMetadata writeIndex = IndexMetadata.builder(DataStream.getDefaultBackingIndexName(dataStreamName, 2)) + IndexMetadata writeIndex = IndexMetadata.builder(DataStream.getDefaultBackingIndexName(dataStreamName, 2, ts)) + .settings(settings(IndexVersion.current())) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); + IndexMetadata failureWriteIndex = IndexMetadata.builder(DataStream.getDefaultFailureStoreName(dataStreamName, 2, ts)) .settings(settings(IndexVersion.current())) .numberOfShards(randomIntBetween(1, 5)) .numberOfReplicas(randomIntBetween(0, 5)) @@ -136,10 +160,20 @@ public void testSkipRolloverIfDataStreamIsAlreadyRolledOver() throws Exception { Metadata.builder() .put(firstGenerationIndex, true) .put(writeIndex, true) - .put(newInstance(dataStreamName, List.of(firstGenerationIndex.getIndex(), writeIndex.getIndex()))) + .put(failureFirstGenerationIndex, true) + .put(failureWriteIndex, true) + .put( + newInstance( + dataStreamName, + List.of(firstGenerationIndex.getIndex(), writeIndex.getIndex()), + List.of(failureFirstGenerationIndex.getIndex(), failureWriteIndex.getIndex()) + ) + ) ) .build(); - PlainActionFuture.get(f -> step.performAction(firstGenerationIndex, clusterState, null, f)); + boolean useFailureStore = randomBoolean(); + IndexMetadata indexToOperateOn = useFailureStore ? failureFirstGenerationIndex : firstGenerationIndex; + PlainActionFuture.get(f -> step.performAction(indexToOperateOn, clusterState, null, f)); verifyNoMoreInteractions(client); verifyNoMoreInteractions(adminClient); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/WaitForActiveShardsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/WaitForActiveShardsTests.java index d9fd2d8a2247e..f5f36781e011b 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/WaitForActiveShardsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/WaitForActiveShardsTests.java @@ -170,13 +170,24 @@ public void testResultEvaluatedOnOnlyIndexTheAliasPointsToIfWriteIndexIsNull() { public void testResultEvaluatedOnDataStream() throws IOException { String dataStreamName = "test-datastream"; - IndexMetadata originalIndexMeta = IndexMetadata.builder(DataStream.getDefaultBackingIndexName(dataStreamName, 1)) + long ts = System.currentTimeMillis(); + IndexMetadata originalIndexMeta = IndexMetadata.builder(DataStream.getDefaultBackingIndexName(dataStreamName, 1, ts)) + .settings(settings(IndexVersion.current())) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); + IndexMetadata failureOriginalIndexMeta = IndexMetadata.builder(DataStream.getDefaultFailureStoreName(dataStreamName, 1, ts)) .settings(settings(IndexVersion.current())) .numberOfShards(randomIntBetween(1, 5)) .numberOfReplicas(randomIntBetween(0, 5)) .build(); - IndexMetadata rolledIndexMeta = IndexMetadata.builder(DataStream.getDefaultBackingIndexName(dataStreamName, 2)) + IndexMetadata rolledIndexMeta = IndexMetadata.builder(DataStream.getDefaultBackingIndexName(dataStreamName, 2, ts)) + .settings(settings(IndexVersion.current()).put("index.write.wait_for_active_shards", "3")) + .numberOfShards(1) + .numberOfReplicas(3) + .build(); + IndexMetadata failureRolledIndexMeta = IndexMetadata.builder(DataStream.getDefaultFailureStoreName(dataStreamName, 2, ts)) .settings(settings(IndexVersion.current()).put("index.write.wait_for_active_shards", "3")) .numberOfShards(1) .numberOfReplicas(3) @@ -186,28 +197,53 @@ public void testResultEvaluatedOnDataStream() throws IOException { ShardRoutingRoleStrategy.NO_SHARD_CREATION, rolledIndexMeta.getIndex() ); + IndexRoutingTable.Builder failureRoutingTable = new IndexRoutingTable.Builder( + ShardRoutingRoleStrategy.NO_SHARD_CREATION, + failureRolledIndexMeta.getIndex() + ); routingTable.addShard( TestShardRouting.newShardRouting(rolledIndexMeta.getIndex().getName(), 0, "node", null, true, ShardRoutingState.STARTED) ); routingTable.addShard( TestShardRouting.newShardRouting(rolledIndexMeta.getIndex().getName(), 0, "node2", null, false, ShardRoutingState.STARTED) ); + failureRoutingTable.addShard( + TestShardRouting.newShardRouting(failureRolledIndexMeta.getIndex().getName(), 0, "node", null, true, ShardRoutingState.STARTED) + ); + failureRoutingTable.addShard( + TestShardRouting.newShardRouting( + failureRolledIndexMeta.getIndex().getName(), + 0, + "node2", + null, + false, + ShardRoutingState.STARTED + ) + ); ClusterState clusterState = ClusterState.builder(ClusterName.DEFAULT) .metadata( Metadata.builder() .put( - DataStreamTestHelper.newInstance(dataStreamName, List.of(originalIndexMeta.getIndex(), rolledIndexMeta.getIndex())) + DataStreamTestHelper.newInstance( + dataStreamName, + List.of(originalIndexMeta.getIndex(), rolledIndexMeta.getIndex()), + List.of(failureOriginalIndexMeta.getIndex(), failureRolledIndexMeta.getIndex()) + ) ) .put(originalIndexMeta, true) .put(rolledIndexMeta, true) + .put(failureOriginalIndexMeta, true) + .put(failureRolledIndexMeta, true) ) - .routingTable(RoutingTable.builder().add(routingTable.build()).build()) + .routingTable(RoutingTable.builder().add(routingTable.build()).add(failureRoutingTable.build()).build()) .build(); WaitForActiveShardsStep waitForActiveShardsStep = createRandomInstance(); - ClusterStateWaitStep.Result result = waitForActiveShardsStep.isConditionMet(originalIndexMeta.getIndex(), clusterState); + boolean useFailureStore = randomBoolean(); + IndexMetadata indexToOperateOn = useFailureStore ? failureOriginalIndexMeta : originalIndexMeta; + ClusterStateWaitStep.Result result = waitForActiveShardsStep.isConditionMet(indexToOperateOn.getIndex(), clusterState); assertThat(result.isComplete(), is(false)); XContentBuilder expected = new WaitForActiveShardsStep.ActiveShardsInfo(2, "3", false).toXContent( diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/WaitForRolloverReadyStepTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/WaitForRolloverReadyStepTests.java index 2d39d093d149e..15958e9396d81 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/WaitForRolloverReadyStepTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ilm/WaitForRolloverReadyStepTests.java @@ -35,6 +35,7 @@ import org.elasticsearch.index.IndexVersion; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.ToXContentObject; +import org.mockito.ArgumentCaptor; import org.mockito.Mockito; import java.util.Collections; @@ -47,6 +48,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -254,7 +256,14 @@ public void onFailure(Exception e) { public void testEvaluateConditionOnDataStreamTarget() { String dataStreamName = "test-datastream"; - IndexMetadata indexMetadata = IndexMetadata.builder(DataStream.getDefaultBackingIndexName(dataStreamName, 1)) + long ts = System.currentTimeMillis(); + boolean failureStoreIndex = randomBoolean(); + IndexMetadata indexMetadata = IndexMetadata.builder(DataStream.getDefaultBackingIndexName(dataStreamName, 1, ts)) + .settings(settings(IndexVersion.current())) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); + IndexMetadata failureStoreMetadata = IndexMetadata.builder(DataStream.getDefaultFailureStoreName(dataStreamName, 1, ts)) .settings(settings(IndexVersion.current())) .numberOfShards(randomIntBetween(1, 5)) .numberOfReplicas(randomIntBetween(0, 5)) @@ -267,9 +276,17 @@ public void testEvaluateConditionOnDataStreamTarget() { SetOnce conditionsMet = new SetOnce<>(); Metadata metadata = Metadata.builder() .put(indexMetadata, true) - .put(DataStreamTestHelper.newInstance(dataStreamName, List.of(indexMetadata.getIndex()))) + .put(failureStoreMetadata, true) + .put( + DataStreamTestHelper.newInstance( + dataStreamName, + List.of(indexMetadata.getIndex()), + List.of(failureStoreMetadata.getIndex()) + ) + ) .build(); - step.evaluateCondition(metadata, indexMetadata.getIndex(), new AsyncWaitStep.Listener() { + IndexMetadata indexToOperateOn = failureStoreIndex ? failureStoreMetadata : indexMetadata; + step.evaluateCondition(metadata, indexToOperateOn.getIndex(), new AsyncWaitStep.Listener() { @Override public void onResponse(boolean complete, ToXContentObject infomationContext) { @@ -286,18 +303,38 @@ public void onFailure(Exception e) { verify(client, Mockito.only()).admin(); verify(adminClient, Mockito.only()).indices(); - verify(indicesClient, Mockito.only()).rolloverIndex(Mockito.any(), Mockito.any()); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(RolloverRequest.class); + verify(indicesClient, Mockito.only()).rolloverIndex(requestCaptor.capture(), Mockito.any()); + + RolloverRequest request = requestCaptor.getValue(); + assertThat(request.indicesOptions().failureStoreOptions().includeFailureIndices(), equalTo(failureStoreIndex)); + assertThat(request.indicesOptions().failureStoreOptions().includeRegularIndices(), not(equalTo(failureStoreIndex))); } public void testSkipRolloverIfDataStreamIsAlreadyRolledOver() { String dataStreamName = "test-datastream"; - IndexMetadata firstGenerationIndex = IndexMetadata.builder(DataStream.getDefaultBackingIndexName(dataStreamName, 1)) + long ts = System.currentTimeMillis(); + boolean failureStoreIndex = randomBoolean(); + IndexMetadata firstGenerationIndex = IndexMetadata.builder(DataStream.getDefaultBackingIndexName(dataStreamName, 1, ts)) .settings(settings(IndexVersion.current())) .numberOfShards(randomIntBetween(1, 5)) .numberOfReplicas(randomIntBetween(0, 5)) .build(); - IndexMetadata writeIndex = IndexMetadata.builder(DataStream.getDefaultBackingIndexName(dataStreamName, 2)) + IndexMetadata writeIndex = IndexMetadata.builder(DataStream.getDefaultBackingIndexName(dataStreamName, 2, ts)) + .settings(settings(IndexVersion.current())) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); + + IndexMetadata firstGenerationFailureIndex = IndexMetadata.builder(DataStream.getDefaultFailureStoreName(dataStreamName, 1, ts)) + .settings(settings(IndexVersion.current())) + .numberOfShards(randomIntBetween(1, 5)) + .numberOfReplicas(randomIntBetween(0, 5)) + .build(); + + IndexMetadata writeFailureIndex = IndexMetadata.builder(DataStream.getDefaultFailureStoreName(dataStreamName, 2, ts)) .settings(settings(IndexVersion.current())) .numberOfShards(randomIntBetween(1, 5)) .numberOfReplicas(randomIntBetween(0, 5)) @@ -308,9 +345,18 @@ public void testSkipRolloverIfDataStreamIsAlreadyRolledOver() { Metadata metadata = Metadata.builder() .put(firstGenerationIndex, true) .put(writeIndex, true) - .put(DataStreamTestHelper.newInstance(dataStreamName, List.of(firstGenerationIndex.getIndex(), writeIndex.getIndex()))) + .put(firstGenerationFailureIndex, true) + .put(writeFailureIndex, true) + .put( + DataStreamTestHelper.newInstance( + dataStreamName, + List.of(firstGenerationIndex.getIndex(), writeIndex.getIndex()), + List.of(firstGenerationFailureIndex.getIndex(), writeFailureIndex.getIndex()) + ) + ) .build(); - step.evaluateCondition(metadata, firstGenerationIndex.getIndex(), new AsyncWaitStep.Listener() { + IndexMetadata indexToOperateOn = failureStoreIndex ? firstGenerationFailureIndex : firstGenerationIndex; + step.evaluateCondition(metadata, indexToOperateOn.getIndex(), new AsyncWaitStep.Listener() { @Override public void onResponse(boolean complete, ToXContentObject infomationContext) { @@ -665,7 +711,7 @@ public void testCreateRolloverRequestRolloverOnlyIfHasDocuments() { String rolloverTarget = randomAlphaOfLength(5); TimeValue masterTimeout = randomPositiveTimeValue(); - RolloverRequest request = step.createRolloverRequest(rolloverTarget, masterTimeout, rolloverOnlyIfHasDocuments); + RolloverRequest request = step.createRolloverRequest(rolloverTarget, masterTimeout, rolloverOnlyIfHasDocuments, false); assertThat(request.getRolloverTarget(), is(rolloverTarget)); assertThat(request.masterNodeTimeout(), is(masterTimeout)); @@ -704,7 +750,7 @@ public void testCreateRolloverRequestRolloverBeyondMaximumPrimaryShardDocCount() c.getMinDocs(), c.getMinPrimaryShardDocs() ); - RolloverRequest request = step.createRolloverRequest(rolloverTarget, masterTimeout, true); + RolloverRequest request = step.createRolloverRequest(rolloverTarget, masterTimeout, true, false); assertThat(request.getRolloverTarget(), is(rolloverTarget)); assertThat(request.masterNodeTimeout(), is(masterTimeout)); assertThat(request.isDryRun(), is(true)); // it's always a dry_run @@ -725,7 +771,7 @@ public void testCreateRolloverRequestRolloverBeyondMaximumPrimaryShardDocCount() c.getMinDocs(), c.getMinPrimaryShardDocs() ); - request = step.createRolloverRequest(rolloverTarget, masterTimeout, true); + request = step.createRolloverRequest(rolloverTarget, masterTimeout, true, false); assertThat(request.getRolloverTarget(), is(rolloverTarget)); assertThat(request.masterNodeTimeout(), is(masterTimeout)); assertThat(request.isDryRun(), is(true)); // it's always a dry_run @@ -747,7 +793,7 @@ public void testCreateRolloverRequestRolloverBeyondMaximumPrimaryShardDocCount() c.getMinDocs(), c.getMinPrimaryShardDocs() ); - request = step.createRolloverRequest(rolloverTarget, masterTimeout, true); + request = step.createRolloverRequest(rolloverTarget, masterTimeout, true, false); assertThat(request.getRolloverTarget(), is(rolloverTarget)); assertThat(request.masterNodeTimeout(), is(masterTimeout)); assertThat(request.isDryRun(), is(true)); // it's always a dry_run From 1558bb8f54f6054bf9a12151094a1bafb43e9153 Mon Sep 17 00:00:00 2001 From: Mark Tozzi Date: Wed, 5 Jun 2024 17:02:35 -0400 Subject: [PATCH 24/30] Remove unused core expressions (#109406) Removes the leaf expressions for arithmetic operations from esql-core. These aren't used in ESQL, and we don't want to refer to them by accident. Earlier work in https://github.com/elastic/elasticsearch/pull/109216 migrated the remaining references to these classes to refer to their esql versions. --- .../predicate/operator/arithmetic/Add.java | 45 ------------- .../DateTimeArithmeticOperation.java | 39 ----------- .../predicate/operator/arithmetic/Div.java | 53 --------------- .../predicate/operator/arithmetic/Mod.java | 34 ---------- .../predicate/operator/arithmetic/Mul.java | 66 ------------------- .../predicate/operator/arithmetic/Sub.java | 36 ---------- 6 files changed, 273 deletions(-) delete mode 100644 x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Add.java delete mode 100644 x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/DateTimeArithmeticOperation.java delete mode 100644 x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Div.java delete mode 100644 x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Mod.java delete mode 100644 x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Mul.java delete mode 100644 x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Sub.java diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Add.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Add.java deleted file mode 100644 index 5b16b478f6519..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Add.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -package org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic; - -import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.tree.NodeInfo; -import org.elasticsearch.xpack.esql.core.tree.Source; - -/** - * Addition function ({@code a + b}). - */ -public class Add extends DateTimeArithmeticOperation implements BinaryComparisonInversible { - public Add(Source source, Expression left, Expression right) { - super(source, left, right, DefaultBinaryArithmeticOperation.ADD); - } - - @Override - protected NodeInfo info() { - return NodeInfo.create(this, Add::new, left(), right()); - } - - @Override - protected Add replaceChildren(Expression left, Expression right) { - return new Add(source(), left, right); - } - - @Override - public Add swapLeftAndRight() { - return new Add(source(), right(), left()); - } - - @Override - public ArithmeticOperationFactory binaryComparisonInverse() { - return Sub::new; - } - - @Override - protected boolean isCommutative() { - return true; - } -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/DateTimeArithmeticOperation.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/DateTimeArithmeticOperation.java deleted file mode 100644 index 9e08cea749a34..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/DateTimeArithmeticOperation.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic; - -import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.tree.Source; -import org.elasticsearch.xpack.esql.core.type.DataType; - -abstract class DateTimeArithmeticOperation extends ArithmeticOperation { - - DateTimeArithmeticOperation(Source source, Expression left, Expression right, BinaryArithmeticOperation operation) { - super(source, left, right, operation); - } - - @Override - protected TypeResolution resolveType() { - if (childrenResolved() == false) { - return new TypeResolution("Unresolved children"); - } - - // arithmetic operation can work on numbers in QL - - DataType l = left().dataType(); - DataType r = right().dataType(); - - // 1. both are numbers - if (l.isNumeric() && r.isNumeric()) { - return TypeResolution.TYPE_RESOLVED; - } - - // fall-back to default checks - return super.resolveType(); - } -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Div.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Div.java deleted file mode 100644 index 5f4c660479579..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Div.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -package org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic; - -import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.tree.NodeInfo; -import org.elasticsearch.xpack.esql.core.tree.Source; -import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.core.type.DataTypeConverter; - -/** - * Division function ({@code a / b}). - */ -public class Div extends ArithmeticOperation implements BinaryComparisonInversible { - - private DataType dataType; - - public Div(Source source, Expression left, Expression right) { - this(source, left, right, null); - } - - public Div(Source source, Expression left, Expression right, DataType dataType) { - super(source, left, right, DefaultBinaryArithmeticOperation.DIV); - this.dataType = dataType; - } - - @Override - protected NodeInfo
info() { - return NodeInfo.create(this, Div::new, left(), right(), dataType); - } - - @Override - protected Div replaceChildren(Expression newLeft, Expression newRight) { - return new Div(source(), newLeft, newRight, dataType); - } - - @Override - public DataType dataType() { - if (dataType == null) { - dataType = DataTypeConverter.commonType(left().dataType(), right().dataType()); - } - return dataType; - } - - @Override - public ArithmeticOperationFactory binaryComparisonInverse() { - return Mul::new; - } -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Mod.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Mod.java deleted file mode 100644 index dea7d4e02e0b3..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Mod.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -package org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic; - -import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.tree.NodeInfo; -import org.elasticsearch.xpack.esql.core.tree.Source; - -/** - * Modulo - * function ({@code a % b}). - * - * Note this operator is also registered as a function (needed for ODBC/SQL) purposes. - */ -public class Mod extends ArithmeticOperation { - - public Mod(Source source, Expression left, Expression right) { - super(source, left, right, DefaultBinaryArithmeticOperation.MOD); - } - - @Override - protected NodeInfo info() { - return NodeInfo.create(this, Mod::new, left(), right()); - } - - @Override - protected Mod replaceChildren(Expression newLeft, Expression newRight) { - return new Mod(source(), newLeft, newRight); - } -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Mul.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Mul.java deleted file mode 100644 index db46ecf81ea1d..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Mul.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -package org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic; - -import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.tree.NodeInfo; -import org.elasticsearch.xpack.esql.core.tree.Source; -import org.elasticsearch.xpack.esql.core.type.DataType; - -import static org.elasticsearch.common.logging.LoggerMessageFormat.format; - -/** - * Multiplication function ({@code a * b}). - */ -public class Mul extends ArithmeticOperation implements BinaryComparisonInversible { - - public Mul(Source source, Expression left, Expression right) { - super(source, left, right, DefaultBinaryArithmeticOperation.MUL); - } - - @Override - protected TypeResolution resolveType() { - if (childrenResolved() == false) { - return new TypeResolution("Unresolved children"); - } - - DataType l = left().dataType(); - DataType r = right().dataType(); - - // 1. both are numbers - if (DataType.isNullOrNumeric(l) && DataType.isNullOrNumeric(r)) { - return TypeResolution.TYPE_RESOLVED; - } - - return new TypeResolution(format(null, "[{}] has arguments with incompatible types [{}] and [{}]", symbol(), l, r)); - } - - @Override - protected NodeInfo info() { - return NodeInfo.create(this, Mul::new, left(), right()); - } - - @Override - protected Mul replaceChildren(Expression newLeft, Expression newRight) { - return new Mul(source(), newLeft, newRight); - } - - @Override - public Mul swapLeftAndRight() { - return new Mul(source(), right(), left()); - } - - @Override - public ArithmeticOperationFactory binaryComparisonInverse() { - return Div::new; - } - - @Override - protected boolean isCommutative() { - return true; - } -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Sub.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Sub.java deleted file mode 100644 index 8a345986e5fba..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Sub.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -package org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic; - -import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.tree.NodeInfo; -import org.elasticsearch.xpack.esql.core.tree.Source; - -/** - * Subtraction function ({@code a - b}). - */ -public class Sub extends DateTimeArithmeticOperation implements BinaryComparisonInversible { - - public Sub(Source source, Expression left, Expression right) { - super(source, left, right, DefaultBinaryArithmeticOperation.SUB); - } - - @Override - protected NodeInfo info() { - return NodeInfo.create(this, Sub::new, left(), right()); - } - - @Override - protected Sub replaceChildren(Expression newLeft, Expression newRight) { - return new Sub(source(), newLeft, newRight); - } - - @Override - public ArithmeticOperationFactory binaryComparisonInverse() { - return Add::new; - } -} From 5af24ae221069518fe72ac880edec741f91ebc78 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine <58790826+elasticsearchmachine@users.noreply.github.com> Date: Thu, 6 Jun 2024 07:14:11 +1000 Subject: [PATCH 25/30] Mute org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAppendTests testEvaluateBlockWithoutNulls {TestCase=, } #109409 --- muted-tests.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/muted-tests.yml b/muted-tests.yml index 3d64f87144bd3..32128da9c3712 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -19,8 +19,7 @@ tests: method: "testGuessIsDayFirstFromLocale" - class: "org.elasticsearch.test.rest.ClientYamlTestSuiteIT" issue: "https://github.com/elastic/elasticsearch/issues/108857" - method: "test {yaml=search/180_locale_dependent_mapping/Test Index and Search locale\ - \ dependent mappings / dates}" + method: "test {yaml=search/180_locale_dependent_mapping/Test Index and Search locale dependent mappings / dates}" - class: "org.elasticsearch.upgrades.SearchStatesIT" issue: "https://github.com/elastic/elasticsearch/issues/108991" method: "testCanMatch" @@ -29,8 +28,7 @@ tests: method: "testTrainedModelInference" - class: "org.elasticsearch.xpack.security.CoreWithSecurityClientYamlTestSuiteIT" issue: "https://github.com/elastic/elasticsearch/issues/109188" - method: "test {yaml=search/180_locale_dependent_mapping/Test Index and Search locale\ - \ dependent mappings / dates}" + method: "test {yaml=search/180_locale_dependent_mapping/Test Index and Search locale dependent mappings / dates}" - class: "org.elasticsearch.xpack.esql.qa.mixed.EsqlClientYamlIT" issue: "https://github.com/elastic/elasticsearch/issues/109189" method: "test {p0=esql/70_locale/Date format with Italian locale}" @@ -45,8 +43,7 @@ tests: method: "testTimestampFieldTypeExposedByAllIndicesServices" - class: "org.elasticsearch.analysis.common.CommonAnalysisClientYamlTestSuiteIT" issue: "https://github.com/elastic/elasticsearch/issues/109318" - method: "test {yaml=analysis-common/50_char_filters/pattern_replace error handling\ - \ (too complex pattern)}" + method: "test {yaml=analysis-common/50_char_filters/pattern_replace error handling (too complex pattern)}" - class: "org.elasticsearch.xpack.ml.integration.ClassificationHousePricingIT" issue: "https://github.com/elastic/elasticsearch/issues/101598" method: "testFeatureImportanceValues" @@ -59,6 +56,9 @@ tests: - class: "org.elasticsearch.xpack.inference.InferenceCrudIT" issue: "https://github.com/elastic/elasticsearch/issues/109391" method: "testDeleteEndpointWhileReferencedByPipeline" +- class: org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAppendTests + method: testEvaluateBlockWithoutNulls {TestCase=, } + issue: https://github.com/elastic/elasticsearch/issues/109409 # Examples: # From 703c7ad5c61f435be6cc2c026e69d5add72a254d Mon Sep 17 00:00:00 2001 From: Mark Tozzi Date: Wed, 5 Jun 2024 17:22:50 -0400 Subject: [PATCH 26/30] [ESQL] Rough index of key classes and packages (#109405) First draft of some top-level developer docs for ESQL. At the moment this is just a collection of links to key classes and packages, but we can add to it as we find useful. --- .../xpack/esql/package-info.java | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/package-info.java diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/package-info.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/package-info.java new file mode 100644 index 0000000000000..863476ba55686 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/package-info.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +/** + * ES|QL Overview and Documentation Links + * + *

Major Components

+ *
    + *
  • {@link org.elasticsearch.compute} - The compute engine drives query execution + *
      + *
    • {@link org.elasticsearch.compute.data.Block} - fundamental unit of data. Operations vectorize over blocks.
    • + *
    • {@link org.elasticsearch.compute.data.Page} - Data is broken up into pages (which are collections of blocks) to + * manage size in memory
    • + *
    + *
  • + *
  • {@link org.elasticsearch.xpack.esql.core} - Core Utility Classes + *
      + *
    • {@link org.elasticsearch.xpack.esql.core.type.DataType} - ES|QL is a typed language, and all the supported data types + * are listed in this collection.
    • + *
    • {@link org.elasticsearch.xpack.esql.core.expression.Expression} - Expression is the basis for all functions in ES|QL, + * but see also {@link org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper}
    • + *
    + *
  • + *
  • org.elasticsearch.compute.gen - ES|QL generates code for evaluators, which are type-specific implementations of + * functions, designed to run over a {@link org.elasticsearch.compute.data.Block}
  • + *
  • {@link org.elasticsearch.xpack.esql.session.EsqlSession} - manages state across a query
  • + *
  • {@link org.elasticsearch.xpack.esql.expression.function.scalar} - Guide to writing scalar functions
  • + *
  • {@link org.elasticsearch.xpack.esql.analysis.Analyzer} - The first step in query processing
  • + *
  • {@link org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer} - Coordinator level logical optimizations
  • + *
  • {@link org.elasticsearch.xpack.esql.optimizer.LocalLogicalPlanOptimizer} - Data node level logical optimizations
  • + *
  • {@link org.elasticsearch.xpack.esql.action.RestEsqlQueryAction} - REST API entry point
  • + *
+ */ + +package org.elasticsearch.xpack.esql; From 84bc73bcfe07e2439030dd26806ff6f293d42954 Mon Sep 17 00:00:00 2001 From: Rene Groeschke Date: Thu, 6 Jun 2024 07:28:58 +0200 Subject: [PATCH 27/30] More cc tweaks to esql build script (#109356) --- x-pack/plugin/esql/build.gradle | 34 +++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/x-pack/plugin/esql/build.gradle b/x-pack/plugin/esql/build.gradle index 669f38bd44ecc..faa5a118a90cd 100644 --- a/x-pack/plugin/esql/build.gradle +++ b/x-pack/plugin/esql/build.gradle @@ -47,24 +47,34 @@ tasks.named("compileJava").configure { exclude { it.file.toString().startsWith("${projectDir}/src/main/generated-src/generated") } } +interface Injected { + @Inject FileSystemOperations getFs() +} + tasks.named("test").configure { if (BuildParams.isCi() == false) { systemProperty 'generateDocs', true + def injected = project.objects.newInstance(Injected) doFirst { - project.delete( - files("${projectDir}/build/testrun/test/temp/esql/functions") - ) + injected.fs.delete { + it.delete("build/testrun/test/temp/esql/functions") + } } + File functionsFolder = file("build/testrun/test/temp/esql/functions") + File signatureFolder = file("build/testrun/test/temp/esql/functions/signature") + File typesFolder = file("build/testrun/test/temp/esql/functions/types") + def functionsDocFolder = file("${rootDir}/docs/reference/esql/functions") + def effectiveProjectDir = projectDir + doLast { - List signatures = file("${projectDir}/build/testrun/test/temp/esql/functions/signature").list().findAll {it.endsWith("svg")} - List types = file("${projectDir}/build/testrun/test/temp/esql/functions/types").list().findAll {it.endsWith("asciidoc")} + List types = typesFolder.list().findAll {it.endsWith("asciidoc")} int count = types == null ? 0 : types.size() Closure readExample = line -> { line.replaceAll(/read-example::([^\[]+)\[tag=([^,\]]+)(, ?json)?\]/, { String file = it[1] String tag = it[2] boolean isJson = it[3] - String allExamples = new File("${projectDir}/qa/testFixtures/src/main/resources/${file}").text + String allExamples = new File("${effectiveProjectDir}/qa/testFixtures/src/main/resources/${file}").text .replaceAll(System.lineSeparator(), "\n") int start = allExamples.indexOf("tag::${tag}[]") int end = allExamples.indexOf("end::${tag}[]", start) @@ -85,9 +95,9 @@ tasks.named("test").configure { logger.quiet("ESQL Docs: No function signatures created. Skipping sync.") } else if (count == 1) { logger.quiet("ESQL Docs: Only files related to $types, patching them into place") - project.sync { - from "${projectDir}/build/testrun/test/temp/esql/functions" - into "${rootDir}/docs/reference/esql/functions" + injected.fs.sync { + from functionsFolder + into functionsDocFolder include '**/*.asciidoc', '**/*.svg', '**/*.md', '**/*.json' preserve { include '/*.asciidoc', '**/*.asciidoc', '**/*.md', '**/*.json', '**/*.svg', 'README.md' @@ -95,9 +105,9 @@ tasks.named("test").configure { filter readExample } } else { - project.sync { - from "${projectDir}/build/testrun/test/temp/esql/functions" - into "${rootDir}/docs/reference/esql/functions" + injected.fs.sync { + from functionsFolder + into functionsDocFolder include '**/*.asciidoc', '**/*.svg', '**/*.md', '**/*.json' preserve { include '/*.asciidoc', 'README.md' From fad8c191fc92d680279a9f415dd25e66edc2b86b Mon Sep 17 00:00:00 2001 From: Rene Groeschke Date: Thu, 6 Jun 2024 07:30:21 +0200 Subject: [PATCH 28/30] Make BWC build logic configuration cache compatible (#109354) More refactoring towards configuration cache compatibility --- .../main/groovy/elasticsearch.bwc-test.gradle | 7 +- .../gradle/internal/BwcSetupExtension.java | 72 ++++++++++++++----- .../gradle/internal/InternalBwcGitPlugin.java | 8 +-- .../InternalDistributionBwcSetupPlugin.java | 42 +++++++++-- .../org/elasticsearch/gradle/LoggedExec.java | 6 +- 5 files changed, 99 insertions(+), 36 deletions(-) diff --git a/build-tools-internal/src/main/groovy/elasticsearch.bwc-test.gradle b/build-tools-internal/src/main/groovy/elasticsearch.bwc-test.gradle index 5512b06d0ab8b..ff9b6fe7a526d 100644 --- a/build-tools-internal/src/main/groovy/elasticsearch.bwc-test.gradle +++ b/build-tools-internal/src/main/groovy/elasticsearch.bwc-test.gradle @@ -33,7 +33,8 @@ tasks.register("bwcTest") { plugins.withType(ElasticsearchTestBasePlugin) { tasks.withType(Test).matching { it.name ==~ /v[0-9\.]+#.*/ }.configureEach { - onlyIf("BWC tests enabled") { project.bwc_tests_enabled } + boolean bwcEnabled = project.bwc_tests_enabled + onlyIf("BWC tests enabled") { bwcEnabled } nonInputProperties.systemProperty 'tests.bwc', 'true' } } @@ -50,5 +51,5 @@ plugins.withType(InternalJavaRestTestPlugin) { } } -tasks.matching { it.name.equals("check") }.configureEach {dependsOn(bwcTestSnapshots) } -tasks.matching { it.name.equals("test") }.configureEach {enabled = false} +tasks.matching { it.name.equals("check") }.configureEach { dependsOn(bwcTestSnapshots) } +tasks.matching { it.name.equals("test") }.configureEach { enabled = false } diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/BwcSetupExtension.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/BwcSetupExtension.java index 3d6d37575eca9..7010ed92d4c57 100644 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/BwcSetupExtension.java +++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/BwcSetupExtension.java @@ -20,6 +20,9 @@ import org.gradle.api.model.ObjectFactory; import org.gradle.api.provider.Property; import org.gradle.api.provider.Provider; +import org.gradle.api.provider.ProviderFactory; +import org.gradle.api.provider.ValueSource; +import org.gradle.api.provider.ValueSourceParameters; import org.gradle.api.tasks.TaskProvider; import org.gradle.jvm.toolchain.JavaLanguageVersion; import org.gradle.jvm.toolchain.JavaToolchainService; @@ -41,6 +44,7 @@ public class BwcSetupExtension { private static final Version BUILD_TOOL_MINIMUM_VERSION = Version.fromString("7.14.0"); private final Project project; private final ObjectFactory objectFactory; + private final ProviderFactory providerFactory; private final JavaToolchainService toolChainService; private final Provider unreleasedVersionInfo; @@ -49,12 +53,14 @@ public class BwcSetupExtension { public BwcSetupExtension( Project project, ObjectFactory objectFactory, + ProviderFactory providerFactory, JavaToolchainService toolChainService, Provider unreleasedVersionInfo, Provider checkoutDir ) { this.project = project; this.objectFactory = objectFactory; + this.providerFactory = providerFactory; this.toolChainService = toolChainService; this.unreleasedVersionInfo = unreleasedVersionInfo; this.checkoutDir = checkoutDir; @@ -65,11 +71,26 @@ TaskProvider bwcTask(String name, Action configuration) } TaskProvider bwcTask(String name, Action configuration, boolean useUniqueUserHome) { - return createRunBwcGradleTask(project, name, configuration, useUniqueUserHome); + return createRunBwcGradleTask( + project, + checkoutDir, + providerFactory, + unreleasedVersionInfo, + objectFactory, + toolChainService, + name, + configuration, + useUniqueUserHome + ); } - private TaskProvider createRunBwcGradleTask( + private static TaskProvider createRunBwcGradleTask( Project project, + Provider checkoutDir, + ProviderFactory providerFactory, + Provider unreleasedVersionInfo, + ObjectFactory objectFactory, + JavaToolchainService toolChainService, String name, Action configAction, boolean useUniqueUserHome @@ -78,10 +99,10 @@ private TaskProvider createRunBwcGradleTask( loggedExec.dependsOn("checkoutBwcBranch"); loggedExec.getWorkingDir().set(checkoutDir.get()); - loggedExec.getEnvironment().put("JAVA_HOME", unreleasedVersionInfo.zip(checkoutDir, (version, checkoutDir) -> { - String minimumCompilerVersion = readFromFile(new File(checkoutDir, minimumCompilerVersionPath(version.version()))); - return getJavaHome(Integer.parseInt(minimumCompilerVersion)); - })); + loggedExec.getNonTrackedEnvironment().put("JAVA_HOME", providerFactory.of(JavaHomeValueSource.class, spec -> { + spec.getParameters().getVersion().set(unreleasedVersionInfo.map(it -> it.version())); + spec.getParameters().getCheckoutDir().set(checkoutDir); + }).flatMap(s -> getJavaHome(objectFactory, toolChainService, Integer.parseInt(s)))); if (BuildParams.isCi() && OS.current() != OS.WINDOWS) { // TODO: Disabled for now until we can figure out why files are getting corrupted @@ -137,10 +158,13 @@ private TaskProvider createRunBwcGradleTask( }); } - private String minimumCompilerVersionPath(Version bwcVersion) { - return (bwcVersion.onOrAfter(BUILD_TOOL_MINIMUM_VERSION)) - ? "build-tools-internal/" + MINIMUM_COMPILER_VERSION_PATH - : "buildSrc/" + MINIMUM_COMPILER_VERSION_PATH; + /** A convenience method for getting java home for a version of java and requiring that version for the given task to execute */ + private static Provider getJavaHome(ObjectFactory objectFactory, JavaToolchainService toolChainService, final int version) { + Property value = objectFactory.property(JavaLanguageVersion.class).value(JavaLanguageVersion.of(version)); + return toolChainService.launcherFor(javaToolchainSpec -> { + javaToolchainSpec.getLanguageVersion().value(value); + javaToolchainSpec.getVendor().set(JvmVendorSpec.ORACLE); + }).map(launcher -> launcher.getMetadata().getInstallationPath().getAsFile().getAbsolutePath()); } private static String readFromFile(File file) { @@ -151,13 +175,25 @@ private static String readFromFile(File file) { } } - /** A convenience method for getting java home for a version of java and requiring that version for the given task to execute */ - public String getJavaHome(final int version) { - Property value = objectFactory.property(JavaLanguageVersion.class).value(JavaLanguageVersion.of(version)); - return toolChainService.launcherFor(javaToolchainSpec -> { - javaToolchainSpec.getLanguageVersion().value(value); - javaToolchainSpec.getVendor().set(JvmVendorSpec.ORACLE); - }).get().getMetadata().getInstallationPath().getAsFile().getAbsolutePath(); - } + public static abstract class JavaHomeValueSource implements ValueSource { + + private String minimumCompilerVersionPath(Version bwcVersion) { + return (bwcVersion.onOrAfter(BUILD_TOOL_MINIMUM_VERSION)) + ? "build-tools-internal/" + MINIMUM_COMPILER_VERSION_PATH + : "buildSrc/" + MINIMUM_COMPILER_VERSION_PATH; + } + @Override + public String obtain() { + return readFromFile( + new File(getParameters().getCheckoutDir().get(), minimumCompilerVersionPath(getParameters().getVersion().get())) + ); + } + + public interface Params extends ValueSourceParameters { + Property getVersion(); + + Property getCheckoutDir(); + } + } } diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/InternalBwcGitPlugin.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/InternalBwcGitPlugin.java index 71c76b2045007..7add1e615f577 100644 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/InternalBwcGitPlugin.java +++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/InternalBwcGitPlugin.java @@ -93,13 +93,6 @@ public void execute(Task task) { String remoteRepo = remote.get(); // for testing only we can override the base remote url String remoteRepoUrl = providerFactory.systemProperty("testRemoteRepo") - .orElse( - providerFactory.provider( - () -> addRemote.getExtensions().getExtraProperties().has("remote") - ? addRemote.getExtensions().getExtraProperties().get("remote").toString() - : null - ) - ) .getOrElse("https://github.com/" + remoteRepo + "/" + rootProjectName); spec.commandLine("git", "remote", "add", remoteRepo, remoteRepoUrl); }); @@ -213,6 +206,7 @@ private String maybeAlignedRefSpec(Logger logger, String defaultRefSpec) { private void writeFile(File file, String content) { try { + file.getParentFile().mkdirs(); Files.writeString(file.toPath(), content, CREATE, TRUNCATE_EXISTING); } catch (IOException e) { throw new UncheckedIOException(e); diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/InternalDistributionBwcSetupPlugin.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/InternalDistributionBwcSetupPlugin.java index f727dc165a8a9..a2247adcf7b9e 100644 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/InternalDistributionBwcSetupPlugin.java +++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/InternalDistributionBwcSetupPlugin.java @@ -16,6 +16,7 @@ import org.gradle.api.Plugin; import org.gradle.api.Project; import org.gradle.api.Task; +import org.gradle.api.file.ProjectLayout; import org.gradle.api.model.ObjectFactory; import org.gradle.api.plugins.JvmToolchainsPlugin; import org.gradle.api.provider.Provider; @@ -63,15 +64,39 @@ public void apply(Project project) { project.getPlugins().apply(JvmToolchainsPlugin.class); toolChainService = project.getExtensions().getByType(JavaToolchainService.class); BuildParams.getBwcVersions().forPreviousUnreleased((BwcVersions.UnreleasedVersionInfo unreleasedVersion) -> { - configureBwcProject(project.project(unreleasedVersion.gradleProjectPath()), unreleasedVersion); + configureBwcProject( + project.project(unreleasedVersion.gradleProjectPath()), + unreleasedVersion, + providerFactory, + objectFactory, + toolChainService + ); }); } - private void configureBwcProject(Project project, BwcVersions.UnreleasedVersionInfo versionInfo) { + private static void configureBwcProject( + Project project, + BwcVersions.UnreleasedVersionInfo versionInfo, + ProviderFactory providerFactory, + ObjectFactory objectFactory, + JavaToolchainService toolChainService + ) { + ProjectLayout layout = project.getLayout(); Provider versionInfoProvider = providerFactory.provider(() -> versionInfo); - Provider checkoutDir = versionInfoProvider.map(info -> new File(project.getBuildDir(), "bwc/checkout-" + info.branch())); + Provider checkoutDir = versionInfoProvider.map( + info -> new File(layout.getBuildDirectory().get().getAsFile(), "bwc/checkout-" + info.branch()) + ); BwcSetupExtension bwcSetupExtension = project.getExtensions() - .create("bwcSetup", BwcSetupExtension.class, project, objectFactory, toolChainService, versionInfoProvider, checkoutDir); + .create( + "bwcSetup", + BwcSetupExtension.class, + project, + objectFactory, + providerFactory, + toolChainService, + versionInfoProvider, + checkoutDir + ); BwcGitExtension gitExtension = project.getPlugins().apply(InternalBwcGitPlugin.class).getGitExtension(); Provider bwcVersion = versionInfoProvider.map(info -> info.version()); gitExtension.setBwcVersion(versionInfoProvider.map(info -> info.version())); @@ -157,7 +182,7 @@ private void configureBwcProject(Project project, BwcVersions.UnreleasedVersionI } } - private void registerBwcDistributionArtifacts(Project bwcProject, DistributionProject distributionProject) { + private static void registerBwcDistributionArtifacts(Project bwcProject, DistributionProject distributionProject) { String projectName = distributionProject.name; String buildBwcTask = buildBwcTaskName(projectName); @@ -174,7 +199,11 @@ private void registerBwcDistributionArtifacts(Project bwcProject, DistributionPr } } - private void registerDistributionArchiveArtifact(Project bwcProject, DistributionProject distributionProject, String buildBwcTask) { + private static void registerDistributionArchiveArtifact( + Project bwcProject, + DistributionProject distributionProject, + String buildBwcTask + ) { File distFile = distributionProject.expectedBuildArtifact.distFile; String artifactFileName = distFile.getName(); String artifactName = artifactFileName.contains("oss") ? "elasticsearch-oss" : "elasticsearch"; @@ -363,5 +392,4 @@ private static class DistributionProjectArtifact { this.expandedDistDir = expandedDistDir; } } - } diff --git a/build-tools/src/main/java/org/elasticsearch/gradle/LoggedExec.java b/build-tools/src/main/java/org/elasticsearch/gradle/LoggedExec.java index 4fda91d332118..6087482db278d 100644 --- a/build-tools/src/main/java/org/elasticsearch/gradle/LoggedExec.java +++ b/build-tools/src/main/java/org/elasticsearch/gradle/LoggedExec.java @@ -65,6 +65,9 @@ public abstract class LoggedExec extends DefaultTask implements FileSystemOperat @Optional abstract public MapProperty getEnvironment(); + @Internal + abstract public MapProperty getNonTrackedEnvironment(); + @Input abstract public Property getExecutable(); @@ -139,7 +142,8 @@ public void run() { execSpec.setStandardOutput(finalOutputStream); execSpec.setErrorOutput(finalOutputStream); execSpec.setExecutable(getExecutable().get()); - execSpec.setEnvironment(getEnvironment().get()); + execSpec.environment(getEnvironment().get()); + execSpec.environment(getNonTrackedEnvironment().get()); if (getArgs().isPresent()) { execSpec.setArgs(getArgs().get()); } From 13afe0fda9c50a0adc5d17b2f1a994168ce6baf3 Mon Sep 17 00:00:00 2001 From: Mikhail Berezovskiy Date: Wed, 5 Jun 2024 23:22:40 -0700 Subject: [PATCH 29/30] Change UnassignedInfo class to record (#109363) --- .../ClusterAllocationExplainIT.java | 18 +- .../admin/indices/create/ShrinkIndexIT.java | 4 +- .../cluster/routing/AllocationIdIT.java | 2 +- .../cluster/routing/PrimaryAllocationIT.java | 4 +- .../cluster/routing/ShardRoutingRoleIT.java | 4 +- .../gateway/GatewayIndexStateIT.java | 8 +- .../RemoveCorruptedShardDataCommandIT.java | 4 +- .../index/store/CorruptedFileIT.java | 2 +- .../index/store/CorruptedTranslogIT.java | 4 +- .../indices/cluster/ShardLockFailureIT.java | 2 +- .../SharedClusterSnapshotRestoreIT.java | 8 +- .../ClusterAllocationExplanation.java | 12 +- .../cluster/health/ClusterShardHealth.java | 4 +- .../cluster/routing/IndexRoutingTable.java | 18 +- .../cluster/routing/RoutingNodes.java | 48 ++-- .../cluster/routing/ShardRouting.java | 4 +- .../cluster/routing/UnassignedInfo.java | 207 ++++++------------ .../routing/allocation/AllocationService.java | 28 +-- .../allocation/ShardChangesObserver.java | 2 +- .../allocator/BalancedShardsAllocator.java | 20 +- .../allocator/DesiredBalanceComputer.java | 35 ++- .../allocator/DesiredBalanceReconciler.java | 20 +- ...AllocateEmptyPrimaryAllocationCommand.java | 10 +- .../decider/MaxRetryAllocationDecider.java | 2 +- .../RestoreInProgressAllocationDecider.java | 2 +- ...rdsAvailabilityHealthIndicatorService.java | 10 +- .../gateway/ReplicaShardAllocator.java | 12 +- .../rest/action/cat/RestShardsAction.java | 8 +- .../snapshots/RestoreService.java | 6 +- .../ClusterAllocationExplainActionTests.java | 6 +- .../cluster/reroute/ClusterRerouteTests.java | 2 +- .../health/ClusterStateHealthTests.java | 4 +- .../MetadataIndexStateServiceTests.java | 2 +- .../DelayedAllocationServiceTests.java | 10 +- .../cluster/routing/ShardRoutingTests.java | 2 +- .../cluster/routing/UnassignedInfoTests.java | 78 +++---- .../MaxRetryAllocationDeciderTests.java | 36 +-- .../TrackFailedAllocationNodesTests.java | 9 +- .../DesiredBalanceComputerTests.java | 22 +- .../DesiredBalanceReconcilerTests.java | 30 +-- .../DesiredBalanceShardsAllocatorTests.java | 6 +- .../decider/DiskThresholdDeciderTests.java | 6 +- ...storeInProgressAllocationDeciderTests.java | 18 +- .../gateway/PrimaryShardAllocatorTests.java | 6 +- .../gateway/ReplicaShardAllocatorTests.java | 4 +- .../snapshots/SnapshotResiliencyTests.java | 2 +- .../cluster/ESAllocationTestCase.java | 4 +- .../xpack/ccr/CcrRepositoryIT.java | 4 +- .../ClusterStateApplierOrderingTests.java | 2 +- .../SearchableSnapshotAllocator.java | 4 +- .../TransportGetShutdownStatusAction.java | 2 +- 51 files changed, 348 insertions(+), 419 deletions(-) diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainIT.java index 26b33acfcbe98..897f10b031dcb 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainIT.java @@ -100,10 +100,10 @@ public void testUnassignedPrimaryWithExistingIndex() throws Exception { // verify unassigned info assertNotNull(unassignedInfo); - assertEquals(Reason.NODE_LEFT, unassignedInfo.getReason()); + assertEquals(Reason.NODE_LEFT, unassignedInfo.reason()); assertTrue( - unassignedInfo.getLastAllocationStatus() == AllocationStatus.FETCHING_SHARD_DATA - || unassignedInfo.getLastAllocationStatus() == AllocationStatus.NO_VALID_SHARD_COPY + unassignedInfo.lastAllocationStatus() == AllocationStatus.FETCHING_SHARD_DATA + || unassignedInfo.lastAllocationStatus() == AllocationStatus.NO_VALID_SHARD_COPY ); // verify cluster info @@ -190,8 +190,8 @@ public void testUnassignedReplicaDelayedAllocation() throws Exception { // verify unassigned info assertNotNull(unassignedInfo); - assertEquals(Reason.NODE_LEFT, unassignedInfo.getReason()); - assertEquals(AllocationStatus.NO_ATTEMPT, unassignedInfo.getLastAllocationStatus()); + assertEquals(Reason.NODE_LEFT, unassignedInfo.reason()); + assertEquals(AllocationStatus.NO_ATTEMPT, unassignedInfo.lastAllocationStatus()); // verify cluster info verifyClusterInfo(clusterInfo, includeDiskInfo, 2); @@ -320,8 +320,8 @@ public void testUnassignedReplicaWithPriorCopy() throws Exception { // verify unassigned info assertNotNull(unassignedInfo); - assertEquals(Reason.NODE_LEFT, unassignedInfo.getReason()); - assertEquals(AllocationStatus.NO_ATTEMPT, unassignedInfo.getLastAllocationStatus()); + assertEquals(Reason.NODE_LEFT, unassignedInfo.reason()); + assertEquals(AllocationStatus.NO_ATTEMPT, unassignedInfo.lastAllocationStatus()); // verify cluster info verifyClusterInfo(clusterInfo, includeDiskInfo, 3); @@ -432,8 +432,8 @@ public void testAllocationFilteringOnIndexCreation() throws Exception { // verify unassigned info assertNotNull(unassignedInfo); - assertEquals(Reason.INDEX_CREATED, unassignedInfo.getReason()); - assertEquals(AllocationStatus.DECIDERS_NO, unassignedInfo.getLastAllocationStatus()); + assertEquals(Reason.INDEX_CREATED, unassignedInfo.reason()); + assertEquals(AllocationStatus.DECIDERS_NO, unassignedInfo.lastAllocationStatus()); // verify cluster info verifyClusterInfo(clusterInfo, includeDiskInfo, 2); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/admin/indices/create/ShrinkIndexIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/admin/indices/create/ShrinkIndexIT.java index aa4fee3a3f94d..61e5c1bfcc811 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/admin/indices/create/ShrinkIndexIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/admin/indices/create/ShrinkIndexIT.java @@ -386,9 +386,9 @@ public void testCreateShrinkIndexFails() throws Exception { assertTrue(routingTables.index("target").shard(0).shard(0).unassigned()); assertEquals( UnassignedInfo.Reason.ALLOCATION_FAILED, - routingTables.index("target").shard(0).shard(0).unassignedInfo().getReason() + routingTables.index("target").shard(0).shard(0).unassignedInfo().reason() ); - assertEquals(1, routingTables.index("target").shard(0).shard(0).unassignedInfo().getNumFailedAllocations()); + assertEquals(1, routingTables.index("target").shard(0).shard(0).unassignedInfo().failedAllocations()); }); // now relocate them all to the right node updateIndexSettings(Settings.builder().put("index.routing.allocation.require._name", mergeNode), "source"); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/routing/AllocationIdIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/routing/AllocationIdIT.java index 784a6e8f419c8..a25de555ce267 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/cluster/routing/AllocationIdIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/routing/AllocationIdIT.java @@ -113,7 +113,7 @@ public void testFailedRecoveryOnAllocateStalePrimaryRequiresAnotherAllocateStale final ClusterState state = clusterAdmin().prepareState().get().getState(); final ShardRouting shardRouting = state.routingTable().index(indexName).shard(shardId.id()).primaryShard(); assertThat(shardRouting.state(), equalTo(ShardRoutingState.UNASSIGNED)); - assertThat(shardRouting.unassignedInfo().getReason(), equalTo(UnassignedInfo.Reason.ALLOCATION_FAILED)); + assertThat(shardRouting.unassignedInfo().reason(), equalTo(UnassignedInfo.Reason.ALLOCATION_FAILED)); }); internalCluster().stopNode(node1); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/routing/PrimaryAllocationIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/routing/PrimaryAllocationIT.java index e7a7a6f2ba727..72594fef8c6ee 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/cluster/routing/PrimaryAllocationIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/routing/PrimaryAllocationIT.java @@ -202,7 +202,7 @@ public void testFailedAllocationOfStalePrimaryToDataNodeWithNoData() throws Exce .shard(0) .primaryShard() .unassignedInfo() - .getReason(), + .reason(), equalTo(UnassignedInfo.Reason.NODE_LEFT) ); @@ -227,7 +227,7 @@ public void testFailedAllocationOfStalePrimaryToDataNodeWithNoData() throws Exce .shard(0) .primaryShard() .unassignedInfo() - .getReason(), + .reason(), equalTo(UnassignedInfo.Reason.NODE_LEFT) ); } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/routing/ShardRoutingRoleIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/routing/ShardRoutingRoleIT.java index 006c9e2394f3c..76311387115d2 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/cluster/routing/ShardRoutingRoleIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/routing/ShardRoutingRoleIT.java @@ -114,7 +114,7 @@ public Collection createAllocationDeciders(Settings settings, @Override public Decision canForceAllocatePrimary(ShardRouting shardRouting, RoutingNode node, RoutingAllocation allocation) { // once a primary is cancelled it _stays_ cancelled - if (shardRouting.unassignedInfo().getReason() == UnassignedInfo.Reason.REROUTE_CANCELLED) { + if (shardRouting.unassignedInfo().reason() == UnassignedInfo.Reason.REROUTE_CANCELLED) { return Decision.NO; } return super.canForceAllocatePrimary(shardRouting, node, allocation); @@ -450,7 +450,7 @@ public AllocationCommand getCancelPrimaryCommand() { shardRouting.role().isPromotableToPrimary() ? UnassignedInfo.Reason.REROUTE_CANCELLED : UnassignedInfo.Reason.UNPROMOTABLE_REPLICA, - shardRouting.unassignedInfo().getReason() + shardRouting.unassignedInfo().reason() ); } } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/gateway/GatewayIndexStateIT.java b/server/src/internalClusterTest/java/org/elasticsearch/gateway/GatewayIndexStateIT.java index d1827bf49410f..e05bda69d2c9c 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/gateway/GatewayIndexStateIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/gateway/GatewayIndexStateIT.java @@ -403,9 +403,9 @@ public void testRecoverBrokenIndexMetadata() throws Exception { assertTrue(shardRoutingTable.primaryShard().unassigned()); assertEquals( UnassignedInfo.AllocationStatus.DECIDERS_NO, - shardRoutingTable.primaryShard().unassignedInfo().getLastAllocationStatus() + shardRoutingTable.primaryShard().unassignedInfo().lastAllocationStatus() ); - assertThat(shardRoutingTable.primaryShard().unassignedInfo().getNumFailedAllocations(), greaterThan(0)); + assertThat(shardRoutingTable.primaryShard().unassignedInfo().failedAllocations(), greaterThan(0)); } }, 60, TimeUnit.SECONDS); indicesAdmin().prepareClose("test").get(); @@ -472,9 +472,9 @@ public void testRecoverMissingAnalyzer() throws Exception { assertTrue(shardRoutingTable.primaryShard().unassigned()); assertEquals( UnassignedInfo.AllocationStatus.DECIDERS_NO, - shardRoutingTable.primaryShard().unassignedInfo().getLastAllocationStatus() + shardRoutingTable.primaryShard().unassignedInfo().lastAllocationStatus() ); - assertThat(shardRoutingTable.primaryShard().unassignedInfo().getNumFailedAllocations(), greaterThan(0)); + assertThat(shardRoutingTable.primaryShard().unassignedInfo().failedAllocations(), greaterThan(0)); } }, 60, TimeUnit.SECONDS); indicesAdmin().prepareClose("test").get(); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/index/shard/RemoveCorruptedShardDataCommandIT.java b/server/src/internalClusterTest/java/org/elasticsearch/index/shard/RemoveCorruptedShardDataCommandIT.java index f43aaf0bacad4..ef4616fdd0b40 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/index/shard/RemoveCorruptedShardDataCommandIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/index/shard/RemoveCorruptedShardDataCommandIT.java @@ -311,8 +311,8 @@ public Settings onNodeStopped(String nodeName) throws Exception { // all shards should be failed due to a corrupted translog assertBusy(() -> { final UnassignedInfo unassignedInfo = getClusterAllocationExplanation(client(), indexName, 0, true).getUnassignedInfo(); - assertThat(unassignedInfo.getReason(), equalTo(UnassignedInfo.Reason.ALLOCATION_FAILED)); - assertThat(ExceptionsHelper.unwrap(unassignedInfo.getFailure(), TranslogCorruptedException.class), not(nullValue())); + assertThat(unassignedInfo.reason(), equalTo(UnassignedInfo.Reason.ALLOCATION_FAILED)); + assertThat(ExceptionsHelper.unwrap(unassignedInfo.failure(), TranslogCorruptedException.class), not(nullValue())); }); // have to shut down primary node - otherwise node lock is present diff --git a/server/src/internalClusterTest/java/org/elasticsearch/index/store/CorruptedFileIT.java b/server/src/internalClusterTest/java/org/elasticsearch/index/store/CorruptedFileIT.java index a9d19473164bf..7e3df8d8e1cbc 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/index/store/CorruptedFileIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/index/store/CorruptedFileIT.java @@ -465,7 +465,7 @@ public void onTimeout(TimeValue timeout) { final var replicaShards = indexRoutingTable.shard(shardId).replicaShards(); if (replicaShards.isEmpty() || replicaShards.stream() - .anyMatch(sr -> sr.unassigned() == false || sr.unassignedInfo().getNumFailedAllocations() < maxRetries)) { + .anyMatch(sr -> sr.unassigned() == false || sr.unassignedInfo().failedAllocations() < maxRetries)) { return false; } } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/index/store/CorruptedTranslogIT.java b/server/src/internalClusterTest/java/org/elasticsearch/index/store/CorruptedTranslogIT.java index ac5a10d246cfc..0c0ece4bf5227 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/index/store/CorruptedTranslogIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/index/store/CorruptedTranslogIT.java @@ -83,8 +83,8 @@ public void onAllNodesStopped() throws Exception { final var description = Strings.toString(allocationExplainResponse); final var unassignedInfo = allocationExplainResponse.getUnassignedInfo(); assertThat(description, unassignedInfo, not(nullValue())); - assertThat(description, unassignedInfo.getReason(), equalTo(UnassignedInfo.Reason.ALLOCATION_FAILED)); - var failure = unassignedInfo.getFailure(); + assertThat(description, unassignedInfo.reason(), equalTo(UnassignedInfo.Reason.ALLOCATION_FAILED)); + var failure = unassignedInfo.failure(); assertNotNull(failure); final Throwable cause = ExceptionsHelper.unwrap(failure, TranslogCorruptedException.class); if (cause != null) { diff --git a/server/src/internalClusterTest/java/org/elasticsearch/indices/cluster/ShardLockFailureIT.java b/server/src/internalClusterTest/java/org/elasticsearch/indices/cluster/ShardLockFailureIT.java index 59e7a67687921..874ba7b42690c 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/indices/cluster/ShardLockFailureIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/indices/cluster/ShardLockFailureIT.java @@ -61,7 +61,7 @@ public void testShardLockFailure() throws Exception { .routingTable() .shardRoutingTable(shardId) .allShards() - .noneMatch(sr -> sr.unassigned() && sr.unassignedInfo().getNumFailedAllocations() > 0) + .noneMatch(sr -> sr.unassigned() && sr.unassignedInfo().failedAllocations() > 0) ); } catch (IndexNotFoundException e) { // ok diff --git a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/SharedClusterSnapshotRestoreIT.java b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/SharedClusterSnapshotRestoreIT.java index 005604b92a723..6e19cf60cf5b9 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/SharedClusterSnapshotRestoreIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/SharedClusterSnapshotRestoreIT.java @@ -579,8 +579,8 @@ public void testUnrestorableFilesDuringRestore() throws Exception { .build(); Consumer checkUnassignedInfo = unassignedInfo -> { - assertThat(unassignedInfo.getReason(), equalTo(UnassignedInfo.Reason.ALLOCATION_FAILED)); - assertThat(unassignedInfo.getNumFailedAllocations(), anyOf(equalTo(maxRetries), equalTo(1))); + assertThat(unassignedInfo.reason(), equalTo(UnassignedInfo.Reason.ALLOCATION_FAILED)); + assertThat(unassignedInfo.failedAllocations(), anyOf(equalTo(maxRetries), equalTo(1))); }; unrestorableUseCase(indexName, createIndexSettings, repositorySettings, Settings.EMPTY, checkUnassignedInfo, () -> {}); @@ -605,7 +605,7 @@ public void testUnrestorableIndexDuringRestore() throws Exception { Settings.EMPTY, Settings.EMPTY, restoreIndexSettings, - unassignedInfo -> assertThat(unassignedInfo.getReason(), equalTo(UnassignedInfo.Reason.NEW_INDEX_RESTORED)), + unassignedInfo -> assertThat(unassignedInfo.reason(), equalTo(UnassignedInfo.Reason.NEW_INDEX_RESTORED)), fixupAction ); } @@ -670,7 +670,7 @@ private void unrestorableUseCase( if (shard.primary()) { assertThat(shard.state(), equalTo(ShardRoutingState.UNASSIGNED)); assertThat(shard.recoverySource().getType(), equalTo(RecoverySource.Type.SNAPSHOT)); - assertThat(shard.unassignedInfo().getLastAllocationStatus(), equalTo(UnassignedInfo.AllocationStatus.DECIDERS_NO)); + assertThat(shard.unassignedInfo().lastAllocationStatus(), equalTo(UnassignedInfo.AllocationStatus.DECIDERS_NO)); checkUnassignedInfo.accept(shard.unassignedInfo()); } } diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplanation.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplanation.java index d22bae9c5a4b1..1e5f9d5d613d2 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplanation.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplanation.java @@ -226,16 +226,16 @@ private Iterator getShardAllocationDecisionChunked(ToXCont private static XContentBuilder unassignedInfoToXContent(UnassignedInfo unassignedInfo, XContentBuilder builder) throws IOException { builder.startObject("unassigned_info"); - builder.field("reason", unassignedInfo.getReason()); - builder.field("at", UnassignedInfo.DATE_TIME_FORMATTER.format(Instant.ofEpochMilli(unassignedInfo.getUnassignedTimeInMillis()))); - if (unassignedInfo.getNumFailedAllocations() > 0) { - builder.field("failed_allocation_attempts", unassignedInfo.getNumFailedAllocations()); + builder.field("reason", unassignedInfo.reason()); + builder.field("at", UnassignedInfo.DATE_TIME_FORMATTER.format(Instant.ofEpochMilli(unassignedInfo.unassignedTimeMillis()))); + if (unassignedInfo.failedAllocations() > 0) { + builder.field("failed_allocation_attempts", unassignedInfo.failedAllocations()); } - String details = unassignedInfo.getDetails(); + String details = unassignedInfo.details(); if (details != null) { builder.field("details", details); } - builder.field("last_allocation_status", AllocationDecision.fromAllocationStatus(unassignedInfo.getLastAllocationStatus())); + builder.field("last_allocation_status", AllocationDecision.fromAllocationStatus(unassignedInfo.lastAllocationStatus())); builder.endObject(); return builder; } diff --git a/server/src/main/java/org/elasticsearch/cluster/health/ClusterShardHealth.java b/server/src/main/java/org/elasticsearch/cluster/health/ClusterShardHealth.java index 785b0db5cc807..adb5a7caf2f45 100644 --- a/server/src/main/java/org/elasticsearch/cluster/health/ClusterShardHealth.java +++ b/server/src/main/java/org/elasticsearch/cluster/health/ClusterShardHealth.java @@ -167,8 +167,8 @@ public static ClusterHealthStatus getInactivePrimaryHealth(final ShardRouting sh assert shardRouting.recoverySource() != null : "cannot invoke on a shard that has no recovery source" + shardRouting; final UnassignedInfo unassignedInfo = shardRouting.unassignedInfo(); RecoverySource.Type recoveryType = shardRouting.recoverySource().getType(); - if (unassignedInfo.getLastAllocationStatus() != AllocationStatus.DECIDERS_NO - && unassignedInfo.getNumFailedAllocations() == 0 + if (unassignedInfo.lastAllocationStatus() != AllocationStatus.DECIDERS_NO + && unassignedInfo.failedAllocations() == 0 && (recoveryType == RecoverySource.Type.EMPTY_STORE || recoveryType == RecoverySource.Type.LOCAL_SHARDS || recoveryType == RecoverySource.Type.SNAPSHOT)) { diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/IndexRoutingTable.java b/server/src/main/java/org/elasticsearch/cluster/routing/IndexRoutingTable.java index 6679f17a0427b..d62dd91d7e87b 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/IndexRoutingTable.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/IndexRoutingTable.java @@ -574,15 +574,15 @@ private static UnassignedInfo withLastAllocatedNodeId(UnassignedInfo unassignedI return previousNodes == null || previousNodes.size() <= shardCopy ? unassignedInfo : new UnassignedInfo( - unassignedInfo.getReason(), - unassignedInfo.getMessage(), - unassignedInfo.getFailure(), - unassignedInfo.getNumFailedAllocations(), - unassignedInfo.getUnassignedTimeInNanos(), - unassignedInfo.getUnassignedTimeInMillis(), - unassignedInfo.isDelayed(), - unassignedInfo.getLastAllocationStatus(), - unassignedInfo.getFailedNodeIds(), + unassignedInfo.reason(), + unassignedInfo.message(), + unassignedInfo.failure(), + unassignedInfo.failedAllocations(), + unassignedInfo.unassignedTimeNanos(), + unassignedInfo.unassignedTimeMillis(), + unassignedInfo.delayed(), + unassignedInfo.lastAllocationStatus(), + unassignedInfo.failedNodeIds(), previousNodes.get(shardCopy) ); } diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/RoutingNodes.java b/server/src/main/java/org/elasticsearch/cluster/routing/RoutingNodes.java index cf8e0608ecbd4..0b3cadb6e187c 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/RoutingNodes.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/RoutingNodes.java @@ -563,8 +563,8 @@ assert getByAllocationId(failedShard.shardId(), failedShard.allocationId().getId "primary failed while replica initializing", null, 0, - unassignedInfo.getUnassignedTimeInNanos(), - unassignedInfo.getUnassignedTimeInMillis(), + unassignedInfo.unassignedTimeNanos(), + unassignedInfo.unassignedTimeMillis(), false, AllocationStatus.NO_ATTEMPT, Collections.emptySet(), @@ -644,11 +644,11 @@ private void unassignPrimaryAndPromoteActiveReplicaIfExists( unpromotableReplica, new UnassignedInfo( UnassignedInfo.Reason.UNPROMOTABLE_REPLICA, - unassignedInfo.getMessage(), - unassignedInfo.getFailure(), + unassignedInfo.message(), + unassignedInfo.failure(), 0, - unassignedInfo.getUnassignedTimeInNanos(), - unassignedInfo.getUnassignedTimeInMillis(), + unassignedInfo.unassignedTimeNanos(), + unassignedInfo.unassignedTimeMillis(), false, // TODO debatable, but do we want to delay reassignment of unpromotable replicas tho? AllocationStatus.NO_ATTEMPT, Set.of(), @@ -970,18 +970,18 @@ public void ignoreShard(ShardRouting shard, AllocationStatus allocationStatus, R ignoredPrimaries++; UnassignedInfo currInfo = shard.unassignedInfo(); assert currInfo != null; - if (allocationStatus.equals(currInfo.getLastAllocationStatus()) == false) { + if (allocationStatus.equals(currInfo.lastAllocationStatus()) == false) { UnassignedInfo newInfo = new UnassignedInfo( - currInfo.getReason(), - currInfo.getMessage(), - currInfo.getFailure(), - currInfo.getNumFailedAllocations(), - currInfo.getUnassignedTimeInNanos(), - currInfo.getUnassignedTimeInMillis(), - currInfo.isDelayed(), + currInfo.reason(), + currInfo.message(), + currInfo.failure(), + currInfo.failedAllocations(), + currInfo.unassignedTimeNanos(), + currInfo.unassignedTimeMillis(), + currInfo.delayed(), allocationStatus, - currInfo.getFailedNodeIds(), - currInfo.getLastAllocatedNodeId() + currInfo.failedNodeIds(), + currInfo.lastAllocatedNodeId() ); ShardRouting updatedShard = shard.updateUnassigned(newInfo, shard.recoverySource()); changes.unassignedInfoUpdated(shard, newInfo); @@ -1283,16 +1283,16 @@ public void resetFailedCounter(RoutingChangesObserver routingChangesObserver) { UnassignedInfo unassignedInfo = shardRouting.unassignedInfo(); unassignedIterator.updateUnassigned( new UnassignedInfo( - unassignedInfo.getNumFailedAllocations() > 0 ? UnassignedInfo.Reason.MANUAL_ALLOCATION : unassignedInfo.getReason(), - unassignedInfo.getMessage(), - unassignedInfo.getFailure(), + unassignedInfo.failedAllocations() > 0 ? UnassignedInfo.Reason.MANUAL_ALLOCATION : unassignedInfo.reason(), + unassignedInfo.message(), + unassignedInfo.failure(), 0, - unassignedInfo.getUnassignedTimeInNanos(), - unassignedInfo.getUnassignedTimeInMillis(), - unassignedInfo.isDelayed(), - unassignedInfo.getLastAllocationStatus(), + unassignedInfo.unassignedTimeNanos(), + unassignedInfo.unassignedTimeMillis(), + unassignedInfo.delayed(), + unassignedInfo.lastAllocationStatus(), Collections.emptySet(), - unassignedInfo.getLastAllocatedNodeId() + unassignedInfo.lastAllocatedNodeId() ), shardRouting.recoverySource(), routingChangesObserver diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/ShardRouting.java b/server/src/main/java/org/elasticsearch/cluster/routing/ShardRouting.java index 95882e26773e5..523dc0efd450b 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/ShardRouting.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/ShardRouting.java @@ -342,7 +342,7 @@ public ShardRouting(ShardId shardId, StreamInput in) throws IOException { } else { recoverySource = null; } - unassignedInfo = in.readOptionalWriteable(UnassignedInfo::new); + unassignedInfo = in.readOptionalWriteable(UnassignedInfo::fromStreamInput); if (in.getTransportVersion().onOrAfter(RELOCATION_FAILURE_INFO_VERSION)) { relocationFailureInfo = RelocationFailureInfo.readFrom(in); } else { @@ -410,7 +410,7 @@ public void writeTo(StreamOutput out) throws IOException { public ShardRouting updateUnassigned(UnassignedInfo unassignedInfo, RecoverySource recoverySource) { assert this.unassignedInfo != null : "can only update unassigned info if it is already set"; - assert this.unassignedInfo.isDelayed() || (unassignedInfo.isDelayed() == false) : "cannot transition from non-delayed to delayed"; + assert this.unassignedInfo.delayed() || (unassignedInfo.delayed() == false) : "cannot transition from non-delayed to delayed"; return new ShardRouting( shardId, currentNodeId, diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/UnassignedInfo.java b/server/src/main/java/org/elasticsearch/cluster/routing/UnassignedInfo.java index bde667df3821a..9423e32be6846 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/UnassignedInfo.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/UnassignedInfo.java @@ -41,9 +41,40 @@ import static org.elasticsearch.cluster.routing.allocation.ExistingShardsAllocator.EXISTING_SHARDS_ALLOCATOR_SETTING; /** - * Holds additional information as to why the shard is in unassigned state. + * Holds additional information as to why the shard is in an unassigned state. + * + * @param reason why the shard is unassigned. + * @param message optional details explaining the reasons. + * @param failure additional failure exception details if exists. + * @param failedAllocations number of previously failed allocations of this shard. + * @param delayed true if allocation of this shard is delayed due to {@link #INDEX_DELAYED_NODE_LEFT_TIMEOUT_SETTING}. + * @param unassignedTimeMillis The timestamp in milliseconds when the shard became unassigned, based on System.currentTimeMillis(). + * Note, we use timestamp here since we want to make sure its preserved across node serializations. + * @param unassignedTimeNanos The timestamp in nanoseconds when the shard became unassigned, based on System.nanoTime(). + * Used to calculate the delay for delayed shard allocation. + * ONLY EXPOSED FOR TESTS! + * @param lastAllocationStatus status for the last allocation attempt for this shard. + * @param failedNodeIds A set of nodeIds that failed to complete allocations for this shard. + * {@link org.elasticsearch.gateway.ReplicaShardAllocator} uses this bset to avoid repeatedly canceling ongoing + * recoveries for copies on those nodes, although they can perform noop recoveries. This set will be discarded when a + * shard moves to started. And if a shard is failed while started (i.e., from started to unassigned), the currently + * assigned node won't be added to this set. + * @see org.elasticsearch.gateway.ReplicaShardAllocator#processExistingRecoveries + * @see org.elasticsearch.cluster.routing.allocation.AllocationService#applyFailedShards(ClusterState, List, List) + * @param lastAllocatedNodeId ID of the node this shard was last allocated to, or null if unavailable. */ -public final class UnassignedInfo implements ToXContentFragment, Writeable { +public record UnassignedInfo( + Reason reason, + @Nullable String message, + @Nullable Exception failure, + int failedAllocations, + long unassignedTimeNanos, + long unassignedTimeMillis, + boolean delayed, + AllocationStatus lastAllocationStatus, + Set failedNodeIds, + @Nullable String lastAllocatedNodeId +) implements ToXContentFragment, Writeable { /** * The version that the {@code lastAllocatedNode} field was added in. Used to adapt streaming of this class as appropriate for the @@ -218,17 +249,6 @@ public String value() { } } - private final Reason reason; - private final long unassignedTimeMillis; // used for display and log messages, in milliseconds - private final long unassignedTimeNanos; // in nanoseconds, used to calculate delay for delayed shard allocation - private final boolean delayed; // if allocation of this shard is delayed due to INDEX_DELAYED_NODE_LEFT_TIMEOUT_SETTING - private final String message; - private final Exception failure; - private final int failedAllocations; - private final Set failedNodeIds; - private final AllocationStatus lastAllocationStatus; // result of the last allocation attempt for this shard - private final String lastAllocatedNodeId; - /** * creates an UnassignedInfo object based on **current** time * @@ -261,28 +281,10 @@ public UnassignedInfo(Reason reason, String message) { * @param failedNodeIds a set of nodeIds that failed to complete allocations for this shard * @param lastAllocatedNodeId the ID of the node this shard was last allocated to */ - public UnassignedInfo( - Reason reason, - @Nullable String message, - @Nullable Exception failure, - int failedAllocations, - long unassignedTimeNanos, - long unassignedTimeMillis, - boolean delayed, - AllocationStatus lastAllocationStatus, - Set failedNodeIds, - @Nullable String lastAllocatedNodeId - ) { - this.reason = Objects.requireNonNull(reason); - this.unassignedTimeMillis = unassignedTimeMillis; - this.unassignedTimeNanos = unassignedTimeNanos; - this.delayed = delayed; - this.message = message; - this.failure = failure; - this.failedAllocations = failedAllocations; - this.lastAllocationStatus = Objects.requireNonNull(lastAllocationStatus); - this.failedNodeIds = Set.copyOf(failedNodeIds); - this.lastAllocatedNodeId = lastAllocatedNodeId; + public UnassignedInfo { + Objects.requireNonNull(reason); + Objects.requireNonNull(lastAllocationStatus); + failedNodeIds = Set.copyOf(failedNodeIds); assert (failedAllocations > 0) == (reason == Reason.ALLOCATION_FAILED) : "failedAllocations: " + failedAllocations + " for reason " + reason; assert (message == null && failure != null) == false : "provide a message if a failure exception is provided"; @@ -294,24 +296,37 @@ public UnassignedInfo( : "last allocated node ID must be set if the shard is unassigned due to a node restarting"; } - public UnassignedInfo(StreamInput in) throws IOException { + public static UnassignedInfo fromStreamInput(StreamInput in) throws IOException { // Because Reason.NODE_RESTARTING is new and can't be sent by older versions, there's no need to vary the deserialization behavior - this.reason = Reason.values()[(int) in.readByte()]; - this.unassignedTimeMillis = in.readLong(); + var reason = Reason.values()[(int) in.readByte()]; + var unassignedTimeMillis = in.readLong(); // As System.nanoTime() cannot be compared across different JVMs, reset it to now. // This means that in master fail-over situations, elapsed delay time is forgotten. - this.unassignedTimeNanos = System.nanoTime(); - this.delayed = in.readBoolean(); - this.message = in.readOptionalString(); - this.failure = in.readException(); - this.failedAllocations = in.readVInt(); - this.lastAllocationStatus = AllocationStatus.readFrom(in); - this.failedNodeIds = in.readCollectionAsImmutableSet(StreamInput::readString); + var unassignedTimeNanos = System.nanoTime(); + var delayed = in.readBoolean(); + var message = in.readOptionalString(); + var failure = in.readException(); + var failedAllocations = in.readVInt(); + var lastAllocationStatus = AllocationStatus.readFrom(in); + var failedNodeIds = in.readCollectionAsImmutableSet(StreamInput::readString); + String lastAllocatedNodeId; if (in.getTransportVersion().onOrAfter(VERSION_LAST_ALLOCATED_NODE_ADDED)) { - this.lastAllocatedNodeId = in.readOptionalString(); + lastAllocatedNodeId = in.readOptionalString(); } else { - this.lastAllocatedNodeId = null; + lastAllocatedNodeId = null; } + return new UnassignedInfo( + reason, + message, + failure, + failedAllocations, + unassignedTimeNanos, + unassignedTimeMillis, + delayed, + lastAllocationStatus, + failedNodeIds, + lastAllocatedNodeId + ); } public void writeTo(StreamOutput out) throws IOException { @@ -335,107 +350,25 @@ public void writeTo(StreamOutput out) throws IOException { } } - /** - * Returns the number of previously failed allocations of this shard. - */ - public int getNumFailedAllocations() { - return failedAllocations; - } - - /** - * Returns true if allocation of this shard is delayed due to {@link #INDEX_DELAYED_NODE_LEFT_TIMEOUT_SETTING} - */ - public boolean isDelayed() { - return delayed; - } - - /** - * The reason why the shard is unassigned. - */ - public Reason getReason() { - return this.reason; - } - - /** - * The timestamp in milliseconds when the shard became unassigned, based on System.currentTimeMillis(). - * Note, we use timestamp here since we want to make sure its preserved across node serializations. - */ - public long getUnassignedTimeInMillis() { - return this.unassignedTimeMillis; - } - - /** - * The timestamp in nanoseconds when the shard became unassigned, based on System.nanoTime(). - * Used to calculate the delay for delayed shard allocation. - * ONLY EXPOSED FOR TESTS! - */ - public long getUnassignedTimeInNanos() { - return this.unassignedTimeNanos; - } - - /** - * Returns optional details explaining the reasons. - */ - @Nullable - public String getMessage() { - return this.message; - } - - /** - * Returns additional failure exception details if exists. - */ - @Nullable - public Exception getFailure() { - return failure; - } - /** * Builds a string representation of the message and the failure if exists. */ @Nullable - public String getDetails() { + public String details() { if (message == null) { return null; } return message + (failure == null ? "" : ", failure " + ExceptionsHelper.stackTrace(failure)); } - /** - * Gets the ID of the node this shard was last allocated to, or null if unavailable. - */ - @Nullable - public String getLastAllocatedNodeId() { - return lastAllocatedNodeId; - } - - /** - * Get the status for the last allocation attempt for this shard. - */ - public AllocationStatus getLastAllocationStatus() { - return lastAllocationStatus; - } - - /** - * A set of nodeIds that failed to complete allocations for this shard. {@link org.elasticsearch.gateway.ReplicaShardAllocator} - * uses this set to avoid repeatedly canceling ongoing recoveries for copies on those nodes although they can perform noop recoveries. - * This set will be discarded when a shard moves to started. And if a shard is failed while started (i.e., from started to unassigned), - * the currently assigned node won't be added to this set. - * - * @see org.elasticsearch.gateway.ReplicaShardAllocator#processExistingRecoveries - * @see org.elasticsearch.cluster.routing.allocation.AllocationService#applyFailedShards(ClusterState, List, List) - */ - public Set getFailedNodeIds() { - return failedNodeIds; - } - /** * Calculates the delay left based on current time (in nanoseconds) and the delay defined by the index settings. - * Only relevant if shard is effectively delayed (see {@link #isDelayed()}) + * Only relevant if shard is effectively delayed (see {@link #delayed()}) * Returns 0 if delay is negative * * @return calculated delay in nanoseconds */ - public long getRemainingDelay(final long nanoTimeNow, final Settings indexSettings, final NodesShutdownMetadata nodesShutdownMetadata) { + public long remainingDelay(final long nanoTimeNow, final Settings indexSettings, final NodesShutdownMetadata nodesShutdownMetadata) { final long indexLevelDelay = INDEX_DELAYED_NODE_LEFT_TIMEOUT_SETTING.get(indexSettings).nanos(); long delayTimeoutNanos = Optional.ofNullable(lastAllocatedNodeId) // If the node wasn't restarting when this became unassigned, use default delay @@ -455,7 +388,7 @@ public long getRemainingDelay(final long nanoTimeNow, final Settings indexSettin public static int getNumberOfDelayedUnassigned(ClusterState state) { int count = 0; for (ShardRouting shard : state.getRoutingNodes().unassigned()) { - if (shard.unassignedInfo().isDelayed()) { + if (shard.unassignedInfo().delayed()) { count++; } } @@ -472,10 +405,10 @@ public static long findNextDelayedAllocation(long currentNanoTime, ClusterState long nextDelayNanos = Long.MAX_VALUE; for (ShardRouting shard : state.getRoutingNodes().unassigned()) { UnassignedInfo unassignedInfo = shard.unassignedInfo(); - if (unassignedInfo.isDelayed()) { + if (unassignedInfo.delayed()) { Settings indexSettings = metadata.index(shard.index()).getSettings(); // calculate next time to schedule - final long newComputedLeftDelayNanos = unassignedInfo.getRemainingDelay( + final long newComputedLeftDelayNanos = unassignedInfo.remainingDelay( currentNanoTime, indexSettings, metadata.nodeShutdowns() @@ -502,7 +435,7 @@ public String shortSummary() { if (lastAllocatedNodeId != null) { sb.append(", last_node[").append(lastAllocatedNodeId).append("]"); } - String details = getDetails(); + String details = details(); if (details != null) { sb.append(", details[").append(details).append("]"); } @@ -530,7 +463,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (lastAllocatedNodeId != null) { builder.field("last_node", lastAllocatedNodeId); } - String details = getDetails(); + String details = details(); if (details != null) { builder.field("details", details); } diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationService.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationService.java index e19e266cc2672..436399a02005f 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationService.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationService.java @@ -215,11 +215,11 @@ public ClusterState applyFailedShards( failedShard ); } - int failedAllocations = failedShard.unassignedInfo() != null ? failedShard.unassignedInfo().getNumFailedAllocations() : 0; + int failedAllocations = failedShard.unassignedInfo() != null ? failedShard.unassignedInfo().failedAllocations() : 0; final Set failedNodeIds; if (failedShard.unassignedInfo() != null) { - failedNodeIds = Sets.newHashSetWithExpectedSize(failedShard.unassignedInfo().getFailedNodeIds().size() + 1); - failedNodeIds.addAll(failedShard.unassignedInfo().getFailedNodeIds()); + failedNodeIds = Sets.newHashSetWithExpectedSize(failedShard.unassignedInfo().failedNodeIds().size() + 1); + failedNodeIds.addAll(failedShard.unassignedInfo().failedNodeIds()); failedNodeIds.add(failedShard.currentNodeId()); } else { failedNodeIds = Collections.emptySet(); @@ -425,8 +425,8 @@ default void removeDelayMarkers(RoutingAllocation allocation) { while (unassignedIterator.hasNext()) { ShardRouting shardRouting = unassignedIterator.next(); UnassignedInfo unassignedInfo = shardRouting.unassignedInfo(); - if (unassignedInfo.isDelayed()) { - final long newComputedLeftDelayNanos = unassignedInfo.getRemainingDelay( + if (unassignedInfo.delayed()) { + final long newComputedLeftDelayNanos = unassignedInfo.remainingDelay( allocation.getCurrentNanoTime(), metadata.getIndexSafe(shardRouting.index()).getSettings(), metadata.nodeShutdowns() @@ -434,16 +434,16 @@ default void removeDelayMarkers(RoutingAllocation allocation) { if (newComputedLeftDelayNanos == 0) { unassignedIterator.updateUnassigned( new UnassignedInfo( - unassignedInfo.getReason(), - unassignedInfo.getMessage(), - unassignedInfo.getFailure(), - unassignedInfo.getNumFailedAllocations(), - unassignedInfo.getUnassignedTimeInNanos(), - unassignedInfo.getUnassignedTimeInMillis(), + unassignedInfo.reason(), + unassignedInfo.message(), + unassignedInfo.failure(), + unassignedInfo.failedAllocations(), + unassignedInfo.unassignedTimeNanos(), + unassignedInfo.unassignedTimeMillis(), false, - unassignedInfo.getLastAllocationStatus(), - unassignedInfo.getFailedNodeIds(), - unassignedInfo.getLastAllocatedNodeId() + unassignedInfo.lastAllocationStatus(), + unassignedInfo.failedNodeIds(), + unassignedInfo.lastAllocatedNodeId() ), shardRouting.recoverySource(), allocation.changes() diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/ShardChangesObserver.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/ShardChangesObserver.java index 1b5d1875bc1d3..f265ab7f62db2 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/ShardChangesObserver.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/ShardChangesObserver.java @@ -36,7 +36,7 @@ public void relocationStarted(ShardRouting startedShard, ShardRouting targetRelo @Override public void shardFailed(ShardRouting failedShard, UnassignedInfo unassignedInfo) { - logger.debug("{} has failed on [{}]: {}", shardIdentifier(failedShard), failedShard.currentNodeId(), unassignedInfo.getReason()); + logger.debug("{} has failed on [{}]: {}", shardIdentifier(failedShard), failedShard.currentNodeId(), unassignedInfo.reason()); } @Override diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java index 57f28e4ea021c..2fca8895b011c 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java @@ -220,19 +220,19 @@ private void failAllocationOfNewPrimaries(RoutingAllocation allocation) { while (unassignedIterator.hasNext()) { final ShardRouting shardRouting = unassignedIterator.next(); final UnassignedInfo unassignedInfo = shardRouting.unassignedInfo(); - if (shardRouting.primary() && unassignedInfo.getLastAllocationStatus() == AllocationStatus.NO_ATTEMPT) { + if (shardRouting.primary() && unassignedInfo.lastAllocationStatus() == AllocationStatus.NO_ATTEMPT) { unassignedIterator.updateUnassigned( new UnassignedInfo( - unassignedInfo.getReason(), - unassignedInfo.getMessage(), - unassignedInfo.getFailure(), - unassignedInfo.getNumFailedAllocations(), - unassignedInfo.getUnassignedTimeInNanos(), - unassignedInfo.getUnassignedTimeInMillis(), - unassignedInfo.isDelayed(), + unassignedInfo.reason(), + unassignedInfo.message(), + unassignedInfo.failure(), + unassignedInfo.failedAllocations(), + unassignedInfo.unassignedTimeNanos(), + unassignedInfo.unassignedTimeMillis(), + unassignedInfo.delayed(), AllocationStatus.DECIDERS_NO, - unassignedInfo.getFailedNodeIds(), - unassignedInfo.getLastAllocatedNodeId() + unassignedInfo.failedNodeIds(), + unassignedInfo.lastAllocatedNodeId() ), shardRouting.recoverySource(), allocation.changes() diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceComputer.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceComputer.java index f0fd108dd31fd..7c04d518eb2f6 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceComputer.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceComputer.java @@ -120,7 +120,7 @@ public DesiredBalance compute( for (final var iterator = unassigned.iterator(); iterator.hasNext();) { final var shardRouting = iterator.next(); if (shardRouting.primary() == primary) { - var lastAllocatedNodeId = shardRouting.unassignedInfo().getLastAllocatedNodeId(); + var lastAllocatedNodeId = shardRouting.unassignedInfo().lastAllocatedNodeId(); if (knownNodeIds.contains(lastAllocatedNodeId) || ignoredShards.contains(discardAllocationStatus(shardRouting)) == false) { shardRoutings.computeIfAbsent(shardRouting.shardId(), ShardRoutings::new).unassigned().add(shardRouting); @@ -154,7 +154,7 @@ public DesiredBalance compute( // preserving last known shard location as a starting point to avoid unnecessary relocations for (ShardRouting shardRouting : routings.unassigned()) { - var lastAllocatedNodeId = shardRouting.unassignedInfo().getLastAllocatedNodeId(); + var lastAllocatedNodeId = shardRouting.unassignedInfo().lastAllocatedNodeId(); if (knownNodeIds.contains(lastAllocatedNodeId)) { targetNodes.add(lastAllocatedNodeId); } @@ -346,19 +346,18 @@ public DesiredBalance compute( for (var shard : routingNodes.unassigned().ignored()) { var info = shard.unassignedInfo(); assert info != null - && (info.getLastAllocationStatus() == UnassignedInfo.AllocationStatus.DECIDERS_NO - || info.getLastAllocationStatus() == UnassignedInfo.AllocationStatus.NO_ATTEMPT - || info.getLastAllocationStatus() == UnassignedInfo.AllocationStatus.DECIDERS_THROTTLED) - : "Unexpected stats in: " + info; + && (info.lastAllocationStatus() == UnassignedInfo.AllocationStatus.DECIDERS_NO + || info.lastAllocationStatus() == UnassignedInfo.AllocationStatus.NO_ATTEMPT + || info.lastAllocationStatus() == UnassignedInfo.AllocationStatus.DECIDERS_THROTTLED) : "Unexpected stats in: " + info; - if (hasChanges == false && info.getLastAllocationStatus() == UnassignedInfo.AllocationStatus.DECIDERS_THROTTLED) { + if (hasChanges == false && info.lastAllocationStatus() == UnassignedInfo.AllocationStatus.DECIDERS_THROTTLED) { // Simulation could not progress due to missing information in any of the deciders. // Currently, this could happen if `HasFrozenCacheAllocationDecider` is still fetching the data. // Progress would be made after the followup reroute call. hasChanges = true; } - var ignored = shard.unassignedInfo().getLastAllocationStatus() == UnassignedInfo.AllocationStatus.DECIDERS_NO ? 0 : 1; + var ignored = shard.unassignedInfo().lastAllocationStatus() == UnassignedInfo.AllocationStatus.DECIDERS_NO ? 0 : 1; assignments.compute( shard.shardId(), (key, oldValue) -> oldValue == null @@ -400,20 +399,20 @@ private static ShardRouting discardAllocationStatus(ShardRouting shardRouting) { } private static UnassignedInfo discardAllocationStatus(UnassignedInfo info) { - if (info.getLastAllocationStatus() == UnassignedInfo.AllocationStatus.NO_ATTEMPT) { + if (info.lastAllocationStatus() == UnassignedInfo.AllocationStatus.NO_ATTEMPT) { return info; } return new UnassignedInfo( - info.getReason(), - info.getMessage(), - info.getFailure(), - info.getNumFailedAllocations(), - info.getUnassignedTimeInNanos(), - info.getUnassignedTimeInMillis(), - info.isDelayed(), + info.reason(), + info.message(), + info.failure(), + info.failedAllocations(), + info.unassignedTimeNanos(), + info.unassignedTimeMillis(), + info.delayed(), UnassignedInfo.AllocationStatus.NO_ATTEMPT, - info.getFailedNodeIds(), - info.getLastAllocatedNodeId() + info.failedNodeIds(), + info.lastAllocatedNodeId() ); } diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconciler.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconciler.java index f70d0b8929252..24e7abca45d2d 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconciler.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconciler.java @@ -226,19 +226,19 @@ private void failAllocationOfNewPrimaries(RoutingAllocation allocation) { while (unassignedIterator.hasNext()) { final ShardRouting shardRouting = unassignedIterator.next(); final UnassignedInfo unassignedInfo = shardRouting.unassignedInfo(); - if (shardRouting.primary() && unassignedInfo.getLastAllocationStatus() == AllocationStatus.NO_ATTEMPT) { + if (shardRouting.primary() && unassignedInfo.lastAllocationStatus() == AllocationStatus.NO_ATTEMPT) { unassignedIterator.updateUnassigned( new UnassignedInfo( - unassignedInfo.getReason(), - unassignedInfo.getMessage(), - unassignedInfo.getFailure(), - unassignedInfo.getNumFailedAllocations(), - unassignedInfo.getUnassignedTimeInNanos(), - unassignedInfo.getUnassignedTimeInMillis(), - unassignedInfo.isDelayed(), + unassignedInfo.reason(), + unassignedInfo.message(), + unassignedInfo.failure(), + unassignedInfo.failedAllocations(), + unassignedInfo.unassignedTimeNanos(), + unassignedInfo.unassignedTimeMillis(), + unassignedInfo.delayed(), AllocationStatus.DECIDERS_NO, - unassignedInfo.getFailedNodeIds(), - unassignedInfo.getLastAllocatedNodeId() + unassignedInfo.failedNodeIds(), + unassignedInfo.lastAllocatedNodeId() ), shardRouting.recoverySource(), allocation.changes() diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/command/AllocateEmptyPrimaryAllocationCommand.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/command/AllocateEmptyPrimaryAllocationCommand.java index e2fdec43d8e12..2b006988a2ae4 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/command/AllocateEmptyPrimaryAllocationCommand.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/command/AllocateEmptyPrimaryAllocationCommand.java @@ -126,20 +126,20 @@ public RerouteExplanation execute(RoutingAllocation allocation, boolean explain) } UnassignedInfo unassignedInfoToUpdate = null; - if (shardRouting.unassignedInfo().getReason() != UnassignedInfo.Reason.FORCED_EMPTY_PRIMARY) { + if (shardRouting.unassignedInfo().reason() != UnassignedInfo.Reason.FORCED_EMPTY_PRIMARY) { String unassignedInfoMessage = "force empty allocation from previous reason " - + shardRouting.unassignedInfo().getReason() + + shardRouting.unassignedInfo().reason() + ", " - + shardRouting.unassignedInfo().getMessage(); + + shardRouting.unassignedInfo().message(); unassignedInfoToUpdate = new UnassignedInfo( UnassignedInfo.Reason.FORCED_EMPTY_PRIMARY, unassignedInfoMessage, - shardRouting.unassignedInfo().getFailure(), + shardRouting.unassignedInfo().failure(), 0, System.nanoTime(), System.currentTimeMillis(), false, - shardRouting.unassignedInfo().getLastAllocationStatus(), + shardRouting.unassignedInfo().lastAllocationStatus(), Collections.emptySet(), null ); diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/MaxRetryAllocationDecider.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/MaxRetryAllocationDecider.java index f37039608d7bd..1f7d1fe0143c3 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/MaxRetryAllocationDecider.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/MaxRetryAllocationDecider.java @@ -50,7 +50,7 @@ public Decision canAllocate(ShardRouting shardRouting, RoutingAllocation allocat final int maxRetries = SETTING_ALLOCATION_MAX_RETRY.get(allocation.metadata().getIndexSafe(shardRouting.index()).getSettings()); final var unassignedInfo = shardRouting.unassignedInfo(); - final int numFailedAllocations = unassignedInfo == null ? 0 : unassignedInfo.getNumFailedAllocations(); + final int numFailedAllocations = unassignedInfo == null ? 0 : unassignedInfo.failedAllocations(); if (numFailedAllocations > 0) { final var decision = numFailedAllocations >= maxRetries ? Decision.NO : Decision.YES; return allocation.debugDecision() ? debugDecision(decision, unassignedInfo, numFailedAllocations, maxRetries) : decision; diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/RestoreInProgressAllocationDecider.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/RestoreInProgressAllocationDecider.java index 7b08a4d94512e..7adfc2c17d4aa 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/RestoreInProgressAllocationDecider.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/RestoreInProgressAllocationDecider.java @@ -55,7 +55,7 @@ public Decision canAllocate(final ShardRouting shardRouting, final RoutingAlloca + "to restore the snapshot again or use the reroute API to force the allocation of an empty primary shard. Details: [%s]", source.snapshot(), shardRouting.getIndexName(), - shardRouting.unassignedInfo().getDetails() + shardRouting.unassignedInfo().details() ); } diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/shards/ShardsAvailabilityHealthIndicatorService.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/shards/ShardsAvailabilityHealthIndicatorService.java index 309848635a440..8fb91d89417e0 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/shards/ShardsAvailabilityHealthIndicatorService.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/shards/ShardsAvailabilityHealthIndicatorService.java @@ -536,15 +536,15 @@ static boolean isNewlyCreatedAndInitializingReplica(ShardRouting routing, Cluste private static boolean isUnassignedDueToTimelyRestart(ShardRouting routing, NodesShutdownMetadata shutdowns) { var info = routing.unassignedInfo(); - if (info == null || info.getReason() != UnassignedInfo.Reason.NODE_RESTARTING) { + if (info == null || info.reason() != UnassignedInfo.Reason.NODE_RESTARTING) { return false; } - var shutdown = shutdowns.get(info.getLastAllocatedNodeId(), SingleNodeShutdownMetadata.Type.RESTART); + var shutdown = shutdowns.get(info.lastAllocatedNodeId(), SingleNodeShutdownMetadata.Type.RESTART); if (shutdown == null) { return false; } var now = System.nanoTime(); - var restartingAllocationDelayExpiration = info.getUnassignedTimeInNanos() + shutdown.getAllocationDelay().nanos(); + var restartingAllocationDelayExpiration = info.unassignedTimeNanos() + shutdown.getAllocationDelay().nanos(); return now - restartingAllocationDelayExpiration <= 0; } @@ -567,10 +567,10 @@ private static boolean isUnassignedDueToNewInitialization(ShardRouting routing, List diagnoseUnassignedShardRouting(ShardRouting shardRouting, ClusterState state) { List diagnosisDefs = new ArrayList<>(); LOGGER.trace("Diagnosing unassigned shard [{}] due to reason [{}]", shardRouting.shardId(), shardRouting.unassignedInfo()); - switch (shardRouting.unassignedInfo().getLastAllocationStatus()) { + switch (shardRouting.unassignedInfo().lastAllocationStatus()) { case NO_VALID_SHARD_COPY -> diagnosisDefs.add(ACTION_RESTORE_FROM_SNAPSHOT); case NO_ATTEMPT -> { - if (shardRouting.unassignedInfo().isDelayed()) { + if (shardRouting.unassignedInfo().delayed()) { diagnosisDefs.add(DIAGNOSIS_WAIT_FOR_OR_FIX_DELAYED_SHARDS); } else { diagnosisDefs.addAll(explainAllocationsAndDiagnoseDeciders(shardRouting, state)); diff --git a/server/src/main/java/org/elasticsearch/gateway/ReplicaShardAllocator.java b/server/src/main/java/org/elasticsearch/gateway/ReplicaShardAllocator.java index fa9636dc89d69..d07d2498d6534 100644 --- a/server/src/main/java/org/elasticsearch/gateway/ReplicaShardAllocator.java +++ b/server/src/main/java/org/elasticsearch/gateway/ReplicaShardAllocator.java @@ -65,7 +65,7 @@ public void processExistingRecoveries(RoutingAllocation allocation, Predicate failedNodeIds = shard.unassignedInfo() == null ? Collections.emptySet() - : shard.unassignedInfo().getFailedNodeIds(); + : shard.unassignedInfo().failedNodeIds(); UnassignedInfo unassignedInfo = new UnassignedInfo( UnassignedInfo.Reason.REALLOCATED_REPLICA, "existing allocation of replica to [" @@ -138,7 +138,7 @@ private static boolean isResponsibleFor(final ShardRouting shard) { return shard.primary() == false // must be a replica && shard.unassigned() // must be unassigned // if we are allocating a replica because of index creation, no need to go and find a copy, there isn't one... - && shard.unassignedInfo().getReason() != UnassignedInfo.Reason.INDEX_CREATED; + && shard.unassignedInfo().reason() != UnassignedInfo.Reason.INDEX_CREATED; } @Override @@ -234,7 +234,7 @@ public AllocateUnassignedDecision makeAllocationDecision( // we found a match return AllocateUnassignedDecision.yes(nodeWithHighestMatch.node(), null, nodeDecisions, true); } - } else if (matchingNodes.hasAnyData() == false && unassignedShard.unassignedInfo().isDelayed()) { + } else if (matchingNodes.hasAnyData() == false && unassignedShard.unassignedInfo().delayed()) { // if we didn't manage to find *any* data (regardless of matching sizes), and the replica is // unassigned due to a node leaving, so we delay allocation of this replica to see if the // node with the shard copy will rejoin so we can re-use the copy it has @@ -262,7 +262,7 @@ public static AllocateUnassignedDecision delayedDecision( Metadata metadata = allocation.metadata(); IndexMetadata indexMetadata = metadata.index(unassignedShard.index()); totalDelayMillis = INDEX_DELAYED_NODE_LEFT_TIMEOUT_SETTING.get(indexMetadata.getSettings()).getMillis(); - long remainingDelayNanos = unassignedInfo.getRemainingDelay( + long remainingDelayNanos = unassignedInfo.remainingDelay( System.nanoTime(), indexMetadata.getSettings(), metadata.nodeShutdowns() @@ -357,7 +357,7 @@ private MatchingNodes findMatchingNodes( DiscoveryNode discoNode = nodeStoreEntry.getKey(); if (noMatchFailedNodes && shard.unassignedInfo() != null - && shard.unassignedInfo().getFailedNodeIds().contains(discoNode.getId())) { + && shard.unassignedInfo().failedNodeIds().contains(discoNode.getId())) { continue; } TransportNodesListShardStoreMetadata.StoreFilesMetadata storeFilesMetadata = nodeStoreEntry.getValue().storeFilesMetadata(); diff --git a/server/src/main/java/org/elasticsearch/rest/action/cat/RestShardsAction.java b/server/src/main/java/org/elasticsearch/rest/action/cat/RestShardsAction.java index 664f9b63dee2a..d9a34fe36c860 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/cat/RestShardsAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/cat/RestShardsAction.java @@ -324,13 +324,13 @@ Table buildTable(RestRequest request, ClusterStateResponse state, IndicesStatsRe table.addCell(commitStats == null ? null : commitStats.getUserData().get(Engine.SYNC_COMMIT_ID)); if (shard.unassignedInfo() != null) { - table.addCell(shard.unassignedInfo().getReason()); - Instant unassignedTime = Instant.ofEpochMilli(shard.unassignedInfo().getUnassignedTimeInMillis()); + table.addCell(shard.unassignedInfo().reason()); + Instant unassignedTime = Instant.ofEpochMilli(shard.unassignedInfo().unassignedTimeMillis()); table.addCell(UnassignedInfo.DATE_TIME_FORMATTER.format(unassignedTime)); table.addCell( - TimeValue.timeValueMillis(Math.max(0, System.currentTimeMillis() - shard.unassignedInfo().getUnassignedTimeInMillis())) + TimeValue.timeValueMillis(Math.max(0, System.currentTimeMillis() - shard.unassignedInfo().unassignedTimeMillis())) ); - table.addCell(shard.unassignedInfo().getDetails()); + table.addCell(shard.unassignedInfo().details()); } else { table.addCell(null); table.addCell(null); diff --git a/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java b/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java index 74b8a3e12dad5..453d0b3201560 100644 --- a/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java +++ b/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java @@ -795,13 +795,13 @@ public void shardFailed(ShardRouting failedShard, UnassignedInfo unassignedInfo) // mark restore entry for this shard as failed when it's due to a file corruption. There is no need wait on retries // to restore this shard on another node if the snapshot files are corrupt. In case where a node just left or crashed, // however, we only want to acknowledge the restore operation once it has been successfully restored on another node. - if (unassignedInfo.getFailure() != null && Lucene.isCorruptionException(unassignedInfo.getFailure().getCause())) { + if (unassignedInfo.failure() != null && Lucene.isCorruptionException(unassignedInfo.failure().getCause())) { changes(recoverySource).put( failedShard.shardId(), new ShardRestoreStatus( failedShard.currentNodeId(), RestoreInProgress.State.FAILURE, - unassignedInfo.getFailure().getCause().getMessage() + unassignedInfo.failure().getCause().getMessage() ) ); } @@ -829,7 +829,7 @@ public void shardInitialized(ShardRouting unassignedShard, ShardRouting initiali public void unassignedInfoUpdated(ShardRouting unassignedShard, UnassignedInfo newUnassignedInfo) { RecoverySource recoverySource = unassignedShard.recoverySource(); if (recoverySource.getType() == RecoverySource.Type.SNAPSHOT) { - if (newUnassignedInfo.getLastAllocationStatus() == UnassignedInfo.AllocationStatus.DECIDERS_NO) { + if (newUnassignedInfo.lastAllocationStatus() == UnassignedInfo.AllocationStatus.DECIDERS_NO) { String reason = "shard could not be allocated to any of the nodes"; changes(recoverySource).put( unassignedShard.shardId(), diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainActionTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainActionTests.java index f9483bd23f216..eb1a64ef66bbd 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainActionTests.java @@ -103,9 +103,9 @@ public ShardAllocationDecision decideShardAllocation(ShardRouting shard, Routing """ ,"unassigned_info": {"reason": "%s", "at": "%s", "last_allocation_status": "%s"} """, - shard.unassignedInfo().getReason(), - UnassignedInfo.DATE_TIME_FORMATTER.format(Instant.ofEpochMilli(shard.unassignedInfo().getUnassignedTimeInMillis())), - AllocationDecision.fromAllocationStatus(shard.unassignedInfo().getLastAllocationStatus()) + shard.unassignedInfo().reason(), + UnassignedInfo.DATE_TIME_FORMATTER.format(Instant.ofEpochMilli(shard.unassignedInfo().unassignedTimeMillis())), + AllocationDecision.fromAllocationStatus(shard.unassignedInfo().lastAllocationStatus()) ) : "", cae.getCurrentNode().getId(), diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/reroute/ClusterRerouteTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/reroute/ClusterRerouteTests.java index a6d380bc7683c..19c268100d4a0 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/cluster/reroute/ClusterRerouteTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/reroute/ClusterRerouteTests.java @@ -162,7 +162,7 @@ public void testClusterStateUpdateTask() { private void assertStateAndFailedAllocations(IndexRoutingTable indexRoutingTable, ShardRoutingState state, int failedAllocations) { assertThat(indexRoutingTable.size(), equalTo(1)); assertThat(indexRoutingTable.shard(0).shard(0).state(), equalTo(state)); - assertThat(indexRoutingTable.shard(0).shard(0).unassignedInfo().getNumFailedAllocations(), equalTo(failedAllocations)); + assertThat(indexRoutingTable.shard(0).shard(0).unassignedInfo().failedAllocations(), equalTo(failedAllocations)); } private ClusterState createInitialClusterState(AllocationService service) { diff --git a/server/src/test/java/org/elasticsearch/cluster/health/ClusterStateHealthTests.java b/server/src/test/java/org/elasticsearch/cluster/health/ClusterStateHealthTests.java index 05e345bf4b52b..96ff00488a1d2 100644 --- a/server/src/test/java/org/elasticsearch/cluster/health/ClusterStateHealthTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/health/ClusterStateHealthTests.java @@ -559,10 +559,10 @@ private boolean primaryInactiveDueToRecovery(final String indexName, final Clust && primaryShard.recoverySource().getType() == RecoverySource.Type.EXISTING_STORE) { return false; } - if (primaryShard.unassignedInfo().getNumFailedAllocations() > 0) { + if (primaryShard.unassignedInfo().failedAllocations() > 0) { return false; } - if (primaryShard.unassignedInfo().getLastAllocationStatus() == UnassignedInfo.AllocationStatus.DECIDERS_NO) { + if (primaryShard.unassignedInfo().lastAllocationStatus() == UnassignedInfo.AllocationStatus.DECIDERS_NO) { return false; } } diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataIndexStateServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataIndexStateServiceTests.java index 6e24735eba454..e034971482bcf 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataIndexStateServiceTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataIndexStateServiceTests.java @@ -457,7 +457,7 @@ private static void assertIsClosed(final String indexName, final ClusterState cl assertThat( RoutingNodesHelper.asStream(shardRoutingTable) .map(ShardRouting::unassignedInfo) - .map(UnassignedInfo::getReason) + .map(UnassignedInfo::reason) .allMatch(info -> info == UnassignedInfo.Reason.INDEX_CLOSED), is(true) ); diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/DelayedAllocationServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/DelayedAllocationServiceTests.java index aacf9f803dde0..171fd397d65f3 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/DelayedAllocationServiceTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/DelayedAllocationServiceTests.java @@ -109,7 +109,7 @@ public void testNoDelayedUnassigned() { assertThat(unassignedShards.size(), equalTo(0)); } else { assertThat(unassignedShards.size(), equalTo(1)); - assertThat(unassignedShards.get(0).unassignedInfo().isDelayed(), equalTo(false)); + assertThat(unassignedShards.get(0).unassignedInfo().delayed(), equalTo(false)); } delayedAllocationService.clusterChanged(new ClusterChangedEvent("test", newState, prevState)); @@ -169,7 +169,7 @@ public void testDelayedUnassignedScheduleReroute() throws Exception { // make sure the replica is marked as delayed (i.e. not reallocated) assertEquals(1, UnassignedInfo.getNumberOfDelayedUnassigned(stateWithDelayedShard)); ShardRouting delayedShard = stateWithDelayedShard.getRoutingNodes().unassigned().iterator().next(); - assertEquals(baseTimestampNanos, delayedShard.unassignedInfo().getUnassignedTimeInNanos()); + assertEquals(baseTimestampNanos, delayedShard.unassignedInfo().unassignedTimeNanos()); // mock ClusterService.submitStateUpdateTask() method CountDownLatch latch = new CountDownLatch(1); @@ -318,8 +318,8 @@ public void testDelayedUnassignedScheduleRerouteAfterDelayedReroute() throws Exc final ClusterState stateWithDelayedShards = clusterState; assertEquals(2, UnassignedInfo.getNumberOfDelayedUnassigned(stateWithDelayedShards)); RoutingNodes.UnassignedShards.UnassignedIterator iter = stateWithDelayedShards.getRoutingNodes().unassigned().iterator(); - assertEquals(baseTimestampNanos, iter.next().unassignedInfo().getUnassignedTimeInNanos()); - assertEquals(baseTimestampNanos, iter.next().unassignedInfo().getUnassignedTimeInNanos()); + assertEquals(baseTimestampNanos, iter.next().unassignedInfo().unassignedTimeNanos()); + assertEquals(baseTimestampNanos, iter.next().unassignedInfo().unassignedTimeNanos()); // mock ClusterService.submitStateUpdateTask() method CountDownLatch latch1 = new CountDownLatch(1); @@ -491,7 +491,7 @@ public void testDelayedUnassignedScheduleRerouteRescheduledOnShorterDelay() { // make sure the replica is marked as delayed (i.e. not reallocated) assertEquals(1, UnassignedInfo.getNumberOfDelayedUnassigned(stateWithDelayedShard)); ShardRouting delayedShard = stateWithDelayedShard.getRoutingNodes().unassigned().iterator().next(); - assertEquals(nodeLeftTimestampNanos, delayedShard.unassignedInfo().getUnassignedTimeInNanos()); + assertEquals(nodeLeftTimestampNanos, delayedShard.unassignedInfo().unassignedTimeNanos()); assertNull(delayedAllocationService.delayedRerouteTask.get()); long delayUntilClusterChangeEvent = TimeValue.timeValueNanos(randomInt((int) shorterDelaySetting.nanos() - 1)).nanos(); diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/ShardRoutingTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/ShardRoutingTests.java index e6466b9237d3a..33695883aebc3 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/ShardRoutingTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/ShardRoutingTests.java @@ -401,7 +401,7 @@ public void testEqualsIgnoringVersion() { .withUnassignedInfo( otherRouting.unassignedInfo() == null ? new UnassignedInfo(UnassignedInfo.Reason.INDEX_CREATED, "test") - : new UnassignedInfo(UnassignedInfo.Reason.INDEX_CREATED, otherRouting.unassignedInfo().getMessage() + "_1") + : new UnassignedInfo(UnassignedInfo.Reason.INDEX_CREATED, otherRouting.unassignedInfo().message() + "_1") ) .build(); } diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/UnassignedInfoTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/UnassignedInfoTests.java index 1d0b01a10da78..eb39d56346eb2 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/UnassignedInfoTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/UnassignedInfoTests.java @@ -137,14 +137,14 @@ public void testSerialization() throws Exception { meta.writeTo(out); out.close(); - UnassignedInfo read = new UnassignedInfo(out.bytes().streamInput()); - assertThat(read.getReason(), equalTo(meta.getReason())); - assertThat(read.getUnassignedTimeInMillis(), equalTo(meta.getUnassignedTimeInMillis())); - assertThat(read.getMessage(), equalTo(meta.getMessage())); - assertThat(read.getDetails(), equalTo(meta.getDetails())); - assertThat(read.getNumFailedAllocations(), equalTo(meta.getNumFailedAllocations())); - assertThat(read.getFailedNodeIds(), equalTo(meta.getFailedNodeIds())); - assertThat(read.getLastAllocatedNodeId(), equalTo(meta.getLastAllocatedNodeId())); + UnassignedInfo read = UnassignedInfo.fromStreamInput(out.bytes().streamInput()); + assertThat(read.reason(), equalTo(meta.reason())); + assertThat(read.unassignedTimeMillis(), equalTo(meta.unassignedTimeMillis())); + assertThat(read.message(), equalTo(meta.message())); + assertThat(read.details(), equalTo(meta.details())); + assertThat(read.failedAllocations(), equalTo(meta.failedAllocations())); + assertThat(read.failedNodeIds(), equalTo(meta.failedNodeIds())); + assertThat(read.lastAllocatedNodeId(), equalTo(meta.lastAllocatedNodeId())); } public void testIndexCreated() { @@ -161,7 +161,7 @@ public void testIndexCreated() { .routingTable(RoutingTable.builder(TestShardRoutingRoleStrategies.DEFAULT_ROLE_ONLY).addAsNew(metadata.index("test")).build()) .build(); for (ShardRouting shard : shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED)) { - assertThat(shard.unassignedInfo().getReason(), equalTo(UnassignedInfo.Reason.INDEX_CREATED)); + assertThat(shard.unassignedInfo().reason(), equalTo(UnassignedInfo.Reason.INDEX_CREATED)); } } @@ -181,7 +181,7 @@ public void testClusterRecovered() { ) .build(); for (ShardRouting shard : shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED)) { - assertThat(shard.unassignedInfo().getReason(), equalTo(UnassignedInfo.Reason.CLUSTER_RECOVERED)); + assertThat(shard.unassignedInfo().reason(), equalTo(UnassignedInfo.Reason.CLUSTER_RECOVERED)); } } @@ -296,8 +296,8 @@ private void assertLastAllocatedNodeIdsAssigned( for (int shardCopy = 0; shardCopy < shardRoutingTable.size(); shardCopy++) { final var shard = shardRoutingTable.shard(shardCopy); assertTrue(shard.unassigned()); - assertThat(shard.unassignedInfo().getReason(), equalTo(expectedUnassignedReason)); - final var lastAllocatedNodeId = shard.unassignedInfo().getLastAllocatedNodeId(); + assertThat(shard.unassignedInfo().reason(), equalTo(expectedUnassignedReason)); + final var lastAllocatedNodeId = shard.unassignedInfo().lastAllocatedNodeId(); if (lastAllocatedNodeId == null) { // restoring an index may change the number of shards/replicas so no guarantee that lastAllocatedNodeId is populated assertTrue(shardCountChanged); @@ -309,7 +309,7 @@ private void assertLastAllocatedNodeIdsAssigned( if (shardCountChanged == false) { assertNotNull(previousShardRoutingTable); assertThat( - shardRoutingTable.primaryShard().unassignedInfo().getLastAllocatedNodeId(), + shardRoutingTable.primaryShard().unassignedInfo().lastAllocatedNodeId(), equalTo(previousShardRoutingTable.primaryShard().currentNodeId()) ); } @@ -335,7 +335,7 @@ public void testIndexReopened() { ) .build(); for (ShardRouting shard : shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED)) { - assertThat(shard.unassignedInfo().getReason(), equalTo(UnassignedInfo.Reason.INDEX_REOPENED)); + assertThat(shard.unassignedInfo().reason(), equalTo(UnassignedInfo.Reason.INDEX_REOPENED)); } } @@ -366,7 +366,7 @@ public void testNewIndexRestored() { ) .build(); for (ShardRouting shard : shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED)) { - assertThat(shard.unassignedInfo().getReason(), equalTo(UnassignedInfo.Reason.NEW_INDEX_RESTORED)); + assertThat(shard.unassignedInfo().reason(), equalTo(UnassignedInfo.Reason.NEW_INDEX_RESTORED)); } } @@ -471,7 +471,7 @@ public void testDanglingIndexImported() { ) .build(); for (ShardRouting shard : shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED)) { - assertThat(shard.unassignedInfo().getReason(), equalTo(UnassignedInfo.Reason.DANGLING_INDEX_IMPORTED)); + assertThat(shard.unassignedInfo().reason(), equalTo(UnassignedInfo.Reason.DANGLING_INDEX_IMPORTED)); } } @@ -501,7 +501,7 @@ public void testReplicaAdded() { assertThat(shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED).size(), equalTo(1)); assertThat(shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED).get(0).unassignedInfo(), notNullValue()); assertThat( - shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED).get(0).unassignedInfo().getReason(), + shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED).get(0).unassignedInfo().reason(), equalTo(UnassignedInfo.Reason.REPLICA_ADDED) ); } @@ -551,11 +551,11 @@ public void testNodeLeave() { assertThat(shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED).size(), equalTo(1)); assertThat(shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED).get(0).unassignedInfo(), notNullValue()); assertThat( - shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED).get(0).unassignedInfo().getReason(), + shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED).get(0).unassignedInfo().reason(), equalTo(UnassignedInfo.Reason.NODE_LEFT) ); assertThat( - shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED).get(0).unassignedInfo().getUnassignedTimeInMillis(), + shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED).get(0).unassignedInfo().unassignedTimeMillis(), greaterThan(0L) ); } @@ -593,19 +593,19 @@ public void testFailedShard() { assertThat(shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED).size(), equalTo(1)); assertThat(shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED).get(0).unassignedInfo(), notNullValue()); assertThat( - shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED).get(0).unassignedInfo().getReason(), + shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED).get(0).unassignedInfo().reason(), equalTo(UnassignedInfo.Reason.ALLOCATION_FAILED) ); assertThat( - shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED).get(0).unassignedInfo().getMessage(), + shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED).get(0).unassignedInfo().message(), equalTo("failed shard on node [" + shardToFail.currentNodeId() + "]: test fail") ); assertThat( - shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED).get(0).unassignedInfo().getDetails(), + shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED).get(0).unassignedInfo().details(), equalTo("failed shard on node [" + shardToFail.currentNodeId() + "]: test fail") ); assertThat( - shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED).get(0).unassignedInfo().getUnassignedTimeInMillis(), + shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED).get(0).unassignedInfo().unassignedTimeMillis(), greaterThan(0L) ); } @@ -768,14 +768,14 @@ private void checkRemainingDelayCalculation( final Settings indexSettings = Settings.builder() .put(UnassignedInfo.INDEX_DELAYED_NODE_LEFT_TIMEOUT_SETTING.getKey(), indexLevelTimeoutSetting) .build(); - long delay = unassignedInfo.getRemainingDelay(baseTime, indexSettings, nodeShutdowns); + long delay = unassignedInfo.remainingDelay(baseTime, indexSettings, nodeShutdowns); assertThat(delay, equalTo(totalDelayNanos)); long delta1 = randomLongBetween(1, (totalDelayNanos - 1)); - delay = unassignedInfo.getRemainingDelay(baseTime + delta1, indexSettings, nodeShutdowns); + delay = unassignedInfo.remainingDelay(baseTime + delta1, indexSettings, nodeShutdowns); assertThat(delay, equalTo(totalDelayNanos - delta1)); - delay = unassignedInfo.getRemainingDelay(baseTime + totalDelayNanos, indexSettings, nodeShutdowns); + delay = unassignedInfo.remainingDelay(baseTime + totalDelayNanos, indexSettings, nodeShutdowns); assertThat(delay, equalTo(0L)); - delay = unassignedInfo.getRemainingDelay(baseTime + totalDelayNanos + randomIntBetween(1, 20), indexSettings, nodeShutdowns); + delay = unassignedInfo.remainingDelay(baseTime + totalDelayNanos + randomIntBetween(1, 20), indexSettings, nodeShutdowns); assertThat(delay, equalTo(0L)); } @@ -918,25 +918,25 @@ public void testSummaryContainsImportantFields() { var info = randomUnassignedInfo(randomBoolean() ? randomIdentifier() : null); var summary = info.shortSummary(); - assertThat("reason", summary, containsString("[reason=" + info.getReason() + ']')); + assertThat("reason", summary, containsString("[reason=" + info.reason() + ']')); assertThat( "delay", summary, - containsString("at[" + UnassignedInfo.DATE_TIME_FORMATTER.format(Instant.ofEpochMilli(info.getUnassignedTimeInMillis())) + ']') + containsString("at[" + UnassignedInfo.DATE_TIME_FORMATTER.format(Instant.ofEpochMilli(info.unassignedTimeMillis())) + ']') ); - if (info.getNumFailedAllocations() > 0) { - assertThat("failed_allocations", summary, containsString("failed_attempts[" + info.getNumFailedAllocations() + ']')); + if (info.failedAllocations() > 0) { + assertThat("failed_allocations", summary, containsString("failed_attempts[" + info.failedAllocations() + ']')); } - if (info.getFailedNodeIds().isEmpty() == false) { - assertThat("failed_nodes", summary, containsString("failed_nodes[" + info.getFailedNodeIds() + ']')); + if (info.failedNodeIds().isEmpty() == false) { + assertThat("failed_nodes", summary, containsString("failed_nodes[" + info.failedNodeIds() + ']')); } - assertThat("delayed", summary, containsString("delayed=" + info.isDelayed())); - if (info.getLastAllocatedNodeId() != null) { - assertThat("last_node", summary, containsString("last_node[" + info.getLastAllocatedNodeId() + ']')); + assertThat("delayed", summary, containsString("delayed=" + info.delayed())); + if (info.lastAllocatedNodeId() != null) { + assertThat("last_node", summary, containsString("last_node[" + info.lastAllocatedNodeId() + ']')); } - if (info.getMessage() != null) { - assertThat("details", summary, containsString("details[" + info.getMessage() + ']')); + if (info.message() != null) { + assertThat("details", summary, containsString("details[" + info.message() + ']')); } - assertThat("allocation_status", summary, containsString("allocation_status[" + info.getLastAllocationStatus().value() + ']')); + assertThat("allocation_status", summary, containsString("allocation_status[" + info.lastAllocationStatus().value() + ']')); } } diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/MaxRetryAllocationDeciderTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/MaxRetryAllocationDeciderTests.java index d28c1875100bb..e863aca526da7 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/MaxRetryAllocationDeciderTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/MaxRetryAllocationDeciderTests.java @@ -91,8 +91,8 @@ public void testSingleRetryOnIgnore() { routingTable = newState.routingTable(); assertEquals(routingTable.index("idx").size(), 1); assertEquals(routingTable.index("idx").shard(0).shard(0).state(), INITIALIZING); - assertEquals(routingTable.index("idx").shard(0).shard(0).unassignedInfo().getNumFailedAllocations(), i + 1); - assertThat(routingTable.index("idx").shard(0).shard(0).unassignedInfo().getMessage(), containsString("boom" + i)); + assertEquals(routingTable.index("idx").shard(0).shard(0).unassignedInfo().failedAllocations(), i + 1); + assertThat(routingTable.index("idx").shard(0).shard(0).unassignedInfo().message(), containsString("boom" + i)); } // now we go and check that we are actually stick to unassigned on the next failure ClusterState newState = applyShardFailure(clusterState, routingTable.index("idx").shard(0).shard(0), "boom"); @@ -100,9 +100,9 @@ public void testSingleRetryOnIgnore() { clusterState = newState; routingTable = newState.routingTable(); assertEquals(routingTable.index("idx").size(), 1); - assertEquals(routingTable.index("idx").shard(0).shard(0).unassignedInfo().getNumFailedAllocations(), retries); + assertEquals(routingTable.index("idx").shard(0).shard(0).unassignedInfo().failedAllocations(), retries); assertEquals(routingTable.index("idx").shard(0).shard(0).state(), UNASSIGNED); - assertThat(routingTable.index("idx").shard(0).shard(0).unassignedInfo().getMessage(), containsString("boom")); + assertThat(routingTable.index("idx").shard(0).shard(0).unassignedInfo().message(), containsString("boom")); // manual resetting of retry count newState = strategy.reroute(clusterState, new AllocationCommands(), false, true, false, ActionListener.noop()).clusterState(); @@ -112,9 +112,9 @@ public void testSingleRetryOnIgnore() { clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build(); assertEquals(routingTable.index("idx").size(), 1); - assertEquals(0, routingTable.index("idx").shard(0).shard(0).unassignedInfo().getNumFailedAllocations()); + assertEquals(0, routingTable.index("idx").shard(0).shard(0).unassignedInfo().failedAllocations()); assertEquals(INITIALIZING, routingTable.index("idx").shard(0).shard(0).state()); - assertThat(routingTable.index("idx").shard(0).shard(0).unassignedInfo().getMessage(), containsString("boom")); + assertThat(routingTable.index("idx").shard(0).shard(0).unassignedInfo().message(), containsString("boom")); // again fail it N-1 times for (int i = 0; i < retries - 1; i++) { @@ -123,9 +123,9 @@ public void testSingleRetryOnIgnore() { clusterState = newState; routingTable = newState.routingTable(); assertEquals(routingTable.index("idx").size(), 1); - assertEquals(i + 1, routingTable.index("idx").shard(0).shard(0).unassignedInfo().getNumFailedAllocations()); + assertEquals(i + 1, routingTable.index("idx").shard(0).shard(0).unassignedInfo().failedAllocations()); assertEquals(INITIALIZING, routingTable.index("idx").shard(0).shard(0).state()); - assertThat(routingTable.index("idx").shard(0).shard(0).unassignedInfo().getMessage(), containsString("boom")); + assertThat(routingTable.index("idx").shard(0).shard(0).unassignedInfo().message(), containsString("boom")); } // now we go and check that we are actually stick to unassigned on the next failure @@ -134,9 +134,9 @@ public void testSingleRetryOnIgnore() { clusterState = newState; routingTable = newState.routingTable(); assertEquals(routingTable.index("idx").size(), 1); - assertEquals(retries, routingTable.index("idx").shard(0).shard(0).unassignedInfo().getNumFailedAllocations()); + assertEquals(retries, routingTable.index("idx").shard(0).shard(0).unassignedInfo().failedAllocations()); assertEquals(UNASSIGNED, routingTable.index("idx").shard(0).shard(0).state()); - assertThat(routingTable.index("idx").shard(0).shard(0).unassignedInfo().getMessage(), containsString("boom")); + assertThat(routingTable.index("idx").shard(0).shard(0).unassignedInfo().message(), containsString("boom")); } public void testFailedAllocation() { @@ -152,8 +152,8 @@ public void testFailedAllocation() { assertEquals(routingTable.index("idx").size(), 1); ShardRouting unassignedPrimary = routingTable.index("idx").shard(0).shard(0); assertEquals(unassignedPrimary.state(), INITIALIZING); - assertEquals(unassignedPrimary.unassignedInfo().getNumFailedAllocations(), i + 1); - assertThat(unassignedPrimary.unassignedInfo().getMessage(), containsString("boom" + i)); + assertEquals(unassignedPrimary.unassignedInfo().failedAllocations(), i + 1); + assertThat(unassignedPrimary.unassignedInfo().message(), containsString("boom" + i)); // MaxRetryAllocationDecider#canForceAllocatePrimary should return YES decisions because canAllocate returns YES here assertEquals( Decision.Type.YES, @@ -168,9 +168,9 @@ public void testFailedAllocation() { routingTable = newState.routingTable(); assertEquals(routingTable.index("idx").size(), 1); ShardRouting unassignedPrimary = routingTable.index("idx").shard(0).shard(0); - assertEquals(unassignedPrimary.unassignedInfo().getNumFailedAllocations(), retries); + assertEquals(unassignedPrimary.unassignedInfo().failedAllocations(), retries); assertEquals(unassignedPrimary.state(), UNASSIGNED); - assertThat(unassignedPrimary.unassignedInfo().getMessage(), containsString("boom")); + assertThat(unassignedPrimary.unassignedInfo().message(), containsString("boom")); // MaxRetryAllocationDecider#canForceAllocatePrimary should return a NO decision because canAllocate returns NO here final var allocation = newRoutingAllocation(clusterState); allocation.debugDecision(true); @@ -211,9 +211,9 @@ public void testFailedAllocation() { // good we are initializing and we are maintaining failure information assertEquals(routingTable.index("idx").size(), 1); ShardRouting unassignedPrimary = routingTable.index("idx").shard(0).shard(0); - assertEquals(unassignedPrimary.unassignedInfo().getNumFailedAllocations(), retries); + assertEquals(unassignedPrimary.unassignedInfo().failedAllocations(), retries); assertEquals(unassignedPrimary.state(), INITIALIZING); - assertThat(unassignedPrimary.unassignedInfo().getMessage(), containsString("boom")); + assertThat(unassignedPrimary.unassignedInfo().message(), containsString("boom")); // bumped up the max retry count, so canForceAllocatePrimary should return a YES decision assertEquals( Decision.Type.YES, @@ -236,9 +236,9 @@ public void testFailedAllocation() { routingTable = newState.routingTable(); assertEquals(routingTable.index("idx").size(), 1); unassignedPrimary = routingTable.index("idx").shard(0).shard(0); - assertEquals(unassignedPrimary.unassignedInfo().getNumFailedAllocations(), 1); + assertEquals(unassignedPrimary.unassignedInfo().failedAllocations(), 1); assertEquals(unassignedPrimary.state(), UNASSIGNED); - assertThat(unassignedPrimary.unassignedInfo().getMessage(), containsString("ZOOOMG")); + assertThat(unassignedPrimary.unassignedInfo().message(), containsString("ZOOOMG")); // Counter reset, so MaxRetryAllocationDecider#canForceAllocatePrimary should return a YES decision assertEquals( Decision.Type.YES, diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/TrackFailedAllocationNodesTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/TrackFailedAllocationNodesTests.java index 438ec85c4b997..84eead8d51dc2 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/TrackFailedAllocationNodesTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/TrackFailedAllocationNodesTests.java @@ -59,17 +59,14 @@ public void testTrackFailedNodes() { List.of(new FailedShard(clusterState.routingTable().index("idx").shard(0).shard(0), null, null, randomBoolean())), List.of() ); - assertThat( - clusterState.routingTable().index("idx").shard(0).shard(0).unassignedInfo().getFailedNodeIds(), - equalTo(failedNodeIds) - ); + assertThat(clusterState.routingTable().index("idx").shard(0).shard(0).unassignedInfo().failedNodeIds(), equalTo(failedNodeIds)); } // reroute with retryFailed=true should discard the failedNodes assertThat(clusterState.routingTable().index("idx").shard(0).shard(0).state(), equalTo(ShardRoutingState.UNASSIGNED)); clusterState = allocationService.reroute(clusterState, new AllocationCommands(), false, true, false, ActionListener.noop()) .clusterState(); - assertThat(clusterState.routingTable().index("idx").shard(0).shard(0).unassignedInfo().getFailedNodeIds(), empty()); + assertThat(clusterState.routingTable().index("idx").shard(0).shard(0).unassignedInfo().failedNodeIds(), empty()); // do not track the failed nodes while shard is started clusterState = startInitializingShardsAndReroute(allocationService, clusterState); @@ -79,6 +76,6 @@ public void testTrackFailedNodes() { List.of(new FailedShard(clusterState.routingTable().index("idx").shard(0).primaryShard(), null, null, false)), List.of() ); - assertThat(clusterState.routingTable().index("idx").shard(0).primaryShard().unassignedInfo().getFailedNodeIds(), empty()); + assertThat(clusterState.routingTable().index("idx").shard(0).primaryShard().unassignedInfo().failedNodeIds(), empty()); } } diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceComputerTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceComputerTests.java index 2fc84c3f32e79..6c3a4157bb4ba 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceComputerTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceComputerTests.java @@ -167,7 +167,7 @@ public void testIgnoresOutOfScopePrimaries() { .replicaShards() .get(0) .unassignedInfo() - .getLastAllocationStatus() == UnassignedInfo.AllocationStatus.DECIDERS_NO ? 1 : 2 + .lastAllocationStatus() == UnassignedInfo.AllocationStatus.DECIDERS_NO ? 1 : 2 ), new ShardId(index, 1), new ShardAssignment(Set.of("node-0", "node-1"), 2, 0, 0) @@ -198,7 +198,7 @@ public void testIgnoresOutOfScopeReplicas() { Set.of("node-0"), 2, 1, - originalReplicaShard.unassignedInfo().getLastAllocationStatus() == UnassignedInfo.AllocationStatus.DECIDERS_NO ? 0 : 1 + originalReplicaShard.unassignedInfo().lastAllocationStatus() == UnassignedInfo.AllocationStatus.DECIDERS_NO ? 0 : 1 ), new ShardId(index, 1), new ShardAssignment(Set.of("node-0", "node-1"), 2, 0, 0) @@ -1301,20 +1301,20 @@ private static ShardRouting mutateAllocationStatus(ShardRouting shardRouting) { var unassignedInfo = shardRouting.unassignedInfo(); return shardRouting.updateUnassigned( new UnassignedInfo( - unassignedInfo.getReason(), - unassignedInfo.getMessage(), - unassignedInfo.getFailure(), - unassignedInfo.getNumFailedAllocations(), - unassignedInfo.getUnassignedTimeInNanos(), - unassignedInfo.getUnassignedTimeInMillis(), - unassignedInfo.isDelayed(), + unassignedInfo.reason(), + unassignedInfo.message(), + unassignedInfo.failure(), + unassignedInfo.failedAllocations(), + unassignedInfo.unassignedTimeNanos(), + unassignedInfo.unassignedTimeMillis(), + unassignedInfo.delayed(), randomFrom( UnassignedInfo.AllocationStatus.DECIDERS_NO, UnassignedInfo.AllocationStatus.NO_ATTEMPT, UnassignedInfo.AllocationStatus.DECIDERS_THROTTLED ), - unassignedInfo.getFailedNodeIds(), - unassignedInfo.getLastAllocatedNodeId() + unassignedInfo.failedNodeIds(), + unassignedInfo.lastAllocatedNodeId() ), shardRouting.recoverySource() ); diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconcilerTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconcilerTests.java index 4ce5e78e308b2..0de27aea5b08f 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconcilerTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconcilerTests.java @@ -132,19 +132,19 @@ public void testFailsNewPrimariesIfNoDataNodes() { final var shardRouting = unassigned.next(); if (shardRouting.primary() && shardRouting.shardId().id() == 1) { final var unassignedInfo = shardRouting.unassignedInfo(); - assertThat(unassignedInfo.getLastAllocationStatus(), equalTo(UnassignedInfo.AllocationStatus.NO_ATTEMPT)); + assertThat(unassignedInfo.lastAllocationStatus(), equalTo(UnassignedInfo.AllocationStatus.NO_ATTEMPT)); unassigned.updateUnassigned( new UnassignedInfo( - unassignedInfo.getReason(), - unassignedInfo.getMessage(), - unassignedInfo.getFailure(), - unassignedInfo.getNumFailedAllocations(), - unassignedInfo.getUnassignedTimeInNanos(), - unassignedInfo.getUnassignedTimeInMillis(), - unassignedInfo.isDelayed(), + unassignedInfo.reason(), + unassignedInfo.message(), + unassignedInfo.failure(), + unassignedInfo.failedAllocations(), + unassignedInfo.unassignedTimeNanos(), + unassignedInfo.unassignedTimeMillis(), + unassignedInfo.delayed(), UnassignedInfo.AllocationStatus.DECIDERS_THROTTLED, - unassignedInfo.getFailedNodeIds(), - unassignedInfo.getLastAllocatedNodeId() + unassignedInfo.failedNodeIds(), + unassignedInfo.lastAllocatedNodeId() ), shardRouting.recoverySource(), new RoutingChangesObserver.DelegatingRoutingChangesObserver() @@ -164,7 +164,7 @@ public void testFailsNewPrimariesIfNoDataNodes() { for (ShardRouting shardRouting : routingAllocation.routingNodes().unassigned()) { assertTrue(shardRouting.toString(), shardRouting.unassigned()); assertThat( - shardRouting.unassignedInfo().getLastAllocationStatus(), + shardRouting.unassignedInfo().lastAllocationStatus(), equalTo( shardRouting.primary() && shardRouting.shardId().id() == 1 ? UnassignedInfo.AllocationStatus.DECIDERS_THROTTLED @@ -190,7 +190,7 @@ public void testFailsNewPrimariesIfNoDataNodes() { for (ShardRouting shardRouting : routingAllocation.routingNodes().unassigned()) { assertTrue(shardRouting.toString(), shardRouting.unassigned()); assertThat( - shardRouting.unassignedInfo().getLastAllocationStatus(), + shardRouting.unassignedInfo().lastAllocationStatus(), equalTo( // we only update primaries, and only if currently NO_ATTEMPT shardRouting.primary() @@ -677,7 +677,7 @@ public Decision canAllocate(ShardRouting shardRouting, RoutingNode node, Routing .replicaShards() .stream() .allMatch( - shardRouting -> shardRouting.unassignedInfo().getLastAllocationStatus() == UnassignedInfo.AllocationStatus.NO_ATTEMPT + shardRouting -> shardRouting.unassignedInfo().lastAllocationStatus() == UnassignedInfo.AllocationStatus.NO_ATTEMPT ) ); } @@ -724,7 +724,7 @@ public Decision canAllocate(ShardRouting shardRouting, RoutingNode node, Routing nonYesDecision == Decision.NO ? UnassignedInfo.AllocationStatus.DECIDERS_NO : UnassignedInfo.AllocationStatus.DECIDERS_THROTTLED, - redState.routingTable().shardRoutingTable("index-0", 0).primaryShard().unassignedInfo().getLastAllocationStatus() + redState.routingTable().shardRoutingTable("index-0", 0).primaryShard().unassignedInfo().lastAllocationStatus() ); assignPrimary.set(true); @@ -733,7 +733,7 @@ public Decision canAllocate(ShardRouting shardRouting, RoutingNode node, Routing startInitializingShardsAndReroute(allocationService, redState) ); for (final var shardRouting : yellowState.routingTable().shardRoutingTable("index-0", 0).replicaShards()) { - assertEquals(UnassignedInfo.AllocationStatus.NO_ATTEMPT, shardRouting.unassignedInfo().getLastAllocationStatus()); + assertEquals(UnassignedInfo.AllocationStatus.NO_ATTEMPT, shardRouting.unassignedInfo().lastAllocationStatus()); } } diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceShardsAllocatorTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceShardsAllocatorTests.java index 53ac77de6fc88..e5b3393723ab1 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceShardsAllocatorTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceShardsAllocatorTests.java @@ -100,7 +100,7 @@ public void testGatewayAllocatorStillFetching() { var shardRouting = routingTable.shardRoutingTable("test-index", 0).primaryShard(); assertFalse(shardRouting.assignedToNode()); assertThat( - shardRouting.unassignedInfo().getLastAllocationStatus(), + shardRouting.unassignedInfo().lastAllocationStatus(), equalTo(UnassignedInfo.AllocationStatus.FETCHING_SHARD_DATA) ); } @@ -111,7 +111,7 @@ public void testGatewayAllocatorDoesNothing() { testAllocate((allocation, unassignedAllocationHandler) -> {}, routingTable -> { var shardRouting = routingTable.shardRoutingTable("test-index", 0).primaryShard(); assertTrue(shardRouting.assignedToNode());// assigned by a followup reconciliation - assertThat(shardRouting.unassignedInfo().getLastAllocationStatus(), equalTo(UnassignedInfo.AllocationStatus.NO_ATTEMPT)); + assertThat(shardRouting.unassignedInfo().lastAllocationStatus(), equalTo(UnassignedInfo.AllocationStatus.NO_ATTEMPT)); }); } @@ -328,7 +328,7 @@ protected long currentNanoTime() { var unassigned = reconciledState.getRoutingNodes().unassigned(); assertThat(unassigned.size(), equalTo(1)); var unassignedShard = unassigned.iterator().next(); - assertThat(unassignedShard.unassignedInfo().isDelayed(), equalTo(true)); + assertThat(unassignedShard.unassignedInfo().delayed(), equalTo(true)); } finally { clusterService.close(); diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderTests.java index 716e7c80a6cde..d5cf73cacb782 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderTests.java @@ -1190,13 +1190,13 @@ private void doTestDiskThresholdWithSnapshotShardSizes(boolean testMaxHeadroom) assertThat( shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED).stream() .map(ShardRouting::unassignedInfo) - .allMatch(unassignedInfo -> Reason.NEW_INDEX_RESTORED.equals(unassignedInfo.getReason())), + .allMatch(unassignedInfo -> Reason.NEW_INDEX_RESTORED.equals(unassignedInfo.reason())), is(true) ); assertThat( shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED).stream() .map(ShardRouting::unassignedInfo) - .allMatch(unassignedInfo -> AllocationStatus.NO_ATTEMPT.equals(unassignedInfo.getLastAllocationStatus())), + .allMatch(unassignedInfo -> AllocationStatus.NO_ATTEMPT.equals(unassignedInfo.lastAllocationStatus())), is(true) ); assertThat(shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED).size(), equalTo(1)); @@ -1218,7 +1218,7 @@ private void doTestDiskThresholdWithSnapshotShardSizes(boolean testMaxHeadroom) assertThat( shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED).stream() .map(ShardRouting::unassignedInfo) - .allMatch(unassignedInfo -> AllocationStatus.FETCHING_SHARD_DATA.equals(unassignedInfo.getLastAllocationStatus())), + .allMatch(unassignedInfo -> AllocationStatus.FETCHING_SHARD_DATA.equals(unassignedInfo.lastAllocationStatus())), is(true) ); assertThat(shardsWithState(clusterState.getRoutingNodes(), UNASSIGNED).size(), equalTo(1)); diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/RestoreInProgressAllocationDeciderTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/RestoreInProgressAllocationDeciderTests.java index ea156ee48a656..ab14345cb53c4 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/RestoreInProgressAllocationDeciderTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/RestoreInProgressAllocationDeciderTests.java @@ -111,16 +111,16 @@ public void testCanAllocatePrimaryExistingInRestoreInProgress() { UnassignedInfo currentInfo = primary.unassignedInfo(); UnassignedInfo newInfo = new UnassignedInfo( - currentInfo.getReason(), - currentInfo.getMessage(), + currentInfo.reason(), + currentInfo.message(), new IOException("i/o failure"), - currentInfo.getNumFailedAllocations(), - currentInfo.getUnassignedTimeInNanos(), - currentInfo.getUnassignedTimeInMillis(), - currentInfo.isDelayed(), - currentInfo.getLastAllocationStatus(), - currentInfo.getFailedNodeIds(), - currentInfo.getLastAllocatedNodeId() + currentInfo.failedAllocations(), + currentInfo.unassignedTimeNanos(), + currentInfo.unassignedTimeMillis(), + currentInfo.delayed(), + currentInfo.lastAllocationStatus(), + currentInfo.failedNodeIds(), + currentInfo.lastAllocatedNodeId() ); primary = primary.updateUnassigned(newInfo, primary.recoverySource()); diff --git a/server/src/test/java/org/elasticsearch/gateway/PrimaryShardAllocatorTests.java b/server/src/test/java/org/elasticsearch/gateway/PrimaryShardAllocatorTests.java index f6b310abac770..a74a00792d701 100644 --- a/server/src/test/java/org/elasticsearch/gateway/PrimaryShardAllocatorTests.java +++ b/server/src/test/java/org/elasticsearch/gateway/PrimaryShardAllocatorTests.java @@ -286,7 +286,7 @@ public void testDontAllocateOnNoOrThrottleForceAllocationDecision() { List ignored = allocation.routingNodes().unassigned().ignored(); assertEquals(ignored.size(), 1); assertEquals( - ignored.get(0).unassignedInfo().getLastAllocationStatus(), + ignored.get(0).unassignedInfo().lastAllocationStatus(), forceDecisionNo ? AllocationStatus.DECIDERS_NO : AllocationStatus.DECIDERS_THROTTLED ); assertTrue(shardsWithState(allocation.routingNodes(), ShardRoutingState.INITIALIZING).isEmpty()); @@ -314,7 +314,7 @@ public void testDontForceAllocateOnThrottleDecision() { assertThat(allocation.routingNodesChanged(), equalTo(true)); List ignored = allocation.routingNodes().unassigned().ignored(); assertEquals(ignored.size(), 1); - assertEquals(ignored.get(0).unassignedInfo().getLastAllocationStatus(), AllocationStatus.DECIDERS_THROTTLED); + assertEquals(ignored.get(0).unassignedInfo().lastAllocationStatus(), AllocationStatus.DECIDERS_THROTTLED); assertTrue(shardsWithState(allocation.routingNodes(), ShardRoutingState.INITIALIZING).isEmpty()); } @@ -454,7 +454,7 @@ public void testRestoreDoesNotAssignIfShardSizeNotAvailable() { assertThat(allocation.routingNodesChanged(), equalTo(true)); assertThat(allocation.routingNodes().unassigned().ignored().isEmpty(), equalTo(false)); ShardRouting ignoredRouting = allocation.routingNodes().unassigned().ignored().get(0); - assertThat(ignoredRouting.unassignedInfo().getLastAllocationStatus(), equalTo(AllocationStatus.FETCHING_SHARD_DATA)); + assertThat(ignoredRouting.unassignedInfo().lastAllocationStatus(), equalTo(AllocationStatus.FETCHING_SHARD_DATA)); assertClusterHealthStatus(allocation, ClusterHealthStatus.YELLOW); } diff --git a/server/src/test/java/org/elasticsearch/gateway/ReplicaShardAllocatorTests.java b/server/src/test/java/org/elasticsearch/gateway/ReplicaShardAllocatorTests.java index e1cba6f1746e4..9582037975318 100644 --- a/server/src/test/java/org/elasticsearch/gateway/ReplicaShardAllocatorTests.java +++ b/server/src/test/java/org/elasticsearch/gateway/ReplicaShardAllocatorTests.java @@ -254,8 +254,8 @@ private void runNoopRetentionLeaseTest(boolean isRelevantShard) { List unassignedShards = shardsWithState(allocation.routingNodes(), ShardRoutingState.UNASSIGNED); assertThat(unassignedShards, hasSize(1)); assertThat(unassignedShards.get(0).shardId(), equalTo(shardId)); - assertThat(unassignedShards.get(0).unassignedInfo().getNumFailedAllocations(), equalTo(0)); - assertThat(unassignedShards.get(0).unassignedInfo().getFailedNodeIds(), equalTo(failedNodeIds)); + assertThat(unassignedShards.get(0).unassignedInfo().failedAllocations(), equalTo(0)); + assertThat(unassignedShards.get(0).unassignedInfo().failedNodeIds(), equalTo(failedNodeIds)); } else { assertThat(allocation.routingNodesChanged(), equalTo(false)); assertThat(shardsWithState(allocation.routingNodes(), ShardRoutingState.UNASSIGNED).size(), equalTo(0)); diff --git a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java index fa700dc5d78f7..697b40671ee8b 100644 --- a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java @@ -1035,7 +1035,7 @@ public void run() { .routingTable() .shardRoutingTable(shardToRelocate.shardId()) .primaryShard(); - if (shardRouting.unassigned() && shardRouting.unassignedInfo().getReason() == UnassignedInfo.Reason.NODE_LEFT) { + if (shardRouting.unassigned() && shardRouting.unassignedInfo().reason() == UnassignedInfo.Reason.NODE_LEFT) { if (masterNodeCount > 1) { scheduleNow(() -> testClusterNodes.stopNode(masterNode)); } diff --git a/test/framework/src/main/java/org/elasticsearch/cluster/ESAllocationTestCase.java b/test/framework/src/main/java/org/elasticsearch/cluster/ESAllocationTestCase.java index 7848f0ef4a625..f3fac694f9980 100644 --- a/test/framework/src/main/java/org/elasticsearch/cluster/ESAllocationTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/cluster/ESAllocationTestCase.java @@ -423,10 +423,10 @@ public void allocateUnassigned( RoutingAllocation allocation, UnassignedAllocationHandler unassignedAllocationHandler ) { - if (shardRouting.primary() || shardRouting.unassignedInfo().getReason() == UnassignedInfo.Reason.INDEX_CREATED) { + if (shardRouting.primary() || shardRouting.unassignedInfo().reason() == UnassignedInfo.Reason.INDEX_CREATED) { return; } - if (shardRouting.unassignedInfo().isDelayed()) { + if (shardRouting.unassignedInfo().delayed()) { unassignedAllocationHandler.removeAndIgnore(UnassignedInfo.AllocationStatus.DELAYED_ALLOCATION, allocation.changes()); } } diff --git a/x-pack/plugin/ccr/src/internalClusterTest/java/org/elasticsearch/xpack/ccr/CcrRepositoryIT.java b/x-pack/plugin/ccr/src/internalClusterTest/java/org/elasticsearch/xpack/ccr/CcrRepositoryIT.java index dff3ff935595f..90bbc29a11b41 100644 --- a/x-pack/plugin/ccr/src/internalClusterTest/java/org/elasticsearch/xpack/ccr/CcrRepositoryIT.java +++ b/x-pack/plugin/ccr/src/internalClusterTest/java/org/elasticsearch/xpack/ccr/CcrRepositoryIT.java @@ -553,7 +553,7 @@ public void testCcrRepositoryFetchesSnapshotShardSizeFromIndexShardStoreStats() if (RestoreInProgress.get(event.state()).isEmpty() == false && event.state().routingTable().hasIndex(followerIndex)) { final IndexRoutingTable indexRoutingTable = event.state().routingTable().index(followerIndex); for (ShardRouting shardRouting : indexRoutingTable.shardsWithState(ShardRoutingState.UNASSIGNED)) { - if (shardRouting.unassignedInfo().getLastAllocationStatus() == AllocationStatus.FETCHING_SHARD_DATA) { + if (shardRouting.unassignedInfo().lastAllocationStatus() == AllocationStatus.FETCHING_SHARD_DATA) { try { assertBusy(() -> { final Long snapshotShardSize = snapshotsInfoService.snapshotShardSizes().getShardSize(shardRouting); @@ -644,7 +644,7 @@ public void testCcrRepositoryFailsToFetchSnapshotShardSizes() throws Exception { assertBusy(() -> { List sizes = indexRoutingTable.shardsWithState(ShardRoutingState.UNASSIGNED) .stream() - .filter(shard -> shard.unassignedInfo().getLastAllocationStatus() == AllocationStatus.FETCHING_SHARD_DATA) + .filter(shard -> shard.unassignedInfo().lastAllocationStatus() == AllocationStatus.FETCHING_SHARD_DATA) .sorted(Comparator.comparingInt(ShardRouting::getId)) .map(shard -> snapshotsInfoService.snapshotShardSizes().getShardSize(shard)) .filter(Objects::nonNull) diff --git a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/ClusterStateApplierOrderingTests.java b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/ClusterStateApplierOrderingTests.java index ffedcb8f9ebd3..ee19fc07e45cb 100644 --- a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/ClusterStateApplierOrderingTests.java +++ b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/ClusterStateApplierOrderingTests.java @@ -96,7 +96,7 @@ public Settings onNodeStopped(String nodeName) { for (RoutingNode routingNode : event.state().getRoutingNodes()) { for (ShardRouting shardRouting : routingNode) { if (shardRouting.unassignedInfo() != null) { - unassignedReasons.add(shardRouting.unassignedInfo().getReason()); + unassignedReasons.add(shardRouting.unassignedInfo().reason()); } } } diff --git a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/allocation/SearchableSnapshotAllocator.java b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/allocation/SearchableSnapshotAllocator.java index ee018578ce143..b05f7e4844908 100644 --- a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/allocation/SearchableSnapshotAllocator.java +++ b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/allocation/SearchableSnapshotAllocator.java @@ -332,8 +332,8 @@ private AllocateUnassignedDecision decideAllocation(RoutingAllocation allocation } private static boolean isDelayedDueToNodeRestart(RoutingAllocation allocation, ShardRouting shardRouting) { - if (shardRouting.unassignedInfo().isDelayed()) { - String lastAllocatedNodeId = shardRouting.unassignedInfo().getLastAllocatedNodeId(); + if (shardRouting.unassignedInfo().delayed()) { + String lastAllocatedNodeId = shardRouting.unassignedInfo().lastAllocatedNodeId(); if (lastAllocatedNodeId != null) { return allocation.metadata().nodeShutdowns().contains(lastAllocatedNodeId, SingleNodeShutdownMetadata.Type.RESTART); } diff --git a/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportGetShutdownStatusAction.java b/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportGetShutdownStatusAction.java index 9e8c54ba594ea..69043c606ef15 100644 --- a/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportGetShutdownStatusAction.java +++ b/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportGetShutdownStatusAction.java @@ -219,7 +219,7 @@ static ShutdownShardMigrationStatus shardMigrationStatus( .unassigned() .stream() .peek(s -> cancellableTask.ensureNotCancelled()) - .filter(s -> Objects.equals(s.unassignedInfo().getLastAllocatedNodeId(), nodeId)) + .filter(s -> Objects.equals(s.unassignedInfo().lastAllocatedNodeId(), nodeId)) .filter(s -> s.primary() || hasShardCopyOnAnotherNode(currentState, s, shuttingDownNodes) == false) .toList(); From 4a1d7426d7027930faf117a2c04df507a9b941db Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Thu, 6 Jun 2024 11:20:53 +0300 Subject: [PATCH 30/30] Adding RankFeature implementation (#108538) --- docs/changelog/108538.yaml | 5 + .../search/rank/FieldBasedRerankerIT.java | 811 ++++++++++++ server/src/main/java/module-info.java | 1 + .../org/elasticsearch/TransportVersions.java | 2 +- ...ngleCoordinatorSearchProgressListener.java | 18 + .../action/search/FetchSearchPhase.java | 68 +- .../action/search/RankFeaturePhase.java | 175 ++- .../action/search/SearchPhase.java | 34 + .../action/search/SearchPhaseController.java | 4 +- .../action/search/SearchProgressListener.java | 32 + .../action/search/SearchRequest.java | 2 +- .../search/SearchTransportAPMMetrics.java | 1 + .../action/search/SearchTransportService.java | 30 + .../elasticsearch/node/NodeConstruction.java | 1 + .../node/NodeServiceProvider.java | 3 + .../search/DefaultSearchContext.java | 13 + .../elasticsearch/search/SearchModule.java | 5 + .../search/SearchPhaseResult.java | 16 + .../elasticsearch/search/SearchService.java | 39 + .../search/fetch/FetchPhase.java | 1 - .../internal/FilteredSearchContext.java | 11 + .../search/internal/SearchContext.java | 5 + .../search/query/QueryPhase.java | 55 +- .../search/rank/RankBuilder.java | 23 +- .../search/rank/RankSearchContext.java | 19 +- ...ankFeaturePhaseRankCoordinatorContext.java | 96 ++ .../RankFeaturePhaseRankShardContext.java | 39 + .../search/rank/feature/RankFeatureDoc.java | 54 + .../rank/feature/RankFeatureResult.java | 70 + .../rank/feature/RankFeatureShardPhase.java | 99 ++ .../rank/feature/RankFeatureShardRequest.java | 101 ++ .../rank/feature/RankFeatureShardResult.java | 68 + .../action/search/RankFeaturePhaseTests.java | 1170 +++++++++++++++++ .../search/DefaultSearchContextTests.java | 4 +- .../search/SearchServiceTests.java | 753 ++++++++++- .../rank/RankFeatureShardPhaseTests.java | 409 ++++++ .../snapshots/SnapshotResiliencyTests.java | 2 + .../java/org/elasticsearch/node/MockNode.java | 4 + .../search/MockSearchService.java | 3 + .../search/rank/TestRankBuilder.java | 19 +- .../elasticsearch/test/TestSearchContext.java | 11 + .../hamcrest/ElasticsearchAssertions.java | 4 + .../xpack/search/AsyncSearchTask.java | 14 + .../xpack/rank/rrf/RRFRankBuilder.java | 19 +- .../xpack/rank/rrf/RRFRetrieverBuilder.java | 2 +- .../security/authz/PreAuthorizationUtils.java | 1 + 46 files changed, 4207 insertions(+), 109 deletions(-) create mode 100644 docs/changelog/108538.yaml create mode 100644 server/src/internalClusterTest/java/org/elasticsearch/search/rank/FieldBasedRerankerIT.java create mode 100644 server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java create mode 100644 server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankShardContext.java create mode 100644 server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java create mode 100644 server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureResult.java create mode 100644 server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java create mode 100644 server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardRequest.java create mode 100644 server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardResult.java create mode 100644 server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java create mode 100644 server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java diff --git a/docs/changelog/108538.yaml b/docs/changelog/108538.yaml new file mode 100644 index 0000000000000..10ae49f0c1670 --- /dev/null +++ b/docs/changelog/108538.yaml @@ -0,0 +1,5 @@ +pr: 108538 +summary: Adding RankFeature search phase implementation +area: Search +type: feature +issues: [] diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/rank/FieldBasedRerankerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/rank/FieldBasedRerankerIT.java new file mode 100644 index 0000000000000..a4e2fda0fd3c9 --- /dev/null +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/rank/FieldBasedRerankerIT.java @@ -0,0 +1,811 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.rank; + +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.SearchPhaseController; +import org.elasticsearch.action.search.SearchPhaseExecutionException; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.SearchPlugin; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.query.QuerySearchResult; +import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext; +import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; +import org.elasticsearch.search.rank.feature.RankFeatureDoc; +import org.elasticsearch.search.rank.feature.RankFeatureShardResult; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.index.query.QueryBuilders.boolQuery; +import static org.elasticsearch.index.query.QueryBuilders.constantScoreQuery; +import static org.elasticsearch.index.query.QueryBuilders.matchQuery; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailuresAndResponse; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasId; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasRank; +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; +import static org.hamcrest.Matchers.equalTo; + +@ESIntegTestCase.ClusterScope(minNumDataNodes = 3) +public class FieldBasedRerankerIT extends ESIntegTestCase { + + @Override + protected Collection> nodePlugins() { + return List.of(FieldBasedRerankerPlugin.class); + } + + public void testFieldBasedReranker() throws Exception { + final String indexName = "test_index"; + final String rankFeatureField = "rankFeatureField"; + final String searchField = "searchField"; + final int rankWindowSize = 10; + + createIndex(indexName); + indexRandom( + true, + prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"), + prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"), + prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"), + prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"), + prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E") + ); + + assertNoFailuresAndResponse( + prepareSearch().setQuery( + boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "B")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "C")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "D")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "E")).boost(randomFloat())) + ) + .setRankBuilder(new FieldBasedRankBuilder(rankWindowSize, rankFeatureField)) + .addFetchField(searchField) + .setTrackTotalHits(true) + .setAllowPartialSearchResults(true) + .setSize(10), + response -> { + assertHitCount(response, 5L); + int rank = 1; + for (SearchHit searchHit : response.getHits().getHits()) { + assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1)))); + assertEquals(searchHit.getScore(), (0.5f - ((rank - 1) * 0.1f)), 1e-5f); + assertThat(searchHit, hasRank(rank)); + assertNotNull(searchHit.getFields().get(searchField)); + rank++; + } + } + ); + assertNoOpenContext(indexName); + } + + public void testFieldBasedRerankerPagination() throws Exception { + final String indexName = "test_index"; + final String rankFeatureField = "rankFeatureField"; + final String searchField = "searchField"; + final int rankWindowSize = 10; + + createIndex(indexName); + indexRandom( + true, + prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"), + prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"), + prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"), + prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"), + prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E") + ); + + assertResponse( + prepareSearch().setQuery( + boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "B")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "C")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "D")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "E")).boost(randomFloat())) + ) + .setRankBuilder(new FieldBasedRankBuilder(rankWindowSize, rankFeatureField)) + .addFetchField(searchField) + .setTrackTotalHits(true) + .setAllowPartialSearchResults(true) + .setSize(2) + .setFrom(2), + response -> { + assertHitCount(response, 5L); + int rank = 3; + for (SearchHit searchHit : response.getHits().getHits()) { + assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1)))); + assertEquals(searchHit.getScore(), (0.5f - ((rank - 1) * 0.1f)), 1e-5f); + assertThat(searchHit, hasRank(rank)); + assertNotNull(searchHit.getFields().get(searchField)); + rank++; + } + } + ); + assertNoOpenContext(indexName); + } + + public void testFieldBasedRerankerPaginationOutsideOfBounds() throws Exception { + final String indexName = "test_index"; + final String rankFeatureField = "rankFeatureField"; + final String searchField = "searchField"; + final int rankWindowSize = 10; + + createIndex(indexName); + indexRandom( + true, + prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"), + prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"), + prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"), + prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"), + prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E") + ); + + assertNoFailuresAndResponse( + prepareSearch().setQuery( + boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "B")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "C")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "D")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "E")).boost(randomFloat())) + ) + .setRankBuilder(new FieldBasedRankBuilder(rankWindowSize, rankFeatureField)) + .addFetchField(searchField) + .setTrackTotalHits(true) + .setAllowPartialSearchResults(true) + .setSize(2) + .setFrom(10), + response -> { + assertHitCount(response, 5L); + assertEquals(0, response.getHits().getHits().length); + } + ); + assertNoOpenContext(indexName); + } + + public void testNotAllShardsArePresentInFetchPhase() throws Exception { + final String indexName = "test_index"; + final String rankFeatureField = "rankFeatureField"; + final String searchField = "searchField"; + final int rankWindowSize = 10; + + createIndex(indexName, Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 10).build()); + indexRandom( + true, + prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A").setRouting("A"), + prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B").setRouting("B"), + prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C").setRouting("C"), + prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D").setRouting("C"), + prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E").setRouting("C") + ); + + assertNoFailuresAndResponse( + prepareSearch().setQuery( + boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(0.1f)) + .should(constantScoreQuery(matchQuery(searchField, "C")).boost(0.3f)) + .should(constantScoreQuery(matchQuery(searchField, "D")).boost(0.3f)) + .should(constantScoreQuery(matchQuery(searchField, "E")).boost(0.3f)) + ) + .setRankBuilder(new FieldBasedRankBuilder(rankWindowSize, rankFeatureField)) + .addFetchField(searchField) + .setTrackTotalHits(true) + .setAllowPartialSearchResults(true) + .setSize(2), + response -> { + assertHitCount(response, 4L); + assertEquals(2, response.getHits().getHits().length); + int rank = 1; + for (SearchHit searchHit : response.getHits().getHits()) { + assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1)))); + assertEquals(searchHit.getScore(), (0.5f - ((rank - 1) * 0.1f)), 1e-5f); + assertThat(searchHit, hasRank(rank)); + assertNotNull(searchHit.getFields().get(searchField)); + rank++; + } + } + ); + assertNoOpenContext(indexName); + } + + public void testFieldBasedRerankerNoMatchingDocs() throws Exception { + final String indexName = "test_index"; + final String rankFeatureField = "rankFeatureField"; + final String searchField = "searchField"; + final int rankWindowSize = 10; + + createIndex(indexName); + indexRandom( + true, + prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"), + prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"), + prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"), + prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"), + prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E") + ); + + assertNoFailuresAndResponse( + prepareSearch().setQuery(boolQuery().should(constantScoreQuery(matchQuery(searchField, "F")).boost(randomFloat()))) + .setRankBuilder(new FieldBasedRankBuilder(rankWindowSize, rankFeatureField)) + .addFetchField(searchField) + .setTrackTotalHits(true) + .setAllowPartialSearchResults(true) + .setSize(10), + response -> { + assertHitCount(response, 0L); + } + ); + assertNoOpenContext(indexName); + } + + public void testQueryPhaseShardThrowingRankBuilderAllContextsAreClosedAllShardsFail() throws Exception { + final String indexName = "test_index"; + final String rankFeatureField = "rankFeatureField"; + final String searchField = "searchField"; + final int rankWindowSize = 10; + + // this test is irrespective of the number of shards, as we will always reach QueryPhaseRankShardContext#combineQueryPhaseResults + // even with no results. So, when we get back to the coordinator, all shards will have failed, and the whole response + // will be marked as a failure + createIndex(indexName); + indexRandom( + true, + prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"), + prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"), + prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"), + prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"), + prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E") + ); + + expectThrows( + SearchPhaseExecutionException.class, + () -> prepareSearch().setQuery( + boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "B")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "C")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "D")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "E")).boost(randomFloat())) + ) + .setRankBuilder( + new ThrowingRankBuilder( + rankWindowSize, + rankFeatureField, + ThrowingRankBuilder.ThrowingRankBuilderType.THROWING_QUERY_PHASE_SHARD_CONTEXT.name() + ) + ) + .addFetchField(searchField) + .setTrackTotalHits(true) + .setAllowPartialSearchResults(true) + .setSize(10) + .get() + ); + assertNoOpenContext(indexName); + } + + public void testQueryPhaseCoordinatorThrowingRankBuilderAllContextsAreClosedAllShardsFail() throws Exception { + final String indexName = "test_index"; + final String rankFeatureField = "rankFeatureField"; + final String searchField = "searchField"; + final int rankWindowSize = 10; + + createIndex(indexName); + indexRandom( + true, + prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"), + prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"), + prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"), + prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"), + prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E") + ); + + // when we throw on the coordinator, the onPhaseFailure handler will be invoked, which in turn will mark the whole + // search request as a failure (i.e. no partial results) + expectThrows( + SearchPhaseExecutionException.class, + () -> prepareSearch().setQuery( + boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "B")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "C")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "D")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "E")).boost(randomFloat())) + ) + .setRankBuilder( + new ThrowingRankBuilder( + rankWindowSize, + rankFeatureField, + ThrowingRankBuilder.ThrowingRankBuilderType.THROWING_QUERY_PHASE_COORDINATOR_CONTEXT.name() + ) + ) + .addFetchField(searchField) + .setTrackTotalHits(true) + .setAllowPartialSearchResults(true) + .setSize(10) + .get() + ); + assertNoOpenContext(indexName); + } + + public void testRankFeaturePhaseShardThrowingRankBuilderAllContextsAreClosedPartialFailures() throws Exception { + final String indexName = "test_index"; + final String rankFeatureField = "rankFeatureField"; + final String searchField = "searchField"; + final int rankWindowSize = 10; + + createIndex(indexName, Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 10).build()); + indexRandom( + true, + prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"), + prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"), + prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"), + prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"), + prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E") + ); + + // we have 10 shards and 5 documents, so when the exception is thrown we know that not all shards will report failures + assertResponse( + prepareSearch().setQuery( + boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "B")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "C")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "D")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "E")).boost(randomFloat())) + ) + .setRankBuilder( + new ThrowingRankBuilder( + rankWindowSize, + rankFeatureField, + ThrowingRankBuilder.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_SHARD_CONTEXT.name() + ) + ) + .addFetchField(searchField) + .setTrackTotalHits(true) + .setAllowPartialSearchResults(true) + .setSize(10), + response -> { + assertTrue(response.getFailedShards() > 0); + assertTrue( + Arrays.stream(response.getShardFailures()) + .allMatch(failure -> failure.getCause().getMessage().contains("rfs - simulated failure")) + ); + assertHitCount(response, 5); + assertTrue(response.getHits().getHits().length == 0); + } + ); + assertNoOpenContext(indexName); + } + + public void testRankFeaturePhaseShardThrowingRankBuilderAllContextsAreClosedAllShardsFail() throws Exception { + final String indexName = "test_index"; + final String rankFeatureField = "rankFeatureField"; + final String searchField = "searchField"; + final int rankWindowSize = 10; + + // we have 1 shard and 5 documents, so when the exception is thrown we know that all shards will have failed + createIndex(indexName, Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1).build()); + indexRandom( + true, + prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"), + prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"), + prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"), + prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"), + prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E") + ); + + expectThrows( + SearchPhaseExecutionException.class, + () -> prepareSearch().setQuery( + boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "B")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "C")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "D")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "E")).boost(randomFloat())) + ) + .setRankBuilder( + new ThrowingRankBuilder( + rankWindowSize, + rankFeatureField, + ThrowingRankBuilder.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_SHARD_CONTEXT.name() + ) + ) + .addFetchField(searchField) + .setTrackTotalHits(true) + .setAllowPartialSearchResults(true) + .setSize(10) + .get() + ); + assertNoOpenContext(indexName); + } + + public void testRankFeaturePhaseCoordinatorThrowingRankBuilderAllContextsAreClosedAllShardsFail() throws Exception { + final String indexName = "test_index"; + final String rankFeatureField = "rankFeatureField"; + final String searchField = "searchField"; + final int rankWindowSize = 10; + + createIndex(indexName); + indexRandom( + true, + prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"), + prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"), + prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"), + prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"), + prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E") + ); + + expectThrows( + SearchPhaseExecutionException.class, + () -> prepareSearch().setQuery( + boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "B")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "C")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "D")).boost(randomFloat())) + .should(constantScoreQuery(matchQuery(searchField, "E")).boost(randomFloat())) + ) + .setRankBuilder( + new ThrowingRankBuilder( + rankWindowSize, + rankFeatureField, + ThrowingRankBuilder.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT.name() + ) + ) + .addFetchField(searchField) + .setTrackTotalHits(true) + .setAllowPartialSearchResults(true) + .setSize(10) + .get() + ); + assertNoOpenContext(indexName); + } + + private void assertNoOpenContext(final String indexName) throws Exception { + assertBusy( + () -> assertThat(indicesAdmin().prepareStats(indexName).get().getTotal().getSearch().getOpenContexts(), equalTo(0L)), + 1, + TimeUnit.SECONDS + ); + } + + public static class FieldBasedRankBuilder extends RankBuilder { + + public static final ParseField FIELD_FIELD = new ParseField("field"); + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "field-based-rank", + args -> { + int rankWindowSize = args[0] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[0]; + String field = (String) args[1]; + if (field == null || field.isEmpty()) { + throw new IllegalArgumentException("Field cannot be null or empty"); + } + return new FieldBasedRankBuilder(rankWindowSize, field); + } + ); + + static { + PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); + PARSER.declareString(constructorArg(), FIELD_FIELD); + } + + protected final String field; + + public static FieldBasedRankBuilder fromXContent(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + public FieldBasedRankBuilder(final int rankWindowSize, final String field) { + super(rankWindowSize); + this.field = field; + } + + public FieldBasedRankBuilder(StreamInput in) throws IOException { + super(in); + this.field = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(field); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(FIELD_FIELD.getPreferredName(), field); + } + + @Override + public boolean isCompoundBuilder() { + return false; + } + + @Override + public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { + return new QueryPhaseRankShardContext(queries, rankWindowSize()) { + @Override + public RankShardResult combineQueryPhaseResults(List rankResults) { + Map rankDocs = new HashMap<>(); + rankResults.forEach(topDocs -> { + for (ScoreDoc scoreDoc : topDocs.scoreDocs) { + rankDocs.compute(scoreDoc.doc, (key, value) -> { + if (value == null) { + return new RankFeatureDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex); + } else { + value.score = Math.max(scoreDoc.score, rankDocs.get(scoreDoc.doc).score); + return value; + } + }); + } + }); + RankFeatureDoc[] sortedResults = rankDocs.values().toArray(RankFeatureDoc[]::new); + Arrays.sort(sortedResults, (o1, o2) -> Float.compare(o2.score, o1.score)); + return new RankFeatureShardResult(sortedResults); + } + }; + } + + @Override + public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { + return new QueryPhaseRankCoordinatorContext(rankWindowSize()) { + @Override + public ScoreDoc[] rankQueryPhaseResults( + List querySearchResults, + SearchPhaseController.TopDocsStats topDocStats + ) { + List rankDocs = new ArrayList<>(); + for (int i = 0; i < querySearchResults.size(); i++) { + QuerySearchResult querySearchResult = querySearchResults.get(i); + RankFeatureShardResult shardResult = (RankFeatureShardResult) querySearchResult.getRankShardResult(); + for (RankFeatureDoc frd : shardResult.rankFeatureDocs) { + frd.shardIndex = i; + rankDocs.add(frd); + } + } + // no support for sort field atm + // should pass needed info to make use of org.elasticsearch.action.search.SearchPhaseController.sortDocs? + rankDocs.sort(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed()); + RankFeatureDoc[] topResults = rankDocs.stream().limit(rankWindowSize).toArray(RankFeatureDoc[]::new); + + assert topDocStats.fetchHits == 0; + topDocStats.fetchHits = topResults.length; + + return topResults; + } + }; + } + + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + return new RankFeaturePhaseRankShardContext(field) { + @Override + public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { + try { + RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; + for (int i = 0; i < hits.getHits().length; i++) { + rankFeatureDocs[i] = new RankFeatureDoc(hits.getHits()[i].docId(), hits.getHits()[i].getScore(), shardId); + rankFeatureDocs[i].featureData(hits.getHits()[i].field(field).getValue().toString()); + } + return new RankFeatureShardResult(rankFeatureDocs); + } catch (Exception ex) { + throw ex; + } + } + }; + } + + @Override + public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize()) { + @Override + protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { + float[] scores = new float[featureDocs.length]; + for (int i = 0; i < featureDocs.length; i++) { + scores[i] = Float.parseFloat(featureDocs[i].featureData); + } + scoreListener.onResponse(scores); + } + }; + } + + @Override + protected boolean doEquals(RankBuilder other) { + return other instanceof FieldBasedRankBuilder && Objects.equals(field, ((FieldBasedRankBuilder) other).field); + } + + @Override + protected int doHashCode() { + return Objects.hash(field); + } + + @Override + public String getWriteableName() { + return "field-based-rank"; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.RANK_FEATURE_PHASE_ADDED; + } + } + + public static class ThrowingRankBuilder extends FieldBasedRankBuilder { + + public enum ThrowingRankBuilderType { + THROWING_QUERY_PHASE_SHARD_CONTEXT, + THROWING_QUERY_PHASE_COORDINATOR_CONTEXT, + THROWING_RANK_FEATURE_PHASE_SHARD_CONTEXT, + THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT; + } + + protected final ThrowingRankBuilderType throwingRankBuilderType; + + public static final ParseField FIELD_FIELD = new ParseField("field"); + public static final ParseField THROWING_TYPE_FIELD = new ParseField("throwing-type"); + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("throwing-rank", args -> { + int rankWindowSize = args[0] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[0]; + String field = (String) args[1]; + if (field == null || field.isEmpty()) { + throw new IllegalArgumentException("Field cannot be null or empty"); + } + String throwingType = (String) args[2]; + return new ThrowingRankBuilder(rankWindowSize, field, throwingType); + }); + + static { + PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); + PARSER.declareString(constructorArg(), FIELD_FIELD); + PARSER.declareString(constructorArg(), THROWING_TYPE_FIELD); + } + + public static FieldBasedRankBuilder fromXContent(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + public ThrowingRankBuilder(final int rankWindowSize, final String field, final String throwingType) { + super(rankWindowSize, field); + this.throwingRankBuilderType = ThrowingRankBuilderType.valueOf(throwingType); + } + + public ThrowingRankBuilder(StreamInput in) throws IOException { + super(in); + this.throwingRankBuilderType = in.readEnum(ThrowingRankBuilderType.class); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + super.doWriteTo(out); + out.writeEnum(throwingRankBuilderType); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + super.doXContent(builder, params); + builder.field(THROWING_TYPE_FIELD.getPreferredName(), throwingRankBuilderType); + } + + @Override + public String getWriteableName() { + return "throwing-rank"; + } + + @Override + public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { + if (this.throwingRankBuilderType == ThrowingRankBuilderType.THROWING_QUERY_PHASE_SHARD_CONTEXT) + return new QueryPhaseRankShardContext(queries, rankWindowSize()) { + @Override + public RankShardResult combineQueryPhaseResults(List rankResults) { + throw new UnsupportedOperationException("qps - simulated failure"); + } + }; + else { + return super.buildQueryPhaseShardContext(queries, from); + } + } + + @Override + public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { + if (this.throwingRankBuilderType == ThrowingRankBuilderType.THROWING_QUERY_PHASE_COORDINATOR_CONTEXT) + return new QueryPhaseRankCoordinatorContext(rankWindowSize()) { + @Override + public ScoreDoc[] rankQueryPhaseResults( + List querySearchResults, + SearchPhaseController.TopDocsStats topDocStats + ) { + throw new UnsupportedOperationException("qpc - simulated failure"); + } + }; + else { + return super.buildQueryPhaseCoordinatorContext(size, from); + } + } + + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + if (this.throwingRankBuilderType == ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_SHARD_CONTEXT) + return new RankFeaturePhaseRankShardContext(field) { + @Override + public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { + throw new UnsupportedOperationException("rfs - simulated failure"); + } + }; + else { + return super.buildRankFeaturePhaseShardContext(); + } + } + + @Override + public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) { + if (this.throwingRankBuilderType == ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT) + return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize()) { + @Override + protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { + throw new UnsupportedOperationException("rfc - simulated failure"); + } + }; + else { + return super.buildRankFeaturePhaseCoordinatorContext(size, from); + } + } + } + + public static class FieldBasedRerankerPlugin extends Plugin implements SearchPlugin { + + private static final String FIELD_BASED_RANK_BUILDER_NAME = "field-based-rank"; + private static final String THROWING_RANK_BUILDER_NAME = "throwing-rank"; + + @Override + public List getNamedWriteables() { + return List.of( + new NamedWriteableRegistry.Entry(RankBuilder.class, FIELD_BASED_RANK_BUILDER_NAME, FieldBasedRankBuilder::new), + new NamedWriteableRegistry.Entry(RankBuilder.class, THROWING_RANK_BUILDER_NAME, ThrowingRankBuilder::new), + new NamedWriteableRegistry.Entry(RankShardResult.class, "rank_feature_shard", RankFeatureShardResult::new) + ); + } + + @Override + public List getNamedXContent() { + return List.of( + new NamedXContentRegistry.Entry( + RankBuilder.class, + new ParseField(FIELD_BASED_RANK_BUILDER_NAME), + FieldBasedRankBuilder::fromXContent + ), + new NamedXContentRegistry.Entry( + RankBuilder.class, + new ParseField(THROWING_RANK_BUILDER_NAME), + ThrowingRankBuilder::fromXContent + ) + ); + } + } +} diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index d8682500c49d6..2f08129b4080d 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -362,6 +362,7 @@ exports org.elasticsearch.search.query; exports org.elasticsearch.search.rank; exports org.elasticsearch.search.rank.context; + exports org.elasticsearch.search.rank.feature; exports org.elasticsearch.search.rescore; exports org.elasticsearch.search.retriever; exports org.elasticsearch.search.runtime; diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index e8a33217b937d..72771855ff622 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -184,7 +184,7 @@ static TransportVersion def(int id) { public static final TransportVersion ML_INFERENCE_GOOGLE_AI_STUDIO_EMBEDDINGS_ADDED = def(8_675_00_0); public static final TransportVersion ADD_MISTRAL_EMBEDDINGS_INFERENCE = def(8_676_00_0); public static final TransportVersion ML_CHUNK_INFERENCE_OPTION = def(8_677_00_0); - + public static final TransportVersion RANK_FEATURE_PHASE_ADDED = def(8_678_00_0); /* * STOP! READ THIS FIRST! No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/server/src/main/java/org/elasticsearch/action/search/CCSSingleCoordinatorSearchProgressListener.java b/server/src/main/java/org/elasticsearch/action/search/CCSSingleCoordinatorSearchProgressListener.java index 3b594c94db9a7..0504d0cde8986 100644 --- a/server/src/main/java/org/elasticsearch/action/search/CCSSingleCoordinatorSearchProgressListener.java +++ b/server/src/main/java/org/elasticsearch/action/search/CCSSingleCoordinatorSearchProgressListener.java @@ -260,6 +260,24 @@ public void onFinalReduce(List shards, TotalHits totalHits, Interna } } + /** + * Executed when a shard returns a rank feature result. + * + * @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards})}. + */ + @Override + public void onRankFeatureResult(int shardIndex) {} + + /** + * Executed when a shard reports a rank feature failure. + * + * @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards})}. + * @param shardTarget The last shard target that thrown an exception. + * @param exc The cause of the failure. + */ + @Override + public void onRankFeatureFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {} + /** * Executed when a shard returns a fetch result. * diff --git a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java index f804ab31faf8e..2308f5fcc8085 100644 --- a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java @@ -17,8 +17,6 @@ import org.elasticsearch.search.fetch.FetchSearchResult; import org.elasticsearch.search.fetch.ShardFetchSearchRequest; import org.elasticsearch.search.internal.ShardSearchContextId; -import org.elasticsearch.search.query.QuerySearchResult; -import org.elasticsearch.transport.Transport; import java.util.List; import java.util.function.BiFunction; @@ -29,7 +27,7 @@ */ final class FetchSearchPhase extends SearchPhase { private final ArraySearchPhaseResults fetchResults; - private final AtomicArray queryResults; + private final AtomicArray searchPhaseShardResults; private final BiFunction, SearchPhase> nextPhaseFactory; private final SearchPhaseContext context; private final Logger logger; @@ -74,7 +72,7 @@ final class FetchSearchPhase extends SearchPhase { } this.fetchResults = new ArraySearchPhaseResults<>(resultConsumer.getNumShards()); context.addReleasable(fetchResults); - this.queryResults = resultConsumer.getAtomicArray(); + this.searchPhaseShardResults = resultConsumer.getAtomicArray(); this.aggregatedDfs = aggregatedDfs; this.nextPhaseFactory = nextPhaseFactory; this.context = context; @@ -103,19 +101,20 @@ private void innerRun() { final int numShards = context.getNumShards(); // Usually when there is a single shard, we force the search type QUERY_THEN_FETCH. But when there's kNN, we might // still use DFS_QUERY_THEN_FETCH, which does not perform the "query and fetch" optimization during the query phase. - final boolean queryAndFetchOptimization = queryResults.length() == 1 + final boolean queryAndFetchOptimization = searchPhaseShardResults.length() == 1 && context.getRequest().hasKnnSearch() == false - && reducedQueryPhase.rankCoordinatorContext() == null; + && reducedQueryPhase.queryPhaseRankCoordinatorContext() == null; if (queryAndFetchOptimization) { assert assertConsistentWithQueryAndFetchOptimization(); // query AND fetch optimization - moveToNextPhase(queryResults); + moveToNextPhase(searchPhaseShardResults); } else { ScoreDoc[] scoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs(); // no docs to fetch -- sidestep everything and return if (scoreDocs.length == 0) { // we have to release contexts here to free up resources - queryResults.asList().stream().map(SearchPhaseResult::queryResult).forEach(this::releaseIrrelevantSearchContext); + searchPhaseShardResults.asList() + .forEach(searchPhaseShardResult -> releaseIrrelevantSearchContext(searchPhaseShardResult, context)); moveToNextPhase(fetchResults.getAtomicArray()); } else { final ScoreDoc[] lastEmittedDocPerShard = context.getRequest().scroll() != null @@ -130,19 +129,19 @@ private void innerRun() { ); for (int i = 0; i < docIdsToLoad.length; i++) { List entry = docIdsToLoad[i]; - SearchPhaseResult queryResult = queryResults.get(i); + SearchPhaseResult shardPhaseResult = searchPhaseShardResults.get(i); if (entry == null) { // no results for this shard ID - if (queryResult != null) { + if (shardPhaseResult != null) { // if we got some hits from this shard we have to release the context there // we do this as we go since it will free up resources and passing on the request on the // transport layer is cheap. - releaseIrrelevantSearchContext(queryResult.queryResult()); + releaseIrrelevantSearchContext(shardPhaseResult, context); progressListener.notifyFetchResult(i); } // in any case we count down this result since we don't talk to this shard anymore counter.countDown(); } else { - executeFetch(queryResult, counter, entry, (lastEmittedDocPerShard != null) ? lastEmittedDocPerShard[i] : null); + executeFetch(shardPhaseResult, counter, entry, (lastEmittedDocPerShard != null) ? lastEmittedDocPerShard[i] : null); } } } @@ -150,31 +149,33 @@ private void innerRun() { } private boolean assertConsistentWithQueryAndFetchOptimization() { - var phaseResults = queryResults.asList(); + var phaseResults = searchPhaseShardResults.asList(); assert phaseResults.isEmpty() || phaseResults.get(0).fetchResult() != null : "phaseResults empty [" + phaseResults.isEmpty() + "], single result: " + phaseResults.get(0).fetchResult(); return true; } private void executeFetch( - SearchPhaseResult queryResult, + SearchPhaseResult shardPhaseResult, final CountedCollector counter, final List entry, ScoreDoc lastEmittedDocForShard ) { - final SearchShardTarget shardTarget = queryResult.getSearchShardTarget(); - final int shardIndex = queryResult.getShardIndex(); - final ShardSearchContextId contextId = queryResult.queryResult().getContextId(); + final SearchShardTarget shardTarget = shardPhaseResult.getSearchShardTarget(); + final int shardIndex = shardPhaseResult.getShardIndex(); + final ShardSearchContextId contextId = shardPhaseResult.queryResult() != null + ? shardPhaseResult.queryResult().getContextId() + : shardPhaseResult.rankFeatureResult().getContextId(); context.getSearchTransport() .sendExecuteFetch( context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId()), new ShardFetchSearchRequest( - context.getOriginalIndices(queryResult.getShardIndex()), + context.getOriginalIndices(shardPhaseResult.getShardIndex()), contextId, - queryResult.getShardSearchRequest(), + shardPhaseResult.getShardSearchRequest(), entry, lastEmittedDocForShard, - queryResult.getRescoreDocIds(), + shardPhaseResult.getRescoreDocIds(), aggregatedDfs ), context.getTask(), @@ -199,40 +200,17 @@ public void onFailure(Exception e) { // the search context might not be cleared on the node where the fetch was executed for example // because the action was rejected by the thread pool. in this case we need to send a dedicated // request to clear the search context. - releaseIrrelevantSearchContext(queryResult.queryResult()); + releaseIrrelevantSearchContext(shardPhaseResult, context); } } } ); } - /** - * Releases shard targets that are not used in the docsIdsToLoad. - */ - private void releaseIrrelevantSearchContext(QuerySearchResult queryResult) { - // we only release search context that we did not fetch from, if we are not scrolling - // or using a PIT and if it has at least one hit that didn't make it to the global topDocs - if (queryResult.hasSearchContext() - && context.getRequest().scroll() == null - && (context.isPartOfPointInTime(queryResult.getContextId()) == false)) { - try { - SearchShardTarget shardTarget = queryResult.getSearchShardTarget(); - Transport.Connection connection = context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId()); - context.sendReleaseSearchContext( - queryResult.getContextId(), - connection, - context.getOriginalIndices(queryResult.getShardIndex()) - ); - } catch (Exception e) { - logger.trace("failed to release context", e); - } - } - } - private void moveToNextPhase(AtomicArray fetchResultsArr) { var resp = SearchPhaseController.merge(context.getRequest().scroll() != null, reducedQueryPhase, fetchResultsArr); context.addReleasable(resp::decRef); fetchResults.close(); - context.executeNextPhase(this, nextPhaseFactory.apply(resp, queryResults)); + context.executeNextPhase(this, nextPhaseFactory.apply(resp, searchPhaseShardResults)); } } diff --git a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java index 767597625edc6..291982dd9bdd3 100644 --- a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java @@ -7,23 +7,39 @@ */ package org.elasticsearch.action.search; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.search.ScoreDoc; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ThreadedActionListener; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.search.SearchPhaseResult; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.dfs.AggregatedDfs; +import org.elasticsearch.search.internal.ShardSearchContextId; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; +import org.elasticsearch.search.rank.feature.RankFeatureDoc; +import org.elasticsearch.search.rank.feature.RankFeatureResult; +import org.elasticsearch.search.rank.feature.RankFeatureShardRequest; + +import java.util.List; /** * This search phase is responsible for executing any re-ranking needed for the given search request, iff that is applicable. - * It starts by retrieving {code num_shards * window_size} results from the query phase and reduces them to a global list of + * It starts by retrieving {@code num_shards * window_size} results from the query phase and reduces them to a global list of * the top {@code window_size} results. It then reaches out to the shards to extract the needed feature data, * and finally passes all this information to the appropriate {@code RankFeatureRankCoordinatorContext} which is responsible for reranking * the results. If no rank query is specified, it proceeds directly to the next phase (FetchSearchPhase) by first reducing the results. */ -public final class RankFeaturePhase extends SearchPhase { +public class RankFeaturePhase extends SearchPhase { + private static final Logger logger = LogManager.getLogger(RankFeaturePhase.class); private final SearchPhaseContext context; - private final SearchPhaseResults queryPhaseResults; - + final SearchPhaseResults queryPhaseResults; + final SearchPhaseResults rankPhaseResults; private final AggregatedDfs aggregatedDfs; + private final SearchProgressListener progressListener; RankFeaturePhase(SearchPhaseResults queryPhaseResults, AggregatedDfs aggregatedDfs, SearchPhaseContext context) { super("rank-feature"); @@ -38,6 +54,9 @@ public final class RankFeaturePhase extends SearchPhase { this.context = context; this.queryPhaseResults = queryPhaseResults; this.aggregatedDfs = aggregatedDfs; + this.rankPhaseResults = new ArraySearchPhaseResults<>(context.getNumShards()); + context.addReleasable(rankPhaseResults); + this.progressListener = context.getTask().getProgressListener(); } @Override @@ -59,16 +78,154 @@ public void onFailure(Exception e) { }); } - private void innerRun() throws Exception { - // other than running reduce, this is currently close to a no-op + void innerRun() throws Exception { + // if the RankBuilder specifies a QueryPhaseCoordinatorContext, it will be called as part of the reduce call + // to operate on the first `window_size * num_shards` results and merge them appropriately. SearchPhaseController.ReducedQueryPhase reducedQueryPhase = queryPhaseResults.reduce(); - moveToNextPhase(queryPhaseResults, reducedQueryPhase); + RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext = coordinatorContext(context.getRequest().source()); + if (rankFeaturePhaseRankCoordinatorContext != null) { + ScoreDoc[] queryScoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs(); // rank_window_size + final List[] docIdsToLoad = SearchPhaseController.fillDocIdsToLoad(context.getNumShards(), queryScoreDocs); + final CountedCollector rankRequestCounter = new CountedCollector<>( + rankPhaseResults, + context.getNumShards(), + () -> onPhaseDone(rankFeaturePhaseRankCoordinatorContext, reducedQueryPhase), + context + ); + + // we send out a request to each shard in order to fetch the needed feature info + for (int i = 0; i < docIdsToLoad.length; i++) { + List entry = docIdsToLoad[i]; + SearchPhaseResult queryResult = queryPhaseResults.getAtomicArray().get(i); + if (entry == null || entry.isEmpty()) { + if (queryResult != null) { + releaseIrrelevantSearchContext(queryResult, context); + progressListener.notifyRankFeatureResult(i); + } + rankRequestCounter.countDown(); + } else { + executeRankFeatureShardPhase(queryResult, rankRequestCounter, entry); + } + } + } else { + moveToNextPhase(queryPhaseResults, reducedQueryPhase); + } + } + + private RankFeaturePhaseRankCoordinatorContext coordinatorContext(SearchSourceBuilder source) { + return source == null || source.rankBuilder() == null + ? null + : context.getRequest() + .source() + .rankBuilder() + .buildRankFeaturePhaseCoordinatorContext(context.getRequest().source().size(), context.getRequest().source().from()); } - private void moveToNextPhase( - SearchPhaseResults phaseResults, + private void executeRankFeatureShardPhase( + SearchPhaseResult queryResult, + final CountedCollector rankRequestCounter, + final List entry + ) { + final SearchShardTarget shardTarget = queryResult.queryResult().getSearchShardTarget(); + final ShardSearchContextId contextId = queryResult.queryResult().getContextId(); + final int shardIndex = queryResult.getShardIndex(); + context.getSearchTransport() + .sendExecuteRankFeature( + context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId()), + new RankFeatureShardRequest( + context.getOriginalIndices(queryResult.getShardIndex()), + queryResult.getContextId(), + queryResult.getShardSearchRequest(), + entry + ), + context.getTask(), + new SearchActionListener<>(shardTarget, shardIndex) { + @Override + protected void innerOnResponse(RankFeatureResult response) { + try { + progressListener.notifyRankFeatureResult(shardIndex); + rankRequestCounter.onResult(response); + } catch (Exception e) { + context.onPhaseFailure(RankFeaturePhase.this, "", e); + } + } + + @Override + public void onFailure(Exception e) { + try { + logger.debug(() -> "[" + contextId + "] Failed to execute rank phase", e); + progressListener.notifyRankFeatureFailure(shardIndex, shardTarget, e); + rankRequestCounter.onFailure(shardIndex, shardTarget, e); + } finally { + releaseIrrelevantSearchContext(queryResult, context); + } + } + } + ); + } + + private void onPhaseDone( + RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext, SearchPhaseController.ReducedQueryPhase reducedQueryPhase ) { + assert rankFeaturePhaseRankCoordinatorContext != null; + ThreadedActionListener rankResultListener = new ThreadedActionListener<>(context, new ActionListener<>() { + @Override + public void onResponse(RankFeatureDoc[] docsWithUpdatedScores) { + RankFeatureDoc[] topResults = rankFeaturePhaseRankCoordinatorContext.rankAndPaginate(docsWithUpdatedScores); + SearchPhaseController.ReducedQueryPhase reducedRankFeaturePhase = newReducedQueryPhaseResults( + reducedQueryPhase, + topResults + ); + moveToNextPhase(rankPhaseResults, reducedRankFeaturePhase); + } + + @Override + public void onFailure(Exception e) { + context.onPhaseFailure(RankFeaturePhase.this, "Computing updated ranks for results failed", e); + } + }); + rankFeaturePhaseRankCoordinatorContext.rankGlobalResults( + rankPhaseResults.getAtomicArray().asList().stream().map(SearchPhaseResult::rankFeatureResult).toList(), + rankResultListener + ); + } + + private SearchPhaseController.ReducedQueryPhase newReducedQueryPhaseResults( + SearchPhaseController.ReducedQueryPhase reducedQueryPhase, + ScoreDoc[] scoreDocs + ) { + + return new SearchPhaseController.ReducedQueryPhase( + reducedQueryPhase.totalHits(), + reducedQueryPhase.fetchHits(), + maxScore(scoreDocs), + reducedQueryPhase.timedOut(), + reducedQueryPhase.terminatedEarly(), + reducedQueryPhase.suggest(), + reducedQueryPhase.aggregations(), + reducedQueryPhase.profileBuilder(), + new SearchPhaseController.SortedTopDocs(scoreDocs, false, null, null, null, 0), + reducedQueryPhase.sortValueFormats(), + reducedQueryPhase.queryPhaseRankCoordinatorContext(), + reducedQueryPhase.numReducePhases(), + reducedQueryPhase.size(), + reducedQueryPhase.from(), + reducedQueryPhase.isEmptyResult() + ); + } + + private float maxScore(ScoreDoc[] scoreDocs) { + float maxScore = Float.NaN; + for (ScoreDoc scoreDoc : scoreDocs) { + if (Float.isNaN(maxScore) || scoreDoc.score > maxScore) { + maxScore = scoreDoc.score; + } + } + return maxScore; + } + + void moveToNextPhase(SearchPhaseResults phaseResults, SearchPhaseController.ReducedQueryPhase reducedQueryPhase) { context.executeNextPhase(this, new FetchSearchPhase(phaseResults, aggregatedDfs, context, reducedQueryPhase)); } } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java index 9d3eadcc42bf9..5ed449667fe57 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java @@ -9,6 +9,9 @@ import org.elasticsearch.cluster.routing.GroupShardsIterator; import org.elasticsearch.core.CheckedRunnable; +import org.elasticsearch.search.SearchPhaseResult; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.transport.Transport; import java.io.IOException; import java.io.UncheckedIOException; @@ -62,4 +65,35 @@ static void doCheckNoMissingShards(String phaseName, SearchRequest request, Grou } } } + + /** + * Releases shard targets that are not used in the docsIdsToLoad. + */ + protected void releaseIrrelevantSearchContext(SearchPhaseResult searchPhaseResult, SearchPhaseContext context) { + // we only release search context that we did not fetch from, if we are not scrolling + // or using a PIT and if it has at least one hit that didn't make it to the global topDocs + if (searchPhaseResult == null) { + return; + } + // phaseResult.getContextId() is the same for query & rank feature results + SearchPhaseResult phaseResult = searchPhaseResult.queryResult() != null + ? searchPhaseResult.queryResult() + : searchPhaseResult.rankFeatureResult(); + if (phaseResult != null + && phaseResult.hasSearchContext() + && context.getRequest().scroll() == null + && (context.isPartOfPointInTime(phaseResult.getContextId()) == false)) { + try { + SearchShardTarget shardTarget = phaseResult.getSearchShardTarget(); + Transport.Connection connection = context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId()); + context.sendReleaseSearchContext( + phaseResult.getContextId(), + connection, + context.getOriginalIndices(phaseResult.getShardIndex()) + ); + } catch (Exception e) { + context.getLogger().trace("failed to release context", e); + } + } + } } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java index 1b894dfe3d8bd..1d3859b9038fe 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java @@ -456,7 +456,7 @@ private static SearchHits getHits( : "not enough hits fetched. index [" + index + "] length: " + fetchResult.hits().getHits().length; SearchHit searchHit = fetchResult.hits().getHits()[index]; searchHit.shard(fetchResult.getSearchShardTarget()); - if (reducedQueryPhase.rankCoordinatorContext != null) { + if (reducedQueryPhase.queryPhaseRankCoordinatorContext != null) { assert shardDoc instanceof RankDoc; searchHit.setRank(((RankDoc) shardDoc).rank); searchHit.score(shardDoc.score); @@ -747,7 +747,7 @@ public record ReducedQueryPhase( // sort value formats used to sort / format the result DocValueFormat[] sortValueFormats, // the rank context if ranking is used - QueryPhaseRankCoordinatorContext rankCoordinatorContext, + QueryPhaseRankCoordinatorContext queryPhaseRankCoordinatorContext, // the number of reduces phases int numReducePhases, // the size of the top hits to return diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java b/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java index f5d280a01257c..3b5e03cb5ac4a 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java @@ -88,6 +88,22 @@ protected void onPartialReduce(List shards, TotalHits totalHits, In */ protected void onFinalReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {} + /** + * Executed when a shard returns a rank feature result. + * + * @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards})}. + */ + protected void onRankFeatureResult(int shardIndex) {} + + /** + * Executed when a shard reports a rank feature failure. + * + * @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards})}. + * @param shardTarget The last shard target that thrown an exception. + * @param exc The cause of the failure. + */ + protected void onRankFeatureFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {} + /** * Executed when a shard returns a fetch result. * @@ -160,6 +176,22 @@ protected final void notifyFinalReduce(List shards, TotalHits total } } + final void notifyRankFeatureResult(int shardIndex) { + try { + onRankFeatureResult(shardIndex); + } catch (Exception e) { + logger.warn(() -> "[" + shards.get(shardIndex) + "] Failed to execute progress listener on rank-feature result", e); + } + } + + final void notifyRankFeatureFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) { + try { + onRankFeatureFailure(shardIndex, shardTarget, exc); + } catch (Exception e) { + logger.warn(() -> "[" + shards.get(shardIndex) + "] Failed to execute progress listener on rank-feature failure", e); + } + } + final void notifyFetchResult(int shardIndex) { try { onFetchResult(shardIndex); diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java b/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java index 4e3fdbc9633b9..3e4f6dfec9fdb 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java @@ -407,7 +407,7 @@ public ActionRequestValidationException validate() { ); } int queryCount = source.subSearches().size() + source.knnSearch().size(); - if (queryCount < 2) { + if (source.rankBuilder().isCompoundBuilder() && queryCount < 2) { validationException = addValidationError( "[rank] requires a minimum of [2] result sets using a combination of sub searches and/or knn searches", validationException diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchTransportAPMMetrics.java b/server/src/main/java/org/elasticsearch/action/search/SearchTransportAPMMetrics.java index 93b8e22d0d7cd..9f8896f169350 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchTransportAPMMetrics.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchTransportAPMMetrics.java @@ -19,6 +19,7 @@ public class SearchTransportAPMMetrics { public static final String DFS_ACTION_METRIC = "dfs_query_then_fetch/shard_dfs_phase"; public static final String QUERY_ID_ACTION_METRIC = "dfs_query_then_fetch/shard_query_phase"; public static final String QUERY_ACTION_METRIC = "query_then_fetch/shard_query_phase"; + public static final String RANK_SHARD_FEATURE_ACTION_METRIC = "rank/shard_feature_phase"; public static final String FREE_CONTEXT_ACTION_METRIC = "shard_release_context"; public static final String FETCH_ID_ACTION_METRIC = "shard_fetch_phase"; public static final String QUERY_SCROLL_ACTION_METRIC = "scroll/shard_query_phase"; diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java index 66c395cf51d96..d627da9b0e33b 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java @@ -39,6 +39,8 @@ import org.elasticsearch.search.query.QuerySearchRequest; import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.search.query.ScrollQuerySearchResult; +import org.elasticsearch.search.rank.feature.RankFeatureResult; +import org.elasticsearch.search.rank.feature.RankFeatureShardRequest; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.RemoteClusterService; @@ -70,6 +72,7 @@ import static org.elasticsearch.action.search.SearchTransportAPMMetrics.QUERY_FETCH_SCROLL_ACTION_METRIC; import static org.elasticsearch.action.search.SearchTransportAPMMetrics.QUERY_ID_ACTION_METRIC; import static org.elasticsearch.action.search.SearchTransportAPMMetrics.QUERY_SCROLL_ACTION_METRIC; +import static org.elasticsearch.action.search.SearchTransportAPMMetrics.RANK_SHARD_FEATURE_ACTION_METRIC; /** * An encapsulation of {@link org.elasticsearch.search.SearchService} operations exposed through @@ -96,6 +99,8 @@ public class SearchTransportService { public static final String FETCH_ID_SCROLL_ACTION_NAME = "indices:data/read/search[phase/fetch/id/scroll]"; public static final String FETCH_ID_ACTION_NAME = "indices:data/read/search[phase/fetch/id]"; + public static final String RANK_FEATURE_SHARD_ACTION_NAME = "indices:data/read/search[phase/rank/feature]"; + /** * The Can-Match phase. It is executed to pre-filter shards that a search request hits. It rewrites the query on * the shard and checks whether the result of the rewrite matches no documents, in which case the shard can be @@ -250,6 +255,21 @@ public void sendExecuteScrollQuery( ); } + public void sendExecuteRankFeature( + Transport.Connection connection, + final RankFeatureShardRequest request, + SearchTask task, + final SearchActionListener listener + ) { + transportService.sendChildRequest( + connection, + RANK_FEATURE_SHARD_ACTION_NAME, + request, + task, + new ConnectionCountingHandler<>(listener, RankFeatureResult::new, connection) + ); + } + public void sendExecuteScrollFetch( Transport.Connection connection, final InternalScrollSearchRequest request, @@ -539,6 +559,16 @@ public static void registerRequestHandler( ); TransportActionProxy.registerProxyAction(transportService, QUERY_FETCH_SCROLL_ACTION_NAME, true, ScrollQueryFetchSearchResult::new); + final TransportRequestHandler rankShardFeatureRequest = (request, channel, task) -> searchService + .executeRankFeaturePhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel)); + transportService.registerRequestHandler( + RANK_FEATURE_SHARD_ACTION_NAME, + EsExecutors.DIRECT_EXECUTOR_SERVICE, + RankFeatureShardRequest::new, + instrumentedHandler(RANK_SHARD_FEATURE_ACTION_METRIC, transportService, searchTransportMetrics, rankShardFeatureRequest) + ); + TransportActionProxy.registerProxyAction(transportService, RANK_FEATURE_SHARD_ACTION_NAME, true, RankFeatureResult::new); + final TransportRequestHandler shardFetchRequestHandler = (request, channel, task) -> searchService .executeFetchPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel)); transportService.registerRequestHandler( diff --git a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java index e0ca0f7a48cdd..fd2aabce8e952 100644 --- a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java +++ b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java @@ -1044,6 +1044,7 @@ record PluginServiceInstances( threadPool, scriptService, bigArrays, + searchModule.getRankFeatureShardPhase(), searchModule.getFetchPhase(), responseCollectorService, circuitBreakerService, diff --git a/server/src/main/java/org/elasticsearch/node/NodeServiceProvider.java b/server/src/main/java/org/elasticsearch/node/NodeServiceProvider.java index ab90ca42bca98..914dd51d0c6b2 100644 --- a/server/src/main/java/org/elasticsearch/node/NodeServiceProvider.java +++ b/server/src/main/java/org/elasticsearch/node/NodeServiceProvider.java @@ -33,6 +33,7 @@ import org.elasticsearch.script.ScriptService; import org.elasticsearch.search.SearchService; import org.elasticsearch.search.fetch.FetchPhase; +import org.elasticsearch.search.rank.feature.RankFeatureShardPhase; import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.telemetry.tracing.Tracer; import org.elasticsearch.threadpool.ThreadPool; @@ -116,6 +117,7 @@ SearchService newSearchService( ThreadPool threadPool, ScriptService scriptService, BigArrays bigArrays, + RankFeatureShardPhase rankFeatureShardPhase, FetchPhase fetchPhase, ResponseCollectorService responseCollectorService, CircuitBreakerService circuitBreakerService, @@ -128,6 +130,7 @@ SearchService newSearchService( threadPool, scriptService, bigArrays, + rankFeatureShardPhase, fetchPhase, responseCollectorService, circuitBreakerService, diff --git a/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java b/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java index 9bacf19a9169d..4f16d3a5720fb 100644 --- a/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java @@ -70,6 +70,7 @@ import org.elasticsearch.search.profile.Profilers; import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; +import org.elasticsearch.search.rank.feature.RankFeatureResult; import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.slice.SliceBuilder; import org.elasticsearch.search.sort.SortAndFormats; @@ -102,6 +103,7 @@ final class DefaultSearchContext extends SearchContext { private final ContextIndexSearcher searcher; private DfsSearchResult dfsResult; private QuerySearchResult queryResult; + private RankFeatureResult rankFeatureResult; private FetchSearchResult fetchResult; private final float queryBoost; private final boolean lowLevelCancellation; @@ -308,6 +310,17 @@ static boolean isParallelCollectionSupportedForResults( return false; } + @Override + public void addRankFeatureResult() { + this.rankFeatureResult = new RankFeatureResult(this.readerContext.id(), this.shardTarget, this.request); + addReleasable(rankFeatureResult::decRef); + } + + @Override + public RankFeatureResult rankFeatureResult() { + return rankFeatureResult; + } + @Override public void addFetchResult() { this.fetchResult = new FetchSearchResult(this.readerContext.id(), this.shardTarget); diff --git a/server/src/main/java/org/elasticsearch/search/SearchModule.java b/server/src/main/java/org/elasticsearch/search/SearchModule.java index 8d5fa0a7ac155..d93ff91a6ffe4 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchModule.java +++ b/server/src/main/java/org/elasticsearch/search/SearchModule.java @@ -226,6 +226,7 @@ import org.elasticsearch.search.fetch.subphase.highlight.Highlighter; import org.elasticsearch.search.fetch.subphase.highlight.PlainHighlighter; import org.elasticsearch.search.internal.ShardSearchRequest; +import org.elasticsearch.search.rank.feature.RankFeatureShardPhase; import org.elasticsearch.search.rescore.QueryRescorerBuilder; import org.elasticsearch.search.rescore.RescorerBuilder; import org.elasticsearch.search.retriever.KnnRetrieverBuilder; @@ -1252,6 +1253,10 @@ private void registerQuery(QuerySpec spec) { ); } + public RankFeatureShardPhase getRankFeatureShardPhase() { + return new RankFeatureShardPhase(); + } + public FetchPhase getFetchPhase() { return new FetchPhase(fetchSubPhases); } diff --git a/server/src/main/java/org/elasticsearch/search/SearchPhaseResult.java b/server/src/main/java/org/elasticsearch/search/SearchPhaseResult.java index 254cd7d3370b5..450b98b22f39c 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchPhaseResult.java +++ b/server/src/main/java/org/elasticsearch/search/SearchPhaseResult.java @@ -15,6 +15,7 @@ import org.elasticsearch.search.internal.ShardSearchContextId; import org.elasticsearch.search.internal.ShardSearchRequest; import org.elasticsearch.search.query.QuerySearchResult; +import org.elasticsearch.search.rank.feature.RankFeatureResult; import org.elasticsearch.transport.TransportResponse; import java.io.IOException; @@ -43,6 +44,14 @@ protected SearchPhaseResult(StreamInput in) throws IOException { super(in); } + /** + * Specifies whether the specific search phase results are associated with an opened SearchContext on the shards that + * executed the request. + */ + public boolean hasSearchContext() { + return false; + } + /** * Returns the search context ID that is used to reference the search context on the executing node * or null if no context was created. @@ -81,6 +90,13 @@ public QuerySearchResult queryResult() { return null; } + /** + * Returns the rank feature result iff it's included in this response otherwise null + */ + public RankFeatureResult rankFeatureResult() { + return null; + } + /** * Returns the fetch result iff it's included in this response otherwise null */ diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index 41796967c3870..3f9dd7895f6a7 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -112,6 +112,9 @@ import org.elasticsearch.search.query.QuerySearchRequest; import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.search.query.ScrollQuerySearchResult; +import org.elasticsearch.search.rank.feature.RankFeatureResult; +import org.elasticsearch.search.rank.feature.RankFeatureShardPhase; +import org.elasticsearch.search.rank.feature.RankFeatureShardRequest; import org.elasticsearch.search.rescore.RescorerBuilder; import org.elasticsearch.search.searchafter.SearchAfterBuilder; import org.elasticsearch.search.sort.FieldSortBuilder; @@ -151,6 +154,7 @@ import static org.elasticsearch.core.TimeValue.timeValueMillis; import static org.elasticsearch.core.TimeValue.timeValueMinutes; import static org.elasticsearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; +import static org.elasticsearch.search.rank.feature.RankFeatureShardPhase.EMPTY_RESULT; public class SearchService extends AbstractLifecycleComponent implements IndexEventListener { private static final Logger logger = LogManager.getLogger(SearchService.class); @@ -276,6 +280,7 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv private final DfsPhase dfsPhase = new DfsPhase(); private final FetchPhase fetchPhase; + private final RankFeatureShardPhase rankFeatureShardPhase; private volatile boolean enableSearchWorkerThreads; private volatile boolean enableQueryPhaseParallelCollection; @@ -314,6 +319,7 @@ public SearchService( ThreadPool threadPool, ScriptService scriptService, BigArrays bigArrays, + RankFeatureShardPhase rankFeatureShardPhase, FetchPhase fetchPhase, ResponseCollectorService responseCollectorService, CircuitBreakerService circuitBreakerService, @@ -327,6 +333,7 @@ public SearchService( this.scriptService = scriptService; this.responseCollectorService = responseCollectorService; this.bigArrays = bigArrays; + this.rankFeatureShardPhase = rankFeatureShardPhase; this.fetchPhase = fetchPhase; this.multiBucketConsumerService = new MultiBucketConsumerService( clusterService, @@ -713,6 +720,32 @@ private SearchPhaseResult executeQueryPhase(ShardSearchRequest request, SearchSh } } + public void executeRankFeaturePhase(RankFeatureShardRequest request, SearchShardTask task, ActionListener listener) { + final ReaderContext readerContext = findReaderContext(request.contextId(), request); + final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.getShardSearchRequest()); + final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest)); + runAsync(getExecutor(readerContext.indexShard()), () -> { + try (SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, ResultsType.RANK_FEATURE, false)) { + int[] docIds = request.getDocIds(); + if (docIds == null || docIds.length == 0) { + searchContext.rankFeatureResult().shardResult(EMPTY_RESULT); + searchContext.rankFeatureResult().incRef(); + return searchContext.rankFeatureResult(); + } + rankFeatureShardPhase.prepareForFetch(searchContext, request); + fetchPhase.execute(searchContext, docIds); + rankFeatureShardPhase.processFetch(searchContext); + var rankFeatureResult = searchContext.rankFeatureResult(); + rankFeatureResult.incRef(); + return rankFeatureResult; + } catch (Exception e) { + assert TransportActions.isShardNotAvailableException(e) == false : new AssertionError(e); + // we handle the failure in the failure listener below + throw e; + } + }, wrapFailureListener(listener, readerContext, markAsUsed)); + } + private QueryFetchSearchResult executeFetchPhase(ReaderContext reader, SearchContext context, long afterQueryTime) { try ( Releasable scope = tracer.withScope(context.getTask()); @@ -1559,6 +1592,12 @@ void addResultsObject(SearchContext context) { context.addQueryResult(); } }, + RANK_FEATURE { + @Override + void addResultsObject(SearchContext context) { + context.addRankFeatureResult(); + } + }, FETCH { @Override void addResultsObject(SearchContext context) { diff --git a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java index 0c54e8ff89589..4ba191794413d 100644 --- a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java +++ b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java @@ -98,7 +98,6 @@ public Source getSource(LeafReaderContext ctx, int doc) { } private SearchHits buildSearchHits(SearchContext context, int[] docIdsToLoad, Profiler profiler) { - FetchContext fetchContext = new FetchContext(context); SourceLoader sourceLoader = context.newSourceLoader(); diff --git a/server/src/main/java/org/elasticsearch/search/internal/FilteredSearchContext.java b/server/src/main/java/org/elasticsearch/search/internal/FilteredSearchContext.java index d5c3c00c00ce1..e32397e25d773 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/FilteredSearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/internal/FilteredSearchContext.java @@ -35,6 +35,7 @@ import org.elasticsearch.search.profile.Profilers; import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; +import org.elasticsearch.search.rank.feature.RankFeatureResult; import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.sort.SortAndFormats; import org.elasticsearch.search.suggest.SuggestionSearchContext; @@ -374,6 +375,16 @@ public float getMaxScore() { return in.getMaxScore(); } + @Override + public void addRankFeatureResult() { + in.addRankFeatureResult(); + } + + @Override + public RankFeatureResult rankFeatureResult() { + return in.rankFeatureResult(); + } + @Override public FetchSearchResult fetchResult() { return in.fetchResult(); diff --git a/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java b/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java index 35f96ee2dc102..9bc622034184c 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java @@ -42,6 +42,7 @@ import org.elasticsearch.search.profile.Profilers; import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; +import org.elasticsearch.search.rank.feature.RankFeatureResult; import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.sort.SortAndFormats; import org.elasticsearch.search.suggest.SuggestionSearchContext; @@ -332,6 +333,10 @@ public Query rewrittenQuery() { public abstract float getMaxScore(); + public abstract void addRankFeatureResult(); + + public abstract RankFeatureResult rankFeatureResult(); + public abstract FetchPhase fetchPhase(); public abstract FetchSearchResult fetchResult(); diff --git a/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java b/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java index 828c6d2b4f3e8..0d2610aa34282 100644 --- a/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java +++ b/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java @@ -87,35 +87,38 @@ static void executeRank(SearchContext searchContext) throws QueryPhaseExecutionE boolean searchTimedOut = querySearchResult.searchTimedOut(); long serviceTimeEWMA = querySearchResult.serviceTimeEWMA(); int nodeQueueSize = querySearchResult.nodeQueueSize(); - - // run each of the rank queries - for (Query rankQuery : queryPhaseRankShardContext.queries()) { - // if a search timeout occurs, exit with partial results - if (searchTimedOut) { - break; - } - try ( - RankSearchContext rankSearchContext = new RankSearchContext( - searchContext, - rankQuery, - queryPhaseRankShardContext.rankWindowSize() - ) - ) { - QueryPhase.addCollectorsAndSearch(rankSearchContext); - QuerySearchResult rrfQuerySearchResult = rankSearchContext.queryResult(); - rrfRankResults.add(rrfQuerySearchResult.topDocs().topDocs); - serviceTimeEWMA += rrfQuerySearchResult.serviceTimeEWMA(); - nodeQueueSize = Math.max(nodeQueueSize, rrfQuerySearchResult.nodeQueueSize()); - searchTimedOut = rrfQuerySearchResult.searchTimedOut(); + try { + // run each of the rank queries + for (Query rankQuery : queryPhaseRankShardContext.queries()) { + // if a search timeout occurs, exit with partial results + if (searchTimedOut) { + break; + } + try ( + RankSearchContext rankSearchContext = new RankSearchContext( + searchContext, + rankQuery, + queryPhaseRankShardContext.rankWindowSize() + ) + ) { + QueryPhase.addCollectorsAndSearch(rankSearchContext); + QuerySearchResult rrfQuerySearchResult = rankSearchContext.queryResult(); + rrfRankResults.add(rrfQuerySearchResult.topDocs().topDocs); + serviceTimeEWMA += rrfQuerySearchResult.serviceTimeEWMA(); + nodeQueueSize = Math.max(nodeQueueSize, rrfQuerySearchResult.nodeQueueSize()); + searchTimedOut = rrfQuerySearchResult.searchTimedOut(); + } } - } - querySearchResult.setRankShardResult(queryPhaseRankShardContext.combineQueryPhaseResults(rrfRankResults)); + querySearchResult.setRankShardResult(queryPhaseRankShardContext.combineQueryPhaseResults(rrfRankResults)); - // record values relevant to all queries - querySearchResult.searchTimedOut(searchTimedOut); - querySearchResult.serviceTimeEWMA(serviceTimeEWMA); - querySearchResult.nodeQueueSize(nodeQueueSize); + // record values relevant to all queries + querySearchResult.searchTimedOut(searchTimedOut); + querySearchResult.serviceTimeEWMA(serviceTimeEWMA); + querySearchResult.nodeQueueSize(nodeQueueSize); + } catch (Exception e) { + throw new QueryPhaseExecutionException(searchContext.shardTarget(), "Failed to execute rank query", e); + } } static void executeQuery(SearchContext searchContext) throws QueryPhaseExecutionException { diff --git a/server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java b/server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java index 7118c9f49b36d..f496758c3f5c6 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java @@ -16,6 +16,8 @@ import org.elasticsearch.search.SearchService; import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext; import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; @@ -32,7 +34,7 @@ public abstract class RankBuilder implements VersionedNamedWriteable, ToXContent public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size"); - public static final int DEFAULT_WINDOW_SIZE = SearchService.DEFAULT_SIZE; + public static final int DEFAULT_RANK_WINDOW_SIZE = SearchService.DEFAULT_SIZE; private final int rankWindowSize; @@ -68,6 +70,12 @@ public int rankWindowSize() { return rankWindowSize; } + /** + * Specify whether this rank builder is a compound builder or not. A compound builder is a rank builder that requires + * two or more queries to be executed in order to generate the final result. + */ + public abstract boolean isCompoundBuilder(); + /** * Generates a context used to execute required searches during the query phase on the shard. */ @@ -78,6 +86,19 @@ public int rankWindowSize() { */ public abstract QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from); + /** + * Generates a context used to execute the rank feature phase on the shard. This is responsible for retrieving any needed + * feature data, and passing them back to the coordinator through the appropriate {@link RankShardResult}. + */ + public abstract RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext(); + + /** + * Generates a context used to perform global ranking during the RankFeature phase, + * on the coordinator based on all the individual shard results. The output of this will be a `size` ranked list of ordered results, + * which will then be passed to fetch phase. + */ + public abstract RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from); + @Override public final boolean equals(Object obj) { if (this == obj) { diff --git a/server/src/main/java/org/elasticsearch/search/rank/RankSearchContext.java b/server/src/main/java/org/elasticsearch/search/rank/RankSearchContext.java index 1cb5843dfc7da..7f8e99971d61b 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/RankSearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/rank/RankSearchContext.java @@ -43,6 +43,7 @@ import org.elasticsearch.search.profile.Profilers; import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; +import org.elasticsearch.search.rank.feature.RankFeatureResult; import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.sort.SortAndFormats; import org.elasticsearch.search.suggest.SuggestionSearchContext; @@ -57,14 +58,14 @@ public class RankSearchContext extends SearchContext { private final SearchContext parent; private final Query rankQuery; - private final int windowSize; + private final int rankWindowSize; private final QuerySearchResult querySearchResult; @SuppressWarnings("this-escape") - public RankSearchContext(SearchContext parent, Query rankQuery, int windowSize) { + public RankSearchContext(SearchContext parent, Query rankQuery, int rankWindowSize) { this.parent = parent; this.rankQuery = parent.buildFilteredQuery(rankQuery); - this.windowSize = windowSize; + this.rankWindowSize = rankWindowSize; this.querySearchResult = new QuerySearchResult(parent.readerContext().id(), parent.shardTarget(), parent.request()); this.addReleasable(querySearchResult::decRef); } @@ -182,7 +183,7 @@ public int from() { @Override public int size() { - return windowSize; + return rankWindowSize; } /** @@ -492,6 +493,16 @@ public FetchPhase fetchPhase() { throw new UnsupportedOperationException(); } + @Override + public void addRankFeatureResult() { + throw new UnsupportedOperationException(); + } + + @Override + public RankFeatureResult rankFeatureResult() { + throw new UnsupportedOperationException(); + } + @Override public FetchSearchResult fetchResult() { throw new UnsupportedOperationException(); diff --git a/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java b/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java new file mode 100644 index 0000000000000..b8951a4779166 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java @@ -0,0 +1,96 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.rank.context; + +import org.apache.lucene.search.ScoreDoc; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.search.rank.feature.RankFeatureDoc; +import org.elasticsearch.search.rank.feature.RankFeatureResult; +import org.elasticsearch.search.rank.feature.RankFeatureShardResult; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; + +import static org.elasticsearch.search.SearchService.DEFAULT_FROM; +import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; + +/** + * {@code RankFeaturePhaseRankCoordinatorContext} is a base class that runs on the coordinating node and is responsible for retrieving + * {@code window_size} total results from all shards, rank them, and then produce a final paginated response of [from, from+size] results. + */ +public abstract class RankFeaturePhaseRankCoordinatorContext { + + protected final int size; + protected final int from; + protected final int rankWindowSize; + + public RankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize) { + this.size = size < 0 ? DEFAULT_SIZE : size; + this.from = from < 0 ? DEFAULT_FROM : from; + this.rankWindowSize = rankWindowSize; + } + + /** + * Computes the updated scores for a list of features (i.e. document-based data). We also pass along an ActionListener + * that should be called with the new scores, and will continue execution to the next phase + */ + protected abstract void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener); + + /** + * This method is responsible for ranking the global results based on the provided rank feature results from each shard. + *

+ * We first start by extracting ordered feature data through a {@code List} + * from the provided rankSearchResults, and then compute the updated score for each of the documents. + * Once all the scores have been computed, we sort the results, perform any pagination needed, and then call the `onFinish` consumer + * with the final array of {@link ScoreDoc} results. + * + * @param rankSearchResults a list of rank feature results from each shard + * @param rankListener a rankListener to handle the global ranking result + */ + public void rankGlobalResults(List rankSearchResults, ActionListener rankListener) { + // extract feature data from each shard rank-feature phase result + RankFeatureDoc[] featureDocs = extractFeatureDocs(rankSearchResults); + + // generate the final `topResults` paginated results, and pass them to fetch phase through the `rankListener` + computeScores(featureDocs, rankListener.delegateFailureAndWrap((listener, scores) -> { + for (int i = 0; i < featureDocs.length; i++) { + featureDocs[i].score = scores[i]; + } + listener.onResponse(featureDocs); + })); + } + + /** + * Ranks the provided {@link RankFeatureDoc} array and paginates the results based on the `from` and `size` parameters. + */ + public RankFeatureDoc[] rankAndPaginate(RankFeatureDoc[] rankFeatureDocs) { + Arrays.sort(rankFeatureDocs, Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed()); + RankFeatureDoc[] topResults = new RankFeatureDoc[Math.max(0, Math.min(size, rankFeatureDocs.length - from))]; + for (int rank = 0; rank < topResults.length; ++rank) { + topResults[rank] = rankFeatureDocs[from + rank]; + topResults[rank].rank = from + rank + 1; + } + return topResults; + } + + private RankFeatureDoc[] extractFeatureDocs(List rankSearchResults) { + List docFeatures = new ArrayList<>(); + for (RankFeatureResult rankFeatureResult : rankSearchResults) { + RankFeatureShardResult shardResult = rankFeatureResult.shardResult(); + for (RankFeatureDoc rankFeatureDoc : shardResult.rankFeatureDocs) { + if (rankFeatureDoc.featureData != null) { + docFeatures.add(rankFeatureDoc); + } + } + } + return docFeatures.toArray(new RankFeatureDoc[0]); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankShardContext.java b/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankShardContext.java new file mode 100644 index 0000000000000..5d3f30bce757a --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankShardContext.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.rank.context; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.rank.RankShardResult; + +/** + * {@link RankFeaturePhaseRankShardContext} is a base class used to execute the RankFeature phase on each shard. + * In this class, we can fetch the feature data for a given set of documents and pass them back to the coordinator + * through the {@link RankShardResult}. + */ +public abstract class RankFeaturePhaseRankShardContext { + + protected final String field; + + public RankFeaturePhaseRankShardContext(final String field) { + this.field = field; + } + + public String getField() { + return field; + } + + /** + * This is used to fetch the feature data for a given set of documents, using the {@link org.elasticsearch.search.fetch.FetchPhase} + * and the {@link org.elasticsearch.search.fetch.subphase.FetchFieldsPhase} subphase. + * The feature data is then stored in a {@link org.elasticsearch.search.rank.feature.RankFeatureDoc} and passed back to the coordinator. + */ + @Nullable + public abstract RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId); +} diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java new file mode 100644 index 0000000000000..8eb3f2fc8339b --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.rank.feature; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.search.rank.RankDoc; + +import java.io.IOException; +import java.util.Objects; + +/** + * A {@link RankDoc} that contains field data to be used later by the reranker on the coordinator node. + */ +public class RankFeatureDoc extends RankDoc { + + // todo: update to support more than 1 fields; and not restrict to string data + public String featureData; + + public RankFeatureDoc(int doc, float score, int shardIndex) { + super(doc, score, shardIndex); + } + + public RankFeatureDoc(StreamInput in) throws IOException { + super(in); + featureData = in.readOptionalString(); + } + + public void featureData(String featureData) { + this.featureData = featureData; + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeOptionalString(featureData); + } + + @Override + protected boolean doEquals(RankDoc rd) { + RankFeatureDoc other = (RankFeatureDoc) rd; + return Objects.equals(this.featureData, other.featureData); + } + + @Override + protected int doHashCode() { + return Objects.hashCode(featureData); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureResult.java b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureResult.java new file mode 100644 index 0000000000000..1e16d18cda367 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureResult.java @@ -0,0 +1,70 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.rank.feature; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.search.SearchPhaseResult; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.internal.ShardSearchContextId; +import org.elasticsearch.search.internal.ShardSearchRequest; + +import java.io.IOException; + +/** + * The result of a rank feature search phase. + * Each instance holds a {@code RankFeatureShardResult} along with the references associated with it. + */ +public class RankFeatureResult extends SearchPhaseResult { + + private RankFeatureShardResult rankShardResult; + + public RankFeatureResult() {} + + public RankFeatureResult(ShardSearchContextId id, SearchShardTarget shardTarget, ShardSearchRequest request) { + this.contextId = id; + setSearchShardTarget(shardTarget); + setShardSearchRequest(request); + } + + public RankFeatureResult(StreamInput in) throws IOException { + super(in); + contextId = new ShardSearchContextId(in); + rankShardResult = in.readOptionalWriteable(RankFeatureShardResult::new); + setShardSearchRequest(in.readOptionalWriteable(ShardSearchRequest::new)); + setSearchShardTarget(in.readOptionalWriteable(SearchShardTarget::new)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + assert hasReferences(); + contextId.writeTo(out); + out.writeOptionalWriteable(rankShardResult); + out.writeOptionalWriteable(getShardSearchRequest()); + out.writeOptionalWriteable(getSearchShardTarget()); + } + + @Override + public RankFeatureResult rankFeatureResult() { + return this; + } + + public void shardResult(RankFeatureShardResult shardResult) { + this.rankShardResult = shardResult; + } + + public RankFeatureShardResult shardResult() { + return rankShardResult; + } + + @Override + public boolean hasSearchContext() { + return rankShardResult != null; + } +} diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java new file mode 100644 index 0000000000000..727ed4e938cca --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java @@ -0,0 +1,99 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.rank.feature; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.search.SearchContextSourcePrinter; +import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.fetch.FetchSearchResult; +import org.elasticsearch.search.fetch.StoredFieldsContext; +import org.elasticsearch.search.fetch.subphase.FetchFieldsContext; +import org.elasticsearch.search.fetch.subphase.FieldAndFormat; +import org.elasticsearch.search.internal.SearchContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; +import org.elasticsearch.tasks.TaskCancelledException; + +import java.util.Arrays; +import java.util.Collections; + +/** + * The {@code RankFeatureShardPhase} executes the rank feature phase on the shard, iff there is a {@code RankBuilder} that requires it. + * This phase is responsible for reading field data for a set of docids. To do this, it reuses the {@code FetchPhase} to read the required + * fields for all requested documents using the `FetchFieldPhase` sub-phase. + */ +public final class RankFeatureShardPhase { + + private static final Logger logger = LogManager.getLogger(RankFeatureShardPhase.class); + + public static final RankFeatureShardResult EMPTY_RESULT = new RankFeatureShardResult(new RankFeatureDoc[0]); + + public RankFeatureShardPhase() {} + + public void prepareForFetch(SearchContext searchContext, RankFeatureShardRequest request) { + if (logger.isTraceEnabled()) { + logger.trace("{}", new SearchContextSourcePrinter(searchContext)); + } + + if (searchContext.isCancelled()) { + throw new TaskCancelledException("cancelled"); + } + + RankFeaturePhaseRankShardContext rankFeaturePhaseRankShardContext = shardContext(searchContext); + if (rankFeaturePhaseRankShardContext != null) { + assert rankFeaturePhaseRankShardContext.getField() != null : "field must not be null"; + searchContext.fetchFieldsContext( + new FetchFieldsContext(Collections.singletonList(new FieldAndFormat(rankFeaturePhaseRankShardContext.getField(), null))) + ); + searchContext.storedFieldsContext(StoredFieldsContext.fromList(Collections.singletonList(StoredFieldsContext._NONE_))); + searchContext.addFetchResult(); + Arrays.sort(request.getDocIds()); + } + } + + public void processFetch(SearchContext searchContext) { + if (logger.isTraceEnabled()) { + logger.trace("{}", new SearchContextSourcePrinter(searchContext)); + } + + if (searchContext.isCancelled()) { + throw new TaskCancelledException("cancelled"); + } + + RankFeaturePhaseRankShardContext rankFeaturePhaseRankShardContext = searchContext.request().source().rankBuilder() != null + ? searchContext.request().source().rankBuilder().buildRankFeaturePhaseShardContext() + : null; + if (rankFeaturePhaseRankShardContext != null) { + // TODO: here we populate the profile part of the fetchResult as well + // we need to see what info we want to include on the overall profiling section. This is something that is per-shard + // so most likely we will still care about the `FetchFieldPhase` profiling info as we could potentially + // operate on `rank_window_size` instead of just `size` results, so this could be much more expensive. + FetchSearchResult fetchSearchResult = searchContext.fetchResult(); + if (fetchSearchResult == null || fetchSearchResult.hits() == null) { + return; + } + // this cannot be null; as we have either already checked for it, or we would have thrown in + // FetchSearchResult#shardResult() + SearchHits hits = fetchSearchResult.hits(); + RankFeatureShardResult featureRankShardResult = (RankFeatureShardResult) rankFeaturePhaseRankShardContext + .buildRankFeatureShardResult(hits, searchContext.shardTarget().getShardId().id()); + // save the result in the search context + // need to add profiling info as well available from fetch + if (featureRankShardResult != null) { + searchContext.rankFeatureResult().shardResult(featureRankShardResult); + } + } + } + + private RankFeaturePhaseRankShardContext shardContext(SearchContext searchContext) { + return searchContext.request().source() != null && searchContext.request().source().rankBuilder() != null + ? searchContext.request().source().rankBuilder().buildRankFeaturePhaseShardContext() + : null; + } +} diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardRequest.java b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardRequest.java new file mode 100644 index 0000000000000..d487fb63a0102 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardRequest.java @@ -0,0 +1,101 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.rank.feature; + +import org.elasticsearch.action.IndicesRequest; +import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.search.SearchShardTask; +import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.search.internal.ShardSearchContextId; +import org.elasticsearch.search.internal.ShardSearchRequest; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.transport.TransportRequest; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.stream.IntStream; + +/** + * Shard level request for extracting all needed feature for a global reranker + */ + +public class RankFeatureShardRequest extends TransportRequest implements IndicesRequest { + + private final OriginalIndices originalIndices; + private final ShardSearchRequest shardSearchRequest; + + private final ShardSearchContextId contextId; + + private final int[] docIds; + + public RankFeatureShardRequest( + OriginalIndices originalIndices, + ShardSearchContextId contextId, + ShardSearchRequest shardSearchRequest, + List docIds + ) { + this.originalIndices = originalIndices; + this.shardSearchRequest = shardSearchRequest; + this.docIds = docIds.stream().flatMapToInt(IntStream::of).toArray(); + this.contextId = contextId; + } + + public RankFeatureShardRequest(StreamInput in) throws IOException { + super(in); + originalIndices = OriginalIndices.readOriginalIndices(in); + shardSearchRequest = in.readOptionalWriteable(ShardSearchRequest::new); + docIds = in.readIntArray(); + contextId = in.readOptionalWriteable(ShardSearchContextId::new); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + OriginalIndices.writeOriginalIndices(originalIndices, out); + out.writeOptionalWriteable(shardSearchRequest); + out.writeIntArray(docIds); + out.writeOptionalWriteable(contextId); + } + + @Override + public String[] indices() { + if (originalIndices == null) { + return null; + } + return originalIndices.indices(); + } + + @Override + public IndicesOptions indicesOptions() { + if (originalIndices == null) { + return null; + } + return originalIndices.indicesOptions(); + } + + public ShardSearchRequest getShardSearchRequest() { + return shardSearchRequest; + } + + public int[] getDocIds() { + return docIds; + } + + public ShardSearchContextId contextId() { + return contextId; + } + + @Override + public SearchShardTask createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new SearchShardTask(id, type, action, getDescription(), parentTaskId, headers); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardResult.java b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardResult.java new file mode 100644 index 0000000000000..e06b963621c60 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardResult.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.rank.feature; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.search.rank.RankShardResult; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Objects; + +/** + * The result set of {@link RankFeatureDoc} docs for the shard. + */ +public class RankFeatureShardResult implements RankShardResult { + + public final RankFeatureDoc[] rankFeatureDocs; + + public RankFeatureShardResult(RankFeatureDoc[] rankFeatureDocs) { + this.rankFeatureDocs = Objects.requireNonNull(rankFeatureDocs); + } + + public RankFeatureShardResult(StreamInput in) throws IOException { + rankFeatureDocs = in.readArray(RankFeatureDoc::new, RankFeatureDoc[]::new); + } + + @Override + public String getWriteableName() { + return "rank_feature_shard"; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.RANK_FEATURE_PHASE_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeArray(rankFeatureDocs); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RankFeatureShardResult that = (RankFeatureShardResult) o; + return Arrays.equals(rankFeatureDocs, that.rankFeatureDocs); + } + + @Override + public int hashCode() { + return 31 * Arrays.hashCode(rankFeatureDocs); + } + + @Override + public String toString() { + return this.getClass().getSimpleName() + "{rankFeatureDocs=" + Arrays.toString(rankFeatureDocs) + '}'; + } +} diff --git a/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java new file mode 100644 index 0000000000000..9716749562eae --- /dev/null +++ b/server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java @@ -0,0 +1,1170 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ +package org.elasticsearch.action.search; + +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.tests.store.MockDirectoryWrapper; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.UUIDs; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; +import org.elasticsearch.common.document.DocumentField; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.SearchPhaseResult; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.internal.ShardSearchContextId; +import org.elasticsearch.search.query.QuerySearchResult; +import org.elasticsearch.search.rank.RankBuilder; +import org.elasticsearch.search.rank.RankShardResult; +import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext; +import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; +import org.elasticsearch.search.rank.feature.RankFeatureDoc; +import org.elasticsearch.search.rank.feature.RankFeatureResult; +import org.elasticsearch.search.rank.feature.RankFeatureShardRequest; +import org.elasticsearch.search.rank.feature.RankFeatureShardResult; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.InternalAggregationTestCase; +import org.elasticsearch.transport.Transport; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; + +public class RankFeaturePhaseTests extends ESTestCase { + + private static final int DEFAULT_RANK_WINDOW_SIZE = 10; + private static final int DEFAULT_FROM = 0; + private static final int DEFAULT_SIZE = 10; + private static final String DEFAULT_FIELD = "some_field"; + + private final RankBuilder DEFAULT_RANK_BUILDER = rankBuilder( + DEFAULT_RANK_WINDOW_SIZE, + defaultQueryPhaseRankShardContext(new ArrayList<>(), DEFAULT_RANK_WINDOW_SIZE), + defaultQueryPhaseRankCoordinatorContext(DEFAULT_RANK_WINDOW_SIZE), + defaultRankFeaturePhaseRankShardContext(DEFAULT_FIELD), + defaultRankFeaturePhaseRankCoordinatorContext(DEFAULT_SIZE, DEFAULT_FROM, DEFAULT_RANK_WINDOW_SIZE) + ); + + private record ExpectedRankFeatureDoc(int doc, int rank, float score, String featureData) {} + + public void testRankFeaturePhaseWith1Shard() { + // request params used within SearchSourceBuilder and *RankContext classes + AtomicBoolean phaseDone = new AtomicBoolean(false); + final ScoreDoc[][] finalResults = new ScoreDoc[1][1]; + + // create a SearchSource to attach to the request + SearchSourceBuilder searchSourceBuilder = searchSourceWithRankBuilder(DEFAULT_RANK_BUILDER); + + SearchPhaseController controller = searchPhaseController(); + SearchShardTarget shard1Target = new SearchShardTarget("node0", new ShardId("test", "na", 0), null); + + MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1); + mockSearchPhaseContext.getRequest().source(searchSourceBuilder); + try (SearchPhaseResults results = searchPhaseResults(controller, mockSearchPhaseContext)) { + final ShardSearchContextId ctx = new ShardSearchContextId(UUIDs.base64UUID(), 123); + QuerySearchResult queryResult = new QuerySearchResult(ctx, shard1Target, null); + try { + queryResult.setShardIndex(shard1Target.getShardId().getId()); + // generate the QuerySearchResults that the RankFeaturePhase would have received from QueryPhase + // here we have 2 results, with doc ids 1 and 2 + int totalHits = randomIntBetween(2, 100); + final ScoreDoc[] shard1Docs = new ScoreDoc[] { new ScoreDoc(1, 10.0F), new ScoreDoc(2, 9.0F) }; + populateQuerySearchResult(queryResult, totalHits, shard1Docs); + results.consumeResult(queryResult, () -> {}); + // do not make an actual http request, but rather generate the response + // as if we would have read it from the RankFeatureShardPhase + mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) { + @Override + public void sendExecuteRankFeature( + Transport.Connection connection, + final RankFeatureShardRequest request, + SearchTask task, + final SearchActionListener listener + ) { + // make sure to match the context id generated above, otherwise we throw + if (request.contextId().getId() == 123 && Arrays.equals(request.getDocIds(), new int[] { 1, 2 })) { + RankFeatureResult rankFeatureResult = new RankFeatureResult(); + buildRankFeatureResult( + mockSearchPhaseContext.getRequest().source().rankBuilder(), + rankFeatureResult, + shard1Target, + totalHits, + shard1Docs + ); + listener.onResponse(rankFeatureResult); + } else { + listener.onFailure(new MockDirectoryWrapper.FakeIOException()); + } + } + }; + } finally { + queryResult.decRef(); + } + + RankFeaturePhase rankFeaturePhase = rankFeaturePhase(results, mockSearchPhaseContext, finalResults, phaseDone); + try { + rankFeaturePhase.run(); + + mockSearchPhaseContext.assertNoFailure(); + assertTrue(mockSearchPhaseContext.failures.isEmpty()); + assertTrue(phaseDone.get()); + assertTrue(mockSearchPhaseContext.releasedSearchContexts.isEmpty()); + + SearchPhaseResults rankPhaseResults = rankFeaturePhase.rankPhaseResults; + assertNotNull(rankPhaseResults.getAtomicArray()); + assertEquals(1, rankPhaseResults.getAtomicArray().length()); + assertEquals(1, rankPhaseResults.getSuccessfulResults().count()); + + SearchPhaseResult shard1Result = rankPhaseResults.getAtomicArray().get(0); + List expectedShardResults = List.of( + new ExpectedRankFeatureDoc(1, 1, 110.0F, "ranked_1"), + new ExpectedRankFeatureDoc(2, 2, 109.0F, "ranked_2") + ); + List expectedFinalResults = new ArrayList<>(expectedShardResults); + assertShardResults(shard1Result, expectedShardResults); + assertFinalResults(finalResults[0], expectedFinalResults); + } finally { + rankFeaturePhase.rankPhaseResults.close(); + } + } finally { + if (mockSearchPhaseContext.searchResponse.get() != null) { + mockSearchPhaseContext.searchResponse.get().decRef(); + } + } + } + + public void testRankFeaturePhaseWithMultipleShardsOneEmpty() { + AtomicBoolean phaseDone = new AtomicBoolean(false); + final ScoreDoc[][] finalResults = new ScoreDoc[1][1]; + + // create a SearchSource to attach to the request + SearchSourceBuilder searchSourceBuilder = searchSourceWithRankBuilder(DEFAULT_RANK_BUILDER); + + SearchPhaseController controller = searchPhaseController(); + SearchShardTarget shard1Target = new SearchShardTarget("node0", new ShardId("test", "na", 0), null); + SearchShardTarget shard2Target = new SearchShardTarget("node1", new ShardId("test", "na", 1), null); + SearchShardTarget shard3Target = new SearchShardTarget("node2", new ShardId("test", "na", 2), null); + + MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(3); + mockSearchPhaseContext.getRequest().source(searchSourceBuilder); + try (SearchPhaseResults results = searchPhaseResults(controller, mockSearchPhaseContext)) { + // generate the QuerySearchResults that the RankFeaturePhase would have received from QueryPhase + // here we have 2 results, with doc ids 1 and 2 found on shards 0 and 1 respectively + final ShardSearchContextId ctxShard1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); + final ShardSearchContextId ctxShard2 = new ShardSearchContextId(UUIDs.base64UUID(), 456); + final ShardSearchContextId ctxShard3 = new ShardSearchContextId(UUIDs.base64UUID(), 789); + + QuerySearchResult queryResultShard1 = new QuerySearchResult(ctxShard1, shard1Target, null); + QuerySearchResult queryResultShard2 = new QuerySearchResult(ctxShard2, shard2Target, null); + QuerySearchResult queryResultShard3 = new QuerySearchResult(ctxShard3, shard2Target, null); + try { + queryResultShard1.setShardIndex(shard1Target.getShardId().getId()); + queryResultShard2.setShardIndex(shard2Target.getShardId().getId()); + queryResultShard3.setShardIndex(shard3Target.getShardId().getId()); + + final int shard1Results = randomIntBetween(1, 100); + final int shard2Results = randomIntBetween(1, 100); + final int shard3Results = 0; + + final ScoreDoc[] shard1Docs = new ScoreDoc[] { new ScoreDoc(1, 10.0F) }; + populateQuerySearchResult(queryResultShard1, shard1Results, shard1Docs); + final ScoreDoc[] shard2Docs = new ScoreDoc[] { new ScoreDoc(2, 9.0F) }; + populateQuerySearchResult(queryResultShard2, shard2Results, shard2Docs); + final ScoreDoc[] shard3Docs = new ScoreDoc[0]; + populateQuerySearchResult(queryResultShard3, shard3Results, shard3Docs); + + results.consumeResult(queryResultShard2, () -> {}); + results.consumeResult(queryResultShard3, () -> {}); + results.consumeResult(queryResultShard1, () -> {}); + + // do not make an actual http request, but rather generate the response + // as if we would have read it from the RankFeatureShardPhase + mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) { + @Override + public void sendExecuteRankFeature( + Transport.Connection connection, + final RankFeatureShardRequest request, + SearchTask task, + final SearchActionListener listener + ) { + // make sure to match the context id generated above, otherwise we throw + // first shard + RankFeatureResult rankFeatureResult = new RankFeatureResult(); + if (request.contextId().getId() == 123 && Arrays.equals(request.getDocIds(), new int[] { 1 })) { + buildRankFeatureResult( + mockSearchPhaseContext.getRequest().source().rankBuilder(), + rankFeatureResult, + shard1Target, + shard1Results, + shard1Docs + ); + listener.onResponse(rankFeatureResult); + } else if (request.contextId().getId() == 456 && Arrays.equals(request.getDocIds(), new int[] { 2 })) { + // second shard + buildRankFeatureResult( + mockSearchPhaseContext.getRequest().source().rankBuilder(), + rankFeatureResult, + shard2Target, + shard2Results, + shard2Docs + ); + listener.onResponse(rankFeatureResult); + } else if (request.contextId().getId() == 789) { + listener.onResponse(rankFeatureResult); + } else { + listener.onFailure(new MockDirectoryWrapper.FakeIOException()); + } + } + }; + } finally { + queryResultShard1.decRef(); + queryResultShard2.decRef(); + queryResultShard3.decRef(); + } + RankFeaturePhase rankFeaturePhase = rankFeaturePhase(results, mockSearchPhaseContext, finalResults, phaseDone); + try { + rankFeaturePhase.run(); + mockSearchPhaseContext.assertNoFailure(); + assertTrue(mockSearchPhaseContext.failures.isEmpty()); + assertTrue(phaseDone.get()); + SearchPhaseResults rankPhaseResults = rankFeaturePhase.rankPhaseResults; + assertNotNull(rankPhaseResults.getAtomicArray()); + assertEquals(3, rankPhaseResults.getAtomicArray().length()); + // one result is null + assertEquals(2, rankPhaseResults.getSuccessfulResults().count()); + + SearchPhaseResult shard1Result = rankPhaseResults.getAtomicArray().get(0); + List expectedShard1Results = List.of(new ExpectedRankFeatureDoc(1, 1, 110.0F, "ranked_1")); + assertShardResults(shard1Result, expectedShard1Results); + + SearchPhaseResult shard2Result = rankPhaseResults.getAtomicArray().get(1); + List expectedShard2Results = List.of(new ExpectedRankFeatureDoc(2, 1, 109.0F, "ranked_2")); + assertShardResults(shard2Result, expectedShard2Results); + + SearchPhaseResult shard3Result = rankPhaseResults.getAtomicArray().get(2); + assertNull(shard3Result); + + List expectedFinalResults = List.of( + new ExpectedRankFeatureDoc(1, 1, 110.0F, "ranked_1"), + new ExpectedRankFeatureDoc(2, 2, 109.0F, "ranked_2") + ); + assertFinalResults(finalResults[0], expectedFinalResults); + } finally { + rankFeaturePhase.rankPhaseResults.close(); + } + } finally { + if (mockSearchPhaseContext.searchResponse.get() != null) { + mockSearchPhaseContext.searchResponse.get().decRef(); + } + } + } + + public void testRankFeaturePhaseNoNeedForFetchingFieldData() { + AtomicBoolean phaseDone = new AtomicBoolean(false); + final ScoreDoc[][] finalResults = new ScoreDoc[1][1]; + + // build the appropriate RankBuilder; using a null rankFeaturePhaseRankShardContext + // and non-field based rankFeaturePhaseRankCoordinatorContext + RankBuilder rankBuilder = rankBuilder( + DEFAULT_RANK_WINDOW_SIZE, + defaultQueryPhaseRankShardContext(Collections.emptyList(), DEFAULT_RANK_WINDOW_SIZE), + negatingScoresQueryFeaturePhaseRankCoordinatorContext(DEFAULT_SIZE, DEFAULT_FROM, DEFAULT_RANK_WINDOW_SIZE), + null, + null + ); + // create a SearchSource to attach to the request + SearchSourceBuilder searchSourceBuilder = searchSourceWithRankBuilder(rankBuilder); + + SearchPhaseController controller = searchPhaseController(); + SearchShardTarget shard1Target = new SearchShardTarget("node0", new ShardId("test", "na", 0), null); + + MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1); + mockSearchPhaseContext.getRequest().source(searchSourceBuilder); + try (SearchPhaseResults results = searchPhaseResults(controller, mockSearchPhaseContext)) { + // generate the QuerySearchResults that the RankFeaturePhase would have received from QueryPhase + // here we have 2 results, with doc ids 1 and 2 + final ShardSearchContextId ctx = new ShardSearchContextId(UUIDs.base64UUID(), 123); + QuerySearchResult queryResult = new QuerySearchResult(ctx, shard1Target, null); + + try { + queryResult.setShardIndex(shard1Target.getShardId().getId()); + int totalHits = randomIntBetween(2, 100); + final ScoreDoc[] shard1Docs = new ScoreDoc[] { new ScoreDoc(1, 10.0F), new ScoreDoc(2, 9.0F) }; + populateQuerySearchResult(queryResult, totalHits, shard1Docs); + results.consumeResult(queryResult, () -> {}); + // do not make an actual http request, but rather generate the response + // as if we would have read it from the RankFeatureShardPhase + mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) { + @Override + public void sendExecuteRankFeature( + Transport.Connection connection, + final RankFeatureShardRequest request, + SearchTask task, + final SearchActionListener listener + ) { + // make sure to match the context id generated above, otherwise we throw + if (request.contextId().getId() == 123 && Arrays.equals(request.getDocIds(), new int[] { 1, 2 })) { + listener.onFailure(new UnsupportedOperationException("should not have reached here")); + } else { + listener.onFailure(new MockDirectoryWrapper.FakeIOException()); + } + } + }; + } finally { + queryResult.decRef(); + } + // override the RankFeaturePhase to skip moving to next phase + RankFeaturePhase rankFeaturePhase = rankFeaturePhase(results, mockSearchPhaseContext, finalResults, phaseDone); + try { + rankFeaturePhase.run(); + mockSearchPhaseContext.assertNoFailure(); + assertTrue(mockSearchPhaseContext.failures.isEmpty()); + assertTrue(phaseDone.get()); + + // in this case there was no additional "RankFeature" results on shards, so we shortcut directly to queryPhaseResults + SearchPhaseResults rankPhaseResults = rankFeaturePhase.queryPhaseResults; + assertNotNull(rankPhaseResults.getAtomicArray()); + assertEquals(1, rankPhaseResults.getAtomicArray().length()); + assertEquals(1, rankPhaseResults.getSuccessfulResults().count()); + + SearchPhaseResult shardResult = rankPhaseResults.getAtomicArray().get(0); + assertTrue(shardResult instanceof QuerySearchResult); + QuerySearchResult rankResult = (QuerySearchResult) shardResult; + assertNull(rankResult.rankFeatureResult()); + assertNotNull(rankResult.queryResult()); + + List expectedFinalResults = List.of( + new ExpectedRankFeatureDoc(2, 1, -9.0F, null), + new ExpectedRankFeatureDoc(1, 2, -10.0F, null) + ); + assertFinalResults(finalResults[0], expectedFinalResults); + } finally { + rankFeaturePhase.rankPhaseResults.close(); + } + } finally { + if (mockSearchPhaseContext.searchResponse.get() != null) { + mockSearchPhaseContext.searchResponse.get().decRef(); + } + } + } + + public void testRankFeaturePhaseOneShardFails() { + AtomicBoolean phaseDone = new AtomicBoolean(false); + final ScoreDoc[][] finalResults = new ScoreDoc[1][1]; + + // create a SearchSource to attach to the request + SearchSourceBuilder searchSourceBuilder = searchSourceWithRankBuilder(DEFAULT_RANK_BUILDER); + + SearchPhaseController controller = searchPhaseController(); + SearchShardTarget shard1Target = new SearchShardTarget("node0", new ShardId("test", "na", 0), null); + SearchShardTarget shard2Target = new SearchShardTarget("node1", new ShardId("test", "na", 1), null); + + MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); + mockSearchPhaseContext.getRequest().source(searchSourceBuilder); + try (SearchPhaseResults results = searchPhaseResults(controller, mockSearchPhaseContext)) { + // generate the QuerySearchResults that the RankFeaturePhase would have received from QueryPhase + // here we have 2 results, with doc ids 1 and 2 found on shards 0 and 1 respectively + final ShardSearchContextId ctxShard1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); + final ShardSearchContextId ctxShard2 = new ShardSearchContextId(UUIDs.base64UUID(), 456); + + QuerySearchResult queryResultShard1 = new QuerySearchResult(ctxShard1, shard1Target, null); + QuerySearchResult queryResultShard2 = new QuerySearchResult(ctxShard2, shard2Target, null); + try { + queryResultShard1.setShardIndex(shard1Target.getShardId().getId()); + queryResultShard2.setShardIndex(shard2Target.getShardId().getId()); + + final int shard1Results = randomIntBetween(1, 100); + final ScoreDoc[] shard1Docs = new ScoreDoc[] { new ScoreDoc(1, 10.0F) }; + populateQuerySearchResult(queryResultShard1, shard1Results, shard1Docs); + + final int shard2Results = randomIntBetween(1, 100); + final ScoreDoc[] shard2Docs = new ScoreDoc[] { new ScoreDoc(2, 9.0F) }; + populateQuerySearchResult(queryResultShard2, shard2Results, shard2Docs); + + results.consumeResult(queryResultShard2, () -> {}); + results.consumeResult(queryResultShard1, () -> {}); + + // do not make an actual http request, but rather generate the response + // as if we would have read it from the RankFeatureShardPhase + mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) { + @Override + public void sendExecuteRankFeature( + Transport.Connection connection, + final RankFeatureShardRequest request, + SearchTask task, + final SearchActionListener listener + ) { + // make sure to match the context id generated above, otherwise we throw + // first shard + if (request.contextId().getId() == 456 && Arrays.equals(request.getDocIds(), new int[] { 2 })) { + RankFeatureResult rankFeatureResult = new RankFeatureResult(); + buildRankFeatureResult( + mockSearchPhaseContext.getRequest().source().rankBuilder(), + rankFeatureResult, + shard2Target, + shard2Results, + shard2Docs + ); + listener.onResponse(rankFeatureResult); + + } else if (request.contextId().getId() == 123 && Arrays.equals(request.getDocIds(), new int[] { 1 })) { + // other shard; this one throws an exception + listener.onFailure(new IllegalArgumentException("simulated failure")); + } else { + listener.onFailure(new MockDirectoryWrapper.FakeIOException()); + } + } + }; + } finally { + queryResultShard1.decRef(); + queryResultShard2.decRef(); + } + RankFeaturePhase rankFeaturePhase = rankFeaturePhase(results, mockSearchPhaseContext, finalResults, phaseDone); + try { + rankFeaturePhase.run(); + + mockSearchPhaseContext.assertNoFailure(); + assertEquals(1, mockSearchPhaseContext.failures.size()); + assertTrue(mockSearchPhaseContext.failures.get(0).getCause().getMessage().contains("simulated failure")); + assertTrue(phaseDone.get()); + + SearchPhaseResults rankPhaseResults = rankFeaturePhase.rankPhaseResults; + assertNotNull(rankPhaseResults.getAtomicArray()); + assertEquals(2, rankPhaseResults.getAtomicArray().length()); + // one shard failed + assertEquals(1, rankPhaseResults.getSuccessfulResults().count()); + + SearchPhaseResult shard1Result = rankPhaseResults.getAtomicArray().get(0); + assertNull(shard1Result); + + SearchPhaseResult shard2Result = rankPhaseResults.getAtomicArray().get(1); + List expectedShard2Results = List.of(new ExpectedRankFeatureDoc(2, 1, 109.0F, "ranked_2")); + List expectedFinalResults = new ArrayList<>(expectedShard2Results); + assertShardResults(shard2Result, expectedShard2Results); + assertFinalResults(finalResults[0], expectedFinalResults); + } finally { + rankFeaturePhase.rankPhaseResults.close(); + } + } finally { + if (mockSearchPhaseContext.searchResponse.get() != null) { + mockSearchPhaseContext.searchResponse.get().decRef(); + } + } + } + + public void testRankFeaturePhaseExceptionThrownOnPhase() { + AtomicBoolean phaseDone = new AtomicBoolean(false); + final ScoreDoc[][] finalResults = new ScoreDoc[1][1]; + + // create a SearchSource to attach to the request + SearchSourceBuilder searchSourceBuilder = searchSourceWithRankBuilder(DEFAULT_RANK_BUILDER); + + SearchPhaseController controller = searchPhaseController(); + SearchShardTarget shard1Target = new SearchShardTarget("node0", new ShardId("test", "na", 0), null); + + MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1); + mockSearchPhaseContext.getRequest().source(searchSourceBuilder); + try (SearchPhaseResults results = searchPhaseResults(controller, mockSearchPhaseContext)) { + // generate the QuerySearchResults that the RankFeaturePhase would have received from QueryPhase + // here we have 2 results, with doc ids 1 and 2 + final ShardSearchContextId ctx = new ShardSearchContextId(UUIDs.base64UUID(), 123); + QuerySearchResult queryResult = new QuerySearchResult(ctx, shard1Target, null); + try { + queryResult.setShardIndex(shard1Target.getShardId().getId()); + int totalHits = randomIntBetween(2, 100); + final ScoreDoc[] shard1Docs = new ScoreDoc[] { new ScoreDoc(1, 10.0F), new ScoreDoc(2, 9.0F) }; + populateQuerySearchResult(queryResult, totalHits, shard1Docs); + results.consumeResult(queryResult, () -> {}); + + // do not make an actual http request, but rather generate the response + // as if we would have read it from the RankFeatureShardPhase + mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) { + @Override + public void sendExecuteRankFeature( + Transport.Connection connection, + final RankFeatureShardRequest request, + SearchTask task, + final SearchActionListener listener + ) { + // make sure to match the context id generated above, otherwise we throw + if (request.contextId().getId() == 123 && Arrays.equals(request.getDocIds(), new int[] { 1, 2 })) { + RankFeatureResult rankFeatureResult = new RankFeatureResult(); + buildRankFeatureResult( + mockSearchPhaseContext.getRequest().source().rankBuilder(), + rankFeatureResult, + shard1Target, + totalHits, + shard1Docs + ); + listener.onResponse(rankFeatureResult); + } else { + listener.onFailure(new MockDirectoryWrapper.FakeIOException()); + } + } + }; + } finally { + queryResult.decRef(); + } + // override the RankFeaturePhase to raise an exception + RankFeaturePhase rankFeaturePhase = new RankFeaturePhase(results, null, mockSearchPhaseContext) { + @Override + void innerRun() { + throw new IllegalArgumentException("simulated failure"); + } + + @Override + public void moveToNextPhase( + SearchPhaseResults phaseResults, + SearchPhaseController.ReducedQueryPhase reducedQueryPhase + ) { + // this is called after the RankFeaturePhaseCoordinatorContext has been executed + phaseDone.set(true); + finalResults[0] = reducedQueryPhase.sortedTopDocs().scoreDocs(); + logger.debug("Skipping moving to next phase"); + } + }; + assertEquals("rank-feature", rankFeaturePhase.getName()); + try { + rankFeaturePhase.run(); + assertNotNull(mockSearchPhaseContext.phaseFailure.get()); + assertTrue(mockSearchPhaseContext.phaseFailure.get().getMessage().contains("simulated failure")); + assertTrue(mockSearchPhaseContext.failures.isEmpty()); + assertFalse(phaseDone.get()); + assertTrue(rankFeaturePhase.rankPhaseResults.getAtomicArray().asList().isEmpty()); + assertNull(finalResults[0][0]); + } finally { + rankFeaturePhase.rankPhaseResults.close(); + } + } finally { + if (mockSearchPhaseContext.searchResponse.get() != null) { + mockSearchPhaseContext.searchResponse.get().decRef(); + } + } + } + + public void testRankFeatureWithPagination() { + // request params used within SearchSourceBuilder and *RankContext classes + final int from = 1; + final int size = 1; + AtomicBoolean phaseDone = new AtomicBoolean(false); + final ScoreDoc[][] finalResults = new ScoreDoc[1][1]; + + // build the appropriate RankBuilder + RankBuilder rankBuilder = rankBuilder( + DEFAULT_RANK_WINDOW_SIZE, + defaultQueryPhaseRankShardContext(Collections.emptyList(), DEFAULT_RANK_WINDOW_SIZE), + defaultQueryPhaseRankCoordinatorContext(DEFAULT_RANK_WINDOW_SIZE), + defaultRankFeaturePhaseRankShardContext(DEFAULT_FIELD), + defaultRankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) + ); + // create a SearchSource to attach to the request + SearchSourceBuilder searchSourceBuilder = searchSourceWithRankBuilder(rankBuilder); + + SearchPhaseController controller = searchPhaseController(); + SearchShardTarget shard1Target = new SearchShardTarget("node0", new ShardId("test", "na", 0), null); + SearchShardTarget shard2Target = new SearchShardTarget("node1", new ShardId("test", "na", 1), null); + SearchShardTarget shard3Target = new SearchShardTarget("node2", new ShardId("test", "na", 2), null); + + MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(3); + mockSearchPhaseContext.getRequest().source(searchSourceBuilder); + try (SearchPhaseResults results = searchPhaseResults(controller, mockSearchPhaseContext)) { + // generate the QuerySearchResults that the RankFeaturePhase would have received from QueryPhase + // here we have 4 results, with doc ids 1 and (11, 2, 200) found on shards 0 and 1 respectively + final ShardSearchContextId ctxShard1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); + final ShardSearchContextId ctxShard2 = new ShardSearchContextId(UUIDs.base64UUID(), 456); + final ShardSearchContextId ctxShard3 = new ShardSearchContextId(UUIDs.base64UUID(), 789); + + QuerySearchResult queryResultShard1 = new QuerySearchResult(ctxShard1, shard1Target, null); + QuerySearchResult queryResultShard2 = new QuerySearchResult(ctxShard2, shard2Target, null); + QuerySearchResult queryResultShard3 = new QuerySearchResult(ctxShard3, shard2Target, null); + + try { + queryResultShard1.setShardIndex(shard1Target.getShardId().getId()); + queryResultShard2.setShardIndex(shard2Target.getShardId().getId()); + queryResultShard3.setShardIndex(shard3Target.getShardId().getId()); + + final int shard1Results = randomIntBetween(1, 100); + final ScoreDoc[] shard1Docs = new ScoreDoc[] { new ScoreDoc(1, 10.0F) }; + populateQuerySearchResult(queryResultShard1, shard1Results, shard1Docs); + + final int shard2Results = randomIntBetween(1, 100); + final ScoreDoc[] shard2Docs = new ScoreDoc[] { + new ScoreDoc(11, 100.0F, -1), + new ScoreDoc(2, 9.0F), + new ScoreDoc(200, 1F, -1) }; + populateQuerySearchResult(queryResultShard2, shard2Results, shard2Docs); + + final int shard3Results = 0; + final ScoreDoc[] shard3Docs = new ScoreDoc[0]; + populateQuerySearchResult(queryResultShard3, shard3Results, shard3Docs); + + results.consumeResult(queryResultShard2, () -> {}); + results.consumeResult(queryResultShard3, () -> {}); + results.consumeResult(queryResultShard1, () -> {}); + + // do not make an actual http request, but rather generate the response + // as if we would have read it from the RankFeatureShardPhase + mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) { + @Override + public void sendExecuteRankFeature( + Transport.Connection connection, + final RankFeatureShardRequest request, + SearchTask task, + final SearchActionListener listener + ) { + + RankFeatureResult rankFeatureResult = new RankFeatureResult(); + // make sure to match the context id generated above, otherwise we throw + // first shard + if (request.contextId().getId() == 123 && Arrays.equals(request.getDocIds(), new int[] { 1 })) { + buildRankFeatureResult( + mockSearchPhaseContext.getRequest().source().rankBuilder(), + rankFeatureResult, + shard1Target, + shard1Results, + shard1Docs + ); + listener.onResponse(rankFeatureResult); + } else if (request.contextId().getId() == 456 && Arrays.equals(request.getDocIds(), new int[] { 11, 2, 200 })) { + // second shard + + buildRankFeatureResult( + mockSearchPhaseContext.getRequest().source().rankBuilder(), + rankFeatureResult, + shard2Target, + shard2Results, + shard2Docs + ); + listener.onResponse(rankFeatureResult); + } else { + listener.onFailure(new MockDirectoryWrapper.FakeIOException()); + } + + } + }; + } finally { + queryResultShard1.decRef(); + queryResultShard2.decRef(); + queryResultShard3.decRef(); + } + RankFeaturePhase rankFeaturePhase = rankFeaturePhase(results, mockSearchPhaseContext, finalResults, phaseDone); + try { + rankFeaturePhase.run(); + + mockSearchPhaseContext.assertNoFailure(); + assertTrue(mockSearchPhaseContext.failures.isEmpty()); + assertTrue(phaseDone.get()); + SearchPhaseResults rankPhaseResults = rankFeaturePhase.rankPhaseResults; + assertNotNull(rankPhaseResults.getAtomicArray()); + assertEquals(3, rankPhaseResults.getAtomicArray().length()); + // one result is null + assertEquals(2, rankPhaseResults.getSuccessfulResults().count()); + + SearchPhaseResult shard1Result = rankPhaseResults.getAtomicArray().get(0); + List expectedShard1Results = List.of(new ExpectedRankFeatureDoc(1, 1, 110.0F, "ranked_1")); + assertShardResults(shard1Result, expectedShard1Results); + + SearchPhaseResult shard2Result = rankPhaseResults.getAtomicArray().get(1); + List expectedShard2Results = List.of( + new ExpectedRankFeatureDoc(11, 1, 200.0F, "ranked_11"), + new ExpectedRankFeatureDoc(2, 2, 109.0F, "ranked_2"), + new ExpectedRankFeatureDoc(200, 3, 101.0F, "ranked_200") + + ); + assertShardResults(shard2Result, expectedShard2Results); + + SearchPhaseResult shard3Result = rankPhaseResults.getAtomicArray().get(2); + assertNull(shard3Result); + + List expectedFinalResults = List.of(new ExpectedRankFeatureDoc(1, 2, 110.0F, "ranked_1")); + assertFinalResults(finalResults[0], expectedFinalResults); + } finally { + rankFeaturePhase.rankPhaseResults.close(); + } + } finally { + if (mockSearchPhaseContext.searchResponse.get() != null) { + mockSearchPhaseContext.searchResponse.get().decRef(); + } + } + } + + public void testRankFeatureCollectOnlyRankWindowSizeFeatures() { + // request params used within SearchSourceBuilder and *RankContext classes + final int rankWindowSize = 2; + AtomicBoolean phaseDone = new AtomicBoolean(false); + final ScoreDoc[][] finalResults = new ScoreDoc[1][1]; + + // build the appropriate RankBuilder + RankBuilder rankBuilder = rankBuilder( + rankWindowSize, + defaultQueryPhaseRankShardContext(Collections.emptyList(), rankWindowSize), + defaultQueryPhaseRankCoordinatorContext(rankWindowSize), + defaultRankFeaturePhaseRankShardContext(DEFAULT_FIELD), + defaultRankFeaturePhaseRankCoordinatorContext(DEFAULT_SIZE, DEFAULT_FROM, rankWindowSize) + ); + // create a SearchSource to attach to the request + SearchSourceBuilder searchSourceBuilder = searchSourceWithRankBuilder(rankBuilder); + + SearchPhaseController controller = searchPhaseController(); + SearchShardTarget shard1Target = new SearchShardTarget("node0", new ShardId("test", "na", 0), null); + SearchShardTarget shard2Target = new SearchShardTarget("node1", new ShardId("test", "na", 1), null); + SearchShardTarget shard3Target = new SearchShardTarget("node2", new ShardId("test", "na", 2), null); + + MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(3); + mockSearchPhaseContext.getRequest().source(searchSourceBuilder); + try (SearchPhaseResults results = searchPhaseResults(controller, mockSearchPhaseContext)) { + // generate the QuerySearchResults that the RankFeaturePhase would have received from QueryPhase + // here we have 3 results, with doc ids 1, and (11, 2) found on shards 0 and 1 respectively + final ShardSearchContextId ctxShard1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); + final ShardSearchContextId ctxShard2 = new ShardSearchContextId(UUIDs.base64UUID(), 456); + final ShardSearchContextId ctxShard3 = new ShardSearchContextId(UUIDs.base64UUID(), 789); + + QuerySearchResult queryResultShard1 = new QuerySearchResult(ctxShard1, shard1Target, null); + QuerySearchResult queryResultShard2 = new QuerySearchResult(ctxShard2, shard2Target, null); + QuerySearchResult queryResultShard3 = new QuerySearchResult(ctxShard3, shard2Target, null); + + try { + queryResultShard1.setShardIndex(shard1Target.getShardId().getId()); + queryResultShard2.setShardIndex(shard2Target.getShardId().getId()); + queryResultShard3.setShardIndex(shard3Target.getShardId().getId()); + + final int shard1Results = randomIntBetween(1, 100); + final ScoreDoc[] shard1Docs = new ScoreDoc[] { new ScoreDoc(1, 10.0F) }; + populateQuerySearchResult(queryResultShard1, shard1Results, shard1Docs); + + final int shard2Results = randomIntBetween(1, 100); + final ScoreDoc[] shard2Docs = new ScoreDoc[] { new ScoreDoc(11, 100.0F), new ScoreDoc(2, 9.0F) }; + populateQuerySearchResult(queryResultShard2, shard2Results, shard2Docs); + + final int shard3Results = 0; + final ScoreDoc[] shard3Docs = new ScoreDoc[0]; + populateQuerySearchResult(queryResultShard3, shard3Results, shard3Docs); + + results.consumeResult(queryResultShard2, () -> {}); + results.consumeResult(queryResultShard3, () -> {}); + results.consumeResult(queryResultShard1, () -> {}); + + // do not make an actual http request, but rather generate the response + // as if we would have read it from the RankFeatureShardPhase + mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) { + @Override + public void sendExecuteRankFeature( + Transport.Connection connection, + final RankFeatureShardRequest request, + SearchTask task, + final SearchActionListener listener + ) { + RankFeatureResult rankFeatureResult = new RankFeatureResult(); + // make sure to match the context id generated above, otherwise we throw + // first shard + if (request.contextId().getId() == 123 && Arrays.equals(request.getDocIds(), new int[] { 1 })) { + buildRankFeatureResult( + mockSearchPhaseContext.getRequest().source().rankBuilder(), + rankFeatureResult, + shard1Target, + shard1Results, + shard1Docs + ); + listener.onResponse(rankFeatureResult); + } else if (request.contextId().getId() == 456 && Arrays.equals(request.getDocIds(), new int[] { 11 })) { + // second shard + buildRankFeatureResult( + mockSearchPhaseContext.getRequest().source().rankBuilder(), + rankFeatureResult, + shard2Target, + shard2Results, + new ScoreDoc[] { shard2Docs[0] } + ); + listener.onResponse(rankFeatureResult); + } else { + listener.onFailure(new MockDirectoryWrapper.FakeIOException()); + } + } + }; + } finally { + queryResultShard1.decRef(); + queryResultShard2.decRef(); + queryResultShard3.decRef(); + } + RankFeaturePhase rankFeaturePhase = rankFeaturePhase(results, mockSearchPhaseContext, finalResults, phaseDone); + try { + rankFeaturePhase.run(); + mockSearchPhaseContext.assertNoFailure(); + assertTrue(mockSearchPhaseContext.failures.isEmpty()); + assertTrue(phaseDone.get()); + SearchPhaseResults rankPhaseResults = rankFeaturePhase.rankPhaseResults; + assertNotNull(rankPhaseResults.getAtomicArray()); + assertEquals(3, rankPhaseResults.getAtomicArray().length()); + // one result is null + assertEquals(2, rankPhaseResults.getSuccessfulResults().count()); + + SearchPhaseResult shard1Result = rankPhaseResults.getAtomicArray().get(0); + List expectedShardResults = List.of(new ExpectedRankFeatureDoc(1, 1, 110.0F, "ranked_1")); + assertShardResults(shard1Result, expectedShardResults); + + SearchPhaseResult shard2Result = rankPhaseResults.getAtomicArray().get(1); + List expectedShard2Results = List.of(new ExpectedRankFeatureDoc(11, 1, 200.0F, "ranked_11")); + assertShardResults(shard2Result, expectedShard2Results); + + SearchPhaseResult shard3Result = rankPhaseResults.getAtomicArray().get(2); + assertNull(shard3Result); + + List expectedFinalResults = List.of( + new ExpectedRankFeatureDoc(11, 1, 200.0F, "ranked_11"), + new ExpectedRankFeatureDoc(1, 2, 110.0F, "ranked_1") + ); + assertFinalResults(finalResults[0], expectedFinalResults); + } finally { + rankFeaturePhase.rankPhaseResults.close(); + } + } finally { + if (mockSearchPhaseContext.searchResponse.get() != null) { + mockSearchPhaseContext.searchResponse.get().decRef(); + } + } + } + + private RankFeaturePhaseRankCoordinatorContext defaultRankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize) { + + @Override + protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { + // no-op + // this one is handled directly in rankGlobalResults to create a RankFeatureDoc + // and avoid modifying in-place the ScoreDoc's rank + } + + @Override + public void rankGlobalResults(List rankSearchResults, ActionListener rankListener) { + List features = new ArrayList<>(); + for (RankFeatureResult rankFeatureResult : rankSearchResults) { + RankFeatureShardResult shardResult = rankFeatureResult.shardResult(); + features.addAll(Arrays.stream(shardResult.rankFeatureDocs).toList()); + } + rankListener.onResponse(features.toArray(new RankFeatureDoc[0])); + } + + @Override + public RankFeatureDoc[] rankAndPaginate(RankFeatureDoc[] rankFeatureDocs) { + Arrays.sort(rankFeatureDocs, Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed()); + RankFeatureDoc[] topResults = new RankFeatureDoc[Math.max(0, Math.min(size, rankFeatureDocs.length - from))]; + // perform pagination + for (int rank = 0; rank < topResults.length; ++rank) { + RankFeatureDoc rfd = rankFeatureDocs[from + rank]; + topResults[rank] = new RankFeatureDoc(rfd.doc, rfd.score, rfd.shardIndex); + topResults[rank].rank = from + rank + 1; + } + return topResults; + } + }; + } + + private QueryPhaseRankCoordinatorContext negatingScoresQueryFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize) { + return new QueryPhaseRankCoordinatorContext(rankWindowSize) { + @Override + public ScoreDoc[] rankQueryPhaseResults( + List rankSearchResults, + SearchPhaseController.TopDocsStats topDocsStats + ) { + List docScores = new ArrayList<>(); + for (QuerySearchResult phaseResults : rankSearchResults) { + docScores.addAll(Arrays.asList(phaseResults.topDocs().topDocs.scoreDocs)); + } + ScoreDoc[] sortedDocs = docScores.toArray(new ScoreDoc[0]); + // negating scores + Arrays.stream(sortedDocs).forEach(doc -> doc.score *= -1); + + Arrays.sort(sortedDocs, Comparator.comparing((ScoreDoc doc) -> doc.score).reversed()); + sortedDocs = Arrays.stream(sortedDocs).limit(rankWindowSize).toArray(ScoreDoc[]::new); + RankFeatureDoc[] topResults = new RankFeatureDoc[Math.max(0, Math.min(size, sortedDocs.length - from))]; + // perform pagination + for (int rank = 0; rank < topResults.length; ++rank) { + ScoreDoc base = sortedDocs[from + rank]; + topResults[rank] = new RankFeatureDoc(base.doc, base.score, base.shardIndex); + topResults[rank].rank = from + rank + 1; + } + topDocsStats.fetchHits = topResults.length; + return topResults; + } + }; + } + + private RankFeaturePhaseRankShardContext defaultRankFeaturePhaseRankShardContext(String field) { + return new RankFeaturePhaseRankShardContext(field) { + @Override + public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { + RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; + for (int i = 0; i < hits.getHits().length; i++) { + SearchHit hit = hits.getHits()[i]; + rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); + rankFeatureDocs[i].score += 100f; + rankFeatureDocs[i].featureData("ranked_" + hit.docId()); + rankFeatureDocs[i].rank = i + 1; + } + return new RankFeatureShardResult(rankFeatureDocs); + } + }; + } + + private QueryPhaseRankCoordinatorContext defaultQueryPhaseRankCoordinatorContext(int rankWindowSize) { + return new QueryPhaseRankCoordinatorContext(rankWindowSize) { + @Override + public ScoreDoc[] rankQueryPhaseResults( + List querySearchResults, + SearchPhaseController.TopDocsStats topDocStats + ) { + List rankDocs = new ArrayList<>(); + for (int i = 0; i < querySearchResults.size(); i++) { + QuerySearchResult querySearchResult = querySearchResults.get(i); + RankFeatureShardResult shardResult = (RankFeatureShardResult) querySearchResult.getRankShardResult(); + for (RankFeatureDoc frd : shardResult.rankFeatureDocs) { + frd.shardIndex = i; + rankDocs.add(frd); + } + } + rankDocs.sort(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed()); + RankFeatureDoc[] topResults = rankDocs.stream().limit(rankWindowSize).toArray(RankFeatureDoc[]::new); + topDocStats.fetchHits = topResults.length; + return topResults; + } + }; + } + + private QueryPhaseRankShardContext defaultQueryPhaseRankShardContext(List queries, int rankWindowSize) { + return new QueryPhaseRankShardContext(queries, rankWindowSize) { + @Override + public RankShardResult combineQueryPhaseResults(List rankResults) { + throw new UnsupportedOperationException( + "shard-level QueryPhase context should not be accessed as part of the RankFeature phase" + ); + } + }; + } + + private SearchPhaseController searchPhaseController() { + return new SearchPhaseController((task, request) -> InternalAggregationTestCase.emptyReduceContextBuilder()); + } + + private RankBuilder rankBuilder( + int rankWindowSize, + QueryPhaseRankShardContext queryPhaseRankShardContext, + QueryPhaseRankCoordinatorContext queryPhaseRankCoordinatorContext, + RankFeaturePhaseRankShardContext rankFeaturePhaseRankShardContext, + RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext + ) { + return new RankBuilder(rankWindowSize) { + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + // no-op + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + // no-op + } + + @Override + public boolean isCompoundBuilder() { + return true; + } + + @Override + public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { + return queryPhaseRankShardContext; + } + + @Override + public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { + return queryPhaseRankCoordinatorContext; + } + + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + return rankFeaturePhaseRankShardContext; + } + + @Override + public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) { + return rankFeaturePhaseRankCoordinatorContext; + } + + @Override + protected boolean doEquals(RankBuilder other) { + return other != null && other.rankWindowSize() == rankWindowSize; + } + + @Override + protected int doHashCode() { + return 0; + } + + @Override + public String getWriteableName() { + return "test-rank-builder"; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.V_8_12_0; + } + }; + } + + private SearchSourceBuilder searchSourceWithRankBuilder(RankBuilder rankBuilder) { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.rankBuilder(rankBuilder); + return searchSourceBuilder; + } + + private SearchPhaseResults searchPhaseResults( + SearchPhaseController controller, + MockSearchPhaseContext mockSearchPhaseContext + ) { + return controller.newSearchPhaseResults( + EsExecutors.DIRECT_EXECUTOR_SERVICE, + new NoopCircuitBreaker(CircuitBreaker.REQUEST), + () -> false, + SearchProgressListener.NOOP, + mockSearchPhaseContext.getRequest(), + mockSearchPhaseContext.numShards, + exc -> {} + ); + } + + private void buildRankFeatureResult( + RankBuilder shardRankBuilder, + RankFeatureResult rankFeatureResult, + SearchShardTarget shardTarget, + int totalHits, + ScoreDoc[] scoreDocs + ) { + rankFeatureResult.setSearchShardTarget(shardTarget); + // these are the SearchHits generated by the FetchFieldPhase processor + SearchHit[] searchHits = new SearchHit[scoreDocs.length]; + float maxScore = Float.MIN_VALUE; + for (int i = 0; i < searchHits.length; i++) { + searchHits[i] = SearchHit.unpooled(scoreDocs[i].doc); + searchHits[i].shard(shardTarget); + searchHits[i].score(scoreDocs[i].score); + searchHits[i].setDocumentField(DEFAULT_FIELD, new DocumentField(DEFAULT_FIELD, Collections.singletonList(scoreDocs[i].doc))); + if (scoreDocs[i].score > maxScore) { + maxScore = scoreDocs[i].score; + } + } + SearchHits hits = null; + try { + hits = SearchHits.unpooled(searchHits, new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), maxScore); + // construct the appropriate RankFeatureDoc objects based on the rank builder + RankFeaturePhaseRankShardContext rankFeaturePhaseRankShardContext = shardRankBuilder.buildRankFeaturePhaseShardContext(); + RankFeatureShardResult rankShardResult = (RankFeatureShardResult) rankFeaturePhaseRankShardContext.buildRankFeatureShardResult( + hits, + shardTarget.getShardId().id() + ); + rankFeatureResult.shardResult(rankShardResult); + } finally { + if (hits != null) { + hits.decRef(); + } + } + } + + private void populateQuerySearchResult(QuerySearchResult queryResult, int totalHits, ScoreDoc[] scoreDocs) { + // this would have been populated during the QueryPhase by the appropriate QueryPhaseShardContext + float maxScore = Float.MIN_VALUE; + RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[scoreDocs.length]; + for (int i = 0; i < scoreDocs.length; i++) { + if (scoreDocs[i].score > maxScore) { + maxScore = scoreDocs[i].score; + } + rankFeatureDocs[i] = new RankFeatureDoc(scoreDocs[i].doc, scoreDocs[i].score, scoreDocs[i].shardIndex); + } + queryResult.setRankShardResult(new RankFeatureShardResult(rankFeatureDocs)); + queryResult.topDocs( + new TopDocsAndMaxScore( + new TopDocs(new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), scoreDocs), + maxScore + + ), + new DocValueFormat[0] + ); + queryResult.size(totalHits); + } + + private RankFeaturePhase rankFeaturePhase( + SearchPhaseResults results, + MockSearchPhaseContext mockSearchPhaseContext, + ScoreDoc[][] finalResults, + AtomicBoolean phaseDone + ) { + // override the RankFeaturePhase to skip moving to next phase + return new RankFeaturePhase(results, null, mockSearchPhaseContext) { + @Override + public void moveToNextPhase( + SearchPhaseResults phaseResults, + SearchPhaseController.ReducedQueryPhase reducedQueryPhase + ) { + // this is called after the RankFeaturePhaseCoordinatorContext has been executed + phaseDone.set(true); + finalResults[0] = reducedQueryPhase.sortedTopDocs().scoreDocs(); + logger.debug("Skipping moving to next phase"); + } + }; + } + + private void assertRankFeatureResults(RankFeatureShardResult rankFeatureShardResult, List expectedResults) { + assertEquals(expectedResults.size(), rankFeatureShardResult.rankFeatureDocs.length); + for (int i = 0; i < expectedResults.size(); i++) { + ExpectedRankFeatureDoc expected = expectedResults.get(i); + RankFeatureDoc actual = rankFeatureShardResult.rankFeatureDocs[i]; + assertEquals(expected.doc, actual.doc); + assertEquals(expected.rank, actual.rank); + assertEquals(expected.score, actual.score, 10E-5); + assertEquals(expected.featureData, actual.featureData); + } + } + + private void assertFinalResults(ScoreDoc[] finalResults, List expectedResults) { + assertEquals(expectedResults.size(), finalResults.length); + for (int i = 0; i < expectedResults.size(); i++) { + ExpectedRankFeatureDoc expected = expectedResults.get(i); + RankFeatureDoc actual = (RankFeatureDoc) finalResults[i]; + assertEquals(expected.doc, actual.doc); + assertEquals(expected.rank, actual.rank); + assertEquals(expected.score, actual.score, 10E-5); + } + } + + private void assertShardResults(SearchPhaseResult shardResult, List expectedShardResults) { + assertTrue(shardResult instanceof RankFeatureResult); + RankFeatureResult rankResult = (RankFeatureResult) shardResult; + assertNotNull(rankResult.rankFeatureResult()); + assertNull(rankResult.queryResult()); + assertNotNull(rankResult.rankFeatureResult().shardResult()); + RankFeatureShardResult rankFeatureShardResult = rankResult.rankFeatureResult().shardResult(); + assertRankFeatureResults(rankFeatureShardResult, expectedShardResults); + } +} diff --git a/server/src/test/java/org/elasticsearch/search/DefaultSearchContextTests.java b/server/src/test/java/org/elasticsearch/search/DefaultSearchContextTests.java index 59acb227385f6..4d58471f4817a 100644 --- a/server/src/test/java/org/elasticsearch/search/DefaultSearchContextTests.java +++ b/server/src/test/java/org/elasticsearch/search/DefaultSearchContextTests.java @@ -644,8 +644,8 @@ public void testIsParallelCollectionSupportedForResults() { ToLongFunction fieldCardinality = name -> -1; for (var resultsType : SearchService.ResultsType.values()) { switch (resultsType) { - case NONE, FETCH -> assertFalse( - "NONE and FETCH phases do not support parallel collection.", + case NONE, RANK_FEATURE, FETCH -> assertFalse( + "NONE, RANK_FEATURE, and FETCH phases do not support parallel collection.", DefaultSearchContext.isParallelCollectionSupportedForResults( resultsType, searchSourceBuilderOrNull, diff --git a/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java b/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java index d2c6c55634ec6..2af20a6ffef4a 100644 --- a/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java +++ b/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java @@ -13,6 +13,8 @@ import org.apache.lucene.index.Term; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHitCountCollectorManager; import org.apache.lucene.store.AlreadyClosedException; import org.apache.lucene.util.SetOnce; @@ -27,6 +29,7 @@ import org.elasticsearch.action.search.ClearScrollRequest; import org.elasticsearch.action.search.ClosePointInTimeRequest; import org.elasticsearch.action.search.OpenPointInTimeRequest; +import org.elasticsearch.action.search.SearchPhaseController; import org.elasticsearch.action.search.SearchPhaseExecutionException; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; @@ -92,6 +95,7 @@ import org.elasticsearch.search.dfs.AggregatedDfs; import org.elasticsearch.search.fetch.FetchSearchResult; import org.elasticsearch.search.fetch.ShardFetchRequest; +import org.elasticsearch.search.fetch.ShardFetchSearchRequest; import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.search.internal.ContextIndexSearcher; @@ -102,12 +106,26 @@ import org.elasticsearch.search.query.NonCountingTermQuery; import org.elasticsearch.search.query.QuerySearchRequest; import org.elasticsearch.search.query.QuerySearchResult; +import org.elasticsearch.search.rank.RankBuilder; +import org.elasticsearch.search.rank.RankShardResult; +import org.elasticsearch.search.rank.TestRankBuilder; +import org.elasticsearch.search.rank.TestRankDoc; +import org.elasticsearch.search.rank.TestRankShardResult; +import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext; +import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; +import org.elasticsearch.search.rank.feature.RankFeatureDoc; +import org.elasticsearch.search.rank.feature.RankFeatureResult; +import org.elasticsearch.search.rank.feature.RankFeatureShardRequest; +import org.elasticsearch.search.rank.feature.RankFeatureShardResult; import org.elasticsearch.search.slice.SliceBuilder; import org.elasticsearch.search.suggest.SuggestBuilder; import org.elasticsearch.tasks.TaskCancelHelper; import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.test.hamcrest.ElasticsearchAssertions; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.json.JsonXContent; @@ -115,8 +133,10 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.Comparator; import java.util.LinkedList; import java.util.List; import java.util.Locale; @@ -136,8 +156,8 @@ import static java.util.Collections.emptyMap; import static java.util.Collections.singletonList; import static org.elasticsearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; -import static org.elasticsearch.indices.cluster.AbstractIndicesClusterStateServiceTestCase.awaitIndexShardCloseAsyncTasks; import static org.elasticsearch.indices.cluster.IndicesClusterStateService.AllocatedIndices.IndexRemovalReason.DELETED; +import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; import static org.elasticsearch.search.SearchService.QUERY_PHASE_PARALLEL_COLLECTION_ENABLED; import static org.elasticsearch.search.SearchService.SEARCH_WORKER_THREADS_ENABLED; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; @@ -371,7 +391,7 @@ public void testSearchWhileIndexDeleted() throws InterruptedException { -1, null ), - new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()), + new SearchShardTask(123L, "", "", "", null, emptyMap()), result.delegateFailure((l, r) -> { r.incRef(); l.onResponse(r); @@ -387,7 +407,7 @@ public void testSearchWhileIndexDeleted() throws InterruptedException { null/* not a scroll */ ); PlainActionFuture listener = new PlainActionFuture<>(); - service.executeFetchPhase(req, new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()), listener); + service.executeFetchPhase(req, new SearchShardTask(123L, "", "", "", null, emptyMap()), listener); listener.get(); if (useScroll) { // have to free context since this test does not remove the index from IndicesService. @@ -422,6 +442,711 @@ public void testSearchWhileIndexDeleted() throws InterruptedException { assertEquals(0, totalStats.getFetchCurrent()); } + public void testRankFeaturePhaseSearchPhases() throws InterruptedException, ExecutionException { + final String indexName = "index"; + final String rankFeatureFieldName = "field"; + final String searchFieldName = "search_field"; + final String searchFieldValue = "some_value"; + final String fetchFieldName = "fetch_field"; + final String fetchFieldValue = "fetch_value"; + + final int minDocs = 3; + final int maxDocs = 10; + int numDocs = between(minDocs, maxDocs); + createIndex(indexName); + // index some documents + for (int i = 0; i < numDocs; i++) { + prepareIndex(indexName).setId(String.valueOf(i)) + .setSource( + rankFeatureFieldName, + "aardvark_" + i, + searchFieldName, + searchFieldValue, + fetchFieldName, + fetchFieldValue + "_" + i + ) + .get(); + } + indicesAdmin().prepareRefresh(indexName).get(); + + final SearchService service = getInstanceFromNode(SearchService.class); + + final IndicesService indicesService = getInstanceFromNode(IndicesService.class); + final IndexService indexService = indicesService.indexServiceSafe(resolveIndex(indexName)); + final IndexShard indexShard = indexService.getShard(0); + SearchShardTask searchTask = new SearchShardTask(123L, "", "", "", null, emptyMap()); + + // create a SearchRequest that will return all documents and defines a TestRankBuilder with shard-level only operations + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true) + .source( + new SearchSourceBuilder().query(new TermQueryBuilder(searchFieldName, searchFieldValue)) + .size(DEFAULT_SIZE) + .fetchField(fetchFieldName) + .rankBuilder( + // here we override only the shard-level contexts + new TestRankBuilder(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + @Override + public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { + return new QueryPhaseRankShardContext(queries, from) { + + @Override + public int rankWindowSize() { + return DEFAULT_RANK_WINDOW_SIZE; + } + + @Override + public RankShardResult combineQueryPhaseResults(List rankResults) { + // we know we have just 1 query, so return all the docs from it + return new TestRankShardResult( + Arrays.stream(rankResults.get(0).scoreDocs) + .map(x -> new TestRankDoc(x.doc, x.score, x.shardIndex)) + .limit(rankWindowSize()) + .toArray(TestRankDoc[]::new) + ); + } + }; + } + + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) { + @Override + public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { + RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; + for (int i = 0; i < hits.getHits().length; i++) { + SearchHit hit = hits.getHits()[i]; + rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); + rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue()); + rankFeatureDocs[i].score = (numDocs - i) + randomFloat(); + rankFeatureDocs[i].rank = i + 1; + } + return new RankFeatureShardResult(rankFeatureDocs); + } + }; + } + } + ) + ); + + ShardSearchRequest request = new ShardSearchRequest( + OriginalIndices.NONE, + searchRequest, + indexShard.shardId(), + 0, + 1, + AliasFilter.EMPTY, + 1.0f, + -1, + null + ); + QuerySearchResult queryResult = null; + RankFeatureResult rankResult = null; + try { + // Execute the query phase and store the result in a SearchPhaseResult container using a PlainActionFuture + PlainActionFuture queryPhaseResults = new PlainActionFuture<>(); + service.executeQueryPhase(request, searchTask, queryPhaseResults); + queryResult = (QuerySearchResult) queryPhaseResults.get(); + + // these are the matched docs from the query phase + final TestRankDoc[] queryRankDocs = ((TestRankShardResult) queryResult.getRankShardResult()).testRankDocs; + + // assume that we have cut down to these from the coordinator node as the top-docs to run the rank feature phase upon + List topRankWindowSizeDocs = randomNonEmptySubsetOf(Arrays.stream(queryRankDocs).map(x -> x.doc).toList()); + + // now we create a RankFeatureShardRequest to extract feature info for the top-docs above + RankFeatureShardRequest rankFeatureShardRequest = new RankFeatureShardRequest( + OriginalIndices.NONE, + queryResult.getContextId(), // use the context from the query phase + request, + topRankWindowSizeDocs + ); + PlainActionFuture rankPhaseResults = new PlainActionFuture<>(); + service.executeRankFeaturePhase(rankFeatureShardRequest, searchTask, rankPhaseResults); + rankResult = rankPhaseResults.get(); + + assertNotNull(rankResult); + assertNotNull(rankResult.rankFeatureResult()); + RankFeatureShardResult rankFeatureShardResult = rankResult.rankFeatureResult().shardResult(); + assertNotNull(rankFeatureShardResult); + + List sortedRankWindowDocs = topRankWindowSizeDocs.stream().sorted().toList(); + assertEquals(sortedRankWindowDocs.size(), rankFeatureShardResult.rankFeatureDocs.length); + for (int i = 0; i < sortedRankWindowDocs.size(); i++) { + assertEquals((long) sortedRankWindowDocs.get(i), rankFeatureShardResult.rankFeatureDocs[i].doc); + assertEquals(rankFeatureShardResult.rankFeatureDocs[i].featureData, "aardvark_" + sortedRankWindowDocs.get(i)); + } + + List globalTopKResults = randomNonEmptySubsetOf( + Arrays.stream(rankFeatureShardResult.rankFeatureDocs).map(x -> x.doc).toList() + ); + + // finally let's create a fetch request to bring back fetch info for the top results + ShardFetchSearchRequest fetchRequest = new ShardFetchSearchRequest( + OriginalIndices.NONE, + rankResult.getContextId(), + request, + globalTopKResults, + null, + rankResult.getRescoreDocIds(), + null + ); + + // execute fetch phase and perform any validations once we retrieve the response + // the difference in how we do assertions here is needed because once the transport service sends back the response + // it decrements the reference to the FetchSearchResult (through the ActionListener#respondAndRelease) and sets hits to null + service.executeFetchPhase(fetchRequest, searchTask, new ActionListener<>() { + @Override + public void onResponse(FetchSearchResult fetchSearchResult) { + assertNotNull(fetchSearchResult); + assertNotNull(fetchSearchResult.hits()); + + int totalHits = fetchSearchResult.hits().getHits().length; + assertEquals(globalTopKResults.size(), totalHits); + for (int i = 0; i < totalHits; i++) { + // rank and score are set by the SearchPhaseController#merge so no need to validate that here + SearchHit hit = fetchSearchResult.hits().getAt(i); + assertNotNull(hit.getFields().get(fetchFieldName)); + assertEquals(hit.getFields().get(fetchFieldName).getValue(), fetchFieldValue + "_" + hit.docId()); + } + } + + @Override + public void onFailure(Exception e) { + throw new AssertionError("No failure should have been raised", e); + } + }); + } catch (Exception ex) { + if (queryResult != null) { + if (queryResult.hasReferences()) { + queryResult.decRef(); + } + service.freeReaderContext(queryResult.getContextId()); + } + if (rankResult != null && rankResult.hasReferences()) { + rankResult.decRef(); + } + throw ex; + } + } + + public void testRankFeaturePhaseUsingClient() { + final String indexName = "index"; + final String rankFeatureFieldName = "field"; + final String searchFieldName = "search_field"; + final String searchFieldValue = "some_value"; + final String fetchFieldName = "fetch_field"; + final String fetchFieldValue = "fetch_value"; + + final int minDocs = 4; + final int maxDocs = 10; + int numDocs = between(minDocs, maxDocs); + createIndex(indexName); + // index some documents + for (int i = 0; i < numDocs; i++) { + prepareIndex(indexName).setId(String.valueOf(i)) + .setSource( + rankFeatureFieldName, + "aardvark_" + i, + searchFieldName, + searchFieldValue, + fetchFieldName, + fetchFieldValue + "_" + i + ) + .get(); + } + indicesAdmin().prepareRefresh(indexName).get(); + + ElasticsearchAssertions.assertResponse( + client().prepareSearch(indexName) + .setSource( + new SearchSourceBuilder().query(new TermQueryBuilder(searchFieldName, searchFieldValue)) + .size(2) + .from(2) + .fetchField(fetchFieldName) + .rankBuilder( + // here we override only the shard-level contexts + new TestRankBuilder(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + + // no need for more than one queries + @Override + public boolean isCompoundBuilder() { + return false; + } + + @Override + public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) { + @Override + protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { + float[] scores = new float[featureDocs.length]; + for (int i = 0; i < featureDocs.length; i++) { + scores[i] = featureDocs[i].score; + } + scoreListener.onResponse(scores); + } + }; + } + + @Override + public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { + return new QueryPhaseRankCoordinatorContext(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + @Override + public ScoreDoc[] rankQueryPhaseResults( + List querySearchResults, + SearchPhaseController.TopDocsStats topDocStats + ) { + List rankDocs = new ArrayList<>(); + for (int i = 0; i < querySearchResults.size(); i++) { + QuerySearchResult querySearchResult = querySearchResults.get(i); + TestRankShardResult shardResult = (TestRankShardResult) querySearchResult + .getRankShardResult(); + for (TestRankDoc trd : shardResult.testRankDocs) { + trd.shardIndex = i; + rankDocs.add(trd); + } + } + rankDocs.sort(Comparator.comparing((TestRankDoc doc) -> doc.score).reversed()); + TestRankDoc[] topResults = rankDocs.stream().limit(rankWindowSize).toArray(TestRankDoc[]::new); + topDocStats.fetchHits = topResults.length; + return topResults; + } + }; + } + + @Override + public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { + return new QueryPhaseRankShardContext(queries, from) { + + @Override + public int rankWindowSize() { + return DEFAULT_RANK_WINDOW_SIZE; + } + + @Override + public RankShardResult combineQueryPhaseResults(List rankResults) { + // we know we have just 1 query, so return all the docs from it + return new TestRankShardResult( + Arrays.stream(rankResults.get(0).scoreDocs) + .map(x -> new TestRankDoc(x.doc, x.score, x.shardIndex)) + .limit(rankWindowSize()) + .toArray(TestRankDoc[]::new) + ); + } + }; + } + + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) { + @Override + public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { + RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; + for (int i = 0; i < hits.getHits().length; i++) { + SearchHit hit = hits.getHits()[i]; + rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); + rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue()); + rankFeatureDocs[i].score = randomFloat(); + rankFeatureDocs[i].rank = i + 1; + } + return new RankFeatureShardResult(rankFeatureDocs); + } + }; + } + } + ) + ), + (response) -> { + SearchHits hits = response.getHits(); + assertEquals(hits.getTotalHits().value, numDocs); + assertEquals(hits.getHits().length, 2); + int index = 0; + for (SearchHit hit : hits.getHits()) { + assertEquals(hit.getRank(), 3 + index); + assertTrue(hit.getScore() >= 0); + assertEquals(hit.getFields().get(fetchFieldName).getValue(), fetchFieldValue + "_" + hit.docId()); + index++; + } + } + ); + } + + public void testRankFeaturePhaseExceptionOnCoordinatingNode() { + final String indexName = "index"; + final String rankFeatureFieldName = "field"; + final String searchFieldName = "search_field"; + final String searchFieldValue = "some_value"; + final String fetchFieldName = "fetch_field"; + final String fetchFieldValue = "fetch_value"; + + final int minDocs = 3; + final int maxDocs = 10; + int numDocs = between(minDocs, maxDocs); + createIndex(indexName); + // index some documents + for (int i = 0; i < numDocs; i++) { + prepareIndex(indexName).setId(String.valueOf(i)) + .setSource( + rankFeatureFieldName, + "aardvark_" + i, + searchFieldName, + searchFieldValue, + fetchFieldName, + fetchFieldValue + "_" + i + ) + .get(); + } + indicesAdmin().prepareRefresh(indexName).get(); + + expectThrows( + SearchPhaseExecutionException.class, + () -> client().prepareSearch(indexName) + .setSource( + new SearchSourceBuilder().query(new TermQueryBuilder(searchFieldName, searchFieldValue)) + .size(2) + .from(2) + .fetchField(fetchFieldName) + .rankBuilder(new TestRankBuilder(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + + // no need for more than one queries + @Override + public boolean isCompoundBuilder() { + return false; + } + + @Override + public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) { + @Override + protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { + throw new IllegalStateException("should have failed earlier"); + } + }; + } + + @Override + public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { + return new QueryPhaseRankCoordinatorContext(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + @Override + public ScoreDoc[] rankQueryPhaseResults( + List querySearchResults, + SearchPhaseController.TopDocsStats topDocStats + ) { + throw new UnsupportedOperationException("simulated failure"); + } + }; + } + + @Override + public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { + return new QueryPhaseRankShardContext(queries, from) { + + @Override + public int rankWindowSize() { + return DEFAULT_RANK_WINDOW_SIZE; + } + + @Override + public RankShardResult combineQueryPhaseResults(List rankResults) { + // we know we have just 1 query, so return all the docs from it + return new TestRankShardResult( + Arrays.stream(rankResults.get(0).scoreDocs) + .map(x -> new TestRankDoc(x.doc, x.score, x.shardIndex)) + .limit(rankWindowSize()) + .toArray(TestRankDoc[]::new) + ); + } + }; + } + + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) { + @Override + public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { + RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; + for (int i = 0; i < hits.getHits().length; i++) { + SearchHit hit = hits.getHits()[i]; + rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); + rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue()); + rankFeatureDocs[i].score = randomFloat(); + rankFeatureDocs[i].rank = i + 1; + } + return new RankFeatureShardResult(rankFeatureDocs); + } + }; + } + }) + ) + .get() + ); + } + + public void testRankFeaturePhaseExceptionAllShardFail() { + final String indexName = "index"; + final String rankFeatureFieldName = "field"; + final String searchFieldName = "search_field"; + final String searchFieldValue = "some_value"; + final String fetchFieldName = "fetch_field"; + final String fetchFieldValue = "fetch_value"; + + final int minDocs = 3; + final int maxDocs = 10; + int numDocs = between(minDocs, maxDocs); + createIndex(indexName); + // index some documents + for (int i = 0; i < numDocs; i++) { + prepareIndex(indexName).setId(String.valueOf(i)) + .setSource( + rankFeatureFieldName, + "aardvark_" + i, + searchFieldName, + searchFieldValue, + fetchFieldName, + fetchFieldValue + "_" + i + ) + .get(); + } + indicesAdmin().prepareRefresh(indexName).get(); + + expectThrows( + SearchPhaseExecutionException.class, + () -> client().prepareSearch(indexName) + .setAllowPartialSearchResults(true) + .setSource( + new SearchSourceBuilder().query(new TermQueryBuilder(searchFieldName, searchFieldValue)) + .fetchField(fetchFieldName) + .rankBuilder( + // here we override only the shard-level contexts + new TestRankBuilder(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + + // no need for more than one queries + @Override + public boolean isCompoundBuilder() { + return false; + } + + @Override + public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) { + @Override + protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { + float[] scores = new float[featureDocs.length]; + for (int i = 0; i < featureDocs.length; i++) { + scores[i] = featureDocs[i].score; + } + scoreListener.onResponse(scores); + } + }; + } + + @Override + public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { + return new QueryPhaseRankCoordinatorContext(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + @Override + public ScoreDoc[] rankQueryPhaseResults( + List querySearchResults, + SearchPhaseController.TopDocsStats topDocStats + ) { + List rankDocs = new ArrayList<>(); + for (int i = 0; i < querySearchResults.size(); i++) { + QuerySearchResult querySearchResult = querySearchResults.get(i); + TestRankShardResult shardResult = (TestRankShardResult) querySearchResult + .getRankShardResult(); + for (TestRankDoc trd : shardResult.testRankDocs) { + trd.shardIndex = i; + rankDocs.add(trd); + } + } + rankDocs.sort(Comparator.comparing((TestRankDoc doc) -> doc.score).reversed()); + TestRankDoc[] topResults = rankDocs.stream().limit(rankWindowSize).toArray(TestRankDoc[]::new); + topDocStats.fetchHits = topResults.length; + return topResults; + } + }; + } + + @Override + public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { + return new QueryPhaseRankShardContext(queries, from) { + + @Override + public int rankWindowSize() { + return DEFAULT_RANK_WINDOW_SIZE; + } + + @Override + public RankShardResult combineQueryPhaseResults(List rankResults) { + // we know we have just 1 query, so return all the docs from it + return new TestRankShardResult( + Arrays.stream(rankResults.get(0).scoreDocs) + .map(x -> new TestRankDoc(x.doc, x.score, x.shardIndex)) + .limit(rankWindowSize()) + .toArray(TestRankDoc[]::new) + ); + } + }; + } + + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) { + @Override + public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { + throw new UnsupportedOperationException("simulated failure"); + } + }; + } + } + ) + ) + .get() + ); + } + + public void testRankFeaturePhaseExceptionOneShardFails() { + // if we have only one shard and it fails, it will fallback to context.onPhaseFailure which will eventually clean up all contexts. + // in this test we want to make sure that even if one shard (of many) fails during the RankFeaturePhase, then the appropriate + // context will have been cleaned up. + final String indexName = "index"; + final String rankFeatureFieldName = "field"; + final String searchFieldName = "search_field"; + final String searchFieldValue = "some_value"; + final String fetchFieldName = "fetch_field"; + final String fetchFieldValue = "fetch_value"; + + final int minDocs = 3; + final int maxDocs = 10; + int numDocs = between(minDocs, maxDocs); + createIndex(indexName, Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 2).build()); + // index some documents + for (int i = 0; i < numDocs; i++) { + prepareIndex(indexName).setId(String.valueOf(i)) + .setSource( + rankFeatureFieldName, + "aardvark_" + i, + searchFieldName, + searchFieldValue, + fetchFieldName, + fetchFieldValue + "_" + i + ) + .get(); + } + indicesAdmin().prepareRefresh(indexName).get(); + + assertResponse( + client().prepareSearch(indexName) + .setAllowPartialSearchResults(true) + .setSource( + new SearchSourceBuilder().query(new TermQueryBuilder(searchFieldName, searchFieldValue)) + .fetchField(fetchFieldName) + .rankBuilder( + // here we override only the shard-level contexts + new TestRankBuilder(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + + // no need for more than one queries + @Override + public boolean isCompoundBuilder() { + return false; + } + + @Override + public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) { + @Override + protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { + float[] scores = new float[featureDocs.length]; + for (int i = 0; i < featureDocs.length; i++) { + scores[i] = featureDocs[i].score; + } + scoreListener.onResponse(scores); + } + }; + } + + @Override + public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { + return new QueryPhaseRankCoordinatorContext(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + @Override + public ScoreDoc[] rankQueryPhaseResults( + List querySearchResults, + SearchPhaseController.TopDocsStats topDocStats + ) { + List rankDocs = new ArrayList<>(); + for (int i = 0; i < querySearchResults.size(); i++) { + QuerySearchResult querySearchResult = querySearchResults.get(i); + TestRankShardResult shardResult = (TestRankShardResult) querySearchResult + .getRankShardResult(); + for (TestRankDoc trd : shardResult.testRankDocs) { + trd.shardIndex = i; + rankDocs.add(trd); + } + } + rankDocs.sort(Comparator.comparing((TestRankDoc doc) -> doc.score).reversed()); + TestRankDoc[] topResults = rankDocs.stream().limit(rankWindowSize).toArray(TestRankDoc[]::new); + topDocStats.fetchHits = topResults.length; + return topResults; + } + }; + } + + @Override + public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { + return new QueryPhaseRankShardContext(queries, from) { + + @Override + public int rankWindowSize() { + return DEFAULT_RANK_WINDOW_SIZE; + } + + @Override + public RankShardResult combineQueryPhaseResults(List rankResults) { + // we know we have just 1 query, so return all the docs from it + return new TestRankShardResult( + Arrays.stream(rankResults.get(0).scoreDocs) + .map(x -> new TestRankDoc(x.doc, x.score, x.shardIndex)) + .limit(rankWindowSize()) + .toArray(TestRankDoc[]::new) + ); + } + }; + } + + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) { + @Override + public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { + if (shardId == 0) { + throw new UnsupportedOperationException("simulated failure"); + } else { + RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; + for (int i = 0; i < hits.getHits().length; i++) { + SearchHit hit = hits.getHits()[i]; + rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); + rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue()); + rankFeatureDocs[i].score = randomFloat(); + rankFeatureDocs[i].rank = i + 1; + } + return new RankFeatureShardResult(rankFeatureDocs); + } + } + }; + } + } + ) + ), + (searchResponse) -> { + assertEquals(1, searchResponse.getSuccessfulShards()); + assertEquals("simulated failure", searchResponse.getShardFailures()[0].getCause().getMessage()); + assertNotEquals(0, searchResponse.getHits().getHits().length); + for (SearchHit hit : searchResponse.getHits().getHits()) { + assertEquals(fetchFieldValue + "_" + hit.getId(), hit.getFields().get(fetchFieldName).getValue()); + assertEquals(1, hit.getShard().getShardId().id()); + } + } + ); + } + public void testSearchWhileIndexDeletedDoesNotLeakSearchContext() throws ExecutionException, InterruptedException { createIndex("index"); prepareIndex("index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); @@ -457,7 +1182,7 @@ public void testSearchWhileIndexDeletedDoesNotLeakSearchContext() throws Executi -1, null ), - new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()), + new SearchShardTask(123L, "", "", "", null, emptyMap()), result ); @@ -694,7 +1419,7 @@ public void testMaxScriptFieldsSearch() throws IOException { for (int i = 0; i < maxScriptFields; i++) { searchSourceBuilder.scriptField( "field" + i, - new Script(ScriptType.INLINE, MockScriptEngine.NAME, CustomScriptPlugin.DUMMY_SCRIPT, Collections.emptyMap()) + new Script(ScriptType.INLINE, MockScriptEngine.NAME, CustomScriptPlugin.DUMMY_SCRIPT, emptyMap()) ); } final ShardSearchRequest request = new ShardSearchRequest( @@ -723,7 +1448,7 @@ public void testMaxScriptFieldsSearch() throws IOException { } searchSourceBuilder.scriptField( "anotherScriptField", - new Script(ScriptType.INLINE, MockScriptEngine.NAME, CustomScriptPlugin.DUMMY_SCRIPT, Collections.emptyMap()) + new Script(ScriptType.INLINE, MockScriptEngine.NAME, CustomScriptPlugin.DUMMY_SCRIPT, emptyMap()) ); IllegalArgumentException ex = expectThrows( IllegalArgumentException.class, @@ -752,7 +1477,7 @@ public void testIgnoreScriptfieldIfSizeZero() throws IOException { searchRequest.source(searchSourceBuilder); searchSourceBuilder.scriptField( "field" + 0, - new Script(ScriptType.INLINE, MockScriptEngine.NAME, CustomScriptPlugin.DUMMY_SCRIPT, Collections.emptyMap()) + new Script(ScriptType.INLINE, MockScriptEngine.NAME, CustomScriptPlugin.DUMMY_SCRIPT, emptyMap()) ); searchSourceBuilder.size(0); final ShardSearchRequest request = new ShardSearchRequest( @@ -1036,7 +1761,7 @@ public void testCanMatch() throws Exception { ); CountDownLatch latch = new CountDownLatch(1); - SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()); + SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap()); // Because the foo field used in alias filter is unmapped the term query builder rewrite can resolve to a match no docs query, // without acquiring a searcher and that means the wrapper is not called assertEquals(5, numWrapInvocations.get()); @@ -1330,7 +2055,7 @@ public void testMatchNoDocsEmptyResponse() throws InterruptedException { 0, null ); - SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()); + SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap()); { CountDownLatch latch = new CountDownLatch(1); @@ -1705,7 +2430,7 @@ public void testWaitOnRefresh() throws ExecutionException, InterruptedException final DocWriteResponse response = prepareIndex("index").setSource("id", "1").get(); assertEquals(RestStatus.CREATED, response.status()); - SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()); + SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap()); ShardSearchRequest request = new ShardSearchRequest( OriginalIndices.NONE, searchRequest, @@ -1740,7 +2465,7 @@ public void testWaitOnRefreshFailsWithRefreshesDisabled() { final DocWriteResponse response = prepareIndex("index").setSource("id", "1").get(); assertEquals(RestStatus.CREATED, response.status()); - SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()); + SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap()); PlainActionFuture future = new PlainActionFuture<>(); ShardSearchRequest request = new ShardSearchRequest( OriginalIndices.NONE, @@ -1778,7 +2503,7 @@ public void testWaitOnRefreshFailsIfCheckpointNotIndexed() { final DocWriteResponse response = prepareIndex("index").setSource("id", "1").get(); assertEquals(RestStatus.CREATED, response.status()); - SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()); + SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap()); PlainActionFuture future = new PlainActionFuture<>(); ShardSearchRequest request = new ShardSearchRequest( OriginalIndices.NONE, @@ -1815,7 +2540,7 @@ public void testWaitOnRefreshTimeout() { final DocWriteResponse response = prepareIndex("index").setSource("id", "1").get(); assertEquals(RestStatus.CREATED, response.status()); - SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()); + SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap()); PlainActionFuture future = new PlainActionFuture<>(); ShardSearchRequest request = new ShardSearchRequest( OriginalIndices.NONE, @@ -1901,7 +2626,7 @@ public void testDfsQueryPhaseRewrite() { PlainActionFuture plainActionFuture = new PlainActionFuture<>(); service.executeQueryPhase( new QuerySearchRequest(null, context.id(), request, new AggregatedDfs(Map.of(), Map.of(), 10)), - new SearchShardTask(42L, "", "", "", null, Collections.emptyMap()), + new SearchShardTask(42L, "", "", "", null, emptyMap()), plainActionFuture ); diff --git a/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java b/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java new file mode 100644 index 0000000000000..cf464044cd701 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java @@ -0,0 +1,409 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.rank; + +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TotalHits; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.document.DocumentField; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.fetch.FetchSearchResult; +import org.elasticsearch.search.fetch.StoredFieldsContext; +import org.elasticsearch.search.fetch.subphase.FetchFieldsContext; +import org.elasticsearch.search.internal.SearchContext; +import org.elasticsearch.search.internal.ShardSearchRequest; +import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext; +import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; +import org.elasticsearch.search.rank.feature.RankFeatureDoc; +import org.elasticsearch.search.rank.feature.RankFeatureResult; +import org.elasticsearch.search.rank.feature.RankFeatureShardPhase; +import org.elasticsearch.search.rank.feature.RankFeatureShardRequest; +import org.elasticsearch.search.rank.feature.RankFeatureShardResult; +import org.elasticsearch.tasks.TaskCancelledException; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.TestSearchContext; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +public class RankFeatureShardPhaseTests extends ESTestCase { + + private SearchContext getSearchContext() { + return new TestSearchContext((SearchExecutionContext) null) { + + private FetchSearchResult fetchResult; + private RankFeatureResult rankFeatureResult; + private FetchFieldsContext fetchFieldsContext; + private StoredFieldsContext storedFieldsContext; + + @Override + public FetchSearchResult fetchResult() { + return fetchResult; + } + + @Override + public void addFetchResult() { + this.fetchResult = new FetchSearchResult(); + this.addReleasable(fetchResult::decRef); + } + + @Override + public RankFeatureResult rankFeatureResult() { + return rankFeatureResult; + } + + @Override + public void addRankFeatureResult() { + this.rankFeatureResult = new RankFeatureResult(); + this.addReleasable(rankFeatureResult::decRef); + } + + @Override + public SearchContext fetchFieldsContext(FetchFieldsContext fetchFieldsContext) { + this.fetchFieldsContext = fetchFieldsContext; + return this; + } + + @Override + public FetchFieldsContext fetchFieldsContext() { + return fetchFieldsContext; + } + + @Override + public SearchContext storedFieldsContext(StoredFieldsContext storedFieldsContext) { + this.storedFieldsContext = storedFieldsContext; + return this; + } + + @Override + public StoredFieldsContext storedFieldsContext() { + return storedFieldsContext; + } + + @Override + public boolean isCancelled() { + return false; + } + }; + } + + private RankBuilder getRankBuilder(final String field) { + return new RankBuilder(DEFAULT_RANK_WINDOW_SIZE) { + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + // no-op + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + // no-op + } + + @Override + public boolean isCompoundBuilder() { + return false; + } + + // no work to be done on the query phase + @Override + public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { + return null; + } + + // no work to be done on the query phase + @Override + public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { + return null; + } + + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + return new RankFeaturePhaseRankShardContext(field) { + @Override + public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { + RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; + for (int i = 0; i < hits.getHits().length; i++) { + SearchHit hit = hits.getHits()[i]; + rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); + rankFeatureDocs[i].featureData(hit.getFields().get(field).getValue()); + rankFeatureDocs[i].rank = i + 1; + } + return new RankFeatureShardResult(rankFeatureDocs); + } + }; + } + + // no work to be done on the coordinator node for the rank feature phase + @Override + public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) { + return null; + } + + @Override + protected boolean doEquals(RankBuilder other) { + return false; + } + + @Override + protected int doHashCode() { + return 0; + } + + @Override + public String getWriteableName() { + return "rank_builder_rank_feature_shard_phase_enabled"; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.RANK_FEATURE_PHASE_ADDED; + } + }; + } + + public void testPrepareForFetch() { + + final String fieldName = "some_field"; + int numDocs = randomIntBetween(10, 30); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.rankBuilder(getRankBuilder(fieldName)); + + ShardSearchRequest searchRequest = mock(ShardSearchRequest.class); + when(searchRequest.source()).thenReturn(searchSourceBuilder); + + try (SearchContext searchContext = spy(getSearchContext())) { + when(searchContext.isCancelled()).thenReturn(false); + when(searchContext.request()).thenReturn(searchRequest); + + RankFeatureShardRequest request = mock(RankFeatureShardRequest.class); + when(request.getDocIds()).thenReturn(new int[] { 4, 9, numDocs - 1 }); + + RankFeatureShardPhase rankFeatureShardPhase = new RankFeatureShardPhase(); + rankFeatureShardPhase.prepareForFetch(searchContext, request); + + assertNotNull(searchContext.fetchFieldsContext()); + assertEquals(searchContext.fetchFieldsContext().fields().size(), 1); + assertEquals(searchContext.fetchFieldsContext().fields().get(0).field, fieldName); + assertNotNull(searchContext.storedFieldsContext()); + assertNull(searchContext.storedFieldsContext().fieldNames()); + assertFalse(searchContext.storedFieldsContext().fetchFields()); + assertNotNull(searchContext.fetchResult()); + } + } + + public void testPrepareForFetchNoRankFeatureContext() { + int numDocs = randomIntBetween(10, 30); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.rankBuilder(null); + + ShardSearchRequest searchRequest = mock(ShardSearchRequest.class); + when(searchRequest.source()).thenReturn(searchSourceBuilder); + + try (SearchContext searchContext = spy(getSearchContext())) { + when(searchContext.isCancelled()).thenReturn(false); + when(searchContext.request()).thenReturn(searchRequest); + + RankFeatureShardRequest request = mock(RankFeatureShardRequest.class); + when(request.getDocIds()).thenReturn(new int[] { 4, 9, numDocs - 1 }); + + RankFeatureShardPhase rankFeatureShardPhase = new RankFeatureShardPhase(); + rankFeatureShardPhase.prepareForFetch(searchContext, request); + + assertNull(searchContext.fetchFieldsContext()); + assertNull(searchContext.fetchResult()); + } + } + + public void testPrepareForFetchWhileTaskIsCancelled() { + + final String fieldName = "some_field"; + int numDocs = randomIntBetween(10, 30); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.rankBuilder(getRankBuilder(fieldName)); + + ShardSearchRequest searchRequest = mock(ShardSearchRequest.class); + when(searchRequest.source()).thenReturn(searchSourceBuilder); + + try (SearchContext searchContext = spy(getSearchContext())) { + when(searchContext.isCancelled()).thenReturn(true); + when(searchContext.request()).thenReturn(searchRequest); + + RankFeatureShardRequest request = mock(RankFeatureShardRequest.class); + when(request.getDocIds()).thenReturn(new int[] { 4, 9, numDocs - 1 }); + + RankFeatureShardPhase rankFeatureShardPhase = new RankFeatureShardPhase(); + expectThrows(TaskCancelledException.class, () -> rankFeatureShardPhase.prepareForFetch(searchContext, request)); + } + } + + public void testProcessFetch() { + final String fieldName = "some_field"; + int numDocs = randomIntBetween(10, 30); + Map expectedFieldData = Map.of(4, "doc_4_aardvark", 9, "doc_9_aardvark", numDocs - 1, "last_doc_aardvark"); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.rankBuilder(getRankBuilder(fieldName)); + + ShardSearchRequest searchRequest = mock(ShardSearchRequest.class); + when(searchRequest.source()).thenReturn(searchSourceBuilder); + + SearchShardTarget shardTarget = new SearchShardTarget( + "node_id", + new ShardId(new Index("some_index", UUID.randomUUID().toString()), 0), + null + ); + SearchHits searchHits = null; + try (SearchContext searchContext = spy(getSearchContext())) { + searchContext.addFetchResult(); + SearchHit[] hits = new SearchHit[3]; + hits[0] = SearchHit.unpooled(4); + hits[0].setDocumentField(fieldName, new DocumentField(fieldName, Collections.singletonList(expectedFieldData.get(4)))); + + hits[1] = SearchHit.unpooled(9); + hits[1].setDocumentField(fieldName, new DocumentField(fieldName, Collections.singletonList(expectedFieldData.get(9)))); + + hits[2] = SearchHit.unpooled(numDocs - 1); + hits[2].setDocumentField( + fieldName, + new DocumentField(fieldName, Collections.singletonList(expectedFieldData.get(numDocs - 1))) + ); + searchHits = SearchHits.unpooled(hits, new TotalHits(3, TotalHits.Relation.EQUAL_TO), 1.0f); + searchContext.fetchResult().shardResult(searchHits, null); + when(searchContext.isCancelled()).thenReturn(false); + when(searchContext.request()).thenReturn(searchRequest); + when(searchContext.shardTarget()).thenReturn(shardTarget); + RankFeatureShardRequest request = mock(RankFeatureShardRequest.class); + when(request.getDocIds()).thenReturn(new int[] { 4, 9, numDocs - 1 }); + + RankFeatureShardPhase rankFeatureShardPhase = new RankFeatureShardPhase(); + // this is called as part of the search context initialization + // with the ResultsType.RANK_FEATURE type + searchContext.addRankFeatureResult(); + rankFeatureShardPhase.processFetch(searchContext); + + assertNotNull(searchContext.rankFeatureResult()); + assertNotNull(searchContext.rankFeatureResult().rankFeatureResult()); + for (RankFeatureDoc rankFeatureDoc : searchContext.rankFeatureResult().rankFeatureResult().shardResult().rankFeatureDocs) { + assertTrue(expectedFieldData.containsKey(rankFeatureDoc.doc)); + assertEquals(rankFeatureDoc.featureData, expectedFieldData.get(rankFeatureDoc.doc)); + } + } finally { + if (searchHits != null) { + searchHits.decRef(); + } + } + } + + public void testProcessFetchEmptyHits() { + final String fieldName = "some_field"; + int numDocs = randomIntBetween(10, 30); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.rankBuilder(getRankBuilder(fieldName)); + + ShardSearchRequest searchRequest = mock(ShardSearchRequest.class); + when(searchRequest.source()).thenReturn(searchSourceBuilder); + + SearchShardTarget shardTarget = new SearchShardTarget( + "node_id", + new ShardId(new Index("some_index", UUID.randomUUID().toString()), 0), + null + ); + + SearchHits searchHits = null; + try (SearchContext searchContext = spy(getSearchContext())) { + searchContext.addFetchResult(); + SearchHit[] hits = new SearchHit[0]; + searchHits = SearchHits.unpooled(hits, new TotalHits(0, TotalHits.Relation.EQUAL_TO), 1.0f); + searchContext.fetchResult().shardResult(searchHits, null); + when(searchContext.isCancelled()).thenReturn(false); + when(searchContext.request()).thenReturn(searchRequest); + when(searchContext.shardTarget()).thenReturn(shardTarget); + RankFeatureShardRequest request = mock(RankFeatureShardRequest.class); + when(request.getDocIds()).thenReturn(new int[] { 4, 9, numDocs - 1 }); + + RankFeatureShardPhase rankFeatureShardPhase = new RankFeatureShardPhase(); + // this is called as part of the search context initialization + // with the ResultsType.RANK_FEATURE type + searchContext.addRankFeatureResult(); + rankFeatureShardPhase.processFetch(searchContext); + + assertNotNull(searchContext.rankFeatureResult()); + assertNotNull(searchContext.rankFeatureResult().rankFeatureResult()); + assertEquals(searchContext.rankFeatureResult().rankFeatureResult().shardResult().rankFeatureDocs.length, 0); + } finally { + if (searchHits != null) { + searchHits.decRef(); + } + } + } + + public void testProcessFetchWhileTaskIsCancelled() { + + final String fieldName = "some_field"; + int numDocs = randomIntBetween(10, 30); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.rankBuilder(getRankBuilder(fieldName)); + + ShardSearchRequest searchRequest = mock(ShardSearchRequest.class); + when(searchRequest.source()).thenReturn(searchSourceBuilder); + + SearchShardTarget shardTarget = new SearchShardTarget( + "node_id", + new ShardId(new Index("some_index", UUID.randomUUID().toString()), 0), + null + ); + + SearchHits searchHits = null; + try (SearchContext searchContext = spy(getSearchContext())) { + searchContext.addFetchResult(); + SearchHit[] hits = new SearchHit[0]; + searchHits = SearchHits.unpooled(hits, new TotalHits(0, TotalHits.Relation.EQUAL_TO), 1.0f); + searchContext.fetchResult().shardResult(searchHits, null); + when(searchContext.isCancelled()).thenReturn(true); + when(searchContext.request()).thenReturn(searchRequest); + when(searchContext.shardTarget()).thenReturn(shardTarget); + RankFeatureShardRequest request = mock(RankFeatureShardRequest.class); + when(request.getDocIds()).thenReturn(new int[] { 4, 9, numDocs - 1 }); + + RankFeatureShardPhase rankFeatureShardPhase = new RankFeatureShardPhase(); + // this is called as part of the search context initialization + // with the ResultsType.RANK_FEATURE type + searchContext.addRankFeatureResult(); + expectThrows(TaskCancelledException.class, () -> rankFeatureShardPhase.processFetch(searchContext)); + } finally { + if (searchHits != null) { + searchHits.decRef(); + } + } + } +} diff --git a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java index 697b40671ee8b..6419759ab5962 100644 --- a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java @@ -178,6 +178,7 @@ import org.elasticsearch.search.SearchService; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.fetch.FetchPhase; +import org.elasticsearch.search.rank.feature.RankFeatureShardPhase; import org.elasticsearch.telemetry.TelemetryProvider; import org.elasticsearch.telemetry.tracing.Tracer; import org.elasticsearch.test.ClusterServiceUtils; @@ -2249,6 +2250,7 @@ public RecyclerBytesStreamOutput newNetworkBytesStream() { threadPool, scriptService, bigArrays, + new RankFeatureShardPhase(), new FetchPhase(Collections.emptyList()), responseCollectorService, new NoneCircuitBreakerService(), diff --git a/test/framework/src/main/java/org/elasticsearch/node/MockNode.java b/test/framework/src/main/java/org/elasticsearch/node/MockNode.java index ef29f9fca4f93..520aff77497ba 100644 --- a/test/framework/src/main/java/org/elasticsearch/node/MockNode.java +++ b/test/framework/src/main/java/org/elasticsearch/node/MockNode.java @@ -40,6 +40,7 @@ import org.elasticsearch.search.MockSearchService; import org.elasticsearch.search.SearchService; import org.elasticsearch.search.fetch.FetchPhase; +import org.elasticsearch.search.rank.feature.RankFeatureShardPhase; import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.telemetry.tracing.Tracer; import org.elasticsearch.test.ESTestCase; @@ -97,6 +98,7 @@ SearchService newSearchService( ThreadPool threadPool, ScriptService scriptService, BigArrays bigArrays, + RankFeatureShardPhase rankFeatureShardPhase, FetchPhase fetchPhase, ResponseCollectorService responseCollectorService, CircuitBreakerService circuitBreakerService, @@ -111,6 +113,7 @@ SearchService newSearchService( threadPool, scriptService, bigArrays, + rankFeatureShardPhase, fetchPhase, responseCollectorService, circuitBreakerService, @@ -124,6 +127,7 @@ SearchService newSearchService( threadPool, scriptService, bigArrays, + rankFeatureShardPhase, fetchPhase, responseCollectorService, circuitBreakerService, diff --git a/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java b/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java index aa1889e15d594..747eff1d21708 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java +++ b/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java @@ -23,6 +23,7 @@ import org.elasticsearch.search.internal.ReaderContext; import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.ShardSearchRequest; +import org.elasticsearch.search.rank.feature.RankFeatureShardPhase; import org.elasticsearch.telemetry.tracing.Tracer; import org.elasticsearch.threadpool.ThreadPool; @@ -81,6 +82,7 @@ public MockSearchService( ThreadPool threadPool, ScriptService scriptService, BigArrays bigArrays, + RankFeatureShardPhase rankFeatureShardPhase, FetchPhase fetchPhase, ResponseCollectorService responseCollectorService, CircuitBreakerService circuitBreakerService, @@ -93,6 +95,7 @@ public MockSearchService( threadPool, scriptService, bigArrays, + rankFeatureShardPhase, fetchPhase, responseCollectorService, circuitBreakerService, diff --git a/test/framework/src/main/java/org/elasticsearch/search/rank/TestRankBuilder.java b/test/framework/src/main/java/org/elasticsearch/search/rank/TestRankBuilder.java index 8e2a2c96a31ab..862c4d2ea3270 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/rank/TestRankBuilder.java +++ b/test/framework/src/main/java/org/elasticsearch/search/rank/TestRankBuilder.java @@ -15,6 +15,8 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext; import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; @@ -31,7 +33,7 @@ public class TestRankBuilder extends RankBuilder { static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( NAME, - args -> new TestRankBuilder(args[0] == null ? DEFAULT_WINDOW_SIZE : (int) args[0]) + args -> new TestRankBuilder(args[0] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[0]) ); static { @@ -74,6 +76,11 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep // do nothing } + @Override + public boolean isCompoundBuilder() { + return true; + } + @Override public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { throw new UnsupportedOperationException(); @@ -84,6 +91,16 @@ public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int si throw new UnsupportedOperationException(); } + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + throw new UnsupportedOperationException(); + } + + @Override + public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) { + throw new UnsupportedOperationException(); + } + @Override protected boolean doEquals(RankBuilder other) { return true; diff --git a/test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java b/test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java index cba2b41d279bb..fa414cd8121d6 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java +++ b/test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java @@ -44,6 +44,7 @@ import org.elasticsearch.search.profile.Profilers; import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; +import org.elasticsearch.search.rank.feature.RankFeatureResult; import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.sort.SortAndFormats; import org.elasticsearch.search.suggest.SuggestionSearchContext; @@ -463,6 +464,16 @@ public float getMaxScore() { return queryResult.getMaxScore(); } + @Override + public void addRankFeatureResult() { + // this space intentionally left blank + } + + @Override + public RankFeatureResult rankFeatureResult() { + return null; + } + @Override public FetchSearchResult fetchResult() { return null; diff --git a/test/framework/src/main/java/org/elasticsearch/test/hamcrest/ElasticsearchAssertions.java b/test/framework/src/main/java/org/elasticsearch/test/hamcrest/ElasticsearchAssertions.java index a7f21bd206c62..bf9eba87ee809 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/hamcrest/ElasticsearchAssertions.java +++ b/test/framework/src/main/java/org/elasticsearch/test/hamcrest/ElasticsearchAssertions.java @@ -687,6 +687,10 @@ public static Matcher hasScore(final float score) { return transformedMatch(SearchHit::getScore, equalTo(score)); } + public static Matcher hasRank(final int rank) { + return transformedMatch(SearchHit::getRank, equalTo(rank)); + } + public static T assertBooleanSubQuery(Query query, Class subqueryType, int i) { assertThat(query, instanceOf(BooleanQuery.class)); BooleanQuery q = (BooleanQuery) query; diff --git a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java index 04b0b11ad38d4..c0305f873327d 100644 --- a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java +++ b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java @@ -397,6 +397,14 @@ protected void onQueryResult(int shardIndex, QuerySearchResult queryResult) { } } + @Override + protected void onRankFeatureResult(int shardIndex) { + checkCancellation(); + if (delegate != null) { + delegate.onRankFeatureResult(shardIndex); + } + } + @Override protected void onFetchResult(int shardIndex) { checkCancellation(); @@ -420,6 +428,12 @@ protected void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exc ); } + @Override + protected void onRankFeatureFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) { + // best effort to cancel expired tasks + checkCancellation(); + } + @Override protected void onFetchFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) { // best effort to cancel expired tasks diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java index 8f3ed15037c08..5c39c6c32fd06 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java @@ -16,6 +16,8 @@ import org.elasticsearch.search.rank.RankBuilder; import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext; import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; @@ -38,7 +40,7 @@ public class RRFRankBuilder extends RankBuilder { public static final ParseField RANK_CONSTANT_FIELD = new ParseField("rank_constant"); static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(RRFRankPlugin.NAME, args -> { - int windowSize = args[0] == null ? DEFAULT_WINDOW_SIZE : (int) args[0]; + int windowSize = args[0] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[0]; int rankConstant = args[1] == null ? DEFAULT_RANK_CONSTANT : (int) args[1]; if (rankConstant < 1) { throw new IllegalArgumentException("[rank_constant] must be greater than [0] for [rrf]"); @@ -94,6 +96,11 @@ public int rankConstant() { return rankConstant; } + @Override + public boolean isCompoundBuilder() { + return true; + } + public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { return new RRFQueryPhaseRankShardContext(queries, rankWindowSize(), rankConstant); } @@ -103,6 +110,16 @@ public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int si return new RRFQueryPhaseRankCoordinatorContext(size, from, rankWindowSize(), rankConstant); } + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + return null; + } + + @Override + public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) { + return null; + } + @Override protected boolean doEquals(RankBuilder other) { return Objects.equals(rankConstant, ((RRFRankBuilder) other).rankConstant); diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java index 077c933fa9add..e5a7983107278 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java @@ -71,7 +71,7 @@ public static RRFRetrieverBuilder fromXContent(XContentParser parser, RetrieverP } List retrieverBuilders = Collections.emptyList(); - int rankWindowSize = RRFRankBuilder.DEFAULT_WINDOW_SIZE; + int rankWindowSize = RRFRankBuilder.DEFAULT_RANK_WINDOW_SIZE; int rankConstant = RRFRankBuilder.DEFAULT_RANK_CONSTANT; @Override diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/PreAuthorizationUtils.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/PreAuthorizationUtils.java index aeb6bfc8de796..221b7a65e1f8f 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/PreAuthorizationUtils.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/PreAuthorizationUtils.java @@ -45,6 +45,7 @@ public final class PreAuthorizationUtils { SearchTransportService.QUERY_ACTION_NAME, SearchTransportService.QUERY_ID_ACTION_NAME, SearchTransportService.FETCH_ID_ACTION_NAME, + SearchTransportService.RANK_FEATURE_SHARD_ACTION_NAME, SearchTransportService.QUERY_CAN_MATCH_NODE_NAME ) );