Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 94 additions & 14 deletions mypyc/lib-rt/librt_internal.c
Original file line number Diff line number Diff line change
Expand Up @@ -584,14 +584,35 @@ read_int_internal(PyObject *data) {
if (likely(first != LONG_INT_TRAILER)) {
return _read_short_int(data, first);
}
PyObject *str_ret = read_str_internal(data);
if (unlikely(str_ret == NULL))

// Long integer encoding -- byte length and sign, followed by a byte array.

// Read byte length and sign.
_CHECK_READ(data, 1, CPY_INT_TAG)
first = _READ(data, uint8_t)
Py_ssize_t size_and_sign = _read_short_int(data, first);
if (size_and_sign == CPY_INT_TAG)
return CPY_INT_TAG;
PyObject* ret_long = PyLong_FromUnicodeObject(str_ret, 10);
Py_DECREF(str_ret);
if (ret_long == NULL)
bool sign = (size_and_sign >> 1) & 1;
Py_ssize_t size = size_and_sign >> 2;

// Construct an int object from the byte array.
_CHECK_READ(data, size, CPY_INT_TAG)
char *buf = ((BufferObject *)data)->buf;
PyObject *num = _PyLong_FromByteArray(
(unsigned char *)(buf + ((BufferObject *)data)->pos), size, 1, 0);
if (num == NULL)
return CPY_INT_TAG;
return CPyTagged_StealFromObject(ret_long);
((BufferObject *)data)->pos += size;
if (sign) {
PyObject *old = num;
num = PyNumber_Negative(old);
Py_DECREF(old);
if (num == NULL) {
return CPY_INT_TAG;
}
}
return CPyTagged_StealFromObject(num);
}

static PyObject*
Expand All @@ -609,22 +630,81 @@ read_int(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnames)
return CPyTagged_StealAsObject(retval);
}


// Convert one hexadecimal digit character to its numeric value.
// Handles '0'-'9', 'a'-'f' and 'A'-'F'. The input is not validated: any
// character outside the first two ranges is treated as an uppercase digit,
// so callers must only pass valid hex digits.
static inline int hex_to_int(char c) {
    if (c >= 'a' && c <= 'f') {
        return 10 + (c - 'a');
    }
    if (c >= '0' && c <= '9') {
        return c - '0';
    }
    // Assume valid hex digit: whatever remains is taken to be 'A'..'F'.
    return 10 + (c - 'A');
}

// Serialize an integer that does not fit the short int encodings.
//
// Layout written to the buffer:
//   1. the LONG_INT_TRAILER marker byte,
//   2. a short int holding (byte_count << 1) | is_negative,
//   3. the absolute value as byte_count bytes, least significant byte first.
//
// Returns CPY_NONE on success and CPY_NONE_ERROR on failure. A ValueError is
// raised here when the encoded length exceeds MAX_FOUR_BYTES_INT. All failures
// after the marker byte go through the single cleanup path at `error`, so the
// temporary Python objects are always released.
static inline char
_write_long_int(PyObject *data, CPyTagged value) {
    _CHECK_SIZE(data, 1)
    _WRITE(data, uint8_t, LONG_INT_TRAILER)
    ((BufferObject *)data)->end += 1;

    PyObject *hex_str = NULL;
    PyObject *int_value = CPyTagged_AsObject(value);
    if (unlikely(int_value == NULL))
        goto error;

    // Render the value as hex text ("0x..." or "-0x..."); two hex digits map to one byte.
    hex_str = PyNumber_ToBase(int_value, 16);
    if (hex_str == NULL)
        goto error;
    Py_DECREF(int_value);
    // Cleared so that the error path does not release it a second time.
    int_value = NULL;

    const char *str = PyUnicode_AsUTF8(hex_str);
    if (str == NULL)
        goto error;
    Py_ssize_t len = (Py_ssize_t)strlen(str);
    bool neg;
    if (str[0] == '-') {
        str++;
        len--;
        neg = true;
    } else {
        neg = false;
    }
    // Skip the 0x hex prefix.
    str += 2;
    len -= 2;

    // Write bytes encoded length and sign.
    Py_ssize_t size = (len + 1) / 2;
    Py_ssize_t encoded_size = (size << 1) | neg;
    if (encoded_size <= MAX_FOUR_BYTES_INT) {
        if (_write_short_int(data, encoded_size) == CPY_NONE_ERROR)
            goto error;
    } else {
        PyErr_SetString(PyExc_ValueError, "int too long to serialize");
        goto error;
    }

    // Write absolute integer value as byte array in a variable-length little endian format.
    // The hex text is consumed from its end (least significant digits) two digits at a time.
    // The index has the same type as len to avoid an implicit narrowing conversion.
    Py_ssize_t i;
    for (i = len; i > 1; i -= 2) {
        if (write_tag_internal(
                data, hex_to_int(str[i - 1]) | (hex_to_int(str[i - 2]) << 4)) == CPY_NONE_ERROR)
            goto error;
    }
    // The final byte may correspond to only one hex digit.
    if (i == 1) {
        if (write_tag_internal(data, hex_to_int(str[i - 1])) == CPY_NONE_ERROR)
            goto error;
    }

    Py_DECREF(hex_str);
    return CPY_NONE;

error:

    Py_XDECREF(int_value);
    Py_XDECREF(hex_str);
    return CPY_NONE_ERROR;
}

