Batch prefill tokens uses max input tokens as default #320

Open
wants to merge 2 commits into
base: main
15 changes: 12 additions & 3 deletions launcher/src/main.rs
@@ -220,8 +220,10 @@ struct Args {
/// Limits the number of tokens for the prefill operation.
/// Since this operation takes the most memory and is compute bound, it is interesting
/// to limit the number of requests that can be sent.
-#[clap(default_value = "4096", long, env)]
-max_batch_prefill_tokens: u32,
+/// The default value will be set based on the max_input_length, since it cannot be less than
+/// that value.
+#[clap(long, env)]
+max_batch_prefill_tokens: Option<u32>,

/// **IMPORTANT** This is one critical control to allow maximum usage
/// of the available hardware.
@@ -1181,6 +1183,13 @@ fn main() -> Result<(), LauncherError> {

tracing::info!("{:?}", args);

+// Set default values derived from other args
+
+// If the value of max_batch_prefill_tokens is not specified, default to max_input_length
+if args.max_batch_prefill_tokens.is_none() {
Contributor

Can we follow the same convention used by max_batch_total_tokens, which is also optional?

+args.max_batch_prefill_tokens = Some(args.max_input_length as u32);
+}

// Validate args
if args.max_input_length >= args.max_total_tokens {
return Err(LauncherError::ArgumentValidation(
@@ -1212,7 +1221,7 @@ fn main() -> Result<(), LauncherError> {
}

if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
-if args.max_batch_prefill_tokens > *max_batch_total_tokens {
+if args.max_batch_prefill_tokens.unwrap() > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
args.max_batch_prefill_tokens, max_batch_total_tokens
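For reference, here is a minimal standalone sketch of the behavior this diff is aiming for: fall back to `max_input_length` when `max_batch_prefill_tokens` is not given, then check the result against `max_batch_total_tokens`. It is not the launcher's code. The `Args` struct here is a hypothetical stand-in with only the three relevant fields, `max_input_length` is assumed to be a `usize`, the helper names `resolve_max_batch_prefill_tokens` and `validate_prefill` are invented for illustration, and errors are plain `String`s rather than `LauncherError`. Resolving the option once into a plain `u32` avoids the `.unwrap()` in the comparison and lets the value be formatted with `{}`, which `Option<u32>` does not support.

```rust
/// Hypothetical stand-in for the launcher's `Args`; only the fields
/// touched by this PR are modelled. `usize` for `max_input_length`
/// is an assumption.
#[derive(Debug)]
struct Args {
    max_input_length: usize,
    max_batch_prefill_tokens: Option<u32>,
    max_batch_total_tokens: Option<u32>,
}

/// Returns the explicit value when one was passed, otherwise falls
/// back to `max_input_length` (same `as u32` cast as in the diff).
fn resolve_max_batch_prefill_tokens(args: &Args) -> u32 {
    args.max_batch_prefill_tokens
        .unwrap_or(args.max_input_length as u32)
}

/// Hypothetical validation helper: resolves the prefill budget and
/// checks it against `max_batch_total_tokens` when that is set.
fn validate_prefill(args: &Args) -> Result<u32, String> {
    let prefill = resolve_max_batch_prefill_tokens(args);
    if let Some(total) = args.max_batch_total_tokens {
        if prefill > total {
            return Err(format!(
                "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
                prefill, total
            ));
        }
    }
    Ok(prefill)
}
```

With `max_input_length: 1024` and no explicit prefill value, `validate_prefill` resolves the budget to 1024. An explicit `Some(4096)` combined with `max_batch_total_tokens: Some(2048)` is rejected with the error message above.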