Skip to content

Commit

Permalink
Embeddings: simplify batching (#51014)
Browse files Browse the repository at this point in the history
This is just some simplification of the batching code so the control
flow is more clear. I'm going to be adding stats collection to this
function, which will likely muddy things even more, so I wanted to clean
this up first.
  • Loading branch information
camdencheek committed Apr 24, 2023
1 parent 7da3dd0 commit afc1b03
Showing 1 changed file with 40 additions and 50 deletions.
90 changes: 40 additions & 50 deletions enterprise/internal/embeddings/embed/embed.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,6 @@ func EmbedRepo(
return &embeddings.RepoEmbeddingIndex{RepoName: repoName, Revision: revision, CodeIndex: codeIndex, TextIndex: textIndex}, nil
}

// createEmptyEmbeddingIndex builds an EmbeddingIndex that contains no rows but
// carries the given column dimension. The Embeddings and RowMetadata slices are
// non-nil empty slices (not nil), so callers and JSON encoding see "[]" rather
// than "null".
func createEmptyEmbeddingIndex(columnDimension int) embeddings.EmbeddingIndex {
	var index embeddings.EmbeddingIndex
	index.Embeddings = []int8{}
	index.RowMetadata = []embeddings.RepoEmbeddingRowMetadata{}
	index.ColumnDimension = columnDimension
	return index
}

// embedFiles embeds file contents from the given file names. Since embedding models can only handle a certain amount of text (tokens) we cannot embed
// entire files. So we split the file contents into chunks and get embeddings for the chunks in batches. Functions returns an EmbeddingIndex containing
// the embeddings and metadata about the chunks the embeddings correspond to.
Expand All @@ -89,11 +81,7 @@ func embedFiles(
) (embeddings.EmbeddingIndex, error) {
dimensions, err := client.GetDimensions()
if err != nil {
return createEmptyEmbeddingIndex(dimensions), err
}

if len(fileNames) == 0 {
return createEmptyEmbeddingIndex(dimensions), nil
return embeddings.EmbeddingIndex{}, err
}

index := embeddings.EmbeddingIndex{
Expand All @@ -103,34 +91,42 @@ func embedFiles(
Ranks: make([]float32, 0, len(fileNames)),
}

// addEmbeddableChunks batches embeddable chunks, gets embeddings for the batches, and appends them to the index above.
addEmbeddableChunks := func(embeddableChunks []split.EmbeddableChunk, batchSize int) error {
// The embeddings API operates with batches up to a certain size, so we can't send all embeddable chunks for embedding at once.
// We batch them according to `batchSize`, and embed one by one.
for i := 0; i < len(embeddableChunks); i += batchSize {
end := min(len(embeddableChunks), i+batchSize)
batch := embeddableChunks[i:end]
batchChunks := make([]string, len(batch))
for idx, chunk := range batch {
batchChunks[idx] = chunk.Content
index.RowMetadata = append(index.RowMetadata, embeddings.RepoEmbeddingRowMetadata{FileName: chunk.FileName, StartLine: chunk.StartLine, EndLine: chunk.EndLine})

// Unknown documents have rank 0. Zoekt is a bit smarter about this, assigning 0
// to "unimportant" files and the average for unknown files. We should probably
// add this here, too.
index.Ranks = append(index.Ranks, float32(repoPathRanks.Paths[chunk.FileName]))
}
var batch []split.EmbeddableChunk

batchEmbeddings, err := client.GetEmbeddingsWithRetries(ctx, batchChunks, GET_EMBEDDINGS_MAX_RETRIES)
if err != nil {
return errors.Wrap(err, "error while getting embeddings")
}
index.Embeddings = append(index.Embeddings, embeddings.Quantize(batchEmbeddings)...)
flush := func() error {
if len(batch) == 0 {
return nil
}

batchChunks := make([]string, len(batch))
for idx, chunk := range batch {
batchChunks[idx] = chunk.Content
index.RowMetadata = append(index.RowMetadata, embeddings.RepoEmbeddingRowMetadata{FileName: chunk.FileName, StartLine: chunk.StartLine, EndLine: chunk.EndLine})

// Unknown documents have rank 0. Zoekt is a bit smarter about this, assigning 0
// to "unimportant" files and the average for unknown files. We should probably
// add this here, too.
index.Ranks = append(index.Ranks, float32(repoPathRanks.Paths[chunk.FileName]))
}

batchEmbeddings, err := client.GetEmbeddingsWithRetries(ctx, batchChunks, GET_EMBEDDINGS_MAX_RETRIES)
if err != nil {
return errors.Wrap(err, "error while getting embeddings")
}
index.Embeddings = append(index.Embeddings, embeddings.Quantize(batchEmbeddings)...)
batch = batch[:0] // reset batch
return nil
}

addToBatch := func(chunk split.EmbeddableChunk) error {
batch = append(batch, chunk)
if len(batch) >= EMBEDDING_BATCH_SIZE {
// Flush if we've hit batch size
return flush()
}
return nil
}

embeddableChunks := []split.EmbeddableChunk{}
for _, fileName := range fileNames {
// This is a fail-safe measure to prevent producing an extremely large index for large repositories.
if len(index.RowMetadata) > maxEmbeddingVectors {
Expand All @@ -139,29 +135,23 @@ func embedFiles(

contentBytes, err := readFile(ctx, fileName)
if err != nil {
return createEmptyEmbeddingIndex(dimensions), errors.Wrap(err, "error while reading a file")
return embeddings.EmbeddingIndex{}, errors.Wrap(err, "error while reading a file")
}

if embeddable, _ := isEmbeddableFileContent(contentBytes); !embeddable {
continue
}
content := string(contentBytes)

embeddableChunks = append(embeddableChunks, split.SplitIntoEmbeddableChunks(content, fileName, splitOptions)...)

if len(embeddableChunks) > EMBEDDING_BATCHES*EMBEDDING_BATCH_SIZE {
err := addEmbeddableChunks(embeddableChunks, EMBEDDING_BATCH_SIZE)
if err != nil {
return createEmptyEmbeddingIndex(dimensions), err
for _, chunk := range split.SplitIntoEmbeddableChunks(string(contentBytes), fileName, splitOptions) {
if err := addToBatch(chunk); err != nil {
return embeddings.EmbeddingIndex{}, err
}
embeddableChunks = []split.EmbeddableChunk{}
}
}

if len(embeddableChunks) > 0 {
err := addEmbeddableChunks(embeddableChunks, EMBEDDING_BATCH_SIZE)
if err != nil {
return createEmptyEmbeddingIndex(dimensions), err
}
// Always do a final flush
if err := flush(); err != nil {
return embeddings.EmbeddingIndex{}, err
}

return index, nil
Expand Down

0 comments on commit afc1b03

Please sign in to comment.