Vec dispatch should pick a scalar per element pair, not one for all pairs #145

@timfennis

Summary

The runtime vec dispatch currently requires a single scalar overload to accept every element pair of the broadcast axis. That rule is stricter than the per-element cost we already pay justifies: we already iterate over the pairs and check each one against parameter types. Picking a different scalar per pair would unlock heterogeneous vec calls without changing the asymptotic cost of dispatch.

Repro

([1, 2, 3], "foo") ++ ([4, 5, 6], "bar")

Expected (intuition): ([1, 2, 3, 4, 5, 6], "foobar") — element 0 uses ++(List, List), element 1 uses ++(String, String).

Today: fails with the error "no function called '++' found matches the arguments (Tuple<List<Int>, String>, Tuple<List<Int>, String>)".

Why it fails today

find_overload in ndc_vm/src/vm.rs returns Callable::Vec(scalar_fn, axis_len) — a single scalar overload. The applicability check vec_candidate_applies accepts a candidate only when one scalar's parameter types satisfy every pair. For the example above:

  • ++(List, List): pair 0 ([1,2,3], [4,5,6]) matches; pair 1 ("foo", "bar") fails → rejected.
  • ++(String, String): pair 0 fails → rejected.
  • Every vec candidate fails the all-pairs check → find_overload returns None → "no function found".
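The all-pairs rule can be modelled in a few lines to show why both overloads are rejected. This is a toy sketch, not the ndc_vm code: Ty, Overload, applies_to_all_pairs and the helper constructors are hypothetical stand-ins, and "accepts" is simplified to exact type equality.

```rust
// Toy model of the current all-pairs rule. Ty, Overload and
// applies_to_all_pairs are hypothetical stand-ins, not ndc_vm types.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Ty {
    Str,
    List,
}

struct Overload {
    name: &'static str,
    params: [Ty; 2],
}

/// Runtime types of one element pair on the broadcast axis.
type Pair = [Ty; 2];

/// A candidate applies only if its single parameter list accepts *every*
/// pair (acceptance simplified here to exact type equality).
fn applies_to_all_pairs(f: &Overload, pairs: &[Pair]) -> bool {
    pairs.iter().all(|p| *p == f.params)
}

fn plus_plus_overloads() -> Vec<Overload> {
    vec![
        Overload { name: "++(List, List)", params: [Ty::List, Ty::List] },
        Overload { name: "++(String, String)", params: [Ty::Str, Ty::Str] },
    ]
}

/// ([1, 2, 3], "foo") ++ ([4, 5, 6], "bar") broadcasts into these two pairs.
fn repro_pairs() -> Vec<Pair> {
    vec![[Ty::List, Ty::List], [Ty::Str, Ty::Str]]
}

fn main() {
    let pairs = repro_pairs();
    for f in &plus_plus_overloads() {
        // Each overload accepts exactly one of the two pairs, so both are rejected.
        println!("{}: applies = {}", f.name, applies_to_all_pairs(f, &pairs));
    }
}
```

Each overload is individually fine for one pair; the rule fails only because it demands one overload for both.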

The analyser bails out earlier too: synthetic_vec_sig collapses each tuple argument to the LUB of its elements, giving [List<Int>, String] as the signature to look up. No single ++ overload accepts that, so no vec candidate is synthesised at compile time either. The call falls through to the all_by_name fallback and ends up as a Dynamic call that the runtime can't dispatch.

Why the cost is mostly free

vec_candidate_applies already walks every pair against every vec candidate's parameter types. The per-call cost in PR #141's benchmarks (≈25% slowdown on tuple-heavy hot loops) was attributed to exactly this iteration. Restructuring to per-pair dispatch replaces "scan one scalar's params against all pairs" with "per pair, scan the candidate list until one accepts". Both loops are bounded by axis_len × the number of candidates, so the asymptotic work is the same; it is just rearranged so the result is usable.
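To make the cost argument concrete, the sketch below counts (candidate, pair) acceptance checks under both rules. It is a hypothetical model: accepts[c][p] stands in for "candidate c's parameter types accept pair p", and all_pairs / per_pair are illustrative names, not functions in ndc_vm.

```rust
// Hypothetical cost model: accepts[c][p] is true when candidate c's parameter
// types accept element pair p. Only the number of checks is counted.

/// Current rule: a candidate wins only if it accepts every pair.
/// Returns (checks performed, winning candidate).
fn all_pairs(accepts: &[Vec<bool>]) -> (usize, Option<usize>) {
    let mut checks = 0;
    for (c, row) in accepts.iter().enumerate() {
        let mut ok = true;
        for &a in row {
            checks += 1;
            if !a {
                ok = false;
                break; // this candidate is rejected at its first failing pair
            }
        }
        if ok {
            return (checks, Some(c));
        }
    }
    (checks, None)
}

/// Proposed rule: each pair takes the first candidate that accepts it.
/// Returns (checks performed, candidate chosen for each pair).
fn per_pair(accepts: &[Vec<bool>], axis_len: usize) -> (usize, Option<Vec<usize>>) {
    let mut checks = 0;
    let mut chosen = Vec::with_capacity(axis_len);
    for p in 0..axis_len {
        let mut found = None;
        for (c, row) in accepts.iter().enumerate() {
            checks += 1;
            if row[p] {
                found = Some(c);
                break;
            }
        }
        match found {
            Some(c) => chosen.push(c),
            None => return (checks, None), // no candidate accepts pair p
        }
    }
    (checks, Some(chosen))
}

fn main() {
    // Repro: ++(List, List) accepts only pair 0, ++(String, String) only pair 1.
    let hetero = vec![vec![true, false], vec![false, true]];
    println!("all-pairs: {:?}", all_pairs(&hetero));
    println!("per-pair:  {:?}", per_pair(&hetero, 2));

    // Homogeneous numeric tuple: one candidate accepts all three pairs.
    let homo = vec![vec![true, true, true]];
    println!("all-pairs: {:?}", all_pairs(&homo));
    println!("per-pair:  {:?}", per_pair(&homo, 3));
}
```

On the repro both rules perform three checks, but only per-pair dispatch ends with a usable choice ([0, 1]); on the homogeneous tuple both perform one check per pair.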

Implementation sketch

Change the Callable shape to carry the candidate list, and resolve per pair in dispatch_vec_call:

pub(crate) enum Callable {
    Scalar(Function),
    Vec { candidates: Vec<Function>, axis_len: usize }, // was (Function, usize)
}

fn dispatch_vec_call(&mut self, candidates: &[Function], args: usize, axis_len: usize, span: Span) -> Result<(), VmError> {
    // ... existing arg_values setup ...
    let mut results = Vec::with_capacity(axis_len);
    for i in 0..axis_len {
        // Slice element i out of every argument.
        let element_args: Vec<Value> = arg_values.iter().map(|a| vec_element_at(a, i)).collect();
        // Resolve the scalar for this pair only: the first candidate whose params accept it.
        let scalar = candidates
            .iter()
            .find(|f| f.matches_value_args(&element_args))
            .ok_or_else(|| {
                // std-only join (no itertools needed): collect the type names first.
                let types: Vec<String> = element_args.iter().map(|v| v.static_type().to_string()).collect();
                VmError::new(format!("no overload accepts element {i}: ({})", types.join(", ")), span)
            })?;
        let result = self.call_callback(scalar.clone(), element_args)?;
        results.push(result);
    }
    // ... push result tuple ...
    Ok(())
}
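A self-contained toy version of the per-pair loop, runnable outside the VM, reproduces the expected result from the repro. Everything here is a hypothetical stand-in: Value, Ty, Function and dispatch_vec only model the shapes used in the sketch above (matches_value_args, vec_element_at), they are not the real ndc_vm types, and errors are plain Strings instead of VmError.

```rust
// Toy model of per-pair vec dispatch. Value, Ty, Function and dispatch_vec
// are hypothetical stand-ins, just rich enough to run the ++ repro.
#[derive(Clone, Debug, PartialEq)]
enum Value {
    Int(i64),
    Str(String),
    List(Vec<Value>),
    Tuple(Vec<Value>),
}

#[derive(Clone, Copy)]
enum Ty {
    Str,
    List,
}

impl Value {
    fn type_name(&self) -> &'static str {
        match self {
            Value::Int(_) => "Int",
            Value::Str(_) => "String",
            Value::List(_) => "List",
            Value::Tuple(_) => "Tuple",
        }
    }
    fn has_type(&self, ty: Ty) -> bool {
        matches!((self, ty), (Value::Str(_), Ty::Str) | (Value::List(_), Ty::List))
    }
}

struct Function {
    params: Vec<Ty>,
    body: fn(&[Value]) -> Value,
}

impl Function {
    fn matches_value_args(&self, args: &[Value]) -> bool {
        args.len() == self.params.len() && args.iter().zip(&self.params).all(|(v, t)| v.has_type(*t))
    }
}

/// Element `i` of a broadcast argument; non-tuple arguments repeat on every pair.
fn vec_element_at(v: &Value, i: usize) -> Value {
    match v {
        Value::Tuple(items) => items[i].clone(),
        other => other.clone(),
    }
}

/// Per-pair dispatch: each element pair picks the first candidate that accepts it.
fn dispatch_vec(candidates: &[Function], args: &[Value], axis_len: usize) -> Result<Value, String> {
    let mut results = Vec::with_capacity(axis_len);
    for i in 0..axis_len {
        let element_args: Vec<Value> = args.iter().map(|a| vec_element_at(a, i)).collect();
        let scalar = candidates
            .iter()
            .find(|f| f.matches_value_args(&element_args))
            .ok_or_else(|| {
                let types: Vec<&str> = element_args.iter().map(Value::type_name).collect();
                format!("no overload accepts element {i}: ({})", types.join(", "))
            })?;
        results.push((scalar.body)(&element_args));
    }
    Ok(Value::Tuple(results))
}

// Bodies of ++(List, List) and ++(String, String); matches_value_args guards the arms.
fn concat_lists(args: &[Value]) -> Value {
    match (&args[0], &args[1]) {
        (Value::List(a), Value::List(b)) => Value::List(a.iter().chain(b).cloned().collect()),
        _ => unreachable!(),
    }
}
fn concat_strs(args: &[Value]) -> Value {
    match (&args[0], &args[1]) {
        (Value::Str(a), Value::Str(b)) => Value::Str(format!("{a}{b}")),
        _ => unreachable!(),
    }
}

fn plus_plus() -> Vec<Function> {
    vec![
        Function { params: vec![Ty::List, Ty::List], body: concat_lists },
        Function { params: vec![Ty::Str, Ty::Str], body: concat_strs },
    ]
}
fn int_list(xs: &[i64]) -> Value {
    Value::List(xs.iter().map(|&x| Value::Int(x)).collect())
}
fn s(x: &str) -> Value {
    Value::Str(x.to_string())
}

fn main() {
    // ([1, 2, 3], "foo") ++ ([4, 5, 6], "bar")
    let lhs = Value::Tuple(vec![int_list(&[1, 2, 3]), s("foo")]);
    let rhs = Value::Tuple(vec![int_list(&[4, 5, 6]), s("bar")]);
    println!("{:?}", dispatch_vec(&plus_plus(), &[lhs, rhs], 2));
    // A pair that no candidate accepts reports which element failed.
    let bad = [Value::Tuple(vec![Value::Int(1), s("x")]), Value::Tuple(vec![int_list(&[]), s("y")])];
    println!("{:?}", dispatch_vec(&plus_plus(), &bad, 2));
}
```

The first call yields the tuple ([1, 2, 3, 4, 5, 6], "foobar"); the second fails with "no overload accepts element 0: (Int, List)", naming the mismatched pair rather than the outer tuple types.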

find_overload becomes: scan candidates, return Callable::Scalar on the first scalar match, otherwise collect every vec candidate into the Vec { candidates, axis_len } variant. The previous "all pairs must match this one scalar" gate goes away.
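The find_overload change can be sketched the same way. All of this is hypothetical: overloads are identified only by their parameter types, acceptance is exact type equality, a broadcast call is signalled by passing Some(axis_len), and a "vec candidate" is assumed to be any overload with the call's arity (the real predicate in ndc_vm may differ).

```rust
// Hypothetical model of find_overload under per-pair dispatch; the types
// mirror the shapes in this issue but are not the real ndc_vm definitions.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Ty {
    Int,
    Str,
    List,
    Tuple,
}

#[derive(Clone, Debug, PartialEq)]
struct Function {
    name: &'static str,
    params: Vec<Ty>,
}

#[derive(Debug, PartialEq)]
enum Callable {
    Scalar(Function),
    // Every vec candidate is kept; the scalar is chosen per pair at dispatch time.
    Vec { candidates: Vec<Function>, axis_len: usize },
}

/// `axis_len` is `Some(n)` when the call broadcasts over tuples of length `n`.
fn find_overload(overloads: &[Function], arg_types: &[Ty], axis_len: Option<usize>) -> Option<Callable> {
    // 1. A scalar overload accepting the arguments as-is wins outright.
    if let Some(f) = overloads.iter().find(|f| f.params == arg_types) {
        return Some(Callable::Scalar(f.clone()));
    }
    // 2. No scalar match: if the call broadcasts, collect every vec candidate
    //    (assumed here: same arity) with no all-pairs gate.
    let axis_len = axis_len?;
    let candidates: Vec<Function> = overloads
        .iter()
        .filter(|f| f.params.len() == arg_types.len())
        .cloned()
        .collect();
    if candidates.is_empty() {
        None
    } else {
        Some(Callable::Vec { candidates, axis_len })
    }
}

fn plus_plus() -> Vec<Function> {
    vec![
        Function { name: "++(List, List)", params: vec![Ty::List, Ty::List] },
        Function { name: "++(String, String)", params: vec![Ty::Str, Ty::Str] },
    ]
}

fn main() {
    let ovs = plus_plus();
    // Plain list concatenation resolves to a scalar call.
    if let Some(Callable::Scalar(f)) = find_overload(&ovs, &[Ty::List, Ty::List], None) {
        println!("scalar: {}", f.name);
    }
    // The repro: tuple arguments, axis length 2, both overloads kept.
    if let Some(Callable::Vec { candidates, axis_len }) = find_overload(&ovs, &[Ty::Tuple, Ty::Tuple], Some(2)) {
        println!("vec over {axis_len} pairs with {} candidates", candidates.len());
    }
    // Non-broadcast call with no matching scalar: nothing to dispatch.
    println!("{:?}", find_overload(&ovs, &[Ty::Int, Ty::Int], None));
}
```

Because the whole candidate list is carried into dispatch, an unusable pair surfaces as a per-element error at call time instead of as a failed lookup.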

Trade-offs

  • Static typing for heterogeneous calls stays at Any. Per-position precise typing (each tuple position keeping its own element type through the result) is RFC option A in docs/design/vectorization.md and is a much bigger analyser change. The runtime behaviour fix here is independent: the analyser simply continues to widen to Any when no single overload covers the whole synthetic sig.
  • Slightly more candidate scans per pair in the worst case. For homogeneous numeric tuples the first candidate matches every pair, so the cost is unchanged. For heterogeneous tuples we now do up to axis_len × candidate_list_size matches instead of erroring early; that is a net positive, because the alternative was refusing to run at all.
  • Error message becomes more informative. Today: "no function called '++' found …" with the outer call's argument types. Under per-pair dispatch: "no overload accepts element N: (…)" pointing at the actual mismatched pair.

Related

Labels: enhancement (New feature or request), interpreter (Issue relates to the interpreter)

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions