-
-
Notifications
You must be signed in to change notification settings - Fork 521
/
embedder.go
39 lines (32 loc) · 1.18 KB
/
embedder.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
package chroma
import (
"context"
chromatypes "github.com/amikos-tech/chroma-go/types"
"github.com/tmc/langchaingo/embeddings"
)
// chromaGoEmbedder adapts a langchaingo embeddings.Embedder so it can be used
// wherever chroma-go expects a chromatypes.EmbeddingFunction.
//
// The Embedder interface is embedded, but this type declares its own
// EmbedDocuments and EmbedQuery methods with chroma-go signatures. Those
// methods shadow the promoted ones, so the wrapped implementation is always
// reached explicitly through the e.Embedder field.
type chromaGoEmbedder struct {
	embeddings.Embedder
}

// Compile-time assertion that the value type (not just *chromaGoEmbedder)
// satisfies chroma-go's EmbeddingFunction interface.
var _ chromatypes.EmbeddingFunction = chromaGoEmbedder{}
// EmbedDocuments embeds texts with the wrapped langchaingo Embedder and
// converts each resulting []float32 vector into a chroma-go Embedding.
//
// The returned slice holds one entry per vector produced by the wrapped
// Embedder, in the same order. If the wrapped Embedder fails, its error is
// returned unchanged together with a nil slice.
//
// NOTE(review): the vector count is not checked against len(texts); this
// relies on the wrapped Embedder returning exactly one vector per input.
func (e chromaGoEmbedder) EmbedDocuments(ctx context.Context, texts []string) ([]*chromatypes.Embedding, error) {
	vectors, err := e.Embedder.EmbedDocuments(ctx, texts)
	if err != nil {
		return nil, err
	}
	out := make([]*chromatypes.Embedding, len(vectors))
	for i, vec := range vectors {
		out[i] = chromatypes.NewEmbeddingFromFloat32(vec)
	}
	return out, nil
}
// EmbedQuery embeds a single query string with the wrapped langchaingo
// Embedder and converts the resulting []float32 vector into a chroma-go
// Embedding. If the wrapped Embedder fails, its error is returned unchanged
// together with a nil Embedding.
func (e chromaGoEmbedder) EmbedQuery(ctx context.Context, text string) (*chromatypes.Embedding, error) {
	vec, err := e.Embedder.EmbedQuery(ctx, text)
	if err != nil {
		return nil, err
	}
	return chromatypes.NewEmbeddingFromFloat32(vec), nil
}
// EmbedRecords delegates to chroma-go's EmbedRecordsDefaultImpl. It passes
// this adapter as the embedding function and forwards records and force
// unchanged. Any error from EmbedRecordsDefaultImpl is returned as-is.
//
// NOTE(review): the meaning of force (presumably "re-embed records that
// already carry an embedding") is defined by EmbedRecordsDefaultImpl.
// Confirm against the chroma-go docs.
func (e chromaGoEmbedder) EmbedRecords(ctx context.Context, records []*chromatypes.Record, force bool) error {
	if err := chromatypes.EmbedRecordsDefaultImpl(e, ctx, records, force); err != nil {
		return err
	}
	return nil
}