diff --git a/extension/llm/runner/stats.h b/extension/llm/runner/stats.h index 79c5781e337..19766329ed3 100644 --- a/extension/llm/runner/stats.h +++ b/extension/llm/runner/stats.h @@ -82,8 +82,6 @@ struct ET_EXPERIMENTAL Stats { long aggregate_sampling_timer_start_timestamp = 0; }; -static constexpr auto kTopp = 0.9f; - inline std::string stats_to_json_string(const Stats& stats) { std::stringstream ss; ss << "{\"prompt_tokens\":" << stats.num_prompt_tokens << "," @@ -168,7 +166,6 @@ namespace executorch { namespace llm { // TODO(T197294990): Remove these deprecated aliases once all users have moved // to the new `::executorch` namespaces. -using ::executorch::extension::llm::kTopp; using ::executorch::extension::llm::print_report; using ::executorch::extension::llm::Stats; } // namespace llm diff --git a/extension/llm/sampler/sampler.cpp b/extension/llm/sampler/sampler.cpp index 63e4b79d568..18e82418841 100644 --- a/extension/llm/sampler/sampler.cpp +++ b/extension/llm/sampler/sampler.cpp @@ -34,6 +34,7 @@ #include #include +#include namespace executorch { namespace extension { @@ -129,6 +130,12 @@ Sampler::Sampler( topp_(topp), rng_state_(rng_seed) {} +Sampler::Sampler(int vocab_size, float temperature) + : vocab_size_(vocab_size), + inv_temperature_(static_cast(temperature) ? 1.0f / temperature : 0), + topp_(kTopp), + rng_state_(std::time(nullptr)) {} + template static void softmax(T* x, int size) { // find max value (for numerical stability) diff --git a/extension/llm/sampler/sampler.h b/extension/llm/sampler/sampler.h index 759eb6c88a7..1525f38692a 100644 --- a/extension/llm/sampler/sampler.h +++ b/extension/llm/sampler/sampler.h @@ -26,6 +26,8 @@ namespace extension { namespace llm { // A simple llama2 sampler. +inline constexpr auto kTopp = 0.9f; + template struct ET_EXPERIMENTAL ProbIndex { T prob; @@ -40,6 +42,8 @@ class ET_EXPERIMENTAL Sampler { float topp, unsigned long long rng_seed); + Sampler(int32_t vocab_size, float temperature); + template int32_t sample(T* logits); @@ -71,3 +75,9 @@ using ::executorch::extension::llm::ProbIndex; using ::executorch::extension::llm::Sampler; } // namespace executor } // namespace torch + +namespace executorch::llm { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch::extension::llm` namespaces. +using ::executorch::extension::llm::kTopp; +} // namespace executorch::llm