Skip to content

Commit

Permalink
YJIT: Expand codegen for TrueClass#=== to FalseClass and `NilClas…
Browse files Browse the repository at this point in the history
…s` (#10679)
  • Loading branch information
rwstauner committed Apr 29, 2024
1 parent 845f2db commit adae813
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 9 deletions.
70 changes: 70 additions & 0 deletions bootstraptest/test_yjit.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4894,3 +4894,73 @@ class TrueClass
results << test
} unless rjit_enabled? # Not yet working on RJIT

# test FalseClass#=== before and after redefining FalseClass#==
assert_equal '[[true, false, false], [true, false, true], [true, :error, :error]]', %q{
def case_equal(x, y)
x === y
rescue NoMethodError
:error
end
def test
[
# first one is always true because rb_equal does object comparison before calling #==
case_equal(false, false),
# these will use #==
case_equal(false, true),
case_equal(false, nil),
]
end
results = [test]
class FalseClass
def ==(x)
!x
end
end
results << test
class FalseClass
undef_method :==
end
results << test
} unless rjit_enabled? # Not yet working on RJIT

# test NilClass#=== before and after redefining NilClass#==
assert_equal '[[true, false, false], [true, false, true], [true, :error, :error]]', %q{
def case_equal(x, y)
x === y
rescue NoMethodError
:error
end
def test
[
# first one is always true because rb_equal does object comparison before calling #==
case_equal(nil, nil),
# these will use #==
case_equal(nil, true),
case_equal(nil, false),
]
end
results = [test]
class NilClass
def ==(x)
!x
end
end
results << test
class NilClass
undef_method :==
end
results << test
} unless rjit_enabled? # Not yet working on RJIT
25 changes: 16 additions & 9 deletions yjit/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6191,23 +6191,23 @@ fn jit_rb_class_superclass(
true
}

// Codegen for rb_trueclass_case_equal()
fn jit_rb_trueclass_case_equal(
fn jit_rb_case_equal(
jit: &mut JITState,
asm: &mut Assembler,
ocb: &mut OutlinedCb,
_ci: *const rb_callinfo,
_cme: *const rb_callable_method_entry_t,
_block: Option<BlockHandler>,
_argc: i32,
_known_recv_class: Option<VALUE>,
known_recv_class: Option<VALUE>,
) -> bool {
if !jit.assume_expected_cfunc( asm, ocb, unsafe { rb_cTrueClass }, ID!(eq), rb_obj_equal as _) {
if !jit.assume_expected_cfunc( asm, ocb, known_recv_class.unwrap(), ID!(eq), rb_obj_equal as _) {
return false;
}

asm_comment!(asm, "case_equal: {}#===", get_class_name(known_recv_class));

// Compare the arguments
asm_comment!(asm, "TrueClass#===");
let arg1 = asm.stack_pop(1);
let arg0 = asm.stack_pop(1);
asm.cmp(arg0, arg1);
Expand Down Expand Up @@ -8882,11 +8882,16 @@ fn gen_send_general(
}
}

/// Get class name from a class pointer.
fn get_class_name(class: Option<VALUE>) -> String {
class.and_then(|class| unsafe {
cstr_to_rust_string(rb_class2name(class))
}).unwrap_or_else(|| "Unknown".to_string())
}

/// Assemble "{class_name}#{method_name}" from a class pointer and a method ID
fn get_method_name(class: Option<VALUE>, mid: u64) -> String {
let class_name = class.and_then(|class| unsafe {
cstr_to_rust_string(rb_class2name(class))
}).unwrap_or_else(|| "Unknown".to_string());
let class_name = get_class_name(class);
let method_name = if mid != 0 {
unsafe { cstr_to_rust_string(rb_id2name(mid)) }
} else {
Expand Down Expand Up @@ -10221,7 +10226,9 @@ pub fn yjit_reg_method_codegen_fns() {
yjit_reg_method(rb_cString, "<<", jit_rb_str_concat);
yjit_reg_method(rb_cString, "+@", jit_rb_str_uplus);

yjit_reg_method(rb_cTrueClass, "===", jit_rb_trueclass_case_equal);
yjit_reg_method(rb_cNilClass, "===", jit_rb_case_equal);
yjit_reg_method(rb_cTrueClass, "===", jit_rb_case_equal);
yjit_reg_method(rb_cFalseClass, "===", jit_rb_case_equal);

yjit_reg_method(rb_cArray, "empty?", jit_rb_ary_empty_p);
yjit_reg_method(rb_cArray, "length", jit_rb_ary_length);
Expand Down

0 comments on commit adae813

Please sign in to comment.