Skip to content

Commit

Permalink
Optionally ignore utf-8 decoding error for scripted C++ tokenizers. (#2134)

Browse files Browse the repository at this point in the history

* Optionally ignore utf-8 decoding error for scripted C++ tokenizers. (#2128)

Summary:
Pull Request resolved: #2128

Binding and test to make sure we can use the 'ignore' option for utf-8 decoding added to PyTorch in D43970697 (pytorch/pytorch#97282).

Reviewed By: Nayef211

Differential Revision: D44315169

fbshipit-source-id: d42fcacafd429cf586c631faf826abc172b173d3

* Linter fixes

---------

Co-authored-by: Shuming Hu <smhu@meta.com>
  • Loading branch information
Nayef211 and shuminghu committed Mar 29, 2023
1 parent f151b4c commit 726f5df
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 0 deletions.
18 changes: 18 additions & 0 deletions test/torchtext_unittest/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,16 @@ def _gpt2_bpe_decoder_with_special_tokens(self, tokenizer):
for idx, ids in enumerate(sample_ids):
self.assertEqual(tokenizer.decode(ids), expected_texts[idx])

def _gpt_bpe_decoder_partial_utf8(self, tokenizer):
sample_ids = [
["47728", "245", "114"],
["47728", "245", "114", "47728"], # containing partial utf-8 encoding
]
expected_texts = ["𝗶", "𝗶"]

for idx, ids in enumerate(sample_ids):
self.assertEqual(tokenizer.decode(ids), expected_texts[idx])

@nested_params([True, False], [True, False])
def test_gpt2_bpe_tokenizer(self, test_scripting, return_tokens):
"""test tokenization on single sentence input as well as batch on sentences"""
Expand All @@ -704,6 +714,14 @@ def test_gpt2_bpe_decoder(self):
self._gpt2_bpe_decoder(self._load_tokenizer(test_scripting=False, return_tokens=False))
self._gpt2_bpe_decoder_with_special_tokens(self._load_tokenizer(test_scripting=False, return_tokens=False))

torch.ops.torchtext.set_utf8_decoding_ignore(True)
self._gpt_bpe_decoder_partial_utf8(self._load_tokenizer(test_scripting=False, return_tokens=False))
self._gpt_bpe_decoder_partial_utf8(self._load_tokenizer(test_scripting=True, return_tokens=False))

torch.ops.torchtext.set_utf8_decoding_ignore(False)
with self.assertRaises(UnicodeDecodeError):
self._gpt_bpe_decoder_partial_utf8(self._load_tokenizer(test_scripting=True, return_tokens=False))

@nested_params([True, False])
def test_gpt2_bpe_tokenizer_with_added_vocab(self, return_tokens):
self._gpt2_bpe_tokenizer_with_added_vocab(
Expand Down
4 changes: 4 additions & 0 deletions torchtext/csrc/register_pybindings.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/csrc/jit/python/module_python.h> // @manual
#include <torch/csrc/jit/python/utf8_decoding_ignore.h>
#include <torch/csrc/utils/pybind.h> // @manual
#include <torch/script.h>
#include <torchtext/csrc/bert_tokenizer.h> // @manual
Expand Down Expand Up @@ -287,6 +288,9 @@ PYBIND11_MODULE(_torchtext, m) {
m.def(
"_build_vocab_from_text_file_using_python_tokenizer",
&_build_vocab_from_text_file_using_python_tokenizer);
m.def(
"torchtext::set_utf8_decoding_ignore",
&torch::jit::setUTF8DecodingIgnore);
}

} // namespace torchtext
4 changes: 4 additions & 0 deletions torchtext/csrc/register_torchbindings.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <torch/csrc/jit/python/utf8_decoding_ignore.h>
#include <torch/script.h>
#include <torchtext/csrc/bert_tokenizer.h> // @manual
#include <torchtext/csrc/clip_tokenizer.h> // @manual
Expand Down Expand Up @@ -210,6 +211,9 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
m.def("torchtext::load_sp_model", &load_sp_model);
m.def("torchtext::load_sp_model_string", &load_sp_model_string);
m.def("torchtext::gpt2_bpe_pre_tokenizer", &gpt2_bpe_pre_tokenizer);
m.def(
"torchtext::set_utf8_decoding_ignore",
&torch::jit::setUTF8DecodingIgnore);
}

} // namespace torchtext
1 change: 1 addition & 0 deletions torchtext/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ def __prepare_scriptable__(self):
return tokenizer_copy
return self

@torch.jit.export
def decode(self, tokens: List[str]) -> str:
"""Return a decoded string given a list of string token ids.
Expand Down

0 comments on commit 726f5df

Please sign in to comment.