Skip to content

Commit

Permalink
[ML] Add mixed cluster tests for inference (elastic#108392)
Browse files Browse the repository at this point in the history
* mixed cluster tests are executable

* add tests from upgrade tests

* [ML] Add mixed cluster tests for existing services

* clean up

* review improvements

* spotless

* remove blocked AzureOpenAI mixed IT

* improvements from DK review

* temp for testing

* refactoring and documentation

* Revert manual testing configs of "temp for testing"

This reverts parts of commit fca46fd.

* revert TESTING.asciidoc formatting

* Update TESTING.asciidoc to avoid reformatting

* add minimum version for tests to match minimum version in services

* spotless
  • Loading branch information
maxhniebergall authored and parkertimmins committed May 17, 2024
1 parent 3a49909 commit d53f684
Show file tree
Hide file tree
Showing 8 changed files with 896 additions and 5 deletions.
16 changes: 11 additions & 5 deletions TESTING.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -551,13 +551,19 @@ When running `./gradlew check`, minimal bwc checks are also run against compatib

==== BWC Testing against a specific remote/branch

Sometimes a backward compatibility change spans two versions. A common case is a new functionality
that needs a BWC bridge in an unreleased version of a release branch (for example, 5.x).
To test the changes, you can instruct Gradle to build the BWC version from another remote/branch combination instead of
pulling the release branch from GitHub. You do so using the `bwc.remote` and `bwc.refspec.BRANCH` system properties:
Sometimes a backward compatibility change spans two versions.
A common case is new functionality that needs a BWC bridge in an unreleased version of a release branch (for example, 5.x).
Another use case, since the introduction of serverless, is to test BWC against main in addition to the other released branches.
To do so, specify the `bwc.refspec` remote and branch to use for the BWC build as `origin/main`.
To test against main, you will also need to create a new version in link:./server/src/main/java/org/elasticsearch/Version.java[Version.java],
increment `elasticsearch` in link:./build-tools-internal/version.properties[version.properties], and hard-code the `project.version` for ml-cpp
in link:./x-pack/plugin/ml/build.gradle[ml/build.gradle].

In general, to test the changes, you can instruct Gradle to build the BWC version from another remote/branch combination instead of pulling the release branch from GitHub.
You do so using the `bwc.refspec.{VERSION}` system property:

-------------------------------------------------
./gradlew check -Dbwc.remote=${remote} -Dbwc.refspec.5.x=index_req_bwc_5.x
./gradlew check -Dtests.bwc.refspec.8.15=origin/main
-------------------------------------------------

The branch needs to be available on the remote that the BWC makes of the
Expand Down
37 changes: 37 additions & 0 deletions x-pack/plugin/inference/qa/mixed-cluster/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import org.elasticsearch.gradle.Version
import org.elasticsearch.gradle.VersionProperties
import org.elasticsearch.gradle.util.GradleUtils
import org.elasticsearch.gradle.internal.info.BuildParams
import org.elasticsearch.gradle.testclusters.StandaloneRestIntegTestTask

// Internal build plugins, by name: Java REST test support, the test-artifact base,
// and the backward-compatibility (BWC) test wiring used further down in this script.
apply plugin: 'elasticsearch.internal-java-rest-test'
apply plugin: 'elasticsearch.internal-test-artifact-base'
apply plugin: 'elasticsearch.bwc-test'

dependencies {
// Test-scope dependency on the sibling inference-service-tests QA project.
testImplementation project(path: ':x-pack:plugin:inference:qa:inference-service-tests')
// x-pack core is needed at compile time only.
compileOnly project(':x-pack:plugin:core')
// The javaRestTest source set compiles against x-pack core's test artifact and the inference plugin itself.
javaRestTestImplementation(testArtifact(project(xpackModule('core'))))
javaRestTestImplementation project(path: xpackModule('inference'))
// Plugin added to the test cluster nodes (presumably supplies mock inference services -- confirm against that project).
clusterPlugins project(
':x-pack:plugin:inference:qa:test-service-plugin'
)
}

// Inference is available in 8.11 or later, so this predicate accepts only
// BWC versions at or after 8.11.0.
def supportedVersion = { candidateVersion ->
  candidateVersion.onOrAfter(Version.fromString("8.11.0"))
}

// For every wire-compatible BWC version accepted by supportedVersion, register:
//   1. a "v<version>#javaRestTest" task that runs the REST tests against that BWC distribution, and
//   2. the per-version BWC task (named by bwcTaskName) that depends on it.
BuildParams.bwcVersions.withWireCompatible(supportedVersion) { bwcVersion, baseName ->
  def mixedClusterRestTest = tasks.register("v${bwcVersion}#javaRestTest", StandaloneRestIntegTestTask) {
    usesBwcDistribution(bwcVersion)
    // Exposes the old node version to the tests as a system property.
    systemProperty("tests.old_cluster_version", bwcVersion)
    // Run the suite in a single test JVM.
    maxParallelForks = 1
  }

  tasks.register(bwcTaskName(bwcVersion)) {
    dependsOn(mixedClusterRestTest)
  }
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.qa.mixed;

import org.apache.http.util.EntityUtils;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.hamcrest.Matchers;

import java.io.IOException;
import java.util.List;
import java.util.Map;

/**
 * Base class for inference mixed-cluster REST tests.
 * <p>
 * It bundles thin wrappers around the {@code _inference} REST endpoints so that concrete tests can
 * create, read, invoke and delete inference endpoints with a single call. Every helper sends its
 * request through {@link ESRestTestCase#client()} and asserts on the response before returning.
 */
public abstract class BaseMixedTestCase extends MixedClusterSpecTestCase {

    /**
     * Builds the base HTTP URL ({@code http://host:port}) of the given mock web server.
     *
     * @param webServer the mock server whose host name and port are used
     * @return the URL without a trailing slash
     */
    protected static String getUrl(MockWebServer webServer) {
        return Strings.format("http://%s:%s", webServer.getHostName(), webServer.getPort());
    }

    /**
     * Configures the REST client to send a basic-auth {@code Authorization} header for the
     * {@code x_pack_rest_user} test user.
     */
    @Override
    protected Settings restClientSettings() {
        String token = ESRestTestCase.basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray()));
        return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
    }

    /**
     * Deletes the inference endpoint addressed by task type and id and asserts the call succeeded.
     *
     * @param inferenceId id of the inference endpoint
     * @param taskType    task type used in the request path
     * @throws IOException if the request fails
     */
    protected void delete(String inferenceId, TaskType taskType) throws IOException {
        performAndAssertOk(new Request("DELETE", Strings.format("_inference/%s/%s", taskType, inferenceId)));
    }

    /**
     * Deletes the inference endpoint addressed by id only and asserts the call succeeded.
     *
     * @param inferenceId id of the inference endpoint
     * @throws IOException if the request fails
     */
    protected void delete(String inferenceId) throws IOException {
        performAndAssertOk(new Request("DELETE", Strings.format("_inference/%s", inferenceId)));
    }

    /**
     * Fetches {@code _inference/_all}.
     *
     * @return the response body parsed into a map
     * @throws IOException if the request fails
     */
    protected Map<String, Object> getAll() throws IOException {
        return performAndParse(new Request("GET", "_inference/_all"));
    }

    /**
     * Fetches the inference endpoint with the given id.
     *
     * @param inferenceId id of the inference endpoint
     * @return the response body parsed into a map
     * @throws IOException if the request fails
     */
    protected Map<String, Object> get(String inferenceId) throws IOException {
        return performAndParse(new Request("GET", Strings.format("_inference/%s", inferenceId)));
    }

    /**
     * Fetches the inference endpoint addressed by task type and id.
     *
     * @param taskType    task type used in the request path
     * @param inferenceId id of the inference endpoint
     * @return the response body parsed into a map
     * @throws IOException if the request fails
     */
    protected Map<String, Object> get(TaskType taskType, String inferenceId) throws IOException {
        return performAndParse(new Request("GET", Strings.format("_inference/%s/%s", taskType, inferenceId)));
    }

    /**
     * Runs inference on a single input string.
     *
     * @param inferenceId id of the inference endpoint
     * @param taskType    task type used in the request path
     * @param input       the text to send; it is JSON-escaped before being embedded in the body
     * @return the response body parsed into a map
     * @throws IOException if the request fails
     */
    protected Map<String, Object> inference(String inferenceId, TaskType taskType, String input) throws IOException {
        var request = new Request("POST", Strings.format("_inference/%s/%s", taskType, inferenceId));
        request.setJsonEntity("{\"input\": [" + quoteJson(input) + "]}");
        return performAndParse(request);
    }

    /**
     * Sends a rerank request with the given query and candidate inputs.
     *
     * @param inferenceId id of the rerank inference endpoint
     * @param inputs      the documents to rank; each one is JSON-escaped
     * @param query       the query to rank against; JSON-escaped
     * @return the response body parsed into a map
     * @throws IOException if the request fails
     */
    protected Map<String, Object> rerank(String inferenceId, List<String> inputs, String query) throws IOException {
        var request = new Request("POST", Strings.format("_inference/rerank/%s", inferenceId));

        StringBuilder body = new StringBuilder("{\"query\":").append(quoteJson(query)).append(",\"input\":[");
        boolean first = true;
        for (String input : inputs) {
            if (first == false) {
                body.append(",");
            }
            body.append(quoteJson(input));
            first = false;
        }
        body.append("]}");
        request.setJsonEntity(body.toString());

        return performAndParse(request);
    }

    /**
     * Creates (or replaces) an inference endpoint from a raw JSON configuration. The request is
     * sent with {@code error_trace} so that failures carry a stack trace in the response.
     *
     * @param inferenceId id of the inference endpoint
     * @param modelConfig the JSON request body, sent as-is
     * @param taskType    task type used in the request path
     * @throws IOException if the request fails
     */
    protected void put(String inferenceId, String modelConfig, TaskType taskType) throws IOException {
        String endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, inferenceId);
        var request = new Request("PUT", endpoint);
        request.setJsonEntity(modelConfig);
        var response = ESRestTestCase.client().performRequest(request);
        // A response that reaches this point is not a problem, so it is logged at INFO rather
        // than WARN, and only through the logger (no direct writes to stdout).
        logger.info("PUT response: {}", response);
        ESRestTestCase.assertOKAndConsume(response);
    }

    /**
     * Asserts that the response status is 200 or 201. On failure the response body, if any, is
     * used as the assertion message.
     *
     * @param response the response to check
     * @throws IOException if the response body cannot be read
     */
    protected static void assertOkOrCreated(Response response) throws IOException {
        int statusCode = response.getStatusLine().getStatusCode();
        // Once EntityUtils.toString(entity) is called the entity cannot be reused.
        // Avoid that call with check here.
        if (statusCode == 200 || statusCode == 201) {
            return;
        }

        // A response may carry no entity; EntityUtils.toString rejects null, which would hide
        // the real status failure behind an unrelated exception.
        var entity = response.getEntity();
        String responseStr = entity == null ? "" : EntityUtils.toString(entity);
        ESTestCase.assertThat(responseStr, statusCode, Matchers.anyOf(Matchers.equalTo(200), Matchers.equalTo(201)));
    }

    /** Performs the request, asserts the response is OK, and returns the response. */
    private static Response performAndAssertOk(Request request) throws IOException {
        var response = ESRestTestCase.client().performRequest(request);
        ESRestTestCase.assertOK(response);
        return response;
    }

    /** Performs the request, asserts the response is OK, and parses the body into a map. */
    private static Map<String, Object> performAndParse(Request request) throws IOException {
        return ESRestTestCase.entityAsMap(performAndAssertOk(request));
    }

    /**
     * Encodes a value as a double-quoted JSON string literal, escaping quotes, backslashes and
     * control characters so that arbitrary test input cannot break the request body.
     * A {@code null} value is rendered as the text {@code "null"}, matching what plain string
     * concatenation produced before.
     */
    private static String quoteJson(String value) {
        String text = String.valueOf(value);
        StringBuilder quoted = new StringBuilder(text.length() + 2).append('"');
        for (int i = 0; i < text.length(); i++) {
            char c = text.charAt(i);
            switch (c) {
                case '"':
                    quoted.append("\\\"");
                    break;
                case '\\':
                    quoted.append("\\\\");
                    break;
                case '\n':
                    quoted.append("\\n");
                    break;
                case '\r':
                    quoted.append("\\r");
                    break;
                case '\t':
                    quoted.append("\\t");
                    break;
                case '\b':
                    quoted.append("\\b");
                    break;
                case '\f':
                    quoted.append("\\f");
                    break;
                default:
                    if (c < 0x20) {
                        // Remaining control characters must be written as \\uXXXX escapes.
                        quoted.append(Strings.format("\\u%04x", (int) c));
                    } else {
                        quoted.append(c);
                    }
            }
        }
        return quoted.append('"').toString();
    }
}

0 comments on commit d53f684

Please sign in to comment.