Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mistral text2vec module #4880

Draft
wants to merge 7 commits into
base: stable/v1.25
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions adapters/handlers/rest/configure_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ import (
modgpt4all "github.com/weaviate/weaviate/modules/text2vec-gpt4all"
modhuggingface "github.com/weaviate/weaviate/modules/text2vec-huggingface"
modjinaai "github.com/weaviate/weaviate/modules/text2vec-jinaai"
modmistral "github.com/weaviate/weaviate/modules/text2vec-mistral"
modtext2vecoctoai "github.com/weaviate/weaviate/modules/text2vec-octoai"
modollama "github.com/weaviate/weaviate/modules/text2vec-ollama"
modopenai "github.com/weaviate/weaviate/modules/text2vec-openai"
Expand Down Expand Up @@ -960,6 +961,14 @@ func registerModules(appState *state.State) error {
Debug("enabled module")
}

if _, ok := enabledModules[modmistral.Name]; ok {
appState.Modules.Register(modmistral.New())
appState.Logger.
WithField("action", "startup").
WithField("module", modmistral.Name).
Debug("enabled module")
}

if _, ok := enabledModules[modbind.Name]; ok {
appState.Modules.Register(modbind.New())
appState.Logger.
Expand Down
54 changes: 54 additions & 0 deletions modules/text2vec-mistral/clients/fakes_for_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//

package clients

type fakeClassConfig struct {
classConfig map[string]interface{}
vectorizePropertyName bool
skippedProperty string
excludedProperty string
}

func (f fakeClassConfig) Class() map[string]interface{} {
return f.classConfig
}

func (f fakeClassConfig) ClassByModuleName(moduleName string) map[string]interface{} {
return f.classConfig
}

func (f fakeClassConfig) Property(propName string) map[string]interface{} {
if propName == f.skippedProperty {
return map[string]interface{}{
"skip": true,
}
}
if propName == f.excludedProperty {
return map[string]interface{}{
"vectorizePropertyName": false,
}
}
if f.vectorizePropertyName {
return map[string]interface{}{
"vectorizePropertyName": true,
}
}
return nil
}

func (f fakeClassConfig) Tenant() string {
return ""
}

func (f fakeClassConfig) TargetVector() string {
return ""
}
19 changes: 19 additions & 0 deletions modules/text2vec-mistral/clients/meta.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//

package clients

func (v *vectorizer) MetaInfo() (map[string]interface{}, error) {
return map[string]interface{}{
"name": "Mistral Module",
"documentationHref": "https://docs.mistral.ai/api/#operation/createEmbedding",
}, nil
}
208 changes: 208 additions & 0 deletions modules/text2vec-mistral/clients/mistral.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//

package clients

import (
"bytes"
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"net/http"
"time"

"github.com/weaviate/weaviate/entities/moduletools"

"github.com/weaviate/weaviate/usecases/modulecomponents"

"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/weaviate/weaviate/modules/text2vec-mistral/ent"
)

type embeddingsRequest struct {
Input []string `json:"input"`
Model string `json:"model"`
EncodingFormat string `json:"encoding_format"`
}

type embeddingsDataResponse struct {
Embeddings []float32 `json:"embedding"`
}

type embeddingsResponse struct {
Data []embeddingsDataResponse `json:"data,omitempty"`
Model string `json:"model,omitempty"`
Message string `json:"message,omitempty"`
}

type vectorizer struct {
apiKey string
httpClient *http.Client
logger logrus.FieldLogger
}

// info from mistral devs
const (
defaultRPM = 300 // 5 req per second
defaultTPM = 2000000
)

func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *vectorizer {
return &vectorizer{
apiKey: apiKey,
httpClient: &http.Client{
Timeout: timeout,
},
logger: logger,
}
}

func (v *vectorizer) Vectorize(ctx context.Context, input []string, cfg moduletools.ClassConfig,
) (*modulecomponents.VectorizationResult, *modulecomponents.RateLimits, error) {
config := v.getVectorizationConfig(cfg)
res, err := v.vectorize(ctx, input, config.Model, config.BaseURL)
return res, nil, err
}

func (v *vectorizer) VectorizeQuery(ctx context.Context, input []string,
cfg moduletools.ClassConfig,
) (*modulecomponents.VectorizationResult, error) {
config := v.getVectorizationConfig(cfg)
return v.vectorize(ctx, input, config.Model, config.BaseURL)
}

func (v *vectorizer) vectorize(ctx context.Context, input []string,
model string, url string,
) (*modulecomponents.VectorizationResult, error) {
body, err := json.Marshal(embeddingsRequest{
Input: input,
Model: model,
EncodingFormat: "float",
})
if err != nil {
return nil, errors.Wrapf(err, "marshal body")
}

req, err := http.NewRequestWithContext(ctx, "POST", url,
bytes.NewReader(body))
if err != nil {
return nil, errors.Wrap(err, "create POST request")
}
apiKey, err := v.getApiKey(ctx)
if err != nil {
return nil, errors.Wrapf(err, "Mistral API Key")
}
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey))
req.Header.Add("Content-Type", "application/json")

res, err := v.httpClient.Do(req)
if err != nil {
return nil, errors.Wrap(err, "send POST request")
}
defer res.Body.Close()
bodyBytes, err := io.ReadAll(res.Body)
if err != nil {
return nil, errors.Wrap(err, "read response body")
}
var resBody embeddingsResponse
if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
return nil, errors.Wrap(err, "unmarshal response body")
}

