-
Notifications
You must be signed in to change notification settings - Fork 4
/
embeddingcache.go
115 lines (93 loc) · 3.15 KB
/
embeddingcache.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
package embed
import (
	"encoding/json"
	"fmt"
	"os"

	"github.com/tzapio/tzap/internal/logging/tl"
	"github.com/tzapio/tzap/pkg/types"
	"github.com/tzapio/tzap/pkg/tzap"
	"github.com/tzapio/tzap/pkg/util/reflectutil"
)
// EmbeddingCache keeps embedding vectors as JSON-encoded strings in a
// key/value collection, keyed by the exact text that was embedded.
type EmbeddingCache struct {
	// embeddingCacheDB maps embedded text to its JSON-encoded vector.
	embeddingCacheDB types.DBCollectionInterface[string]
}

// NewEmbeddingCache returns an EmbeddingCache backed by embeddingCacheDB.
// The collection is used as-is; no copying or initialization is performed.
func NewEmbeddingCache(embeddingCacheDB types.DBCollectionInterface[string]) *EmbeddingCache {
	cache := &EmbeddingCache{embeddingCacheDB: embeddingCacheDB}
	return cache
}
// GetCachedEmbeddings looks up each vector in embeddings by its
// Metadata.SplitPart text. It returns a new Embeddings holding only the
// vectors whose values were found in the cache. Each returned vector keeps
// the input's ID and Metadata, has TimeStamp set to 0, and takes its Values
// from the cached JSON.
//
// Vectors with no cache entry, or an empty one, are skipped and a warning is
// written to stderr. An error is returned if a cached entry is not valid JSON
// or does not contain exactly 1536 values.
//
// The files parameter is currently unused.
func (ec *EmbeddingCache) GetCachedEmbeddings(files []types.FileReader, embeddings *types.Embeddings) (*types.Embeddings, error) {
	// Width of the stored embeddings; must match the array length of
	// types.Vector.Values.
	const embeddingDims = 1536

	if embeddings == nil {
		return &types.Embeddings{}, nil
	}
	tl.Logger.Println("Getting cached embeddings", len(embeddings.Vectors))

	var cachedEmbeddings []*types.Vector
	for _, vector := range embeddings.Vectors {
		splitPart := vector.Metadata.SplitPart
		kv, exists := ec.embeddingCacheDB.ScanGet(splitPart)
		if !exists || reflectutil.IsZero(kv.Value) {
			fmt.Fprintf(os.Stderr, "Warning: %v is uncached.\n", vector.ID)
			continue
		}

		// Decode into a slice, not a fixed-size array. json.Unmarshal silently
		// truncates longer JSON arrays and zero-pads shorter ones when the
		// target is a Go array, so a length check on the array could never fail.
		var values []float32
		if err := json.Unmarshal([]byte(kv.Value), &values); err != nil {
			return nil, fmt.Errorf("decoding cached embedding for vector %v: %w", vector.ID, err)
		}
		if len(values) != embeddingDims {
			return nil, fmt.Errorf("cached embedding for vector %v has %d values, want %d", vector.ID, len(values), embeddingDims)
		}

		var float32Vector [embeddingDims]float32
		copy(float32Vector[:], values)
		cachedEmbeddings = append(cachedEmbeddings, &types.Vector{
			ID:        vector.ID,
			TimeStamp: 0,
			Metadata:  vector.Metadata,
			Values:    float32Vector,
		})
	}
	return &types.Embeddings{Vectors: cachedEmbeddings}, nil
}
// GetUncachedEmbeddings returns a new Embeddings containing every vector
// whose Metadata.SplitPart has no cache entry, or whose cached value is
// empty. Vectors appear in input order. The input is not modified; the
// returned Vectors slice holds the same *types.Vector pointers as the input.
func (ec *EmbeddingCache) GetUncachedEmbeddings(embeddings *types.Embeddings) *types.Embeddings {
	result := &types.Embeddings{}
	for idx := range embeddings.Vectors {
		candidate := embeddings.Vectors[idx]
		entry, found := ec.embeddingCacheDB.ScanGet(candidate.Metadata.SplitPart)
		if found && !reflectutil.IsZero(entry.Value) {
			// Already cached with a non-empty value; nothing to fetch.
			continue
		}
		result.Vectors = append(result.Vectors, candidate)
	}
	return result
}
// FetchThenCacheNewEmbeddings requests embeddings for every vector in
// uncachedEmbeddings, using each vector's Metadata.SplitPart text as the
// input, via t.TG.FetchEmbedding. Each result is stored in the cache as a
// JSON-encoded string under that same text.
//
// Requests go out in batches of at most 200 inputs. Each batch is written to
// the cache before the next is fetched, so an error part-way through leaves
// earlier batches cached.
//
// It returns the first fetch, encode, or cache-write error. It also returns
// an error if a fetch yields a different number of embeddings than inputs
// were sent. A nil or empty uncachedEmbeddings is a no-op. The files
// parameter is currently unused.
func (ec *EmbeddingCache) FetchThenCacheNewEmbeddings(t *tzap.Tzap, files []types.FileReader, uncachedEmbeddings *types.Embeddings) error {
	const batchSize = 200

	if uncachedEmbeddings == nil {
		return nil
	}
	vectors := uncachedEmbeddings.Vectors
	for start := 0; start < len(vectors); start += batchSize {
		end := start + batchSize
		if end > len(vectors) {
			end = len(vectors)
		}
		batch := vectors[start:end]

		inputStrings := make([]string, len(batch))
		for j, vector := range batch {
			inputStrings[j] = vector.Metadata.SplitPart
		}

		embeddingsResult, err := t.TG.FetchEmbedding(t.C, inputStrings...)
		if err != nil {
			return fmt.Errorf("fetching embeddings for inputs %d-%d: %w", start, end, err)
		}
		// Results are matched to inputs by position. A mismatch would either
		// panic on inputStrings[j] (too many results) or silently leave
		// inputs uncached (too few).
		if len(embeddingsResult) != len(inputStrings) {
			return fmt.Errorf("fetching embeddings for inputs %d-%d: got %d results for %d inputs", start, end, len(embeddingsResult), len(inputStrings))
		}

		cacheKeyVal := make([]types.KeyValue[string], len(embeddingsResult))
		for j, embedding := range embeddingsResult {
			embBytes, err := json.Marshal(embedding)
			if err != nil {
				return fmt.Errorf("encoding embedding for cache: %w", err)
			}
			cacheKeyVal[j] = types.KeyValue[string]{Key: inputStrings[j], Value: string(embBytes)}
		}

		added, err := ec.embeddingCacheDB.BatchSet(cacheKeyVal)
		if err != nil {
			return fmt.Errorf("writing %d embeddings to cache: %w", len(cacheKeyVal), err)
		}
		tl.UILogger.Println("Added", added, "embeddings to cache")
	}
	return nil
}