Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optionally ignore utf-8 decoding error when converting std::string to python str. #97282

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions build_variables.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ core_sources_full_mobile_no_backend_interface_xplat = [
"torch/csrc/jit/passes/quantization/fusion_passes.cpp",
"torch/csrc/jit/passes/quantization/register_packed_params.cpp",
"torch/csrc/jit/python/update_graph_executor_opt.cpp",
"torch/csrc/jit/python/utf8_decoding_ignore.cpp",
"torch/csrc/jit/runtime/argument_spec.cpp",
"torch/csrc/jit/runtime/autodiff.cpp",
"torch/csrc/jit/runtime/graph_executor.cpp",
Expand Down
48 changes: 26 additions & 22 deletions torch/csrc/jit/python/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
#include <torch/csrc/jit/python/python_tracer.h>
#include <torch/csrc/jit/python/python_tree_views.h>
#include <torch/csrc/jit/python/script_init.h>
#include <torch/csrc/jit/python/utf8_decoding_ignore.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/decomposition_registry.h>
Expand Down Expand Up @@ -993,7 +994,7 @@ void initJITBindings(PyObject* module) {
#ifdef TORCH_ENABLE_LLVM
return true;
#else
return false;
return false;
#endif
})
.def(
Expand Down Expand Up @@ -1124,27 +1125,30 @@ void initJITBindings(PyObject* module) {
return retval;
})
.def("_jit_pass_batch_mm", BatchMM)
.def("_jit_decay_packed_param_input_types", [](Graph& g) {
for (Value* i : g.inputs()) {
if (i->type() ==
getCustomClass(
"__torch__.torch.classes.quantized.Conv2dPackedParamsBase") ||
i->type() ==
getCustomClass(
"__torch__.torch.classes.quantized.Conv3dPackedParamsBase") ||
i->type() ==
getCustomClass(
"__torch__.torch.classes.quantized.LinearPackedParamsBase")) {
// Dummy CompleteTensorType to appease ONNX validator.
i->setType(TensorType::create(
at::kQInt8,
c10::kCPU,
std::vector<int64_t>{1},
std::vector<int64_t>{1},
c10::nullopt));
}
}
});
.def(
"_jit_decay_packed_param_input_types",
[](Graph& g) {
for (Value* i : g.inputs()) {
if (i->type() ==
getCustomClass(
"__torch__.torch.classes.quantized.Conv2dPackedParamsBase") ||
i->type() ==
getCustomClass(
"__torch__.torch.classes.quantized.Conv3dPackedParamsBase") ||
i->type() ==
getCustomClass(
"__torch__.torch.classes.quantized.LinearPackedParamsBase")) {
// Dummy CompleteTensorType to appease ONNX validator.
i->setType(TensorType::create(
at::kQInt8,
c10::kCPU,
std::vector<int64_t>{1},
std::vector<int64_t>{1},
c10::nullopt));
}
}
})
.def("_jit_set_utf8_decoding_ignore", &setUTF8DecodingIgnore);

// NB: This isn't actually used for regular PyTorch symbolic tracing;
// XLA is what needs this
Expand Down
9 changes: 8 additions & 1 deletion torch/csrc/jit/python/pybind_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <torch/csrc/jit/python/python_dict.h>
#include <torch/csrc/jit/python/python_ivalue.h>
#include <torch/csrc/jit/python/python_list.h>
#include <torch/csrc/jit/python/utf8_decoding_ignore.h>

#include <ATen/ScalarOps.h>

Expand Down Expand Up @@ -555,7 +556,13 @@ py::object toPyObject(IValue ivalue) {
} else if (ivalue.isBool()) {
return py::cast(std::move(ivalue).toBool());
} else if (ivalue.isString()) {
return py::cast(std::move(ivalue).toStringRef());
if (getUTF8DecodingIgnore()) {
std::string s = std::move(ivalue).toStringRef();
PyObject* pyObj = PyUnicode_DecodeUTF8(s.data(), s.length(), "ignore");
return py::reinterpret_steal<py::object>(pyObj);
} else {
return py::cast(std::move(ivalue).toStringRef());
}
} else if (ivalue.isList()) {
auto list = std::move(ivalue).toList();
py::list t{list.size()};
Expand Down
16 changes: 16 additions & 0 deletions torch/csrc/jit/python/utf8_decoding_ignore.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#include <torch/csrc/jit/python/utf8_decoding_ignore.h>

namespace torch::jit {

namespace {
thread_local bool kIgnore = false;
}

void setUTF8DecodingIgnore(bool o) {
kIgnore = o;
}
bool getUTF8DecodingIgnore() {
return kIgnore;
}

} // namespace torch::jit
8 changes: 8 additions & 0 deletions torch/csrc/jit/python/utf8_decoding_ignore.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#pragma once
#include <torch/csrc/Export.h>
namespace torch {
namespace jit {
TORCH_API void setUTF8DecodingIgnore(bool o);
TORCH_API bool getUTF8DecodingIgnore();
} // namespace jit
} // namespace torch