diff --git a/loom-core/src/lib.rs b/loom-core/src/lib.rs index 7c0e039..1edc2b0 100644 --- a/loom-core/src/lib.rs +++ b/loom-core/src/lib.rs @@ -11578,6 +11578,17 @@ pub mod optimize { enum CSEAction { SaveToLocal(u32), // Save result to local using local.tee LoadFromLocal(u32), // Replace with local.get + // PR-K2: span replacement for pure Call exprs. The first + // occurrence is annotated with `SaveToLocal` at the call + // position (so `local.tee` is appended after the call). A + // duplicate occurrence is annotated with `ReplaceSpanWithLoad` + // at the leftmost-arg position with `end_pos = call_pos`, + // emitting one `local.get` and skipping all instructions in + // `[arg_pos..=call_pos]`. + ReplaceSpanWithLoad { + local_idx: u32, + end_pos: usize, + }, } // Expression representation for CSE @@ -12158,32 +12169,222 @@ pub mod optimize { // Insert local.tee after first occurrence, replace duplicates with local.get // // Strategy: Conservative transformation for simple expressions - // - Only transform single-instruction expressions (constants, local.get) - // - For binary operations, skip for now (requires tracking instruction spans) + // - Single-instruction Const exprs: tee at the call site, load + // at duplicate sites (existing path). + // - PR-K2: pure-Call exprs whose args are all single-instruction + // pure pushers (Const, LocalGet) get span-based substitution. + // The first occurrence is unchanged except for a trailing + // `local.tee` after the Call instruction; each subsequent + // occurrence collapses its entire `[arg_start..=call_pos]` + // span into a single `local.get` of the cache local. + // - Binary/Unary/nested-Call args are NOT folded in this PR + // (deferred to a follow-up): they would require a more + // careful span-overlap analysis. The verifier would still + // catch a mistake, but conservative-over-fast: skip. // Build transformation plan: first occurrence -> save, others -> load let mut position_action: HashMap = HashMap::new(); + // PR-K2: occupied[i] = true if position i is already covered by + // some other Expr's transform span (first-occurrence's call site + // or a duplicate's [arg..=call] span). Used to reject overlaps. + let mut occupied: Vec = vec![false; func.instructions.len()]; + + // PR-K2: helper — is this Expr arg a single-instruction pure + // pusher whose value at runtime equals the pusher's local + // identity in the IR? Const is always safe (value is encoded + // in the instruction). LocalGet is safe ONLY if the local is + // not mutated between the cached site and the use site — we + // enforce that below via a separate scan. + fn arg_is_simple_pusher(e: &Expr) -> bool { + matches!( + e, + Expr::Const32(_) | Expr::Const64(_) | Expr::LocalGet(_) + ) + } + // Collect the set of LocalGet indices referenced by a Call's + // arg list (used to verify no intervening local.set/local.tee + // could invalidate the cached value). + fn arg_local_indices(args: &[Expr]) -> Vec { + args.iter() + .filter_map(|a| match a { + Expr::LocalGet(idx) => Some(*idx), + _ => None, + }) + .collect() + } for (hash, local_idx) in &hash_to_local { if let Some(positions) = hash_to_positions.get(hash) { if positions.len() > 1 { - // Check if this is a simple expression we can safely transform - // SAFETY: Only CSE constants - LocalGet is UNSAFE because the - // local's value might change between uses (via local.set) if let Some((expr, _)) = expr_at_position.get(&positions[0]) { - let is_safe_to_cse = - matches!(expr, Expr::Const32(_) | Expr::Const64(_)); - - if is_safe_to_cse { - // First occurrence: add local.tee after it - position_action - .insert(positions[0], CSEAction::SaveToLocal(*local_idx)); - - // Subsequent occurrences: replace with local.get - for &pos in &positions[1..] { - position_action - .insert(pos, CSEAction::LoadFromLocal(*local_idx)); + match expr { + // SAFETY: constants are referentially + // transparent — same bits in any context. + Expr::Const32(_) | Expr::Const64(_) => { + if occupied[positions[0]] { + continue; + } + let mut blocked = false; + for &pos in &positions[1..] { + if occupied[pos] { + blocked = true; + break; + } + } + if blocked { + continue; + } + occupied[positions[0]] = true; + position_action.insert( + positions[0], + CSEAction::SaveToLocal(*local_idx), + ); + for &pos in &positions[1..] { + occupied[pos] = true; + position_action.insert( + pos, + CSEAction::LoadFromLocal(*local_idx), + ); + } + } + + // PR-K2: pure+no-trap single-result Call + // with simple-pusher args → span dedup. + Expr::Call { args, .. } => { + // DEFENSE-IN-DEPTH: PR-K constructed + // the Call expr only for pure+no-trap + // callees with all-pure args. We add + // a stronger check here: every arg + // must be a single-instruction pure + // pusher (Const or LocalGet). Binary, + // Unary, and nested-Call args are + // deferred to a follow-up PR. + if !args.iter().all(arg_is_simple_pusher) { + continue; + } + + // Compute spans for every occurrence. + // For a Call with N simple-pusher args, + // the span is exactly N+1 instructions: + // [start_at_position[pos] ..= pos]. + let mut spans: Vec<(usize, usize)> = + Vec::with_capacity(positions.len()); + let mut span_ok = true; + for &pos in positions { + let start = match start_at_position.get(&pos) { + Some(s) => *s, + None => { + span_ok = false; + break; + } + }; + // For N=args.len() simple pushers, + // expect span length N+1. If the + // measured span is smaller (would + // happen if the leftmost-arg-start + // tracking missed an instruction) + // we abandon: cannot prove the + // span is exactly the call subtree. + let expected_len = args.len() + 1; + if pos < start || (pos - start + 1) != expected_len { + span_ok = false; + break; + } + spans.push((start, pos)); + } + if !span_ok { + continue; + } + + // Overlap check: every position in + // every span must be free. + let mut overlap = false; + for &(start, end) in &spans { + for p in start..=end { + if occupied[p] { + overlap = true; + break; + } + } + if overlap { + break; + } + } + if overlap { + continue; + } + + // local.set/local.tee scan: a LocalGet + // arg is only safe to cache if the + // local is not mutated anywhere from + // the FIRST occurrence's call position + // through the LAST duplicate position. + // (Conservative bound — we could narrow + // to per-duplicate, but the bigger + // window is simpler and gives the + // verifier the strongest invariant.) + let arg_locals = arg_local_indices(args); + let first_call = spans[0].1; + let last_dup_start = spans.last().unwrap().0; + let mut local_mutated = false; + if !arg_locals.is_empty() { + for ins in func + .instructions + .iter() + .skip(first_call) + .take(last_dup_start.saturating_sub(first_call) + 1) + { + match ins { + Instruction::LocalSet(idx) + | Instruction::LocalTee(idx) => { + if arg_locals.contains(idx) { + local_mutated = true; + break; + } + } + _ => {} + } + } + } + if local_mutated { + continue; + } + + // Plan the transform. + let (_first_start, first_call_pos) = spans[0]; + // Mark every instruction in every span + // as occupied so later iterations + // don't double-plan. + for &(start, end) in &spans { + for p in start..=end { + occupied[p] = true; + } + } + // First occurrence: keep the whole + // [start..=call_pos] sequence and tee + // after the call. + position_action.insert( + first_call_pos, + CSEAction::SaveToLocal(*local_idx), + ); + // Each later occurrence: collapse the + // span to a single local.get at the + // arg_start; skip up through call_pos. + for &(start, end) in &spans[1..] { + position_action.insert( + start, + CSEAction::ReplaceSpanWithLoad { + local_idx: *local_idx, + end_pos: end, + }, + ); + } } + + // SAFETY: LocalGet alone is unsafe under + // the same local.set issue as before; + // Binary/Unary span dedup is a future PR. + _ => {} } } } @@ -12193,8 +12394,17 @@ pub mod optimize { // Apply transformations: rebuild instruction list if !position_action.is_empty() { let mut new_instructions = Vec::new(); + let mut skip_until: Option = None; for (pos, instr) in func.instructions.iter().enumerate() { + // Honor an outstanding span-skip request. + if let Some(end) = skip_until { + if pos <= end { + continue; + } else { + skip_until = None; + } + } match position_action.get(&pos) { Some(CSEAction::SaveToLocal(local_idx)) => { // Keep the original instruction, add local.tee @@ -12205,6 +12415,14 @@ pub mod optimize { // Replace with local.get new_instructions.push(Instruction::LocalGet(*local_idx)); } + Some(CSEAction::ReplaceSpanWithLoad { + local_idx, + end_pos, + }) => { + // Replace `[pos ..= end_pos]` with one local.get. + new_instructions.push(Instruction::LocalGet(*local_idx)); + skip_until = Some(*end_pos); + } None => { // Keep instruction as-is new_instructions.push(instr.clone()); @@ -15016,16 +15234,22 @@ mod tests { // calls aren't deduped together. #[test] - #[ignore = "PR-K is INFRASTRUCTURE-ONLY for v0.8.0. The Call \ - expression is recognized, hashed, cost-gated, and \ - tracked in hash_to_positions. But the existing CSE \ - replacement loop at lib.rs:12109 only substitutes \ - single-instruction Const exprs (`is_safe_to_cse = \ - matches!(expr, Expr::Const32(_) | Expr::Const64(_))`). \ - Call replacement requires span-based substitution \ - (remove args + call, insert local.get) and is deferred \ - to a follow-up PR-K2. This test pins the EXPECTED \ - end state for that PR."] + #[ignore = "PR-K2 lands the SPAN REPLACEMENT mechanism (Phase 4 now \ + handles Expr::Call via [arg_start..=call_pos] → \ + local.get N substitution). However the existing Z3 \ + translation validator models every Instruction::Call as \ + a FRESH symbolic constant (BV::new_const at verify.rs:4035 \ + falls through to `call_unknown_result` for the empty \ + sig-context used by CSE). Two identical pure+no-trap \ + calls therefore translate to two INDEPENDENT BVs, and \ + the optimization `R1+R2 → 2*R3` is treated as unproven \ + — verify_or_revert reverts every dedup. Making this \ + test pass requires a verifier-side change (model Call \ + results as `f(args)` uninterpreted-function applications \ + when the IPA proves pure+no-trap) and is the deliverable \ + of PR-K3. Gate kept conservative until then; cost gate \ + also rejects N=2 cost=4 patterns to stop futile \ + attempt+revert churn at runtime."] fn test_cse_dedupes_repeated_pure_calls() { // $pure_helper is pure+no-trap (no Store/Load/Div, no GlobalSet). // Two adjacent identical `call $pure_helper(x)` should dedupe: @@ -15144,6 +15368,91 @@ mod tests { ); } + #[test] + #[ignore = "Same blocker as test_cse_dedupes_repeated_pure_calls: \ + the SPAN mechanism is in place (PR-K2) but the Z3 \ + validator can't yet prove Call CSE equivalent, so the \ + transform reverts. Re-enable in PR-K3 once the verifier \ + models pure+no-trap calls as uninterpreted functions \ + of their args."] + fn test_cse_dedupes_pure_clamp_calls_via_span_replacement() { + // PR-K2: span-replacement regression test. A realistic shape — + // two `call $clamp(local.get $x)` invocations with the same + // single arg — should fold the second call's `[arg, call]` span + // into a single `local.get` of the cache local. $clamp is pure + // and no-trap (only arithmetic and constants). + let wat = r#"(module + (func $clamp (param i32) (result i32) + local.get 0 + i32.const 0 + i32.const 100 + ;; min(max(x, 0), 100) ≈ clamp via two arithmetic ops. + ;; Encoded as i32.add chains for soundness within the + ;; supported-instruction subset; what matters is that + ;; $clamp is pure+no-trap. + i32.add + i32.add + ) + (func $caller (export "test") (param i32) (result i32) + local.get 0 + call $clamp + local.get 0 + call $clamp + i32.add + ) + )"#; + + let mut module = parse::parse_wat(wat).expect("parse"); + let instrs_before = module.functions[1].instructions.len(); + let calls_before = module.functions[1] + .instructions + .iter() + .filter(|i| matches!(i, Instruction::Call(_))) + .count(); + assert_eq!(calls_before, 2, "sanity: two calls before CSE"); + + optimize::eliminate_common_subexpressions_enhanced(&mut module).expect("cse"); + + let calls_after = module.functions[1] + .instructions + .iter() + .filter(|i| matches!(i, Instruction::Call(_))) + .count(); + assert_eq!( + calls_after, 1, + "span-CSE must dedupe two identical `call $clamp(local.get 0)` \ + invocations into a single call + cached local.get" + ); + + // The duplicate's span (LocalGet + Call = 2 instructions) collapses + // into one local.get; the first occurrence keeps both and adds a + // local.tee. Net instruction-count change is zero, but the second + // dynamic call is gone — that's the observable win. + let instrs_after = module.functions[1].instructions.len(); + assert_eq!( + instrs_after, instrs_before, + "span replacement should remove the duplicate Call but add a \ + local.tee at the first occurrence — net zero instruction count" + ); + // Confirm the second call's args are gone: there should be exactly + // one LocalGet of the original param-local 0 in the caller now + // (the duplicate's LocalGet was absorbed into the span swap). + let local0_gets_after = module.functions[1] + .instructions + .iter() + .filter(|i| matches!(i, Instruction::LocalGet(0))) + .count(); + assert_eq!( + local0_gets_after, 1, + "the duplicate `local.get 0` arg of the second call must be \ + removed by span replacement" + ); + + // Output must validate. + let bytes = encode::encode_wasm(&module).expect("encode"); + wasmparser::validate(&bytes).expect("output validates"); + } + #[test] fn test_cse_dedupes_calls_with_different_args_separately() { // Two pure calls with DIFFERENT args. Neither call's result