Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions extension/llm/runner/stats.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ struct ET_EXPERIMENTAL Stats {
long aggregate_sampling_timer_start_timestamp = 0;
};

static constexpr auto kTopp = 0.9f;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this BC-breaking (backward-compatibility breaking), since this header doesn't include sampler.h?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This constant belongs in sampler.h, not stats.h. Anyone using this constant should already have included sampler.h, so this doesn't break anything. They no longer have to include stats.h if they only need this constant.


inline std::string stats_to_json_string(const Stats& stats) {
std::stringstream ss;
ss << "{\"prompt_tokens\":" << stats.num_prompt_tokens << ","
Expand Down Expand Up @@ -168,7 +166,6 @@ namespace executorch {
namespace llm {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::extension::llm::kTopp;
using ::executorch::extension::llm::print_report;
using ::executorch::extension::llm::Stats;
} // namespace llm
Expand Down
7 changes: 7 additions & 0 deletions extension/llm/sampler/sampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

#include <executorch/extension/llm/sampler/sampler.h>
#include <algorithm>
#include <ctime>

namespace executorch {
namespace extension {
Expand Down Expand Up @@ -129,6 +130,12 @@ Sampler::Sampler(
topp_(topp),
rng_state_(rng_seed) {}

/**
 * Constructs a Sampler from a vocabulary size and a temperature only.
 *
 * topp_ is initialised to kTopp and the RNG state is seeded from
 * std::time(nullptr), so the seed is not chosen by the caller.
 *
 * @param vocab_size Stored in vocab_size_. The type is int32_t so that this
 *     definition matches the declaration in sampler.h on every platform,
 *     not only where int32_t happens to be a typedef of int.
 * @param temperature Sampling temperature. Its reciprocal is stored in
 *     inv_temperature_; a temperature of zero stores 0 instead.
 */
Sampler::Sampler(int32_t vocab_size, float temperature)
    : vocab_size_(vocab_size),
      // Guard the reciprocal so a zero temperature never divides by zero.
      inv_temperature_(temperature != 0.0f ? 1.0f / temperature : 0.0f),
      topp_(kTopp),
      rng_state_(std::time(nullptr)) {}

template <typename T>
static void softmax(T* x, int size) {
// find max value (for numerical stability)
Expand Down
10 changes: 10 additions & 0 deletions extension/llm/sampler/sampler.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ namespace extension {
namespace llm {
// A simple llama2 sampler.

// Default top-p value; Sampler uses it when constructed without an explicit
// `topp` argument.
inline constexpr float kTopp = 0.9F;

template <typename T>
struct ET_EXPERIMENTAL ProbIndex {
T prob;
Expand All @@ -40,6 +42,8 @@ class ET_EXPERIMENTAL Sampler {
float topp,
unsigned long long rng_seed);

// Convenience constructor taking only the vocabulary size and temperature.
// The definition in sampler.cpp sets topp to kTopp and seeds the RNG from
// std::time(nullptr).
Sampler(int32_t vocab_size, float temperature);

template <typename T>
int32_t sample(T* logits);

Expand Down Expand Up @@ -71,3 +75,9 @@ using ::executorch::extension::llm::ProbIndex;
using ::executorch::extension::llm::Sampler;
} // namespace executor
} // namespace torch

namespace executorch {
namespace llm {
// TODO(T197294990): Drop this deprecated alias after every user has migrated
// to the new `::executorch::extension::llm` namespace.
using ::executorch::extension::llm::kTopp;
} // namespace llm
} // namespace executorch
Loading