static char
Expand Down
27 changes: 26 additions & 1 deletion mypyc/test-data/run-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -2804,13 +2804,36 @@ def test_buffer_int_size() -> None:

def test_buffer_int_powers() -> None:
    """Round-trip values around +/- 2**p through write_int/read_int.

    For every exponent p in range(2, 200), the four values 2**p, 2**p - 1,
    -(2**p) and -(2**p) + 1 are written to one buffer and must be read back
    in the same order with the same values.
    """
    # 0, 1, 2 are tested above
    for p in range(2, 200):
        b = Buffer()
        write_int(b, 1 << p)
        write_int(b, (1 << p) - 1)
        write_int(b, -1 << p)
        write_int(b, (-1 << p) + 1)
        # Re-open the serialized bytes so reading starts from the beginning.
        b = Buffer(b.getvalue())
        assert read_int(b) == 1 << p
        assert read_int(b) == (1 << p) - 1
        assert read_int(b) == -1 << p
        assert read_int(b) == (-1 << p) + 1

def test_positive_long_int_serialized_bytes() -> None:
    """Pin the exact serialized bytes of a positive long int and round-trip it."""
    value = 0x123456789ab
    out = Buffer()
    write_int(out, value)
    encoded = out.getvalue()
    # Two prefix bytes, followed by little endian encoded integer in variable-length format
    assert encoded == b"\x0f\x2c\xab\x89\x67\x45\x23\x01"
    source = Buffer(encoded)
    assert read_int(source) == value

def test_negative_long_int_serialized_bytes() -> None:
    """Pin the exact serialized bytes of a negative long int and round-trip it."""
    value = -0x123456789abcde
    out = Buffer()
    write_int(out, value)
    encoded = out.getvalue()
    # Same layout as the positive case; only the second prefix byte differs.
    assert encoded == b"\x0f\x32\xde\xbc\x9a\x78\x56\x34\x12"
    source = Buffer(encoded)
    assert read_int(source) == value

def test_buffer_str_size() -> None:
for s in ("", "a", "a" * 117):
Expand All @@ -2835,6 +2858,8 @@ test_buffer_roundtrip()
test_buffer_int_size()
test_buffer_str_size()
test_buffer_int_powers()
test_positive_long_int_serialized_bytes()
test_negative_long_int_serialized_bytes()

def test_buffer_basic_interpreted() -> None:
b = Buffer(b"foo")
Expand Down
Loading