Skip to content

Commit

Permalink
feat: support gemini embeddings (text-embedding-004,embedding-001) (songquanpeng#1475)
Browse files Browse the repository at this point in the history

* Refactor Gemini Adaptor to Support Embeddings

* Add new models to ModelList

(cherry picked from commit 9321427)
  • Loading branch information
mxdlzg authored and sunls24 committed Jun 22, 2024
1 parent 72cda7f commit 3218250
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 4 deletions.
26 changes: 23 additions & 3 deletions relay/adaptor/gemini/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
)

type Adaptor struct {
Expand All @@ -24,7 +25,14 @@ func (a *Adaptor) Init(meta *meta.Meta) {

func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion)
action := "generateContent"
action := ""
switch meta.Mode {
case relaymode.Embeddings:
action = "batchEmbedContents"
default:
action = "generateContent"
}

if meta.IsStream {
action = "streamGenerateContent?alt=sse"
}
Expand All @@ -41,7 +49,14 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
if request == nil {
return nil, errors.New("request is nil")
}
return ConvertRequest(*request), nil
switch relayMode {
case relaymode.Embeddings:
geminiEmbeddingRequest := ConvertEmbeddingRequest(*request)
return geminiEmbeddingRequest, nil
default:
geminiRequest := ConvertRequest(*request)
return geminiRequest, nil
}
}

func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
Expand All @@ -61,7 +76,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
err, responseText = StreamHandler(c, resp)
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
} else {
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
switch meta.Mode {
case relaymode.Embeddings:
err, usage = EmbeddingHandler(c, resp)
default:
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
}
return
}
Expand Down
2 changes: 1 addition & 1 deletion relay/adaptor/gemini/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ package gemini

// ModelList lists the Gemini model names known to this package. Each name
// appears exactly once; "embedding-001" and "text-embedding-004" are the
// embedding models, the remaining entries are the generative models.
var ModelList = []string{
	"gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro",
	"gemini-pro-vision", "gemini-1.0-pro-vision-001",
	"embedding-001", "text-embedding-004",
}
76 changes: 76 additions & 0 deletions relay/adaptor/gemini/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,29 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
return &geminiRequest
}

// ConvertEmbeddingRequest builds a Gemini batch embedding request from an
// OpenAI-style request.
//
// One EmbeddingRequest is produced per input returned by request.ParseInput,
// in the same order. Each entry targets "models/<request.Model>" and carries
// its input as a single text part.
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *BatchEmbeddingRequest {
	texts := request.ParseInput()
	target := fmt.Sprintf("models/%s", request.Model)

	batch := &BatchEmbeddingRequest{
		Requests: make([]EmbeddingRequest, 0, len(texts)),
	}
	for _, text := range texts {
		batch.Requests = append(batch.Requests, EmbeddingRequest{
			Model:   target,
			Content: ChatContent{Parts: []Part{{Text: text}}},
		})
	}
	return batch
}

type ChatResponse struct {
Candidates []ChatCandidate `json:"candidates"`
PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
Expand Down Expand Up @@ -230,6 +253,23 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatC
return &response
}

// embeddingResponseGemini2OpenAI converts a Gemini batch embedding response
// into the OpenAI embedding list format.
//
// Every entry of response.Embeddings becomes one item in Data, and the item's
// Index is its position in that slice so callers can match each vector back to
// the input it was computed from. Model is the fixed label "gemini-embedding"
// and Usage is zeroed, because this function receives neither the requested
// model name nor any token counts.
func embeddingResponseGemini2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
	openAIEmbeddingResponse := openai.EmbeddingResponse{
		Object: "list",
		Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)),
		Model:  "gemini-embedding",
		Usage:  model.Usage{TotalTokens: 0},
	}
	for i, item := range response.Embeddings {
		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
			Object: `embedding`,
			// Use the position in the batch; a constant 0 would make
			// multi-input responses indistinguishable by index.
			Index:     i,
			Embedding: item.Values,
		})
	}
	return &openAIEmbeddingResponse
}

