From afc1b036690662bf2c013ddc2fbd478b252e0d6d Mon Sep 17 00:00:00 2001
From: Camden Cheek
Date: Mon, 24 Apr 2023 14:24:51 -0600
Subject: [PATCH] Embeddings: simplify batching (#51014)

This is just some simplification of the batching code so the control flow
is more clear. I'm going to be adding stats collection to this function,
which will likely muddy things even more, so I wanted to clean this up
first.
---
 enterprise/internal/embeddings/embed/embed.go | 90 +++++++++----------
 1 file changed, 40 insertions(+), 50 deletions(-)

diff --git a/enterprise/internal/embeddings/embed/embed.go b/enterprise/internal/embeddings/embed/embed.go
index bc2454a5c18f..3b149317a911 100644
--- a/enterprise/internal/embeddings/embed/embed.go
+++ b/enterprise/internal/embeddings/embed/embed.go
@@ -67,14 +67,6 @@ func EmbedRepo(
 	return &embeddings.RepoEmbeddingIndex{RepoName: repoName, Revision: revision, CodeIndex: codeIndex, TextIndex: textIndex}, nil
 }
 
-func createEmptyEmbeddingIndex(columnDimension int) embeddings.EmbeddingIndex {
-	return embeddings.EmbeddingIndex{
-		Embeddings:      []int8{},
-		RowMetadata:     []embeddings.RepoEmbeddingRowMetadata{},
-		ColumnDimension: columnDimension,
-	}
-}
-
 // embedFiles embeds file contents from the given file names. Since embedding models can only handle a certain amount of text (tokens) we cannot embed
 // entire files. So we split the file contents into chunks and get embeddings for the chunks in batches. Functions returns an EmbeddingIndex containing
 // the embeddings and metadata about the chunks the embeddings correspond to.
@@ -89,11 +81,7 @@ func embedFiles(
 ) (embeddings.EmbeddingIndex, error) {
 	dimensions, err := client.GetDimensions()
 	if err != nil {
-		return createEmptyEmbeddingIndex(dimensions), err
-	}
-
-	if len(fileNames) == 0 {
-		return createEmptyEmbeddingIndex(dimensions), nil
+		return embeddings.EmbeddingIndex{}, err
 	}
 
 	index := embeddings.EmbeddingIndex{
@@ -103,34 +91,42 @@ func embedFiles(
 		Ranks:           make([]float32, 0, len(fileNames)),
 	}
 
-	// addEmbeddableChunks batches embeddable chunks, gets embeddings for the batches, and appends them to the index above.
-	addEmbeddableChunks := func(embeddableChunks []split.EmbeddableChunk, batchSize int) error {
-		// The embeddings API operates with batches up to a certain size, so we can't send all embeddable chunks for embedding at once.
-		// We batch them according to `batchSize`, and embed one by one.
-		for i := 0; i < len(embeddableChunks); i += batchSize {
-			end := min(len(embeddableChunks), i+batchSize)
-			batch := embeddableChunks[i:end]
-			batchChunks := make([]string, len(batch))
-			for idx, chunk := range batch {
-				batchChunks[idx] = chunk.Content
-				index.RowMetadata = append(index.RowMetadata, embeddings.RepoEmbeddingRowMetadata{FileName: chunk.FileName, StartLine: chunk.StartLine, EndLine: chunk.EndLine})
-
-				// Unknown documents have rank 0. Zoekt is a bit smarter about this, assigning 0
-				// to "unimportant" files and the average for unknown files. We should probably
-				// add this here, too.
-				index.Ranks = append(index.Ranks, float32(repoPathRanks.Paths[chunk.FileName]))
-			}
+	var batch []split.EmbeddableChunk
 
-			batchEmbeddings, err := client.GetEmbeddingsWithRetries(ctx, batchChunks, GET_EMBEDDINGS_MAX_RETRIES)
-			if err != nil {
-				return errors.Wrap(err, "error while getting embeddings")
-			}
-			index.Embeddings = append(index.Embeddings, embeddings.Quantize(batchEmbeddings)...)
+	flush := func() error {
+		if len(batch) == 0 {
+			return nil
+		}
+
+		batchChunks := make([]string, len(batch))
+		for idx, chunk := range batch {
+			batchChunks[idx] = chunk.Content
+			index.RowMetadata = append(index.RowMetadata, embeddings.RepoEmbeddingRowMetadata{FileName: chunk.FileName, StartLine: chunk.StartLine, EndLine: chunk.EndLine})
+
+			// Unknown documents have rank 0. Zoekt is a bit smarter about this, assigning 0
+			// to "unimportant" files and the average for unknown files. We should probably
+			// add this here, too.
+			index.Ranks = append(index.Ranks, float32(repoPathRanks.Paths[chunk.FileName]))
+		}
+
+		batchEmbeddings, err := client.GetEmbeddingsWithRetries(ctx, batchChunks, GET_EMBEDDINGS_MAX_RETRIES)
+		if err != nil {
+			return errors.Wrap(err, "error while getting embeddings")
+		}
+		index.Embeddings = append(index.Embeddings, embeddings.Quantize(batchEmbeddings)...)
+		batch = batch[:0] // reset batch
+		return nil
+	}
+
+	addToBatch := func(chunk split.EmbeddableChunk) error {
+		batch = append(batch, chunk)
+		if len(batch) >= EMBEDDING_BATCH_SIZE {
+			// Flush if we've hit batch size
+			return flush()
 		}
 		return nil
 	}
 
-	embeddableChunks := []split.EmbeddableChunk{}
 	for _, fileName := range fileNames {
 		// This is a fail-safe measure to prevent producing an extremely large index for large repositories.
 		if len(index.RowMetadata) > maxEmbeddingVectors {
@@ -139,29 +135,23 @@ func embedFiles(
 
 		contentBytes, err := readFile(ctx, fileName)
 		if err != nil {
-			return createEmptyEmbeddingIndex(dimensions), errors.Wrap(err, "error while reading a file")
+			return embeddings.EmbeddingIndex{}, errors.Wrap(err, "error while reading a file")
 		}
+
 		if embeddable, _ := isEmbeddableFileContent(contentBytes); !embeddable {
 			continue
 		}
-		content := string(contentBytes)
-
-		embeddableChunks = append(embeddableChunks, split.SplitIntoEmbeddableChunks(content, fileName, splitOptions)...)
 
-		if len(embeddableChunks) > EMBEDDING_BATCHES*EMBEDDING_BATCH_SIZE {
-			err := addEmbeddableChunks(embeddableChunks, EMBEDDING_BATCH_SIZE)
-			if err != nil {
-				return createEmptyEmbeddingIndex(dimensions), err
+		for _, chunk := range split.SplitIntoEmbeddableChunks(string(contentBytes), fileName, splitOptions) {
+			if err := addToBatch(chunk); err != nil {
+				return embeddings.EmbeddingIndex{}, err
 			}
-			embeddableChunks = []split.EmbeddableChunk{}
 		}
 	}
 
-	if len(embeddableChunks) > 0 {
-		err := addEmbeddableChunks(embeddableChunks, EMBEDDING_BATCH_SIZE)
-		if err != nil {
-			return createEmptyEmbeddingIndex(dimensions), err
-		}
+	// Always do a final flush
+	if err := flush(); err != nil {
+		return embeddings.EmbeddingIndex{}, err
 	}
 
 	return index, nil
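
Note (not part of the patch): below is a minimal, self-contained Go sketch of the flush/addToBatch control flow this change introduces. The Chunk type, fakeEmbed helper, and batchSize constant are illustrative stand-ins, not the Sourcegraph embeddings client API; metadata collection, quantization, and error handling are omitted so only the batching shape is visible.

// Standalone sketch of the batch-and-flush pattern from this patch.
// All names here are hypothetical stubs for illustration only.
package main

import "fmt"

// Chunk stands in for split.EmbeddableChunk.
type Chunk struct {
	FileName string
	Content  string
}

const batchSize = 3 // stands in for EMBEDDING_BATCH_SIZE

// fakeEmbed stands in for the embeddings client: it returns one vector per input text.
func fakeEmbed(texts []string) [][]float32 {
	vecs := make([][]float32, len(texts))
	for i := range texts {
		vecs[i] = []float32{float32(len(texts[i]))}
	}
	return vecs
}

func main() {
	chunks := []Chunk{
		{"a.go", "func A() {}"},
		{"a.go", "func B() {}"},
		{"b.go", "func C() {}"},
		{"b.go", "func D() {}"},
	}

	var (
		batch      []Chunk
		embeddings [][]float32
	)

	// flush embeds whatever is currently buffered and resets the batch.
	flush := func() {
		if len(batch) == 0 {
			return
		}
		texts := make([]string, len(batch))
		for i, c := range batch {
			texts[i] = c.Content
		}
		embeddings = append(embeddings, fakeEmbed(texts)...)
		batch = batch[:0] // reset the batch, keeping its capacity
	}

	// addToBatch buffers one chunk and flushes once the batch is full.
	addToBatch := func(c Chunk) {
		batch = append(batch, c)
		if len(batch) >= batchSize {
			flush()
		}
	}

	for _, c := range chunks {
		addToBatch(c)
	}
	flush() // final flush for the trailing, partially filled batch

	fmt.Printf("embedded %d chunks\n", len(embeddings))
}

The point of the refactor, as the diff shows, is that all batching state lives in one place: addToBatch only appends and triggers flush when the batch is full, and a single final flush handles the trailing partial batch, replacing the old addEmbeddableChunks helper that accumulated a large slice and re-batched it.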