diff --git a/mypyc/lib-rt/librt_internal.c b/mypyc/lib-rt/librt_internal.c index 4f6e138c96f9..acbacba61ec9 100644 --- a/mypyc/lib-rt/librt_internal.c +++ b/mypyc/lib-rt/librt_internal.c @@ -584,14 +584,35 @@ read_int_internal(PyObject *data) { if (likely(first != LONG_INT_TRAILER)) { return _read_short_int(data, first); } - PyObject *str_ret = read_str_internal(data); - if (unlikely(str_ret == NULL)) + + // Long integer encoding -- byte length and sign, followed by a byte array. + + // Read byte length and sign. + _CHECK_READ(data, 1, CPY_INT_TAG) + first = _READ(data, uint8_t) + Py_ssize_t size_and_sign = _read_short_int(data, first); + if (size_and_sign == CPY_INT_TAG) return CPY_INT_TAG; - PyObject* ret_long = PyLong_FromUnicodeObject(str_ret, 10); - Py_DECREF(str_ret); - if (ret_long == NULL) + bool sign = (size_and_sign >> 1) & 1; + Py_ssize_t size = size_and_sign >> 2; + + // Construct an int object from the byte array. + _CHECK_READ(data, size, CPY_INT_TAG) + char *buf = ((BufferObject *)data)->buf; + PyObject *num = _PyLong_FromByteArray( + (unsigned char *)(buf + ((BufferObject *)data)->pos), size, 1, 0); + if (num == NULL) return CPY_INT_TAG; - return CPyTagged_StealFromObject(ret_long); + ((BufferObject *)data)->pos += size; + if (sign) { + PyObject *old = num; + num = PyNumber_Negative(old); + Py_DECREF(old); + if (num == NULL) { + return CPY_INT_TAG; + } + } + return CPyTagged_StealFromObject(num); } static PyObject* @@ -609,22 +630,81 @@ read_int(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnames) return CPyTagged_StealAsObject(retval); } + +static inline int hex_to_int(char c) { + if (c >= '0' && c <= '9') + return c - '0'; + else if (c >= 'a' && c <= 'f') + return c - 'a' + 10; + else + return c - 'A' + 10; // Assume valid hex digit +} + static inline char _write_long_int(PyObject *data, CPyTagged value) { - // TODO(jukka): write a more compact/optimal format for arbitrary length ints. _CHECK_SIZE(data, 1) _WRITE(data, uint8_t, LONG_INT_TRAILER) ((BufferObject *)data)->end += 1; + + PyObject *hex_str = NULL; PyObject* int_value = CPyTagged_AsObject(value); if (unlikely(int_value == NULL)) - return CPY_NONE_ERROR; - PyObject *str_value = PyObject_Str(int_value); + goto error; + + hex_str = PyNumber_ToBase(int_value, 16); + if (hex_str == NULL) + goto error; Py_DECREF(int_value); - if (unlikely(str_value == NULL)) - return CPY_NONE_ERROR; - char res = write_str_internal(data, str_value); - Py_DECREF(str_value); - return res; + int_value = NULL; + + const char *str = PyUnicode_AsUTF8(hex_str); + if (str == NULL) + goto error; + Py_ssize_t len = strlen(str); + bool neg; + if (str[0] == '-') { + str++; + len--; + neg = true; + } else { + neg = false; + } + // Skip the 0x hex prefix. + str += 2; + len -= 2; + + // Write bytes encoded length and sign. + Py_ssize_t size = (len + 1) / 2; + Py_ssize_t encoded_size = (size << 1) | neg; + if (encoded_size <= MAX_FOUR_BYTES_INT) { + if (_write_short_int(data, encoded_size) == CPY_NONE_ERROR) + goto error; + } else { + PyErr_SetString(PyExc_ValueError, "int too long to serialize"); + goto error; + } + + // Write absolute integer value as byte array in a variable-length little endian format. + int i; + for (i = len; i > 1; i -= 2) { + if (write_tag_internal( + data, hex_to_int(str[i - 1]) | (hex_to_int(str[i - 2]) << 4)) == CPY_NONE_ERROR) + goto error; + } + // The final byte may correspond to only one hex digit. + if (i == 1) { + if (write_tag_internal(data, hex_to_int(str[i - 1])) == CPY_NONE_ERROR) + goto error; + } + + Py_DECREF(hex_str); + return CPY_NONE; + + error: + + Py_XDECREF(int_value); + Py_XDECREF(hex_str); + return CPY_NONE_ERROR; } static char diff --git a/mypyc/test-data/run-classes.test b/mypyc/test-data/run-classes.test index 04bbed78b318..09655cf35d6b 100644 --- a/mypyc/test-data/run-classes.test +++ b/mypyc/test-data/run-classes.test @@ -2804,13 +2804,36 @@ def test_buffer_int_size() -> None: def test_buffer_int_powers() -> None: # 0, 1, 2 are tested above - for p in range(2, 9): + for p in range(2, 200): b = Buffer() write_int(b, 1 << p) + write_int(b, (1 << p) - 1) write_int(b, -1 << p) + write_int(b, (-1 << p) + 1) b = Buffer(b.getvalue()) assert read_int(b) == 1 << p + assert read_int(b) == (1 << p) - 1 assert read_int(b) == -1 << p + assert read_int(b) == (-1 << p) + 1 + +def test_positive_long_int_serialized_bytes() -> None: + b = Buffer() + n = 0x123456789ab + write_int(b, n) + x = b.getvalue() + # Two prefix bytes, followed by little endian encoded integer in variable-length format + assert x == b"\x0f\x2c\xab\x89\x67\x45\x23\x01" + b = Buffer(x) + assert read_int(b) == n + +def test_negative_long_int_serialized_bytes() -> None: + b = Buffer() + n = -0x123456789abcde + write_int(b, n) + x = b.getvalue() + assert x == b"\x0f\x32\xde\xbc\x9a\x78\x56\x34\x12" + b = Buffer(x) + assert read_int(b) == n def test_buffer_str_size() -> None: for s in ("", "a", "a" * 117): @@ -2835,6 +2858,8 @@ test_buffer_roundtrip() test_buffer_int_size() test_buffer_str_size() test_buffer_int_powers() +test_positive_long_int_serialized_bytes() +test_negative_long_int_serialized_bytes() def test_buffer_basic_interpreted() -> None: b = Buffer(b"foo")