From 426773d552126fb0a81f1510e4804f28a8a78235 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Tue, 3 Sep 2024 16:35:56 -0700 Subject: [PATCH 1/4] [tokenizer] Consolidate how the runner decides which tokenizer to use Summary: Make sure `tiktoken.cpp` gives the correct error when the file is invalid. Make sure `runner.cpp` is able to fall back to BPETokenizer when an invalid tokenizer.model file is given. Test Plan: Rely on unit tests Reviewers: Subscribers: Tasks: Tags: --- examples/models/llama2/runner/runner.cpp | 10 ++- extension/llm/tokenizer/base64.h | 76 ++++++++++++------- .../test_tiktoken_invalid_base64.model | 1 + .../test_tiktoken_invalid_rank.model | 1 + .../resources/test_tiktoken_no_space.model | 1 + .../llm/tokenizer/test/test_tiktoken.cpp | 37 +++++++++ extension/llm/tokenizer/tiktoken.cpp | 44 +++++++---- 7 files changed, 124 insertions(+), 46 deletions(-) create mode 100644 extension/llm/tokenizer/test/resources/test_tiktoken_invalid_base64.model create mode 100644 extension/llm/tokenizer/test/resources/test_tiktoken_invalid_rank.model create mode 100644 extension/llm/tokenizer/test/resources/test_tiktoken_no_space.model diff --git a/examples/models/llama2/runner/runner.cpp b/examples/models/llama2/runner/runner.cpp index 8b9e6865516..ebfc59ae76a 100644 --- a/examples/models/llama2/runner/runner.cpp +++ b/examples/models/llama2/runner/runner.cpp @@ -69,17 +69,19 @@ Error Runner::load() { return Error::Ok; } ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward")); - // load tokenizer + // load tokenizer. Assuming tiktoken is the default tokenizer tokenizer_ = nullptr; - tokenizer_ = std::make_unique(); + tokenizer_ = get_tiktoken_for_llama(); Error err = tokenizer_->load(tokenizer_path_); + // Rely on tiktoken to throw error if the artifact is incompatible. Then we + // fallback to BPE tokenizer. 
if (err == Error::InvalidArgument) { ET_LOG( Info, - "Failed to load %s as a BPETokenizer artifact, trying Tiktoken", + "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer", tokenizer_path_.c_str()); tokenizer_.reset(); - tokenizer_ = get_tiktoken_for_llama(); + tokenizer_ = std::make_unique(); tokenizer_->load(tokenizer_path_); } diff --git a/extension/llm/tokenizer/base64.h b/extension/llm/tokenizer/base64.h index 7337ecead4e..83ef9e0696b 100644 --- a/extension/llm/tokenizer/base64.h +++ b/extension/llm/tokenizer/base64.h @@ -24,6 +24,8 @@ #pragma once +#include +#include #include #include #include @@ -32,10 +34,13 @@ namespace executorch { namespace extension { namespace llm { +using Error = executorch::runtime::Error; +template +using Result = executorch::runtime::Result; namespace base64 { -std::string decode(const std::string_view& input); +Result decode(const std::string_view& input); namespace detail { @@ -59,96 +64,111 @@ constexpr uint32_t DECODE_TABLE[] = { 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}; -inline void validate(uint32_t v) { - ET_CHECK_MSG(v != 255, "invalid char"); +inline Error validate(uint32_t v) { + ET_CHECK_OR_RETURN_ERROR(v != 255, InvalidArgument, "invalid char"); + return Error::Ok; } -inline void decode(const std::string_view& input, std::string& output) { - ET_CHECK_MSG( - input.size() == 4, "input length must be 4, got %zu", input.size()); +inline Error decode(const std::string_view& input, std::string& output) { + ET_CHECK_OR_RETURN_ERROR( + input.size() == 4, + InvalidArgument, + "input length must be 4, got %zu", + input.size()); uint32_t val = 0; uint8_t c = input[0]; auto v = DECODE_TABLE[c]; - validate(v); + ET_CHECK_OK_OR_RETURN_ERROR(validate(v)); val = v; c = input[1]; v = DECODE_TABLE[c]; - validate(v); + ET_CHECK_OK_OR_RETURN_ERROR(validate(v)); val = (val << 6) | v; c = input[2]; v = DECODE_TABLE[c]; - validate(v); + ET_CHECK_OK_OR_RETURN_ERROR(validate(v)); val = (val << 
6) | v; c = input[3]; v = DECODE_TABLE[c]; - validate(v); + ET_CHECK_OK_OR_RETURN_ERROR(validate(v)); val = (val << 6) | v; output.push_back(static_cast((val >> 16) & 0xFF)); output.push_back(static_cast((val >> 8) & 0xFF)); output.push_back(static_cast(val & 0xFF)); + return Error::Ok; } -inline void decode_1_padding( +inline Error decode_1_padding( const std::string_view& input, std::string& output) { - ET_CHECK_MSG( - input.size() == 3, "input length must be 3, got %zu", input.size()); + ET_CHECK_OR_RETURN_ERROR( + input.size() == 3, + InvalidArgument, + "input length must be 3, got %zu", + input.size()); uint32_t val = 0; uint8_t c = input[0]; auto v = DECODE_TABLE[c]; - validate(v); + ET_CHECK_OK_OR_RETURN_ERROR(validate(v)); val = v; c = input[1]; v = DECODE_TABLE[c]; - validate(v); + ET_CHECK_OK_OR_RETURN_ERROR(validate(v)); val = (val << 6) | v; c = input[2]; v = DECODE_TABLE[c]; - validate(v); + ET_CHECK_OK_OR_RETURN_ERROR(validate(v)); val = (val << 6) | v; output.push_back(static_cast((val >> 10) & 0xFF)); output.push_back(static_cast((val >> 2) & 0xFF)); + return Error::Ok; } -inline void decode_2_padding( +inline Error decode_2_padding( const std::string_view& input, std::string& output) { - assert(input.size() == 2); + ET_CHECK_OR_RETURN_ERROR( + input.size() == 2, + InvalidArgument, + "input length must be 2, got %zu", + input.size()); uint32_t val = 0; uint8_t c = input[0]; auto v = DECODE_TABLE[c]; - validate(v); + ET_CHECK_OK_OR_RETURN_ERROR(validate(v)); val = v; c = input[1]; v = DECODE_TABLE[c]; - validate(v); + ET_CHECK_OK_OR_RETURN_ERROR(validate(v)); val = (val << 6) | v; output.push_back(static_cast((val >> 4) & 0xFF)); + return Error::Ok; } } // namespace detail -inline std::string decode(const std::string_view& input) { - ET_CHECK_MSG(!input.empty(), "empty input"); +inline Result decode(const std::string_view& input) { + ET_CHECK_OR_RETURN_ERROR(!input.empty(), InvalidArgument, "empty input"); // Faster than `input.size() % 4`. 
- ET_CHECK_MSG( + ET_CHECK_OR_RETURN_ERROR( (input.size() & 3) == 0 && input.size() >= 4, + InvalidArgument, "input length must be larger than 4 and is multiple of 4, got %zu", input.size()); @@ -156,21 +176,23 @@ inline std::string decode(const std::string_view& input) { output.reserve(input.size() / 4 * 3); auto idx = 0U; for (; idx < input.size() - 4; idx += 4) { - detail::decode(input.substr(idx, 4), output); + ET_CHECK_OK_OR_RETURN_ERROR(detail::decode(input.substr(idx, 4), output)); } // Last 4 bytes. Might contain paddings. if (input[idx + 3] == '=') { if (input[idx + 2] == '=') { // Tow paddings. - detail::decode_2_padding(input.substr(idx, 2), output); + ET_CHECK_OK_OR_RETURN_ERROR( + detail::decode_2_padding(input.substr(idx, 2), output)); } else { // One padding. - detail::decode_1_padding(input.substr(idx, 3), output); + ET_CHECK_OK_OR_RETURN_ERROR( + detail::decode_1_padding(input.substr(idx, 3), output)); } } else { // No padding. - detail::decode(input.substr(idx, 4), output); + ET_CHECK_OK_OR_RETURN_ERROR(detail::decode(input.substr(idx, 4), output)); } return output; diff --git a/extension/llm/tokenizer/test/resources/test_tiktoken_invalid_base64.model b/extension/llm/tokenizer/test/resources/test_tiktoken_invalid_base64.model new file mode 100644 index 00000000000..0d29969a91c --- /dev/null +++ b/extension/llm/tokenizer/test/resources/test_tiktoken_invalid_base64.model @@ -0,0 +1 @@ +testtest 0 diff --git a/extension/llm/tokenizer/test/resources/test_tiktoken_invalid_rank.model b/extension/llm/tokenizer/test/resources/test_tiktoken_invalid_rank.model new file mode 100644 index 00000000000..07d43b1e439 --- /dev/null +++ b/extension/llm/tokenizer/test/resources/test_tiktoken_invalid_rank.model @@ -0,0 +1 @@ +ICAgICAgIA== 18446744073709551616 diff --git a/extension/llm/tokenizer/test/resources/test_tiktoken_no_space.model b/extension/llm/tokenizer/test/resources/test_tiktoken_no_space.model new file mode 100644 index 00000000000..c025dddd3ba --- 
/dev/null +++ b/extension/llm/tokenizer/test/resources/test_tiktoken_no_space.model @@ -0,0 +1 @@ +ICAgICAgIA==10 diff --git a/extension/llm/tokenizer/test/test_tiktoken.cpp b/extension/llm/tokenizer/test/test_tiktoken.cpp index a81b20bcf88..9a8359adb0a 100644 --- a/extension/llm/tokenizer/test/test_tiktoken.cpp +++ b/extension/llm/tokenizer/test/test_tiktoken.cpp @@ -8,7 +8,9 @@ #include #include +#include #include +#include #include using namespace ::testing; @@ -140,3 +142,38 @@ TEST_F(TiktokenExtensionTest, ConstructionWithInvalidEOSIndex) { ""); #endif } + +TEST_F(TiktokenExtensionTest, LoadWithInvalidPath) { + auto invalidModelPath = + std::getenv("RESOURCES_PATH") + std::string("/nonexistent.model"); + + Error res = tokenizer_->load(invalidModelPath.c_str()); + EXPECT_EQ(res, Error::InvalidArgument); +} + +TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithInvalidRank) { + auto invalidModelPath = std::getenv("RESOURCES_PATH") + + std::string("/test_tiktoken_invalid_rank.model"); + + Error res = tokenizer_->load(invalidModelPath.c_str()); + + EXPECT_EQ(res, Error::InvalidArgument); +} + +TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithInvalidBase64) { + auto invalidModelPath = std::getenv("RESOURCES_PATH") + + std::string("/test_tiktoken_invalid_base64.model"); + + Error res = tokenizer_->load(invalidModelPath.c_str()); + + EXPECT_EQ(res, Error::InvalidArgument); +} + +TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithBPEFile) { + auto invalidModelPath = + std::getenv("RESOURCES_PATH") + std::string("/test_bpe_tokenizer.bin"); + + Error res = tokenizer_->load(invalidModelPath.c_str()); + + EXPECT_EQ(res, Error::InvalidArgument); +} diff --git a/extension/llm/tokenizer/tiktoken.cpp b/extension/llm/tokenizer/tiktoken.cpp index 7b15d25f0da..768aa18541b 100644 --- a/extension/llm/tokenizer/tiktoken.cpp +++ b/extension/llm/tokenizer/tiktoken.cpp @@ -27,6 +27,7 @@ #include #include +#include #include #include @@ -65,33 +66,43 @@ static Re2UPtr 
_build_special_token_regex(const Encoder& special_encoder) { return _create_regex(special_pattern); } -static std::pair _parse(const std::string& line) { +static Result> _parse( + const std::string& line) { + // Tiktoken format + // https://github.com/openai/tiktoken/blob/main/tiktoken/load.py#L140 auto pos = line.find(" "); - ET_CHECK_MSG( - pos != std::string::npos, "invalid encoder line: %s", line.c_str()); + ET_CHECK_OR_RETURN_ERROR( + pos != std::string::npos, + InvalidArgument, + "invalid tiktoken line: %s", + line.c_str()); - auto token = base64::decode({line.data(), pos}); + auto token = ET_UNWRAP(base64::decode({line.data(), pos})); uint64_t rank = 0; try { rank = std::stoul(line.substr(pos + 1)); } catch (const std::exception&) { - ET_CHECK_MSG(false, "invalid encoder rank: %s", line.c_str()); + ET_CHECK_OR_RETURN_ERROR( + false, InvalidArgument, "invalid encoder rank: %s", line.c_str()); } - return {std::move(token), rank}; + return std::pair{std::move(token), rank}; } -static Encoder _load_encoder(const std::string& path) { +static Result _load_encoder(const std::string& path) { std::ifstream file(path); - ET_CHECK_MSG(file, "failed to open encoder file: %s", path.c_str()); + ET_CHECK_OR_RETURN_ERROR( + file, InvalidArgument, "failed to open encoder file: %s", path.c_str()); Encoder encoder; std::string line; while (std::getline(file, line)) { - auto [token, rank] = _parse(line); + auto [token, rank] = ET_UNWRAP(_parse(line)); - ET_CHECK_MSG( + ET_CHECK_OR_RETURN_ERROR( encoder.emplace(std::move(token), rank).second, + InvalidArgument, "duplicate item: %s", line.c_str()); } @@ -99,13 +110,16 @@ static Encoder _load_encoder(const std::string& path) { return encoder; } -static Decoder _build_decoder(const Encoder& encoder) { +static Result _build_decoder(const Encoder& encoder) { Decoder decoder; for (const auto& [k, v] : encoder) { decoder.emplace(v, k); } - ET_CHECK_MSG(encoder.size() == decoder.size(), "duplicate items in encoder"); + 
ET_CHECK_OR_RETURN_ERROR( + encoder.size() == decoder.size(), + InvalidArgument, + "duplicate items in encoder"); return decoder; } @@ -356,11 +370,11 @@ Tiktoken::Tiktoken( } Error Tiktoken::load(const std::string& path) { - _encoder = _load_encoder(path); + _encoder = ET_UNWRAP(_load_encoder(path)); _special_token_encoder = _build_special_token_encoder(_encoder.size()); - _decoder = _build_decoder(_encoder); - _special_token_decoder = _build_decoder(_special_token_encoder); + _decoder = ET_UNWRAP(_build_decoder(_encoder)); + _special_token_decoder = ET_UNWRAP(_build_decoder(_special_token_encoder)); _regex = _create_regex(_pattern); // Warmup re2 as it is slow on the first run, void the return value as it's From b854cb087bf7ea880fa52c7ebf4985e21002b371 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Tue, 3 Sep 2024 16:40:06 -0700 Subject: [PATCH 2/4] Add test for tokenizer without space Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- extension/llm/tokenizer/test/test_tiktoken.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/extension/llm/tokenizer/test/test_tiktoken.cpp b/extension/llm/tokenizer/test/test_tiktoken.cpp index 9a8359adb0a..ce2a781aa1c 100644 --- a/extension/llm/tokenizer/test/test_tiktoken.cpp +++ b/extension/llm/tokenizer/test/test_tiktoken.cpp @@ -169,6 +169,15 @@ TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithInvalidBase64) { EXPECT_EQ(res, Error::InvalidArgument); } +TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithNoSpace) { + auto invalidModelPath = std::getenv("RESOURCES_PATH") + + std::string("/test_tiktoken_no_space.model"); + + Error res = tokenizer_->load(invalidModelPath.c_str()); + + EXPECT_EQ(res, Error::InvalidArgument); +} + TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithBPEFile) { auto invalidModelPath = std::getenv("RESOURCES_PATH") + std::string("/test_bpe_tokenizer.bin"); From da65188ee5284474ff5cb4a86d7ab1d4071bf800 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Wed, 4 Sep 2024 08:09:27 -0700 
Subject: [PATCH 3/4] Fix Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- extension/llm/tokenizer/tiktoken.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extension/llm/tokenizer/tiktoken.cpp b/extension/llm/tokenizer/tiktoken.cpp index 768aa18541b..f8ccf74fd6b 100644 --- a/extension/llm/tokenizer/tiktoken.cpp +++ b/extension/llm/tokenizer/tiktoken.cpp @@ -407,7 +407,7 @@ Tiktoken::encode(const std::string& text, int8_t bos, int8_t eos) const { for (auto i = 0; i < eos; ++i) { res.push_back(eos_tok_); } - return Result(res); + return Result>(std::move(res)); } Result Tiktoken::decode(uint64_t prev, uint64_t cur) const { From fefc9588b57408c265127ddf313f01cfc2dabf33 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Wed, 4 Sep 2024 08:40:04 -0700 Subject: [PATCH 4/4] Fix test --- .../tokenizer/test/resources/test_tiktoken_invalid_base64.model | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extension/llm/tokenizer/test/resources/test_tiktoken_invalid_base64.model b/extension/llm/tokenizer/test/resources/test_tiktoken_invalid_base64.model index 0d29969a91c..2d9c39f19d6 100644 --- a/extension/llm/tokenizer/test/resources/test_tiktoken_invalid_base64.model +++ b/extension/llm/tokenizer/test/resources/test_tiktoken_invalid_base64.model @@ -1 +1 @@ -testtest 0 +tet 0