Skip to content

Commit

Permalink
add basic pick operation support.
Browse files Browse the repository at this point in the history
Signed-off-by: Takeshi Yoneda <takeshi@tetrate.io>
  • Loading branch information
mathetake committed Nov 30, 2021
1 parent 000901e commit 588c7bf
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 18 deletions.
55 changes: 53 additions & 2 deletions wasm/jit/jit_amd64.go
Expand Up @@ -4,6 +4,7 @@
package jit

import (
"errors"
"fmt"
"strings"

Expand Down Expand Up @@ -77,8 +78,9 @@ func (e *engine) compileWasmFunction(f *wasm.FunctionInstance) (*compiledWasmFun
case *wazeroir.OperationSelect:
return nil, fmt.Errorf("unsupported operation in JIT compiler: %v", o)
case *wazeroir.OperationPick:
// TODO:
return nil, fmt.Errorf("unsupported operation in JIT compiler: %v", o)
if err := builder.handlePick(o); err != nil {
return nil, fmt.Errorf("error handling pick operation %v: %w", o, err)
}
case *wazeroir.OperationSwap:
return nil, fmt.Errorf("unsupported operation in JIT compiler: %v", o)
case *wazeroir.OperationGlobalGet:
Expand Down Expand Up @@ -274,6 +276,55 @@ func (b *amd64Builder) handleLabel(o *wazeroir.OperationLabel) error {
return nil
}

func (b *amd64Builder) handlePick(o *wazeroir.OperationPick) error {
// TODO: if we track the type of values on the stack,
// we could optimize the instruction according to the bit size of the value.
// For now, we just move the entire register i.e. as a quad word (8 bytes).
pickTarget := b.locationStack.stack[len(b.locationStack.stack)-1-o.Depth]
if reg, err := b.locationStack.takeFreeRegister(gpTypeInt); err == nil {
if pickTarget.onRegister() {
prog := b.newProg()
prog.As = x86.AMOVQ
prog.From.Type = obj.TYPE_REG
prog.From.Reg = *pickTarget.register
prog.To.Type = obj.TYPE_REG
prog.To.Reg = reg
b.addInstruction(prog)
} else if pickTarget.onStack() {
// Place the stack pointer at first.
prog := b.newProg()
prog.As = x86.AMOVQ
prog.From.Type = obj.TYPE_CONST
prog.From.Offset = int64(*pickTarget.stackPointer)
prog.To.Type = obj.TYPE_REG
prog.To.Reg = reg
b.addInstruction(prog)

// Then Copy the value from the stack.
prog = b.newProg()
prog.As = x86.AMOVQ
prog.From.Type = obj.TYPE_MEM
prog.From.Reg = cachedStackBasePointerReg
prog.From.Index = reg
prog.From.Scale = 8
prog.To.Type = obj.TYPE_REG
prog.To.Reg = reg
b.addInstruction(prog)
} else if pickTarget.onConditionalRegister() {
panic("TODO")
}
// Now we already placed the picked value on the register,
// so push the location onto the stack.
loc := &valueLocation{register: &reg}
b.locationStack.push(loc)
return nil
} else if !errors.Is(err, errFreeRegisterNotFound) {
return fmt.Errorf("cannot take free register: %w", err)
}
// TODO: handle the case when there's not free register.
return nil
}

func (b *amd64Builder) setJITStatus(status jitStatusCodes) *obj.Prog {
prog := b.newProg()
prog.As = x86.AMOVL
Expand Down
97 changes: 95 additions & 2 deletions wasm/jit/jit_amd64_test.go
Expand Up @@ -54,6 +54,7 @@ func requireNewBuilder(t *testing.T) *amd64Builder {
b, err := asm.NewBuilder("amd64", 128)
require.NoError(t, err)
return &amd64Builder{eng: nil, builder: b,
locationStack: newValueLocationStack(),
onLabelStartCallbacks: map[string][]func(*obj.Prog){},
labelProgs: map[string]*obj.Prog{},
}
Expand Down Expand Up @@ -513,7 +514,7 @@ func TestAmd64Builder_initializeReservedRegisters(t *testing.T) {
code, err := builder.assemble()
require.NoError(t, err)

// Run codes.
// Run code.
eng := newEngine()
mem := newMemoryInst()
jitcall(
Expand All @@ -525,6 +526,7 @@ func TestAmd64Builder_initializeReservedRegisters(t *testing.T) {

func TestAmd64Builder_handleLabel(t *testing.T) {
builder := requireNewBuilder(t)
builder.initializeReservedRegisters()
label := &wazeroir.Label{FrameID: 100, Kind: wazeroir.LabelKindContinuation}

var called bool
Expand All @@ -542,7 +544,7 @@ func TestAmd64Builder_handleLabel(t *testing.T) {
code, err := builder.assemble()
require.NoError(t, err)

// Run codes.
// Run code.
eng := newEngine()
mem := newMemoryInst()
jitcall(
Expand All @@ -551,3 +553,94 @@ func TestAmd64Builder_handleLabel(t *testing.T) {
uintptr(unsafe.Pointer(&mem.Buffer[0])),
)
}

func TestAmd64Builder_handlePick(t *testing.T) {
t.Run("free register", func(t *testing.T) {
o := &wazeroir.OperationPick{Depth: 1}
builder := requireNewBuilder(t)
builder.initializeReservedRegisters()
// The case when the original value is already in register.
t.Run("on reg", func(t *testing.T) {
// Set up the pick target original value.
orignalReg := int16(x86.REG_AX)
loc := &valueLocation{register: &orignalReg}
builder.locationStack.push(loc)
builder.locationStack.push(nil)
builder.movConstToRegister(100, orignalReg)
// Now insert pick code.
err := builder.handlePick(o)
require.NoError(t, err)
// Increment the picked value.
pickedLocation := builder.locationStack.peek()
prog := builder.newProg()
prog.As = x86.AINCQ
prog.To.Type = obj.TYPE_REG
prog.To.Reg = *pickedLocation.register
builder.addInstruction(prog)
// To verify the behavior, we push the incremented picked value
// to the stack.
builder.pushRegisterToStack(*pickedLocation.register)
builder.returnFunction()

// Assemble.
code, err := builder.assemble()
require.NoError(t, err)
// Run code.
eng := newEngine()
mem := newMemoryInst()
jitcall(
uintptr(unsafe.Pointer(&code[0])),
uintptr(unsafe.Pointer(eng)),
uintptr(unsafe.Pointer(&mem.Buffer[0])),
)
// Check the stack.
require.Equal(t, uint64(1), eng.currentStackPointer)
require.Equal(t, uint64(101), eng.stack[eng.currentStackPointer-1])
})
// The case when the original value is in stack.
t.Run("on stack", func(t *testing.T) {
eng := newEngine()

// Setup the original value.
sp := uint64(1)
loc := &valueLocation{stackPointer: &sp}
builder.locationStack.push(loc)
builder.locationStack.push(nil)
eng.currentStackPointer = 5
eng.currentBaseStackPointer = 1
eng.stack[eng.currentBaseStackPointer+sp] = 100

// Now insert pick code.
err := builder.handlePick(o)
require.NoError(t, err)

// Increment the picked value.
pickedLocation := builder.locationStack.peek()
prog := builder.newProg()
prog.As = x86.AINCQ
prog.To.Type = obj.TYPE_REG
prog.To.Reg = *pickedLocation.register
builder.addInstruction(prog)

// To verify the behavior, we push the incremented picked value
// to the stack.
builder.pushRegisterToStack(*pickedLocation.register)
builder.returnFunction()

// Assemble.
code, err := builder.assemble()
require.NoError(t, err)
// Run code.
mem := newMemoryInst()
jitcall(
uintptr(unsafe.Pointer(&code[0])),
uintptr(unsafe.Pointer(eng)),
uintptr(unsafe.Pointer(&mem.Buffer[0])),
)

// Check the stack.
require.Equal(t, uint64(6), eng.currentStackPointer)
require.Equal(t, uint64(101), eng.stack[eng.currentBaseStackPointer+eng.currentStackPointer-1])
})
})
}
32 changes: 21 additions & 11 deletions wasm/jit/jit_value_location_amd64.go
Expand Up @@ -13,7 +13,7 @@ import (
type valueLocation struct {
// TODO: might not be neeeded at all!
valueType wazeroir.SignLessType
register *int
register *int16
stackPointer *uint64
// conditional registers?
}
Expand All @@ -26,8 +26,13 @@ func (v *valueLocation) onRegister() bool {
return v.register != nil
}

func (v *valueLocation) onConditionalRegister() bool {
// TODO!
return false
}

var (
gpFloatRegisters = []int{
gpFloatRegisters = []int16{
x86.REG_X0, x86.REG_X1, x86.REG_X2, x86.REG_X3,
x86.REG_X4, x86.REG_X5, x86.REG_X6, x86.REG_X7,
x86.REG_X8, x86.REG_X9, x86.REG_X10, x86.REG_X11,
Expand All @@ -37,32 +42,32 @@ var (
// so we don't need to care about the calling convension.
// TODO: we still have to take into acounts RAX,RDX register
// usages in DIV,MUL operations.
gpIntRegisters = []int{
gpIntRegisters = []int16{
x86.REG_AX, x86.REG_CX, x86.REG_DX, x86.REG_BX,
x86.REG_BP, x86.REG_SI, x86.REG_DI, x86.REG_R8,
x86.REG_R9, x86.REG_R10, x86.REG_R11,
}
errFreeRegisterNotFound = errors.New("free register not found")
)

func isIntRegister(r int) bool {
func isIntRegister(r int16) bool {
return gpIntRegisters[0] <= r && r <= gpIntRegisters[len(gpIntRegisters)-1]
}

func isFloatRegister(r int) bool {
func isFloatRegister(r int16) bool {
return gpFloatRegisters[0] <= r && r <= gpFloatRegisters[len(gpFloatRegisters)-1]
}

func newValueLocationStack() *valueLocationStack {
return &valueLocationStack{
usedRegisters: map[int]struct{}{},
usedRegisters: map[int16]struct{}{},
}
}

type valueLocationStack struct {
stack []*valueLocation
sp int
usedRegisters map[int]struct{}
usedRegisters map[int16]struct{}
}

func (s *valueLocationStack) push(loc *valueLocation) {
Expand All @@ -81,11 +86,16 @@ func (s *valueLocationStack) pop() (loc *valueLocation) {
return
}

func (s *valueLocationStack) releaseRegister(reg int) {
func (s *valueLocationStack) peek() (loc *valueLocation) {
loc = s.stack[s.sp-1]
return
}

func (s *valueLocationStack) releaseRegister(reg int16) {
delete(s.usedRegisters, reg)
}

func (s *valueLocationStack) markRegisterUsed(reg int) {
func (s *valueLocationStack) markRegisterUsed(reg int16) {
s.usedRegisters[reg] = struct{}{}
}

Expand All @@ -108,8 +118,8 @@ func gpRegisterTypeFromSignLess(in wazeroir.SignLessType) (ret generalPurposeReg

// Search for unused registers, and if found, returns the resgister
// and mark it used.
func (s *valueLocationStack) takeFreeRegister(tp generalPurposeRegisterType) (int, error) {
var targetRegs []int
func (s *valueLocationStack) takeFreeRegister(tp generalPurposeRegisterType) (int16, error) {
var targetRegs []int16
switch tp {
case gpTypeFloat:
targetRegs = gpFloatRegisters
Expand Down
6 changes: 3 additions & 3 deletions wasm/jit/jit_value_location_amd64_test.go
Expand Up @@ -30,7 +30,7 @@ func TestValueLocationStack_basic(t *testing.T) {
require.Equal(t, 1, s.sp)
require.Equal(t, depth, *s.stack[s.sp-1].stackPointer)
// markRegisterUsed.
reg := x86.REG_X1
reg := int16(x86.REG_X1)
s.markRegisterUsed(reg)
require.Contains(t, s.usedRegisters, reg)
// releaseRegister
Expand Down Expand Up @@ -66,9 +66,9 @@ func TestValueLocationStack_takeFreeRegister(t *testing.T) {

func TestValueLocationStack_takeStealTargetFromUsedRegister(t *testing.T) {
s := newValueLocationStack()
intReg := x86.REG_R10
intReg := int16(x86.REG_R10)
intLocation := &valueLocation{register: &intReg}
floatReg := x86.REG_X0
floatReg := int16(x86.REG_X0)
floatLocation := &valueLocation{register: &floatReg}
s.push(intLocation)
s.push(floatLocation)
Expand Down

0 comments on commit 588c7bf

Please sign in to comment.