Skip to content

Commit 232e77a

Browse files
committed
Reject class/module defs in method params/rescue/ensure/else
Fix #1936
1 parent 7fb4b67 commit 232e77a

File tree

6 files changed

+174
-71
lines changed

6 files changed

+174
-71
lines changed

include/prism/parser.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,9 @@ typedef enum {
297297
/** an ensure statement */
298298
PM_CONTEXT_ENSURE,
299299

300+
/** an ensure statement within a method definition */
301+
PM_CONTEXT_ENSURE_DEF,
302+
300303
/** a for loop */
301304
PM_CONTEXT_FOR,
302305

@@ -333,9 +336,15 @@ typedef enum {
333336
/** a rescue else statement */
334337
PM_CONTEXT_RESCUE_ELSE,
335338

339+
/** a rescue else statement within a method definition */
340+
PM_CONTEXT_RESCUE_ELSE_DEF,
341+
336342
/** a rescue statement */
337343
PM_CONTEXT_RESCUE,
338344

345+
/** a rescue statement within a method definition */
346+
PM_CONTEXT_RESCUE_DEF,
347+
339348
/** a singleton class definition */
340349
PM_CONTEXT_SCLASS,
341350

src/diagnostic.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ static const char* const diagnostic_messages[PM_DIAGNOSTIC_ID_LEN] = {
9090
[PM_ERR_CASE_MATCH_MISSING_PREDICATE] = "expected a predicate for a case matching statement",
9191
[PM_ERR_CASE_MISSING_CONDITIONS] = "expected a `when` or `in` clause after `case`",
9292
[PM_ERR_CASE_TERM] = "expected an `end` to close the `case` statement",
93-
[PM_ERR_CLASS_IN_METHOD] = "unexpected class definition in a method body",
93+
[PM_ERR_CLASS_IN_METHOD] = "unexpected class definition in a method definition",
9494
[PM_ERR_CLASS_NAME] = "expected a constant name after `class`",
9595
[PM_ERR_CLASS_SUPERCLASS] = "expected a superclass after `<`",
9696
[PM_ERR_CLASS_TERM] = "expected an `end` to close the `class` statement",
@@ -185,7 +185,7 @@ static const char* const diagnostic_messages[PM_DIAGNOSTIC_ID_LEN] = {
185185
[PM_ERR_LIST_W_UPPER_ELEMENT] = "expected a string in a `%W` list",
186186
[PM_ERR_LIST_W_UPPER_TERM] = "expected a closing delimiter for the `%W` list",
187187
[PM_ERR_MALLOC_FAILED] = "failed to allocate memory",
188-
[PM_ERR_MODULE_IN_METHOD] = "unexpected module definition in a method body",
188+
[PM_ERR_MODULE_IN_METHOD] = "unexpected module definition in a method definition",
189189
[PM_ERR_MODULE_NAME] = "expected a constant name after `module`",
190190
[PM_ERR_MODULE_TERM] = "expected an `end` to close the `module` statement",
191191
[PM_ERR_MULTI_ASSIGN_MULTI_SPLATS] = "multiple splats in multiple assignment",

src/prism.c

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6603,6 +6603,7 @@ context_terminator(pm_context_t context, pm_token_t *token) {
66036603
case PM_CONTEXT_ELSE:
66046604
case PM_CONTEXT_FOR:
66056605
case PM_CONTEXT_ENSURE:
6606+
case PM_CONTEXT_ENSURE_DEF:
66066607
return token->type == PM_TOKEN_KEYWORD_END;
66076608
case PM_CONTEXT_FOR_INDEX:
66086609
return token->type == PM_TOKEN_KEYWORD_IN;
@@ -6623,8 +6624,10 @@ context_terminator(pm_context_t context, pm_token_t *token) {
66236624
return token->type == PM_TOKEN_PARENTHESIS_RIGHT;
66246625
case PM_CONTEXT_BEGIN:
66256626
case PM_CONTEXT_RESCUE:
6627+
case PM_CONTEXT_RESCUE_DEF:
66266628
return token->type == PM_TOKEN_KEYWORD_ENSURE || token->type == PM_TOKEN_KEYWORD_RESCUE || token->type == PM_TOKEN_KEYWORD_ELSE || token->type == PM_TOKEN_KEYWORD_END;
66276629
case PM_CONTEXT_RESCUE_ELSE:
6630+
case PM_CONTEXT_RESCUE_ELSE_DEF:
66286631
return token->type == PM_TOKEN_KEYWORD_ENSURE || token->type == PM_TOKEN_KEYWORD_END;
66296632
case PM_CONTEXT_LAMBDA_BRACES:
66306633
return token->type == PM_TOKEN_BRACE_RIGHT;
@@ -6690,6 +6693,10 @@ context_def_p(pm_parser_t *parser) {
66906693
while (context_node != NULL) {
66916694
switch (context_node->context) {
66926695
case PM_CONTEXT_DEF:
6696+
case PM_CONTEXT_DEF_PARAMS:
6697+
case PM_CONTEXT_ENSURE_DEF:
6698+
case PM_CONTEXT_RESCUE_DEF:
6699+
case PM_CONTEXT_RESCUE_ELSE_DEF:
66936700
return true;
66946701
case PM_CONTEXT_CLASS:
66956702
case PM_CONTEXT_MODULE:
@@ -11837,7 +11844,7 @@ parse_parameters(
1183711844
* nodes pointing to each other from the top.
1183811845
*/
1183911846
static inline void
11840-
parse_rescues(pm_parser_t *parser, pm_begin_node_t *parent_node) {
11847+
parse_rescues(pm_parser_t *parser, pm_begin_node_t *parent_node, bool def_p) {
1184111848
pm_rescue_node_t *current = NULL;
1184211849

1184311850
while (accept1(parser, PM_TOKEN_KEYWORD_RESCUE)) {
@@ -11900,7 +11907,7 @@ parse_rescues(pm_parser_t *parser, pm_begin_node_t *parent_node) {
1190011907

1190111908
if (!match3(parser, PM_TOKEN_KEYWORD_ELSE, PM_TOKEN_KEYWORD_ENSURE, PM_TOKEN_KEYWORD_END)) {
1190211909
pm_accepts_block_stack_push(parser, true);
11903-
pm_statements_node_t *statements = parse_statements(parser, PM_CONTEXT_RESCUE);
11910+
pm_statements_node_t *statements = parse_statements(parser, def_p ? PM_CONTEXT_RESCUE_DEF : PM_CONTEXT_RESCUE);
1190411911
if (statements) {
1190511912
pm_rescue_node_statements_set(rescue, statements);
1190611913
}
@@ -11936,7 +11943,7 @@ parse_rescues(pm_parser_t *parser, pm_begin_node_t *parent_node) {
1193611943
pm_statements_node_t *else_statements = NULL;
1193711944
if (!match2(parser, PM_TOKEN_KEYWORD_END, PM_TOKEN_KEYWORD_ENSURE)) {
1193811945
pm_accepts_block_stack_push(parser, true);
11939-
else_statements = parse_statements(parser, PM_CONTEXT_RESCUE_ELSE);
11946+
else_statements = parse_statements(parser, def_p ? PM_CONTEXT_RESCUE_ELSE_DEF : PM_CONTEXT_RESCUE_ELSE);
1194011947
pm_accepts_block_stack_pop(parser);
1194111948
accept2(parser, PM_TOKEN_NEWLINE, PM_TOKEN_SEMICOLON);
1194211949
}
@@ -11952,7 +11959,7 @@ parse_rescues(pm_parser_t *parser, pm_begin_node_t *parent_node) {
1195211959
pm_statements_node_t *ensure_statements = NULL;
1195311960
if (!match1(parser, PM_TOKEN_KEYWORD_END)) {
1195411961
pm_accepts_block_stack_push(parser, true);
11955-
ensure_statements = parse_statements(parser, PM_CONTEXT_ENSURE);
11962+
ensure_statements = parse_statements(parser, def_p ? PM_CONTEXT_ENSURE_DEF : PM_CONTEXT_ENSURE);
1195611963
pm_accepts_block_stack_pop(parser);
1195711964
accept2(parser, PM_TOKEN_NEWLINE, PM_TOKEN_SEMICOLON);
1195811965
}
@@ -11970,10 +11977,10 @@ parse_rescues(pm_parser_t *parser, pm_begin_node_t *parent_node) {
1197011977
}
1197111978

1197211979
static inline pm_begin_node_t *
11973-
parse_rescues_as_begin(pm_parser_t *parser, pm_statements_node_t *statements) {
11980+
parse_rescues_as_begin(pm_parser_t *parser, pm_statements_node_t *statements, bool def_p) {
1197411981
pm_token_t no_begin_token = not_provided(parser);
1197511982
pm_begin_node_t *begin_node = pm_begin_node_create(parser, &no_begin_token, statements);
11976-
parse_rescues(parser, begin_node);
11983+
parse_rescues(parser, begin_node, def_p);
1197711984

1197811985
// All nodes within a begin node are optional, so we look
1197911986
// for the earliest possible node that we can use to set
@@ -12078,7 +12085,7 @@ parse_block(pm_parser_t *parser) {
1207812085

1207912086
if (match2(parser, PM_TOKEN_KEYWORD_RESCUE, PM_TOKEN_KEYWORD_ENSURE)) {
1208012087
assert(statements == NULL || PM_NODE_TYPE_P(statements, PM_STATEMENTS_NODE));
12081-
statements = (pm_node_t *) parse_rescues_as_begin(parser, (pm_statements_node_t *) statements);
12088+
statements = (pm_node_t *) parse_rescues_as_begin(parser, (pm_statements_node_t *) statements, false);
1208212089
}
1208312090
}
1208412091

@@ -14547,7 +14554,7 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power) {
1454714554
}
1454814555

1454914556
pm_begin_node_t *begin_node = pm_begin_node_create(parser, &begin_keyword, begin_statements);
14550-
parse_rescues(parser, begin_node);
14557+
parse_rescues(parser, begin_node, false);
1455114558

1455214559
expect1(parser, PM_TOKEN_KEYWORD_END, PM_ERR_BEGIN_TERM);
1455314560
begin_node->base.location.end = parser->previous.end;
@@ -14665,7 +14672,7 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power) {
1466514672

1466614673
if (match2(parser, PM_TOKEN_KEYWORD_RESCUE, PM_TOKEN_KEYWORD_ENSURE)) {
1466714674
assert(statements == NULL || PM_NODE_TYPE_P(statements, PM_STATEMENTS_NODE));
14668-
statements = (pm_node_t *) parse_rescues_as_begin(parser, (pm_statements_node_t *) statements);
14675+
statements = (pm_node_t *) parse_rescues_as_begin(parser, (pm_statements_node_t *) statements, false);
1466914676
}
1467014677

1467114678
expect1(parser, PM_TOKEN_KEYWORD_END, PM_ERR_CLASS_TERM);
@@ -14717,7 +14724,7 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power) {
1471714724

1471814725
if (match2(parser, PM_TOKEN_KEYWORD_RESCUE, PM_TOKEN_KEYWORD_ENSURE)) {
1471914726
assert(statements == NULL || PM_NODE_TYPE_P(statements, PM_STATEMENTS_NODE));
14720-
statements = (pm_node_t *) parse_rescues_as_begin(parser, (pm_statements_node_t *) statements);
14727+
statements = (pm_node_t *) parse_rescues_as_begin(parser, (pm_statements_node_t *) statements, false);
1472114728
}
1472214729

1472314730
expect1(parser, PM_TOKEN_KEYWORD_END, PM_ERR_CLASS_TERM);
@@ -14744,6 +14751,8 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power) {
1474414751
pm_token_t operator = not_provided(parser);
1474514752
pm_token_t name = (pm_token_t) { .type = PM_TOKEN_MISSING, .start = def_keyword.end, .end = def_keyword.end };
1474614753

14754+
// This context is necessary for lexing `...` in a bare params correctly.
14755+
// It must be pushed before lexing the first param, so it is here.
1474714756
context_push(parser, PM_CONTEXT_DEF_PARAMS);
1474814757
parser_lex(parser);
1474914758
pm_constant_id_t old_param_name = parser->current_param_name;
@@ -14844,7 +14853,12 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power) {
1484414853
break;
1484514854
}
1484614855
case PM_TOKEN_PARENTHESIS_LEFT: {
14856+
// The current context is `PM_CONTEXT_DEF_PARAMS`, however the inner expression
14857+
// of this parenthesis should not be processed under this context.
14858+
// Thus, the context is popped here.
14859+
context_pop(parser);
1484714860
parser_lex(parser);
14861+
1484814862
pm_token_t lparen = parser->previous;
1484914863
pm_node_t *expression = parse_value_expression(parser, PM_BINDING_POWER_STATEMENT, PM_ERR_DEF_RECEIVER);
1485014864

@@ -14859,6 +14873,9 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power) {
1485914873

1486014874
pm_parser_scope_push(parser, true);
1486114875
parser->current_param_name = 0;
14876+
14877+
// To push `PM_CONTEXT_DEF_PARAMS` again is for the same reason as described the above.
14878+
context_push(parser, PM_CONTEXT_DEF_PARAMS);
1486214879
name = parse_method_definition_name(parser);
1486314880
break;
1486414881
}
@@ -14967,7 +14984,7 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power) {
1496714984

1496814985
if (match2(parser, PM_TOKEN_KEYWORD_RESCUE, PM_TOKEN_KEYWORD_ENSURE)) {
1496914986
assert(statements == NULL || PM_NODE_TYPE_P(statements, PM_STATEMENTS_NODE));
14970-
statements = (pm_node_t *) parse_rescues_as_begin(parser, (pm_statements_node_t *) statements);
14987+
statements = (pm_node_t *) parse_rescues_as_begin(parser, (pm_statements_node_t *) statements, true);
1497114988
}
1497214989

1497314990
pm_accepts_block_stack_pop(parser);
@@ -15222,7 +15239,7 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power) {
1522215239

1522315240
if (match2(parser, PM_TOKEN_KEYWORD_RESCUE, PM_TOKEN_KEYWORD_ENSURE)) {
1522415241
assert(statements == NULL || PM_NODE_TYPE_P(statements, PM_STATEMENTS_NODE));
15225-
statements = (pm_node_t *) parse_rescues_as_begin(parser, (pm_statements_node_t *) statements);
15242+
statements = (pm_node_t *) parse_rescues_as_begin(parser, (pm_statements_node_t *) statements, false);
1522615243
}
1522715244

1522815245
pm_constant_id_list_t locals = parser->current_scope->locals;
@@ -15893,7 +15910,7 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power) {
1589315910

1589415911
if (match2(parser, PM_TOKEN_KEYWORD_RESCUE, PM_TOKEN_KEYWORD_ENSURE)) {
1589515912
assert(body == NULL || PM_NODE_TYPE_P(body, PM_STATEMENTS_NODE));
15896-
body = (pm_node_t *) parse_rescues_as_begin(parser, (pm_statements_node_t *) body);
15913+
body = (pm_node_t *) parse_rescues_as_begin(parser, (pm_statements_node_t *) body, false);
1589715914
}
1589815915

1589915916
expect1(parser, PM_TOKEN_KEYWORD_END, PM_ERR_LAMBDA_TERM_END);

test/prism/errors_test.rb

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ def test_module_definition_in_method_body
428428
)
429429

430430
assert_errors expected, "def foo;module A;end;end", [
431-
["unexpected module definition in a method body", 8..14]
431+
["unexpected module definition in a method definition", 8..14]
432432
]
433433
end
434434

@@ -467,7 +467,7 @@ def test_module_definition_in_method_body_within_block
467467
Location()
468468
)
469469

470-
assert_errors expected, <<~RUBY, [["unexpected module definition in a method body", 21..27]]
470+
assert_errors expected, <<~RUBY, [["unexpected module definition in a method definition", 21..27]]
471471
def foo
472472
bar do
473473
module Foo;end
@@ -476,6 +476,20 @@ module Foo;end
476476
RUBY
477477
end
478478

479+
def test_module_definition_in_method_defs
480+
source = <<~RUBY
481+
def foo(bar = module A;end);end
482+
def foo;rescue;module A;end;end
483+
def foo;ensure;module A;end;end
484+
RUBY
485+
message = "unexpected module definition in a method definition"
486+
assert_errors expression(source), source, [
487+
[message, 14..20],
488+
[message, 47..53],
489+
[message, 79..85],
490+
]
491+
end
492+
479493
def test_class_definition_in_method_body
480494
expected = DefNode(
481495
:foo,
@@ -504,7 +518,21 @@ def test_class_definition_in_method_body
504518
)
505519

506520
assert_errors expected, "def foo;class A;end;end", [
507-
["unexpected class definition in a method body", 8..13]
521+
["unexpected class definition in a method definition", 8..13]
522+
]
523+
end
524+
525+
def test_class_definition_in_method_defs
526+
source = <<~RUBY
527+
def foo(bar = class A;end);end
528+
def foo;rescue;class A;end;end
529+
def foo;ensure;class A;end;end
530+
RUBY
531+
message = "unexpected class definition in a method definition"
532+
assert_errors expression(source), source, [
533+
[message, 14..19],
534+
[message, 46..51],
535+
[message, 77..82],
508536
]
509537
end
510538

test/prism/fixtures/methods.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,5 @@ def foo(...)
182182
end
183183

184184
def foo(bar = (def baz(bar) = bar; 1)) = 2
185+
186+
def (class Foo; end).foo(bar = 1) = 2

0 commit comments

Comments
 (0)