Embeddings: search with SIMD #51372
Conversation
```asm
// Sign-extend 16 bytes into 16 int16s
VPMOVSXBW (AX), Y1
VPMOVSXBW (BX), Y2
```
We do operations on int16s because there is (surprisingly) no instruction to multiply-add signed int8 vectors.
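The widening trick can be modeled in scalar Go. This sketch is for illustration only (it is not code from this PR): each byte is sign-extended to int16 before multiplying, mirroring `VPMOVSXBW` followed by a widening multiply-add, with adjacent products accumulated into an int32. int16 is wide enough because the largest-magnitude product of two int8s is (-128) × (-128) = 16384.

```go
package main

import "fmt"

// dotBlock is a scalar model of one 16-byte block of the AVX2 loop:
// sign-extend each int8 to int16, multiply pairwise (the product of two
// int8s always fits in an int16), and add adjacent products into an
// int32 accumulator.
func dotBlock(a, b []int8) int32 {
	var acc int32
	for i := 0; i+1 < len(a); i += 2 {
		p0 := int32(int16(a[i]) * int16(b[i]))
		p1 := int32(int16(a[i+1]) * int16(b[i+1]))
		acc += p0 + p1
	}
	return acc
}

func main() {
	a := []int8{-128, 127, 3, -4}
	b := []int8{-128, 127, -5, 6}
	fmt.Println(dotBlock(a, b)) // 16384 + 16129 - 15 - 24 = 32474
}
```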
```asm
// X0 is the low bits of Y0.
// Extract the high bits into X1, fold in half, add, repeat.
VEXTRACTI128 $1, Y0, X1
VPADDD X0, X1, X0

VPSRLDQ $8, X0, X1
VPADDD X0, X1, X0

VPSRLDQ $4, X0, X1
VPADDD X0, X1, X0
```
This section sums the 8 32-bit ints in Y0 by repeatedly folding the vector in half and adding vertically. We are left with the sum in the rightmost position of X0.
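The fold-and-add reduction can be sketched in plain Go (illustrative only; lane orientation differs from the assembly, which leaves the sum in the rightmost lane, while this model uses index 0):

```go
package main

import "fmt"

// reduce sums 8 int32 lanes the way the assembly does: fold the vector
// in half and add the halves vertically, then repeat on ever-smaller
// halves (widths 4, 2, 1), leaving the total in a single lane.
func reduce(v [8]int32) int32 {
	for width := 4; width >= 1; width /= 2 {
		for i := 0; i < width; i++ {
			v[i] += v[i+width]
		}
	}
	return v[0]
}

func main() {
	fmt.Println(reduce([8]int32{1, 2, 3, 4, 5, 6, 7, 8})) // 36
}
```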
```asm
// In tailloop, we add to the dot product one at a time
tailloop:
	CMPQ DX, $0
	JE end

	// Load values from the input slices
	MOVBQSX (AX), R9
	MOVBQSX (BX), R10

	// Multiply and accumulate
	IMULQ R9, R10
	ADDQ R10, R8

	INCQ AX
	INCQ BX
	DECQ DX
	JMP tailloop
```
In case our input is not a multiple of 16 (which it will be for OpenAI embeddings), this handles the remainder.
```
goos: linux
goarch: amd64
pkg: github.com/sourcegraph/sourcegraph/enterprise/internal/embeddings
cpu: Intel(R) Xeon(R) CPU E5-2643 v3 @ 3.40GHz
                                   │ /tmp/before.txt │         /tmp/after.txt          │
                                   │     sec/op      │    sec/op     vs base           │
SimilaritySearch/numWorkers=1-24       1817.3m ± 48%   286.2m ± 67%  -84.25% (p=0.000 n=10)
SimilaritySearch/numWorkers=2-24        845.2m ± 31%   150.9m ± 21%  -82.15% (p=0.000 n=10)
SimilaritySearch/numWorkers=4-24        593.8m ± 21%   107.0m ± 15%  -81.99% (p=0.000 n=10)
SimilaritySearch/numWorkers=8-24       302.67m ± 19%   77.49m ± 14%  -74.40% (p=0.000 n=10)
SimilaritySearch/numWorkers=16-24      173.93m ± 12%   83.05m ±  6%  -52.25% (p=0.000 n=10)
geomean                                 544.9m         124.3m        -77.18%
```
```go
func CosineSimilarity(row []int8, query []int8) int32 {
	similarity := int32(0)

	count := len(row)
	if count > len(query) {
		// Do this ahead of time so the compiler doesn't need to bounds check
		// every time we index into query.
		panic("mismatched vector lengths")
	}

	i := 0
	for ; i+3 < count; i += 4 {
		m0 := int32(row[i]) * int32(query[i])
		m1 := int32(row[i+1]) * int32(query[i+1])
		m2 := int32(row[i+2]) * int32(query[i+2])
		m3 := int32(row[i+3]) * int32(query[i+3])
		similarity += (m0 + m1 + m2 + m3)
	}

	for ; i < count; i++ {
		similarity += int32(row[i]) * int32(query[i])
	}

	return similarity
}

func CosineSimilarityFloat32(row []float32, query []float32) float32 {
	similarity := float32(0)

	count := len(row)
	if count > len(query) {
		// Do this ahead of time so the compiler doesn't need to bounds check
		// every time we index into query.
		panic("mismatched vector lengths")
	}

	i := 0
	for ; i+3 < count; i += 4 {
		m0 := row[i] * query[i]
		m1 := row[i+1] * query[i+1]
		m2 := row[i+2] * query[i+2]
		m3 := row[i+3] * query[i+3]
		similarity += (m0 + m1 + m2 + m3)
	}

	for ; i < count; i++ {
		similarity += row[i] * query[i]
	}

	return similarity
}
```
I moved these into `dot.go` and renamed them to `Dot*`. The dot product is only equivalent to cosine similarity if the vectors are normalized, so I think the rename is justified in case we ever use non-normalized vectors.
This makes sense to me, but I'm not up-to-speed on Go assembly! Maybe there's someone more knowledgeable who could give a timely review?
Also, it's interesting that when testing on the GCE instance, we get a max 2x speedup. This is different from what we observed when testing locally, where the speedup scales with the number of workers: #51372. It means we should limit the request parallelism to something more conservative like 2 threads, rather than the number of processors as we do now. This would be for a follow-up, as it's separate from this PR.
I went through the code (on mobile 😄 ) and it LGTM. It's been two decades since I last did assembly and, of course, I have no idea of these new instructions, but it looks reasonable, and the test coverage is convincing. Overall, wow, impressive work!
Looks good to me from the config and search side. I will dig through the assembly at another time :)
The backport to `5.0` failed.
To backport manually, run these commands in your terminal:

```shell
# Fetch latest updates from GitHub
git fetch
# Create a new working tree
git worktree add .worktrees/backport-5.0 5.0
# Navigate to the new working tree
cd .worktrees/backport-5.0
# Create a new branch
git switch --create backport-51372-to-5.0
# Cherry-pick the merged commit of this pull request and resolve the conflicts
git cherry-pick -x --mainline 1 17a8ec942c1eaca26ae62191460e7ff9bd6285aa
# Push it to GitHub
git push --set-upstream origin backport-51372-to-5.0
# Go back to the original working tree
cd ../..
# Delete the working tree
git worktree remove .worktrees/backport-5.0
```

Then, create a pull request.
This implements a hand-written assembly version of the int8 dot product that takes advantage of AVX2 SIMD instructions. This speeds up our embeddings searches by roughly 10x on modern x86_64 machines. (cherry picked from commit 17a8ec9)
```asm
#include "textflag.h"
```
This file is missing from this patch; maybe remove this `#include`?
What do you mean "missing from this patch"? The `#include "textflag.h"` is needed to define the NOSPLIT symbol.
If you mean "you didn't commit a textflag.h file", it's a Go compiler builtin.
```asm
	SUBQ $16, DX
	JMP blockloop

reduce:
```
Given that `reduce:` and `tailloop:` run only once (or are very small), it would make things a bit simpler to remove the assembly for them and write them in Go. (Assuming it's possible to access Y0 from Go; otherwise it would make sense to leave the `reduce` code in assembly.)
AFAIK, it is not possible to access Y0 from Go without assembly
```asm
MOVQ a_base+0(FP), AX
MOVQ b_base+24(FP), BX
MOVQ a_len+8(FP), DX
```
Is this hard-coding the stack offsets based on the ABI of a slice? If so, it'll break if the compiler starts using SROA for slices.
Instead, you could use https://pkg.go.dev/unsafe#SliceData to get the underlying pointer in a stable way. Then this function would take in two pointers and the one length as the arguments, instead of hard-coding stack offsets here.
Otherwise, at least add a comment describing where these hard-coded constants come from.
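A rough Go sketch of that suggestion (hypothetical names; `dotAsm` here is a pure-Go stand-in for the assembly routine, which would really be declared in a `.s` file and take raw pointers):

```go
package main

import (
	"fmt"
	"unsafe"
)

// dot resolves the slice headers in Go with unsafe.SliceData (Go 1.20+)
// and hands the low-level routine plain pointers plus a length, instead
// of hard-coding the slice ABI layout in the assembly itself.
func dot(a, b []int8) int32 {
	if len(a) != len(b) {
		panic("mismatched vector lengths")
	}
	return dotAsm(unsafe.SliceData(a), unsafe.SliceData(b), len(a))
}

// dotAsm is a stand-in for the assembly implementation, written in Go so
// this sketch is runnable.
func dotAsm(pa, pb *int8, n int) int32 {
	as := unsafe.Slice(pa, n)
	bs := unsafe.Slice(pb, n)
	var s int32
	for i := 0; i < n; i++ {
		s += int32(as[i]) * int32(bs[i])
	}
	return s
}

func main() {
	fmt.Println(dot([]int8{1, 2, 3}, []int8{4, 5, 6})) // 4 + 10 + 18 = 32
}
```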
My understanding was that, unless I opt into `ABIInternal` (or any future stable ABI), I can depend on the current (stack-based) ABI to be stable.

Of note, the `a_base` notation is a mnemonic that is checked by go vet. So if the field offset does not line up with the FP offset I specified there, go vet will complain.

I'll add some comments describing the offsets though 👍
```asm
VPMOVSXBW (AX), Y1
VPMOVSXBW (BX), Y2
```
Could you add a link to the calling convention where it's described whether these registers are preserved or clobbered across a call? It seems like this code is assuming that all the registers it is using are caller-preserve (i.e. OK to clobber).
See "Clobber sets" here. Based on my read of it, I do not need to worry about callee-saved registers
```asm
	JMP tailloop

end:
	MOVQ R8, ret+48(FP)
```
Is there a more "stable"/"reliable" way to get this rather than hard-coding the stack offset? IIUC this is just writing the return value. It'll also break if Go starts returning small return values in registers.
I'm surprised this is actually working, I thought Go started using a register based calling convention recently... Maybe only for parameters or for functions implemented in Go?
All the examples I've seen hard-code the stack offset. I agree it's awkward and error-prone.

The register-based calling convention is only used for compiled Go source code unless you opt into it with the `ABIInternal` flag on the function definition. So, since I did not opt into it, I am using the old stack-based calling convention.
This will also be caught by go vet, though. `ret` is the implicit return variable name, so go vet will check that `48(FP)` is, in fact, the target address for the return value.
Actually, though, I thought go vet runs in CI. It does not. Apparently, I miscalculated the frame size. PR incoming.
go vet should be running 🏃♀️ It runs as part of the nogo linters 🤔 I'll double check the BAZEL config to make sure
neat.
> Also, it's interesting that when testing on the GCE instance, we get a max 2x speedup
@jtibshirani is it possible that the node you are running on has a lot more CPUs than kubernetes is configured to let you use? So we end up doing too much parallelism / something else is confusing in the measurements. I would make sure we are using automaxprocs: https://sourcegraph.com/search?q=context:global+r:%5Egithub%5C.com/sourcegraph/+maxprocs.Set&patternType=standard&sm=0&groupBy=repo
```go
got := Dot(a, b)

if want != got {
	t.Fatalf("a: %#v\nb: %#v\ngot: %d\nwant: %d", a, b, got, want)
```
This should be `t.Log`, otherwise you never return false.
@keegancsmith, the benchmarks were not running in Kubernetes, so it's unlikely to be related to reserved CPUs.

For clarity, we're seeing different behaviors on the different machines I've tested on. M1 scales linearly up to 8 cores, which is what we based our initial assumption of scaling on. My home server (2014 Intel 12 core) scales linearly up to 16 cores without SIMD, but only up to 8 cores with SIMD. I expect it's starting to hit memory bandwidth and/or cache limits with the SIMD implementation (M1s have stupidly good memory bandwidth). The GCE n2-standard-4 scales up to 2 cores without SIMD, and 4 cores with SIMD, but the 4 cores is only ~2x faster than 1 core. This is where the "2x" number is coming from.

This problem should be very parallel-friendly, but we could be hitting cache effects. I put together a spreadsheet with the numbers I'm working from. Note that these numbers aren't super rigorous and I haven't looked into this very closely. It's more of an observation that I thought was interesting and probably deserves a little bit of looking to make sure we're not throwing more CPU at the problem than we can use.
Something we could consider here, since we are using… It might be worth it to consider building another set of binaries that unlock these gains at the runtime level by setting…
AFAICT, the runtime uses very few v3-specific features so far. Am I looking at the right thing?
I'm comparing this hand-written assembly to what clang generates from C++ code, and shouldn't there be a…
This implements a hand-written assembly version of the int8 dot product that takes advantage of AVX2 SIMD instructions. This speeds up our embeddings searches by roughly 10x on modern x86_64 machines.
Test plan
Added quickchecks and fuzz tests to compare output with the go version.
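As a sketch of what such a quickcheck looks like (illustrative only; the real tests compare the assembly `Dot` against the Go version, while here an unrolled Go loop stands in for it):

```go
package main

import (
	"fmt"
	"testing/quick"
)

// dotRef is the straightforward reference implementation.
func dotRef(a, b []int8) int32 {
	var s int32
	for i := range a {
		s += int32(a[i]) * int32(b[i])
	}
	return s
}

// dotUnrolled stands in for the optimized implementation under test.
func dotUnrolled(a, b []int8) int32 {
	var s int32
	i := 0
	for ; i+3 < len(a); i += 4 {
		s += int32(a[i])*int32(b[i]) + int32(a[i+1])*int32(b[i+1]) +
			int32(a[i+2])*int32(b[i+2]) + int32(a[i+3])*int32(b[i+3])
	}
	for ; i < len(a); i++ {
		s += int32(a[i]) * int32(b[i])
	}
	return s
}

func main() {
	// Property: both implementations agree on random inputs of any length.
	prop := func(a, b []int8) bool {
		n := len(a)
		if len(b) < n {
			n = len(b)
		}
		return dotRef(a[:n], b[:n]) == dotUnrolled(a[:n], b[:n])
	}
	if err := quick.Check(prop, nil); err != nil {
		fmt.Println("mismatch:", err)
	} else {
		fmt.Println("ok")
	}
}
```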
The following benchmark is for an `n2-standard-4` GCE instance, which is a very standard machine type. For the single core benchmark, we can search about 6 million embeddings per second, which is the equivalent of a 6GB monorepo.

This is low-risk to merge because it is disabled by default.