Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Call*Node#*name use the constant pool #1533

Merged
merged 1 commit into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
14 changes: 7 additions & 7 deletions config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -641,9 +641,9 @@ nodes:
type: flags
kind: CallNodeFlags
- name: read_name
type: string
type: constant
- name: write_name
type: string
type: constant
- name: operator_loc
type: location
- name: value
Expand Down Expand Up @@ -674,7 +674,7 @@ nodes:
type: flags
kind: CallNodeFlags
- name: name
type: string
type: constant
comment: |
Represents a method call, in all of the various forms that can take.

Expand Down Expand Up @@ -714,9 +714,9 @@ nodes:
type: flags
kind: CallNodeFlags
- name: read_name
type: string
type: constant
- name: write_name
type: string
type: constant
- name: operator
type: constant
- name: operator_loc
Expand Down Expand Up @@ -747,9 +747,9 @@ nodes:
type: flags
kind: CallNodeFlags
- name: read_name
type: string
type: constant
- name: write_name
type: string
type: constant
- name: operator_loc
type: location
- name: value
Expand Down
5 changes: 5 additions & 0 deletions include/prism/util/pm_constant_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ typedef struct {
// Initialize a new constant pool with a given capacity.
bool pm_constant_pool_init(pm_constant_pool_t *pool, uint32_t capacity);

static inline pm_constant_t* pm_constant_pool_id_to_constant(pm_constant_pool_t *pool, pm_constant_id_t constant_id) {
assert(constant_id > 0 && constant_id <= pool->size);
return &pool->constants[constant_id - 1];
}

// Insert a constant into a constant pool that is a slice of a source string.
// Returns the id of the constant, or 0 if any potential calls to resize fail.
pm_constant_id_t pm_constant_pool_insert_shared(pm_constant_pool_t *pool, const uint8_t *start, size_t length);
Expand Down
3 changes: 2 additions & 1 deletion include/prism/util/pm_string.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@

// This struct represents a string value.
typedef struct {
enum { PM_STRING_SHARED, PM_STRING_OWNED, PM_STRING_CONSTANT, PM_STRING_MAPPED } type;
const uint8_t *source;
size_t length;
// This field is not the first one, because otherwise things like .pm_string_t_field = 123/pm_constant_id_t does not warn or error
enum { PM_STRING_SHARED, PM_STRING_OWNED, PM_STRING_CONSTANT, PM_STRING_MAPPED } type;
} pm_string_t;

#define PM_EMPTY_STRING ((pm_string_t) { .type = PM_STRING_CONSTANT, .source = NULL, .length = 0 })
Expand Down
86 changes: 51 additions & 35 deletions src/prism.c
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,22 @@ pm_parser_constant_id_owned(pm_parser_t *parser, const uint8_t *start, size_t le
return pm_constant_pool_insert_owned(&parser->constant_pool, start, length);
}

// Retrieve the constant pool id for the given static literal C string.
static inline pm_constant_id_t
pm_parser_constant_id_static(pm_parser_t *parser, const char *start, size_t length) {
uint8_t *owned_copy;
if (length > 0) {
owned_copy = malloc(length);
memcpy(owned_copy, start, length);
} else {
owned_copy = malloc(1);
owned_copy[0] = '\0';
}
return pm_constant_pool_insert_owned(&parser->constant_pool, owned_copy, length);
// Does not work because the static literal cannot be serialized as an offset of source
// return pm_constant_pool_insert_shared(&parser->constant_pool, start, length);
}
Comment on lines +442 to +454
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think once we always serialize all constants as embedded then we can simplify this and actually use the C literal string pointer without copying, because then we only need start+offset, and don't care if that's within the source or not when serializing.


// Retrieve the constant pool id for the given token.
static inline pm_constant_id_t
pm_parser_constant_id_token(pm_parser_t *parser, const pm_token_t *token) {
Expand Down Expand Up @@ -1343,7 +1359,8 @@ pm_call_node_create(pm_parser_t *parser) {
.opening_loc = PM_OPTIONAL_LOCATION_NOT_PROVIDED_VALUE,
.arguments = NULL,
.closing_loc = PM_OPTIONAL_LOCATION_NOT_PROVIDED_VALUE,
.block = NULL
.block = NULL,
.name = 0
};

return node;
Expand Down Expand Up @@ -1371,7 +1388,7 @@ pm_call_node_aref_create(pm_parser_t *parser, pm_node_t *receiver, pm_arguments_
node->closing_loc = arguments->closing_loc;
node->block = arguments->block;

pm_string_constant_init(&node->name, "[]", 2);
node->name = pm_parser_constant_id_static(parser, "[]", 2);
return node;
}

Expand All @@ -1390,7 +1407,7 @@ pm_call_node_binary_create(pm_parser_t *parser, pm_node_t *receiver, pm_token_t
pm_arguments_node_arguments_append(arguments, argument);
node->arguments = arguments;

pm_string_shared_init(&node->name, operator->start, operator->end);
node->name = pm_parser_constant_id_token(parser, operator);
return node;
}

Expand Down Expand Up @@ -1422,7 +1439,7 @@ pm_call_node_call_create(pm_parser_t *parser, pm_node_t *receiver, pm_token_t *o
node->base.flags |= PM_CALL_NODE_FLAGS_SAFE_NAVIGATION;
}

pm_string_shared_init(&node->name, message->start, message->end);
node->name = pm_parser_constant_id_token(parser, message);
return node;
}

Expand All @@ -1449,7 +1466,7 @@ pm_call_node_fcall_create(pm_parser_t *parser, pm_token_t *message, pm_arguments
node->closing_loc = arguments->closing_loc;
node->block = arguments->block;

pm_string_shared_init(&node->name, message->start, message->end);
node->name = pm_parser_constant_id_token(parser, message);
return node;
}

Expand All @@ -1471,7 +1488,7 @@ pm_call_node_not_create(pm_parser_t *parser, pm_node_t *receiver, pm_token_t *me
node->arguments = arguments->arguments;
node->closing_loc = arguments->closing_loc;

pm_string_constant_init(&node->name, "!", 1);
node->name = pm_parser_constant_id_static(parser, "!", 1);
return node;
}

Expand All @@ -1498,7 +1515,7 @@ pm_call_node_shorthand_create(pm_parser_t *parser, pm_node_t *receiver, pm_token
node->base.flags |= PM_CALL_NODE_FLAGS_SAFE_NAVIGATION;
}

pm_string_constant_init(&node->name, "call", 4);
node->name = pm_parser_constant_id_static(parser, "call", 4);
return node;
}

