Skip to content

Commit

Permalink
Implement Wasm tail-call proposal (#683)
Browse files Browse the repository at this point in the history
* add tail call support to Config

* implement return_call[_indirect] for func translator

* add tail-call tests from Wasm spec testsuite

* expect failure for tail call tests since not yet implemented

* generalize CallOutcome to later support tail calls

* make StoreIdx base on NonZeroU32

* apply rustfmt

* partially implement return_call[_indirect]

* revert

* add DropKeep to return call bytecodes

* create unique drop_keep for return calls

* fix bug in drop_keep_return_call method

* wasm_return_call works now

* apply clippy suggestion

* properly charge for drop_keep in return_call_indirect

* fix return_call_indirect

* enable return_call_indirect Wasm spec test case

* fix performance regressions

* add fib_tail_recursive to benchmarks

* add TODO comment for resumable calls

* add more TODO comments for resumable calls

* comment out fib_tail_recursive

This is just temporary so that the CI produces benchmarks again.

* make use of return_call in benchmark .wat

* refactor engine executor calls

* apply rustfmt

* add doc comments

* revert changes as it slowed down Wasm targets

* add tests for resumable calls + tail calls

* fix resumable tail calls edge case
  • Loading branch information
Robbepop committed Mar 1, 2023
1 parent 01423af commit aecf008
Show file tree
Hide file tree
Showing 11 changed files with 375 additions and 97 deletions.
1 change: 1 addition & 0 deletions crates/wasmi/benches/bench/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub fn load_wasm_from_file(file_name: &str) -> Vec<u8> {
/// Returns a [`Config`] useful for benchmarking.
fn bench_config() -> Config {
let mut config = Config::default();
config.wasm_tail_call(true);
config.set_stack_limits(StackLimits::new(1024, 1024 * 1024, 64 * 1024).unwrap());
config
}
Expand Down
4 changes: 2 additions & 2 deletions crates/wasmi/benches/wat/fibonacci.wat
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@
(return (local.get $b))
)
)
(call $fib_tail_recursive
(return_call $fib_tail_recursive
(i64.sub (local.get $N) (i64.const 1))
(local.get $b)
(i64.add (local.get $a) (local.get $b))
)
)

(func (export "fibonacci_tail") (param $N i64) (result i64)
(call $fib_tail_recursive (local.get $N) (i64.const 0) (i64.const 1))
(return_call $fib_tail_recursive (local.get $N) (i64.const 0) (i64.const 1))
)