if res.StatusCode != 200 {
if resBody.Message != "" {
errorMessage := getErrorMessage(res.StatusCode, resBody.Message, "connection to Mistral failed with status: %d error: %v")
return nil, errors.Errorf(errorMessage)
}
errorMessage := getErrorMessage(res.StatusCode, "", "connection to Mistral failed with status: %d")
return nil, errors.Errorf(errorMessage)
}

if len(resBody.Data) == 0 || len(resBody.Data[0].Embeddings) == 0 {
return nil, errors.Errorf("empty embeddings response")
}

vectors := make([][]float32, len(resBody.Data))
for i, data := range resBody.Data {
vectors[i] = data.Embeddings
}

return &modulecomponents.VectorizationResult{
Text: input,
Dimensions: len(resBody.Data[0].Embeddings),
Vector: vectors,
}, nil
}

func getErrorMessage(statusCode int, resBodyError string, errorTemplate string) string {
if resBodyError != "" {
return fmt.Sprintf(errorTemplate, statusCode, resBodyError)
}
return fmt.Sprintf(errorTemplate, statusCode)
}

func (v *vectorizer) getApiKey(ctx context.Context) (string, error) {
if apiKey := modulecomponents.GetValueFromContext(ctx, "X-Mistral-Api-Key"); apiKey != "" {
return apiKey, nil
}
if v.apiKey != "" {
return v.apiKey, nil
}
return "", errors.New("no api key found " +
"neither in request header: X-Mistral-Api-Key " +
"nor in environment variable under MISTRAL_APIKEY")
}

func (v *vectorizer) GetApiKeyHash(ctx context.Context, cfg moduletools.ClassConfig) [32]byte {
key, err := v.getApiKey(ctx)
if err != nil {
return [32]byte{}
}
return sha256.Sum256([]byte(key))
}

func (v *vectorizer) GetVectorizerRateLimit(ctx context.Context) *modulecomponents.RateLimits {
rpm, tpm := modulecomponents.GetRateLimitFromContext(ctx, "Mistral", defaultRPM, defaultTPM)
execAfterRequestFunction := func(limits *modulecomponents.RateLimits, tokensUsed int, deductRequest bool) {
// refresh is after 60 seconds but leave a bit of room for errors. Otherwise, we only deduct the request that just happened
if limits.LastOverwrite.Add(61 * time.Second).After(time.Now()) {
if deductRequest {
limits.RemainingRequests -= 1
}
limits.RemainingTokens -= tokensUsed
return
}

limits.RemainingRequests = rpm
limits.ResetRequests = time.Now().Add(time.Duration(61) * time.Second)
limits.LimitRequests = rpm
limits.LastOverwrite = time.Now()

limits.RemainingTokens = tpm
limits.LimitTokens = tpm
limits.ResetTokens = time.Now().Add(time.Duration(61) * time.Second)
}

initialRL := &modulecomponents.RateLimits{AfterRequestFunction: execAfterRequestFunction, LastOverwrite: time.Now().Add(-61 * time.Minute)}
initialRL.ResetAfterRequestFunction(0) // set initial values

return initialRL
}

func (v *vectorizer) getVectorizationConfig(cfg moduletools.ClassConfig) ent.VectorizationConfig {
settings := ent.NewClassSettings(cfg)
return ent.VectorizationConfig{
Model: settings.Model(), BaseURL: settings.BaseURL(),
}
}
Loading
Loading