Expand All @@ -1513,7 +1530,7 @@ pm_call_node_unary_create(pm_parser_t *parser, pm_token_t *operator, pm_node_t *
node->receiver = receiver;
node->message_loc = PM_OPTIONAL_LOCATION_TOKEN_VALUE(operator);

pm_string_constant_init(&node->name, name, strlen(name));
node->name = pm_parser_constant_id_static(parser, name, strlen(name));
return node;
}

Expand All @@ -1526,7 +1543,7 @@ pm_call_node_variable_call_create(pm_parser_t *parser, pm_token_t *message) {
node->base.location = PM_LOCATION_TOKEN_VALUE(message);
node->message_loc = PM_OPTIONAL_LOCATION_TOKEN_VALUE(message);

pm_string_shared_init(&node->name, message->start, message->end);
node->name = pm_parser_constant_id_token(parser, message);
return node;
}

Expand All @@ -1539,17 +1556,18 @@ pm_call_node_variable_call_p(pm_call_node_t *node) {

// Initialize the read name by reading the write name and chopping off the '='.
static void
pm_call_write_read_name_init(pm_string_t *read_name, pm_string_t *write_name) {
if (write_name->length >= 1) {
size_t length = write_name->length - 1;
pm_call_write_read_name_init(pm_parser_t *parser, pm_constant_id_t *read_name, pm_constant_id_t *write_name) {
pm_constant_t *write_constant = pm_constant_pool_id_to_constant(&parser->constant_pool, *write_name);
if (write_constant->length >= 1) {
size_t length = write_constant->length - 1;

void *memory = malloc(length);
memcpy(memory, write_name->source, length);
memcpy(memory, write_constant->start, length);

pm_string_owned_init(read_name, (uint8_t *) memory, length);
*read_name = pm_constant_pool_insert_owned(&parser->constant_pool, (uint8_t *) memory, length);
} else {
// We can get here if the message was missing because of a syntax error.
pm_string_constant_init(read_name, "", 0);
*read_name = pm_parser_constant_id_static(parser, "", 0);
}
}

Expand All @@ -1575,13 +1593,13 @@ pm_call_and_write_node_create(pm_parser_t *parser, pm_call_node_t *target, const
.opening_loc = target->opening_loc,
.arguments = target->arguments,
.closing_loc = target->closing_loc,
.read_name = PM_EMPTY_STRING,
.read_name = 0,
.write_name = target->name,
.operator_loc = PM_LOCATION_TOKEN_VALUE(operator),
.value = value
};

pm_call_write_read_name_init(&node->read_name, &node->write_name);
pm_call_write_read_name_init(parser, &node->read_name, &node->write_name);

// Here we're going to free the target, since it is no longer necessary.
// However, we don't want to call `pm_node_destroy` because we want to keep
Expand Down Expand Up @@ -1612,14 +1630,14 @@ pm_call_operator_write_node_create(pm_parser_t *parser, pm_call_node_t *target,
.opening_loc = target->opening_loc,
.arguments = target->arguments,
.closing_loc = target->closing_loc,
.read_name = PM_EMPTY_STRING,
.read_name = 0,
.write_name = target->name,
.operator = pm_parser_constant_id_location(parser, operator->start, operator->end - 1),
.operator_loc = PM_LOCATION_TOKEN_VALUE(operator),
.value = value
};

pm_call_write_read_name_init(&node->read_name, &node->write_name);
pm_call_write_read_name_init(parser, &node->read_name, &node->write_name);

// Here we're going to free the target, since it is no longer necessary.
// However, we don't want to call `pm_node_destroy` because we want to keep
Expand Down Expand Up @@ -1651,13 +1669,13 @@ pm_call_or_write_node_create(pm_parser_t *parser, pm_call_node_t *target, const
.opening_loc = target->opening_loc,
.arguments = target->arguments,
.closing_loc = target->closing_loc,
.read_name = PM_EMPTY_STRING,
.read_name = 0,
.write_name = target->name,
.operator_loc = PM_LOCATION_TOKEN_VALUE(operator),
.value = value
};

pm_call_write_read_name_init(&node->read_name, &node->write_name);
pm_call_write_read_name_init(parser, &node->read_name, &node->write_name);

// Here we're going to free the target, since it is no longer necessary.
// However, we don't want to call `pm_node_destroy` because we want to keep
Expand Down Expand Up @@ -8388,23 +8406,23 @@ parse_starred_expression(pm_parser_t *parser, pm_binding_power_t binding_power,
}

// Convert the name of a method into the corresponding write method name. For
// exmaple, foo would be turned into foo=.
// example, foo would be turned into foo=.
static void
parse_write_name(pm_string_t *string) {
parse_write_name(pm_parser_t *parser, pm_constant_id_t *name_field) {
// The method name needs to change. If we previously had
// foo, we now need foo=. In this case we'll allocate a new
// owned string, copy the previous method name in, and
// append an =.
size_t length = pm_string_length(string);
pm_constant_t *constant = pm_constant_pool_id_to_constant(&parser->constant_pool, *name_field);
size_t length = constant->length;
uint8_t *name = calloc(length + 1, sizeof(uint8_t));
if (name == NULL) return;

memcpy(name, pm_string_source(string), length);
memcpy(name, constant->start, length);
name[length] = '=';

// Now switch the name to the new string.
pm_string_free(string);
pm_string_owned_init(string, name, length + 1);
*name_field = pm_constant_pool_insert_owned(&parser->constant_pool, name, length + 1);
}

// Convert the given node into a valid target node.
Expand Down Expand Up @@ -8502,7 +8520,7 @@ parse_target(pm_parser_t *parser, pm_node_t *target) {
}

if (*call->message_loc.start == '_' || parser->encoding.alnum_char(call->message_loc.start, call->message_loc.end - call->message_loc.start)) {
parse_write_name(&call->name);
parse_write_name(parser, &call->name);
return (pm_node_t *) call;
}
}
Expand All @@ -8517,9 +8535,8 @@ parse_target(pm_parser_t *parser, pm_node_t *target) {
(call->message_loc.end[-1] == ']') &&
(call->block == NULL)
) {
// Free the previous name and replace it with "[]=".
pm_string_free(&call->name);
pm_string_constant_init(&call->name, "[]=", 3);
// Replace the name with "[]=".
call->name = pm_parser_constant_id_static(parser, "[]=", 3);
return target;
}
}
Expand Down Expand Up @@ -8664,7 +8681,7 @@ parse_write(pm_parser_t *parser, pm_node_t *target, pm_token_t *operator, pm_nod
pm_arguments_node_arguments_append(arguments, value);
call->base.location.end = arguments->base.location.end;

parse_write_name(&call->name);
parse_write_name(parser, &call->name);
return (pm_node_t *) call;
}
}
Expand All @@ -8685,9 +8702,8 @@ parse_write(pm_parser_t *parser, pm_node_t *target, pm_token_t *operator, pm_nod
pm_arguments_node_arguments_append(call->arguments, value);
target->location.end = value->location.end;

// Free the previous name and replace it with "[]=".
pm_string_free(&call->name);
pm_string_constant_init(&call->name, "[]=", 3);
// Replace the name with "[]=".
call->name = pm_parser_constant_id_static(parser, "[]=", 3);
return target;
}

Expand Down
5 changes: 3 additions & 2 deletions templates/lib/prism/serialize.rb.erb
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,14 @@ module Prism
end

def load_string
case io.getbyte
type = io.getbyte
case type
when 1
input.byteslice(load_varint, load_varint).force_encoding(encoding)
when 2
load_embedded_string
else
raise
raise "Unknown serialized string type: #{type}"
end
end

Expand Down
20 changes: 10 additions & 10 deletions test/prism/encoding_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -39,28 +39,28 @@ class EncodingTest < TestCase
CP1252
].each do |encoding|
define_method "test_encoding_#{encoding}" do
result = Prism.parse("# encoding: #{encoding}\nident")
actual = result.value.statements.body.first.name.encoding
result = Prism.parse("# encoding: #{encoding}\n'string'")
actual = result.value.statements.body.first.unescaped.encoding
assert_equal Encoding.find(encoding), actual
end
end

