diff --git a/services/openai-dialog/src/lib.rs b/services/openai-dialog/src/lib.rs index c638360..27042a9 100644 --- a/services/openai-dialog/src/lib.rs +++ b/services/openai-dialog/src/lib.rs @@ -15,7 +15,10 @@ use openai_api_rs::realtime::{ api::RealtimeClient, client_event::{self, ClientEvent}, server_event::{self, ServerEvent}, - types::{self, ItemContentType, ItemRole, ItemStatus, ItemType, RealtimeVoice, ResponseStatus}, + types::{ + self, ItemContentType, ItemRole, ItemStatus, ItemType, RealtimeVoice, ResponseStatus, + ToolChoice, + }, }; use serde::{Deserialize, Serialize}; use tokio::{net::TcpStream, select}; @@ -42,6 +45,7 @@ pub struct Params { pub temperature: Option, #[serde(default)] pub tools: Vec, + tool_choice: Option, } impl Params { @@ -54,6 +58,7 @@ impl Params { voice: None, temperature: None, tools: vec![], + tool_choice: None, } } } @@ -113,9 +118,18 @@ pub enum ServiceInputEvent { Prompt { text: String, }, + #[serde(rename_all = "camelCase")] SessionUpdate { + #[serde(skip_serializing_if = "Option::is_none")] + instructions: Option, + #[serde(skip_serializing_if = "Option::is_none")] + voice: Option, + #[serde(skip_serializing_if = "Option::is_none")] + temperature: Option, #[serde(skip_serializing_if = "Option::is_none")] tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + tool_choice: Option, }, } @@ -251,11 +265,6 @@ impl Client { send_update = true; }; - if !params.tools.is_empty() { - session.tools = Some(params.tools); - send_update = true; - } - if let Some(voice) = params.voice { session.voice = Some(voice); send_update = true; @@ -266,6 +275,16 @@ impl Client { send_update = true; } + if !params.tools.is_empty() { + session.tools = Some(params.tools); + send_update = true; + } + + if let Some(tool_choice) = params.tool_choice { + session.tool_choice = Some(tool_choice); + send_update = true; + } + if send_update { self.send_client_event(ClientEvent::SessionUpdate(client_event::SessionUpdate { event_id: None, @@ -415,10 +434,20 @@ impl Client { info!("Received prompt"); self.push_prompt(PromptRequest(text)).await?; } - ServiceInputEvent::SessionUpdate { tools } => { + ServiceInputEvent::SessionUpdate { + instructions, + voice, + temperature, + tools, + tool_choice, + } => { let event = ClientEvent::SessionUpdate(client_event::SessionUpdate { session: types::Session { + instructions, + voice, + temperature, tools, + tool_choice, ..Default::default() }, ..Default::default()