diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiAssistantMessage.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiAssistantMessage.java new file mode 100644 index 00000000000..c18265c33ac --- /dev/null +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiAssistantMessage.java @@ -0,0 +1,181 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mistralai; + +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.content.Media; + +/** + * A Mistral AI specific implementation of {@link AssistantMessage} that supports + * additional fields returned by Magistral reasoning models. + * + *

+ * Magistral models (like magistral-medium-latest and magistral-small-latest) return + * thinking/reasoning content alongside the regular response content. This class captures + * both the final response text and the intermediate reasoning process. + *

+ * + * @author Kyle Kreuter + * @since 1.1.0 + */ +public class MistralAiAssistantMessage extends AssistantMessage { + + /** + * The thinking/reasoning content from Magistral models. This contains the model's + * intermediate reasoning steps before producing the final response. + */ + private String thinkingContent; + + /** + * Constructs a new MistralAiAssistantMessage with all fields. + * @param content the main text content of the message + * @param thinkingContent the thinking/reasoning content from Magistral models + * @param properties additional metadata properties + * @param toolCalls list of tool calls requested by the model + * @param media list of media attachments + */ + protected MistralAiAssistantMessage(String content, String thinkingContent, Map properties, + List toolCalls, List media) { + super(content, properties, toolCalls, media); + this.thinkingContent = thinkingContent; + } + + /** + * Returns the thinking/reasoning content from Magistral models. + * @return the thinking content, or null if not available + */ + public String getThinkingContent() { + return this.thinkingContent; + } + + /** + * Sets the thinking/reasoning content. + * @param thinkingContent the thinking content to set + * @return this instance for method chaining + */ + public MistralAiAssistantMessage setThinkingContent(String thinkingContent) { + this.thinkingContent = thinkingContent; + return this; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof MistralAiAssistantMessage that)) { + return false; + } + if (!super.equals(o)) { + return false; + } + return Objects.equals(this.thinkingContent, that.thinkingContent); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), this.thinkingContent); + } + + @Override + public String toString() { + return "MistralAiAssistantMessage{" + "media=" + this.media + ", messageType=" + this.messageType + + ", metadata=" + this.metadata + ", thinkingContent='" + this.thinkingContent + '\'' + + ", textContent='" + this.textContent + '\'' + '}'; + } + + /** + * Builder for creating MistralAiAssistantMessage instances. + */ + public static final class Builder { + + private String content; + + private Map properties = Map.of(); + + private List toolCalls = List.of(); + + private List media = List.of(); + + private String thinkingContent; + + /** + * Sets the main text content. + * @param content the content to set + * @return this builder + */ + public Builder content(String content) { + this.content = content; + return this; + } + + /** + * Sets the metadata properties. + * @param properties the properties to set + * @return this builder + */ + public Builder properties(Map properties) { + this.properties = properties; + return this; + } + + /** + * Sets the tool calls. + * @param toolCalls the tool calls to set + * @return this builder + */ + public Builder toolCalls(List toolCalls) { + this.toolCalls = toolCalls; + return this; + } + + /** + * Sets the media attachments. + * @param media the media to set + * @return this builder + */ + public Builder media(List media) { + this.media = media; + return this; + } + + /** + * Sets the thinking/reasoning content from Magistral models. + * @param thinkingContent the thinking content to set + * @return this builder + */ + public Builder thinkingContent(String thinkingContent) { + this.thinkingContent = thinkingContent; + return this; + } + + /** + * Builds the MistralAiAssistantMessage instance. + * @return a new MistralAiAssistantMessage + */ + public MistralAiAssistantMessage build() { + return new MistralAiAssistantMessage(this.content, this.thinkingContent, this.properties, this.toolCalls, + this.media); + } + + } + +} diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index 431c25afc19..02eff86e661 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -358,8 +358,8 @@ private Generation buildGeneration(Choice choice, Map metadata) toolCall.function().name(), toolCall.function().arguments())) .toList(); - var assistantMessage = AssistantMessage.builder() - .content(choice.message().content()) + var assistantMessage = new MistralAiAssistantMessage.Builder().content(choice.message().content()) + .thinkingContent(choice.message().thinkingContent()) .properties(metadata) .toolCalls(toolCalls) .build(); diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java index e55987e91f2..0b9510600fb 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java @@ -63,6 +63,7 @@ * @author Thomas Vitale * @author Jason Smith * @author Nicolas Krier + * @author Kyle Kreuter * @since 1.0.0 */ public class MistralAiApi { @@ -209,6 +210,51 @@ public Flux chatCompletionStream(ChatCompletionRequest chat .flatMap(mono -> mono); } + /** + * Sealed interface for content chunks returned by Magistral reasoning models. + * Magistral models can return content as an array of typed blocks instead of a simple + * string. + * + * @since 1.0.0 + */ + public sealed interface ContentChunk permits TextChunk, ThinkChunk, ReferenceChunk { + + } + + /** + * A text content chunk containing the main response text. + * + * @param text the text content + */ + @JsonInclude(Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record TextChunk(@JsonProperty("text") String text) implements ContentChunk { + + } + + /** + * A thinking/reasoning content chunk from Magistral models. Contains the model's + * intermediate reasoning process. + * + * @param thinking the thinking/reasoning content + */ + @JsonInclude(Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ThinkChunk(@JsonProperty("thinking") String thinking) implements ContentChunk { + + } + + /** + * A reference content chunk containing citation reference IDs. + * + * @param referenceIds list of reference IDs for citations + */ + @JsonInclude(Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ReferenceChunk(@JsonProperty("reference_ids") List referenceIds) implements ContentChunk { + + } + /** * The reason the model stopped generating tokens. */ @@ -806,7 +852,9 @@ public ChatCompletionMessage(Object content, Role role) { } /** - * Get message content as String. + * Returns the text content of the message. For reasoning models (Magistral), + * extracts the text block from the content array. + * @return the text content or null if not available */ public String content() { if (this.rawContent == null) { @@ -815,7 +863,132 @@ public String content() { if (this.rawContent instanceof String text) { return text; } - throw new IllegalStateException("The content is not a string!"); + if (this.rawContent instanceof List blocks) { + StringBuilder textBuilder = new StringBuilder(); + for (Object block : blocks) { + if (block instanceof Map map && "text".equals(map.get("type"))) { + Object text = map.get("text"); + if (text instanceof String s) { + if (!textBuilder.isEmpty()) { + textBuilder.append("\n"); + } + textBuilder.append(s); + } + } + } + return textBuilder.isEmpty() ? null : textBuilder.toString(); + } + throw new IllegalStateException("Unexpected content type: " + rawContent.getClass()); + } + + /** + * Returns the thinking/reasoning content from Magistral models. For non-Magistral + * models or when no thinking content is present, returns null. + * @return the thinking content or null if not available + */ + public String thinkingContent() { + if (this.rawContent == null) { + return null; + } + if (this.rawContent instanceof String) { + return null; + } + if (this.rawContent instanceof List blocks) { + StringBuilder thinkingBuilder = new StringBuilder(); + for (Object block : blocks) { + if (block instanceof Map map && "thinking".equals(map.get("type"))) { + Object thinking = map.get("thinking"); + if (thinking instanceof List thinkingBlocks) { + for (Object thinkingBlock : thinkingBlocks) { + if (thinkingBlock instanceof Map thinkingMap + && "text".equals(thinkingMap.get("type"))) { + Object text = thinkingMap.get("text"); + if (text instanceof String s) { + if (!thinkingBuilder.isEmpty()) { + thinkingBuilder.append("\n"); + } + thinkingBuilder.append(s); + } + } + } + } + else if (thinking instanceof String s) { + if (!thinkingBuilder.isEmpty()) { + thinkingBuilder.append("\n"); + } + thinkingBuilder.append(s); + } + } + } + return thinkingBuilder.isEmpty() ? null : thinkingBuilder.toString(); + } + return null; + } + + /** + * Parses the raw content into a list of typed ContentChunk objects. For string + * content, returns a single TextChunk. For array content from Magistral models, + * parses each block into its appropriate type. + * @return list of ContentChunk objects, or empty list if content is null + */ + @SuppressWarnings("unchecked") + public List contentChunks() { + if (this.rawContent == null) { + return List.of(); + } + if (this.rawContent instanceof String text) { + return List.of(new TextChunk(text)); + } + if (this.rawContent instanceof List blocks) { + List chunks = new java.util.ArrayList<>(); + for (Object block : blocks) { + if (block instanceof Map map) { + String type = (String) map.get("type"); + if ("text".equals(type)) { + String text = (String) map.get("text"); + if (text != null) { + chunks.add(new TextChunk(text)); + } + } + else if ("thinking".equals(type)) { + Object thinking = map.get("thinking"); + if (thinking instanceof List thinkingBlocks) { + StringBuilder thinkingBuilder = new StringBuilder(); + for (Object thinkingBlock : thinkingBlocks) { + if (thinkingBlock instanceof Map thinkingMap + && "text".equals(thinkingMap.get("type"))) { + Object text = thinkingMap.get("text"); + if (text instanceof String s) { + if (!thinkingBuilder.isEmpty()) { + thinkingBuilder.append("\n"); + } + thinkingBuilder.append(s); + } + } + } + if (!thinkingBuilder.isEmpty()) { + chunks.add(new ThinkChunk(thinkingBuilder.toString())); + } + } + else if (thinking instanceof String s) { + chunks.add(new ThinkChunk(s)); + } + } + else if ("reference".equals(type)) { + Object refIds = map.get("reference_ids"); + if (refIds instanceof List ids) { + List referenceIds = ((List) ids).stream() + .filter(id -> id instanceof Number) + .map(id -> ((Number) id).intValue()) + .toList(); + chunks.add(new ReferenceChunk(referenceIds)); + } + } + } + } + return chunks; + } + return List.of(); } /** diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiAssistantMessageTests.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiAssistantMessageTests.java new file mode 100644 index 00000000000..65a584137b2 --- /dev/null +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiAssistantMessageTests.java @@ -0,0 +1,326 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mistralai; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.chat.messages.AssistantMessage.ToolCall; +import org.springframework.ai.content.Media; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNoException; + +/** + * Unit tests for {@link MistralAiAssistantMessage}. Tests the builder pattern, + * equals/hashCode contract, and proper handling of the thinkingContent field for + * Magistral reasoning models. + * + * @author Kyle Kreuter + */ +class MistralAiAssistantMessageTests { + + // Builder Tests + + @Test + void testBuildMessageWithContentOnly() { + String content = "Hello, world!"; + MistralAiAssistantMessage message = new MistralAiAssistantMessage.Builder().content(content).build(); + + assertThat(message.getText()).isEqualTo(content); + assertThat(message.getThinkingContent()).isNull(); + assertThat(message.getToolCalls()).isEmpty(); + assertThat(message.getMedia()).isEmpty(); + } + + @Test + void testBuildMessageWithContentAndThinkingContent() { + String content = "The answer is 42."; + String thinkingContent = "Let me calculate this step by step..."; + + MistralAiAssistantMessage message = new MistralAiAssistantMessage.Builder().content(content) + .thinkingContent(thinkingContent) + .build(); + + assertThat(message.getText()).isEqualTo(content); + assertThat(message.getThinkingContent()).isEqualTo(thinkingContent); + } + + @Test + void testBuildMessageWithAllProperties() { + String content = "Response content"; + String thinkingContent = "Thinking process"; + Map properties = new HashMap<>(); + properties.put("key1", "value1"); + properties.put("key2", 123); + + List toolCalls = List.of(new ToolCall("1", "function", "testFunction", "{}")); + + List media = List.of(); + + MistralAiAssistantMessage message = new MistralAiAssistantMessage.Builder().content(content) + .thinkingContent(thinkingContent) + .properties(properties) + .toolCalls(toolCalls) + .media(media) + .build(); + + assertThat(message.getText()).isEqualTo(content); + assertThat(message.getThinkingContent()).isEqualTo(thinkingContent); + assertThat(message.getMetadata()).containsAllEntriesOf(properties); + assertThat(message.getToolCalls()).isEqualTo(toolCalls); + assertThat(message.getMedia()).isEqualTo(media); + } + + @Test + void testBuildMessageWithNullContent() { + MistralAiAssistantMessage message = new MistralAiAssistantMessage.Builder().content(null).build(); + + assertThat(message.getText()).isNull(); + assertThat(message.getThinkingContent()).isNull(); + } + + @Test + void testBuildMessageWithEmptyContent() { + MistralAiAssistantMessage message = new MistralAiAssistantMessage.Builder().content("").build(); + + assertThat(message.getText()).isEmpty(); + } + + @Test + void testDefaultEmptyCollectionsWhenNotSpecified() { + MistralAiAssistantMessage message = new MistralAiAssistantMessage.Builder().content("content").build(); + + assertThat(message.getToolCalls()).isNotNull().isEmpty(); + assertThat(message.getMedia()).isNotNull().isEmpty(); + assertThat(message.getMetadata()).isNotNull(); + } + + // Setter Tests + + @Test + void testSetAndGetThinkingContent() { + MistralAiAssistantMessage message = new MistralAiAssistantMessage.Builder().content("content").build(); + + String thinkingContent = "New thinking content"; + message.setThinkingContent(thinkingContent); + + assertThat(message.getThinkingContent()).isEqualTo(thinkingContent); + } + + @Test + void testMethodChainingOnSetThinkingContent() { + MistralAiAssistantMessage message = new MistralAiAssistantMessage.Builder().content("content").build(); + + MistralAiAssistantMessage result = message.setThinkingContent("thinking"); + + assertThat(result).isSameAs(message); + assertThat(message.getThinkingContent()).isEqualTo("thinking"); + } + + @Test + void testSettingThinkingContentToNull() { + MistralAiAssistantMessage message = new MistralAiAssistantMessage.Builder().content("content") + .thinkingContent("initial thinking") + .build(); + + message.setThinkingContent(null); + + assertThat(message.getThinkingContent()).isNull(); + } + + // Equals and HashCode Tests + + @Test + void testEqualityForSameContentAndThinkingContent() { + MistralAiAssistantMessage message1 = new MistralAiAssistantMessage.Builder().content("content") + .thinkingContent("thinking") + .build(); + + MistralAiAssistantMessage message2 = new MistralAiAssistantMessage.Builder().content("content") + .thinkingContent("thinking") + .build(); + + assertThat(message1).isEqualTo(message2); + assertThat(message1.hashCode()).isEqualTo(message2.hashCode()); + } + + @Test + void testInequalityForDifferentThinkingContent() { + MistralAiAssistantMessage message1 = new MistralAiAssistantMessage.Builder().content("content") + .thinkingContent("thinking1") + .build(); + + MistralAiAssistantMessage message2 = new MistralAiAssistantMessage.Builder().content("content") + .thinkingContent("thinking2") + .build(); + + assertThat(message1).isNotEqualTo(message2); + } + + @Test + void testInequalityForDifferentContent() { + MistralAiAssistantMessage message1 = new MistralAiAssistantMessage.Builder().content("content1") + .thinkingContent("thinking") + .build(); + + MistralAiAssistantMessage message2 = new MistralAiAssistantMessage.Builder().content("content2") + .thinkingContent("thinking") + .build(); + + assertThat(message1).isNotEqualTo(message2); + } + + @Test + void testNullThinkingContentEquality() { + MistralAiAssistantMessage message1 = new MistralAiAssistantMessage.Builder().content("content") + .thinkingContent(null) + .build(); + + MistralAiAssistantMessage message2 = new MistralAiAssistantMessage.Builder().content("content") + .thinkingContent(null) + .build(); + + MistralAiAssistantMessage message3 = new MistralAiAssistantMessage.Builder().content("content") + .thinkingContent("thinking") + .build(); + + assertThat(message1).isEqualTo(message2); + assertThat(message1).isNotEqualTo(message3); + } + + @Test + void testEqualityToItself() { + MistralAiAssistantMessage message = new MistralAiAssistantMessage.Builder().content("content") + .thinkingContent("thinking") + .build(); + + assertThat(message).isEqualTo(message); + } + + @Test + void testInequalityToNull() { + MistralAiAssistantMessage message = new MistralAiAssistantMessage.Builder().content("content").build(); + + assertThat(message).isNotEqualTo(null); + } + + @Test + void testInequalityToDifferentType() { + MistralAiAssistantMessage message = new MistralAiAssistantMessage.Builder().content("content").build(); + + assertThat(message).isNotEqualTo("content"); + } + + // ToString Tests + + @Test + void testToStringDoesNotThrowException() { + MistralAiAssistantMessage message = new MistralAiAssistantMessage.Builder().content("content") + .thinkingContent("thinking") + .build(); + + assertThatNoException().isThrownBy(message::toString); + } + + @Test + void testToStringIncludesRelevantFields() { + MistralAiAssistantMessage message = new MistralAiAssistantMessage.Builder().content("test content") + .thinkingContent("test thinking") + .build(); + + String toString = message.toString(); + + assertThat(toString).contains("test content"); + assertThat(toString).contains("test thinking"); + assertThat(toString).contains("MistralAiAssistantMessage"); + } + + @Test + void testToStringHandlesNullThinkingContent() { + MistralAiAssistantMessage message = new MistralAiAssistantMessage.Builder().content("content").build(); + + assertThatNoException().isThrownBy(message::toString); + assertThat(message.toString()).contains("null"); + } + + // Tool Calls Tests + + @Test + void testBuildMessageWithToolCalls() { + List toolCalls = List.of(new ToolCall("call-1", "function", "getWeather", "{\"city\":\"Paris\"}"), + new ToolCall("call-2", "function", "getTime", "{\"timezone\":\"UTC\"}")); + + MistralAiAssistantMessage message = new MistralAiAssistantMessage.Builder().content("content") + .toolCalls(toolCalls) + .build(); + + assertThat(message.getToolCalls()).hasSize(2); + assertThat(message.getToolCalls().get(0).id()).isEqualTo("call-1"); + assertThat(message.getToolCalls().get(0).name()).isEqualTo("getWeather"); + assertThat(message.getToolCalls().get(1).id()).isEqualTo("call-2"); + assertThat(message.getToolCalls().get(1).name()).isEqualTo("getTime"); + } + + @Test + void testBuildMessageWithThinkingContentAndToolCalls() { + List toolCalls = List.of(new ToolCall("call-1", "function", "calculator", "{\"op\":\"add\"}")); + + MistralAiAssistantMessage message = new MistralAiAssistantMessage.Builder().content("Let me calculate...") + .thinkingContent("I need to use the calculator tool") + .toolCalls(toolCalls) + .build(); + + assertThat(message.getText()).isEqualTo("Let me calculate..."); + assertThat(message.getThinkingContent()).isEqualTo("I need to use the calculator tool"); + assertThat(message.getToolCalls()).hasSize(1); + } + + // Properties/Metadata Tests + + @Test + void testBuildMessageWithProperties() { + Map properties = Map.of("id", "msg-123", "role", "assistant", "finishReason", "stop"); + + MistralAiAssistantMessage message = new MistralAiAssistantMessage.Builder().content("content") + .properties(properties) + .build(); + + assertThat(message.getMetadata()).containsEntry("id", "msg-123") + .containsEntry("role", "assistant") + .containsEntry("finishReason", "stop"); + } + + @Test + void testMutablePropertiesMapInBuilder() { + Map properties = new HashMap<>(); + properties.put("key", "value"); + + MistralAiAssistantMessage message = new MistralAiAssistantMessage.Builder().content("content") + .properties(properties) + .build(); + + // Modifying original map should not affect the message + properties.put("newKey", "newValue"); + + assertThat(message.getMetadata()).containsKey("key"); + } + +} diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiContentParsingTests.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiContentParsingTests.java new file mode 100644 index 00000000000..0b17f823fe6 --- /dev/null +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiContentParsingTests.java @@ -0,0 +1,332 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mistralai; + +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage; +import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.Role; +import org.springframework.ai.mistralai.api.MistralAiApi.ContentChunk; +import org.springframework.ai.mistralai.api.MistralAiApi.ReferenceChunk; +import org.springframework.ai.mistralai.api.MistralAiApi.TextChunk; +import org.springframework.ai.mistralai.api.MistralAiApi.ThinkChunk; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for content parsing in {@link ChatCompletionMessage}. Tests the parsing of + * content returned by Magistral reasoning models which can return content as either a + * simple string or an array of typed blocks (text, thinking, reference). + * + * @author Kyle Kreuter + */ +class MistralAiContentParsingTests { + + // String Content Parsing Tests (Backward Compatibility) + + @Test + void testParseSimpleStringContent() { + String textContent = "Hello, I am a response from Mistral AI."; + ChatCompletionMessage message = new ChatCompletionMessage(textContent, Role.ASSISTANT); + + assertThat(message.content()).isEqualTo(textContent); + assertThat(message.thinkingContent()).isNull(); + assertThat(message.rawContent()).isEqualTo(textContent); + } + + @Test + void testReturnNullContentForNullRawContent() { + ChatCompletionMessage message = new ChatCompletionMessage(null, Role.ASSISTANT); + + assertThat(message.content()).isNull(); + assertThat(message.thinkingContent()).isNull(); + } + + @Test + void testReturnEmptyStringContentAsIs() { + ChatCompletionMessage message = new ChatCompletionMessage("", Role.ASSISTANT); + + assertThat(message.content()).isEmpty(); + assertThat(message.thinkingContent()).isNull(); + } + + @Test + void testParseStringContentWithSpecialCharacters() { + String textContent = "Response with special chars: <>&\"' and unicode: \u00e9\u00e8\u00ea"; + ChatCompletionMessage message = new ChatCompletionMessage(textContent, Role.ASSISTANT); + + assertThat(message.content()).isEqualTo(textContent); + } + + // Array Content Parsing Tests (Magistral Models) + + @Test + void testParseArrayContentWithTextChunkOnly() { + List> content = List.of(Map.of("type", "text", "text", "This is the response text.")); + + ChatCompletionMessage message = new ChatCompletionMessage(content, Role.ASSISTANT); + + assertThat(message.content()).isEqualTo("This is the response text."); + assertThat(message.thinkingContent()).isNull(); + } + + @Test + void testParseArrayContentWithThinkChunkOnly() { + List> content = List + .of(Map.of("type", "thinking", "thinking", "Let me reason through this problem...")); + + ChatCompletionMessage message = new ChatCompletionMessage(content, Role.ASSISTANT); + + assertThat(message.content()).isNull(); + assertThat(message.thinkingContent()).isEqualTo("Let me reason through this problem..."); + } + + @Test + void testParseArrayContentWithBothTextAndThinkChunks() { + List> content = List.of( + Map.of("type", "thinking", "thinking", "First, I need to analyze the question..."), + Map.of("type", "text", "text", "The answer is 42.")); + + ChatCompletionMessage message = new ChatCompletionMessage(content, Role.ASSISTANT); + + assertThat(message.content()).isEqualTo("The answer is 42."); + assertThat(message.thinkingContent()).isEqualTo("First, I need to analyze the question..."); + } + + @Test + void testParseArrayContentWithReferenceChunk() { + List> content = List.of(Map.of("type", "text", "text", "According to the sources..."), + Map.of("type", "reference", "reference_ids", List.of(1, 2, 3))); + + ChatCompletionMessage message = new ChatCompletionMessage(content, Role.ASSISTANT); + + assertThat(message.content()).isEqualTo("According to the sources..."); + + List chunks = message.contentChunks(); + assertThat(chunks).hasSize(2); + assertThat(chunks.get(0)).isInstanceOf(TextChunk.class); + assertThat(chunks.get(1)).isInstanceOf(ReferenceChunk.class); + + ReferenceChunk refChunk = (ReferenceChunk) chunks.get(1); + assertThat(refChunk.referenceIds()).containsExactly(1, 2, 3); + } + + @Test + void testConcatenateMultipleTextChunksWithNewlines() { + List> content = List.of(Map.of("type", "text", "text", "First paragraph."), + Map.of("type", "text", "text", "Second paragraph.")); + + ChatCompletionMessage message = new ChatCompletionMessage(content, Role.ASSISTANT); + + assertThat(message.content()).isEqualTo("First paragraph.\nSecond paragraph."); + } + + @Test + void testConcatenateMultipleThinkingChunksWithNewlines() { + List> content = List.of(Map.of("type", "thinking", "thinking", "Step 1: Analyze..."), + Map.of("type", "thinking", "thinking", "Step 2: Evaluate...")); + + ChatCompletionMessage message = new ChatCompletionMessage(content, Role.ASSISTANT); + + assertThat(message.thinkingContent()).isEqualTo("Step 1: Analyze...\nStep 2: Evaluate..."); + } + + @Test + void testHandleEmptyArrayContent() { + List> content = List.of(); + + ChatCompletionMessage message = new ChatCompletionMessage(content, Role.ASSISTANT); + + assertThat(message.content()).isNull(); + assertThat(message.thinkingContent()).isNull(); + assertThat(message.contentChunks()).isEmpty(); + } + + @Test + void testHandleArrayWithUnknownChunkTypesGracefully() { + List> content = List.of(Map.of("type", "unknown", "data", "some data"), + Map.of("type", "text", "text", "Valid text content.")); + + ChatCompletionMessage message = new ChatCompletionMessage(content, Role.ASSISTANT); + + assertThat(message.content()).isEqualTo("Valid text content."); + assertThat(message.contentChunks()).hasSize(1); + } + + @Test + void testHandleArrayWithNullTextInTextChunk() { + // Use HashMap to allow null values (Map.of() doesn't support nulls) + java.util.HashMap map = new java.util.HashMap<>(); + map.put("type", "text"); + map.put("text", null); + List> content = List.of(map); + + ChatCompletionMessage message = new ChatCompletionMessage(content, Role.ASSISTANT); + + assertThat(message.content()).isNull(); + assertThat(message.contentChunks()).isEmpty(); + } + + // ContentChunks Parsing Tests + + @Test + void testReturnSingleTextChunkForStringContent() { + String textContent = "Simple text response"; + ChatCompletionMessage message = new ChatCompletionMessage(textContent, Role.ASSISTANT); + + List chunks = message.contentChunks(); + + assertThat(chunks).hasSize(1); + assertThat(chunks.get(0)).isInstanceOf(TextChunk.class); + assertThat(((TextChunk) chunks.get(0)).text()).isEqualTo(textContent); + } + + @Test + void testReturnEmptyListForNullContent() { + ChatCompletionMessage message = new ChatCompletionMessage(null, Role.ASSISTANT); + + assertThat(message.contentChunks()).isEmpty(); + } + + @Test + void testParseAllChunkTypesCorrectly() { + List> content = List.of(Map.of("type", "thinking", "thinking", "Reasoning..."), + Map.of("type", "text", "text", "Answer text"), + Map.of("type", "reference", "reference_ids", List.of(1, 2))); + + ChatCompletionMessage message = new ChatCompletionMessage(content, Role.ASSISTANT); + + List chunks = message.contentChunks(); + assertThat(chunks).hasSize(3); + assertThat(chunks.get(0)).isInstanceOf(ThinkChunk.class); + assertThat(chunks.get(1)).isInstanceOf(TextChunk.class); + assertThat(chunks.get(2)).isInstanceOf(ReferenceChunk.class); + + assertThat(((ThinkChunk) chunks.get(0)).thinking()).isEqualTo("Reasoning..."); + assertThat(((TextChunk) chunks.get(1)).text()).isEqualTo("Answer text"); + assertThat(((ReferenceChunk) chunks.get(2)).referenceIds()).containsExactly(1, 2); + } + + @Test + void testHandleReferenceChunkWithNumericValues() { + List> content = List.of(Map.of("type", "reference", "reference_ids", List.of(1L, 2L, 3L))); + + ChatCompletionMessage message = new ChatCompletionMessage(content, Role.ASSISTANT); + + List chunks = message.contentChunks(); + assertThat(chunks).hasSize(1); + assertThat(chunks.get(0)).isInstanceOf(ReferenceChunk.class); + assertThat(((ReferenceChunk) chunks.get(0)).referenceIds()).containsExactly(1, 2, 3); + } + + // Edge Cases and Error Handling Tests + + @Test + void testThrowExceptionForUnexpectedContentType() { + // Using an Integer as content which is not a supported type + ChatCompletionMessage message = new ChatCompletionMessage(12345, Role.ASSISTANT); + + assertThatThrownBy(message::content).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Unexpected content type"); + } + + @Test + void testHandleMultilineTextContent() { + String multilineText = """ + Line 1 + Line 2 + Line 3 + """; + ChatCompletionMessage message = new ChatCompletionMessage(multilineText, Role.ASSISTANT); + + assertThat(message.content()).isEqualTo(multilineText); + } + + @Test + void testHandleVeryLongContent() { + String longContent = "A".repeat(10000); + ChatCompletionMessage message = new ChatCompletionMessage(longContent, Role.ASSISTANT); + + assertThat(message.content()).isEqualTo(longContent); + assertThat(message.content()).hasSize(10000); + } + + @Test + void testHandleMixedContentWithEmptyStrings() { + List> content = List.of(Map.of("type", "thinking", "thinking", ""), + Map.of("type", "text", "text", "Valid text")); + + ChatCompletionMessage message = new ChatCompletionMessage(content, Role.ASSISTANT); + + // Empty thinking should result in empty chunk being added (but still parsed) + assertThat(message.content()).isEqualTo("Valid text"); + } + + @Test + void testPreserveMessageRole() { + ChatCompletionMessage assistantMessage = new ChatCompletionMessage("content", Role.ASSISTANT); + ChatCompletionMessage userMessage = new ChatCompletionMessage("content", Role.USER); + ChatCompletionMessage systemMessage = new ChatCompletionMessage("content", Role.SYSTEM); + + assertThat(assistantMessage.role()).isEqualTo(Role.ASSISTANT); + assertThat(userMessage.role()).isEqualTo(Role.USER); + assertThat(systemMessage.role()).isEqualTo(Role.SYSTEM); + } + + // Record Components Tests + + @Test + void testTextChunkRecordEquality() { + TextChunk chunk1 = new TextChunk("Hello"); + TextChunk chunk2 = new TextChunk("Hello"); + TextChunk chunk3 = new TextChunk("World"); + + assertThat(chunk1.text()).isEqualTo("Hello"); + assertThat(chunk1).isEqualTo(chunk2); + assertThat(chunk1).isNotEqualTo(chunk3); + assertThat(chunk1.hashCode()).isEqualTo(chunk2.hashCode()); + } + + @Test + void testThinkChunkRecordEquality() { + ThinkChunk chunk1 = new ThinkChunk("Thinking..."); + ThinkChunk chunk2 = new ThinkChunk("Thinking..."); + ThinkChunk chunk3 = new ThinkChunk("Different thinking"); + + assertThat(chunk1.thinking()).isEqualTo("Thinking..."); + assertThat(chunk1).isEqualTo(chunk2); + assertThat(chunk1).isNotEqualTo(chunk3); + assertThat(chunk1.hashCode()).isEqualTo(chunk2.hashCode()); + } + + @Test + void testReferenceChunkRecordEquality() { + ReferenceChunk chunk1 = new ReferenceChunk(List.of(1, 2, 3)); + ReferenceChunk chunk2 = new ReferenceChunk(List.of(1, 2, 3)); + ReferenceChunk chunk3 = new ReferenceChunk(List.of(4, 5)); + + assertThat(chunk1.referenceIds()).containsExactly(1, 2, 3); + assertThat(chunk1).isEqualTo(chunk2); + assertThat(chunk1).isNotEqualTo(chunk3); + assertThat(chunk1.hashCode()).isEqualTo(chunk2.hashCode()); + } + +} diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiMagistralIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiMagistralIT.java new file mode 100644 index 00000000000..6a043a106c6 --- /dev/null +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiMagistralIT.java @@ -0,0 +1,239 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mistralai; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.mistralai.api.MistralAiApi; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for Mistral AI Magistral reasoning models. These tests verify that + * the Magistral models (magistral-small-latest, magistral-medium-latest) properly return + * thinking/reasoning content alongside the regular response. + * + *

+ * Magistral models are reasoning models that show their thought process before providing + * an answer. The thinking content is returned in a separate field from the main response + * content. + *

+ * + * @author Kyle Kreuter + */ +@SpringBootTest(classes = MistralAiTestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+") +class MistralAiMagistralIT { + + private static final Logger logger = LoggerFactory.getLogger(MistralAiMagistralIT.class); + + @Autowired + private ChatModel chatModel; + + @Autowired + private StreamingChatModel streamingChatModel; + + @Test + void testMagistralModelReturnsThinkingContent() { + // Magistral models excel at reasoning tasks - use a question that requires + // step-by-step thinking + var promptOptions = MistralAiChatOptions.builder() + .model(MistralAiApi.ChatModel.MAGISTRAL_SMALL.getValue()) + .build(); + + Prompt prompt = new Prompt("9.11 and 9.8, which is greater?", promptOptions); + ChatResponse response = this.chatModel.call(prompt); + + assertThat(response).isNotNull(); + assertThat(response.getResult()).isNotNull(); + assertThat(response.getResult().getOutput()).isInstanceOf(MistralAiAssistantMessage.class); + + MistralAiAssistantMessage assistantMessage = (MistralAiAssistantMessage) response.getResult().getOutput(); + + // Magistral models should provide thinking content for reasoning questions + assertThat(assistantMessage.getThinkingContent()).isNotNull().isNotEmpty(); + assertThat(assistantMessage.getText()).isNotNull().isNotEmpty(); + + logger.info("Thinking content: {}", assistantMessage.getThinkingContent()); + logger.info("Response text: {}", assistantMessage.getText()); + } + + @Test + void testMagistralModelHandlesMathProblemsWithReasoning() { + var promptOptions = MistralAiChatOptions.builder() + .model(MistralAiApi.ChatModel.MAGISTRAL_SMALL.getValue()) + .build(); + + Prompt prompt = new Prompt( + "If a train travels at 60 mph for 2.5 hours, how far does it travel? Show your reasoning.", + promptOptions); + ChatResponse response = this.chatModel.call(prompt); + + assertThat(response).isNotNull(); + assertThat(response.getResult()).isNotNull(); + + MistralAiAssistantMessage assistantMessage = (MistralAiAssistantMessage) response.getResult().getOutput(); + + assertThat(assistantMessage.getThinkingContent()).isNotNull().isNotEmpty(); + assertThat(assistantMessage.getText()).isNotNull(); + + // The answer should contain 150 (60 * 2.5 = 150 miles) + assertThat(assistantMessage.getText()).containsAnyOf("150", "one hundred fifty"); + + logger.info("Math problem thinking: {}", assistantMessage.getThinkingContent()); + logger.info("Math problem answer: {}", assistantMessage.getText()); + } + + @Test + void testMagistralModelStreamingWorks() { + var promptOptions = MistralAiChatOptions.builder() + .model(MistralAiApi.ChatModel.MAGISTRAL_SMALL.getValue()) + .build(); + + Prompt prompt = new Prompt("What is 25 * 4? Think step by step.", promptOptions); + + String aggregatedContent = this.streamingChatModel.stream(prompt) + .collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getText) + .filter(text -> text != null) + .collect(Collectors.joining()); + + assertThat(aggregatedContent).isNotEmpty(); + // The answer should contain 100 (25 * 4 = 100) + assertThat(aggregatedContent).containsAnyOf("100", "one hundred"); + + logger.info("Streamed response: {}", aggregatedContent); + } + + @Test + void testMagistralModelMultiRoundConversationPreservesContext() { + List messages = new ArrayList<>(); + messages.add(new UserMessage("What is 5 + 3?")); + + var promptOptions = MistralAiChatOptions.builder() + .model(MistralAiApi.ChatModel.MAGISTRAL_SMALL.getValue()) + .build(); + + // First round + Prompt prompt1 = new Prompt(messages, promptOptions); + ChatResponse response1 = this.chatModel.call(prompt1); + + assertThat(response1).isNotNull(); + MistralAiAssistantMessage message1 = (MistralAiAssistantMessage) response1.getResult().getOutput(); + assertThat(message1.getText()).containsAnyOf("8", "eight"); + + logger.info("First response thinking: {}", message1.getThinkingContent()); + logger.info("First response: {}", message1.getText()); + + // Add assistant response to conversation + messages.add(new AssistantMessage(message1.getText())); + messages.add(new UserMessage("Now multiply that result by 2")); + + // Second round + Prompt prompt2 = new Prompt(messages, promptOptions); + ChatResponse response2 = this.chatModel.call(prompt2); + + assertThat(response2).isNotNull(); + MistralAiAssistantMessage message2 = (MistralAiAssistantMessage) response2.getResult().getOutput(); + // 8 * 2 = 16 + assertThat(message2.getText()).containsAnyOf("16", "sixteen"); + + logger.info("Second response thinking: {}", message2.getThinkingContent()); + logger.info("Second response: {}", message2.getText()); + } + + @Test + void testMagistralModelHandlesLogicPuzzles() { + var promptOptions = MistralAiChatOptions.builder() + .model(MistralAiApi.ChatModel.MAGISTRAL_SMALL.getValue()) + .build(); + + String puzzle = """ + There are three boxes. One contains only apples, one contains only oranges, + and one contains both apples and oranges. The boxes have been incorrectly labeled + such that no label identifies the actual contents of the box it labels. + Opening just one box, and without looking in the box, you take out one piece of fruit. + By looking at the fruit, how can you immediately label all of the boxes correctly? + Which box should you open? + """; + + Prompt prompt = new Prompt(puzzle, promptOptions); + ChatResponse response = this.chatModel.call(prompt); + + assertThat(response).isNotNull(); + MistralAiAssistantMessage assistantMessage = (MistralAiAssistantMessage) response.getResult().getOutput(); + + // For reasoning puzzles, thinking content should be substantial + assertThat(assistantMessage.getThinkingContent()).isNotNull().isNotEmpty(); + assertThat(assistantMessage.getText()).isNotNull().isNotEmpty(); + + // The answer should mention the "both" or "mixed" box + assertThat(assistantMessage.getText().toLowerCase()).containsAnyOf("both", "mixed", "apples and oranges"); + + logger.info("Logic puzzle thinking (length: {}): {}", + assistantMessage.getThinkingContent() != null ? assistantMessage.getThinkingContent().length() : 0, + assistantMessage.getThinkingContent()); + logger.info("Logic puzzle answer: {}", assistantMessage.getText()); + } + + @Test + void testResponseMetadataPopulatedCorrectly() { + var promptOptions = MistralAiChatOptions.builder() + .model(MistralAiApi.ChatModel.MAGISTRAL_SMALL.getValue()) + .build(); + + Prompt prompt = new Prompt("What is 2 + 2?", promptOptions); + ChatResponse response = this.chatModel.call(prompt); + + assertThat(response).isNotNull(); + assertThat(response.getMetadata()).isNotNull(); + assertThat(response.getMetadata().getModel()).containsIgnoringCase("magistral"); + assertThat(response.getMetadata().getUsage()).isNotNull(); + assertThat(response.getMetadata().getUsage().getPromptTokens()).isGreaterThan(0); + assertThat(response.getMetadata().getUsage().getCompletionTokens()).isGreaterThan(0); + assertThat(response.getMetadata().getUsage().getTotalTokens()).isGreaterThan(0); + + logger.info("Model used: {}", response.getMetadata().getModel()); + logger.info("Token usage - Prompt: {}, Completion: {}, Total: {}", + response.getMetadata().getUsage().getPromptTokens(), + response.getMetadata().getUsage().getCompletionTokens(), + response.getMetadata().getUsage().getTotalTokens()); + } + +}