Skip to content

Commit

Permalink
Make str(msg) in Python print raw UTF-8 strings. Only invalid UTF-8…
Browse files Browse the repository at this point in the history
… is escaped.

PiperOrigin-RevId: 597917280
  • Loading branch information
haberman authored and Copybara-Service committed Jan 12, 2024
1 parent 3a007b5 commit f2a91b3
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 46 deletions.
24 changes: 24 additions & 0 deletions python/google/protobuf/internal/message_test.py
Expand Up @@ -2606,6 +2606,30 @@ def testTypeNamesCanBeImported(self):
self.assertImportFromName(pb.repeated_nested_message, 'Composite')


# We can only test this case under proto2, because proto3 will reject invalid
# UTF-8 in the parser, so there should be no way of creating a string field
# that contains invalid UTF-8.
#
# We also can't test it in pure-Python, which validates all string fields for
# UTF-8 even when the spec says it shouldn't.
@unittest.skipIf(api_implementation.Type() == 'python',
'Python can\'t create invalid UTF-8 strings')
@testing_refleaks.TestCase
class InvalidUtf8Test(unittest.TestCase):

def testInvalidUtf8Printing(self):
one_bytes = unittest_pb2.OneBytes()
one_bytes.data = b'ABC\xff123'
one_string = unittest_pb2.OneString()
one_string.ParseFromString(one_bytes.SerializeToString())
self.assertIn('data: "ABC\\377123"', str(one_string))

def testValidUtf8Printing(self):
self.assertIn('data: "€"', str(unittest_pb2.OneString(data='€'))) # 2 byte
self.assertIn('data: "£"', str(unittest_pb2.OneString(data='£'))) # 3 byte
self.assertIn('data: "🙂"', str(unittest_pb2.OneString(data='🙂'))) # 4 byte


@testing_refleaks.TestCase
class PackedFieldTest(unittest.TestCase):

Expand Down
23 changes: 18 additions & 5 deletions python/google/protobuf/pyext/message.cc
Expand Up @@ -30,10 +30,16 @@
#endif
#include "google/protobuf/stubs/common.h"
#include "google/protobuf/descriptor.pb.h"
#include "absl/strings/escaping.h"
#include "absl/strings/string_view.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/strtod.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "google/protobuf/unknown_field_set.h"
#include "google/protobuf/util/message_differencer.h"
#include "google/protobuf/pyext/descriptor.h"
#include "google/protobuf/pyext/descriptor_pool.h"
#include "google/protobuf/pyext/extension_dict.h"
Expand All @@ -46,11 +52,6 @@
#include "google/protobuf/pyext/scoped_pyobject_ptr.h"
#include "google/protobuf/pyext/unknown_field_set.h"
#include "google/protobuf/pyext/unknown_fields.h"
#include "google/protobuf/util/message_differencer.h"
#include "absl/strings/string_view.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/strtod.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"

// clang-format off
#include "google/protobuf/port_def.inc"
Expand Down Expand Up @@ -1683,6 +1684,18 @@ class PythonFieldValuePrinter : public TextFormat::FastFieldValuePrinter {

generator->PrintString(PyString_AsString(py_str.get()));
}
void PrintString(const std::string& val,
TextFormat::BaseTextGenerator* generator) const override {
TextFormat::Printer::HardenedPrintString(val, generator);
}
void PrintBytes(const std::string& val,
TextFormat::BaseTextGenerator* generator) const override {
generator->PrintLiteral("\"");
if (!val.empty()) {
generator->PrintString(absl::CEscape(val));
}
generator->PrintLiteral("\"");
}
};

