Embeddings: search with SIMD #51372

Merged
camdencheek merged 10 commits into main from cc/simd-embeddings on May 2, 2023
Conversation

@camdencheek (Member) commented May 2, 2023

This implements a hand-written assembly version of the int8 dot product that takes advantage of AVX2 SIMD instructions. This speeds up our embeddings searches by roughly 10x on modern x86_64 machines.

Test plan

Added quickchecks and fuzz tests to compare output with the Go version.

The following benchmark is for an n2-standard-4 GCE instance, which is a very standard machine type. For the single-core benchmark, we can search about 6 million embeddings per second, which is the equivalent of a 6GB monorepo.

This is low-risk to merge because it is disabled by default.
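For context, the Go side dispatches to the assembly roughly like this (a minimal sketch; the names dotAVX2, dotPortable, and simdEnabled, the int64 return width, and the exact flag wiring are assumptions, not copied from the diff):

//go:build amd64

package embeddings

import "golang.org/x/sys/cpu"

// dotAVX2 is implemented in assembly (dot_amd64.s in this PR).
//
//go:noescape
func dotAVX2(a, b []int8) int64

// Dot uses the SIMD implementation when it is enabled and the CPU
// supports AVX2, and falls back to the portable Go loop otherwise.
func Dot(a, b []int8) int32 {
	if simdEnabled && cpu.X86.HasAVX2 {
		return int32(dotAVX2(a, b))
	}
	return dotPortable(a, b) // the unrolled loop shown further down
}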

goos: linux
goarch: amd64
pkg: github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings
cpu: Intel(R) Xeon(R) CPU @ 2.80GHz
                                 │ /tmp/noasm.txt │            /tmp/asm.txt             │
                                 │     sec/op     │   sec/op     vs base                │
SimilaritySearch/numWorkers=1-4     1697.3m ±  0%   169.6m ± 1%  -90.01% (p=0.000 n=10)
SimilaritySearch/numWorkers=2-4      728.7m ± 94%   108.9m ± 7%  -85.05% (p=0.000 n=10)
SimilaritySearch/numWorkers=4-4     705.96m ±  0%   63.01m ± 2%  -91.07% (p=0.000 n=10)
SimilaritySearch/numWorkers=8-4     713.51m ±  1%   69.47m ± 2%  -90.26% (p=0.000 n=10)
SimilaritySearch/numWorkers=16-4    709.74m ±  0%   65.26m ± 1%  -90.80% (p=0.000 n=10)
geomean                              849.4m         88.00m       -89.64%

@cla-bot cla-bot bot added the cla-signed label May 2, 2023
Comment on lines +18 to +20
// Sign-extend 16 bytes into 16 int16s
VPMOVSXBW (AX), Y1
VPMOVSXBW (BX), Y2
Member Author

We do operations on int16s because there is (surprisingly) no instruction to multiply-add signed int8 vectors.
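For reference, once both vectors are widened to int16, the multiply-accumulate is typically a single VPMADDWD; a sketch of the step this likely feeds (register assignments are assumed, not copied from the diff):

	// Multiply int16 pairs and sum adjacent products into eight int32s
	VPMADDWD Y1, Y2, Y3
	// Accumulate the partial sums into Y0
	VPADDD Y3, Y0, Y0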

Comment on lines +35 to +44
// X0 is the low bits of Y0.
// Extract the high bits into X1, fold in half, add, repeat.
VEXTRACTI128 $1, Y0, X1
VPADDD X0, X1, X0

VPSRLDQ $8, X0, X1
VPADDD X0, X1, X0

VPSRLDQ $4, X0, X1
VPADDD X0, X1, X0
Member Author

This section sums the 8 32-bit ints in Y0 by repeatedly folding the vector in half and adding vertically. We are left with the sum in the rightmost position of X0.
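For intuition, the same fold-and-add reduction written as scalar Go (an illustrative sketch, not code from this PR):

func horizontalSum(v [8]int32) int32 {
	// Fold the upper half onto the lower half until one element remains.
	for width := 4; width >= 1; width /= 2 {
		for i := 0; i < width; i++ {
			v[i] += v[i+width]
		}
	}
	return v[0]
}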

Comment on lines +49 to +65
// In tailloop, we add to the dot product one at a time
tailloop:
	CMPQ DX, $0
	JE   end

	// Load values from the input slices
	MOVBQSX (AX), R9
	MOVBQSX (BX), R10

	// Multiply and accumulate
	IMULQ R9, R10
	ADDQ  R10, R8

	INCQ AX
	INCQ BX
	DECQ DX
	JMP  tailloop
Member Author

In case our input length is not a multiple of 16 (which it will be for OpenAI embeddings), this handles the remainder one element at a time.
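The Go equivalent of the tail loop above would be roughly (an illustrative sketch, with i, sum, a, and b standing in for the registers):

	for ; i < len(a); i++ {
		sum += int64(a[i]) * int64(b[i])
	}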

goos: linux
goarch: amd64
pkg: github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings
cpu: Intel(R) Xeon(R) CPU E5-2643 v3 @ 3.40GHz
                                  │ /tmp/before.txt │            /tmp/after.txt            │
                                  │     sec/op      │    sec/op     vs base                │
SimilaritySearch/numWorkers=1-24      1817.3m ± 48%   286.2m ± 67%  -84.25% (p=0.000 n=10)
SimilaritySearch/numWorkers=2-24       845.2m ± 31%   150.9m ± 21%  -82.15% (p=0.000 n=10)
SimilaritySearch/numWorkers=4-24       593.8m ± 21%   107.0m ± 15%  -81.99% (p=0.000 n=10)
SimilaritySearch/numWorkers=8-24      302.67m ± 19%   77.49m ± 14%  -74.40% (p=0.000 n=10)
SimilaritySearch/numWorkers=16-24     173.93m ± 12%   83.05m ±  6%  -52.25% (p=0.000 n=10)
geomean                                544.9m         124.3m        -77.18%
Comment on lines -204 to -254
func CosineSimilarity(row []int8, query []int8) int32 {
	similarity := int32(0)

	count := len(row)
	if count > len(query) {
		// Do this ahead of time so the compiler doesn't need to bounds check
		// every time we index into query.
		panic("mismatched vector lengths")
	}

	i := 0
	for ; i+3 < count; i += 4 {
		m0 := int32(row[i]) * int32(query[i])
		m1 := int32(row[i+1]) * int32(query[i+1])
		m2 := int32(row[i+2]) * int32(query[i+2])
		m3 := int32(row[i+3]) * int32(query[i+3])
		similarity += (m0 + m1 + m2 + m3)
	}

	for ; i < count; i++ {
		similarity += int32(row[i]) * int32(query[i])
	}

	return similarity
}

func CosineSimilarityFloat32(row []float32, query []float32) float32 {
	similarity := float32(0)

	count := len(row)
	if count > len(query) {
		// Do this ahead of time so the compiler doesn't need to bounds check
		// every time we index into query.
		panic("mismatched vector lengths")
	}

	i := 0
	for ; i+3 < count; i += 4 {
		m0 := row[i] * query[i]
		m1 := row[i+1] * query[i+1]
		m2 := row[i+2] * query[i+2]
		m3 := row[i+3] * query[i+3]
		similarity += (m0 + m1 + m2 + m3)
	}

	for ; i < count; i++ {
		similarity += row[i] * query[i]
	}

	return similarity
}
Member Author

I moved these into dot.go and renamed them to Dot*. The dot product is only equivalent to cosine similarity if the vectors are normalized, so I think the rename is justified in case we ever use non-normalized vectors.
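For reference, cosine similarity normalizes the dot product by the vector magnitudes, cos(a, b) = (a · b) / (‖a‖‖b‖), so a true cosine built on the renamed functions would look like this (a hypothetical helper, not part of this PR; assumes the float32 variant is named DotFloat32):

import "math"

func cosineSimilarity(a, b []float32) float32 {
	norm := math.Sqrt(float64(DotFloat32(a, a))) * math.Sqrt(float64(DotFloat32(b, b)))
	return float32(float64(DotFloat32(a, b)) / norm)
}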

@camdencheek camdencheek marked this pull request as ready for review May 2, 2023 18:41
@camdencheek camdencheek requested a review from a team May 2, 2023 18:41
@jtibshirani (Member) left a comment

This makes sense to me, but I'm not up-to-speed on Go assembly! Maybe there's someone more knowledgeable who could give a timely review?

Also, it's interesting that when testing on the GCE instance, we get a max 2x speedup. This is different from what we observed when testing locally, where the speedup scales with the number of workers: #51372. It means we should limit the request parallelism to something more conservative, like 2 threads, rather than the number of processors as we do now. This would be a follow-up, as it's separate from this PR.

enterprise/internal/embeddings/dot_amd64.go (outdated, resolved)
enterprise/internal/embeddings/dot.go (resolved)
enterprise/internal/embeddings/dot_test.go (resolved)
@vdavid (Contributor) left a comment

I went through the code (on mobile 😄) and it LGTM. It's been two decades since I last did assembly and, of course, I have no idea about these new instructions, but it looks reasonable, and the test coverage is convincing. Overall, wow, impressive work!

@jtibshirani (Member) left a comment

Looks good to me from the config and search side. I will dig through the assembly at another time :)

@camdencheek camdencheek merged commit 17a8ec9 into main May 2, 2023
17 checks passed
@camdencheek camdencheek deleted the cc/simd-embeddings branch May 2, 2023 20:58
@github-actions bot (Contributor) commented May 2, 2023

The backport to 5.0 failed:

The process '/usr/bin/git' failed with exit code 1

To backport manually, run these commands in your terminal:

# Fetch latest updates from GitHub
git fetch
# Create a new working tree
git worktree add .worktrees/backport-5.0 5.0
# Navigate to the new working tree
cd .worktrees/backport-5.0
# Create a new branch
git switch --create backport-51372-to-5.0
# Cherry-pick the merged commit of this pull request and resolve the conflicts
git cherry-pick -x --mainline 1 17a8ec942c1eaca26ae62191460e7ff9bd6285aa
# Push it to GitHub
git push --set-upstream origin backport-51372-to-5.0
# Go back to the original working tree
cd ../..
# Delete the working tree
git worktree remove .worktrees/backport-5.0

Then, create a pull request where the base branch is 5.0 and the compare/head branch is backport-51372-to-5.0.

@github-actions bot added the backports, failed-backport-to-5.0, and release-blocker (Prevents us from releasing: https://about.sourcegraph.com/handbook/engineering/releases) labels May 2, 2023
camdencheek added a commit that referenced this pull request May 2, 2023
This implements a hand-written assembly version of the int8 dot product
that takes advantage of AVX2 SIMD instructions. This speeds up our
embeddings searches by roughly 10x on modern x86_64 machines.

(cherry picked from commit 17a8ec9)
@@ -0,0 +1,70 @@
#include "textflag.h"
Contributor

This file is missing from this patch; maybe remove this #include?

Member Author

What do you mean by "missing from this patch"? The #include "textflag.h" is needed to define the NOSPLIT symbol.

Member Author

If you mean "you didn't commit a textflag.h file": it's a Go compiler builtin.

	SUBQ $16, DX
	JMP  blockloop

reduce:
Contributor

Given that reduce: and tailloop: run only once (or are very small), it would make things a bit simpler to remove the assembly for them and write them in Go. (Assuming it's possible to access Y0 from Go; otherwise it would make sense to leave the reduce code in assembly.)

Member Author

AFAIK, it is not possible to access Y0 from Go without assembly.

Comment on lines +4 to +6
MOVQ a_base+0(FP), AX
MOVQ b_base+24(FP), BX
MOVQ a_len+8(FP), DX
Contributor

Is this hard-coding the stack offsets based on the ABI of a slice? If so, it'll break if the compiler starts using SROA for slices.

Instead, you could use https://pkg.go.dev/unsafe#SliceData to get the underlying pointer in a stable way. Then this function would take in two pointers and the one length as the arguments, instead of hard-coding stack offsets here.

Otherwise, at least add a comment describing where these hard-coded constants come from.
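For illustration, that suggestion would look something like this (dotRaw is a hypothetical assembly stub; unsafe.SliceData exists since Go 1.20):

import "unsafe"

// dotRaw takes two raw pointers and one length; implemented in assembly.
func dotRaw(a, b *int8, n int) int64

func Dot(a, b []int8) int64 {
	if len(a) != len(b) {
		panic("mismatched vector lengths")
	}
	// Pass stable raw pointers plus a single length instead of slice headers.
	return dotRaw(unsafe.SliceData(a), unsafe.SliceData(b), len(a))
}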

Member Author

My understanding was that, unless I opt into ABIInternal (or any future stable ABI), I can depend on the current (stack-based) ABI to be stable.

Of note, the a_base notation is a mnemonic that is checked by go vet. So if the field offset does not line up with the FP offset I specified there, go vet will complain.

I'll add some comments describing the offsets though 👍
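Those comments would look roughly like this (the 0, 8, 24, and 48 offsets appear in the diff; the 16, 32, and 40 slots are inferred from Go's slice header layout of pointer/len/cap at 8 bytes each):

	// Stack layout for func dot(a, b []int8) (stack-based ABI):
	//   a_base+0(FP)   a_len+8(FP)    a_cap+16(FP)
	//   b_base+24(FP)  b_len+32(FP)   b_cap+40(FP)
	//   ret+48(FP)     return value
	MOVQ a_base+0(FP), AX  // AX = &a[0]
	MOVQ b_base+24(FP), BX // BX = &b[0]
	MOVQ a_len+8(FP), DX   // DX = len(a)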

Comment on lines +19 to +20
VPMOVSXBW (AX), Y1
VPMOVSXBW (BX), Y2
Contributor

Could you add a link to the calling convention where it's described whether these registers are preserved or clobbered across a call? It seems like this code is assuming that all the registers it is using are caller-preserve (i.e. OK to clobber).

Member Author

See "Clobber sets" here. Based on my read of it, I do not need to worry about callee-saved registers

	JMP tailloop

end:
	MOVQ R8, ret+48(FP)
Contributor

Is there a more "stable"/"reliable" way to get this rather than hard-coding the stack offset? IIUC this is just writing the return value. It'll also break if Go starts returning small return values in registers.

I'm surprised this actually works; I thought Go started using a register-based calling convention recently... Maybe only for parameters, or for functions implemented in Go?

Member Author

All the examples I've seen hard-code the stack offset. I agree it's awkward and error-prone.

The register-based calling convention is only used for compiled Go source code unless you opt into it with the ABIInternal flag on the function definition. So, since I did not opt in, I am using the old stack-based calling convention.

Member Author

This will also be caught by go vet, though. ret is the implicit return variable name, so go vet will check that 48(FP) is, in fact, the target address for the return value.

Member Author

Actually, though: I thought go vet ran in CI. It does not. Apparently, I miscalculated the frame size. PR incoming.

Contributor

go vet should be running 🏃‍♀️ It runs as part of the nogo linters 🤔 I'll double-check the Bazel config to make sure.

@keegancsmith (Member) left a comment

neat.

Also, it's interesting that when testing on the GCE instance, we get a max 2x speedup

@jtibshirani is it possible that the node you are running on has a lot more CPUs than Kubernetes is configured to let you use? If so, we could be doing too much parallelism, or something else is confusing the measurements. I would make sure we are using automaxprocs: https://sourcegraph.com/search?q=context:global+r:%5Egithub%5C.com/sourcegraph/+maxprocs.Set&patternType=standard&sm=0&groupBy=repo
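For reference, wiring up automaxprocs is a one-liner at process start (a minimal sketch using go.uber.org/automaxprocs; not necessarily how Sourcegraph's services initialize it):

import (
	"log"

	"go.uber.org/automaxprocs/maxprocs"
)

func init() {
	// Align GOMAXPROCS with the container's CPU quota
	// instead of the host's core count.
	if _, err := maxprocs.Set(maxprocs.Logger(log.Printf)); err != nil {
		log.Printf("failed to set GOMAXPROCS: %v", err)
	}
}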

got := Dot(a, b)

if want != got {
	t.Fatalf("a: %#v\nb: %#v\ngot: %d\nwant: %d", a, b, got, want)
Member

Use t.Log here; otherwise you never return false (t.Fatalf stops the test before the property function can report failure).
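A sketch of the suggested pattern inside a testing/quick property (dotPortable stands in for the reference Go implementation):

err := quick.Check(func(a, b []int8) bool {
	// Trim to a common length so the vectors match.
	n := len(a)
	if len(b) < n {
		n = len(b)
	}
	a, b = a[:n], b[:n]
	want, got := dotPortable(a, b), Dot(a, b)
	if want != got {
		t.Logf("a: %#v\nb: %#v\ngot: %d\nwant: %d", a, b, got, want)
		return false // let quick.Check report the failing inputs
	}
	return true
}, nil)
if err != nil {
	t.Fatal(err)
}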

@camdencheek (Member Author) commented May 3, 2023

@keegancsmith, the benchmarks were not running in Kubernetes, so it's unlikely to be related to reserved CPUs.

For clarity, we're seeing different behaviors on the different machines I've tested on.

M1 scales linearly up to 8 cores, which is what we based our initial assumption of scaling on.

My home server (2014 Intel 12-core) scales linearly up to 16 cores without SIMD, but only up to 8 cores with SIMD. I expect it's starting to hit memory bandwidth and/or cache limits with the SIMD implementation (M1s have stupidly good memory bandwidth).

The GCE n2-standard-4 scales up to 2 cores without SIMD and 4 cores with SIMD, but 4 cores is only ~2x faster than 1 core. This is where the "2x" number is coming from. This problem should be very parallel-friendly, but we could be hitting cache effects.

I put together a spreadsheet with the numbers I'm working from. Note that these numbers aren't super rigorous and I haven't looked into this very closely. It's more of an observation that I thought was interesting and probably deserves a closer look to make sure we're not throwing more CPU at the problem than we can use.

@daxmc99 (Member) commented May 4, 2023

Something we could consider here: since we are using VPMOVSXBW, we are already leveraging the GOAMD64=v3 class of instructions with AVX.

It might be worth building another set of binaries that unlock these gains at the runtime level by setting GOAMD64=v3 at compile time. Maybe that's easier to do now that we have Bazel?
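For reference, opting a build into the x86-64-v3 feature set is a single environment variable at compile time (the package path here is a placeholder):

# Build binaries that may assume AVX/AVX2-era (x86-64-v3) CPUs
GOAMD64=v3 go build ./cmd/...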

@camdencheek (Member Author)

AFAICT, the runtime uses very few v3-specific features so far. Am I looking at the right thing?

@kiroma commented May 10, 2023

I'm comparing this hand-written assembly to what clang generates from C++ code, and shouldn't there be a vzeroupper as the last instruction of the function?

@camdencheek (Member Author)

@kiroma good catch! This was an oversight on my part. Fixed here.
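For context, the fix amounts to clearing the upper halves of the YMM registers before returning; a sketch against the epilogue shown above (the actual follow-up PR may differ):

end:
	MOVQ R8, ret+48(FP)
	// Zero the upper 128 bits of all YMM registers to avoid
	// AVX-SSE transition penalties in later SSE code.
	VZEROUPPER
	RET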

@camdencheek added the backported-to-5.0 label and removed the release-blocker and failed-backport-to-5.0 labels May 16, 2023