Skip to content

Commit

Permalink
Optionally ignore utf-8 decoding error when converting std::string to…
Browse files Browse the repository at this point in the history
… python str. (#97282)

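Summary: When language models use a C++ tokenizer, the outputs are C++ strings that are not necessarily valid UTF-8. The default pybind11 cast decodes with strict UTF-8 and fails on such strings. This change adds an opt-in flag that relaxes the decoding by passing the 'ignore' error handler, so invalid bytes are dropped instead of raising an error.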

Test Plan: https://www.internalfb.com/intern/testinfra/testrun/6473924609918070

Reviewed By: Nayef211

Differential Revision: D43970697

Pull Request resolved: #97282
Approved by: https://github.com/davidberard98
  • Loading branch information
shuminghu authored and pytorchmergebot committed Mar 23, 2023
1 parent a524123 commit b45880c
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 23 deletions.
1 change: 1 addition & 0 deletions build_variables.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ core_sources_full_mobile_no_backend_interface_xplat = [
"torch/csrc/jit/passes/quantization/fusion_passes.cpp",
"torch/csrc/jit/passes/quantization/register_packed_params.cpp",
"torch/csrc/jit/python/update_graph_executor_opt.cpp",
"torch/csrc/jit/python/utf8_decoding_ignore.cpp",
"torch/csrc/jit/runtime/argument_spec.cpp",
"torch/csrc/jit/runtime/autodiff.cpp",
"torch/csrc/jit/runtime/graph_executor.cpp",
Expand Down
48 changes: 26 additions & 22 deletions torch/csrc/jit/python/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
#include <torch/csrc/jit/python/python_tracer.h>
#include <torch/csrc/jit/python/python_tree_views.h>
#include <torch/csrc/jit/python/script_init.h>
#include <torch/csrc/jit/python/utf8_decoding_ignore.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/decomposition_registry.h>
Expand Down Expand Up @@ -993,7 +994,7 @@ void initJITBindings(PyObject* module) {
#ifdef TORCH_ENABLE_LLVM
return true;
#else
return false;
return false;
#endif
})
.def(
Expand Down Expand Up @@ -1124,27 +1125,30 @@ void initJITBindings(PyObject* module) {
return retval;
})
.def("_jit_pass_batch_mm", BatchMM)
.def("_jit_decay_packed_param_input_types", [](Graph& g) {
for (Value* i : g.inputs()) {
if (i->type() ==
getCustomClass(
"__torch__.torch.classes.quantized.Conv2dPackedParamsBase") ||
i->type() ==
getCustomClass(
"__torch__.torch.classes.quantized.Conv3dPackedParamsBase") ||
i->type() ==
getCustomClass(
"__torch__.torch.classes.quantized.LinearPackedParamsBase")) {
// Dummy CompleteTensorType to appease ONNX validator.
i->setType(TensorType::create(
at::kQInt8,
c10::kCPU,
std::vector<int64_t>{1},
std::vector<int64_t>{1},
c10::nullopt));
}
}
});
.def(
"_jit_decay_packed_param_input_types",
[](Graph& g) {
for (Value* i : g.inputs()) {
if (i->type() ==
getCustomClass(
"__torch__.torch.classes.quantized.Conv2dPackedParamsBase") ||
i->type() ==
getCustomClass(
"__torch__.torch.classes.quantized.Conv3dPackedParamsBase") ||
i->type() ==
getCustomClass(
"__torch__.torch.classes.quantized.LinearPackedParamsBase")) {
// Dummy CompleteTensorType to appease ONNX validator.
i->setType(TensorType::create(
at::kQInt8,
c10::kCPU,
std::vector<int64_t>{1},
std::vector<int64_t>{1},
c10::nullopt));
}
}
})
.def("_jit_set_utf8_decoding_ignore", &setUTF8DecodingIgnore);

// NB: This isn't actually used for regular PyTorch symbolic tracing;
// XLA is what needs this
Expand Down
9 changes: 8 additions & 1 deletion torch/csrc/jit/python/pybind_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <torch/csrc/jit/python/python_dict.h>
#include <torch/csrc/jit/python/python_ivalue.h>
#include <torch/csrc/jit/python/python_list.h>
#include <torch/csrc/jit/python/utf8_decoding_ignore.h>

#include <ATen/ScalarOps.h>

Expand Down Expand Up @@ -555,7 +556,13 @@ py::object toPyObject(IValue ivalue) {
} else if (ivalue.isBool()) {
return py::cast(std::move(ivalue).toBool());
} else if (ivalue.isString()) {
return py::cast(std::move(ivalue).toStringRef());
if (getUTF8DecodingIgnore()) {
std::string s = std::move(ivalue).toStringRef();
PyObject* pyObj = PyUnicode_DecodeUTF8(s.data(), s.length(), "ignore");
return py::reinterpret_steal<py::object>(pyObj);
} else {
return py::cast(std::move(ivalue).toStringRef());
}
} else if (ivalue.isList()) {
auto list = std::move(ivalue).toList();
py::list t{list.size()};
Expand Down
16 changes: 16 additions & 0 deletions torch/csrc/jit/python/utf8_decoding_ignore.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#include <torch/csrc/jit/python/utf8_decoding_ignore.h>

namespace torch::jit {

namespace {

// Returns a reference to the calling thread's copy of the "ignore UTF-8
// decoding errors" switch. The storage is thread_local and starts out false,
// so a value stored on one thread is not observed by any other thread.
bool& utf8DecodingIgnoreFlag() {
  thread_local bool ignoreErrors = false;
  return ignoreErrors;
}

} // namespace

// Stores `o` as the calling thread's setting.
void setUTF8DecodingIgnore(bool o) {
  utf8DecodingIgnoreFlag() = o;
}

// Reports the calling thread's current setting (false until it is set).
bool getUTF8DecodingIgnore() {
  return utf8DecodingIgnoreFlag();
}

} // namespace torch::jit
8 changes: 8 additions & 0 deletions torch/csrc/jit/python/utf8_decoding_ignore.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#pragma once
#include <torch/csrc/Export.h>
namespace torch::jit {

// Sets the flag that getUTF8DecodingIgnore() reports. The Python binding
// `_jit_set_utf8_decoding_ignore` forwards to this function.
TORCH_API void setUTF8DecodingIgnore(bool o);

// Returns the flag stored by setUTF8DecodingIgnore(). While it is true,
// toPyObject decodes string IValues with the "ignore" UTF-8 error handler
// instead of using the default pybind11 cast.
TORCH_API bool getUTF8DecodingIgnore();

} // namespace torch::jit

0 comments on commit b45880c

Please sign in to comment.