Skip to content

Commit

Permalink
Mark the first string element of a regexp as binary if US-ASCII
Browse files Browse the repository at this point in the history
  • Loading branch information
kddnewton committed May 3, 2024
1 parent b5cefa7 commit 5409661
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 24 deletions.
8 changes: 8 additions & 0 deletions prism/prism.c
Original file line number Diff line number Diff line change
Expand Up @@ -19173,6 +19173,14 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power, b
pm_token_t opening = not_provided(parser);
pm_token_t closing = not_provided(parser);
pm_node_t *part = (pm_node_t *) pm_string_node_create_unescaped(parser, &opening, &parser->previous, &closing, &unescaped);

if (parser->encoding == PM_ENCODING_US_ASCII_ENTRY) {
// This is extremely strange, but the first string part of a
// regular expression will always be tagged as binary if we
// are in a US-ASCII file, no matter its contents.
pm_node_flag_set(part, PM_STRING_FLAGS_FORCED_BINARY_ENCODING);
}

pm_interpolated_regular_expression_node_append(interpolated, part);
} else {
// If the first part of the body of the regular expression is not a
Expand Down
84 changes: 61 additions & 23 deletions prism_compile.c
Original file line number Diff line number Diff line change
Expand Up @@ -363,18 +363,34 @@ parse_regexp_error(rb_iseq_t *iseq, int32_t line_number, const char *fmt, ...)
}

/**
 * Create the Ruby string for a single literal string part of a regular
 * expression and validate it with rb_reg_check_preprocess.
 *
 * NOTE(review): this view appears to interleave the pre-change and post-change
 * lines of the commit (two signature lines, two `VALUE string` declarations).
 * The description below follows the two-encoding form.
 *
 * The encoding given to the new string is chosen in this order:
 *   1. explicit_regexp_encoding, when it is not NULL;
 *   2. rb_ascii8bit_encoding(), when the node has
 *      PM_STRING_FLAGS_FORCED_BINARY_ENCODING set;
 *   3. rb_utf8_encoding(), when the node has
 *      PM_STRING_FLAGS_FORCED_UTF8_ENCODING set;
 *   4. implicit_regexp_encoding otherwise.
 *
 * If rb_reg_check_preprocess yields a non-nil value, it is converted to a
 * string and handed to parse_regexp_error together with the line number of
 * the node. The created string is what the function returns.
 */
static VALUE
parse_regexp_string_part(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t *node, const pm_string_t *unescaped, rb_encoding *regexp_encoding)
parse_regexp_string_part(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t *node, const pm_string_t *unescaped, rb_encoding *implicit_regexp_encoding, rb_encoding *explicit_regexp_encoding)
{
// If we were passed an explicit regexp encoding, then we need to double
// check that it's okay here for this fragment of the string.
VALUE string = rb_enc_str_new((const char *) pm_string_source(unescaped), pm_string_length(unescaped), regexp_encoding);
rb_encoding *encoding;

// An explicit encoding wins over any per-node flag.
if (explicit_regexp_encoding != NULL) {
encoding = explicit_regexp_encoding;
}
// Otherwise honor the encoding flags carried by the string node itself.
else if (node->flags & PM_STRING_FLAGS_FORCED_BINARY_ENCODING) {
encoding = rb_ascii8bit_encoding();
}
else if (node->flags & PM_STRING_FLAGS_FORCED_UTF8_ENCODING) {
encoding = rb_utf8_encoding();
}
// No explicit encoding and no forcing flag: use the implicit encoding.
else {
encoding = implicit_regexp_encoding;
}

VALUE string = rb_enc_str_new((const char *) pm_string_source(unescaped), pm_string_length(unescaped), encoding);
// A non-nil result from the preprocess check is treated as an error value.
VALUE error = rb_reg_check_preprocess(string);

if (error != Qnil) parse_regexp_error(iseq, pm_node_line_number(scope_node->parser, node), "%" PRIsVALUE, rb_obj_as_string(error));
return string;
}

