From a9a90c8a7c041681191710263fd2bcd496e3a1fc Mon Sep 17 00:00:00 2001
From: Laith Sakka
Date: Mon, 27 Jun 2022 09:48:12 -0700
Subject: [PATCH] Fix performance issues in gpt2_bpe_tokenizer (#401)

Summary:
Pull Request resolved: https://github.com/pytorch/torcharrow/pull/401

Complex structures in C++ should be passed by const reference rather than by
value to avoid copying the data. Several functions in gpt2_bpe_tokenizer were
taking such structures by value; this change switches them to const references.
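As a rough illustration of the idiom this change applies (the function and
variable names below are made up for the example and are not part of this
codebase):

    #include <iostream>
    #include <string>
    #include <vector>

    // Pass-by-value: the caller's vector (and every string in it) is copied
    // into `tokens` on each call.
    std::string join_by_value(std::vector<std::string> tokens) {
      std::string out;
      for (const auto& t : tokens) out += t;
      return out;
    }

    // Pass-by-const-reference: binds directly to the caller's vector,
    // read-only access, no copy.
    std::string join_by_ref(const std::vector<std::string>& tokens) {
      std::string out;
      for (const auto& t : tokens) out += t;
      return out;
    }

    int main() {
      std::vector<std::string> tokens{"he", "l", "l", "o"};
      std::cout << join_by_value(tokens) << "\n";  // copies the whole vector
      std::cout << join_by_ref(tokens) << "\n";    // no copy
      return 0;
    }

Both calls print "hello"; only the const-reference version avoids the copy. The
same reasoning applies to the std::vector, std::unordered_map, c10::Dict, and
state-tuple parameters touched below.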
Differential Revision: D37423480

fbshipit-source-id: 3daa2790a54b1214f3f632b201373e719a3b60a5
---
 .../functions/text/gpt2_bpe_tokenizer.cpp     | 36 ++++++++++---------
 .../velox/functions/text/gpt2_bpe_tokenizer.h | 24 ++++++-------
 2 files changed, 32 insertions(+), 28 deletions(-)

diff --git a/csrc/velox/functions/text/gpt2_bpe_tokenizer.cpp b/csrc/velox/functions/text/gpt2_bpe_tokenizer.cpp
index 838a080c0..d01da693c 100644
--- a/csrc/velox/functions/text/gpt2_bpe_tokenizer.cpp
+++ b/csrc/velox/functions/text/gpt2_bpe_tokenizer.cpp
@@ -34,7 +34,8 @@ bool is_whitespace(const std::string& input) {
 }
 
 template <class Key, class Value>
-c10::Dict<Key, Value> _map_to_c10_dict(std::unordered_map<Key, Value> m) {
+c10::Dict<Key, Value> _map_to_c10_dict(
+    const std::unordered_map<Key, Value>& m) {
   c10::Dict<Key, Value> d;
   for (const auto& item : m)
     d.insert(item.first, item.second);
@@ -42,14 +43,15 @@ c10::Dict<Key, Value> _map_to_c10_dict(std::unordered_map<Key, Value> m) {
 }
 
 template <class Key, class Value>
-std::unordered_map<Key, Value> _c10_dict_to_map(c10::Dict<Key, Value> d) {
+std::unordered_map<Key, Value> _c10_dict_to_map(
+    const c10::Dict<Key, Value>& d) {
   std::unordered_map<Key, Value> m;
   for (const auto& item : d)
     m[item.key()] = item.value();
   return m;
 }
 
-std::vector<std::string> gpt2_bpe_pre_tokenizer(std::string input) {
+std::vector<std::string> gpt2_bpe_pre_tokenizer(const std::string& input) {
   // Python implementation:
   // https://github.com/pytorch/fairseq/blob/main/fairseq/data/encoders/gpt2_bpe_utils.py#L69
   // Original regex contains a negative lookahead pattern, which is not
@@ -102,16 +104,16 @@ std::vector<std::string> gpt2_bpe_pre_tokenizer(std::string input) {
 }
 
 std::pair<std::string, std::string> split_tokens(
-    std::string s,
-    std::string delimiter) {
+    const std::string& s,
+    const std::string& delimiter) {
   auto pos = s.find(delimiter);
   TORCH_CHECK(pos != std::string::npos, "Expected `s`to contain `delimiter`");
   return std::make_pair(s.substr(0, pos), s.substr(pos + delimiter.length()));
 }
 
 int list_str_index(
-    std::vector<std::string> list,
-    std::string element,
+    const std::vector<std::string>& list,
+    const std::string& element,
     int start) {
   // Equivalent to: list.index(element, start)
   for (std::size_t i = start; i < list.size(); ++i) {
@@ -130,7 +132,7 @@ std::string concatenate_strings(const std::vector<std::string>& list) {
 }
 
 std::vector<std::string> get_pairs(
-    std::vector<std::string> token_list,
+    const std::vector<std::string>& token_list,
     const std::string& separator) {
   // For example: ["he", "l", "l", "o"]
   // ==> ["he\u0001l", "l\u0001l", "l\u0001o"]
@@ -175,7 +177,7 @@ GPT2BPEEncoder::GPT2BPEEncoder(
           _map_to_c10_dict(byte_encoder),
           caching_enabled) {}
 
-std::vector<std::string> GPT2BPEEncoder::ByteEncode_(std::string token) {
+std::vector<std::string> GPT2BPEEncoder::ByteEncode_(const std::string& token) {
   // Equivalent to: (self.byte_encoder[b] for b in token.encode('utf-8')
   std::vector<std::string> encoded;
   for (auto& ch : token) {
@@ -184,14 +186,15 @@ std::vector<std::string> GPT2BPEEncoder::ByteEncode_(std::string token) {
   return encoded;
 }
 
-int64_t GPT2BPEEncoder::GetBPEMergeRank_(std::string pair) {
+int64_t GPT2BPEEncoder::GetBPEMergeRank_(const std::string& pair) {
   if (bpe_merge_ranks_.contains(pair)) {
     return bpe_merge_ranks_.at(pair);
  }
   return inf_;
 }
 
-std::string GPT2BPEEncoder::FindBestPair_(std::vector<std::string> pairs) {
+std::string GPT2BPEEncoder::FindBestPair_(
+    const std::vector<std::string>& pairs) {
   // Equivalent to:
   // min(pairs, key = lambda pair: self.bpe_merge_ranks.get(pair,
   // float('inf')))
@@ -277,7 +280,8 @@ std::vector<std::string> GPT2BPEEncoder::BPE_(
   return tok_list;
 }
 
-std::vector<std::string> GPT2BPEEncoder::PreTokenize_(std::string input) {
+std::vector<std::string> GPT2BPEEncoder::PreTokenize_(
+    const std::string& input) {
   return gpt2_bpe_pre_tokenizer(input);
 }
 
@@ -327,8 +331,8 @@ GPT2BPEEncoderStatesTorchbind _serialize_gpt2_bpe_encoder_torchbind(
 }
 
 c10::intrusive_ptr<GPT2BPEEncoder> _deserialize_gpt2_bpe_encoder_pybind(
-    GPT2BPEEncoderStatesPybind states) {
-  auto state_size = std::tuple_size<decltype(states)>::value;
+    const GPT2BPEEncoderStatesPybind& states) {
+  auto state_size = std::tuple_size<GPT2BPEEncoderStatesPybind>::value;
   TORCH_CHECK(
       state_size == 5,
       "Expected deserialized GPT2BPEEncoder to have 5 states but found " +
@@ -342,8 +346,8 @@ c10::intrusive_ptr<GPT2BPEEncoder> _deserialize_gpt2_bpe_encoder_pybind(
 }
 
 c10::intrusive_ptr<GPT2BPEEncoder> _deserialize_gpt2_bpe_encoder_torchbind(
-    GPT2BPEEncoderStatesTorchbind states) {
-  auto state_size = std::tuple_size<decltype(states)>::value;
+    const GPT2BPEEncoderStatesTorchbind& states) {
+  auto state_size = std::tuple_size<GPT2BPEEncoderStatesTorchbind>::value;
   TORCH_CHECK(
       state_size == 5,
       "Expected deserialized GPT2BPEEncoder to have 5 states but found " +
diff --git a/csrc/velox/functions/text/gpt2_bpe_tokenizer.h b/csrc/velox/functions/text/gpt2_bpe_tokenizer.h
index 0192190ab..8747469d4 100644
--- a/csrc/velox/functions/text/gpt2_bpe_tokenizer.h
+++ b/csrc/velox/functions/text/gpt2_bpe_tokenizer.h
@@ -42,42 +42,42 @@ typedef std::tuple<
 
 // Applies regex based pre-tokenization step for GPT-2 BPE tokenizer
 // and returns a list of tokens.
-std::vector<std::string> gpt2_bpe_pre_tokenizer(std::string input);
+std::vector<std::string> gpt2_bpe_pre_tokenizer(const std::string& input);
 
 // Concatenate a vector of strings to a single string
 std::string concatenate_strings(const std::vector<std::string>& list);
 
 // Return set of token pairs in a word, separated by the `separator`.
 std::vector<std::string> get_pairs(
-    std::vector<std::string> token_list,
+    const std::vector<std::string>& token_list,
     const std::string& separator);
 
 // Split a string into 2 parts separated by a `separator`.
 std::pair<std::string, std::string> split_tokens(
-    std::string s,
-    std::string delimiter);
+    const std::string& s,
+    const std::string& delimiter);
 
 // Find index of `element` in a list of strings.
 int list_str_index(
-    std::vector<std::string> list,
-    std::string element,
+    const std::vector<std::string>& list,
+    const std::string& element,
     int start);
 
 struct GPT2BPEEncoder : torch::CustomClassHolder {
  private:
   const int64_t inf_;
   // Encode byte into an unicode character.
-  std::vector<std::string> ByteEncode_(std::string token);
-  int64_t GetBPEMergeRank_(std::string pair);
+  std::vector<std::string> ByteEncode_(const std::string& token);
+  int64_t GetBPEMergeRank_(const std::string& pair);
 
  protected:
   c10::Dict<std::string, std::vector<std::string>> cache_;
-  virtual std::vector<std::string> PreTokenize_(std::string input);
+  virtual std::vector<std::string> PreTokenize_(const std::string& input);
   // Return a list of bpe tokens.
   virtual std::vector<std::string> BPE_(
       const std::vector<std::string>& token_list);
   // Return the token pair(e.g bpe merge) with lowest rank.
-  std::string FindBestPair_(std::vector<std::string> pairs);
+  std::string FindBestPair_(const std::vector<std::string>& pairs);
 
  public:
   const c10::Dict<std::string, int64_t> bpe_encoder_;
@@ -122,9 +122,9 @@ GPT2BPEEncoderStatesPybind _serialize_gpt2_bpe_encoder_pybind(
 GPT2BPEEncoderStatesTorchbind _serialize_gpt2_bpe_encoder_torchbind(
     const c10::intrusive_ptr<GPT2BPEEncoder>& self);
 c10::intrusive_ptr<GPT2BPEEncoder> _deserialize_gpt2_bpe_encoder_pybind(
-    GPT2BPEEncoderStatesPybind states);
+    const GPT2BPEEncoderStatesPybind& states);
 c10::intrusive_ptr<GPT2BPEEncoder> _deserialize_gpt2_bpe_encoder_torchbind(
-    GPT2BPEEncoderStatesTorchbind states);
+    const GPT2BPEEncoderStatesTorchbind& states);
 
 } // namespace facebook::torcharrow::functions
 
 #endif // GPT2_BPE_TOKENIZER_H_