Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

YJIT: expandarray for non-arrays #9495

Merged
merged 3 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 17 additions & 0 deletions bootstraptest/test_yjit.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2494,6 +2494,23 @@ def obj.to_ary
expandarray_not_array(obj)
}

assert_equal '[1, 2]', %q{
class NilClass
private
def to_ary
[1, 2]
end
end

def expandarray_redefined_nilclass
a, b = nil
[a, b]
end

expandarray_redefined_nilclass
expandarray_redefined_nilclass
} unless rjit_enabled?

assert_equal '[1, 2, nil]', %q{
def expandarray_rhs_too_small
a, b, c = [1, 2]
Expand Down
62 changes: 46 additions & 16 deletions yjit/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1637,18 +1637,6 @@ fn gen_expandarray(

let array_opnd = asm.stack_opnd(0);

// If the array operand is nil, just push on nils
if asm.ctx.get_opnd_type(array_opnd.into()) == Type::Nil {
asm.stack_pop(1); // pop after using the type info
// special case for a, b = nil pattern
// push N nils onto the stack
for _ in 0..num {
let push_opnd = asm.stack_push(Type::Nil);
asm.mov(push_opnd, Qnil.into());
}
return Some(KeepCompiling);
}

// Defer compilation so we can specialize on a runtime `self`
if !jit.at_current_insn() {
defer_compilation(jit, asm, ocb);
Expand All @@ -1657,10 +1645,52 @@ fn gen_expandarray(

let comptime_recv = jit.peek_at_stack(&asm.ctx, 0);

// If the comptime receiver is not an array, bail
if comptime_recv.class_of() != unsafe { rb_cArray } {
gen_counter_incr(asm, Counter::expandarray_comptime_not_array);
return None;
// If the comptime receiver is not an array
if !unsafe { RB_TYPE_P(comptime_recv, RUBY_T_ARRAY) } {
// at compile time, ensure to_ary is not defined
let target_cme = unsafe { rb_callable_method_entry_or_negative(comptime_recv.class_of(), ID!(to_ary)) };
let cme_def_type = unsafe { get_cme_def_type(target_cme) };

// if to_ary is defined, return can't compile so to_ary can be called
if cme_def_type != VM_METHOD_TYPE_UNDEF {
gen_counter_incr(asm, Counter::expandarray_to_ary);
return None;
maximecb marked this conversation as resolved.
Show resolved Hide resolved
}

// invalidate compile block if to_ary is later defined
jit.assume_method_lookup_stable(asm, ocb, target_cme);

jit_guard_known_klass(
jit,
asm,
ocb,
comptime_recv.class_of(),
array_opnd,
array_opnd.into(),
comptime_recv,
SEND_MAX_DEPTH,
Counter::expandarray_not_array,
);

let opnd = asm.stack_pop(1); // pop after using the type info

// If we don't actually want any values, then just keep going
if num == 0 {
return Some(KeepCompiling);
}

// load opnd to avoid a race because we are also pushing onto the stack
let opnd = asm.load(opnd);

for _ in 1..num {
let push_opnd = asm.stack_push(Type::Nil);
asm.mov(push_opnd, Qnil.into());
}

let push_opnd = asm.stack_push(Type::Unknown);
asm.mov(push_opnd, opnd);

return Some(KeepCompiling);
}

// Get the compile-time array length
Expand Down
1 change: 1 addition & 0 deletions yjit/src/cruby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,7 @@ pub(crate) mod ids {
name: max content: b"max"
name: hash content: b"hash"
name: respond_to_missing content: b"respond_to_missing?"
name: to_ary content: b"to_ary"
}
}

Expand Down
2 changes: 1 addition & 1 deletion yjit/src/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ make_counters! {
expandarray_splat,
expandarray_postarg,
expandarray_not_array,
expandarray_comptime_not_array,
expandarray_to_ary,
expandarray_chain_max_depth,

// getblockparam
Expand Down