Skip to content
This repository was archived by the owner on Nov 1, 2024. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 20 additions & 16 deletions csrc/velox/functions/text/gpt2_bpe_tokenizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,24 @@ bool is_whitespace(const std::string& input) {
}

// Convert a std::unordered_map into a c10::Dict holding the same
// key/value pairs. Takes the map by const reference to avoid copying it.
// NOTE(review): this span contained both the pre- and post-change diff
// lines; reconstructed here as the post-change (const-ref) version.
template <class Key_, class Value_>
c10::Dict<Key_, Value_> _map_to_c10_dict(
    const std::unordered_map<Key_, Value_>& m) {
  c10::Dict<Key_, Value_> d;
  for (const auto& item : m)
    d.insert(item.first, item.second);
  return d;
}

template <class Key_, class Value_>
std::unordered_map<Key_, Value_> _c10_dict_to_map(c10::Dict<Key_, Value_> d) {
std::unordered_map<Key_, Value_> _c10_dict_to_map(
const c10::Dict<Key_, Value_>& d) {
std::unordered_map<Key_, Value_> m;
for (const auto& item : d)
m[item.key()] = item.value();
return m;
}

std::vector<std::string> gpt2_bpe_pre_tokenizer(std::string input) {
std::vector<std::string> gpt2_bpe_pre_tokenizer(const std::string& input) {
// Python implementation:
// https://github.com/pytorch/fairseq/blob/main/fairseq/data/encoders/gpt2_bpe_utils.py#L69
// Original regex contains a negative lookahead pattern, which is not
Expand Down Expand Up @@ -102,16 +104,16 @@ std::vector<std::string> gpt2_bpe_pre_tokenizer(std::string input) {
}

std::pair<std::string, std::string> split_tokens(
std::string s,
std::string delimiter) {
const std::string& s,
const std::string& delimiter) {
auto pos = s.find(delimiter);
TORCH_CHECK(pos != std::string::npos, "Expected `s`to contain `delimiter`");
return std::make_pair(s.substr(0, pos), s.substr(pos + delimiter.length()));
}

int list_str_index(
std::vector<std::string> list,
std::string element,
const std::vector<std::string>& list,
const std::string& element,
int start) {
// Equivalent to: list.index(element, start)
for (std::size_t i = start; i < list.size(); ++i) {
Expand All @@ -130,7 +132,7 @@ std::string concatenate_strings(const std::vector<std::string>& list) {
}

std::vector<std::string> get_pairs(
std::vector<std::string> token_list,
const std::vector<std::string>& token_list,
const std::string& separator) {
// For example: ["he", "l", "l", "o"]
// ==> ["he\u0001l", "l\u0001l", "l\u0001o"]
Expand Down Expand Up @@ -175,7 +177,7 @@ GPT2BPEEncoder::GPT2BPEEncoder(
_map_to_c10_dict<int64_t, std::string>(byte_encoder),
caching_enabled) {}

std::vector<std::string> GPT2BPEEncoder::ByteEncode_(std::string token) {
std::vector<std::string> GPT2BPEEncoder::ByteEncode_(const std::string& token) {
// Equivalent to: (self.byte_encoder[b] for b in token.encode('utf-8')
std::vector<std::string> encoded;
for (auto& ch : token) {
Expand All @@ -184,14 +186,15 @@ std::vector<std::string> GPT2BPEEncoder::ByteEncode_(std::string token) {
return encoded;
}

int64_t GPT2BPEEncoder::GetBPEMergeRank_(std::string pair) {
int64_t GPT2BPEEncoder::GetBPEMergeRank_(const std::string& pair) {
if (bpe_merge_ranks_.contains(pair)) {
return bpe_merge_ranks_.at(pair);
}
return inf_;
}

std::string GPT2BPEEncoder::FindBestPair_(std::vector<std::string> pairs) {
std::string GPT2BPEEncoder::FindBestPair_(
const std::vector<std::string>& pairs) {
// Equivalent to:
// min(pairs, key = lambda pair: self.bpe_merge_ranks.get(pair,
// float('inf')))
Expand Down Expand Up @@ -277,7 +280,8 @@ std::vector<std::string> GPT2BPEEncoder::BPE_(
return tok_list;
}

std::vector<std::string> GPT2BPEEncoder::PreTokenize_(std::string input) {
std::vector<std::string> GPT2BPEEncoder::PreTokenize_(
const std::string& input) {
return gpt2_bpe_pre_tokenizer(input);
}

Expand Down Expand Up @@ -327,8 +331,8 @@ GPT2BPEEncoderStatesTorchbind _serialize_gpt2_bpe_encoder_torchbind(
}

c10::intrusive_ptr<GPT2BPEEncoder> _deserialize_gpt2_bpe_encoder_pybind(
GPT2BPEEncoderStatesPybind states) {
auto state_size = std::tuple_size<decltype(states)>::value;
const GPT2BPEEncoderStatesPybind& states) {
auto state_size = std::tuple_size<GPT2BPEEncoderStatesPybind>::value;
TORCH_CHECK(
state_size == 5,
"Expected deserialized GPT2BPEEncoder to have 5 states but found " +
Expand All @@ -342,8 +346,8 @@ c10::intrusive_ptr<GPT2BPEEncoder> _deserialize_gpt2_bpe_encoder_pybind(
}

c10::intrusive_ptr<GPT2BPEEncoder> _deserialize_gpt2_bpe_encoder_torchbind(
GPT2BPEEncoderStatesTorchbind states) {
auto state_size = std::tuple_size<decltype(states)>::value;
const GPT2BPEEncoderStatesTorchbind& states) {
auto state_size = std::tuple_size<GPT2BPEEncoderStatesTorchbind>::value;
TORCH_CHECK(
state_size == 5,
"Expected deserialized GPT2BPEEncoder to have 5 states but found " +
Expand Down
24 changes: 12 additions & 12 deletions csrc/velox/functions/text/gpt2_bpe_tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,42 +42,42 @@ typedef std::tuple<

// Applies regex based pre-tokenization step for GPT-2 BPE tokenizer
// and returns a list of tokens.
std::vector<std::string> gpt2_bpe_pre_tokenizer(std::string input);
std::vector<std::string> gpt2_bpe_pre_tokenizer(const std::string& input);

// Concatenate a vector of strings to a single string
std::string concatenate_strings(const std::vector<std::string>& list);

// Return set of token pairs in a word, separated by the `separator`.
std::vector<std::string> get_pairs(
std::vector<std::string> token_list,
const std::vector<std::string>& token_list,
const std::string& separator);

// Split a string into 2 parts at the first occurrence of `delimiter`.
std::pair<std::string, std::string> split_tokens(
std::string s,
std::string delimiter);
const std::string& s,
const std::string& delimiter);

// Find index of `element` in a list of strings.
int list_str_index(
std::vector<std::string> list,
std::string element,
const std::vector<std::string>& list,
const std::string& element,
int start);

struct GPT2BPEEncoder : torch::CustomClassHolder {
private:
const int64_t inf_;
// Encode byte into an unicode character.
std::vector<std::string> ByteEncode_(std::string token);
int64_t GetBPEMergeRank_(std::string pair);
std::vector<std::string> ByteEncode_(const std::string& token);
int64_t GetBPEMergeRank_(const std::string& pair);

protected:
c10::Dict<std::string, std::vector<std::string>> cache_;
virtual std::vector<std::string> PreTokenize_(std::string input);
virtual std::vector<std::string> PreTokenize_(const std::string& input);
// Return a list of bpe tokens.
virtual std::vector<std::string> BPE_(
const std::vector<std::string>& token_list);
// Return the token pair(e.g bpe merge) with lowest rank.
std::string FindBestPair_(std::vector<std::string> pairs);
std::string FindBestPair_(const std::vector<std::string>& pairs);

public:
const c10::Dict<std::string, int64_t> bpe_encoder_;
Expand Down Expand Up @@ -122,9 +122,9 @@ GPT2BPEEncoderStatesPybind _serialize_gpt2_bpe_encoder_pybind(
GPT2BPEEncoderStatesTorchbind _serialize_gpt2_bpe_encoder_torchbind(
const c10::intrusive_ptr<GPT2BPEEncoder>& self);
c10::intrusive_ptr<GPT2BPEEncoder> _deserialize_gpt2_bpe_encoder_pybind(
GPT2BPEEncoderStatesPybind states);
const GPT2BPEEncoderStatesPybind& states);
c10::intrusive_ptr<GPT2BPEEncoder> _deserialize_gpt2_bpe_encoder_torchbind(
GPT2BPEEncoderStatesTorchbind states);
const GPT2BPEEncoderStatesTorchbind& states);
} // namespace facebook::torcharrow::functions

#endif // GPT2_BPE_TOKENIZER_H_