From d53f684383bc2d859866ec7165d75597f248b7a0 Mon Sep 17 00:00:00 2001 From: Max Hniebergall <137079448+maxhniebergall@users.noreply.github.com> Date: Wed, 15 May 2024 15:13:09 -0400 Subject: [PATCH] [ML] Add mixed cluster tests for inference (#108392) * mixed cluster tests are executable * add tests from upgrade tests * [ML] Add mixed cluster tests for existing services * clean up * review improvements * spotless * remove blocked AzureOpenAI mixed IT * improvements from DK review * temp for testing * refactoring and documentation * Revert manual testing configs of "temp for testing" This reverts parts of commit fca46fd2b6253accc010a2e2a8bf05edfff5ea9b. * revert TESTING.asciidoc formatting * Update TESTING.asciidoc to avoid reformatting * add minimum version for tests to match minimum version in services * spotless --- TESTING.asciidoc | 16 +- .../inference/qa/mixed-cluster/build.gradle | 37 +++ .../inference/qa/mixed/BaseMixedTestCase.java | 129 +++++++++ .../qa/mixed/CohereServiceMixedIT.java | 271 ++++++++++++++++++ .../qa/mixed/HuggingFaceServiceMixedIT.java | 147 ++++++++++ .../qa/mixed/MixedClusterSpecTestCase.java | 53 ++++ .../inference/qa/mixed/MixedClustersSpec.java | 25 ++ .../qa/mixed/OpenAIServiceMixedIT.java | 223 ++++++++++++++ 8 files changed, 896 insertions(+), 5 deletions(-) create mode 100644 x-pack/plugin/inference/qa/mixed-cluster/build.gradle create mode 100644 x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/BaseMixedTestCase.java create mode 100644 x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/CohereServiceMixedIT.java create mode 100644 x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/HuggingFaceServiceMixedIT.java create mode 100644 x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/MixedClusterSpecTestCase.java create mode 100644 x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/MixedClustersSpec.java create mode 100644 x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/OpenAIServiceMixedIT.java diff --git a/TESTING.asciidoc b/TESTING.asciidoc index 96f94755a2758b..2c205f9090ba84 100644 --- a/TESTING.asciidoc +++ b/TESTING.asciidoc @@ -551,13 +551,19 @@ When running `./gradlew check`, minimal bwc checks are also run against compatib ==== BWC Testing against a specific remote/branch -Sometimes a backward compatibility change spans two versions. A common case is a new functionality -that needs a BWC bridge in an unreleased versioned of a release branch (for example, 5.x). -To test the changes, you can instruct Gradle to build the BWC version from another remote/branch combination instead of -pulling the release branch from GitHub. You do so using the `bwc.remote` and `bwc.refspec.BRANCH` system properties: +Sometimes a backward compatibility change spans two versions. +A common case is a new functionality that needs a BWC bridge in an unreleased versioned of a release branch (for example, 5.x). +Another use case, since the introduction of serverless, is to test BWC against main in addition to the other released branches. +To do so, specify the `bwc.refspec` remote and branch to use for the BWC build as `origin/main`. +To test against main, you will also need to create a new version in link:./server/src/main/java/org/elasticsearch/Version.java[Version.java], +increment `elasticsearch` in link:./build-tools-internal/version.properties[version.properties], and hard-code the `project.version` for ml-cpp +in link:./x-pack/plugin/ml/build.gradle[ml/build.gradle]. + +In general, to test the changes, you can instruct Gradle to build the BWC version from another remote/branch combination instead of pulling the release branch from GitHub. +You do so using the `bwc.refspec.{VERSION}` system property: ------------------------------------------------- -./gradlew check -Dbwc.remote=${remote} -Dbwc.refspec.5.x=index_req_bwc_5.x +./gradlew check -Dtests.bwc.refspec.8.15=origin/main ------------------------------------------------- The branch needs to be available on the remote that the BWC makes of the diff --git a/x-pack/plugin/inference/qa/mixed-cluster/build.gradle b/x-pack/plugin/inference/qa/mixed-cluster/build.gradle new file mode 100644 index 00000000000000..1d5369468b054d --- /dev/null +++ b/x-pack/plugin/inference/qa/mixed-cluster/build.gradle @@ -0,0 +1,37 @@ +import org.elasticsearch.gradle.Version +import org.elasticsearch.gradle.VersionProperties +import org.elasticsearch.gradle.util.GradleUtils +import org.elasticsearch.gradle.internal.info.BuildParams +import org.elasticsearch.gradle.testclusters.StandaloneRestIntegTestTask + +apply plugin: 'elasticsearch.internal-java-rest-test' +apply plugin: 'elasticsearch.internal-test-artifact-base' +apply plugin: 'elasticsearch.bwc-test' + +dependencies { + testImplementation project(path: ':x-pack:plugin:inference:qa:inference-service-tests') + compileOnly project(':x-pack:plugin:core') + javaRestTestImplementation(testArtifact(project(xpackModule('core')))) + javaRestTestImplementation project(path: xpackModule('inference')) + clusterPlugins project( + ':x-pack:plugin:inference:qa:test-service-plugin' + ) +} + +// inference is available in 8.11 or later +def supportedVersion = bwcVersion -> { + return bwcVersion.onOrAfter(Version.fromString("8.11.0")); +} + +BuildParams.bwcVersions.withWireCompatible(supportedVersion) { bwcVersion, baseName -> + def javaRestTest = tasks.register("v${bwcVersion}#javaRestTest", StandaloneRestIntegTestTask) { + usesBwcDistribution(bwcVersion) + systemProperty("tests.old_cluster_version", bwcVersion) + maxParallelForks = 1 + } + + tasks.register(bwcTaskName(bwcVersion)) { + dependsOn javaRestTest + } +} + diff --git a/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/BaseMixedTestCase.java b/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/BaseMixedTestCase.java new file mode 100644 index 00000000000000..2c47578f466e38 --- /dev/null +++ b/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/BaseMixedTestCase.java @@ -0,0 +1,129 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.qa.mixed; + +import org.apache.http.util.EntityUtils; +import org.elasticsearch.client.Request; +import org.elasticsearch.client.Response; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.test.rest.ESRestTestCase; +import org.hamcrest.Matchers; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +public abstract class BaseMixedTestCase extends MixedClusterSpecTestCase { + protected static String getUrl(MockWebServer webServer) { + return Strings.format("http://%s:%s", webServer.getHostName(), webServer.getPort()); + } + + @Override + protected Settings restClientSettings() { + String token = ESRestTestCase.basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray())); + return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build(); + } + + protected void delete(String inferenceId, TaskType taskType) throws IOException { + var request = new Request("DELETE", Strings.format("_inference/%s/%s", taskType, inferenceId)); + var response = ESRestTestCase.client().performRequest(request); + ESRestTestCase.assertOK(response); + } + + protected void delete(String inferenceId) throws IOException { + var request = new Request("DELETE", Strings.format("_inference/%s", inferenceId)); + var response = ESRestTestCase.client().performRequest(request); + ESRestTestCase.assertOK(response); + } + + protected Map getAll() throws IOException { + var request = new Request("GET", "_inference/_all"); + var response = ESRestTestCase.client().performRequest(request); + ESRestTestCase.assertOK(response); + return ESRestTestCase.entityAsMap(response); + } + + protected Map get(String inferenceId) throws IOException { + var endpoint = Strings.format("_inference/%s", inferenceId); + var request = new Request("GET", endpoint); + var response = ESRestTestCase.client().performRequest(request); + ESRestTestCase.assertOK(response); + return ESRestTestCase.entityAsMap(response); + } + + protected Map get(TaskType taskType, String inferenceId) throws IOException { + var endpoint = Strings.format("_inference/%s/%s", taskType, inferenceId); + var request = new Request("GET", endpoint); + var response = ESRestTestCase.client().performRequest(request); + ESRestTestCase.assertOK(response); + return ESRestTestCase.entityAsMap(response); + } + + protected Map inference(String inferenceId, TaskType taskType, String input) throws IOException { + var endpoint = Strings.format("_inference/%s/%s", taskType, inferenceId); + var request = new Request("POST", endpoint); + request.setJsonEntity("{\"input\": [" + '"' + input + '"' + "]}"); + + var response = ESRestTestCase.client().performRequest(request); + ESRestTestCase.assertOK(response); + return ESRestTestCase.entityAsMap(response); + } + + protected Map rerank(String inferenceId, List inputs, String query) throws IOException { + var endpoint = Strings.format("_inference/rerank/%s", inferenceId); + var request = new Request("POST", endpoint); + + StringBuilder body = new StringBuilder("{").append("\"query\":\"").append(query).append("\",").append("\"input\":["); + + for (int i = 0; i < inputs.size(); i++) { + body.append("\"").append(inputs.get(i)).append("\""); + if (i < inputs.size() - 1) { + body.append(","); + } + } + + body.append("]}"); + request.setJsonEntity(body.toString()); + + var response = ESRestTestCase.client().performRequest(request); + ESRestTestCase.assertOK(response); + return ESRestTestCase.entityAsMap(response); + } + + protected void put(String inferenceId, String modelConfig, TaskType taskType) throws IOException { + String endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, inferenceId); + var request = new Request("PUT", endpoint); + request.setJsonEntity(modelConfig); + var response = ESRestTestCase.client().performRequest(request); + logger.warn("PUT response: {}", response.toString()); + System.out.println("PUT response: " + response.toString()); + ESRestTestCase.assertOKAndConsume(response); + } + + protected static void assertOkOrCreated(Response response) throws IOException { + int statusCode = response.getStatusLine().getStatusCode(); + // Once EntityUtils.toString(entity) is called the entity cannot be reused. + // Avoid that call with check here. + if (statusCode == 200 || statusCode == 201) { + return; + } + + String responseStr = EntityUtils.toString(response.getEntity()); + ESTestCase.assertThat( + responseStr, + response.getStatusLine().getStatusCode(), + Matchers.anyOf(Matchers.equalTo(200), Matchers.equalTo(201)) + ); + } +} diff --git a/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/CohereServiceMixedIT.java b/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/CohereServiceMixedIT.java new file mode 100644 index 00000000000000..69274b46d75c1c --- /dev/null +++ b/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/CohereServiceMixedIT.java @@ -0,0 +1,271 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.qa.mixed; + +import org.elasticsearch.Version; +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; +import org.hamcrest.Matchers; +import org.junit.AfterClass; +import org.junit.BeforeClass; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.qa.mixed.MixedClusterSpecTestCase.bwcVersion; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasEntry; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.oneOf; + +public class CohereServiceMixedIT extends BaseMixedTestCase { + + private static final String COHERE_EMBEDDINGS_ADDED = "8.13.0"; + private static final String COHERE_RERANK_ADDED = "8.14.0"; + private static final String BYTE_ALIAS_FOR_INT8_ADDED = "8.14.0"; + private static final String MINIMUM_SUPPORTED_VERSION = "8.15.0"; + + private static MockWebServer cohereEmbeddingsServer; + private static MockWebServer cohereRerankServer; + + @BeforeClass + public static void startWebServer() throws IOException { + cohereEmbeddingsServer = new MockWebServer(); + cohereEmbeddingsServer.start(); + + cohereRerankServer = new MockWebServer(); + cohereRerankServer.start(); + } + + @AfterClass + public static void shutdown() { + cohereEmbeddingsServer.close(); + cohereRerankServer.close(); + } + + @SuppressWarnings("unchecked") + public void testCohereEmbeddings() throws IOException { + var embeddingsSupported = bwcVersion.onOrAfter(Version.fromString(COHERE_EMBEDDINGS_ADDED)); + assumeTrue("Cohere embedding service added in " + COHERE_EMBEDDINGS_ADDED, embeddingsSupported); + assumeTrue( + "Cohere service requires at least " + MINIMUM_SUPPORTED_VERSION, + bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION)) + ); + + final String inferenceIdInt8 = "mixed-cluster-cohere-embeddings-int8"; + final String inferenceIdFloat = "mixed-cluster-cohere-embeddings-float"; + + // queue a response as PUT will call the service + cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte())); + put(inferenceIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING); + + // float model + cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat())); + put(inferenceIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING); + + var configs = (List>) get(TaskType.TEXT_EMBEDDING, inferenceIdInt8).get("endpoints"); + assertEquals("cohere", configs.get(0).get("service")); + var serviceSettings = (Map) configs.get(0).get("service_settings"); + assertThat(serviceSettings, hasEntry("model_id", "embed-english-light-v3.0")); + var embeddingType = serviceSettings.get("embedding_type"); + // An upgraded node will report the embedding type as byte, an old node int8 + assertThat(embeddingType, Matchers.is(oneOf("int8", "byte"))); + + configs = (List>) get(TaskType.TEXT_EMBEDDING, inferenceIdFloat).get("endpoints"); + serviceSettings = (Map) configs.get(0).get("service_settings"); + assertThat(serviceSettings, hasEntry("embedding_type", "float")); + + assertEmbeddingInference(inferenceIdInt8, CohereEmbeddingType.BYTE); + assertEmbeddingInference(inferenceIdFloat, CohereEmbeddingType.FLOAT); + + delete(inferenceIdFloat); + delete(inferenceIdInt8); + + } + + void assertEmbeddingInference(String inferenceId, CohereEmbeddingType type) throws IOException { + switch (type) { + case INT8: + case BYTE: + cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte())); + break; + case FLOAT: + cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat())); + } + + var inferenceMap = inference(inferenceId, TaskType.TEXT_EMBEDDING, "some text"); + assertThat(inferenceMap.entrySet(), not(empty())); + } + + @SuppressWarnings("unchecked") + public void testRerank() throws IOException { + var rerankSupported = bwcVersion.onOrAfter(Version.fromString(COHERE_RERANK_ADDED)); + assumeTrue("Cohere rerank service added in " + COHERE_RERANK_ADDED, rerankSupported); + assumeTrue( + "Cohere service requires at least " + MINIMUM_SUPPORTED_VERSION, + bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION)) + ); + + final String inferenceId = "mixed-cluster-rerank"; + + put(inferenceId, rerankConfig(getUrl(cohereRerankServer)), TaskType.RERANK); + assertRerank(inferenceId); + + var configs = (List>) get(TaskType.RERANK, inferenceId).get("endpoints"); + assertThat(configs, hasSize(1)); + assertEquals("cohere", configs.get(0).get("service")); + var serviceSettings = (Map) configs.get(0).get("service_settings"); + assertThat(serviceSettings, hasEntry("model_id", "rerank-english-v3.0")); + var taskSettings = (Map) configs.get(0).get("task_settings"); + assertThat(taskSettings, hasEntry("top_n", 3)); + + assertRerank(inferenceId); + + } + + private void assertRerank(String inferenceId) throws IOException { + cohereRerankServer.enqueue(new MockResponse().setResponseCode(200).setBody(rerankResponse())); + var inferenceMap = rerank( + inferenceId, + List.of("luke", "like", "leia", "chewy", "r2d2", "star", "wars"), + "star wars main character" + ); + assertThat(inferenceMap.entrySet(), not(empty())); + } + + private String embeddingConfigByte(String url) { + return embeddingConfigTemplate(url, "byte"); + } + + private String embeddingConfigInt8(String url) { + return embeddingConfigTemplate(url, "int8"); + } + + private String embeddingConfigFloat(String url) { + return embeddingConfigTemplate(url, "float"); + } + + private String embeddingConfigTemplate(String url, String embeddingType) { + return Strings.format(""" + { + "service": "cohere", + "service_settings": { + "url": "%s", + "api_key": "XXXX", + "model_id": "embed-english-light-v3.0", + "embedding_type": "%s" + } + } + """, url, embeddingType); + } + + private String embeddingResponseByte() { + return """ + { + "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", + "texts": [ + "hello" + ], + "embeddings": [ + [ + 12, + 56 + ] + ], + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_bytes" + } + """; + } + + private String embeddingResponseFloat() { + return """ + { + "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", + "texts": [ + "hello" + ], + "embeddings": [ + [ + -0.0018434525, + 0.01777649 + ] + ], + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_floats" + } + """; + } + + private String rerankConfig(String url) { + return Strings.format(""" + { + "service": "cohere", + "service_settings": { + "api_key": "XXXX", + "model_id": "rerank-english-v3.0", + "url": "%s" + }, + "task_settings": { + "return_documents": false, + "top_n": 3 + } + } + """, url); + } + + private String rerankResponse() { + return """ + { + "index": "d0760819-5a73-4d58-b163-3956d3648b62", + "results": [ + { + "index": 2, + "relevance_score": 0.98005307 + }, + { + "index": 3, + "relevance_score": 0.27904198 + }, + { + "index": 0, + "relevance_score": 0.10194652 + } + ], + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "search_units": 1 + } + } + } + """; + } + +} diff --git a/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/HuggingFaceServiceMixedIT.java b/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/HuggingFaceServiceMixedIT.java new file mode 100644 index 00000000000000..a2793f9060d8a8 --- /dev/null +++ b/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/HuggingFaceServiceMixedIT.java @@ -0,0 +1,147 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.qa.mixed; + +import org.elasticsearch.Version; +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.junit.AfterClass; +import org.junit.BeforeClass; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.not; + +public class HuggingFaceServiceMixedIT extends BaseMixedTestCase { + + private static final String HF_EMBEDDINGS_ADDED = "8.12.0"; + private static final String HF_ELSER_ADDED = "8.12.0"; + private static final String MINIMUM_SUPPORTED_VERSION = "8.15.0"; + + private static MockWebServer embeddingsServer; + private static MockWebServer elserServer; + + @BeforeClass + public static void startWebServer() throws IOException { + embeddingsServer = new MockWebServer(); + embeddingsServer.start(); + + elserServer = new MockWebServer(); + elserServer.start(); + } + + @AfterClass + public static void shutdown() { + embeddingsServer.close(); + elserServer.close(); + } + + @SuppressWarnings("unchecked") + public void testHFEmbeddings() throws IOException { + var embeddingsSupported = bwcVersion.onOrAfter(Version.fromString(HF_EMBEDDINGS_ADDED)); + assumeTrue("Hugging Face embedding service added in " + HF_EMBEDDINGS_ADDED, embeddingsSupported); + assumeTrue( + "HuggingFace service requires at least " + MINIMUM_SUPPORTED_VERSION, + bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION)) + ); + + final String inferenceId = "mixed-cluster-embeddings"; + + embeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse())); + put(inferenceId, embeddingConfig(getUrl(embeddingsServer)), TaskType.TEXT_EMBEDDING); + var configs = (List>) get(TaskType.TEXT_EMBEDDING, inferenceId).get("endpoints"); + assertThat(configs, hasSize(1)); + assertEquals("hugging_face", configs.get(0).get("service")); + assertEmbeddingInference(inferenceId); + } + + void assertEmbeddingInference(String inferenceId) throws IOException { + embeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse())); + var inferenceMap = inference(inferenceId, TaskType.TEXT_EMBEDDING, "some text"); + assertThat(inferenceMap.entrySet(), not(empty())); + } + + @SuppressWarnings("unchecked") + public void testElser() throws IOException { + var supported = bwcVersion.onOrAfter(Version.fromString(HF_ELSER_ADDED)); + assumeTrue("HF elser service added in " + HF_ELSER_ADDED, supported); + assumeTrue( + "HuggingFace service requires at least " + MINIMUM_SUPPORTED_VERSION, + bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION)) + ); + + final String inferenceId = "mixed-cluster-elser"; + final String upgradedClusterId = "upgraded-cluster-elser"; + + put(inferenceId, elserConfig(getUrl(elserServer)), TaskType.SPARSE_EMBEDDING); + + var configs = (List>) get(TaskType.SPARSE_EMBEDDING, inferenceId).get("endpoints"); + assertThat(configs, hasSize(1)); + assertEquals("hugging_face", configs.get(0).get("service")); + assertElser(inferenceId); + } + + private void assertElser(String inferenceId) throws IOException { + elserServer.enqueue(new MockResponse().setResponseCode(200).setBody(elserResponse())); + var inferenceMap = inference(inferenceId, TaskType.SPARSE_EMBEDDING, "some text"); + assertThat(inferenceMap.entrySet(), not(empty())); + } + + private String embeddingConfig(String url) { + return Strings.format(""" + { + "service": "hugging_face", + "service_settings": { + "url": "%s", + "api_key": "XXXX" + } + } + """, url); + } + + private String embeddingResponse() { + return """ + [ + [ + 0.014539449, + -0.015288644 + ] + ] + """; + } + + private String elserConfig(String url) { + return Strings.format(""" + { + "service": "hugging_face", + "service_settings": { + "api_key": "XXXX", + "url": "%s" + } + } + """, url); + } + + private String elserResponse() { + return """ + [ + { + ".": 0.133155956864357, + "the": 0.6747211217880249 + } + ] + """; + } + +} diff --git a/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/MixedClusterSpecTestCase.java b/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/MixedClusterSpecTestCase.java new file mode 100644 index 00000000000000..45cd3716f21df9 --- /dev/null +++ b/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/MixedClusterSpecTestCase.java @@ -0,0 +1,53 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.qa.mixed; + +import org.elasticsearch.Version; +import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.test.rest.ESRestTestCase; +import org.elasticsearch.test.rest.TestFeatureService; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.ClassRule; + +public abstract class MixedClusterSpecTestCase extends ESRestTestCase { + @ClassRule + public static ElasticsearchCluster cluster = MixedClustersSpec.mixedVersionCluster(); + + @Override + protected String getTestRestCluster() { + return cluster.getHttpAddresses(); + } + + static final Version bwcVersion = Version.fromString(System.getProperty("tests.old_cluster_version")); + + private static TestFeatureService oldClusterTestFeatureService = null; + + @Before + public void extractOldClusterFeatures() { + if (oldClusterTestFeatureService == null) { + oldClusterTestFeatureService = testFeatureService; + } + } + + protected static boolean oldClusterHasFeature(String featureId) { + assert oldClusterTestFeatureService != null; + return oldClusterTestFeatureService.clusterHasFeature(featureId); + } + + protected static boolean oldClusterHasFeature(NodeFeature feature) { + return oldClusterHasFeature(feature.id()); + } + + @AfterClass + public static void cleanUp() { + oldClusterTestFeatureService = null; + } + +} diff --git a/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/MixedClustersSpec.java b/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/MixedClustersSpec.java new file mode 100644 index 00000000000000..7802c2e966e019 --- /dev/null +++ b/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/MixedClustersSpec.java @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.qa.mixed; + +import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.test.cluster.local.distribution.DistributionType; +import org.elasticsearch.test.cluster.util.Version; + +public class MixedClustersSpec { + public static ElasticsearchCluster mixedVersionCluster() { + Version oldVersion = Version.fromString(System.getProperty("tests.old_cluster_version")); + return ElasticsearchCluster.local() + .distribution(DistributionType.DEFAULT) + .withNode(node -> node.version(oldVersion)) + .withNode(node -> node.version(Version.CURRENT)) + .setting("xpack.security.enabled", "false") + .setting("xpack.license.self_generated.type", "trial") + .build(); + } +} diff --git a/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/OpenAIServiceMixedIT.java b/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/OpenAIServiceMixedIT.java new file mode 100644 index 00000000000000..33cad6a179281a --- /dev/null +++ b/x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/OpenAIServiceMixedIT.java @@ -0,0 +1,223 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.qa.mixed; + +import org.elasticsearch.Version; +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.junit.AfterClass; +import org.junit.BeforeClass; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.qa.mixed.MixedClusterSpecTestCase.bwcVersion; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasEntry; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.not; + +public class OpenAIServiceMixedIT extends BaseMixedTestCase { + + private static final String OPEN_AI_EMBEDDINGS_ADDED = "8.12.0"; + private static final String OPEN_AI_EMBEDDINGS_MODEL_SETTING_MOVED = "8.13.0"; + private static final String OPEN_AI_COMPLETIONS_ADDED = "8.14.0"; + private static final String MINIMUM_SUPPORTED_VERSION = "8.15.0"; + + private static MockWebServer openAiEmbeddingsServer; + private static MockWebServer openAiChatCompletionsServer; + + @BeforeClass + public static void startWebServer() throws IOException { + openAiEmbeddingsServer = new MockWebServer(); + openAiEmbeddingsServer.start(); + + openAiChatCompletionsServer = new MockWebServer(); + openAiChatCompletionsServer.start(); + } + + @AfterClass + public static void shutdown() { + openAiEmbeddingsServer.close(); + openAiChatCompletionsServer.close(); + } + + @SuppressWarnings("unchecked") + public void testOpenAiEmbeddings() throws IOException { + var openAiEmbeddingsSupported = bwcVersion.onOrAfter(Version.fromString(OPEN_AI_EMBEDDINGS_ADDED)); + assumeTrue("OpenAI embedding service added in " + OPEN_AI_EMBEDDINGS_ADDED, openAiEmbeddingsSupported); + assumeTrue( + "OpenAI service requires at least " + MINIMUM_SUPPORTED_VERSION, + bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION)) + ); + + final String inferenceId = "mixed-cluster-embeddings"; + + String inferenceConfig = oldClusterVersionCompatibleEmbeddingConfig(); + // queue a response as PUT will call the service + openAiEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse())); + put(inferenceId, inferenceConfig, TaskType.TEXT_EMBEDDING); + + var configs = (List>) get(TaskType.TEXT_EMBEDDING, inferenceId).get("endpoints"); + assertThat(configs, hasSize(1)); + assertEquals("openai", configs.get(0).get("service")); + var serviceSettings = (Map) configs.get(0).get("service_settings"); + var taskSettings = (Map) configs.get(0).get("task_settings"); + var modelIdFound = serviceSettings.containsKey("model_id") || taskSettings.containsKey("model_id"); + assertTrue("model_id not found in config: " + configs.toString(), modelIdFound); + + assertEmbeddingInference(inferenceId); + } + + void assertEmbeddingInference(String inferenceId) throws IOException { + openAiEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse())); + var inferenceMap = inference(inferenceId, TaskType.TEXT_EMBEDDING, "some text"); + assertThat(inferenceMap.entrySet(), not(empty())); + } + + @SuppressWarnings("unchecked") + public void testOpenAiCompletions() throws IOException { + var openAiEmbeddingsSupported = bwcVersion.onOrAfter(Version.fromString(OPEN_AI_EMBEDDINGS_ADDED)); + assumeTrue("OpenAI completions service added in " + OPEN_AI_COMPLETIONS_ADDED, openAiEmbeddingsSupported); + assumeTrue( + "OpenAI service requires at least " + MINIMUM_SUPPORTED_VERSION, + bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION)) + ); + + final String inferenceId = "mixed-cluster-completions"; + final String upgradedClusterId = "upgraded-cluster-completions"; + + put(inferenceId, chatCompletionsConfig(getUrl(openAiChatCompletionsServer)), TaskType.COMPLETION); + + var configsMap = get(TaskType.COMPLETION, inferenceId); + logger.warn("Configs: {}", configsMap); + var configs = (List>) configsMap.get("endpoints"); + assertThat(configs, hasSize(1)); + assertEquals("openai", configs.get(0).get("service")); + var serviceSettings = (Map) configs.get(0).get("service_settings"); + assertThat(serviceSettings, hasEntry("model_id", "gpt-4")); + var taskSettings = (Map) configs.get(0).get("task_settings"); + assertThat(taskSettings.keySet(), empty()); + + assertCompletionInference(inferenceId); + } + + void assertCompletionInference(String inferenceId) throws IOException { + openAiChatCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionsResponse())); + var inferenceMap = inference(inferenceId, TaskType.COMPLETION, "some text"); + assertThat(inferenceMap.entrySet(), not(empty())); + } + + private String oldClusterVersionCompatibleEmbeddingConfig() { + if (getOldClusterTestVersion().before(OPEN_AI_EMBEDDINGS_MODEL_SETTING_MOVED)) { + return embeddingConfigWithModelInTaskSettings(getUrl(openAiEmbeddingsServer)); + } else { + return embeddingConfigWithModelInServiceSettings(getUrl(openAiEmbeddingsServer)); + } + } + + protected static org.elasticsearch.test.cluster.util.Version getOldClusterTestVersion() { + return org.elasticsearch.test.cluster.util.Version.fromString(bwcVersion.toString()); + } + + private String embeddingConfigWithModelInTaskSettings(String url) { + return Strings.format(""" + { + "service": "openai", + "service_settings": { + "api_key": "XXXX", + "url": "%s" + }, + "task_settings": { + "model": "text-embedding-ada-002" + } + } + """, url); + } + + static String embeddingConfigWithModelInServiceSettings(String url) { + return Strings.format(""" + { + "service": "openai", + "service_settings": { + "api_key": "XXXX", + "url": "%s", + "model_id": "text-embedding-ada-002" + } + } + """, url); + } + + private String chatCompletionsConfig(String url) { + return Strings.format(""" + { + "service": "openai", + "service_settings": { + "api_key": "XXXX", + "url": "%s", + "model_id": "gpt-4" + } + } + """, url); + } + + static String embeddingResponse() { + return """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + } + + private String chatCompletionsResponse() { + return """ + { + "id": "some-id", + "object": "chat.completion", + "created": 1705397787, + "model": "gpt-3.5-turbo-0613", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "some content" + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 46, + "completion_tokens": 39, + "total_tokens": 85 + }, + "system_fingerprint": null + } + """; + } + +}