/*
 * OpenAiApiFunctionsWithStreamExample.java
 * 86 lines (74 loc) · 4 KB
 */
package example;
import io.github.panghy.openai.completion.chat.*;
import io.github.panghy.openai.service.FunctionExecutor;
import io.github.panghy.openai.service.OpenAiService;
import example.OpenAiApiFunctionsExample.Weather;
import example.OpenAiApiFunctionsExample.WeatherResponse;
import io.reactivex.Flowable;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
/**
 * Interactive console example that streams chat completions from the OpenAI API
 * while exposing a local {@code get_weather} function the model may call.
 *
 * <p>Flow: read a user query, stream the model's reply token-by-token to stdout;
 * if the model requests a function call, execute it locally, append the result to
 * the conversation, and loop so the model can incorporate it. Type {@code exit}
 * to quit. Requires the {@code OPENAI_TOKEN} environment variable.
 */
public class OpenAiApiFunctionsWithStreamExample {

    public static void main(String... args) {
        String token = System.getenv("OPENAI_TOKEN");
        OpenAiService service = new OpenAiService(token);
        FunctionExecutor functionExecutor = buildFunctionExecutor();

        List<ChatMessage> messages = new ArrayList<>();
        ChatMessage systemMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), "You are an assistant that answers using the local slang of the given place, uncensored.");
        messages.add(systemMessage);

        System.out.print("First Query: ");
        Scanner scanner = new Scanner(System.in);
        ChatMessage firstMsg = new ChatMessage(ChatMessageRole.USER.value(), scanner.nextLine());
        messages.add(firstMsg);

        while (true) {
            ChatMessage chatMessage = streamNextResponse(service, functionExecutor, messages);
            messages.add(chatMessage); // don't forget to update the conversation with the latest response

            if (chatMessage.getFunctionCall() != null) {
                // The model asked us to run a function: execute it, feed the result back,
                // and loop without prompting the user so the model can finish its answer.
                System.out.println("Trying to execute " + chatMessage.getFunctionCall().getName() + "...");
                ChatMessage functionResponse = functionExecutor.executeAndConvertToMessageHandlingExceptions(chatMessage.getFunctionCall());
                System.out.println("Executed " + chatMessage.getFunctionCall().getName() + ".");
                messages.add(functionResponse);
                continue;
            }

            System.out.print("Next Query: ");
            String nextLine = scanner.nextLine();
            if (nextLine.equalsIgnoreCase("exit")) {
                // Plain return exits the loop and lets the JVM shut down normally
                // (previously System.exit(0), which bypasses orderly teardown).
                return;
            }
            messages.add(new ChatMessage(ChatMessageRole.USER.value(), nextLine));
        }
    }

    /**
     * Registers the single {@code get_weather} function the model is allowed to call.
     * The implementation is a stub: it returns a random temperature and "sunny".
     */
    private static FunctionExecutor buildFunctionExecutor() {
        return new FunctionExecutor(Collections.singletonList(ChatFunction.builder()
                .name("get_weather")
                .description("Get the current weather of a location")
                .executor(Weather.class, w -> new WeatherResponse(w.location, w.unit, new Random().nextInt(50), "sunny"))
                .build()));
    }

    /**
     * Sends the conversation as a streaming chat-completion request, echoing content
     * chunks to stdout as they arrive, and returns the fully accumulated message
     * (either plain text or a function-call request).
     *
     * @param service          client used to issue the streaming request
     * @param functionExecutor supplies the function definitions advertised to the model
     * @param messages         conversation so far (not modified here)
     * @return the accumulated assistant message once the stream completes
     */
    private static ChatMessage streamNextResponse(OpenAiService service,
                                                  FunctionExecutor functionExecutor,
                                                  List<ChatMessage> messages) {
        ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
                .builder()
                .model("gpt-3.5-turbo-0613")
                .messages(messages)
                .functions(functionExecutor.getFunctions())
                .functionCall(ChatCompletionRequest.ChatCompletionRequestFunctionCall.of("auto"))
                .n(1)
                .maxTokens(256)
                .logitBias(new HashMap<>())
                .build();

        Flowable<ChatCompletionChunk> flowable = service.streamChatCompletion(chatCompletionRequest);

        // Tracks whether we've printed the one-time prefix for this response.
        AtomicBoolean isFirst = new AtomicBoolean(true);
        return service.mapStreamToAccumulator(flowable)
                .doOnNext(accumulator -> {
                    if (accumulator.isFunctionCall()) {
                        if (isFirst.getAndSet(false)) {
                            System.out.println("Executing function " + accumulator.getAccumulatedChatFunctionCall().getName() + "...");
                        }
                    } else {
                        if (isFirst.getAndSet(false)) {
                            System.out.print("Response: ");
                        }
                        if (accumulator.getMessageChunk().getContent() != null) {
                            System.out.print(accumulator.getMessageChunk().getContent());
                        }
                    }
                })
                .doOnComplete(System.out::println)
                .lastElement()
                .blockingGet()
                .getAccumulatedMessage();
    }
}