diff --git a/src/transform-sdk/go/transform/ai/model.go b/src/transform-sdk/go/transform/ai/model.go
index 7686053cc3960..8e1187fe3e978 100644
--- a/src/transform-sdk/go/transform/ai/model.go
+++ b/src/transform-sdk/go/transform/ai/model.go
@@ -49,7 +49,7 @@ func LoadLargeLanguageModel(m HuggingFaceModel) (*LargeLanguageModel, error) {
 }
 
 // LoadEmbeddingsModel from hugging face
-func LoadEmbeddingsModel(m HuggingFaceModel) (*LargeLanguageModel, error) {
+func LoadEmbeddingsModel(m HuggingFaceModel) (*EmbeddingsModel, error) {
 	r := loadHfEmbeddings(
 		unsafe.Pointer(unsafe.StringData(m.Repo)),
 		int32(len(m.Repo)),
@@ -59,7 +59,7 @@ func LoadEmbeddingsModel(m HuggingFaceModel) (*LargeLanguageModel, error) {
 	if r < 0 {
 		return nil, errors.New("unable to load model, errorcode: " + strconv.Itoa(int(r)))
 	}
-	return &LargeLanguageModel{modelHandle(r)}, nil
+	return &EmbeddingsModel{modelHandle(r)}, nil
 }
 
 // Wrapper around LLM
diff --git a/src/transform-sdk/go/transform/internal/testdata/build_all.sh b/src/transform-sdk/go/transform/internal/testdata/build_all.sh
index f5202d4e99bce..f0718322088ac 100755
--- a/src/transform-sdk/go/transform/internal/testdata/build_all.sh
+++ b/src/transform-sdk/go/transform/internal/testdata/build_all.sh
@@ -9,6 +9,7 @@ PKGS=(
   wasi
   dynamic
   llm-prompt-forward
+  make-embeddings
 )
 
 for PKG in "${PKGS[@]}"; do
diff --git a/src/transform-sdk/go/transform/internal/testdata/make-embeddings/transform.go b/src/transform-sdk/go/transform/internal/testdata/make-embeddings/transform.go
new file mode 100644
index 0000000000000..0b330d5d412cb
--- /dev/null
+++ b/src/transform-sdk/go/transform/internal/testdata/make-embeddings/transform.go
@@ -0,0 +1,49 @@
+package main
+
+import (
+	"encoding/json"
+	"unsafe"
+
+	"github.com/redpanda-data/redpanda/src/transform-sdk/go/transform"
+	"github.com/redpanda-data/redpanda/src/transform-sdk/go/transform/ai"
+)
+
+// model is the embeddings model loaded once in main and used by myTransform.
+var model *ai.EmbeddingsModel
+
+// main loads the embeddings model and passes myTransform to
+// transform.OnRecordWritten. It panics if the model cannot be loaded.
+func main() {
+	m, err := ai.LoadEmbeddingsModel(ai.HuggingFaceModel{
+		Repo: "leliuga/all-MiniLM-L12-v2-GGUF",
+		File: "all-MiniLM-L12-v2.Q8_0.gguf",
+	})
+	if err != nil {
+		panic(err)
+	}
+	model = m
+	transform.OnRecordWritten(myTransform)
+}
+
+// myTransform computes embeddings for the record value and writes them out
+// as JSON under the original record key.
+func myTransform(e transform.WriteEvent, w transform.RecordWriter) error {
+	// Zero-copy view of the value as a string. unsafe.SliceData is used
+	// instead of &val[0] so an empty or nil value yields "" rather than an
+	// index-out-of-range panic.
+	val := e.Record().Value
+	str := unsafe.String(unsafe.SliceData(val), len(val))
+	resp, err := model.ComputeEmbeddings(str)
+	if err != nil {
+		return err
+	}
+	// Turn the embeddings into JSON
+	b, err := json.Marshal(resp)
+	if err != nil {
+		return err
+	}
+	return w.Write(transform.Record{
+		Key:   e.Record().Key,
+		Value: b,
+	})
+}