-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathdatasets.go
88 lines (70 loc) · 2.13 KB
/
datasets.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
package embedding
import (
"context"
"fmt"
"io"
"math"
"os"
"github.com/copilot-extensions/rag-extension/copilot"
)
// Create requests an embedding for content from the Copilot embeddings
// endpoint and returns the first embedding contained in the response.
//
// integrationID and apiToken are forwarded to copilot.Embeddings unchanged.
// An error is returned when the request fails or when the response carries
// no embedding data.
func Create(ctx context.Context, integrationID, apiToken string, content string) ([]float32, error) {
	request := &copilot.EmbeddingsRequest{
		Model: copilot.ModelEmbeddings,
		Input: []string{content},
	}

	resp, err := copilot.Embeddings(ctx, integrationID, apiToken, request)
	if err != nil {
		return nil, fmt.Errorf("error fetching embeddings: %w", err)
	}

	// Only a single input is sent, so only the first result is of interest.
	if len(resp.Data) == 0 {
		return nil, fmt.Errorf("no embeddings found in response")
	}
	return resp.Data[0].Embedding, nil
}
// Dataset pairs a file with the embedding computed from its contents.
type Dataset struct {
// Embedding is the embedding vector for the file; FindBestDataset compares
// it against a target vector.
Embedding []float32
// Filename is the path of the file the embedding was generated from.
Filename string
}
// GenerateDatasets reads every file in filenames, creates an embedding for
// its contents via Create, and returns one Dataset per file in the same order
// as filenames.
//
// integrationID and apiToken are passed through to Create unchanged. The
// first file that cannot be opened, read, or embedded aborts the run: the
// error is returned wrapped with the offending filename and no partial result
// is returned.
func GenerateDatasets(integrationID, apiToken string, filenames []string) ([]*Dataset, error) {
	datasets := make([]*Dataset, len(filenames))
	for i, filename := range filenames {
		fileContent, err := readFileContent(filename)
		if err != nil {
			return nil, err
		}

		embedding, err := Create(context.Background(), integrationID, apiToken, string(fileContent))
		if err != nil {
			return nil, fmt.Errorf("error creating embedding for file %s: %w", filename, err)
		}

		datasets[i] = &Dataset{
			Embedding: embedding,
			Filename:  filename,
		}
	}
	return datasets, nil
}

// readFileContent returns the full contents of filename and closes the file
// before returning.
func readFileContent(filename string) ([]byte, error) {
	file, err := os.Open(filename)
	if err != nil {
		return nil, fmt.Errorf("error reading in file %s: %w", filename, err)
	}
	// Deferred in this helper rather than in the caller's loop so each file
	// is closed as soon as it has been read, not when the whole run finishes.
	defer file.Close()

	content, err := io.ReadAll(file)
	if err != nil {
		// A failed read must not be embedded as if it were the file's content.
		return nil, fmt.Errorf("error reading in file %s: %w", filename, err)
	}
	return content, nil
}
// FindBestDataset returns the dataset whose embedding has the highest cosine
// similarity to target.
//
// Any computable similarity qualifies, including zero or negative ones, so a
// non-nil dataset is returned whenever at least one comparison is defined.
// Ties keep the earliest dataset. It returns (nil, nil) when datasets is
// empty or when no similarity is defined because target or every dataset
// embedding has zero magnitude. It returns an error as soon as a dataset
// whose embedding length differs from len(target) is encountered.
func FindBestDataset(datasets []*Dataset, target []float32) (*Dataset, error) {
	// The target norm is the same for every comparison, so compute it once.
	// Sums are accumulated in float64 to limit rounding error on long vectors.
	var targetSquares float64
	for _, v := range target {
		targetSquares += float64(v) * float64(v)
	}
	targetNorm := math.Sqrt(targetSquares)

	var best *Dataset
	// Start below every possible cosine value so that non-positive
	// similarities can still be selected.
	bestScore := math.Inf(-1)

	for _, dataset := range datasets {
		if len(target) != len(dataset.Embedding) {
			return nil, fmt.Errorf("embeddings are different length, cannot compare")
		}

		var docSquares, dotProduct float64
		for i, v := range dataset.Embedding {
			docSquares += float64(v) * float64(v)
			dotProduct += float64(target[i]) * float64(v)
		}

		denominator := targetNorm * math.Sqrt(docSquares)
		if denominator == 0 {
			// Cosine similarity is undefined for a zero vector; skip it
			// rather than dividing by zero and comparing against NaN.
			continue
		}

		if score := dotProduct / denominator; score > bestScore {
			best = dataset
			bestScore = score
		}
	}
	return best, nil
}