Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PRISM] opt_case_dispatch #9887

Merged
merged 2 commits into from
Feb 9, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
271 changes: 222 additions & 49 deletions prism_compile.c
Original file line number Diff line number Diff line change
Expand Up @@ -4025,6 +4025,54 @@ pm_compile_constant_path(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *co
}
}

/**
* When we're compiling a case node, it's possible that we can speed it up using
* a dispatch hash, which will allow us to jump directly to the correct when
* clause body based on a hash lookup of the value. This can only happen when
* the conditions are literals that can be compiled into a hash key.
*
* This function accepts a dispatch hash and the condition of a when clause. It
* is responsible for compiling the condition into a hash key and then adding it
* to the dispatch hash.
*
* If the value can be successfully compiled into the hash, then this function
* returns the dispatch hash with the new key added. If the value cannot be
* compiled into the hash, then this function returns Qundef. In the case of
* Qundef, this function is signaling that the caller should abandon the
* optimization entirely.
*/
static VALUE
pm_compile_case_node_dispatch(VALUE dispatch, const pm_node_t *node, LABEL *label, const pm_scope_node_t *scope_node)
{
VALUE key = Qundef;

switch (PM_NODE_TYPE(node)) {
case PM_FALSE_NODE:
case PM_FLOAT_NODE:
case PM_INTEGER_NODE:
case PM_NIL_NODE:
case PM_SOURCE_FILE_NODE:
case PM_SOURCE_LINE_NODE:
case PM_SYMBOL_NODE:
case PM_TRUE_NODE:
key = pm_static_literal_value(node, scope_node, scope_node->parser);
break;
case PM_STRING_NODE: {
const pm_string_node_t *cast = (const pm_string_node_t *) node;
key = rb_fstring(parse_string_encoded(node, &cast->unescaped, scope_node->parser));
break;
}
default:
return Qundef;
}

if (NIL_P(rb_hash_lookup(dispatch, key))) {
rb_hash_aset(dispatch, key, ((VALUE) label) | 1);
}

return dispatch;
}

