feat: support customizing top_p parameter #434

Merged · 1 commit · Apr 24, 2024
5 changes: 3 additions & 2 deletions README.md
@@ -63,8 +63,9 @@ Feel free to adjust the configuration according to your needs.
 > Get `config.yaml` path with command `aichat --info` or repl command `.info`.

 ```yaml
-model: openai:gpt-3.5-turbo # The Large Language Model (LLM) to use
-temperature: 1.0 # Controls the randomness and creativity of the LLM's responses
+model: openai:gpt-3.5-turbo # Specify the language model to use
+temperature: null # Set default temperature parameter
+top_p: null # Set default top-p parameter
 save: true # Indicates whether to persist the message
 save_session: null # Controls the persistence of the session, if null, asking the user
 highlight: true # Controls syntax highlighting
5 changes: 3 additions & 2 deletions config.example.yaml
@@ -1,5 +1,6 @@
-model: openai:gpt-3.5-turbo # The Large Language Model (LLM) to use
-temperature: 1.0 # Controls the randomness and creativity of the LLM's responses
+model: openai:gpt-3.5-turbo # Specify the language model to use
+temperature: null # Set default temperature parameter
+top_p: null # Set default top-p parameter
 save: true # Indicates whether to persist the message
 save_session: null # Controls the persistence of the session, if null, asking the user
 highlight: true # Controls syntax highlighting
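With this change both sampling knobs default to `null`, meaning they are omitted from request bodies and each provider's own default applies. A non-null value here becomes the global default; it can also be changed at runtime via the `.set top_p` REPL command wired up in `src/config/mod.rs` below (e.g. `.set top_p 0.9`).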
4 changes: 4 additions & 0 deletions src/client/claude.rs
@@ -140,6 +140,7 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
     let SendData {
         mut messages,
         temperature,
+        top_p,
         stream,
     } = data;

@@ -205,6 +206,9 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
     if let Some(v) = temperature {
         body["temperature"] = v.into();
     }
+    if let Some(v) = top_p {
+        body["top_p"] = v.into();
+    }
     if stream {
         body["stream"] = true.into();
     }
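The same guard-then-insert pattern recurs in every client this PR touches. A minimal standalone sketch of the technique (not the repo's exact code; it only assumes the `serde_json` crate the clients already use):

```rust
use serde_json::{json, Value};

// Only write a sampling field into the request body when the user set it;
// otherwise the field stays absent and the provider's default applies.
fn apply_sampling(body: &mut Value, temperature: Option<f64>, top_p: Option<f64>) {
    if let Some(v) = temperature {
        body["temperature"] = v.into();
    }
    if let Some(v) = top_p {
        body["top_p"] = v.into();
    }
}

fn main() {
    // "example-model" is a placeholder, not a real model id.
    let mut body = json!({ "model": "example-model" });
    apply_sampling(&mut body, None, Some(0.9));
    // temperature was None, so only top_p appears in the body.
    assert_eq!(body, json!({ "model": "example-model", "top_p": 0.9 }));
}
```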
4 changes: 4 additions & 0 deletions src/client/cohere.rs
@@ -110,6 +110,7 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
     let SendData {
         mut messages,
         temperature,
+        top_p,
         stream,
     } = data;

@@ -173,6 +174,9 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
     if let Some(temperature) = temperature {
         body["temperature"] = temperature.into();
     }
+    if let Some(top_p) = top_p {
+        body["p"] = top_p.into();
+    }
     if stream {
         body["stream"] = true.into();
     }
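Note the rename here: Cohere's chat API calls the nucleus-sampling parameter `p`, so the client maps the user-facing `top_p` onto `body["p"]` rather than `body["top_p"]`.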
1 change: 1 addition & 0 deletions src/client/common.rs
@@ -323,6 +323,7 @@ pub struct ExtraConfig {
 pub struct SendData {
     pub messages: Vec<Message>,
     pub temperature: Option<f64>,
+    pub top_p: Option<f64>,
     pub stream: bool,
 }
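`SendData` is the provider-agnostic envelope every client destructures, so this one-line addition is what threads `top_p` through all of them. For orientation, a runnable mirror of the struct (the `messages` field is elided since `Message` is the repo's own type):

```rust
// Simplified stand-in for src/client/common.rs's SendData after this PR.
#[derive(Debug, Default)]
struct SendData {
    temperature: Option<f64>, // None => field omitted from the request body
    top_p: Option<f64>,       // None => field omitted from the request body
    stream: bool,
}

fn main() {
    // Both knobs stay None unless the user sets them.
    let data = SendData { top_p: Some(0.9), ..Default::default() };
    println!("{data:?}");
}
```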
4 changes: 4 additions & 0 deletions src/client/ernie.rs
@@ -230,6 +230,7 @@ fn build_body(data: SendData, model: &Model) -> Value {
     let SendData {
         mut messages,
         temperature,
+        top_p,
         stream,
     } = data;

@@ -242,6 +243,9 @@ fn build_body(data: SendData, model: &Model) -> Value {
     if let Some(temperature) = temperature {
         body["temperature"] = temperature.into();
     }
+    if let Some(top_p) = top_p {
+        body["top_p"] = top_p.into();
+    }

     if let Some(max_output_tokens) = model.max_output_tokens {
         body["max_output_tokens"] = max_output_tokens.into();
4 changes: 4 additions & 0 deletions src/client/ollama.rs
@@ -122,6 +122,7 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
     let SendData {
         messages,
         temperature,
+        top_p,
         stream,
     } = data;

@@ -185,6 +186,9 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
     if let Some(temperature) = temperature {
         body["options"]["temperature"] = temperature.into();
     }
+    if let Some(top_p) = top_p {
+        body["options"]["top_p"] = top_p.into();
+    }

     Ok(body)
 }
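Unlike the other providers, Ollama nests sampling parameters under an `options` object rather than at the top level of the body, hence `body["options"]["top_p"]`.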
12 changes: 8 additions & 4 deletions src/client/openai.rs
@@ -130,6 +130,7 @@ pub fn openai_build_body(data: SendData, model: &Model) -> Value {
     let SendData {
         messages,
         temperature,
+        top_p,
         stream,
     } = data;

@@ -139,13 +140,16 @@ pub fn openai_build_body(data: SendData, model: &Model) -> Value {
     });

     if let Some(max_tokens) = model.max_output_tokens {
-        body["max_tokens"] = json!(max_tokens);
+        body["max_tokens"] = max_tokens.into();
     } else if model.name == "gpt-4-vision-preview" {
         // The default max_tokens of gpt-4-vision-preview is only 16, we need to make it larger
-        body["max_tokens"] = json!(4096);
+        body["max_tokens"] = 4096.into();
     }
-    if let Some(v) = temperature {
-        body["temperature"] = v.into();
+    if let Some(temperature) = temperature {
+        body["temperature"] = temperature.into();
+    }
+    if let Some(top_p) = top_p {
+        body["top_p"] = top_p.into();
     }
     if stream {
         body["stream"] = true.into();
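Besides adding `top_p`, this hunk folds in a small cleanup: the `json!(...)` wrappers are replaced with the equivalent `.into()` calls, and the binding `v` is renamed to `temperature` to match the other clients.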
53 changes: 24 additions & 29 deletions src/client/qianwen.rs
@@ -130,7 +130,6 @@ async fn send_message_streaming(
     is_vl: bool,
 ) -> Result<()> {
     let mut es = builder.eventsource()?;
-    let mut offset = 0;

     while let Some(event) = es.next().await {
         match event {
@@ -139,12 +138,10 @@ async fn send_message_streaming(
                 let data: Value = serde_json::from_str(&message.data)?;
                 catch_error(&data)?;
                 if is_vl {
-                    let text =
-                        data["output"]["choices"][0]["message"]["content"][0]["text"].as_str();
-                    if let Some(text) = text {
-                        let text = &text[offset..];
+                    if let Some(text) =
+                        data["output"]["choices"][0]["message"]["content"][0]["text"].as_str()
+                    {
                         handler.text(text)?;
-                        offset += text.len();
                     }
                 } else if let Some(text) = data["output"]["text"].as_str() {
                     handler.text(text)?;
@@ -169,11 +166,12 @@ fn build_body(data: SendData, model: &Model, is_vl: bool) -> Result<(Value, bool)> {
     let SendData {
         messages,
         temperature,
+        top_p,
         stream,
     } = data;

     let mut has_upload = false;
-    let (input, parameters) = if is_vl {
+    let input = if is_vl {
         let messages: Vec<Value> = messages
             .into_iter()
             .map(|message| {
@@ -199,40 +197,37 @@ fn build_body(data: SendData, model: &Model, is_vl: bool) -> Result<(Value, bool)> {
             })
             .collect();

-        let input = json!({
+        json!({
             "messages": messages,
-        });
-
-        let mut parameters = json!({});
-        if let Some(v) = temperature {
-            parameters["temperature"] = v.into();
-        }
-        (input, parameters)
+        })
     } else {
-        let input = json!({
+        json!({
             "messages": messages,
-        });
+        })
+    };

-        let mut parameters = json!({});
-        if stream {
-            parameters["incremental_output"] = true.into();
-        }
+    let mut parameters = json!({});
+    if stream {
+        parameters["incremental_output"] = true.into();
+    }

-        if let Some(max_tokens) = model.max_output_tokens {
-            parameters["max_tokens"] = max_tokens.into();
-        }
+    if let Some(max_tokens) = model.max_output_tokens {
+        parameters["max_tokens"] = max_tokens.into();
+    }

-        if let Some(v) = temperature {
-            parameters["temperature"] = v.into();
-        }
-        (input, parameters)
-    };
+    if let Some(temperature) = temperature {
+        parameters["temperature"] = temperature.into();
+    }
+    if let Some(top_p) = top_p {
+        parameters["top_p"] = top_p.into();
+    }

     let body = json!({
         "model": &model.name,
         "input": input,
         "parameters": parameters
     });

     Ok((body, has_upload))
 }
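Two things happen here beyond adding `top_p`. First, `build_body` is refactored so the `if is_vl` expression yields only the `input` value, letting the `parameters` object (incremental output, max tokens, temperature, and now top_p) be built once for both branches instead of being duplicated. Second, because the shared `parameters` now sets `incremental_output` for the vision-language branch too, streamed events carry only new text, so the old `offset` bookkeeping that sliced off the previously-seen prefix becomes unnecessary and is removed.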
7 changes: 6 additions & 1 deletion src/client/vertexai.rs
@@ -158,7 +158,8 @@ pub(crate) fn build_body(
     let SendData {
         mut messages,
         temperature,
-        ..
+        top_p,
+        stream: _,
     } = data;

     patch_system_message(&mut messages);

@@ -223,6 +224,10 @@ pub(crate) fn build_body(
         body["generationConfig"]["temperature"] = temperature.into();
     }

+    if let Some(top_p) = top_p {
+        body["generationConfig"]["topP"] = top_p.into();
+    }
+
     Ok(body)
 }
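Gemini's request schema groups sampling settings under `generationConfig` and uses camelCase, so the field lands as `generationConfig.topP`. The destructuring also swaps the `..` catch-all for an explicit `stream: _`, so any future field added to `SendData` will produce a compile error here instead of being silently ignored.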
32 changes: 32 additions & 0 deletions src/config/mod.rs
@@ -53,6 +53,7 @@ pub struct Config {
     #[serde(rename(serialize = "model", deserialize = "model"))]
     pub model_id: Option<String>,
     pub temperature: Option<f64>,
+    pub top_p: Option<f64>,
     pub dry_run: bool,
     pub save: bool,
     pub save_session: Option<bool>,

@@ -89,6 +90,7 @@ impl Default for Config {
         Self {
             model_id: None,
             temperature: None,
+            top_p: None,
             save: true,
             save_session: None,
             highlight: true,

@@ -297,6 +299,7 @@ impl Config {
         if let Some(session) = self.session.as_mut() {
             session.guard_empty()?;
             session.set_temperature(role.temperature);
+            session.set_top_p(role.top_p);
         }
         self.role = Some(role);
         Ok(())

@@ -335,6 +338,16 @@ impl Config {
         }
     }

+    pub fn set_top_p(&mut self, value: Option<f64>) {
+        if let Some(session) = self.session.as_mut() {
+            session.set_top_p(value);
+        } else if let Some(role) = self.role.as_mut() {
+            role.set_top_p(value);
+        } else {
+            self.top_p = value;
+        }
+    }
+
     pub fn set_save_session(&mut self, value: Option<bool>) {
         if let Some(session) = self.session.as_mut() {
             session.set_save_session(value);

@@ -411,6 +424,7 @@ impl Config {
         let items = vec![
             ("model", self.model.id()),
             ("temperature", format_option(&self.temperature)),
+            ("top_p", format_option(&self.top_p)),
             ("dry_run", self.dry_run.to_string()),
             ("save", self.save.to_string()),
             ("save_session", format_option(&self.save_session)),

@@ -478,6 +492,7 @@ impl Config {
             ".session" => self.list_sessions(),
             ".set" => vec![
                 "temperature ",
+                "top_p ",
                 "compress_threshold",
                 "save ",
                 "save_session ",

@@ -529,6 +544,10 @@ impl Config {
                 let value = parse_value(value)?;
                 self.set_temperature(value);
             }
+            "top_p" => {
+                let value = parse_value(value)?;
+                self.set_top_p(value);
+            }
             "compress_threshold" => {
                 let value = parse_value(value)?;
                 self.set_compress_threshold(value);

@@ -756,10 +775,18 @@ impl Config {
         } else {
             self.temperature
         };
+        let top_p = if let Some(session) = input.session(&self.session) {
+            session.top_p()
+        } else if let Some(role) = input.role() {
+            role.top_p
+        } else {
+            self.top_p
+        };
         self.model.max_input_tokens_limit(&messages)?;
         Ok(SendData {
             messages,
             temperature,
+            top_p,
             stream,
         })
     }

@@ -791,6 +818,11 @@ impl Config {
                 output.insert("temperature", temperature.to_string());
             }
         }
+        if let Some(top_p) = self.top_p {
+            if top_p != 0.0 {
+                output.insert("top_p", top_p.to_string());
+            }
+        }
         if self.dry_run {
             output.insert("dry_run", "true".to_string());
         }
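The resolution order when building `SendData` deserves a note: an active session wins outright, then an active role, then the global config value, even if the winning scope's value is `None`. A compact sketch of that logic (a hypothetical helper, simplified away from the repo's `Input`/`Session`/`Role` types):

```rust
/// Resolve top_p the way src/config/mod.rs does: the innermost active
/// scope decides, even when its own value is None (provider default).
fn resolve_top_p(
    session: Option<Option<f64>>, // Some(v) while a session is active
    role: Option<Option<f64>>,    // Some(v) while a role is active
    global: Option<f64>,
) -> Option<f64> {
    match (session, role) {
        (Some(v), _) => v,
        (None, Some(v)) => v,
        (None, None) => global,
    }
}

fn main() {
    // An active session that left top_p unset overrides a role's 0.8:
    assert_eq!(resolve_top_p(Some(None), Some(Some(0.8)), Some(0.5)), None);
    // With no session, the role's value applies.
    assert_eq!(resolve_top_p(None, Some(Some(0.8)), Some(0.5)), Some(0.8));
}
```

The `top_p != 0.0` guard in the last hunk mirrors the existing `temperature` handling, skipping zero values when exporting role settings.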