def test_coding
result = Prism.parse("# coding: utf-8\nident")
actual = result.value.statements.body.first.name.encoding
result = Prism.parse("# coding: utf-8\n'string'")
actual = result.value.statements.body.first.unescaped.encoding
assert_equal Encoding.find("utf-8"), actual
end

def test_coding_with_whitespace
result = Prism.parse("# coding \t \r \v : \t \v \r ascii-8bit \nident")
actual = result.value.statements.body.first.name.encoding
result = Prism.parse("# coding \t \r \v : \t \v \r ascii-8bit \n'string'")
actual = result.value.statements.body.first.unescaped.encoding
assert_equal Encoding.find("ascii-8bit"), actual
end


def test_emacs_style
result = Prism.parse("# -*- coding: utf-8 -*-\nident")
actual = result.value.statements.body.first.name.encoding
result = Prism.parse("# -*- coding: utf-8 -*-\n'string'")
actual = result.value.statements.body.first.unescaped.encoding
assert_equal Encoding.find("utf-8"), actual
end

Expand All @@ -86,8 +86,8 @@ def test_utf_8_variations
utf-8-mac
utf-8-*
].each do |encoding|
result = Prism.parse("# coding: #{encoding}\nident")
actual = result.value.statements.body.first.name.encoding
result = Prism.parse("# coding: #{encoding}\n'string'")
actual = result.value.statements.body.first.unescaped.encoding
assert_equal Encoding.find("utf-8"), actual
end
end
Expand Down