From 5113d6b0591fe2c80cace33654e9088d1330277c Mon Sep 17 00:00:00 2001 From: tompng Date: Mon, 26 Feb 2024 22:05:30 +0900 Subject: [PATCH] [ruby/prism] Faster pm_integer_parse pm_integer_string using karatsuba algorithm https://github.com/ruby/prism/commit/ae4fb6b988 --- prism/util/pm_integer.c | 406 ++++++++++++++++++++++++++++------------ 1 file changed, 288 insertions(+), 118 deletions(-) diff --git a/prism/util/pm_integer.c b/prism/util/pm_integer.c index c03b930ad3f3d7..5bcb508c1ce0d7 100644 --- a/prism/util/pm_integer.c +++ b/prism/util/pm_integer.c @@ -1,117 +1,139 @@ #include "prism/util/pm_integer.h" /** - * Create a new node for an integer in the linked list. + * Bigint with arbitary base. In practice, base is 1<<32 or 10**9. + * When base is 10**9, it acts as bigdecimal. */ -static pm_integer_word_t * -pm_integer_node_create(pm_integer_t *integer, uint32_t value) { - integer->length++; - - pm_integer_word_t *node = xmalloc(sizeof(pm_integer_word_t)); - if (node == NULL) return NULL; - - *node = (pm_integer_word_t) { .next = NULL, .value = value }; - return node; -} +typedef struct { + size_t length; + uint32_t *values; +} bigint_t; /** - * Copy one integer onto another. + * Adds two bigint_t with the given base. */ -static void -pm_integer_copy(pm_integer_t *dest, const pm_integer_t *src) { - dest->negative = src->negative; - dest->length = 0; - - dest->head.value = src->head.value; - dest->head.next = NULL; - - pm_integer_word_t *dest_current = &dest->head; - const pm_integer_word_t *src_current = src->head.next; - - while (src_current != NULL) { - dest_current->next = pm_integer_node_create(dest, src_current->value); - if (dest_current->next == NULL) return; - - dest_current = dest_current->next; - src_current = src_current->next; +static bigint_t +big_add(bigint_t left, bigint_t right, uint64_t base) { + size_t length = (left.length < right.length ? right.length : left.length); + uint32_t *values = (uint32_t*) malloc(sizeof(uint32_t) * (length + 1)); + uint64_t carry = 0; + for (size_t i = 0; i < length; i++) { + uint64_t sum = carry + (i < left.length ? left.values[i] : 0) + (i < right.length ? right.values[i] : 0); + values[i] = (uint32_t) (sum % base); + carry = sum / base; } - - dest_current->next = NULL; + if (carry > 0) { + values[length] = (uint32_t) carry; + length++; + } + return (bigint_t) { length, values }; } /** - * Add a 32-bit integer to an integer. + * Calculates `a - b - c` with the given base. + * Result is assumed to be positive value. Internal use for karatsuba_multiply. */ -static void -pm_integer_add(pm_integer_t *integer, uint32_t addend) { - uint32_t carry = addend; - pm_integer_word_t *current = &integer->head; - - while (carry > 0) { - uint64_t result = (uint64_t) current->value + carry; - carry = (uint32_t) (result >> 32); - current->value = (uint32_t) result; - - if (carry > 0) { - if (current->next == NULL) { - current->next = pm_integer_node_create(integer, carry); - break; - } - - current = current->next; +static bigint_t +big_sub2(bigint_t a, bigint_t b, bigint_t c, uint64_t base) { + size_t length = a.length; + uint32_t *values = (uint32_t*) malloc(sizeof(uint32_t) * length); + int64_t carry = 0; + for (size_t i = 0; i < length; i++) { + int64_t sub = carry + a.values[i] - (i < b.length ? b.values[i] : 0) - (i < c.length ? c.values[i] : 0); + if (sub >= 0) { + values[i] = (uint32_t) sub; + carry = 0; + } else { + sub += 2 * (int64_t) base; + values[i] = (uint32_t) ((uint64_t) sub % base); + carry = sub / (int64_t) base - 2; } } + while (length > 1 && values[length - 1] == 0) length--; + return (bigint_t) { length, values }; } /** - * Multiple an integer by a 32-bit integer. In practice, the multiplier is the - * base of the integer, so this is 2, 8, 10, or 16. + * Multiply two bigint_t with the given base using karatsuba algorithm. */ -static void -pm_integer_multiply(pm_integer_t *integer, uint32_t multiplier) { - uint32_t carry = 0; - - for (pm_integer_word_t *current = &integer->head; current != NULL; current = current->next) { - uint64_t result = (uint64_t) current->value * multiplier + carry; - carry = (uint32_t) (result >> 32); - current->value = (uint32_t) result; - - if (carry > 0 && current->next == NULL) { - current->next = pm_integer_node_create(integer, carry); - break; +static bigint_t +karatsuba_multiply(bigint_t left, bigint_t right, uint64_t base) { + if (left.length > right.length) { + bigint_t temp = left; + left = right; + right = temp; + } + if (left.length <= 10) { + size_t length = left.length + right.length; + uint32_t *values = (uint32_t*) calloc(length, sizeof(uint32_t)); + for (size_t i = 0; i < left.length; i++) { + uint32_t carry = 0; + for (size_t j = 0; j < right.length; j++) { + uint64_t product = (uint64_t) left.values[i] * right.values[j] + values[i + j] + carry; + values[i + j] = (uint32_t) (product % base); + carry = (uint32_t) (product / base); + } + values[i + right.length] = carry; } + while (length > 1 && values[length - 1] == 0) length--; + return (bigint_t) { length, values }; } -} - -/** - * Divide an individual word by a 32-bit integer. This will recursively divide - * any subsequent nodes in the linked list. - */ -static uint32_t -pm_integer_divide_word(pm_integer_t *integer, pm_integer_word_t *word, uint32_t dividend) { - uint32_t remainder = 0; - if (word->next != NULL) { - remainder = pm_integer_divide_word(integer, word->next, dividend); - - if (integer->length > 0 && word->next->value == 0) { - xfree(word->next); - word->next = NULL; - integer->length--; + if (left.length * 2 <= right.length) { + uint32_t *values = (uint32_t*) calloc(left.length + right.length, sizeof(uint32_t)); + for (size_t start_offset = 0; start_offset < right.length; start_offset += left.length) { + size_t end_offset = start_offset + left.length; + if (end_offset > right.length) end_offset = right.length; + bigint_t sliced_right = { end_offset - start_offset, right.values + start_offset }; + bigint_t v = karatsuba_multiply(left, sliced_right, base); + uint32_t carry = 0; + for (size_t i = 0; i < v.length; i++) { + uint64_t sum = (uint64_t) values[start_offset + i] + v.values[i] + carry; + values[start_offset + i] = (uint32_t) (sum % base); + carry = (uint32_t) (sum / base); + } + free(v.values); + values[start_offset + v.length] += carry; } + return (bigint_t) { left.length + right.length, values }; } - - uint64_t value = ((uint64_t) remainder << 32) | word->value; - word->value = (uint32_t) (value / dividend); - return (uint32_t) (value % dividend); -} - -/** - * Divide an integer by a 32-bit integer. In practice, this is only 10 so that - * we can format it as a string. It returns the remainder of the division. - */ -static uint32_t -pm_integer_divide(pm_integer_t *integer, uint32_t dividend) { - return pm_integer_divide_word(integer, &integer->head, dividend); + size_t half = left.length / 2; + bigint_t x0 = { half, left.values }; + bigint_t x1 = { left.length - half, left.values + half }; + bigint_t y0 = { half, right.values }; + bigint_t y1 = { right.length - half, right.values + half }; + bigint_t z0 = karatsuba_multiply(x0, y0, base); + bigint_t z2 = karatsuba_multiply(x1, y1, base); + + // For simplicity to avoid considering negative values, + // use `z1 = (x0 + x1) * (y0 + y1) - z0 - z2` instead of original karatsuba algorithm. + bigint_t x01 = big_add(x0, x1, base); + bigint_t y01 = big_add(y0, y1, base); + bigint_t xy = karatsuba_multiply(x01, y01, base); + bigint_t z1 = big_sub2(xy, z0, z2, base); + + size_t length = left.length + right.length; + uint32_t *values = (uint32_t*) calloc(length, sizeof(uint32_t)); + memcpy(values, z0.values, sizeof(uint32_t) * z0.length); + memcpy(values + 2 * half, z2.values, sizeof(uint32_t) * z2.length); + uint32_t carry = 0; + for(size_t i = 0; i < z1.length; i++) { + uint64_t sum = (uint64_t) carry + values[i + half] + z1.values[i]; + values[i + half] = (uint32_t) (sum % base); + carry = (uint32_t) (sum / base); + } + for(size_t i = half + z1.length; carry > 0; i++) { + uint64_t sum = (uint64_t) carry + values[i]; + values[i] = (uint32_t) (sum % base); + carry = (uint32_t) (sum / base); + } + while (length > 1 && values[length - 1] == 0) length--; + free(z0.values); + free(z1.values); + free(z2.values); + free(x01.values); + free(y01.values); + free(xy.values); + return (bigint_t) { length, values }; } /** @@ -140,6 +162,140 @@ pm_integer_parse_digit(const uint8_t character) { } } +/** + * Create a bigint_t from uint64_t with the given base. + */ +static bigint_t +uint64_to_bigint(uint64_t value, uint64_t base) { + uint64_t v = value; + size_t len = 0; + while (value > 0) { len++; value /= base; } + if (len == 0) len = 1; + uint32_t *values = (uint32_t*) malloc(sizeof(uint32_t) * len); + for (size_t i = 0; i < len; i++) { + values[i] = (uint32_t) (v % base); + v /= base; + } + return (bigint_t) { len, values }; +} + +/** + * Convert base of bigint. + * In practice, it converts 10**9 to 1<<32 or 1<<32 to 10**9. + */ +static bigint_t +karatsuba_convert_base(bigint_t source, uint64_t base_from, uint64_t base_to) { + size_t bigints_length = (source.length + 1) / 2; + bigint_t *bigints = (bigint_t*) malloc(sizeof(bigint_t) * bigints_length); + for (size_t i = 0; i < source.length; i += 2) { + uint64_t v = source.values[i] + base_from * (i + 1 < source.length ? source.values[i + 1] : 0); + bigints[i / 2] = uint64_to_bigint(v, base_to); + } + bigint_t base = uint64_to_bigint(base_from, base_to); + while (bigints_length > 1) { + size_t new_length = (bigints_length + 1) / 2; + bigint_t new_base = karatsuba_multiply(base, base, base_to); + free(base.values); + base = new_base; + bigint_t *new_bigints = (bigint_t*) malloc(sizeof(bigint_t) * new_length); + for (size_t i = 0; i < bigints_length; i += 2) { + if (i + 1 == bigints_length) { + new_bigints[i / 2] = bigints[i]; + } else { + bigint_t multiplied = karatsuba_multiply(base, bigints[i + 1], base_to); + new_bigints[i / 2] = big_add(bigints[i], multiplied, base_to); + free(bigints[i].values); + free(bigints[i + 1].values); + free(multiplied.values); + } + } + free(bigints); + bigints = new_bigints; + bigints_length = new_length; + } + free(base.values); + bigint_t result = bigints[0]; + free(bigints); + return result; +} + +/** + * Convert digits to bigint_t with the given power-of-two base. + */ +static bigint_t +big_parse_powof2(uint32_t base, const uint8_t *digits, size_t digits_length) { + size_t bit = 1; + while (base > (uint32_t) (1 << bit)) bit++; + size_t length = (digits_length * bit + 31) / 32; + uint32_t *values = (uint32_t*) calloc(length, sizeof(uint32_t)); + for (size_t i = 0; i < digits_length; i++) { + size_t bit_position = bit * (digits_length - i - 1); + uint32_t value = digits[i]; + size_t index = bit_position / 32; + size_t shift = bit_position % 32; + values[index] |= value << shift; + if (32 - shift < bit) values[index + 1] |= value >> (32 - shift); + } + while (length > 1 && values[length - 1] == 0) length--; + return (bigint_t) { length, values }; +} + +/** + * Convert decimal digits to bigint. + */ +static bigint_t +big_parse_decimal(const uint8_t *digits, size_t digits_length) { + // Construct a bigdecimal from the digits. + const size_t batch = 9; + const uint64_t batch_base = 1000000000; + size_t values_length = (digits_length + batch - 1) / batch; + bigint_t bigint = { values_length, (uint32_t*) calloc(values_length, sizeof(uint32_t)) }; + uint32_t v = 0; + for (size_t i = 0; i < digits_length; i++) { + v = v * 10 + digits[i]; + size_t reverse_index = digits_length - i - 1; + if (reverse_index % batch == 0) { + bigint.values[reverse_index / batch] = v; + v = 0; + } + } + // Convert bigint base from 10**9 to 1<<32. + bigint_t converted = karatsuba_convert_base(bigint, batch_base, ((uint64_t) 1 << 32)); + free(bigint.values); + return converted; +} + +/** + * Parse a large integer from a string that does not fit into uint32_t. + */ +static void +pm_integer_parse_big(pm_integer_t *integer, uint32_t multiplier, const uint8_t *start, const uint8_t *end) { + // Allocate an array to store digits. + uint8_t *digits = malloc(sizeof(uint8_t) * (size_t) (end - start)); + size_t digits_length = 0; + for (; start < end; start++) { + if (*start == '_') continue; + digits[digits_length++] = (uint8_t) pm_integer_parse_digit(*start); + } + // Construct bigint_t from the digits. + bigint_t bigint = + multiplier == 10 ? big_parse_decimal(digits, digits_length) : big_parse_powof2(multiplier, digits, digits_length); + + // Pack bigint_t to pm_integer_t. + integer->length = bigint.length - 1; + integer->head.value = bigint.values[0]; + pm_integer_word_t *current = &integer->head; + for (size_t i = 1; i < bigint.length; i++) { + current->next = malloc(sizeof(pm_integer_word_t)); + current = current->next; + current->value = bigint.values[i]; + current->next = NULL; + } + + free(bigint.values); + free(digits); +} + /** * Parse an integer from a string. This assumes that the format of the integer * has already been validated, as internal validation checks are not performed @@ -189,15 +345,19 @@ pm_integer_parse(pm_integer_t *integer, pm_integer_base_t base, const uint8_t *s // invalid integer. If this is the case, we'll just return 0. if (start >= end) return; - // Add the first digit to the integer. - pm_integer_add(integer, pm_integer_parse_digit(*start++)); - - // Add the subsequent digits to the integer. - for (; start < end; start++) { - if (*start == '_') continue; - pm_integer_multiply(integer, multiplier); - pm_integer_add(integer, pm_integer_parse_digit(*start)); + const uint8_t *ptr = start; + uint64_t value = pm_integer_parse_digit(*ptr++); + for (; ptr < end; ptr++) { + if (*ptr == '_') continue; + value = value * multiplier + pm_integer_parse_digit(*ptr); + if (value > UINT32_MAX) { + // If the integer is too large to fit into a single node, then we'll + // parse it as a big integer. + pm_integer_parse_big(integer, multiplier, start, end); + return; + } } + integer->head.value = (uint32_t) value; } /** @@ -254,29 +414,39 @@ pm_integer_string(pm_buffer_t *buffer, const pm_integer_t *integer) { return; } default: { - // First, allocate a buffer that we'll copy the decimal digits into. - size_t length = (integer->length + 1) * 10; - char *digits = xcalloc(length, sizeof(char)); + // Pack pm_integer_t to bigint_t. + size_t length = integer->length + 1; + uint32_t *values = calloc(length, sizeof(uint32_t)); + const pm_integer_word_t *current = &(integer->head); + for (size_t i = 0; i < length; i++) { + values[i] = current->value; + current = current->next; + } + bigint_t bigint = { length, values }; + // Convert bigint base from 1<<32 to 10**9. + bigint_t converted = karatsuba_convert_base(bigint, (uint64_t) 1 << 32, 1000000000); + free(values); + + // Allocate a buffer that we'll copy the decimal digits into. + size_t char_length = converted.length * 9; + char *digits = calloc(char_length, sizeof(char)); if (digits == NULL) return; - // Next, create a new integer that we'll use to store the result of - // the division and modulo operations. - pm_integer_t copy; - pm_integer_copy(©, integer); - - // Then, iterate through the integer, dividing by 10 and storing the - // result in the buffer. - char *ending = digits + length - 1; - char *current = ending; - - while (copy.length > 0 || copy.head.value > 0) { - uint32_t remainder = pm_integer_divide(©, 10); - *current-- = (char) ('0' + remainder); + // Pack bigdecimal to digits. + for (size_t i = 0; i < converted.length; i++) { + uint32_t v = converted.values[i]; + for (size_t j = 0; j < 9; j++) { + digits[char_length - 9 * i - j - 1] = (char) ('0' + v % 10); + v /= 10; + } } + size_t start_offset = 0; + while (start_offset < char_length - 1 && digits[start_offset] == '0') start_offset++; // Finally, append the string to the buffer and free the digits. - pm_buffer_append_string(buffer, current + 1, (size_t) (ending - current)); - xfree(digits); + pm_buffer_append_string(buffer, digits + start_offset, char_length - start_offset); + free(digits); + free(converted.values); return; } }