Skip to content

Commit 2a1210f

Browse files
ZJIT: Implement getspecial (#13642)
ZJIT: Implement getspecial in ZJIT Adds support for the getspecial instruction in zjit. We split getspecial into two instructions, one for special symbols (`$&`, $'`, etc) and one for special backrefs (`$1`, `$2`, etc). Co-authored-by: Aaron Patterson <tenderlove@ruby-lang.org>
1 parent 1d7ed95 commit 2a1210f

File tree

3 files changed

+191
-1
lines changed

3 files changed

+191
-1
lines changed

test/ruby/test_zjit.rb

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,6 +1221,106 @@ def test = 1.nil?
12211221
}, insns: [:opt_nil_p]
12221222
end
12231223

1224+
def test_getspecial_last_match
1225+
assert_compiles '"hello"', %q{
1226+
def test(str)
1227+
str =~ /hello/
1228+
$&
1229+
end
1230+
test("hello world")
1231+
}, insns: [:getspecial]
1232+
end
1233+
1234+
def test_getspecial_match_pre
1235+
assert_compiles '"hello "', %q{
1236+
def test(str)
1237+
str =~ /world/
1238+
$`
1239+
end
1240+
test("hello world")
1241+
}, insns: [:getspecial]
1242+
end
1243+
1244+
def test_getspecial_match_post
1245+
assert_compiles '" world"', %q{
1246+
def test(str)
1247+
str =~ /hello/
1248+
$'
1249+
end
1250+
test("hello world")
1251+
}, insns: [:getspecial]
1252+
end
1253+
1254+
def test_getspecial_match_last_group
1255+
assert_compiles '"world"', %q{
1256+
def test(str)
1257+
str =~ /(hello) (world)/
1258+
$+
1259+
end
1260+
test("hello world")
1261+
}, insns: [:getspecial]
1262+
end
1263+
1264+
def test_getspecial_numbered_match_1
1265+
assert_compiles '"hello"', %q{
1266+
def test(str)
1267+
str =~ /(hello) (world)/
1268+
$1
1269+
end
1270+
test("hello world")
1271+
}, insns: [:getspecial]
1272+
end
1273+
1274+
def test_getspecial_numbered_match_2
1275+
assert_compiles '"world"', %q{
1276+
def test(str)
1277+
str =~ /(hello) (world)/
1278+
$2
1279+
end
1280+
test("hello world")
1281+
}, insns: [:getspecial]
1282+
end
1283+
1284+
def test_getspecial_numbered_match_nonexistent
1285+
assert_compiles 'nil', %q{
1286+
def test(str)
1287+
str =~ /(hello)/
1288+
$2
1289+
end
1290+
test("hello world")
1291+
}, insns: [:getspecial]
1292+
end
1293+
1294+
def test_getspecial_no_match
1295+
assert_compiles 'nil', %q{
1296+
def test(str)
1297+
str =~ /xyz/
1298+
$&
1299+
end
1300+
test("hello world")
1301+
}, insns: [:getspecial]
1302+
end
1303+
1304+
def test_getspecial_complex_pattern
1305+
assert_compiles '"123"', %q{
1306+
def test(str)
1307+
str =~ /(\d+)/
1308+
$1
1309+
end
1310+
test("abc123def")
1311+
}, insns: [:getspecial]
1312+
end
1313+
1314+
def test_getspecial_multiple_groups
1315+
assert_compiles '"456"', %q{
1316+
def test(str)
1317+
str =~ /(\d+)-(\d+)/
1318+
$2
1319+
end
1320+
test("123-456")
1321+
}, insns: [:getspecial]
1322+
end
1323+
12241324
# tool/ruby_vm/views/*.erb relies on the zjit instructions a) being contiguous and
12251325
# b) being reliably ordered after all the other instructions.
12261326
def test_instruction_order

zjit/src/codegen.rs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::state::ZJITState;
1010
use crate::stats::{counter_ptr, with_time_stat, Counter, Counter::compile_time_ns};
1111
use crate::{asm::CodeBlock, cruby::*, options::debug, virtualmem::CodePtr};
1212
use crate::backend::lir::{self, asm_comment, asm_ccall, Assembler, Opnd, Target, CFP, C_ARG_OPNDS, C_RET_OPND, EC, NATIVE_STACK_PTR, NATIVE_BASE_PTR, SP};
13-
use crate::hir::{iseq_to_hir, Block, BlockId, BranchEdge, Invariant, RangeType, SideExitReason, SideExitReason::*, SpecialObjectType, SELF_PARAM_IDX};
13+
use crate::hir::{iseq_to_hir, Block, BlockId, BranchEdge, Invariant, RangeType, SideExitReason, SideExitReason::*, SpecialObjectType, SpecialBackrefSymbol, SELF_PARAM_IDX};
1414
use crate::hir::{Const, FrameState, Function, Insn, InsnId};
1515
use crate::hir_type::{types, Type};
1616
use crate::options::get_option;
@@ -378,6 +378,8 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
378378
Insn::PutSpecialObject { value_type } => gen_putspecialobject(asm, *value_type),
379379
Insn::AnyToString { val, str, state } => gen_anytostring(asm, opnd!(val), opnd!(str), &function.frame_state(*state))?,
380380
Insn::Defined { op_type, obj, pushval, v, state } => gen_defined(jit, asm, *op_type, *obj, *pushval, opnd!(v), &function.frame_state(*state))?,
381+
Insn::GetSpecialSymbol { symbol_type, state: _ } => gen_getspecial_symbol(asm, *symbol_type),
382+
Insn::GetSpecialNumber { nth, state } => gen_getspecial_number(asm, *nth, &function.frame_state(*state)),
381383
&Insn::IncrCounter(counter) => return Some(gen_incr_counter(asm, counter)),
382384
Insn::ObjToString { val, cd, state, .. } => gen_objtostring(jit, asm, opnd!(val), *cd, &function.frame_state(*state))?,
383385
Insn::ArrayExtend { .. }
@@ -640,6 +642,37 @@ fn gen_putspecialobject(asm: &mut Assembler, value_type: SpecialObjectType) -> O
640642
asm_ccall!(asm, rb_vm_get_special_object, ep_reg, Opnd::UImm(u64::from(value_type)))
641643
}
642644

645+
fn gen_getspecial_symbol(asm: &mut Assembler, symbol_type: SpecialBackrefSymbol) -> Opnd {
646+
// Fetch a "special" backref based on the symbol type
647+
648+
let backref = asm_ccall!(asm, rb_backref_get,);
649+
650+
match symbol_type {
651+
SpecialBackrefSymbol::LastMatch => {
652+
asm_ccall!(asm, rb_reg_last_match, backref)
653+
}
654+
SpecialBackrefSymbol::PreMatch => {
655+
asm_ccall!(asm, rb_reg_match_pre, backref)
656+
}
657+
SpecialBackrefSymbol::PostMatch => {
658+
asm_ccall!(asm, rb_reg_match_post, backref)
659+
}
660+
SpecialBackrefSymbol::LastGroup => {
661+
asm_ccall!(asm, rb_reg_match_last, backref)
662+
}
663+
}
664+
}
665+
666+
fn gen_getspecial_number(asm: &mut Assembler, nth: u64, state: &FrameState) -> Opnd {
667+
// Fetch the N-th match from the last backref based on type shifted by 1
668+
669+
let backref = asm_ccall!(asm, rb_backref_get,);
670+
671+
gen_prepare_call_with_gc(asm, state);
672+
673+
asm_ccall!(asm, rb_reg_nth_match, Opnd::Imm((nth >> 1).try_into().unwrap()), backref)
674+
}
675+
643676
/// Compile an interpreter entry block to be inserted into an ISEQ
644677
fn gen_entry_prologue(asm: &mut Assembler, iseq: IseqPtr) {
645678
asm_comment!(asm, "ZJIT entry point: {}", iseq_get_location(iseq, 0));

zjit/src/hir.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,29 @@ impl From<RangeType> for u32 {
321321
}
322322
}
323323

324+
/// Special regex backref symbol types
325+
#[derive(Debug, Clone, Copy, PartialEq)]
326+
pub enum SpecialBackrefSymbol {
327+
LastMatch, // $&
328+
PreMatch, // $`
329+
PostMatch, // $'
330+
LastGroup, // $+
331+
}
332+
333+
impl TryFrom<u8> for SpecialBackrefSymbol {
334+
type Error = String;
335+
336+
fn try_from(value: u8) -> Result<Self, Self::Error> {
337+
match value as char {
338+
'&' => Ok(SpecialBackrefSymbol::LastMatch),
339+
'`' => Ok(SpecialBackrefSymbol::PreMatch),
340+
'\'' => Ok(SpecialBackrefSymbol::PostMatch),
341+
'+' => Ok(SpecialBackrefSymbol::LastGroup),
342+
c => Err(format!("invalid backref symbol: '{}'", c)),
343+
}
344+
}
345+
}
346+
324347
/// Print adaptor for [`Const`]. See [`PtrPrintMap`].
325348
struct ConstPrinter<'a> {
326349
inner: &'a Const,
@@ -415,6 +438,7 @@ pub enum SideExitReason {
415438
PatchPoint(Invariant),
416439
CalleeSideExit,
417440
ObjToStringFallback,
441+
UnknownSpecialVariable(u64),
418442
}
419443

420444
impl std::fmt::Display for SideExitReason {
@@ -494,6 +518,8 @@ pub enum Insn {
494518
GetLocal { level: u32, ep_offset: u32 },
495519
/// Set a local variable in a higher scope or the heap
496520
SetLocal { level: u32, ep_offset: u32, val: InsnId },
521+
GetSpecialSymbol { symbol_type: SpecialBackrefSymbol, state: InsnId },
522+
GetSpecialNumber { nth: u64, state: InsnId },
497523

498524
/// Own a FrameState so that instructions can look up their dominating FrameState when
499525
/// generating deopt side-exits and frame reconstruction metadata. Does not directly generate
@@ -774,6 +800,8 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> {
774800
Insn::SetGlobal { id, val, .. } => write!(f, "SetGlobal :{}, {val}", id.contents_lossy()),
775801
Insn::GetLocal { level, ep_offset } => write!(f, "GetLocal l{level}, EP@{ep_offset}"),
776802
Insn::SetLocal { val, level, ep_offset } => write!(f, "SetLocal l{level}, EP@{ep_offset}, {val}"),
803+
Insn::GetSpecialSymbol { symbol_type, .. } => write!(f, "GetSpecialSymbol {symbol_type:?}"),
804+
Insn::GetSpecialNumber { nth, .. } => write!(f, "GetSpecialNumber {nth}"),
777805
Insn::ToArray { val, .. } => write!(f, "ToArray {val}"),
778806
Insn::ToNewArray { val, .. } => write!(f, "ToNewArray {val}"),
779807
Insn::ArrayExtend { left, right, .. } => write!(f, "ArrayExtend {left}, {right}"),
@@ -1221,6 +1249,8 @@ impl Function {
12211249
&GetIvar { self_val, id, state } => GetIvar { self_val: find!(self_val), id, state },
12221250
&SetIvar { self_val, id, val, state } => SetIvar { self_val: find!(self_val), id, val: find!(val), state },
12231251
&SetLocal { val, ep_offset, level } => SetLocal { val: find!(val), ep_offset, level },
1252+
&GetSpecialSymbol { symbol_type, state } => GetSpecialSymbol { symbol_type, state },
1253+
&GetSpecialNumber { nth, state } => GetSpecialNumber { nth, state },
12241254
&ToArray { val, state } => ToArray { val: find!(val), state },
12251255
&ToNewArray { val, state } => ToNewArray { val: find!(val), state },
12261256
&ArrayExtend { left, right, state } => ArrayExtend { left: find!(left), right: find!(right), state },
@@ -1306,6 +1336,8 @@ impl Function {
13061336
Insn::ArrayMax { .. } => types::BasicObject,
13071337
Insn::GetGlobal { .. } => types::BasicObject,
13081338
Insn::GetIvar { .. } => types::BasicObject,
1339+
Insn::GetSpecialSymbol { .. } => types::BasicObject,
1340+
Insn::GetSpecialNumber { .. } => types::BasicObject,
13091341
Insn::ToNewArray { .. } => types::ArrayExact,
13101342
Insn::ToArray { .. } => types::ArrayExact,
13111343
Insn::ObjToString { .. } => types::BasicObject,
@@ -1995,6 +2027,8 @@ impl Function {
19952027
worklist.push_back(state);
19962028
}
19972029
&Insn::GetGlobal { state, .. } |
2030+
&Insn::GetSpecialSymbol { state, .. } |
2031+
&Insn::GetSpecialNumber { state, .. } |
19982032
&Insn::SideExit { state, .. } => worklist.push_back(state),
19992033
}
20002034
}
@@ -3325,6 +3359,29 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
33253359
let anytostring = fun.push_insn(block, Insn::AnyToString { val, str, state: exit_id });
33263360
state.stack_push(anytostring);
33273361
}
3362+
YARVINSN_getspecial => {
3363+
let key = get_arg(pc, 0).as_u64();
3364+
let svar = get_arg(pc, 1).as_u64();
3365+
3366+
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
3367+
3368+
if svar == 0 {
3369+
// TODO: Handle non-backref
3370+
fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnknownSpecialVariable(key) });
3371+
// End the block
3372+
break;
3373+
} else if svar & 0x01 != 0 {
3374+
// Handle symbol backrefs like $&, $`, $', $+
3375+
let shifted_svar: u8 = (svar >> 1).try_into().unwrap();
3376+
let symbol_type = SpecialBackrefSymbol::try_from(shifted_svar).expect("invalid backref symbol");
3377+
let result = fun.push_insn(block, Insn::GetSpecialSymbol { symbol_type, state: exit_id });
3378+
state.stack_push(result);
3379+
} else {
3380+
// Handle number backrefs like $1, $2, $3
3381+
let result = fun.push_insn(block, Insn::GetSpecialNumber { nth: svar, state: exit_id });
3382+
state.stack_push(result);
3383+
}
3384+
}
33283385
_ => {
33293386
// Unknown opcode; side-exit into the interpreter
33303387
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });

0 commit comments

Comments
 (0)