Skip to content

Commit

Permalink
refactor(azure/openai): upgrade to @azure/openai#1.0.0-beta.8
Browse files Browse the repository at this point in the history
  • Loading branch information
paztek committed Dec 11, 2023
1 parent 98d0f02 commit d316306
Show file tree
Hide file tree
Showing 34 changed files with 767 additions and 560 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,6 @@ pnpmfile.js
# env files
.env
!.env.example

# Sensible example
examples/insurance
12 changes: 12 additions & 0 deletions .idea/runConfigurations/build_watch.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions .idea/runConfigurations/tests.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

128 changes: 82 additions & 46 deletions dist/index.cjs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class Assistant {
constructor(params) {
this.client = params.client;
this.instructions = params.instructions;
this.functions = params.functions;
this.tools = params.tools;
this.deployment = params.deployment;
}
listChatCompletions(messages) {
Expand All @@ -24,7 +24,7 @@ class Assistant {
};
messages = [systemMessage, ...messages];
const options = {
functions: this.functions,
tools: this.tools,
};
const completions = this.client.listChatCompletions(this.deployment, messages, options);
return stream.Readable.from(completions, {
Expand All @@ -36,8 +36,9 @@ class Assistant {
class Thread extends EventEmitter {
constructor(messages = []) {
super();
this.messages = messages;
this.messages = [];
this._stream = null;
this.messages = messages;
}
get stream() {
if (!this._stream) {
Expand All @@ -46,8 +47,7 @@ class Thread extends EventEmitter {
return this._stream;
}
addMessage(message) {
this.messages.push(message);
this.emit('message', message);
this.doAddMessage(message);
}
run(assistant) {
this._stream = new stream.Readable({
Expand All @@ -57,7 +57,8 @@ class Thread extends EventEmitter {
}
doRun(assistant) {
this.emit('in_progress');
const stream = assistant.listChatCompletions(this.messages);
const messages = this.getRequestMessages();
const stream = assistant.listChatCompletions(messages);
/**
* When the LLM responds with a function call, the first completion's first choice looks like this:
* {
Expand Down Expand Up @@ -90,12 +91,32 @@ class Thread extends EventEmitter {
if (!delta) {
throw new Error('No delta returned');
}
if (delta.functionCall) {
const name = delta.functionCall.name;
this.handleStreamAsFunctionCall(name, stream, assistant);
if (delta.toolCalls.length > 0) {
this.handleStreamAsToolCalls(delta.toolCalls, stream, assistant);
}
else {
this.handleStreamAsChatMessage(stream);
this.handleStreamAsChatResponseMessage(stream);
}
});
}
/**
* Convert the mix of ChatRequestMessages and ChatResponseMessages to ChatRequestMessages only
* so they can be sent again to the LLM.
*/
getRequestMessages() {
return this.messages.map((m) => {
if (m.role === 'system' || m.role === 'user' || m.role === 'tool') {
// These are messages from the application (a.k.a request messages)
return m;
}
else {
// These are messages from the assistant (a.k.a response messages)
const responseMessage = m;
return {
role: 'assistant',
content: responseMessage.content,
toolCalls: responseMessage.toolCalls,
};
}
});
}
Expand All @@ -117,8 +138,8 @@ class Thread extends EventEmitter {
* }
* { index: 0, finishReason: 'function_call', delta: {} } <---- end of the function call
*/
handleStreamAsFunctionCall(name, stream, assistant) {
let args = '';
handleStreamAsToolCalls(toolCalls, stream, assistant) {
const argsList = Array(toolCalls.length).fill('');
stream.on('data', (completions) => {
const choice = completions.choices[0];
if (!choice) {
Expand All @@ -128,36 +149,35 @@ class Thread extends EventEmitter {
if (!delta) {
throw new Error('No delta returned');
}
if (delta.functionCall) {
const functionCall = delta.functionCall;
if (functionCall.arguments) {
args += functionCall.arguments;
}
}
if (choice.finishReason === 'function_call') {
const functionCall = {
name,
arguments: args,
};
delta.toolCalls.forEach((toolCall, index) => {
argsList[index] += toolCall.function.arguments;
});
if (choice.finishReason === 'tool_calls') {
const finalToolCalls = toolCalls.map((toolCall, index) => ({
...toolCall,
function: {
...toolCall.function,
arguments: argsList[index],
},
}));
// Adds the assistant's response to the messages
const message = {
role: 'assistant',
content: null,
functionCall,
toolCalls: finalToolCalls,
};
this.addMessage(message);
const requiredAction = new RequiredAction({
name,
arguments: args,
});
requiredAction.on('submitting', (toolOutput) => {
// Adds the tool output to the messages
const message = {
role: 'function',
name: functionCall.name,
content: JSON.stringify(toolOutput),
};
this.addMessage(message);
this.doAddMessage(message);
const requiredAction = new RequiredAction(finalToolCalls);
requiredAction.on('submitting', (toolOutputs) => {
// Adds the tool outputs to the messages
for (const toolOutput of toolOutputs) {
const message = {
role: 'tool',
content: JSON.stringify(toolOutput.value),
toolCallId: toolOutput.callId,
};
this.doAddMessage(message);
}
this.doRun(assistant);
});
this.emit('requires_action', requiredAction);
Expand All @@ -182,7 +202,7 @@ class Thread extends EventEmitter {
* }
* { index: 0, finishReason: 'stop', delta: {} } <---- end of the message
*/
handleStreamAsChatMessage(stream) {
handleStreamAsChatResponseMessage(stream) {
let content = '';
stream.on('data', (completions) => {
const choice = completions.choices[0];
Expand All @@ -206,28 +226,44 @@ class Thread extends EventEmitter {
const message = {
role: 'assistant',
content,
toolCalls: [],
};
this.addMessage(message);
this.doAddMessage(message);
this.emit('completed');
this._stream?.push(null);
}
});
}
doAddMessage(message) {
this.messages.push(message);
this.emit('message', message);
if (isChatRequestMessage(message)) {
this.emit('message:request', message);
}
else {
this.emit('message:response', message);
}
}
}
class RequiredAction extends EventEmitter {
constructor(functionCall) {
constructor(toolCalls) {
super();
this.toolCall = {
name: functionCall.name,
arguments: JSON.parse(functionCall.arguments),
};
this.toolCalls = toolCalls;
}
submitToolOutput(toolOutput) {
this.emit('submitting', toolOutput);
submitToolOutputs(toolOutputs) {
this.emit('submitting', toolOutputs);
}
}
function isChatResponseMessage(m) {
return 'toolCalls' in m;
}
function isChatRequestMessage(m) {
return !isChatResponseMessage(m);
}

exports.Assistant = Assistant;
exports.RequiredAction = RequiredAction;
exports.Thread = Thread;
exports.isChatRequestMessage = isChatRequestMessage;
exports.isChatResponseMessage = isChatResponseMessage;
//# sourceMappingURL=index.cjs.map
Loading

0 comments on commit d316306

Please sign in to comment.