From b672cf3ebd8521c846b8fa9efa5bf6cbe87ba8e3 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Thu, 27 Mar 2025 14:58:02 -0700 Subject: [PATCH] Add a convenient constructor (#9707) Summary: As titled. Use default values for `topp` and `rng`. Reviewed By: iseeyuan Differential Revision: D71956172 --- extension/llm/runner/stats.h | 3 --- extension/llm/sampler/sampler.cpp | 7 +++++++ extension/llm/sampler/sampler.h | 10 ++++++++++ 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/extension/llm/runner/stats.h b/extension/llm/runner/stats.h index 79c5781e337..19766329ed3 100644 --- a/extension/llm/runner/stats.h +++ b/extension/llm/runner/stats.h @@ -82,8 +82,6 @@ struct ET_EXPERIMENTAL Stats { long aggregate_sampling_timer_start_timestamp = 0; }; -static constexpr auto kTopp = 0.9f; - inline std::string stats_to_json_string(const Stats& stats) { std::stringstream ss; ss << "{\"prompt_tokens\":" << stats.num_prompt_tokens << "," @@ -168,7 +166,6 @@ namespace executorch { namespace llm { // TODO(T197294990): Remove these deprecated aliases once all users have moved // to the new `::executorch` namespaces. -using ::executorch::extension::llm::kTopp; using ::executorch::extension::llm::print_report; using ::executorch::extension::llm::Stats; } // namespace llm diff --git a/extension/llm/sampler/sampler.cpp b/extension/llm/sampler/sampler.cpp index 63e4b79d568..18e82418841 100644 --- a/extension/llm/sampler/sampler.cpp +++ b/extension/llm/sampler/sampler.cpp @@ -34,6 +34,7 @@ #include #include +#include namespace executorch { namespace extension { @@ -129,6 +130,12 @@ Sampler::Sampler( topp_(topp), rng_state_(rng_seed) {} +Sampler::Sampler(int vocab_size, float temperature) + : vocab_size_(vocab_size), + inv_temperature_(static_cast(temperature) ? 1.0f / temperature : 0), + topp_(kTopp), + rng_state_(std::time(nullptr)) {} + template static void softmax(T* x, int size) { // find max value (for numerical stability) diff --git a/extension/llm/sampler/sampler.h b/extension/llm/sampler/sampler.h index 759eb6c88a7..1525f38692a 100644 --- a/extension/llm/sampler/sampler.h +++ b/extension/llm/sampler/sampler.h @@ -26,6 +26,8 @@ namespace extension { namespace llm { // A simple llama2 sampler. +inline constexpr auto kTopp = 0.9f; + template struct ET_EXPERIMENTAL ProbIndex { T prob; @@ -40,6 +42,8 @@ class ET_EXPERIMENTAL Sampler { float topp, unsigned long long rng_seed); + Sampler(int32_t vocab_size, float temperature); + template int32_t sample(T* logits); @@ -71,3 +75,9 @@ using ::executorch::extension::llm::ProbIndex; using ::executorch::extension::llm::Sampler; } // namespace executor } // namespace torch + +namespace executorch::llm { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch::extension::llm` namespaces. +using ::executorch::extension::llm::kTopp; +} // namespace executorch::llm