From 75e121d4b656f663305981296a27b5b422949dbc Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 2 Dec 2025 14:06:15 +0100 Subject: [PATCH] Make ToolCallAdvisor extensible with hook methods - Add protected doInitializeLoop, doBeforeCall, and doAfterCall hooks to allow subclasses to customize the tool calling loop behavior. - Update Builder to support inheritance via self-referential generics. Signed-off-by: Christian Tzolov --- .../chat/client/advisor/ToolCallAdvisor.java | 71 ++++++-- .../client/advisor/ToolCallAdvisorTests.java | 154 +++++++++++++++++- 2 files changed, 213 insertions(+), 12 deletions(-) diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ToolCallAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ToolCallAdvisor.java index ad2e6530dc9..54c0e395063 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ToolCallAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ToolCallAdvisor.java @@ -44,7 +44,7 @@ * * @author Christian Tzolov */ -public final class ToolCallAdvisor implements CallAdvisor, StreamAdvisor { +public class ToolCallAdvisor implements CallAdvisor, StreamAdvisor { private final ToolCallingManager toolCallingManager; @@ -57,7 +57,7 @@ public final class ToolCallAdvisor implements CallAdvisor, StreamAdvisor { */ private final int advisorOrder; - private ToolCallAdvisor(ToolCallingManager toolCallingManager, int advisorOrder) { + protected ToolCallAdvisor(ToolCallingManager toolCallingManager, int advisorOrder) { Assert.notNull(toolCallingManager, "toolCallingManager must not be null"); Assert.isTrue(advisorOrder > BaseAdvisor.HIGHEST_PRECEDENCE && advisorOrder < BaseAdvisor.LOWEST_PRECEDENCE, "advisorOrder must be between HIGHEST_PRECEDENCE and LOWEST_PRECEDENCE"); @@ -76,7 +76,6 @@ public int getOrder() { return this.advisorOrder; } - @SuppressWarnings("null") @Override public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) { Assert.notNull(callAdvisorChain, "callAdvisorChain must not be null"); @@ -88,6 +87,8 @@ public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAd "ToolCall Advisor requires ToolCallingChatOptions to be set in the ChatClientRequest options."); } + chatClientRequest = this.doInitializeLoop(chatClientRequest, callAdvisorChain); + // Overwrite the ToolCallingChatOptions to disable internal tool execution. var optionsCopy = (ToolCallingChatOptions) chatClientRequest.prompt().getOptions().copy(); @@ -109,8 +110,12 @@ public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAd .build(); // Next Call + processedChatClientRequest = this.doBeforeCall(processedChatClientRequest, callAdvisorChain); + chatClientResponse = callAdvisorChain.copy(this).nextCall(processedChatClientRequest); + chatClientResponse = this.doAfterCall(chatClientResponse, callAdvisorChain); + // After Call // TODO: check that this is tool call is sufficiant for all chat models @@ -148,6 +153,19 @@ public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAd return chatClientResponse; } + protected ChatClientRequest doInitializeLoop(ChatClientRequest chatClientRequest, + CallAdvisorChain callAdvisorChain) { + return chatClientRequest; + } + + protected ChatClientRequest doBeforeCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) { + return chatClientRequest; + } + + protected ChatClientResponse doAfterCall(ChatClientResponse chatClientResponse, CallAdvisorChain callAdvisorChain) { + return chatClientResponse; + } + @Override public Flux adviseStream(ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) { @@ -158,20 +176,35 @@ public Flux adviseStream(ChatClientRequest chatClientRequest * Creates a new Builder instance for constructing a ToolCallAdvisor. * @return a new Builder instance */ - public static Builder builder() { - return new Builder(); + public static Builder builder() { + return new Builder<>(); } /** * Builder for creating instances of ToolCallAdvisor. + *

+ * This builder uses the self-referential generic pattern to support extensibility. + * + * @param the builder type, used for self-referential generics to support method + * chaining in subclasses */ - public final static class Builder { + public static class Builder> { private ToolCallingManager toolCallingManager = ToolCallingManager.builder().build(); private int advisorOrder = BaseAdvisor.HIGHEST_PRECEDENCE + 300; - private Builder() { + protected Builder() { + } + + /** + * Returns this builder cast to the appropriate type for method chaining. + * Subclasses should override this method to return the correct type. + * @return this builder instance + */ + @SuppressWarnings("unchecked") + protected T self() { + return (T) this; } /** @@ -179,9 +212,9 @@ private Builder() { * @param toolCallingManager the ToolCallingManager instance * @return this Builder instance for method chaining */ - public Builder toolCallingManager(ToolCallingManager toolCallingManager) { + public T toolCallingManager(ToolCallingManager toolCallingManager) { this.toolCallingManager = toolCallingManager; - return this; + return self(); } /** @@ -190,9 +223,25 @@ public Builder toolCallingManager(ToolCallingManager toolCallingManager) { * LOWEST_PRECEDENCE * @return this Builder instance for method chaining */ - public Builder advisorOrder(int advisorOrder) { + public T advisorOrder(int advisorOrder) { this.advisorOrder = advisorOrder; - return this; + return self(); + } + + /** + * Returns the configured ToolCallingManager. + * @return the ToolCallingManager instance + */ + protected ToolCallingManager getToolCallingManager() { + return this.toolCallingManager; + } + + /** + * Returns the configured advisor order. + * @return the advisor order value + */ + protected int getAdvisorOrder() { + return this.advisorOrder; } /** diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/ToolCallAdvisorTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/ToolCallAdvisorTests.java index a56b54f0810..7328ff6e317 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/ToolCallAdvisorTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/ToolCallAdvisorTests.java @@ -377,6 +377,99 @@ void testGetOrder() { assertThat(advisor.getOrder()).isEqualTo(customOrder); } + @Test + void testBuilderGetters() { + ToolCallingManager customManager = mock(ToolCallingManager.class); + int customOrder = BaseAdvisor.HIGHEST_PRECEDENCE + 500; + + ToolCallAdvisor.Builder builder = ToolCallAdvisor.builder() + .toolCallingManager(customManager) + .advisorOrder(customOrder); + + assertThat(builder.getToolCallingManager()).isEqualTo(customManager); + assertThat(builder.getAdvisorOrder()).isEqualTo(customOrder); + } + + @Test + void testExtendedAdvisorWithCustomHooks() { + int[] hookCallCounts = { 0, 0, 0 }; // initializeLoop, beforeCall, afterCall + + // Create extended advisor to verify hooks are called + TestableToolCallAdvisor advisor = new TestableToolCallAdvisor(this.toolCallingManager, + BaseAdvisor.HIGHEST_PRECEDENCE + 300, hookCallCounts); + + ChatClientRequest request = createMockRequest(true); + ChatClientResponse response = createMockResponse(false); + + CallAdvisor terminalAdvisor = new TerminalCallAdvisor((req, chain) -> response); + + CallAdvisorChain realChain = DefaultAroundAdvisorChain.builder(ObservationRegistry.NOOP) + .pushAll(List.of(advisor, terminalAdvisor)) + .build(); + + advisor.adviseCall(request, realChain); + + // Verify hooks were called + assertThat(hookCallCounts[0]).isEqualTo(1); // doInitializeLoop called once + assertThat(hookCallCounts[1]).isEqualTo(1); // doBeforeCall called once + assertThat(hookCallCounts[2]).isEqualTo(1); // doAfterCall called once + } + + @Test + void testExtendedAdvisorHooksCalledMultipleTimesWithToolCalls() { + int[] hookCallCounts = { 0, 0, 0 }; // initializeLoop, beforeCall, afterCall + + TestableToolCallAdvisor advisor = new TestableToolCallAdvisor(this.toolCallingManager, + BaseAdvisor.HIGHEST_PRECEDENCE + 300, hookCallCounts); + + ChatClientRequest request = createMockRequest(true); + ChatClientResponse responseWithToolCall = createMockResponse(true); + ChatClientResponse finalResponse = createMockResponse(false); + + int[] callCount = { 0 }; + CallAdvisor terminalAdvisor = new TerminalCallAdvisor((req, chain) -> { + callCount[0]++; + return callCount[0] == 1 ? responseWithToolCall : finalResponse; + }); + + CallAdvisorChain realChain = DefaultAroundAdvisorChain.builder(ObservationRegistry.NOOP) + .pushAll(List.of(advisor, terminalAdvisor)) + .build(); + + // Mock tool execution result + List conversationHistory = List.of(new UserMessage("test"), + AssistantMessage.builder().content("").build(), ToolResponseMessage.builder().build()); + ToolExecutionResult toolExecutionResult = ToolExecutionResult.builder() + .conversationHistory(conversationHistory) + .build(); + when(this.toolCallingManager.executeToolCalls(any(Prompt.class), any(ChatResponse.class))) + .thenReturn(toolExecutionResult); + + advisor.adviseCall(request, realChain); + + // Verify hooks were called correct number of times + assertThat(hookCallCounts[0]).isEqualTo(1); // doInitializeLoop called once + // (before loop) + assertThat(hookCallCounts[1]).isEqualTo(2); // doBeforeCall called twice (each + // iteration) + assertThat(hookCallCounts[2]).isEqualTo(2); // doAfterCall called twice (each + // iteration) + } + + @Test + void testExtendedBuilderWithCustomBuilder() { + ToolCallingManager customManager = mock(ToolCallingManager.class); + int customOrder = BaseAdvisor.HIGHEST_PRECEDENCE + 450; + + TestableToolCallAdvisor advisor = TestableToolCallAdvisor.testBuilder() + .toolCallingManager(customManager) + .advisorOrder(customOrder) + .build(); + + assertThat(advisor).isNotNull(); + assertThat(advisor.getOrder()).isEqualTo(customOrder); + } + // Helper methods private ChatClientRequest createMockRequest(boolean withToolCallingOptions) { @@ -472,6 +565,65 @@ public ChatClientResponse adviseCall(ChatClientRequest req, CallAdvisorChain cha return this.responseFunction.apply(req, chain); } - }; + } + + /** + * Test subclass of ToolCallAdvisor to verify extensibility and hook methods. + */ + private static class TestableToolCallAdvisor extends ToolCallAdvisor { + + private final int[] hookCallCounts; + + TestableToolCallAdvisor(ToolCallingManager toolCallingManager, int advisorOrder, int[] hookCallCounts) { + super(toolCallingManager, advisorOrder); + this.hookCallCounts = hookCallCounts; + } + + @Override + protected ChatClientRequest doInitializeLoop(ChatClientRequest chatClientRequest, + CallAdvisorChain callAdvisorChain) { + if (this.hookCallCounts != null) { + this.hookCallCounts[0]++; + } + return super.doInitializeLoop(chatClientRequest, callAdvisorChain); + } + + @Override + protected ChatClientRequest doBeforeCall(ChatClientRequest chatClientRequest, + CallAdvisorChain callAdvisorChain) { + if (this.hookCallCounts != null) { + this.hookCallCounts[1]++; + } + return super.doBeforeCall(chatClientRequest, callAdvisorChain); + } + + @Override + protected ChatClientResponse doAfterCall(ChatClientResponse chatClientResponse, + CallAdvisorChain callAdvisorChain) { + if (this.hookCallCounts != null) { + this.hookCallCounts[2]++; + } + return super.doAfterCall(chatClientResponse, callAdvisorChain); + } + + static TestableBuilder testBuilder() { + return new TestableBuilder(); + } + + static class TestableBuilder extends ToolCallAdvisor.Builder { + + @Override + protected TestableBuilder self() { + return this; + } + + @Override + public TestableToolCallAdvisor build() { + return new TestableToolCallAdvisor(getToolCallingManager(), getAdvisorOrder(), null); + } + + } + + } }