diff --git a/prism_compile.c b/prism_compile.c index 27c77c0410876a..e9c8c7f88f8aaa 100644 --- a/prism_compile.c +++ b/prism_compile.c @@ -214,7 +214,7 @@ parse_symbol(const uint8_t *start, const uint8_t *end, pm_parser_t *parser) } static inline ID -parse_string_symbol(pm_string_t *string, pm_parser_t *parser) +parse_string_symbol(const pm_string_t *string, pm_parser_t *parser) { const uint8_t *start = pm_string_source(string); return parse_symbol(start, start + pm_string_length(string), parser); @@ -409,6 +409,53 @@ pm_static_literal_value(const pm_node_t *node, pm_scope_node_t *scope_node, pm_p } } +/** + * Currently, the ADD_INSN family of macros expects a NODE as the second + * parameter. It uses this node to determine the line number and the node ID for + * the instruction. + * + * Because prism does not use the NODE struct (or have node IDs for that matter) + * we need to generate a dummy node to pass to these macros. We also need to use + * the line number from the node to generate labels. + * + * We use this struct to store the dummy node and the line number together so + * that we can use it while we're compiling code. + * + * In the future, we'll need to eventually remove this dependency and figure out + * a more permanent solution. For the line numbers, this shouldn't be too much + * of a problem, we can redefine the ADD_INSN family of macros. For the node ID, + * we can probably replace it directly with the column information since we have + * that at the time that we're generating instructions. In theory this could + * make node ID unnecessary. + */ +typedef struct { + NODE node; + int lineno; +} pm_line_node_t; + +/** + * The function generates a dummy node and stores the line number after it looks + * it up for the given scope and node. (The scope in this case is just used + * because it holds a reference to the parser, which holds a reference to the + * newline list that we need to look up the line numbers.) + */ +static void +pm_line_node(pm_line_node_t *line_node, const pm_scope_node_t *scope_node, const pm_node_t *node) +{ + // First, clear out the pointer. + memset(line_node, 0, sizeof(pm_line_node_t)); + + // Next, retrieve the line and column information from prism. + pm_line_column_t line_column = pm_newline_list_line_column(&scope_node->parser->newline_list, node->location.start); + + // Next, use the line number for the dummy node. + int lineno = (int) line_column.line; + + nd_set_line(&line_node->node, lineno); + nd_set_node_id(&line_node->node, lineno); + line_node->lineno = lineno; +} + static void pm_compile_branch_condition(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const pm_node_t *cond, LABEL *then_label, LABEL *else_label, const uint8_t *src, bool popped, pm_scope_node_t *scope_node); @@ -1195,60 +1242,706 @@ pm_compile_multi_write_lhs(rb_iseq_t *iseq, NODE dummy_line_node, const pm_node_ return pushed; } +// When we compile a pattern matching expression, we use the stack as a scratch +// space to store lots of different values (consider it like we have a pattern +// matching function and we need space for a bunch of different local +// variables). The "base index" refers to the index on the stack where we +// started compiling the pattern matching expression. These offsets from that +// base index indicate the location of the various locals we need. +#define PM_PATTERN_BASE_INDEX_OFFSET_DECONSTRUCTED_CACHE 0 +#define PM_PATTERN_BASE_INDEX_OFFSET_ERROR_STRING 1 +#define PM_PATTERN_BASE_INDEX_OFFSET_KEY_ERROR_P 2 +#define PM_PATTERN_BASE_INDEX_OFFSET_KEY_ERROR_MATCHEE 3 +#define PM_PATTERN_BASE_INDEX_OFFSET_KEY_ERROR_KEY 4 + +// A forward declaration because this is the recursive function that handles +// compiling a pattern. It can be reentered by nesting patterns, as in the case +// of arrays or hashes. +static int pm_compile_pattern(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const pm_node_t *node, LINK_ANCHOR *const ret, const uint8_t *src, LABEL *matched_label, LABEL *unmatched_label, bool in_single_pattern, bool in_alternation_pattern, bool use_deconstructed_cache, unsigned int base_index); + +/** + * This function generates the code to set up the error string and error_p + * locals depending on whether or not the pattern matched. + */ +static int +pm_compile_pattern_generic_error(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const pm_node_t *node, LINK_ANCHOR *const ret, VALUE message, unsigned int base_index) +{ + pm_line_node_t line; + pm_line_node(&line, scope_node, node); + + LABEL *match_succeeded_label = NEW_LABEL(line.lineno); + + ADD_INSN(ret, &line.node, dup); + ADD_INSNL(ret, &line.node, branchif, match_succeeded_label); + + ADD_INSN1(ret, &line.node, putspecialobject, INT2FIX(VM_SPECIAL_OBJECT_VMCORE)); + ADD_INSN1(ret, &line.node, putobject, message); + ADD_INSN1(ret, &line.node, topn, INT2FIX(3)); + ADD_SEND(ret, &line.node, id_core_sprintf, INT2FIX(2)); + ADD_INSN1(ret, &line.node, setn, INT2FIX(base_index + PM_PATTERN_BASE_INDEX_OFFSET_ERROR_STRING + 1)); + + ADD_INSN1(ret, &line.node, putobject, Qfalse); + ADD_INSN1(ret, &line.node, setn, INT2FIX(base_index + PM_PATTERN_BASE_INDEX_OFFSET_KEY_ERROR_P + 2)); + + ADD_INSN(ret, &line.node, pop); + ADD_INSN(ret, &line.node, pop); + ADD_LABEL(ret, match_succeeded_label); + + return COMPILE_OK; +} + +/** + * This function generates the code to set up the error string and error_p + * locals depending on whether or not the pattern matched when the value needs + * to match a specific deconstructed length. + */ +static int +pm_compile_pattern_length_error(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const pm_node_t *node, LINK_ANCHOR *const ret, VALUE message, VALUE length, unsigned int base_index) +{ + pm_line_node_t line; + pm_line_node(&line, scope_node, node); + + LABEL *match_succeeded_label = NEW_LABEL(line.lineno); + + ADD_INSN(ret, &line.node, dup); + ADD_INSNL(ret, &line.node, branchif, match_succeeded_label); + + ADD_INSN1(ret, &line.node, putspecialobject, INT2FIX(VM_SPECIAL_OBJECT_VMCORE)); + ADD_INSN1(ret, &line.node, putobject, message); + ADD_INSN1(ret, &line.node, topn, INT2FIX(3)); + ADD_INSN(ret, &line.node, dup); + ADD_SEND(ret, &line.node, idLength, INT2FIX(0)); + ADD_INSN1(ret, &line.node, putobject, length); + ADD_SEND(ret, &line.node, id_core_sprintf, INT2FIX(4)); + ADD_INSN1(ret, &line.node, setn, INT2FIX(base_index + PM_PATTERN_BASE_INDEX_OFFSET_ERROR_STRING + 1)); + + ADD_INSN1(ret, &line.node, putobject, Qfalse); + ADD_INSN1(ret, &line.node, setn, INT2FIX(base_index + PM_PATTERN_BASE_INDEX_OFFSET_KEY_ERROR_P + 2)); + + ADD_INSN(ret, &line.node, pop); + ADD_INSN(ret, &line.node, pop); + ADD_LABEL(ret, match_succeeded_label); + + return COMPILE_OK; +} + +/** + * This function generates the code to set up the error string and error_p + * locals depending on whether or not the pattern matched when the value needs + * to pass a specific #=== method call. + */ +static int +pm_compile_pattern_eqq_error(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const pm_node_t *node, LINK_ANCHOR *const ret, unsigned int base_index) +{ + pm_line_node_t line; + pm_line_node(&line, scope_node, node); + + LABEL *match_succeeded_label = NEW_LABEL(line.lineno); + + ADD_INSN(ret, &line.node, dup); + ADD_INSNL(ret, &line.node, branchif, match_succeeded_label); + + ADD_INSN1(ret, &line.node, putspecialobject, INT2FIX(VM_SPECIAL_OBJECT_VMCORE)); + ADD_INSN1(ret, &line.node, putobject, rb_fstring_lit("%p === %p does not return true")); + ADD_INSN1(ret, &line.node, topn, INT2FIX(3)); + ADD_INSN1(ret, &line.node, topn, INT2FIX(5)); + ADD_SEND(ret, &line.node, id_core_sprintf, INT2FIX(3)); + ADD_INSN1(ret, &line.node, setn, INT2FIX(base_index + PM_PATTERN_BASE_INDEX_OFFSET_ERROR_STRING + 1)); + ADD_INSN1(ret, &line.node, putobject, Qfalse); + ADD_INSN1(ret, &line.node, setn, INT2FIX(base_index + PM_PATTERN_BASE_INDEX_OFFSET_KEY_ERROR_P + 2)); + ADD_INSN(ret, &line.node, pop); + ADD_INSN(ret, &line.node, pop); + + ADD_LABEL(ret, match_succeeded_label); + ADD_INSN1(ret, &line.node, setn, INT2FIX(2)); + ADD_INSN(ret, &line.node, pop); + ADD_INSN(ret, &line.node, pop); + + return COMPILE_OK; +} + +/** + * This is a variation on compiling a pattern matching expression that is used + * to have the pattern matching instructions fall through to immediately after + * the pattern if it passes. Otherwise it jumps to the given unmatched_label + * label. + */ +static int +pm_compile_pattern_match(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const pm_node_t *node, LINK_ANCHOR *const ret, const uint8_t *src, LABEL *unmatched_label, bool in_single_pattern, bool in_alternation_pattern, bool use_deconstructed_cache, unsigned int base_index) +{ + LABEL *matched_label = NEW_LABEL(nd_line(node)); + CHECK(pm_compile_pattern(iseq, scope_node, node, ret, src, matched_label, unmatched_label, in_single_pattern, in_alternation_pattern, use_deconstructed_cache, base_index)); + ADD_LABEL(ret, matched_label); + return COMPILE_OK; +} + +/** + * This function compiles in the code necessary to call #deconstruct on the + * value to match against. It raises appropriate errors if the method does not + * exist or if it returns the wrong type. + */ +static int +pm_compile_pattern_deconstruct(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const pm_node_t *node, LINK_ANCHOR *const ret, const uint8_t *src, LABEL *deconstruct_label, LABEL *match_failed_label, LABEL *deconstructed_label, LABEL *type_error_label, bool in_single_pattern, bool use_deconstructed_cache, unsigned int base_index) +{ + pm_line_node_t line; + pm_line_node(&line, scope_node, node); + + if (use_deconstructed_cache) { + ADD_INSN1(ret, &line.node, topn, INT2FIX(base_index + PM_PATTERN_BASE_INDEX_OFFSET_DECONSTRUCTED_CACHE)); + ADD_INSNL(ret, &line.node, branchnil, deconstruct_label); + + ADD_INSN1(ret, &line.node, topn, INT2FIX(base_index + PM_PATTERN_BASE_INDEX_OFFSET_DECONSTRUCTED_CACHE)); + ADD_INSNL(ret, &line.node, branchunless, match_failed_label); + + ADD_INSN(ret, &line.node, pop); + ADD_INSN1(ret, &line.node, topn, INT2FIX(base_index + PM_PATTERN_BASE_INDEX_OFFSET_DECONSTRUCTED_CACHE - 1)); + ADD_INSNL(ret, &line.node, jump, deconstructed_label); + } else { + ADD_INSNL(ret, &line.node, jump, deconstruct_label); + } + + ADD_LABEL(ret, deconstruct_label); + ADD_INSN(ret, &line.node, dup); + ADD_INSN1(ret, &line.node, putobject, ID2SYM(rb_intern("deconstruct"))); + ADD_SEND(ret, &line.node, idRespond_to, INT2FIX(1)); + + if (use_deconstructed_cache) { + ADD_INSN1(ret, &line.node, setn, INT2FIX(base_index + PM_PATTERN_BASE_INDEX_OFFSET_DECONSTRUCTED_CACHE + 1)); + } + + if (in_single_pattern) { + CHECK(pm_compile_pattern_generic_error(iseq, scope_node, node, ret, rb_fstring_lit("%p does not respond to #deconstruct"), base_index + 1)); + } + + ADD_INSNL(ret, &line.node, branchunless, match_failed_label); + ADD_SEND(ret, &line.node, rb_intern("deconstruct"), INT2FIX(0)); + + if (use_deconstructed_cache) { + ADD_INSN1(ret, &line.node, setn, INT2FIX(base_index + PM_PATTERN_BASE_INDEX_OFFSET_DECONSTRUCTED_CACHE)); + } + + ADD_INSN(ret, &line.node, dup); + ADD_INSN1(ret, &line.node, checktype, INT2FIX(T_ARRAY)); + ADD_INSNL(ret, &line.node, branchunless, type_error_label); + ADD_LABEL(ret, deconstructed_label); + + return COMPILE_OK; +} + +/** + * This function compiles in the code necessary to match against the optional + * constant path that is attached to an array, find, or hash pattern. + */ +static int +pm_compile_pattern_constant(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const pm_node_t *node, LINK_ANCHOR *const ret, const uint8_t *src, LABEL *match_failed_label, bool in_single_pattern, unsigned int base_index) +{ + pm_line_node_t line; + pm_line_node(&line, scope_node, node); + + ADD_INSN(ret, &line.node, dup); + PM_COMPILE_NOT_POPPED(node); + + if (in_single_pattern) { + ADD_INSN1(ret, &line.node, dupn, INT2FIX(2)); + } + ADD_INSN1(ret, &line.node, checkmatch, INT2FIX(VM_CHECKMATCH_TYPE_CASE)); + if (in_single_pattern) { + CHECK(pm_compile_pattern_eqq_error(iseq, scope_node, node, ret, base_index + 3)); + } + ADD_INSNL(ret, &line.node, branchunless, match_failed_label); + return COMPILE_OK; +} + +/** + * When matching fails, an appropriate error must be raised. This function is + * responsible for compiling in those error raising instructions. + */ +static void +pm_compile_pattern_error_handler(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t *node, LINK_ANCHOR *const ret, const uint8_t *src, LABEL *done_label, bool popped) +{ + pm_line_node_t line; + pm_line_node(&line, scope_node, node); + + LABEL *key_error_label = NEW_LABEL(line.lineno); + LABEL *cleanup_label = NEW_LABEL(line.lineno); + + struct rb_callinfo_kwarg *kw_arg = rb_xmalloc_mul_add(2, sizeof(VALUE), sizeof(struct rb_callinfo_kwarg)); + kw_arg->references = 0; + kw_arg->keyword_len = 2; + kw_arg->keywords[0] = ID2SYM(rb_intern("matchee")); + kw_arg->keywords[1] = ID2SYM(rb_intern("key")); + + ADD_INSN1(ret, &line.node, putspecialobject, INT2FIX(VM_SPECIAL_OBJECT_VMCORE)); + ADD_INSN1(ret, &line.node, topn, INT2FIX(PM_PATTERN_BASE_INDEX_OFFSET_KEY_ERROR_P + 2)); + ADD_INSNL(ret, &line.node, branchif, key_error_label); + + ADD_INSN1(ret, &line.node, putobject, rb_eNoMatchingPatternError); + ADD_INSN1(ret, &line.node, putspecialobject, INT2FIX(VM_SPECIAL_OBJECT_VMCORE)); + ADD_INSN1(ret, &line.node, putobject, rb_fstring_lit("%p: %s")); + ADD_INSN1(ret, &line.node, topn, INT2FIX(4)); + ADD_INSN1(ret, &line.node, topn, INT2FIX(PM_PATTERN_BASE_INDEX_OFFSET_ERROR_STRING + 6)); + ADD_SEND(ret, &line.node, id_core_sprintf, INT2FIX(3)); + ADD_SEND(ret, &line.node, id_core_raise, INT2FIX(2)); + ADD_INSNL(ret, &line.node, jump, cleanup_label); + + ADD_LABEL(ret, key_error_label); + ADD_INSN1(ret, &line.node, putobject, rb_eNoMatchingPatternKeyError); + ADD_INSN1(ret, &line.node, putspecialobject, INT2FIX(VM_SPECIAL_OBJECT_VMCORE)); + ADD_INSN1(ret, &line.node, putobject, rb_fstring_lit("%p: %s")); + ADD_INSN1(ret, &line.node, topn, INT2FIX(4)); + ADD_INSN1(ret, &line.node, topn, INT2FIX(PM_PATTERN_BASE_INDEX_OFFSET_ERROR_STRING + 6)); + ADD_SEND(ret, &line.node, id_core_sprintf, INT2FIX(3)); + ADD_INSN1(ret, &line.node, topn, INT2FIX(PM_PATTERN_BASE_INDEX_OFFSET_KEY_ERROR_MATCHEE + 4)); + ADD_INSN1(ret, &line.node, topn, INT2FIX(PM_PATTERN_BASE_INDEX_OFFSET_KEY_ERROR_KEY + 5)); + ADD_SEND_R(ret, &line.node, rb_intern("new"), INT2FIX(1), NULL, INT2FIX(VM_CALL_KWARG), kw_arg); + ADD_SEND(ret, &line.node, id_core_raise, INT2FIX(1)); + ADD_LABEL(ret, cleanup_label); + + ADD_INSN1(ret, &line.node, adjuststack, INT2FIX(7)); + if (!popped) ADD_INSN(ret, &line.node, putnil); + ADD_INSNL(ret, &line.node, jump, done_label); + ADD_INSN1(ret, &line.node, dupn, INT2FIX(5)); + if (popped) ADD_INSN(ret, &line.node, putnil); +} + /** * Compile a pattern matching expression. */ static int -pm_compile_pattern(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret, const uint8_t *src, pm_scope_node_t *scope_node, LABEL *matched_label, LABEL *unmatched_label, bool in_alternation_pattern) +pm_compile_pattern(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const pm_node_t *node, LINK_ANCHOR *const ret, const uint8_t *src, LABEL *matched_label, LABEL *unmatched_label, bool in_single_pattern, bool in_alternation_pattern, bool use_deconstructed_cache, unsigned int base_index) { - int lineno = (int) pm_newline_list_line_column(&scope_node->parser->newline_list, node->location.start).line; - NODE dummy_line_node = generate_dummy_line_node(lineno, lineno); + pm_line_node_t line; + pm_line_node(&line, scope_node, node); switch (PM_NODE_TYPE(node)) { - case PM_ARRAY_PATTERN_NODE: - rb_bug("Array pattern matching not yet supported."); - break; - case PM_FIND_PATTERN_NODE: - rb_bug("Find pattern matching not yet supported."); - break; - case PM_HASH_PATTERN_NODE: - rb_bug("Hash pattern matching not yet supported."); + case PM_ARRAY_PATTERN_NODE: { + // Array patterns in pattern matching are triggered by using commas in + // a pattern or wrapping it in braces. They are represented by a + // ArrayPatternNode. This looks like: + // + // foo => [1, 2, 3] + // + // It can optionally have a splat in the middle of it, which can + // optionally have a name attached. + const pm_array_pattern_node_t *cast = (const pm_array_pattern_node_t *) node; + + const size_t requireds_size = cast->requireds.size; + const size_t posts_size = cast->posts.size; + const size_t minimum_size = requireds_size + posts_size; + + bool use_rest_size = ( + cast->rest != NULL && + PM_NODE_TYPE_P(cast->rest, PM_SPLAT_NODE) && + ((((const pm_splat_node_t *) cast->rest)->expression != NULL) || posts_size > 0) + ); + + LABEL *match_failed_label = NEW_LABEL(line.lineno); + LABEL *type_error_label = NEW_LABEL(line.lineno); + LABEL *deconstruct_label = NEW_LABEL(line.lineno); + LABEL *deconstructed_label = NEW_LABEL(line.lineno); + + if (use_rest_size) { + ADD_INSN1(ret, &line.node, putobject, INT2FIX(0)); + ADD_INSN(ret, &line.node, swap); + base_index++; + } + + if (cast->constant != NULL) { + CHECK(pm_compile_pattern_constant(iseq, scope_node, cast->constant, ret, src, match_failed_label, in_single_pattern, base_index)); + } + + CHECK(pm_compile_pattern_deconstruct(iseq, scope_node, node, ret, src, deconstruct_label, match_failed_label, deconstructed_label, type_error_label, in_single_pattern, use_deconstructed_cache, base_index)); + + ADD_INSN(ret, &line.node, dup); + ADD_SEND(ret, &line.node, idLength, INT2FIX(0)); + ADD_INSN1(ret, &line.node, putobject, INT2FIX(minimum_size)); + ADD_SEND(ret, &line.node, cast->rest == NULL ? idEq : idGE, INT2FIX(1)); + if (in_single_pattern) { + VALUE message = cast->rest == NULL ? rb_fstring_lit("%p length mismatch (given %p, expected %p)") : rb_fstring_lit("%p length mismatch (given %p, expected %p+)"); + CHECK(pm_compile_pattern_length_error(iseq, scope_node, node, ret, message, INT2FIX(minimum_size), base_index + 1)); + } + ADD_INSNL(ret, &line.node, branchunless, match_failed_label); + + for (size_t index = 0; index < requireds_size; index++) { + const pm_node_t *required = cast->requireds.nodes[index]; + ADD_INSN(ret, &line.node, dup); + ADD_INSN1(ret, &line.node, putobject, INT2FIX(index)); + ADD_SEND(ret, &line.node, idAREF, INT2FIX(1)); + CHECK(pm_compile_pattern_match(iseq, scope_node, required, ret, src, match_failed_label, in_single_pattern, in_alternation_pattern, false, base_index + 1)); + } + + if (cast->rest != NULL) { + if (((const pm_splat_node_t *) cast->rest)->expression != NULL) { + ADD_INSN(ret, &line.node, dup); + ADD_INSN1(ret, &line.node, putobject, INT2FIX(requireds_size)); + ADD_INSN1(ret, &line.node, topn, INT2FIX(1)); + ADD_SEND(ret, &line.node, idLength, INT2FIX(0)); + ADD_INSN1(ret, &line.node, putobject, INT2FIX(minimum_size)); + ADD_SEND(ret, &line.node, idMINUS, INT2FIX(1)); + ADD_INSN1(ret, &line.node, setn, INT2FIX(4)); + ADD_SEND(ret, &line.node, idAREF, INT2FIX(2)); + CHECK(pm_compile_pattern_match(iseq, scope_node, ((const pm_splat_node_t *) cast->rest)->expression, ret, src, match_failed_label, in_single_pattern, in_alternation_pattern, false, base_index + 1)); + } else if (posts_size > 0) { + ADD_INSN(ret, &line.node, dup); + ADD_SEND(ret, &line.node, idLength, INT2FIX(0)); + ADD_INSN1(ret, &line.node, putobject, INT2FIX(minimum_size)); + ADD_SEND(ret, &line.node, idMINUS, INT2FIX(1)); + ADD_INSN1(ret, &line.node, setn, INT2FIX(2)); + ADD_INSN(ret, &line.node, pop); + } + } + + for (size_t index = 0; index < posts_size; index++) { + const pm_node_t *post = cast->posts.nodes[index]; + ADD_INSN(ret, &line.node, dup); + + ADD_INSN1(ret, &line.node, putobject, INT2FIX(requireds_size + index)); + ADD_INSN1(ret, &line.node, topn, INT2FIX(3)); + ADD_SEND(ret, &line.node, idPLUS, INT2FIX(1)); + ADD_SEND(ret, &line.node, idAREF, INT2FIX(1)); + CHECK(pm_compile_pattern_match(iseq, scope_node, post, ret, src, match_failed_label, in_single_pattern, in_alternation_pattern, false, base_index + 1)); + } + + ADD_INSN(ret, &line.node, pop); + if (use_rest_size) { + ADD_INSN(ret, &line.node, pop); + } + + ADD_INSNL(ret, &line.node, jump, matched_label); + ADD_INSN(ret, &line.node, putnil); + if (use_rest_size) { + ADD_INSN(ret, &line.node, putnil); + } + + ADD_LABEL(ret, type_error_label); + ADD_INSN1(ret, &line.node, putspecialobject, INT2FIX(VM_SPECIAL_OBJECT_VMCORE)); + ADD_INSN1(ret, &line.node, putobject, rb_eTypeError); + ADD_INSN1(ret, &line.node, putobject, rb_fstring_lit("deconstruct must return Array")); + ADD_SEND(ret, &line.node, id_core_raise, INT2FIX(2)); + ADD_INSN(ret, &line.node, pop); + + ADD_LABEL(ret, match_failed_label); + ADD_INSN(ret, &line.node, pop); + if (use_rest_size) { + ADD_INSN(ret, &line.node, pop); + } + + ADD_INSNL(ret, &line.node, jump, unmatched_label); break; - case PM_CAPTURE_PATTERN_NODE: - rb_bug("Capture pattern matching not yet supported."); + } + case PM_FIND_PATTERN_NODE: { + // Find patterns in pattern matching are triggered by using commas in + // a pattern or wrapping it in braces and using a splat on both the left + // and right side of the pattern. This looks like: + // + // foo => [*, 1, 2, 3, *] + // + // There can be any number of requireds in the middle. The splats on + // both sides can optionally have names attached. + const pm_find_pattern_node_t *cast = (const pm_find_pattern_node_t *) node; + const size_t size = cast->requireds.size; + + LABEL *match_failed_label = NEW_LABEL(line.lineno); + LABEL *type_error_label = NEW_LABEL(line.lineno); + LABEL *deconstruct_label = NEW_LABEL(line.lineno); + LABEL *deconstructed_label = NEW_LABEL(line.lineno); + + if (cast->constant) { + CHECK(pm_compile_pattern_constant(iseq, scope_node, cast->constant, ret, src, match_failed_label, in_single_pattern, base_index)); + } + + CHECK(pm_compile_pattern_deconstruct(iseq, scope_node, node, ret, src, deconstruct_label, match_failed_label, deconstructed_label, type_error_label, in_single_pattern, use_deconstructed_cache, base_index)); + + ADD_INSN(ret, &line.node, dup); + ADD_SEND(ret, &line.node, idLength, INT2FIX(0)); + ADD_INSN1(ret, &line.node, putobject, INT2FIX(size)); + ADD_SEND(ret, &line.node, idGE, INT2FIX(1)); + if (in_single_pattern) { + CHECK(pm_compile_pattern_length_error(iseq, scope_node, node, ret, rb_fstring_lit("%p length mismatch (given %p, expected %p+)"), INT2FIX(size), base_index + 1)); + } + ADD_INSNL(ret, &line.node, branchunless, match_failed_label); + + { + LABEL *while_begin_label = NEW_LABEL(line.lineno); + LABEL *next_loop_label = NEW_LABEL(line.lineno); + LABEL *find_succeeded_label = NEW_LABEL(line.lineno); + LABEL *find_failed_label = NEW_LABEL(line.lineno); + + ADD_INSN(ret, &line.node, dup); + ADD_SEND(ret, &line.node, idLength, INT2FIX(0)); + + ADD_INSN(ret, &line.node, dup); + ADD_INSN1(ret, &line.node, putobject, INT2FIX(size)); + ADD_SEND(ret, &line.node, idMINUS, INT2FIX(1)); + ADD_INSN1(ret, &line.node, putobject, INT2FIX(0)); + ADD_LABEL(ret, while_begin_label); + + ADD_INSN(ret, &line.node, dup); + ADD_INSN1(ret, &line.node, topn, INT2FIX(2)); + ADD_SEND(ret, &line.node, idLE, INT2FIX(1)); + ADD_INSNL(ret, &line.node, branchunless, find_failed_label); + + for (size_t index = 0; index < size; index++) { + ADD_INSN1(ret, &line.node, topn, INT2FIX(3)); + ADD_INSN1(ret, &line.node, topn, INT2FIX(1)); + + if (index != 0) { + ADD_INSN1(ret, &line.node, putobject, INT2FIX(index)); + ADD_SEND(ret, &line.node, idPLUS, INT2FIX(1)); + } + + ADD_SEND(ret, &line.node, idAREF, INT2FIX(1)); + CHECK(pm_compile_pattern_match(iseq, scope_node, cast->requireds.nodes[index], ret, src, next_loop_label, in_single_pattern, in_alternation_pattern, false, base_index + 4)); + } + + assert(PM_NODE_TYPE_P(cast->left, PM_SPLAT_NODE)); + const pm_splat_node_t *left = (const pm_splat_node_t *) cast->left; + + if (left->expression != NULL) { + ADD_INSN1(ret, &line.node, topn, INT2FIX(3)); + ADD_INSN1(ret, &line.node, putobject, INT2FIX(0)); + ADD_INSN1(ret, &line.node, topn, INT2FIX(2)); + ADD_SEND(ret, &line.node, idAREF, INT2FIX(2)); + CHECK(pm_compile_pattern_match(iseq, scope_node, left->expression, ret, src, find_failed_label, in_single_pattern, in_alternation_pattern, false, base_index + 4)); + } + + assert(PM_NODE_TYPE_P(cast->right, PM_SPLAT_NODE)); + const pm_splat_node_t *right = (const pm_splat_node_t *) cast->right; + + if (right->expression != NULL) { + ADD_INSN1(ret, &line.node, topn, INT2FIX(3)); + ADD_INSN1(ret, &line.node, topn, INT2FIX(1)); + ADD_INSN1(ret, &line.node, putobject, INT2FIX(size)); + ADD_SEND(ret, &line.node, idPLUS, INT2FIX(1)); + ADD_INSN1(ret, &line.node, topn, INT2FIX(3)); + ADD_SEND(ret, &line.node, idAREF, INT2FIX(2)); + pm_compile_pattern_match(iseq, scope_node, right->expression, ret, src, find_failed_label, in_single_pattern, in_alternation_pattern, false, base_index + 4); + } + + ADD_INSNL(ret, &line.node, jump, find_succeeded_label); + + ADD_LABEL(ret, next_loop_label); + ADD_INSN1(ret, &line.node, putobject, INT2FIX(1)); + ADD_SEND(ret, &line.node, idPLUS, INT2FIX(1)); + ADD_INSNL(ret, &line.node, jump, while_begin_label); + + ADD_LABEL(ret, find_failed_label); + ADD_INSN1(ret, &line.node, adjuststack, INT2FIX(3)); + if (in_single_pattern) { + ADD_INSN1(ret, &line.node, putspecialobject, INT2FIX(VM_SPECIAL_OBJECT_VMCORE)); + ADD_INSN1(ret, &line.node, putobject, rb_fstring_lit("%p does not match to find pattern")); + ADD_INSN1(ret, &line.node, topn, INT2FIX(2)); + ADD_SEND(ret, &line.node, id_core_sprintf, INT2FIX(2)); + ADD_INSN1(ret, &line.node, setn, INT2FIX(base_index + PM_PATTERN_BASE_INDEX_OFFSET_ERROR_STRING + 1)); + + ADD_INSN1(ret, &line.node, putobject, Qfalse); + ADD_INSN1(ret, &line.node, setn, INT2FIX(base_index + PM_PATTERN_BASE_INDEX_OFFSET_KEY_ERROR_P + 2)); + + ADD_INSN(ret, &line.node, pop); + ADD_INSN(ret, &line.node, pop); + } + ADD_INSNL(ret, &line.node, jump, match_failed_label); + ADD_INSN1(ret, &line.node, dupn, INT2FIX(3)); + + ADD_LABEL(ret, find_succeeded_label); + ADD_INSN1(ret, &line.node, adjuststack, INT2FIX(3)); + } + + ADD_INSN(ret, &line.node, pop); + ADD_INSNL(ret, &line.node, jump, matched_label); + ADD_INSN(ret, &line.node, putnil); + + ADD_LABEL(ret, type_error_label); + ADD_INSN1(ret, &line.node, putspecialobject, INT2FIX(VM_SPECIAL_OBJECT_VMCORE)); + ADD_INSN1(ret, &line.node, putobject, rb_eTypeError); + ADD_INSN1(ret, &line.node, putobject, rb_fstring_lit("deconstruct must return Array")); + ADD_SEND(ret, &line.node, id_core_raise, INT2FIX(2)); + ADD_INSN(ret, &line.node, pop); + + ADD_LABEL(ret, match_failed_label); + ADD_INSN(ret, &line.node, pop); + ADD_INSNL(ret, &line.node, jump, unmatched_label); + break; - case PM_IF_NODE: { - // If guards can be placed on patterns to further limit matches based on - // a dynamic predicate. This looks like: + } + case PM_HASH_PATTERN_NODE: { + // Hash patterns in pattern matching are triggered by using labels and + // values in a pattern or by using the ** operator. They are represented + // by the HashPatternNode. This looks like: // - // case foo - // in bar if baz - // end + // foo => { a: 1, b: 2, **bar } // - pm_if_node_t *cast = (pm_if_node_t *) node; + // It can optionally have an assoc splat in the middle of it, which can + // optionally have a name. + const pm_hash_pattern_node_t *cast = (const pm_hash_pattern_node_t *) node; - pm_compile_pattern(iseq, cast->statements->body.nodes[0], ret, src, scope_node, matched_label, unmatched_label, in_alternation_pattern); - PM_COMPILE_NOT_POPPED(cast->predicate); + // We don't consider it a "rest" parameter if it's a ** that is unnamed. + bool has_rest = cast->rest != NULL && !(PM_NODE_TYPE_P(cast->rest, PM_ASSOC_SPLAT_NODE) && ((const pm_assoc_splat_node_t *) cast->rest)->value == NULL); + bool has_keys = cast->elements.size > 0 || cast->rest != NULL; + + LABEL *match_failed_label = NEW_LABEL(line.lineno); + LABEL *type_error_label = NEW_LABEL(line.lineno); + VALUE keys = Qnil; + + if (has_keys && !has_rest) { + keys = rb_ary_new_capa(cast->elements.size); + + for (size_t index = 0; index < cast->elements.size; index++) { + const pm_node_t *element = cast->elements.nodes[index]; + assert(PM_NODE_TYPE_P(element, PM_ASSOC_NODE)); + + const pm_node_t *key = ((const pm_assoc_node_t *) element)->key; + assert(PM_NODE_TYPE_P(key, PM_SYMBOL_NODE)); + + VALUE symbol = ID2SYM(parse_string_symbol(&((const pm_symbol_node_t *) key)->unescaped, scope_node->parser)); + rb_ary_push(keys, symbol); + } + } - ADD_INSNL(ret, &dummy_line_node, branchunless, unmatched_label); - ADD_INSNL(ret, &dummy_line_node, jump, matched_label); + if (cast->constant) { + CHECK(pm_compile_pattern_constant(iseq, scope_node, cast->constant, ret, src, match_failed_label, in_single_pattern, base_index)); + } + + ADD_INSN(ret, &line.node, dup); + ADD_INSN1(ret, &line.node, putobject, ID2SYM(rb_intern("deconstruct_keys"))); + ADD_SEND(ret, &line.node, idRespond_to, INT2FIX(1)); + if (in_single_pattern) { + CHECK(pm_compile_pattern_generic_error(iseq, scope_node, node, ret, rb_fstring_lit("%p does not respond to #deconstruct_keys"), base_index + 1)); + } + ADD_INSNL(ret, &line.node, branchunless, match_failed_label); + + if (NIL_P(keys)) { + ADD_INSN(ret, &line.node, putnil); + } else { + ADD_INSN1(ret, &line.node, duparray, keys); + RB_OBJ_WRITTEN(iseq, Qundef, rb_obj_hide(keys)); + } + ADD_SEND(ret, &line.node, rb_intern("deconstruct_keys"), INT2FIX(1)); + + ADD_INSN(ret, &line.node, dup); + ADD_INSN1(ret, &line.node, checktype, INT2FIX(T_HASH)); + ADD_INSNL(ret, &line.node, branchunless, type_error_label); + + if (has_rest) { + ADD_SEND(ret, &line.node, rb_intern("dup"), INT2FIX(0)); + } + + if (has_keys) { + DECL_ANCHOR(match_values); + INIT_ANCHOR(match_values); + + for (size_t index = 0; index < cast->elements.size; index++) { + const pm_node_t *element = cast->elements.nodes[index]; + assert(PM_NODE_TYPE_P(element, PM_ASSOC_NODE)); + + const pm_assoc_node_t *assoc = (const pm_assoc_node_t *) element; + const pm_node_t *key = assoc->key; + assert(PM_NODE_TYPE_P(key, PM_SYMBOL_NODE)); + + VALUE symbol = ID2SYM(parse_string_symbol(&((const pm_symbol_node_t *) key)->unescaped, scope_node->parser)); + ADD_INSN(ret, &line.node, dup); + ADD_INSN1(ret, &line.node, putobject, symbol); + ADD_SEND(ret, &line.node, rb_intern("key?"), INT2FIX(1)); + + if (in_single_pattern) { + LABEL *match_succeeded_label = NEW_LABEL(line.lineno); + + ADD_INSN(ret, &line.node, dup); + ADD_INSNL(ret, &line.node, branchif, match_succeeded_label); + + ADD_INSN1(ret, &line.node, putobject, rb_str_freeze(rb_sprintf("key not found: %+"PRIsVALUE, symbol))); + ADD_INSN1(ret, &line.node, setn, INT2FIX(base_index + PM_PATTERN_BASE_INDEX_OFFSET_ERROR_STRING + 2)); + ADD_INSN1(ret, &line.node, putobject, Qtrue); + ADD_INSN1(ret, &line.node, setn, INT2FIX(base_index + PM_PATTERN_BASE_INDEX_OFFSET_KEY_ERROR_P + 3)); + ADD_INSN1(ret, &line.node, topn, INT2FIX(3)); + ADD_INSN1(ret, &line.node, setn, INT2FIX(base_index + PM_PATTERN_BASE_INDEX_OFFSET_KEY_ERROR_MATCHEE + 4)); + ADD_INSN1(ret, &line.node, putobject, symbol); + ADD_INSN1(ret, &line.node, setn, INT2FIX(base_index + PM_PATTERN_BASE_INDEX_OFFSET_KEY_ERROR_KEY + 5)); + + ADD_INSN1(ret, &line.node, adjuststack, INT2FIX(4)); + ADD_LABEL(ret, match_succeeded_label); + } + + ADD_INSNL(ret, &line.node, branchunless, match_failed_label); + ADD_INSN(match_values, &line.node, dup); + ADD_INSN1(match_values, &line.node, putobject, symbol); + ADD_SEND(match_values, &line.node, has_rest ? rb_intern("delete") : idAREF, INT2FIX(1)); + + CHECK(pm_compile_pattern_match(iseq, scope_node, assoc->value, match_values, src, match_failed_label, in_single_pattern, in_alternation_pattern, false, base_index + 1)); + } + + ADD_SEQ(ret, match_values); + } else { + ADD_INSN(ret, &line.node, dup); + ADD_SEND(ret, &line.node, idEmptyP, INT2FIX(0)); + if (in_single_pattern) { + CHECK(pm_compile_pattern_generic_error(iseq, scope_node, node, ret, rb_fstring_lit("%p is not empty"), base_index + 1)); + } + ADD_INSNL(ret, &line.node, branchunless, match_failed_label); + } + + if (has_rest) { + switch (PM_NODE_TYPE(cast->rest)) { + case PM_NO_KEYWORDS_PARAMETER_NODE: { + ADD_INSN(ret, &line.node, dup); + ADD_SEND(ret, &line.node, idEmptyP, INT2FIX(0)); + if (in_single_pattern) { + pm_compile_pattern_generic_error(iseq, scope_node, node, ret, rb_fstring_lit("rest of %p is not empty"), base_index + 1); + } + ADD_INSNL(ret, &line.node, branchunless, match_failed_label); + break; + } + case PM_ASSOC_SPLAT_NODE: { + const pm_assoc_splat_node_t *splat = (const pm_assoc_splat_node_t *) cast->rest; + ADD_INSN(ret, &line.node, dup); + pm_compile_pattern_match(iseq, scope_node, splat->value, ret, src, match_failed_label, in_single_pattern, in_alternation_pattern, false, base_index + 1); + break; + } + default: + rb_bug("unreachable"); + break; + } + } + + ADD_INSN(ret, &line.node, pop); + ADD_INSNL(ret, &line.node, jump, matched_label); + ADD_INSN(ret, &line.node, putnil); + + ADD_LABEL(ret, type_error_label); + ADD_INSN1(ret, &line.node, putspecialobject, INT2FIX(VM_SPECIAL_OBJECT_VMCORE)); + ADD_INSN1(ret, &line.node, putobject, rb_eTypeError); + ADD_INSN1(ret, &line.node, putobject, rb_fstring_lit("deconstruct_keys must return Hash")); + ADD_SEND(ret, &line.node, id_core_raise, INT2FIX(2)); + ADD_INSN(ret, &line.node, pop); + + ADD_LABEL(ret, match_failed_label); + ADD_INSN(ret, &line.node, pop); + ADD_INSNL(ret, &line.node, jump, unmatched_label); break; } - case PM_UNLESS_NODE: { - // Unless guards can be placed on patterns to further limit matches - // based on a dynamic predicate. This looks like: + case PM_CAPTURE_PATTERN_NODE: { + // Capture patterns allow you to pattern match against an element in a + // pattern and also capture the value into a local variable. This looks + // like: // - // case foo - // in bar unless baz - // end + // [1] => [Integer => foo] // - pm_unless_node_t *cast = (pm_unless_node_t *) node; + // In this case the `Integer => foo` will be represented by a + // CapturePatternNode, which has both a value (the pattern to match + // against) and a target (the place to write the variable into). + const pm_capture_pattern_node_t *cast = (const pm_capture_pattern_node_t *) node; - pm_compile_pattern(iseq, cast->statements->body.nodes[0], ret, src, scope_node, matched_label, unmatched_label, in_alternation_pattern); - PM_COMPILE_NOT_POPPED(cast->predicate); + LABEL *match_failed_label = NEW_LABEL(line.lineno); + + ADD_INSN(ret, &line.node, dup); + CHECK(pm_compile_pattern_match(iseq, scope_node, cast->value, ret, src, match_failed_label, in_single_pattern, in_alternation_pattern, use_deconstructed_cache, base_index + 1)); + CHECK(pm_compile_pattern(iseq, scope_node, cast->target, ret, src, matched_label, match_failed_label, in_single_pattern, in_alternation_pattern, false, base_index)); + ADD_INSN(ret, &line.node, putnil); + + ADD_LABEL(ret, match_failed_label); + ADD_INSN(ret, &line.node, pop); + ADD_INSNL(ret, &line.node, jump, unmatched_label); - ADD_INSNL(ret, &dummy_line_node, branchif, unmatched_label); - ADD_INSNL(ret, &dummy_line_node, jump, matched_label); break; } case PM_LOCAL_VARIABLE_TARGET_NODE: { @@ -1272,8 +1965,8 @@ pm_compile_pattern(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const re } } - ADD_SETLOCAL(ret, &dummy_line_node, index, (int) cast->depth); - ADD_INSNL(ret, &dummy_line_node, jump, matched_label); + ADD_SETLOCAL(ret, &line.node, index, (int) cast->depth); + ADD_INSNL(ret, &line.node, jump, matched_label); break; } case PM_ALTERNATION_PATTERN_NODE: { @@ -1281,26 +1974,26 @@ pm_compile_pattern(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const re // single expression using the | operator. pm_alternation_pattern_node_t *cast = (pm_alternation_pattern_node_t *) node; - LABEL *matched_left_label = NEW_LABEL(lineno); - LABEL *unmatched_left_label = NEW_LABEL(lineno); + LABEL *matched_left_label = NEW_LABEL(line.lineno); + LABEL *unmatched_left_label = NEW_LABEL(line.lineno); // First, we're going to attempt to match against the left pattern. If // that pattern matches, then we'll skip matching the right pattern. - PM_DUP; - pm_compile_pattern(iseq, cast->left, ret, src, scope_node, matched_left_label, unmatched_left_label, true); + ADD_INSN(ret, &line.node, dup); + CHECK(pm_compile_pattern(iseq, scope_node, cast->left, ret, src, matched_left_label, unmatched_left_label, in_single_pattern, true, true, base_index + 1)); // If we get here, then we matched on the left pattern. In this case we // should pop out the duplicate value that we preemptively added to // match against the right pattern and then jump to the match label. ADD_LABEL(ret, matched_left_label); - PM_POP; - ADD_INSNL(ret, &dummy_line_node, jump, matched_label); - PM_PUTNIL; + ADD_INSN(ret, &line.node, pop); + ADD_INSNL(ret, &line.node, jump, matched_label); + ADD_INSN(ret, &line.node, putnil); // If we get here, then we didn't match on the left pattern. In this // case we attempt to match against the right pattern. ADD_LABEL(ret, unmatched_left_label); - pm_compile_pattern(iseq, cast->right, ret, src, scope_node, matched_label, unmatched_label, true); + CHECK(pm_compile_pattern(iseq, scope_node, cast->right, ret, src, matched_label, unmatched_label, in_single_pattern, true, true, base_index)); break; } case PM_ARRAY_NODE: @@ -1327,22 +2020,32 @@ pm_compile_pattern(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const re case PM_STRING_NODE: case PM_SYMBOL_NODE: case PM_TRUE_NODE: - case PM_X_STRING_NODE: + case PM_X_STRING_NODE: { // These nodes are all simple patterns, which means we'll use the // checkmatch instruction to match against them, which is effectively a // VM-level === operator. PM_COMPILE_NOT_POPPED(node); - ADD_INSN1(ret, &dummy_line_node, checkmatch, INT2FIX(VM_CHECKMATCH_TYPE_CASE)); - ADD_INSNL(ret, &dummy_line_node, branchif, matched_label); - ADD_INSNL(ret, &dummy_line_node, jump, unmatched_label); + if (in_single_pattern) { + ADD_INSN1(ret, &line.node, dupn, INT2FIX(2)); + } + + ADD_INSN1(ret, &line.node, checkmatch, INT2FIX(VM_CHECKMATCH_TYPE_CASE)); + + if (in_single_pattern) { + pm_compile_pattern_eqq_error(iseq, scope_node, node, ret, base_index + 2); + } + + ADD_INSNL(ret, &line.node, branchif, matched_label); + ADD_INSNL(ret, &line.node, jump, unmatched_label); break; + } case PM_PINNED_VARIABLE_NODE: { // Pinned variables are a way to match against the value of a variable // without it looking like you're trying to write to the variable. This // looks like: foo in ^@bar. To compile these, we compile the variable // that they hold. pm_pinned_variable_node_t *cast = (pm_pinned_variable_node_t *) node; - pm_compile_pattern(iseq, cast->variable, ret, src, scope_node, matched_label, unmatched_label, false); + CHECK(pm_compile_pattern(iseq, scope_node, cast->variable, ret, src, matched_label, unmatched_label, in_single_pattern, in_alternation_pattern, true, base_index)); break; } case PM_PINNED_EXPRESSION_NODE: { @@ -1351,7 +2054,70 @@ pm_compile_pattern(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const re // foo in ^(bar). To compile these, we compile the expression that they // hold. pm_pinned_expression_node_t *cast = (pm_pinned_expression_node_t *) node; - pm_compile_pattern(iseq, cast->expression, ret, src, scope_node, matched_label, unmatched_label, false); + CHECK(pm_compile_pattern(iseq, scope_node, cast->expression, ret, src, matched_label, unmatched_label, in_single_pattern, in_alternation_pattern, true, base_index)); + break; + } + case PM_IF_NODE: + case PM_UNLESS_NODE: { + // If and unless nodes can show up here as guards on `in` clauses. This + // looks like: + // + // case foo + // in bar if baz? + // qux + // end + // + // Because we know they're in the modifier form and they can't have any + // variation on this pattern, we compile them differently (more simply) + // here than we would in the normal compilation path. + const pm_node_t *predicate; + const pm_node_t *statement; + + if (PM_NODE_TYPE_P(node, PM_IF_NODE)) { + const pm_if_node_t *cast = (const pm_if_node_t *) node; + predicate = cast->predicate; + + assert(cast->statements != NULL && cast->statements->body.size == 1); + statement = cast->statements->body.nodes[0]; + } else { + const pm_unless_node_t *cast = (const pm_unless_node_t *) node; + predicate = cast->predicate; + + assert(cast->statements != NULL && cast->statements->body.size == 1); + statement = cast->statements->body.nodes[0]; + } + + CHECK(pm_compile_pattern_match(iseq, scope_node, statement, ret, src, unmatched_label, in_single_pattern, in_alternation_pattern, use_deconstructed_cache, base_index)); + PM_COMPILE_NOT_POPPED(predicate); + + if (in_single_pattern) { + LABEL *match_succeeded_label = NEW_LABEL(line.lineno); + + ADD_INSN(ret, &line.node, dup); + if (PM_NODE_TYPE_P(node, PM_IF_NODE)) { + ADD_INSNL(ret, &line.node, branchif, match_succeeded_label); + } else { + ADD_INSNL(ret, &line.node, branchunless, match_succeeded_label); + } + + ADD_INSN1(ret, &line.node, putobject, rb_fstring_lit("guard clause does not return true")); + ADD_INSN1(ret, &line.node, setn, INT2FIX(base_index + PM_PATTERN_BASE_INDEX_OFFSET_ERROR_STRING + 1)); + ADD_INSN1(ret, &line.node, putobject, Qfalse); + ADD_INSN1(ret, &line.node, setn, INT2FIX(base_index + PM_PATTERN_BASE_INDEX_OFFSET_KEY_ERROR_P + 2)); + + ADD_INSN(ret, &line.node, pop); + ADD_INSN(ret, &line.node, pop); + + ADD_LABEL(ret, match_succeeded_label); + } + + if (PM_NODE_TYPE_P(node, PM_IF_NODE)) { + ADD_INSNL(ret, &line.node, branchunless, unmatched_label); + } else { + ADD_INSNL(ret, &line.node, branchif, unmatched_label); + } + + ADD_INSNL(ret, &line.node, jump, matched_label); break; } default: @@ -1365,6 +2131,12 @@ pm_compile_pattern(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const re return COMPILE_OK; } +#undef PM_PATTERN_BASE_INDEX_OFFSET_DECONSTRUCTED_CACHE +#undef PM_PATTERN_BASE_INDEX_OFFSET_ERROR_STRING +#undef PM_PATTERN_BASE_INDEX_OFFSET_KEY_ERROR_P +#undef PM_PATTERN_BASE_INDEX_OFFSET_KEY_ERROR_MATCHEE +#undef PM_PATTERN_BASE_INDEX_OFFSET_KEY_ERROR_KEY + // Generate a scope node from the given node. void pm_scope_node_init(const pm_node_t *node, pm_scope_node_t *scope, pm_scope_node_t *previous, pm_parser_t *parser) @@ -2466,6 +3238,160 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret, ADD_LABEL(ret, end_label); return; } + case PM_CASE_MATCH_NODE: { + // If you use the `case` keyword to create a case match node, it will + // match against all of the `in` clauses until it finds one that + // matches. If it doesn't find one, it can optionally fall back to an + // `else` clause. If none is present and a match wasn't found, it will + // raise an appropriate error. + const pm_case_match_node_t *cast = (const pm_case_match_node_t *) node; + + // This is the anchor that we will compile the bodies of the various + // `in` nodes into. We'll make sure that the patterns that are compiled + // jump into the correct spots within this anchor. + DECL_ANCHOR(body_seq); + INIT_ANCHOR(body_seq); + + // This is the anchor that we will compile the patterns of the various + // `in` nodes into. If a match is found, they will need to jump into the + // body_seq anchor to the correct spot. + DECL_ANCHOR(cond_seq); + INIT_ANCHOR(cond_seq); + + // This label is used to indicate the end of the entire node. It is + // jumped to after the entire stack is cleaned up. + LABEL *end_label = NEW_LABEL(lineno); + + // This label is used as the fallback for the case match. If no match is + // found, then we jump to this label. This is either an `else` clause or + // an error handler. + LABEL *else_label = NEW_LABEL(lineno); + + // We're going to use this to uniquely identify each branch so that we + // can track coverage information. + int branch_id = 0; + // VALUE branches = 0; + + // If there is only one pattern, then the behavior changes a bit. It + // effectively gets treated as a match required node (this is how it is + // represented in the other parser). + bool in_single_pattern = cast->consequent == NULL && cast->conditions.size == 1; + + // First, we're going to push a bunch of stuff onto the stack that is + // going to serve as our scratch space. + if (in_single_pattern) { + ADD_INSN(ret, &dummy_line_node, putnil); // key error key + ADD_INSN(ret, &dummy_line_node, putnil); // key error matchee + ADD_INSN1(ret, &dummy_line_node, putobject, Qfalse); // key error? + ADD_INSN(ret, &dummy_line_node, putnil); // error string + } + + // Now we're going to compile the value to match against. + ADD_INSN(ret, &dummy_line_node, putnil); // deconstruct cache + PM_COMPILE_NOT_POPPED(cast->predicate); + + // Next, we'll loop through every in clause and compile its body into + // the body_seq anchor and its pattern into the cond_seq anchor. We'll + // make sure the pattern knows how to jump correctly into the body if it + // finds a match. + for (size_t index = 0; index < cast->conditions.size; index++) { + const pm_node_t *condition = cast->conditions.nodes[index]; + assert(PM_NODE_TYPE_P(condition, PM_IN_NODE)); + + const pm_in_node_t *in_node = (const pm_in_node_t *) cast->conditions.nodes[index]; + + pm_line_node_t in_line; + pm_line_node(&in_line, scope_node, (const pm_node_t *) in_node); + + pm_line_node_t pattern_line; + pm_line_node(&pattern_line, scope_node, (const pm_node_t *) in_node->pattern); + + if (branch_id) { + ADD_INSN(body_seq, &in_line.node, putnil); + } + + LABEL *body_label = NEW_LABEL(in_line.lineno); + ADD_LABEL(body_seq, body_label); + ADD_INSN1(body_seq, &in_line.node, adjuststack, INT2FIX(in_single_pattern ? 6 : 2)); + + // TODO: We need to come back to this and enable trace branch + // coverage. At the moment we can't call this function because it + // accepts a NODE* and not a pm_node_t*. + // add_trace_branch_coverage(iseq, body_seq, in_node->statements || in, branch_id++, "in", branches); + + branch_id++; + if (in_node->statements != NULL) { + PM_COMPILE_INTO_ANCHOR(body_seq, (const pm_node_t *) in_node->statements); + } else if (!popped) { + ADD_INSN(body_seq, &in_line.node, putnil); + } + + ADD_INSNL(body_seq, &in_line.node, jump, end_label); + LABEL *next_pattern_label = NEW_LABEL(pattern_line.lineno); + + ADD_INSN(cond_seq, &pattern_line.node, dup); + pm_compile_pattern(iseq, scope_node, in_node->pattern, cond_seq, src, body_label, next_pattern_label, in_single_pattern, false, true, 2); + ADD_LABEL(cond_seq, next_pattern_label); + LABEL_UNREMOVABLE(next_pattern_label); + } + + if (cast->consequent != NULL) { + // If we have an `else` clause, then this becomes our fallback (and + // there is no need to compile in code to potentially raise an + // error). + const pm_else_node_t *else_node = (const pm_else_node_t *) cast->consequent; + + ADD_LABEL(cond_seq, else_label); + ADD_INSN(cond_seq, &dummy_line_node, pop); + ADD_INSN(cond_seq, &dummy_line_node, pop); + + // TODO: trace branch coverage + // add_trace_branch_coverage(iseq, cond_seq, cast->consequent, branch_id, "else", branches); + + if (else_node->statements != NULL) { + PM_COMPILE_INTO_ANCHOR(cond_seq, (const pm_node_t *) else_node->statements); + } else if (!popped) { + ADD_INSN(cond_seq, &dummy_line_node, putnil); + } + + ADD_INSNL(cond_seq, &dummy_line_node, jump, end_label); + ADD_INSN(cond_seq, &dummy_line_node, putnil); + if (popped) { + ADD_INSN(cond_seq, &dummy_line_node, putnil); + } + } else { + // Otherwise, if we do not have an `else` clause, we will compile in + // the code to handle raising an appropriate error. + ADD_LABEL(cond_seq, else_label); + + // TODO: trace branch coverage + // add_trace_branch_coverage(iseq, cond_seq, orig_node, branch_id, "else", branches); + + if (in_single_pattern) { + pm_compile_pattern_error_handler(iseq, scope_node, node, cond_seq, src, end_label, popped); + } else { + ADD_INSN1(cond_seq, &dummy_line_node, putspecialobject, INT2FIX(VM_SPECIAL_OBJECT_VMCORE)); + ADD_INSN1(cond_seq, &dummy_line_node, putobject, rb_eNoMatchingPatternError); + ADD_INSN1(cond_seq, &dummy_line_node, topn, INT2FIX(2)); + ADD_SEND(cond_seq, &dummy_line_node, id_core_raise, INT2FIX(2)); + + ADD_INSN1(cond_seq, &dummy_line_node, adjuststack, INT2FIX(3)); + if (!popped) ADD_INSN(cond_seq, &dummy_line_node, putnil); + ADD_INSNL(cond_seq, &dummy_line_node, jump, end_label); + ADD_INSN1(cond_seq, &dummy_line_node, dupn, INT2FIX(1)); + if (popped) ADD_INSN(cond_seq, &dummy_line_node, putnil); + } + } + + // At the end of all of this compilation, we will add the code for the + // conditions first, then the various bodies, then mark the end of the + // entire sequence with the end label. + ADD_SEQ(ret, cond_seq); + ADD_SEQ(ret, body_seq); + ADD_LABEL(ret, end_label); + + return; + } case PM_CLASS_NODE: { pm_class_node_t *class_node = (pm_class_node_t *)node; pm_scope_node_t next_scope_node; @@ -3335,6 +4261,12 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret, PM_COMPILE(cast->value); return; } + case PM_IN_NODE: { + // In nodes are handled by the case match node directly, so we should + // never end up hitting them through this path. + rb_bug("Should not ever enter an in node directly"); + return; + } case PM_INDEX_AND_WRITE_NODE: { pm_index_and_write_node_t *index_and_write_node = (pm_index_and_write_node_t *)node; @@ -3744,7 +4676,7 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret, LABEL *matched_label = NEW_LABEL(lineno); LABEL *unmatched_label = NEW_LABEL(lineno); LABEL *done_label = NEW_LABEL(lineno); - pm_compile_pattern(iseq, cast->pattern, ret, src, scope_node, matched_label, unmatched_label, false); + pm_compile_pattern(iseq, scope_node, cast->pattern, ret, src, matched_label, unmatched_label, false, false, true, 2); // If the pattern did not match, then compile the necessary instructions // to handle pushing false onto the stack, then jump to the end. @@ -3766,6 +4698,60 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret, ADD_LABEL(ret, done_label); return; } + case PM_MATCH_REQUIRED_NODE: { + // A match required node represents pattern matching against a single + // pattern using the => operator. For example, + // + // foo => bar + // + // This is somewhat analogous to compiling a case match statement with a + // single pattern. In both cases, if the pattern fails it should + // immediately raise an error. + const pm_match_required_node_t *cast = (const pm_match_required_node_t *) node; + + LABEL *matched_label = NEW_LABEL(lineno); + LABEL *unmatched_label = NEW_LABEL(lineno); + LABEL *done_label = NEW_LABEL(lineno); + + // First, we're going to push a bunch of stuff onto the stack that is + // going to serve as our scratch space. + ADD_INSN(ret, &dummy_line_node, putnil); // key error key + ADD_INSN(ret, &dummy_line_node, putnil); // key error matchee + ADD_INSN1(ret, &dummy_line_node, putobject, Qfalse); // key error? + ADD_INSN(ret, &dummy_line_node, putnil); // error string + ADD_INSN(ret, &dummy_line_node, putnil); // deconstruct cache + + // Next we're going to compile the value expression such that it's on + // the stack. + PM_COMPILE_NOT_POPPED(cast->value); + + // Here we'll dup it so that it can be used for comparison, but also be + // used for error handling. + ADD_INSN(ret, &dummy_line_node, dup); + + // Next we'll compile the pattern. We indicate to the pm_compile_pattern + // function that this is the only pattern that will be matched against + // through the in_single_pattern parameter. We also indicate that the + // value to compare against is 2 slots from the top of the stack (the + // base_index parameter). + pm_compile_pattern(iseq, scope_node, cast->pattern, ret, src, matched_label, unmatched_label, true, false, true, 2); + + // If the pattern did not match the value, then we're going to compile + // in our error handler code. This will determine which error to raise + // and raise it. + ADD_LABEL(ret, unmatched_label); + pm_compile_pattern_error_handler(iseq, scope_node, node, ret, src, done_label, popped); + + // If the pattern did match, we'll clean up the values we've pushed onto + // the stack and then push nil onto the stack if it's not popped. + ADD_LABEL(ret, matched_label); + ADD_INSN1(ret, &dummy_line_node, adjuststack, INT2FIX(6)); + if (!popped) ADD_INSN(ret, &dummy_line_node, putnil); + ADD_INSNL(ret, &dummy_line_node, jump, done_label); + + ADD_LABEL(ret, done_label); + return; + } case PM_MATCH_WRITE_NODE: { // Match write nodes are specialized call nodes that have a regular // expression with valid named capture groups on the left, the =~ @@ -5361,7 +6347,6 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret, } case PM_WHEN_NODE: { rb_bug("Should not ever enter a when node directly"); - return; } case PM_WHILE_NODE: { diff --git a/test/ruby/test_compile_prism.rb b/test/ruby/test_compile_prism.rb index 192094bafe1975..e01344ccc5eb1e 100644 --- a/test/ruby/test_compile_prism.rb +++ b/test/ruby/test_compile_prism.rb @@ -1715,6 +1715,164 @@ def test_AlternationPatternNode assert_prism_eval("1 in 2 | 3") end + def test_ArrayPatternNode + assert_prism_eval("[] => []") + + ["in", "=>"].each do |operator| + ["", "Array"].each do |constant| + assert_prism_eval("[1, 2, 3] #{operator} #{constant}[1, 2, 3]") + + assert_prism_eval("[1, 2, 3] #{operator} #{constant}[*]") + assert_prism_eval("[1, 2, 3] #{operator} #{constant}[1, *]") + assert_prism_eval("[1, 2, 3] #{operator} #{constant}[1, 2, *]") + assert_prism_eval("[1, 2, 3] #{operator} #{constant}[1, 2, 3, *]") + + assert_prism_eval("[1, 2, 3] #{operator} #{constant}[*foo]") + assert_prism_eval("[1, 2, 3] #{operator} #{constant}[1, *foo]") + assert_prism_eval("[1, 2, 3] #{operator} #{constant}[1, 2, *foo]") + assert_prism_eval("[1, 2, 3] #{operator} #{constant}[1, 2, 3, *foo]") + + assert_prism_eval("[1, 2, 3] #{operator} #{constant}[*, 3]") + assert_prism_eval("[1, 2, 3] #{operator} #{constant}[*, 2, 3]") + assert_prism_eval("[1, 2, 3] #{operator} #{constant}[*, 1, 2, 3]") + + assert_prism_eval("[1, 2, 3] #{operator} #{constant}[*foo, 3]") + assert_prism_eval("[1, 2, 3] #{operator} #{constant}[*foo, 2, 3]") + assert_prism_eval("[1, 2, 3] #{operator} #{constant}[*foo, 1, 2, 3]") + end + end + + assert_prism_eval("begin; Object.new => [1, 2, 3]; rescue NoMatchingPatternError; true; end") + assert_prism_eval("begin; [1, 2, 3] => Object[1, 2, 3]; rescue NoMatchingPatternError; true; end") + end + + def test_CapturePatternNode + assert_prism_eval("[1] => [Integer => foo]") + end + + def test_CaseMatchNode + assert_prism_eval(<<~RUBY) + case [1, 2, 3] + in [1, 2, 3] + 4 + end + RUBY + + assert_prism_eval(<<~RUBY) + case { a: 5, b: 6 } + in [1, 2, 3] + 4 + in { a: 5, b: 6 } + 7 + end + RUBY + + assert_prism_eval(<<~RUBY) + case [1, 2, 3, 4] + in [1, 2, 3] + 4 + in { a: 5, b: 6 } + 7 + else + end + RUBY + + assert_prism_eval(<<~RUBY) + case [1, 2, 3, 4] + in [1, 2, 3] + 4 + in { a: 5, b: 6 } + 7 + else + 8 + end + RUBY + + assert_prism_eval(<<~RUBY) + case [1, 2, 3] + in [1, 2, 3] unless to_s + in [1, 2, 3] if to_s.nil? + in [1, 2, 3] + true + end + RUBY + end + + def test_FindPatternNode + ["in", "=>"].each do |operator| + ["", "Array"].each do |constant| + assert_prism_eval("[1, 2, 3, 4, 5] #{operator} #{constant}[*, 1, 2, 3, 4, 5, *]") + + assert_prism_eval("[1, 2, 3, 4, 5] #{operator} #{constant}[*, 1, *]") + assert_prism_eval("[1, 2, 3, 4, 5] #{operator} #{constant}[*, 3, *]") + assert_prism_eval("[1, 2, 3, 4, 5] #{operator} #{constant}[*, 5, *]") + + assert_prism_eval("[1, 2, 3, 4, 5] #{operator} #{constant}[*, 1, 2, *]") + assert_prism_eval("[1, 2, 3, 4, 5] #{operator} #{constant}[*, 2, 3, *]") + assert_prism_eval("[1, 2, 3, 4, 5] #{operator} #{constant}[*, 3, 4, *]") + assert_prism_eval("[1, 2, 3, 4, 5] #{operator} #{constant}[*, 4, 5, *]") + + assert_prism_eval("[1, 2, 3, 4, 5] #{operator} #{constant}[*, 1, 2, 3, *]") + assert_prism_eval("[1, 2, 3, 4, 5] #{operator} #{constant}[*, 2, 3, 4, *]") + assert_prism_eval("[1, 2, 3, 4, 5] #{operator} #{constant}[*, 3, 4, 5, *]") + + assert_prism_eval("[1, 2, 3, 4, 5] #{operator} #{constant}[*, 1, 2, 3, 4, *]") + assert_prism_eval("[1, 2, 3, 4, 5] #{operator} #{constant}[*, 2, 3, 4, 5, *]") + + assert_prism_eval("[1, 2, 3, 4, 5] #{operator} #{constant}[*foo, 3, *]") + assert_prism_eval("[1, 2, 3, 4, 5] #{operator} #{constant}[*foo, 3, 4, *]") + assert_prism_eval("[1, 2, 3, 4, 5] #{operator} #{constant}[*foo, 3, 4, 5, *]") + assert_prism_eval("[1, 2, 3, 4, 5] #{operator} #{constant}[*foo, 1, 2, 3, 4, *]") + + assert_prism_eval("[1, 2, 3, 4, 5] #{operator} #{constant}[*, 3, *foo]") + assert_prism_eval("[1, 2, 3, 4, 5] #{operator} #{constant}[*, 3, 4, *foo]") + assert_prism_eval("[1, 2, 3, 4, 5] #{operator} #{constant}[*, 3, 4, 5, *foo]") + assert_prism_eval("[1, 2, 3, 4, 5] #{operator} #{constant}[*, 1, 2, 3, 4, *foo]") + + assert_prism_eval("[1, 2, 3, 4, 5] #{operator} #{constant}[*foo, 3, *bar]") + assert_prism_eval("[1, 2, 3, 4, 5] #{operator} #{constant}[*foo, 3, 4, *bar]") + assert_prism_eval("[1, 2, 3, 4, 5] #{operator} #{constant}[*foo, 3, 4, 5, *bar]") + assert_prism_eval("[1, 2, 3, 4, 5] #{operator} #{constant}[*foo, 1, 2, 3, 4, *bar]") + end + end + + assert_prism_eval("[1, [2, [3, [4, [5]]]]] => [*, [*, [*, [*, [*]]]]]") + assert_prism_eval("[1, [2, [3, [4, [5]]]]] => [1, [2, [3, [4, [5]]]]]") + + assert_prism_eval("begin; Object.new => [*, 2, *]; rescue NoMatchingPatternError; true; end") + assert_prism_eval("begin; [1, 2, 3] => Object[*, 2, *]; rescue NoMatchingPatternError; true; end") + end + + def test_HashPatternNode + assert_prism_eval("{} => {}") + + [["{ ", " }"], ["Hash[", "]"]].each do |(prefix, suffix)| + assert_prism_eval("{} => #{prefix} **nil #{suffix}") + + assert_prism_eval("{ a: 1, b: 2, c: 3 } => #{prefix} a: 1 #{suffix}") + assert_prism_eval("{ a: 1, b: 2, c: 3 } => #{prefix} a: 1, b: 2 #{suffix}") + assert_prism_eval("{ a: 1, b: 2, c: 3 } => #{prefix} b: 2, c: 3 #{suffix}") + assert_prism_eval("{ a: 1, b: 2, c: 3 } => #{prefix} a: 1, b: 2, c: 3 #{suffix}") + + assert_prism_eval("{ a: 1, b: 2, c: 3 } => #{prefix} ** #{suffix}") + assert_prism_eval("{ a: 1, b: 2, c: 3 } => #{prefix} a: 1, ** #{suffix}") + assert_prism_eval("{ a: 1, b: 2, c: 3 } => #{prefix} a: 1, b: 2, ** #{suffix}") + assert_prism_eval("{ a: 1, b: 2, c: 3 } => #{prefix} b: 2, c: 3, ** #{suffix}") + assert_prism_eval("{ a: 1, b: 2, c: 3 } => #{prefix} a: 1, b: 2, c: 3, ** #{suffix}") + + assert_prism_eval("{ a: 1, b: 2, c: 3 } => #{prefix} **foo #{suffix}") + assert_prism_eval("{ a: 1, b: 2, c: 3 } => #{prefix} a: 1, **foo #{suffix}") + assert_prism_eval("{ a: 1, b: 2, c: 3 } => #{prefix} a: 1, b: 2, **foo #{suffix}") + assert_prism_eval("{ a: 1, b: 2, c: 3 } => #{prefix} b: 2, c: 3, **foo #{suffix}") + assert_prism_eval("{ a: 1, b: 2, c: 3 } => #{prefix} a: 1, b: 2, c: 3, **foo #{suffix}") + + assert_prism_eval("{ a: 1 } => #{prefix} a: 1, **nil #{suffix}") + assert_prism_eval("{ a: 1, b: 2, c: 3 } => #{prefix} a: 1, b: 2, c: 3, **nil #{suffix}") + end + + assert_prism_eval("{ a: { b: { c: 1 } } } => { a: { b: { c: 1 } } }") + end + def test_MatchPredicateNode assert_prism_eval("1 in 1") assert_prism_eval("1.0 in 1.0") @@ -1748,6 +1906,33 @@ def test_MatchPredicateNode assert_prism_eval("1 in 2") end + def test_MatchRequiredNode + assert_prism_eval("1 => 1") + assert_prism_eval("1.0 => 1.0") + assert_prism_eval("1i => 1i") + assert_prism_eval("1r => 1r") + + assert_prism_eval("\"foo\" => \"foo\"") + assert_prism_eval("\"foo \#{1}\" => \"foo \#{1}\"") + + assert_prism_eval("false => false") + assert_prism_eval("nil => nil") + assert_prism_eval("true => true") + + assert_prism_eval("5 => 0..10") + assert_prism_eval("5 => 0...10") + + assert_prism_eval("[\"5\"] => %w[5]") + + assert_prism_eval(":prism => :prism") + assert_prism_eval("%s[prism\#{1}] => %s[prism\#{1}]") + assert_prism_eval("\"foo\" => /.../") + assert_prism_eval("\"foo1\" => /...\#{1}/") + assert_prism_eval("4 => ->(v) { v.even? }") + + assert_prism_eval("5 => foo") + end + def test_PinnedExpressionNode assert_prism_eval("4 in ^(4)") end