Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions examples/models/llama2/runner/runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,19 @@ Error Runner::load() {
return Error::Ok;
}
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
// load tokenizer
// load tokenizer. Assuming tiktoken is the default tokenizer
tokenizer_ = nullptr;
tokenizer_ = std::make_unique<BPETokenizer>();
tokenizer_ = get_tiktoken_for_llama();
Error err = tokenizer_->load(tokenizer_path_);
// Rely on tiktoken to throw error if the artifact is incompatible. Then we
// fallback to BPE tokenizer.
if (err == Error::InvalidArgument) {
ET_LOG(
Info,
"Failed to load %s as a BPETokenizer artifact, trying Tiktoken",
"Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
tokenizer_path_.c_str());
tokenizer_.reset();
tokenizer_ = get_tiktoken_for_llama();
tokenizer_ = std::make_unique<BPETokenizer>();
tokenizer_->load(tokenizer_path_);
}

Expand Down
76 changes: 49 additions & 27 deletions extension/llm/tokenizer/base64.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

#pragma once

#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/platform/assert.h>
#include <cassert>
#include <string>
Expand All @@ -32,10 +34,13 @@
namespace executorch {
namespace extension {
namespace llm {
using Error = executorch::runtime::Error;
template <typename T>
using Result = executorch::runtime::Result<T>;

namespace base64 {

std::string decode(const std::string_view& input);
Result<std::string> decode(const std::string_view& input);

namespace detail {

Expand All @@ -59,118 +64,135 @@ constexpr uint32_t DECODE_TABLE[] = {
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255};

inline void validate(uint32_t v) {
ET_CHECK_MSG(v != 255, "invalid char");
inline Error validate(uint32_t v) {
ET_CHECK_OR_RETURN_ERROR(v != 255, InvalidArgument, "invalid char");
return Error::Ok;
}

inline void decode(const std::string_view& input, std::string& output) {
ET_CHECK_MSG(
input.size() == 4, "input length must be 4, got %zu", input.size());
inline Error decode(const std::string_view& input, std::string& output) {
ET_CHECK_OR_RETURN_ERROR(
input.size() == 4,
InvalidArgument,
"input length must be 4, got %zu",
input.size());

uint32_t val = 0;

uint8_t c = input[0];
auto v = DECODE_TABLE[c];
validate(v);
ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
val = v;

c = input[1];
v = DECODE_TABLE[c];
validate(v);
ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
val = (val << 6) | v;

c = input[2];
v = DECODE_TABLE[c];
validate(v);
ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
val = (val << 6) | v;

c = input[3];
v = DECODE_TABLE[c];
validate(v);
ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
val = (val << 6) | v;

output.push_back(static_cast<char>((val >> 16) & 0xFF));
output.push_back(static_cast<char>((val >> 8) & 0xFF));
output.push_back(static_cast<char>(val & 0xFF));
return Error::Ok;
}

inline void decode_1_padding(
inline Error decode_1_padding(
const std::string_view& input,
std::string& output) {
ET_CHECK_MSG(
input.size() == 3, "input length must be 3, got %zu", input.size());
ET_CHECK_OR_RETURN_ERROR(
input.size() == 3,
InvalidArgument,
"input length must be 3, got %zu",
input.size());

uint32_t val = 0;

uint8_t c = input[0];
auto v = DECODE_TABLE[c];
validate(v);
ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
val = v;

c = input[1];
v = DECODE_TABLE[c];
validate(v);
ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
val = (val << 6) | v;

c = input[2];
v = DECODE_TABLE[c];
validate(v);
ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
val = (val << 6) | v;

output.push_back(static_cast<char>((val >> 10) & 0xFF));
output.push_back(static_cast<char>((val >> 2) & 0xFF));
return Error::Ok;
}

inline void decode_2_padding(
inline Error decode_2_padding(
const std::string_view& input,
std::string& output) {
assert(input.size() == 2);
ET_CHECK_OR_RETURN_ERROR(
input.size() == 2,
InvalidArgument,
"input length must be 2, got %zu",
input.size());

uint32_t val = 0;

uint8_t c = input[0];
auto v = DECODE_TABLE[c];
validate(v);
ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
val = v;

c = input[1];
v = DECODE_TABLE[c];
validate(v);
ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
val = (val << 6) | v;

output.push_back(static_cast<char>((val >> 4) & 0xFF));
return Error::Ok;
}

} // namespace detail

inline std::string decode(const std::string_view& input) {
ET_CHECK_MSG(!input.empty(), "empty input");
inline Result<std::string> decode(const std::string_view& input) {
ET_CHECK_OR_RETURN_ERROR(!input.empty(), InvalidArgument, "empty input");

// Faster than `input.size() % 4`.
ET_CHECK_MSG(
ET_CHECK_OR_RETURN_ERROR(
(input.size() & 3) == 0 && input.size() >= 4,
InvalidArgument,
"input length must be larger than 4 and is multiple of 4, got %zu",
input.size());

std::string output;
output.reserve(input.size() / 4 * 3);
auto idx = 0U;
for (; idx < input.size() - 4; idx += 4) {
detail::decode(input.substr(idx, 4), output);
ET_CHECK_OK_OR_RETURN_ERROR(detail::decode(input.substr(idx, 4), output));
}

// Last 4 bytes. Might contain paddings.
if (input[idx + 3] == '=') {
if (input[idx + 2] == '=') {
// Tow paddings.
detail::decode_2_padding(input.substr(idx, 2), output);
ET_CHECK_OK_OR_RETURN_ERROR(
detail::decode_2_padding(input.substr(idx, 2), output));
} else {
// One padding.
detail::decode_1_padding(input.substr(idx, 3), output);
ET_CHECK_OK_OR_RETURN_ERROR(
detail::decode_1_padding(input.substr(idx, 3), output));
}
} else {
// No padding.
detail::decode(input.substr(idx, 4), output);
ET_CHECK_OK_OR_RETURN_ERROR(detail::decode(input.substr(idx, 4), output));
}

return output;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
tet 0
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ICAgICAgIA== 18446744073709551616
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ICAgICAgIA==10
46 changes: 46 additions & 0 deletions extension/llm/tokenizer/test/test_tiktoken.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

#include <executorch/extension/llm/tokenizer/tiktoken.h>
#include <executorch/runtime/platform/runtime.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <sstream>
#include <vector>

using namespace ::testing;
Expand Down Expand Up @@ -140,3 +142,47 @@ TEST_F(TiktokenExtensionTest, ConstructionWithInvalidEOSIndex) {
"");
#endif
}

TEST_F(TiktokenExtensionTest, LoadWithInvalidPath) {
auto invalidModelPath =
std::getenv("RESOURCES_PATH") + std::string("/nonexistent.model");

Error res = tokenizer_->load(invalidModelPath.c_str());
EXPECT_EQ(res, Error::InvalidArgument);
}

TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithInvalidRank) {
auto invalidModelPath = std::getenv("RESOURCES_PATH") +
std::string("/test_tiktoken_invalid_rank.model");

Error res = tokenizer_->load(invalidModelPath.c_str());

EXPECT_EQ(res, Error::InvalidArgument);
}

TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithInvalidBase64) {
auto invalidModelPath = std::getenv("RESOURCES_PATH") +
std::string("/test_tiktoken_invalid_base64.model");

Error res = tokenizer_->load(invalidModelPath.c_str());

EXPECT_EQ(res, Error::InvalidArgument);
}

TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithNoSpace) {
auto invalidModelPath = std::getenv("RESOURCES_PATH") +
std::string("/test_tiktoken_no_space.model");

Error res = tokenizer_->load(invalidModelPath.c_str());

EXPECT_EQ(res, Error::InvalidArgument);
}

TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithBPEFile) {
auto invalidModelPath =
std::getenv("RESOURCES_PATH") + std::string("/test_bpe_tokenizer.bin");

Error res = tokenizer_->load(invalidModelPath.c_str());

EXPECT_EQ(res, Error::InvalidArgument);
}
46 changes: 30 additions & 16 deletions extension/llm/tokenizer/tiktoken.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <executorch/extension/llm/tokenizer/base64.h>
#include <executorch/extension/llm/tokenizer/tiktoken.h>
#include <executorch/runtime/core/result.h>
#include <fstream>
#include <limits>

Expand Down Expand Up @@ -65,47 +66,60 @@ static Re2UPtr _build_special_token_regex(const Encoder& special_encoder) {
return _create_regex(special_pattern);
}

static std::pair<std::string, uint64_t> _parse(const std::string& line) {
static Result<std::pair<std::string, uint64_t>> _parse(
const std::string& line) {
// Tiktoken format
// https://github.com/openai/tiktoken/blob/main/tiktoken/load.py#L140 <base64
// encoded token str> <rank>
auto pos = line.find(" ");
ET_CHECK_MSG(
pos != std::string::npos, "invalid encoder line: %s", line.c_str());
ET_CHECK_OR_RETURN_ERROR(
pos != std::string::npos,
InvalidArgument,
"invalid tiktoken line: %s",
line.c_str());

auto token = base64::decode({line.data(), pos});
auto token = ET_UNWRAP(base64::decode({line.data(), pos}));
uint64_t rank = 0;
try {
rank = std::stoul(line.substr(pos + 1));
} catch (const std::exception&) {
ET_CHECK_MSG(false, "invalid encoder rank: %s", line.c_str());
ET_CHECK_OR_RETURN_ERROR(
false, InvalidArgument, "invalid encoder rank: %s", line.c_str());
}

return {std::move(token), rank};
return std::pair{std::move(token), rank};
}

static Encoder _load_encoder(const std::string& path) {
static Result<Encoder> _load_encoder(const std::string& path) {
std::ifstream file(path);
ET_CHECK_MSG(file, "failed to open encoder file: %s", path.c_str());
ET_CHECK_OR_RETURN_ERROR(
file, InvalidArgument, "failed to open encoder file: %s", path.c_str());

Encoder encoder;
std::string line;
while (std::getline(file, line)) {
auto [token, rank] = _parse(line);
auto [token, rank] = ET_UNWRAP(_parse(line));

ET_CHECK_MSG(
ET_CHECK_OR_RETURN_ERROR(
encoder.emplace(std::move(token), rank).second,
InvalidArgument,
"duplicate item: %s",
line.c_str());
}

return encoder;
}

static Decoder _build_decoder(const Encoder& encoder) {
static Result<Decoder> _build_decoder(const Encoder& encoder) {
Decoder decoder;
for (const auto& [k, v] : encoder) {
decoder.emplace(v, k);
}

ET_CHECK_MSG(encoder.size() == decoder.size(), "duplicate items in encoder");
ET_CHECK_OR_RETURN_ERROR(
encoder.size() == decoder.size(),
InvalidArgument,
"duplicate items in encoder");

return decoder;
}
Expand Down Expand Up @@ -356,11 +370,11 @@ Tiktoken::Tiktoken(
}

Error Tiktoken::load(const std::string& path) {
_encoder = _load_encoder(path);
_encoder = ET_UNWRAP(_load_encoder(path));
_special_token_encoder = _build_special_token_encoder(_encoder.size());

_decoder = _build_decoder(_encoder);
_special_token_decoder = _build_decoder(_special_token_encoder);
_decoder = ET_UNWRAP(_build_decoder(_encoder));
_special_token_decoder = ET_UNWRAP(_build_decoder(_special_token_encoder));

_regex = _create_regex(_pattern);
// Warmup re2 as it is slow on the first run, void the return value as it's
Expand Down Expand Up @@ -393,7 +407,7 @@ Tiktoken::encode(const std::string& text, int8_t bos, int8_t eos) const {
for (auto i = 0; i < eos; ++i) {
res.push_back(eos_tok_);
}
return Result(res);
return Result<std::vector<uint64_t>>(std::move(res));
}

Result<std::string> Tiktoken::decode(uint64_t prev, uint64_t cur) const {
Expand Down
Loading