
Missing auto-vectorization for take+sum #115160

Closed
numas13 opened this issue Aug 24, 2023 · 9 comments
Labels
A-iterators Area: Iterators I-slow Issue: Problems and improvements with respect to performance of generated code. T-libs Relevant to the library team, which will review and decide on the PR/issue.

Comments

@numas13
numas13 commented Aug 24, 2023

LLVM cannot auto-vectorize the following code:

pub fn slice_sum(s: &[u64], l: usize) -> u64 {
    s.iter().take(l).sum()
}
ASM
example::slice_sum:
        test    rdx, rdx
        je      .LBB0_1
        shl     rsi, 3
        xor     ecx, ecx
        xor     eax, eax
.LBB0_3:
        cmp     rsi, rcx
        je      .LBB0_5
        add     rax, qword ptr [rdi + rcx]
        add     rcx, 8
        dec     rdx
        jne     .LBB0_3
.LBB0_5:
        ret
.LBB0_1:
        xor     eax, eax
        ret

But it can auto-vectorize the same computation when a for loop is used instead of fold or sum:

pub fn slice_sum_loop(s: &[u64], l: usize) -> u64 {
    let mut acc = 0;
    for i in s.iter().take(l) {
        acc += *i;
    }
    acc
}
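Both functions are meant to compute exactly the same value, so the difference is purely in code generation. A quick sketch that checks this, including the cases l == 0 and l larger than the slice (the test harness is an illustration, not part of the original report):

```rust
// The two functions from the report, unchanged apart from comments.
pub fn slice_sum(s: &[u64], l: usize) -> u64 {
    // Internal iteration: the integer Sum impl drives Take via fold.
    s.iter().take(l).sum()
}

pub fn slice_sum_loop(s: &[u64], l: usize) -> u64 {
    // External iteration: the for loop calls Take::next repeatedly.
    let mut acc = 0;
    for i in s.iter().take(l) {
        acc += *i;
    }
    acc
}

fn main() {
    let data: Vec<u64> = (0..100).collect();
    // Cover l == 0, l inside the slice, and l past the end of the slice.
    for l in [0, 1, 17, 99, 100, 1_000] {
        assert_eq!(slice_sum(&data, l), slice_sum_loop(&data, l));
    }
    println!("ok");
}
```

Only the emitted assembly differs; comparing the two on godbolt with the flags above shows the vectorized and scalar loops.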
ASM
example::slice_sum_loop:
        xor     eax, eax
        test    rdx, rdx
        je      .LBB1_10
        test    rsi, rsi
        je      .LBB1_10
        lea     rax, [rsi - 1]
        mov     cl, 61
        bzhi    rax, rax, rcx
        lea     r8, [rdx - 1]
        cmp     rax, r8
        cmovb   r8, rax
        cmp     r8, 15
        jae     .LBB1_4
        xor     eax, eax
        mov     rcx, rdi
        jmp     .LBB1_7
.LBB1_4:
        inc     r8
        mov     r9, r8
        and     r9, -16
        lea     rcx, [rdi + 8*r9]
        sub     rdx, r9
        vpxor   xmm0, xmm0, xmm0
        xor     eax, eax
        vpxor   xmm1, xmm1, xmm1
        vpxor   xmm2, xmm2, xmm2
        vpxor   xmm3, xmm3, xmm3
.LBB1_5:
        vpaddq  ymm0, ymm0, ymmword ptr [rdi + 8*rax]
        vpaddq  ymm1, ymm1, ymmword ptr [rdi + 8*rax + 32]
        vpaddq  ymm2, ymm2, ymmword ptr [rdi + 8*rax + 64]
        vpaddq  ymm3, ymm3, ymmword ptr [rdi + 8*rax + 96]
        add     rax, 16
        cmp     r9, rax
        jne     .LBB1_5
        vpaddq  ymm0, ymm1, ymm0
        vpaddq  ymm1, ymm3, ymm2
        vpaddq  ymm0, ymm1, ymm0
        vextracti128    xmm1, ymm0, 1
        vpaddq  xmm0, xmm0, xmm1
        vpshufd xmm1, xmm0, 238
        vpaddq  xmm0, xmm0, xmm1
        vmovq   rax, xmm0
        cmp     r8, r9
        je      .LBB1_10
.LBB1_7:
        lea     rsi, [rdi + 8*rsi]
        add     rcx, 8
        dec     rdx
.LBB1_8:
        add     rax, qword ptr [rcx - 8]
        sub     rdx, 1
        jb      .LBB1_10
        lea     rdi, [rcx + 8]
        cmp     rcx, rsi
        mov     rcx, rdi
        jne     .LBB1_8
.LBB1_10:
        vzeroupper
        ret
$ rustc -C opt-level=3 -C target-cpu=x86-64-v3 ...
$ rustc --version --verbose
rustc 1.71.0 (8ede3aae2 2023-07-12)
binary: rustc
commit-hash: 8ede3aae28fe6e4d52b38157d7bfe0d3bceef225
commit-date: 2023-07-12
host: x86_64-unknown-linux-gnu
release: 1.71.0
LLVM version: 16.0.5
Compiler returned: 0

https://rust.godbolt.org/z/o1hcvczTW

@rustbot rustbot added the needs-triage This issue may need triage. Remove it if it has been sufficiently triaged. label Aug 24, 2023
@the8472
Member

the8472 commented Aug 24, 2023

sum() calls Take::fold(), which is implemented via Take::try_fold, which in turn calls slice::Iter::try_fold. That uses the default Iterator impl, which finally calls slice::Iter::next. And of course try_fold does lots of unnecessary work.

The manual loop just nests all the next impls directly, which happens to be better in this case. So for once, external iteration is faster than internal iteration 😮‍💨
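To make the call chain above concrete, here is a simplified sketch of what "fold implemented via try_fold" looks like for a length-limited sum. The helper take_sum_via_try_fold is hypothetical and only mirrors the general shape (a countdown plus a Break/Continue decision on every element); it is not the actual standard library source:

```rust
use std::ops::ControlFlow;

/// Hypothetical helper: sums at most `n` items by driving the slice iterator
/// with `try_fold`, breaking out once the budget is used up. Every element
/// pays for a counter update and a Break/Continue branch.
fn take_sum_via_try_fold(s: &[u64], n: usize) -> u64 {
    if n == 0 {
        return 0;
    }
    let mut remaining = n;
    let result = s.iter().try_fold(0u64, |acc, &x| {
        let acc = acc + x;
        remaining -= 1;
        if remaining == 0 {
            ControlFlow::Break(acc)
        } else {
            ControlFlow::Continue(acc)
        }
    });
    // Both variants carry the running sum.
    match result {
        ControlFlow::Continue(acc) | ControlFlow::Break(acc) => acc,
    }
}

/// External iteration: the shape of the manual loop from the report.
fn take_sum_loop(s: &[u64], n: usize) -> u64 {
    let mut acc = 0;
    for &x in s.iter().take(n) {
        acc += x;
    }
    acc
}

fn main() {
    let data: Vec<u64> = (1..=10).collect();
    for n in [0, 1, 5, 10, 20] {
        assert_eq!(take_sum_via_try_fold(&data, n), take_sum_loop(&data, n));
    }
    println!("{}", take_sum_via_try_fold(&data, 5)); // 15
}
```

Both produce the same result; the point is that the try_fold shape adds a data-dependent early exit inside the loop body, which is the kind of structure that can keep LLVM from vectorizing. Whether a given version vectorizes still depends on the compiler version and flags.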

I can think of several ways to fix this.

@rustbot claim

@the8472 the8472 added I-slow Issue: Problems and improvements with respect to performance of generated code. A-iterators Area: Iterators T-libs Relevant to the library team, which will review and decide on the PR/issue. and removed needs-triage This issue may need triage. Remove it if it has been sufficiently triaged. labels Aug 24, 2023
@numas13
Author

numas13 commented Aug 24, 2023

@the8472

Same problem with MapWhile.

https://rust.godbolt.org/z/76KacvjPr

@numas13
Author

numas13 commented Aug 24, 2023

Same problem with TakeWhile.

https://rust.godbolt.org/z/q8335TGhW

@the8472
Member

the8472 commented Aug 24, 2023

On the library level, those are a lot harder to optimize than Take because their closures can contain arbitrary code, whereas we know what Take does (restrict the length) and can substitute alternative implementations.
MapWhile and TakeWhile can only be optimized by the backend when the concrete closure is inlined.
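As an illustration of why these adapters differ from Take: with take_while or map_while, how many elements get consumed depends on what the closure returns for each element, so there is no length the library could compute up front. A small sketch using a hypothetical "stop at the first zero" predicate (the function names are made up for this example):

```rust
/// Sums elements up to (not including) the first zero, using take_while.
/// The stopping point depends on the predicate's result for each element.
pub fn sum_until_zero_take_while(s: &[u64]) -> u64 {
    s.iter().take_while(|&&x| x != 0).sum()
}

/// The same computation with map_while: the closure returns None to stop.
pub fn sum_until_zero_map_while(s: &[u64]) -> u64 {
    s.iter()
        .map_while(|&x| if x != 0 { Some(x) } else { None })
        .sum()
}

/// Manual loop equivalent with an explicit early exit.
pub fn sum_until_zero_loop(s: &[u64]) -> u64 {
    let mut acc = 0;
    for &x in s {
        if x == 0 {
            break;
        }
        acc += x;
    }
    acc
}

fn main() {
    let data = [3u64, 4, 5, 0, 100];
    assert_eq!(sum_until_zero_take_while(&data), 12);
    assert_eq!(sum_until_zero_map_while(&data), 12);
    assert_eq!(sum_until_zero_loop(&data), 12);
    println!("ok");
}
```

Even the manual loop has a data-dependent exit here, so any vectorization has to come from the backend after the closure is inlined, not from a library-level substitution.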

@the8472
Member

the8472 commented Aug 25, 2023

Is this meant to highlight take + sum in general or specifically for slice iters? Getting vectorization for other iters is more complicated compared to slices.

@numas13
Author

numas13 commented Aug 27, 2023

Ideally, there should be no difference between iterators and explicit loops, since this contradicts the idea of zero-cost abstractions.

The point is this: iterators, although a high-level abstraction, get compiled down to roughly the same code as if you’d written the lower-level code yourself. Iterators are one of Rust’s zero-cost abstractions, by which we mean using the abstraction imposes no additional runtime overhead.
...
The implementations of closures and iterators are such that runtime performance is not affected. This is part of Rust’s goal to strive to provide zero-cost abstractions.

https://doc.rust-lang.org/book/ch13-04-performance.html

But if it's hard to implement in general, it's fine to fix it for slices and other simple cases, and there should probably be a note that this isn't always the case :)

@the8472
Member

the8472 commented Aug 27, 2023

My question was more about what you're actually reporting. You started with slice.take.sum. Then you followed up with slice.take_while.sum and slice.map_while.sum. The more you generalize the less specific and more open-ended the issue becomes.

@numas13
Author

numas13 commented Aug 27, 2023

You pointed out that Take has a specialized fold, but I didn't pay attention to that at first, and it led me to think that TakeWhile and MapWhile have the same problem.

So the problem is the specialization of fold/try_fold for Take, TakeWhile and MapWhile, which prevents LLVM from doing auto-vectorization.
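For the slice case specifically, one user-side workaround is to avoid the Take adapter entirely and shrink the slice before iterating, so the loop bound is just the slice length. This is a sketch with a hypothetical function name, slice_sum_sliced; whether it actually vectorizes depends on the compiler version and target flags, so it is worth checking on godbolt like the links above:

```rust
/// Sums the first `l` elements of `s` (or all of them if `l > s.len()`).
/// Slicing up front replaces Take with a bounded slice iterator.
pub fn slice_sum_sliced(s: &[u64], l: usize) -> u64 {
    s[..l.min(s.len())].iter().sum()
}

// Version from the report, for comparison.
pub fn slice_sum(s: &[u64], l: usize) -> u64 {
    s.iter().take(l).sum()
}

fn main() {
    let data: Vec<u64> = (1..=8).collect();
    for l in [0, 3, 8, 50] {
        assert_eq!(slice_sum_sliced(&data, l), slice_sum(&data, l));
    }
    println!("{}", slice_sum_sliced(&data, 3)); // 6
}
```

Clamping with l.min(s.len()) keeps the behavior identical to take(l) when l exceeds the slice length, instead of panicking on an out-of-range index.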

@the8472
Member

the8472 commented Sep 3, 2023

#115273 should fix it for slice.take.sum

@numas13 numas13 closed this as completed Sep 4, 2023

3 participants