diff --git a/pkg/aiusechat/aiutil/aiutil.go b/pkg/aiusechat/aiutil/aiutil.go new file mode 100644 index 0000000000..fb9f8bb517 --- /dev/null +++ b/pkg/aiusechat/aiutil/aiutil.go @@ -0,0 +1,182 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package aiutil + +import ( + "bytes" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "strconv" + "strings" + + "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" + "github.com/wavetermdev/waveterm/pkg/util/utilfn" +) + +// ExtractXmlAttribute extracts an attribute value from an XML-like tag. +// Expects double-quoted strings where internal quotes are encoded as ". +// Returns the unquoted value and true if found, or empty string and false if not found or invalid. +func ExtractXmlAttribute(tag, attrName string) (string, bool) { + attrStart := strings.Index(tag, attrName+"=") + if attrStart == -1 { + return "", false + } + + pos := attrStart + len(attrName+"=") + start := strings.Index(tag[pos:], `"`) + if start == -1 { + return "", false + } + start += pos + + end := strings.Index(tag[start+1:], `"`) + if end == -1 { + return "", false + } + end += start + 1 + + quotedValue := tag[start : end+1] + value, err := strconv.Unquote(quotedValue) + if err != nil { + return "", false + } + + value = strings.ReplaceAll(value, """, `"`) + return value, true +} + +// GenerateDeterministicSuffix creates an 8-character hash from input strings +func GenerateDeterministicSuffix(inputs ...string) string { + hasher := sha256.New() + for _, input := range inputs { + hasher.Write([]byte(input)) + } + hash := hasher.Sum(nil) + return hex.EncodeToString(hash)[:8] +} + +// ExtractImageUrl extracts an image URL from either URL field (http/https/data) or raw Data +func ExtractImageUrl(data []byte, url, mimeType string) (string, error) { + if url != "" { + if !strings.HasPrefix(url, "data:") && + !strings.HasPrefix(url, "http://") && + !strings.HasPrefix(url, "https://") { + return "", fmt.Errorf("unsupported URL protocol in file part: %s", url) + } + return url, nil + } + if len(data) > 0 { + base64Data := base64.StdEncoding.EncodeToString(data) + return fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data), nil + } + return "", fmt.Errorf("file part missing both url and data") +} + +// ExtractTextData extracts text data from either Data field or URL field (data: URLs only) +func ExtractTextData(data []byte, url string) ([]byte, error) { + if len(data) > 0 { + return data, nil + } + if url != "" { + if strings.HasPrefix(url, "data:") { + _, decodedData, err := utilfn.DecodeDataURL(url) + if err != nil { + return nil, fmt.Errorf("failed to decode data URL for text/plain file: %w", err) + } + return decodedData, nil + } + return nil, fmt.Errorf("dropping text/plain file with URL (must be fetched and converted to data)") + } + return nil, fmt.Errorf("text/plain file part missing data") +} + +// FormatAttachedTextFile formats a text file attachment with proper encoding and deterministic suffix +func FormatAttachedTextFile(fileName string, textContent []byte) string { + if fileName == "" { + fileName = "untitled.txt" + } + + encodedFileName := strings.ReplaceAll(fileName, `"`, """) + quotedFileName := strconv.Quote(encodedFileName) + + textStr := string(textContent) + deterministicSuffix := GenerateDeterministicSuffix(textStr, fileName) + return fmt.Sprintf("\n%s\n", deterministicSuffix, quotedFileName, textStr, deterministicSuffix) +} + +// FormatAttachedDirectoryListing formats a directory listing attachment with proper encoding and deterministic suffix +func FormatAttachedDirectoryListing(directoryName, jsonContent string) string { + if directoryName == "" { + directoryName = "unnamed-directory" + } + + encodedDirName := strings.ReplaceAll(directoryName, `"`, """) + quotedDirName := strconv.Quote(encodedDirName) + + deterministicSuffix := GenerateDeterministicSuffix(jsonContent, directoryName) + return fmt.Sprintf("\n%s\n", deterministicSuffix, quotedDirName, jsonContent, deterministicSuffix) +} + +// ConvertDataUserFile converts OpenAI attached file/directory blocks to UIMessagePart +// Returns (found, part) where found indicates if the prefix was matched, +// and part is the converted UIMessagePart (can be nil if parsing failed) +func ConvertDataUserFile(blockText string) (bool, *uctypes.UIMessagePart) { + if strings.HasPrefix(blockText, " 0 { // Premium requests exhausted, but regular requests available stopReason := &uctypes.WaveStopReason{ - Kind: uctypes.StopKindPremiumRateLimit, - RateLimitInfo: rateLimitInfo, + Kind: uctypes.StopKindPremiumRateLimit, } return stopReason, nil, rateLimitInfo, nil } if rateLimitInfo.Req == 0 { // All requests exhausted stopReason := &uctypes.WaveStopReason{ - Kind: uctypes.StopKindRateLimit, - RateLimitInfo: rateLimitInfo, + Kind: uctypes.StopKindRateLimit, } return stopReason, nil, rateLimitInfo, nil } @@ -590,8 +588,6 @@ func handleAnthropicStreamingResp( rtnStopReason = &uctypes.WaveStopReason{ Kind: uctypes.StopKindDone, RawReason: state.stopFromDelta, - MessageID: state.msgID, - Model: state.model, } return rtnStopReason, state.rtnMessage } @@ -849,41 +845,30 @@ func handleAnthropicEvent( switch reason { case "tool_use": return nil, &uctypes.WaveStopReason{ - Kind: uctypes.StopKindToolUse, - RawReason: reason, - MessageID: state.msgID, - Model: state.model, - ToolCalls: state.toolCalls, - FinishStep: true, + Kind: uctypes.StopKindToolUse, + RawReason: reason, + ToolCalls: state.toolCalls, } case "max_tokens": return nil, &uctypes.WaveStopReason{ Kind: uctypes.StopKindMaxTokens, RawReason: reason, - MessageID: state.msgID, - Model: state.model, } case "refusal": return nil, &uctypes.WaveStopReason{ Kind: uctypes.StopKindContent, RawReason: reason, - MessageID: state.msgID, - Model: state.model, } case "pause_turn": return nil, &uctypes.WaveStopReason{ Kind: uctypes.StopKindPauseTurn, RawReason: reason, - MessageID: state.msgID, - Model: state.model, } default: // end_turn, stop_sequence (treat as end of this call) return nil, &uctypes.WaveStopReason{ Kind: uctypes.StopKindDone, RawReason: reason, - MessageID: state.msgID, - Model: state.model, } } diff --git a/pkg/aiusechat/openai/openai-backend.go b/pkg/aiusechat/openai/openai-backend.go index c05745ee6a..f551a76845 100644 --- a/pkg/aiusechat/openai/openai-backend.go +++ b/pkg/aiusechat/openai/openai-backend.go @@ -93,7 +93,7 @@ type OpenAIMessageContent struct { Name string `json:"name,omitempty"` } -func (c *OpenAIMessageContent) Clean() *OpenAIMessageContent { +func (c *OpenAIMessageContent) clean() *OpenAIMessageContent { if c.PreviewUrl == "" { return c } @@ -102,17 +102,17 @@ func (c *OpenAIMessageContent) Clean() *OpenAIMessageContent { return &rtn } -func (m *OpenAIMessage) CleanAndCopy() *OpenAIMessage { +func (m *OpenAIMessage) cleanAndCopy() *OpenAIMessage { rtn := &OpenAIMessage{Role: m.Role} rtn.Content = make([]OpenAIMessageContent, len(m.Content)) for idx, block := range m.Content { - cleaned := block.Clean() + cleaned := block.clean() rtn.Content[idx] = *cleaned } return rtn } -func (f *OpenAIFunctionCallInput) Clean() *OpenAIFunctionCallInput { +func (f *OpenAIFunctionCallInput) clean() *OpenAIFunctionCallInput { if f.ToolUseData == nil { return f } @@ -481,10 +481,10 @@ func RunOpenAIChatStep( // Convert to appropriate input type based on what's populated if chatMsg.Message != nil { // Clean message to remove preview URLs - cleanedMsg := chatMsg.Message.CleanAndCopy() + cleanedMsg := chatMsg.Message.cleanAndCopy() inputs = append(inputs, *cleanedMsg) } else if chatMsg.FunctionCall != nil { - cleanedFunctionCall := chatMsg.FunctionCall.Clean() + cleanedFunctionCall := chatMsg.FunctionCall.clean() inputs = append(inputs, *cleanedFunctionCall) } else if chatMsg.FunctionCallOutput != nil { inputs = append(inputs, *chatMsg.FunctionCallOutput) @@ -526,16 +526,14 @@ func RunOpenAIChatStep( if rateLimitInfo.PReq == 0 && rateLimitInfo.Req > 0 { // Premium requests exhausted, but regular requests available stopReason := &uctypes.WaveStopReason{ - Kind: uctypes.StopKindPremiumRateLimit, - RateLimitInfo: rateLimitInfo, + Kind: uctypes.StopKindPremiumRateLimit, } return stopReason, nil, rateLimitInfo, nil } if rateLimitInfo.Req == 0 { // All requests exhausted stopReason := &uctypes.WaveStopReason{ - Kind: uctypes.StopKindRateLimit, - RateLimitInfo: rateLimitInfo, + Kind: uctypes.StopKindRateLimit, } return stopReason, nil, rateLimitInfo, nil } @@ -797,8 +795,6 @@ func handleOpenAIEvent( Kind: uctypes.StopKindError, ErrorType: "api", ErrorText: errorMsg, - MessageID: state.msgID, - Model: state.model, }, nil } @@ -831,8 +827,6 @@ func handleOpenAIEvent( Kind: stopKind, RawReason: reason, ErrorText: errorMsg, - MessageID: state.msgID, - Model: state.model, }, finalMessages } @@ -847,8 +841,6 @@ func handleOpenAIEvent( return &uctypes.WaveStopReason{ Kind: stopKind, RawReason: ev.Response.Status, - MessageID: state.msgID, - Model: state.model, ToolCalls: toolCalls, }, finalMessages @@ -860,22 +852,8 @@ func handleOpenAIEvent( } if st := state.blockMap[ev.ItemId]; st != nil && st.kind == openaiBlockToolUse { st.partialJSON = append(st.partialJSON, []byte(ev.Delta)...) - toolDef := state.chatOpts.GetToolDefinition(st.toolName) - if toolDef != nil && toolDef.ToolProgressDesc != nil { - parsedJSON, err := utilfn.ParsePartialJson(st.partialJSON) - if err == nil { - statusLines, err := toolDef.ToolProgressDesc(parsedJSON) - if err == nil { - progressData := &uctypes.UIMessageDataToolProgress{ - ToolCallId: st.toolCallID, - ToolName: st.toolName, - StatusLines: statusLines, - } - _ = sse.AiMsgData("data-toolprogress", "progress-"+st.toolCallID, progressData) - } - } - } + sendToolProgress(st, toolDef, sse, st.partialJSON, true) } return nil, nil @@ -888,28 +866,10 @@ func handleOpenAIEvent( // Get the function call info from the block state if st := state.blockMap[ev.ItemId]; st != nil && st.kind == openaiBlockToolUse { - // raw := json.RawMessage(ev.Arguments) - // no longer send tool inputs to fe - // _ = sse.AiMsgToolInputAvailable(st.toolCallID, st.toolName, raw) - toolDef := state.chatOpts.GetToolDefinition(st.toolName) toolUseData := createToolUseData(st.toolCallID, st.toolName, toolDef, ev.Arguments, state.chatOpts) state.toolUseData[st.toolCallID] = toolUseData - - if toolDef != nil && toolDef.ToolProgressDesc != nil { - var parsedJSON any - if err := json.Unmarshal([]byte(ev.Arguments), &parsedJSON); err == nil { - statusLines, err := toolDef.ToolProgressDesc(parsedJSON) - if err == nil { - progressData := &uctypes.UIMessageDataToolProgress{ - ToolCallId: st.toolCallID, - ToolName: st.toolName, - StatusLines: statusLines, - } - _ = sse.AiMsgData("data-toolprogress", "progress-"+st.toolCallID, progressData) - } - } - } + sendToolProgress(st, toolDef, sse, []byte(ev.Arguments), false) } return nil, nil @@ -966,6 +926,32 @@ func handleOpenAIEvent( } } +func sendToolProgress(st *openaiBlockState, toolDef *uctypes.ToolDefinition, sse *sse.SSEHandlerCh, jsonData []byte, usePartialParse bool) { + if toolDef == nil || toolDef.ToolProgressDesc == nil { + return + } + var parsedJSON any + var err error + if usePartialParse { + parsedJSON, err = utilfn.ParsePartialJson(jsonData) + } else { + err = json.Unmarshal(jsonData, &parsedJSON) + } + if err != nil { + return + } + statusLines, err := toolDef.ToolProgressDesc(parsedJSON) + if err != nil { + return + } + progressData := &uctypes.UIMessageDataToolProgress{ + ToolCallId: st.toolCallID, + ToolName: st.toolName, + StatusLines: statusLines, + } + _ = sse.AiMsgData("data-toolprogress", "progress-"+st.toolCallID, progressData) +} + func createToolUseData(toolCallID, toolName string, toolDef *uctypes.ToolDefinition, arguments string, chatOpts uctypes.WaveChatOpts) *uctypes.UIMessageDataToolUse { toolUseData := &uctypes.UIMessageDataToolUse{ ToolCallId: toolCallID, diff --git a/pkg/aiusechat/openai/openai-convertmessage.go b/pkg/aiusechat/openai/openai-convertmessage.go index 604284bbb8..d2dc594a21 100644 --- a/pkg/aiusechat/openai/openai-convertmessage.go +++ b/pkg/aiusechat/openai/openai-convertmessage.go @@ -4,22 +4,18 @@ package openai import ( - "bytes" "context" - "crypto/sha256" "encoding/base64" - "encoding/hex" "encoding/json" "errors" "fmt" "log" "net/http" - "strconv" "strings" "github.com/google/uuid" + "github.com/wavetermdev/waveterm/pkg/aiusechat/aiutil" "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" - "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/wavebase" ) @@ -28,46 +24,46 @@ const ( OpenAIDefaultMaxTokens = 4096 ) -// extractXmlAttribute extracts an attribute value from an XML-like tag. -// Expects double-quoted strings where internal quotes are encoded as ". -// Returns the unquoted value and true if found, or empty string and false if not found or invalid. -func extractXmlAttribute(tag, attrName string) (string, bool) { - attrStart := strings.Index(tag, attrName+"=") - if attrStart == -1 { - return "", false - } - - pos := attrStart + len(attrName+"=") - start := strings.Index(tag[pos:], `"`) - if start == -1 { - return "", false - } - start += pos - - end := strings.Index(tag[start+1:], `"`) - if end == -1 { - return "", false - } - end += start + 1 +// convertContentBlockToParts converts a single content block to UIMessageParts +func convertContentBlockToParts(block OpenAIMessageContent, role string) []uctypes.UIMessagePart { + var parts []uctypes.UIMessagePart - quotedValue := tag[start : end+1] - value, err := strconv.Unquote(quotedValue) - if err != nil { - return "", false + switch block.Type { + case "input_text", "output_text": + if found, part := aiutil.ConvertDataUserFile(block.Text); found { + if part != nil { + parts = append(parts, *part) + } + } else { + parts = append(parts, uctypes.UIMessagePart{ + Type: "text", + Text: block.Text, + }) + } + case "input_image": + if role == "user" { + parts = append(parts, uctypes.UIMessagePart{ + Type: "data-userfile", + Data: uctypes.UIMessageDataUserFile{ + MimeType: "image/*", + PreviewUrl: block.PreviewUrl, + }, + }) + } + case "input_file": + if role == "user" { + parts = append(parts, uctypes.UIMessagePart{ + Type: "data-userfile", + Data: uctypes.UIMessageDataUserFile{ + FileName: block.Filename, + MimeType: "application/pdf", + PreviewUrl: block.PreviewUrl, + }, + }) + } } - value = strings.ReplaceAll(value, """, `"`) - return value, true -} - -// generateDeterministicSuffix creates an 8-character hash from input strings -func generateDeterministicSuffix(inputs ...string) string { - hasher := sha256.New() - for _, input := range inputs { - hasher.Write([]byte(input)) - } - hash := hasher.Sum(nil) - return hex.EncodeToString(hash)[:8] + return parts } // appendToLastUserMessage appends a text block to the last user message in the inputs slice @@ -146,12 +142,11 @@ type OpenAIRequestTool struct { // ConvertToolDefinitionToOpenAI converts a generic ToolDefinition to OpenAI format func ConvertToolDefinitionToOpenAI(tool uctypes.ToolDefinition) OpenAIRequestTool { - cleanedTool := tool.Clean() return OpenAIRequestTool{ - Name: cleanedTool.Name, - Description: cleanedTool.Description, - Parameters: cleanedTool.InputSchema, - Strict: cleanedTool.Strict, + Name: tool.Name, + Description: tool.Description, + Parameters: tool.InputSchema, + Strict: tool.Strict, Type: "function", } } @@ -218,14 +213,13 @@ func buildOpenAIHTTPRequest(ctx context.Context, inputs []any, chatOpts uctypes. maxTokens = OpenAIDefaultMaxTokens } + // injected data if chatOpts.TabState != "" { appendToLastUserMessage(inputs, chatOpts.TabState) } - if chatOpts.AppStaticFiles != "" { appendToLastUserMessage(inputs, "\n"+chatOpts.AppStaticFiles+"\n") } - if chatOpts.AppGoFile != "" { appendToLastUserMessage(inputs, "\n"+chatOpts.AppGoFile+"\n") } @@ -276,29 +270,18 @@ func buildOpenAIHTTPRequest(ctx context.Context, inputs []any, chatOpts uctypes. } } - // Set temperature if provided - if opts.APIVersion != "" && opts.APIVersion != OpenAIDefaultAPIVersion { - // Temperature and other parameters could be set here based on config - // For now, using defaults - } - debugPrintReq(reqBody, endpoint) // Encode request body - var buf bytes.Buffer - encoder := json.NewEncoder(&buf) - encoder.SetEscapeHTML(false) - err := encoder.Encode(reqBody) + buf, err := aiutil.JsonEncodeRequestBody(reqBody) if err != nil { return nil, err } - // Create HTTP request req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, &buf) if err != nil { return nil, err } - // Set headers req.Header.Set("Content-Type", "application/json") if opts.APIToken != "" { @@ -309,11 +292,7 @@ func buildOpenAIHTTPRequest(ctx context.Context, inputs []any, chatOpts uctypes. req.Header.Set("X-Wave-ClientId", chatOpts.ClientId) } req.Header.Set("X-Wave-APIType", "openai") - if chatOpts.BuilderId != "" { - req.Header.Set("X-Wave-RequestType", "waveapps-builder") - } else { - req.Header.Set("X-Wave-RequestType", "waveai") - } + req.Header.Set("X-Wave-RequestType", chatOpts.GetWaveRequestType()) return req, nil } @@ -330,23 +309,9 @@ func convertFileAIMessagePart(part uctypes.AIMessagePart) (*OpenAIMessageContent // Handle different file types switch { case strings.HasPrefix(part.MimeType, "image/"): - // Handle images - var imageUrl string - - if part.URL != "" { - // Validate URL protocol - only allow data:, http:, https: - if !strings.HasPrefix(part.URL, "data:") && - !strings.HasPrefix(part.URL, "http://") && - !strings.HasPrefix(part.URL, "https://") { - return nil, fmt.Errorf("unsupported URL protocol in file part: %s", part.URL) - } - imageUrl = part.URL - } else if len(part.Data) > 0 { - // Convert raw data to base64 data URL - base64Data := base64.StdEncoding.EncodeToString(part.Data) - imageUrl = fmt.Sprintf("data:%s;base64,%s", part.MimeType, base64Data) - } else { - return nil, fmt.Errorf("file part missing both url and data") + imageUrl, err := aiutil.ExtractImageUrl(part.Data, part.URL, part.MimeType) + if err != nil { + return nil, err } return &OpenAIMessageContent{ @@ -375,35 +340,11 @@ func convertFileAIMessagePart(part uctypes.AIMessagePart) (*OpenAIMessageContent }, nil case part.MimeType == "text/plain": - var textContent string - - if len(part.Data) > 0 { - textContent = string(part.Data) - } else if part.URL != "" { - if strings.HasPrefix(part.URL, "data:") { - _, decodedData, err := utilfn.DecodeDataURL(part.URL) - if err != nil { - return nil, fmt.Errorf("failed to decode data URL for text/plain file: %w", err) - } - textContent = string(decodedData) - } else { - return nil, fmt.Errorf("dropping text/plain file with URL (must be fetched and converted to data)") - } - } else { - return nil, fmt.Errorf("text/plain file part missing data") - } - - fileName := part.FileName - if fileName == "" { - fileName = "untitled.txt" + textData, err := aiutil.ExtractTextData(part.Data, part.URL) + if err != nil { + return nil, err } - - encodedFileName := strings.ReplaceAll(fileName, `"`, """) - quotedFileName := strconv.Quote(encodedFileName) - - deterministicSuffix := generateDeterministicSuffix(textContent, fileName) - formattedText := fmt.Sprintf("\n%s\n", deterministicSuffix, quotedFileName, textContent, deterministicSuffix) - + formattedText := aiutil.FormatAttachedTextFile(part.FileName, textData) return &OpenAIMessageContent{ Type: "input_text", Text: formattedText, @@ -417,16 +358,7 @@ func convertFileAIMessagePart(part uctypes.AIMessagePart) (*OpenAIMessageContent return nil, fmt.Errorf("directory listing part missing data") } - directoryName := part.FileName - if directoryName == "" { - directoryName = "unnamed-directory" - } - - encodedDirName := strings.ReplaceAll(directoryName, `"`, """) - quotedDirName := strconv.Quote(encodedDirName) - - deterministicSuffix := generateDeterministicSuffix(jsonContent, directoryName) - formattedText := fmt.Sprintf("\n%s\n", deterministicSuffix, quotedDirName, jsonContent, deterministicSuffix) + formattedText := aiutil.FormatAttachedDirectoryListing(part.FileName, jsonContent) return &OpenAIMessageContent{ Type: "input_text", @@ -540,89 +472,17 @@ func ConvertToolResultsToOpenAIChatMessage(toolResults []uctypes.AIToolResult) ( return messages, nil } -// ConvertToUIMessage converts an OpenAIChatMessage to a UIMessage -func (m *OpenAIChatMessage) ConvertToUIMessage() *uctypes.UIMessage { +// convertToUIMessage converts an OpenAIChatMessage to a UIMessage +func (m *OpenAIChatMessage) convertToUIMessage() *uctypes.UIMessage { var parts []uctypes.UIMessagePart var role string // Handle different message types if m.Message != nil { role = m.Message.Role - // Iterate over all content blocks for _, block := range m.Message.Content { - switch block.Type { - case "input_text", "output_text": - if strings.HasPrefix(block.Text, " 0 && rtnMessage[0] != nil { - messageID = rtnMessage[0].GetMessageId() - } + processToolCalls(backend, stopReason, chatOpts, sseHandler, metrics) cont = &uctypes.WaveContinueResponse{ - MessageID: messageID, - Model: chatOpts.Config.Model, - ContinueFromKind: uctypes.StopKindToolUse, - ContinueFromRawReason: stopReason.RawReason, + Model: chatOpts.Config.Model, + ContinueFromKind: uctypes.StopKindToolUse, } continue } @@ -563,22 +527,14 @@ func ResolveToolCall(toolDef *uctypes.ToolDefinition, toolCall uctypes.WaveToolC func WaveAIPostMessageWrap(ctx context.Context, sseHandler *sse.SSEHandlerCh, message *uctypes.AIMessage, chatOpts uctypes.WaveChatOpts) error { startTime := time.Now() - // Convert AIMessage to Anthropic chat message - var convertedMessage uctypes.GenAIMessage - if chatOpts.Config.APIType == APIType_Anthropic { - var err error - convertedMessage, err = anthropic.ConvertAIMessageToAnthropicChatMessage(*message) - if err != nil { - return fmt.Errorf("message conversion failed: %w", err) - } - } else if chatOpts.Config.APIType == APIType_OpenAI { - var err error - convertedMessage, err = openai.ConvertAIMessageToOpenAIChatMessage(*message) - if err != nil { - return fmt.Errorf("message conversion failed: %w", err) - } - } else { - return fmt.Errorf("unsupported APIType %q", chatOpts.Config.APIType) + // Convert AIMessage to native chat message using backend + backend, err := GetBackendByAPIType(chatOpts.Config.APIType) + if err != nil { + return err + } + convertedMessage, err := backend.ConvertAIMessageToNativeChatMessage(*message) + if err != nil { + return fmt.Errorf("message conversion failed: %w", err) } // Post message to chat store @@ -586,7 +542,7 @@ func WaveAIPostMessageWrap(ctx context.Context, sseHandler *sse.SSEHandlerCh, me return fmt.Errorf("failed to store message: %w", err) } - metrics, err := RunAIChat(ctx, sseHandler, chatOpts) + metrics, err := RunAIChat(ctx, sseHandler, backend, chatOpts) if metrics != nil { metrics.RequestDuration = int(time.Since(startTime).Milliseconds()) for _, part := range message.Parts { @@ -803,15 +759,12 @@ func CreateWriteTextFileDiff(ctx context.Context, chatId string, toolCallId stri return nil, nil, fmt.Errorf("chat not found: %s", chatId) } - if aiChat.APIType == APIType_Anthropic { - return nil, nil, fmt.Errorf("CreateWriteTextFileDiff is not implemented for Anthropic") - } - - if aiChat.APIType != APIType_OpenAI { - return nil, nil, fmt.Errorf("unsupported API type: %s", aiChat.APIType) + backend, err := GetBackendByAPIType(aiChat.APIType) + if err != nil { + return nil, nil, err } - funcCallInput := openai.GetFunctionCallInputByToolCallId(*aiChat, toolCallId) + funcCallInput := backend.GetFunctionCallInputByToolCallId(*aiChat, toolCallId) if funcCallInput == nil { return nil, nil, fmt.Errorf("tool call not found: %s", toolCallId) } @@ -865,7 +818,6 @@ func CreateWriteTextFileDiff(ctx context.Context, chatId string, toolCallId stri return originalContent, modifiedContent, nil } - type StaticFileInfo struct { Name string `json:"name"` Size int64 `json:"size"` @@ -879,7 +831,7 @@ func generateBuilderAppData(appId string) (string, string, error) { if err == nil { appGoFile = string(fileData.Contents) } - + staticFilesJSON := "" allFiles, err := waveappstore.ListAllAppFiles(appId) if err == nil { @@ -894,7 +846,7 @@ func generateBuilderAppData(appId string) (string, string, error) { }) } } - + if len(staticFiles) > 0 { staticFilesBytes, marshalErr := json.Marshal(staticFiles) if marshalErr == nil { @@ -902,6 +854,6 @@ func generateBuilderAppData(appId string) (string, string, error) { } } } - + return appGoFile, staticFilesJSON, nil }