Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
*
* @author Christian Tzolov
*/
public final class ToolCallAdvisor implements CallAdvisor, StreamAdvisor {
public class ToolCallAdvisor implements CallAdvisor, StreamAdvisor {

private final ToolCallingManager toolCallingManager;

Expand All @@ -57,7 +57,7 @@ public final class ToolCallAdvisor implements CallAdvisor, StreamAdvisor {
*/
private final int advisorOrder;

private ToolCallAdvisor(ToolCallingManager toolCallingManager, int advisorOrder) {
protected ToolCallAdvisor(ToolCallingManager toolCallingManager, int advisorOrder) {
Assert.notNull(toolCallingManager, "toolCallingManager must not be null");
Assert.isTrue(advisorOrder > BaseAdvisor.HIGHEST_PRECEDENCE && advisorOrder < BaseAdvisor.LOWEST_PRECEDENCE,
"advisorOrder must be between HIGHEST_PRECEDENCE and LOWEST_PRECEDENCE");
Expand All @@ -76,7 +76,6 @@ public int getOrder() {
return this.advisorOrder;
}

@SuppressWarnings("null")
@Override
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
Assert.notNull(callAdvisorChain, "callAdvisorChain must not be null");
Expand All @@ -88,6 +87,8 @@ public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAd
"ToolCall Advisor requires ToolCallingChatOptions to be set in the ChatClientRequest options.");
}

chatClientRequest = this.doInitializeLoop(chatClientRequest, callAdvisorChain);

// Overwrite the ToolCallingChatOptions to disable internal tool execution.
var optionsCopy = (ToolCallingChatOptions) chatClientRequest.prompt().getOptions().copy();

Expand All @@ -109,8 +110,12 @@ public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAd
.build();

// Next Call
processedChatClientRequest = this.doBeforeCall(processedChatClientRequest, callAdvisorChain);

chatClientResponse = callAdvisorChain.copy(this).nextCall(processedChatClientRequest);

chatClientResponse = this.doAfterCall(chatClientResponse, callAdvisorChain);

// After Call

// TODO: check that this is tool call is sufficiant for all chat models
Expand Down Expand Up @@ -148,6 +153,19 @@ public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAd
return chatClientResponse;
}

protected ChatClientRequest doInitializeLoop(ChatClientRequest chatClientRequest,
CallAdvisorChain callAdvisorChain) {
return chatClientRequest;
}

protected ChatClientRequest doBeforeCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
return chatClientRequest;
}

protected ChatClientResponse doAfterCall(ChatClientResponse chatClientResponse, CallAdvisorChain callAdvisorChain) {
return chatClientResponse;
}

@Override
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
StreamAdvisorChain streamAdvisorChain) {
Expand All @@ -158,30 +176,45 @@ public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest
* Creates a new Builder instance for constructing a ToolCallAdvisor.
* @return a new Builder instance
*/
public static Builder builder() {
return new Builder();
public static Builder<?> builder() {
return new Builder<>();
}

/**
* Builder for creating instances of ToolCallAdvisor.
* <p>
* This builder uses the self-referential generic pattern to support extensibility.
*
* @param <T> the builder type, used for self-referential generics to support method
* chaining in subclasses
*/
public final static class Builder {
public static class Builder<T extends Builder<T>> {

private ToolCallingManager toolCallingManager = ToolCallingManager.builder().build();

private int advisorOrder = BaseAdvisor.HIGHEST_PRECEDENCE + 300;

private Builder() {
protected Builder() {
}

/**
* Returns this builder cast to the appropriate type for method chaining.
* Subclasses should override this method to return the correct type.
* @return this builder instance
*/
@SuppressWarnings("unchecked")
protected T self() {
return (T) this;
}

/**
* Sets the ToolCallingManager to be used by the advisor.
* @param toolCallingManager the ToolCallingManager instance
* @return this Builder instance for method chaining
*/
public Builder toolCallingManager(ToolCallingManager toolCallingManager) {
public T toolCallingManager(ToolCallingManager toolCallingManager) {
this.toolCallingManager = toolCallingManager;
return this;
return self();
}

/**
Expand All @@ -190,9 +223,25 @@ public Builder toolCallingManager(ToolCallingManager toolCallingManager) {
* LOWEST_PRECEDENCE
* @return this Builder instance for method chaining
*/
public Builder advisorOrder(int advisorOrder) {
public T advisorOrder(int advisorOrder) {
this.advisorOrder = advisorOrder;
return this;
return self();
}

/**
* Returns the configured ToolCallingManager.
* @return the ToolCallingManager instance
*/
protected ToolCallingManager getToolCallingManager() {
return this.toolCallingManager;
}

/**
* Returns the configured advisor order.
* @return the advisor order value
*/
protected int getAdvisorOrder() {
return this.advisorOrder;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,99 @@ void testGetOrder() {
assertThat(advisor.getOrder()).isEqualTo(customOrder);
}

@Test
void testBuilderGetters() {
ToolCallingManager customManager = mock(ToolCallingManager.class);
int customOrder = BaseAdvisor.HIGHEST_PRECEDENCE + 500;

ToolCallAdvisor.Builder<?> builder = ToolCallAdvisor.builder()
.toolCallingManager(customManager)
.advisorOrder(customOrder);

assertThat(builder.getToolCallingManager()).isEqualTo(customManager);
assertThat(builder.getAdvisorOrder()).isEqualTo(customOrder);
}

@Test
void testExtendedAdvisorWithCustomHooks() {
int[] hookCallCounts = { 0, 0, 0 }; // initializeLoop, beforeCall, afterCall

// Create extended advisor to verify hooks are called
TestableToolCallAdvisor advisor = new TestableToolCallAdvisor(this.toolCallingManager,
BaseAdvisor.HIGHEST_PRECEDENCE + 300, hookCallCounts);

ChatClientRequest request = createMockRequest(true);
ChatClientResponse response = createMockResponse(false);

CallAdvisor terminalAdvisor = new TerminalCallAdvisor((req, chain) -> response);

CallAdvisorChain realChain = DefaultAroundAdvisorChain.builder(ObservationRegistry.NOOP)
.pushAll(List.of(advisor, terminalAdvisor))
.build();

advisor.adviseCall(request, realChain);

// Verify hooks were called
assertThat(hookCallCounts[0]).isEqualTo(1); // doInitializeLoop called once
assertThat(hookCallCounts[1]).isEqualTo(1); // doBeforeCall called once
assertThat(hookCallCounts[2]).isEqualTo(1); // doAfterCall called once
}

@Test
void testExtendedAdvisorHooksCalledMultipleTimesWithToolCalls() {
int[] hookCallCounts = { 0, 0, 0 }; // initializeLoop, beforeCall, afterCall

TestableToolCallAdvisor advisor = new TestableToolCallAdvisor(this.toolCallingManager,
BaseAdvisor.HIGHEST_PRECEDENCE + 300, hookCallCounts);

ChatClientRequest request = createMockRequest(true);
ChatClientResponse responseWithToolCall = createMockResponse(true);
ChatClientResponse finalResponse = createMockResponse(false);

int[] callCount = { 0 };
CallAdvisor terminalAdvisor = new TerminalCallAdvisor((req, chain) -> {
callCount[0]++;
return callCount[0] == 1 ? responseWithToolCall : finalResponse;
});

CallAdvisorChain realChain = DefaultAroundAdvisorChain.builder(ObservationRegistry.NOOP)
.pushAll(List.of(advisor, terminalAdvisor))
.build();

// Mock tool execution result
List<Message> conversationHistory = List.of(new UserMessage("test"),
AssistantMessage.builder().content("").build(), ToolResponseMessage.builder().build());
ToolExecutionResult toolExecutionResult = ToolExecutionResult.builder()
.conversationHistory(conversationHistory)
.build();
when(this.toolCallingManager.executeToolCalls(any(Prompt.class), any(ChatResponse.class)))
.thenReturn(toolExecutionResult);

advisor.adviseCall(request, realChain);

// Verify hooks were called correct number of times
assertThat(hookCallCounts[0]).isEqualTo(1); // doInitializeLoop called once
// (before loop)
assertThat(hookCallCounts[1]).isEqualTo(2); // doBeforeCall called twice (each
// iteration)
assertThat(hookCallCounts[2]).isEqualTo(2); // doAfterCall called twice (each
// iteration)
}

@Test
void testExtendedBuilderWithCustomBuilder() {
ToolCallingManager customManager = mock(ToolCallingManager.class);
int customOrder = BaseAdvisor.HIGHEST_PRECEDENCE + 450;

TestableToolCallAdvisor advisor = TestableToolCallAdvisor.testBuilder()
.toolCallingManager(customManager)
.advisorOrder(customOrder)
.build();

assertThat(advisor).isNotNull();
assertThat(advisor.getOrder()).isEqualTo(customOrder);
}

// Helper methods

private ChatClientRequest createMockRequest(boolean withToolCallingOptions) {
Expand Down Expand Up @@ -472,6 +565,65 @@ public ChatClientResponse adviseCall(ChatClientRequest req, CallAdvisorChain cha
return this.responseFunction.apply(req, chain);
}

};
}

/**
* Test subclass of ToolCallAdvisor to verify extensibility and hook methods.
*/
private static class TestableToolCallAdvisor extends ToolCallAdvisor {

private final int[] hookCallCounts;

TestableToolCallAdvisor(ToolCallingManager toolCallingManager, int advisorOrder, int[] hookCallCounts) {
super(toolCallingManager, advisorOrder);
this.hookCallCounts = hookCallCounts;
}

@Override
protected ChatClientRequest doInitializeLoop(ChatClientRequest chatClientRequest,
CallAdvisorChain callAdvisorChain) {
if (this.hookCallCounts != null) {
this.hookCallCounts[0]++;
}
return super.doInitializeLoop(chatClientRequest, callAdvisorChain);
}

@Override
protected ChatClientRequest doBeforeCall(ChatClientRequest chatClientRequest,
CallAdvisorChain callAdvisorChain) {
if (this.hookCallCounts != null) {
this.hookCallCounts[1]++;
}
return super.doBeforeCall(chatClientRequest, callAdvisorChain);
}

@Override
protected ChatClientResponse doAfterCall(ChatClientResponse chatClientResponse,
CallAdvisorChain callAdvisorChain) {
if (this.hookCallCounts != null) {
this.hookCallCounts[2]++;
}
return super.doAfterCall(chatClientResponse, callAdvisorChain);
}

static TestableBuilder testBuilder() {
return new TestableBuilder();
}

static class TestableBuilder extends ToolCallAdvisor.Builder<TestableBuilder> {

@Override
protected TestableBuilder self() {
return this;
}

@Override
public TestableToolCallAdvisor build() {
return new TestableToolCallAdvisor(getToolCallingManager(), getAdvisorOrder(), null);
}

}

}

}