static PyObject* ToStr(CMessage* self) {
Expand Down
14 changes: 7 additions & 7 deletions src/google/protobuf/text_format.cc
Expand Up @@ -1706,14 +1706,17 @@ size_t SkipPassthroughBytes(absl::string_view val) {
return val.size();
}

void HardenedPrintString(absl::string_view src,
TextFormat::BaseTextGenerator* generator) {
} // namespace

void TextFormat::Printer::HardenedPrintString(
absl::string_view src, TextFormat::BaseTextGenerator* generator) {
// Print as UTF-8, while guarding against any invalid UTF-8 in the string
// field.
//
// If in the future we have a guaranteed invariant that invalid UTF-8 will
// never be present, we could avoid the UTF-8 check here.

generator->PrintLiteral("\"");
while (!src.empty()) {
size_t n = SkipPassthroughBytes(src);
if (n != 0) {
Expand All @@ -1727,20 +1730,17 @@ void HardenedPrintString(absl::string_view src,
generator->PrintString(absl::CEscape(src.substr(0, 1)));
src.remove_prefix(1);
}
generator->PrintLiteral("\"");
}

} // namespace

// ===========================================================================
// An internal field value printer that escape UTF8 strings.
class TextFormat::Printer::FastFieldValuePrinterUtf8Escaping
: public TextFormat::Printer::DebugStringFieldValuePrinter {
public:
void PrintString(const std::string& val,
TextFormat::BaseTextGenerator* generator) const override {
generator->PrintLiteral("\"");
HardenedPrintString(val, generator);
generator->PrintLiteral("\"");
TextFormat::Printer::HardenedPrintString(val, generator);
}
void PrintBytes(const std::string& val,
TextFormat::BaseTextGenerator* generator) const override {
Expand Down
10 changes: 10 additions & 0 deletions src/google/protobuf/text_format.h
Expand Up @@ -78,6 +78,12 @@ namespace io {
class ErrorCollector; // tokenizer.h
}

namespace python {
namespace cmessage {
class PythonFieldValuePrinter;
}
} // namespace python

namespace internal {
// Enum used to set printing options for StringifyMessage.
PROTOBUF_EXPORT enum class Option;
Expand Down Expand Up @@ -526,6 +532,10 @@ class PROTOBUF_EXPORT TextFormat {
: it->second.get();
}

friend class google::protobuf::python::cmessage::PythonFieldValuePrinter;
static void HardenedPrintString(absl::string_view src,
TextFormat::BaseTextGenerator* generator);

int initial_indent_level_;
bool single_line_mode_;
bool use_field_number_;
Expand Down
2 changes: 1 addition & 1 deletion upb/text/BUILD
Expand Up @@ -18,12 +18,12 @@ cc_library(
copts = UPB_DEFAULT_COPTS,
visibility = ["//visibility:public"],
deps = [
"//third_party/utf8_range",
"//upb:base",
"//upb:eps_copy_input_stream",
"//upb:message",
"//upb:port",
"//upb:reflection",
"//upb:wire",
"//upb:wire_reader",
"//upb/lex",
"//upb/message:internal",
Expand Down
146 changes: 113 additions & 33 deletions upb/text/encode.c
Expand Up @@ -29,6 +29,7 @@
#include "upb/wire/eps_copy_input_stream.h"
#include "upb/wire/reader.h"
#include "upb/wire/types.h"
#include "utf8_range.h"

// Must be last.
#include "upb/port/def.inc"
Expand Down Expand Up @@ -108,42 +109,121 @@ static void txtenc_enum(int32_t val, const upb_FieldDef* f, txtenc* e) {
}
}

static void txtenc_string(txtenc* e, upb_StringView str, bool bytes) {
const char* ptr = str.data;
const char* end = ptr + str.size;
txtenc_putstr(e, "\"");
static void txtenc_escaped(txtenc* e, unsigned char ch) {
switch (ch) {
case '\n':
txtenc_putstr(e, "\\n");
break;
case '\r':
txtenc_putstr(e, "\\r");
break;
case '\t':
txtenc_putstr(e, "\\t");
break;
case '\"':
txtenc_putstr(e, "\\\"");
break;
case '\'':
txtenc_putstr(e, "\\'");
break;
case '\\':
txtenc_putstr(e, "\\\\");
break;
default:
txtenc_printf(e, "\\%03o", ch);
break;
}
}

// Returns true if `ch` needs to be escaped in TextFormat, independent of any
// UTF-8 validity issues.
static bool upb_DefinitelyNeedsEscape(unsigned char ch) {
if (ch < 32) return true;
switch (ch) {
case '\"':
case '\'':
case '\\':
case 127:
return true;
}
return false;
}

static bool upb_AsciiIsPrint(unsigned char ch) { return ch >= 32 && ch < 127; }

// Returns true if this is a high byte that requires UTF-8 validation. If the
// UTF-8 validation fails, we must escape the byte.
static bool upb_NeedsUtf8Validation(unsigned char ch) { return ch > 127; }

// Returns the number of bytes in the prefix of `val` that do not need escaping.
// This is like utf8_range::SpanStructurallyValid(), except that it also
// terminates at any ASCII char that needs to be escaped in TextFormat (any char
// that has `DefinitelyNeedsEscape(ch) == true`).
//
// If we could get a variant of utf8_range::SpanStructurallyValid() that could
// terminate on any of these chars, that might be more efficient, but it would
// be much more complicated to modify that heavily SIMD code.
static size_t SkipPassthroughBytes(const char* ptr, size_t size) {
for (size_t i = 0; i < size; i++) {
unsigned char uc = ptr[i];
if (upb_DefinitelyNeedsEscape(uc)) return i;
if (upb_NeedsUtf8Validation(uc)) {
// Find the end of this region of consecutive high bytes, so that we only
// give high bytes to the UTF-8 checker. This avoids needing to perform
// a second scan of the ASCII characters looking for characters that
// need escaping.
//
// We assume that high bytes are less frequent than plain, printable ASCII
// bytes, so we accept the double-scan of high bytes.
size_t end = i + 1;
for (; end < size; end++) {
if (!upb_NeedsUtf8Validation(ptr[end])) break;
}
size_t n = end - i;
size_t ok = utf8_range_ValidPrefix(ptr + i, n);
if (ok != n) return i + ok;
i += ok - 1;
}
}
return size;
}

static void upb_HardenedPrintString(txtenc* e, const char* ptr, size_t len) {
// Print as UTF-8, while guarding against any invalid UTF-8 in the string
// field.
//
// If in the future we have a guaranteed invariant that invalid UTF-8 will
// never be present, we could avoid the UTF-8 check here.
txtenc_putstr(e, "\"");
const char* end = ptr + len;
while (ptr < end) {
switch (*ptr) {
case '\n':
txtenc_putstr(e, "\\n");
break;
case '\r':
txtenc_putstr(e, "\\r");
break;
case '\t':
txtenc_putstr(e, "\\t");
break;
case '\"':
txtenc_putstr(e, "\\\"");
break;
case '\'':
txtenc_putstr(e, "\\'");
break;
case '\\':
txtenc_putstr(e, "\\\\");
break;
default:
if ((bytes || (uint8_t)*ptr < 0x80) && !isprint(*ptr)) {
txtenc_printf(e, "\\%03o", (int)(uint8_t)*ptr);
} else {
txtenc_putbytes(e, ptr, 1);
}
break;
size_t n = SkipPassthroughBytes(ptr, end - ptr);
if (n != 0) {
txtenc_putbytes(e, ptr, n);
ptr += n;
if (ptr == end) break;
}

// If repeated calls to CEscape() and PrintString() are expensive, we could
// consider batching them, at the cost of some complexity.
txtenc_escaped(e, *ptr);
ptr++;
}
txtenc_putstr(e, "\"");
}

static void txtenc_bytes(txtenc* e, upb_StringView data) {
const char* ptr = data.data;
const char* end = ptr + data.size;
txtenc_putstr(e, "\"");
for (; ptr < end; ptr++) {
unsigned char uc = *ptr;
if (upb_AsciiIsPrint(uc)) {
txtenc_putbytes(e, ptr, 1);
} else {
txtenc_escaped(e, uc);
}
}
txtenc_putstr(e, "\"");
}

Expand Down Expand Up @@ -206,10 +286,10 @@ static void txtenc_field(txtenc* e, upb_MessageValue val,
txtenc_printf(e, "%" PRIu64, val.uint64_val);
break;
case kUpb_CType_String:
txtenc_string(e, val.str_val, false);
upb_HardenedPrintString(e, val.str_val.data, val.str_val.size);
break;
case kUpb_CType_Bytes:
txtenc_string(e, val.str_val, true);
txtenc_bytes(e, val.str_val);
break;
case kUpb_CType_Enum:
txtenc_enum(val.int32_val, f, e);
Expand Down Expand Up @@ -378,7 +458,7 @@ static const char* txtenc_unknown(txtenc* e, const char* ptr,
const char* str = ptr;
ptr = upb_EpsCopyInputStream_ReadString(stream, &str, size, NULL);
UPB_ASSERT(ptr);
txtenc_string(e, (upb_StringView){.data = str, .size = size}, true);
txtenc_bytes(e, (upb_StringView){.data = str, .size = size});
}
break;
}
Expand Down

0 comments on commit f2a91b3

Please sign in to comment.