diff --git a/enterprise/internal/embeddings/embed/embed.go b/enterprise/internal/embeddings/embed/embed.go
index bc2454a5c18f..3b149317a911 100644
--- a/enterprise/internal/embeddings/embed/embed.go
+++ b/enterprise/internal/embeddings/embed/embed.go
@@ -67,14 +67,6 @@ func EmbedRepo(
 	return &embeddings.RepoEmbeddingIndex{RepoName: repoName, Revision: revision, CodeIndex: codeIndex, TextIndex: textIndex}, nil
 }
 
-func createEmptyEmbeddingIndex(columnDimension int) embeddings.EmbeddingIndex {
-	return embeddings.EmbeddingIndex{
-		Embeddings:      []int8{},
-		RowMetadata:     []embeddings.RepoEmbeddingRowMetadata{},
-		ColumnDimension: columnDimension,
-	}
-}
-
 // embedFiles embeds file contents from the given file names. Since embedding models can only handle a certain amount of text (tokens) we cannot embed
 // entire files. So we split the file contents into chunks and get embeddings for the chunks in batches. Functions returns an EmbeddingIndex containing
 // the embeddings and metadata about the chunks the embeddings correspond to.
@@ -89,11 +81,7 @@ func embedFiles(
 ) (embeddings.EmbeddingIndex, error) {
 	dimensions, err := client.GetDimensions()
 	if err != nil {
-		return createEmptyEmbeddingIndex(dimensions), err
-	}
-
-	if len(fileNames) == 0 {
-		return createEmptyEmbeddingIndex(dimensions), nil
+		return embeddings.EmbeddingIndex{}, err
 	}
 
 	index := embeddings.EmbeddingIndex{
@@ -103,34 +91,42 @@ func embedFiles(
 		Ranks:           make([]float32, 0, len(fileNames)),
 	}
 
-	// addEmbeddableChunks batches embeddable chunks, gets embeddings for the batches, and appends them to the index above.
-	addEmbeddableChunks := func(embeddableChunks []split.EmbeddableChunk, batchSize int) error {
-		// The embeddings API operates with batches up to a certain size, so we can't send all embeddable chunks for embedding at once.
-		// We batch them according to `batchSize`, and embed one by one.
-		for i := 0; i < len(embeddableChunks); i += batchSize {
-			end := min(len(embeddableChunks), i+batchSize)
-			batch := embeddableChunks[i:end]
-			batchChunks := make([]string, len(batch))
-			for idx, chunk := range batch {
-				batchChunks[idx] = chunk.Content
-				index.RowMetadata = append(index.RowMetadata, embeddings.RepoEmbeddingRowMetadata{FileName: chunk.FileName, StartLine: chunk.StartLine, EndLine: chunk.EndLine})
-
-				// Unknown documents have rank 0. Zoekt is a bit smarter about this, assigning 0
-				// to "unimportant" files and the average for unknown files. We should probably
-				// add this here, too.
-				index.Ranks = append(index.Ranks, float32(repoPathRanks.Paths[chunk.FileName]))
-			}
+	var batch []split.EmbeddableChunk
 
-			batchEmbeddings, err := client.GetEmbeddingsWithRetries(ctx, batchChunks, GET_EMBEDDINGS_MAX_RETRIES)
-			if err != nil {
-				return errors.Wrap(err, "error while getting embeddings")
-			}
-			index.Embeddings = append(index.Embeddings, embeddings.Quantize(batchEmbeddings)...)
+	flush := func() error {
+		if len(batch) == 0 {
+			return nil
+		}
+
+		batchChunks := make([]string, len(batch))
+		for idx, chunk := range batch {
+			batchChunks[idx] = chunk.Content
+			index.RowMetadata = append(index.RowMetadata, embeddings.RepoEmbeddingRowMetadata{FileName: chunk.FileName, StartLine: chunk.StartLine, EndLine: chunk.EndLine})
+
+			// Unknown documents have rank 0. Zoekt is a bit smarter about this, assigning 0
+			// to "unimportant" files and the average for unknown files. We should probably
+			// add this here, too.
+			index.Ranks = append(index.Ranks, float32(repoPathRanks.Paths[chunk.FileName]))
+		}
+
+		batchEmbeddings, err := client.GetEmbeddingsWithRetries(ctx, batchChunks, GET_EMBEDDINGS_MAX_RETRIES)
+		if err != nil {
+			return errors.Wrap(err, "error while getting embeddings")
+		}
+		index.Embeddings = append(index.Embeddings, embeddings.Quantize(batchEmbeddings)...)
+		batch = batch[:0] // reset batch
+		return nil
+	}
+
+	addToBatch := func(chunk split.EmbeddableChunk) error {
+		batch = append(batch, chunk)
+		if len(batch) >= EMBEDDING_BATCH_SIZE {
+			// Flush if we've hit batch size
+			return flush()
 		}
 		return nil
 	}
 
-	embeddableChunks := []split.EmbeddableChunk{}
 	for _, fileName := range fileNames {
 		// This is a fail-safe measure to prevent producing an extremely large index for large repositories.
 		if len(index.RowMetadata) > maxEmbeddingVectors {
@@ -139,29 +135,23 @@ func embedFiles(
 
 		contentBytes, err := readFile(ctx, fileName)
 		if err != nil {
-			return createEmptyEmbeddingIndex(dimensions), errors.Wrap(err, "error while reading a file")
+			return embeddings.EmbeddingIndex{}, errors.Wrap(err, "error while reading a file")
 		}
+
 		if embeddable, _ := isEmbeddableFileContent(contentBytes); !embeddable {
 			continue
 		}
-		content := string(contentBytes)
-
-		embeddableChunks = append(embeddableChunks, split.SplitIntoEmbeddableChunks(content, fileName, splitOptions)...)
 
-		if len(embeddableChunks) > EMBEDDING_BATCHES*EMBEDDING_BATCH_SIZE {
-			err := addEmbeddableChunks(embeddableChunks, EMBEDDING_BATCH_SIZE)
-			if err != nil {
-				return createEmptyEmbeddingIndex(dimensions), err
+		for _, chunk := range split.SplitIntoEmbeddableChunks(string(contentBytes), fileName, splitOptions) {
+			if err := addToBatch(chunk); err != nil {
+				return embeddings.EmbeddingIndex{}, err
			}
-			embeddableChunks = []split.EmbeddableChunk{}
 		}
 	}
 
-	if len(embeddableChunks) > 0 {
-		err := addEmbeddableChunks(embeddableChunks, EMBEDDING_BATCH_SIZE)
-		if err != nil {
-			return createEmptyEmbeddingIndex(dimensions), err
-		}
+	// Always do a final flush
+	if err := flush(); err != nil {
+		return embeddings.EmbeddingIndex{}, err
 	}
 
 	return index, nil
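
The heart of the diff is the accumulate-and-flush batching pattern introduced by the `flush` and `addToBatch` closures: callers push one chunk at a time, a full batch is sent to the embeddings client immediately, and an unconditional final `flush()` handles whatever partial batch remains. Below is a minimal, standalone Go sketch of that pattern under illustrative assumptions; `batchSize`, `process`, and the string items are placeholders and not identifiers from the Sourcegraph codebase.

```go
package main

import "fmt"

// batchSize stands in for EMBEDDING_BATCH_SIZE; the value here is only illustrative.
const batchSize = 3

func main() {
	var batch []string

	// process is a stand-in for the per-batch work (the embeddings API call in the diff).
	process := func(items []string) error {
		fmt.Println("processing batch:", items)
		return nil
	}

	// flush sends the accumulated batch and resets it. It is a no-op when the
	// batch is empty, which makes the unconditional final flush safe.
	flush := func() error {
		if len(batch) == 0 {
			return nil
		}
		if err := process(batch); err != nil {
			return err
		}
		batch = batch[:0] // keep the backing array, drop the contents
		return nil
	}

	// addToBatch accumulates one item and flushes as soon as the batch is full.
	addToBatch := func(item string) error {
		batch = append(batch, item)
		if len(batch) >= batchSize {
			return flush()
		}
		return nil
	}

	for _, item := range []string{"a", "b", "c", "d", "e"} {
		if err := addToBatch(item); err != nil {
			fmt.Println("error:", err)
			return
		}
	}

	// The final flush handles the trailing partial batch ("d", "e").
	if err := flush(); err != nil {
		fmt.Println("error:", err)
	}
}
```

Compared with the removed `addEmbeddableChunks` helper, which accumulated up to EMBEDDING_BATCHES*EMBEDDING_BATCH_SIZE chunks and then sliced them into batches after the fact, this formulation holds at most one batch in memory at a time and makes the trailing flush explicit.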