Skip to content

Commit

Permalink
Make it possible to get embeddings for prompt + check distances
Browse files Browse the repository at this point in the history
  • Loading branch information
xyproto committed May 1, 2024
1 parent 90ccc3b commit bc4234a
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 10 deletions.
113 changes: 103 additions & 10 deletions lua/ollama/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
package ollama

import (
"fmt"
"math"
"strings"
"sync"

"github.com/dustin/go-humanize"
log "github.com/sirupsen/logrus"
"github.com/xyproto/algernon/lua/convert"
lua "github.com/xyproto/gopher-lua"
"github.com/xyproto/ollamaclient/v2"

"sync"
)

const (
Expand Down Expand Up @@ -168,6 +169,46 @@ func ollamaGenerateOutput(L *lua.LState) int {
return 1 // number of results
}

// Use ollama to get a []float64 representation of a given prompt.
// This can be used to ie. find the distance between two strings.
func ollamaEmbeddings(L *lua.LState) int {
mut.Lock()
defer mut.Unlock()
oc := checkOllamaClient(L) // arg 1
prompt := defaultPrompt
top := L.GetTop()
if top > 2 {
prompt = L.ToString(2) // arg 2
oc.ModelName = L.ToString(3) // arg 3
err := oc.PullIfNeeded(true)
if err != nil {
log.Error(err)
L.Push(lua.LString(err.Error()))
return 1 // number of results
}
} else if top > 1 {
prompt = L.ToString(2) // arg 2
}
oc.SetReproducible()
var err error
var floats []float64
floats, err = oc.Embeddings(prompt)
if err != nil {
log.Error(err)
L.Push(lua.LString(err.Error()))
return 1 // number of results
}

// Push the floats as Lua numbers
tbl := L.NewTable()
for i, f := range floats {
index := fmt.Sprintf("%d", i+1) // Convert index to string for Lua table
L.SetField(tbl, index, lua.LNumber(f))
}
L.Push(tbl)
return 1 // number of results
}

func ollamaGenerateOutputCreative(L *lua.LState) int {
mut.Lock()
defer mut.Unlock()
Expand Down Expand Up @@ -217,14 +258,15 @@ func constructOllamaClient(L *lua.LState) (*lua.LUserData, error) {

// The hash map methods that are to be registered
var ollamaMethods = map[string]lua.LGFunction{
"ask": ollamaGenerateOutput,
"bytesize": ollamaSizeInBytes,
"creative": ollamaGenerateOutputCreative,
"has": ollamaHas,
"list": ollamaList,
"model": ollamaModel, // set or get the current model, but don't pull anything
"pull": ollamaPullIfNeeded,
"size": ollamaSize,
"ask": ollamaGenerateOutput,
"bytesize": ollamaSizeInBytes,
"creative": ollamaGenerateOutputCreative,
"has": ollamaHas,
"list": ollamaList,
"model": ollamaModel, // set or get the current model, but don't pull anything
"pull": ollamaPullIfNeeded,
"size": ollamaSize,
"embeddings": ollamaEmbeddings, // get a []float64 representation of a given prompt
}

func askOllama(L *lua.LState) int {
Expand Down Expand Up @@ -256,6 +298,56 @@ func askOllama(L *lua.LState) int {
return 1 // number of results
}

// distance calculates the cosine similarity between two embeddings (Lua tables of floats).
func distance(L *lua.LState) int {
// Check and get the first table argument
tbl1 := L.CheckTable(1)
// Check and get the second table argument
tbl2 := L.CheckTable(2)

// Convert Lua tables to slices of float64
slice1, err1 := tableToFloatSlice(tbl1)
if err1 != nil {
L.Push(lua.LString(err1.Error()))
return 1 // number of results (error message)
}
slice2, err2 := tableToFloatSlice(tbl2)
if err2 != nil {
L.Push(lua.LString(err2.Error()))
return 1 // number of results (error message)
}

// Calculate the cosine similarity
if len(slice1) != len(slice2) {
L.Push(lua.LString("error: embeddings must be of the same length"))
return 1 // number of results (error message)
}

dotProduct := 0.0
normA := 0.0
normB := 0.0
for i := range slice1 {
dotProduct += slice1[i] * slice2[i]
normA += slice1[i] * slice1[i]
normB += slice2[i] * slice2[i]
}
normA = math.Sqrt(normA)
normB = math.Sqrt(normB)

if normA == 0.0 || normB == 0.0 {
L.Push(lua.LString("error: one or both vectors are zero vectors"))
return 1 // number of results
}

cosineSimilarity := dotProduct / (normA * normB)

// Cosine similarity ranges from -1 to 1, higher values mean more similarity
// We convert it to a distance measure that ranges from 0 to 2
cosineDistance := 1 - cosineSimilarity
L.Push(lua.LNumber(cosineDistance))
return 1 // number of results (distance)
}

// Load makes functions related Ollama clients available to the given Lua state
func Load(L *lua.LState) {
// Register the OllamaClient class and the methods that belongs with it.
Expand All @@ -278,4 +370,5 @@ func Load(L *lua.LState) {
}))

L.SetGlobal("ollama", L.NewFunction(askOllama))
L.SetGlobal("edistance", L.NewFunction(distance))
}
27 changes: 27 additions & 0 deletions lua/ollama/table.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package ollama

import (
"fmt"

lua "github.com/xyproto/gopher-lua"
)

// Helper function to convert a Lua table to a slice of float64
func tableToFloatSlice(tbl *lua.LTable) ([]float64, error) {
var floats []float64
var errFlag error
tbl.ForEach(func(_ lua.LValue, data lua.LValue) {
if errFlag != nil {
return
}
if num, ok := data.(lua.LNumber); ok {
floats = append(floats, float64(num))
} else {
errFlag = fmt.Errorf("table contains non-numeric values")
}
})
if errFlag != nil {
return []float64{}, errFlag
}
return floats, nil
}

0 comments on commit bc4234a

Please sign in to comment.