diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/pom.xml b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/pom.xml index 3dde6f19d25..3ee1253e9fa 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/pom.xml +++ b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/pom.xml @@ -35,12 +35,24 @@ + + org.springframework.ai + spring-ai-autoconfigure-model-tool + ${project.parent.version} + + org.springframework.ai spring-ai-autoconfigure-model-chat-observation ${project.parent.version} + + org.springframework.ai + spring-ai-autoconfigure-retry + ${project.parent.version} + + org.springframework.boot @@ -48,6 +60,12 @@ true + + org.springframework.boot + spring-boot-starter-restclient + true + + org.springframework.boot spring-boot-configuration-processor diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceApiAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceApiAutoConfiguration.java new file mode 100644 index 00000000000..c863c77f6c7 --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceApiAutoConfiguration.java @@ -0,0 +1,61 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.huggingface.autoconfigure; + +import org.springframework.ai.huggingface.api.HuggingfaceApi; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; +import org.springframework.context.annotation.Bean; + +/** + * {@link AutoConfiguration Auto-configuration} for HuggingFace API. + * + * @author Myeongdeok Kang + */ +@AutoConfiguration(after = RestClientAutoConfiguration.class) +@ConditionalOnClass(HuggingfaceApi.class) +@EnableConfigurationProperties(HuggingfaceConnectionProperties.class) +public class HuggingfaceApiAutoConfiguration { + + @Bean + @ConditionalOnMissingBean(HuggingfaceConnectionDetails.class) + PropertiesHuggingfaceConnectionDetails huggingfaceConnectionDetails(HuggingfaceConnectionProperties properties) { + return new PropertiesHuggingfaceConnectionDetails(properties); + } + + // This bean is no longer created here since Chat and Embedding + // need different base URLs. Each AutoConfiguration creates its own API instance. 
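+ // PropertiesHuggingfaceConnectionDetails adapts the configuration properties to the
+ // HuggingfaceConnectionDetails contract; because of the @ConditionalOnMissingBean guard above,
+ // applications can replace it by defining their own HuggingfaceConnectionDetails bean
+ // (for example, one backed by a secrets manager).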
+ + static class PropertiesHuggingfaceConnectionDetails implements HuggingfaceConnectionDetails { + + private final HuggingfaceConnectionProperties properties; + + PropertiesHuggingfaceConnectionDetails(HuggingfaceConnectionProperties properties) { + this.properties = properties; + } + + @Override + public String getApiKey() { + return this.properties.getApiKey(); + } + + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceChatAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceChatAutoConfiguration.java index 8b8526d9c2a..b294f1ac773 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceChatAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceChatAutoConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,27 +16,101 @@ package org.springframework.ai.model.huggingface.autoconfigure; +import io.micrometer.observation.ObservationRegistry; + +import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.huggingface.HuggingfaceChatModel; +import org.springframework.ai.huggingface.api.HuggingfaceApi; import org.springframework.ai.model.SpringAIModelProperties; import org.springframework.ai.model.SpringAIModels; +import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; +import org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration; +import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration; +import org.springframework.beans.factory.ObjectProvider; +import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; import org.springframework.context.annotation.Bean; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.util.StringUtils; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; -@AutoConfiguration +/** + * {@link AutoConfiguration Auto-configuration} for HuggingFace Chat Model. 
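+ *
+ * <p>Registers a dedicated {@code huggingfaceChatApi} {@link HuggingfaceApi} client bound to the chat
+ * base URL, and a {@link HuggingfaceChatModel} wired with tool calling, retry and observation support.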
+ * + * @author Mark Pollack + * @author Josh Long + * @author Soby Chacko + * @author Ilayaperumal Gopinathan + * @author Myeongdeok Kang + */ +@AutoConfiguration(after = { RestClientAutoConfiguration.class, HuggingfaceApiAutoConfiguration.class, + SpringAiRetryAutoConfiguration.class, ToolCallingAutoConfiguration.class }) @ConditionalOnClass(HuggingfaceChatModel.class) -@EnableConfigurationProperties(HuggingfaceChatProperties.class) @ConditionalOnProperty(name = SpringAIModelProperties.CHAT_MODEL, havingValue = SpringAIModels.HUGGINGFACE, matchIfMissing = true) +@EnableConfigurationProperties({ HuggingfaceConnectionProperties.class, HuggingfaceChatProperties.class }) public class HuggingfaceChatAutoConfiguration { + @Bean + @ConditionalOnMissingBean(name = "huggingfaceChatApi") + @ConditionalOnProperty(prefix = HuggingfaceChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", + matchIfMissing = true) + public HuggingfaceApi huggingfaceChatApi(HuggingfaceConnectionDetails connectionDetails, + HuggingfaceChatProperties chatProperties, ObjectProvider restClientBuilderProvider, + ObjectProvider responseErrorHandlerProvider) { + + String apiKey = connectionDetails.getApiKey(); + if (!StringUtils.hasText(apiKey)) { + throw new IllegalArgumentException( + "HuggingFace API key must be set. Please configure spring.ai.huggingface.api-key"); + } + + RestClient.Builder restClientBuilder = restClientBuilderProvider.getIfAvailable(RestClient::builder); + ResponseErrorHandler responseErrorHandler = responseErrorHandlerProvider.getIfAvailable(() -> null); + + HuggingfaceApi.Builder apiBuilder = HuggingfaceApi.builder() + .baseUrl(chatProperties.getUrl()) + .apiKey(apiKey) + .restClientBuilder(restClientBuilder); + + if (responseErrorHandler != null) { + apiBuilder.responseErrorHandler(responseErrorHandler); + } + + return apiBuilder.build(); + } + @Bean @ConditionalOnMissingBean - public HuggingfaceChatModel huggingfaceChatModel(HuggingfaceChatProperties huggingfaceChatProperties) { - return new HuggingfaceChatModel(huggingfaceChatProperties.getApiKey(), huggingfaceChatProperties.getUrl()); + @ConditionalOnProperty(prefix = HuggingfaceChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", + matchIfMissing = true) + public HuggingfaceChatModel huggingfaceChatModel(@Qualifier("huggingfaceChatApi") HuggingfaceApi huggingfaceApi, + HuggingfaceChatProperties chatProperties, ToolCallingManager toolCallingManager, + ObjectProvider observationRegistry, + ObjectProvider observationConvention, RetryTemplate retryTemplate, + ObjectProvider huggingfaceToolExecutionEligibilityPredicate) { + + var chatModel = HuggingfaceChatModel.builder() + .huggingfaceApi(huggingfaceApi) + .defaultOptions(chatProperties.getOptions()) + .toolCallingManager(toolCallingManager) + .toolExecutionEligibilityPredicate(huggingfaceToolExecutionEligibilityPredicate + .getIfUnique(DefaultToolExecutionEligibilityPredicate::new)) + .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) + .retryTemplate(retryTemplate) + .build(); + + observationConvention.ifAvailable(chatModel::setObservationConvention); + + return chatModel; } } diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceChatProperties.java b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceChatProperties.java index 
5b2004a88b8..0dc0189bc73 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceChatProperties.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceChatProperties.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,11 @@ package org.springframework.ai.model.huggingface.autoconfigure; +import org.springframework.ai.huggingface.HuggingfaceChatOptions; +import org.springframework.ai.huggingface.api.HuggingfaceApi; +import org.springframework.ai.huggingface.api.common.HuggingfaceApiConstants; import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; /** * Configuration properties for Hugging Face chat model. @@ -25,6 +29,7 @@ * @author Josh Long * @author Mark Pollack * @author Thomas Vitale + * @author Myeongdeok Kang */ @ConfigurationProperties(HuggingfaceChatProperties.CONFIG_PREFIX) public class HuggingfaceChatProperties { @@ -32,21 +37,31 @@ public class HuggingfaceChatProperties { public static final String CONFIG_PREFIX = "spring.ai.huggingface.chat"; /** - * API Key to authenticate with the Inference Endpoint. + * Enable HuggingFace chat model autoconfiguration. */ - private String apiKey; + private boolean enabled = true; /** - * URL of the Inference Endpoint. + * Base URL for the HuggingFace Chat API (OpenAI-compatible endpoint). */ - private String url; + private String url = HuggingfaceApiConstants.DEFAULT_CHAT_BASE_URL; - public String getApiKey() { - return this.apiKey; + /** + * Client-level HuggingFace chat options. Use this property to configure the model, + * temperature, max_tokens, and other parameters. Null values are ignored, defaulting + * to the API defaults. + */ + @NestedConfigurationProperty + private final HuggingfaceChatOptions options = HuggingfaceChatOptions.builder() + .model(HuggingfaceApi.DEFAULT_CHAT_MODEL) + .build(); + + public boolean isEnabled() { + return this.enabled; } - public void setApiKey(String apiKey) { - this.apiKey = apiKey; + public void setEnabled(boolean enabled) { + this.enabled = enabled; } public String getUrl() { @@ -57,4 +72,16 @@ public void setUrl(String url) { this.url = url; } + public String getModel() { + return this.options.getModel(); + } + + public void setModel(String model) { + this.options.setModel(model); + } + + public HuggingfaceChatOptions getOptions() { + return this.options; + } + } diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceConnectionDetails.java b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceConnectionDetails.java new file mode 100644 index 00000000000..9812c68e678 --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceConnectionDetails.java @@ -0,0 +1,32 @@ +/* + * Copyright 2023-2025 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.huggingface.autoconfigure; + +/** + * Details required to establish a connection to HuggingFace Inference API. + * + * @author Myeongdeok Kang + */ +public interface HuggingfaceConnectionDetails { + + /** + * The API key for authenticating with HuggingFace. + * @return the API key + */ + String getApiKey(); + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceConnectionProperties.java b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceConnectionProperties.java new file mode 100644 index 00000000000..2fadbb5578f --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceConnectionProperties.java @@ -0,0 +1,45 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.huggingface.autoconfigure; + +import org.springframework.boot.context.properties.ConfigurationProperties; + +/** + * HuggingFace connection configuration properties. These properties are shared across all + * HuggingFace model types (chat, embedding, etc.). + * + * @author Myeongdeok Kang + */ +@ConfigurationProperties(HuggingfaceConnectionProperties.CONFIG_PREFIX) +public class HuggingfaceConnectionProperties { + + public static final String CONFIG_PREFIX = "spring.ai.huggingface"; + + /** + * API Key to authenticate with HuggingFace Inference API. 
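+ * Bound from the {@code spring.ai.huggingface.api-key} property and shared by the chat and embedding clients.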
+ */ + private String apiKey; + + public String getApiKey() { + return this.apiKey; + } + + public void setApiKey(String apiKey) { + this.apiKey = apiKey; + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceEmbeddingAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceEmbeddingAutoConfiguration.java new file mode 100644 index 00000000000..94195dab083 --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceEmbeddingAutoConfiguration.java @@ -0,0 +1,106 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.huggingface.autoconfigure; + +import io.micrometer.observation.ObservationRegistry; + +import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; +import org.springframework.ai.huggingface.HuggingfaceEmbeddingModel; +import org.springframework.ai.huggingface.api.HuggingfaceApi; +import org.springframework.ai.model.SpringAIModelProperties; +import org.springframework.ai.model.SpringAIModels; +import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration; +import org.springframework.beans.factory.ObjectProvider; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; +import org.springframework.context.annotation.Bean; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.util.StringUtils; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; + +/** + * {@link AutoConfiguration Auto-configuration} for HuggingFace Embedding Model. 
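+ *
+ * <p>Registers a dedicated {@code huggingfaceEmbeddingApi} {@link HuggingfaceApi} client bound to the
+ * feature-extraction base URL (independent of the chat client), plus a {@link HuggingfaceEmbeddingModel}
+ * with retry and observation support.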
+ * + * @author Myeongdeok Kang + */ +@AutoConfiguration(after = { RestClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class, + HuggingfaceApiAutoConfiguration.class }) +@ConditionalOnClass(HuggingfaceApi.class) +@ConditionalOnProperty(name = SpringAIModelProperties.EMBEDDING_MODEL, havingValue = SpringAIModels.HUGGINGFACE, + matchIfMissing = true) +@EnableConfigurationProperties({ HuggingfaceConnectionProperties.class, HuggingfaceEmbeddingProperties.class }) +public class HuggingfaceEmbeddingAutoConfiguration { + + @Bean + @Qualifier("huggingfaceEmbeddingApi") + @ConditionalOnMissingBean(name = "huggingfaceEmbeddingApi") + @ConditionalOnProperty(prefix = HuggingfaceEmbeddingProperties.CONFIG_PREFIX, name = "enabled", + havingValue = "true", matchIfMissing = true) + public HuggingfaceApi huggingfaceEmbeddingApi(HuggingfaceConnectionDetails connectionDetails, + HuggingfaceEmbeddingProperties embeddingProperties, + ObjectProvider restClientBuilderProvider, + ObjectProvider responseErrorHandlerProvider) { + + String apiKey = connectionDetails.getApiKey(); + if (!StringUtils.hasText(apiKey)) { + throw new IllegalArgumentException( + "HuggingFace API key must be set. Please configure spring.ai.huggingface.api-key"); + } + + RestClient.Builder restClientBuilder = restClientBuilderProvider.getIfAvailable(RestClient::builder); + ResponseErrorHandler responseErrorHandler = responseErrorHandlerProvider.getIfAvailable(() -> null); + + HuggingfaceApi.Builder apiBuilder = HuggingfaceApi.builder() + .baseUrl(embeddingProperties.getUrl()) + .apiKey(apiKey) + .restClientBuilder(restClientBuilder); + + if (responseErrorHandler != null) { + apiBuilder.responseErrorHandler(responseErrorHandler); + } + + return apiBuilder.build(); + } + + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(prefix = HuggingfaceEmbeddingProperties.CONFIG_PREFIX, name = "enabled", + havingValue = "true", matchIfMissing = true) + public HuggingfaceEmbeddingModel huggingfaceEmbeddingModel( + @Qualifier("huggingfaceEmbeddingApi") HuggingfaceApi huggingfaceApi, + HuggingfaceEmbeddingProperties embeddingProperties, ObjectProvider observationRegistry, + ObjectProvider observationConvention, RetryTemplate retryTemplate) { + + var embeddingModel = HuggingfaceEmbeddingModel.builder() + .huggingfaceApi(huggingfaceApi) + .defaultOptions(embeddingProperties.getOptions()) + .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) + .retryTemplate(retryTemplate) + .build(); + + observationConvention.ifAvailable(embeddingModel::setObservationConvention); + + return embeddingModel; + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceEmbeddingProperties.java b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceEmbeddingProperties.java new file mode 100644 index 00000000000..e194ba4ab34 --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceEmbeddingProperties.java @@ -0,0 +1,83 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.huggingface.autoconfigure; + +import org.springframework.ai.huggingface.HuggingfaceEmbeddingOptions; +import org.springframework.ai.huggingface.api.HuggingfaceApi; +import org.springframework.ai.huggingface.api.common.HuggingfaceApiConstants; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +/** + * HuggingFace Embedding autoconfiguration properties. + * + * @author Myeongdeok Kang + */ +@ConfigurationProperties(HuggingfaceEmbeddingProperties.CONFIG_PREFIX) +public class HuggingfaceEmbeddingProperties { + + public static final String CONFIG_PREFIX = "spring.ai.huggingface.embedding"; + + /** + * Enable HuggingFace embedding model autoconfiguration. + */ + private boolean enabled = true; + + /** + * Base URL for the HuggingFace Embedding API (Feature Extraction endpoint). + */ + private String url = HuggingfaceApiConstants.DEFAULT_EMBEDDING_BASE_URL; + + /** + * Client-level HuggingFace embedding options. Use this property to configure the + * model, dimensions, and other parameters. Null values are ignored, defaulting to the + * API defaults. + */ + @NestedConfigurationProperty + private final HuggingfaceEmbeddingOptions options = HuggingfaceEmbeddingOptions.builder() + .model(HuggingfaceApi.DEFAULT_EMBEDDING_MODEL) + .build(); + + public boolean isEnabled() { + return this.enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + + public String getUrl() { + return this.url; + } + + public void setUrl(String url) { + this.url = url; + } + + public String getModel() { + return this.options.getModel(); + } + + public void setModel(String model) { + this.options.setModel(model); + } + + public HuggingfaceEmbeddingOptions getOptions() { + return this.options; + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index 092c087b87f..a51739c3e32 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -13,4 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +org.springframework.ai.model.huggingface.autoconfigure.HuggingfaceApiAutoConfiguration org.springframework.ai.model.huggingface.autoconfigure.HuggingfaceChatAutoConfiguration +org.springframework.ai.model.huggingface.autoconfigure.HuggingfaceEmbeddingAutoConfiguration diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/test/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceChatAutoConfigurationIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/test/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceChatAutoConfigurationIT.java index db991fe088a..e08561fb29a 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/test/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceChatAutoConfigurationIT.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/test/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceChatAutoConfigurationIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,34 +32,97 @@ import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.huggingface.HuggingfaceChatModel; -import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.ai.huggingface.HuggingfaceChatOptions; +import org.springframework.ai.utils.SpringAiTestAutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import static org.assertj.core.api.Assertions.assertThat; +/** + * Integration tests for HuggingFace Chat Auto Configuration. 
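+ *
+ * <p>Requires the {@code HUGGINGFACE_API_KEY} environment variable; the tests are skipped when it is not set.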
+ * + * @author Mark Pollack + * @author Myeongdeok Kang + */ @EnabledIfEnvironmentVariable(named = "HUGGINGFACE_API_KEY", matches = ".+") -@EnabledIfEnvironmentVariable(named = "HUGGINGFACE_CHAT_URL", matches = ".+") public class HuggingfaceChatAutoConfigurationIT { private static final Log logger = LogFactory.getLog(HuggingfaceChatAutoConfigurationIT.class); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withPropertyValues( - // @formatter:off - "spring.ai.huggingface.chat.api-key=" + System.getenv("HUGGINGFACE_API_KEY"), - "spring.ai.huggingface.chat.url=" + System.getenv("HUGGINGFACE_CHAT_URL")) - // @formatter:on - .withConfiguration(AutoConfigurations.of(HuggingfaceChatAutoConfiguration.class)); + private static final String DEFAULT_CHAT_MODEL = "meta-llama/Llama-3.2-3B-Instruct"; + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.huggingface.api-key=" + System.getenv("HUGGINGFACE_API_KEY"), + "spring.ai.huggingface.chat.options.model=" + DEFAULT_CHAT_MODEL) + .withConfiguration(SpringAiTestAutoConfigurations.of(HuggingfaceChatAutoConfiguration.class)); @Test void generate() { this.contextRunner.run(context -> { HuggingfaceChatModel chatModel = context.getBean(HuggingfaceChatModel.class); - String response = chatModel.call("Hello"); + String response = chatModel.call("Say 'Hello World' and nothing else"); assertThat(response).isNotEmpty(); logger.info("Response: " + response); }); } + @Test + void chatActivation() { + // Default activation + this.contextRunner.run(context -> { + assertThat(context.getBeansOfType(HuggingfaceChatProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(HuggingfaceChatModel.class)).isNotEmpty(); + }); + + // Disabled via property + this.contextRunner.withPropertyValues("spring.ai.model.chat=none").run(context -> { + assertThat(context.getBeansOfType(HuggingfaceChatProperties.class)).isEmpty(); + assertThat(context.getBeansOfType(HuggingfaceChatModel.class)).isEmpty(); + }); + + // Explicitly enabled + this.contextRunner.withPropertyValues("spring.ai.model.chat=huggingface").run(context -> { + assertThat(context.getBeansOfType(HuggingfaceChatProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(HuggingfaceChatModel.class)).isNotEmpty(); + }); + } + + @Test + void chatProperties() { + this.contextRunner + .withPropertyValues("spring.ai.huggingface.chat.options.model=" + DEFAULT_CHAT_MODEL, + "spring.ai.huggingface.chat.options.temperature=0.8", + "spring.ai.huggingface.chat.options.maxTokens=500") + .run(context -> { + var chatProperties = context.getBean(HuggingfaceChatProperties.class); + assertThat(chatProperties.getOptions().getModel()).isEqualTo(DEFAULT_CHAT_MODEL); + assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.8); + assertThat(chatProperties.getOptions().getMaxTokens()).isEqualTo(500); + }); + } + + @Test + void chatCallWithOptions() { + this.contextRunner.run(context -> { + HuggingfaceChatModel chatModel = context.getBean(HuggingfaceChatModel.class); + + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder() + .model(DEFAULT_CHAT_MODEL) + .temperature(0.7) + .maxTokens(100) + .build(); + + ChatResponse response = chatModel.call(new Prompt("Say 'Hello' and nothing else", options)); + + assertThat(response).isNotNull(); + assertThat(response.getResult()).isNotNull(); + assertThat(response.getResult().getOutput()).isNotNull(); + 
assertThat(response.getResult().getOutput().getText()).isNotEmpty(); + + logger.info("Response with options: " + response.getResult().getOutput().getText()); + }); + } + @Disabled("Until streaming support is added") @Test void generateStreaming() { diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/test/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceChatAutoConfigurationTests.java b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/test/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceChatAutoConfigurationTests.java new file mode 100644 index 00000000000..fa4120cf128 --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/test/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceChatAutoConfigurationTests.java @@ -0,0 +1,110 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.huggingface.autoconfigure; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.utils.SpringAiTestAutoConfigurations; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link HuggingfaceChatAutoConfiguration}. 
+ * + * @author Myeongdeok Kang + */ +public class HuggingfaceChatAutoConfigurationTests { + + @Test + public void propertiesTest() { + + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.huggingface.api-key=TEST_API_KEY", + "spring.ai.huggingface.chat.url=https://test.huggingface.co/v1", + "spring.ai.huggingface.chat.options.model=meta-llama/Llama-3.2-3B-Instruct", + "spring.ai.huggingface.chat.options.temperature=0.7", + "spring.ai.huggingface.chat.options.maxTokens=512", + "spring.ai.huggingface.chat.options.topP=0.9" + // @formatter:on + ) + .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class)) + .withConfiguration(SpringAiTestAutoConfigurations.of(HuggingfaceApiAutoConfiguration.class, + HuggingfaceChatAutoConfiguration.class)) + .run(context -> { + assertThat(context).hasSingleBean(HuggingfaceChatProperties.class); + assertThat(context).hasSingleBean(HuggingfaceConnectionProperties.class); + + var chatProperties = context.getBean(HuggingfaceChatProperties.class); + var connectionProperties = context.getBean(HuggingfaceConnectionProperties.class); + + assertThat(connectionProperties.getApiKey()).isEqualTo("TEST_API_KEY"); + assertThat(chatProperties.getUrl()).isEqualTo("https://test.huggingface.co/v1"); + assertThat(chatProperties.getOptions().getModel()).isEqualTo("meta-llama/Llama-3.2-3B-Instruct"); + assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.7); + assertThat(chatProperties.getOptions().getMaxTokens()).isEqualTo(512); + assertThat(chatProperties.getOptions().getTopP()).isEqualTo(0.9); + }); + } + + @Test + public void chatActivationTest() { + new ApplicationContextRunner() + .withPropertyValues("spring.ai.huggingface.api-key=TEST_API_KEY", + "spring.ai.huggingface.chat.enabled=false") + .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class)) + .withConfiguration(SpringAiTestAutoConfigurations.of(HuggingfaceApiAutoConfiguration.class, + HuggingfaceChatAutoConfiguration.class)) + .run(context -> assertThat(context).doesNotHaveBean("huggingfaceChatModel")); + + new ApplicationContextRunner() + .withPropertyValues("spring.ai.huggingface.api-key=TEST_API_KEY", "spring.ai.huggingface.chat.enabled=true") + .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class)) + .withConfiguration(SpringAiTestAutoConfigurations.of(HuggingfaceApiAutoConfiguration.class, + HuggingfaceChatAutoConfiguration.class)) + .run(context -> assertThat(context) + .hasSingleBean(org.springframework.ai.huggingface.HuggingfaceChatModel.class)); + } + + @Test + public void newParametersTest() { + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.huggingface.api-key=TEST_API_KEY", + "spring.ai.huggingface.chat.options.seed=42", + "spring.ai.huggingface.chat.options.tool-prompt=You have access to tools:", + "spring.ai.huggingface.chat.options.logprobs=true", + "spring.ai.huggingface.chat.options.top-logprobs=3" + // @formatter:on + ) + .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class)) + .withConfiguration(SpringAiTestAutoConfigurations.of(HuggingfaceApiAutoConfiguration.class, + HuggingfaceChatAutoConfiguration.class)) + .run(context -> { + var chatProperties = context.getBean(HuggingfaceChatProperties.class); + + assertThat(chatProperties.getOptions().getSeed()).isEqualTo(42); + assertThat(chatProperties.getOptions().getToolPrompt()).isEqualTo("You have access to tools:"); + assertThat(chatProperties.getOptions().getLogprobs()).isTrue(); + 
assertThat(chatProperties.getOptions().getTopLogprobs()).isEqualTo(3); + }); + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/test/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceEmbeddingAutoConfigurationIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/test/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceEmbeddingAutoConfigurationIT.java new file mode 100644 index 00000000000..d1025fc912e --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/test/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceEmbeddingAutoConfigurationIT.java @@ -0,0 +1,121 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.huggingface.autoconfigure; + +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.huggingface.HuggingfaceEmbeddingModel; +import org.springframework.ai.utils.SpringAiTestAutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for HuggingFace Embedding Auto Configuration. 
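+ *
+ * <p>Skipped unless the {@code HUGGINGFACE_API_KEY} environment variable is set.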
+ * + * @author Myeongdeok Kang + */ +@EnabledIfEnvironmentVariable(named = "HUGGINGFACE_API_KEY", matches = ".+") +public class HuggingfaceEmbeddingAutoConfigurationIT { + + private static final Log logger = LogFactory.getLog(HuggingfaceEmbeddingAutoConfigurationIT.class); + + private static final String MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"; + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.huggingface.api-key=" + System.getenv("HUGGINGFACE_API_KEY"), + "spring.ai.huggingface.embedding.options.model=" + MODEL_NAME) + // @formatter:on + .withConfiguration(SpringAiTestAutoConfigurations.of(HuggingfaceEmbeddingAutoConfiguration.class)); + + @Test + void singleTextEmbedding() { + this.contextRunner.run(context -> { + HuggingfaceEmbeddingModel embeddingModel = context.getBean(HuggingfaceEmbeddingModel.class); + assertThat(embeddingModel).isNotNull(); + + EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World")); + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSizeGreaterThan(0); + + logger.info("Embedding dimensions: " + embeddingResponse.getResults().get(0).getOutput().length); + }); + } + + @Test + void batchTextEmbedding() { + this.contextRunner.run(context -> { + HuggingfaceEmbeddingModel embeddingModel = context.getBean(HuggingfaceEmbeddingModel.class); + assertThat(embeddingModel).isNotNull(); + + List texts = List.of("Hello World", "Spring AI", "HuggingFace Integration"); + EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(texts); + assertThat(embeddingResponse.getResults()).hasSize(3); + + embeddingResponse.getResults().forEach(result -> { + assertThat(result.getOutput()).isNotEmpty(); + assertThat(result.getOutput()).hasSizeGreaterThan(0); + }); + + logger.info("Batch embedding completed for " + texts.size() + " texts"); + }); + } + + @Test + void embeddingActivation() { + // Default activation + this.contextRunner.run(context -> { + assertThat(context.getBeansOfType(HuggingfaceEmbeddingProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(HuggingfaceEmbeddingModel.class)).isNotEmpty(); + }); + + // Disabled via property + this.contextRunner.withPropertyValues("spring.ai.model.embedding=none").run(context -> { + assertThat(context.getBeansOfType(HuggingfaceEmbeddingProperties.class)).isEmpty(); + assertThat(context.getBeansOfType(HuggingfaceEmbeddingModel.class)).isEmpty(); + }); + + // Explicitly enabled + this.contextRunner.withPropertyValues("spring.ai.model.embedding=huggingface").run(context -> { + assertThat(context.getBeansOfType(HuggingfaceEmbeddingProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(HuggingfaceEmbeddingModel.class)).isNotEmpty(); + }); + } + + @Test + void embeddingProperties() { + this.contextRunner + .withPropertyValues("spring.ai.huggingface.embedding.options.model=" + MODEL_NAME, + "spring.ai.huggingface.embedding.options.normalize=true", + "spring.ai.huggingface.embedding.options.prompt-name=query") + .run(context -> { + var embeddingProperties = context.getBean(HuggingfaceEmbeddingProperties.class); + assertThat(embeddingProperties.getOptions().getModel()).isEqualTo(MODEL_NAME); + assertThat(embeddingProperties.getOptions().getNormalize()).isTrue(); + 
assertThat(embeddingProperties.getOptions().getPromptName()).isEqualTo("query"); + }); + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/test/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceEmbeddingAutoConfigurationTests.java b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/test/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceEmbeddingAutoConfigurationTests.java new file mode 100644 index 00000000000..26d1c045bae --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/test/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceEmbeddingAutoConfigurationTests.java @@ -0,0 +1,75 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.huggingface.autoconfigure; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.utils.SpringAiTestAutoConfigurations; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link HuggingfaceEmbeddingAutoConfiguration}. 
+ * + * @author Myeongdeok Kang + */ +public class HuggingfaceEmbeddingAutoConfigurationTests { + + @Test + public void propertiesTest() { + + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.huggingface.api-key=TEST_API_KEY", + "spring.ai.huggingface.embedding.url=https://test.huggingface.co/hf-inference/models", + "spring.ai.huggingface.embedding.options.model=sentence-transformers/all-MiniLM-L6-v2", + "spring.ai.huggingface.embedding.options.normalize=true" + // @formatter:on + ) + .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class)) + .withConfiguration(SpringAiTestAutoConfigurations.of(HuggingfaceApiAutoConfiguration.class, + HuggingfaceEmbeddingAutoConfiguration.class)) + .run(context -> { + assertThat(context).hasSingleBean(HuggingfaceEmbeddingProperties.class); + assertThat(context).hasSingleBean(HuggingfaceConnectionProperties.class); + }); + } + + @Test + public void embeddingActivationTest() { + new ApplicationContextRunner() + .withPropertyValues("spring.ai.huggingface.api-key=TEST_API_KEY", + "spring.ai.huggingface.embedding.enabled=false") + .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class)) + .withConfiguration(SpringAiTestAutoConfigurations.of(HuggingfaceApiAutoConfiguration.class, + HuggingfaceEmbeddingAutoConfiguration.class)) + .run(context -> assertThat(context).doesNotHaveBean("huggingfaceEmbeddingModel")); + + new ApplicationContextRunner() + .withPropertyValues("spring.ai.huggingface.api-key=TEST_API_KEY", + "spring.ai.huggingface.embedding.enabled=true") + .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class)) + .withConfiguration(SpringAiTestAutoConfigurations.of(HuggingfaceApiAutoConfiguration.class, + HuggingfaceEmbeddingAutoConfiguration.class)) + .run(context -> assertThat(context) + .hasSingleBean(org.springframework.ai.huggingface.HuggingfaceEmbeddingModel.class)); + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/test/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceModelConfigurationTests.java b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/test/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceModelConfigurationTests.java index 2786ef96b3d..63cda60c961 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/test/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceModelConfigurationTests.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-huggingface/src/test/java/org/springframework/ai/model/huggingface/autoconfigure/HuggingfaceModelConfigurationTests.java @@ -19,7 +19,9 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.huggingface.HuggingfaceChatModel; +import org.springframework.ai.utils.SpringAiTestAutoConfigurations; import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import static org.assertj.core.api.Assertions.assertThat; @@ -33,7 +35,10 @@ public class HuggingfaceModelConfigurationTests { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(HuggingfaceChatAutoConfiguration.class)); + .withPropertyValues("spring.ai.huggingface.api-key=TEST_API_KEY") + 
.withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class)) + .withConfiguration(SpringAiTestAutoConfigurations.of(HuggingfaceApiAutoConfiguration.class, + HuggingfaceChatAutoConfiguration.class)); @Test void chatModelActivation() { diff --git a/models/spring-ai-huggingface/pom.xml b/models/spring-ai-huggingface/pom.xml index 44533bfba90..2fa99c08a7e 100644 --- a/models/spring-ai-huggingface/pom.xml +++ b/models/spring-ai-huggingface/pom.xml @@ -47,20 +47,12 @@ ${project.parent.version} - - io.swagger.core.v3 - swagger-annotations-jakarta - ${swagger-annotations.version} - - - - javax.annotation - javax.annotation-api - 1.3.2 + org.springframework.ai + spring-ai-retry + ${project.parent.version} - org.springframework @@ -78,67 +70,25 @@ + + org.springframework.ai + spring-ai-client-chat + ${project.parent.version} + test + + org.springframework.boot spring-boot-starter-test test - - - - - - io.swagger.codegen.v3 - swagger-codegen-maven-plugin - 3.0.75 - - - - generate - - - ${project.basedir}/src/main/resources/openapi.json - java - resttemplate - src/main/resources/handlebars/Java - org.springframework.ai.huggingface.api - org.springframework.ai.huggingface.model - org.springframework.ai.huggingface.invoker - false - false - - src/main/java - java8 - - true - - - - - - - - org.codehaus.mojo - build-helper-maven-plugin - 3.4.0 - - - 01-add-test-sources - generate-sources - - add-source - - - - ${project.build.directory}/generated-sources/swagger/src/main/java - - - - - + + io.micrometer + micrometer-observation-test + test + + - - diff --git a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java index 2a11a12145f..acf6d174885 100644 --- a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java +++ b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,115 +16,431 @@ package org.springframework.ai.huggingface; -import java.util.ArrayList; import java.util.List; -import java.util.Map; +import java.util.stream.Collectors; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.micrometer.observation.ObservationRegistry; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.DefaultUsage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.observation.ChatModelObservationContext; +import org.springframework.ai.chat.observation.ChatModelObservationConvention; +import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; +import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.huggingface.api.TextGenerationInferenceApi; -import org.springframework.ai.huggingface.invoker.ApiClient; -import org.springframework.ai.huggingface.model.AllOfGenerateResponseDetails; -import org.springframework.ai.huggingface.model.CompatGenerateRequest; -import org.springframework.ai.huggingface.model.GenerateParameters; -import org.springframework.ai.huggingface.model.GenerateResponse; +import org.springframework.ai.huggingface.api.HuggingfaceApi; +import org.springframework.ai.huggingface.api.common.HuggingfaceApiConstants; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; +import org.springframework.ai.model.tool.ToolExecutionResult; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; /** - * An implementation of {@link ChatModel} that interfaces with HuggingFace Inference - * Endpoints for text generation. + * {@link ChatModel} implementation for HuggingFace Inference API. HuggingFace provides + * access to thousands of pre-trained models for various NLP tasks including chat + * completions. 
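+ *
+ * <p>A minimal usage sketch, assuming the {@code HuggingfaceApi} and chat-model builders fall back to
+ * sensible defaults for the collaborators not shown here (REST client builder, tool calling manager,
+ * retry template, observation registry):
+ *
+ * <pre>{@code
+ * HuggingfaceApi api = HuggingfaceApi.builder()
+ *     .baseUrl(HuggingfaceApiConstants.DEFAULT_CHAT_BASE_URL)
+ *     .apiKey(System.getenv("HUGGINGFACE_API_KEY"))
+ *     .build();
+ *
+ * HuggingfaceChatModel chatModel = HuggingfaceChatModel.builder()
+ *     .huggingfaceApi(api)
+ *     .defaultOptions(HuggingfaceChatOptions.builder()
+ *         .model(HuggingfaceApi.DEFAULT_CHAT_MODEL)
+ *         .build())
+ *     .build();
+ *
+ * ChatResponse response = chatModel.call(new Prompt("Hello"));
+ * }</pre>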
* * @author Mark Pollack + * @author Josh Long + * @author Soby Chacko * @author Jihoon Kim + * @author Myeongdeok Kang */ public class HuggingfaceChatModel implements ChatModel { + private static final Logger logger = LoggerFactory.getLogger(HuggingfaceChatModel.class); + + private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); + + private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build(); + + private final HuggingfaceApi huggingfaceApi; + + private final HuggingfaceChatOptions defaultOptions; + + private final ObservationRegistry observationRegistry; + + private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + + private final RetryTemplate retryTemplate; + + private final ToolCallingManager toolCallingManager; + + private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate; + /** - * Token required for authenticating with the HuggingFace Inference API. + * Constructor for HuggingfaceChatModel. + * @param huggingfaceApi The HuggingFace API client. + * @param defaultOptions Default chat options. + * @param toolCallingManager Tool calling manager for executing tools. + * @param observationRegistry Observation registry for metrics. + * @param retryTemplate Retry template for handling transient errors. + * @param toolExecutionEligibilityPredicate Predicate to determine if tool execution + * is required. */ - private final String apiToken; + public HuggingfaceChatModel(HuggingfaceApi huggingfaceApi, HuggingfaceChatOptions defaultOptions, + ToolCallingManager toolCallingManager, ObservationRegistry observationRegistry, RetryTemplate retryTemplate, + ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { + Assert.notNull(huggingfaceApi, "huggingfaceApi must not be null"); + Assert.notNull(defaultOptions, "defaultOptions must not be null"); + Assert.notNull(toolCallingManager, "toolCallingManager must not be null"); + Assert.notNull(observationRegistry, "observationRegistry must not be null"); + Assert.notNull(retryTemplate, "retryTemplate must not be null"); + Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate must not be null"); + + this.huggingfaceApi = huggingfaceApi; + this.defaultOptions = defaultOptions; + this.toolCallingManager = toolCallingManager; + this.observationRegistry = observationRegistry; + this.retryTemplate = retryTemplate; + this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate; + } + + /** + * Create a new builder for HuggingfaceChatModel. + * @return A new builder instance. + */ + public static Builder builder() { + return new Builder(); + } + + @Override + public ChatResponse call(Prompt prompt) { + Assert.notEmpty(prompt.getInstructions(), "At least one message is required!"); + + // Build the final request, merging runtime and default options + Prompt requestPrompt = buildChatRequest(prompt); + return this.internalCall(requestPrompt, null); + } /** - * Client for making API calls. + * Internal method for making chat completion calls with tool execution support. + * @param prompt The prompt to send to the model. + * @param previousChatResponse Previous chat response for cumulative usage tracking. + * @return The chat response from the model. 
*/ - private ApiClient apiClient = new ApiClient(); + public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { + HuggingfaceApi.ChatRequest apiRequest = createApiRequest(prompt); + + ChatModelObservationContext observationContext = ChatModelObservationContext.builder() + .prompt(prompt) + .provider(HuggingfaceApiConstants.PROVIDER_NAME) + .build(); + + ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION + .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry) + .observe(() -> { + HuggingfaceApi.ChatResponse apiResponse = RetryUtils.execute(this.retryTemplate, + () -> this.huggingfaceApi.chat(apiRequest)); + + List generations = apiResponse.choices().stream().map(choice -> { + List toolCalls = extractToolCalls(choice.message().toolCalls()); + AssistantMessage.Builder messageBuilder = AssistantMessage.builder() + .content(choice.message().content()); + if (toolCalls != null) { + messageBuilder.toolCalls(toolCalls); + } + return new Generation(messageBuilder.build(), + ChatGenerationMetadata.builder().finishReason(choice.finishReason()).build()); + }).collect(Collectors.toList()); + + ChatResponseMetadata metadata = ChatResponseMetadata.builder() + .model(apiResponse.model()) + .usage(apiResponse.usage() != null ? new DefaultUsage(apiResponse.usage().promptTokens(), + apiResponse.usage().completionTokens()) : new DefaultUsage(0, 0)) + .build(); + + ChatResponse chatResponse = new ChatResponse(generations, metadata); + + observationContext.setResponse(chatResponse); + + return chatResponse; + }); + + // Tool execution handling + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); + if (toolExecutionResult.returnDirect()) { + // Return tool execution result directly to the client. + return ChatResponse.builder() + .from(response) + .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) + .build(); + } + else { + // Send the tool execution result back to the model. + return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), + response); + } + } + + return response; + } /** - * Mapper for converting between Java objects and JSON. + * Build the chat request by merging runtime and default options. + * @param chatRequest The original chat request. + * @return A new chat request with merged options. */ - private ObjectMapper objectMapper = new ObjectMapper(); + Prompt buildChatRequest(Prompt chatRequest) { + // Process runtime options + HuggingfaceChatOptions runtimeOptions = null; + if (chatRequest.getOptions() != null) { + if (chatRequest.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { + runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class, + HuggingfaceChatOptions.class); + } + else { + runtimeOptions = ModelOptionsUtils.copyToTarget(chatRequest.getOptions(), ChatOptions.class, + HuggingfaceChatOptions.class); + } + } + + // Merge runtime and default options + HuggingfaceChatOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, + HuggingfaceChatOptions.class); + + // Merge @JsonIgnore-annotated options explicitly since they are ignored by + // Jackson, used by ModelOptionsUtils. 
+ if (runtimeOptions != null) { + requestOptions.setInternalToolExecutionEnabled( + ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(), + this.defaultOptions.getInternalToolExecutionEnabled())); + requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), + this.defaultOptions.getToolNames())); + requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), + this.defaultOptions.getToolCallbacks())); + requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(), + this.defaultOptions.getToolContext())); + } + + // Validate + if (!StringUtils.hasText(requestOptions.getModel())) { + throw new IllegalArgumentException("model cannot be null or empty"); + } + + return new Prompt(chatRequest.getInstructions(), requestOptions); + } /** - * API for text generation inferences. + * Create the API request from the chat request. + * @param prompt The chat request. + * @return The API request. */ - private TextGenerationInferenceApi textGenApi = new TextGenerationInferenceApi(); + private HuggingfaceApi.ChatRequest createApiRequest(Prompt prompt) { + HuggingfaceChatOptions options = (HuggingfaceChatOptions) prompt.getOptions(); + + List messages = prompt.getInstructions() + .stream() + .flatMap(message -> toHuggingfaceMessage(message).stream()) + .toList(); + + // Add tool definitions to the request if present + List toolDefinitions = this.toolCallingManager.resolveToolDefinitions(options); + + if (!CollectionUtils.isEmpty(toolDefinitions)) { + List tools = getFunctionTools(toolDefinitions); + return new HuggingfaceApi.ChatRequest(options.getModel(), messages, tools, "auto", options.toMap()); + } + + return new HuggingfaceApi.ChatRequest(options.getModel(), messages, options.toMap()); + } /** - * The maximum number of new tokens to be generated. Note: The total token size for - * the Mistral7b instruct model should be less than 1500. + * Convert Spring AI message to HuggingFace API message(s). Tool response messages may + * produce multiple API messages (one per tool response). + * @param message The Spring AI message. + * @return The list of HuggingFace API messages. */ - private int maxNewTokens = 1000; + private List toHuggingfaceMessage(Message message) { + if (message.getMessageType() == MessageType.TOOL) { + // Tool response messages need special handling + ToolResponseMessage toolMessage = (ToolResponseMessage) message; + return toolMessage.getResponses() + .stream() + .map(response -> new HuggingfaceApi.Message(response.responseData(), MessageType.TOOL.getValue(), + response.name(), response.id())) + .toList(); + } + else if (message instanceof AssistantMessage assistantMessage && assistantMessage.getToolCalls() != null + && !assistantMessage.getToolCalls().isEmpty()) { + // Assistant message with tool calls + List toolCalls = assistantMessage.getToolCalls() + .stream() + .map(toolCall -> new HuggingfaceApi.ToolCall(toolCall.id(), toolCall.type(), + new HuggingfaceApi.ChatCompletionFunction(toolCall.name(), toolCall.arguments()))) + .toList(); + return List + .of(new HuggingfaceApi.Message(message.getMessageType().getValue(), message.getText(), toolCalls)); + } + else { + // Regular user/system/assistant message + return List.of(new HuggingfaceApi.Message(message.getMessageType().getValue(), message.getText())); + } + } /** - * Constructs a new HuggingfaceChatModel with the specified API token and base path. 
- * @param apiToken The API token for HuggingFace. - * @param basePath The base path for API requests. + * Convert tool definitions to HuggingFace API function tools. + * @param toolDefinitions The tool definitions. + * @return The list of function tools. */ - public HuggingfaceChatModel(final String apiToken, String basePath) { - this.apiToken = apiToken; - this.apiClient.setBasePath(basePath); - this.apiClient.addDefaultHeader("Authorization", "Bearer " + this.apiToken); - this.textGenApi.setApiClient(this.apiClient); + private List getFunctionTools(List toolDefinitions) { + return toolDefinitions.stream().map(toolDefinition -> { + var function = new HuggingfaceApi.FunctionTool.Function(toolDefinition.description(), toolDefinition.name(), + ModelOptionsUtils.jsonToMap(toolDefinition.inputSchema())); + return new HuggingfaceApi.FunctionTool(function); + }).toList(); } /** - * Generate text based on the provided prompt. - * @param prompt The input prompt based on which text is to be generated. - * @return ChatResponse containing the generated text and other related details. + * Extract tool calls from HuggingFace API response and convert to Spring AI format. + * @param apiToolCalls The tool calls from the API response. + * @return The list of tool calls in Spring AI format, or null if no tool calls. */ - @Override - public ChatResponse call(Prompt prompt) { - CompatGenerateRequest compatGenerateRequest = new CompatGenerateRequest(); - compatGenerateRequest.setInputs(prompt.getContents()); - GenerateParameters generateParameters = new GenerateParameters(); - // TODO - need to expose API to set parameters per call. - generateParameters.setMaxNewTokens(this.maxNewTokens); - compatGenerateRequest.setParameters(generateParameters); - List generateResponses = this.textGenApi.compatGenerate(compatGenerateRequest); - List generations = new ArrayList<>(); - for (GenerateResponse generateResponse : generateResponses) { - String generatedText = generateResponse.getGeneratedText(); - AllOfGenerateResponseDetails allOfGenerateResponseDetails = generateResponse.getDetails(); - Map detailsMap = this.objectMapper.convertValue(allOfGenerateResponseDetails, - new TypeReference<>() { - - }); - Generation generation = new Generation( - AssistantMessage.builder().content(generatedText).properties(detailsMap).build()); - generations.add(generation); - } - return new ChatResponse(generations); + private List extractToolCalls(List apiToolCalls) { + if (apiToolCalls == null || apiToolCalls.isEmpty()) { + return null; + } + + return apiToolCalls.stream() + .map(apiToolCall -> new AssistantMessage.ToolCall(apiToolCall.id(), apiToolCall.type(), + apiToolCall.function().name(), apiToolCall.function().arguments())) + .toList(); } /** - * Gets the maximum number of new tokens to be generated. - * @return The maximum number of new tokens. + * Set the observation convention for reporting metrics. + * @param observationConvention The observation convention. */ - public int getMaxNewTokens() { - return this.maxNewTokens; + public void setObservationConvention(ChatModelObservationConvention observationConvention) { + Assert.notNull(observationConvention, "observationConvention cannot be null"); + this.observationConvention = observationConvention; + } + + @Override + public ChatOptions getDefaultOptions() { + return this.defaultOptions.copy(); } /** - * Sets the maximum number of new tokens to be generated. - * @param maxNewTokens The maximum number of new tokens. + * Builder for creating HuggingfaceChatModel instances. 
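+	 * <p>
+	 * A minimal usage sketch; the token environment variable is illustrative and the
+	 * model id is the declared {@code HuggingfaceApi.DEFAULT_CHAT_MODEL}:
+	 * <pre>{@code
+	 * HuggingfaceApi api = HuggingfaceApi.builder()
+	 *     .apiKey(System.getenv("HF_TOKEN"))
+	 *     .build();
+	 *
+	 * HuggingfaceChatModel chatModel = HuggingfaceChatModel.builder()
+	 *     .huggingfaceApi(api)
+	 *     .defaultOptions(HuggingfaceChatOptions.builder()
+	 *         .model("meta-llama/Llama-3.2-3B-Instruct")
+	 *         .temperature(0.7)
+	 *         .build())
+	 *     .build();
+	 *
+	 * ChatResponse response = chatModel.call(new Prompt("Tell me a joke"));
+	 * }</pre>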
*/ - public void setMaxNewTokens(int maxNewTokens) { - this.maxNewTokens = maxNewTokens; + public static final class Builder { + + private HuggingfaceApi huggingfaceApi; + + private HuggingfaceChatOptions defaultOptions = HuggingfaceChatOptions.builder() + .model(HuggingfaceApi.DEFAULT_CHAT_MODEL) + .build(); + + private ToolCallingManager toolCallingManager; + + private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; + + private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + + private ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate = new DefaultToolExecutionEligibilityPredicate(); + + private Builder() { + } + + /** + * Set the HuggingFace API client. + * @param huggingfaceApi The API client. + * @return This builder. + */ + public Builder huggingfaceApi(HuggingfaceApi huggingfaceApi) { + this.huggingfaceApi = huggingfaceApi; + return this; + } + + /** + * Set the default chat options. + * @param defaultOptions The default options. + * @return This builder. + */ + public Builder defaultOptions(HuggingfaceChatOptions defaultOptions) { + this.defaultOptions = defaultOptions; + return this; + } + + /** + * Set the tool calling manager. + * @param toolCallingManager The tool calling manager. + * @return This builder. + */ + public Builder toolCallingManager(ToolCallingManager toolCallingManager) { + this.toolCallingManager = toolCallingManager; + return this; + } + + /** + * Set the observation registry. + * @param observationRegistry The observation registry. + * @return This builder. + */ + public Builder observationRegistry(ObservationRegistry observationRegistry) { + this.observationRegistry = observationRegistry; + return this; + } + + /** + * Set the retry template. + * @param retryTemplate The retry template. + * @return This builder. + */ + public Builder retryTemplate(RetryTemplate retryTemplate) { + this.retryTemplate = retryTemplate; + return this; + } + + /** + * Set the tool execution eligibility predicate. + * @param toolExecutionEligibilityPredicate The predicate. + * @return This builder. + */ + public Builder toolExecutionEligibilityPredicate( + ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { + this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate; + return this; + } + + /** + * Build the HuggingfaceChatModel instance. + * @return A new HuggingfaceChatModel. 
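+		 * If no {@code ToolCallingManager} was supplied, a shared default manager is used.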
+ */ + public HuggingfaceChatModel build() { + Assert.notNull(this.huggingfaceApi, "huggingfaceApi must not be null"); + Assert.notNull(this.toolExecutionEligibilityPredicate, + "toolExecutionEligibilityPredicate must not be null"); + if (this.toolCallingManager != null) { + return new HuggingfaceChatModel(this.huggingfaceApi, this.defaultOptions, this.toolCallingManager, + this.observationRegistry, this.retryTemplate, this.toolExecutionEligibilityPredicate); + } + return new HuggingfaceChatModel(this.huggingfaceApi, this.defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, + this.observationRegistry, this.retryTemplate, this.toolExecutionEligibilityPredicate); + } + } } diff --git a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatOptions.java b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatOptions.java new file mode 100644 index 00000000000..c255774be20 --- /dev/null +++ b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatOptions.java @@ -0,0 +1,468 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.huggingface; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.lang.Nullable; + +/** + * Chat options for HuggingFace chat model. + * + * @author Myeongdeok Kang + */ +@JsonInclude(Include.NON_NULL) +public class HuggingfaceChatOptions implements ToolCallingChatOptions { + + /** + * The model name to use for chat. + */ + @JsonProperty("model") + private String model; + + /** + * Controls the randomness of the output. Higher values (e.g., 1.0) make the output + * more random, while lower values (e.g., 0.1) make it more focused and deterministic. + */ + @JsonProperty("temperature") + private Double temperature; + + /** + * The maximum number of tokens to generate in the chat completion. + */ + @JsonProperty("max_tokens") + private Integer maxTokens; + + /** + * Nucleus sampling parameter. The model considers the results of the tokens with + * top_p probability mass. + */ + @JsonProperty("top_p") + private Double topP; + + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on their + * existing frequency in the text so far. + */ + @JsonProperty("frequency_penalty") + private Double frequencyPenalty; + + /** + * Number between -2.0 and 2.0. 
Positive values penalize new tokens based on whether + * they appear in the text so far. + */ + @JsonProperty("presence_penalty") + private Double presencePenalty; + + /** + * Up to 4 sequences where the API will stop generating further tokens. + */ + @JsonProperty("stop") + private List stop; + + /** + * Integer seed for reproducibility. This makes repeated requests with the same seed + * and parameters return the same result. + */ + @JsonProperty("seed") + private Integer seed; + + /** + * An object specifying the format that the model must output. Setting to {"type": + * "json_object"} enables JSON mode, which guarantees the message the model generates + * is valid JSON. Setting to {"type": "json_schema", "json_schema": {...}} enables + * Structured Outputs which ensures the model will match your supplied JSON schema. + */ + @JsonProperty("response_format") + private Map responseFormat; + + /** + * A prompt to be appended before the tools. + */ + @JsonProperty("tool_prompt") + private String toolPrompt; + + /** + * Whether to return log probabilities of the output tokens or not. If true, returns + * the log probabilities of each output token returned in the content of message. + */ + @JsonProperty("logprobs") + private Boolean logprobs; + + /** + * An integer between 0 and 5 specifying the number of most likely tokens to return at + * each token position, each with an associated log probability. logprobs must be set + * to true if this parameter is used. + */ + @JsonProperty("top_logprobs") + private Integer topLogprobs; + + /** + * Tool callbacks to be registered with the ChatModel. + */ + @JsonIgnore + private List toolCallbacks = new ArrayList<>(); + + /** + * Names of the tools to register with the ChatModel. + */ + @JsonIgnore + private Set toolNames = new HashSet<>(); + + /** + * Whether the ChatModel is responsible for executing the tools requested by the model + * or if the tools should be executed directly by the caller. + */ + @JsonIgnore + private Boolean internalToolExecutionEnabled; + + /** + * Tool context values as map. 
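+	 * These values are not serialized into the request; they are made available to the
+	 * executed tools through their tool context.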
+ */ + @JsonIgnore + private Map toolContext = new HashMap<>(); + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public Double getFrequencyPenalty() { + return this.frequencyPenalty; + } + + public void setFrequencyPenalty(Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + @Override + public Integer getMaxTokens() { + return this.maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + @Override + public Double getPresencePenalty() { + return this.presencePenalty; + } + + public void setPresencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + @Override + public Double getTemperature() { + return this.temperature; + } + + public void setTemperature(Double temperature) { + this.temperature = temperature; + } + + @Override + public Double getTopP() { + return this.topP; + } + + public void setTopP(Double topP) { + this.topP = topP; + } + + @Override + public Integer getTopK() { + return null; + } + + @Override + public List getStopSequences() { + return this.stop; + } + + public void setStopSequences(List stopSequences) { + this.stop = stopSequences; + } + + public Integer getSeed() { + return this.seed; + } + + public void setSeed(Integer seed) { + this.seed = seed; + } + + public Map getResponseFormat() { + return this.responseFormat; + } + + public void setResponseFormat(Map responseFormat) { + this.responseFormat = responseFormat; + } + + public String getToolPrompt() { + return this.toolPrompt; + } + + public void setToolPrompt(String toolPrompt) { + this.toolPrompt = toolPrompt; + } + + public Boolean getLogprobs() { + return this.logprobs; + } + + public void setLogprobs(Boolean logprobs) { + this.logprobs = logprobs; + } + + public Integer getTopLogprobs() { + return this.topLogprobs; + } + + public void setTopLogprobs(Integer topLogprobs) { + this.topLogprobs = topLogprobs; + } + + @Override + public List getToolCallbacks() { + return this.toolCallbacks; + } + + @Override + public void setToolCallbacks(List toolCallbacks) { + this.toolCallbacks = toolCallbacks != null ? toolCallbacks : new ArrayList<>(); + } + + @Override + public Set getToolNames() { + return this.toolNames; + } + + @Override + public void setToolNames(Set toolNames) { + this.toolNames = toolNames != null ? toolNames : new HashSet<>(); + } + + @Override + @Nullable + public Boolean getInternalToolExecutionEnabled() { + return this.internalToolExecutionEnabled; + } + + @Override + public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) { + this.internalToolExecutionEnabled = internalToolExecutionEnabled; + } + + @Override + public Map getToolContext() { + return this.toolContext; + } + + @Override + public void setToolContext(Map toolContext) { + this.toolContext = toolContext != null ? toolContext : new HashMap<>(); + } + + @Override + public HuggingfaceChatOptions copy() { + return fromOptions(this); + } + + /** + * Create a new {@link HuggingfaceChatOptions} instance from the given options. 
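+	 * Immutable fields are carried over as-is, while the collection-valued tool settings
+	 * (callbacks, names, context) are shared by reference rather than deep-copied.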
+ * @param fromOptions the options to copy from + * @return a new {@link HuggingfaceChatOptions} instance + */ + public static HuggingfaceChatOptions fromOptions(HuggingfaceChatOptions fromOptions) { + return builder().model(fromOptions.getModel()) + .temperature(fromOptions.getTemperature()) + .maxTokens(fromOptions.getMaxTokens()) + .topP(fromOptions.getTopP()) + .frequencyPenalty(fromOptions.getFrequencyPenalty()) + .presencePenalty(fromOptions.getPresencePenalty()) + .stopSequences(fromOptions.getStopSequences()) + .seed(fromOptions.getSeed()) + .responseFormat(fromOptions.getResponseFormat()) + .toolPrompt(fromOptions.getToolPrompt()) + .logprobs(fromOptions.getLogprobs()) + .topLogprobs(fromOptions.getTopLogprobs()) + .toolCallbacks(fromOptions.getToolCallbacks()) + .toolNames(fromOptions.getToolNames()) + .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) + .toolContext(fromOptions.getToolContext()) + .build(); + } + + /** + * Convert the {@link HuggingfaceChatOptions} object to a {@link Map} of key/value + * pairs. + * @return the {@link Map} of key/value pairs + */ + public Map toMap() { + return ModelOptionsUtils.objectToMap(this); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + HuggingfaceChatOptions that = (HuggingfaceChatOptions) o; + return Objects.equals(this.model, that.model) && Objects.equals(this.temperature, that.temperature) + && Objects.equals(this.maxTokens, that.maxTokens) && Objects.equals(this.topP, that.topP) + && Objects.equals(this.frequencyPenalty, that.frequencyPenalty) + && Objects.equals(this.presencePenalty, that.presencePenalty) && Objects.equals(this.stop, that.stop) + && Objects.equals(this.seed, that.seed) && Objects.equals(this.responseFormat, that.responseFormat) + && Objects.equals(this.toolPrompt, that.toolPrompt) && Objects.equals(this.logprobs, that.logprobs) + && Objects.equals(this.topLogprobs, that.topLogprobs) + && Objects.equals(this.toolCallbacks, that.toolCallbacks) + && Objects.equals(this.toolNames, that.toolNames) + && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) + && Objects.equals(this.toolContext, that.toolContext); + } + + @Override + public int hashCode() { + return Objects.hash(this.model, this.temperature, this.maxTokens, this.topP, this.frequencyPenalty, + this.presencePenalty, this.stop, this.seed, this.responseFormat, this.toolPrompt, this.logprobs, + this.topLogprobs, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, + this.toolContext); + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + + private final HuggingfaceChatOptions options = new HuggingfaceChatOptions(); + + private Builder() { + } + + public Builder model(String model) { + this.options.setModel(model); + return this; + } + + public Builder temperature(Double temperature) { + this.options.setTemperature(temperature); + return this; + } + + public Builder maxTokens(Integer maxTokens) { + this.options.setMaxTokens(maxTokens); + return this; + } + + public Builder topP(Double topP) { + this.options.setTopP(topP); + return this; + } + + public Builder frequencyPenalty(Double frequencyPenalty) { + this.options.setFrequencyPenalty(frequencyPenalty); + return this; + } + + public Builder presencePenalty(Double presencePenalty) { + this.options.setPresencePenalty(presencePenalty); + return this; + } + + public 
Builder stopSequences(List stopSequences) { + this.options.setStopSequences(stopSequences); + return this; + } + + public Builder seed(Integer seed) { + this.options.setSeed(seed); + return this; + } + + public Builder responseFormat(Map responseFormat) { + this.options.setResponseFormat(responseFormat); + return this; + } + + public Builder toolPrompt(String toolPrompt) { + this.options.setToolPrompt(toolPrompt); + return this; + } + + public Builder logprobs(Boolean logprobs) { + this.options.setLogprobs(logprobs); + return this; + } + + public Builder topLogprobs(Integer topLogprobs) { + this.options.setTopLogprobs(topLogprobs); + return this; + } + + public Builder toolCallbacks(List toolCallbacks) { + this.options.setToolCallbacks(toolCallbacks); + return this; + } + + public Builder toolNames(Set toolNames) { + this.options.setToolNames(toolNames); + return this; + } + + public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) { + this.options.setInternalToolExecutionEnabled(internalToolExecutionEnabled); + return this; + } + + public Builder toolContext(Map toolContext) { + this.options.setToolContext(toolContext); + return this; + } + + public HuggingfaceChatOptions build() { + return this.options; + } + + } + +} diff --git a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceEmbeddingModel.java b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceEmbeddingModel.java new file mode 100644 index 00000000000..aa4223ae176 --- /dev/null +++ b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceEmbeddingModel.java @@ -0,0 +1,264 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.huggingface; + +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import io.micrometer.observation.ObservationRegistry; + +import org.springframework.ai.chat.metadata.EmptyUsage; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.AbstractEmbeddingModel; +import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; +import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation; +import org.springframework.ai.huggingface.api.HuggingfaceApi; +import org.springframework.ai.huggingface.api.HuggingfaceApi.EmbeddingsResponse; +import org.springframework.ai.huggingface.api.common.HuggingfaceApiConstants; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * {@link EmbeddingModel} implementation for HuggingFace Inference API. HuggingFace + * provides access to thousands of pre-trained models for various NLP tasks including text + * embeddings. + * + * @author Myeongdeok Kang + */ +public class HuggingfaceEmbeddingModel extends AbstractEmbeddingModel { + + private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention(); + + private final HuggingfaceApi huggingfaceApi; + + private final HuggingfaceEmbeddingOptions defaultOptions; + + private final ObservationRegistry observationRegistry; + + private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + + private final RetryTemplate retryTemplate; + + /** + * Constructor for HuggingfaceEmbeddingModel. + * @param huggingfaceApi The HuggingFace API client. + * @param defaultOptions Default embedding options. + * @param observationRegistry Observation registry for metrics. + * @param retryTemplate Retry template for handling transient errors. + */ + public HuggingfaceEmbeddingModel(HuggingfaceApi huggingfaceApi, HuggingfaceEmbeddingOptions defaultOptions, + ObservationRegistry observationRegistry, RetryTemplate retryTemplate) { + Assert.notNull(huggingfaceApi, "huggingfaceApi must not be null"); + Assert.notNull(defaultOptions, "defaultOptions must not be null"); + Assert.notNull(observationRegistry, "observationRegistry must not be null"); + Assert.notNull(retryTemplate, "retryTemplate must not be null"); + + this.huggingfaceApi = huggingfaceApi; + this.defaultOptions = defaultOptions; + this.observationRegistry = observationRegistry; + this.retryTemplate = retryTemplate; + } + + /** + * Create a new builder for HuggingfaceEmbeddingModel. + * @return A new builder instance. 
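+	 * A minimal usage sketch; the token environment variable is illustrative, and the
+	 * embedding endpoint uses a different base URL than chat, so it is set explicitly:
+	 * <pre>{@code
+	 * HuggingfaceApi api = HuggingfaceApi.builder()
+	 *     .baseUrl(HuggingfaceApiConstants.DEFAULT_EMBEDDING_BASE_URL)
+	 *     .apiKey(System.getenv("HF_TOKEN"))
+	 *     .build();
+	 *
+	 * HuggingfaceEmbeddingModel embeddingModel = HuggingfaceEmbeddingModel.builder()
+	 *     .huggingfaceApi(api)
+	 *     .defaultOptions(HuggingfaceEmbeddingOptions.builder()
+	 *         .model("sentence-transformers/all-MiniLM-L6-v2")
+	 *         .normalize(true)
+	 *         .build())
+	 *     .build();
+	 *
+	 * float[] embedding = embeddingModel.embed("Spring AI");
+	 * }</pre>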
+ */ + public static Builder builder() { + return new Builder(); + } + + @Override + public float[] embed(Document document) { + Assert.notNull(document, "Document must not be null"); + return embed(document.getText()); + } + + @Override + public EmbeddingResponse call(EmbeddingRequest request) { + Assert.notEmpty(request.getInstructions(), "At least one text is required!"); + + // Build the final request, merging runtime and default options + EmbeddingRequest embeddingRequest = buildEmbeddingRequest(request); + + HuggingfaceApi.EmbeddingsRequest huggingfaceEmbeddingRequest = createApiRequest(embeddingRequest); + + var observationContext = EmbeddingModelObservationContext.builder() + .embeddingRequest(embeddingRequest) + .provider(HuggingfaceApiConstants.PROVIDER_NAME) + .build(); + + return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION + .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry) + .observe(() -> { + EmbeddingsResponse response = RetryUtils.execute(this.retryTemplate, + () -> this.huggingfaceApi.embeddings(huggingfaceEmbeddingRequest)); + + AtomicInteger indexCounter = new AtomicInteger(0); + + List embeddings = convertToEmbeddings(response.embeddings(), indexCounter); + + // HuggingFace Inference API doesn't provide usage information + EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata(response.model(), new EmptyUsage()); + + EmbeddingResponse embeddingResponse = new EmbeddingResponse(embeddings, metadata); + + observationContext.setResponse(embeddingResponse); + + return embeddingResponse; + }); + } + + /** + * Convert list of float arrays to list of Embedding objects. + * @param embeddingsList The list of embedding arrays from the API. + * @param indexCounter Counter for tracking embedding indices. + * @return List of Embedding objects. + */ + private List convertToEmbeddings(List embeddingsList, AtomicInteger indexCounter) { + return embeddingsList.stream() + .map(embeddingVector -> new Embedding(embeddingVector, indexCounter.getAndIncrement())) + .toList(); + } + + /** + * Build the embedding request by merging runtime and default options. + * @param embeddingRequest The original embedding request. + * @return A new embedding request with merged options. + */ + EmbeddingRequest buildEmbeddingRequest(EmbeddingRequest embeddingRequest) { + // Process runtime options + HuggingfaceEmbeddingOptions runtimeOptions = null; + if (embeddingRequest.getOptions() != null) { + runtimeOptions = ModelOptionsUtils.copyToTarget(embeddingRequest.getOptions(), EmbeddingOptions.class, + HuggingfaceEmbeddingOptions.class); + } + + // Merge runtime and default options + HuggingfaceEmbeddingOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, + HuggingfaceEmbeddingOptions.class); + + // Validate + if (!StringUtils.hasText(requestOptions.getModel())) { + throw new IllegalArgumentException("model cannot be null or empty"); + } + + return new EmbeddingRequest(embeddingRequest.getInstructions(), requestOptions); + } + + /** + * Create the API request from the embedding request. + * @param embeddingRequest The embedding request. + * @return The API request. 
+ */ + private HuggingfaceApi.EmbeddingsRequest createApiRequest(EmbeddingRequest embeddingRequest) { + HuggingfaceEmbeddingOptions options = (HuggingfaceEmbeddingOptions) embeddingRequest.getOptions(); + return new HuggingfaceApi.EmbeddingsRequest(options.getModel(), embeddingRequest.getInstructions(), + options.toMap()); + } + + /** + * Set the observation convention for reporting metrics. + * @param observationConvention The observation convention. + */ + public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) { + Assert.notNull(observationConvention, "observationConvention cannot be null"); + this.observationConvention = observationConvention; + } + + /** + * Builder for creating HuggingfaceEmbeddingModel instances. + */ + public static final class Builder { + + private HuggingfaceApi huggingfaceApi; + + private HuggingfaceEmbeddingOptions defaultOptions = HuggingfaceEmbeddingOptions.builder() + .model(HuggingfaceApi.DEFAULT_EMBEDDING_MODEL) + .build(); + + private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; + + private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + + private Builder() { + } + + /** + * Set the HuggingFace API client. + * @param huggingfaceApi The API client. + * @return This builder. + */ + public Builder huggingfaceApi(HuggingfaceApi huggingfaceApi) { + this.huggingfaceApi = huggingfaceApi; + return this; + } + + /** + * Set the default embedding options. + * @param defaultOptions The default options. + * @return This builder. + */ + public Builder defaultOptions(HuggingfaceEmbeddingOptions defaultOptions) { + this.defaultOptions = defaultOptions; + return this; + } + + /** + * Set the observation registry. + * @param observationRegistry The observation registry. + * @return This builder. + */ + public Builder observationRegistry(ObservationRegistry observationRegistry) { + this.observationRegistry = observationRegistry; + return this; + } + + /** + * Set the retry template. + * @param retryTemplate The retry template. + * @return This builder. + */ + public Builder retryTemplate(RetryTemplate retryTemplate) { + this.retryTemplate = retryTemplate; + return this; + } + + /** + * Build the HuggingfaceEmbeddingModel instance. + * @return A new HuggingfaceEmbeddingModel. + */ + public HuggingfaceEmbeddingModel build() { + Assert.notNull(this.huggingfaceApi, "huggingfaceApi must not be null"); + return new HuggingfaceEmbeddingModel(this.huggingfaceApi, this.defaultOptions, this.observationRegistry, + this.retryTemplate); + } + + } + +} diff --git a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceEmbeddingOptions.java b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceEmbeddingOptions.java new file mode 100644 index 00000000000..fe97eec576d --- /dev/null +++ b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceEmbeddingOptions.java @@ -0,0 +1,257 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.huggingface; + +import java.util.Map; +import java.util.Objects; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.model.ModelOptionsUtils; + +/** + * Options for HuggingFace embedding model. + * + * @author Myeongdeok Kang + */ +@JsonInclude(Include.NON_NULL) +public class HuggingfaceEmbeddingOptions implements EmbeddingOptions { + + /** + * The name of the model to use for embeddings. + */ + @JsonProperty("model") + private String model; + + /** + * Whether to normalize the embedding vectors. + */ + @JsonProperty("normalize") + private Boolean normalize; + + /** + * The name of a predefined prompt from the model configuration to apply to the input + * text. + *
+ * For example, setting this to "query" might prepend "query: " to your input text, + * which can improve retrieval performance for query-document matching tasks. + */ + @JsonProperty("prompt_name") + private String promptName; + + /** + * Whether to truncate input text that exceeds the model's maximum sequence length. + */ + @JsonProperty("truncate") + private Boolean truncate; + + /** + * Which side of the text to truncate when it exceeds the maximum length. Must be + * either "left" or "right". + *
+ * Only meaningful when truncate is set to true. + */ + @JsonProperty("truncation_direction") + private String truncationDirection; + + /** + * Create a new builder for HuggingfaceEmbeddingOptions. + * @return A new builder instance. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Create a copy from existing options. + * @param fromOptions The options to copy from. + * @return A new HuggingfaceEmbeddingOptions instance with copied values. + */ + public static HuggingfaceEmbeddingOptions fromOptions(HuggingfaceEmbeddingOptions fromOptions) { + return builder().model(fromOptions.getModel()) + .normalize(fromOptions.getNormalize()) + .promptName(fromOptions.getPromptName()) + .truncate(fromOptions.getTruncate()) + .truncationDirection(fromOptions.getTruncationDirection()) + .build(); + } + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + @JsonIgnore + public Integer getDimensions() { + return null; + } + + public Boolean getNormalize() { + return this.normalize; + } + + public void setNormalize(Boolean normalize) { + this.normalize = normalize; + } + + public String getPromptName() { + return this.promptName; + } + + public void setPromptName(String promptName) { + this.promptName = promptName; + } + + public Boolean getTruncate() { + return this.truncate; + } + + public void setTruncate(Boolean truncate) { + this.truncate = truncate; + } + + public String getTruncationDirection() { + return this.truncationDirection; + } + + public void setTruncationDirection(String truncationDirection) { + this.truncationDirection = truncationDirection; + } + + /** + * Create a copy of this options instance. + * @return A new copy of this options. + */ + public HuggingfaceEmbeddingOptions copy() { + return fromOptions(this); + } + + /** + * Convert the {@link HuggingfaceEmbeddingOptions} object to a {@link Map} of + * key/value pairs. + * @return the {@link Map} of key/value pairs + */ + public Map toMap() { + return ModelOptionsUtils.objectToMap(this); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + HuggingfaceEmbeddingOptions that = (HuggingfaceEmbeddingOptions) o; + return Objects.equals(this.model, that.model) && Objects.equals(this.normalize, that.normalize) + && Objects.equals(this.promptName, that.promptName) && Objects.equals(this.truncate, that.truncate) + && Objects.equals(this.truncationDirection, that.truncationDirection); + } + + @Override + public int hashCode() { + return Objects.hash(this.model, this.normalize, this.promptName, this.truncate, this.truncationDirection); + } + + @Override + public String toString() { + return "HuggingfaceEmbeddingOptions{" + "model='" + this.model + '\'' + ", normalize=" + this.normalize + + ", promptName='" + this.promptName + '\'' + ", truncate=" + this.truncate + ", truncationDirection='" + + this.truncationDirection + '\'' + '}'; + } + + /** + * Builder for HuggingfaceEmbeddingOptions. + */ + public static final class Builder { + + private final HuggingfaceEmbeddingOptions options = new HuggingfaceEmbeddingOptions(); + + private Builder() { + } + + /** + * Set the model name. + * @param model The model name. + * @return This builder. + */ + public Builder model(String model) { + this.options.model = model; + return this; + } + + /** + * Set whether to normalize the embedding vectors. 
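+		 * Unit-length vectors allow cosine similarity to be computed as a plain dot
+		 * product, which many vector stores rely on.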
+ * @param normalize True to normalize, false otherwise. + * @return This builder. + */ + public Builder normalize(Boolean normalize) { + this.options.normalize = normalize; + return this; + } + + /** + * Set the name of a predefined prompt to apply to the input text. + * @param promptName The prompt name from the model configuration. + * @return This builder. + */ + public Builder promptName(String promptName) { + this.options.promptName = promptName; + return this; + } + + /** + * Set whether to truncate input text that exceeds the model's maximum length. + * @param truncate True to truncate, false otherwise. + * @return This builder. + */ + public Builder truncate(Boolean truncate) { + this.options.truncate = truncate; + return this; + } + + /** + * Set which side of the text to truncate when it exceeds the maximum length. + * @param truncationDirection Either "left" or "right" (case-sensitive). + * @return This builder. + */ + public Builder truncationDirection(String truncationDirection) { + this.options.truncationDirection = truncationDirection; + return this; + } + + /** + * Build the HuggingfaceEmbeddingOptions instance. + * @return A new HuggingfaceEmbeddingOptions instance. + */ + public HuggingfaceEmbeddingOptions build() { + return this.options; + } + + } + +} diff --git a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/aot/HuggingfaceRuntimeHints.java b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/aot/HuggingfaceRuntimeHints.java new file mode 100644 index 00000000000..77b62cef310 --- /dev/null +++ b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/aot/HuggingfaceRuntimeHints.java @@ -0,0 +1,43 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.huggingface.aot; + +import org.springframework.aot.hint.MemberCategory; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; + +import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; + +/** + * Runtime hints for HuggingFace integration to support GraalVM Native Image compilation. + * Registers reflection hints for all JSON-annotated classes in the HuggingFace package. 
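+ * The registrar is typically activated through a {@code META-INF/spring/aot.factories}
+ * entry or an {@code @ImportRuntimeHints(HuggingfaceRuntimeHints.class)} declaration;
+ * without such registration the hints are not picked up by the AOT engine.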
+ * + * @author Myeongdeok Kang + */ +public class HuggingfaceRuntimeHints implements RuntimeHintsRegistrar { + + @Override + public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) { + var mcs = MemberCategory.values(); + for (var tr : findJsonAnnotatedClassesInPackage("org.springframework.ai.huggingface")) { + hints.reflection().registerType(tr, mcs); + } + } + +} diff --git a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/api/HuggingfaceApi.java b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/api/HuggingfaceApi.java new file mode 100644 index 00000000000..1170424bcd6 --- /dev/null +++ b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/api/HuggingfaceApi.java @@ -0,0 +1,606 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.huggingface.api; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import com.fasterxml.jackson.annotation.JsonFormat; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.springframework.ai.huggingface.api.common.HuggingfaceApiConstants; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.util.Assert; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; + +/** + * Java Client for the HuggingFace Inference API. Supports both Chat and Embedding + * endpoints. HuggingFace + * Inference API + * + * @author Myeongdeok Kang + */ +public final class HuggingfaceApi { + + // API Endpoint Paths + public static final String CHAT_COMPLETIONS_PATH = "/chat/completions"; + + public static final String EMBEDDING_PATH_TEMPLATE = "/%s/pipeline/feature-extraction"; + + // Default Models + public static final String DEFAULT_CHAT_MODEL = "meta-llama/Llama-3.2-3B-Instruct"; + + public static final String DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"; + + private static final String REQUEST_BODY_NULL_ERROR = "The request body cannot be null."; + + private final RestClient restClient; + + /** + * Create a new HuggingfaceApi instance. + * @param baseUrl The base URL of the HuggingFace Inference API. + * @param apiKey The HuggingFace API key for authentication. + * @param restClientBuilder The {@link RestClient.Builder} to use. + * @param responseErrorHandler Response error handler. 
+ */ + private HuggingfaceApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder, + ResponseErrorHandler responseErrorHandler) { + Assert.hasText(baseUrl, "baseUrl must not be empty"); + Assert.hasText(apiKey, "apiKey must not be empty"); + + Consumer defaultHeaders = headers -> { + headers.setContentType(MediaType.APPLICATION_JSON); + headers.setAccept(List.of(MediaType.APPLICATION_JSON)); + headers.setBearerAuth(apiKey); + }; + + RestClient.Builder builder = restClientBuilder.clone().baseUrl(baseUrl).defaultHeaders(defaultHeaders); + + if (responseErrorHandler != null) { + builder.defaultStatusHandler(responseErrorHandler); + } + + this.restClient = builder.build(); + } + + /** + * Generate chat completion using the specified model (OpenAI-compatible endpoint). + * Supports parameters from the Chat Completion API specification: + * https://huggingface.co/docs/inference-providers/tasks/chat-completion + * @param chatRequest Chat request containing the model, messages, and optional + * parameters (temperature, max_tokens, top_p, frequency_penalty, presence_penalty, + * stop, seed, response_format, tool_prompt, logprobs, top_logprobs, etc.) + * @return Chat response containing the generated text. + */ + public ChatResponse chat(ChatRequest chatRequest) { + Assert.notNull(chatRequest, REQUEST_BODY_NULL_ERROR); + Assert.hasText(chatRequest.model(), "Model must not be empty"); + Assert.notEmpty(chatRequest.messages(), "Messages must not be empty"); + + // OpenAI-compatible chat completions endpoint + ResponseEntity responseEntity = this.restClient.post() + .uri(CHAT_COMPLETIONS_PATH) + .body(chatRequest) + .retrieve() + .toEntity(ChatResponse.class); + + ChatResponse response = responseEntity.getBody(); + if (response == null) { + throw new IllegalStateException("No response returned from HuggingFace API"); + } + + return response; + } + + /** + * Generate embeddings from a model using the Feature Extraction pipeline. + * @param embeddingsRequest Embedding request containing the model and inputs. + * @return Embeddings response containing the generated embeddings. + */ + public EmbeddingsResponse embeddings(EmbeddingsRequest embeddingsRequest) { + Assert.notNull(embeddingsRequest, REQUEST_BODY_NULL_ERROR); + Assert.hasText(embeddingsRequest.model(), "Model must not be empty"); + Assert.notEmpty(embeddingsRequest.inputs(), "Inputs must not be empty"); + + // HuggingFace Inference API endpoint for feature extraction + String uri = String.format(EMBEDDING_PATH_TEMPLATE, embeddingsRequest.model()); + + ResponseEntity responseEntity = this.restClient.post() + .uri(uri) + .body(new EmbeddingsRequestBody(embeddingsRequest.inputs(), embeddingsRequest.options())) + .retrieve() + .toEntity(float[][].class); + + float[][] embeddings = responseEntity.getBody(); + if (embeddings == null || embeddings.length == 0) { + throw new IllegalStateException("No embeddings returned from HuggingFace API"); + } + + // Convert float[][] to List for consistency with other implementations + return new EmbeddingsResponse(embeddingsRequest.model(), Arrays.asList(embeddings)); + } + + /** + * Create a new builder for HuggingfaceApi. + * @return A new builder instance. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating HuggingfaceApi instances. 
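+	 * <p>
+	 * A minimal usage sketch; the default base URL targets the chat router, so override
+	 * {@code baseUrl} when calling the embeddings endpoint:
+	 * <pre>{@code
+	 * HuggingfaceApi api = HuggingfaceApi.builder()
+	 *     .apiKey(System.getenv("HF_TOKEN"))
+	 *     .build();
+	 *
+	 * ChatResponse response = api.chat(new ChatRequest(DEFAULT_CHAT_MODEL,
+	 *     List.of(new Message("user", "Hello!"))));
+	 * }</pre>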
+ */ + public static final class Builder { + + private String baseUrl = HuggingfaceApiConstants.DEFAULT_CHAT_BASE_URL; + + private String apiKey; + + private RestClient.Builder restClientBuilder = RestClient.builder(); + + private ResponseErrorHandler responseErrorHandler; + + private Builder() { + } + + /** + * Set the base URL for the HuggingFace Inference API. + * @param baseUrl The base URL. + * @return This builder. + */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** + * Set the API key for authentication. + * @param apiKey The HuggingFace API key. + * @return This builder. + */ + public Builder apiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + /** + * Set the RestClient.Builder to use. + * @param restClientBuilder The RestClient.Builder. + * @return This builder. + */ + public Builder restClientBuilder(RestClient.Builder restClientBuilder) { + this.restClientBuilder = restClientBuilder; + return this; + } + + /** + * Set the response error handler. + * @param responseErrorHandler The error handler. + * @return This builder. + */ + public Builder responseErrorHandler(ResponseErrorHandler responseErrorHandler) { + this.responseErrorHandler = responseErrorHandler; + return this; + } + + /** + * Build the HuggingfaceApi instance. + * @return A new HuggingfaceApi instance. + */ + public HuggingfaceApi build() { + return new HuggingfaceApi(this.baseUrl, this.apiKey, this.restClientBuilder, this.responseErrorHandler); + } + + } + + /** + * Chat request for HuggingFace Inference API. + * + * @param model The name of the model to use for chat. + * @param messages The list of messages in the conversation. + * @param tools A list of tools the model may call. Currently, only functions are + * supported as a tool. + * @param toolChoice Controls which (if any) function is called by the model. + * @param options Additional options for the chat request (optional). + */ + @JsonInclude(Include.NON_NULL) + public record ChatRequest(@JsonProperty("model") String model, @JsonProperty("messages") List messages, + @JsonProperty("tools") List tools, @JsonProperty("tool_choice") Object toolChoice, + @JsonProperty("options") Map options) { + + /** + * Shortcut constructor without options. + * @param model The model name. + * @param messages The messages. + */ + public ChatRequest(String model, List messages) { + this(model, messages, null, null, null); + } + + /** + * Constructor with options but no tools. + * @param model The model name. + * @param messages The messages. + * @param options Additional options. + */ + public ChatRequest(String model, List messages, Map options) { + this(model, messages, null, null, options); + } + + /** + * Constructor with tools and tool choice. + * @param model The model name. + * @param messages The messages. + * @param tools The list of tools. + * @param toolChoice Controls which function is called. + */ + public ChatRequest(String model, List messages, List tools, Object toolChoice) { + this(model, messages, tools, toolChoice, null); + } + + /** + * Constructor with tools, tool choice, and additional options. + * @param model The model name. + * @param messages The messages. + * @param tools The list of tools. + * @param toolChoice Controls which function is called. + * @param options Additional options. 
+ */ + public ChatRequest(String model, List messages, List tools, Object toolChoice, + Map options) { + this.model = model; + this.messages = messages; + this.tools = tools; + this.toolChoice = toolChoice; + this.options = options; + } + + } + + /** + * Chat message. + * + * @param role The role of the message sender (system, user, assistant, tool). + * @param content The content of the message. + * @param name An optional name for the participant. In case of function calling, the + * name is the function name that the message is responding to. + * @param toolCallId Tool call that this message is responding to. Only applicable for + * the tool role. + * @param toolCalls The tool calls generated by the model, such as function calls. + * Applicable only for assistant role. + */ + @JsonInclude(Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record Message(@JsonProperty("role") String role, @JsonProperty("content") String content, + @JsonProperty("name") String name, @JsonProperty("tool_call_id") String toolCallId, + @JsonProperty("tool_calls") @JsonFormat( + with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) List toolCalls) { + + /** + * Create a simple message with role and content. + * @param role The role of the message sender. + * @param content The content of the message. + */ + public Message(String role, String content) { + this(role, content, null, null, null); + } + + /** + * Create a tool response message. + * @param content The content of the tool response. + * @param role The role (should be "tool"). + * @param name The function name. + * @param toolCallId The tool call ID this message responds to. + */ + public Message(String content, String role, String name, String toolCallId) { + this(role, content, name, toolCallId, null); + } + + /** + * Create an assistant message with tool calls. + * @param role The role (should be "assistant"). + * @param content The content of the message. + * @param toolCalls The tool calls generated by the model. + */ + public Message(String role, String content, List toolCalls) { + this(role, content, null, null, toolCalls); + } + + } + + /** + * Chat response from HuggingFace Inference API (OpenAI-compatible). + * + * @param id Unique identifier for the chat completion. + * @param object Object type, always "chat.completion". + * @param created Unix timestamp of when the chat completion was created. + * @param model The model used for generating the response. + * @param choices The list of generated choices. + * @param usage Token usage information (optional). + * @param systemFingerprint System fingerprint (optional). + */ + @JsonInclude(Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ChatResponse(@JsonProperty("id") String id, @JsonProperty("object") String object, + @JsonProperty("created") Long created, @JsonProperty("model") String model, + @JsonProperty("choices") List choices, @JsonProperty("usage") Usage usage, + @JsonProperty("system_fingerprint") String systemFingerprint) { + } + + /** + * A chat completion choice. + * + * @param index The index of the choice. + * @param message The generated message. + * @param finishReason The reason the generation stopped. + */ + @JsonInclude(Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record Choice(@JsonProperty("index") Integer index, @JsonProperty("message") Message message, + @JsonProperty("finish_reason") String finishReason) { + } + + /** + * Token usage information. 
+ * + * @param promptTokens Number of tokens in the prompt. + * @param completionTokens Number of tokens in the completion. + * @param totalTokens Total number of tokens. + */ + @JsonInclude(Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record Usage(@JsonProperty("prompt_tokens") Integer promptTokens, + @JsonProperty("completion_tokens") Integer completionTokens, + @JsonProperty("total_tokens") Integer totalTokens) { + } + + /** + * Embedding request for HuggingFace Inference API. + * + * @param model The name of the model to use for embeddings (e.g., + * "sentence-transformers/all-MiniLM-L6-v2"). + * @param inputs The list of text inputs to generate embeddings for. + * @param options Additional options for the embedding request (optional). + */ + @JsonInclude(Include.NON_NULL) + public record EmbeddingsRequest(@JsonProperty("model") String model, @JsonProperty("inputs") List inputs, + @JsonProperty("options") Map options) { + + /** + * Shortcut constructor without options. + * @param model The model name. + * @param inputs The text inputs. + */ + public EmbeddingsRequest(String model, List inputs) { + this(model, inputs, null); + } + + } + + /** + * Internal request body sent to the HuggingFace API for embeddings. The API doesn't + * expect a "model" field in the body since it's in the URL path. + *

+ * Options are flattened at the top level alongside inputs, as per the HuggingFace API + * specification. Example request body:

+	 * <pre>{@code
+	 * {
+	 *   "inputs": ["text1", "text2"],
+	 *   "normalize": true,
+	 *   "dimensions": 256,
+	 *   "prompt_name": "query"
+	 * }
+	 * }</pre>
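+	 * <p>
+	 * Illustrative only (assumed usage, not part of the public API): a body matching the
+	 * example above could be created as
+	 * {@code new EmbeddingsRequestBody(List.of("text1", "text2"), Map.of("normalize", true))}.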
+ * + * @param inputs The text inputs. + * @param options Additional options (dimensions, normalize, prompt_name, etc.) that + * get flattened at the top level. + */ + @JsonInclude(Include.NON_NULL) + record EmbeddingsRequestBody(@JsonProperty("inputs") List inputs, + @com.fasterxml.jackson.annotation.JsonUnwrapped Map options) { + } + + /** + * Embedding response from HuggingFace Inference API. + * + * @param model The model used for generating embeddings. + * @param embeddings The generated embeddings as a list of float arrays. Each array + * represents one input's embedding vector. + */ + @JsonInclude(Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record EmbeddingsResponse(@JsonProperty("model") String model, + @JsonProperty("embeddings") List embeddings) { + } + + /** + * Represents a tool the model may call. Currently, only functions are supported as a + * tool. + */ + @JsonInclude(Include.NON_NULL) + public static class FunctionTool { + + /** + * The type of the tool. Currently, only 'function' is supported. + */ + @JsonProperty("type") + private Type type = Type.FUNCTION; + + /** + * The function definition. + */ + @JsonProperty("function") + private Function function; + + public FunctionTool() { + + } + + /** + * Create a tool of type 'function' and the given function definition. + * @param type the tool type + * @param function function definition + */ + public FunctionTool(Type type, Function function) { + this.type = type; + this.function = function; + } + + /** + * Create a tool of type 'function' and the given function definition. + * @param function function definition. + */ + public FunctionTool(Function function) { + this(Type.FUNCTION, function); + } + + public Type getType() { + return this.type; + } + + public Function getFunction() { + return this.function; + } + + public void setType(Type type) { + this.type = type; + } + + public void setFunction(Function function) { + this.function = function; + } + + /** + * Create a tool of type 'function' and the given function definition. + */ + public enum Type { + + /** + * Function tool type. + */ + @JsonProperty("function") + FUNCTION + + } + + /** + * Function definition. + */ + @JsonInclude(Include.NON_NULL) + public static class Function { + + @JsonProperty("description") + private String description; + + @JsonProperty("name") + private String name; + + @JsonProperty("parameters") + private Map parameters; + + /** + * NOTE: Required by Jackson, JSON deserialization! + */ + @SuppressWarnings("unused") + private Function() { + } + + /** + * Create tool function definition. + * @param description A description of what the function does, used by the + * model to choose when and how to call the function. + * @param name The name of the function to be called. Must be a-z, A-Z, 0-9, + * or contain underscores and dashes, with a maximum length of 64. + * @param parameters The parameters the functions accepts, described as a JSON + * Schema object. 
+ */ + public Function(String description, String name, Map parameters) { + this.description = description; + this.name = name; + this.parameters = parameters; + } + + public String getDescription() { + return this.description; + } + + public String getName() { + return this.name; + } + + public Map getParameters() { + return this.parameters; + } + + public void setDescription(String description) { + this.description = description; + } + + public void setName(String name) { + this.name = name; + } + + public void setParameters(Map parameters) { + this.parameters = parameters; + } + + } + + } + + /** + * The relevant tool call. + * + * @param index The index of the tool call. + * @param id The ID of the tool call. + * @param type The type of tool call the output is required for. For now, this is + * always function. + * @param function The function definition. + */ + @JsonInclude(Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ToolCall(@JsonProperty("index") Integer index, @JsonProperty("id") String id, + @JsonProperty("type") String type, @JsonProperty("function") ChatCompletionFunction function) { + + public ToolCall(String id, String type, ChatCompletionFunction function) { + this(null, id, type, function); + } + + } + + /** + * The function definition. + * + * @param name The name of the function. + * @param arguments The arguments that the model expects you to pass to the function. + */ + @JsonInclude(Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ChatCompletionFunction(@JsonProperty("name") String name, + @JsonProperty("arguments") String arguments) { + } + +} diff --git a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/api/common/HuggingfaceApiConstants.java b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/api/common/HuggingfaceApiConstants.java new file mode 100644 index 00000000000..4de499dc158 --- /dev/null +++ b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/api/common/HuggingfaceApiConstants.java @@ -0,0 +1,47 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.huggingface.api.common; + +import org.springframework.ai.observation.conventions.AiProvider; + +/** + * Common value constants for HuggingFace API. + * + * @author Myeongdeok Kang + */ +public final class HuggingfaceApiConstants { + + /** + * Default base URL for HuggingFace Chat API (OpenAI-compatible endpoint). + */ + public static final String DEFAULT_CHAT_BASE_URL = "https://router.huggingface.co/v1"; + + /** + * Default base URL for HuggingFace Embedding API (Feature Extraction endpoint). + */ + public static final String DEFAULT_EMBEDDING_BASE_URL = "https://router.huggingface.co/hf-inference/models"; + + /** + * Provider name for observation and metrics. 
+ */ + public static final String PROVIDER_NAME = AiProvider.HUGGINGFACE.value(); + + private HuggingfaceApiConstants() { + // Prevent instantiation + } + +} diff --git a/models/spring-ai-huggingface/src/main/resources/META-INF/spring/aot.factories b/models/spring-ai-huggingface/src/main/resources/META-INF/spring/aot.factories new file mode 100644 index 00000000000..81e54fd59ce --- /dev/null +++ b/models/spring-ai-huggingface/src/main/resources/META-INF/spring/aot.factories @@ -0,0 +1,2 @@ +org.springframework.aot.hint.RuntimeHintsRegistrar=\ + org.springframework.ai.huggingface.aot.HuggingfaceRuntimeHints diff --git a/models/spring-ai-huggingface/src/main/resources/openapi.json b/models/spring-ai-huggingface/src/main/resources/openapi.json deleted file mode 100644 index 82a9e59eb41..00000000000 --- a/models/spring-ai-huggingface/src/main/resources/openapi.json +++ /dev/null @@ -1,852 +0,0 @@ -{ - "openapi": "3.0.3", - "info": { - "title": "Text Generation Inference", - "description": "Text Generation Webserver", - "contact": { - "name": "Olivier Dehaene" - }, - "license": { - "name": "Apache 2.0", - "url": "https://www.apache.org/licenses/LICENSE-2.0" - }, - "version": "1.0.2" - }, - "paths": { - "/": { - "post": { - "tags": [ - "Text Generation Inference" - ], - "summary": "Generate tokens if `stream == false` or a stream of token if `stream == true`", - "description": "Generate tokens if `stream == false` or a stream of token if `stream == true`", - "operationId": "compat_generate", - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/CompatGenerateRequest" - } - } - }, - "required": true - }, - "responses": { - "200": { - "description": "Generated Text", - "content": { - "application/json": { - "schema": { - "type": "array", - "items": { - "$ref": "#/components/schemas/GenerateResponse" - } - } - }, - "text/event-stream": { - "schema": { - "$ref": "#/components/schemas/StreamResponse" - } - } - } - }, - "422": { - "description": "Input validation error", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ErrorResponse" - }, - "example": { - "error": "Input validation error" - } - } - } - }, - "424": { - "description": "Generation Error", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ErrorResponse" - }, - "example": { - "error": "Request failed during generation" - } - } - } - }, - "429": { - "description": "Model is overloaded", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ErrorResponse" - }, - "example": { - "error": "Model is overloaded" - } - } - } - }, - "500": { - "description": "Incomplete generation", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ErrorResponse" - }, - "example": { - "error": "Incomplete generation" - } - } - } - } - } - } - }, - "/generate": { - "post": { - "tags": [ - "Text Generation Inference" - ], - "summary": "Generate tokens", - "description": "Generate tokens", - "operationId": "generate", - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/GenerateRequest" - } - } - }, - "required": true - }, - "responses": { - "200": { - "description": "Generated Text", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/GenerateResponse" - } - } - } - }, - "422": { - "description": "Input validation error", - "content": { - "application/json": { - "schema": { - "$ref": 
"#/components/schemas/ErrorResponse" - }, - "example": { - "error": "Input validation error" - } - } - } - }, - "424": { - "description": "Generation Error", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ErrorResponse" - }, - "example": { - "error": "Request failed during generation" - } - } - } - }, - "429": { - "description": "Model is overloaded", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ErrorResponse" - }, - "example": { - "error": "Model is overloaded" - } - } - } - }, - "500": { - "description": "Incomplete generation", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ErrorResponse" - }, - "example": { - "error": "Incomplete generation" - } - } - } - } - } - } - }, - "/generate_stream": { - "post": { - "tags": [ - "Text Generation Inference" - ], - "summary": "Generate a stream of token using Server-Sent Events", - "description": "Generate a stream of token using Server-Sent Events", - "operationId": "generate_stream", - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/GenerateRequest" - } - } - }, - "required": true - }, - "responses": { - "200": { - "description": "Generated Text", - "content": { - "text/event-stream": { - "schema": { - "$ref": "#/components/schemas/StreamResponse" - } - } - } - }, - "422": { - "description": "Input validation error", - "content": { - "text/event-stream": { - "schema": { - "$ref": "#/components/schemas/ErrorResponse" - }, - "example": { - "error": "Input validation error" - } - } - } - }, - "424": { - "description": "Generation Error", - "content": { - "text/event-stream": { - "schema": { - "$ref": "#/components/schemas/ErrorResponse" - }, - "example": { - "error": "Request failed during generation" - } - } - } - }, - "429": { - "description": "Model is overloaded", - "content": { - "text/event-stream": { - "schema": { - "$ref": "#/components/schemas/ErrorResponse" - }, - "example": { - "error": "Model is overloaded" - } - } - } - }, - "500": { - "description": "Incomplete generation", - "content": { - "text/event-stream": { - "schema": { - "$ref": "#/components/schemas/ErrorResponse" - }, - "example": { - "error": "Incomplete generation" - } - } - } - } - } - } - }, - "/health": { - "get": { - "tags": [ - "Text Generation Inference" - ], - "summary": "Health check method", - "description": "Health check method", - "operationId": "health", - "responses": { - "200": { - "description": "Everything is working fine" - }, - "503": { - "description": "Text generation inference is down", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ErrorResponse" - }, - "example": { - "error": "unhealthy", - "error_type": "healthcheck" - } - } - } - } - } - } - }, - "/info": { - "get": { - "tags": [ - "Text Generation Inference" - ], - "summary": "Text Generation Inference endpoint info", - "description": "Text Generation Inference endpoint info", - "operationId": "get_model_info", - "responses": { - "200": { - "description": "Served model info", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/Info" - } - } - } - } - } - } - }, - "/metrics": { - "get": { - "tags": [ - "Text Generation Inference" - ], - "summary": "Prometheus metrics scrape endpoint", - "description": "Prometheus metrics scrape endpoint", - "operationId": "metrics", - "responses": { - "200": { - "description": "Prometheus Metrics", - "content": { - "text/plain": 
{ - "schema": { - "type": "string" - } - } - } - } - } - } - } - }, - "components": { - "schemas": { - "BestOfSequence": { - "type": "object", - "required": [ - "generated_text", - "finish_reason", - "generated_tokens", - "prefill", - "tokens" - ], - "properties": { - "finish_reason": { - "$ref": "#/components/schemas/FinishReason" - }, - "generated_text": { - "type": "string", - "example": "test" - }, - "generated_tokens": { - "type": "integer", - "format": "int32", - "example": 1, - "minimum": 0.0 - }, - "prefill": { - "type": "array", - "items": { - "$ref": "#/components/schemas/PrefillToken" - } - }, - "seed": { - "type": "integer", - "format": "int64", - "example": 42, - "nullable": true, - "minimum": 0.0 - }, - "tokens": { - "type": "array", - "items": { - "$ref": "#/components/schemas/Token" - } - } - } - }, - "CompatGenerateRequest": { - "type": "object", - "required": [ - "inputs" - ], - "properties": { - "inputs": { - "type": "string", - "example": "My name is Olivier and I" - }, - "parameters": { - "$ref": "#/components/schemas/GenerateParameters" - }, - "stream": { - "type": "boolean", - "default": "false" - } - } - }, - "Details": { - "type": "object", - "required": [ - "finish_reason", - "generated_tokens", - "prefill", - "tokens" - ], - "properties": { - "best_of_sequences": { - "type": "array", - "items": { - "$ref": "#/components/schemas/BestOfSequence" - }, - "nullable": true - }, - "finish_reason": { - "$ref": "#/components/schemas/FinishReason" - }, - "generated_tokens": { - "type": "integer", - "format": "int32", - "example": 1, - "minimum": 0.0 - }, - "prefill": { - "type": "array", - "items": { - "$ref": "#/components/schemas/PrefillToken" - } - }, - "seed": { - "type": "integer", - "format": "int64", - "example": 42, - "nullable": true, - "minimum": 0.0 - }, - "tokens": { - "type": "array", - "items": { - "$ref": "#/components/schemas/Token" - } - } - } - }, - "ErrorResponse": { - "type": "object", - "required": [ - "error", - "error_type" - ], - "properties": { - "error": { - "type": "string" - }, - "error_type": { - "type": "string" - } - } - }, - "FinishReason": { - "type": "string", - "enum": [ - "length", - "eos_token", - "stop_sequence" - ] - }, - "GenerateParameters": { - "type": "object", - "properties": { - "best_of": { - "type": "integer", - "default": "null", - "example": 1, - "nullable": true, - "minimum": 0.0, - "exclusiveMinimum": 0.0 - }, - "decoder_input_details": { - "type": "boolean", - "default": "true" - }, - "details": { - "type": "boolean", - "default": "true" - }, - "do_sample": { - "type": "boolean", - "default": "false", - "example": true - }, - "max_new_tokens": { - "type": "integer", - "format": "int32", - "default": "20", - "minimum": 0.0, - "exclusiveMaximum": 512.0, - "exclusiveMinimum": 0.0 - }, - "repetition_penalty": { - "type": "number", - "format": "float", - "default": "null", - "example": 1.03, - "nullable": true, - "exclusiveMinimum": 0.0 - }, - "return_full_text": { - "type": "boolean", - "default": "null", - "example": false, - "nullable": true - }, - "seed": { - "type": "integer", - "format": "int64", - "default": "null", - "example": "null", - "nullable": true, - "minimum": 0.0, - "exclusiveMinimum": 0.0 - }, - "stop": { - "type": "array", - "items": { - "type": "string" - }, - "example": [ - "photographer" - ], - "maxItems": 4 - }, - "temperature": { - "type": "number", - "format": "float", - "default": "null", - "example": 0.5, - "nullable": true, - "exclusiveMinimum": 0.0 - }, - "top_k": { - "type": "integer", - "format": 
"int32", - "default": "null", - "example": 10, - "nullable": true, - "exclusiveMinimum": 0.0 - }, - "top_p": { - "type": "number", - "format": "float", - "default": "null", - "example": 0.95, - "nullable": true, - "maximum": 1.0, - "exclusiveMinimum": 0.0 - }, - "truncate": { - "type": "integer", - "default": "null", - "example": "null", - "nullable": true, - "minimum": 0.0 - }, - "typical_p": { - "type": "number", - "format": "float", - "default": "null", - "example": 0.95, - "nullable": true, - "maximum": 1.0, - "exclusiveMinimum": 0.0 - }, - "watermark": { - "type": "boolean", - "default": "false", - "example": true - } - } - }, - "GenerateRequest": { - "type": "object", - "required": [ - "inputs" - ], - "properties": { - "inputs": { - "type": "string", - "example": "My name is Olivier and I" - }, - "parameters": { - "$ref": "#/components/schemas/GenerateParameters" - } - } - }, - "GenerateResponse": { - "type": "object", - "required": [ - "generated_text" - ], - "properties": { - "details": { - "allOf": [ - { - "$ref": "#/components/schemas/Details" - } - ], - "nullable": true - }, - "generated_text": { - "type": "string", - "example": "test" - } - } - }, - "Info": { - "type": "object", - "required": [ - "model_id", - "model_dtype", - "model_device_type", - "max_concurrent_requests", - "max_best_of", - "max_stop_sequences", - "max_input_length", - "max_total_tokens", - "waiting_served_ratio", - "max_batch_total_tokens", - "max_waiting_tokens", - "validation_workers", - "version" - ], - "properties": { - "docker_label": { - "type": "string", - "example": "null", - "nullable": true - }, - "max_batch_total_tokens": { - "type": "integer", - "format": "int32", - "example": "32000", - "minimum": 0.0 - }, - "max_best_of": { - "type": "integer", - "example": "2", - "minimum": 0.0 - }, - "max_concurrent_requests": { - "type": "integer", - "description": "Router Parameters", - "example": "128", - "minimum": 0.0 - }, - "max_input_length": { - "type": "integer", - "example": "1024", - "minimum": 0.0 - }, - "max_stop_sequences": { - "type": "integer", - "example": "4", - "minimum": 0.0 - }, - "max_total_tokens": { - "type": "integer", - "example": "2048", - "minimum": 0.0 - }, - "max_waiting_tokens": { - "type": "integer", - "example": "20", - "minimum": 0.0 - }, - "model_device_type": { - "type": "string", - "example": "cuda" - }, - "model_dtype": { - "type": "string", - "example": "torch.float16" - }, - "model_id": { - "type": "string", - "description": "Model info", - "example": "bigscience/blomm-560m" - }, - "model_pipeline_tag": { - "type": "string", - "example": "text-generation", - "nullable": true - }, - "model_sha": { - "type": "string", - "example": "e985a63cdc139290c5f700ff1929f0b5942cced2", - "nullable": true - }, - "sha": { - "type": "string", - "example": "null", - "nullable": true - }, - "validation_workers": { - "type": "integer", - "example": "2", - "minimum": 0.0 - }, - "version": { - "type": "string", - "description": "Router Info", - "example": "0.5.0" - }, - "waiting_served_ratio": { - "type": "number", - "format": "float", - "example": "1.2" - } - } - }, - "PrefillToken": { - "type": "object", - "required": [ - "id", - "text", - "logprob" - ], - "properties": { - "id": { - "type": "integer", - "format": "int32", - "example": 0, - "minimum": 0.0 - }, - "logprob": { - "type": "number", - "format": "float", - "example": -0.34, - "nullable": true - }, - "text": { - "type": "string", - "example": "test" - } - } - }, - "StreamDetails": { - "type": "object", - "required": [ - 
"finish_reason", - "generated_tokens" - ], - "properties": { - "finish_reason": { - "$ref": "#/components/schemas/FinishReason" - }, - "generated_tokens": { - "type": "integer", - "format": "int32", - "example": 1, - "minimum": 0.0 - }, - "seed": { - "type": "integer", - "format": "int64", - "example": 42, - "nullable": true, - "minimum": 0.0 - } - } - }, - "StreamResponse": { - "type": "object", - "required": [ - "token" - ], - "properties": { - "details": { - "allOf": [ - { - "$ref": "#/components/schemas/StreamDetails" - } - ], - "nullable": true - }, - "generated_text": { - "type": "string", - "default": "null", - "example": "test", - "nullable": true - }, - "token": { - "$ref": "#/components/schemas/Token" - } - } - }, - "Token": { - "type": "object", - "required": [ - "id", - "text", - "logprob", - "special" - ], - "properties": { - "id": { - "type": "integer", - "format": "int32", - "example": 0, - "minimum": 0.0 - }, - "logprob": { - "type": "number", - "format": "float", - "example": -0.34, - "nullable": true - }, - "special": { - "type": "boolean", - "example": "false" - }, - "text": { - "type": "string", - "example": "test" - } - } - } - } - }, - "tags": [ - { - "name": "Text Generation Inference", - "description": "Hugging Face Text Generation Inference API" - } - ] -} diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/BaseHuggingfaceIT.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/BaseHuggingfaceIT.java new file mode 100644 index 00000000000..9cd99f219a4 --- /dev/null +++ b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/BaseHuggingfaceIT.java @@ -0,0 +1,103 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.huggingface; + +import io.micrometer.observation.ObservationRegistry; + +import org.springframework.ai.huggingface.api.HuggingfaceApi; +import org.springframework.ai.huggingface.api.common.HuggingfaceApiConstants; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.util.StringUtils; + +/** + * Base class for Huggingface integration tests. Provides shared configuration and common + * utilities for testing against the Huggingface API. + *

+ * Integration tests require a valid HuggingFace API key to be set in the + * HUGGINGFACE_API_KEY environment variable. + * + * @author Myeongdeok Kang + */ +@SpringBootTest(classes = BaseHuggingfaceIT.Config.class) +public abstract class BaseHuggingfaceIT { + + protected static final String DEFAULT_CHAT_MODEL = "meta-llama/Llama-3.2-3B-Instruct"; + + protected static final String DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"; + + /** + * Get the Huggingface API key from environment variable. + * @return the API key + * @throws IllegalStateException if the API key is not set + */ + protected static String getApiKey() { + String apiKey = System.getenv("HUGGINGFACE_API_KEY"); + if (!StringUtils.hasText(apiKey)) { + throw new IllegalStateException( + "HUGGINGFACE_API_KEY environment variable must be set for integration tests"); + } + return apiKey; + } + + /** + * Spring Boot configuration for Huggingface integration tests. + */ + @SpringBootConfiguration + static class Config { + + @Bean + public HuggingfaceApi huggingfaceChatApi() { + return HuggingfaceApi.builder() + .baseUrl(HuggingfaceApiConstants.DEFAULT_CHAT_BASE_URL) + .apiKey(getApiKey()) + .build(); + } + + @Bean + public HuggingfaceApi huggingfaceEmbeddingApi() { + return HuggingfaceApi.builder() + .baseUrl(HuggingfaceApiConstants.DEFAULT_EMBEDDING_BASE_URL) + .apiKey(getApiKey()) + .build(); + } + + @Bean + public HuggingfaceChatModel huggingfaceChatModel(HuggingfaceApi huggingfaceChatApi) { + return HuggingfaceChatModel.builder() + .huggingfaceApi(huggingfaceChatApi) + .defaultOptions(HuggingfaceChatOptions.builder().model(DEFAULT_CHAT_MODEL).temperature(0.7).build()) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) + .observationRegistry(ObservationRegistry.NOOP) + .build(); + } + + @Bean + public HuggingfaceEmbeddingModel huggingfaceEmbeddingModel(HuggingfaceApi huggingfaceEmbeddingApi) { + return HuggingfaceEmbeddingModel.builder() + .huggingfaceApi(huggingfaceEmbeddingApi) + .defaultOptions(HuggingfaceEmbeddingOptions.builder().model(DEFAULT_EMBEDDING_MODEL).build()) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) + .observationRegistry(ObservationRegistry.NOOP) + .build(); + } + + } + +} diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceChatModelFunctionCallingIT.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceChatModelFunctionCallingIT.java new file mode 100644 index 00000000000..1af77cfd331 --- /dev/null +++ b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceChatModelFunctionCallingIT.java @@ -0,0 +1,165 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.huggingface; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.huggingface.api.tool.MockWeatherService; +import org.springframework.ai.huggingface.api.tool.MockWeatherService.Request; +import org.springframework.ai.huggingface.api.tool.MockWeatherService.Response; +import org.springframework.ai.tool.function.FunctionToolCallback; +import org.springframework.beans.factory.annotation.Autowired; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for HuggingFace ChatModel function calling capabilities. These tests + * verify that the high-level ChatModel API correctly handles automatic tool/function + * execution. + * + *

+ * Note: Function calling requires specific models and providers. This test uses + * meta-llama/Llama-3.2-3B-Instruct with the 'together' provider which supports function + * calling through the HuggingFace Inference API. + *

+ * + *

+ * Streaming Support: HuggingfaceChatModel currently does NOT implement + * StreamingChatModel, so streaming function calling tests are not included. Streaming + * support will be added in a future PR when WebClient integration is implemented. + *

+ * + * @author Myeongdeok Kang + * @see HuggingFace + * Function Calling Guide + * @see Function + * Calling Models Collection + */ +@EnabledIfEnvironmentVariable(named = "HUGGINGFACE_API_KEY", matches = ".+") +class HuggingfaceChatModelFunctionCallingIT extends BaseHuggingfaceIT { + + private static final Logger logger = LoggerFactory.getLogger(HuggingfaceChatModelFunctionCallingIT.class); + + // Use function-calling compatible model with provider specification + // Provider suffix notation (":together") is required for function calling support + private static final String FUNCTION_CALLING_MODEL = "meta-llama/Llama-3.2-3B-Instruct:together"; + + @Autowired + ChatModel chatModel; + + /** + * Test basic function calling with automatic tool execution. Verifies that: + *
    + *
+	 * <ul>
+	 * <li>Function callbacks are properly registered via HuggingfaceChatOptions</li>
+	 * <li>The model correctly identifies when to call functions</li>
+	 * <li>Functions are automatically executed by the framework</li>
+	 * <li>Tool results are integrated into the final response</li>
+	 * </ul>
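+	 * <p>
+	 * For illustration, the callback exercised here is registered in the test body below via
+	 * {@code FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())} and
+	 * passed to the shared {@link #functionCallTest(HuggingfaceChatOptions)} helper.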
+ */ + @Test + void functionCallTest() { + functionCallTest(HuggingfaceChatOptions.builder() + .model(FUNCTION_CALLING_MODEL) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location. Return temperature in Celsius.") + .inputType(MockWeatherService.Request.class) + .build())) + .build()); + } + + /** + * Test function calling with ToolContext support. Verifies that: + *
    + *
+	 * <ul>
+	 * <li>ToolContext can be passed to function callbacks</li>
+	 * <li>BiFunction&lt;Request, ToolContext, Response&gt; signature works correctly</li>
+	 * <li>Context values are accessible during function execution</li>
+	 * <li>Context propagates correctly through the tool execution flow</li>
+	 * </ul>
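+	 * <p>
+	 * Illustrative note: the {@code sessionId} entry asserted inside the BiFunction is supplied
+	 * in the test body below through
+	 * {@code HuggingfaceChatOptions.builder().toolContext(Map.of("sessionId", "123"))}.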
+ */ + @Test + void functionCallWithToolContextTest() { + + var biFunction = new BiFunction() { + + @Override + public Response apply(Request request, ToolContext toolContext) { + + // Verify ToolContext contains expected values + assertThat(toolContext.getContext()).containsEntry("sessionId", "123"); + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new MockWeatherService.Response(temperature, 15, 20, 2, 53, 45, MockWeatherService.Unit.C); + } + + }; + + functionCallTest(HuggingfaceChatOptions.builder() + .model(FUNCTION_CALLING_MODEL) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", biFunction) + .description("Get the weather in location. Return temperature in Celsius.") + .inputType(MockWeatherService.Request.class) + .build())) + .toolContext(Map.of("sessionId", "123")) + .build()); + } + + /** + * Common test logic for function calling scenarios. + * @param promptOptions The chat options including tool callbacks and context + */ + void functionCallTest(HuggingfaceChatOptions promptOptions) { + + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + + List messages = new ArrayList<>(List.of(userMessage)); + + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + + // Verify the response contains the expected temperatures from MockWeatherService + // San Francisco: 30C, Tokyo: 10C, Paris: 15C + assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15"); + } + +} diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceChatModelIT.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceChatModelIT.java new file mode 100644 index 00000000000..80235b4193e --- /dev/null +++ b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceChatModelIT.java @@ -0,0 +1,325 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.huggingface; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.memory.MessageWindowChatMemory; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.converter.BeanOutputConverter; +import org.springframework.ai.converter.ListOutputConverter; +import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.core.convert.support.DefaultConversionService; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link HuggingfaceChatModel}. These tests require a valid + * HuggingFace API key set in the HUGGINGFACE_API_KEY environment variable. + * + * @author Myeongdeok Kang + */ +@EnabledIfEnvironmentVariable(named = "HUGGINGFACE_API_KEY", matches = ".+") +class HuggingfaceChatModelIT extends BaseHuggingfaceIT { + + @Autowired + private HuggingfaceChatModel chatModel; + + @Test + void roleTest() { + Message systemMessage = new SystemPromptTemplate(""" + You are a helpful AI assistant. Your name is {name}. + You are an AI assistant that helps people find information. + Your name is {name} + You should reply to the user's request with your name and also in the style of a {voice}. + """).createMessage(Map.of("name", "Bob", "voice", "pirate")); + + UserMessage userMessage = new UserMessage("Tell me about 3 famous pirates from the Golden Age of Piracy."); + + // portable/generic options + var portableOptions = ChatOptions.builder().temperature(0.7).build(); + + Prompt prompt = new Prompt(List.of(systemMessage, userMessage), portableOptions); + + ChatResponse response = this.chatModel.call(prompt); + assertThat(response.getResult().getOutput().getText()).isNotEmpty(); + + // huggingface specific options + var huggingfaceOptions = HuggingfaceChatOptions.builder().temperature(0.8).maxTokens(200).build(); + + response = this.chatModel.call(new Prompt(List.of(systemMessage, userMessage), huggingfaceOptions)); + assertThat(response.getResult().getOutput().getText()).isNotEmpty(); + } + + @Test + void testMessageHistory() { + Message systemMessage = new SystemPromptTemplate(""" + You are a helpful AI assistant. Your name is {name}. + You are an AI assistant that helps people find information. + Your name is {name} + You should reply to the user's request with your name and also in the style of a {voice}. 
+ """).createMessage(Map.of("name", "Bob", "voice", "pirate")); + + UserMessage userMessage = new UserMessage( + "Tell me about 3 famous pirates from the Golden Age of Piracy and why they were famous."); + + Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); + + ChatResponse response = this.chatModel.call(prompt); + assertThat(response.getResult().getOutput().getText()).isNotEmpty(); + + var promptWithMessageHistory = new Prompt(List.of(new UserMessage("Hello"), response.getResult().getOutput(), + new UserMessage("Tell me just the names of those pirates."))); + response = this.chatModel.call(promptWithMessageHistory); + + assertThat(response.getResult().getOutput().getText()).isNotEmpty(); + } + + @Test + void simplePromptTest() { + Prompt prompt = new Prompt("Tell me a short joke about programming"); + ChatResponse response = this.chatModel.call(prompt); + + assertThat(response).isNotNull(); + assertThat(response.getResult()).isNotNull(); + assertThat(response.getResult().getOutput().getText()).isNotEmpty(); + } + + @Test + void usageTest() { + Prompt prompt = new Prompt("Tell me a short joke"); + ChatResponse response = this.chatModel.call(prompt); + Usage usage = response.getMetadata().getUsage(); + + assertThat(usage).isNotNull(); + assertThat(usage.getPromptTokens()).isPositive(); + assertThat(usage.getCompletionTokens()).isPositive(); + assertThat(usage.getTotalTokens()).isPositive(); + assertThat(usage.getTotalTokens()).isEqualTo(usage.getPromptTokens() + usage.getCompletionTokens()); + } + + @Test + void listOutputConverter() { + DefaultConversionService conversionService = new DefaultConversionService(); + ListOutputConverter outputConverter = new ListOutputConverter(conversionService); + + String format = outputConverter.getFormat(); + String template = """ + List five {subject} + {format} + """; + PromptTemplate promptTemplate = PromptTemplate.builder() + .template(template) + .variables(Map.of("subject", "ice cream flavors.", "format", format)) + .build(); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = this.chatModel.call(prompt).getResult(); + + List list = outputConverter.convert(generation.getOutput().getText()); + assertThat(list).hasSizeGreaterThanOrEqualTo(3); // At least 3 items + } + + @Test + void mapOutputConverter() { + MapOutputConverter outputConverter = new MapOutputConverter(); + + String format = outputConverter.getFormat(); + String template = """ + For each letter in the RGB color scheme, tell me what it stands for. + Example: R -> Red. + {format} + """; + PromptTemplate promptTemplate = PromptTemplate.builder() + .template(template) + .variables(Map.of("format", format)) + .build(); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + + Generation generation = this.chatModel.call(prompt).getResult(); + + Map result = outputConverter.convert(generation.getOutput().getText()); + assertThat(result).isNotNull(); + assertThat(result).containsKeys("R", "G", "B"); + } + + @Test + void beanOutputConverterRecords() { + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography of 3 movies for Tom Hanks. 
+ {format} + """; + PromptTemplate promptTemplate = PromptTemplate.builder() + .template(template) + .variables(Map.of("format", format)) + .build(); + + // Set higher maxTokens and lower temperature to ensure complete JSON response + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder().maxTokens(1000).temperature(0.1).build(); + + Prompt prompt = new Prompt(promptTemplate.createMessage(), options); + Generation generation = this.chatModel.call(prompt).getResult(); + + ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getText()); + assertThat(actorsFilms.actor()).containsIgnoringCase("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSizeGreaterThanOrEqualTo(3); + } + + @Test + void chatMemory() { + ChatMemory memory = MessageWindowChatMemory.builder().build(); + String conversationId = UUID.randomUUID().toString(); + + UserMessage userMessage1 = new UserMessage("My name is James Bond"); + memory.add(conversationId, userMessage1); + ChatResponse response1 = this.chatModel.call(new Prompt(memory.get(conversationId))); + + assertThat(response1).isNotNull(); + memory.add(conversationId, response1.getResult().getOutput()); + + UserMessage userMessage2 = new UserMessage("What is my name?"); + memory.add(conversationId, userMessage2); + ChatResponse response2 = this.chatModel.call(new Prompt(memory.get(conversationId))); + + assertThat(response2).isNotNull(); + memory.add(conversationId, response2.getResult().getOutput()); + + assertThat(response2.getResults()).hasSize(1); + assertThat(response2.getResult().getOutput().getText()).containsIgnoringCase("James Bond"); + } + + @Test + void chatClientSimplePrompt() { + String joke = ChatClient.create(this.chatModel).prompt("Tell me a joke about developers").call().content(); + + assertThat(joke).isNotEmpty(); + } + + @Test + void customOptionsTest() { + HuggingfaceChatOptions customOptions = HuggingfaceChatOptions.builder() + .model(DEFAULT_CHAT_MODEL) + .temperature(0.3) + .maxTokens(50) + .build(); + + Prompt prompt = new Prompt("Say 'Hello'", customOptions); + ChatResponse response = this.chatModel.call(prompt); + + assertThat(response).isNotNull(); + assertThat(response.getResult().getOutput().getText()).isNotEmpty(); + } + + @Test + void multipleGenerationsTest() { + Prompt prompt = new Prompt("What is 2 + 2?"); + ChatResponse response = this.chatModel.call(prompt); + + assertThat(response).isNotNull(); + assertThat(response.getResults()).isNotEmpty(); + assertThat(response.getResult().getOutput().getText()).contains("4"); + } + + @Test + void testStopSequences() { + List stopSequences = Arrays.asList("STOP", "END"); + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder() + .model(DEFAULT_CHAT_MODEL) + .temperature(0.7) + .maxTokens(100) + .stopSequences(stopSequences) + .build(); + + Prompt prompt = new Prompt("Count from 1 to 10. 
When you see STOP, stop counting.", options); + ChatResponse response = this.chatModel.call(prompt); + + assertThat(response).isNotNull(); + assertThat(response.getResult().getOutput().getText()).isNotEmpty(); + // The response should be limited by stop sequences + } + + @Test + void testSeedForReproducibility() { + int seed = 42; + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder() + .model(DEFAULT_CHAT_MODEL) + .temperature(0.7) + .seed(seed) + .build(); + + Prompt prompt = new Prompt("Tell me a random number between 1 and 100", options); + + // Call twice with the same seed + ChatResponse response1 = this.chatModel.call(prompt); + ChatResponse response2 = this.chatModel.call(prompt); + + assertThat(response1).isNotNull(); + assertThat(response2).isNotNull(); + assertThat(response1.getResult().getOutput().getText()).isNotEmpty(); + assertThat(response2.getResult().getOutput().getText()).isNotEmpty(); + // With the same seed, responses should be deterministic (same or very similar) + } + + @Test + void testResponseFormatJsonObject() { + Map responseFormat = new HashMap<>(); + responseFormat.put("type", "json_object"); + + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder() + .model(DEFAULT_CHAT_MODEL) + .temperature(0.7) + .maxTokens(200) + .responseFormat(responseFormat) + .build(); + + Prompt prompt = new Prompt( + "Generate a JSON object with fields: name (string), age (number), city (string). Make up the values.", + options); + ChatResponse response = this.chatModel.call(prompt); + + assertThat(response).isNotNull(); + String output = response.getResult().getOutput().getText(); + assertThat(output).isNotEmpty(); + // The output should be valid JSON when response_format is json_object + assertThat(output).contains("{"); + assertThat(output).contains("}"); + } + + record ActorsFilmsRecord(String actor, List movies) { + } + +} diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceChatModelObservationIT.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceChatModelObservationIT.java new file mode 100644 index 00000000000..47f6a2b9323 --- /dev/null +++ b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceChatModelObservationIT.java @@ -0,0 +1,138 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.huggingface; + +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.huggingface.api.HuggingfaceApi; +import org.springframework.ai.observation.conventions.AiOperationType; +import org.springframework.ai.observation.conventions.AiProvider; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames; +import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.LowCardinalityKeyNames; + +/** + * Integration tests for observation instrumentation in {@link HuggingfaceChatModel}. + * + * @author Myeongdeok Kang + */ +@SpringBootTest(classes = HuggingfaceChatModelObservationIT.Config.class) +@EnabledIfEnvironmentVariable(named = "HUGGINGFACE_API_KEY", matches = ".+") +class HuggingfaceChatModelObservationIT { + + private static final String MODEL = "meta-llama/Llama-3.2-3B-Instruct"; + + @Autowired + TestObservationRegistry observationRegistry; + + @Autowired + HuggingfaceChatModel chatModel; + + @BeforeEach + void beforeEach() { + this.observationRegistry.clear(); + } + + @Test + void observationForChatOperation() { + var options = HuggingfaceChatOptions.builder() + .model(MODEL) + .frequencyPenalty(0.0) + .maxTokens(2048) + .presencePenalty(0.0) + .temperature(0.7) + .topP(1.0) + .build(); + + Prompt prompt = new Prompt("Why does a raven look like a desk?", options); + + ChatResponse chatResponse = this.chatModel.call(prompt); + assertThat(chatResponse.getResult().getOutput().getText()).isNotEmpty(); + + ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); + assertThat(responseMetadata).isNotNull(); + + validate(responseMetadata); + } + + private void validate(ChatResponseMetadata responseMetadata) { + TestObservationRegistryAssert.assertThat(this.observationRegistry) + .doesNotHaveAnyRemainingCurrentObservation() + .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME) + .that() + .hasContextualNameEqualTo("chat " + MODEL) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), + AiOperationType.CHAT.value()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.HUGGINGFACE.value()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(), MODEL) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), responseMetadata.getModel()) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString(), "0.0") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_MAX_TOKENS.asString(), "2048") + 
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_PRESENCE_PENALTY.asString(), "0.0") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TEMPERATURE.asString(), "0.7") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TOP_P.asString(), "1.0") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getPromptTokens())) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getCompletionTokens())) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getTotalTokens())) + .hasBeenStarted() + .hasBeenStopped(); + } + + @SpringBootConfiguration + static class Config { + + @Bean + public TestObservationRegistry observationRegistry() { + return TestObservationRegistry.create(); + } + + @Bean + public HuggingfaceApi huggingfaceApi() { + String apiKey = System.getenv("HUGGINGFACE_API_KEY"); + return HuggingfaceApi.builder().apiKey(apiKey).build(); + } + + @Bean + public HuggingfaceChatModel huggingfaceChatModel(HuggingfaceApi huggingfaceApi, + TestObservationRegistry observationRegistry) { + return HuggingfaceChatModel.builder() + .huggingfaceApi(huggingfaceApi) + .observationRegistry(observationRegistry) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) + .build(); + } + + } + +} diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceChatModelTests.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceChatModelTests.java new file mode 100644 index 00000000000..b3358469322 --- /dev/null +++ b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceChatModelTests.java @@ -0,0 +1,222 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.huggingface; + +import io.micrometer.observation.ObservationRegistry; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.huggingface.api.HuggingfaceApi; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; +import org.springframework.ai.retry.RetryUtils; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link HuggingfaceChatModel}. 
+ * + * @author Myeongdeok Kang + */ +@ExtendWith(MockitoExtension.class) +class HuggingfaceChatModelTests { + + @Mock + HuggingfaceApi huggingfaceApi; + + @Mock + ToolCallingManager toolCallingManager; + + @Mock + ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate; + + @Test + void buildHuggingfaceChatModelWithConstructor() { + ChatModel chatModel = new HuggingfaceChatModel(this.huggingfaceApi, + HuggingfaceChatOptions.builder().model("meta-llama/Llama-3.2-3B-Instruct").build(), + this.toolCallingManager, ObservationRegistry.NOOP, RetryUtils.DEFAULT_RETRY_TEMPLATE, + this.toolExecutionEligibilityPredicate); + assertThat(chatModel).isNotNull(); + } + + @Test + void buildHuggingfaceChatModelWithBuilder() { + ChatModel chatModel = HuggingfaceChatModel.builder() + .huggingfaceApi(this.huggingfaceApi) + .toolCallingManager(this.toolCallingManager) + .toolExecutionEligibilityPredicate(this.toolExecutionEligibilityPredicate) + .build(); + assertThat(chatModel).isNotNull(); + } + + @Test + void buildHuggingfaceChatModelWithNullApi() { + assertThatThrownBy(() -> HuggingfaceChatModel.builder() + .huggingfaceApi(null) + .toolCallingManager(this.toolCallingManager) + .toolExecutionEligibilityPredicate(this.toolExecutionEligibilityPredicate) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("huggingfaceApi must not be null"); + } + + @Test + void buildHuggingfaceChatModelWithAllBuilderOptions() { + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder() + .model("meta-llama/Llama-3.2-3B-Instruct") + .temperature(0.7) + .maxTokens(100) + .topP(0.9) + .build(); + + ChatModel chatModel = HuggingfaceChatModel.builder() + .huggingfaceApi(this.huggingfaceApi) + .defaultOptions(options) + .toolCallingManager(this.toolCallingManager) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) + .observationRegistry(ObservationRegistry.NOOP) + .toolExecutionEligibilityPredicate(this.toolExecutionEligibilityPredicate) + .build(); + + assertThat(chatModel).isNotNull(); + assertThat(chatModel).isInstanceOf(HuggingfaceChatModel.class); + } + + @Test + void buildHuggingfaceChatModelWithCustomObservationRegistry() { + ObservationRegistry customRegistry = ObservationRegistry.create(); + + ChatModel chatModel = HuggingfaceChatModel.builder() + .huggingfaceApi(this.huggingfaceApi) + .toolCallingManager(this.toolCallingManager) + .toolExecutionEligibilityPredicate(this.toolExecutionEligibilityPredicate) + .observationRegistry(customRegistry) + .build(); + + assertThat(chatModel).isNotNull(); + } + + @Test + void buildHuggingfaceChatModelImmutability() { + // Test that the builder creates immutable instances + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder() + .model("meta-llama/Llama-3.2-3B-Instruct") + .temperature(0.5) + .build(); + + ChatModel chatModel1 = HuggingfaceChatModel.builder() + .huggingfaceApi(this.huggingfaceApi) + .defaultOptions(options) + .toolCallingManager(this.toolCallingManager) + .toolExecutionEligibilityPredicate(this.toolExecutionEligibilityPredicate) + .build(); + + ChatModel chatModel2 = HuggingfaceChatModel.builder() + .huggingfaceApi(this.huggingfaceApi) + .defaultOptions(options) + .toolCallingManager(this.toolCallingManager) + .toolExecutionEligibilityPredicate(this.toolExecutionEligibilityPredicate) + .build(); + + // Should create different instances + assertThat(chatModel1).isNotSameAs(chatModel2); + assertThat(chatModel1).isNotNull(); + assertThat(chatModel2).isNotNull(); + } + + @Test + void 
buildHuggingfaceChatModelWithMinimalConfiguration() { + // Test building with only required parameters + ChatModel chatModel = HuggingfaceChatModel.builder() + .huggingfaceApi(this.huggingfaceApi) + .toolCallingManager(this.toolCallingManager) + .toolExecutionEligibilityPredicate(this.toolExecutionEligibilityPredicate) + .build(); + + assertThat(chatModel).isNotNull(); + assertThat(chatModel).isInstanceOf(HuggingfaceChatModel.class); + } + + @Test + void getDefaultOptionsReturnsCopy() { + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder() + .model("meta-llama/Llama-3.2-3B-Instruct") + .temperature(0.7) + .build(); + + HuggingfaceChatModel chatModel = HuggingfaceChatModel.builder() + .huggingfaceApi(this.huggingfaceApi) + .defaultOptions(options) + .toolCallingManager(this.toolCallingManager) + .toolExecutionEligibilityPredicate(this.toolExecutionEligibilityPredicate) + .build(); + + HuggingfaceChatOptions retrievedOptions = (HuggingfaceChatOptions) chatModel.getDefaultOptions(); + assertThat(retrievedOptions).isNotNull(); + assertThat(retrievedOptions).isNotSameAs(options); + assertThat(retrievedOptions.getModel()).isEqualTo(options.getModel()); + assertThat(retrievedOptions.getTemperature()).isEqualTo(options.getTemperature()); + } + + @Test + void setObservationConventionValidation() { + HuggingfaceChatModel chatModel = HuggingfaceChatModel.builder() + .huggingfaceApi(this.huggingfaceApi) + .toolCallingManager(this.toolCallingManager) + .toolExecutionEligibilityPredicate(this.toolExecutionEligibilityPredicate) + .build(); + + assertThatThrownBy(() -> chatModel.setObservationConvention(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("observationConvention cannot be null"); + } + + @Test + void buildHuggingfaceChatModelWithNullDefaultOptions() { + assertThatThrownBy(() -> new HuggingfaceChatModel(this.huggingfaceApi, null, this.toolCallingManager, + ObservationRegistry.NOOP, RetryUtils.DEFAULT_RETRY_TEMPLATE, this.toolExecutionEligibilityPredicate)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("defaultOptions must not be null"); + } + + @Test + void buildHuggingfaceChatModelWithNullObservationRegistry() { + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder() + .model("meta-llama/Llama-3.2-3B-Instruct") + .build(); + + assertThatThrownBy(() -> new HuggingfaceChatModel(this.huggingfaceApi, options, this.toolCallingManager, null, + RetryUtils.DEFAULT_RETRY_TEMPLATE, this.toolExecutionEligibilityPredicate)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("observationRegistry must not be null"); + } + + @Test + void buildHuggingfaceChatModelWithNullRetryTemplate() { + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder() + .model("meta-llama/Llama-3.2-3B-Instruct") + .build(); + + assertThatThrownBy(() -> new HuggingfaceChatModel(this.huggingfaceApi, options, this.toolCallingManager, + ObservationRegistry.NOOP, null, this.toolExecutionEligibilityPredicate)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("retryTemplate must not be null"); + } + +} diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceChatOptionsTests.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceChatOptionsTests.java new file mode 100644 index 00000000000..495b5810f39 --- /dev/null +++ 
b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceChatOptionsTests.java @@ -0,0 +1,372 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.huggingface; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link HuggingfaceChatOptions}. + * + * @author Myeongdeok Kang + */ +class HuggingfaceChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder() + .model("meta-llama/Llama-3.2-3B-Instruct") + .temperature(0.7) + .maxTokens(100) + .topP(0.9) + .frequencyPenalty(0.5) + .presencePenalty(0.8) + .build(); + + assertThat(options) + .extracting("model", "temperature", "maxTokens", "topP", "frequencyPenalty", "presencePenalty") + .containsExactly("meta-llama/Llama-3.2-3B-Instruct", 0.7, 100, 0.9, 0.5, 0.8); + } + + @Test + void testCopy() { + HuggingfaceChatOptions originalOptions = HuggingfaceChatOptions.builder() + .model("meta-llama/Llama-3.2-3B-Instruct") + .temperature(0.7) + .maxTokens(100) + .build(); + + HuggingfaceChatOptions copiedOptions = originalOptions.copy(); + + assertThat(copiedOptions).isNotSameAs(originalOptions).isEqualTo(originalOptions); + } + + @Test + void testSetters() { + HuggingfaceChatOptions options = new HuggingfaceChatOptions(); + + options.setModel("test-model"); + options.setTemperature(0.5); + options.setMaxTokens(50); + options.setTopP(0.8); + options.setFrequencyPenalty(0.3); + options.setPresencePenalty(0.6); + + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getTemperature()).isEqualTo(0.5); + assertThat(options.getMaxTokens()).isEqualTo(50); + assertThat(options.getTopP()).isEqualTo(0.8); + assertThat(options.getFrequencyPenalty()).isEqualTo(0.3); + assertThat(options.getPresencePenalty()).isEqualTo(0.6); + } + + @Test + void testDefaultValues() { + HuggingfaceChatOptions options = new HuggingfaceChatOptions(); + + assertThat(options.getModel()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getMaxTokens()).isNull(); + assertThat(options.getTopP()).isNull(); + assertThat(options.getFrequencyPenalty()).isNull(); + assertThat(options.getPresencePenalty()).isNull(); + } + + @Test + void testFromOptions() { + HuggingfaceChatOptions originalOptions = HuggingfaceChatOptions.builder() + .model("meta-llama/Llama-3.2-3B-Instruct") + .temperature(0.7) + .build(); + + HuggingfaceChatOptions copiedOptions = HuggingfaceChatOptions.fromOptions(originalOptions); + + assertThat(copiedOptions).isNotSameAs(originalOptions).isEqualTo(originalOptions); + } + + @Test + void testToMap() { + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder() + .model("test-model") + .temperature(0.7) + .maxTokens(100) + .topP(0.9) + 
.build(); + + Map<String, Object> map = options.toMap(); + + assertThat(map).containsEntry("model", "test-model") + .containsEntry("temperature", 0.7) + .containsEntry("max_tokens", 100) + .containsEntry("top_p", 0.9); + } + + @Test + void testEqualsAndHashCode() { + HuggingfaceChatOptions options1 = HuggingfaceChatOptions.builder() + .model("test-model") + .temperature(0.7) + .maxTokens(100) + .build(); + + HuggingfaceChatOptions options2 = HuggingfaceChatOptions.builder() + .model("test-model") + .temperature(0.7) + .maxTokens(100) + .build(); + + assertThat(options1).isEqualTo(options2); + assertThat(options1.hashCode()).isEqualTo(options2.hashCode()); + } + + @Test + void testBuilderWithNullValues() { + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder() + .model(null) + .temperature(null) + .maxTokens(null) + .build(); + + assertThat(options.getModel()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getMaxTokens()).isNull(); + } + + @Test + void testBuilderChaining() { + HuggingfaceChatOptions.Builder builder = HuggingfaceChatOptions.builder(); + + HuggingfaceChatOptions.Builder result = builder.model("test-model").temperature(0.7).maxTokens(100); + + assertThat(result).isSameAs(builder); + } + + @Test + void testCopyChangeIndependence() { + HuggingfaceChatOptions originalOptions = HuggingfaceChatOptions.builder() + .model("original-model") + .temperature(0.5) + .build(); + + HuggingfaceChatOptions copiedOptions = originalOptions.copy(); + + // Modify original + originalOptions.setTemperature(0.9); + + // Copy should retain original values + assertThat(copiedOptions.getTemperature()).isEqualTo(0.5); + assertThat(originalOptions.getTemperature()).isEqualTo(0.9); + } + + @Test + void testStopSequences() { + List<String> stopSequences = Arrays.asList("STOP", "END"); + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder() + .model("test-model") + .stopSequences(stopSequences) + .build(); + + assertThat(options.getStopSequences()).isEqualTo(stopSequences); + } + + @Test + void testSeed() { + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder().model("test-model").seed(12345).build(); + + assertThat(options.getSeed()).isEqualTo(12345); + } + + @Test + void testResponseFormat() { + Map<String, Object> responseFormat = new HashMap<>(); + responseFormat.put("type", "json_object"); + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder() + .model("test-model") + .responseFormat(responseFormat) + .build(); + + assertThat(options.getResponseFormat()).isEqualTo(responseFormat); + } + + @Test + void testToolPrompt() { + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder() + .model("test-model") + .toolPrompt("You have access to the following tools:") + .build(); + + assertThat(options.getToolPrompt()).isEqualTo("You have access to the following tools:"); + } + + @Test + void testLogprobs() { + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder().model("test-model").logprobs(true).build(); + + assertThat(options.getLogprobs()).isTrue(); + } + + @Test + void testTopLogprobs() { + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder() + .model("test-model") + .logprobs(true) + .topLogprobs(5) + .build(); + + assertThat(options.getTopLogprobs()).isEqualTo(5); + } + + @Test + void testBuilderWithAllNewParameters() { + List<String> stopSequences = Arrays.asList("STOP", "END"); + Map<String, Object> responseFormat = new HashMap<>(); + responseFormat.put("type", "json_object"); + + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder() +
.model("test-model") + .temperature(0.7) + .maxTokens(100) + .topP(0.9) + .frequencyPenalty(0.5) + .presencePenalty(0.8) + .stopSequences(stopSequences) + .seed(12345) + .responseFormat(responseFormat) + .toolPrompt("You have access to tools:") + .logprobs(true) + .topLogprobs(3) + .build(); + + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getStopSequences()).isEqualTo(stopSequences); + assertThat(options.getSeed()).isEqualTo(12345); + assertThat(options.getResponseFormat()).isEqualTo(responseFormat); + assertThat(options.getToolPrompt()).isEqualTo("You have access to tools:"); + assertThat(options.getLogprobs()).isTrue(); + assertThat(options.getTopLogprobs()).isEqualTo(3); + } + + @Test + void testFromOptionsWithNewParameters() { + List<String> stopSequences = Arrays.asList("STOP"); + Map<String, Object> responseFormat = new HashMap<>(); + responseFormat.put("type", "json_object"); + + HuggingfaceChatOptions originalOptions = HuggingfaceChatOptions.builder() + .model("test-model") + .stopSequences(stopSequences) + .seed(999) + .responseFormat(responseFormat) + .toolPrompt("Tools:") + .logprobs(true) + .topLogprobs(2) + .build(); + + HuggingfaceChatOptions copiedOptions = HuggingfaceChatOptions.fromOptions(originalOptions); + + assertThat(copiedOptions.getStopSequences()).isEqualTo(stopSequences); + assertThat(copiedOptions.getSeed()).isEqualTo(999); + assertThat(copiedOptions.getResponseFormat()).isEqualTo(responseFormat); + assertThat(copiedOptions.getToolPrompt()).isEqualTo("Tools:"); + assertThat(copiedOptions.getLogprobs()).isTrue(); + assertThat(copiedOptions.getTopLogprobs()).isEqualTo(2); + } + + @Test + void testToMapWithNewParameters() { + List<String> stopSequences = Arrays.asList("STOP", "END"); + Map<String, Object> responseFormat = new HashMap<>(); + responseFormat.put("type", "json_object"); + + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder() + .model("test-model") + .temperature(0.7) + .stopSequences(stopSequences) + .seed(12345) + .responseFormat(responseFormat) + .toolPrompt("Tools:") + .logprobs(true) + .topLogprobs(3) + .build(); + + Map<String, Object> map = options.toMap(); + + assertThat(map).containsEntry("model", "test-model") + .containsEntry("temperature", 0.7) + .containsEntry("stop", stopSequences) + .containsEntry("seed", 12345) + .containsEntry("response_format", responseFormat) + .containsEntry("tool_prompt", "Tools:") + .containsEntry("logprobs", true) + .containsEntry("top_logprobs", 3); + } + + @Test + void testEqualsAndHashCodeWithNewParameters() { + List<String> stopSequences = Arrays.asList("STOP"); + Map<String, Object> responseFormat = new HashMap<>(); + responseFormat.put("type", "json_object"); + + HuggingfaceChatOptions options1 = HuggingfaceChatOptions.builder() + .model("test-model") + .stopSequences(stopSequences) + .seed(999) + .responseFormat(responseFormat) + .logprobs(true) + .build(); + + HuggingfaceChatOptions options2 = HuggingfaceChatOptions.builder() + .model("test-model") + .stopSequences(stopSequences) + .seed(999) + .responseFormat(responseFormat) + .logprobs(true) + .build(); + + assertThat(options1).isEqualTo(options2); + assertThat(options1.hashCode()).isEqualTo(options2.hashCode()); + } + + @Test + void testSettersForNewParameters() { + HuggingfaceChatOptions options = new HuggingfaceChatOptions(); + List<String> stopSequences = Arrays.asList("STOP"); + Map<String, Object> responseFormat = new HashMap<>(); + responseFormat.put("type", "json_object"); + + options.setStopSequences(stopSequences); + options.setSeed(777); + options.setResponseFormat(responseFormat); +
options.setToolPrompt("Tools available:"); + options.setLogprobs(true); + options.setTopLogprobs(4); + + assertThat(options.getStopSequences()).isEqualTo(stopSequences); + assertThat(options.getSeed()).isEqualTo(777); + assertThat(options.getResponseFormat()).isEqualTo(responseFormat); + assertThat(options.getToolPrompt()).isEqualTo("Tools available:"); + assertThat(options.getLogprobs()).isTrue(); + assertThat(options.getTopLogprobs()).isEqualTo(4); + } + +} diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceChatRequestTests.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceChatRequestTests.java new file mode 100644 index 00000000000..518fb465e2a --- /dev/null +++ b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceChatRequestTests.java @@ -0,0 +1,334 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.huggingface; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.huggingface.api.HuggingfaceApi; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.retry.RetryUtils; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for request building in {@link HuggingfaceChatModel}. 
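+ * Covers merging of default options, Huggingface-specific runtime options, and portable {@link ChatOptions} into the request.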
+ * + * @author Myeongdeok Kang + */ +class HuggingfaceChatRequestTests { + + private final HuggingfaceChatModel chatModel = HuggingfaceChatModel.builder() + .huggingfaceApi(HuggingfaceApi.builder().apiKey("test-key").build()) + .defaultOptions(HuggingfaceChatOptions.builder() + .model("meta-llama/Llama-3.2-3B-Instruct") + .temperature(0.7) + .maxTokens(100) + .topP(0.9) + .build()) + .toolCallingManager(ToolCallingManager.builder().build()) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) + .build(); + + @Test + void createRequestWithDefaultOptions() { + var prompt = this.chatModel.buildChatRequest(new Prompt("Test message content")); + + assertThat(prompt.getInstructions()).hasSize(1); + assertThat(prompt.getOptions()).isNotNull(); + + HuggingfaceChatOptions options = (HuggingfaceChatOptions) prompt.getOptions(); + assertThat(options.getModel()).isEqualTo("meta-llama/Llama-3.2-3B-Instruct"); + assertThat(options.getTemperature()).isEqualTo(0.7); + assertThat(options.getMaxTokens()).isEqualTo(100); + assertThat(options.getTopP()).isEqualTo(0.9); + } + + @Test + void createRequestWithPromptHuggingfaceOptions() { + // Runtime options should override the default options. + HuggingfaceChatOptions promptOptions = HuggingfaceChatOptions.builder() + .temperature(0.8) + .topP(0.5) + .maxTokens(200) + .build(); + var prompt = this.chatModel.buildChatRequest(new Prompt("Test message content", promptOptions)); + + assertThat(prompt.getInstructions()).hasSize(1); + + HuggingfaceChatOptions options = (HuggingfaceChatOptions) prompt.getOptions(); + assertThat(options.getModel()).isEqualTo("meta-llama/Llama-3.2-3B-Instruct"); + assertThat(options.getTemperature()).isEqualTo(0.8); + assertThat(options.getMaxTokens()).isEqualTo(200); // overridden + assertThat(options.getTopP()).isEqualTo(0.5); // overridden + } + + @Test + void createRequestWithPromptPortableChatOptions() { + // Portable runtime options. 
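+ // Fields set on the portable options should override the defaults, while unset fields (e.g. maxTokens) keep their default values.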
+ ChatOptions portablePromptOptions = ChatOptions.builder().temperature(0.9).topP(0.6).build(); + var prompt = this.chatModel.buildChatRequest(new Prompt("Test message content", portablePromptOptions)); + + assertThat(prompt.getInstructions()).hasSize(1); + + HuggingfaceChatOptions options = (HuggingfaceChatOptions) prompt.getOptions(); + assertThat(options.getModel()).isEqualTo("meta-llama/Llama-3.2-3B-Instruct"); + assertThat(options.getTemperature()).isEqualTo(0.9); + assertThat(options.getTopP()).isEqualTo(0.6); + assertThat(options.getMaxTokens()).isEqualTo(100); // default value maintained + } + + @Test + void createRequestWithPromptOptionsModelOverride() { + // Runtime options override model + HuggingfaceChatOptions promptOptions = HuggingfaceChatOptions.builder() + .model("mistralai/Mistral-7B-Instruct-v0.3") + .build(); + var prompt = this.chatModel.buildChatRequest(new Prompt("Test message content", promptOptions)); + + HuggingfaceChatOptions options = (HuggingfaceChatOptions) prompt.getOptions(); + assertThat(options.getModel()).isEqualTo("mistralai/Mistral-7B-Instruct-v0.3"); + } + + @Test + void createRequestWithDefaultOptionsModelOverride() { + HuggingfaceChatModel chatModel = HuggingfaceChatModel.builder() + .huggingfaceApi(HuggingfaceApi.builder().apiKey("test-key").build()) + .defaultOptions(HuggingfaceChatOptions.builder().model("google/gemma-2-2b-it").temperature(0.5).build()) + .toolCallingManager(ToolCallingManager.builder().build()) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) + .build(); + + var prompt1 = chatModel.buildChatRequest(new Prompt("Test message content")); + + HuggingfaceChatOptions options1 = (HuggingfaceChatOptions) prompt1.getOptions(); + assertThat(options1.getModel()).isEqualTo("google/gemma-2-2b-it"); + + // Prompt options should override the default options. 
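+ // Only the model is supplied at runtime here, so it should replace the default "google/gemma-2-2b-it".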
+ HuggingfaceChatOptions promptOptions = HuggingfaceChatOptions.builder() + .model("meta-llama/Llama-3.2-3B-Instruct") + .build(); + var prompt2 = chatModel.buildChatRequest(new Prompt("Test message content", promptOptions)); + + HuggingfaceChatOptions options2 = (HuggingfaceChatOptions) prompt2.getOptions(); + assertThat(options2.getModel()).isEqualTo("meta-llama/Llama-3.2-3B-Instruct"); + } + + @Test + void createRequestWithAllMessageTypes() { + var prompt = this.chatModel.buildChatRequest(new Prompt(createMessagesWithAllMessageTypes())); + + assertThat(prompt.getInstructions()).hasSize(3); + + var systemMessage = prompt.getInstructions().get(0); + assertThat(systemMessage).isInstanceOf(SystemMessage.class); + assertThat(systemMessage.getText()).isEqualTo("Test system message"); + + var userMessage = prompt.getInstructions().get(1); + assertThat(userMessage).isInstanceOf(UserMessage.class); + assertThat(userMessage.getText()).isEqualTo("Test user message"); + + var assistantMessage = prompt.getInstructions().get(2); + assertThat(assistantMessage).isInstanceOf(AssistantMessage.class); + assertThat(assistantMessage.getText()).isEqualTo("Test assistant message"); + } + + @Test + void createRequestWithMultipleUserMessages() { + List<Message> messages = List.of(new UserMessage("First question"), new UserMessage("Second question"), + new UserMessage("Third question")); + + var prompt = this.chatModel.buildChatRequest(new Prompt(messages)); + + assertThat(prompt.getInstructions()).hasSize(3); + assertThat(prompt.getInstructions().get(0).getText()).isEqualTo("First question"); + assertThat(prompt.getInstructions().get(1).getText()).isEqualTo("Second question"); + assertThat(prompt.getInstructions().get(2).getText()).isEqualTo("Third question"); + } + + @Test + void createRequestPreservesMessageOrder() { + List<Message> messages = List.of(new SystemMessage("System"), new UserMessage("User 1"), + new AssistantMessage("Assistant 1"), new UserMessage("User 2")); + + var prompt = this.chatModel.buildChatRequest(new Prompt(messages)); + + assertThat(prompt.getInstructions()).hasSize(4); + assertThat(prompt.getInstructions().get(0)).isInstanceOf(SystemMessage.class); + assertThat(prompt.getInstructions().get(1)).isInstanceOf(UserMessage.class); + assertThat(prompt.getInstructions().get(2)).isInstanceOf(AssistantMessage.class); + assertThat(prompt.getInstructions().get(3)).isInstanceOf(UserMessage.class); + } + + @Test + void createRequestWithMinimalOptions() { + HuggingfaceChatModel minimalChatModel = HuggingfaceChatModel.builder() + .huggingfaceApi(HuggingfaceApi.builder().apiKey("test-key").build()) + .defaultOptions(HuggingfaceChatOptions.builder().model("meta-llama/Llama-3.2-3B-Instruct").build()) + .toolCallingManager(ToolCallingManager.builder().build()) + .build(); + + var prompt = minimalChatModel.buildChatRequest(new Prompt("Test message")); + + HuggingfaceChatOptions options = (HuggingfaceChatOptions) prompt.getOptions(); + assertThat(options.getModel()).isEqualTo("meta-llama/Llama-3.2-3B-Instruct"); + // Other options should be null or default values + } + + @Test + void createRequestWithMaximumOptions() { + HuggingfaceChatOptions maxOptions = HuggingfaceChatOptions.builder() + .model("meta-llama/Llama-3.2-3B-Instruct") + .temperature(0.8) + .maxTokens(500) + .topP(0.95) + .frequencyPenalty(0.5) + .presencePenalty(0.3) + .build(); + + var prompt = this.chatModel.buildChatRequest(new Prompt("Test message", maxOptions)); + + HuggingfaceChatOptions options = (HuggingfaceChatOptions) prompt.getOptions(); +
assertThat(options.getModel()).isEqualTo("meta-llama/Llama-3.2-3B-Instruct"); + assertThat(options.getTemperature()).isEqualTo(0.8); + assertThat(options.getMaxTokens()).isEqualTo(500); + assertThat(options.getTopP()).isEqualTo(0.95); + assertThat(options.getFrequencyPenalty()).isEqualTo(0.5); + assertThat(options.getPresencePenalty()).isEqualTo(0.3); + } + + @Test + void createRequestWithStopSequences() { + List<String> stopSequences = Arrays.asList("STOP", "END", "DONE"); + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder() + .model("meta-llama/Llama-3.2-3B-Instruct") + .stopSequences(stopSequences) + .build(); + + var prompt = this.chatModel.buildChatRequest(new Prompt("Test message", options)); + + HuggingfaceChatOptions resultOptions = (HuggingfaceChatOptions) prompt.getOptions(); + assertThat(resultOptions.getStopSequences()).isEqualTo(stopSequences); + } + + @Test + void createRequestWithSeed() { + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder() + .model("meta-llama/Llama-3.2-3B-Instruct") + .seed(42) + .build(); + + var prompt = this.chatModel.buildChatRequest(new Prompt("Test message", options)); + + HuggingfaceChatOptions resultOptions = (HuggingfaceChatOptions) prompt.getOptions(); + assertThat(resultOptions.getSeed()).isEqualTo(42); + } + + @Test + void createRequestWithResponseFormat() { + Map<String, Object> responseFormat = new HashMap<>(); + responseFormat.put("type", "json_object"); + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder() + .model("meta-llama/Llama-3.2-3B-Instruct") + .responseFormat(responseFormat) + .build(); + + var prompt = this.chatModel.buildChatRequest(new Prompt("Test message", options)); + + HuggingfaceChatOptions resultOptions = (HuggingfaceChatOptions) prompt.getOptions(); + assertThat(resultOptions.getResponseFormat()).isEqualTo(responseFormat); + } + + @Test + void createRequestWithToolPrompt() { + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder() + .model("meta-llama/Llama-3.2-3B-Instruct") + .toolPrompt("You have access to these tools:") + .build(); + + var prompt = this.chatModel.buildChatRequest(new Prompt("Test message", options)); + + HuggingfaceChatOptions resultOptions = (HuggingfaceChatOptions) prompt.getOptions(); + assertThat(resultOptions.getToolPrompt()).isEqualTo("You have access to these tools:"); + } + + @Test + void createRequestWithLogprobs() { + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder() + .model("meta-llama/Llama-3.2-3B-Instruct") + .logprobs(true) + .topLogprobs(3) + .build(); + + var prompt = this.chatModel.buildChatRequest(new Prompt("Test message", options)); + + HuggingfaceChatOptions resultOptions = (HuggingfaceChatOptions) prompt.getOptions(); + assertThat(resultOptions.getLogprobs()).isTrue(); + assertThat(resultOptions.getTopLogprobs()).isEqualTo(3); + } + + @Test + void createRequestWithAllNewParameters() { + List<String> stopSequences = Arrays.asList("STOP"); + Map<String, Object> responseFormat = new HashMap<>(); + responseFormat.put("type", "json_object"); + + HuggingfaceChatOptions options = HuggingfaceChatOptions.builder() + .model("meta-llama/Llama-3.2-3B-Instruct") + .temperature(0.7) + .maxTokens(200) + .stopSequences(stopSequences) + .seed(12345) + .responseFormat(responseFormat) + .toolPrompt("Tools available:") + .logprobs(true) + .topLogprobs(5) + .build(); + + var prompt = this.chatModel.buildChatRequest(new Prompt("Test message", options)); + + HuggingfaceChatOptions resultOptions = (HuggingfaceChatOptions) prompt.getOptions(); +
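// Every request-level parameter set above should survive the merge with the default options unchanged. +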
assertThat(resultOptions.getModel()).isEqualTo("meta-llama/Llama-3.2-3B-Instruct"); + assertThat(resultOptions.getTemperature()).isEqualTo(0.7); + assertThat(resultOptions.getMaxTokens()).isEqualTo(200); + assertThat(resultOptions.getStopSequences()).isEqualTo(stopSequences); + assertThat(resultOptions.getSeed()).isEqualTo(12345); + assertThat(resultOptions.getResponseFormat()).isEqualTo(responseFormat); + assertThat(resultOptions.getToolPrompt()).isEqualTo("Tools available:"); + assertThat(resultOptions.getLogprobs()).isTrue(); + assertThat(resultOptions.getTopLogprobs()).isEqualTo(5); + } + + private static List createMessagesWithAllMessageTypes() { + var systemMessage = new SystemMessage("Test system message"); + var userMessage = new UserMessage("Test user message"); + var assistantMessage = new AssistantMessage("Test assistant message"); + + return List.of(systemMessage, userMessage, assistantMessage); + } + +} diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceEmbeddingModelIT.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceEmbeddingModelIT.java new file mode 100644 index 00000000000..43eddd90111 --- /dev/null +++ b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceEmbeddingModelIT.java @@ -0,0 +1,248 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.huggingface; + +import java.util.List; + +import org.assertj.core.data.Offset; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.huggingface.api.HuggingfaceApi; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.beans.factory.annotation.Autowired; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link HuggingfaceEmbeddingModel}. These tests require a valid + * HuggingFace API key set in the HUGGINGFACE_API_KEY environment variable. 
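+ * The tests are skipped automatically via {@code @EnabledIfEnvironmentVariable} when the variable is not set.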
+ * + * @author Myeongdeok Kang + */ +@EnabledIfEnvironmentVariable(named = "HUGGINGFACE_API_KEY", matches = ".+") +class HuggingfaceEmbeddingModelIT extends BaseHuggingfaceIT { + + @Autowired + private HuggingfaceEmbeddingModel embeddingModel; + + @Test + void defaultEmbedding() { + assertThat(this.embeddingModel).isNotNull(); + EmbeddingResponse embeddingResponse = this.embeddingModel + .call(new EmbeddingRequest(List.of("Hello World"), HuggingfaceEmbeddingOptions.builder().build())); + + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getMetadata()).isNotNull(); + } + + @Test + void embeddingBatchDocuments() { + assertThat(this.embeddingModel).isNotNull(); + List<String> texts = List.of("Hello World", "Spring AI is awesome", "Huggingface provides great models"); + + EmbeddingResponse embeddingResponse = this.embeddingModel + .call(new EmbeddingRequest(texts, HuggingfaceEmbeddingOptions.builder().build())); + + assertThat(embeddingResponse.getResults()).hasSize(3); + assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1); + assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(2).getIndex()).isEqualTo(2); + assertThat(embeddingResponse.getResults().get(2).getOutput()).isNotEmpty(); + + assertThat(embeddingResponse.getMetadata()).isNotNull(); + assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo(DEFAULT_EMBEDDING_MODEL); + } + + @Test + void embeddingWithDocuments() { + List<Document> documents = List.of(new Document("Spring Framework is great"), + new Document("AI is transforming technology"), new Document("Integration tests are important")); + + List<String> texts = documents.stream().map(Document::getText).toList(); + + EmbeddingResponse embeddingResponse = this.embeddingModel + .call(new EmbeddingRequest(texts, HuggingfaceEmbeddingOptions.builder().build())); + + assertThat(embeddingResponse.getResults()).hasSize(3); + for (int i = 0; i < 3; i++) { + assertThat(embeddingResponse.getResults().get(i).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(i).getIndex()).isEqualTo(i); + } + } + + @Test + void embeddingDimensions() { + assertThat(this.embeddingModel).isNotNull(); + + // For sentence-transformers/all-MiniLM-L6-v2, the dimension should be 384 + // Note: The dimensions() method returns the model's native dimensions, + // not a configurable parameter + Integer dimensions = this.embeddingModel.dimensions(); + assertThat(dimensions).isNotNull(); + assertThat(dimensions).isEqualTo(384); + } + + @Test + void embeddingWithCustomModel() { + HuggingfaceEmbeddingOptions customOptions = HuggingfaceEmbeddingOptions.builder() + .model("sentence-transformers/all-MiniLM-L6-v2") + .build(); + + EmbeddingResponse embeddingResponse = this.embeddingModel + .call(new EmbeddingRequest(List.of("Custom model test"), customOptions)); + + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("sentence-transformers/all-MiniLM-L6-v2"); + } + + @Test + void embeddingWithEmptyString() { + EmbeddingResponse embeddingResponse = this.embeddingModel + .call(new EmbeddingRequest(List.of(""),
HuggingfaceEmbeddingOptions.builder().build())); + + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); + } + + @Test + void embeddingWithLongText() { + String longText = "This is a longer text that contains multiple sentences. " + + "It is used to test how the embedding model handles longer inputs. " + + "The model should be able to process this text and return meaningful embeddings. " + + "These embeddings can then be used for various NLP tasks such as similarity search or classification."; + + EmbeddingResponse embeddingResponse = this.embeddingModel + .call(new EmbeddingRequest(List.of(longText), HuggingfaceEmbeddingOptions.builder().build())); + + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSizeGreaterThan(100); + } + + @Test + void embeddingVectorSimilarity() { + // Test that similar texts produce similar embeddings + List similarTexts = List.of("The cat sat on the mat", "A cat is sitting on a mat"); + + EmbeddingResponse embeddingResponse = this.embeddingModel + .call(new EmbeddingRequest(similarTexts, HuggingfaceEmbeddingOptions.builder().build())); + + assertThat(embeddingResponse.getResults()).hasSize(2); + float[] embedding1 = embeddingResponse.getResults().get(0).getOutput(); + float[] embedding2 = embeddingResponse.getResults().get(1).getOutput(); + + // Both embeddings should have the same dimensions + assertThat(embedding1).hasSameSizeAs(embedding2); + + // Calculate cosine similarity (should be high for similar texts) + double similarity = cosineSimilarity(embedding1, embedding2); + assertThat(similarity).isGreaterThan(0.7); // Similar texts should have high + // similarity + } + + @Test + void embeddingWithNormalizeOption() { + // Note: The normalize, prompt_name, truncate, and truncation_direction parameters + // are part of the HuggingFace Inference API Feature Extraction specification: + // https://huggingface.co/docs/inference-providers/tasks/feature-extraction + // + // This test verifies that: + // 1. The normalize option can be set and sent to the API (via toMap()) + // 2. The API accepts the parameter without throwing errors + // 3. 
The resulting embeddings are normalized (magnitude ≈ 1.0) + + HuggingfaceEmbeddingOptions optionsWithNormalize = HuggingfaceEmbeddingOptions.builder() + .model("sentence-transformers/all-MiniLM-L6-v2") + .normalize(true) + .build(); + + // Verify the option is included in the request + assertThat(optionsWithNormalize.getNormalize()).isTrue(); + assertThat(optionsWithNormalize.toMap()).containsEntry("normalize", true); + + EmbeddingResponse response = this.embeddingModel + .call(new EmbeddingRequest(List.of("Test normalize option"), optionsWithNormalize)); + + assertThat(response.getResults()).hasSize(1); + float[] embedding = response.getResults().get(0).getOutput(); + assertThat(embedding).isNotEmpty(); + + // The standard HuggingFace Inference API normalizes embeddings by default, + // so the magnitude should be close to 1.0 regardless of the normalize parameter + double magnitude = calculateMagnitude(embedding); + assertThat(magnitude).isCloseTo(1.0, Offset.offset(0.01)); + } + + @Test + void embeddingWithWrongBaseUrl() { + HuggingfaceApi wrongApi = HuggingfaceApi.builder() + .baseUrl("https://router.huggingface.co/v1") + .apiKey(getApiKey()) + .build(); + + HuggingfaceEmbeddingModel wrongEmbeddingModel = HuggingfaceEmbeddingModel.builder() + .huggingfaceApi(wrongApi) + .defaultOptions(HuggingfaceEmbeddingOptions.builder().model(DEFAULT_EMBEDDING_MODEL).build()) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) + .observationRegistry(io.micrometer.observation.ObservationRegistry.NOOP) + .build(); + + org.junit.jupiter.api.Assertions.assertThrows(org.springframework.web.client.HttpClientErrorException.class, + () -> wrongEmbeddingModel.call(new EmbeddingRequest(List.of("Test with wrong URL"), + HuggingfaceEmbeddingOptions.builder().model(DEFAULT_EMBEDDING_MODEL).build()))); + } + + @Test + void embeddingWithCorrectBaseUrl() { + EmbeddingResponse embeddingResponse = this.embeddingModel.call(new EmbeddingRequest( + List.of("Verify correct baseURL usage"), HuggingfaceEmbeddingOptions.builder().build())); + + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); + } + + private double calculateMagnitude(float[] vector) { + double sum = 0.0; + for (float v : vector) { + sum += v * v; + } + return Math.sqrt(sum); + } + + private double cosineSimilarity(float[] vectorA, float[] vectorB) { + double dotProduct = 0.0; + double normA = 0.0; + double normB = 0.0; + + for (int i = 0; i < vectorA.length; i++) { + dotProduct += vectorA[i] * vectorB[i]; + normA += Math.pow(vectorA[i], 2); + normB += Math.pow(vectorB[i], 2); + } + + return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB)); + } + +} diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceEmbeddingModelObservationIT.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceEmbeddingModelObservationIT.java new file mode 100644 index 00000000000..cd8310fd518 --- /dev/null +++ b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceEmbeddingModelObservationIT.java @@ -0,0 +1,112 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.huggingface; + +import java.util.List; + +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; +import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames; +import org.springframework.ai.huggingface.api.HuggingfaceApi; +import org.springframework.ai.observation.conventions.AiOperationType; +import org.springframework.ai.observation.conventions.AiProvider; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for observation instrumentation in {@link HuggingfaceEmbeddingModel}. + * + * @author Myeongdeok Kang + */ +@SpringBootTest(classes = HuggingfaceEmbeddingModelObservationIT.Config.class) +@EnabledIfEnvironmentVariable(named = "HUGGINGFACE_API_KEY", matches = ".+") +class HuggingfaceEmbeddingModelObservationIT { + + private static final String MODEL = "sentence-transformers/all-MiniLM-L6-v2"; + + @Autowired + TestObservationRegistry observationRegistry; + + @Autowired + HuggingfaceEmbeddingModel embeddingModel; + + @Test + void observationForEmbeddingOperation() { + var options = HuggingfaceEmbeddingOptions.builder().model(MODEL).build(); + + EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options); + + EmbeddingResponse embeddingResponse = this.embeddingModel.call(embeddingRequest); + assertThat(embeddingResponse.getResults()).isNotEmpty(); + + EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata(); + assertThat(responseMetadata).isNotNull(); + + TestObservationRegistryAssert.assertThat(this.observationRegistry) + .doesNotHaveAnyRemainingCurrentObservation() + .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME) + .that() + .hasContextualNameEqualTo("embedding " + MODEL) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), + AiOperationType.EMBEDDING.value()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.HUGGINGFACE.value()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(), MODEL) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), responseMetadata.getModel()) + .hasBeenStarted() + .hasBeenStopped(); + } + + @SpringBootConfiguration + static class Config { + + @Bean + public TestObservationRegistry observationRegistry() { + 
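// TestObservationRegistry records completed observations so the test can assert their names and key values. +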
return TestObservationRegistry.create(); + } + + @Bean + public HuggingfaceApi huggingfaceApi() { + String apiKey = System.getenv("HUGGINGFACE_API_KEY"); + return HuggingfaceApi.builder() + .baseUrl("https://router.huggingface.co/hf-inference/models") + .apiKey(apiKey) + .build(); + } + + @Bean + public HuggingfaceEmbeddingModel huggingfaceEmbeddingModel(HuggingfaceApi huggingfaceApi, + TestObservationRegistry observationRegistry) { + return HuggingfaceEmbeddingModel.builder() + .huggingfaceApi(huggingfaceApi) + .observationRegistry(observationRegistry) + .build(); + } + + } + +} diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceEmbeddingModelTests.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceEmbeddingModelTests.java new file mode 100644 index 00000000000..968772c0131 --- /dev/null +++ b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceEmbeddingModelTests.java @@ -0,0 +1,298 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.huggingface; + +import java.util.List; + +import io.micrometer.observation.ObservationRegistry; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResultMetadata; +import org.springframework.ai.huggingface.api.HuggingfaceApi; +import org.springframework.ai.huggingface.api.HuggingfaceApi.EmbeddingsRequest; +import org.springframework.ai.huggingface.api.HuggingfaceApi.EmbeddingsResponse; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.BDDMockito.given; + +/** + * Unit tests for {@link HuggingfaceEmbeddingModel}. 
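+ * The {@link HuggingfaceApi} is mocked with Mockito, so these tests run without any network access.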
+ * + * @author Myeongdeok Kang + */ +@ExtendWith(MockitoExtension.class) +class HuggingfaceEmbeddingModelTests { + + @Mock + HuggingfaceApi huggingfaceApi; + + @Captor + ArgumentCaptor<EmbeddingsRequest> embeddingsRequestCaptor; + + @Test + void options() { + given(this.huggingfaceApi.embeddings(this.embeddingsRequestCaptor.capture())) + .willReturn(new EmbeddingsResponse("RESPONSE_MODEL_NAME", + List.of(new float[] { 1f, 2f, 3f }, new float[] { 4f, 5f, 6f }))) + .willReturn(new EmbeddingsResponse("RESPONSE_MODEL_NAME2", + List.of(new float[] { 7f, 8f, 9f }, new float[] { 10f, 11f, 12f }))); + + // Tests default options + var defaultOptions = HuggingfaceEmbeddingOptions.builder().model("DEFAULT_MODEL").build(); + + var embeddingModel = HuggingfaceEmbeddingModel.builder() + .huggingfaceApi(this.huggingfaceApi) + .defaultOptions(defaultOptions) + .build(); + + EmbeddingResponse response = embeddingModel + .call(new EmbeddingRequest(List.of("Input1", "Input2", "Input3"), EmbeddingOptions.builder().build())); + + assertThat(response.getResults()).hasSize(2); + assertThat(response.getResults().get(0).getIndex()).isEqualTo(0); + assertThat(response.getResults().get(0).getOutput()).isEqualTo(new float[] { 1f, 2f, 3f }); + assertThat(response.getResults().get(0).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY); + assertThat(response.getResults().get(1).getIndex()).isEqualTo(1); + assertThat(response.getResults().get(1).getOutput()).isEqualTo(new float[] { 4f, 5f, 6f }); + assertThat(response.getResults().get(1).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY); + assertThat(response.getMetadata().getModel()).isEqualTo("RESPONSE_MODEL_NAME"); + + assertThat(this.embeddingsRequestCaptor.getValue().inputs()).isEqualTo(List.of("Input1", "Input2", "Input3")); + assertThat(this.embeddingsRequestCaptor.getValue().model()).isEqualTo("DEFAULT_MODEL"); + + // Tests runtime options + var runtimeOptions = HuggingfaceEmbeddingOptions.builder().model("RUNTIME_MODEL").build(); + + response = embeddingModel.call(new EmbeddingRequest(List.of("Input4", "Input5", "Input6"), runtimeOptions)); + + assertThat(response.getResults()).hasSize(2); + assertThat(response.getResults().get(0).getIndex()).isEqualTo(0); + assertThat(response.getResults().get(0).getOutput()).isEqualTo(new float[] { 7f, 8f, 9f }); + assertThat(response.getResults().get(0).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY); + assertThat(response.getResults().get(1).getIndex()).isEqualTo(1); + assertThat(response.getResults().get(1).getOutput()).isEqualTo(new float[] { 10f, 11f, 12f }); + assertThat(response.getResults().get(1).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY); + assertThat(response.getMetadata().getModel()).isEqualTo("RESPONSE_MODEL_NAME2"); + + assertThat(this.embeddingsRequestCaptor.getValue().inputs()).isEqualTo(List.of("Input4", "Input5", "Input6")); + assertThat(this.embeddingsRequestCaptor.getValue().model()).isEqualTo("RUNTIME_MODEL"); + } + + @Test + void singleInputEmbedding() { + given(this.huggingfaceApi.embeddings(this.embeddingsRequestCaptor.capture())) + .willReturn(new EmbeddingsResponse("TEST_MODEL", List.of(new float[] { 0.1f, 0.2f, 0.3f }))); + + var embeddingModel = HuggingfaceEmbeddingModel.builder() + .huggingfaceApi(this.huggingfaceApi) + .defaultOptions(HuggingfaceEmbeddingOptions.builder().model("TEST_MODEL").build()) + .build(); + + EmbeddingResponse response = embeddingModel + .call(new EmbeddingRequest(List.of("Single input text"), EmbeddingOptions.builder().build())); + +
assertThat(response.getResults()).hasSize(1); + assertThat(response.getResults().get(0).getIndex()).isEqualTo(0); + assertThat(response.getResults().get(0).getOutput()).isEqualTo(new float[] { 0.1f, 0.2f, 0.3f }); + assertThat(response.getMetadata().getModel()).isEqualTo("TEST_MODEL"); + + assertThat(this.embeddingsRequestCaptor.getValue().inputs()).isEqualTo(List.of("Single input text")); + assertThat(this.embeddingsRequestCaptor.getValue().model()).isEqualTo("TEST_MODEL"); + } + + @Test + void embeddingWithNullOptions() { + given(this.huggingfaceApi.embeddings(this.embeddingsRequestCaptor.capture())) + .willReturn(new EmbeddingsResponse("NULL_OPTIONS_MODEL", List.of(new float[] { 0.5f }))); + + var embeddingModel = HuggingfaceEmbeddingModel.builder() + .huggingfaceApi(this.huggingfaceApi) + .defaultOptions(HuggingfaceEmbeddingOptions.builder().model("NULL_OPTIONS_MODEL").build()) + .build(); + + EmbeddingResponse response = embeddingModel.call(new EmbeddingRequest(List.of("Null options test"), null)); + + assertThat(response.getResults()).hasSize(1); + assertThat(response.getMetadata().getModel()).isEqualTo("NULL_OPTIONS_MODEL"); + + assertThat(this.embeddingsRequestCaptor.getValue().model()).isEqualTo("NULL_OPTIONS_MODEL"); + } + + @Test + void embeddingWithMultipleLargeInputs() { + List largeInputs = List.of( + "This is a very long text input that might be used for document embedding scenarios", + "Another substantial piece of text content that could represent a paragraph or section", + "A third lengthy input to test batch processing capabilities of the embedding model"); + + given(this.huggingfaceApi.embeddings(this.embeddingsRequestCaptor.capture())) + .willReturn(new EmbeddingsResponse("BATCH_MODEL", List.of(new float[] { 0.1f, 0.2f, 0.3f, 0.4f }, + new float[] { 0.5f, 0.6f, 0.7f, 0.8f }, new float[] { 0.9f, 1.0f, 1.1f, 1.2f }))); + + var embeddingModel = HuggingfaceEmbeddingModel.builder() + .huggingfaceApi(this.huggingfaceApi) + .defaultOptions(HuggingfaceEmbeddingOptions.builder().model("BATCH_MODEL").build()) + .build(); + + EmbeddingResponse response = embeddingModel + .call(new EmbeddingRequest(largeInputs, EmbeddingOptions.builder().build())); + + assertThat(response.getResults()).hasSize(3); + assertThat(response.getResults().get(0).getOutput()).hasSize(4); + assertThat(response.getResults().get(1).getOutput()).hasSize(4); + assertThat(response.getResults().get(2).getOutput()).hasSize(4); + + assertThat(this.embeddingsRequestCaptor.getValue().inputs()).isEqualTo(largeInputs); + } + + @Test + void embeddingResponseMetadata() { + given(this.huggingfaceApi.embeddings(this.embeddingsRequestCaptor.capture())) + .willReturn(new EmbeddingsResponse("METADATA_MODEL", List.of(new float[] { 0.1f, 0.2f }))); + + var embeddingModel = HuggingfaceEmbeddingModel.builder() + .huggingfaceApi(this.huggingfaceApi) + .defaultOptions(HuggingfaceEmbeddingOptions.builder().model("METADATA_MODEL").build()) + .build(); + + EmbeddingResponse response = embeddingModel + .call(new EmbeddingRequest(List.of("Metadata test"), EmbeddingOptions.builder().build())); + + assertThat(response.getMetadata().getModel()).isEqualTo("METADATA_MODEL"); + assertThat(response.getResults()).hasSize(1); + assertThat(response.getResults().get(0).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY); + } + + @Test + void embeddingWithZeroLengthVectors() { + given(this.huggingfaceApi.embeddings(this.embeddingsRequestCaptor.capture())) + .willReturn(new EmbeddingsResponse("ZERO_MODEL", List.of(new float[] {}))); + + var 
embeddingModel = HuggingfaceEmbeddingModel.builder() + .huggingfaceApi(this.huggingfaceApi) + .defaultOptions(HuggingfaceEmbeddingOptions.builder().model("ZERO_MODEL").build()) + .build(); + + EmbeddingResponse response = embeddingModel + .call(new EmbeddingRequest(List.of("Zero length test"), EmbeddingOptions.builder().build())); + + assertThat(response.getResults()).hasSize(1); + assertThat(response.getResults().get(0).getOutput()).isEmpty(); + } + + @Test + void builderValidation() { + // Test that builder requires huggingfaceApi + assertThatThrownBy(() -> HuggingfaceEmbeddingModel.builder().build()) + .isInstanceOf(IllegalArgumentException.class); + + // Test successful builder with minimal required parameters + var model = HuggingfaceEmbeddingModel.builder().huggingfaceApi(this.huggingfaceApi).build(); + + assertThat(model).isNotNull(); + } + + @Test + void builderWithAllOptions() { + HuggingfaceEmbeddingOptions options = HuggingfaceEmbeddingOptions.builder() + .model("test-model") + .normalize(true) + .promptName("query") + .truncate(true) + .truncationDirection("Right") + .build(); + + var embeddingModel = HuggingfaceEmbeddingModel.builder() + .huggingfaceApi(this.huggingfaceApi) + .defaultOptions(options) + .observationRegistry(ObservationRegistry.NOOP) + .build(); + + assertThat(embeddingModel).isNotNull(); + } + + @Test + void embedDocument() { + given(this.huggingfaceApi.embeddings(this.embeddingsRequestCaptor.capture())) + .willReturn(new EmbeddingsResponse("DOC_MODEL", List.of(new float[] { 0.1f, 0.2f, 0.3f }))); + + var embeddingModel = HuggingfaceEmbeddingModel.builder() + .huggingfaceApi(this.huggingfaceApi) + .defaultOptions(HuggingfaceEmbeddingOptions.builder().model("DOC_MODEL").build()) + .build(); + + Document document = new Document("Document content for embedding"); + float[] embedding = embeddingModel.embed(document); + + assertThat(embedding).isEqualTo(new float[] { 0.1f, 0.2f, 0.3f }); + } + + @Test + void buildEmbeddingRequestMergesOptions() { + var defaultOptions = HuggingfaceEmbeddingOptions.builder().model("DEFAULT_MODEL").build(); + + var embeddingModel = HuggingfaceEmbeddingModel.builder() + .huggingfaceApi(this.huggingfaceApi) + .defaultOptions(defaultOptions) + .build(); + + var runtimeOptions = HuggingfaceEmbeddingOptions.builder().model("RUNTIME_MODEL").build(); + + var originalRequest = new EmbeddingRequest(List.of("Test text"), runtimeOptions); + var builtRequest = embeddingModel.buildEmbeddingRequest(originalRequest); + + assertThat(builtRequest.getOptions()).isInstanceOf(HuggingfaceEmbeddingOptions.class); + HuggingfaceEmbeddingOptions mergedOptions = (HuggingfaceEmbeddingOptions) builtRequest.getOptions(); + assertThat(mergedOptions.getModel()).isEqualTo("RUNTIME_MODEL"); + } + + @Test + void buildEmbeddingRequestValidatesModel() { + var defaultOptions = HuggingfaceEmbeddingOptions.builder().build(); + + var embeddingModel = HuggingfaceEmbeddingModel.builder() + .huggingfaceApi(this.huggingfaceApi) + .defaultOptions(defaultOptions) + .build(); + + var request = new EmbeddingRequest(List.of("Test text"), null); + + assertThatThrownBy(() -> embeddingModel.buildEmbeddingRequest(request)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("model cannot be null or empty"); + } + + @Test + void setObservationConventionValidation() { + var embeddingModel = HuggingfaceEmbeddingModel.builder().huggingfaceApi(this.huggingfaceApi).build(); + + assertThatThrownBy(() -> embeddingModel.setObservationConvention(null)) + 
.isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("observationConvention cannot be null"); + } + +} diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceEmbeddingOptionsTests.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceEmbeddingOptionsTests.java new file mode 100644 index 00000000000..818935d4181 --- /dev/null +++ b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceEmbeddingOptionsTests.java @@ -0,0 +1,211 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.huggingface; + +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link HuggingfaceEmbeddingOptions}. + * + * @author Myeongdeok Kang + */ +class HuggingfaceEmbeddingOptionsTests { + + @Test + void testBuilderWithAllFields() { + HuggingfaceEmbeddingOptions options = HuggingfaceEmbeddingOptions.builder() + .model("sentence-transformers/all-MiniLM-L6-v2") + .normalize(true) + .promptName("query") + .truncate(true) + .truncationDirection("Right") + .build(); + + assertThat(options).extracting("model", "normalize", "promptName", "truncate", "truncationDirection") + .containsExactly("sentence-transformers/all-MiniLM-L6-v2", true, "query", true, "Right"); + } + + @Test + void testCopy() { + HuggingfaceEmbeddingOptions originalOptions = HuggingfaceEmbeddingOptions.builder() + .model("test-model") + .normalize(false) + .promptName("passage") + .truncate(false) + .truncationDirection("Left") + .build(); + + HuggingfaceEmbeddingOptions copiedOptions = originalOptions.copy(); + + assertThat(copiedOptions).isNotSameAs(originalOptions).isEqualTo(originalOptions); + } + + @Test + void testSetters() { + HuggingfaceEmbeddingOptions options = new HuggingfaceEmbeddingOptions(); + + options.setModel("test-model"); + options.setNormalize(true); + options.setPromptName("query"); + options.setTruncate(true); + options.setTruncationDirection("Right"); + + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getNormalize()).isTrue(); + assertThat(options.getPromptName()).isEqualTo("query"); + assertThat(options.getTruncate()).isTrue(); + assertThat(options.getTruncationDirection()).isEqualTo("Right"); + } + + @Test + void testDefaultValues() { + HuggingfaceEmbeddingOptions options = new HuggingfaceEmbeddingOptions(); + + assertThat(options.getModel()).isNull(); + assertThat(options.getNormalize()).isNull(); + assertThat(options.getPromptName()).isNull(); + assertThat(options.getTruncate()).isNull(); + assertThat(options.getTruncationDirection()).isNull(); + } + + @Test + void testFromOptions() { + HuggingfaceEmbeddingOptions originalOptions = HuggingfaceEmbeddingOptions.builder() + .model("original-model") + .normalize(true) + .promptName("document") + .truncate(true) + 
.truncationDirection("Left") + .build(); + + HuggingfaceEmbeddingOptions copiedOptions = HuggingfaceEmbeddingOptions.fromOptions(originalOptions); + + assertThat(copiedOptions).isNotSameAs(originalOptions).isEqualTo(originalOptions); + } + + @Test + void testToMap() { + HuggingfaceEmbeddingOptions options = HuggingfaceEmbeddingOptions.builder() + .model("test-model") + .normalize(true) + .promptName("query") + .truncate(false) + .truncationDirection("Right") + .build(); + + Map map = options.toMap(); + + assertThat(map).containsEntry("model", "test-model") + .containsEntry("normalize", true) + .containsEntry("prompt_name", "query") + .containsEntry("truncate", false) + .containsEntry("truncation_direction", "Right"); + } + + @Test + void testEqualsAndHashCode() { + HuggingfaceEmbeddingOptions options1 = HuggingfaceEmbeddingOptions.builder() + .model("test-model") + .normalize(false) + .promptName("passage") + .truncate(true) + .truncationDirection("Left") + .build(); + + HuggingfaceEmbeddingOptions options2 = HuggingfaceEmbeddingOptions.builder() + .model("test-model") + .normalize(false) + .promptName("passage") + .truncate(true) + .truncationDirection("Left") + .build(); + + assertThat(options1).isEqualTo(options2); + assertThat(options1.hashCode()).isEqualTo(options2.hashCode()); + } + + @Test + void testToString() { + HuggingfaceEmbeddingOptions options = HuggingfaceEmbeddingOptions.builder() + .model("test-model") + .normalize(true) + .promptName("query") + .truncate(false) + .truncationDirection("Right") + .build(); + + String result = options.toString(); + + assertThat(result).contains("test-model", "true", "query", "false", "Right"); + } + + @Test + void testBuilderWithNullValues() { + HuggingfaceEmbeddingOptions options = HuggingfaceEmbeddingOptions.builder() + .model(null) + .normalize(null) + .promptName(null) + .truncate(null) + .truncationDirection(null) + .build(); + + assertThat(options.getModel()).isNull(); + assertThat(options.getNormalize()).isNull(); + assertThat(options.getPromptName()).isNull(); + assertThat(options.getTruncate()).isNull(); + assertThat(options.getTruncationDirection()).isNull(); + } + + @Test + void testCopyChangeIndependence() { + HuggingfaceEmbeddingOptions originalOptions = HuggingfaceEmbeddingOptions.builder() + .model("original-model") + .normalize(true) + .promptName("query") + .build(); + + HuggingfaceEmbeddingOptions copiedOptions = originalOptions.copy(); + + // Modify original + originalOptions.setModel("modified-model"); + originalOptions.setNormalize(false); + + // Copy should retain original values + assertThat(copiedOptions.getModel()).isEqualTo("original-model"); + assertThat(copiedOptions.getNormalize()).isTrue(); + assertThat(originalOptions.getModel()).isEqualTo("modified-model"); + assertThat(originalOptions.getNormalize()).isFalse(); + } + + @Test + void testBuilderChaining() { + HuggingfaceEmbeddingOptions.Builder builder = HuggingfaceEmbeddingOptions.builder(); + + HuggingfaceEmbeddingOptions.Builder result = builder.model("test-model") + .normalize(true) + .promptName("query") + .truncate(true) + .truncationDirection("Right"); + + assertThat(result).isSameAs(builder); + } + +} diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceRetryTests.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceRetryTests.java new file mode 100644 index 00000000000..7ada6e7f04e --- /dev/null +++ 
b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceRetryTests.java @@ -0,0 +1,220 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.huggingface; + +import java.util.List; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.huggingface.api.HuggingfaceApi; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.retry.NonTransientAiException; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.retry.TransientAiException; +import org.springframework.core.retry.RetryListener; +import org.springframework.core.retry.RetryPolicy; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.core.retry.Retryable; +import org.springframework.web.client.ResourceAccessException; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Unit tests for retry logic in {@link HuggingfaceChatModel}. + * + * @author Myeongdeok Kang + */ +@ExtendWith(MockitoExtension.class) +class HuggingfaceRetryTests { + + private static final String MODEL = "meta-llama/Llama-3.2-3B-Instruct"; + + private TestRetryListener retryListener; + + private RetryTemplate retryTemplate; + + @Mock + private HuggingfaceApi huggingfaceApi; + + private HuggingfaceChatModel chatModel; + + @BeforeEach + public void beforeEach() { + this.retryTemplate = RetryUtils.SHORT_RETRY_TEMPLATE; + this.retryListener = new TestRetryListener(); + this.retryTemplate.setRetryListener(this.retryListener); + + this.chatModel = HuggingfaceChatModel.builder() + .huggingfaceApi(this.huggingfaceApi) + .defaultOptions(HuggingfaceChatOptions.builder().model(MODEL).temperature(0.9).build()) + .toolCallingManager(ToolCallingManager.builder().build()) + .retryTemplate(this.retryTemplate) + .build(); + } + + @Test + void huggingfaceChatTransientError() { + String promptText = "What is the capital of Bulgaria and what is the size? 
What it the national anthem?"; + var expectedChatResponse = new HuggingfaceApi.ChatResponse("id-123", "chat.completion", + System.currentTimeMillis(), MODEL, + List.of(new HuggingfaceApi.Choice(0, new HuggingfaceApi.Message("assistant", "Response"), "stop")), + new HuggingfaceApi.Usage(10, 20, 30), null); + + when(this.huggingfaceApi.chat(isA(HuggingfaceApi.ChatRequest.class))) + .thenThrow(new TransientAiException("Transient Error 1")) + .thenThrow(new TransientAiException("Transient Error 2")) + .thenReturn(expectedChatResponse); + + var result = this.chatModel.call(new Prompt(promptText)); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput().getText()).isSameAs("Response"); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); + assertThat(this.retryListener.retryCount).isEqualTo(2); + } + + @Test + void huggingfaceChatSuccessOnFirstAttempt() { + String promptText = "Simple question"; + var expectedChatResponse = new HuggingfaceApi.ChatResponse("id-123", "chat.completion", + System.currentTimeMillis(), MODEL, List.of(new HuggingfaceApi.Choice(0, + new HuggingfaceApi.Message("assistant", "Quick response"), "stop")), + new HuggingfaceApi.Usage(5, 10, 15), null); + + when(this.huggingfaceApi.chat(isA(HuggingfaceApi.ChatRequest.class))).thenReturn(expectedChatResponse); + + var result = this.chatModel.call(new Prompt(promptText)); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput().getText()).isEqualTo("Quick response"); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(0); + assertThat(this.retryListener.retryCount).isEqualTo(0); + verify(this.huggingfaceApi, times(1)).chat(isA(HuggingfaceApi.ChatRequest.class)); + } + + @Test + void huggingfaceChatNonTransientErrorShouldNotRetry() { + String promptText = "Invalid request"; + + when(this.huggingfaceApi.chat(isA(HuggingfaceApi.ChatRequest.class))) + .thenThrow(new NonTransientAiException("Model not found")); + + assertThatThrownBy(() -> this.chatModel.call(new Prompt(promptText))) + .isInstanceOf(NonTransientAiException.class) + .hasMessage("Model not found"); + + verify(this.huggingfaceApi, times(1)).chat(isA(HuggingfaceApi.ChatRequest.class)); + } + + @Test + void huggingfaceChatWithMultipleMessages() { + List messages = List.of(new UserMessage("What is AI?"), new UserMessage("Explain machine learning")); + Prompt prompt = new Prompt(messages); + + var expectedChatResponse = new HuggingfaceApi.ChatResponse("id-123", "chat.completion", + System.currentTimeMillis(), MODEL, + List.of(new HuggingfaceApi.Choice(0, + new HuggingfaceApi.Message("assistant", "AI is artificial intelligence..."), "stop")), + new HuggingfaceApi.Usage(15, 30, 45), null); + + when(this.huggingfaceApi.chat(isA(HuggingfaceApi.ChatRequest.class))) + .thenThrow(new TransientAiException("Temporary overload")) + .thenReturn(expectedChatResponse); + + var result = this.chatModel.call(prompt); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput().getText()).isEqualTo("AI is artificial intelligence..."); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); + assertThat(this.retryListener.retryCount).isEqualTo(1); + } + + @Test + void huggingfaceChatWithCustomOptions() { + String promptText = "Custom temperature request"; + HuggingfaceChatOptions customOptions = HuggingfaceChatOptions.builder() + .model(MODEL) + .temperature(0.1) + .topP(0.9) + .build(); + + var expectedChatResponse = new HuggingfaceApi.ChatResponse("id-123", "chat.completion", + 
System.currentTimeMillis(), MODEL, List.of(new HuggingfaceApi.Choice(0, + new HuggingfaceApi.Message("assistant", "Deterministic response"), "stop")), + new HuggingfaceApi.Usage(8, 12, 20), null); + + when(this.huggingfaceApi.chat(isA(HuggingfaceApi.ChatRequest.class))) + .thenThrow(new ResourceAccessException("Connection timeout")) + .thenReturn(expectedChatResponse); + + var result = this.chatModel.call(new Prompt(promptText, customOptions)); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput().getText()).isEqualTo("Deterministic response"); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); + } + + @Test + void huggingfaceChatWithEmptyResponse() { + String promptText = "Edge case request"; + var expectedChatResponse = new HuggingfaceApi.ChatResponse("id-123", "chat.completion", + System.currentTimeMillis(), MODEL, + List.of(new HuggingfaceApi.Choice(0, new HuggingfaceApi.Message("assistant", ""), "stop")), + new HuggingfaceApi.Usage(5, 0, 5), null); + + when(this.huggingfaceApi.chat(isA(HuggingfaceApi.ChatRequest.class))) + .thenThrow(new TransientAiException("Rate limit exceeded")) + .thenReturn(expectedChatResponse); + + var result = this.chatModel.call(new Prompt(promptText)); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput().getText()).isEmpty(); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); + } + + private static class TestRetryListener implements RetryListener { + + int retryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onRetrySuccess(final RetryPolicy retryPolicy, final Retryable retryable, final Object result) { + this.onSuccessRetryCount++; + } + + @Override + public void beforeRetry(RetryPolicy retryPolicy, Retryable retryable) { + this.retryCount++; + } + + } + +} diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceTestConfiguration.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceTestConfiguration.java index 5f933a09c8c..05dc94c786e 100644 --- a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceTestConfiguration.java +++ b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceTestConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.ai.huggingface; +import org.springframework.ai.huggingface.api.HuggingfaceApi; import org.springframework.boot.SpringBootConfiguration; import org.springframework.context.annotation.Bean; import org.springframework.util.StringUtils; @@ -24,16 +25,32 @@ public class HuggingfaceTestConfiguration { @Bean - public HuggingfaceChatModel huggingfaceChatModel() { + public HuggingfaceApi huggingfaceApi() { String apiKey = System.getenv("HUGGINGFACE_API_KEY"); if (!StringUtils.hasText(apiKey)) { throw new IllegalArgumentException( "You must provide an API key. 
Put it in an environment variable under the name HUGGINGFACE_API_KEY"); } + + // Create builder with API key + HuggingfaceApi.Builder builder = HuggingfaceApi.builder().apiKey(apiKey); + + // Optional: use custom base URL if provided // Created aws-mistral-7b-instruct and update the HUGGINGFACE_CHAT_URL - HuggingfaceChatModel huggingfaceChatModel = new HuggingfaceChatModel(apiKey, - System.getenv("HUGGINGFACE_CHAT_URL")); - return huggingfaceChatModel; + String chatUrl = System.getenv("HUGGINGFACE_CHAT_URL"); + if (StringUtils.hasText(chatUrl)) { + builder.baseUrl(chatUrl); + } + + return builder.build(); + } + + @Bean + public HuggingfaceChatModel huggingfaceChatModel(HuggingfaceApi huggingfaceApi) { + return HuggingfaceChatModel.builder() + .huggingfaceApi(huggingfaceApi) + .defaultOptions(HuggingfaceChatOptions.builder().model(HuggingfaceApi.DEFAULT_CHAT_MODEL).build()) + .build(); } } diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/aot/HuggingfaceRuntimeHintsTests.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/aot/HuggingfaceRuntimeHintsTests.java new file mode 100644 index 00000000000..0bf0f48817a --- /dev/null +++ b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/aot/HuggingfaceRuntimeHintsTests.java @@ -0,0 +1,299 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.huggingface.aot; + +import java.util.HashSet; +import java.util.Set; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.huggingface.HuggingfaceChatOptions; +import org.springframework.ai.huggingface.HuggingfaceEmbeddingOptions; +import org.springframework.ai.huggingface.api.HuggingfaceApi; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.TypeReference; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; + +/** + * Unit tests for {@link HuggingfaceRuntimeHints}. 
+ * + * @author Myeongdeok Kang + */ +class HuggingfaceRuntimeHintsTests { + + @Test + void registerHints() { + RuntimeHints runtimeHints = new RuntimeHints(); + HuggingfaceRuntimeHints huggingfaceRuntimeHints = new HuggingfaceRuntimeHints(); + huggingfaceRuntimeHints.registerHints(runtimeHints, null); + + Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage( + "org.springframework.ai.huggingface"); + + Set registeredTypes = new HashSet<>(); + runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType())); + + for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) { + assertThat(registeredTypes.contains(jsonAnnotatedClass)).isTrue(); + } + + // Check specific HuggingFace API classes + assertThat(registeredTypes.contains(TypeReference.of(HuggingfaceApi.ChatRequest.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(HuggingfaceApi.Message.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(HuggingfaceApi.ChatResponse.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(HuggingfaceApi.FunctionTool.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(HuggingfaceChatOptions.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(HuggingfaceEmbeddingOptions.class))).isTrue(); + } + + @Test + void registerHintsWithNullClassLoader() { + RuntimeHints runtimeHints = new RuntimeHints(); + HuggingfaceRuntimeHints huggingfaceRuntimeHints = new HuggingfaceRuntimeHints(); + + // Should not throw exception with null ClassLoader + org.assertj.core.api.Assertions.assertThatCode(() -> huggingfaceRuntimeHints.registerHints(runtimeHints, null)) + .doesNotThrowAnyException(); + } + + @Test + void ensureReflectionHintsAreRegistered() { + RuntimeHints runtimeHints = new RuntimeHints(); + HuggingfaceRuntimeHints huggingfaceRuntimeHints = new HuggingfaceRuntimeHints(); + huggingfaceRuntimeHints.registerHints(runtimeHints, null); + + // Ensure reflection hints are properly registered + assertThat(runtimeHints.reflection().typeHints().spliterator().estimateSize()).isGreaterThan(0); + } + + @Test + void verifyMultipleRegistrationCallsAreIdempotent() { + RuntimeHints runtimeHints = new RuntimeHints(); + HuggingfaceRuntimeHints huggingfaceRuntimeHints = new HuggingfaceRuntimeHints(); + + // Register hints multiple times + huggingfaceRuntimeHints.registerHints(runtimeHints, null); + long firstCount = runtimeHints.reflection().typeHints().spliterator().estimateSize(); + + huggingfaceRuntimeHints.registerHints(runtimeHints, null); + long secondCount = runtimeHints.reflection().typeHints().spliterator().estimateSize(); + + // Should not register duplicate hints + assertThat(firstCount).isEqualTo(secondCount); + } + + @Test + void verifyMainApiClassesRegistered() { + RuntimeHints runtimeHints = new RuntimeHints(); + HuggingfaceRuntimeHints huggingfaceRuntimeHints = new HuggingfaceRuntimeHints(); + huggingfaceRuntimeHints.registerHints(runtimeHints, null); + + Set registeredTypes = new HashSet<>(); + runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType())); + + // Verify that the main classes we know exist are registered + assertThat(registeredTypes.contains(TypeReference.of(HuggingfaceApi.ChatRequest.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(HuggingfaceApi.Message.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(HuggingfaceChatOptions.class))).isTrue(); + } + + @Test + void 
verifyJsonAnnotatedClassesFromCorrectPackage() { + Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage( + "org.springframework.ai.huggingface"); + + // Ensure we found some JSON annotated classes in the expected package + assertThat(jsonAnnotatedClasses.spliterator().estimateSize()).isGreaterThan(0); + + // Verify all found classes are from the expected package + for (TypeReference classRef : jsonAnnotatedClasses) { + assertThat(classRef.getName()).startsWith("org.springframework.ai.huggingface"); + } + } + + @Test + void verifyNoUnnecessaryHintsRegistered() { + RuntimeHints runtimeHints = new RuntimeHints(); + HuggingfaceRuntimeHints huggingfaceRuntimeHints = new HuggingfaceRuntimeHints(); + huggingfaceRuntimeHints.registerHints(runtimeHints, null); + + Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage( + "org.springframework.ai.huggingface"); + + Set registeredTypes = new HashSet<>(); + runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType())); + + // Ensure we don't register significantly more types than needed + // Allow for some additional utility types but prevent hint bloat + assertThat(registeredTypes.size()).isLessThanOrEqualTo(jsonAnnotatedClasses.size() + 15); + } + + @Test + void verifyNestedClassHintsAreRegistered() { + RuntimeHints runtimeHints = new RuntimeHints(); + HuggingfaceRuntimeHints huggingfaceRuntimeHints = new HuggingfaceRuntimeHints(); + huggingfaceRuntimeHints.registerHints(runtimeHints, null); + + Set registeredTypes = new HashSet<>(); + runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType())); + + // Verify nested classes that we know exist + assertThat(registeredTypes.contains(TypeReference.of(HuggingfaceApi.FunctionTool.Function.class))).isTrue(); + + // Count nested classes to ensure comprehensive registration + long nestedClassCount = registeredTypes.stream().filter(typeRef -> typeRef.getName().contains("$")).count(); + assertThat(nestedClassCount).isGreaterThan(0); + } + + @Test + void verifyEmbeddingRelatedClassesAreRegistered() { + RuntimeHints runtimeHints = new RuntimeHints(); + HuggingfaceRuntimeHints huggingfaceRuntimeHints = new HuggingfaceRuntimeHints(); + huggingfaceRuntimeHints.registerHints(runtimeHints, null); + + Set registeredTypes = new HashSet<>(); + runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType())); + + // Verify embedding-related classes are registered for reflection + assertThat(registeredTypes.contains(TypeReference.of(HuggingfaceApi.EmbeddingsRequest.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(HuggingfaceApi.EmbeddingsResponse.class))).isTrue(); + + // Count classes related to embedding functionality + long embeddingClassCount = registeredTypes.stream() + .filter(typeRef -> typeRef.getName().toLowerCase().contains("embedding")) + .count(); + assertThat(embeddingClassCount).isGreaterThan(0); + } + + @Test + void verifyHintsRegistrationWithCustomClassLoader() { + RuntimeHints runtimeHints = new RuntimeHints(); + HuggingfaceRuntimeHints huggingfaceRuntimeHints = new HuggingfaceRuntimeHints(); + + // Create a custom class loader + ClassLoader customClassLoader = Thread.currentThread().getContextClassLoader(); + + // Should work with custom class loader + org.assertj.core.api.Assertions + .assertThatCode(() -> huggingfaceRuntimeHints.registerHints(runtimeHints, customClassLoader)) + .doesNotThrowAnyException(); + + // Verify hints are still registered properly + 
Set registeredTypes = new HashSet<>(); + runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType())); + + assertThat(registeredTypes.size()).isGreaterThan(0); + assertThat(registeredTypes.contains(TypeReference.of(HuggingfaceApi.ChatRequest.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(HuggingfaceChatOptions.class))).isTrue(); + } + + @Test + void verifyNoProxyHintsAreRegistered() { + RuntimeHints runtimeHints = new RuntimeHints(); + HuggingfaceRuntimeHints huggingfaceRuntimeHints = new HuggingfaceRuntimeHints(); + huggingfaceRuntimeHints.registerHints(runtimeHints, null); + + // HuggingFace should only register reflection hints, not proxy hints + assertThat(runtimeHints.proxies().jdkProxyHints().count()).isEqualTo(0); + } + + @Test + void verifyNoSerializationHintsAreRegistered() { + RuntimeHints runtimeHints = new RuntimeHints(); + HuggingfaceRuntimeHints huggingfaceRuntimeHints = new HuggingfaceRuntimeHints(); + huggingfaceRuntimeHints.registerHints(runtimeHints, null); + + // HuggingFace should only register reflection hints, not serialization hints + assertThat(runtimeHints.serialization().javaSerializationHints().count()).isEqualTo(0); + } + + @Test + void verifyConstructorHintsAreRegistered() { + RuntimeHints runtimeHints = new RuntimeHints(); + HuggingfaceRuntimeHints huggingfaceRuntimeHints = new HuggingfaceRuntimeHints(); + huggingfaceRuntimeHints.registerHints(runtimeHints, null); + + // Verify that reflection hints include constructor access for JSON + // deserialization + boolean hasConstructorHints = runtimeHints.reflection() + .typeHints() + .anyMatch(typeHint -> typeHint.constructors().findAny().isPresent() || typeHint.getMemberCategories() + .contains(org.springframework.aot.hint.MemberCategory.INVOKE_DECLARED_CONSTRUCTORS)); + + assertThat(hasConstructorHints).as("Should register constructor hints for JSON deserialization").isTrue(); + } + + @Test + void verifyEnumTypesAreRegistered() { + RuntimeHints runtimeHints = new RuntimeHints(); + HuggingfaceRuntimeHints huggingfaceRuntimeHints = new HuggingfaceRuntimeHints(); + huggingfaceRuntimeHints.registerHints(runtimeHints, null); + + Set registeredTypes = new HashSet<>(); + runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType())); + + // Verify enum types are registered (critical for JSON deserialization) + assertThat(registeredTypes.contains(TypeReference.of(HuggingfaceApi.FunctionTool.Type.class))).isTrue(); + + boolean hasEnumTypes = registeredTypes.stream() + .anyMatch(tr -> tr.getName().contains("$") || tr.getName().toLowerCase().contains("type")); + + assertThat(hasEnumTypes).as("Enum types should be registered for native image compatibility").isTrue(); + } + + @Test + void verifyResponseTypesAreRegistered() { + RuntimeHints runtimeHints = new RuntimeHints(); + HuggingfaceRuntimeHints huggingfaceRuntimeHints = new HuggingfaceRuntimeHints(); + huggingfaceRuntimeHints.registerHints(runtimeHints, null); + + Set registeredTypes = new HashSet<>(); + runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType())); + + // Verify response wrapper types are registered + assertThat(registeredTypes.stream().anyMatch(tr -> tr.getName().contains("Response"))) + .as("Response types should be registered") + .isTrue(); + + assertThat(registeredTypes.contains(TypeReference.of(HuggingfaceApi.ChatResponse.class))) + .as("ChatResponse type should be registered") + .isTrue(); + } + + @Test + void 
verifyToolRelatedClassesAreRegistered() { + RuntimeHints runtimeHints = new RuntimeHints(); + HuggingfaceRuntimeHints huggingfaceRuntimeHints = new HuggingfaceRuntimeHints(); + huggingfaceRuntimeHints.registerHints(runtimeHints, null); + + Set registeredTypes = new HashSet<>(); + runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType())); + + // Verify tool-related classes are registered + assertThat(registeredTypes.contains(TypeReference.of(HuggingfaceApi.FunctionTool.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(HuggingfaceApi.ToolCall.class))).isTrue(); + + // Count tool-related classes + long toolClassCount = registeredTypes.stream() + .filter(typeRef -> typeRef.getName().toLowerCase().contains("tool")) + .count(); + assertThat(toolClassCount).isGreaterThan(0); + } + +} diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/api/tool/HuggingfaceApiToolFunctionCallIT.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/api/tool/HuggingfaceApiToolFunctionCallIT.java new file mode 100644 index 00000000000..c8a9103e122 --- /dev/null +++ b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/api/tool/HuggingfaceApiToolFunctionCallIT.java @@ -0,0 +1,180 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.huggingface.api.tool; + +import java.util.ArrayList; +import java.util.List; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.huggingface.api.HuggingfaceApi; +import org.springframework.ai.huggingface.api.HuggingfaceApi.ChatRequest; +import org.springframework.ai.huggingface.api.HuggingfaceApi.ChatResponse; +import org.springframework.ai.huggingface.api.HuggingfaceApi.Message; +import org.springframework.ai.huggingface.api.HuggingfaceApi.ToolCall; +import org.springframework.ai.model.ModelOptionsUtils; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Low-level API integration tests for HuggingFace tool/function calling. + * + *
+ * Note: Function calling is only supported by specific models and providers on + * HuggingFace Inference API. This test uses meta-llama/Llama-3.2-3B-Instruct (3B + * parameters) with the 'together' provider specified using the :provider suffix notation. + * The model supports function calling through multiple providers (novita, hyperbolic, + * together). To specify a provider, append :provider-name to the model ID (e.g., + * "model:together", "model:fastest", "model:cheapest"). + *
+ * + * @author Myeongdeok Kang + * @see HuggingFace + * Function Calling Guide + */ +@EnabledIfEnvironmentVariable(named = "HUGGINGFACE_API_KEY", matches = ".+") +public class HuggingfaceApiToolFunctionCallIT { + + private final Logger logger = LoggerFactory.getLogger(HuggingfaceApiToolFunctionCallIT.class); + + MockWeatherService weatherService = new MockWeatherService(); + + HuggingfaceApi huggingfaceApi = HuggingfaceApi.builder().apiKey(System.getenv("HUGGINGFACE_API_KEY")).build(); + + private static T fromJson(String json, Class targetClass) { + try { + return new ObjectMapper().readValue(json, targetClass); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + @SuppressWarnings("null") + @Test + public void toolFunctionCall() { + // Reset the weather service call history before the test + this.weatherService.reset(); + + // Step 1: send the conversation and available functions to the model + var message = new Message("user", "What's the weather like in San Francisco, Tokyo, and Paris?"); + + var functionTool = new HuggingfaceApi.FunctionTool(HuggingfaceApi.FunctionTool.Type.FUNCTION, + new HuggingfaceApi.FunctionTool.Function("Get the weather in location. Return temperature in Celsius.", + "getCurrentWeather", ModelOptionsUtils.jsonToMap(""" + { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. San Francisco, CA" + }, + "lat": { + "type": "number", + "description": "The city latitude" + }, + "lon": { + "type": "number", + "description": "The city longitude" + }, + "unit": { + "type": "string", + "enum": ["C", "F"] + } + }, + "required": ["location", "lat", "lon", "unit"] + } + """))); + + List messages = new ArrayList<>(List.of(message)); + + // Specify the 'together' provider using :provider suffix notation + ChatRequest chatRequest = new ChatRequest("meta-llama/Llama-3.2-3B-Instruct:together", messages, + List.of(functionTool), "auto"); + + ChatResponse chatResponse = this.huggingfaceApi.chat(chatRequest); + + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.choices()).isNotEmpty(); + + Message responseMessage = chatResponse.choices().get(0).message(); + + // Check if the model wanted to call a function + assertThat(responseMessage.role()).isEqualTo("assistant"); + assertThat(responseMessage.toolCalls()).isNotNull(); + + // extend conversation with assistant's reply. + messages.add(responseMessage); + + // Send the info for each function call and function response to the model. + for (ToolCall toolCall : responseMessage.toolCalls()) { + var functionName = toolCall.function().name(); + if ("getCurrentWeather".equals(functionName)) { + MockWeatherService.Request weatherRequest = fromJson(toolCall.function().arguments(), + MockWeatherService.Request.class); + + MockWeatherService.Response weatherResponse = this.weatherService.apply(weatherRequest); + + // extend conversation with function response. 
+ messages.add(new Message("" + weatherResponse.temp() + weatherRequest.unit(), "tool", functionName, + toolCall.id())); + } + } + + // Use the same provider for the follow-up request + var functionResponseRequest = new ChatRequest("meta-llama/Llama-3.2-3B-Instruct:together", messages); + + ChatResponse chatResponse2 = this.huggingfaceApi.chat(functionResponseRequest); + + logger.info("Final response: " + chatResponse2); + + assertThat(chatResponse2.choices()).isNotEmpty(); + assertThat(chatResponse2.choices().get(0).message().role()).isEqualTo("assistant"); + assertThat(chatResponse2.choices().get(0).message().content()).isNotEmpty(); + + // Verify that all three cities are mentioned in the response + String finalContent = chatResponse2.choices().get(0).message().content(); + assertThat(finalContent).containsIgnoringCase("San Francisco"); + assertThat(finalContent).containsIgnoringCase("Tokyo"); + assertThat(finalContent).containsIgnoringCase("Paris"); + + // Verify that the function was actually called 3 times (once for each city) + assertThat(this.weatherService.getCallCount()).isEqualTo(3); + + // Verify the function was called with correct locations + List callHistory = this.weatherService.getCallHistory(); + assertThat(callHistory).hasSize(3); + + List locations = callHistory.stream().map(MockWeatherService.Request::location).toList(); + assertThat(locations).anyMatch(loc -> loc.contains("San Francisco")); + assertThat(locations).anyMatch(loc -> loc.contains("Tokyo")); + assertThat(locations).anyMatch(loc -> loc.contains("Paris")); + + // Verify all calls used Celsius unit + assertThat(callHistory).allMatch(req -> req.unit() == MockWeatherService.Unit.C); + + // Verify lat/lon were provided for all calls + assertThat(callHistory).allMatch(req -> req.lat() != null && req.lon() != null); + } + +} diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/api/tool/MockWeatherService.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/api/tool/MockWeatherService.java new file mode 100644 index 00000000000..c10f8a04b1f --- /dev/null +++ b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/api/tool/MockWeatherService.java @@ -0,0 +1,124 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.huggingface.api.tool; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; + +import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +/** + * @author Christian Tzolov + */ +public class MockWeatherService implements Function { + + private final List callHistory = new ArrayList<>(); + + @Override + public Response apply(Request request) { + // Track function calls + this.callHistory.add(request); + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); + } + + /** + * Get the number of times the function was called. + * @return The call count. + */ + public int getCallCount() { + return this.callHistory.size(); + } + + /** + * Get the history of all function calls. + * @return List of all requests received. + */ + public List getCallHistory() { + return new ArrayList<>(this.callHistory); + } + + /** + * Reset the call history. + */ + public void reset() { + this.callHistory.clear(); + } + + /** + * Temperature units. + */ + public enum Unit { + + /** + * Celsius. + */ + C("metric"), + /** + * Fahrenheit. + */ + F("imperial"); + + /** + * Human readable unit name. + */ + public final String unitName; + + Unit(String text) { + this.unitName = text; + } + + } + + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "lat") @JsonPropertyDescription("The city latitude") Double lat, + @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") Double lon, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + + } + + /** + * Weather Function response. + */ + public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, + Unit unit) { + + } + +} diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java deleted file mode 100644 index ce84e2a3d36..00000000000 --- a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright 2023-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.huggingface.client; - -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; - -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.huggingface.HuggingfaceChatModel; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.boot.test.context.SpringBootTest; - -import static org.assertj.core.api.Assertions.assertThat; - -@SpringBootTest -@EnabledIfEnvironmentVariable(named = "HUGGINGFACE_API_KEY", matches = ".+") -@EnabledIfEnvironmentVariable(named = "HUGGINGFACE_CHAT_URL", matches = ".+") -public class ClientIT { - - @Autowired - protected HuggingfaceChatModel huggingfaceChatModel; - - @Test - void helloWorldCompletion() { - String mistral7bInstruct = """ - [INST] You are a helpful code assistant. Your task is to generate a valid JSON object based on the given information: - name: John - lastname: Smith - address: #1 Samuel St. - Just generate the JSON object without explanations: - Your response should be in JSON format. - Do not include any explanations, only provide a RFC8259 compliant JSON response following this format without deviation. - Do not include markdown code blocks in your response. - Remove the ```json markdown from the output. - [/INST] - """; - Prompt prompt = new Prompt(mistral7bInstruct); - ChatResponse chatResponse = this.huggingfaceChatModel.call(prompt); - assertThat(chatResponse.getResult().getOutput().getText()).isNotEmpty(); - String expectedResponse = """ - { - "name": "John", - "lastname": "Smith", - "address": "#1 Samuel St." - }"""; - assertThat(chatResponse.getResult().getOutput().getText()).isEqualTo(expectedResponse); - assertThat(chatResponse.getResult().getOutput().getMetadata()).containsKey("generated_tokens"); - assertThat(chatResponse.getResult().getOutput().getMetadata()).containsEntry("generated_tokens", 32); - - } - -} diff --git a/spring-ai-commons/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java b/spring-ai-commons/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java index 88105725a69..6bf89ed0d9f 100644 --- a/spring-ai-commons/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java +++ b/spring-ai-commons/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java @@ -55,6 +55,11 @@ public enum AiProvider { */ GOOGLE_GENAI_AI("google_genai"), + /** + * AI system provided by HuggingFace. + */ + HUGGINGFACE("huggingface"), + /** * AI system provided by Minimax. */ diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/huggingface.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/huggingface.adoc index 7ee36094338..1061c50e168 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/huggingface.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/huggingface.adoc @@ -1,32 +1,28 @@ -= Hugging Face Chat += HuggingFace Chat -Hugging Face Text Generation Inference (TGI) is a specialized deployment solution for serving Large Language Models (LLMs) in the cloud, making them accessible via an API. TGI provides optimized performance for text generation tasks through features like continuous batching, token streaming, and efficient memory management. 
+Spring AI supports HuggingFace's language models through the HuggingFace Inference API. +HuggingFace provides access to thousands of pre-trained language models, from small efficient models to large state-of-the-art models, making advanced AI capabilities accessible through a simple API. -IMPORTANT: Text Generation Inference requires models to be compatible with its architecture-specific optimizations. While many popular LLMs are supported, not all models on Hugging Face Hub can be deployed using TGI. If you need to deploy other types of models, consider using standard Hugging Face Inference Endpoints instead. +IMPORTANT: The HuggingFace Chat implementation uses OpenAI-compatible endpoints (`/v1/chat/completions`). This provides broad compatibility with various HuggingFace deployment options including Inference Endpoints, Dedicated Endpoints, and Serverless Inference API. -TIP: For a complete and up-to-date list of supported models and architectures, see the link:https://huggingface.co/docs/text-generation-inference/en/supported_models[Text Generation Inference supported models documentation]. +TIP: For the most up-to-date list of supported models and deployment options, see the link:https://huggingface.co/docs/api-inference[HuggingFace Inference API documentation]. == Prerequisites -You will need to create an Inference Endpoint on Hugging Face and create an API token to access the endpoint. -Further details can be found link:https://huggingface.co/docs/inference-endpoints/index[here]. +You will need to create an API token with HuggingFace to access the Inference API. -The Spring AI project defines two configuration properties: +Create an account at https://huggingface.co/join[HuggingFace signup page] and generate a token on the https://huggingface.co/settings/tokens[Access Tokens page]. -1. `spring.ai.huggingface.chat.api-key`: Set this to the value of the API token obtained from Hugging Face. -2. `spring.ai.huggingface.chat.url`: Set this to the inference endpoint URL obtained when provisioning your model in Hugging Face. +The Spring AI project defines a configuration property named `spring.ai.huggingface.api-key` that you should set to the value of the API token obtained from huggingface.co. -You can find your inference endpoint URL on the Inference Endpoint's UI link:https://ui.endpoints.huggingface.co/[here]. 
- -You can set these configuration properties in your `application.properties` file: +You can set this configuration property in your `application.properties` file: [source,properties] ---- -spring.ai.huggingface.chat.api-key= -spring.ai.huggingface.chat.url= +spring.ai.huggingface.api-key= ---- -For enhanced security when handling sensitive information like API keys, you can use Spring Expression Language (SpEL) to reference custom environment variables: +For enhanced security when handling sensitive information like API keys, you can use Spring Expression Language (SpEL) to reference an environment variable: [source,yaml] ---- @@ -34,25 +30,21 @@ For enhanced security when handling sensitive information like API keys, you can spring: ai: huggingface: - chat: - api-key: ${HUGGINGFACE_API_KEY} - url: ${HUGGINGFACE_ENDPOINT_URL} + api-key: ${HUGGINGFACE_API_KEY} ---- [source,bash] ---- # In your environment or .env file -export HUGGINGFACE_API_KEY= -export HUGGINGFACE_ENDPOINT_URL= +export HUGGINGFACE_API_KEY= ---- -You can also set these configurations programmatically in your application code: +You can also set this configuration programmatically in your application code: [source,java] ---- -// Retrieve API key and endpoint URL from secure sources or environment variables +// Retrieve API key from a secure source or environment variable String apiKey = System.getenv("HUGGINGFACE_API_KEY"); -String endpointUrl = System.getenv("HUGGINGFACE_ENDPOINT_URL"); ---- === Add Repositories and BOM @@ -70,9 +62,13 @@ There has been a significant change in the Spring AI auto-configuration, starter Please refer to the https://docs.spring.io/spring-ai/reference/upgrade-notes.html[upgrade notes] for more information. ==== -Spring AI provides Spring Boot auto-configuration for the Hugging Face Chat Client. -To enable it add the following dependency to your project's Maven `pom.xml` file: +Spring AI provides Spring Boot auto-configuration for the HuggingFace Chat Model. +To enable it add the following dependency to your project's Maven `pom.xml` or Gradle `build.gradle` build files: +[tabs] +====== +Maven:: ++ [source, xml] ---- @@ -81,19 +77,52 @@ To enable it add the following dependency to your project's Maven `pom.xml` file ---- -or to your Gradle `build.gradle` build file. - +Gradle:: ++ [source,groovy] ---- dependencies { implementation 'org.springframework.ai:spring-ai-starter-model-huggingface' } ---- +====== TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. === Chat Properties +==== Retry Properties + +The prefix `spring.ai.retry` is used as the property prefix that lets you configure the retry mechanism for the HuggingFace Chat model. + +[cols="3,5,1", stripes=even] +|==== +| Property | Description | Default + +| spring.ai.retry.max-attempts | Maximum number of retry attempts. | 10 +| spring.ai.retry.backoff.initial-interval | Initial sleep duration for the exponential backoff policy. | 2 sec. +| spring.ai.retry.backoff.multiplier | Backoff interval multiplier. | 5 +| spring.ai.retry.backoff.max-interval | Maximum backoff duration. | 3 min. +| spring.ai.retry.on-client-errors | If false, throw a NonTransientAiException, and do not attempt retry for `4xx` client error codes | false +| spring.ai.retry.exclude-on-http-codes | List of HTTP status codes that should not trigger a retry (e.g. to throw NonTransientAiException). 
| empty +| spring.ai.retry.on-http-codes | List of HTTP status codes that should trigger a retry (e.g. to throw TransientAiException). | empty +|==== + +==== Connection Properties + +The prefix `spring.ai.huggingface` is used as the property prefix that lets you connect to HuggingFace Inference API. + +[cols="3,5,1", stripes=even] +|==== +| Property | Description | Default + +| spring.ai.huggingface.api-key | The API Key (token) | - +|==== + +NOTE: The API key is shared between the Chat and Embedding models. You only need to configure it once. + +==== Configuration Properties + [NOTE] ==== Enabling and disabling of the chat auto-configurations are now configured via top level properties with the prefix `spring.ai.model.chat`. @@ -105,30 +134,196 @@ To disable, spring.ai.model.chat=none (or any value which doesn't match huggingf This change is done to allow configuration of multiple models. ==== -The prefix `spring.ai.huggingface` is the property prefix that lets you configure the chat model implementation for Hugging Face. +The prefix `spring.ai.huggingface.chat` is the property prefix that lets you configure the chat model implementation for HuggingFace. + +NOTE: Default values shown below for numeric options (e.g., temperature, frequency-penalty, presence-penalty) are HuggingFace Inference API defaults when these parameters are not specified in the request. Spring AI itself does not set these values - they are `null` by default and only applied by the API if omitted. To explicitly configure these values, set them in your application properties or at runtime. [cols="3,5,1", stripes=even] |==== | Property | Description | Default -| spring.ai.huggingface.chat.api-key | API Key to authenticate with the Inference Endpoint. | - -| spring.ai.huggingface.chat.url | URL of the Inference Endpoint to connect to | - -| spring.ai.huggingface.chat.enabled (Removed and no longer valid) | Enable Hugging Face chat model. | true -| spring.ai.model.chat | Enable Hugging Face chat model. | huggingface + +| spring.ai.model.chat | Enable HuggingFace chat model. | huggingface +| spring.ai.huggingface.chat.url | Base URL for the HuggingFace Inference API Chat endpoint | +https://router.huggingface.co/v1+ +| spring.ai.huggingface.chat.options.model | The model to use. Examples: `meta-llama/Llama-3.2-3B-Instruct`, `mistralai/Mistral-7B-Instruct-v0.3`, `google/gemma-2-9b-it` | meta-llama/Llama-3.2-3B-Instruct +| spring.ai.huggingface.chat.options.temperature | The sampling temperature to use that controls the apparent creativity of generated completions. Higher values will make output more random while lower values will make results more focused and deterministic. | - +| spring.ai.huggingface.chat.options.max-tokens | The maximum number of tokens to generate in the chat completion. | - +| spring.ai.huggingface.chat.options.top-p | An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. | - +| spring.ai.huggingface.chat.options.frequency-penalty | Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. | - +| spring.ai.huggingface.chat.options.presence-penalty | Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. 
| - +| spring.ai.huggingface.chat.options.stop | Up to 4 sequences where the API will stop generating further tokens. | - +| spring.ai.huggingface.chat.options.seed | Integer seed for reproducibility. Makes repeated requests with the same seed and parameters return the same result. | - +| spring.ai.huggingface.chat.options.response-format | An object specifying the format that the model must output. Setting to `{"type": "json_object"}` enables JSON mode. Can be configured as a Map with type and optional schema fields. | - +| spring.ai.huggingface.chat.options.tool-prompt | A prompt to be appended before the tools when using function calling. | - +| spring.ai.huggingface.chat.options.logprobs | Whether to return log probabilities of the output tokens. If true, returns the log probabilities of each output token. | - +| spring.ai.huggingface.chat.options.top-logprobs | An integer between 0 and 5 specifying the number of most likely tokens to return at each token position. Requires logprobs to be set to true. | - +| spring.ai.huggingface.chat.options.tool-names | List of tools, identified by their names, to enable for function calling in a single prompt request. Tools with those names must exist in the ToolCallback registry. | - +| spring.ai.huggingface.chat.options.tool-callbacks | Tool Callbacks to register with the ChatModel for function calling. | - +| spring.ai.huggingface.chat.options.internal-tool-execution-enabled | If false, Spring AI will not handle the tool calls internally, but will proxy them to the client. If true (the default), Spring AI will handle the function calls internally. | true |==== -== Sample Controller (Auto-configuration) +NOTE: You can override the common `spring.ai.huggingface.api-key` for the `ChatModel` and `EmbeddingModel` implementations if needed. The `spring.ai.huggingface.chat.api-key` property (if set) takes precedence over the common property. + +TIP: All properties prefixed with `spring.ai.huggingface.chat.options` can be overridden at runtime by adding request-specific <> to the `Prompt` call. + +== Runtime Options [[chat-options]] + +The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatOptions.java[HuggingfaceChatOptions.java] provides the HuggingFace configurations, such as the model to use, the temperature, max tokens, etc. + +The default options can be configured using the `spring.ai.huggingface.chat.options` properties as well. + +At start-time, use the `HuggingfaceChatModel` constructor to set the default options used for all chat requests. +At run-time, you can override the default options by adding a `HuggingfaceChatOptions` instance as part of your `Prompt`. 
+ +For example, to override the default model name for a specific request: + +[source,java] +---- +ChatResponse response = chatModel.call( + new Prompt( + "Generate the names of 5 famous pirates.", + HuggingfaceChatOptions.builder() + .model("mistralai/Mistral-7B-Instruct-v0.3") + .temperature(0.4) + .build() + )); +---- + +=== Advanced Options + +You can use additional parameters for more control over the model's behavior: + +[source,java] +---- +// Using stop sequences to limit generation +ChatResponse response = chatModel.call( + new Prompt( + "Count from 1 to 10", + HuggingfaceChatOptions.builder() + .model("meta-llama/Llama-3.2-3B-Instruct") + .stopSequences(Arrays.asList("5", "STOP")) + .build() + )); + +// Using seed for reproducible outputs +ChatResponse response = chatModel.call( + new Prompt( + "Generate a random story", + HuggingfaceChatOptions.builder() + .seed(42) // Same seed produces same results + .temperature(0.7) + .build() + )); + +// Using JSON response format +Map responseFormat = new HashMap<>(); +responseFormat.put("type", "json_object"); +ChatResponse response = chatModel.call( + new Prompt( + "Generate a JSON object with fields: name, age, city", + HuggingfaceChatOptions.builder() + .responseFormat(responseFormat) + .build() + )); +---- + +TIP: In addition to model-specific `HuggingfaceChatOptions`, you can use portable https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java[ChatOptions] instance. This enables you to switch between different chat model providers with minimal code changes. + +== Function Calling + +You can register custom Java functions with the `HuggingfaceChatModel` and have the HuggingFace model intelligently choose to call them when appropriate. + +This is a powerful technique to connect the LLM capabilities with external tools and APIs. +Read more about link:https://docs.spring.io/spring-ai/reference/api/tools.html[Tool/Function Calling] in Spring AI. + +=== Example: Weather Service Function + +Here's a complete example demonstrating function calling with HuggingFace: + +[source,java] +---- +@Configuration +public class FunctionConfiguration { + + @Bean + @Description("Get the current weather conditions for a specific location") + public Function weatherFunction() { + return new MockWeatherService(); + } + + public record WeatherRequest(String location, String unit) {} + public record WeatherResponse(double temperature, double windSpeed, String forecast) {} +} + +@RestController +public class ChatController { + + private final ChatModel chatModel; + + public ChatController(ChatModel chatModel) { + this.chatModel = chatModel; + } + + @GetMapping("/ai/weather") + public String getWeather(@RequestParam String location) { + UserMessage userMessage = new UserMessage( + "What's the weather like in " + location + "?"); + + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + HuggingfaceChatOptions.builder() + .functionCallbacks(List.of(FunctionCallback.builder() + .function("weatherFunction", new MockWeatherService()) + .description("Get the current weather conditions") + .build())) + .build())); + + return response.getResult().getOutput().getText(); + } +} +---- + +IMPORTANT: Function calling support in HuggingFace requires both a compatible model AND provider. Not all models or providers support this feature. 
+ +**Model and Provider Requirements:** + +* **Provider Suffix Required:** Function-calling models typically require a provider suffix in the model name (e.g., `meta-llama/Llama-3.2-3B-Instruct:together`) +* **Supported Providers:** Common providers include `together`, `fastest`, and others depending on the model +* **Compatible Models:** See the link:https://huggingface.co/collections/MarketAgents/function-calling-models-tool-use[HuggingFace Function Calling Models Collection] for a curated list + +**Configuration Example:** + +[source,yaml] +---- +spring: + ai: + huggingface: + chat: + options: + model: meta-llama/Llama-3.2-3B-Instruct:together # Note the :together provider suffix +---- + +For more details about function calling with HuggingFace, see the link:https://huggingface.co/docs/inference-providers/guides/function-calling[HuggingFace Function Calling Guide]. + +NOTE: Streaming function calling is not yet supported in this release. Non-streaming function calling (using `ChatModel.call()`) works as expected. + +== Sample Controller https://start.spring.io/[Create] a new Spring Boot project and add the `spring-ai-starter-model-huggingface` to your pom (or gradle) dependencies. -Add an `application.properties` file, under the `src/main/resources` directory, to enable and configure the Hugging Face chat model: +Add an `application.yml` file, under the `src/main/resources` directory, to enable and configure the HuggingFace chat model: -[source,application.properties] +[source,yaml] ---- -spring.ai.huggingface.chat.api-key=YOUR_API_KEY -spring.ai.huggingface.chat.url=YOUR_INFERENCE_ENDPOINT_URL +spring: + ai: + huggingface: + api-key: ${HUGGINGFACE_API_KEY} + chat: + options: + model: meta-llama/Llama-3.2-3B-Instruct + temperature: 0.7 ---- -TIP: replace the `api-key` and `url` with your Hugging Face values. +TIP: Replace the `api-key` with your HuggingFace API token value. This will create a `HuggingfaceChatModel` implementation that you can inject into your class. Here is an example of a simple `@Controller` class that uses the chat model for text generations. @@ -138,25 +333,26 @@ Here is an example of a simple `@Controller` class that uses the chat model for @RestController public class ChatController { - private final HuggingfaceChatModel chatModel; + private final ChatModel chatModel; @Autowired - public ChatController(HuggingfaceChatModel chatModel) { + public ChatController(ChatModel chatModel) { this.chatModel = chatModel; } @GetMapping("/ai/generate") public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { - return Map.of("generation", this.chatModel.call(message)); + return Map.of("generation", chatModel.call(message)); } } ---- -== Manual Configuration +NOTE: Streaming is not currently supported by `HuggingfaceChatModel`. This feature is planned for a future release using WebClient integration. -The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java[HuggingfaceChatModel] implements the `ChatModel` interface and uses the <> to connect to the Hugging Face inference endpoints. +== Manual Configuration -Add the `spring-ai-huggingface` dependency to your project's Maven `pom.xml` file: +If you are not using Spring Boot, you can manually configure the HuggingFace Chat Model. 
+For this, add the `spring-ai-huggingface` dependency to your project's Maven `pom.xml` file:
 
 [source, xml]
 ----
@@ -177,14 +373,113 @@ dependencies {
 
 TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file.
 
-Next, create a `HuggingfaceChatModel` and use it for text generations:
+NOTE: The `spring-ai-huggingface` dependency also provides access to the `HuggingfaceEmbeddingModel`.
+For more information about the `HuggingfaceEmbeddingModel`, refer to the link:../embeddings/huggingface-embeddings.html[HuggingFace Embeddings] section.
+
+Next, create a `HuggingfaceChatModel` instance and use it for text generations:
 
 [source,java]
 ----
-HuggingfaceChatModel chatModel = new HuggingfaceChatModel(apiKey, url);
-
-ChatResponse response = this.chatModel.call(
+var huggingfaceApi = HuggingfaceApi.builder()
+    .baseUrl("https://router.huggingface.co/v1")
+    .apiKey(System.getenv("HUGGINGFACE_API_KEY"))
+    .build();
+
+var chatModel = HuggingfaceChatModel.builder()
+    .huggingfaceApi(huggingfaceApi)
+    .defaultOptions(HuggingfaceChatOptions.builder()
+        .model("meta-llama/Llama-3.2-3B-Instruct")
+        .temperature(0.7)
+        .build())
+    .build();
+
+ChatResponse response = chatModel.call(
     new Prompt("Generate the names of 5 famous pirates."));
 
 System.out.println(response.getResult().getOutput().getText());
 ----
+
+The `HuggingfaceChatOptions` provides the configuration information for the chat requests.
+Both the API and options classes offer a `builder()` for easy instance creation.
+
+== Low-level HuggingfaceApi Client [[low-level-api]]
+
+The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/api/HuggingfaceApi.java[HuggingfaceApi] provides a lightweight Java client for the link:https://huggingface.co/docs/api-inference[HuggingFace Inference API].
+
+The `HuggingfaceApi` supports:
+
+* **OpenAI-compatible Chat Completions**: Accessible via the `/v1/chat/completions` endpoint (relative to the chat base URL)
+* **Feature Extraction for Embeddings**: Accessible via the `/{model}/pipeline/feature-extraction` endpoint (relative to the embedding base URL: `https://router.huggingface.co/hf-inference/models`)
+
+Here's a simple example of how to use the `HuggingfaceApi` directly for chat completions:
+
+[source,java]
+----
+HuggingfaceApi huggingfaceApi = HuggingfaceApi.builder()
+    .baseUrl("https://router.huggingface.co/v1")
+    .apiKey(System.getenv("HUGGINGFACE_API_KEY"))
+    .build();
+
+// Create a user message
+HuggingfaceApi.Message userMessage = new HuggingfaceApi.Message(
+    "user",
+    "Explain quantum computing in simple terms");
+
+// Create chat request with options
+Map<String, Object> options = new HashMap<>();
+options.put("temperature", 0.7);
+
+HuggingfaceApi.ChatRequest chatRequest = new HuggingfaceApi.ChatRequest(
+    "meta-llama/Llama-3.2-3B-Instruct",
+    List.of(userMessage),
+    options);
+
+// Call the API
+HuggingfaceApi.ChatResponse response = huggingfaceApi.chat(chatRequest);
+String assistantReply = response.choices().get(0).message().content();
+----
+
+== Supported Models
+
+HuggingFace Inference API provides access to thousands of models.
Popular chat models include:
+
+* **Llama Models**: `meta-llama/Llama-3.2-3B-Instruct`, `meta-llama/Llama-3.1-8B-Instruct`
+* **Mistral Models**: `mistralai/Mistral-7B-Instruct-v0.3`, `mistralai/Mixtral-8x7B-Instruct-v0.1`
+* **Gemma Models**: `google/gemma-2-9b-it`, `google/gemma-2-27b-it`
+* **Qwen Models**: `Qwen/Qwen2.5-7B-Instruct`, `Qwen/Qwen2.5-72B-Instruct`
+* **DeepSeek Models**: `deepseek-ai/DeepSeek-V3`, `deepseek-ai/DeepSeek-R1`
+
+You can browse all available models at https://huggingface.co/models?pipeline_tag=text-generation&sort=trending[HuggingFace Model Hub].
+
+IMPORTANT: Ensure the model you choose supports the OpenAI-compatible chat completions endpoint. Most instruction-tuned models work well, but always check the model card for API compatibility information.
+
+== Observability
+
+Spring AI provides built-in observability for HuggingFace Chat models through Micrometer and Spring Boot Actuator.
+
+To enable observability:
+
+1. Add the Spring Boot Actuator dependency to your project
+2. Enable metrics in your `application.yml`:
+
+[source,yaml]
+----
+management:
+  endpoints:
+    web:
+      exposure:
+        include: "*"
+  simple:
+    metrics:
+      export:
+        enabled: true
+----
+
+The HuggingFace Chat model will automatically export metrics including:
+
+* Request count
+* Request duration
+* Token usage (when provided by the model)
+* Error rates
+
+These metrics are tagged with the model name and provider for easy filtering and aggregation.
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/huggingface-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/huggingface-embeddings.adoc
new file mode 100644
index 00000000000..8dc081a18a9
--- /dev/null
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/huggingface-embeddings.adoc
@@ -0,0 +1,293 @@
+= HuggingFace Embeddings
+
+Spring AI supports HuggingFace's text embedding models through the HuggingFace Inference API.
+HuggingFace's text embeddings measure the relatedness of text strings using various transformer-based models.
+An embedding is a vector (list) of floating point numbers. The distance between two vectors measures their relatedness. Small distances suggest high relatedness and large distances suggest low relatedness.
+
+IMPORTANT: The HuggingFace Embedding implementation uses the Feature Extraction pipeline endpoint. Make sure the model you select supports feature extraction for text embeddings.
+
+== Prerequisites
+
+You will need to create an API token with HuggingFace to access HuggingFace Inference API embedding models.
+
+Create an account at https://huggingface.co/join[HuggingFace signup page] and generate a token on the https://huggingface.co/settings/tokens[Access Tokens page].
+
+The Spring AI project defines a configuration property named `spring.ai.huggingface.api-key` that you should set to the value of the API token obtained from huggingface.co.
+
+You can set this configuration property in your `application.properties` file:
+
+[source,properties]
+----
+spring.ai.huggingface.api-key=
+----
+
+For enhanced security when handling sensitive information like API keys, you can use Spring Expression Language (SpEL) to reference an environment variable:
+
+[source,yaml]
+----
+# In application.yml
+spring:
+  ai:
+    huggingface:
+      api-key: ${HUGGINGFACE_API_KEY}
+----
+
+[source,bash]
+----
+# In your environment or .env file
+export HUGGINGFACE_API_KEY=
+----
+
+You can also set this configuration programmatically in your application code:
+
+[source,java]
+----
+// Retrieve API key from a secure source or environment variable
+String apiKey = System.getenv("HUGGINGFACE_API_KEY");
+----
+
+=== Add Repositories and BOM
+
+Spring AI artifacts are published in Maven Central and Spring Snapshot repositories.
+Refer to the xref:getting-started.adoc#artifact-repositories[Artifact Repositories] section to add these repositories to your build system.
+
+To help with dependency management, Spring AI provides a BOM (bill of materials) to ensure that a consistent version of Spring AI is used throughout the entire project. Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build system.
+
+
+== Auto-configuration
+
+[NOTE]
+====
+There has been a significant change in the Spring AI auto-configuration, starter modules' artifact names.
+Please refer to the https://docs.spring.io/spring-ai/reference/upgrade-notes.html[upgrade notes] for more information.
+====
+
+Spring AI provides Spring Boot auto-configuration for the HuggingFace Embedding Model.
+To enable it, add the following dependency to your project's Maven `pom.xml` file:
+
+[source, xml]
+----
+<dependency>
+    <groupId>org.springframework.ai</groupId>
+    <artifactId>spring-ai-starter-model-huggingface</artifactId>
+</dependency>
+----
+
+or to your Gradle `build.gradle` build file.
+
+[source,groovy]
+----
+dependencies {
+    implementation 'org.springframework.ai:spring-ai-starter-model-huggingface'
+}
+----
+
+TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file.
+
+=== Embedding Properties
+
+==== Retry Properties
+
+The prefix `spring.ai.retry` is used as the property prefix that lets you configure the retry mechanism for the HuggingFace Embedding model.
+
+[cols="3,5,1", stripes=even]
+|====
+| Property | Description | Default
+
+| spring.ai.retry.max-attempts | Maximum number of retry attempts. | 10
+| spring.ai.retry.backoff.initial-interval | Initial sleep duration for the exponential backoff policy. | 2 sec.
+| spring.ai.retry.backoff.multiplier | Backoff interval multiplier. | 5
+| spring.ai.retry.backoff.max-interval | Maximum backoff duration. | 3 min.
+| spring.ai.retry.on-client-errors | If false, throw a NonTransientAiException, and do not attempt retry for `4xx` client error codes | false
+| spring.ai.retry.exclude-on-http-codes | List of HTTP status codes that should not trigger a retry (e.g. to throw NonTransientAiException). | empty
+| spring.ai.retry.on-http-codes | List of HTTP status codes that should trigger a retry (e.g. to throw TransientAiException). | empty
+|====
+
+==== Connection Properties
+
+The prefix `spring.ai.huggingface` is used as the property prefix that lets you connect to the HuggingFace Inference API.
+
+[cols="3,5,1", stripes=even]
+|====
+| Property | Description | Default
+
+| spring.ai.huggingface.api-key | The API Key (token) | -
+|====
+
+NOTE: The API key is shared between the Chat and Embedding models. You only need to configure it once.
+
+==== Configuration Properties
+
+[NOTE]
+====
+Enabling and disabling of the embedding auto-configurations are now configured via top level properties with the prefix `spring.ai.model.embedding`.
+
+To enable, spring.ai.model.embedding=huggingface (It is enabled by default)
+
+To disable, spring.ai.model.embedding=none (or any value which doesn't match huggingface)
+
+This change is done to allow configuration of multiple models.
+====
+
+The prefix `spring.ai.huggingface.embedding` is the property prefix that configures the `EmbeddingModel` implementation for HuggingFace.
+
+[cols="3,5,1", stripes=even]
+|====
+| Property | Description | Default
+
+| spring.ai.model.embedding | Enable HuggingFace embedding model. | huggingface
+| spring.ai.huggingface.embedding.enabled | Enable HuggingFace embedding model (deprecated, use spring.ai.model.embedding instead) | true
+| spring.ai.huggingface.embedding.url | Base URL for the HuggingFace Inference API Feature Extraction endpoint | +https://router.huggingface.co/hf-inference/models+
+| spring.ai.huggingface.embedding.options.model | The model to use for embeddings | sentence-transformers/all-MiniLM-L6-v2
+| spring.ai.huggingface.embedding.options.normalize | Whether to normalize embeddings to unit length | -
+| spring.ai.huggingface.embedding.options.prompt-name | Name of a predefined prompt from model config to apply | -
+| spring.ai.huggingface.embedding.options.truncate | Whether to truncate text exceeding model's max length | -
+| spring.ai.huggingface.embedding.options.truncation-direction | Which side to truncate: "left" or "right" | -
+|====
+
+NOTE: HuggingFace Embedding uses the Feature Extraction API. The options `normalize`, `prompt_name`, `truncate`, and `truncation_direction` are part of the standard Feature Extraction API specification.
+
+TIP: All properties prefixed with `spring.ai.huggingface.embedding.options` can be overridden at runtime by adding a request-specific <<embedding-options>> instance to the `EmbeddingRequest` call.
+
+== Runtime Options [[embedding-options]]
+
+The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceEmbeddingOptions.java[HuggingfaceEmbeddingOptions.java] provides the HuggingFace configurations, such as the model to use.
+
+The default options can be configured using the `spring.ai.huggingface.embedding.options` properties as well.
+
+At start-time use the `HuggingfaceEmbeddingModel` constructor to set the default options used for all embedding requests.
+At run-time you can override the default options, using a `HuggingfaceEmbeddingOptions` instance as part of your `EmbeddingRequest`.
+
+For example, to override the default model name for a specific request:
+
+[source,java]
+----
+EmbeddingResponse embeddingResponse = embeddingModel.call(
+    new EmbeddingRequest(List.of("Hello World", "World is big and salvation is near"),
+        HuggingfaceEmbeddingOptions.builder()
+            .model("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
+            .build()));
+----
+
+=== Using Advanced Options
+
+You can use advanced options to customize the embedding behavior:
+
+[source,java]
+----
+EmbeddingResponse queryEmbedding = embeddingModel.call(
+    new EmbeddingRequest(List.of("What is machine learning?"),
+        HuggingfaceEmbeddingOptions.builder()
+            .promptName("query") // Apply "query" prompt from model config
+            .truncate(true) // Truncate long text
+            .truncationDirection("right") // Truncate from the right
+            .normalize(true) // Normalize embeddings to unit length
+            .build()));
+
+EmbeddingResponse documentEmbedding = embeddingModel.call(
+    new EmbeddingRequest(List.of("Machine learning is a subset of AI..."),
+        HuggingfaceEmbeddingOptions.builder()
+            .promptName("passage") // Apply "passage" prompt for documents
+            .normalize(true)
+            .build()));
+----
+
+TIP: You can use a portable `EmbeddingOptions` implementation for runtime configuration, enabling you to switch between different embedding model providers with minimal code changes.
+
+== Sample Controller
+
+This will create an `EmbeddingModel` implementation that you can inject into your class.
+Here is an example of a simple `@Controller` class that uses the `EmbeddingModel` implementation.
+
+[source,application.properties]
+----
+spring.ai.huggingface.api-key=YOUR_API_KEY
+spring.ai.huggingface.embedding.options.model=sentence-transformers/all-MiniLM-L6-v2
+----
+
+[source,java]
+----
+@RestController
+public class EmbeddingController {
+
+    private final EmbeddingModel embeddingModel;
+
+    @Autowired
+    public EmbeddingController(EmbeddingModel embeddingModel) {
+        this.embeddingModel = embeddingModel;
+    }
+
+    @GetMapping("/ai/embedding")
+    public Map embed(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) {
+        EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of(message));
+        return Map.of("embedding", embeddingResponse);
+    }
+}
+----
+
+== Manual Configuration
+
+If you are not using Spring Boot, you can manually configure the HuggingFace Embedding Model.
+For this, add the `spring-ai-huggingface` dependency to your project's Maven `pom.xml` file:
+
+[source, xml]
+----
+<dependency>
+    <groupId>org.springframework.ai</groupId>
+    <artifactId>spring-ai-huggingface</artifactId>
+</dependency>
+----
+
+or to your Gradle `build.gradle` build file.
+
+[source,groovy]
+----
+dependencies {
+    implementation 'org.springframework.ai:spring-ai-huggingface'
+}
+----
+
+TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file.
+
+NOTE: The `spring-ai-huggingface` dependency also provides access to the `HuggingfaceChatModel`.
+For more information about the `HuggingfaceChatModel`, refer to the link:../chat/huggingface-chat.html[HuggingFace Chat Client] section.
+ +Next, create a `HuggingfaceEmbeddingModel` instance and use it to compute the similarity between two input texts: + +[source,java] +---- +var huggingfaceApi = HuggingfaceApi.builder() + .baseUrl("https://router.huggingface.co/hf-inference/models") + .apiKey(System.getenv("HUGGINGFACE_API_KEY")) + .build(); + +var embeddingModel = HuggingfaceEmbeddingModel.builder() + .huggingfaceApi(huggingfaceApi) + .defaultOptions(HuggingfaceEmbeddingOptions.builder() + .model("sentence-transformers/all-MiniLM-L6-v2") + .build()) + .build(); + +EmbeddingResponse embeddingResponse = embeddingModel + .embedForResponse(List.of("Hello World", "World is big and salvation is near")); +---- + +The `HuggingfaceEmbeddingOptions` provides the configuration information for the embedding requests. +Both the API and options classes offer a `builder()` for easy instance creation. + +NOTE: The HuggingFace Embedding implementation returns `EmptyUsage` for usage metadata since the HuggingFace Inference API does not provide token usage information for embedding requests. + +== Supported Models + +HuggingFace Inference API supports a wide range of embedding models through the Feature Extraction pipeline. +Popular choices include: + +- `sentence-transformers/all-MiniLM-L6-v2` - Fast, efficient, general-purpose embeddings (default) +- `sentence-transformers/all-mpnet-base-v2` - High-quality general-purpose embeddings +- `sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2` - Multilingual support (50+ languages) +- `BAAI/bge-small-en-v1.5` - High-quality English embeddings +- `intfloat/e5-large-v2` - State-of-the-art embeddings for various tasks + +You can find more embedding models on the https://huggingface.co/models?pipeline_tag=feature-extraction&sort=trending[HuggingFace Model Hub]. + +IMPORTANT: Ensure the model you choose supports the Feature Extraction pipeline and is compatible with the HuggingFace Inference API. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/providers/huggingface/index.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/providers/huggingface/index.adoc index 06860b9863a..af7c77f1858 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/providers/huggingface/index.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/providers/huggingface/index.adoc @@ -1,14 +1,136 @@ [[hugging-face]] -= Hugging Face += HuggingFace -One of the easiest ways you can get access to many machine learning and artificial intelligence models is by using the https://en.wikipedia.org/wiki/Hugging_Face[Hugging Face's] https://huggingface.co/inference-endpoints[Inference Endpoints]. +Spring AI provides comprehensive integration with link:https://huggingface.co[HuggingFace], one of the most popular platforms for machine learning and artificial intelligence models. +HuggingFace offers access to thousands of pre-trained models, datasets, and deployment options, making cutting-edge AI capabilities accessible to developers. -Hugging Face Hub is a platform that provides a collaborative environment for creating and sharing tens of thousands of Open Source ML/AI models, data sets, and demo applications. +== Overview -Inference Endpoints let you deploy AI Models on dedicated infrastructure with a pay-as-you-go billing model. 
+HuggingFace Hub is a collaborative platform providing: + +* **Tens of thousands of open-source AI models** - From small efficient models to large state-of-the-art language models +* **Model hosting and deployment** - Inference Endpoints, Dedicated Endpoints, and Serverless API +* **Diverse model types** - Text generation, embeddings, image generation, speech recognition, and more +* **Community-driven** - Active open-source community continuously contributing new models + +link:https://huggingface.co/inference-endpoints[Inference Endpoints] enable you to deploy AI models on dedicated infrastructure with a pay-as-you-go billing model. You can use infrastructure provided by Amazon Web Services, Microsoft Azure, and Google Cloud Platform. -Hugging Face lets you run the models on your own machine, but it is quite common to not have enough CPU/GPU resources to run the larger, more AI-focused models. -It provides access to Meta's recent (August 2023) Llama 2 and CodeLlama 2 models and provides the https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard[Open LLM Leaderboard], where you can quickly discover high quality models. +== Spring AI Support + +Spring AI provides native support for HuggingFace models through two main integrations: + +=== Chat Models +The xref:api/chat/huggingface.adoc[HuggingFace Chat integration] enables you to use any text generation model from HuggingFace Hub for conversational AI applications. + +**Key Features:** +* OpenAI-compatible API endpoint (`/v1/chat/completions`) +* Support for thousands of instruction-tuned models (Llama, Mistral, Gemma, Qwen, etc.) +* Function/tool calling capabilities (model-dependent) +* Full observability and metrics support + +**Example Configuration:** +[source,yaml] +---- +spring: + ai: + huggingface: + api-key: ${HUGGINGFACE_API_KEY} + chat: + options: + model: meta-llama/Llama-3.2-3B-Instruct + temperature: 0.7 +---- + +link:api/chat/huggingface.adoc[Learn more about HuggingFace Chat integration →] + +=== Embedding Models +The xref:api/embeddings/huggingface-embeddings.adoc[HuggingFace Embedding integration] enables you to generate text embeddings using HuggingFace's Feature Extraction pipeline. + +**Key Features:** +* Access to specialized embedding models (sentence-transformers, BAAI, intfloat, etc.) +* Multilingual embedding support +* Semantic search and similarity calculations +* Full observability and metrics support + +**Example Configuration:** +[source,yaml] +---- +spring: + ai: + huggingface: + api-key: ${HUGGINGFACE_API_KEY} + embedding: + options: + model: sentence-transformers/all-MiniLM-L6-v2 +---- + +link:api/embeddings/huggingface-embeddings.adoc[Learn more about HuggingFace Embeddings integration →] + +== Popular Models + +HuggingFace provides access to a vast collection of models. 
Here are some popular choices: + +=== Chat/Text Generation Models +* **Llama Series**: `meta-llama/Llama-3.2-3B-Instruct`, `meta-llama/Llama-3.1-8B-Instruct` +* **Mistral Series**: `mistralai/Mistral-7B-Instruct-v0.3`, `mistralai/Mixtral-8x7B-Instruct-v0.1` +* **Gemma Series**: `google/gemma-2-9b-it`, `google/gemma-2-27b-it` +* **Qwen Series**: `Qwen/Qwen2.5-7B-Instruct`, `Qwen/Qwen2.5-72B-Instruct` + +Browse more at link:https://huggingface.co/models?pipeline_tag=text-generation&sort=trending[HuggingFace Model Hub (Text Generation)] + +=== Embedding Models +* **Sentence Transformers**: `sentence-transformers/all-MiniLM-L6-v2`, `sentence-transformers/all-mpnet-base-v2` +* **BGE Series**: `BAAI/bge-small-en-v1.5`, `BAAI/bge-large-en-v1.5` +* **E5 Series**: `intfloat/e5-large-v2`, `intfloat/e5-base-v2` +* **Multilingual**: `sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2` + +Browse more at link:https://huggingface.co/models?pipeline_tag=feature-extraction&sort=trending[HuggingFace Model Hub (Feature Extraction)] + +== Deployment Options + +HuggingFace offers several deployment options that work seamlessly with Spring AI: + +=== Inference Endpoints (Recommended) +Dedicated infrastructure for production workloads: + +* **Flexible scaling**: Choose CPU or GPU instances based on your needs +* **Pay-as-you-go**: $0.06 per CPU core/hr, $0.6 per GPU/hr (pricing may vary) +* **Multiple cloud providers**: AWS, Azure, Google Cloud +* **Automatic scaling**: Scale up or down based on demand + +link:https://huggingface.co/inference-endpoints[Learn more about Inference Endpoints] + +=== Serverless Inference API +Free tier for development and testing: + +* **No infrastructure management**: Fully managed by HuggingFace +* **Quick experimentation**: Test models without setup +* **Rate-limited**: Suitable for development, not production + +link:https://huggingface.co/docs/api-inference/index[Learn more about Serverless Inference API] + +=== Dedicated Endpoints +Enterprise-grade deployment for mission-critical applications: + +* **Reserved capacity**: Guaranteed availability and performance +* **SLA guarantees**: Production-ready reliability +* **Custom configurations**: Tailored to your specific requirements + +== Getting Started + +To get started with HuggingFace in Spring AI: + +1. **Create a HuggingFace account** at link:https://huggingface.co/join[https://huggingface.co/join] +2. **Generate an API token** on the link:https://huggingface.co/settings/tokens[Access Tokens page] +3. **Add the Spring AI HuggingFace starter** to your project +4. **Configure your application** with the API key and desired model + +Complete setup instructions are available in the Chat and Embedding documentation pages linked above. + +== Additional Resources -While Hugging Face has a free hosting tier, which is very useful for quickly evaluating if a specific ML/AI Model fits your needs, they do not let you access many of those models on the free tier by using the https://huggingface.co/docs/text-generation-inference/main/en/index[Text Generation Interface API]. If you want to end up on production anyway, with a stable API, pay a few cents to try out a reliable solution. Prices are as low as $0.06 per CPU core/hr and $0.6 per GPU/hr. 
+* link:https://huggingface.co/docs[HuggingFace Documentation] +* link:https://huggingface.co/docs/api-inference[Inference API Documentation] +* link:https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard[Open LLM Leaderboard] +* link:https://huggingface.co/pricing[HuggingFace Pricing]