Skip to content

Commit

Permalink
[Inference API] Add Azure AI Studio Embeddings and Chat Completion Support (elastic#108472)
Browse files Browse the repository at this point in the history

* redo after messy merge commit

* cleanups; refactoring; and added a few tests

* filter xContent ratelimit; reduce boilerplate code

* fix checkstyle issue

* ... and spotlessApply

* set lower rate limit 240; rename back files

* clean lint

* fix code and tests after merge

* change completion temp and top_p to double

* clean lint

* add default max_new_tokens of 64

* constrain top_p temperature to 0.0-2.0 range

* remove Snowflake provider; cleanups
  • Loading branch information
markjhoy committed May 15, 2024
1 parent 172c059 commit e87047f
Show file tree
Hide file tree
Showing 55 changed files with 7,191 additions and 77 deletions.
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ static TransportVersion def(int id) {
public static final TransportVersion JOIN_STATUS_AGE_SERIALIZATION = def(8_656_00_0);
public static final TransportVersion ML_RERANK_DOC_OPTIONAL = def(8_657_00_0);
public static final TransportVersion FAILURE_STORE_FIELD_PARITY = def(8_658_00_0);
public static final TransportVersion ML_INFERENCE_AZURE_AI_STUDIO = def(8_659_00_0);

/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings;
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionTaskSettings;
Expand Down Expand Up @@ -69,106 +73,137 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
new NamedWriteableRegistry.Entry(InferenceResults.class, LegacyTextEmbeddingResults.NAME, LegacyTextEmbeddingResults::new)
);

// Inference results
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceServiceResults.class, SparseEmbeddingResults.NAME, SparseEmbeddingResults::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceServiceResults.class, TextEmbeddingResults.NAME, TextEmbeddingResults::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceServiceResults.class, TextEmbeddingByteResults.NAME, TextEmbeddingByteResults::new)
);
addInferenceResultsNamedWriteables(namedWriteables);
addChunkedInferenceResultsNamedWriteables(namedWriteables);

// Empty default task settings
namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, EmptyTaskSettings.NAME, EmptyTaskSettings::new));

// Default secret settings
namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, DefaultSecretSettings.NAME, DefaultSecretSettings::new));

addInternalElserNamedWriteables(namedWriteables);

// Internal TextEmbedding service config
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceServiceResults.class, ChatCompletionResults.NAME, ChatCompletionResults::new)
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
ElasticsearchInternalServiceSettings.NAME,
ElasticsearchInternalServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceServiceResults.class, RankedDocsResults.NAME, RankedDocsResults::new)
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
MultilingualE5SmallInternalServiceSettings.NAME,
MultilingualE5SmallInternalServiceSettings::new
)
);

// Chunked inference results
addHuggingFaceNamedWriteables(namedWriteables);
addOpenAiNamedWriteables(namedWriteables);
addCohereNamedWriteables(namedWriteables);
addAzureOpenAiNamedWriteables(namedWriteables);
addAzureAiStudioNamedWriteables(namedWriteables);

return namedWriteables;
}

private static void addAzureAiStudioNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(
InferenceServiceResults.class,
ErrorChunkedInferenceResults.NAME,
ErrorChunkedInferenceResults::new
ServiceSettings.class,
AzureAiStudioEmbeddingsServiceSettings.NAME,
AzureAiStudioEmbeddingsServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
InferenceServiceResults.class,
ChunkedSparseEmbeddingResults.NAME,
ChunkedSparseEmbeddingResults::new
TaskSettings.class,
AzureAiStudioEmbeddingsTaskSettings.NAME,
AzureAiStudioEmbeddingsTaskSettings::new
)
);

namedWriteables.add(
new NamedWriteableRegistry.Entry(
InferenceServiceResults.class,
ChunkedTextEmbeddingResults.NAME,
ChunkedTextEmbeddingResults::new
ServiceSettings.class,
AzureAiStudioChatCompletionServiceSettings.NAME,
AzureAiStudioChatCompletionServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
InferenceServiceResults.class,
ChunkedTextEmbeddingFloatResults.NAME,
ChunkedTextEmbeddingFloatResults::new
TaskSettings.class,
AzureAiStudioChatCompletionTaskSettings.NAME,
AzureAiStudioChatCompletionTaskSettings::new
)
);
}

