Skip to content

Commit

Permalink
RJIT: Eliminate known-result branches
Browse files Browse the repository at this point in the history
  • Loading branch information
k0kubun committed Apr 5, 2023
1 parent 9a5d4cc commit 4f77d1c
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 95 deletions.
217 changes: 122 additions & 95 deletions lib/ruby_vm/rjit/insn_compiler.rb
Expand Up @@ -1846,44 +1846,53 @@ def branchif(jit, ctx, asm)
jit_check_ints(jit, ctx, asm)
end

# TODO: skip check for known truthy
# Get the branch target instruction offsets
next_pc = jit.pc + C.VALUE.size * jit.insn.len
jump_pc = jit.pc + C.VALUE.size * (jit.insn.len + jump_offset)

# This `test` sets ZF only for Qnil and Qfalse, which let jz jump.
val = ctx.stack_pop
asm.test(val, ~Qnil)
val_type = ctx.get_opnd_type(StackOpnd[0])
val_opnd = ctx.stack_pop(1)

# Set stubs
branch_stub = BranchStub.new(
iseq: jit.iseq,
shape: Default,
target0: BranchTarget.new(ctx:, pc: jit.pc + C.VALUE.size * (jit.insn.len + jump_offset)), # branch target
target1: BranchTarget.new(ctx:, pc: jit.pc + C.VALUE.size * jit.insn.len), # fallthrough
)
branch_stub.target0.address = Assembler.new.then do |ocb_asm|
@exit_compiler.compile_branch_stub(ctx, ocb_asm, branch_stub, true)
@ocb.write(ocb_asm)
end
branch_stub.target1.address = Assembler.new.then do |ocb_asm|
@exit_compiler.compile_branch_stub(ctx, ocb_asm, branch_stub, false)
@ocb.write(ocb_asm)
end
if (result = val_type.known_truthy) != nil
target_pc = result ? jump_pc : next_pc
jit_direct_jump(jit.iseq, target_pc, ctx, asm)
else
# This `test` sets ZF only for Qnil and Qfalse, which let jz jump.
asm.test(val_opnd, ~Qnil)

# Jump to target0 on jnz
branch_stub.compile = proc do |branch_asm|
branch_asm.comment("branchif #{branch_stub.shape}")
branch_asm.stub(branch_stub) do
case branch_stub.shape
in Default
branch_asm.jnz(branch_stub.target0.address)
branch_asm.jmp(branch_stub.target1.address)
in Next0
branch_asm.jz(branch_stub.target1.address)
in Next1
branch_asm.jnz(branch_stub.target0.address)
# Set stubs
branch_stub = BranchStub.new(
iseq: jit.iseq,
shape: Default,
target0: BranchTarget.new(ctx:, pc: jump_pc), # branch target
target1: BranchTarget.new(ctx:, pc: next_pc), # fallthrough
)
branch_stub.target0.address = Assembler.new.then do |ocb_asm|
@exit_compiler.compile_branch_stub(ctx, ocb_asm, branch_stub, true)
@ocb.write(ocb_asm)
end
branch_stub.target1.address = Assembler.new.then do |ocb_asm|
@exit_compiler.compile_branch_stub(ctx, ocb_asm, branch_stub, false)
@ocb.write(ocb_asm)
end

# Jump to target0 on jnz
branch_stub.compile = proc do |branch_asm|
branch_asm.comment("branchif #{branch_stub.shape}")
branch_asm.stub(branch_stub) do
case branch_stub.shape
in Default
branch_asm.jnz(branch_stub.target0.address)
branch_asm.jmp(branch_stub.target1.address)
in Next0
branch_asm.jz(branch_stub.target1.address)
in Next1
branch_asm.jnz(branch_stub.target0.address)
end
end
end
branch_stub.compile.call(asm)
end
branch_stub.compile.call(asm)

EndBlock
end
Expand All @@ -1898,44 +1907,53 @@ def branchunless(jit, ctx, asm)
jit_check_ints(jit, ctx, asm)
end

# TODO: skip check for known truthy
# Get the branch target instruction offsets
next_pc = jit.pc + C.VALUE.size * jit.insn.len
jump_pc = jit.pc + C.VALUE.size * (jit.insn.len + jump_offset)

# This `test` sets ZF only for Qnil and Qfalse, which let jz jump.
val = ctx.stack_pop
asm.test(val, ~Qnil)
val_type = ctx.get_opnd_type(StackOpnd[0])
val_opnd = ctx.stack_pop(1)

# Set stubs
branch_stub = BranchStub.new(
iseq: jit.iseq,
shape: Default,
target0: BranchTarget.new(ctx:, pc: jit.pc + C.VALUE.size * (jit.insn.len + jump_offset)), # branch target
target1: BranchTarget.new(ctx:, pc: jit.pc + C.VALUE.size * jit.insn.len), # fallthrough
)
branch_stub.target0.address = Assembler.new.then do |ocb_asm|
@exit_compiler.compile_branch_stub(ctx, ocb_asm, branch_stub, true)
@ocb.write(ocb_asm)
end
branch_stub.target1.address = Assembler.new.then do |ocb_asm|
@exit_compiler.compile_branch_stub(ctx, ocb_asm, branch_stub, false)
@ocb.write(ocb_asm)
end
if (result = val_type.known_truthy) != nil
target_pc = result ? next_pc : jump_pc
jit_direct_jump(jit.iseq, target_pc, ctx, asm)
else
# This `test` sets ZF only for Qnil and Qfalse, which let jz jump.
asm.test(val_opnd, ~Qnil)

