diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index f5363aebe261bb..02e3759cd69d9e 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -4497,8 +4497,6 @@ impl Function { self.push_insn_id(block, insn_id); continue; } if self.policy.no_side_exits { - // TODO: Support polymorphic DefinedIvar shape-specialized paths. - // https://github.com/Shopify/ruby/issues/980 // On the final version, keep the DefinedIvar fallback instead of another shape guard. self.push_insn_id(block, insn_id); continue; } @@ -7155,7 +7153,68 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result { // (ID id, IVC ic, VALUE pushval) let id = ID(get_arg(pc, 0).as_u64()); let pushval = get_arg(pc, 2); - state.stack_push(fun.push_insn(block, Insn::DefinedIvar { self_val: self_param, id, pushval, state: exit_id })); + if let Some(summary) = fun.polymorphic_summary(&profiles, self_param, exit_state.insn_idx) { + self_param = fun.push_insn(block, Insn::GuardType { val: self_param, guard_type: types::HeapBasicObject, state: exit_id }); + let rbasic_flags = fun.load_rbasic_flags(block, self_param); + let join_block = insn_idx_to_block.get(&insn_idx).copied().unwrap_or_else(|| fun.new_block(insn_idx)); + let join_param = fun.push_insn(join_block, Insn::Param); + // Dedup by expected shape and type so objects with different classes + // but the same shape can share code. + let mut seen_shape_and_flags = Vec::with_capacity(summary.buckets().len()); + for &profiled_type in summary.buckets() { + // End of the buckets + if profiled_type.is_empty() { break; } + // Runtime immediates cannot pass the HeapBasicObject guard, so don't + // generate unreachable shape branches for profiled immediate buckets. + if profiled_type.flags().is_immediate() { continue; } + // Class/module/T_DATA ivars use different storage rules. + // Let the fallthrough DefinedIvar handle these. + if !profiled_type.flags().is_t_object() { continue; } + let expected_shape = profiled_type.shape(); + let (expected_rbasic_flags, rbasic_flags_mask) = profiled_type.rbasic_flags_and_mask(); + assert!(expected_shape.is_valid()); + // Too-complex shapes use hash tables for ivars; + // rb_shape_get_iv_index doesn't work for them. + // Let the fallthrough DefinedIvar handle these. + if expected_shape.is_complex() { continue; } + if seen_shape_and_flags.contains(&expected_rbasic_flags) { continue; } + seen_shape_and_flags.push(expected_rbasic_flags); + let rbasic_flags_mask = fun.push_insn(block, Insn::Const { val: Const::CUInt64(rbasic_flags_mask) }); + // The expected shape can change over run, so we put it + // as a pointer to keep it stable in snapshot tests. + let expected_rbasic_flags = fun.push_insn(block, Insn::Const { val: Const::CPtr(ptr::without_provenance(expected_rbasic_flags.to_usize())) }); + let expected_rbasic_flags = fun.push_insn(block, Insn::RefineType { val: expected_rbasic_flags, new_type: types::CUInt64 }); + let masked = fun.push_insn(block, Insn::IntAnd { left: rbasic_flags, right: rbasic_flags_mask}); + let has_shape_and_type = fun.push_insn(block, Insn::IsBitEqual { left: masked, right: expected_rbasic_flags }); + let iftrue_block = fun.new_block(insn_idx); + let target = BranchEdge { target: iftrue_block, args: vec![] }; + let fall_through = fun.new_block(insn_idx); + + fun.push_insn(block, Insn::CondBranch { val: has_shape_and_type, + if_true: target, + if_false: BranchEdge { target: fall_through, args: vec![] } + }); + + block = fall_through; + let mut ivar_index: attr_index_t = 0; + let result = if unsafe { rb_shape_get_iv_index(expected_shape.0, id, &mut ivar_index) } { + fun.push_insn(iftrue_block, Insn::Const { val: Const::Value(pushval) }) + } else { + fun.push_insn(iftrue_block, Insn::Const { val: Const::Value(Qnil) }) + }; + fun.push_insn(iftrue_block, Insn::Jump(BranchEdge { target: join_block, args: vec![result] })); + } + // In the fallthrough case, do a generic interpreter definedivar and then join. + let result = fun.push_insn(block, Insn::DefinedIvar { self_val: self_param, id, pushval, state: exit_id }); + fun.push_insn(block, Insn::Jump(BranchEdge { target: join_block, args: vec![result] })); + state.stack_push(join_param); + block = join_block; + } else { + // TODO: Handle monomorphic definedivar specialization here too, including the + // no_side_exits policy, so optimize_getivar doesn't need a separate DefinedIvar + // path. Unlike GetIvar, DefinedIvar isn't emitted by later lowering passes. + state.stack_push(fun.push_insn(block, Insn::DefinedIvar { self_val: self_param, id, pushval, state: exit_id })); + } } YARVINSN_checkkeyword => { // When a keyword is unspecified past index 32, a hash will be used instead. diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs index 3212390ecf948f..74b951d2e59410 100644 --- a/zjit/src/hir/opt_tests.rs +++ b/zjit/src/hir/opt_tests.rs @@ -5638,7 +5638,40 @@ mod hir_opt_tests { } #[test] - fn test_dont_specialize_definedivar_with_t_data() { + fn test_dont_specialize_definedivar_with_immediate() { + eval(" + module M + def test = defined?(@a) + end + + class Integer + include M + end + + 1.test + 2.test + TEST = M.instance_method(:test) + "); + assert_snapshot!(hir_string_proc("TEST"), @" + fn test@:3: + bb1(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb3(v1) + bb2(): + EntryPoint JIT(0) + v4:BasicObject = LoadArg :self@0 + Jump bb3(v4) + bb3(v6:BasicObject): + v10:StringExact|NilClass = DefinedIvar v6, :@a + CheckInterrupts + Return v10 + "); + } + + #[test] + fn test_dont_specialize_definedivar_with_t_struct() { + // Range is T_STRUCT (not T_OBJECT): falls back to DefinedIvar. eval(" class C < Range def test = defined?(@a) @@ -5666,7 +5699,7 @@ mod hir_opt_tests { } #[test] - fn test_dont_specialize_polymorphic_definedivar() { + fn test_optimize_definedivar_polymorphic() { set_call_threshold(3); eval(" class C @@ -5691,9 +5724,206 @@ mod hir_opt_tests { v4:BasicObject = LoadArg :self@0 Jump bb3(v4) bb3(v6:BasicObject): - v10:StringExact|NilClass = DefinedIvar v6, :@a + v10:HeapBasicObject = GuardType v6, HeapBasicObject + v11:CUInt64 = LoadField v10, :RBASIC_FLAGS@0x1000 + v13:CUInt64[0xffffffff0000001f] = Const CUInt64(0xffffffff0000001f) + v14:CPtr[CPtr(0x1001)] = Const CPtr(0x1001) + v15 = RefineType v14, CUInt64 + v16:CInt64 = IntAnd v11, v13 + v17:CBool = IsBitEqual v16, v15 + CondBranch v17, bb5(), bb6() + bb5(): + v19:NilClass = Const Value(nil) + Jump bb4(v19) + bb6(): + v21:CUInt64[0xffffffff0000001f] = Const CUInt64(0xffffffff0000001f) + v22:CPtr[CPtr(0x1002)] = Const CPtr(0x1002) + v23 = RefineType v22, CUInt64 + v24:CInt64 = IntAnd v11, v21 + v25:CBool = IsBitEqual v24, v23 + CondBranch v25, bb7(), bb8() + bb7(): + v27:StringExact[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + Jump bb4(v27) + bb8(): + v29:StringExact|NilClass = DefinedIvar v10, :@a + Jump bb4(v29) + bb4(v12:StringExact|NilClass): CheckInterrupts - Return v10 + Return v12 + "); + } + + #[test] + fn test_optimize_definedivar_polymorphic_with_immediate() { + set_call_threshold(3); + eval(r#" + module M + def test = defined?(@a) + end + + class C + include M + end + + class Integer + include M + end + + obj = C.new + obj.instance_variable_set(:@a, 1) + + obj.test + 1.test + TEST = M.instance_method(:test) + "#); + assert_snapshot!(hir_string_proc("TEST"), @" + fn test@:3: + bb1(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb3(v1) + bb2(): + EntryPoint JIT(0) + v4:BasicObject = LoadArg :self@0 + Jump bb3(v4) + bb3(v6:BasicObject): + v10:HeapBasicObject = GuardType v6, HeapBasicObject + v11:CUInt64 = LoadField v10, :RBASIC_FLAGS@0x1000 + v13:CUInt64[0xffffffff0000001f] = Const CUInt64(0xffffffff0000001f) + v14:CPtr[CPtr(0x1001)] = Const CPtr(0x1001) + v15 = RefineType v14, CUInt64 + v16:CInt64 = IntAnd v11, v13 + v17:CBool = IsBitEqual v16, v15 + CondBranch v17, bb5(), bb6() + bb5(): + v19:StringExact[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + Jump bb4(v19) + bb6(): + v21:StringExact|NilClass = DefinedIvar v10, :@a + Jump bb4(v21) + bb4(v12:StringExact|NilClass): + CheckInterrupts + Return v12 + "); + } + + #[test] + fn test_optimize_definedivar_polymorphic_with_t_struct() { + set_call_threshold(3); + eval(r#" + module M + def test = defined?(@a) + end + + class C + include M + end + + class D < Range + include M + end + + obj = C.new + obj.instance_variable_set(:@a, 1) + + range = D.new 0, 1 + range.instance_variable_set(:@a, 1) + + obj.test + range.test + TEST = M.instance_method(:test) + "#); + assert_snapshot!(hir_string_proc("TEST"), @" + fn test@:3: + bb1(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb3(v1) + bb2(): + EntryPoint JIT(0) + v4:BasicObject = LoadArg :self@0 + Jump bb3(v4) + bb3(v6:BasicObject): + v10:HeapBasicObject = GuardType v6, HeapBasicObject + v11:CUInt64 = LoadField v10, :RBASIC_FLAGS@0x1000 + v13:CUInt64[0xffffffff0000001f] = Const CUInt64(0xffffffff0000001f) + v14:CPtr[CPtr(0x1001)] = Const CPtr(0x1001) + v15 = RefineType v14, CUInt64 + v16:CInt64 = IntAnd v11, v13 + v17:CBool = IsBitEqual v16, v15 + CondBranch v17, bb5(), bb6() + bb5(): + v19:StringExact[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + Jump bb4(v19) + bb6(): + v21:StringExact|NilClass = DefinedIvar v10, :@a + Jump bb4(v21) + bb4(v12:StringExact|NilClass): + CheckInterrupts + Return v12 + "); + } + + #[test] + fn test_optimize_definedivar_polymorphic_with_complex_shape() { + set_call_threshold(3); + eval(r#" + module M + def test = defined?(@a) + end + + class C + include M + end + + class D + include M + end + + obj = C.new + obj.instance_variable_set(:@a, 1) + + complex = D.new + (0..1000).each do |i| + complex.instance_variable_set(:"@v#{i}", i) + end + (0..1000).each do |i| + complex.remove_instance_variable(:"@v#{i}") + end + + obj.test + complex.test + TEST = M.instance_method(:test) + "#); + assert_snapshot!(hir_string_proc("TEST"), @" + fn test@:3: + bb1(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb3(v1) + bb2(): + EntryPoint JIT(0) + v4:BasicObject = LoadArg :self@0 + Jump bb3(v4) + bb3(v6:BasicObject): + v10:HeapBasicObject = GuardType v6, HeapBasicObject + v11:CUInt64 = LoadField v10, :RBASIC_FLAGS@0x1000 + v13:CUInt64[0xffffffff0000001f] = Const CUInt64(0xffffffff0000001f) + v14:CPtr[CPtr(0x1001)] = Const CPtr(0x1001) + v15 = RefineType v14, CUInt64 + v16:CInt64 = IntAnd v11, v13 + v17:CBool = IsBitEqual v16, v15 + CondBranch v17, bb5(), bb6() + bb5(): + v19:StringExact[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + Jump bb4(v19) + bb6(): + v21:StringExact|NilClass = DefinedIvar v10, :@a + Jump bb4(v21) + bb4(v12:StringExact|NilClass): + CheckInterrupts + Return v12 "); } @@ -8010,7 +8240,7 @@ mod hir_opt_tests { fn test_definedivar_shape_guard_recompile() { // Call with one shape to compile, then call with a different shape to // trigger shape guard exits and recompilation. On the recompiled version, - // DefinedIvar stays as a C call fallback to avoid more shape guard exits. + // DefinedIvar uses polymorphic fast paths plus a C call fallback. eval(" class C def initialize(extra = false) @@ -8038,9 +8268,32 @@ mod hir_opt_tests { v4:HeapBasicObject = LoadArg :self@0 Jump bb3(v4) bb3(v6:HeapBasicObject): - v10:StringExact|NilClass = DefinedIvar v6, :@foo + v11:CUInt64 = LoadField v6, :RBASIC_FLAGS@0x1000 + v13:CUInt64[0xffffffff0000001f] = Const CUInt64(0xffffffff0000001f) + v14:CPtr[CPtr(0x1001)] = Const CPtr(0x1001) + v15 = RefineType v14, CUInt64 + v16:CInt64 = IntAnd v11, v13 + v17:CBool = IsBitEqual v16, v15 + CondBranch v17, bb5(), bb6() + bb5(): + v19:StringExact[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + Jump bb4(v19) + bb6(): + v21:CUInt64[0xffffffff0000001f] = Const CUInt64(0xffffffff0000001f) + v22:CPtr[CPtr(0x1010)] = Const CPtr(0x1010) + v23 = RefineType v22, CUInt64 + v24:CInt64 = IntAnd v11, v21 + v25:CBool = IsBitEqual v24, v23 + CondBranch v25, bb7(), bb8() + bb7(): + v27:StringExact[VALUE(0x1008)] = Const Value(VALUE(0x1008)) + Jump bb4(v27) + bb8(): + v29:StringExact|NilClass = DefinedIvar v6, :@foo + Jump bb4(v29) + bb4(v12:StringExact|NilClass): CheckInterrupts - Return v10 + Return v12 "); }