diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index c48436a80..c86e8ab62 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -942,7 +942,7 @@ jobs: path: | ./et-build ./torchchat/utils/scripts - key: et-build-${{runner.os}}-${{runner.arch}}-${{env.et-git-hash}}-${{ hashFiles('**/install_et.sh') }} + key: et-build-${{runner.os}}-${{runner.arch}}-${{env.et-git-hash}}-${{ hashFiles('**/install_et.sh', '**/build_native.sh') }} - if: ${{ steps.install-et.outputs.cache-hit != 'true' }} continue-on-error: true run: | @@ -1053,7 +1053,7 @@ jobs: # Pull submodules (re2, abseil) for Tiktoken git submodule sync - git submodule update --init + git submodule update --init --recursive ./runner/build_android.sh echo "Tests complete." diff --git a/.gitmodules b/.gitmodules index 7681823df..76bc1b9fd 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,9 +1,3 @@ -[submodule "tokenizer/third-party/abseil-cpp"] - path = tokenizer/third-party/abseil-cpp - url = https://github.com/abseil/abseil-cpp.git -[submodule "tokenizer/third-party/re2"] - path = tokenizer/third-party/re2 - url = https://github.com/google/re2.git -[submodule "tokenizer/third-party/sentencepiece"] - path = tokenizer/third-party/sentencepiece - url = https://github.com/google/sentencepiece.git +[submodule "runner/third-party/tokenizers"] + path = runner/third-party/tokenizers + url = https://github.com/pytorch-labs/tokenizers diff --git a/CMakeLists.txt b/CMakeLists.txt index 61fd4d5a6..e004dbfcb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,18 +7,21 @@ ELSE() ENDIF() project(Torchchat) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-attributes") # include tokenizer -add_subdirectory(tokenizer) +add_subdirectory(runner/third-party/tokenizers) # include et_run executable include(runner/et.cmake) if(TARGET et_run) - target_link_libraries(et_run PUBLIC tokenizer microkernels-prod) + target_link_libraries(et_run PUBLIC tokenizers microkernels-prod) + target_include_directories(et_run PUBLIC runner/third-party/tokenizers/include) endif() # include aoti_run executable include(runner/aoti.cmake) if(TARGET aoti_run) - target_link_libraries(aoti_run tokenizer) + target_link_libraries(aoti_run tokenizers) + target_include_directories(aoti_run PUBLIC runner/third-party/tokenizers/include) endif() diff --git a/runner/run.cpp b/runner/run.cpp index abfbb4584..f2b8e8e6b 100644 --- a/runner/run.cpp +++ b/runner/run.cpp @@ -7,20 +7,21 @@ LICENSE file in the root directory of this source tree. 
*/ /* Inference for Llama-2 Transformer model in pure C++ */ +#include "sentencepiece.h" +#include "tiktoken.h" +#include +#include +#include +#include #include +#include #include #include #include #include #include -#include -#include -#include -#include -#include -#include #include - +#include #ifdef DEBUG #include #include @@ -47,13 +48,25 @@ torch::Device aoti_device(torch::kCPU); #endif using exec_aten::ScalarType; -using torch::executor::EValue; -using executorch::extension::TensorPtr; using executorch::extension::make_tensor_ptr; +using executorch::extension::TensorPtr; +using torch::executor::EValue; using torch::executor::Module; using torch::executor::Result; #endif +using tokenizers::SPTokenizer; +using tokenizers::Tiktoken; +using tokenizers::Tokenizer; + +#define UNWRAP(x) \ + ({ \ + if (!(x).ok()) { \ + fprintf(stderr, "Got error code % " PRIu32, x.error()); \ + exit(EXIT_FAILURE); \ + } \ + std::move(x.get()); \ + }) // ---------------------------------------------------------------------------- // Transformer model @@ -65,56 +78,56 @@ enum ModelType { ModelType get_model_type(int model_int) { switch (model_int) { - case 2: - return LLAMA2_MODEL; - break; - case 3: - return LLAMA3_MODEL; - break; - default: - return UNKNOWN_MODEL; + case 2: + return LLAMA2_MODEL; + break; + case 3: + return LLAMA3_MODEL; + break; + default: + return UNKNOWN_MODEL; } } typedef struct { int vocab_size; // vocabulary size, usually 256 (byte-level) - int seq_len; // max sequence length + int seq_len; // max sequence length } Config; typedef struct { - float* logits; // output logits - int64_t* toks; // tokens seen so far; no kv-cache :( + float *logits; // output logits + int64_t *toks; // tokens seen so far; no kv-cache :( } RunState; typedef struct { - Config config; // the hyperparameters of the architecture (the blueprint) + Config config; // the hyperparameters of the architecture (the blueprint) RunState state; // buffers for the "wave" of activations in the forward pass #ifdef __AOTI_MODEL__ - torch::inductor::AOTIModelPackageLoader* runner; + torch::inductor::AOTIModelPackageLoader *runner; #else // __ET_MODEL__ - Module* runner; + Module *runner; #endif } Transformer; -void malloc_run_state(RunState* s, Config* p) { +void malloc_run_state(RunState *s, Config *p) { // we calloc instead of malloc to keep valgrind happy - s->logits = (float*)calloc(p->vocab_size, sizeof(float)); - s->toks = (int64_t*)calloc(p->seq_len, sizeof(int64_t)); + s->logits = (float *)calloc(p->vocab_size, sizeof(float)); + s->toks = (int64_t *)calloc(p->seq_len, sizeof(int64_t)); if (!s->logits || !s->toks) { fprintf(stderr, "malloc failed!\n"); exit(EXIT_FAILURE); } } -void free_run_state(RunState* s) { +void free_run_state(RunState *s) { free(s->logits); free(s->toks); } -void read_checkpoint(char* checkpoint, Config* config) { - FILE* file = fopen(checkpoint, "rb"); +void read_checkpoint(char *checkpoint, Config *config) { + FILE *file = fopen(checkpoint, "rb"); if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); exit(EXIT_FAILURE); @@ -128,11 +141,8 @@ void read_checkpoint(char* checkpoint, Config* config) { config->vocab_size = abs(config->vocab_size); } -void build_transformer( - Transformer* t, - char* model_path, - int vocab_size, - int seq_len) { +void build_transformer(Transformer *t, char *model_path, int vocab_size, + int seq_len) { // read in the Config and the Weights from the model // read_checkpoint(model_path, &t->config); // allocate the RunState buffers @@ -142,7 +152,9 @@ void 
build_transformer(
 #ifdef __AOTI_MODEL__
   t->runner = new torch::inductor::AOTIModelPackageLoader(model_path);
-  aoti_device = t->runner->get_metadata()["AOTI_DEVICE_KEY"] == "cpu" ? torch::Device(torch::kCPU) : torch::Device(torch::kCUDA);
+  aoti_device = t->runner->get_metadata()["AOTI_DEVICE_KEY"] == "cpu"
+                    ? torch::Device(torch::kCPU)
+                    : torch::Device(torch::kCUDA);
 #else //__ET_MODEL__
   t->runner = new Module(
       /* path to PTE model */ model_path,
@@ -150,7 +162,7 @@ void build_transformer(
 #endif
 }
 
-void free_transformer(Transformer* t) {
+void free_transformer(Transformer *t) {
   // free the RunState buffers
   free_run_state(&t->state);
   delete t->runner;
@@ -159,7 +171,7 @@ void free_transformer(Transformer* t) {
 // ----------------------------------------------------------------------------
 // neural net blocks; the dynamics of the Transformer
 
-void softmax(float* x, int size) {
+void softmax(float *x, int size) {
   // find max value (for numerical stability)
   float max_val = x[0];
   for (int i = 1; i < size; i++) {
@@ -179,9 +191,9 @@
   }
 }
 
-float* forward(Transformer* transformer, int token, int pos) {
-  Config* p = &transformer->config;
-  RunState* s = &transformer->state;
+float *forward(Transformer *transformer, int token, int pos) {
+  Config *p = &transformer->config;
+  RunState *s = &transformer->state;
   s->toks[pos] = token;
   long token_buffer[1] = {token};
   long pos_buffer[1] = {pos};
@@ -194,8 +206,8 @@ float* forward(Transformer* transformer, int token, int pos) {
   torch::Tensor token_tensor = torch::from_blob(token_buffer, {1, 1}, torch::kLong);
   torch::Tensor pos_tensor = torch::from_blob(pos_buffer, {1}, torch::kLong);
-  std::vector<torch::Tensor> inputs{
-      token_tensor.to(aoti_device), pos_tensor.to(aoti_device)};
+  std::vector<torch::Tensor> inputs{token_tensor.to(aoti_device),
+                                    pos_tensor.to(aoti_device)};
 
   torch::Tensor result = transformer->runner->run(inputs)[0]
                              .to(torch::dtype(torch::kFloat32))
@@ -204,7 +216,8 @@ float* forward(Transformer* transformer, int token, int pos) {
   memcpy(s->logits, logits, p->vocab_size * sizeof(float));
 #else // __ET_MODEL__
   TensorPtr pos_managed = make_tensor_ptr({1}, pos_buffer, ScalarType::Long);
-  TensorPtr tokens_managed = make_tensor_ptr({1, 1}, token_buffer, ScalarType::Long);
+  TensorPtr tokens_managed =
+      make_tensor_ptr({1, 1}, token_buffer, ScalarType::Long);
   std::vector<EValue> inputs;
   auto tmp1 = EValue(tokens_managed);
   auto tmp2 = EValue(pos_managed);
@@ -221,17 +234,12 @@ float* forward(Transformer* transformer, int token, int pos) {
   // HACK: the rest of this runner assumes that logits must be float,
   // so we simply convert them rather than plumbing
   // templating/switch-on-type through the rest of this file.
-  const auto& result_tensor = result[0].toTensor();
+  const auto &result_tensor = result[0].toTensor();
   ET_SWITCH_REALHBBF16_TYPES(
-      result_tensor.scalar_type(),
-      unused,
-      "forward",
-      CTYPE,
-      [&]() {
-        const CTYPE* logits = result_tensor.const_data_ptr<CTYPE>();
-        std::transform(logits, logits + p->vocab_size, s->logits, [](auto x) {
-          return static_cast<float>(x);
-        });
+      result_tensor.scalar_type(), unused, "forward", CTYPE, [&]() {
+        const CTYPE *logits = result_tensor.const_data_ptr<CTYPE>();
+        std::transform(logits, logits + p->vocab_size, s->logits,
+                       [](auto x) { return static_cast<float>(x); });
       });
 #endif
 
@@ -249,13 +257,13 @@ typedef struct {
 
 typedef struct {
   int vocab_size;
-  ProbIndex* probindex; // buffer used in top-p sampling
+  ProbIndex *probindex; // buffer used in top-p sampling
   float temperature;
   float topp;
   unsigned long long rng_state;
 } Sampler;
 
-int sample_argmax(float* probabilities, int n) {
+int sample_argmax(float *probabilities, int n) {
   // return the index that has the highest probability
   int max_i = 0;
   float max_p = probabilities[0];
@@ -268,7 +276,7 @@
   return max_i;
 }
 
-int sample_mult(float* probabilities, int n, float coin) {
+int sample_mult(float *probabilities, int n, float coin) {
   // sample index from probabilities (they must sum to 1!)
   // coin is a random number in [0, 1), usually from random_f32()
   float cdf = 0.0f;
@@ -281,9 +289,9 @@
   return n - 1; // in case of rounding errors
 }
 
-int compare(const void* a, const void* b) {
-  ProbIndex* a_ = (ProbIndex*)a;
-  ProbIndex* b_ = (ProbIndex*)b;
+int compare(const void *a, const void *b) {
+  ProbIndex *a_ = (ProbIndex *)a;
+  ProbIndex *b_ = (ProbIndex *)b;
   if (a_->prob > b_->prob)
     return -1;
   if (a_->prob < b_->prob)
@@ -291,12 +299,8 @@
   return 0;
 }
 
-int sample_topp(
-    float* probabilities,
-    int n,
-    float topp,
-    ProbIndex* probindex,
-    float coin) {
+int sample_topp(float *probabilities, int n, float topp, ProbIndex *probindex,
+                float coin) {
   // top-p sampling (or "nucleus sampling") samples from the smallest set of
   // tokens that exceed probability topp. This way we never sample tokens that
   // have very low probabilities and are less likely to go "off the rails".
@@ -339,37 +343,31 @@ int sample_topp( return probindex[last_idx].index; // in case of rounding errors } -void build_sampler( - Sampler* sampler, - int vocab_size, - float temperature, - float topp, - unsigned long long rng_seed) { +void build_sampler(Sampler *sampler, int vocab_size, float temperature, + float topp, unsigned long long rng_seed) { sampler->vocab_size = vocab_size; sampler->temperature = temperature; sampler->topp = topp; sampler->rng_state = rng_seed; // buffer only used with nucleus sampling; may not need but it's ~small sampler->probindex = - (ProbIndex*)malloc(sampler->vocab_size * sizeof(ProbIndex)); + (ProbIndex *)malloc(sampler->vocab_size * sizeof(ProbIndex)); } -void free_sampler(Sampler* sampler) { - free(sampler->probindex); -} +void free_sampler(Sampler *sampler) { free(sampler->probindex); } -unsigned int random_u32(unsigned long long* state) { +unsigned int random_u32(unsigned long long *state) { // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A *state ^= *state >> 12; *state ^= *state << 25; *state ^= *state >> 27; return (*state * 0x2545F4914F6CDD1Dull) >> 32; } -float random_f32(unsigned long long* state) { // random float32 in [0,1) +float random_f32(unsigned long long *state) { // random float32 in [0,1) return (random_u32(state) >> 8) / 16777216.0f; } -int sample(Sampler* sampler, float* logits) { +int sample(Sampler *sampler, float *logits) { // sample the token given the logits and some hyperparameters int next; if (sampler->temperature == 0.0f) { @@ -390,39 +388,37 @@ int sample(Sampler* sampler, float* logits) { next = sample_mult(logits, sampler->vocab_size, coin); } else { // top-p (nucleus) sampling, clamping the least likely tokens to zero - next = sample_topp( - logits, sampler->vocab_size, sampler->topp, sampler->probindex, coin); + next = sample_topp(logits, sampler->vocab_size, sampler->topp, + sampler->probindex, coin); } } return next; } -Tokenizer* build_tokenizer(const char* tokenizer_path, ModelType model_type) { - Tokenizer* tokenizer = NULL; +Tokenizer *build_tokenizer(const char *tokenizer_path, ModelType model_type) { + Tokenizer *tokenizer = NULL; switch (model_type) { - case LLAMA2_MODEL: - tokenizer = new SPTokenizer(); - tokenizer->load(tokenizer_path); - break; - case LLAMA3_MODEL: - tokenizer = new Tiktoken(); - tokenizer->load(tokenizer_path); - break; - default: - fprintf(stderr, "No tokenizer defined for model type %d.\n", model_type); - exit(EXIT_FAILURE); + case LLAMA2_MODEL: + tokenizer = new SPTokenizer(); + tokenizer->load(tokenizer_path); + break; + case LLAMA3_MODEL: + tokenizer = new Tiktoken(); + tokenizer->load(tokenizer_path); + break; + default: + fprintf(stderr, "No tokenizer defined for model type %d.\n", model_type); + exit(EXIT_FAILURE); } return tokenizer; } -void free_tokenizer(Tokenizer* tokenizer) { - delete tokenizer; -} +void free_tokenizer(Tokenizer *tokenizer) { delete tokenizer; } // ---------------------------------------------------------------------------- // utilities: time -void safe_printf(const char* piece) { +void safe_printf(const char *piece) { // piece might be a raw byte token, and we only want to print printable chars // or whitespace because some of the other bytes can be various control codes, // backspace, etc. @@ -454,21 +450,18 @@ long time_in_ms() { // Prints decoded tokens generated from the transformer. 
 // The first token is not printed and is assumed to be a BOS or other similar
 // token
-unsigned generate_from_prompt_tokens(
-    Transformer* transformer,
-    Tokenizer* tokenizer,
-    Sampler* sampler,
-    const std::vector<uint64_t>& prompt_tokens,
-    unsigned pos,
-    const std::vector<uint64_t>& stop_tokens,
-    int stop_pos,
-    bool print_prompt,
-    bool print_tok_per_sec) {
+unsigned generate_from_prompt_tokens(Transformer *transformer,
+                                     Tokenizer *tokenizer, Sampler *sampler,
+                                     const std::vector<uint64_t> &prompt_tokens,
+                                     unsigned pos,
+                                     const std::vector<uint64_t> &stop_tokens,
+                                     int stop_pos, bool print_prompt,
+                                     bool print_tok_per_sec) {
   if (prompt_tokens.size() == 0) {
     return pos;
   }
 
-  uint64_t next; // will store the next token in the sequence
+  uint64_t next;  // will store the next token in the sequence
   uint64_t token; // stores the current token to feed into the transformer
   bool done_with_prompt; // whether we are done processing prompt
 
@@ -486,7 +479,7 @@ unsigned generate_from_prompt_tokens(
     if (pos_in_prompt < prompt_tokens.size()) {
       // Token comes from prompt
       token = prompt_tokens[pos_in_prompt++];
-      float* logits = forward(transformer, token, pos);
+      float *logits = forward(transformer, token, pos);
 
       // Next token is either from prompt or if on last
       // prompt token, next is sampled
@@ -498,29 +491,27 @@ unsigned generate_from_prompt_tokens(
     } else {
       // Token comes from next sampled from previous round.
       token = next;
-      float* logits = forward(transformer, token, pos);
+      float *logits = forward(transformer, token, pos);
       next = sample(sampler, logits);
     }
     done_with_prompt = (pos_in_prompt >= prompt_tokens.size());
 
     // we terminate on finding the stop_token if we are done processing the
     // prompt (stop_tokens in the prompt do not terminate the loop)
-    if (done_with_prompt &&
-        (std::find(stop_tokens.begin(), stop_tokens.end(), token) !=
-         stop_tokens.end())) {
+    if (done_with_prompt && (std::find(stop_tokens.begin(), stop_tokens.end(),
+                                       token) != stop_tokens.end())) {
      found_stop_token = true;
     }
 
    // We print next in each iteration of the loop, not token
    if (!found_stop_token && (print_prompt || done_with_prompt)) {
      // The stop_token is printed as newline
-      bool next_is_stop =
-          std::find(stop_tokens.begin(), stop_tokens.end(), next) !=
-          stop_tokens.end();
+      bool next_is_stop = std::find(stop_tokens.begin(), stop_tokens.end(),
+                                    next) != stop_tokens.end();
       if (next_is_stop) {
         printf("\n");
       } else {
-        std::string piece = tokenizer->decode(token, next);
+        std::string piece = UNWRAP(tokenizer->decode(token, next));
         safe_printf(piece.c_str()); // same as printf("%s", piece), but skips
                                     // "unsafe" bytes
         fflush(stdout);
@@ -538,23 +529,16 @@ unsigned generate_from_prompt_tokens(
   // iteration)
   if (print_tok_per_sec && pos > 1) {
     long end = time_in_ms();
-    fprintf(
-        stderr,
-        "\n\nachieved tok/s: %f\n",
-        (pos - 1) / (double)(end - start) * 1000);
+    fprintf(stderr, "\n\nachieved tok/s: %f\n",
+            (pos - 1) / (double)(end - start) * 1000);
   }
 
   return pos;
 }
 
-void generate(
-    Transformer* transformer,
-    Tokenizer* tokenizer,
-    Sampler* sampler,
-    const char* prompt,
-    int steps,
-    ModelType model_type) {
-  const char* default_prompt = "Once upon a time";
+void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
+              const char *prompt, int steps, ModelType model_type) {
+  const char *default_prompt = "Once upon a time";
   if (prompt == NULL) {
     prompt = default_prompt;
   }
@@ -566,33 +550,30 @@ void generate(
   std::vector<uint64_t> prompt_tokens;
   std::vector<uint64_t> stop_tokens;
   switch (model_type) {
-    case LLAMA2_MODEL:
-      prompt_tokens =
tokenizer->encode(prompt, 1, 0); - stop_tokens.push_back(tokenizer->eos_tok()); - break; - case LLAMA3_MODEL: - prompt_tokens = tokenizer->encode(prompt, 1, 0); - stop_tokens.push_back(tokenizer->encode("<|end_of_text|>", 0, 0)[0]); - stop_tokens.push_back(tokenizer->encode("<|eot_id|>", 0, 0)[0]); - break; - default: - fprintf(stderr, "Generate does not support model type %d.\n", model_type); - exit(EXIT_FAILURE); - } - - generate_from_prompt_tokens( - transformer, - tokenizer, - sampler, - prompt_tokens, - /*pos=*/0, - /*stop_tokens=*/stop_tokens, - /*stop_pos=*/steps - 1, - /*print_prompt=*/true, - /*print_tok_per_sec=*/true); + case LLAMA2_MODEL: + prompt_tokens = UNWRAP(tokenizer->encode(prompt, 1, 0)); + stop_tokens.push_back(tokenizer->eos_tok()); + break; + case LLAMA3_MODEL: + prompt_tokens = UNWRAP(tokenizer->encode(prompt, 1, 0)); + stop_tokens.push_back( + UNWRAP(tokenizer->encode("<|end_of_text|>", 0, 0))[0]); + stop_tokens.push_back(UNWRAP(tokenizer->encode("<|eot_id|>", 0, 0))[0]); + break; + default: + fprintf(stderr, "Generate does not support model type %d.\n", model_type); + exit(EXIT_FAILURE); + } + + generate_from_prompt_tokens(transformer, tokenizer, sampler, prompt_tokens, + /*pos=*/0, + /*stop_tokens=*/stop_tokens, + /*stop_pos=*/steps - 1, + /*print_prompt=*/true, + /*print_tok_per_sec=*/true); } -void read_stdin(const char* guide, char* buffer, size_t bufsize) { +void read_stdin(const char *guide, char *buffer, size_t bufsize) { // read a line from stdin, up to but not including \n printf("%s", guide); if (fgets(buffer, bufsize, stdin) != NULL) { @@ -609,11 +590,10 @@ void read_stdin(const char* guide, char* buffer, size_t bufsize) { // python reference and that seemed ok, but this was not thoroughly tested and // is not safely implemented, it's more a proof of concept atm. 
-std::vector<uint64_t> get_initial_prompt_tokens(
-    const char* cli_system_prompt,
-    const char* cli_user_prompt,
-    Tokenizer* tokenizer,
-    ModelType model_type) {
+std::vector<uint64_t> get_initial_prompt_tokens(const char *cli_system_prompt,
+                                                const char *cli_user_prompt,
+                                                Tokenizer *tokenizer,
+                                                ModelType model_type) {
   char system_prompt[512];
   char user_prompt[512];
   char rendered_prompt[512 * 2 + 200]; // the prompt template is ~170
@@ -622,10 +602,8 @@ std::vector<uint64_t> get_initial_prompt_tokens(
   if (cli_system_prompt != NULL) {
     strcpy(system_prompt, cli_system_prompt);
   } else {
-    read_stdin(
-        "Enter system prompt (optional): ",
-        system_prompt,
-        sizeof(system_prompt));
+    read_stdin("Enter system prompt (optional): ", system_prompt,
+               sizeof(system_prompt));
   }
 
   if (cli_user_prompt != NULL) {
@@ -637,48 +615,40 @@ std::vector<uint64_t> get_initial_prompt_tokens(
   std::vector<uint64_t> tokens;
 
   switch (model_type) {
-    case LLAMA2_MODEL:
-      if (system_prompt[0] != '\0') {
-        snprintf(
-            rendered_prompt,
-            sizeof(rendered_prompt) - 1,
-            "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]",
-            system_prompt,
-            user_prompt);
-      } else {
-        snprintf(
-            rendered_prompt,
-            sizeof(rendered_prompt) - 1,
-            "[INST] %s [/INST]",
-            user_prompt);
-      }
+  case LLAMA2_MODEL:
+    if (system_prompt[0] != '\0') {
+      snprintf(rendered_prompt, sizeof(rendered_prompt) - 1,
+               "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]", system_prompt,
+               user_prompt);
+    } else {
+      snprintf(rendered_prompt, sizeof(rendered_prompt) - 1,
+               "[INST] %s [/INST]", user_prompt);
+    }
 
-      // We need to add BOS token here and not in template because llama2
-      // tokenizer does not pattern match special tokens
-      tokens = tokenizer->encode(rendered_prompt, 1, 0);
-      break;
-
-    case LLAMA3_MODEL:
-      if (system_prompt[0] != '\0') {
-        snprintf(
-            rendered_prompt,
-            sizeof(rendered_prompt) - 1,
-            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
-            system_prompt,
-            user_prompt);
-      } else {
-        snprintf(
-            rendered_prompt,
-            sizeof(rendered_prompt) - 1,
-            "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
-            user_prompt);
-      }
-      tokens = tokenizer->encode(rendered_prompt, 0, 0);
-      break;
+    // We need to add BOS token here and not in template because llama2
+    // tokenizer does not pattern match special tokens
+    tokens = UNWRAP(tokenizer->encode(rendered_prompt, 1, 0));
+    break;
+
+  case LLAMA3_MODEL:
+    if (system_prompt[0] != '\0') {
+      snprintf(rendered_prompt, sizeof(rendered_prompt) - 1,
+               "<|begin_of_text|><|start_header_id|>system<|end_header_id|>"
+               "\n\n%s<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n%s<"
+               "|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+               system_prompt, user_prompt);
+    } else {
+      snprintf(rendered_prompt, sizeof(rendered_prompt) - 1,
+               "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n%"
+               "s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+               user_prompt);
+    }
+    tokens = UNWRAP(tokenizer->encode(rendered_prompt, 0, 0));
+    break;
 
-    default:
-      fprintf(stderr, "Chat does not support model type %d.\n", model_type);
-      exit(EXIT_FAILURE);
+  default:
+    fprintf(stderr, "Chat does not support model type %d.\n", model_type);
+    exit(EXIT_FAILURE);
   }
 
 #ifdef DEBUG
@@ -695,9 +665,8 @@ std::vector<uint64_t> get_initial_prompt_tokens(
   return tokens;
 }
 
-std::vector<uint64_t> get_next_user_prompt_tokens(
-    Tokenizer* tokenizer,
-    ModelType model_type) {
+std::vector<uint64_t> get_next_user_prompt_tokens(Tokenizer *tokenizer,
+                                                  ModelType model_type) {
   char user_prompt[512];
   char rendered_prompt[512 + 150]; // the prompt template is ~100 characters. We
                                    // use 150 to be safe.
@@ -706,30 +675,26 @@ std::vector<uint64_t> get_next_user_prompt_tokens(
   std::vector<uint64_t> tokens;
 
   switch (model_type) {
-    case LLAMA2_MODEL:
-      snprintf(
-          rendered_prompt,
-          sizeof(rendered_prompt) - 1,
-          "[INST] %s [/INST]",
-          user_prompt);
-
-      // We need to add BOS token here and not in template because llama2
-      // tokenizer does not pattern match special tokens
-      tokens = tokenizer->encode(rendered_prompt, /*bos*/ 1, /*eos*/ 0);
-      break;
-
-    case LLAMA3_MODEL:
-      snprintf(
-          rendered_prompt,
-          sizeof(rendered_prompt) - 1,
-          "<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
-          user_prompt);
-      tokens = tokenizer->encode(rendered_prompt, 0, 0);
-      break;
-
-    default:
-      fprintf(stderr, "Chat does not support model type %d.\n", model_type);
-      exit(EXIT_FAILURE);
+  case LLAMA2_MODEL:
+    snprintf(rendered_prompt, sizeof(rendered_prompt) - 1, "[INST] %s [/INST]",
+             user_prompt);
+
+    // We need to add BOS token here and not in template because llama2
+    // tokenizer does not pattern match special tokens
+    tokens = UNWRAP(tokenizer->encode(rendered_prompt, /*bos*/ 1, /*eos*/ 0));
+    break;
+
+  case LLAMA3_MODEL:
+    snprintf(rendered_prompt, sizeof(rendered_prompt) - 1,
+             "<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_"
+             "header_id|>assistant<|end_header_id|>\n\n",
+             user_prompt);
+    tokens = UNWRAP(tokenizer->encode(rendered_prompt, 0, 0));
+    break;
+
+  default:
+    fprintf(stderr, "Chat does not support model type %d.\n", model_type);
+    exit(EXIT_FAILURE);
   }
 
 #ifdef DEBUG
@@ -746,14 +711,9 @@ std::vector<uint64_t> get_next_user_prompt_tokens(
   return tokens;
 }
 
-void chat(
-    Transformer* transformer,
-    Tokenizer* tokenizer,
-    Sampler* sampler,
-    const char* cli_user_prompt,
-    const char* cli_system_prompt,
-    unsigned steps,
-    ModelType model_type) {
+void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
+          const char *cli_user_prompt, const char *cli_system_prompt,
+          unsigned steps, ModelType model_type) {
   if (steps == 0) {
     return;
   }
@@ -761,16 +721,16 @@ void chat(
   uint64_t eot_token;
   std::vector<uint64_t> prompt_tokens;
   switch (model_type) {
-    case LLAMA2_MODEL:
-      // llama2 uses EOS as EOT token
-      eot_token = tokenizer->eos_tok();
-      break;
-    case LLAMA3_MODEL:
-      eot_token = tokenizer->encode("<|eot_id|>", 0, 0)[0];
-      break;
-    default:
-      fprintf(stderr, "Chat does not support model type %d.\n", model_type);
-      exit(EXIT_FAILURE);
+  case LLAMA2_MODEL:
+    // llama2 uses EOS as EOT token
+    eot_token = tokenizer->eos_tok();
+    break;
+  case LLAMA3_MODEL:
+    eot_token = UNWRAP(tokenizer->encode("<|eot_id|>", 0, 0))[0];
+    break;
+  default:
+    fprintf(stderr, "Chat does not support model type %d.\n", model_type);
+    exit(EXIT_FAILURE);
   }
   std::vector<uint64_t> stop_tokens{eot_token};
 
@@ -784,11 +744,7 @@
     }
     printf("Assistant: ");
     pos = generate_from_prompt_tokens(
-        transformer,
-        tokenizer,
-        sampler,
-        prompt_tokens,
-        pos,
+        transformer, tokenizer, sampler, prompt_tokens, pos,
         /*stop_tokens=*/stop_tokens,
         /*stop_pos=*/steps - 1, // We could pass in -1 here if we do not want
                                 // the model to stop mid-reply
@@ -803,46 +759,43 @@ void chat(
 
 void error_usage() {
   fprintf(stderr, "Usage: run [options]\n");
-  fprintf(
-      stderr, "Example: run model.{so,pte} -n 256 -i \"Once upon a time\"\n");
+  fprintf(stderr,
+          "Example: run model.{so,pte} -n 256 -i \"Once upon a time\"\n");
   fprintf(stderr,
"Options:\n"); fprintf(stderr, " -t temperature in [0,inf], default 1.0\n"); - fprintf( - stderr, - " -p p value in top-p (nucleus) sampling in [0,1], default 0.9\n"); + fprintf(stderr, " -p p value in top-p (nucleus) sampling in [0,1], " + "default 0.9\n"); fprintf(stderr, " -s random seed, default time(NULL)\n"); - fprintf( - stderr, - " -n number of steps to run for, default 256. 0 = max_seq_len\n"); + fprintf(stderr, " -n number of steps to run for, default 256. 0 = " + "max_seq_len\n"); fprintf(stderr, " -i input prompt\n"); fprintf(stderr, " -z path to tokenizer\n"); fprintf(stderr, " -m mode: generate|chat, default: generate\n"); fprintf(stderr, " -y (optional) system prompt in chat mode\n"); - fprintf( - stderr, - " -v (optional) vocab size, default is model-specific.\n"); - fprintf( - stderr, " -l (optional) llama version (2 or 3), default 2.\n"); + fprintf(stderr, + " -v (optional) vocab size, default is model-specific.\n"); + fprintf(stderr, + " -l (optional) llama version (2 or 3), default 2.\n"); fprintf( stderr, " -d (optional) device(CUDA or CPU) model was exported for\n"); exit(EXIT_FAILURE); } -int main(int argc, char* argv[]) { +int main(int argc, char *argv[]) { // default parameters - char* model_path = NULL; - char* tokenizer_path = NULL; + char *model_path = NULL; + char *tokenizer_path = NULL; float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, // but slower - int steps = 128; // number of steps to run for - const char* prompt = NULL; // prompt string + int steps = 128; // number of steps to run for + const char *prompt = NULL; // prompt string unsigned long long rng_seed = 0; // seed rng with time by default - const char* mode = "generate"; // generate|chat - char* system_prompt = + const char *mode = "generate"; // generate|chat + char *system_prompt = NULL; // the (optional) system prompt to use in chat mode int vocab_size = -1; @@ -916,10 +869,8 @@ int main(int argc, char* argv[]) { ModelType model_type = get_model_type(llama_ver); if (model_type == UNKNOWN_MODEL) { - fprintf( - stderr, - "Unknown model type passed by -l argument. Received l=%d.", - llama_ver); + fprintf(stderr, "Unknown model type passed by -l argument. 
Received l=%d.", + llama_ver); error_usage(); } @@ -943,7 +894,7 @@ int main(int argc, char* argv[]) { if (steps < 0) steps = 0; - Tokenizer* tokenizer = build_tokenizer(tokenizer_path, model_type); + Tokenizer *tokenizer = build_tokenizer(tokenizer_path, model_type); // If no tokenizer path provided, get default for model_type if (vocab_size == -1) { @@ -959,14 +910,8 @@ int main(int argc, char* argv[]) { if (strcmp(mode, "generate") == 0) { generate(&transformer, tokenizer, &sampler, prompt, steps, model_type); } else if (strcmp(mode, "chat") == 0) { - chat( - &transformer, - tokenizer, - &sampler, - prompt, - system_prompt, - steps, - model_type); + chat(&transformer, tokenizer, &sampler, prompt, system_prompt, steps, + model_type); } else { fprintf(stderr, "unknown mode: %s\n", mode); error_usage(); diff --git a/runner/third-party/tokenizers b/runner/third-party/tokenizers new file mode 160000 index 000000000..19e463d66 --- /dev/null +++ b/runner/third-party/tokenizers @@ -0,0 +1 @@ +Subproject commit 19e463d665110e1d23145df1ad72bb8db111618b diff --git a/tokenizer/CMakeLists.txt b/tokenizer/CMakeLists.txt deleted file mode 100644 index 39c20885d..000000000 --- a/tokenizer/CMakeLists.txt +++ /dev/null @@ -1,29 +0,0 @@ -cmake_minimum_required(VERSION 3.24) -set(CMAKE_CXX_STANDARD 17) -IF(DEFINED ENV{TORCHCHAT_ROOT}) - set(TORCHCHAT_ROOT $ENV{TORCHCHAT_ROOT}) -ELSE() - set(TORCHCHAT_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..) -ENDIF() - -# build tokenizer library -add_library( - tokenizer - tokenizer.h - sentencepiece.cpp - tiktoken.cpp) - -target_include_directories(tokenizer PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} third-party/sentencepiece/src) - -# add RE2 as subdirectory -set(ABSL_ENABLE_INSTALL ON) -set(ABSL_PROPAGATE_CXX_STD ON) -set(_pic_flag -${CMAKE_POSITION_INDEPENDENT_CODE}) -set(CMAKE_POSITION_INDEPENDENT_CODE ON) -add_subdirectory(third-party/abseil-cpp) -add_subdirectory(third-party/re2) -add_subdirectory(third-party/sentencepiece) -set(CMAKE_POSITION_INDEPENDENT_CODE ${_pic_flag}) - -target_link_libraries(tokenizer PUBLIC re2::re2 sentencepiece-static) diff --git a/tokenizer/__init__.py b/tokenizer/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tokenizer/base.py b/tokenizer/base.py deleted file mode 100644 index 75998b32b..000000000 --- a/tokenizer/base.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -""" -Abstract base class for all tokenizer classes in python matching c++ interface. -""" - -# Standard -from abc import ABC, abstractmethod -from typing import List - - -class TokenizerBase(ABC): - __doc__ = __doc__ - - @abstractmethod - def encode(self, s: str, *, bos: bool = False, eos: bool = False) -> List[int]: - """Encode the given string and optionally include bos/eos tokens""" - - @abstractmethod - def decode(self, ids: List[int]) -> str: - """Decode the given token ids into a string""" - - @abstractmethod - def bos_id(self) -> int: - """The id of the begin-of-string token""" - - @abstractmethod - def eos_id(self) -> int: - """The id of the end-of-string token""" diff --git a/tokenizer/base64.h b/tokenizer/base64.h deleted file mode 100644 index 12b8703a8..000000000 --- a/tokenizer/base64.h +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. 
- * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ -// @lint-ignore-every LICENSELINT -/************************************************************************** - Copyright (c) 2023 sewenew - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - *************************************************************************/ - -#pragma once - -#include -#include -#include -#include - -namespace base64 { - -std::string decode(const std::string_view& input); - -namespace detail { - -constexpr uint32_t DECODE_TABLE[] = { - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62, 255, - 255, 255, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 255, 255, - 255, 255, 255, 255, 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, - 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, - 25, 255, 255, 255, 255, 255, 255, 26, 27, 28, 29, 30, 31, 32, 33, - 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, - 49, 50, 51, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255}; - -inline void validate(uint32_t v) { - if (v == 255) { - fprintf(stderr, "invalid char"); - exit(EXIT_FAILURE); - } -} - -inline void decode(const std::string_view& input, std::string& output) { - if (input.size() != 4) { - fprintf(stderr, "input length must be 4, got %zu", input.size()); - exit(EXIT_FAILURE); - } - - uint32_t val = 0; - - uint8_t c = input[0]; - auto v = DECODE_TABLE[c]; - validate(v); - val = v; - - c = input[1]; - v = DECODE_TABLE[c]; - validate(v); - val = (val << 6) | v; - - c = input[2]; - v = DECODE_TABLE[c]; - validate(v); - val = (val << 6) | v; - - c = input[3]; - v = DECODE_TABLE[c]; - validate(v); - val = (val << 6) | v; - - output.push_back(static_cast((val >> 16) & 0xFF)); - output.push_back(static_cast((val >> 8) & 0xFF)); - output.push_back(static_cast(val & 0xFF)); -} - -inline void decode_1_padding( - const std::string_view& input, - std::string& output) { - if (input.size() != 3) { - fprintf(stderr, "input length must be 3, got %zu", input.size()); - exit(EXIT_FAILURE); - } - - uint32_t val = 0; - - uint8_t c = input[0]; - auto v = DECODE_TABLE[c]; - validate(v); - val = v; - - c = input[1]; - v = DECODE_TABLE[c]; - validate(v); - val = (val << 6) | v; - - 
c = input[2]; - v = DECODE_TABLE[c]; - validate(v); - val = (val << 6) | v; - - output.push_back(static_cast((val >> 10) & 0xFF)); - output.push_back(static_cast((val >> 2) & 0xFF)); -} - -inline void decode_2_padding( - const std::string_view& input, - std::string& output) { - assert(input.size() == 2); - - uint32_t val = 0; - - uint8_t c = input[0]; - auto v = DECODE_TABLE[c]; - validate(v); - val = v; - - c = input[1]; - v = DECODE_TABLE[c]; - validate(v); - val = (val << 6) | v; - - output.push_back(static_cast((val >> 4) & 0xFF)); -} - -} // namespace detail - -inline std::string decode(const std::string_view& input) { - if (input.empty()) { - fprintf(stderr, "empty input"); - exit(EXIT_FAILURE); - } - - // Faster than `input.size() % 4`. - if ((input.size() & 3) != 0 || input.size() < 4) { - fprintf( - stderr, - "input length must be larger than 4 and is multiple of 4, got %zu", - input.size()); - exit(EXIT_FAILURE); - } - - std::string output; - output.reserve(input.size() / 4 * 3); - auto idx = 0U; - for (; idx < input.size() - 4; idx += 4) { - detail::decode(input.substr(idx, 4), output); - } - - // Last 4 bytes. Might contain paddings. - if (input[idx + 3] == '=') { - if (input[idx + 2] == '=') { - // Tow paddings. - detail::decode_2_padding(input.substr(idx, 2), output); - } else { - // One padding. - detail::decode_1_padding(input.substr(idx, 3), output); - } - } else { - // No padding. - detail::decode(input.substr(idx, 4), output); - } - - return output; -} -} // namespace base64 diff --git a/tokenizer/hf_tokenizer.py b/tokenizer/hf_tokenizer.py deleted file mode 100644 index 7ad5807d1..000000000 --- a/tokenizer/hf_tokenizer.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# Standard -from typing import List, Optional -import json -import os - -# Third Party -from tokenizers import Tokenizer - -# Local -from .base import TokenizerBase - - -class HFTokenizer(TokenizerBase): - """ - Wrapper around the Huggingface `tokenizers` library for API compatibility - """ - - def __init__(self, file_path: str): - # If the path is a directory, look for "tokenizer.json" which is - # standard for transformers checkpoints and also look for the - # "tokenizer_config.json" file to parse eos/bos tokens - if os.path.isdir(file_path): - tokenizer_path = os.path.join(file_path, "tokenizer.json") - tokenizer_config_path = os.path.join(file_path, "tokenizer_config.json") - else: - tokenizer_path = file_path - tokenizer_config_path = os.path.join(os.path.dirname(file_path), "tokenizer_config.json") - if not os.path.isfile(tokenizer_path): - tokenizer_config_path = None - - # Load the tokenizer itself - self._tokenizer = Tokenizer.from_file(tokenizer_path) - - # If available, parse bos/eos tokens from the tokenizer config - self._bos_id, self._eos_id = None, None - if tokenizer_config_path is not None: - with open(tokenizer_config_path, "r") as handle: - tok_config = json.load(handle) - bos_token = tok_config.get("bos_token") - eos_token = tok_config.get("eos_token") - if bos_token is not None: - self._bos_id = self._tokenizer.token_to_id(bos_token) - if eos_token is not None: - self._eos_id = self._tokenizer.token_to_id(eos_token) - - # If no eos/bos tokens found, go looking for them! 
- if None in [self._bos_id, self._eos_id]: - tok_content = json.loads(self._tokenizer.to_str()) - if self._bos_id is None: - self._bos_id = self._look_for_special_token(tok_content, ["begin", "text"]) - if self._eos_id is None: - self._eos_id = self._look_for_special_token(tok_content, ["end", "text"]) - - assert None not in [self._bos_id, self._eos_id], "Unable to find an BOS/EOS tokens" - - @staticmethod - def _look_for_special_token(added_tokens: dict, search_strs: List[str]) -> Optional[int]: - candidate_toks = added_tokens - for search_str in search_strs: - candidate_toks = [ - tok for tok in candidate_toks - if tok["special"] and search_str in tok["content"] - ] - if len(candidate_toks) == 1: - return candidate_toks[0]["id"] - - def encode( - self, - s: str, - *, - bos: bool = False, - eos: bool = False, - ) -> List[int]: - res = self._tokenizer.encode(s, add_special_tokens=bos).ids - if eos and (not res or res[-1] != self._eos_token): - res.append(self._eos_token) - return res - - def decode(self, ids: List[int]) -> str: - return self._tokenizer.decode(ids) - - def bos_id(self) -> int: - return self._bos_id - - def eos_id(self) -> int: - return self._eos_id diff --git a/tokenizer/sentencepiece.cpp b/tokenizer/sentencepiece.cpp deleted file mode 100644 index 0cdfc7e30..000000000 --- a/tokenizer/sentencepiece.cpp +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -// sentencepiece tokenizer - -#include -#include -#include -#include -#include "absl/strings/str_replace.h" - -const char kSpaceSymbol[] = "\xe2\x96\x81"; - -SPTokenizer::SPTokenizer() - : Tokenizer(), - _processor(std::make_unique()) {} - -/** - * @brief Load the tokenizer from a file. The tokenizer file contains the - * vocabulary and scores. The format is: the first integer is the maximum - * token length, followed by a list of (word_len, word) pairs. Here we - * are reading all the vocabulary into memory and keep it sorted for fast - * lookup. - * - * @param tokenizer_path The path to the tokenizer file. - * @return void - */ -void SPTokenizer::load(const std::string& tokenizer_path) { - if (initialized_) { - fprintf(stderr, "Tokenizer already initialized.\n"); - return; - } - // read in the file - const auto status = _processor->Load(tokenizer_path); - if (!status.ok()) { - fprintf(stderr, "couldn't load %s\n. If this tokenizer artifact is for llama3, please pass `-l 3`.", tokenizer_path.c_str()); - exit(EXIT_FAILURE); - } - // load vocab_size, bos_tok, eos_tok - vocab_size_ = _processor->GetPieceSize(); - bos_tok_ = _processor->bos_id(); - eos_tok_ = _processor->eos_id(); - initialized_ = true; -} - -SPTokenizer::~SPTokenizer() {} - -/** - * @brief Decode a token into string. - * - * @param prev_token The previous token. - * @param token The current token. - * @return std::string A pointer to the string representation of the - * token. - */ -std::string SPTokenizer::decode(uint64_t prev_token, uint64_t token) { - if (!initialized_) { - fprintf(stderr, "Tokenizer not initialized\n"); - exit(EXIT_FAILURE); - } - // get rid of the control ids and - if (_processor->IsControl(token)) { - // NB: returning empty string doesn't work for some reason. It causes - // free(): invalid pointer error. 
- return " "; - } - - std::string result = - absl::StrReplaceAll(_processor->IdToPiece(token), {{kSpaceSymbol, " "}}); - - // following BOS token, sentencepiece decoder strips any leading - // whitespace - if (prev_token == bos_tok_ && result[0] == ' ') { - result = result.substr(1); - } - - // handle <0x0A> - result = absl::StrReplaceAll(result, {{"<0x0A>", "\n"}}); - - return result; -} - -/** - * @brief Encode a string into a sequence of tokens. - * - * @param text The string to be encoded. - * @param bos The number of BOS to prepend to the token list. - * @param eos The number of EOS to append to the token list. - * @return std::vector - */ -std::vector -SPTokenizer::encode(const std::string& text, int8_t bos, int8_t eos) { - if (!initialized_) { - fprintf(stderr, "Tokenizer not initialized\n"); - exit(EXIT_FAILURE); - } - // workaround a weird issue that text doesn't have correct size() - std::string input(text.c_str()); - // should we reserve memory? - std::vector res; - auto status = _processor->Encode(input, &res); - if (!status.ok()) { - fprintf(stderr, "couldn't encode %s\n", text.c_str()); - exit(EXIT_FAILURE); - } - - std::vector tokens; - for (auto i = 0; i < bos; ++i) { - tokens.push_back(bos_tok_); - } - - for (auto i = 0; i < res.size(); ++i) { - tokens.push_back(res[i]); - } - - for (auto i = 0; i < eos; ++i) { - tokens.push_back(eos_tok_); - } - return tokens; -} diff --git a/tokenizer/third-party/abseil-cpp b/tokenizer/third-party/abseil-cpp deleted file mode 160000 index 854193071..000000000 --- a/tokenizer/third-party/abseil-cpp +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 854193071498f330b71083d7e06a7cd18e02a4cc diff --git a/tokenizer/third-party/re2 b/tokenizer/third-party/re2 deleted file mode 160000 index ac82d4f62..000000000 --- a/tokenizer/third-party/re2 +++ /dev/null @@ -1 +0,0 @@ -Subproject commit ac82d4f628a2045d89964ae11c48403d3b091af1 diff --git a/tokenizer/third-party/sentencepiece b/tokenizer/third-party/sentencepiece deleted file mode 160000 index 7dcb54145..000000000 --- a/tokenizer/third-party/sentencepiece +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 7dcb541451b1862d73f473b3804ccf8f2a9e10f6 diff --git a/tokenizer/tiktoken.cpp b/tokenizer/tiktoken.cpp deleted file mode 100644 index 2f31f057a..000000000 --- a/tokenizer/tiktoken.cpp +++ /dev/null @@ -1,390 +0,0 @@ -// @lint-ignore-every LICENSELINT -/************************************************************************** - Copyright (c) 2023 sewenew - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- *************************************************************************/
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-// ------------------------------Util start------------------------------------
-
-static uint64_t _max_size() {
-  return std::numeric_limits<uint64_t>::max();
-}
-
-static Re2UPtr _create_regex(const std::string& pattern) {
-  assert(!pattern.empty());
-
-  return std::make_unique<re2::RE2>("(" + pattern + ")");
-}
-
-static Re2UPtr _build_special_token_regex(const Encoder& special_encoder) {
-  std::string special_pattern;
-  for (const auto& ele : special_encoder) {
-    if (!special_pattern.empty()) {
-      special_pattern += "|";
-    }
-    special_pattern += re2::RE2::QuoteMeta(ele.first);
-  }
-
-  if (special_pattern.empty()) {
-    return nullptr;
-  }
-
-  return _create_regex(special_pattern);
-}
-
-static std::pair<std::string, uint64_t> _parse(const std::string& line) {
-  auto pos = line.find(" ");
-  if (pos == std::string::npos) {
-    throw std::invalid_argument("invalid encoder line: " + line);
-  }
-
-  auto token = base64::decode({line.data(), pos});
-  uint64_t rank = 0;
-  try {
-    rank = std::stoul(line.substr(pos + 1));
-  } catch (const std::exception&) {
-    throw std::invalid_argument("invalid encoder rank: " + line);
-  }
-
-  return {std::move(token), rank};
-}
-
-static Encoder _load_encoder(const std::string& path) {
-  std::ifstream file(path);
-  if (!file) {
-    fprintf(stderr, "failed to open encoder file: %s\n", path.c_str());
-    exit(EXIT_FAILURE);
-  }
-
-  Encoder encoder;
-  std::string line;
-  while (std::getline(file, line)) {
-    auto [token, rank] = _parse(line);
-
-    if (!encoder.emplace(std::move(token), rank).second) {
-      fprintf(stderr, "duplicate item: %s\n", line.c_str());
-    }
-  }
-  return encoder;
-}
-
-static Decoder _build_decoder(const Encoder& encoder) {
-  Decoder decoder;
-  for (const auto& [k, v] : encoder) {
-    decoder.emplace(v, k);
-  }
-
-  if (encoder.size() != decoder.size()) {
-    fprintf(stderr, "duplicate items in encoder");
-    exit(EXIT_FAILURE);
-  }
-
-  return decoder;
-}
-
-static std::vector<uint64_t> _byte_pair_merge(
-    const std::string& piece,
-    const std::unordered_map<std::string, uint64_t>& ranks,
-    std::function<uint64_t(uint64_t, uint64_t)> func) {
-  // This is a vector of (start, rank).
-  // The rank is of the byte pair starting at position start.
-  // The rank of the last item in the vector is not a valid value.
-  std::vector<std::pair<uint64_t, uint64_t>> parts;
-  parts.reserve(piece.size() + 1);
-  for (auto idx = 0U; idx < piece.size() + 1; ++idx) {
-    parts.emplace_back(idx, _max_size());
-  }
-
-  auto get_rank = [&piece, &ranks](
-                      const std::vector<std::pair<uint64_t, uint64_t>>& parts,
-                      uint64_t start_idx,
-                      uint64_t skip) -> std::optional<uint64_t> {
-    if (start_idx + skip + 2 < parts.size()) {
-      auto s = parts[start_idx].first;
-      auto e = parts[start_idx + skip + 2].first;
-      auto key = piece.substr(s, e - s);
-      auto iter = ranks.find(key);
-      if (iter != ranks.end()) {
-        return iter->second;
-      }
-    }
-    return std::nullopt;
-  };
-
-  // We look up the ranks once in the beginning and iteratively update
-  // them during each merge, which reduces the number of rank lookups.
-  for (auto i = 0U; i < parts.size() - 2; ++i) {
-    auto rank = get_rank(parts, i, 0);
-    if (rank) {
-      // usize::MAX is a sentinel value and cannot be a valid rank
-      if (*rank == _max_size()) {
-        fprintf(stderr, "at %" PRIu32 " rank is too large\n", i);
-      }
-      parts[i].second = *rank;
-    }
-  }
-
-  // If you have n parts and m merges, this does O(mn) work.
-  // We could do something with a heap and do O(m log n) work.
-  // It is important to consider that n is often small (<100), and as such
-  // the cache-locality benefits outweigh the algorithmic complexity downsides
-  // of the `parts` vector data structure above.
-
-  // Note that we hash bytes, not token pairs. As long as we train BPE the way
-  // we currently do, this is equivalent. An easy way to break this would be
-  // to decouple merge priority from token index or to prevent specific token
-  // merges.
-  while (true) {
-    if (parts.size() == 1) {
-      break;
-    }
-
-    // usize::MAX is a sentinel rank value allowing us to
-    // take the min more quickly
-    auto min_rank = std::make_pair(_max_size(), 0);
-    for (auto i = 0U; i < parts.size() - 1; ++i) {
-      auto rank = parts[i].second;
-      if (rank < min_rank.first) {
-        min_rank.first = rank;
-        min_rank.second = i;
-      }
-    }
-
-    if (min_rank.first != _max_size()) {
-      auto i = min_rank.second;
-
-      // NOTE: We are about to remove parts[i + 1]. We do not do it
-      // yet because there are cache-locality benefits to updating
-      // parts[i] and parts[i-1] before removing, which could thrash
-      // the cache. Thus, we update the rank calculation by skipping over
-      // parts[i + 1], by invoking `get_rank!` with `skip = 1`.
-      auto rank = get_rank(parts, i, 1);
-      if (rank) {
-        parts[i].second = *rank;
-      } else {
-        parts[i].second = _max_size();
-      }
-      if (i > 0) {
-        rank = get_rank(parts, i - 1, 1);
-        if (rank) {
-          parts[i - 1].second = *rank;
-        } else {
-          parts[i - 1].second = _max_size();
-        }
-      }
-
-      parts.erase(parts.begin() + (i + 1));
-    } else {
-      break;
-    }
-  }
-  std::vector<uint64_t> out;
-  out.reserve(parts.size() - 1);
-  for (auto i = 0U; i < parts.size() - 1; ++i) {
-    auto s = parts[i].first;
-    auto e = parts[i + 1].first;
-    out.push_back(func(s, e));
-  }
-  return out;
-}
-
-static std::vector<uint64_t> _byte_pair_encode(
-    const std::string& piece,
-    const Encoder& encoder) {
-  if (piece.size() == 1) {
-    auto iter = encoder.find(piece);
-    if (iter != encoder.end()) {
-      return std::vector<uint64_t>({iter->second});
-    } else {
-      // TODO: is it possible?
-      return {};
-    }
-  }
-
-  return _byte_pair_merge(
-      piece, encoder, [&piece, &encoder](uint64_t start, uint64_t stop) {
-        std::string key = piece.substr(start, stop - start);
-        auto iter = encoder.find(key);
-        if (iter != encoder.end()) {
-          return iter->second;
-        } else {
-          // TODO: what if key does not exist? Should we return `unknown`?
-          // assert(false); // ??
-          return uint64_t(0);
-        }
-      });
-}
-// ------------------------------Util end------------------------------------
-// -------------------------private method start-------------------------------
-
-template <typename T>
-std::pair<std::optional<std::string>, re2::StringPiece>
-Tiktoken::_split_with_allowed_special_token(
-    re2::StringPiece& input,
-    const T& allowed_special) {
-  if (!_special_token_regex) {
-    return std::make_pair(std::nullopt, input);
-  }
-
-  auto start = input.begin();
-  std::string special;
-  while (true) {
-    if (!re2::RE2::FindAndConsume(&input, *_special_token_regex, &special)) {
-      // No special token.
-      break;
-    }
-
-    if (allowed_special.count(special) == 1) {
-      // Found an allowed special token, split the text with it.
-      return std::make_pair(
-          special,
-          re2::StringPiece(start, input.begin() - start - special.size()));
-    } // else try to find the next special token
-  }
-
-  return std::make_pair(std::nullopt, input);
-}
-
-void Tiktoken::_encode(
-    re2::StringPiece& input,
-    std::vector<uint64_t>& ret,
-    uint64_t& last_piece_token_len) {
-  std::string piece;
-  assert(_regex);
-  while (re2::RE2::FindAndConsume(&input, *_regex, &piece)) {
-    auto iter = _encoder.find(piece);
-    if (iter != _encoder.end()) {
-      last_piece_token_len = 1;
-      ret.push_back(iter->second);
-      continue;
-    }
-    auto tokens = _byte_pair_encode(piece, _encoder);
-    last_piece_token_len = tokens.size();
-    ret.insert(ret.end(), tokens.begin(), tokens.end());
-  }
-}
-
-template <typename T>
-std::pair<std::vector<uint64_t>, uint64_t> Tiktoken::_encode_with_special_token(
-    const std::string& text,
-    const T& allowed_special) {
-  std::vector<uint64_t> tokens;
-  uint64_t last_piece_token_len = 0;
-  re2::StringPiece input(text);
-  while (true) {
-    auto [special, sub_input] =
-        _split_with_allowed_special_token(input, allowed_special);
-
-    _encode(sub_input, tokens, last_piece_token_len);
-
-    if (special) {
-      uint64_t token = 0;
-      try {
-        token = _special_token_encoder.at(*special);
-      } catch (const std::out_of_range&) {
-        // Should never go here, since special pattern includes all special
-        // chars.
-        fprintf(stderr, "unknown special token: %s\n", special->c_str());
-        exit(EXIT_FAILURE);
-      }
-
-      tokens.push_back(token);
-      last_piece_token_len = 0;
-    } else {
-      break;
-    }
-  }
-
-  // last_piece_token_len is how many tokens came from the last regex split.
-  // This is used for determining unstable tokens, since you can't merge
-  // across (stable) regex splits
-  return std::make_pair(tokens, last_piece_token_len);
-}
-
-// -------------------------private method end-------------------------------
-// -------------------------public method start-------------------------------
-
-Tiktoken::Tiktoken() : Tokenizer() {}
-
-void Tiktoken::load(const std::string& path) {
-  _encoder = _load_encoder(path);
-  _special_token_encoder = _get_special_tokens(_encoder.size());
-
-  _decoder = _build_decoder(_encoder);
-  _special_token_decoder = _build_decoder(_special_token_encoder);
-
-  _regex = _create_regex(_pattern);
-  _special_token_regex = _build_special_token_regex(_special_token_encoder);
-
-  // initialize vocab_size, bos_tok, eos_tok
-  vocab_size_ = _encoder.size() + _special_token_encoder.size();
-  bos_tok_ = _encoder.size(); // hardcoded (see _get_special_tokens)
-  eos_tok_ = _encoder.size() + 1; // hardcoded (see _get_special_tokens)
-  initialized_ = true;
-}
-
-std::vector<uint64_t>
-Tiktoken::encode(const std::string& text, int8_t bos, int8_t eos) {
-  if (!initialized_) {
-    exit(EXIT_FAILURE);
-  }
-  auto res = _encode_with_special_token(text, _special_token_encoder).first;
-  for (auto i = 0; i < bos; ++i) {
-    res.insert(res.begin(), bos_tok_);
-  }
-  for (auto i = 0; i < eos; ++i) {
-    res.push_back(eos_tok_);
-  }
-  return res;
-}
-
-std::string Tiktoken::decode(uint64_t prev, uint64_t cur) {
-  (void)prev;
-  if (!initialized_) {
-    exit(EXIT_FAILURE);
-  }
-  std::string ret;
-
-  std::string token_bytes;
-  auto iter = _decoder.find(cur);
-  if (iter != _decoder.end()) {
-    token_bytes = iter->second;
-  } else {
-    iter = _special_token_decoder.find(cur);
-    if (iter != _special_token_decoder.end()) {
-      token_bytes = iter->second;
-    } else {
-      fprintf(stderr, "unknown token: %" PRIu64 "\n", cur);
-      exit(EXIT_FAILURE);
-    }
-  }
-  ret += token_bytes;
-
-  return ret;
-}
-// -------------------------public method
end------------------------------- diff --git a/tokenizer/tiktoken.py b/tokenizer/tiktoken.py deleted file mode 100644 index 30eb98624..000000000 --- a/tokenizer/tiktoken.py +++ /dev/null @@ -1,241 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import os -from logging import getLogger -from pathlib import Path -from typing import ( - AbstractSet, - cast, - Collection, - Dict, - Iterator, - List, - Literal, - Sequence, - TypedDict, - Union, -) - -import tiktoken -from tiktoken.load import load_tiktoken_bpe - -from .base import TokenizerBase - - -logger = getLogger(__name__) - - -Role = Literal["system", "user", "assistant"] - - -class Message(TypedDict): - role: Role - content: str - - -Dialog = Sequence[Message] - - -class Tokenizer(TokenizerBase): - """ - tokenizing and encoding/decoding text using the Tiktoken tokenizer. - """ - - special_tokens: Dict[str, int] - - num_reserved_special_tokens = 256 - - pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 - - def __init__(self, model_path: str): - """ - Initializes the Tokenizer with a Tiktoken model. - - Args: - model_path (str): The path to the Tiktoken model file. - """ - # reload tokenizer - assert os.path.isfile(model_path), model_path - - mergeable_ranks = load_tiktoken_bpe(model_path) - num_base_tokens = len(mergeable_ranks) - special_tokens = [ - "<|begin_of_text|>", - "<|end_of_text|>", - "<|reserved_special_token_0|>", - "<|reserved_special_token_1|>", - "<|reserved_special_token_2|>", - "<|reserved_special_token_3|>", - "<|start_header_id|>", - "<|end_header_id|>", - "<|reserved_special_token_4|>", - "<|eot_id|>", # end of turn - ] + [ - f"<|reserved_special_token_{i}|>" - for i in range(5, self.num_reserved_special_tokens - 5) - ] - self.special_tokens = { - token: num_base_tokens + i for i, token in enumerate(special_tokens) - } - self.model = tiktoken.Encoding( - name=Path(model_path).name, - pat_str=self.pat_str, - mergeable_ranks=mergeable_ranks, - special_tokens=self.special_tokens, - ) - logger.debug(f"Reloaded Tiktoken model from {model_path}") - - # BOS / EOS token IDs - self.n_words: int = self.model.n_vocab - self._bos_id: int = self.special_tokens["<|begin_of_text|>"] - self._eos_id: int = self.special_tokens["<|end_of_text|>"] - self.pad_id: int = -1 - self.stop_tokens = { - self.special_tokens["<|end_of_text|>"], - self.special_tokens["<|eot_id|>"], - } - logger.debug( - f"#words: {self.n_words} - BOS ID: {self._bos_id} - EOS ID: {self._eos_id}" - ) - - def encode( - self, - s: str, - *, - bos: bool = False, - eos: bool = False, - allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa B006 - disallowed_special: Union[Literal["all"], Collection[str]] = (), - ) -> List[int]: - """ - Encodes a string into a list of token IDs. - - Args: - s (str): The input string to be encoded. - bos (bool): Whether to prepend the beginning-of-sequence token. - eos (bool): Whether to append the end-of-sequence token. - allowed_special ("all"|set[str]): allowed special tokens in string - disallowed_special ("all"|set[str]): special tokens that raise an error when in string - - Returns: - list[int]: A list of token IDs. - - By default, setting disallowed_special=() encodes a string by ignoring - special tokens. 
-    def encode(
-        self,
-        s: str,
-        *,
-        bos: bool = False,
-        eos: bool = False,
-        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa B006
-        disallowed_special: Union[Literal["all"], Collection[str]] = (),
-    ) -> List[int]:
-        """
-        Encodes a string into a list of token IDs.
-
-        Args:
-            s (str): The input string to be encoded.
-            bos (bool): Whether to prepend the beginning-of-sequence token.
-            eos (bool): Whether to append the end-of-sequence token.
-            allowed_special ("all"|set[str]): allowed special tokens in string
-            disallowed_special ("all"|set[str]): special tokens that raise an error when in string
-
-        Returns:
-            list[int]: A list of token IDs.
-
-        By default, setting disallowed_special=() encodes a string by ignoring
-        special tokens. Specifically:
-        - Setting `disallowed_special` to () will cause all text corresponding
-          to special tokens to be encoded as natural text (instead of raising
-          an error).
-        - Setting `allowed_special` to "all" will cause all text corresponding
-          to special tokens to be encoded as special tokens.
-        """
-        assert type(s) is str
-
-        # The tiktoken tokenizer can handle <=400k chars without
-        # pyo3_runtime.PanicException (may go beyond 400k)
-        TIKTOKEN_MAX_ENCODE_CHARS = 400_000
-
-        # https://github.com/openai/tiktoken/issues/195
-        # Here we iterate over subsequences and split if we exceed the limit
-        # of max consecutive non-whitespace or whitespace characters.
-        MAX_NO_WHITESPACES_CHARS = 25_000
-
-        substrs = (
-            substr
-            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
-            for substr in self._split_whitespaces_or_nonwhitespaces(
-                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
-            )
-        )
-        t: List[int] = []
-        for substr in substrs:
-            t.extend(
-                self.model.encode(
-                    substr,
-                    allowed_special=allowed_special,
-                    disallowed_special=disallowed_special,
-                )
-            )
-        if bos:
-            t.insert(0, self._bos_id)
-        if eos:
-            t.append(self._eos_id)
-        return t
-
-    def bos_id(self) -> int:
-        return self._bos_id
-
-    def eos_id(self) -> int:
-        return self._eos_id
-
-    def decode(self, t: Sequence[int]) -> str:
-        """
-        Decodes a list of token IDs into a string.
-
-        Args:
-            t (List[int]): The list of token IDs to be decoded.
-
-        Returns:
-            str: The decoded string.
-        """
-        # typecast is safe here, Tiktoken doesn't do anything list-related with the sequence.
-        return self.model.decode(cast(List[int], t))
-
-    @staticmethod
-    def _split_whitespaces_or_nonwhitespaces(
-        s: str, max_consecutive_slice_len: int
-    ) -> Iterator[str]:
-        """
-        Split the string `s` so that each substring contains no more than
-        `max_consecutive_slice_len` consecutive whitespace or consecutive
-        non-whitespace characters.
-        """
-        current_slice_len = 0
-        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
-        slice_start = 0
-
-        for i in range(len(s)):
-            is_now_space = s[i].isspace()
-
-            if current_slice_is_space ^ is_now_space:
-                current_slice_len = 1
-                current_slice_is_space = is_now_space
-            else:
-                current_slice_len += 1
-                if current_slice_len > max_consecutive_slice_len:
-                    yield s[slice_start:i]
-                    slice_start = i
-                    current_slice_len = 1
-        yield s[slice_start:]
-
-
-class ChatFormat:
-    def __init__(self, tokenizer: Tokenizer):
-        self.tokenizer = tokenizer
-
-    def encode_header(self, message: Message) -> List[int]:
-        tokens = []
-        tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
-        tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
-        tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
-        tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
-        return tokens
-
-    def encode_message(self, message: Message) -> List[int]:
-        tokens = self.encode_header(message)
-        tokens.extend(
-            self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)
-        )
-        tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
-        return tokens
-
-    def encode_dialog_prompt(self, dialog: Dialog) -> List[int]:
-        tokens = []
-        tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
-        for message in dialog:
-            tokens.extend(self.encode_message(message))
-        # Add the start of an assistant message for the model to complete
-        tokens.extend(self.encode_header({"role": "assistant", "content": ""}))
-        return tokens
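`encode_dialog_prompt` above is what produced the Llama 3 instruct framing. A usage sketch of the deleted API (hypothetical model path; comments spell out the token layout the three methods build):

```python
chat = ChatFormat(Tokenizer("/path/to/tokenizer.model"))  # hypothetical path

ids = chat.encode_dialog_prompt([{"role": "user", "content": "Hello!"}])
# Layout, reading the three methods above:
#   <|begin_of_text|>
#   <|start_header_id|>user<|end_header_id|>\n\nHello!<|eot_id|>
#   <|start_header_id|>assistant<|end_header_id|>\n\n   <- left open for the model
```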
diff --git a/tokenizer/tokenizer.h b/tokenizer/tokenizer.h
deleted file mode 100644
index 9e1977b71..000000000
--- a/tokenizer/tokenizer.h
+++ /dev/null
@@ -1,147 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-// A simple Tokenizer interface.
-#pragma once
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include "sentencepiece_processor.h"
-
-class Tokenizer {
- public:
-  explicit Tokenizer() {}
-  virtual ~Tokenizer() {}
-
-  virtual void load(const std::string& tokenizer_path) = 0;
-
-  virtual std::vector<uint64_t>
-  encode(const std::string& input, int8_t bos, int8_t eos) = 0;
-
-  virtual std::string decode(uint64_t prev_token, uint64_t token) = 0;
-
-  // getters
-  int32_t vocab_size() const {
-    return vocab_size_;
-  }
-
-  uint64_t bos_tok() const {
-    return bos_tok_;
-  }
-
-  uint64_t eos_tok() const {
-    return eos_tok_;
-  }
-
- protected:
-  bool initialized_ = false;
-  int32_t vocab_size_;
-  uint64_t bos_tok_, eos_tok_;
-};
-
-// ----------------------- SPTokenizer -----------------------
-// Used by sentencepiece. Adapted from llama2.c.
-struct TokenIndex {
-  const char* str;
-  int32_t id;
-};
-
-class SPTokenizer : public Tokenizer {
- public:
-  explicit SPTokenizer();
-  ~SPTokenizer() override;
-
-  void load(const std::string& tokenizer_path) override;
-
-  std::vector<uint64_t> encode(const std::string& input, int8_t bos, int8_t eos)
-      override;
-
-  std::string decode(uint64_t prev_token, uint64_t token) override;
-
- private:
-  std::unique_ptr<sentencepiece::SentencePieceProcessor> _processor;
-};
-
-// ----------------------- Tiktoken -----------------------
-// Used by OpenAI, adapted from https://github.com/sewenew/tokenizer
-
-using Encoder = std::unordered_map<std::string, uint64_t>;
-using Decoder = std::unordered_map<uint64_t, std::string>;
-using Re2UPtr = std::unique_ptr<re2::RE2>;
-
-class Tiktoken : public Tokenizer {
- public:
-  explicit Tiktoken();
-  ~Tiktoken(){};
-
-  void load(const std::string& tokenizer_path);
-
-  std::vector<uint64_t>
-  encode(const std::string& input, int8_t bos, int8_t eos);
-
-  std::string decode(uint64_t prev_token, uint64_t token);
-
- private:
-  static inline const Encoder _get_special_tokens(ssize_t num_base_tokens) {
-    Encoder special_tokens;
-    special_tokens.emplace("<|begin_of_text|>", num_base_tokens++);
-    special_tokens.emplace("<|end_of_text|>", num_base_tokens++);
-    special_tokens.emplace("<|reserved_special_token_0|>", num_base_tokens++);
-    special_tokens.emplace("<|reserved_special_token_1|>", num_base_tokens++);
-    special_tokens.emplace("<|reserved_special_token_2|>", num_base_tokens++);
-    special_tokens.emplace("<|reserved_special_token_3|>", num_base_tokens++);
-    special_tokens.emplace("<|start_header_id|>", num_base_tokens++);
-    special_tokens.emplace("<|end_header_id|>", num_base_tokens++);
-    special_tokens.emplace("<|reserved_special_token_4|>", num_base_tokens++);
-    special_tokens.emplace("<|eot_id|>", num_base_tokens++);
-    for (auto i = 5; i < 251; ++i) {
-      special_tokens.emplace(
-          "<|reserved_special_token_" + std::to_string(i) + "|>",
-          num_base_tokens++);
-    }
-    return special_tokens;
-  }
-
-  template <typename T>
-  std::pair<std::optional<std::string>, re2::StringPiece>
-  _split_with_allowed_special_token(
-      re2::StringPiece& input,
-      const T& allowed_special);
-
-  void _encode(
-      re2::StringPiece& input,
-      std::vector<uint64_t>& ret,
-      uint64_t& last_piece_token_len);
-
-  template <typename T>
-  std::pair<std::vector<uint64_t>, uint64_t> _encode_with_special_token(
-      const std::string& text,
-      const T& allowed_special);
-
-  // Removed negative lookahead \s+(?!\S) since it's not supported by RE2.
-  const std::string _pattern =
-      R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)";
-  Encoder _encoder;
-  Encoder _special_token_encoder;
-  Decoder _decoder;
-  Decoder _special_token_decoder;
-
-  Re2UPtr _regex;
-  Re2UPtr _special_token_regex;
-};
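The header's comment about dropping `\s+(?!\S)` is worth unpacking: with the lookahead, a whitespace run gives up its final space so it can attach to the following word; plain `\s+` swallows the whole run, so multi-space inputs pre-tokenize differently. A simplified, ASCII-only demonstration using Python's `re` (the real pattern relies on `\p{L}`/`\p{N}` classes, which `re` lacks):

```python
import re

with_lookahead = re.compile(r" ?\w+|\s+(?!\S)|\s+")
re2_compatible = re.compile(r" ?\w+|\s+")  # RE2 supports no lookahead

text = "a   b"
print(with_lookahead.findall(text))  # ['a', '  ', ' b'] -- ' b' keeps its space
print(re2_compatible.findall(text))  # ['a', '   ', 'b'] -- the run swallows it
```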
diff --git a/torchchat/utils/scripts/build_native.sh b/torchchat/utils/scripts/build_native.sh
index 3c2c1c846..909fd2b97 100755
--- a/torchchat/utils/scripts/build_native.sh
+++ b/torchchat/utils/scripts/build_native.sh
@@ -64,7 +64,7 @@ fi
 pushd ${TORCHCHAT_ROOT}
-git submodule update --init
+git submodule update --init --recursive
 git submodule sync
 if [[ "$TARGET" == "et" ]]; then
   if [ ! -d "${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/install" ]; then