diff --git a/prism/extension.c b/prism/extension.c index 1a471000b49606..292e67891f386b 100644 --- a/prism/extension.c +++ b/prism/extension.c @@ -1047,30 +1047,16 @@ integer_parse(VALUE self, VALUE source) { pm_integer_t integer = { 0 }; pm_integer_parse(&integer, PM_INTEGER_BASE_UNKNOWN, start, start + length); - VALUE number; - - if (integer.values == NULL) { - number = UINT2NUM(integer.value); - } else { - number = UINT2NUM(0); - for (size_t i = 0; i < integer.length; i++) { - VALUE receiver = rb_funcall(UINT2NUM(integer.values[i]), rb_intern("<<"), 1, ULONG2NUM(i * 32)); - number = rb_funcall(receiver, rb_intern("|"), 1, number); - } - } - - if (integer.negative) number = rb_funcall(number, rb_intern("-@"), 0); - pm_buffer_t buffer = { 0 }; pm_integer_string(&buffer, &integer); VALUE string = rb_str_new(pm_buffer_value(&buffer), pm_buffer_length(&buffer)); pm_buffer_free(&buffer); - pm_integer_free(&integer); VALUE result = rb_ary_new_capa(2); - rb_ary_push(result, number); + rb_ary_push(result, pm_integer_new(&integer)); rb_ary_push(result, string); + pm_integer_free(&integer); return result; } diff --git a/prism/extension.h b/prism/extension.h index 6e5a3450122a93..13a9aabde3e5c3 100644 --- a/prism/extension.h +++ b/prism/extension.h @@ -10,6 +10,7 @@ VALUE pm_source_new(const pm_parser_t *parser, rb_encoding *encoding); VALUE pm_token_new(const pm_parser_t *parser, const pm_token_t *token, rb_encoding *encoding, VALUE source); VALUE pm_ast_new(const pm_parser_t *parser, const pm_node_t *node, rb_encoding *encoding, VALUE source); +VALUE pm_integer_new(const pm_integer_t *integer); void Init_prism_api_node(void); void Init_prism_pack(void); diff --git a/prism/templates/ext/prism/api_node.c.erb b/prism/templates/ext/prism/api_node.c.erb index a195fa6325d7b4..0e8aaae3226093 100644 --- a/prism/templates/ext/prism/api_node.c.erb +++ b/prism/templates/ext/prism/api_node.c.erb @@ -37,23 +37,26 @@ pm_string_new(const pm_string_t *string, rb_encoding *encoding) { return rb_enc_str_new((const char *) pm_string_source(string), pm_string_length(string), encoding); } -static VALUE +VALUE pm_integer_new(const pm_integer_t *integer) { VALUE result; - if (integer->values) { - VALUE str = rb_str_new(NULL, integer->length * 8); - unsigned char *buf = (unsigned char *)RSTRING_PTR(str); + if (integer->values == NULL) { + result = UINT2NUM(integer->value); + } else { + VALUE string = rb_str_new(NULL, integer->length * 8); + unsigned char *bytes = (unsigned char *) RSTRING_PTR(string); + size_t offset = integer->length * 8; - for (size_t i = 0; i < integer->length; i++) { - uint32_t value = integer->values[i]; - for (int i = 0; i < 8; i++) { - int n = (value >> (4 * i)) & 0xf; - buf[--offset] = n < 10 ? n + '0' : n - 10 + 'a'; + for (size_t value_index = 0; value_index < integer->length; value_index++) { + uint32_t value = integer->values[value_index]; + + for (int index = 0; index < 8; index++) { + int byte = (value >> (4 * index)) & 0xf; + bytes[--offset] = byte < 10 ? byte + '0' : byte - 10 + 'a'; } } - result = rb_funcall(str, rb_intern("to_i"), 1, UINT2NUM(16)); - } else { - result = UINT2NUM(integer->value); + + result = rb_funcall(string, rb_intern("to_i"), 1, UINT2NUM(16)); } if (integer->negative) { diff --git a/prism/util/pm_integer.c b/prism/util/pm_integer.c index 1f198af101b858..160a78920c0086 100644 --- a/prism/util/pm_integer.c +++ b/prism/util/pm_integer.c @@ -29,8 +29,9 @@ big_add(pm_integer_t *destination, pm_integer_t *left, pm_integer_t *right, uint size_t length = left_length < right_length ? right_length : left_length; uint32_t *values = (uint32_t *) xmalloc(sizeof(uint32_t) * (length + 1)); - uint64_t carry = 0; + if (values == NULL) return; + uint64_t carry = 0; for (size_t index = 0; index < length; index++) { uint64_t sum = carry + (index < left_length ? left_values[index] : 0) + (index < right_length ? right_values[index] : 0); values[index] = (uint32_t) (sum % base); @@ -54,36 +55,15 @@ static void big_sub2(pm_integer_t *destination, pm_integer_t *a, pm_integer_t *b, pm_integer_t *c, uint64_t base) { size_t a_length; uint32_t *a_values; - - if (a->values == NULL) { - a_length = 1; - a_values = &a->value; - } else { - a_length = a->length; - a_values = a->values; - } + INTEGER_EXTRACT(a, a_length, a_values) size_t b_length; uint32_t *b_values; - - if (b->values == NULL) { - b_length = 1; - b_values = &b->value; - } else { - b_length = b->length; - b_values = b->values; - } + INTEGER_EXTRACT(b, b_length, b_values) size_t c_length; uint32_t *c_values; - - if (c->values == NULL) { - c_length = 1; - c_values = &c->value; - } else { - c_length = c->length; - c_values = c->values; - } + INTEGER_EXTRACT(c, c_length, c_values) uint32_t *values = (uint32_t*) xmalloc(sizeof(uint32_t) * a_length); int64_t carry = 0; @@ -137,6 +117,7 @@ karatsuba_multiply(pm_integer_t *destination, pm_integer_t *left, pm_integer_t * if (left_length <= 10) { size_t length = left_length + right_length; uint32_t *values = (uint32_t *) xcalloc(length, sizeof(uint32_t)); + if (values == NULL) return; for (size_t left_index = 0; left_index < left_length; left_index++) { uint32_t carry = 0; @@ -293,6 +274,8 @@ pm_integer_from_uint64(pm_integer_t *integer, uint64_t value, uint64_t base) { } uint32_t *values = (uint32_t *) xmalloc(sizeof(uint32_t) * length); + if (values == NULL) return; + for (size_t value_index = 0; value_index < length; value_index++) { values[value_index] = (uint32_t) (value % base); value /= base; @@ -340,6 +323,7 @@ pm_integer_convert_base(pm_integer_t *destination, const pm_integer_t *source, u size_t bigints_length = (source_length + 1) / 2; pm_integer_t *bigints = (pm_integer_t *) xcalloc(bigints_length, sizeof(pm_integer_t)); + if (bigints == NULL) return; for (size_t index = 0; index < source_length; index += 2) { uint64_t value = source_values[index] + base_from * (index + 1 < source_length ? source_values[index + 1] : 0); @@ -516,7 +500,7 @@ pm_integer_parse(pm_integer_t *integer, pm_integer_base_t base, const uint8_t *s const uint8_t *cursor = start; uint64_t value = (uint64_t) pm_integer_parse_digit(*cursor++); - + for (; cursor < end; cursor++) { if (*cursor == '_') continue; value = value * multiplier + (uint64_t) pm_integer_parse_digit(*cursor); @@ -528,7 +512,7 @@ pm_integer_parse(pm_integer_t *integer, pm_integer_base_t base, const uint8_t *s return; } } - + integer->value = (uint32_t) value; }