diff --git a/mypyc/lib-rt/librt_internal.c b/mypyc/lib-rt/librt_internal.c index 8599017c31a8..4f6e138c96f9 100644 --- a/mypyc/lib-rt/librt_internal.c +++ b/mypyc/lib-rt/librt_internal.c @@ -8,15 +8,23 @@ #include "librt_internal.h" #define START_SIZE 512 -#define MAX_SHORT_INT_TAGGED (255 << 1) -#define MAX_SHORT_LEN 127 -#define LONG_STR_TAG 1 +// See comment in read_int_internal() on motivation for these values. +#define MIN_ONE_BYTE_INT -10 +#define MAX_ONE_BYTE_INT 117 // 2 ** 7 - 1 - 10 +#define MIN_TWO_BYTES_INT -100 +#define MAX_TWO_BYTES_INT 16283 // 2 ** (8 + 6) - 1 - 100 +#define MIN_FOUR_BYTES_INT -10000 +#define MAX_FOUR_BYTES_INT 536860911 // 2 ** (3 * 8 + 5) - 1 - 10000 -#define MIN_SHORT_INT -10 -#define MAX_SHORT_INT 117 -#define MEDIUM_INT_TAG 1 -#define LONG_INT_TAG 3 +#define TWO_BYTES_INT_BIT 1 +#define FOUR_BYTES_INT_BIT 2 +#define LONG_INT_BIT 4 + +#define FOUR_BYTES_INT_TRAILER 3 +// We add one reserved bit here so that we can potentially support +// 8 bytes format in the future. +#define LONG_INT_TRAILER 15 #define CPY_BOOL_ERROR 2 #define CPY_NONE_ERROR 2 @@ -35,13 +43,22 @@ #define _WRITE(data, type, v) *(type *)(((BufferObject *)data)->buf + ((BufferObject *)data)->pos) = v; \ ((BufferObject *)data)->pos += sizeof(type); +#if PY_BIG_ENDIAN +uint16_t reverse_16(uint16_t number) { + return (number << 8) | (number >> 8); +} + +uint32_t reverse_32(uint32_t number) { + return ((number & 0xFF) << 24) | ((number & 0xFF00) << 8) | ((number & 0xFF0000) >> 8) | (number >> 24); +} +#endif + typedef struct { PyObject_HEAD Py_ssize_t pos; Py_ssize_t end; Py_ssize_t size; char *buf; - PyObject *source; } BufferObject; static PyTypeObject BufferType; @@ -259,26 +276,50 @@ write_bool(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwname } /* -str format: size followed by UTF-8 bytes - short strings (len <= 127): single byte for size as `(uint8_t)size << 1` - long strings: \x01 followed by size as Py_ssize_t +str format: size as int (see below) followed by UTF-8 bytes */ +static inline CPyTagged +_read_short_int(PyObject *data, uint8_t first) { + uint8_t second; + uint16_t two_more; + if ((first & TWO_BYTES_INT_BIT) == 0) { + // Note we use tagged ints since this function can return an error. + return ((Py_ssize_t)(first >> 1) + MIN_ONE_BYTE_INT) << 1; + } + if ((first & FOUR_BYTES_INT_BIT) == 0) { + _CHECK_READ(data, 1, CPY_INT_TAG) + second = _READ(data, uint8_t) + return ((((Py_ssize_t)second) << 6) + (Py_ssize_t)(first >> 2) + MIN_TWO_BYTES_INT) << 1; + } + // The caller is responsible to verify this is called only for short ints. + _CHECK_READ(data, 3, CPY_INT_TAG) + // TODO: check if compilers emit optimal code for these two reads, and tweak if needed. + second = _READ(data, uint8_t) + two_more = _READ(data, uint16_t) +#if PY_BIG_ENDIAN + two_more = reverse_16(two_more); +#endif + Py_ssize_t higher = (((Py_ssize_t)two_more) << 13) + (((Py_ssize_t)second) << 5); + return (higher + (Py_ssize_t)(first >> 3) + MIN_FOUR_BYTES_INT) << 1; +} + static PyObject* read_str_internal(PyObject *data) { _CHECK_BUFFER(data, NULL) // Read string length. - Py_ssize_t size; _CHECK_READ(data, 1, NULL) uint8_t first = _READ(data, uint8_t) - if (likely(first != LONG_STR_TAG)) { - // Common case: short string (len <= 127). - size = (Py_ssize_t)(first >> 1); - } else { - _CHECK_READ(data, sizeof(CPyTagged), NULL) - size = _READ(data, Py_ssize_t) + if (unlikely(first == LONG_INT_TRAILER)) { + // Fail fast for invalid/tampered data. + PyErr_SetString(PyExc_ValueError, "invalid str size"); + return NULL; } + CPyTagged tagged_size = _read_short_int(data, first); + if (tagged_size == CPY_INT_TAG) + return NULL; + Py_ssize_t size = tagged_size >> 1; // Read string content. char *buf = ((BufferObject *)data)->buf; _CHECK_READ(data, size, NULL) @@ -302,6 +343,35 @@ read_str(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnames) return read_str_internal(data); } +// The caller *must* check that real_value is within allowed range (29 bits). +static inline char +_write_short_int(PyObject *data, Py_ssize_t real_value) { + if (real_value >= MIN_ONE_BYTE_INT && real_value <= MAX_ONE_BYTE_INT) { + _CHECK_SIZE(data, 1) + _WRITE(data, uint8_t, (uint8_t)(real_value - MIN_ONE_BYTE_INT) << 1) + ((BufferObject *)data)->end += 1; + } else if (real_value >= MIN_TWO_BYTES_INT && real_value <= MAX_TWO_BYTES_INT) { + _CHECK_SIZE(data, 2) +#if PY_BIG_ENDIAN + uint16_t to_write = ((uint16_t)(real_value - MIN_TWO_BYTES_INT) << 2) | TWO_BYTES_INT_BIT; + _WRITE(data, uint16_t, reverse_16(to_write)) +#else + _WRITE(data, uint16_t, ((uint16_t)(real_value - MIN_TWO_BYTES_INT) << 2) | TWO_BYTES_INT_BIT) +#endif + ((BufferObject *)data)->end += 2; + } else { + _CHECK_SIZE(data, 4) +#if PY_BIG_ENDIAN + uint32_t to_write = ((uint32_t)(real_value - MIN_FOUR_BYTES_INT) << 3) | FOUR_BYTES_INT_TRAILER; + _WRITE(data, uint32_t, reverse_32(to_write)) +#else + _WRITE(data, uint32_t, ((uint32_t)(real_value - MIN_FOUR_BYTES_INT) << 3) | FOUR_BYTES_INT_TRAILER) +#endif + ((BufferObject *)data)->end += 4; + } + return CPY_NONE; +} + static char write_str_internal(PyObject *data, PyObject *value) { _CHECK_BUFFER(data, CPY_NONE_ERROR) @@ -311,24 +381,20 @@ write_str_internal(PyObject *data, PyObject *value) { if (unlikely(chunk == NULL)) return CPY_NONE_ERROR; - Py_ssize_t need; // Write string length. - if (likely(size <= MAX_SHORT_LEN)) { - // Common case: short string (len <= 127) store as single byte. - need = size + 1; - _CHECK_SIZE(data, need) - _WRITE(data, uint8_t, (uint8_t)size << 1) + if (likely(size >= MIN_FOUR_BYTES_INT && size <= MAX_FOUR_BYTES_INT)) { + if (_write_short_int(data, size) == CPY_NONE_ERROR) + return CPY_NONE_ERROR; } else { - need = size + sizeof(Py_ssize_t) + 1; - _CHECK_SIZE(data, need) - _WRITE(data, uint8_t, LONG_STR_TAG) - _WRITE(data, Py_ssize_t, size) + PyErr_SetString(PyExc_ValueError, "str too long to serialize"); + return CPY_NONE_ERROR; } // Write string content. + _CHECK_SIZE(data, size) char *buf = ((BufferObject *)data)->buf; memcpy(buf + ((BufferObject *)data)->pos, chunk, size); ((BufferObject *)data)->pos += size; - ((BufferObject *)data)->end += need; + ((BufferObject *)data)->end += size; return CPY_NONE; } @@ -353,9 +419,7 @@ write_str(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnames } /* -bytes format: size followed by bytes - short bytes (len <= 127): single byte for size as `(uint8_t)size << 1` - long bytes: \x01 followed by size as Py_ssize_t +bytes format: size as int (see below) followed by bytes */ static PyObject* @@ -363,16 +427,17 @@ read_bytes_internal(PyObject *data) { _CHECK_BUFFER(data, NULL) // Read length. - Py_ssize_t size; _CHECK_READ(data, 1, NULL) uint8_t first = _READ(data, uint8_t) - if (likely(first != LONG_STR_TAG)) { - // Common case: short bytes (len <= 127). - size = (Py_ssize_t)(first >> 1); - } else { - _CHECK_READ(data, sizeof(CPyTagged), NULL) - size = _READ(data, Py_ssize_t) + if (unlikely(first == LONG_INT_TRAILER)) { + // Fail fast for invalid/tampered data. + PyErr_SetString(PyExc_ValueError, "invalid bytes size"); + return NULL; } + CPyTagged tagged_size = _read_short_int(data, first); + if (tagged_size == CPY_INT_TAG) + return NULL; + Py_ssize_t size = tagged_size >> 1; // Read bytes content. char *buf = ((BufferObject *)data)->buf; _CHECK_READ(data, size, NULL) @@ -405,24 +470,20 @@ write_bytes_internal(PyObject *data, PyObject *value) { return CPY_NONE_ERROR; Py_ssize_t size = PyBytes_GET_SIZE(value); - Py_ssize_t need; // Write length. - if (likely(size <= MAX_SHORT_LEN)) { - // Common case: short bytes (len <= 127) store as single byte. - need = size + 1; - _CHECK_SIZE(data, need) - _WRITE(data, uint8_t, (uint8_t)size << 1) + if (likely(size >= MIN_FOUR_BYTES_INT && size <= MAX_FOUR_BYTES_INT)) { + if (_write_short_int(data, size) == CPY_NONE_ERROR) + return CPY_NONE_ERROR; } else { - need = size + sizeof(Py_ssize_t) + 1; - _CHECK_SIZE(data, need) - _WRITE(data, uint8_t, LONG_STR_TAG) - _WRITE(data, Py_ssize_t, size) + PyErr_SetString(PyExc_ValueError, "bytes too long to serialize"); + return CPY_NONE_ERROR; } // Write bytes content. + _CHECK_SIZE(data, size) char *buf = ((BufferObject *)data)->buf; memcpy(buf + ((BufferObject *)data)->pos, chunk, size); ((BufferObject *)data)->pos += size; - ((BufferObject *)data)->end += need; + ((BufferObject *)data)->end += size; return CPY_NONE; } @@ -455,7 +516,7 @@ static double read_float_internal(PyObject *data) { _CHECK_BUFFER(data, CPY_FLOAT_ERROR) _CHECK_READ(data, sizeof(double), CPY_FLOAT_ERROR) - double res = _READ(data, double); + double res = _READ(data, double) return res; } @@ -505,9 +566,13 @@ write_float(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnam /* int format: - most common values (-10 <= value <= 117): single byte as `(uint8_t)(value + 10) << 1` - medium values (fit in CPyTagged): \x01 followed by CPyTagged value - long values (very rare): \x03 followed by decimal string (see str format) + one byte: last bit 0, 7 bits used + two bytes: last two bits 01, 14 bits used + four bytes: last three bits 011, 29 bits used + everything else: 00001111 followed by serialized string representation + +Note: for fixed size formats we skew ranges towards more positive values, +since negative integers are much more rare. */ static CPyTagged @@ -516,22 +581,17 @@ read_int_internal(PyObject *data) { _CHECK_READ(data, 1, CPY_INT_TAG) uint8_t first = _READ(data, uint8_t) - if ((first & MEDIUM_INT_TAG) == 0) { - // Most common case: int that is small in absolute value. - return ((Py_ssize_t)(first >> 1) + MIN_SHORT_INT) << 1; - } - if (first == MEDIUM_INT_TAG) { - _CHECK_READ(data, sizeof(CPyTagged), CPY_INT_TAG) - CPyTagged ret = _READ(data, CPyTagged) - return ret; + if (likely(first != LONG_INT_TRAILER)) { + return _read_short_int(data, first); } - // People who have literal ints not fitting in size_t should be punished :-) PyObject *str_ret = read_str_internal(data); if (unlikely(str_ret == NULL)) return CPY_INT_TAG; PyObject* ret_long = PyLong_FromUnicodeObject(str_ret, 10); Py_DECREF(str_ret); - return ((CPyTagged)ret_long) | CPY_INT_TAG; + if (ret_long == NULL) + return CPY_INT_TAG; + return CPyTagged_StealFromObject(ret_long); } static PyObject* @@ -549,36 +609,38 @@ read_int(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnames) return CPyTagged_StealAsObject(retval); } +static inline char +_write_long_int(PyObject *data, CPyTagged value) { + // TODO(jukka): write a more compact/optimal format for arbitrary length ints. + _CHECK_SIZE(data, 1) + _WRITE(data, uint8_t, LONG_INT_TRAILER) + ((BufferObject *)data)->end += 1; + PyObject* int_value = CPyTagged_AsObject(value); + if (unlikely(int_value == NULL)) + return CPY_NONE_ERROR; + PyObject *str_value = PyObject_Str(int_value); + Py_DECREF(int_value); + if (unlikely(str_value == NULL)) + return CPY_NONE_ERROR; + char res = write_str_internal(data, str_value); + Py_DECREF(str_value); + return res; +} + static char write_int_internal(PyObject *data, CPyTagged value) { _CHECK_BUFFER(data, CPY_NONE_ERROR) if (likely((value & CPY_INT_TAG) == 0)) { Py_ssize_t real_value = CPyTagged_ShortAsSsize_t(value); - if (real_value >= MIN_SHORT_INT && real_value <= MAX_SHORT_INT) { - // Most common case: int that is small in absolute value. - _CHECK_SIZE(data, 1) - _WRITE(data, uint8_t, (uint8_t)(real_value - MIN_SHORT_INT) << 1) - ((BufferObject *)data)->end += 1; + if (likely(real_value >= MIN_FOUR_BYTES_INT && real_value <= MAX_FOUR_BYTES_INT)) { + return _write_short_int(data, real_value); } else { - _CHECK_SIZE(data, sizeof(CPyTagged) + 1) - _WRITE(data, uint8_t, MEDIUM_INT_TAG) - _WRITE(data, CPyTagged, value) - ((BufferObject *)data)->end += sizeof(CPyTagged) + 1; + return _write_long_int(data, value); } } else { - _CHECK_SIZE(data, 1) - _WRITE(data, uint8_t, LONG_INT_TAG) - ((BufferObject *)data)->end += 1; - PyObject *str_value = PyObject_Str(CPyTagged_LongAsObject(value)); - if (unlikely(str_value == NULL)) - return CPY_NONE_ERROR; - char res = write_str_internal(data, str_value); - Py_DECREF(str_value); - if (unlikely(res == CPY_NONE_ERROR)) - return CPY_NONE_ERROR; + return _write_long_int(data, value); } - return CPY_NONE; } static PyObject* diff --git a/mypyc/test-data/run-classes.test b/mypyc/test-data/run-classes.test index e08f2fd7007d..04bbed78b318 100644 --- a/mypyc/test-data/run-classes.test +++ b/mypyc/test-data/run-classes.test @@ -2738,8 +2738,8 @@ def test_buffer_roundtrip() -> None: write_bytes(b, b"bar") write_bytes(b, b"bar" * 100) write_bytes(b, b"") - write_bytes(b, b"a" * 127) - write_bytes(b, b"a" * 128) + write_bytes(b, b"a" * 117) + write_bytes(b, b"a" * 118) write_float(b, 0.1) write_float(b, -113.0) write_int(b, 0) @@ -2752,8 +2752,9 @@ def test_buffer_roundtrip() -> None: write_int(b, 255) write_int(b, -1) write_int(b, -255) - write_int(b, 1234512344) - write_int(b, 1234512345) + write_int(b, 536860911) + write_int(b, 536860912) + write_int(b, 1234567891) b = Buffer(b.getvalue()) assert read_str(b) == "foo" @@ -2763,8 +2764,8 @@ def test_buffer_roundtrip() -> None: assert read_bytes(b) == b"bar" assert read_bytes(b) == b"bar" * 100 assert read_bytes(b) == b"" - assert read_bytes(b) == b"a" * 127 - assert read_bytes(b) == b"a" * 128 + assert read_bytes(b) == b"a" * 117 + assert read_bytes(b) == b"a" * 118 assert read_float(b) == 0.1 assert read_float(b) == -113.0 assert read_int(b) == 0 @@ -2777,8 +2778,9 @@ def test_buffer_roundtrip() -> None: assert read_int(b) == 255 assert read_int(b) == -1 assert read_int(b) == -255 - assert read_int(b) == 1234512344 - assert read_int(b) == 1234512345 + assert read_int(b) == 536860911 + assert read_int(b) == 536860912 + assert read_int(b) == 1234567891 def test_buffer_int_size() -> None: for i in (-10, -9, 0, 116, 117): @@ -2787,21 +2789,44 @@ def test_buffer_int_size() -> None: assert len(b.getvalue()) == 1 b = Buffer(b.getvalue()) assert read_int(b) == i - for i in (-12345, -12344, -11, 118, 12344, 12345): + for i in (-100, -11, 118, 12344, 16283): b = Buffer() write_int(b, i) - assert len(b.getvalue()) <= 9 # sizeof(size_t) + 1 + assert len(b.getvalue()) == 2 b = Buffer(b.getvalue()) assert read_int(b) == i + for i in (-10000, 16284, 123456789): + b = Buffer() + write_int(b, i) + assert len(b.getvalue()) == 4 + b = Buffer(b.getvalue()) + assert read_int(b) == i + +def test_buffer_int_powers() -> None: + # 0, 1, 2 are tested above + for p in range(2, 9): + b = Buffer() + write_int(b, 1 << p) + write_int(b, -1 << p) + b = Buffer(b.getvalue()) + assert read_int(b) == 1 << p + assert read_int(b) == -1 << p def test_buffer_str_size() -> None: - for s in ("", "a", "a" * 127): + for s in ("", "a", "a" * 117): b = Buffer() write_str(b, s) assert len(b.getvalue()) == len(s) + 1 b = Buffer(b.getvalue()) assert read_str(b) == s + for s in ("a" * 118, "a" * 16283): + b = Buffer() + write_str(b, s) + assert len(b.getvalue()) == len(s) + 2 + b = Buffer(b.getvalue()) + assert read_str(b) == s + [file driver.py] from native import * @@ -2809,6 +2834,7 @@ test_buffer_basic() test_buffer_roundtrip() test_buffer_int_size() test_buffer_str_size() +test_buffer_int_powers() def test_buffer_basic_interpreted() -> None: b = Buffer(b"foo") @@ -2823,8 +2849,8 @@ def test_buffer_roundtrip_interpreted() -> None: write_bytes(b, b"bar") write_bytes(b, b"bar" * 100) write_bytes(b, b"") - write_bytes(b, b"a" * 127) - write_bytes(b, b"a" * 128) + write_bytes(b, b"a" * 117) + write_bytes(b, b"a" * 118) write_float(b, 0.1) write_int(b, 0) write_int(b, 1) @@ -2836,8 +2862,9 @@ def test_buffer_roundtrip_interpreted() -> None: write_int(b, 255) write_int(b, -1) write_int(b, -255) - write_int(b, 1234512344) - write_int(b, 1234512345) + write_int(b, 536860911) + write_int(b, 536860912) + write_int(b, 1234567891) b = Buffer(b.getvalue()) assert read_str(b) == "foo" @@ -2847,8 +2874,8 @@ def test_buffer_roundtrip_interpreted() -> None: assert read_bytes(b) == b"bar" assert read_bytes(b) == b"bar" * 100 assert read_bytes(b) == b"" - assert read_bytes(b) == b"a" * 127 - assert read_bytes(b) == b"a" * 128 + assert read_bytes(b) == b"a" * 117 + assert read_bytes(b) == b"a" * 118 assert read_float(b) == 0.1 assert read_int(b) == 0 assert read_int(b) == 1 @@ -2860,8 +2887,9 @@ def test_buffer_roundtrip_interpreted() -> None: assert read_int(b) == 255 assert read_int(b) == -1 assert read_int(b) == -255 - assert read_int(b) == 1234512344 - assert read_int(b) == 1234512345 + assert read_int(b) == 536860911 + assert read_int(b) == 536860912 + assert read_int(b) == 1234567891 def test_buffer_int_size_interpreted() -> None: for i in (-10, -9, 0, 116, 117): @@ -2870,25 +2898,49 @@ def test_buffer_int_size_interpreted() -> None: assert len(b.getvalue()) == 1 b = Buffer(b.getvalue()) assert read_int(b) == i - for i in (-12345, -12344, -11, 118, 12344, 12345): + for i in (-100, -11, 118, 12344, 16283): b = Buffer() write_int(b, i) - assert len(b.getvalue()) <= 9 # sizeof(size_t) + 1 + assert len(b.getvalue()) == 2 b = Buffer(b.getvalue()) assert read_int(b) == i + for i in (-10000, 16284, 123456789): + b = Buffer() + write_int(b, i) + assert len(b.getvalue()) == 4 + b = Buffer(b.getvalue()) + assert read_int(b) == i + +def test_buffer_int_powers_interpreted() -> None: + # 0, 1, 2 are tested above + for p in range(2, 9): + b = Buffer() + write_int(b, 1 << p) + write_int(b, -1 << p) + b = Buffer(b.getvalue()) + assert read_int(b) == 1 << p + assert read_int(b) == -1 << p def test_buffer_str_size_interpreted() -> None: - for s in ("", "a", "a" * 127): + for s in ("", "a", "a" * 117): b = Buffer() write_str(b, s) assert len(b.getvalue()) == len(s) + 1 b = Buffer(b.getvalue()) assert read_str(b) == s + for s in ("a" * 118, "a" * 16283): + b = Buffer() + write_str(b, s) + assert len(b.getvalue()) == len(s) + 2 + b = Buffer(b.getvalue()) + assert read_str(b) == s + test_buffer_basic_interpreted() test_buffer_roundtrip_interpreted() test_buffer_int_size_interpreted() test_buffer_str_size_interpreted() +test_buffer_int_powers_interpreted() [case testBufferEmpty_librt_internal] from librt.internal import Buffer, write_int, read_int