diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientRequest.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientRequest.java index 5a04db360c7..094a75acd86 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientRequest.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientRequest.java @@ -37,6 +37,7 @@ public record ChatClientRequest(Prompt prompt, Map context) { Assert.notNull(prompt, "prompt cannot be null"); Assert.notNull(context, "context cannot be null"); Assert.noNullElements(context.keySet(), "context keys cannot be null"); + Assert.noNullElements(context.values(), "context values cannot be null"); } public ChatClientRequest copy() { @@ -68,12 +69,15 @@ public Builder prompt(Prompt prompt) { public Builder context(Map context) { Assert.notNull(context, "context cannot be null"); + Assert.noNullElements(context.keySet(), "context keys cannot be null"); + Assert.noNullElements(context.values(), "context values cannot be null"); this.context.putAll(context); return this; } public Builder context(String key, Object value) { Assert.notNull(key, "key cannot be null"); + Assert.notNull(value, "value cannot be null"); this.context.put(key, value); return this; } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientResponse.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientResponse.java index a069702356b..36998c743c4 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientResponse.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientResponse.java @@ -36,6 +36,7 @@ public record ChatClientResponse(@Nullable ChatResponse chatResponse, Map context) { Assert.notNull(context, "context cannot be null"); + Assert.noNullElements(context.keySet(), "context keys cannot be null"); + Assert.noNullElements(context.values(), "context values cannot be null"); this.context.putAll(context); return this; } public Builder context(String key, Object value) { Assert.notNull(key, "key cannot be null"); + Assert.notNull(value, "value cannot be null"); this.context.put(key, value); return this; } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientRequestTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientRequestTests.java index 17178cd2b31..a2ea633d92f 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientRequestTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientRequestTests.java @@ -62,6 +62,30 @@ void whenContextHasNullKeysThenThrow() { .hasMessage("context keys cannot be null"); } + @Test + void whenContextHasNullValuesThenThrow() { + Map context = new HashMap<>(); + context.put("key", null); + + assertThatThrownBy(() -> new ChatClientRequest(new Prompt(), context)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("context values cannot be null"); + + assertThatThrownBy(() -> ChatClientRequest.builder().prompt(new Prompt()).context(context).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("context values cannot be null"); + } + + @Test + void whenBuilderContextMapHasNullKeyThenThrow() { + Map context = new HashMap<>(); + context.put(null, "value"); + + assertThatThrownBy(() -> ChatClientRequest.builder().prompt(new Prompt()).context(context).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("context keys cannot be null"); + } + @Test void whenCopyThenImmutableContext() { Map context = new HashMap<>(); @@ -86,6 +110,13 @@ void whenMutateThenImmutableContext() { assertThat(copy.context()).isEqualTo(Map.of("key", "newValue")); } + @Test + void whenBuilderAddsNullValueThenThrow() { + assertThatThrownBy(() -> ChatClientRequest.builder().prompt(new Prompt()).context("key", null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("value cannot be null"); + } + @Test void whenBuilderWithMultipleContextEntriesThenSuccess() { Prompt prompt = new Prompt("test message"); diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientResponseTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientResponseTests.java index 1e26a6c334f..579b27b548d 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientResponseTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientResponseTests.java @@ -53,6 +53,19 @@ void whenContextHasNullKeysThenThrow() { .hasMessage("context keys cannot be null"); } + @Test + void whenContextHasNullValuesThenThrow() { + Map context = new HashMap<>(); + context.put("key", null); + + assertThatThrownBy(() -> new ChatClientResponse(null, context)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("context values cannot be null"); + + assertThatThrownBy(() -> ChatClientResponse.builder().chatResponse(null).context(context).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("context values cannot be null"); + } + @Test void whenCopyThenImmutableContext() { Map context = new HashMap<>(); @@ -120,16 +133,15 @@ void whenEmptyContextThenCreateSuccessfully() { } @Test - void whenContextWithNullValuesThenCreateSuccessfully() { + void whenContextWithNullValuesThenThrow() { ChatResponse chatResponse = mock(ChatResponse.class); Map context = new HashMap<>(); context.put("key1", "value1"); context.put("key2", null); - ChatClientResponse response = new ChatClientResponse(chatResponse, context); - - assertThat(response.context()).containsEntry("key1", "value1"); - assertThat(response.context()).containsEntry("key2", null); + assertThatThrownBy(() -> new ChatClientResponse(chatResponse, context)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("context values cannot be null"); } @Test @@ -166,6 +178,23 @@ void whenBuilderWithoutChatResponseThenCreateWithNull() { assertThat(response.chatResponse()).isNull(); } + @Test + void whenBuilderContextMapHasNullKeyThenThrow() { + Map context = new HashMap<>(); + context.put(null, "value"); + + assertThatThrownBy(() -> ChatClientResponse.builder().context(context).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("context keys cannot be null"); + } + + @Test + void whenBuilderAddsNullValueThenThrow() { + assertThatThrownBy(() -> ChatClientResponse.builder().context("key", null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("value cannot be null"); + } + @Test void whenComplexObjectsInContextThenPreserveCorrectly() { ChatResponse chatResponse = mock(ChatResponse.class);