diff --git a/mypyc/lib-rt/librt_internal.c b/mypyc/lib-rt/librt_internal.c
index 6cae63cfadcb..eaf451eff22b 100644
--- a/mypyc/lib-rt/librt_internal.c
+++ b/mypyc/lib-rt/librt_internal.c
@@ -319,6 +319,11 @@ read_str_internal(PyObject *data) {
     CPyTagged tagged_size = _read_short_int(data, first);
     if (tagged_size == CPY_INT_TAG)
         return NULL;
+    if ((Py_ssize_t)tagged_size < 0) {
+        // Fail fast for invalid/tampered data.
+        PyErr_SetString(PyExc_ValueError, "invalid str size");
+        return NULL;
+    }
     Py_ssize_t size = tagged_size >> 1;
     // Read string content.
     char *buf = ((BufferObject *)data)->buf;
@@ -437,6 +442,11 @@ read_bytes_internal(PyObject *data) {
     CPyTagged tagged_size = _read_short_int(data, first);
     if (tagged_size == CPY_INT_TAG)
         return NULL;
+    if ((Py_ssize_t)tagged_size < 0) {
+        // Fail fast for invalid/tampered data.
+        PyErr_SetString(PyExc_ValueError, "invalid bytes size");
+        return NULL;
+    }
     Py_ssize_t size = tagged_size >> 1;
     // Read bytes content.
     char *buf = ((BufferObject *)data)->buf;
@@ -601,6 +611,11 @@ read_int_internal(PyObject *data) {
     Py_ssize_t size_and_sign = _read_short_int(data, first);
     if (size_and_sign == CPY_INT_TAG)
         return CPY_INT_TAG;
+    if ((Py_ssize_t)size_and_sign < 0) {
+        // Fail fast for invalid/tampered data.
+        PyErr_SetString(PyExc_ValueError, "invalid int data");
+        return CPY_INT_TAG;
+    }
     bool sign = (size_and_sign >> 1) & 1;
     Py_ssize_t size = size_and_sign >> 2;
 
diff --git a/mypyc/test-data/run-classes.test b/mypyc/test-data/run-classes.test
index b02d10446800..0805da184e1a 100644
--- a/mypyc/test-data/run-classes.test
+++ b/mypyc/test-data/run-classes.test
@@ -5359,3 +5359,75 @@ def test_deletable_attr() -> None:
     assert i.del_counter == 1
 
 test_deletable_attr()
+
+[case testBufferCorruptedData_librt_internal]
+from librt.internal import (
+    Buffer, read_bool, read_str, read_float, read_int, read_tag, read_bytes
+)
+from random import randbytes
+
+def check(data: bytes) -> None:
+    """Run every reader over data until it rejects the input.
+
+    Each reader gets a fresh Buffer and is called in a loop that only ends
+    when ValueError is raised. Any other exception propagates to the caller.
+    """
+    b = Buffer(data)
+    try:
+        while True:
+            read_bool(b)
+    except ValueError:
+        pass
+    b = Buffer(data)
+    read_tag(b)  # Always succeeds
+    try:
+        while True:
+            read_int(b)
+    except ValueError:
+        pass
+    b = Buffer(data)
+    try:
+        while True:
+            read_str(b)
+    except ValueError:
+        pass
+    b = Buffer(data)
+    try:
+        while True:
+            read_bytes(b)
+    except ValueError:
+        pass
+    b = Buffer(data)
+    try:
+        while True:
+            read_float(b)
+    except ValueError:
+        pass
+
+def test_read_corrupted_data() -> None:
+    """Feed enumerated and random byte strings to check()."""
+    # Test various deterministic byte sequences (1 to 4 bytes).
+    for a in range(256):
+        check(bytes([a]))
+    for a in range(256):
+        for b in range(256):
+            check(bytes([a, b]))
+    for a in range(32):
+        for b in range(48):
+            for c in range(48):
+                check(bytes([a, b, c]))
+    for a in range(32):
+        for b in (0, 5, 17, 34):
+            for c in (0, 5, 17, 34):
+                for d in (0, 5, 17, 34):
+                    check(bytes([a, b, c, d]))
+    # Also test some random data.
+    for _ in range(20000):
+        data = randbytes(16)
+        try:
+            check(data)
+        except BaseException as e:
+            # Print the failing input so the failure can be reproduced.
+            print("RANDOMIZED TEST FAILURE -- please open an issue with the following context:")
+            print(">>>", e, data)
+            raise