private static void addAzureOpenAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(
InferenceServiceResults.class,
ChunkedTextEmbeddingByteResults.NAME,
ChunkedTextEmbeddingByteResults::new
AzureOpenAiSecretSettings.class,
AzureOpenAiSecretSettings.NAME,
AzureOpenAiSecretSettings::new
)
);

// Empty default task settings
namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, EmptyTaskSettings.NAME, EmptyTaskSettings::new));

// Default secret settings
namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, DefaultSecretSettings.NAME, DefaultSecretSettings::new));

// Internal ELSER config
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, ElserInternalServiceSettings.NAME, ElserInternalServiceSettings::new)
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
AzureOpenAiEmbeddingsServiceSettings.NAME,
AzureOpenAiEmbeddingsServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, ElserMlNodeTaskSettings.NAME, ElserMlNodeTaskSettings::new)
new NamedWriteableRegistry.Entry(
TaskSettings.class,
AzureOpenAiEmbeddingsTaskSettings.NAME,
AzureOpenAiEmbeddingsTaskSettings::new
)
);

// Internal TextEmbedding service config
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
ElasticsearchInternalServiceSettings.NAME,
ElasticsearchInternalServiceSettings::new
AzureOpenAiCompletionServiceSettings.NAME,
AzureOpenAiCompletionServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
MultilingualE5SmallInternalServiceSettings.NAME,
MultilingualE5SmallInternalServiceSettings::new
TaskSettings.class,
AzureOpenAiCompletionTaskSettings.NAME,
AzureOpenAiCompletionTaskSettings::new
)
);
}

// Hugging Face config
private static void addCohereNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, CohereServiceSettings.NAME, CohereServiceSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
HuggingFaceElserServiceSettings.NAME,
HuggingFaceElserServiceSettings::new
CohereEmbeddingsServiceSettings.NAME,
CohereEmbeddingsServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, HuggingFaceServiceSettings.NAME, HuggingFaceServiceSettings::new)
new NamedWriteableRegistry.Entry(TaskSettings.class, CohereEmbeddingsTaskSettings.NAME, CohereEmbeddingsTaskSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(SecretSettings.class, HuggingFaceElserSecretSettings.NAME, HuggingFaceElserSecretSettings::new)
new NamedWriteableRegistry.Entry(ServiceSettings.class, CohereRerankServiceSettings.NAME, CohereRerankServiceSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, CohereRerankTaskSettings.NAME, CohereRerankTaskSettings::new)
);
}

// OpenAI
private static void addOpenAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
Expand All @@ -193,67 +228,86 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
OpenAiChatCompletionTaskSettings::new
)
);
}

// Cohere
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, CohereServiceSettings.NAME, CohereServiceSettings::new)
);
private static void addHuggingFaceNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
CohereEmbeddingsServiceSettings.NAME,
CohereEmbeddingsServiceSettings::new
HuggingFaceElserServiceSettings.NAME,
HuggingFaceElserServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, CohereEmbeddingsTaskSettings.NAME, CohereEmbeddingsTaskSettings::new)
new NamedWriteableRegistry.Entry(ServiceSettings.class, HuggingFaceServiceSettings.NAME, HuggingFaceServiceSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, CohereRerankServiceSettings.NAME, CohereRerankServiceSettings::new)
new NamedWriteableRegistry.Entry(SecretSettings.class, HuggingFaceElserSecretSettings.NAME, HuggingFaceElserSecretSettings::new)
);
}

private static void addInternalElserNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, CohereRerankTaskSettings.NAME, CohereRerankTaskSettings::new)
new NamedWriteableRegistry.Entry(ServiceSettings.class, ElserInternalServiceSettings.NAME, ElserInternalServiceSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, ElserMlNodeTaskSettings.NAME, ElserMlNodeTaskSettings::new)
);
}

// Azure OpenAI
private static void addChunkedInferenceResultsNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(
AzureOpenAiSecretSettings.class,
AzureOpenAiSecretSettings.NAME,
AzureOpenAiSecretSettings::new
InferenceServiceResults.class,
ErrorChunkedInferenceResults.NAME,
ErrorChunkedInferenceResults::new
)
);

namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
AzureOpenAiEmbeddingsServiceSettings.NAME,
AzureOpenAiEmbeddingsServiceSettings::new
InferenceServiceResults.class,
ChunkedSparseEmbeddingResults.NAME,
ChunkedSparseEmbeddingResults::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
TaskSettings.class,
AzureOpenAiEmbeddingsTaskSettings.NAME,
AzureOpenAiEmbeddingsTaskSettings::new
InferenceServiceResults.class,
ChunkedTextEmbeddingResults.NAME,
ChunkedTextEmbeddingResults::new
)
);

namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
AzureOpenAiCompletionServiceSettings.NAME,
AzureOpenAiCompletionServiceSettings::new
InferenceServiceResults.class,
ChunkedTextEmbeddingFloatResults.NAME,
ChunkedTextEmbeddingFloatResults::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
TaskSettings.class,
AzureOpenAiCompletionTaskSettings.NAME,
AzureOpenAiCompletionTaskSettings::new
InferenceServiceResults.class,
ChunkedTextEmbeddingByteResults.NAME,
ChunkedTextEmbeddingByteResults::new
)
);
}

return namedWriteables;
private static void addInferenceResultsNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceServiceResults.class, SparseEmbeddingResults.NAME, SparseEmbeddingResults::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceServiceResults.class, TextEmbeddingResults.NAME, TextEmbeddingResults::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceServiceResults.class, TextEmbeddingByteResults.NAME, TextEmbeddingByteResults::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceServiceResults.class, ChatCompletionResults.NAME, ChatCompletionResults::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceServiceResults.class, RankedDocsResults.NAME, RankedDocsResults::new)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import org.elasticsearch.xpack.inference.rest.RestInferenceAction;
import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
Expand Down Expand Up @@ -190,6 +191,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
context -> new OpenAiService(httpFactory.get(), serviceComponents.get()),
context -> new CohereService(httpFactory.get(), serviceComponents.get()),
context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get()),
context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get()),
ElasticsearchInternalService::new
);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.azureaistudio;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.AzureAiStudioRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;

import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;

/**
 * An {@link ExecutableAction} that hands inference inputs to a {@link Sender} together with an
 * {@link AzureAiStudioRequestManager}.
 *
 * <p>The constructor is protected, so instances are created through subclasses (or code in the
 * same package) that supply the sender, the request manager and the error message used when
 * reporting failures.
 */
public class AzureAiStudioAction implements ExecutableAction {
    /** Sender that the request manager and inputs are passed to. */
    protected final Sender sender;
    /** Request manager passed to {@link #sender} on every execution. */
    protected final AzureAiStudioRequestManager requestCreator;
    /** Message passed to the failure-wrapping helpers when an execution fails. */
    protected final String errorMessage;

    /**
     * Stores the collaborators as-is; no validation is performed here.
     *
     * @param sender         sender used by {@link #execute}
     * @param requestCreator request manager forwarded to the sender
     * @param errorMessage   message used when wrapping or creating failure exceptions
     */
    protected AzureAiStudioAction(Sender sender, AzureAiStudioRequestManager requestCreator, String errorMessage) {
        this.errorMessage = errorMessage;
        this.requestCreator = requestCreator;
        this.sender = sender;
    }

    /**
     * Sends the inputs through the sender, using a listener produced by
     * {@code wrapFailuresInElasticsearchException(errorMessage, listener)}.
     *
     * <p>Exceptions thrown synchronously while wrapping the listener or calling the sender are not
     * propagated to the caller; they are delivered to the original {@code listener} instead.
     *
     * @param inferenceInputs inputs forwarded to the sender
     * @param timeout         timeout forwarded to the sender
     * @param listener        receives the result, or any failure
     */
    @Override
    public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
        try {
            sender.send(requestCreator, inferenceInputs, timeout, wrapFailuresInElasticsearchException(errorMessage, listener));
        } catch (ElasticsearchException elasticsearchFailure) {
            // Already an ElasticsearchException: report it unchanged.
            listener.onFailure(elasticsearchFailure);
        } catch (Exception unexpectedFailure) {
            // Any other exception is converted via createInternalServerError before being reported.
            listener.onFailure(createInternalServerError(unexpectedFailure, errorMessage));
        }
    }
}
Loading

0 comments on commit e87047f

Please sign in to comment.