From 89f84327f8167d25fd1c638a80be224bd71426c6 Mon Sep 17 00:00:00 2001 From: lvchzh Date: Thu, 30 May 2024 19:14:07 +0800 Subject: [PATCH 1/4] feature:add Wenxin model client --- models/spring-ai-wenxin/README.md | 0 models/spring-ai-wenxin/pom.xml | 95 ++++ .../ai/wenxin/WenxinAudioSpeechModel.java | 10 + .../ai/wenxin/WenxinAudioSpeechOptions.java | 10 + .../wenxin/WenxinAudioTranscriptionModel.java | 10 + .../WenxinAudioTranscriptionOptions.java | 10 + .../ai/wenxin/WenxinChatModel.java | 261 ++++++++++ .../ai/wenxin/WenxinChatOptions.java | 483 ++++++++++++++++++ .../ai/wenxin/WenxinEmbeddingModel.java | 119 +++++ .../ai/wenxin/WenxinEmbeddingOptions.java | 61 +++ .../ai/wenxin/WenxinImageModel.java | 10 + .../ai/wenxin/WenxinImageOptions.java | 10 + .../ai/wenxin/aot/WenxinRuntimeHints.java | 35 ++ .../ai/wenxin/api/ApiUtils.java | 102 ++++ .../api/CustomResponseErrorHandler.java | 39 ++ .../ai/wenxin/api/WenxinApi.java | 463 +++++++++++++++++ .../ai/wenxin/api/WenxinAudioApi.java | 10 + .../ai/wenxin/api/WenxinImageApi.java | 10 + .../WenxinStreamFunctionCallingHelper.java | 191 +++++++ .../common/WenxinApiClientErrorException.java | 18 + .../wenxin/api/common/WenxinApiException.java | 18 + .../ai/wenxin/audio/speech/Speech.java | 10 + .../ai/wenxin/audio/speech/SpeechMessage.java | 10 + .../ai/wenxin/audio/speech/SpeechModel.java | 10 + .../ai/wenxin/audio/speech/SpeechPrompt.java | 10 + .../wenxin/audio/speech/SpeechResponse.java | 10 + .../audio/speech/StreamingSpeechModel.java | 10 + .../transcription/AudioTranscription.java | 10 + .../AudioTranscriptionPrompt.java | 10 + .../AudioTranscriptionResponse.java | 10 + .../metadata/WenxinChatResponseMetadata.java | 74 +++ .../WenxinImageGenerationMetadata.java | 10 + .../metadata/WenxinImageResponseMetadata.java | 10 + .../ai/wenxin/metadata/WenxinRateLimit.java | 69 +++ .../ai/wenxin/metadata/WenxinUsage.java | 49 ++ .../audio/WenxinAudioSpeechMetadata.java | 10 + .../WenxinAudioSpeechResponseMetadata.java | 10 + .../WenxinAudioTranscriptionMetadata.java | 10 + ...nxinAudioTransriptionResponseMetadata.java | 10 + .../support/WenxinApiResponseHeaders.java | 34 ++ .../WenxinResponseHeaderExtractor.java | 60 +++ .../resources/META-INF/spring/aot.factories | 2 + .../ai/wenxin/api/WenxinApiIT.java | 45 ++ .../ai/wenxin/embedding/EmbeddingIT.java | 30 ++ pom.xml | 2 + spring-ai-bom/pom.xml | 5 + spring-ai-spring-boot-autoconfigure/pom.xml | 7 + .../wenxin/WenxinAudioSpeechProperties.java | 10 + .../WenxinAudioTranscriptionProperties.java | 10 + .../wenxin/WenxinAutoConfiguration.java | 105 ++++ .../wenxin/WenxinChatProperties.java | 44 ++ .../wenxin/WenxinConnectionProperties.java | 21 + .../wenxin/WenxinEmbeddingProperties.java | 51 ++ .../wenxin/WenxinImageProperties.java | 10 + .../wenxin/WenxinParentProperties.java | 40 ++ ...ot.autoconfigure.AutoConfiguration.imports | 1 + .../spring-ai-starter-wenxin/pom.xml | 42 ++ 57 files changed, 2826 insertions(+) create mode 100644 models/spring-ai-wenxin/README.md create mode 100644 models/spring-ai-wenxin/pom.xml create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinAudioSpeechModel.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinAudioSpeechOptions.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinAudioTranscriptionModel.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinAudioTranscriptionOptions.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinChatModel.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinChatOptions.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinEmbeddingModel.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinEmbeddingOptions.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinImageModel.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinImageOptions.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/aot/WenxinRuntimeHints.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/ApiUtils.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/CustomResponseErrorHandler.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinApi.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinAudioApi.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinImageApi.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinStreamFunctionCallingHelper.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/common/WenxinApiClientErrorException.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/common/WenxinApiException.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/speech/Speech.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/speech/SpeechMessage.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/speech/SpeechModel.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/speech/SpeechPrompt.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/speech/SpeechResponse.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/speech/StreamingSpeechModel.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/transcription/AudioTranscription.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/transcription/AudioTranscriptionPrompt.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/transcription/AudioTranscriptionResponse.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinChatResponseMetadata.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinImageGenerationMetadata.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinImageResponseMetadata.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinRateLimit.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinUsage.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/audio/WenxinAudioSpeechMetadata.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/audio/WenxinAudioSpeechResponseMetadata.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/audio/WenxinAudioTranscriptionMetadata.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/audio/WenxinAudioTransriptionResponseMetadata.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/support/WenxinApiResponseHeaders.java create mode 100644 models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/support/WenxinResponseHeaderExtractor.java create mode 100644 models/spring-ai-wenxin/src/main/resources/META-INF/spring/aot.factories create mode 100644 models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/api/WenxinApiIT.java create mode 100644 models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/embedding/EmbeddingIT.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinAudioSpeechProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinAudioTranscriptionProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinAutoConfiguration.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinChatProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinConnectionProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinEmbeddingProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinImageProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinParentProperties.java create mode 100644 spring-ai-spring-boot-starters/spring-ai-starter-wenxin/pom.xml diff --git a/models/spring-ai-wenxin/README.md b/models/spring-ai-wenxin/README.md new file mode 100644 index 00000000000..e69de29bb2d diff --git a/models/spring-ai-wenxin/pom.xml b/models/spring-ai-wenxin/pom.xml new file mode 100644 index 00000000000..0dd62468c9d --- /dev/null +++ b/models/spring-ai-wenxin/pom.xml @@ -0,0 +1,95 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-wenxin + jar + Spring AI Wenxin + Wenxin support + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + + commons-codec + commons-codec + + + + + org.springframework.ai + spring-ai-core + ${project.parent.version} + + + + org.springframework.ai + spring-ai-retry + ${project.parent.version} + + + + + org.springframework.boot + spring-boot + + + + io.rest-assured + json-path + + + + + com.github.victools + jsonschema-generator + ${victools.version} + + + + com.github.victools + jsonschema-module-jackson + ${victools.version} + + + + org.springframework + spring-context-support + + + org.springframework.boot + spring-boot-starter-logging + + + + + org.springframework.ai + spring-ai-test + ${project.version} + test + + + + + + + maven_central + Maven Central + https://repo.maven.apache.org/maven2/ + + + + diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinAudioSpeechModel.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinAudioSpeechModel.java new file mode 100644 index 00000000000..f7548ceef32 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinAudioSpeechModel.java @@ -0,0 +1,10 @@ +package org.springframework.ai.wenxin; + +/** + * @author lvchzh + * @date 2024年05月27日 上午9:32 + * @description: + */ +public class WenxinAudioSpeechModel { + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinAudioSpeechOptions.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinAudioSpeechOptions.java new file mode 100644 index 00000000000..5e4e0b8d987 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinAudioSpeechOptions.java @@ -0,0 +1,10 @@ +package org.springframework.ai.wenxin; + +/** + * @author lvchzh + * @date 2024年05月27日 上午9:33 + * @description: + */ +public class WenxinAudioSpeechOptions { + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinAudioTranscriptionModel.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinAudioTranscriptionModel.java new file mode 100644 index 00000000000..4cc21e5451f --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinAudioTranscriptionModel.java @@ -0,0 +1,10 @@ +package org.springframework.ai.wenxin; + +/** + * @author lvchzh + * @date 2024年05月14日 下午5:06 + * @description: + */ +public class WenxinAudioTranscriptionModel { + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinAudioTranscriptionOptions.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinAudioTranscriptionOptions.java new file mode 100644 index 00000000000..4a408249cdf --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinAudioTranscriptionOptions.java @@ -0,0 +1,10 @@ +package org.springframework.ai.wenxin; + +/** + * @author lvchzh + * @date 2024年05月14日 下午5:13 + * @description: + */ +public class WenxinAudioTranscriptionOptions { + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinChatModel.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinChatModel.java new file mode 100644 index 00000000000..4a77a54e91b --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinChatModel.java @@ -0,0 +1,261 @@ +package org.springframework.ai.wenxin; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.RateLimit; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.function.AbstractFunctionCallSupport; +import org.springframework.ai.model.function.FunctionCallbackContext; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.wenxin.api.WenxinApi; +import org.springframework.ai.wenxin.metadata.WenxinChatResponseMetadata; +import org.springframework.ai.wenxin.metadata.support.WenxinResponseHeaderExtractor; +import org.springframework.http.ResponseEntity; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import reactor.core.publisher.Flux; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +/** + * @author lvchzh + * @date 2024年05月14日 下午5:26 + * @description: + */ +public class WenxinChatModel extends + AbstractFunctionCallSupport> + implements ChatModel, StreamingChatModel { + + // @formatter:off + private static final Logger logger = LoggerFactory.getLogger(WenxinChatModel.class); + private final RetryTemplate retryTemplate; + private final WenxinApi wenxinApi; + private WenxinChatOptions defaultOptions; + + public WenxinChatModel(WenxinApi wenxinApi) { + this(wenxinApi, + WenxinChatOptions.builder().withModel(WenxinApi.DEFAULT_CHAT_MODEL).withTemperature(0.7f).build()); + } + + public WenxinChatModel(WenxinApi wenxinApi, WenxinChatOptions options) { + this(wenxinApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + public WenxinChatModel(WenxinApi wenxinApi, WenxinChatOptions options, + FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate) { + super(functionCallbackContext); + Assert.notNull(wenxinApi, "WenxinApi must not be null"); + Assert.notNull(options, "WenxinChatOptions must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); + this.wenxinApi = wenxinApi; + this.defaultOptions = options; + this.retryTemplate = retryTemplate; + } + + @Override + public ChatResponse call(Prompt prompt) { + + WenxinApi.ChatCompletionRequest request = createRequest(prompt, false); + + return this.retryTemplate.execute(ctx -> { + + ResponseEntity completionEntity = this.callWithFunctionSupport(request); + + var chatCompletion = completionEntity.getBody(); + + if (chatCompletion == null) { + logger.warn("No chat completion returned for prompt: {}", prompt); + return new ChatResponse(List.of()); + } + + RateLimit rateLimits = WenxinResponseHeaderExtractor.extractAiResponseHeaders(completionEntity); + + Generation generation = new Generation(chatCompletion.result(), toMap(chatCompletion.id(), + chatCompletion)); + + List generations = List.of(generation); + + return new ChatResponse(generations, + WenxinChatResponseMetadata.from(chatCompletion).withRateLimit(rateLimits)); + }); + } + + @Override + public ChatOptions getDefaultOptions() { + return WenxinChatOptions.fromOptions(this.defaultOptions); + } + + private Map toMap(String id, WenxinApi.ChatCompletion chatCompletion) { + Map map = new HashMap<>(); + if (chatCompletion.finishReason() != null) { + map.put("finishReason", chatCompletion.finishReason().name()); + } + map.put("id", id); + return map; + } + + @Override + public Flux stream(Prompt prompt) { + WenxinApi.ChatCompletionRequest request = createRequest(prompt, true); + + return this.retryTemplate.execute(ctx -> { + + Flux completionChunks = this.wenxinApi.chatCompletionStream(request); + + return completionChunks.map(chunk -> chunkToChatCompletion(chunk)).map(chatCompletion -> { + try { + chatCompletion = handleFunctionCallOrReturn(request, + ResponseEntity.of(Optional.of(chatCompletion))).getBody(); + + @SuppressWarnings("null") + String id = chatCompletion.id(); + String finish = chatCompletion.finishReason() != null ? chatCompletion.finishReason().name() : + null; + + var generation = new Generation(chatCompletion.result(), Map.of("id", id, "finishReason", finish)); + if (chatCompletion.finishReason() != null) { + generation = generation.withGenerationMetadata( + ChatGenerationMetadata.from(chatCompletion.finishReason().name(), null)); + } + List generations = List.of(generation); + + return new ChatResponse(generations); + } catch (Exception e) { + logger.error("Error processing chat completion", e); + return new ChatResponse(List.of()); + } + }); + }); + } + + private WenxinApi.ChatCompletion chunkToChatCompletion(WenxinApi.ChatCompletionChunk chunk) { + return new WenxinApi.ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.sentenceId(), + chunk.isEnd(), chunk.isTruncated(), chunk.finishReason(), chunk.searchInfo(), chunk.result(), + chunk.needClearHistory(), chunk.flag(), chunk.banRound(), null, chunk.functionCall()); + } + + WenxinApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { + + Set functionsForThisRequest = new HashSet<>(); + + List chatCompletionMessages = prompt.getInstructions().stream() + .map(m -> new WenxinApi.ChatCompletionMessage(m.getContent(), + WenxinApi.Role.valueOf(m.getMessageType().name()))).toList(); + WenxinApi.ChatCompletionRequest request = new WenxinApi.ChatCompletionRequest(chatCompletionMessages, stream); + + if (prompt.getOptions() != null) { + if (prompt.getOptions() instanceof ChatOptions runtimeOptions) { + WenxinChatOptions updateRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, + ChatOptions.class, WenxinChatOptions.class); + + Set promptEnabledFunctions = this.handleFunctionCallbackConfigurations(updateRuntimeOptions, + IS_RUNTIME_CALL); + + functionsForThisRequest.addAll(promptEnabledFunctions); + + request = ModelOptionsUtils.merge(updateRuntimeOptions, request, + WenxinApi.ChatCompletionRequest.class); + } else { + throw new IllegalArgumentException("Prompt options are not of type ChatOptions: " + + prompt.getOptions().getClass().getSimpleName()); + } + } + + if (this.defaultOptions != null) { + + Set defaultEnableFunctions = this.handleFunctionCallbackConfigurations(this.defaultOptions, + !IS_RUNTIME_CALL); + + functionsForThisRequest.addAll(defaultEnableFunctions); + + request = ModelOptionsUtils.merge(this.defaultOptions, request, WenxinApi.ChatCompletionRequest.class); + } + + if (!CollectionUtils.isEmpty(functionsForThisRequest)) { + request = ModelOptionsUtils.merge( + WenxinChatOptions.builder().withTools(this.getFunctionTools(functionsForThisRequest)).build(), + request, WenxinApi.ChatCompletionRequest.class); + } + + return request; + } + + private List getFunctionTools(Set functionNames) { + return this.resolveFunctionCallbacks(functionNames).stream().map(functioncallback -> { + var function = new WenxinApi.FunctionTool(functioncallback.getName(), functioncallback.getDescription(), + functioncallback.getInputTypeSchema(), null); + return function; + }).toList(); + } + + @Override + protected WenxinApi.ChatCompletionRequest doCreateToolResponseRequest( + WenxinApi.ChatCompletionRequest previousRequest, WenxinApi.ChatCompletionMessage responseMessage, + List conversationHistory) { + + var functionName = responseMessage.functionCall().name(); + String functionArguments = responseMessage.functionCall().arguments(); + if (!this.functionCallbackRegister.containsKey(functionName)) { + throw new IllegalStateException("Function callback not found for function name: " + functionName); + } + + String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments); + + conversationHistory.add( + new WenxinApi.ChatCompletionMessage(functionResponse, WenxinApi.Role.FUNCTION, functionName, null)); + + WenxinApi.ChatCompletionRequest newRequest = new WenxinApi.ChatCompletionRequest(conversationHistory, false); + newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, WenxinApi.ChatCompletionRequest.class); + return newRequest; + } + + @Override + protected List doGetUserMessages(WenxinApi.ChatCompletionRequest request) { + return request.messages(); + } + + @Override + protected WenxinApi.ChatCompletionMessage doGetToolResponseMessage( + ResponseEntity chatCompletion) { + return new WenxinApi.ChatCompletionMessage(chatCompletion.getBody().result(), WenxinApi.Role.ASSISTANT, null, + chatCompletion.getBody().functionCall()); + } + + @Override + protected ResponseEntity doChatCompletion(WenxinApi.ChatCompletionRequest request) { + return this.wenxinApi.chatCompletionEntity(request); + } + + @Override + protected Flux> doChatCompletionStream( + WenxinApi.ChatCompletionRequest request) { + return this.wenxinApi.chatCompletionStream(request) + .map(this::chunkToChatCompletion) + .map(Optional::ofNullable) + .map(ResponseEntity::of); + } + + @Override + protected boolean isToolFunctionCall(ResponseEntity chatCompletion) { + var body = chatCompletion.getBody(); + if (body == null) { + return false; + } + return body.finishReason() == WenxinApi.ChatCompletionFinishReason.FUNCTION_CALL; + } + // @formatter:on + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinChatOptions.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinChatOptions.java new file mode 100644 index 00000000000..103c3fa231d --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinChatOptions.java @@ -0,0 +1,483 @@ +package org.springframework.ai.wenxin; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.wenxin.api.WenxinApi; +import org.springframework.boot.context.properties.NestedConfigurationProperty; +import org.springframework.util.Assert; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * @author lvchzh + * @date 2024年05月14日 下午5:26 + * @description: + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class WenxinChatOptions implements FunctionCallingOptions, ChatOptions { + + // + private @JsonProperty("model") String model; + + private @JsonProperty("penalty_score") Float penaltyScore; + + private @JsonProperty("max_output_tokens") Integer maxOutputTokens; + + private @JsonProperty("response_format") WenxinApi.ChatCompletionRequest.ResponseFormat responseFormat; + + private @JsonProperty("stop") List stop; + + private @JsonProperty("temperature") Float temperature; + + private @JsonProperty("top_p") Float topP; + + private @JsonProperty("functions") List tools; + + private @JsonProperty("tool_choice") String toolChoice; + + private @JsonProperty("user_id") String userId; + + private @JsonProperty("system") String system; + + private @JsonProperty("disable_search") Boolean disableSearch; + + private @JsonProperty("enable_citation") Boolean enableCitation; + + private @JsonProperty("enable_trace") Boolean enableTrace; + + @NestedConfigurationProperty + @JsonIgnore + private List functionCallbacks = new ArrayList<>(); + + @NestedConfigurationProperty + @JsonIgnore + private Set functions = new HashSet<>(); + + public static Builder builder() { + return new Builder(); + } + + // @formatter:off + public static WenxinChatOptions fromOptions(WenxinChatOptions fromOptions) { + return WenxinChatOptions.builder() + .withModel(fromOptions.getModel()) + .withPenaltyScore(fromOptions.getPenaltyScore()) + .withMaxOutputTokens(fromOptions.getMaxOutputTokens()) + .withResponseFormat(fromOptions.getResponseFormat()) + .withStop(fromOptions.getStop()) + .withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withTools(fromOptions.getTools()) + .withToolChoice(fromOptions.getToolChoice()) + .withUserId(fromOptions.getUserId()) + .withSystem(fromOptions.getSystem()) + .withDisableSearch(fromOptions.getDisableSearch()) + .withEnableCitation(fromOptions.getEnableCitation()) + .withEnableTrace(fromOptions.getEnableTrace()) + .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) + .withFunctions(fromOptions.getFunctions()) + .build(); + } + // @formatter:on + + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + public Float getPenaltyScore() { + return penaltyScore; + } + + public void setPenaltyScore(Float penaltyScore) { + this.penaltyScore = penaltyScore; + } + + public Integer getMaxOutputTokens() { + return maxOutputTokens; + } + + public void setMaxOutputTokens(Integer maxOutputTokens) { + this.maxOutputTokens = maxOutputTokens; + } + + public WenxinApi.ChatCompletionRequest.ResponseFormat getResponseFormat() { + return responseFormat; + } + + public void setResponseFormat(WenxinApi.ChatCompletionRequest.ResponseFormat responseFormat) { + this.responseFormat = responseFormat; + } + + public List getStop() { + return stop; + } + + public void setStop(List stop) { + this.stop = stop; + } + + @Override + public Float getTemperature() { + return this.temperature; + } + + public void setTemperature(Float temperature) { + this.temperature = temperature; + } + + @Override + public Float getTopP() { + return this.topP; + } + + public void setTopP(Float topP) { + this.topP = topP; + } + + @Override + @JsonIgnore + public Integer getTopK() { + throw new UnsupportedOperationException("Unimplemented method 'getTopK'"); + } + + @JsonIgnore + public void setTopK(Integer topK) { + throw new UnsupportedOperationException("Unimplemented method 'setTopK'"); + } + + public List getTools() { + return tools; + } + + public void setTools(List tools) { + this.tools = tools; + } + + public String getToolChoice() { + return toolChoice; + } + + public void setToolChoice(String toolChoice) { + this.toolChoice = toolChoice; + } + + public String getUserId() { + return userId; + } + + public void setUserId(String userId) { + this.userId = userId; + } + + public String getSystem() { + return system; + } + + public void setSystem(String system) { + this.system = system; + } + + public Boolean getDisableSearch() { + return disableSearch; + } + + public void setDisableSearch(Boolean disableSearch) { + this.disableSearch = disableSearch; + } + + public Boolean getEnableCitation() { + return enableCitation; + } + + public void setEnableCitation(Boolean enableCitation) { + this.enableCitation = enableCitation; + } + + public Boolean getEnableTrace() { + return enableTrace; + } + + public void setEnableTrace(Boolean enableTrace) { + this.enableTrace = enableTrace; + } + + @Override + public List getFunctionCallbacks() { + return this.functionCallbacks; + } + + @Override + public void setFunctionCallbacks(List functionCallbacks) { + this.functionCallbacks = functionCallbacks; + } + + @Override + public Set getFunctions() { + return functions; + } + + @Override + public void setFunctions(Set functionNames) { + this.functions = functionNames; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((model == null) ? 0 : model.hashCode()); + result = prime * result + ((penaltyScore == null) ? 0 : penaltyScore.hashCode()); + result = prime * result + ((maxOutputTokens == null) ? 0 : maxOutputTokens.hashCode()); + result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode()); + result = prime * result + ((stop == null) ? 0 : stop.hashCode()); + result = prime * result + ((temperature == null) ? 0 : temperature.hashCode()); + result = prime * result + ((topP == null) ? 0 : topP.hashCode()); + result = prime * result + ((tools == null) ? 0 : tools.hashCode()); + result = prime * result + ((toolChoice == null) ? 0 : toolChoice.hashCode()); + result = prime * result + ((userId == null) ? 0 : userId.hashCode()); + result = prime * result + ((system == null) ? 0 : system.hashCode()); + result = prime * result + ((disableSearch == null) ? 0 : disableSearch.hashCode()); + result = prime * result + ((enableCitation == null) ? 0 : enableCitation.hashCode()); + result = prime * result + ((enableTrace == null) ? 0 : enableTrace.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + WenxinChatOptions other = (WenxinChatOptions) obj; + if (this.model == null) { + if (other.model != null) { + return false; + } + } + else if (!this.model.equals(other.model)) { + return false; + } + if (this.penaltyScore == null) { + if (other.penaltyScore != null) { + return false; + } + } + else if (!this.penaltyScore.equals(other.penaltyScore)) { + return false; + } + if (this.maxOutputTokens == null) { + if (other.maxOutputTokens != null) { + return false; + } + } + else if (!this.maxOutputTokens.equals(other.maxOutputTokens)) { + return false; + } + if (this.responseFormat != other.responseFormat) { + return false; + } + if (this.stop == null) { + if (other.stop != null) { + return false; + } + } + else if (!this.stop.equals(other.stop)) { + return false; + } + if (this.temperature == null) { + if (other.temperature != null) { + return false; + } + } + else if (!this.temperature.equals(other.temperature)) { + return false; + } + if (this.topP == null) { + if (other.topP != null) { + return false; + } + } + else if (!this.topP.equals(other.topP)) { + return false; + } + if (this.tools == null) { + if (other.tools != null) { + return false; + } + } + else if (!this.tools.equals(other.tools)) { + return false; + } + if (this.toolChoice == null) { + if (other.toolChoice != null) { + return false; + } + } + else if (!this.toolChoice.equals(other.toolChoice)) { + return false; + } + if (this.userId == null) { + if (other.userId != null) { + return false; + } + } + else if (!this.userId.equals(other.userId)) { + return false; + } + if (this.system == null) { + if (other.system != null) { + return false; + } + } + else if (!this.system.equals(other.system)) { + return false; + } + if (this.disableSearch == null) { + if (other.disableSearch != null) { + return false; + } + } + else if (!this.disableSearch.equals(other.disableSearch)) { + return false; + } + if (this.enableCitation == null) { + if (other.enableCitation != null) { + return false; + } + } + else if (!this.enableCitation.equals(other.enableCitation)) { + return false; + } + if (this.enableTrace == null) { + if (other.enableTrace != null) { + return false; + } + } + else if (!this.enableTrace.equals(other.enableTrace)) { + return false; + } + return true; + } + + public static class Builder { + + protected WenxinChatOptions options; + + public Builder() { + this.options = new WenxinChatOptions(); + } + + public Builder(WenxinChatOptions options) { + this.options = options; + } + + public Builder withModel(String model) { + this.options.model = model; + return this; + } + + public Builder withPenaltyScore(Float penaltyScore) { + this.options.penaltyScore = penaltyScore; + return this; + } + + public Builder withMaxOutputTokens(Integer maxOutputTokens) { + this.options.maxOutputTokens = maxOutputTokens; + return this; + } + + public Builder withResponseFormat(WenxinApi.ChatCompletionRequest.ResponseFormat responseFormat) { + this.options.responseFormat = responseFormat; + return this; + } + + public Builder withStop(List stop) { + this.options.stop = stop; + return this; + } + + public Builder withTemperature(Float temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withTopP(Float topP) { + this.options.topP = topP; + return this; + } + + public Builder withTools(List tools) { + this.options.tools = tools; + return this; + } + + public Builder withToolChoice(String toolChoice) { + this.options.toolChoice = toolChoice; + return this; + } + + public Builder withUserId(String userId) { + this.options.userId = userId; + return this; + } + + public Builder withSystem(String system) { + this.options.system = system; + return this; + } + + public Builder withDisableSearch(Boolean disableSearch) { + this.options.disableSearch = disableSearch; + return this; + } + + public Builder withEnableCitation(Boolean enableCitation) { + this.options.enableCitation = enableCitation; + return this; + } + + public Builder withEnableTrace(Boolean enableTrace) { + this.options.enableTrace = enableTrace; + return this; + } + + public Builder withFunctionCallbacks(List functionCallbacks) { + this.options.functionCallbacks = functionCallbacks; + return this; + } + + public Builder withFunctions(Set functionNames) { + Assert.notNull(functionNames, "Function names must not be null"); + this.options.functions = functionNames; + return this; + } + + public Builder withFunction(String functionName) { + Assert.hasText(functionName, "Function name must not be empty"); + this.options.functions.add(functionName); + return this; + } + + public WenxinChatOptions build() { + return this.options; + } + + } + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinEmbeddingModel.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinEmbeddingModel.java new file mode 100644 index 00000000000..830f96eb936 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinEmbeddingModel.java @@ -0,0 +1,119 @@ +package org.springframework.ai.wenxin; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.document.Document; +import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.embedding.AbstractEmbeddingModel; +import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.wenxin.api.WenxinApi; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; + +import java.util.List; + +/** + * @author lvchzh + * @date 2024年05月14日 下午5:27 + * @description: + */ +public class WenxinEmbeddingModel extends AbstractEmbeddingModel { + + private static final Logger logger = LoggerFactory.getLogger(WenxinEmbeddingModel.class); + + private final WenxinEmbeddingOptions defaultOptions; + + private final RetryTemplate retryTemplate; + + private final WenxinApi wenxinApi; + + private final MetadataMode metadataMode; + + public WenxinEmbeddingModel(WenxinApi wenxinApi) { + this(wenxinApi, MetadataMode.EMBED); + } + + public WenxinEmbeddingModel(WenxinApi wenxinApi, MetadataMode metadataMode) { + this(wenxinApi, metadataMode, + WenxinEmbeddingOptions.builder().withModel(WenxinApi.DEFAULT_EMBEDDING_MODEL).build(), + RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + public WenxinEmbeddingModel(WenxinApi wenxinApi, MetadataMode metadataMode, + WenxinEmbeddingOptions wenxinEmbeddingOptions) { + this(wenxinApi, metadataMode, wenxinEmbeddingOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + public WenxinEmbeddingModel(WenxinApi wenxinApi, MetadataMode metadataMode, WenxinEmbeddingOptions options, + RetryTemplate retryTemplate) { + Assert.notNull(wenxinApi, "WenxinApi must not be null"); + Assert.notNull(metadataMode, "MetadataMode must not be null"); + Assert.notNull(options, "WenxinEmbeddingOptions must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); + + this.wenxinApi = wenxinApi; + this.metadataMode = metadataMode; + this.defaultOptions = options; + this.retryTemplate = retryTemplate; + } + + @Override + public List embed(Document document) { + Assert.notNull(document, "Document must not be null"); + return this.embed(document.getFormattedContent(this.metadataMode)); + } + + @SuppressWarnings("unchecked") + @Override + public EmbeddingResponse call(EmbeddingRequest request) { + + return this.retryTemplate.execute(ctx -> { + + WenxinApi.EmbeddingRequest> apiRequest = (this.defaultOptions != null) + ? new WenxinApi.EmbeddingRequest<>(request.getInstructions(), this.defaultOptions.getModel(), + this.defaultOptions.getUserId()) + : new WenxinApi.EmbeddingRequest<>(request.getInstructions(), WenxinApi.DEFAULT_EMBEDDING_MODEL); + + if (request.getOptions() != null && !EmbeddingOptions.EMPTY.equals(request.getOptions())) { + apiRequest = ModelOptionsUtils.merge(request.getOptions(), apiRequest, + WenxinApi.EmbeddingRequest.class); + } + + WenxinApi.EmbeddingList apiEmbeddingResponse = this.wenxinApi.embeddings(apiRequest) + .getBody(); + + if (apiEmbeddingResponse == null) { + logger.warn("No embeddings returned from request: {}", request); + return new EmbeddingResponse(List.of()); + } + + var metadata = generateResponseMetadata(apiEmbeddingResponse.id(), apiEmbeddingResponse.object(), + apiEmbeddingResponse.created(), apiEmbeddingResponse.usage()); + + List embeddings = apiEmbeddingResponse.data() + .stream() + .map(e -> new Embedding(e.embedding(), e.index())) + .toList(); + + return new EmbeddingResponse(embeddings, metadata); + }); + } + + private EmbeddingResponseMetadata generateResponseMetadata(String id, String object, Long created, + WenxinApi.EmbeddingList.Usage usage) { + EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata(); + metadata.put("id", id); + metadata.put("object", object); + metadata.put("created", created); + metadata.put("prompt-tokens", usage.promptTokens()); + metadata.put("total-tokens", usage.totalTokens()); + return metadata; + } + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinEmbeddingOptions.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinEmbeddingOptions.java new file mode 100644 index 00000000000..3e708fd0d73 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinEmbeddingOptions.java @@ -0,0 +1,61 @@ +package org.springframework.ai.wenxin; + +import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.embedding.EmbeddingOptions; + +/** + * @author lvchzh + * @date 2024年05月14日 下午5:28 + * @description: + */ +public class WenxinEmbeddingOptions implements EmbeddingOptions { + + private @JsonProperty("model") String model; + + private @JsonProperty("user_id") String userId; + + public static Builder builder() { + return new Builder(); + } + + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + public String getUserId() { + return this.userId; + } + + public void setUserId(String userId) { + this.userId = userId; + } + + public static class Builder { + + protected WenxinEmbeddingOptions options; + + public Builder() { + this.options = new WenxinEmbeddingOptions(); + } + + public Builder withModel(String model) { + this.options.setModel(model); + return this; + } + + public Builder withUserId(String userId) { + this.options.setUserId(userId); + return this; + } + + public WenxinEmbeddingOptions build() { + return this.options; + } + + } + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinImageModel.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinImageModel.java new file mode 100644 index 00000000000..39a1ebe2274 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinImageModel.java @@ -0,0 +1,10 @@ +package org.springframework.ai.wenxin; + +/** + * @author lvchzh + * @date 2024年05月14日 下午5:29 + * @description: + */ +public class WenxinImageModel { + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinImageOptions.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinImageOptions.java new file mode 100644 index 00000000000..1c8f93cba9a --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinImageOptions.java @@ -0,0 +1,10 @@ +package org.springframework.ai.wenxin; + +/** + * @author lvchzh + * @date 2024年05月14日 下午5:29 + * @description: + */ +public class WenxinImageOptions { + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/aot/WenxinRuntimeHints.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/aot/WenxinRuntimeHints.java new file mode 100644 index 00000000000..570ed95e89b --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/aot/WenxinRuntimeHints.java @@ -0,0 +1,35 @@ +package org.springframework.ai.wenxin.aot; + +import org.springframework.ai.wenxin.api.WenxinApi; +import org.springframework.ai.wenxin.api.WenxinAudioApi; +import org.springframework.ai.wenxin.api.WenxinImageApi; +import org.springframework.aot.hint.MemberCategory; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; + +import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; + +/** + * @author lvchzh + * @date 2024年05月14日 下午4:50 + * @description: + */ +public class WenxinRuntimeHints implements RuntimeHintsRegistrar { + + @Override + public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) { + var mcs = MemberCategory.values(); + for (var tr : findJsonAnnotatedClassesInPackage(WenxinApi.class)) { + hints.reflection().registerType(tr, mcs); + } + for (var tr : findJsonAnnotatedClassesInPackage(WenxinAudioApi.class)) { + hints.reflection().registerType(tr, mcs); + } + for (var tr : findJsonAnnotatedClassesInPackage(WenxinImageApi.class)) { + hints.reflection().registerType(tr, mcs); + } + } + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/ApiUtils.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/ApiUtils.java new file mode 100644 index 00000000000..2ba4cfd6b3d --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/ApiUtils.java @@ -0,0 +1,102 @@ +package org.springframework.ai.wenxin.api; + +import org.apache.commons.codec.binary.Hex; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; + +import javax.crypto.Mac; +import javax.crypto.spec.SecretKeySpec; +import java.nio.charset.StandardCharsets; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.time.Instant; +import java.time.ZoneOffset; +import java.time.format.DateTimeFormatter; +import java.util.Optional; +import java.util.function.Consumer; + +/** + * @author lvchzh + * @date 2024年05月22日 下午2:45 + * @description: ApiUtils + */ +public class ApiUtils { + + // @formatter:off + public static final String DEFAULT_BASE_URL = "https://aip.baidubce.com"; + + public static final String DEFAULT_BASE_CHAT_URI = "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/"; + + public static final String DEFAULT_BASE_EMBEDDING_URI = "/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/"; + + public static final String DEFAULT_HOST = "aip.baidubce.com"; + + private static final String EXPIRATION_PERIOD_IN_SECONDS = "1800"; + + private static final String HMAC_SHA256 = "HmacSHA256"; + + private static final DateTimeFormatter alternateIso8601DateFormat = DateTimeFormatter.ofPattern( + "yyyy-MM-dd'T'HH:mm:ss'Z'").withZone(ZoneOffset.UTC); + + public static Consumer getJsonContentHeaders() { + return (headers) -> headers.setContentType(MediaType.APPLICATION_JSON); + } + + public static String generationSignature(String accessKey, String secretKey, Instant timestamp, String modelName, String uri) { + var canonicalRequest = createCanonicalRequest(uri, modelName); + var authStringPrefix = createAuthStringPrefix(accessKey, timestamp); + var signingKey = hmacSha256Hex(secretKey, authStringPrefix); + var signature = hmacSha256Hex(signingKey, canonicalRequest.toString()); + return new StringBuilder() + .append(authStringPrefix) + .append("/host/") + .append(signature) + .toString(); + } + + private static String createAuthStringPrefix(String accessKey, Instant timestamp) { + return new StringBuilder() + .append("bce-auth-v1/").append(accessKey) + .append("/") + .append(formatDate(timestamp)) + .append("/") + .append(EXPIRATION_PERIOD_IN_SECONDS) + .toString(); + } + + private static StringBuilder createCanonicalRequest(String uri, String modelName) { + return new StringBuilder() + .append("POST") + .append("\n") + .append(uri).append(modelName) + .append("\n\n") + .append("host:").append(DEFAULT_HOST); + } + + private static String hmacSha256Hex(String secretKey, String authStringPrefix) { + try { + var mac = Mac.getInstance(HMAC_SHA256); + mac.init(new SecretKeySpec(secretKey.getBytes(StandardCharsets.UTF_8), HMAC_SHA256)); + return new String(Hex.encodeHex(mac.doFinal(authStringPrefix.getBytes(StandardCharsets.UTF_8)))); + } catch (NoSuchAlgorithmException | InvalidKeyException e) { + throw new RuntimeException("Failed to generate HMAC-SHA256 signature", e); + } + } + + private static Optional formatAlternateIso8601Date(Instant instant) { + if (instant == null) { + return Optional.empty(); + } + return Optional.of(alternateIso8601DateFormat.format(instant)); + } + + public static String formatDate(Instant instant) { + return formatAlternateIso8601Date(instant).orElseThrow(() -> new RuntimeException("Failed to format date")); + } + + public static String generationAuthorization(String accessKey, String secretKey, Instant timestamp, String model, String uri) { + return generationSignature(accessKey, secretKey, timestamp, model, uri); + } + // @formatter:on + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/CustomResponseErrorHandler.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/CustomResponseErrorHandler.java new file mode 100644 index 00000000000..f1c9e5efa73 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/CustomResponseErrorHandler.java @@ -0,0 +1,39 @@ +package org.springframework.ai.wenxin.api; + +import org.springframework.http.HttpStatus; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.util.StreamUtils; +import org.springframework.web.client.ResponseErrorHandler; + +import java.io.IOException; +import java.nio.charset.Charset; + +/** + * @author lvchzh + * @date 2024年05月24日 下午6:28 + * @description: + */ +public class CustomResponseErrorHandler implements ResponseErrorHandler { + + @Override + public boolean hasError(ClientHttpResponse httpResponse) throws IOException { + return (httpResponse.getStatusCode().value() == HttpStatus.Series.CLIENT_ERROR.value() + || httpResponse.getStatusCode().value() == HttpStatus.Series.SERVER_ERROR.value()); + } + + @Override + public void handleError(ClientHttpResponse httpResponse) throws IOException { + if (httpResponse.getStatusCode().value() == HttpStatus.Series.SERVER_ERROR.value()) { + // handle SERVER_ERROR + System.out.println("Server error: " + httpResponse.getStatusCode()); + } + else if (httpResponse.getStatusCode().value() == HttpStatus.Series.CLIENT_ERROR.value()) { + // handle CLIENT_ERROR + System.out.println("Client error: " + httpResponse.getStatusCode()); + } + // Print the response body + System.out + .println("Response body: " + StreamUtils.copyToString(httpResponse.getBody(), Charset.defaultCharset())); + } + +} \ No newline at end of file diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinApi.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinApi.java new file mode 100644 index 00000000000..48300b43751 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinApi.java @@ -0,0 +1,463 @@ +package org.springframework.ai.wenxin.api; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.boot.context.properties.bind.ConstructorBinding; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.ResponseEntity; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.time.Instant; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Predicate; + +/** + * @author lvchzh + * @date 2024年05月14日 下午4:51 + * @description: + */ +public class WenxinApi { + + // @formatter:off + public static final String DEFAULT_CHAT_MODEL = ChatModel.ERNIE_3_5_8K.getValue(); + + public static final String DEFAULT_EMBEDDING_MODEL = EmbeddingModel.Embedding_V1.getValue(); + + private static final Predicate SSE_DONE_PREDICATE = "[DONE]"::equals; + + private final RestClient restClient; + + private final WebClient webClient; + + private final String accessKey; + + private final String secretKey; + + private WenxinStreamFunctionCallingHelper chunkMerger = new WenxinStreamFunctionCallingHelper(); + + public WenxinApi(String accessKey, String secretKey) { + this(ApiUtils.DEFAULT_BASE_URL, accessKey, secretKey); + } + + public WenxinApi(String baseUrl, String accessKey, String secretKey) { + this(baseUrl, RestClient.builder(), WebClient.builder(), accessKey, secretKey); + } + + public WenxinApi(String baseUrl, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, + String accessKey, String secretKey) { + this(baseUrl, restClientBuilder, webClientBuilder, + RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER, + //new CustomResponseErrorHandler(), + accessKey, + secretKey); + } + + public WenxinApi(String baseUrl, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, + ResponseErrorHandler responseErrorHandler, String accessKey, String secretKey) { + + this.restClient = restClientBuilder + .baseUrl(baseUrl) + .defaultHeaders(ApiUtils.getJsonContentHeaders()) + .defaultStatusHandler(responseErrorHandler) + .build(); + + this.webClient = webClientBuilder + .baseUrl(baseUrl) + .defaultHeaders(ApiUtils.getJsonContentHeaders()) + .build(); + + this.accessKey = accessKey; + this.secretKey = secretKey; + } + + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(!chatRequest.stream(), "Request must set the steam property to false."); + + var timestamp = Instant.now(); + var authorization = ApiUtils.generationAuthorization(accessKey, secretKey, timestamp, chatRequest.model(), + ApiUtils.DEFAULT_BASE_CHAT_URI); + + return this.restClient.post() + .uri(ApiUtils.DEFAULT_BASE_CHAT_URI + chatRequest.model()) + .headers(headers -> { + headers.set("x-bce-date", ApiUtils.formatDate(timestamp)); + headers.set("Authorization", authorization); + }) + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletion.class); + } + + public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(chatRequest.stream(), "Request must set the steam property to true."); + + AtomicBoolean isInsideTool = new AtomicBoolean(false); + Instant timestamp = Instant.now(); + String authorization = ApiUtils.generationAuthorization(accessKey, secretKey, timestamp, chatRequest.model(), + ApiUtils.DEFAULT_BASE_CHAT_URI); + + return this.webClient.post() + .uri(ApiUtils.DEFAULT_BASE_CHAT_URI + chatRequest.model()) + .headers(headers -> { + headers.set("x-bce-date", ApiUtils.formatDate(timestamp)); + headers.set("Authorization", authorization); + }) + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(String.class) + .takeUntil(SSE_DONE_PREDICATE) + .filter(SSE_DONE_PREDICATE.negate()) + .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) + .map(chunk -> { + if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { + isInsideTool.set(true); + } + return chunk; + }) + .windowUntil(chunk -> { + if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { + isInsideTool.set(false); + return true; + } + return !isInsideTool.get(); + }) + .concatMapIterable(window -> { + Mono monoChunk = window.reduce( + new ChatCompletionChunk(null, null, null, null, null, null, null, null, null, null, null, + null, null, null), + (previous, current) -> this.chunkMerger.merge(previous, current)); + return List.of(monoChunk); + }) + .flatMap(mono -> mono); + } + + public enum ChatModel { + + ERNIE_4_8K("completions_pro"), + ERNIE_4_8K_PREEMPTIVE("completions_pro_preemptive"), + ERNIE_4_8K_PREVIEW("ernie-4.0-8k-preview"), + ERNIE_4_8K_0329("ernie-4.0-8k-0329"), + ERNIE_4_8K_0104("ernie-4.0-8k-0104"), + ERNIE_3_5_8K("completions"), + ERNIE_3_5_8K_0205("ernie-3.5-8k-0205"), + ERNIE_3_5_8K_1222("ernie-3.5-8k-1222"), + ERNIE_3_5_4K_0205("ernie-3.5-4k-0205"), + ERNIE_3_5_8K_PREEMPTIVE("completions_preemptive"), + ERNIE_3_5_8K_Preview("ernie-3.5-8k-preview"), + ERNIE_3_5_8K_0329("ernie-3.5-8k-0329"); + + public final String value; + + ChatModel(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + + } + + public enum Role { + + @JsonProperty("user") USER, + @JsonProperty("assistant") ASSISTANT, + @JsonProperty("function") FUNCTION + + } + + public enum ChatCompletionFinishReason { + + @JsonProperty("normal") NORMAL, + @JsonProperty("stop") STOP, + @JsonProperty("length") LENGTH, + @JsonProperty("content_filter") CONTENT_FILTER, + @JsonProperty("function_call") FUNCTION_CALL + + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record FunctionTool( + @JsonProperty("name") String name, + @JsonProperty("description") String description, + @JsonProperty("parameters") Map parameters, + @JsonProperty("responses") Map responses, + @JsonProperty("examples") List> examples) { + + @ConstructorBinding + public FunctionTool(String name, String description, String jsonSchemaForParameters, + List> examples) { + this(name, description, ModelOptionsUtils.jsonToMap(jsonSchemaForParameters), null, examples); + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Example( + @JsonProperty("role") Role role, + @JsonProperty("content") String content, + @JsonProperty("name") String name, + @JsonProperty("function_call") FunctionCall functionCall) { + } + + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record FunctionCall( + @JsonProperty("name") String name, + @JsonProperty("arguments") String arguments, + @JsonProperty("thoughts") String thoughts) { + + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ChatCompletionRequest( + @JsonProperty("messages") List messages, + @JsonProperty("model") String model, + @JsonProperty("penalty_score") Float penaltyScore, + @JsonProperty("max_output_tokens") Integer maxOutputTokens, + @JsonProperty("response_format") ResponseFormat responseFormat, + @JsonProperty("stop") List stop, + @JsonProperty("stream") Boolean stream, + @JsonProperty("temperature") Float temperature, + @JsonProperty("top_p") Float topP, + @JsonProperty("functions") List functions, + @JsonProperty("tool_choice") String toolChoice, + @JsonProperty("user_id") String userId, + @JsonProperty("system") String system, + @JsonProperty("disable_search") Boolean disableSearch, + @JsonProperty("enable_citation") Boolean enableCitation, + @JsonProperty("enable_trace") Boolean enableTrace) { + + public ChatCompletionRequest(List messages, String model, Float temperature) { + this(messages, model, null, null, null, null, false, temperature, null, null, null, null, null, false, + false, false); + } + + public ChatCompletionRequest(List messages, String model, Float temperature, + boolean stream) { + this(messages, model, null, null, null, null, stream, temperature, null, null, null, null, null, false, + false, false); + } + + public ChatCompletionRequest(List messages, String model, List tools, + String toolChoice, Boolean disableSearch) { + this(messages, model, null, null, null, null, false, 0.8f, null, tools, toolChoice, null, null, + disableSearch, false, false); + } + + public ChatCompletionRequest(List messages, Boolean stream) { + this(messages, DEFAULT_CHAT_MODEL, null, null, null, null, stream, null, null, null, null, null, null, + false, false, false); + } + + public enum ResponseFormat { + + @JsonProperty("text") TEXT, + @JsonProperty("json_object") JSON_OBJECT + + } + + public static class ToolChoiceBuilder { + + public static final String DEFAULT_TOOL_CHOICE = "auto"; + + public static final String NONE = "none"; + + public static String FUNCTION(String functionName) { + return ModelOptionsUtils.toJsonString( + Map.of( + "type", "function", + "function", + Map.of("name", functionName) + ) + ); + } + + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ChatCompletionMessage( + @JsonProperty("content") String content, + @JsonProperty("role") Role role, + @JsonProperty("name") String name, + @JsonProperty("function_call") FunctionCall functionCall) { + + public ChatCompletionMessage(String content, Role role) { + this(content, role, null, null); + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ChatCompletion( + @JsonProperty("id") String id, + @JsonProperty("object") String object, + @JsonProperty("created") Long created, + @JsonProperty("sentence_id") String sentenceId, + @JsonProperty("is_end") Boolean isEnd, + @JsonProperty("is_truncated") Boolean isTruncated, + @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason, + @JsonProperty("search_info") SearchInfo searchInfo, + @JsonProperty("result") String result, + @JsonProperty("need_clear_history") Boolean needClearHistory, + @JsonProperty("flag") Integer flag, + @JsonProperty("ban_round") Integer banRound, + @JsonProperty("usage") Usage usage, + @JsonProperty("function_call") FunctionCall functionCall) { + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record SearchInfo(@JsonProperty("search_results") List searchResults) { + + } + + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record SearchResult( + @JsonProperty("index") Integer index, + @JsonProperty("url") String url, + @JsonProperty("title") String title) { + + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Usage( + @JsonProperty("prompt_tokens") Integer promptTokens, + @JsonProperty("completion_tokens") Integer completionTokens, + @JsonProperty("total_tokens") Integer totalTokens, + @JsonProperty("plugins") List plugins) { + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record PluginUsage( + @JsonProperty("name") String name, + @JsonProperty("parse_tokens") Integer parseTokens, + @JsonProperty("abstract_tokens") Integer abstractTokens, + @JsonProperty("search_tokens") Integer searchTokens, + @JsonProperty("total_tokens") Integer totalTokens) { + + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ChatCompletionChunk( + @JsonProperty("id") String id, + @JsonProperty("object") String object, + @JsonProperty("created") Long created, + @JsonProperty("sentence_id") String sentenceId, + @JsonProperty("is_end") Boolean isEnd, + @JsonProperty("is_truncated") Boolean isTruncated, + @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason, + @JsonProperty("search_info") ChatCompletion.SearchInfo searchInfo, + @JsonProperty("result") String result, + @JsonProperty("need_clear_history") Boolean needClearHistory, + @JsonProperty("flag") Integer flag, + @JsonProperty("ban_round") Integer banRound, + @JsonProperty("usage") Usage usage, + @JsonProperty("function_call") FunctionCall functionCall) { + + } + + // Embedding API + public ResponseEntity> embeddings(EmbeddingRequest embeddingRequest) { + + Assert.notNull(embeddingRequest, "The request body can not be null."); + Assert.notNull(embeddingRequest.input(), "The input can not be null."); + Assert.isTrue(embeddingRequest.input() instanceof String || embeddingRequest.input() instanceof List, + "The input must be either a String, or a List of Strings or List of List of integers."); + + if (embeddingRequest.input() instanceof List list) { + Assert.isTrue(!CollectionUtils.isEmpty(list), "The input list can not be empty."); + Assert.isTrue(list.size() <= 2048, "The list must be dimensions or less"); + Assert.isTrue(list.get(0) instanceof String || list.get(0) instanceof Integer + || list.get(0) instanceof List, + "The input must be either a String, or a list of Strings or list of integers."); + } + + Instant timestamp = Instant.now(); + String authorization = ApiUtils.generationAuthorization(accessKey, secretKey, timestamp, + embeddingRequest.model(), ApiUtils.DEFAULT_BASE_EMBEDDING_URI); + + return this.restClient.post() + .uri(ApiUtils.DEFAULT_BASE_EMBEDDING_URI + embeddingRequest.model()) + .headers(headers -> { + headers.set("x-bce-date", ApiUtils.formatDate(timestamp)); + headers.set("Authorization", authorization); + }) + .body(embeddingRequest) + .retrieve() + .toEntity(new ParameterizedTypeReference<>() { + }); + } + + public enum EmbeddingModel { + Embedding_V1("embedding-v1"), + BGE_LARGE_ZH("bge_large_zh"), + BGE_LARGE_EN("bge_large_en"), + TAO_8K("tao_8k"); + + public final String value; + + EmbeddingModel(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Embedding( + @JsonProperty("index") Integer index, + @JsonProperty("embedding") List embedding, + @JsonProperty("object") String object) { + + public Embedding(Integer index, List embedding) { + this(index, embedding, "embedding"); + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record EmbeddingRequest( + @JsonProperty("input") T input, + @JsonProperty("model") String model, + @JsonProperty("user_id") String userId) { + + public EmbeddingRequest(T input, String model) { + this(input, model, null); + } + + public EmbeddingRequest(T input) { + this(input, DEFAULT_EMBEDDING_MODEL); + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record EmbeddingList( + @JsonProperty("id") String id, + @JsonProperty("object") String object, + @JsonProperty("created") Long created, + @JsonProperty("data") List data, + @JsonProperty("usage") Usage usage) { + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Usage( + @JsonProperty("prompt_tokens") Integer promptTokens, + @JsonProperty("total_tokens") Integer totalTokens) { + + } + } + // @formatter:on + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinAudioApi.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinAudioApi.java new file mode 100644 index 00000000000..45938d8e4cc --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinAudioApi.java @@ -0,0 +1,10 @@ +package org.springframework.ai.wenxin.api; + +/** + * @author lvchzh + * @date 2024年05月14日 下午4:53 + * @description: + */ +public class WenxinAudioApi { + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinImageApi.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinImageApi.java new file mode 100644 index 00000000000..1131192d524 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinImageApi.java @@ -0,0 +1,10 @@ +package org.springframework.ai.wenxin.api; + +/** + * @author lvchzh + * @date 2024年05月14日 下午4:53 + * @description: + */ +public class WenxinImageApi { + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinStreamFunctionCallingHelper.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinStreamFunctionCallingHelper.java new file mode 100644 index 00000000000..1e562333679 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinStreamFunctionCallingHelper.java @@ -0,0 +1,191 @@ +package org.springframework.ai.wenxin.api; + +import org.springframework.ai.wenxin.api.WenxinApi.ChatCompletionChunk; +import org.springframework.util.CollectionUtils; + +import java.util.ArrayList; +import java.util.List; + +/** + * @author lvchzh + * @date 2024年05月14日 下午4:52 + * @description: + */ +public class WenxinStreamFunctionCallingHelper { + + public ChatCompletionChunk merge(ChatCompletionChunk previous, ChatCompletionChunk current) { + if (previous == null) { + return current; + } + + String id = (current.id() != null ? current.id() : previous.id()); + String object = (current.object() != null ? current.object() : previous.object()); + Long created = (current.created() != null ? current.created() : previous.created()); + String sentenceId = (current.sentenceId() != null ? current.sentenceId() : previous.sentenceId()); + Boolean isEnd = (current.isEnd() != null ? current.isEnd() : previous.isEnd()); + Boolean isTruncated = (current.isTruncated() != null ? current.isTruncated() : previous.isTruncated()); + Boolean needClearHistory = (current.needClearHistory() != null ? current.needClearHistory() + : previous.needClearHistory()); + String result = (current.result() != null ? current.result() : previous.result()); + Integer flag = (current.flag() != null ? current.flag() : previous.flag()); + Integer banRound = (current.banRound() != null ? current.banRound() : previous.banRound()); + + WenxinApi.ChatCompletionFinishReason finishReason = (current.finishReason() != null ? current.finishReason() + : previous.finishReason()); + + WenxinApi.ChatCompletion.SearchInfo searchInfo = merge(previous.searchInfo(), current.searchInfo()); + + WenxinApi.FunctionCall functionCall = merge(previous.functionCall(), current.functionCall()); + + WenxinApi.Usage usage = merge(previous.usage(), current.usage()); + + return new ChatCompletionChunk(id, object, created, sentenceId, isEnd, isTruncated, finishReason, searchInfo, + result, needClearHistory, flag, banRound, usage, functionCall); + + } + + private WenxinApi.ChatCompletion.SearchInfo merge(WenxinApi.ChatCompletion.SearchInfo previous, + WenxinApi.ChatCompletion.SearchInfo current) { + if (previous == null) { + return current; + } + + List searchResults = new ArrayList<>(); + WenxinApi.SearchResult lastPreviousSearchResult = null; + if (previous.searchResults() != null) { + lastPreviousSearchResult = previous.searchResults().get(previous.searchResults().size() - 1); + if (previous.searchResults() != null) { + searchResults.addAll(previous.searchResults().subList(0, previous.searchResults().size() - 1)); + } + } + if (current.searchResults() != null) { + if (current.searchResults().size() > 1) { + throw new IllegalArgumentException("Currently only one tool call is supported per message!"); + } + var currentSearchResult = current.searchResults().iterator().next(); + if (currentSearchResult.index() != null) { + if (lastPreviousSearchResult != null) { + searchResults.add(lastPreviousSearchResult); + } + searchResults.add(currentSearchResult); + } + else { + searchResults.add(merge(lastPreviousSearchResult, currentSearchResult)); + } + } + else { + if (lastPreviousSearchResult != null) { + searchResults.add(lastPreviousSearchResult); + } + } + return new WenxinApi.ChatCompletion.SearchInfo(searchResults); + } + + private WenxinApi.SearchResult merge(WenxinApi.SearchResult previous, WenxinApi.SearchResult current) { + if (previous != null) { + return current; + } + + Integer id = current.index() != null ? current.index() : previous.index(); + String title = current.title() != null ? current.title() : previous.title(); + String url = current.url() != null ? current.url() : previous.url(); + + return new WenxinApi.SearchResult(id, title, url); + } + + private WenxinApi.FunctionCall merge(WenxinApi.FunctionCall previous, WenxinApi.FunctionCall current) { + if (previous == null) { + return current; + } + + String name = current.name() != null ? current.name() : previous.name(); + String thoughts = current.thoughts() != null ? current.thoughts() : previous.thoughts(); + StringBuilder arguments = new StringBuilder(); + if (previous.arguments() != null) { + arguments.append(previous.arguments()); + } + if (current.arguments() != null) { + arguments.append(current.arguments()); + } + + return new WenxinApi.FunctionCall(name, arguments.toString(), thoughts); + + } + + private WenxinApi.Usage merge(WenxinApi.Usage previous, WenxinApi.Usage current) { + if (previous == null) { + return current; + } + + Integer promptTokens = current.promptTokens() != null ? current.promptTokens() : previous.promptTokens(); + Integer completionTokens = current.completionTokens() != null ? current.completionTokens() + : previous.completionTokens(); + Integer totalTokens = current.totalTokens() != null ? current.totalTokens() : previous.totalTokens(); + + List plugins = new ArrayList<>(); + WenxinApi.Usage.PluginUsage lastPreviousPluginUsage = null; + if (previous.plugins() != null) { + lastPreviousPluginUsage = previous.plugins().get(previous.plugins().size() - 1); + if (previous.plugins().size() > 1) { + plugins.addAll(previous.plugins().subList(0, previous.plugins().size() - 1)); + } + } + if (current.plugins() != null) { + if (current.plugins().size() > 1) { + throw new IllegalArgumentException("Currently only one tool call is supported per message!"); + } + var currentPluginUsage = current.plugins().iterator().next(); + if (currentPluginUsage.name() != null) { + if (lastPreviousPluginUsage != null) { + plugins.add(lastPreviousPluginUsage); + } + plugins.add(currentPluginUsage); + } + else { + plugins.add(merge(lastPreviousPluginUsage, currentPluginUsage)); + } + } + else { + if (lastPreviousPluginUsage != null) { + plugins.add(lastPreviousPluginUsage); + } + } + return new WenxinApi.Usage(promptTokens, completionTokens, totalTokens, plugins); + } + + private WenxinApi.Usage.PluginUsage merge(WenxinApi.Usage.PluginUsage previous, + WenxinApi.Usage.PluginUsage current) { + if (previous == null) { + return current; + } + + String name = current.name() != null ? current.name() : previous.name(); + Integer parseTokens = current.parseTokens() != null ? current.parseTokens() : previous.parseTokens(); + Integer abstractTokens = current.abstractTokens() != null ? current.abstractTokens() + : previous.abstractTokens(); + Integer searchTokens = current.searchTokens() != null ? current.searchTokens() : previous.searchTokens(); + Integer totalTokens = current.totalTokens() != null ? current.totalTokens() : previous.totalTokens(); + + return new WenxinApi.Usage.PluginUsage(name, parseTokens, abstractTokens, searchTokens, totalTokens); + } + + public boolean isStreamingToolFunctionCall(ChatCompletionChunk chatCompletion) { + + if (chatCompletion == null || chatCompletion.functionCall() == null) { + return false; + } + + return true; + } + + public boolean isStreamingToolFunctionCallFinish(ChatCompletionChunk chatCompletion) { + + if (chatCompletion == null || chatCompletion.functionCall() == null) { + return false; + } + + return chatCompletion.finishReason() == WenxinApi.ChatCompletionFinishReason.FUNCTION_CALL; + + } + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/common/WenxinApiClientErrorException.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/common/WenxinApiClientErrorException.java new file mode 100644 index 00000000000..f5416bc008a --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/common/WenxinApiClientErrorException.java @@ -0,0 +1,18 @@ +package org.springframework.ai.wenxin.api.common; + +/** + * @author lvchzh + * @date 2024年05月14日 下午2:40 + * @description: + */ +public class WenxinApiClientErrorException extends RuntimeException { + + public WenxinApiClientErrorException(String message) { + super(message); + } + + public WenxinApiClientErrorException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/common/WenxinApiException.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/common/WenxinApiException.java new file mode 100644 index 00000000000..919f5a21583 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/common/WenxinApiException.java @@ -0,0 +1,18 @@ +package org.springframework.ai.wenxin.api.common; + +/** + * @author lvchzh + * @date 2024年05月14日 下午2:43 + * @description: + */ +public class WenxinApiException extends RuntimeException { + + public WenxinApiException(String message) { + super(message); + } + + public WenxinApiException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/speech/Speech.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/speech/Speech.java new file mode 100644 index 00000000000..95d857159e8 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/speech/Speech.java @@ -0,0 +1,10 @@ +package org.springframework.ai.wenxin.audio.speech; + +/** + * @author lvchzh + * @date 2024年05月27日 上午9:27 + * @description: + */ +public class Speech { + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/speech/SpeechMessage.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/speech/SpeechMessage.java new file mode 100644 index 00000000000..bc30ed65338 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/speech/SpeechMessage.java @@ -0,0 +1,10 @@ +package org.springframework.ai.wenxin.audio.speech; + +/** + * @author lvchzh + * @date 2024年05月27日 上午9:27 + * @description: + */ +public class SpeechMessage { + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/speech/SpeechModel.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/speech/SpeechModel.java new file mode 100644 index 00000000000..33284b0902b --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/speech/SpeechModel.java @@ -0,0 +1,10 @@ +package org.springframework.ai.wenxin.audio.speech; + +/** + * @author lvchzh + * @date 2024年05月27日 上午9:27 + * @description: + */ +public interface SpeechModel { + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/speech/SpeechPrompt.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/speech/SpeechPrompt.java new file mode 100644 index 00000000000..f64be46a636 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/speech/SpeechPrompt.java @@ -0,0 +1,10 @@ +package org.springframework.ai.wenxin.audio.speech; + +/** + * @author lvchzh + * @date 2024年05月27日 上午9:28 + * @description: + */ +public class SpeechPrompt { + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/speech/SpeechResponse.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/speech/SpeechResponse.java new file mode 100644 index 00000000000..9b713dc630c --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/speech/SpeechResponse.java @@ -0,0 +1,10 @@ +package org.springframework.ai.wenxin.audio.speech; + +/** + * @author lvchzh + * @date 2024年05月27日 上午9:30 + * @description: + */ +public class SpeechResponse { + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/speech/StreamingSpeechModel.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/speech/StreamingSpeechModel.java new file mode 100644 index 00000000000..ac3505136a9 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/speech/StreamingSpeechModel.java @@ -0,0 +1,10 @@ +package org.springframework.ai.wenxin.audio.speech; + +/** + * @author lvchzh + * @date 2024年05月27日 上午9:28 + * @description: + */ +public interface StreamingSpeechModel { + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/transcription/AudioTranscription.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/transcription/AudioTranscription.java new file mode 100644 index 00000000000..df1e0b0d4b6 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/transcription/AudioTranscription.java @@ -0,0 +1,10 @@ +package org.springframework.ai.wenxin.audio.transcription; + +/** + * @author lvchzh + * @date 2024年05月14日 下午4:54 + * @description: + */ +public class AudioTranscription { + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/transcription/AudioTranscriptionPrompt.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/transcription/AudioTranscriptionPrompt.java new file mode 100644 index 00000000000..e4a5f4f8e8c --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/transcription/AudioTranscriptionPrompt.java @@ -0,0 +1,10 @@ +package org.springframework.ai.wenxin.audio.transcription; + +/** + * @author lvchzh + * @date 2024年05月14日 下午4:54 + * @description: + */ +public class AudioTranscriptionPrompt { + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/transcription/AudioTranscriptionResponse.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/transcription/AudioTranscriptionResponse.java new file mode 100644 index 00000000000..723565fdd5d --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/audio/transcription/AudioTranscriptionResponse.java @@ -0,0 +1,10 @@ +package org.springframework.ai.wenxin.audio.transcription; + +/** + * @author lvchzh + * @date 2024年05月14日 下午4:55 + * @description: + */ +public class AudioTranscriptionResponse { + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinChatResponseMetadata.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinChatResponseMetadata.java new file mode 100644 index 00000000000..59dc585b41f --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinChatResponseMetadata.java @@ -0,0 +1,74 @@ +package org.springframework.ai.wenxin.metadata; + +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.EmptyRateLimit; +import org.springframework.ai.chat.metadata.EmptyUsage; +import org.springframework.ai.chat.metadata.RateLimit; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.wenxin.api.WenxinApi; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +import java.util.HashMap; + +/** + * @author lvchzh + * @date 2024年05月14日 下午5:02 + * @description: + */ +public class WenxinChatResponseMetadata extends HashMap implements ChatResponseMetadata { + + protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, usage: %3$s, rateLimit: %4$s }"; + + private final String id; + + private final Usage usage; + + @Nullable + private RateLimit rateLimit; + + protected WenxinChatResponseMetadata(String id, WenxinUsage usage) { + this(id, usage, null); + } + + protected WenxinChatResponseMetadata(String id, WenxinUsage usage, @Nullable WenxinRateLimit rateLimit) { + this.id = id; + this.usage = usage; + this.rateLimit = rateLimit; + } + + public static WenxinChatResponseMetadata from(WenxinApi.ChatCompletion result) { + Assert.notNull(result, "Wenxin ChatCompletionResult must not be null"); + WenxinUsage usage = WenxinUsage.from(result.usage()); + WenxinChatResponseMetadata chatResponseMetadata = new WenxinChatResponseMetadata(result.id(), usage); + return chatResponseMetadata; + } + + public String getId() { + return this.id; + } + + @Override + @Nullable + public RateLimit getRateLimit() { + RateLimit rateLimit = this.rateLimit; + return rateLimit != null ? rateLimit : new EmptyRateLimit(); + } + + @Override + public Usage getUsage() { + Usage usage = this.usage; + return usage != null ? usage : new EmptyUsage(); + } + + public WenxinChatResponseMetadata withRateLimit(RateLimit rateLimit) { + this.rateLimit = rateLimit; + return this; + } + + @Override + public String toString() { + return AI_METADATA_STRING.formatted(getClass().getName(), getId(), getUsage(), getRateLimit()); + } + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinImageGenerationMetadata.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinImageGenerationMetadata.java new file mode 100644 index 00000000000..ec5d3f9223f --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinImageGenerationMetadata.java @@ -0,0 +1,10 @@ +package org.springframework.ai.wenxin.metadata; + +/** + * @author lvchzh + * @date 2024年05月14日 下午5:03 + * @description: + */ +public class WenxinImageGenerationMetadata { + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinImageResponseMetadata.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinImageResponseMetadata.java new file mode 100644 index 00000000000..f3fe886f9f6 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinImageResponseMetadata.java @@ -0,0 +1,10 @@ +package org.springframework.ai.wenxin.metadata; + +/** + * @author lvchzh + * @date 2024年05月14日 下午5:04 + * @description: + */ +public class WenxinImageResponseMetadata { + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinRateLimit.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinRateLimit.java new file mode 100644 index 00000000000..897a7978e98 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinRateLimit.java @@ -0,0 +1,69 @@ +package org.springframework.ai.wenxin.metadata; + +import org.springframework.ai.chat.metadata.RateLimit; + +import java.time.Duration; + +/** + * @author lvchzh + * @date 2024年05月14日 下午5:04 + * @description: + */ +public class WenxinRateLimit implements RateLimit { + + // @formatter:off + private static final String RATE_LIMIT_STRING = "{ @type: %1$s, requestsLimit: %2$s, requestsRemaining: %3$s, tokensLimit: %4$s, tokensRemaining: %5$s }"; + + private final Long requestsLimit; + + private final Long requestsRemaining; + + private final Long tokensLimit; + + private final Long tokensRemaining; + + public WenxinRateLimit(Long requestsLimit, Long requestsRemaining, Long tokensLimit, Long tokensRemaining) { + this.requestsLimit = requestsLimit; + this.requestsRemaining = requestsRemaining; + this.tokensLimit = tokensLimit; + this.tokensRemaining = tokensRemaining; + } + + @Override + public Long getRequestsLimit() { + return this.requestsLimit; + } + + @Override + public Long getRequestsRemaining() { + return this.requestsRemaining; + } + + @Override + public Duration getRequestsReset() { + throw new UnsupportedOperationException("unimplemented method 'getRequestsReset'"); + } + + @Override + public Long getTokensLimit() { + return this.tokensLimit; + } + + @Override + public Long getTokensRemaining() { + return this.tokensRemaining; + } + + @Override + public Duration getTokensReset() { + throw new UnsupportedOperationException("unimplemented method 'getTokensReset'"); + } + + @Override + public String toString() { + return RATE_LIMIT_STRING.formatted(getClass().getName(), getRequestsLimit(), getRequestsRemaining(), + getTokensLimit(), getTokensRemaining()); + } + // @formatter:on + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinUsage.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinUsage.java new file mode 100644 index 00000000000..def234ccdea --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinUsage.java @@ -0,0 +1,49 @@ +package org.springframework.ai.wenxin.metadata; + +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.wenxin.api.WenxinApi; +import org.springframework.util.Assert; + +/** + * @author lvchzh + * @date 2024年05月14日 下午5:05 + * @description: + */ +public class WenxinUsage implements Usage { + + private final WenxinApi.Usage usage; + + protected WenxinUsage(WenxinApi.Usage usage) { + Assert.notNull(usage, "Wenxin Usage must not be null"); + this.usage = usage; + } + + public static WenxinUsage from(WenxinApi.Usage usage) { + return new WenxinUsage(usage); + } + + protected WenxinApi.Usage getUsage() { + return this.usage; + } + + @Override + public Long getPromptTokens() { + return getUsage().promptTokens().longValue(); + } + + @Override + public Long getGenerationTokens() { + return getUsage().completionTokens().longValue(); + } + + @Override + public Long getTotalTokens() { + return getUsage().totalTokens().longValue(); + } + + @Override + public String toString() { + return getUsage().toString(); + } + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/audio/WenxinAudioSpeechMetadata.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/audio/WenxinAudioSpeechMetadata.java new file mode 100644 index 00000000000..ade38d648a6 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/audio/WenxinAudioSpeechMetadata.java @@ -0,0 +1,10 @@ +package org.springframework.ai.wenxin.metadata.audio; + +/** + * @author lvchzh + * @date 2024年05月27日 上午9:29 + * @description: + */ +public interface WenxinAudioSpeechMetadata { + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/audio/WenxinAudioSpeechResponseMetadata.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/audio/WenxinAudioSpeechResponseMetadata.java new file mode 100644 index 00000000000..71a0e960198 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/audio/WenxinAudioSpeechResponseMetadata.java @@ -0,0 +1,10 @@ +package org.springframework.ai.wenxin.metadata.audio; + +/** + * @author lvchzh + * @date 2024年05月27日 上午9:31 + * @description: + */ +public class WenxinAudioSpeechResponseMetadata { + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/audio/WenxinAudioTranscriptionMetadata.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/audio/WenxinAudioTranscriptionMetadata.java new file mode 100644 index 00000000000..87d10cb5a96 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/audio/WenxinAudioTranscriptionMetadata.java @@ -0,0 +1,10 @@ +package org.springframework.ai.wenxin.metadata.audio; + +/** + * @author lvchzh + * @date 2024年05月14日 下午4:56 + * @description: + */ +public interface WenxinAudioTranscriptionMetadata { + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/audio/WenxinAudioTransriptionResponseMetadata.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/audio/WenxinAudioTransriptionResponseMetadata.java new file mode 100644 index 00000000000..8d72029664a --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/audio/WenxinAudioTransriptionResponseMetadata.java @@ -0,0 +1,10 @@ +package org.springframework.ai.wenxin.metadata.audio; + +/** + * @author lvchzh + * @date 2024年05月14日 下午4:57 + * @description: + */ +public class WenxinAudioTransriptionResponseMetadata { + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/support/WenxinApiResponseHeaders.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/support/WenxinApiResponseHeaders.java new file mode 100644 index 00000000000..50aaa2dcf0c --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/support/WenxinApiResponseHeaders.java @@ -0,0 +1,34 @@ +package org.springframework.ai.wenxin.metadata.support; + +/** + * @author lvchzh + * @date 2024年05月14日 下午4:59 + * @description: + */ +public enum WenxinApiResponseHeaders { + + // @formatter:off + REQUESTS_LIMIT_HEADER("X-Ratelimit-Limit-Requests", "Total number of requests allowed within timeframe."), + TOKENS_LIMIT_HEADER("X-Ratelimit-Limit-Tokens", "Remaining number of tokens available in timeframe."), + REQUESTS_REMAINING_HEADER("X-Ratelimit-Remaining-Requests", "Remaining number of requests available in timeframe."), + TOKENS_REMAINING_HEADER("X-Ratelimit-Remaining-Tokens", "Duration of time until the number of tokens reset."); + // @formatter:on + + private String headerName; + + private String description; + + WenxinApiResponseHeaders(String headerName, String description) { + this.headerName = headerName; + this.description = description; + } + + public String getName() { + return this.headerName; + } + + public String getDescription() { + return this.description; + } + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/support/WenxinResponseHeaderExtractor.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/support/WenxinResponseHeaderExtractor.java new file mode 100644 index 00000000000..bc2e5458916 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/support/WenxinResponseHeaderExtractor.java @@ -0,0 +1,60 @@ +package org.springframework.ai.wenxin.metadata.support; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.metadata.RateLimit; +import org.springframework.ai.wenxin.metadata.WenxinRateLimit; +import org.springframework.http.ResponseEntity; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +import static org.springframework.ai.wenxin.metadata.support.WenxinApiResponseHeaders.REQUESTS_LIMIT_HEADER; +import static org.springframework.ai.wenxin.metadata.support.WenxinApiResponseHeaders.REQUESTS_REMAINING_HEADER; +import static org.springframework.ai.wenxin.metadata.support.WenxinApiResponseHeaders.TOKENS_LIMIT_HEADER; +import static org.springframework.ai.wenxin.metadata.support.WenxinApiResponseHeaders.TOKENS_REMAINING_HEADER; + +/** + * @author lvchzh + * @date 2024年05月14日 下午5:00 + * @description: + */ +public class WenxinResponseHeaderExtractor { + + private static final Logger logger = LoggerFactory.getLogger(WenxinResponseHeaderExtractor.class); + + public static RateLimit extractAiResponseHeaders(ResponseEntity response) { + + Long requestsLimit = getHeaderAsLong(response, REQUESTS_LIMIT_HEADER.getName()); + Long requestRemaining = getHeaderAsLong(response, REQUESTS_REMAINING_HEADER.getName()); + Long tokensLimit = getHeaderAsLong(response, TOKENS_LIMIT_HEADER.getName()); + Long tokensRemaining = getHeaderAsLong(response, TOKENS_REMAINING_HEADER.getName()); + + return new WenxinRateLimit(requestsLimit, requestRemaining, tokensLimit, tokensRemaining); + } + + private static Long getHeaderAsLong(ResponseEntity response, String headerName) { + var headers = response.getHeaders(); + if (headers.containsKey(headerName)) { + var values = headers.get(headerName); + if (!CollectionUtils.isEmpty(values)) { + return parseLong(headerName, values.get(0)); + } + } + return null; + } + + private static Long parseLong(String headerName, String headerValue) { + + if (StringUtils.hasText(headerValue)) { + try { + return Long.valueOf(headerValue); + } + catch (NumberFormatException e) { + logger.warn("Value [{}] for HTTP header [{}] is not valid: {}", headerName, headerValue, + e.getMessage()); + } + } + return null; + } + +} diff --git a/models/spring-ai-wenxin/src/main/resources/META-INF/spring/aot.factories b/models/spring-ai-wenxin/src/main/resources/META-INF/spring/aot.factories new file mode 100644 index 00000000000..f6a8c79b1ca --- /dev/null +++ b/models/spring-ai-wenxin/src/main/resources/META-INF/spring/aot.factories @@ -0,0 +1,2 @@ +org.springframework.aot.hint.RuntimeHintsRegistrar=\ + org.springframework.ai.wenxin.aot.WenxinRuntimeHints \ No newline at end of file diff --git a/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/api/WenxinApiIT.java b/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/api/WenxinApiIT.java new file mode 100644 index 00000000000..25dd85496a8 --- /dev/null +++ b/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/api/WenxinApiIT.java @@ -0,0 +1,45 @@ +package org.springframework.ai.wenxin.api; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.http.ResponseEntity; +import reactor.core.publisher.Flux; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author lvchzh + * @date 2024年05月22日 下午6:50 + * @description: + */ +@EnabledIfEnvironmentVariable(named = "ACCESS_KEY", matches = ".+") +@EnabledIfEnvironmentVariable(named = "SECRET_KEY", matches = ".+") +public class WenxinApiIT { + + WenxinApi wenxinApi = new WenxinApi(System.getenv("ACCESS_KEY"), System.getenv("SECRET_KEY")); + + @Test + void chatCompletionEntity() { + WenxinApi.ChatCompletionMessage chatCompletionMessage = new WenxinApi.ChatCompletionMessage("Tell me a joke", + WenxinApi.Role.USER); + ResponseEntity response = wenxinApi.chatCompletionEntity( + new WenxinApi.ChatCompletionRequest(List.of(chatCompletionMessage), "completions", 0.8f, false)); + System.out.println(response.getBody().result()); + assertThat(response).isNotNull(); + assertThat(response.getBody()).isNotNull(); + } + + @Test + void chatCompletionStream() { + WenxinApi.ChatCompletionMessage chatCompletionMessage = new WenxinApi.ChatCompletionMessage("Tell me a joke", + WenxinApi.Role.USER); + Flux response = wenxinApi.chatCompletionStream( + new WenxinApi.ChatCompletionRequest(List.of(chatCompletionMessage), "completions", 0.8f, true)); + + assertThat(response).isNotNull(); + assertThat(response.collectList().block()).isNotNull(); + } + +} diff --git a/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/embedding/EmbeddingIT.java b/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/embedding/EmbeddingIT.java new file mode 100644 index 00000000000..004e6c4e278 --- /dev/null +++ b/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/embedding/EmbeddingIT.java @@ -0,0 +1,30 @@ +package org.springframework.ai.wenxin.embedding; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.wenxin.WenxinEmbeddingModel; +import org.springframework.ai.wenxin.api.WenxinApi; + +import java.util.List; + +/** + * @author lvchzh + * @date 2024年05月30日 下午4:13 + * @description: + */ +@EnabledIfEnvironmentVariable(named = "ACCESS_KEY", matches = ".+") +@EnabledIfEnvironmentVariable(named = "SECRET_KEY", matches = ".+") +public class EmbeddingIT { + + WenxinEmbeddingModel wenxinEmbeddingModel = new WenxinEmbeddingModel( + new WenxinApi(System.getenv("ACCESS_KEY"), System.getenv("SECRET_KEY"))); + + @Test + void defaultEmbedding() { + EmbeddingResponse embeddingResponse = wenxinEmbeddingModel.embedForResponse(List.of("Hello World")); + System.out.println(embeddingResponse); + + } + +} diff --git a/pom.xml b/pom.xml index f9abde46034..b931fc811ac 100644 --- a/pom.xml +++ b/pom.xml @@ -64,6 +64,7 @@ models/spring-ai-ollama models/spring-ai-openai models/spring-ai-postgresml + models/spring-ai-wenxin models/spring-ai-stability-ai models/spring-ai-transformers models/spring-ai-vertex-ai-gemini @@ -85,6 +86,7 @@ spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-palm2 spring-ai-spring-boot-starters/spring-ai-starter-watsonx-ai spring-ai-spring-boot-starters/spring-ai-starter-zhipuai + spring-ai-spring-boot-starters/spring-ai-starter-wenxin diff --git a/spring-ai-bom/pom.xml b/spring-ai-bom/pom.xml index 84615ecc2d4..b7dc8c2c1b4 100644 --- a/spring-ai-bom/pom.xml +++ b/spring-ai-bom/pom.xml @@ -130,6 +130,11 @@ ${project.version} + + org.springframework.ai + spring-ai-wenxin + ${project.version} + diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml b/spring-ai-spring-boot-autoconfigure/pom.xml index 84bd54a188a..fb7f0fdad9e 100644 --- a/spring-ai-spring-boot-autoconfigure/pom.xml +++ b/spring-ai-spring-boot-autoconfigure/pom.xml @@ -34,6 +34,13 @@ spring-boot-starter + + org.springframework.ai + spring-ai-wenxin + ${project.parent.version} + true + + org.springframework.ai spring-ai-openai diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinAudioSpeechProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinAudioSpeechProperties.java new file mode 100644 index 00000000000..8994219a2be --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinAudioSpeechProperties.java @@ -0,0 +1,10 @@ +package org.springframework.ai.autoconfigure.wenxin; + +/** + * @author lvchzh + * @date 2024年05月27日 上午9:35 + * @description: + */ +public class WenxinAudioSpeechProperties { + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinAudioTranscriptionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinAudioTranscriptionProperties.java new file mode 100644 index 00000000000..d84679b0371 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinAudioTranscriptionProperties.java @@ -0,0 +1,10 @@ +package org.springframework.ai.autoconfigure.wenxin; + +/** + * @author lvchzh + * @date 2024年05月14日 下午5:52 + * @description: + */ +public class WenxinAudioTranscriptionProperties { + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinAutoConfiguration.java new file mode 100644 index 00000000000..81d4e85d80e --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinAutoConfiguration.java @@ -0,0 +1,105 @@ +package org.springframework.ai.autoconfigure.wenxin; + +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallbackContext; +import org.springframework.ai.wenxin.WenxinChatModel; +import org.springframework.ai.wenxin.WenxinEmbeddingModel; +import org.springframework.ai.wenxin.api.WenxinApi; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.ImportAutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.autoconfigure.web.reactive.function.client.WebClientAutoConfiguration; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; + +import java.util.List; + +/** + * @author lvchzh + * @date 2024年05月14日 下午5:54 + * @description: + */ +@AutoConfiguration(after = { RestClientAutoConfiguration.class, WebClientAutoConfiguration.class, + SpringAiRetryAutoConfiguration.class }) +@ConditionalOnClass(WenxinApi.class) +@EnableConfigurationProperties({ WenxinConnectionProperties.class, WenxinChatProperties.class, + WenxinEmbeddingProperties.class }) +@ImportAutoConfiguration(classes = { SpringAiRetryAutoConfiguration.class, RestClientAutoConfiguration.class, + WebClientAutoConfiguration.class }) +public class WenxinAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(prefix = WenxinChatProperties.CONFIG_PREFIX, name = "enable", havingValue = "true", + matchIfMissing = true) + public WenxinChatModel wenxinChatModel(WenxinConnectionProperties commonProperties, + WenxinChatProperties chatProperties, RestClient.Builder restClientBuilder, + WebClient.Builder webClientBuilder, List toolFunctionCallbacks, + FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate, + ResponseErrorHandler responseErrorHandler) { + var wenxinApi = wenxinApi(chatProperties.getBaseUrl(), commonProperties.getBaseUrl(), + chatProperties.getAccessKey(), commonProperties.getAccessKey(), chatProperties.getSecretKey(), + commonProperties.getSecretKey(), restClientBuilder, webClientBuilder, responseErrorHandler); + + if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) { + chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks); + } + + return new WenxinChatModel(wenxinApi, chatProperties.getOptions(), functionCallbackContext, retryTemplate); + + } + + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(prefix = WenxinEmbeddingProperties.CONFIG_PREFIX, name = "enable", havingValue = "true", + matchIfMissing = true) + public WenxinEmbeddingModel wenxinEmbeddingModel(WenxinConnectionProperties commonProperties, + WenxinEmbeddingProperties embeddingProperties, RestClient.Builder restClientBuilder, + WebClient.Builder webClientBuilder, RetryTemplate retryTemplate, + ResponseErrorHandler responseErrorHandler) { + var wenxinApi = wenxinApi(embeddingProperties.getBaseUrl(), commonProperties.getBaseUrl(), + embeddingProperties.getAccessKey(), commonProperties.getAccessKey(), embeddingProperties.getSecretKey(), + commonProperties.getSecretKey(), restClientBuilder, webClientBuilder, responseErrorHandler); + + return new WenxinEmbeddingModel(wenxinApi, embeddingProperties.getMetadataMode(), + embeddingProperties.getOptions(), retryTemplate); + } + + private WenxinApi wenxinApi(String chatBaseUrl, String commonBaseUrl, String accessKey, String commonAccessKey, + String secretKey, String commonSecretKey, RestClient.Builder restClientBuilder, + WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { + + String resolvedChatBaseUrl = StringUtils.hasText(chatBaseUrl) ? chatBaseUrl : commonBaseUrl; + Assert.hasText(resolvedChatBaseUrl, "The Wenxin API base URL must be set!"); + + String resolvedAccessKey = StringUtils.hasText(secretKey) ? secretKey : commonAccessKey; + Assert.hasText(resolvedAccessKey, "The Wenxin API client ID must be set!"); + + String resolvedSecretKey = StringUtils.hasText(accessKey) ? accessKey : commonSecretKey; + Assert.hasText(resolvedSecretKey, "The Wenxin API client secret must be set!"); + + return new WenxinApi(resolvedChatBaseUrl, restClientBuilder, webClientBuilder, responseErrorHandler, + resolvedAccessKey, resolvedSecretKey); + } + + @Bean + @ConditionalOnMissingBean + public FunctionCallbackContext springAiFunctionCallbackContext(ApplicationContext context) { + FunctionCallbackContext manager = new FunctionCallbackContext(); + manager.setApplicationContext(context); + return manager; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinChatProperties.java new file mode 100644 index 00000000000..eeac97d2eb2 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinChatProperties.java @@ -0,0 +1,44 @@ +package org.springframework.ai.autoconfigure.wenxin; + +import org.springframework.ai.wenxin.WenxinChatOptions; +import org.springframework.boot.context.properties.ConfigurationProperties; + +/** + * @author lvchzh + * @date 2024年05月14日 下午5:54 + * @description: + */ +@ConfigurationProperties(WenxinChatProperties.CONFIG_PREFIX) +public class WenxinChatProperties extends WenxinParentProperties { + + public static final String CONFIG_PREFIX = "spring.ai.wenxin.chat"; + + public static final String DEFAULT_CHAT_MODEL = "completions"; + + private static final Double DEFAULT_TEMPERATURE = 0.7; + + private boolean enable = true; + + //@formatter:off + private WenxinChatOptions options = WenxinChatOptions.builder() + .withModel(DEFAULT_CHAT_MODEL) + .withTemperature(DEFAULT_TEMPERATURE.floatValue()) + .build(); + //@formatter:on + public WenxinChatOptions getOptions() { + return options; + } + + public void setOptions(WenxinChatOptions options) { + this.options = options; + } + + public boolean isEnable() { + return enable; + } + + public void setEnable(boolean enable) { + this.enable = enable; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinConnectionProperties.java new file mode 100644 index 00000000000..65441067893 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinConnectionProperties.java @@ -0,0 +1,21 @@ +package org.springframework.ai.autoconfigure.wenxin; + +import org.springframework.boot.context.properties.ConfigurationProperties; + +/** + * @author lvchzh + * @date 2024年05月14日 下午5:55 + * @description: + */ +@ConfigurationProperties(WenxinConnectionProperties.CONFIG_PREFIX) +public class WenxinConnectionProperties extends WenxinParentProperties { + + public static final String CONFIG_PREFIX = "spring.ai.wenxin"; + + public static final String DEFAULT_BASE_URL = "https://aip.baidubce.com"; + + public WenxinConnectionProperties() { + super.setBaseUrl(DEFAULT_BASE_URL); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinEmbeddingProperties.java new file mode 100644 index 00000000000..769e25f1587 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinEmbeddingProperties.java @@ -0,0 +1,51 @@ +package org.springframework.ai.autoconfigure.wenxin; + +import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.wenxin.WenxinEmbeddingOptions; +import org.springframework.boot.context.properties.ConfigurationProperties; + +/** + * @author lvchzh + * @date 2024年05月14日 下午5:56 + * @description: + */ +@ConfigurationProperties(WenxinEmbeddingProperties.CONFIG_PREFIX) +public class WenxinEmbeddingProperties extends WenxinParentProperties { + + public static final String CONFIG_PREFIX = "spring.ai.wenxin.embedding"; + + public static final String DEFAULT_EMBEDDING_MODEL = "Embedding-V1"; + + private boolean enabled = true; + + private MetadataMode metadataMode = MetadataMode.EMBED; + + private WenxinEmbeddingOptions options = WenxinEmbeddingOptions.builder() + .withModel(DEFAULT_EMBEDDING_MODEL) + .build(); + + public WenxinEmbeddingOptions getOptions() { + return this.options; + } + + public void setOptions(WenxinEmbeddingOptions options) { + this.options = options; + } + + public MetadataMode getMetadataMode() { + return this.metadataMode; + } + + public void setMetadataMode(MetadataMode metadataMode) { + this.metadataMode = metadataMode; + } + + public boolean isEnabled() { + return this.enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinImageProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinImageProperties.java new file mode 100644 index 00000000000..4d8d745b743 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinImageProperties.java @@ -0,0 +1,10 @@ +package org.springframework.ai.autoconfigure.wenxin; + +/** + * @author lvchzh + * @date 2024年05月14日 下午5:56 + * @description: + */ +public class WenxinImageProperties { + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinParentProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinParentProperties.java new file mode 100644 index 00000000000..d644e1530f5 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinParentProperties.java @@ -0,0 +1,40 @@ +package org.springframework.ai.autoconfigure.wenxin; + +/** + * @author lvchzh + * @date 2024年05月14日 下午5:57 + * @description: + */ +public class WenxinParentProperties { + + private String baseUrl; + + private String accessKey; + + private String secretKey; + + public String getBaseUrl() { + return baseUrl; + } + + public void setBaseUrl(String baseUrl) { + this.baseUrl = baseUrl; + } + + public String getAccessKey() { + return accessKey; + } + + public void setAccessKey(String accessKey) { + this.accessKey = accessKey; + } + + public String getSecretKey() { + return secretKey; + } + + public void setSecretKey(String secretKey) { + this.secretKey = secretKey; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index c744be669c2..4e6ec165713 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -1,4 +1,5 @@ org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration +org.springframework.ai.autoconfigure.wenxin.WenxinAutoConfiguration org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration org.springframework.ai.autoconfigure.stabilityai.StabilityAiImageAutoConfiguration org.springframework.ai.autoconfigure.transformers.TransformersEmbeddingModelAutoConfiguration diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-wenxin/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-wenxin/pom.xml new file mode 100644 index 00000000000..666d1545f91 --- /dev/null +++ b/spring-ai-spring-boot-starters/spring-ai-starter-wenxin/pom.xml @@ -0,0 +1,42 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-wenxin-spring-boot-starter + jar + Spring AI Starter - Wenxin + Spring AI Wenxin Auto Configuration + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + org.springframework.boot + spring-boot-starter + + + + org.springframework.ai + spring-ai-spring-boot-autoconfigure + ${project.parent.version} + + + + org.springframework.ai + spring-ai-wenxin + ${project.parent.version} + + + + From 110dd67e8866ce3868e66e9b635f1c93699988c2 Mon Sep 17 00:00:00 2001 From: lvchzh Date: Thu, 6 Jun 2024 10:14:23 +0800 Subject: [PATCH 2/4] chore: modify the spring-ai-wenxin model name --- models/spring-ai-wenxin/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/spring-ai-wenxin/pom.xml b/models/spring-ai-wenxin/pom.xml index 0dd62468c9d..e5be0e750c5 100644 --- a/models/spring-ai-wenxin/pom.xml +++ b/models/spring-ai-wenxin/pom.xml @@ -10,7 +10,7 @@ spring-ai-wenxin jar - Spring AI Wenxin + Spring AI Model - Wenxin Wenxin support https://github.com/spring-projects/spring-ai From 71e8cee64ea7f2df7b703d8820eb40391b8cfdd2 Mon Sep 17 00:00:00 2001 From: lvchzh Date: Thu, 25 Jul 2024 22:48:32 +0800 Subject: [PATCH 3/4] fix bug #1118 & #1117 --- .../ai/chat/messages/AbstractMessage.java | 5 ++++- .../ai/chat/messages/AssistantMessage.java | 5 ++++- .../ai/chat/messages/SystemMessage.java | 5 ++++- .../ai/chat/messages/ToolResponseMessage.java | 5 ++++- .../ai/chat/messages/UserMessage.java | 5 ++++- .../springframework/ai/model/ModelOptionsUtils.java | 5 +++-- .../ai/model/function/AbstractFunctionCallback.java | 11 +++++++---- .../ai/model/function/FunctionCallbackWrapper.java | 12 ++++++++---- .../ai/model/function/TypeResolverHelper.java | 9 +++++---- .../ai/model/function/TypeResolverHelperTests.java | 7 ++++--- 10 files changed, 47 insertions(+), 22 deletions(-) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java index 05a89117c6b..7a24698c021 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.io.InputStream; +import java.io.Serializable; import java.nio.charset.Charset; import java.util.HashMap; import java.util.Map; @@ -33,7 +34,9 @@ * * @see Message */ -public abstract class AbstractMessage implements Message { +public abstract class AbstractMessage implements Message, Serializable { + + private static final long serialVersionUID = 1L; public static final String MESSAGE_TYPE = "messageType"; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java index beac0344c11..c0d946567c9 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.chat.messages; +import java.io.Serializable; import java.util.List; import java.util.Map; import java.util.Objects; @@ -31,7 +32,9 @@ * @author Christian Tzolov * @since 1.0.0 */ -public class AssistantMessage extends AbstractMessage { +public class AssistantMessage extends AbstractMessage implements Serializable { + + private static final long serialVersionUID = 1L; public record ToolCall(String id, String type, String name, String arguments) { } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/SystemMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/SystemMessage.java index ddcff796678..936479b868e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/SystemMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/SystemMessage.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.chat.messages; +import java.io.Serializable; import java.util.Map; import java.util.Objects; @@ -27,7 +28,9 @@ * generative to behave like a certain character or to provide answers in a specific * format. */ -public class SystemMessage extends AbstractMessage { +public class SystemMessage extends AbstractMessage implements Serializable { + + private static final long serialVersionUID = 1L; public SystemMessage(String textContent) { super(MessageType.SYSTEM, textContent, Map.of()); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java index 42f91f9df54..242fbb17d7f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.chat.messages; +import java.io.Serializable; import java.util.List; import java.util.Map; import java.util.Objects; @@ -26,7 +27,9 @@ * @author Christian Tzolov * @since 1.0.0 */ -public class ToolResponseMessage extends AbstractMessage { +public class ToolResponseMessage extends AbstractMessage implements Serializable { + + private static final long serialVersionUID = 1L; public record ToolResponse(String id, String name, String responseData) { }; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java index 53c32425722..0a3a63cf16c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.chat.messages; +import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -31,7 +32,9 @@ * end-user or developer. They represent questions, prompts, or any input that you want * the generative to respond to. */ -public class UserMessage extends AbstractMessage implements MediaContent { +public class UserMessage extends AbstractMessage implements MediaContent, Serializable { + + private static final long serialVersionUID = 1L; protected final List media; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java index d9b2bda508c..08268b919b1 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java @@ -17,6 +17,7 @@ import java.beans.PropertyDescriptor; import java.lang.reflect.Field; +import java.lang.reflect.Type; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -331,7 +332,7 @@ private static String toGetName(String name) { * @param toUpperCaseTypeValues if true, the type values are converted to upper case. * @return the generated JSON Schema as a String. */ - public static String getJsonSchema(Class clazz, boolean toUpperCaseTypeValues) { + public static String getJsonSchema(Type type, boolean toUpperCaseTypeValues) { if (SCHEMA_GENERATOR_CACHE.get() == null) { @@ -350,7 +351,7 @@ public static String getJsonSchema(Class clazz, boolean toUpperCaseTypeValues SCHEMA_GENERATOR_CACHE.compareAndSet(null, generator); } - ObjectNode node = SCHEMA_GENERATOR_CACHE.get().generateSchema(clazz); + ObjectNode node = SCHEMA_GENERATOR_CACHE.get().generateSchema(type); if (toUpperCaseTypeValues) { // Required for OpenAPI 3.0 (at least Vertex AI // version of it). toUpperCaseTypeValues(node); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java index 6bd639c883e..f9aaac5afc0 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java @@ -15,9 +15,11 @@ */ package org.springframework.ai.model.function; +import java.lang.reflect.Type; import java.util.function.Function; import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JavaType; import com.fasterxml.jackson.databind.ObjectMapper; import org.springframework.util.Assert; @@ -43,7 +45,7 @@ abstract class AbstractFunctionCallback implements Function, Functio private final String description; - private final Class inputType; + private final Type inputType; private final String inputTypeSchema; @@ -66,7 +68,7 @@ abstract class AbstractFunctionCallback implements Function, Functio * @param objectMapper Used to convert the function's input and output types to and * from JSON. */ - protected AbstractFunctionCallback(String name, String description, String inputTypeSchema, Class inputType, + protected AbstractFunctionCallback(String name, String description, String inputTypeSchema, Type inputType, Function responseConverter, ObjectMapper objectMapper) { Assert.notNull(name, "Name must not be null"); Assert.notNull(description, "Description must not be null"); @@ -107,9 +109,10 @@ public String call(String functionArguments) { return this.andThen(this.responseConverter).apply(request); } - private T fromJson(String json, Class targetClass) { + private T fromJson(String json, Type targetClass) { try { - return this.objectMapper.readValue(json, targetClass); + JavaType javaType = objectMapper.getTypeFactory().constructType(targetClass); + return this.objectMapper.readValue(json, javaType); } catch (JsonProcessingException e) { throw new RuntimeException(e); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java index bfa9c9c3c28..337ea13a639 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.model.function; +import java.lang.reflect.Type; import java.util.function.Function; import com.fasterxml.jackson.databind.DeserializationFeature; @@ -36,7 +37,7 @@ public class FunctionCallbackWrapper extends AbstractFunctionCallback function; - private FunctionCallbackWrapper(String name, String description, String inputTypeSchema, Class inputType, + private FunctionCallbackWrapper(String name, String description, String inputTypeSchema, Type inputType, Function responseConverter, ObjectMapper objectMapper, Function function) { super(name, description, inputTypeSchema, inputType, responseConverter, objectMapper); Assert.notNull(function, "Function must not be null"); @@ -44,8 +45,8 @@ private FunctionCallbackWrapper(String name, String description, String inputTyp } @SuppressWarnings("unchecked") - private static Class resolveInputType(Function function) { - return (Class) TypeResolverHelper.getFunctionInputClass((Class>) function.getClass()); + private static Type resolveInputType(Function function) { + return TypeResolverHelper.getFunctionInputClass((Class>) function.getClass()); } @Override @@ -69,7 +70,9 @@ public enum SchemaType { private String description; - private Class inputType; + // private Class inputType; + + private Type inputType; private final Function function; @@ -141,6 +144,7 @@ public FunctionCallbackWrapper build() { Assert.notNull(this.objectMapper, "ObjectMapper must not be null"); if (this.inputType == null) { + // this.inputType = resolveInputType(this.function); this.inputType = resolveInputType(this.function); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java index 1fa0736d3eb..62db288b748 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java @@ -37,7 +37,7 @@ public abstract class TypeResolverHelper { * @param functionClass The function class. * @return The input class of the function. */ - public static Class getFunctionInputClass(Class> functionClass) { + public static Type getFunctionInputClass(Class> functionClass) { return getFunctionArgumentClass(functionClass, 0); } @@ -46,7 +46,7 @@ public static Class getFunctionInputClass(Class> fun * @param functionClass The function class. * @return The output class of the function. */ - public static Class getFunctionOutputClass(Class> functionClass) { + public static Type getFunctionOutputClass(Class> functionClass) { return getFunctionArgumentClass(functionClass, 1); } @@ -56,13 +56,14 @@ public static Class getFunctionOutputClass(Class> fu * @param argumentIndex The index of the argument whose class should be retrieved. * @return The class of the specified function argument. */ - public static Class getFunctionArgumentClass(Class> functionClass, int argumentIndex) { + public static Type getFunctionArgumentClass(Class> functionClass, int argumentIndex) { Type type = TypeResolver.reify(Function.class, functionClass); var argumentType = type instanceof ParameterizedType ? ((ParameterizedType) type).getActualTypeArguments()[argumentIndex] : Object.class; - return toRawClass(argumentType); + // return toRawClass(argumentType); + return argumentType; } /** diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperTests.java b/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperTests.java index 76622a22281..8f81696953b 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperTests.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.model.function; +import java.lang.reflect.Type; import java.util.function.Function; import com.fasterxml.jackson.annotation.JsonClassDescription; @@ -36,20 +37,20 @@ public class TypeResolverHelperTests { @Test public void testGetFunctionInputType() { - Class inputType = TypeResolverHelper.getFunctionInputClass(MockWeatherService.class); + Type inputType = TypeResolverHelper.getFunctionInputClass(MockWeatherService.class); assertThat(inputType).isEqualTo(Request.class); } @Test public void testGetFunctionOutputType() { - Class outputType = TypeResolverHelper.getFunctionOutputClass(MockWeatherService.class); + Type outputType = TypeResolverHelper.getFunctionOutputClass(MockWeatherService.class); assertThat(outputType).isEqualTo(Response.class); } @Test public void testGetFunctionInputTypeForInstance() { MockWeatherService service = new MockWeatherService(); - Class inputType = TypeResolverHelper.getFunctionInputClass(service.getClass()); + Type inputType = TypeResolverHelper.getFunctionInputClass(service.getClass()); assertThat(inputType).isEqualTo(Request.class); } From f8f2d00410ed145f5c3e8161f63d197ee9b994c8 Mon Sep 17 00:00:00 2001 From: lvchzh Date: Wed, 21 Aug 2024 13:42:40 +0800 Subject: [PATCH 4/4] fix compile error --- .../ai/wenxin/WenxinChatOptions.java | 20 +++++++++++++++++++ .../ai/wenxin/WenxinEmbeddingModel.java | 4 ++-- .../ai/wenxin/WenxinEmbeddingOptions.java | 5 +++++ .../ai/wenxin/api/WenxinApi.java | 4 ++-- 4 files changed, 29 insertions(+), 4 deletions(-) diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinChatOptions.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinChatOptions.java index bb783adbe63..47fb3a8b5f7 100644 --- a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinChatOptions.java +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinChatOptions.java @@ -90,6 +90,26 @@ public String getModel() { return model; } + @Override + public Float getFrequencyPenalty() { + return 0f; + } + + @Override + public Integer getMaxTokens() { + return 0; + } + + @Override + public Float getPresencePenalty() { + return 0f; + } + + @Override + public List getStopSequences() { + return List.of(); + } + public void setModel(String model) { this.model = model; } diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinEmbeddingModel.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinEmbeddingModel.java index c5f4f761266..833af199a3a 100644 --- a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinEmbeddingModel.java +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinEmbeddingModel.java @@ -64,7 +64,7 @@ public WenxinEmbeddingModel(WenxinApi wenxinApi, MetadataMode metadataMode, Wenx } @Override - public List embed(Document document) { + public float[] embed(Document document) { Assert.notNull(document, "Document must not be null"); return this.embed(document.getFormattedContent(this.metadataMode)); } @@ -80,7 +80,7 @@ public EmbeddingResponse call(EmbeddingRequest request) { this.defaultOptions.getUserId()) : new WenxinApi.EmbeddingRequest<>(request.getInstructions(), WenxinApi.DEFAULT_EMBEDDING_MODEL); - if (request.getOptions() != null && !EmbeddingOptions.EMPTY.equals(request.getOptions())) { + if (request.getOptions() != null) { apiRequest = ModelOptionsUtils.merge(request.getOptions(), apiRequest, WenxinApi.EmbeddingRequest.class); } diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinEmbeddingOptions.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinEmbeddingOptions.java index b68b41ddf4d..88700ece4e8 100644 --- a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinEmbeddingOptions.java +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinEmbeddingOptions.java @@ -21,6 +21,11 @@ public String getModel() { return this.model; } + @Override + public Integer getDimensions() { + return 0; + } + public void setModel(String model) { this.model = model; } diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinApi.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinApi.java index 0e18ba36948..c313beb4526 100644 --- a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinApi.java +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinApi.java @@ -418,10 +418,10 @@ public String getValue() { @JsonInclude(JsonInclude.Include.NON_NULL) public record Embedding( @JsonProperty("index") Integer index, - @JsonProperty("embedding") List embedding, + @JsonProperty("embedding")float[] embedding, @JsonProperty("object") String object) { - public Embedding(Integer index, List embedding) { + public Embedding(Integer index, float[] embedding) { this(index, embedding, "embedding"); } }