# Jump to target0 on jz
branch_stub.compile = proc do |branch_asm|
branch_asm.comment("branchunless #{branch_stub.shape}")
branch_asm.stub(branch_stub) do
case branch_stub.shape
in Default
branch_asm.jz(branch_stub.target0.address)
branch_asm.jmp(branch_stub.target1.address)
in Next0
branch_asm.jnz(branch_stub.target1.address)
in Next1
branch_asm.jz(branch_stub.target0.address)
# Set stubs
branch_stub = BranchStub.new(
iseq: jit.iseq,
shape: Default,
target0: BranchTarget.new(ctx:, pc: jump_pc), # branch target
target1: BranchTarget.new(ctx:, pc: next_pc), # fallthrough
)
branch_stub.target0.address = Assembler.new.then do |ocb_asm|
@exit_compiler.compile_branch_stub(ctx, ocb_asm, branch_stub, true)
@ocb.write(ocb_asm)
end
branch_stub.target1.address = Assembler.new.then do |ocb_asm|
@exit_compiler.compile_branch_stub(ctx, ocb_asm, branch_stub, false)
@ocb.write(ocb_asm)
end

# Jump to target0 on jz
branch_stub.compile = proc do |branch_asm|
branch_asm.comment("branchunless #{branch_stub.shape}")
branch_asm.stub(branch_stub) do
case branch_stub.shape
in Default
branch_asm.jz(branch_stub.target0.address)
branch_asm.jmp(branch_stub.target1.address)
in Next0
branch_asm.jnz(branch_stub.target1.address)
in Next1
branch_asm.jz(branch_stub.target0.address)
end
end
end
branch_stub.compile.call(asm)
end
branch_stub.compile.call(asm)

EndBlock
end
Expand All @@ -1950,43 +1968,52 @@ def branchnil(jit, ctx, asm)
jit_check_ints(jit, ctx, asm)
end

# TODO: skip check for known truthy
# Get the branch target instruction offsets
next_pc = jit.pc + C.VALUE.size * jit.insn.len
jump_pc = jit.pc + C.VALUE.size * (jit.insn.len + jump_offset)

val = ctx.stack_pop
asm.cmp(val, Qnil)
val_type = ctx.get_opnd_type(StackOpnd[0])
val_opnd = ctx.stack_pop(1)

# Set stubs
branch_stub = BranchStub.new(
iseq: jit.iseq,
shape: Default,
target0: BranchTarget.new(ctx:, pc: jit.pc + C.VALUE.size * (jit.insn.len + jump_offset)), # branch target
target1: BranchTarget.new(ctx:, pc: jit.pc + C.VALUE.size * jit.insn.len), # fallthrough
)
branch_stub.target0.address = Assembler.new.then do |ocb_asm|
@exit_compiler.compile_branch_stub(ctx, ocb_asm, branch_stub, true)
@ocb.write(ocb_asm)
end
branch_stub.target1.address = Assembler.new.then do |ocb_asm|
@exit_compiler.compile_branch_stub(ctx, ocb_asm, branch_stub, false)
@ocb.write(ocb_asm)
end
if (result = val_type.known_nil) != nil
target_pc = result ? jump_pc : next_pc
jit_direct_jump(jit.iseq, target_pc, ctx, asm)
else
asm.cmp(val_opnd, Qnil)

# Jump to target0 on je
branch_stub.compile = proc do |branch_asm|
branch_asm.comment("branchnil #{branch_stub.shape}")
branch_asm.stub(branch_stub) do
case branch_stub.shape
in Default
branch_asm.je(branch_stub.target0.address)
branch_asm.jmp(branch_stub.target1.address)
in Next0
branch_asm.jne(branch_stub.target1.address)
in Next1
branch_asm.je(branch_stub.target0.address)
# Set stubs
branch_stub = BranchStub.new(
iseq: jit.iseq,
shape: Default,
target0: BranchTarget.new(ctx:, pc: jump_pc), # branch target
target1: BranchTarget.new(ctx:, pc: next_pc), # fallthrough
)
branch_stub.target0.address = Assembler.new.then do |ocb_asm|
@exit_compiler.compile_branch_stub(ctx, ocb_asm, branch_stub, true)
@ocb.write(ocb_asm)
end
branch_stub.target1.address = Assembler.new.then do |ocb_asm|
@exit_compiler.compile_branch_stub(ctx, ocb_asm, branch_stub, false)
@ocb.write(ocb_asm)
end

# Jump to target0 on je
branch_stub.compile = proc do |branch_asm|
branch_asm.comment("branchnil #{branch_stub.shape}")
branch_asm.stub(branch_stub) do
case branch_stub.shape
in Default
branch_asm.je(branch_stub.target0.address)
branch_asm.jmp(branch_stub.target1.address)
in Next0
branch_asm.jne(branch_stub.target1.address)
in Next1
branch_asm.je(branch_stub.target0.address)
end
end
end
branch_stub.compile.call(asm)
end
branch_stub.compile.call(asm)

EndBlock
end
Expand Down
10 changes: 10 additions & 0 deletions lib/ruby_vm/rjit/type.rb
Expand Up @@ -89,6 +89,16 @@ def known_truthy
end
end

# Returns a boolean representing whether the value is equal to nil if known, otherwise nil
def known_nil
case [self, self.known_truthy]
in Type::Nil, _ then true
in Type::False, _ then false # Qfalse is not nil
in _, true then false # if truthy, can't be nil
in _, _ then nil # otherwise unknown
end
end

def diff(dst)
# Perfect match, difference is zero
if self == dst
Expand Down

0 comments on commit 4f77d1c

Please sign in to comment.