Skip to content

Commit

Permalink
[ruby/prism] Replace old circular parameter definition detection
Browse files Browse the repository at this point in the history
  • Loading branch information
kddnewton authored and matzbot committed Apr 5, 2024
1 parent bf3a911 commit dcec1e0
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 70 deletions.
3 changes: 0 additions & 3 deletions prism/parser.h
Expand Up @@ -790,9 +790,6 @@ struct pm_parser {
*/
const pm_encoding_t *explicit_encoding;

/** The current parameter name id on parsing its default value. */
pm_constant_id_t current_param_name;

/**
* When parsing block exits (e.g., break, next, redo), we need to validate
* that they are in correct contexts. For the most part we can do this by
Expand Down
92 changes: 29 additions & 63 deletions prism/prism.c
Expand Up @@ -872,6 +872,17 @@ pm_locals_unread(pm_locals_t *locals, pm_constant_id_t name) {
local->reads--;
}

/**
* Returns the current number of reads for a local variable.
*/
static uint32_t
pm_locals_reads(pm_locals_t *locals, pm_constant_id_t name) {
uint32_t index = pm_locals_find(locals, name);
assert(index != UINT32_MAX);

return locals->locals[index].reads;
}

/**
* Write out the locals into the given list of constant ids in the correct
* order. This is used to set the list of locals on the nodes in the tree once
Expand Down Expand Up @@ -5097,10 +5108,6 @@ pm_local_variable_or_write_node_create(pm_parser_t *parser, pm_node_t *target, c
*/
static pm_local_variable_read_node_t *
pm_local_variable_read_node_create_constant_id(pm_parser_t *parser, const pm_token_t *name, pm_constant_id_t name_id, uint32_t depth, bool missing) {
if (parser->current_param_name == name_id) {
PM_PARSER_ERR_TOKEN_FORMAT_CONTENT(parser, *name, PM_ERR_PARAMETER_CIRCULAR);
}

if (!missing) {
pm_scope_t *scope = parser->current_scope;
for (uint32_t index = 0; index < depth; index++) scope = scope->previous;
Expand Down Expand Up @@ -7275,33 +7282,6 @@ pm_parser_scope_shareable_constant_set(pm_parser_t *parser, pm_shareable_constan
} while (!scope->closed && (scope = scope->previous) != NULL);
}

/**
* Save the current param name as the return value and set it to the given
* constant id.
*/
static inline pm_constant_id_t
pm_parser_current_param_name_set(pm_parser_t *parser, pm_constant_id_t current_param_name) {
pm_constant_id_t saved_param_name = parser->current_param_name;
parser->current_param_name = current_param_name;
return saved_param_name;
}

/**
* Save the current param name as the return value and clear it.
*/
static inline pm_constant_id_t
pm_parser_current_param_name_unset(pm_parser_t *parser) {
return pm_parser_current_param_name_set(parser, PM_CONSTANT_ID_UNSET);
}

/**
* Restore the current param name from the given value.
*/
static inline void
pm_parser_current_param_name_restore(pm_parser_t *parser, pm_constant_id_t saved_param_name) {
parser->current_param_name = saved_param_name;
}

/**
* Check if any of the currently visible scopes contain a local variable
* described by the given constant id.
Expand Down Expand Up @@ -13548,16 +13528,24 @@ parse_parameters(
pm_token_t operator = parser->previous;
context_push(parser, PM_CONTEXT_DEFAULT_PARAMS);

pm_constant_id_t saved_param_name = pm_parser_current_param_name_set(parser, pm_parser_constant_id_token(parser, &name));
pm_node_t *value = parse_value_expression(parser, binding_power, false, PM_ERR_PARAMETER_NO_DEFAULT);
pm_constant_id_t name_id = pm_parser_constant_id_token(parser, &name);
uint32_t reads = pm_locals_reads(&parser->current_scope->locals, name_id);

pm_node_t *value = parse_value_expression(parser, binding_power, false, PM_ERR_PARAMETER_NO_DEFAULT);
pm_optional_parameter_node_t *param = pm_optional_parameter_node_create(parser, &name, &operator, value);

if (repeated) {
pm_node_flag_set_repeated_parameter((pm_node_t *)param);
}
pm_parameters_node_optionals_append(params, param);

pm_parser_current_param_name_restore(parser, saved_param_name);
// If the value of the parameter increased the number of
// reads of that parameter, then we need to warn that we
// have a circular definition.
if (pm_locals_reads(&parser->current_scope->locals, name_id) != reads) {
PM_PARSER_ERR_TOKEN_FORMAT_CONTENT(parser, name, PM_ERR_PARAMETER_CIRCULAR);
}

context_pop(parser);

// If parsing the value of the parameter resulted in error recovery,
Expand Down Expand Up @@ -13626,12 +13614,15 @@ parse_parameters(
if (token_begins_expression_p(parser->current.type)) {
context_push(parser, PM_CONTEXT_DEFAULT_PARAMS);

pm_constant_id_t saved_param_name = pm_parser_current_param_name_set(parser, pm_parser_constant_id_token(parser, &local));
pm_constant_id_t name_id = pm_parser_constant_id_token(parser, &local);
uint32_t reads = pm_locals_reads(&parser->current_scope->locals, name_id);
pm_node_t *value = parse_value_expression(parser, binding_power, false, PM_ERR_PARAMETER_NO_DEFAULT_KW);

pm_parser_current_param_name_restore(parser, saved_param_name);
context_pop(parser);
if (pm_locals_reads(&parser->current_scope->locals, name_id) != reads) {
PM_PARSER_ERR_TOKEN_FORMAT_CONTENT(parser, local, PM_ERR_PARAMETER_CIRCULAR);
}

context_pop(parser);
param = (pm_node_t *) pm_optional_keyword_parameter_node_create(parser, &name, value);
}
else {
Expand Down Expand Up @@ -14067,7 +14058,6 @@ parse_block(pm_parser_t *parser) {
pm_token_t opening = parser->previous;
accept1(parser, PM_TOKEN_NEWLINE);

pm_constant_id_t saved_param_name = pm_parser_current_param_name_unset(parser);
pm_accepts_block_stack_push(parser, true);
pm_parser_scope_push(parser, false);

Expand Down Expand Up @@ -14124,7 +14114,6 @@ parse_block(pm_parser_t *parser) {

pm_parser_scope_pop(parser);
pm_accepts_block_stack_pop(parser);
pm_parser_current_param_name_restore(parser, saved_param_name);

return pm_block_node_create(parser, &locals, &opening, parameters, statements, &parser->previous);
}
Expand Down Expand Up @@ -17426,7 +17415,6 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power, b
pm_token_t operator = parser->previous;
pm_node_t *expression = parse_value_expression(parser, PM_BINDING_POWER_NOT, true, PM_ERR_EXPECT_EXPRESSION_AFTER_LESS_LESS);

pm_constant_id_t saved_param_name = pm_parser_current_param_name_unset(parser);
pm_parser_scope_push(parser, true);
accept2(parser, PM_TOKEN_NEWLINE, PM_TOKEN_SEMICOLON);

Expand All @@ -17449,7 +17437,6 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power, b

pm_parser_scope_pop(parser);
pm_do_loop_stack_pop(parser);
pm_parser_current_param_name_restore(parser, saved_param_name);

return (pm_node_t *) pm_singleton_class_node_create(parser, &locals, &class_keyword, &operator, expression, statements, &parser->previous);
}
Expand Down Expand Up @@ -17479,7 +17466,6 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power, b
superclass = NULL;
}

pm_constant_id_t saved_param_name = pm_parser_current_param_name_unset(parser);
pm_parser_scope_push(parser, true);

if (inheritance_operator.type != PM_TOKEN_NOT_PROVIDED) {
Expand Down Expand Up @@ -17511,7 +17497,6 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power, b

pm_parser_scope_pop(parser);
pm_do_loop_stack_pop(parser);
pm_parser_current_param_name_restore(parser, saved_param_name);

if (!PM_NODE_TYPE_P(constant_path, PM_CONSTANT_PATH_NODE) && !(PM_NODE_TYPE_P(constant_path, PM_CONSTANT_READ_NODE))) {
pm_parser_err_node(parser, constant_path, PM_ERR_CLASS_NAME);
Expand All @@ -17533,13 +17518,10 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power, b
// correctly. It must be pushed before lexing the first param, so it
// is here.
context_push(parser, PM_CONTEXT_DEF_PARAMS);
pm_constant_id_t saved_param_name;

parser_lex(parser);

switch (parser->current.type) {
case PM_CASE_OPERATOR:
saved_param_name = pm_parser_current_param_name_unset(parser);
pm_parser_scope_push(parser, true);
lex_state_set(parser, PM_LEX_STATE_ENDFN);
parser_lex(parser);
Expand All @@ -17553,15 +17535,13 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power, b
receiver = parse_variable_call(parser);
receiver = pm_node_check_it(parser, receiver);

saved_param_name = pm_parser_current_param_name_unset(parser);
pm_parser_scope_push(parser, true);
lex_state_set(parser, PM_LEX_STATE_FNAME);
parser_lex(parser);

operator = parser->previous;
name = parse_method_definition_name(parser);
} else {
saved_param_name = pm_parser_current_param_name_unset(parser);
pm_refute_numbered_parameter(parser, parser->previous.start, parser->previous.end);
pm_parser_scope_push(parser, true);

Expand All @@ -17581,7 +17561,6 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power, b
case PM_TOKEN_KEYWORD___FILE__:
case PM_TOKEN_KEYWORD___LINE__:
case PM_TOKEN_KEYWORD___ENCODING__: {
saved_param_name = pm_parser_current_param_name_unset(parser);
pm_parser_scope_push(parser, true);
parser_lex(parser);

Expand Down Expand Up @@ -17656,18 +17635,14 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power, b
operator = parser->previous;
receiver = (pm_node_t *) pm_parentheses_node_create(parser, &lparen, expression, &rparen);

saved_param_name = pm_parser_current_param_name_unset(parser);
pm_parser_scope_push(parser, true);

// To push `PM_CONTEXT_DEF_PARAMS` again is for the same reason as described the above.
pm_parser_scope_push(parser, true);
context_push(parser, PM_CONTEXT_DEF_PARAMS);
name = parse_method_definition_name(parser);
break;
}
default:
saved_param_name = pm_parser_current_param_name_unset(parser);
pm_parser_scope_push(parser, true);

name = parse_method_definition_name(parser);
break;
}
Expand Down Expand Up @@ -17784,9 +17759,7 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power, b

pm_constant_id_list_t locals;
pm_locals_order(parser, &parser->current_scope->locals, &locals, true);

pm_parser_scope_pop(parser);
pm_parser_current_param_name_restore(parser, saved_param_name);

/**
* If the final character is @. As is the case when defining
Expand Down Expand Up @@ -18031,9 +18004,7 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power, b
pm_parser_err_token(parser, &name, PM_ERR_MODULE_NAME);
}

pm_constant_id_t saved_param_name = pm_parser_current_param_name_unset(parser);
pm_parser_scope_push(parser, true);

accept2(parser, PM_TOKEN_SEMICOLON, PM_TOKEN_NEWLINE);
pm_node_t *statements = NULL;

Expand All @@ -18052,8 +18023,6 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power, b
pm_locals_order(parser, &parser->current_scope->locals, &locals, true);

pm_parser_scope_pop(parser);
pm_parser_current_param_name_restore(parser, saved_param_name);

expect1(parser, PM_TOKEN_KEYWORD_END, PM_ERR_MODULE_TERM);

if (context_def_p(parser)) {
Expand Down Expand Up @@ -18747,7 +18716,6 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power, b
parser_lex(parser);

pm_token_t operator = parser->previous;
pm_constant_id_t saved_param_name = pm_parser_current_param_name_unset(parser);
pm_parser_scope_push(parser, false);

pm_block_parameters_node_t *block_parameters;
Expand Down Expand Up @@ -18823,7 +18791,6 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power, b

pm_parser_scope_pop(parser);
pm_accepts_block_stack_pop(parser);
pm_parser_current_param_name_restore(parser, saved_param_name);

return (pm_node_t *) pm_lambda_node_create(parser, &locals, &operator, &opening, &parser->previous, parameters, body);
}
Expand Down Expand Up @@ -20199,7 +20166,6 @@ pm_parser_init(pm_parser_t *parser, const uint8_t *source, size_t size, const pm
.encoding_changed = false,
.pattern_matching_newlines = false,
.in_keyword_arg = false,
.current_param_name = 0,
.current_block_exits = NULL,
.semantic_token_seen = false,
.frozen_string_literal = PM_OPTIONS_FROZEN_STRING_LITERAL_UNSET,
Expand Down
10 changes: 6 additions & 4 deletions test/prism/errors_test.rb
Expand Up @@ -1949,11 +1949,13 @@ def foo(bar: bar) = 42
RUBY

assert_errors expression(source), source, [
["circular argument reference - bar", 14..17],
["circular argument reference - bar", 37..40],
["circular argument reference - foo", 61..64],
["circular argument reference - foo", 81..84],
["circular argument reference - bar", 8..11],
["circular argument reference - bar", 32..35],
["circular argument reference - foo", 55..58],
["circular argument reference - foo", 76..79]
]

refute_error_messages("def foo(bar: bar = 1); end")
end

def test_command_calls
Expand Down

0 comments on commit dcec1e0

Please sign in to comment.