diff --git a/prism/extension.c b/prism/extension.c index cb7d3a5e6a3096..4af28bdb9672b4 100644 --- a/prism/extension.c +++ b/prism/extension.c @@ -949,6 +949,34 @@ named_captures(VALUE self, VALUE source) { return names; } +/** + * call-seq: + * Debug::number_parse(source) -> Integer + * + * Parses the given source string and returns the number it represents. + */ +static VALUE +number_parse(VALUE self, VALUE source) { + const uint8_t *start = (const uint8_t *) RSTRING_PTR(source); + size_t length = RSTRING_LEN(source); + + pm_number_t number = { 0 }; + pm_number_parse(&number, PM_NUMBER_BASE_UNKNOWN, start, start + length); + + VALUE result = UINT2NUM(number.head.value); + size_t shift = 0; + + for (pm_number_node_t *node = number.head.next; node != NULL; node = node->next) { + VALUE receiver = rb_funcall(UINT2NUM(node->value), rb_intern("<<"), 1, ULONG2NUM(++shift * 32)); + result = rb_funcall(receiver, rb_intern("|"), 1, result); + } + + if (number.negative) result = rb_funcall(result, rb_intern("-@"), 0); + pm_number_free(&number); + + return result; +} + /** * call-seq: * Debug::memsize(source) -> { length: xx, memsize: xx, node_count: xx } @@ -1148,6 +1176,7 @@ Init_prism(void) { // internal tasks. We expose these to make them easier to test. VALUE rb_cPrismDebug = rb_define_module_under(rb_cPrism, "Debug"); rb_define_singleton_method(rb_cPrismDebug, "named_captures", named_captures, 1); + rb_define_singleton_method(rb_cPrismDebug, "number_parse", number_parse, 1); rb_define_singleton_method(rb_cPrismDebug, "memsize", memsize, 1); rb_define_singleton_method(rb_cPrismDebug, "profile_file", profile_file, 1); rb_define_singleton_method(rb_cPrismDebug, "inspect_node", inspect_node, 1); diff --git a/prism/prism.h b/prism/prism.h index ffc722e90c48a6..88eb128769225d 100644 --- a/prism/prism.h +++ b/prism/prism.h @@ -10,6 +10,7 @@ #include "prism/util/pm_buffer.h" #include "prism/util/pm_char.h" #include "prism/util/pm_memchr.h" +#include "prism/util/pm_number.h" #include "prism/util/pm_strncasecmp.h" #include "prism/util/pm_strpbrk.h" #include "prism/ast.h" diff --git a/prism/util/pm_number.c b/prism/util/pm_number.c new file mode 100644 index 00000000000000..d2733330a64cc6 --- /dev/null +++ b/prism/util/pm_number.c @@ -0,0 +1,164 @@ +#include "prism/util/pm_number.h" + +/** + * Create a new node for a number in the linked list. + */ +static pm_number_node_t * +pm_number_node_create(pm_number_t *number, uint32_t value) { + number->length++; + pm_number_node_t *node = malloc(sizeof(pm_number_node_t)); + *node = (pm_number_node_t) { .next = NULL, .value = value }; + return node; +} + +/** + * Add a 32-bit integer to a number. + */ +static void +pm_number_add(pm_number_t *number, uint32_t addend) { + uint32_t carry = addend; + pm_number_node_t *current = &number->head; + + while (carry > 0) { + uint64_t result = (uint64_t) current->value + carry; + carry = (uint32_t) (result >> 32); + current->value = (uint32_t) result; + + if (carry > 0) { + if (current->next == NULL) { + current->next = pm_number_node_create(number, carry); + break; + } + + current = current->next; + } + } +} + +/** + * Multiple a number by a 32-bit integer. In practice, the multiplier is the + * base of the number, so this is 2, 8, 10, or 16. + */ +static void +pm_number_multiply(pm_number_t *number, uint32_t multiplier) { + uint32_t carry = 0; + + for (pm_number_node_t *current = &number->head; current != NULL; current = current->next) { + uint64_t result = (uint64_t) current->value * multiplier + carry; + carry = (uint32_t) (result >> 32); + current->value = (uint32_t) result; + + if (carry > 0 && current->next == NULL) { + current->next = pm_number_node_create(number, carry); + break; + } + } +} + +/** + * Return the value of a digit in a number. + */ +static uint32_t +pm_number_parse_digit(const uint8_t character) { + switch (character) { + case '0': return 0; + case '1': return 1; + case '2': return 2; + case '3': return 3; + case '4': return 4; + case '5': return 5; + case '6': return 6; + case '7': return 7; + case '8': return 8; + case '9': return 9; + case 'a': case 'A': return 10; + case 'b': case 'B': return 11; + case 'c': case 'C': return 12; + case 'd': case 'D': return 13; + case 'e': case 'E': return 14; + case 'f': case 'F': return 15; + default: assert(false && "unreachable"); + } +} + +/** + * Parse a number from a string. This assumes that the format of the number has + * already been validated, as internal validation checks are not performed here. + */ +PRISM_EXPORTED_FUNCTION void +pm_number_parse(pm_number_t *number, pm_number_base_t base, const uint8_t *start, const uint8_t *end) { + switch (*start) { + case '-': + number->negative = true; + /* fallthrough */ + case '+': + start++; + break; + default: + break; + } + + uint32_t multiplier; + switch (base) { + case PM_NUMBER_BASE_BINARY: + start += 2; // 0b + multiplier = 2; + break; + case PM_NUMBER_BASE_OCTAL: + start++; // 0 + if (*start == 'o' || *start == 'O') start++; // o + multiplier = 8; + break; + case PM_NUMBER_BASE_DECIMAL: + if (*start == '0' && (end - start) > 1) start += 2; // 0d + multiplier = 10; + break; + case PM_NUMBER_BASE_HEXADECIMAL: + start += 2; // 0x + multiplier = 16; + break; + case PM_NUMBER_BASE_UNKNOWN: + if (*start == '0' && (end - start) > 1) { + switch (start[1]) { + case '0': case '1': case '2': case '3': case '4': case '5': case '6': case '7': start++; multiplier = 8; break; + case 'b': case 'B': start += 2; multiplier = 2; break; + case 'o': case 'O': start += 2; multiplier = 8; break; + case 'd': case 'D': start += 2; multiplier = 10; break; + case 'x': case 'X': start += 2; multiplier = 16; break; + default: assert(false && "unreachable"); + } + } else { + multiplier = 10; + } + break; + } + + for (pm_number_add(number, pm_number_parse_digit(*start++)); start < end; start++) { + if (*start == '_') continue; + pm_number_multiply(number, multiplier); + pm_number_add(number, pm_number_parse_digit(*start)); + } +} + +/** + * Recursively destroy the linked list of a number. + */ +static void +pm_number_node_destroy(pm_number_node_t *number) { + if (number->next != NULL) { + pm_number_node_destroy(number->next); + } + + free(number); +} + +/** + * Free the internal memory of a number. This memory will only be allocated if + * the number exceeds the size of a single node in the linked list. + */ +PRISM_EXPORTED_FUNCTION void +pm_number_free(pm_number_t *number) { + if (number->head.next) { + pm_number_node_destroy(number->head.next); + } +} diff --git a/prism/util/pm_number.h b/prism/util/pm_number.h new file mode 100644 index 00000000000000..c2507cf1df42ca --- /dev/null +++ b/prism/util/pm_number.h @@ -0,0 +1,95 @@ +/** + * @file pm_number.h + * + * This module provides functions for working with arbitrary-sized numbers. + */ +#ifndef PRISM_NUMBER_H +#define PRISM_NUMBER_H + +#include "prism/defines.h" + +#include +#include +#include +#include + +/** + * A node in the linked list of a pm_number_t. + */ +typedef struct pm_number_node { + /** A pointer to the next node in the list. */ + struct pm_number_node *next; + + /** The value of the node. */ + uint32_t value; +} pm_number_node_t; + +/** + * This structure represents an arbitrary-sized number. It is implemented as a + * linked list of 32-bit integers, with the least significant digit at the head + * of the list. + */ +typedef struct { + /** + * The head of the linked list, embedded directly so that allocations do not + * need to be performed for small numbers. + */ + pm_number_node_t head; + + /** The number of nodes in the linked list that have been allocated. */ + size_t length; + + /** + * Whether or not the number is negative. It is stored this way so that a + * zeroed pm_number_t is always positive zero. + */ + bool negative; +} pm_number_t; + +/** + * An enum controlling the base of a number. It is expected that the base is + * already known before parsing the number, even though it could be derived from + * the string itself. + */ +typedef enum { + /** The binary base, indicated by a 0b or 0B prefix. */ + PM_NUMBER_BASE_BINARY, + + /** The octal base, indicated by a 0, 0o, or 0O prefix. */ + PM_NUMBER_BASE_OCTAL, + + /** The decimal base, indicated by a 0d, 0D, or empty prefix. */ + PM_NUMBER_BASE_DECIMAL, + + /** The hexidecimal base, indicated by a 0x or 0X prefix. */ + PM_NUMBER_BASE_HEXADECIMAL, + + /** + * An unknown base, in which case pm_number_parse will derive it based on + * the content of the string. This is less efficient and does more + * comparisons, so if callers know the base ahead of time, they should use + * that instead. + */ + PM_NUMBER_BASE_UNKNOWN +} pm_number_base_t; + +/** + * Parse a number from a string. This assumes that the format of the number has + * already been validated, as internal validation checks are not performed here. + * + * @param number The number to parse into. + * @param base The base of the number. + * @param start The start of the string. + * @param end The end of the string. + */ +PRISM_EXPORTED_FUNCTION void pm_number_parse(pm_number_t *number, pm_number_base_t base, const uint8_t *start, const uint8_t *end); + +/** + * Free the internal memory of a number. This memory will only be allocated if + * the number exceeds the size of a single node in the linked list. + * + * @param number The number to free. + */ +PRISM_EXPORTED_FUNCTION void pm_number_free(pm_number_t *number); + +#endif diff --git a/test/prism/number_parse_test.rb b/test/prism/number_parse_test.rb new file mode 100644 index 00000000000000..401aead243fb20 --- /dev/null +++ b/test/prism/number_parse_test.rb @@ -0,0 +1,36 @@ +# frozen_string_literal: true + +require_relative "test_helper" + +return if Prism::BACKEND == :FFI + +module Prism + class NumberParseTest < TestCase + def test_number_parse + assert_number_parse(1) + assert_number_parse(50) + assert_number_parse(100) + assert_number_parse(100, "1_0_0") + + assert_number_parse(10, "0b1010") + assert_number_parse(10, "0B1010") + assert_number_parse(10, "0o12") + assert_number_parse(10, "0O12") + assert_number_parse(10, "012") + assert_number_parse(10, "0d10") + assert_number_parse(10, "0D10") + assert_number_parse(10, "0xA") + assert_number_parse(10, "0XA") + + assert_number_parse(2**32) + assert_number_parse(2**64 + 2**32) + assert_number_parse(2**128 + 2**64 + 2**32) + end + + private + + def assert_number_parse(expected, source = expected.to_s) + assert_equal expected, Debug.number_parse(source) + end + end +end