Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,13 @@ private ZhiPuAiEmbeddingOptions mergeOptions(@Nullable EmbeddingOptions runtimeO

return ZhiPuAiEmbeddingOptions.builder()
.model(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getModel(), defaultOptions.getModel()))
.withDimensions(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getDimensions(),
defaultOptions.getDimensions()))
.build();
}

private ZhiPuAiApi.EmbeddingRequest<String> createEmbeddingRequest(String text, EmbeddingOptions requestOptions) {
return new ZhiPuAiApi.EmbeddingRequest<>(text, requestOptions.getModel());
return new ZhiPuAiApi.EmbeddingRequest<>(text, requestOptions.getModel(), requestOptions.getDimensions());
}

public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

package org.springframework.ai.zhipuai;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
Expand All @@ -34,6 +33,12 @@
@JsonInclude(Include.NON_NULL)
public class ZhiPuAiEmbeddingOptions implements EmbeddingOptions {

/**
* The default vector dimension is 2048. The model supports custom vector dimensions,
* and it is recommended to choose dimensions of 256, 512, 1024, or 2048.
*/
private @JsonProperty("dimensions") Integer dimensions;

// @formatter:off
/**
* ID of the model to use.
Expand All @@ -55,9 +60,12 @@ public void setModel(String model) {
}

@Override
@JsonIgnore
public Integer getDimensions() {
return null;
return this.dimensions;
}

public void setDimensions(Integer dimensions) {
this.dimensions = dimensions;
}

public static class Builder {
Expand All @@ -73,6 +81,11 @@ public Builder model(String model) {
return this;
}

public Builder withDimensions(Integer dimensions) {
this.options.setDimensions(dimensions);
return this;
}

public ZhiPuAiEmbeddingOptions build() {
return this.options;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,10 @@

package org.springframework.ai.zhipuai.api;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.function.Predicate;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.springframework.ai.model.ChatModelDescription;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
Expand All @@ -43,6 +32,16 @@
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.function.Predicate;

// @formatter:off
/**
Expand Down Expand Up @@ -283,7 +282,14 @@ public enum ChatCompletionFinishReason {
* Only for compatibility with Mistral AI API.
*/
@JsonProperty("tool_call")
TOOL_CALL
TOOL_CALL,
/**
* 'Sensitive' means that the content has been intercepted by the security
* audit interface (users should judge and decide whether to withdraw
* public content)
*/
@JsonProperty("sensitive")
SENSITIVE
}

/**
Expand All @@ -295,7 +301,12 @@ public enum EmbeddingModel {
/**
* DIMENSION: 1024
*/
Embedding_2("Embedding-2");
Embedding_2("Embedding-2"),

/**
* DIMENSION: 2048
*/
Embedding_3("Embedding-3");

public final String value;

Expand Down Expand Up @@ -956,15 +967,25 @@ public String toString() {
@JsonInclude(Include.NON_NULL)
public record EmbeddingRequest<T>(
@JsonProperty("input") T input,
@JsonProperty("model") String model) {

@JsonProperty("model") String model,

@JsonProperty("dimensions") Integer dimensions) {

/**
* Create an embedding request with the given input. Encoding model is set to 'embedding-2'.
* @param input Input text to embed.
*/
public EmbeddingRequest(T input) {
this(input, DEFAULT_EMBEDDING_MODEL);
this(input,DEFAULT_EMBEDDING_MODEL,null);
}

/**
* Create an embedding request with the given input. Encoding model is set to 'embedding-3'.
* @param input Input text to embed.
*/
public EmbeddingRequest(T input, Integer dimensions) {
this(input,EmbeddingModel.Embedding_3.getValue(),dimensions);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.zhipuai;

import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.util.StringUtils;

/**
* @author Yee Pan
*/
@SpringBootConfiguration
public class ZhiPuAiEmbeddingModleTestConfiguration {

@Bean
public ZhiPuAiApi zhiPuAiApi() {
return new ZhiPuAiApi(getApiKey());
}

@Bean
public ZhiPuAiImageApi zhiPuAiImageApi() {
return new ZhiPuAiImageApi(getApiKey());
}

private String getApiKey() {
String apiKey = System.getenv("ZHIPU_AI_API_KEY");
if (!StringUtils.hasText(apiKey)) {
throw new IllegalArgumentException(
"You must provide an API key. Put it in an environment variable under the name ZHIPU_AI_API_KEY");
}
return apiKey;
}

@Bean
public EmbeddingModel zhiPuAiEmbeddingModel(ZhiPuAiApi api) {
return new ZhiPuAiEmbeddingModel(api, MetadataMode.EMBED,
ZhiPuAiEmbeddingOptions.builder()
.model(ZhiPuAiApi.EmbeddingModel.Embedding_3.getValue())
.withDimensions(1024)
.build());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,14 @@ void embeddings() {
assertThat(response.getBody().data().get(0).embedding()).hasSize(1024);
}

@Test
void embeddings3() {
ResponseEntity<EmbeddingList<Embedding>> response = this.zhiPuAiApi
.embeddings(new ZhiPuAiApi.EmbeddingRequest<>("Hello world", 1024));

assertThat(response).isNotNull();
assertThat(Objects.requireNonNull(response.getBody()).data()).hasSize(1);
assertThat(response.getBody().data().get(0).embedding()).hasSize(1024);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.zhipuai.embedding3;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.zhipuai.ZhiPuAiEmbeddingModel;
import org.springframework.ai.zhipuai.ZhiPuAiEmbeddingModleTestConfiguration;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;

import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;

/**
* <a href="https://bigmodel.cn/dev/api/vector/embedding-3">zhipuai vector model
* embedding-3 doc</a>
*
* @author Yee Pan
*/
@SpringBootTest(classes = ZhiPuAiEmbeddingModleTestConfiguration.class)
@EnabledIfEnvironmentVariable(named = "ZHIPU_AI_API_KEY", matches = ".+")
class Embedding3IT {

@Autowired
private ZhiPuAiEmbeddingModel embeddingModel;

@Test
void defaultEmbedding() {
assertThat(this.embeddingModel).isNotNull();

EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World"));

assertThat(embeddingResponse.getResults()).hasSize(1);

assertThat(embeddingResponse.getResults().get(0)).isNotNull();
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1024);

assertThat(this.embeddingModel.dimensions()).isEqualTo(1024);
}

@Test
void batchEmbedding() {
assertThat(this.embeddingModel).isNotNull();

EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World", "HI"));

assertThat(embeddingResponse.getResults()).hasSize(2);

assertThat(embeddingResponse.getResults().get(0)).isNotNull();
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1024);

assertThat(embeddingResponse.getResults().get(1)).isNotNull();
assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(1024);

assertThat(this.embeddingModel.dimensions()).isEqualTo(1024);
}

}
Loading