Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 36 additions & 7 deletions services/openai-dialog/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ use openai_api_rs::realtime::{
api::RealtimeClient,
client_event::{self, ClientEvent},
server_event::{self, ServerEvent},
types::{self, ItemContentType, ItemRole, ItemStatus, ItemType, RealtimeVoice, ResponseStatus},
types::{
self, ItemContentType, ItemRole, ItemStatus, ItemType, RealtimeVoice, ResponseStatus,
ToolChoice,
},
};
use serde::{Deserialize, Serialize};
use tokio::{net::TcpStream, select};
Expand All @@ -42,6 +45,7 @@ pub struct Params {
pub temperature: Option<f32>,
#[serde(default)]
pub tools: Vec<types::ToolDefinition>,
Comment thread
pragmatrix marked this conversation as resolved.
tool_choice: Option<ToolChoice>,
}

impl Params {
Expand All @@ -54,6 +58,7 @@ impl Params {
voice: None,
temperature: None,
tools: vec![],
tool_choice: None,
}
}
}
Expand Down Expand Up @@ -113,9 +118,18 @@ pub enum ServiceInputEvent {
Prompt {
text: String,
},
#[serde(rename_all = "camelCase")]
SessionUpdate {
#[serde(skip_serializing_if = "Option::is_none")]
instructions: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
voice: Option<RealtimeVoice>,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<types::ToolDefinition>>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_choice: Option<ToolChoice>,
},
}

Expand Down Expand Up @@ -251,11 +265,6 @@ impl Client {
send_update = true;
};

if !params.tools.is_empty() {
session.tools = Some(params.tools);
send_update = true;
}

if let Some(voice) = params.voice {
session.voice = Some(voice);
send_update = true;
Expand All @@ -266,6 +275,16 @@ impl Client {
send_update = true;
}

if !params.tools.is_empty() {
session.tools = Some(params.tools);
send_update = true;
}

if let Some(tool_choice) = params.tool_choice {
session.tool_choice = Some(tool_choice);
send_update = true;
}

if send_update {
self.send_client_event(ClientEvent::SessionUpdate(client_event::SessionUpdate {
event_id: None,
Expand Down Expand Up @@ -415,10 +434,20 @@ impl Client {
info!("Received prompt");
self.push_prompt(PromptRequest(text)).await?;
}
ServiceInputEvent::SessionUpdate { tools } => {
ServiceInputEvent::SessionUpdate {
instructions,
voice,
temperature,
tools,
tool_choice,
} => {
let event = ClientEvent::SessionUpdate(client_event::SessionUpdate {
session: types::Session {
instructions,
voice,
temperature,
tools,
tool_choice,
..Default::default()
},
..Default::default()
Expand Down
Loading