Skip to content

Commit

Permalink
refactor: improve ergonomics of fn replace (#250)
Browse files Browse the repository at this point in the history
After trialing `replace_(imported|exported)_func` in WASI-Virt, it's clear
that the ergonomics around the builder function need to be
improved. `FunctionBuilder` (particularly `FunctionBuilder::new()` is
difficult to use without a mutable borrow of the module itself.

This commit refactors `replace_(imported|exported)_func` in order to
pass through the mutable borrow which makes it easier to use
`FunctionBuilder`s.

Signed-off-by: Victor Adossi <vadossi@cosmonic.com>
  • Loading branch information
vados-cosmonic authored Oct 18, 2023
1 parent 632d220 commit 45ca488
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 98 deletions.
4 changes: 2 additions & 2 deletions crates/tests/tests/spec-tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,13 @@ fn run(wast: &Path) -> Result<(), anyhow::Error> {
let wasm = fs::read(&path)?;
let mut wasm = config
.parse(&wasm)
.context(format!("error parsing wasm (line {})", line))?;
.with_context(|| format!("error parsing wasm (line {})", line))?;
let wasm1 = wasm.emit_wasm();
fs::write(&path, &wasm1)?;
let wasm2 = config
.parse(&wasm1)
.map(|mut m| m.emit_wasm())
.context(format!("error re-parsing wasm (line {})", line))?;
.with_context(|| format!("error re-parsing wasm (line {})", line))?;
if wasm1 != wasm2 {
panic!("wasm module at line {} isn't deterministic", line);
}
Expand Down
14 changes: 14 additions & 0 deletions src/module/exports.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,20 @@ impl ModuleExports {
_ => false,
})
}

/// Delete an exported function by name from this module.
pub fn delete_func_by_name(&mut self, name: impl AsRef<str>) -> Result<()> {
let fid = self.get_func_by_name(name.as_ref()).context(format!(
"failed to find exported func with name [{}]",
name.as_ref()
))?;
self.delete(
self.get_exported_func(fid)
.with_context(|| format!("failed to find exported func with ID [{fid:?}]"))?
.id(),
);
Ok(())
}
}

