diff --git a/lib/ruby_vm/rjit/insn_compiler.rb b/lib/ruby_vm/rjit/insn_compiler.rb index eea9315a119fd7..be497bc9a116ac 100644 --- a/lib/ruby_vm/rjit/insn_compiler.rb +++ b/lib/ruby_vm/rjit/insn_compiler.rb @@ -1479,30 +1479,120 @@ def opt_newarray_min(jit, ctx, asm) # @param ctx [RubyVM::RJIT::Context] # @param asm [RubyVM::RJIT::Assembler] def invokesuper(jit, ctx, asm) - # Specialize on a compile-time receiver, and split a block for chain guards + cd = C.rb_call_data.new(jit.operand(0)) + block = jit.operand(1) + + # Defer compilation so we can specialize on class of receiver unless jit.at_current_insn? defer_compilation(jit, ctx, asm) return EndBlock end - cd = C.rb_call_data.new(jit.operand(0)) - blockiseq = jit.operand(1) + me = C.rb_vm_frame_method_entry(jit.cfp) + if me.nil? + return CantCompile + end - block_handler = jit_caller_setup_arg_block(jit, ctx, asm, cd.ci, blockiseq, true) - if block_handler == CantCompile + # FIXME: We should track and invalidate this block when this cme is invalidated + current_defined_class = me.defined_class + mid = me.def.original_id + + if me.to_i != C.rb_callable_method_entry(current_defined_class, me.called_id).to_i + # Though we likely could generate this call, as we are only concerned + # with the method entry remaining valid, assume_method_lookup_stable + # below requires that the method lookup matches as well return CantCompile end - # calling->ci - mid = C.vm_ci_mid(cd.ci) - calling = build_calling(ci: cd.ci, block_handler:) + # vm_search_normal_superclass + rbasic_klass = C.to_ruby(C.RBasic.new(C.to_value(current_defined_class)).klass) + if C::BUILTIN_TYPE(current_defined_class) == C::RUBY_T_ICLASS && C::BUILTIN_TYPE(rbasic_klass) == C::RUBY_T_MODULE && \ + C::FL_TEST_RAW(rbasic_klass, C::RMODULE_IS_REFINEMENT) != 0 + return CantCompile + end + comptime_superclass = C.rb_class_get_superclass(C.RCLASS_ORIGIN(current_defined_class)) - # vm_sendish - cme = jit_search_super_method(jit, ctx, asm, mid, calling) - if cme == CantCompile + ci = cd.ci + argc = C.vm_ci_argc(ci) + + ci_flags = C.vm_ci_flag(ci) + + # Don't JIT calls that aren't simple + # Note, not using VM_CALL_ARGS_SIMPLE because sometimes we pass a block. + + if ci_flags & C::VM_CALL_KWARG != 0 + asm.incr_counter(:send_keywords) return CantCompile end - jit_call_general(jit, ctx, asm, mid, calling, cme, nil) + if ci_flags & C::VM_CALL_KW_SPLAT != 0 + asm.incr_counter(:send_kw_splat) + return CantCompile + end + if ci_flags & C::VM_CALL_ARGS_BLOCKARG != 0 + asm.incr_counter(:send_block_arg) + return CantCompile + end + + # Ensure we haven't rebound this method onto an incompatible class. + # In the interpreter we try to avoid making this check by performing some + # cheaper calculations first, but since we specialize on the method entry + # and so only have to do this once at compile time this is fine to always + # check and side exit. + comptime_recv = jit.peek_at_stack(argc) + unless C.obj_is_kind_of(comptime_recv, current_defined_class) + return CantCompile + end + + # Do method lookup + cme = C.rb_callable_method_entry(comptime_superclass, mid) + + if cme.nil? + return CantCompile + end + + # Check that we'll be able to write this method dispatch before generating checks + cme_def_type = cme.def.type + if cme_def_type != C::VM_METHOD_TYPE_ISEQ && cme_def_type != C::VM_METHOD_TYPE_CFUNC + # others unimplemented + return CantCompile + end + + asm.comment('guard known me') + lep_opnd = :rax + jit_get_lep(jit, asm, reg: lep_opnd) + ep_me_opnd = [lep_opnd, C.VALUE.size * C::VM_ENV_DATA_INDEX_ME_CREF] + + asm.mov(:rcx, me.to_i) + asm.cmp(ep_me_opnd, :rcx) + asm.jne(counted_exit(side_exit(jit, ctx), :invokesuper_me_changed)) + + if block == C::VM_BLOCK_HANDLER_NONE + # Guard no block passed + # rb_vm_frame_block_handler(GET_EC()->cfp) == VM_BLOCK_HANDLER_NONE + # note, we assume VM_ASSERT(VM_ENV_LOCAL_P(ep)) + # + # TODO: this could properly forward the current block handler, but + # would require changes to gen_send_* + asm.comment('guard no block given') + ep_specval_opnd = [lep_opnd, C.VALUE.size * C::VM_ENV_DATA_INDEX_SPECVAL] + asm.cmp(ep_specval_opnd, C::VM_BLOCK_HANDLER_NONE) + asm.jne(counted_exit(side_exit(jit, ctx), :invokesuper_block)) + end + + # We need to assume that both our current method entry and the super + # method entry we invoke remain stable + Invariants.assume_method_lookup_stable(jit, me) + Invariants.assume_method_lookup_stable(jit, cme) + + calling = build_calling(ci:, block_handler: block) + case cme_def_type + in C::VM_METHOD_TYPE_ISEQ + iseq = def_iseq_ptr(cme.def) + frame_type = C::VM_FRAME_MAGIC_METHOD | C::VM_ENV_FLAG_LOCAL + jit_call_iseq(jit, ctx, asm, cme, calling, iseq, frame_type:) + in C::VM_METHOD_TYPE_CFUNC + jit_call_cfunc(jit, ctx, asm, cme, calling, nil) + end end # @param jit [RubyVM::RJIT::JITState] @@ -3906,87 +3996,6 @@ def jit_search_method(jit, ctx, asm, mid, calling) return cme, comptime_recv_klass end - def jit_search_super_method(jit, ctx, asm, mid, calling) - assert_equal(true, jit.at_current_insn?) - - me = C.rb_vm_frame_method_entry(jit.cfp) - if me.nil? - return CantCompile - end - - # FIXME: We should track and invalidate this block when this cme is invalidated - current_defined_class = me.defined_class - mid = me.def.original_id - - if me.to_i != C.rb_callable_method_entry(current_defined_class, me.called_id).to_i - # Though we likely could generate this call, as we are only concerned - # with the method entry remaining valid, assume_method_lookup_stable - # below requires that the method lookup matches as well - return CantCompile - end - - # vm_search_normal_superclass - rbasic_klass = C.to_ruby(C.RBasic.new(C.to_value(current_defined_class)).klass) - if C::BUILTIN_TYPE(current_defined_class) == C::RUBY_T_ICLASS && C::BUILTIN_TYPE(rbasic_klass) == C::RUBY_T_MODULE && \ - C::FL_TEST_RAW(rbasic_klass, C::RMODULE_IS_REFINEMENT) != 0 - return CantCompile - end - comptime_superclass = C.rb_class_get_superclass(C.RCLASS_ORIGIN(current_defined_class)) - - # Don't JIT calls that aren't simple - # Note, not using VM_CALL_ARGS_SIMPLE because sometimes we pass a block. - - if calling.flags & C::VM_CALL_KWARG != 0 - asm.incr_counter(:send_kwarg) - return CantCompile - end - if calling.flags & C::VM_CALL_KW_SPLAT != 0 - asm.incr_counter(:send_kw_splat) - return CantCompile - end - - # Ensure we haven't rebound this method onto an incompatible class. - # In the interpreter we try to avoid making this check by performing some - # cheaper calculations first, but since we specialize on the method entry - # and so only have to do this once at compile time this is fine to always - # check and side exit. - comptime_recv = jit.peek_at_stack(calling.argc) - unless C.obj_is_kind_of(comptime_recv, current_defined_class) - return CantCompile - end - - # Do method lookup - cme = C.rb_callable_method_entry(comptime_superclass, mid) - - if cme.nil? - return CantCompile - end - - # Check that we'll be able to write this method dispatch before generating checks - cme_def_type = cme.def.type - if cme_def_type != C::VM_METHOD_TYPE_ISEQ && cme_def_type != C::VM_METHOD_TYPE_CFUNC - # others unimplemented - return CantCompile - end - - # Guard that the receiver has the same class as the one from compile time - side_exit = side_exit(jit, ctx) - - asm.comment('guard known me') - jit_get_lep(jit, asm, reg: :rax) - - asm.mov(:rcx, me.to_i) - asm.cmp([:rax, C.VALUE.size * C::VM_ENV_DATA_INDEX_ME_CREF], :rcx) - asm.jne(counted_exit(side_exit, :invokesuper_me_changed)) - - # We need to assume that both our current method entry and the super - # method entry we invoke remain stable - Invariants.assume_method_lookup_stable(jit, me) - Invariants.assume_method_lookup_stable(jit, cme) - - return cme - end - # vm_call_general # @param jit [RubyVM::RJIT::JITState] # @param ctx [RubyVM::RJIT::Context] diff --git a/rjit_c.h b/rjit_c.h index 1e809a10b44a31..6aca73556c6876 100644 --- a/rjit_c.h +++ b/rjit_c.h @@ -46,6 +46,7 @@ RJIT_RUNTIME_COUNTERS( send_c_tracing, send_is_a_class_mismatch, send_instance_of_class_mismatch, + send_keywords, send_blockiseq, send_block_handler, @@ -105,6 +106,7 @@ RJIT_RUNTIME_COUNTERS( send_bmethod_blockarg, invokesuper_me_changed, + invokesuper_block, invokeblock_none, invokeblock_symbol, diff --git a/rjit_c.rb b/rjit_c.rb index 13c3c6ad87985f..2ba8f2d378703f 100644 --- a/rjit_c.rb +++ b/rjit_c.rb @@ -1334,6 +1334,7 @@ def C.rb_rjit_runtime_counters send_c_tracing: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_rjit_runtime_counters *)NULL)), send_c_tracing)")], send_is_a_class_mismatch: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_rjit_runtime_counters *)NULL)), send_is_a_class_mismatch)")], send_instance_of_class_mismatch: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_rjit_runtime_counters *)NULL)), send_instance_of_class_mismatch)")], + send_keywords: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_rjit_runtime_counters *)NULL)), send_keywords)")], send_blockiseq: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_rjit_runtime_counters *)NULL)), send_blockiseq)")], send_block_handler: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_rjit_runtime_counters *)NULL)), send_block_handler)")], send_block_setup: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_rjit_runtime_counters *)NULL)), send_block_setup)")], @@ -1384,6 +1385,7 @@ def C.rb_rjit_runtime_counters send_bmethod_not_iseq: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_rjit_runtime_counters *)NULL)), send_bmethod_not_iseq)")], send_bmethod_blockarg: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_rjit_runtime_counters *)NULL)), send_bmethod_blockarg)")], invokesuper_me_changed: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_rjit_runtime_counters *)NULL)), invokesuper_me_changed)")], + invokesuper_block: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_rjit_runtime_counters *)NULL)), invokesuper_block)")], invokeblock_none: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_rjit_runtime_counters *)NULL)), invokeblock_none)")], invokeblock_symbol: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_rjit_runtime_counters *)NULL)), invokeblock_symbol)")], invokeblock_proc: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_rjit_runtime_counters *)NULL)), invokeblock_proc)")],