func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := ""
scanner := bufio.NewScanner(resp.Body)
Expand Down Expand Up @@ -337,3 +377,39 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

// EmbeddingHandler relays a non-streaming Gemini embedding response to the
// client in OpenAI format.
//
// It reads and closes resp.Body, decodes it as an EmbeddingResponse, converts
// it with embeddingResponseGemini2OpenAI and writes the JSON result to c using
// the upstream status code.
//
// A non-nil error (with nil usage) is returned when the body cannot be read,
// closed or decoded, when the converted response cannot be encoded, or when
// the decoded payload carries a Gemini error object; nothing is written to c
// in those cases. On success the error is nil and the returned usage is the
// Usage of the converted response.
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		// Close on the failure path as well so the body is not leaked; the
		// read error is the one worth reporting, so a close error is dropped.
		_ = resp.Body.Close()
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	if err = resp.Body.Close(); err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}

	var geminiEmbeddingResponse EmbeddingResponse
	if err = json.Unmarshal(responseBody, &geminiEmbeddingResponse); err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if geminiEmbeddingResponse.Error != nil {
		// Surface the upstream error together with the upstream HTTP status.
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: geminiEmbeddingResponse.Error.Message,
				Type:    "gemini_error",
				Param:   "",
				Code:    geminiEmbeddingResponse.Error.Code,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}

	fullTextResponse := embeddingResponseGemini2OpenAI(&geminiEmbeddingResponse)
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	// Best-effort write: the status line has already been sent, so a failure
	// here can no longer be turned into an error response for the client.
	_, _ = c.Writer.Write(jsonResponse)
	return nil, &fullTextResponse.Usage
}
27 changes: 27 additions & 0 deletions relay/adaptor/gemini/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,33 @@ type ChatRequest struct {
Tools []ChatTools `json:"tools,omitempty"`
}

// EmbeddingRequest is a single embedding request; BatchEmbeddingRequest
// carries a list of them.
type EmbeddingRequest struct {
// Model is the target model reference. ConvertEmbeddingRequest fills it as
// "models/<model name>".
Model string `json:"model"`
// Content holds the text to embed. ConvertEmbeddingRequest stores each input
// as a single text part.
Content ChatContent `json:"content"`
// TaskType, Title and OutputDimensionality are optional and omitted from the
// JSON when empty/zero. ConvertEmbeddingRequest does not set them.
// NOTE(review): presumably these mirror the optional fields of Gemini's
// embedContent request — confirm against the API reference before use.
TaskType string `json:"taskType,omitempty"`
Title string `json:"title,omitempty"`
OutputDimensionality int `json:"outputDimensionality,omitempty"`
}

// BatchEmbeddingRequest is the request body built by ConvertEmbeddingRequest
// for embedding calls: a list of individual embedding requests serialized
// under the "requests" key.
type BatchEmbeddingRequest struct {
// Requests holds one EmbeddingRequest per input.
Requests []EmbeddingRequest `json:"requests"`
}

// EmbeddingData is one embedding entry of an EmbeddingResponse.
type EmbeddingData struct {
// Values is the embedding vector, decoded from the "values" key.
Values []float64 `json:"values"`
}

// EmbeddingResponse is the JSON body that EmbeddingHandler decodes from the
// upstream embedding response.
type EmbeddingResponse struct {
// Embeddings holds the returned embedding entries.
Embeddings []EmbeddingData `json:"embeddings"`
// Error is non-nil when the body contains an "error" object; EmbeddingHandler
// treats that as a failed request.
Error *Error `json:"error,omitempty"`
}

// Error is the error object that may appear in an EmbeddingResponse.
type Error struct {
// Code is copied into the relayed error's Code by EmbeddingHandler.
Code int `json:"code,omitempty"`
// Message is copied into the relayed error's Message by EmbeddingHandler.
Message string `json:"message,omitempty"`
// Status is decoded from the "status" key. NOTE(review): not read by
// EmbeddingHandler; presumably a textual status name — confirm against the
// Gemini API error format.
Status string `json:"status,omitempty"`
}

type InlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
Expand Down

0 comments on commit 3218250

Please sign in to comment.