Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DotProduct Method and README Example for Embedding Similarity Search #492

Merged
merged 2 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,62 @@ func main() {
```
</details>

<detail>
<summary>Embedding Semantic Similarity</summary>

```go
package main

import (
"context"
"log"
openai "github.com/sashabaranov/go-openai"

)

func main() {
client := openai.NewClient("your-token")

// Create an EmbeddingRequest for the user query
queryReq := openai.EmbeddingRequest{
Input: []string{"How many chucks would a woodchuck chuck"},
Model: openai.AdaEmbeddingv2,
}

// Create an embedding for the user query
queryResponse, err := client.CreateEmbeddings(context.Background(), queryReq)
if err != nil {
log.Fatal("Error creating query embedding:", err)
}

// Create an EmbeddingRequest for the target text
targetReq := openai.EmbeddingRequest{
Input: []string{"How many chucks would a woodchuck chuck if the woodchuck could chuck wood"},
Model: openai.AdaEmbeddingv2,
}

// Create an embedding for the target text
targetResponse, err := client.CreateEmbeddings(context.Background(), targetReq)
if err != nil {
log.Fatal("Error creating target embedding:", err)
}

// Now that we have the embeddings for the user query and the target text, we
// can calculate their similarity.
queryEmbedding := queryResponse.Data[0]
targetEmbedding := targetResponse.Data[0]

similarity, err := queryEmbedding.DotProduct(&targetEmbedding)
if err != nil {
log.Fatal("Error calculating dot product:", err)
}

log.Printf("The similarity score between the query and the target is %f", similarity)
}

```
</detail>

<details>
<summary>Azure OpenAI Embeddings</summary>

Expand Down
20 changes: 20 additions & 0 deletions embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@ import (
"context"
"encoding/base64"
"encoding/binary"
"errors"
"math"
"net/http"
)

var ErrVectorLengthMismatch = errors.New("vector length mismatch")

// EmbeddingModel enumerates the models which can be used
// to generate Embedding vectors.
type EmbeddingModel int
Expand Down Expand Up @@ -124,6 +127,23 @@ type Embedding struct {
Index int `json:"index"`
}

// DotProduct calculates the dot product of the embedding vector with another
// embedding vector. Both vectors must have the same length; otherwise, an
// ErrVectorLengthMismatch is returned. The method returns the calculated dot
// product as a float32 value.
func (e *Embedding) DotProduct(other *Embedding) (float32, error) {
if len(e.Embedding) != len(other.Embedding) {
return 0, ErrVectorLengthMismatch
}

var dotProduct float32
for i := range e.Embedding {
dotProduct += e.Embedding[i] * other.Embedding[i]
}

return dotProduct, nil
}

// EmbeddingResponse is the response from a Create embeddings request.
type EmbeddingResponse struct {
Object string `json:"object"`
Expand Down
38 changes: 38 additions & 0 deletions embeddings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"math"
"net/http"
"reflect"
"testing"
Expand Down Expand Up @@ -233,3 +235,39 @@ func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) {
})
}
}

func TestDotProduct(t *testing.T) {
v1 := &Embedding{Embedding: []float32{1, 2, 3}}
v2 := &Embedding{Embedding: []float32{2, 4, 6}}
expected := float32(28.0)

result, err := v1.DotProduct(v2)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

if math.Abs(float64(result-expected)) > 1e-12 {
t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result)
}

v1 = &Embedding{Embedding: []float32{1, 0, 0}}
v2 = &Embedding{Embedding: []float32{0, 1, 0}}
expected = float32(0.0)

result, err = v1.DotProduct(v2)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

if math.Abs(float64(result-expected)) > 1e-12 {
t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result)
}

// Test for VectorLengthMismatchError
v1 = &Embedding{Embedding: []float32{1, 0, 0}}
v2 = &Embedding{Embedding: []float32{0, 1}}
_, err = v1.DotProduct(v2)
if !errors.Is(err, ErrVectorLengthMismatch) {
t.Errorf("Expected Vector Length Mismatch Error, but got: %v", err)
}
}
Loading