Skip to content

Commit

Permalink
[PRISM] Fix EnsureNode, pass depth to get locals
Browse files Browse the repository at this point in the history
This commit fixes a bug with locals in ensure nodes by setting
the local tables correctly. It also changes accessing locals to
look at local tables in parent scopes, and account for this
correctly on depths of get or setlocals.
  • Loading branch information
jemmaissroff committed Nov 29, 2023
1 parent 6ebcf25 commit 5384194
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 39 deletions.
84 changes: 45 additions & 39 deletions prism_compile.c
Expand Up @@ -685,9 +685,11 @@ pm_interpolated_node_compile(pm_node_list_t *parts, rb_iseq_t *iseq, NODE dummy_
}
}

// This recurses through scopes and finds the local index at any scope leve
// This recurses through scopes and finds the local index at any scope level
// It also takes a pointer to depth, and increments depth appropriately
// according to the depth of the local
static int
pm_lookup_local_index_any_scope(rb_iseq_t *iseq, pm_scope_node_t *scope_node, pm_constant_id_t constant_id)
pm_lookup_local_index_any_scope(rb_iseq_t *iseq, pm_scope_node_t *scope_node, pm_constant_id_t constant_id, int *depth)
{
if (!scope_node) {
// We have recursed up all scope nodes
Expand All @@ -699,7 +701,8 @@ pm_lookup_local_index_any_scope(rb_iseq_t *iseq, pm_scope_node_t *scope_node, pm

if (!st_lookup(scope_node->index_lookup_table, constant_id, &local_index)) {
// Local does not exist at this level, continue recursing up
return pm_lookup_local_index_any_scope(iseq, scope_node->previous, constant_id);
(*depth)++;
return pm_lookup_local_index_any_scope(iseq, scope_node->previous, constant_id, depth);
}

return (int)scope_node->index_lookup_table->num_entries - (int)local_index;
Expand All @@ -720,14 +723,14 @@ pm_lookup_local_index(rb_iseq_t *iseq, pm_scope_node_t *scope_node, pm_constant_
}

static int
pm_lookup_local_index_with_depth(rb_iseq_t *iseq, pm_scope_node_t *scope_node, pm_constant_id_t constant_id, uint32_t depth)
pm_lookup_local_index_with_depth(rb_iseq_t *iseq, pm_scope_node_t *scope_node, pm_constant_id_t constant_id, int *depth)
{
for(uint32_t i = 0; i < depth; i++) {
for(int i = 0; i < *depth; i++) {
scope_node = scope_node->previous;
iseq = (rb_iseq_t *)ISEQ_BODY(iseq)->parent_iseq;
}

return pm_lookup_local_index(iseq, scope_node, constant_id);
return pm_lookup_local_index_any_scope(iseq, scope_node, constant_id, depth);
}

// This returns the CRuby ID which maps to the pm_constant_id_t
Expand Down Expand Up @@ -1782,31 +1785,30 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,
ADD_LABEL(ret, lend);

if (begin_node->ensure_clause) {
DECL_ANCHOR(ensr);

iseq_set_exception_local_table(iseq);
pm_scope_node_t next_scope_node;
pm_scope_node_init((pm_node_t *)begin_node->ensure_clause, &next_scope_node, scope_node, parser);

child_iseq = NEW_CHILD_ISEQ(next_scope_node,
rb_str_new2("ensure in"),
ISEQ_TYPE_ENSURE, lineno);
pm_statements_node_t *statements = begin_node->ensure_clause->statements;
if (statements) {
PM_COMPILE((pm_node_t *)statements);
PM_POP_UNLESS_POPPED;
}

LABEL *lcont = NEW_LABEL(lineno);
struct ensure_range er;
struct iseq_compile_data_ensure_node_stack enl;
struct ensure_range *erange;

INIT_ANCHOR(ensr);
PM_COMPILE_INTO_ANCHOR(ensr, (pm_node_t *)&next_scope_node);

er.begin = lstart;
er.end = lend;
er.next = 0;
push_ensure_entry(iseq, &enl, &er, (void *)&next_scope_node);
push_ensure_entry(iseq, &enl, &er, (void *)&begin_node->ensure_clause);

pm_scope_node_t next_scope_node;
pm_scope_node_init((pm_node_t *)begin_node->ensure_clause, &next_scope_node, scope_node, parser);

child_iseq = NEW_CHILD_ISEQ(next_scope_node,
rb_str_new2("ensure in"),
ISEQ_TYPE_ENSURE, lineno);
ISEQ_COMPILE_DATA(iseq)->current_block = child_iseq;
ADD_SEQ(ret, ensr);

ADD_LABEL(ret, lcont);
erange = ISEQ_COMPILE_DATA(iseq)->ensure_node_stack->erange;
if (lstart->link.next != &lend->link) {
Expand All @@ -1815,7 +1817,6 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,
erange = erange->next;
}
}
PM_POP_UNLESS_POPPED;
}
return;
}
Expand Down Expand Up @@ -3017,9 +3018,9 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,
LABEL *end_label = NEW_LABEL(lineno);

pm_constant_id_t constant_id = local_variable_and_write_node->name;
int depth = local_variable_and_write_node->depth;
int *depth = (int *)&local_variable_and_write_node->depth;
int local_index = pm_lookup_local_index_with_depth(iseq, scope_node, constant_id, depth);
ADD_GETLOCAL(ret, &dummy_line_node, local_index, depth);
ADD_GETLOCAL(ret, &dummy_line_node, local_index, *depth);

PM_DUP_UNLESS_POPPED;

Expand All @@ -3031,7 +3032,7 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,

PM_DUP_UNLESS_POPPED;

ADD_SETLOCAL(ret, &dummy_line_node, local_index, depth);
ADD_SETLOCAL(ret, &dummy_line_node, local_index, *depth);
ADD_LABEL(ret, end_label);

return;
Expand All @@ -3041,9 +3042,9 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,

pm_constant_id_t constant_id = local_variable_operator_write_node->name;

int depth = local_variable_operator_write_node->depth;
int *depth = (int *)&local_variable_operator_write_node->depth;
int local_index = pm_lookup_local_index_with_depth(iseq, scope_node, constant_id, depth);
ADD_GETLOCAL(ret, &dummy_line_node, local_index, depth);
ADD_GETLOCAL(ret, &dummy_line_node, local_index, *depth);

PM_COMPILE_NOT_POPPED(local_variable_operator_write_node->value);
ID method_id = pm_constant_id_lookup(scope_node, local_variable_operator_write_node->operator);
Expand All @@ -3053,7 +3054,7 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,

PM_DUP_UNLESS_POPPED;

ADD_SETLOCAL(ret, &dummy_line_node, local_index, depth);
ADD_SETLOCAL(ret, &dummy_line_node, local_index, *depth);

return;
}
Expand All @@ -3067,9 +3068,9 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,
ADD_INSNL(ret, &dummy_line_node, branchunless, set_label);

pm_constant_id_t constant_id = local_variable_or_write_node->name;
int depth = local_variable_or_write_node->depth;
int *depth = (int *)&local_variable_or_write_node->depth;
int local_index = pm_lookup_local_index_with_depth(iseq, scope_node, constant_id, depth);
ADD_GETLOCAL(ret, &dummy_line_node, local_index, depth);
ADD_GETLOCAL(ret, &dummy_line_node, local_index, *depth);

PM_DUP_UNLESS_POPPED;

Expand All @@ -3082,7 +3083,7 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,

PM_DUP_UNLESS_POPPED;

ADD_SETLOCAL(ret, &dummy_line_node, local_index, depth);
ADD_SETLOCAL(ret, &dummy_line_node, local_index, *depth);
ADD_LABEL(ret, end_label);

return;
Expand All @@ -3091,18 +3092,20 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,
pm_local_variable_read_node_t *local_read_node = (pm_local_variable_read_node_t *) node;

if (!popped) {
int index = pm_lookup_local_index_with_depth(iseq, scope_node, local_read_node->name, local_read_node->depth);
ADD_GETLOCAL(ret, &dummy_line_node, index, local_read_node->depth);
int *depth = (int *)&local_read_node->depth;
int index = pm_lookup_local_index_with_depth(iseq, scope_node, local_read_node->name, depth);
ADD_GETLOCAL(ret, &dummy_line_node, index, *depth);
}
return;
}
case PM_LOCAL_VARIABLE_TARGET_NODE: {
pm_local_variable_target_node_t *local_write_node = (pm_local_variable_target_node_t *) node;

pm_constant_id_t constant_id = local_write_node->name;
int index = pm_lookup_local_index_any_scope(iseq, scope_node, constant_id);
int depth = 0;
int index = pm_lookup_local_index_any_scope(iseq, scope_node, constant_id, &depth);

ADD_SETLOCAL(ret, &dummy_line_node, index, local_write_node->depth);
ADD_SETLOCAL(ret, &dummy_line_node, index, depth);
return;
}
case PM_LOCAL_VARIABLE_WRITE_NODE: {
Expand All @@ -3112,9 +3115,11 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,
PM_DUP_UNLESS_POPPED;

pm_constant_id_t constant_id = local_write_node->name;
int index = pm_lookup_local_index_any_scope(iseq, scope_node, constant_id);

ADD_SETLOCAL(ret, &dummy_line_node, index, local_write_node->depth);
int depth = 0;
int index = pm_lookup_local_index_any_scope(iseq, scope_node, constant_id, &depth);

ADD_SETLOCAL(ret, &dummy_line_node, index, depth);
return;
}
case PM_MATCH_LAST_LINE_NODE: {
Expand Down Expand Up @@ -3828,12 +3833,13 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,
}
case ISEQ_TYPE_ENSURE: {
iseq_set_exception_local_table(iseq);
PM_COMPILE((pm_node_t *)scope_node->body);
PM_POP;

if (scope_node->body) {
PM_COMPILE_POPPED((pm_node_t *)scope_node->body);
}

ADD_GETLOCAL(ret, &dummy_line_node, 1, 0);
ADD_INSN1(ret, &dummy_line_node, throw, INT2FIX(0));

return;
}
default:
Expand Down
32 changes: 32 additions & 0 deletions test/ruby/test_compile_prism.rb
Expand Up @@ -700,6 +700,38 @@ def test_BreakNode
def test_EnsureNode
assert_prism_eval("begin; 1; ensure; 2; end")
assert_prism_eval("begin; 1; begin; 3; ensure; 4; end; ensure; 2; end")
assert_prism_eval(<<-CODE)
begin
a = 2
ensure
end
CODE
assert_prism_eval(<<-CODE)
begin
a = 2
ensure
a = 3
end
a
CODE
assert_prism_eval(<<-CODE)
a = 1
begin
a = 2
ensure
a = 3
end
a
CODE
assert_prism_eval(<<-CODE)
a = 1
begin
b = 2
ensure
c = 3
end
a + b + c
CODE
end

def test_NextNode
Expand Down

0 comments on commit 5384194

Please sign in to comment.