(func $fib_iterative (export "fibonacci_iter") (param $N i64) (result i64)
Expand Down
9 changes: 9 additions & 0 deletions crates/wasmi/src/engine/bytecode/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ pub enum Instruction {
},
Return(DropKeep),
ReturnIfNez(DropKeep),
ReturnCall {
drop_keep: DropKeep,
func: FuncIdx,
},
ReturnCallIndirect {
drop_keep: DropKeep,
table: TableIdx,
func_type: SignatureIdx,
},
Call(FuncIdx),
CallIndirect {
table: TableIdx,
Expand Down
17 changes: 16 additions & 1 deletion crates/wasmi/src/engine/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ pub struct Config {
bulk_memory: bool,
/// Is `true` if the [`reference-types`] Wasm proposal is enabled.
reference_types: bool,
/// Is `true` if the [`tail-call`] Wasm proposal is enabled.
tail_call: bool,
/// Is `true` if Wasm instructions on `f32` and `f64` types are allowed.
floats: bool,
/// Is `true` if `wasmi` executions shall consume fuel.
Expand Down Expand Up @@ -94,6 +96,7 @@ impl Default for Config {
multi_value: true,
bulk_memory: true,
reference_types: true,
tail_call: false,
floats: true,
consume_fuel: false,
fuel_costs: FuelCosts::default(),
Expand Down Expand Up @@ -201,6 +204,18 @@ impl Config {
self
}

/// Enable or disable the [`tail-call`] Wasm proposal for the [`Config`].
///
/// # Note
///
/// Disabled by default.
///
/// [`tail-call`]: https://github.com/WebAssembly/tail-calls
pub fn wasm_tail_call(&mut self, enable: bool) -> &mut Self {
self.tail_call = enable;
self
}

/// Enable or disable Wasm floating point (`f32` and `f64`) instructions and types.
///
/// Enabled by default.
Expand Down Expand Up @@ -252,12 +267,12 @@ impl Config {
sign_extension: self.sign_extension,
bulk_memory: self.bulk_memory,
reference_types: self.reference_types,
tail_call: self.tail_call,
floats: self.floats,
component_model: false,
simd: false,
relaxed_simd: false,
threads: false,
tail_call: false,
multi_memory: false,
exceptions: false,
memory64: false,
Expand Down
183 changes: 142 additions & 41 deletions crates/wasmi/src/engine/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use crate::{
table::TableEntity,
Func,
FuncRef,
Instance,
StoreInner,
Table,
};
Expand All @@ -42,7 +43,7 @@ pub enum WasmOutcome {
/// The Wasm execution has ended and returns to the host side.
Return,
/// The Wasm execution calls a host function.
Call(Func),
Call { host_func: Func, instance: Instance },
}

/// The outcome of a Wasm execution.
Expand All @@ -56,7 +57,16 @@ pub enum CallOutcome {
/// The Wasm execution continues in Wasm.
Continue,
/// The Wasm execution calls a host function.
Call(Func),
Call { host_func: Func, instance: Instance },
}

/// The kind of a function call.
#[derive(Debug, Copy, Clone)]
pub enum CallKind {
/// A nested function call.
Nested,
/// A tailing function call.
Tail,
}

/// The outcome of a Wasm return statement.
Expand Down Expand Up @@ -203,16 +213,56 @@ impl<'ctx, 'engine> Executor<'ctx, 'engine> {
return Ok(WasmOutcome::Return);
}
}
Instr::ReturnCall { drop_keep, func } => {
if let CallOutcome::Call {
host_func,
instance,
} = self.visit_return_call(drop_keep, func)?
{
return Ok(WasmOutcome::Call {
host_func,
instance,
});
}
}
Instr::ReturnCallIndirect {
drop_keep,
table,
func_type,
} => {
if let CallOutcome::Call {
host_func,
instance,
} = self.visit_return_call_indirect(drop_keep, table, func_type)?
{
return Ok(WasmOutcome::Call {
host_func,
instance,
});
}
}
Instr::Call(func) => {
if let CallOutcome::Call(host_func) = self.visit_call(func)? {
return Ok(WasmOutcome::Call(host_func));
if let CallOutcome::Call {
host_func,
instance,
} = self.visit_call(func)?
{
return Ok(WasmOutcome::Call {
host_func,
instance,
});
}
}
Instr::CallIndirect { table, func_type } => {
if let CallOutcome::Call(host_func) =
self.visit_call_indirect(table, func_type)?
if let CallOutcome::Call {
host_func,
instance,
} = self.visit_call_indirect(table, func_type)?
{
return Ok(WasmOutcome::Call(host_func));
return Ok(WasmOutcome::Call {
host_func,
instance,
});
}
}
Instr::Drop => self.visit_drop(),
Expand Down Expand Up @@ -519,24 +569,30 @@ impl<'ctx, 'engine> Executor<'ctx, 'engine> {
/// the function call so that the stack and execution state is synchronized
/// with the outer structures.
#[inline(always)]
fn call_func(&mut self, func: &Func) -> Result<CallOutcome, TrapCode> {
fn call_func(&mut self, func: &Func, kind: CallKind) -> Result<CallOutcome, TrapCode> {
self.next_instr();
self.sync_stack_ptr();
self.call_stack
.push(FuncFrame::new(self.ip, self.cache.instance()))?;
let wasm_func = match self.ctx.resolve_func(func) {
FuncEntity::Wasm(wasm_func) => wasm_func,
if matches!(kind, CallKind::Nested) {
self.call_stack
.push(FuncFrame::new(self.ip, self.cache.instance()))?;
}
match self.ctx.resolve_func(func) {
FuncEntity::Wasm(wasm_func) => {
let header = self.code_map.header(wasm_func.func_body());
self.value_stack.prepare_wasm_call(header)?;
self.sp = self.value_stack.stack_ptr();
self.cache.update_instance(wasm_func.instance());
self.ip = self.code_map.instr_ptr(header.iref());
Ok(CallOutcome::Continue)
}
FuncEntity::Host(_host_func) => {
self.cache.reset();
return Ok(CallOutcome::Call(*func));
Ok(CallOutcome::Call {
host_func: *func,
instance: *self.cache.instance(),
})
}
};
let header = self.code_map.header(wasm_func.func_body());
self.value_stack.prepare_wasm_call(header)?;
self.sp = self.value_stack.stack_ptr();
self.cache.update_instance(wasm_func.instance());
self.ip = self.code_map.instr_ptr(header.iref());
Ok(CallOutcome::Continue)
}
}

/// Returns to the caller.
Expand Down Expand Up @@ -608,6 +664,48 @@ impl<'ctx, 'engine> Executor<'ctx, 'engine> {
fn fuel_costs(&self) -> &FuelCosts {
self.ctx.engine().config().fuel_costs()
}

/// Executes a `call` or `return_call` instruction.
#[inline(always)]
fn execute_call(
&mut self,
func_index: FuncIdx,
kind: CallKind,
) -> Result<CallOutcome, TrapCode> {
let callee = self.cache.get_func(self.ctx, func_index);
self.call_func(&callee, kind)
}

/// Executes a `call_indirect` or `return_call_indirect` instruction.
#[inline(always)]
fn execute_call_indirect(
&mut self,
table: TableIdx,
func_index: u32,
func_type: SignatureIdx,
kind: CallKind,
) -> Result<CallOutcome, TrapCode> {
let table = self.cache.get_table(self.ctx, table);
let funcref = self
.ctx
.resolve_table(&table)
.get_untyped(func_index)
.map(FuncRef::from)
.ok_or(TrapCode::TableOutOfBounds)?;
let func = funcref.func().ok_or(TrapCode::IndirectCallToNull)?;
let actual_signature = self.ctx.resolve_func(func).ty_dedup();
let expected_signature = self
.ctx
.resolve_instance(self.cache.instance())
.get_signature(func_type.into_inner())
.unwrap_or_else(|| {
panic!("missing signature for call_indirect at index: {func_type:?}")
});
if actual_signature != expected_signature {
return Err(TrapCode::BadSignature).map_err(Into::into);
}
self.call_func(func, kind)
}
}

impl<'ctx, 'engine> Executor<'ctx, 'engine> {
Expand Down Expand Up @@ -712,10 +810,32 @@ impl<'ctx, 'engine> Executor<'ctx, 'engine> {
self.next_instr()
}

#[inline(always)]
fn visit_return_call(
&mut self,
drop_keep: DropKeep,
func_index: FuncIdx,
) -> Result<CallOutcome, TrapCode> {
self.sp.drop_keep(drop_keep);
self.execute_call(func_index, CallKind::Tail)
}

#[inline(always)]
fn visit_return_call_indirect(
&mut self,
drop_keep: DropKeep,
table: TableIdx,
func_type: SignatureIdx,
) -> Result<CallOutcome, TrapCode> {
let func_index: u32 = self.sp.pop_as();
self.sp.drop_keep(drop_keep);
self.execute_call_indirect(table, func_index, func_type, CallKind::Tail)
}

#[inline(always)]
fn visit_call(&mut self, func_index: FuncIdx) -> Result<CallOutcome, TrapCode> {
let callee = self.cache.get_func(self.ctx, func_index);
self.call_func(&callee)
self.call_func(&callee, CallKind::Nested)
}

#[inline(always)]
Expand All @@ -725,26 +845,7 @@ impl<'ctx, 'engine> Executor<'ctx, 'engine> {
func_type: SignatureIdx,
) -> Result<CallOutcome, TrapCode> {
let func_index: u32 = self.sp.pop_as();
let table = self.cache.get_table(self.ctx, table);
let funcref = self
.ctx
.resolve_table(&table)
.get_untyped(func_index)
.map(FuncRef::from)
.ok_or(TrapCode::TableOutOfBounds)?;
let func = funcref.func().ok_or(TrapCode::IndirectCallToNull)?;
let actual_signature = self.ctx.resolve_func(func).ty_dedup();
let expected_signature = self
.ctx
.resolve_instance(self.cache.instance())
.get_signature(func_type.into_inner())
.unwrap_or_else(|| {
panic!("missing signature for call_indirect at index: {func_type:?}")
});
if actual_signature != expected_signature {
return Err(TrapCode::BadSignature).map_err(Into::into);
}
self.call_func(func)
self.execute_call_indirect(table, func_index, func_type, CallKind::Nested)
}

#[inline(always)]
Expand Down
3 changes: 3 additions & 0 deletions crates/wasmi/src/engine/func_builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ macro_rules! impl_visit_operator {
( @reference_types $($rest:tt)* ) => {
impl_visit_operator!(@@supported $($rest)*);
};
( @tail_call $($rest:tt)* ) => {
impl_visit_operator!(@@supported $($rest)*);
};
( @@supported $op:ident $({ $($arg:ident: $argty:ty),* })? => $visit:ident $($rest:tt)* ) => {
fn $visit(&mut self $($(,$arg: $argty)*)?) -> Self::Output {
let offset = self.current_pos();
Expand Down

0 comments on commit aecf008

Please sign in to comment.