Skip to content

Commit

Permalink
feat: Support attaching images to generate and chat requests
Browse files Browse the repository at this point in the history
  • Loading branch information
prantlf committed Jun 9, 2024
1 parent d042aec commit a3d7f15
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 25 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ Generates a text using the specified prompt. See the available [bison text model
❯ curl localhost:22434/api/generate -d '{
"model": "gemini-1.5-pro-preview-0409",
"prompt": "Describe guilds from Dungeons and Dragons.",
"images": [],
"stream": false
}'
Expand Down Expand Up @@ -167,7 +168,8 @@ Replies to a chat with the specified message history. See the available [bison c
},
{
"role": "user",
"content": "What race is the best for a barbarian?"
"content": "What race is the best for a barbarian?",
"images": []
}
],
"stream": false
Expand Down
24 changes: 13 additions & 11 deletions internal/routes/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ import (
)

type message struct {
Role string `json:"role"`
Content string `json:"content"`
Role string `json:"role"`
Content string `json:"content"`
Images []string `json:"images"`
}

type chatInput struct {
Expand Down Expand Up @@ -48,15 +49,16 @@ func convertGeminiMessages(messages []message) ([]geminiContent, error) {
} else if msg.Role == "assistant" {
role = "model"
} else {
return []geminiContent{}, fmt.Errorf("invalid chat message role: %q", msg.Role)
return nil, fmt.Errorf("invalid chat message role: %q", msg.Role)
}
parts, err := createGeminiParts(msg.Content, msg.Images)
if err != nil {
return nil, err
}
parts[0].Text = msg.Content
chatMessages = append(chatMessages, geminiContent{
Role: role,
Parts: []geminiPart{
{
Text: msg.Content,
},
},
Role: role,
Parts: parts,
})
}
}
Expand Down Expand Up @@ -133,7 +135,7 @@ func convertBisonMessages(messages []message) (string, []bisonMessage, error) {
} else if msg.Role == "assistant" {
role = "bot"
} else {
return "", []bisonMessage{}, fmt.Errorf("invalid chat message role: %q", msg.Role)
return "", nil, fmt.Errorf("invalid chat message role: %q", msg.Role)
}
chatMessages = append(chatMessages, bisonMessage{
Author: role,
Expand All @@ -142,7 +144,7 @@ func convertBisonMessages(messages []message) (string, []bisonMessage, error) {
}
}
if len(chatMessages) == 0 {
return "", []bisonMessage{}, errors.New("no user message found")
return "", nil, errors.New("no user message found")
}
var context string
if len(systemMessages) > 0 {
Expand Down
63 changes: 50 additions & 13 deletions internal/routes/generate.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package routes

import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
Expand All @@ -23,19 +24,26 @@ type modelParameters struct {
type generateInput struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Images []string `json:"images"`
Stream bool `json:"stream"`
Options modelParameters `json:"options"`
}

type generateModelHandler interface {
prepareBody(input *generateInput) (string, interface{}, interface{})
prepareBody(input *generateInput) (string, interface{}, interface{}, error)
extractResponse(data interface{}) (string, int, int)
}

type generateGeminiHandler struct{}

type inlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
}

type geminiPart struct {
Text string `json:"text"`
Text string `json:"text,omitempty"`
InlineData *inlineData `json:"inlineData,omitempty"`
}

type geminiContent struct {
Expand Down Expand Up @@ -78,25 +86,51 @@ func mergeParameters(target *cfg.GenerationConfig, source *modelParameters) {
}
}

func (h *generateGeminiHandler) prepareBody(input *generateInput) (string, interface{}, interface{}) {
func createGeminiParts(content string, images []string) ([]geminiPart, error) {
parts := []geminiPart{
{
Text: content,
},
}
for _, image := range images {
bytes, err := base64.StdEncoding.DecodeString(image)
if err != nil {
return nil, fmt.Errorf("invalid image encoding: %s", err.Error())
}
mimeType := http.DetectContentType(bytes)
if !strings.HasPrefix(mimeType, "image/") {
return nil, fmt.Errorf("invalid image type: %s", mimeType)
}
part := geminiPart{
InlineData: &inlineData{
MimeType: mimeType,
Data: image,
},
}
parts = append(parts, part)
}
return parts, nil
}

func (h *generateGeminiHandler) prepareBody(input *generateInput) (string, interface{}, interface{}, error) {
urlPrefix := input.Model + ":generateContent"
generationConfig := cfg.Defaults.GeminiDefaults.GenerationConfig
mergeParameters(&generationConfig, &input.Options)
parts, err := createGeminiParts(input.Prompt, input.Images)
if err != nil {
return "", nil, nil, err
}
body := &geminiBody{
Contents: []geminiContent{
{
Role: "user",
Parts: []geminiPart{
{
Text: input.Prompt,
},
},
Role: "user",
Parts: parts,
},
},
GenerationConfig: generationConfig,
SafetySettings: cfg.Defaults.GeminiDefaults.SafetySettings,
}
return urlPrefix, body, &geminiOutput{}
return urlPrefix, body, &geminiOutput{}, nil
}

func extractGeminiResponse(data interface{}) (string, int, int) {
Expand Down Expand Up @@ -152,7 +186,7 @@ type generateBisonOutput struct {
Metadata bisonMetadata `json:"metadata"`
}

func (h *generateBisonHandler) prepareBody(input *generateInput) (string, interface{}, interface{}) {
func (h *generateBisonHandler) prepareBody(input *generateInput) (string, interface{}, interface{}, error) {
urlPrefix := input.Model + ":predict"
parameters := cfg.Defaults.BisonDefaults.Parameters
mergeParameters(&parameters, &input.Options)
Expand All @@ -164,7 +198,7 @@ func (h *generateBisonHandler) prepareBody(input *generateInput) (string, interf
},
Parameters: parameters,
}
return urlPrefix, body, &generateBisonOutput{}
return urlPrefix, body, &generateBisonOutput{}, nil
}

func (h *generateBisonHandler) extractResponse(data interface{}) (string, int, int) {
Expand Down Expand Up @@ -237,7 +271,10 @@ func HandleGenerate(w http.ResponseWriter, r *http.Request) int {
if handler == nil {
return proxyRequest("generate", reqPayload, w, "result", input.Model)
}
urlSuffix, reqBody, output := handler.prepareBody(&input)
urlSuffix, reqBody, output, err := handler.prepareBody(&input)
if err != nil {
return wrongInput(w, err.Error())
}
status, duration, err := forwardRequest(urlSuffix, reqBody, output)
if err != nil {
return failRequest(w, status, err.Error())
Expand Down

0 comments on commit a3d7f15

Please sign in to comment.