static VALUE
pm_static_literal_concat(rb_iseq_t *iseq, const pm_node_list_t *nodes, const pm_scope_node_t *scope_node, rb_encoding *regexp_encoding, bool top)
pm_static_literal_concat(rb_iseq_t *iseq, const pm_node_list_t *nodes, const pm_scope_node_t *scope_node, rb_encoding *implicit_regexp_encoding, rb_encoding *explicit_regexp_encoding, bool top)
{
VALUE current = Qnil;

Expand All @@ -384,9 +400,9 @@ pm_static_literal_concat(rb_iseq_t *iseq, const pm_node_list_t *nodes, const pm_

switch (PM_NODE_TYPE(part)) {
case PM_STRING_NODE:
if (regexp_encoding != NULL) {
if (implicit_regexp_encoding != NULL) {
if (top) {
string = parse_regexp_string_part(iseq, scope_node, part, &((const pm_string_node_t *) part)->unescaped, regexp_encoding);
string = parse_regexp_string_part(iseq, scope_node, part, &((const pm_string_node_t *) part)->unescaped, implicit_regexp_encoding, explicit_regexp_encoding);
}
else {
string = parse_string_encoded(part, &((const pm_string_node_t *) part)->unescaped, scope_node->encoding);
Expand All @@ -399,11 +415,11 @@ pm_static_literal_concat(rb_iseq_t *iseq, const pm_node_list_t *nodes, const pm_
}
break;
case PM_INTERPOLATED_STRING_NODE:
string = pm_static_literal_concat(iseq, &((const pm_interpolated_string_node_t *) part)->parts, scope_node, regexp_encoding, false);
string = pm_static_literal_concat(iseq, &((const pm_interpolated_string_node_t *) part)->parts, scope_node, implicit_regexp_encoding, explicit_regexp_encoding, false);
break;
case PM_EMBEDDED_STATEMENTS_NODE: {
const pm_embedded_statements_node_t *cast = (const pm_embedded_statements_node_t *) part;
string = pm_static_literal_concat(iseq, &cast->statements->body, scope_node, regexp_encoding, false);
string = pm_static_literal_concat(iseq, &cast->statements->body, scope_node, implicit_regexp_encoding, explicit_regexp_encoding, false);
break;
}
default:
Expand Down Expand Up @@ -499,7 +515,7 @@ parse_regexp_encoding(const pm_scope_node_t *scope_node, const pm_node_t *node)
return rb_enc_get_from_index(ENCINDEX_Windows_31J);
}
else {
return scope_node->encoding;
return NULL;
}
}

Expand Down Expand Up @@ -527,22 +543,26 @@ static inline VALUE
parse_regexp_literal(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t *node, const pm_string_t *unescaped)
{
// Builds a string from the unescaped source of a regular expression that has
// no interpolated parts and passes it to parse_regexp, whose result is
// returned.
rb_encoding *regexp_encoding = parse_regexp_encoding(scope_node, node);
// parse_regexp_encoding can return NULL; in that case fall back to the
// encoding recorded on the scope node.
if (regexp_encoding == NULL) regexp_encoding = scope_node->encoding;

VALUE string = rb_enc_str_new((const char *) pm_string_source(unescaped), pm_string_length(unescaped), regexp_encoding);
return parse_regexp(iseq, scope_node, node, string);
}

/**
 * Concatenate the given parts of a regular expression into one string with
 * pm_static_literal_concat and pass that string to parse_regexp, returning
 * its result.
 *
 * NOTE(review): this view appears to interleave the pre-change lines (using a
 * single `regexp_encoding`) with the post-change lines (explicit/implicit
 * pair). The comments below describe the explicit/implicit form.
 */
static inline VALUE
parse_regexp_concat(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t *node, const pm_node_list_t *parts)
{
rb_encoding *regexp_encoding = parse_regexp_encoding(scope_node, node);
VALUE string = pm_static_literal_concat(iseq, parts, scope_node, regexp_encoding, false);
// The explicit encoding is whatever parse_regexp_encoding reports, possibly
// NULL. The implicit encoding is never NULL here: it is the explicit one when
// present, otherwise the scope node's encoding.
rb_encoding *explicit_regexp_encoding = parse_regexp_encoding(scope_node, node);
rb_encoding *implicit_regexp_encoding = explicit_regexp_encoding != NULL ? explicit_regexp_encoding : scope_node->encoding;

// The trailing `false` is the `top` argument of pm_static_literal_concat.
VALUE string = pm_static_literal_concat(iseq, parts, scope_node, implicit_regexp_encoding, explicit_regexp_encoding, false);
return parse_regexp(iseq, scope_node, node, string);
}

static void pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret, bool popped, pm_scope_node_t *scope_node);

static int
pm_interpolated_node_compile(rb_iseq_t *iseq, const pm_node_list_t *parts, const pm_line_column_t *node_location, LINK_ANCHOR *const ret, bool popped, pm_scope_node_t *scope_node, rb_encoding *regexp_encoding)
pm_interpolated_node_compile(rb_iseq_t *iseq, const pm_node_list_t *parts, const pm_line_column_t *node_location, LINK_ANCHOR *const ret, bool popped, pm_scope_node_t *scope_node, rb_encoding *implicit_regexp_encoding, rb_encoding *explicit_regexp_encoding)
{
int stack_size = 0;
size_t parts_size = parts->size;
Expand All @@ -558,11 +578,11 @@ pm_interpolated_node_compile(rb_iseq_t *iseq, const pm_node_list_t *parts, const
const pm_string_node_t *string_node = (const pm_string_node_t *) part;
VALUE string_value;

if (regexp_encoding == NULL) {
if (implicit_regexp_encoding == NULL) {
string_value = parse_string_encoded(part, &string_node->unescaped, scope_node->encoding);
}
else {
string_value = parse_regexp_string_part(iseq, scope_node, (const pm_node_t *) string_node, &string_node->unescaped, regexp_encoding);
string_value = parse_regexp_string_part(iseq, scope_node, (const pm_node_t *) string_node, &string_node->unescaped, implicit_regexp_encoding, explicit_regexp_encoding);
}

if (RTEST(current_string)) {
Expand All @@ -584,11 +604,11 @@ pm_interpolated_node_compile(rb_iseq_t *iseq, const pm_node_list_t *parts, const
const pm_string_node_t *string_node = (const pm_string_node_t *) ((const pm_embedded_statements_node_t *) part)->statements->body.nodes[0];
VALUE string_value;

if (regexp_encoding == NULL) {
if (implicit_regexp_encoding == NULL) {
string_value = parse_string_encoded(part, &string_node->unescaped, scope_node->encoding);
}
else {
string_value = parse_regexp_string_part(iseq, scope_node, (const pm_node_t *) string_node, &string_node->unescaped, regexp_encoding);
string_value = parse_regexp_string_part(iseq, scope_node, (const pm_node_t *) string_node, &string_node->unescaped, implicit_regexp_encoding, explicit_regexp_encoding);
}

if (RTEST(current_string)) {
Expand All @@ -600,7 +620,24 @@ pm_interpolated_node_compile(rb_iseq_t *iseq, const pm_node_list_t *parts, const
}
else {
if (!RTEST(current_string)) {
current_string = rb_enc_str_new(NULL, 0, regexp_encoding != NULL ? regexp_encoding : scope_node->encoding);
rb_encoding *encoding;

if (implicit_regexp_encoding != NULL) {
if (explicit_regexp_encoding != NULL) {
encoding = explicit_regexp_encoding;
}
else if (scope_node->parser->encoding == PM_ENCODING_US_ASCII_ENTRY) {
encoding = rb_ascii8bit_encoding();
}
else {
encoding = implicit_regexp_encoding;
}
}
else {
encoding = scope_node->encoding;
}

current_string = rb_enc_str_new(NULL, 0, encoding);
}

PUSH_INSN1(ret, *node_location, putobject, rb_fstring(current_string));
Expand Down Expand Up @@ -639,9 +676,10 @@ pm_interpolated_node_compile(rb_iseq_t *iseq, const pm_node_list_t *parts, const
/**
 * Emit the instructions for a regular expression whose parts are compiled by
 * pm_interpolated_node_compile, followed by a `toregexp` instruction.
 *
 * NOTE(review): this view appears to interleave the pre-change lines (using a
 * single `regexp_encoding`) with the post-change lines (explicit/implicit
 * pair). The comments below describe the explicit/implicit form.
 */
static void
pm_compile_regexp_dynamic(rb_iseq_t *iseq, const pm_node_t *node, const pm_node_list_t *parts, const pm_line_column_t *node_location, LINK_ANCHOR *const ret, bool popped, pm_scope_node_t *scope_node)
{
rb_encoding *regexp_encoding = parse_regexp_encoding(scope_node, node);
int length = pm_interpolated_node_compile(iseq, parts, node_location, ret, popped, scope_node, regexp_encoding);
// The explicit encoding may be NULL; the implicit encoding falls back to the
// scope node's encoding so that it is always set for the parts compiler.
rb_encoding *explicit_regexp_encoding = parse_regexp_encoding(scope_node, node);
rb_encoding *implicit_regexp_encoding = explicit_regexp_encoding != NULL ? explicit_regexp_encoding : scope_node->encoding;

// `length` is the value returned by pm_interpolated_node_compile — presumably
// the number of values it left on the stack; TODO confirm against its body.
int length = pm_interpolated_node_compile(iseq, parts, node_location, ret, popped, scope_node, implicit_regexp_encoding, explicit_regexp_encoding);
// Only the low 8 bits of the regexp flags are passed as the first operand.
PUSH_INSN2(ret, *node_location, toregexp, INT2FIX(parse_regexp_flags(node) & 0xFF), INT2FIX(length));
}

Expand Down Expand Up @@ -738,13 +776,13 @@ pm_static_literal_value(rb_iseq_t *iseq, const pm_node_t *node, const pm_scope_n
return parse_regexp_concat(iseq, scope_node, (const pm_node_t *) cast, &cast->parts);
}
case PM_INTERPOLATED_STRING_NODE: {
VALUE string = pm_static_literal_concat(iseq, &((const pm_interpolated_string_node_t *) node)->parts, scope_node, NULL, false);
VALUE string = pm_static_literal_concat(iseq, &((const pm_interpolated_string_node_t *) node)->parts, scope_node, NULL, NULL, false);
int line_number = pm_node_line_number(scope_node->parser, node);
return pm_static_literal_string(iseq, string, line_number);
}
case PM_INTERPOLATED_SYMBOL_NODE: {
const pm_interpolated_symbol_node_t *cast = (const pm_interpolated_symbol_node_t *) node;
VALUE string = pm_static_literal_concat(iseq, &cast->parts, scope_node, NULL, true);
VALUE string = pm_static_literal_concat(iseq, &cast->parts, scope_node, NULL, NULL, true);

return ID2SYM(rb_intern_str(string));
}
Expand Down Expand Up @@ -6524,7 +6562,7 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,
}
else {
const pm_interpolated_string_node_t *cast = (const pm_interpolated_string_node_t *) node;
int length = pm_interpolated_node_compile(iseq, &cast->parts, &location, ret, popped, scope_node, NULL);
int length = pm_interpolated_node_compile(iseq, &cast->parts, &location, ret, popped, scope_node, NULL, NULL);
if (length > 1) PUSH_INSN1(ret, location, concatstrings, INT2FIX(length));
if (popped) PUSH_INSN(ret, location, pop);
}
Expand All @@ -6543,7 +6581,7 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,
}
}
else {
int length = pm_interpolated_node_compile(iseq, &cast->parts, &location, ret, popped, scope_node, NULL);
int length = pm_interpolated_node_compile(iseq, &cast->parts, &location, ret, popped, scope_node, NULL, NULL);
if (length > 1) {
PUSH_INSN1(ret, location, concatstrings, INT2FIX(length));
}
Expand All @@ -6565,7 +6603,7 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret,

PUSH_INSN(ret, location, putself);

int length = pm_interpolated_node_compile(iseq, &cast->parts, &location, ret, false, scope_node, NULL);
int length = pm_interpolated_node_compile(iseq, &cast->parts, &location, ret, false, scope_node, NULL, NULL);
if (length > 1) PUSH_INSN1(ret, location, concatstrings, INT2FIX(length));

PUSH_SEND_WITH_FLAG(ret, location, idBackquote, INT2NUM(1), INT2FIX(VM_CALL_FCALL | VM_CALL_ARGS_SIMPLE));
Expand Down
1 change: 0 additions & 1 deletion test/.excludes-prism/TestM17N.rb
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
exclude(:test_regexp_ascii, "https://github.com/ruby/prism/issues/2664")
exclude(:test_regexp_usascii, "unknown")
exclude(:test_string_mixed_unicode, "unknown")

0 comments on commit 5409661

Please sign in to comment.