impl Module {
Expand Down
196 changes: 100 additions & 96 deletions src/module/functions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::parse::IndicesToIds;
use crate::tombstone_arena::{Id, Tombstone, TombstoneArena};
use crate::ty::TypeId;
use crate::ty::ValType;
use crate::{ExportItem, Memory, MemoryId};
use crate::{ExportItem, FunctionBuilder, InstrSeqBuilder, LocalId, Memory, MemoryId};

pub use self::local_function::LocalFunction;

Expand Down Expand Up @@ -447,98 +447,118 @@ impl Module {

/// Replace a single exported function with the result of the provided builder function.
///
/// The builder function is provided a mutable reference to an [`InstrSeqBuilder`] which can be
/// used to build the function as necessary.
///
/// For example, if you wanted to replace an exported function with a no-op,
///
/// ```ignore
/// // Since `FunctionBuilder` requires a mutable pointer to the module's types,
/// // we must build it *outside* the closure and `move` it in
/// let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);
///
/// module.replace_exported_func(fid, move || {
/// module.replace_exported_func(fid, |(body, arg_locals)| {
/// builder.func_body().unreachable();
/// builder.local_func(vec![])
/// });
/// ```
///
/// The arguments passed to the original function will be passed to the
/// new exported function that was built in your closure.
///
/// This function returns the function ID of the *new* function,
/// after it has been inserted into the module as an export.
pub fn replace_exported_func<F>(&mut self, fid: FunctionId, fn_builder: F) -> Result<FunctionId>
where
F: FnOnce((&FuncParams, &FuncResults)) -> Result<LocalFunction>,
{
match (self.exports.get_exported_func(fid), self.funcs.get(fid)) {
(
Some(exported_fn),
Function {
kind: FunctionKind::Local(lf),
..
},
) => {
// Retrieve the params & result types for the exported (local) function
let ty = self.types.get(lf.ty());
let (params, results) = (ty.params().to_vec(), ty.results().to_vec());

// Add the function produced by `fn_builder` as a local function,
let new_fid = self.funcs.add_local(
fn_builder((&params, &results)).context("export fn builder failed")?,
);

// Mutate the existing export to use the new local function
let export = self.exports.get_mut(exported_fn.id());
export.item = ExportItem::Function(new_fid);

Ok(new_fid)
}
// The export didn't exist, or the function isn't the kind we expect
_ => bail!("cannot replace function [{fid:?}], it is not an exported function"),
pub fn replace_exported_func(
&mut self,
fid: FunctionId,
builder_fn: impl FnOnce((&mut InstrSeqBuilder, &Vec<LocalId>)),
) -> Result<FunctionId> {
let original_export_id = self
.exports
.get_exported_func(fid)
.map(|e| e.id())
.with_context(|| format!("no exported function with ID [{fid:?}]"))?;

if let Function {
kind: FunctionKind::Local(lf),
..
} = self.funcs.get(fid)
{
// Retrieve the params & result types for the exported (local) function
let ty = self.types.get(lf.ty());
let (params, results) = (ty.params().to_vec(), ty.results().to_vec());

// Add the function produced by `fn_builder` as a local function
let mut builder = FunctionBuilder::new(&mut self.types, &params, &results);
let mut new_fn_body = builder.func_body();
builder_fn((&mut new_fn_body, &lf.args));
let func = builder.local_func(lf.args.clone());
let new_fn_id = self.funcs.add_local(func);

// Mutate the existing export to use the new local function
let export = self.exports.get_mut(original_export_id);
export.item = ExportItem::Function(new_fn_id);
Ok(new_fn_id)
} else {
bail!("cannot replace function [{fid:?}], it is not an exported function");
}
}

/// Replace a single imported function with the result of the provided builder function.
///
/// ```ignore
/// // Since `FunctionBuilder` requires a mutable pointer to the module's types,
/// // we must build it *outside* the closure and `move` it in
/// let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);
/// The builder function is provided a mutable reference to an [`InstrSeqBuilder`] which can be
/// used to build the function as necessary.
///
/// module.replace_imported_func(fid, move || {
/// For example, if you wanted to replace an imported function with a no-op,
///
/// ```ignore
/// module.replace_imported_func(fid, |(body, arg_locals)| {
/// builder.func_body().unreachable();
/// builder.local_func(vec![])
/// });
/// ```
///
/// The arguments passed to the original function will be passed to the
/// new exported function that was built in your closure.
///
/// This function returns the function ID of the *new* function, and
/// removes the existing import that has been replaced (the function will become local).
pub fn replace_imported_func<F>(&mut self, fid: FunctionId, fn_builder: F) -> Result<FunctionId>
where
F: FnOnce((&FuncParams, &FuncResults)) -> Result<LocalFunction>,
{
// If the function is in the imports, replace it
match (self.imports.get_imported_func(fid), self.funcs.get(fid)) {
(
Some(original_imported_fn),
Function {
kind: FunctionKind::Import(ImportedFunction { ty: tid, .. }),
..
},
) => {
// Retrieve the params & result types for the imported function
let ty = self.types.get(*tid);
let (params, results) = (ty.params().to_vec(), ty.results().to_vec());

// Mutate the existing function, changing it from a FunctionKind::ImportedFunction
// to the local function produced by running the provided `fn_builder`
let func = self.funcs.get_mut(fid);
func.kind = FunctionKind::Local(
fn_builder((&params, &results)).context("import fn builder failed")?,
);

self.imports.delete(original_imported_fn.id());

Ok(fid)
}
// The export didn't exist, or the function isn't the kind we expect
_ => bail!("cannot replace function [{fid:?}], it is not an imported function"),
pub fn replace_imported_func(
&mut self,
fid: FunctionId,
builder_fn: impl FnOnce((&mut InstrSeqBuilder, &Vec<LocalId>)),
) -> Result<FunctionId> {
let original_import_id = self
.imports
.get_imported_func(fid)
.map(|import| import.id())
.with_context(|| format!("no exported function with ID [{fid:?}]"))?;

if let Function {
kind: FunctionKind::Import(ImportedFunction { ty: tid, .. }),
..
} = self.funcs.get(fid)
{
// Retrieve the params & result types for the imported function
let ty = self.types.get(*tid);
let (params, results) = (ty.params().to_vec(), ty.results().to_vec());

// Build the list LocalIds used by args to match the original function
let args = params
.iter()
.map(|ty| self.locals.add(*ty))
.collect::<Vec<_>>();

// Build the new function
let mut builder = FunctionBuilder::new(&mut self.types, &params, &results);
let mut new_fn_body = builder.func_body();
builder_fn((&mut new_fn_body, &args));
let new_func_kind = FunctionKind::Local(builder.local_func(args));

// Mutate the existing function, changing it from a FunctionKind::ImportedFunction
// to the local function produced by running the provided `fn_builder`
let func = self.funcs.get_mut(fid);
func.kind = new_func_kind;

self.imports.delete(original_import_id);

Ok(fid)
} else {
bail!("cannot replace function [{fid:?}], it is not an imported function");
}
}
}
Expand Down Expand Up @@ -683,14 +703,10 @@ mod tests {
let original_fn_id: FunctionId = builder.finish(vec![], &mut module.funcs);
let original_export_id = module.exports.add("dummy", original_fn_id);

// Create builder to use inside closure
let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);

// Replace the existing function with a new one with a reversed const value
let new_fn_id = module
.replace_exported_func(original_fn_id, move |_| {
builder.func_body().i32_const(4321).drop();
Ok(builder.local_func(vec![]))
.replace_exported_func(original_fn_id, |(body, _)| {
body.i32_const(4321).drop();
})
.expect("function replacement worked");

Expand Down Expand Up @@ -728,14 +744,10 @@ mod tests {
let original_fn_id: FunctionId = builder.finish(vec![], &mut module.funcs);
let original_export_id = module.exports.add("dummy", original_fn_id);

// Create builder to use inside closure
let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);

// Replace the existing function with a new one with a reversed const value
let new_fn_id = module
.replace_exported_func(original_fn_id, move |_| {
builder.func_body().unreachable();
Ok(builder.local_func(vec![]))
.replace_exported_func(original_fn_id, |(body, _arg_locals)| {
body.unreachable();
})
.expect("export function replacement worked");

Expand Down Expand Up @@ -773,14 +785,10 @@ mod tests {
let types = module.types.add(&[], &[]);
let (original_fn_id, original_import_id) = module.add_import_func("mod", "dummy", types);

// Create builder to use inside closure
let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);

// Replace the existing function with a new one with a reversed const value
let new_fn_id = module
.replace_imported_func(original_fn_id, |_| {
builder.func_body().i32_const(4321).drop();
Ok(builder.local_func(vec![]))
.replace_imported_func(original_fn_id, |(body, _)| {
body.i32_const(4321).drop();
})
.expect("import fn replacement worked");

Expand Down Expand Up @@ -815,14 +823,10 @@ mod tests {
let types = module.types.add(&[], &[]);
let (original_fn_id, original_import_id) = module.add_import_func("mod", "dummy", types);

// Create builder to use inside closure
let mut builder = FunctionBuilder::new(&mut module.types, &[], &[]);

// Replace the existing function with a new one with a reversed const value
let new_fn_id = module
.replace_imported_func(original_fn_id, |_| {
builder.func_body().unreachable();
Ok(builder.local_func(vec![]))
.replace_imported_func(original_fn_id, |(body, _arg_locals)| {
body.unreachable();
})
.expect("import fn replacement worked");

Expand Down
19 changes: 19 additions & 0 deletions src/module/imports.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,25 @@ impl ModuleImports {
_ => None,
})
}

/// Delete an imported function by name from this module.
pub fn delete_func_by_name(
&mut self,
module: impl AsRef<str>,
name: impl AsRef<str>,
) -> Result<()> {
let fid = self
.get_func_by_name(module, name.as_ref())
.with_context(|| {
format!("failed to find imported func with name [{}]", name.as_ref())
})?;
self.delete(
self.get_imported_func(fid)
.with_context(|| format!("failed to find imported func with ID [{fid:?}]"))?
.id(),
);
Ok(())
}
}

impl Module {
Expand Down
1 change: 1 addition & 0 deletions src/module/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub use crate::module::debug::ModuleDebugData;
pub use crate::module::elements::ElementKind;
pub use crate::module::elements::{Element, ElementId, ModuleElements};
pub use crate::module::exports::{Export, ExportId, ExportItem, ModuleExports};
pub use crate::module::functions::{FuncParams, FuncResults};
pub use crate::module::functions::{Function, FunctionId, ModuleFunctions};
pub use crate::module::functions::{FunctionKind, ImportedFunction, LocalFunction};
pub use crate::module::globals::{Global, GlobalId, GlobalKind, ModuleGlobals};
Expand Down

0 comments on commit 45ca488

Please sign in to comment.