Skip to content

Commit

Permalink
Optionally ignore utf-8 decoding error when converting std::string to…
Browse files Browse the repository at this point in the history
… python str. (#97282)

Summary:
X-link: pytorch/pytorch#97282

Pull Request resolved: pytorch#2126

When language models use a C++ tokenizer, the outputs are C++ strings that are not necessarily valid utf-8 encodings. Default pybind11 casting uses strict utf-8 decoding. We relax the decoding using the 'ignore' argument.

Reviewed By: Nayef211

Differential Revision: D43970697

fbshipit-source-id: 4988147e6905d1fb8096bf6d172ab8d8952b49b0
  • Loading branch information
shuminghu authored and facebook-github-bot committed Mar 22, 2023
1 parent 145479c commit c560fd9
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 0 deletions.
20 changes: 20 additions & 0 deletions test/torchtext_unittest/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,16 @@ def _gpt2_bpe_decoder_with_special_tokens(self, tokenizer):
for idx, ids in enumerate(sample_ids):
self.assertEqual(tokenizer.decode(ids), expected_texts[idx])

def _gpt_bpe_decoder_partial_utf8(self, tokenizer):
sample_ids = [
['47728', '245', '114'],
['47728', '245', '114', '47728'], # containing partial utf-8 encoding
]
expected_texts = ["𝗶", "𝗶"]

for idx, ids in enumerate(sample_ids):
self.assertEqual(tokenizer.decode(ids), expected_texts[idx])

@nested_params([True, False], [True, False])
def test_gpt2_bpe_tokenizer(self, test_scripting, return_tokens):
"""test tokenization on single sentence input as well as batch on sentences"""
Expand All @@ -704,6 +714,16 @@ def test_gpt2_bpe_decoder(self):
self._gpt2_bpe_decoder(self._load_tokenizer(test_scripting=False, return_tokens=False))
self._gpt2_bpe_decoder_with_special_tokens(self._load_tokenizer(test_scripting=False, return_tokens=False))

torch.ops.torchtext.set_utf8_decoding_ignore(True)
self._gpt_bpe_decoder_partial_utf8(self._load_tokenizer(test_scripting=False, return_tokens=False))
self._gpt_bpe_decoder_partial_utf8(self._load_tokenizer(test_scripting=True, return_tokens=False))

torch.ops.torchtext.set_utf8_decoding_ignore(False)
with self.assertRaises(UnicodeDecodeError):
self._gpt_bpe_decoder_partial_utf8(self._load_tokenizer(test_scripting=True, return_tokens=False))



@nested_params([True, False])
def test_gpt2_bpe_tokenizer_with_added_vocab(self, return_tokens):
self._gpt2_bpe_tokenizer_with_added_vocab(
Expand Down
4 changes: 4 additions & 0 deletions torchtext/csrc/register_pybindings.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/csrc/jit/python/module_python.h> // @manual
#include <torch/csrc/jit/python/utf8_decoding_ignore.h>
#include <torch/csrc/utils/pybind.h> // @manual
#include <torch/script.h>
#include <torchtext/csrc/bert_tokenizer.h> // @manual
Expand Down Expand Up @@ -287,6 +288,9 @@ PYBIND11_MODULE(_torchtext, m) {
m.def(
"_build_vocab_from_text_file_using_python_tokenizer",
&_build_vocab_from_text_file_using_python_tokenizer);
m.def(
"torchtext::set_utf8_decoding_ignore",
&torch::jit::setUTF8DecodingIgnore);
}

} // namespace torchtext
4 changes: 4 additions & 0 deletions torchtext/csrc/register_torchbindings.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <torch/csrc/jit/python/utf8_decoding_ignore.h>
#include <torch/script.h>
#include <torchtext/csrc/bert_tokenizer.h> // @manual
#include <torchtext/csrc/clip_tokenizer.h> // @manual
Expand Down Expand Up @@ -210,6 +211,9 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
m.def("torchtext::load_sp_model", &load_sp_model);
m.def("torchtext::load_sp_model_string", &load_sp_model_string);
m.def("torchtext::gpt2_bpe_pre_tokenizer", &gpt2_bpe_pre_tokenizer);
m.def(
"torchtext::set_utf8_decoding_ignore",
&torch::jit::setUTF8DecodingIgnore);
}

} // namespace torchtext
1 change: 1 addition & 0 deletions torchtext/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ def __prepare_scriptable__(self):
return tokenizer_copy
return self

@torch.jit.export
def decode(self, tokens: List[str]) -> str:
"""Return a decoded string given a list of string token ids.
Expand Down

0 comments on commit c560fd9

Please sign in to comment.