/*
* Compiles a prism node into instruction sequences
*
Expand Down Expand Up @@ -4463,77 +4511,202 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,
return;
}
case PM_CASE_NODE: {
pm_case_node_t *case_node = (pm_case_node_t *)node;
bool has_predicate = case_node->predicate;
if (has_predicate) {
PM_COMPILE_NOT_POPPED(case_node->predicate);
}
LABEL *end_label = NEW_LABEL(lineno);
// case foo; when bar; end
// ^^^^^^^^^^^^^^^^^^^^^^^
const pm_case_node_t *cast = (const pm_case_node_t *) node;
const pm_node_list_t *conditions = &cast->conditions;

// This is the anchor that we will compile the conditions of the various
// `when` nodes into. If a match is found, they will need to jump into
// the body_seq anchor to the correct spot.
DECL_ANCHOR(cond_seq);
INIT_ANCHOR(cond_seq);

pm_node_list_t conditions = case_node->conditions;
// This is the anchor that we will compile the bodies of the various
// `when` nodes into. We'll make sure that the clauses that are compiled
// jump into the correct spots within this anchor.
DECL_ANCHOR(body_seq);
INIT_ANCHOR(body_seq);

LABEL **conditions_labels = (LABEL **)ALLOCA_N(VALUE, conditions.size + 1);
LABEL *label;
// This is the label where all of the when clauses will jump to if they
// have matched and are done executing their bodies.
LABEL *end_label = NEW_LABEL(lineno);

for (size_t i = 0; i < conditions.size; i++) {
label = NEW_LABEL(lineno);
conditions_labels[i] = label;
pm_when_node_t *when_node = (pm_when_node_t *)conditions.nodes[i];

for (size_t i = 0; i < when_node->conditions.size; i++) {
pm_node_t *condition_node = when_node->conditions.nodes[i];

if (PM_NODE_TYPE_P(condition_node, PM_SPLAT_NODE)) {
int checkmatch_type = has_predicate ? VM_CHECKMATCH_TYPE_CASE : VM_CHECKMATCH_TYPE_WHEN;
ADD_INSN (ret, &dummy_line_node, dup);
PM_COMPILE_NOT_POPPED(condition_node);
ADD_INSN1(ret, &dummy_line_node, checkmatch,
INT2FIX(checkmatch_type | VM_CHECKMATCH_ARRAY));
// If we have a predicate on this case statement, then it's going to
// compare all of the various when clauses to the predicate. If we
// don't, then it's basically an if-elsif-else chain.
if (cast->predicate == NULL) {
// Loop through each clauses in the case node and compile each of
// the conditions within them into cond_seq. If they match, they
// should jump into their respective bodies in body_seq.
for (size_t clause_index = 0; clause_index < conditions->size; clause_index++) {
const pm_when_node_t *clause = (const pm_when_node_t *) conditions->nodes[clause_index];
const pm_node_list_t *conditions = &clause->conditions;

int clause_lineno = (int) pm_newline_list_line_column(&scope_node->parser->newline_list, clause->base.location.start).line;
LABEL *label = NEW_LABEL(clause_lineno);

ADD_LABEL(body_seq, label);
if (clause->statements != NULL) {
pm_compile_node(iseq, (const pm_node_t *) clause->statements, body_seq, popped, scope_node);
}
else {
PM_COMPILE_NOT_POPPED(condition_node);
if (has_predicate) {
ADD_INSN1(ret, &dummy_line_node, topn, INT2FIX(1));
ADD_SEND_WITH_FLAG(ret, &dummy_line_node, idEqq, INT2NUM(1), INT2FIX(VM_CALL_FCALL | VM_CALL_ARGS_SIMPLE));
}
else if (!popped) {
ADD_INSN(body_seq, &dummy_line_node, putnil);
}

ADD_INSNL(ret, &dummy_line_node, branchif, label);
ADD_INSNL(body_seq, &dummy_line_node, jump, end_label);

// Compile each of the conditions for the when clause into the
// cond_seq. Each one should have a unique condition and should
// jump to the subsequent one if it doesn't match.
for (size_t condition_index = 0; condition_index < conditions->size; condition_index++) {
const pm_node_t *condition = conditions->nodes[condition_index];

if (PM_NODE_TYPE_P(condition, PM_SPLAT_NODE)) {
ADD_INSN(cond_seq, &dummy_line_node, putnil);
pm_compile_node(iseq, condition, cond_seq, false, scope_node);
ADD_INSN1(cond_seq, &dummy_line_node, checkmatch, INT2FIX(VM_CHECKMATCH_TYPE_WHEN | VM_CHECKMATCH_ARRAY));
ADD_INSNL(cond_seq, &dummy_line_node, branchif, label);
}
else {
int condition_lineno = (int) pm_newline_list_line_column(&scope_node->parser->newline_list, condition->location.start).line;
LABEL *next_label = NEW_LABEL(condition_lineno);

pm_compile_branch_condition(iseq, cond_seq, condition, label, next_label, false, scope_node);
ADD_LABEL(cond_seq, next_label);
}
}
}
}

if (has_predicate) {
PM_POP;
}
// Compile the consequent else clause if there is one.
if (cast->consequent) {
pm_compile_node(iseq, (const pm_node_t *) cast->consequent, cond_seq, popped, scope_node);
}
else if (!popped) {
ADD_INSN(cond_seq, &dummy_line_node, putnil);
}

if (case_node->consequent) {
PM_COMPILE((pm_node_t *)case_node->consequent);
// Finally, jump to the end label if none of the other conditions
// have matched.
ADD_INSNL(cond_seq, &dummy_line_node, jump, end_label);
ADD_SEQ(ret, cond_seq);
}
else {
PM_PUTNIL_UNLESS_POPPED;
}
// This is the label where everything will fall into if none of the
// conditions matched.
LABEL *else_label = NEW_LABEL(lineno);

// It's possible for us to speed up the case node by using a
// dispatch hash. This is a hash that maps the conditions of the
// various when clauses to the labels of their bodies. If we can
// compile the conditions into a hash key, then we can use a hash
// lookup to jump directly to the correct when clause body.
VALUE dispatch = Qundef;
if (ISEQ_COMPILE_DATA(iseq)->option->specialized_instruction) {
dispatch = rb_hash_new();
RHASH_TBL_RAW(dispatch)->type = &cdhash_type;
}

// We're going to loop through each of the conditions in the case
// node and compile each of their contents into both the cond_seq
// and the body_seq. Each condition will use its own label to jump
// from its conditions into its body.
//
// Note that none of the code in the loop below should be adding
// anything to ret, as we're going to be laying out the entire case
// node instructions later.
for (size_t clause_index = 0; clause_index < conditions->size; clause_index++) {
const pm_when_node_t *clause = (const pm_when_node_t *) conditions->nodes[clause_index];
const pm_node_list_t *conditions = &clause->conditions;

LABEL *label = NEW_LABEL(lineno);

// Compile each of the conditions for the when clause into the
// cond_seq. Each one should have a unique comparison that then
// jumps into the body if it matches.
for (size_t condition_index = 0; condition_index < conditions->size; condition_index++) {
const pm_node_t *condition = conditions->nodes[condition_index];

// If we haven't already abandoned the optimization, then
// we're going to try to compile the condition into the
// dispatch hash.
if (dispatch != Qundef) {
dispatch = pm_compile_case_node_dispatch(dispatch, condition, label, scope_node);
}

ADD_INSNL(ret, &dummy_line_node, jump, end_label);
if (PM_NODE_TYPE_P(condition, PM_SPLAT_NODE)) {
ADD_INSN(cond_seq, &dummy_line_node, dup);
pm_compile_node(iseq, condition, cond_seq, false, scope_node);
ADD_INSN1(cond_seq, &dummy_line_node, checkmatch, INT2FIX(VM_CHECKMATCH_TYPE_CASE | VM_CHECKMATCH_ARRAY));
}
else {
if (PM_NODE_TYPE_P(condition, PM_STRING_NODE)) {
const pm_string_node_t *string = (const pm_string_node_t *) condition;
VALUE value = rb_fstring(parse_string_encoded((const pm_node_t *) string, &string->unescaped, parser));
ADD_INSN1(cond_seq, &dummy_line_node, putobject, value);
}
else {
pm_compile_node(iseq, condition, cond_seq, false, scope_node);
}

for (size_t i = 0; i < conditions.size; i++) {
label = conditions_labels[i];
ADD_LABEL(ret, label);
if (has_predicate) {
PM_POP;
ADD_INSN1(cond_seq, &dummy_line_node, topn, INT2FIX(1));
ADD_SEND_WITH_FLAG(cond_seq, &dummy_line_node, idEqq, INT2NUM(1), INT2FIX(VM_CALL_FCALL | VM_CALL_ARGS_SIMPLE));
}

ADD_INSNL(cond_seq, &dummy_line_node, branchif, label);
}

// Now, add the label to the body and compile the body of the
// when clause. This involves popping the predicate, compiling
// the statements to be executed, and then compiling a jump to
// the end of the case node.
ADD_LABEL(body_seq, label);
ADD_INSN(body_seq, &dummy_line_node, pop);

if (clause->statements != NULL) {
pm_compile_node(iseq, (const pm_node_t *) clause->statements, body_seq, popped, scope_node);
}
else if (!popped) {
ADD_INSN(body_seq, &dummy_line_node, putnil);
}

ADD_INSNL(body_seq, &dummy_line_node, jump, end_label);
}

pm_while_node_t *condition_node = (pm_while_node_t *)conditions.nodes[i];
if (condition_node->statements) {
PM_COMPILE((pm_node_t *)condition_node->statements);
// Now that we have compiled the conditions and the bodies of the
// various when clauses, we can compile the predicate, lay out the
// conditions, compile the fallback consequent if there is one, and
// finally put in the bodies of the when clauses.
PM_COMPILE_NOT_POPPED(cast->predicate);

// If we have a dispatch hash, then we'll use it here to create the
// optimization.
if (dispatch != Qundef) {
PM_DUP;
ADD_INSN2(ret, &dummy_line_node, opt_case_dispatch, dispatch, else_label);
LABEL_REF(else_label);
}
else {
PM_PUTNIL_UNLESS_POPPED;

ADD_SEQ(ret, cond_seq);

// Compile either the explicit else clause or an implicit else
// clause.
ADD_LABEL(ret, else_label);
PM_POP;

if (cast->consequent != NULL) {
PM_COMPILE((const pm_node_t *) cast->consequent);
}
else if (!popped) {
PM_PUTNIL;
}

ADD_INSNL(ret, &dummy_line_node, jump, end_label);
}

ADD_SEQ(ret, body_seq);
ADD_LABEL(ret, end_label);

return;
}
case PM_CASE_MATCH_NODE: {
Expand Down