diff --git a/pkg/aiusechat/aiutil/aiutil.go b/pkg/aiusechat/aiutil/aiutil.go
new file mode 100644
index 0000000000..fb9f8bb517
--- /dev/null
+++ b/pkg/aiusechat/aiutil/aiutil.go
@@ -0,0 +1,182 @@
+// Copyright 2025, Command Line Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package aiutil
+
+import (
+ "bytes"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/hex"
+ "encoding/json"
+ "fmt"
+ "strconv"
+ "strings"
+
+ "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes"
+ "github.com/wavetermdev/waveterm/pkg/util/utilfn"
+)
+
+// ExtractXmlAttribute extracts an attribute value from an XML-like tag.
+// Expects double-quoted strings where internal quotes are encoded as ".
+// Returns the unquoted value and true if found, or empty string and false if not found or invalid.
+func ExtractXmlAttribute(tag, attrName string) (string, bool) {
+ attrStart := strings.Index(tag, attrName+"=")
+ if attrStart == -1 {
+ return "", false
+ }
+
+ pos := attrStart + len(attrName+"=")
+ start := strings.Index(tag[pos:], `"`)
+ if start == -1 {
+ return "", false
+ }
+ start += pos
+
+ end := strings.Index(tag[start+1:], `"`)
+ if end == -1 {
+ return "", false
+ }
+ end += start + 1
+
+ quotedValue := tag[start : end+1]
+ value, err := strconv.Unquote(quotedValue)
+ if err != nil {
+ return "", false
+ }
+
+ value = strings.ReplaceAll(value, """, `"`)
+ return value, true
+}
+
+// GenerateDeterministicSuffix creates an 8-character hash from input strings
+func GenerateDeterministicSuffix(inputs ...string) string {
+ hasher := sha256.New()
+ for _, input := range inputs {
+ hasher.Write([]byte(input))
+ }
+ hash := hasher.Sum(nil)
+ return hex.EncodeToString(hash)[:8]
+}
+
+// ExtractImageUrl extracts an image URL from either URL field (http/https/data) or raw Data
+func ExtractImageUrl(data []byte, url, mimeType string) (string, error) {
+ if url != "" {
+ if !strings.HasPrefix(url, "data:") &&
+ !strings.HasPrefix(url, "http://") &&
+ !strings.HasPrefix(url, "https://") {
+ return "", fmt.Errorf("unsupported URL protocol in file part: %s", url)
+ }
+ return url, nil
+ }
+ if len(data) > 0 {
+ base64Data := base64.StdEncoding.EncodeToString(data)
+ return fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data), nil
+ }
+ return "", fmt.Errorf("file part missing both url and data")
+}
+
+// ExtractTextData extracts text data from either Data field or URL field (data: URLs only)
+func ExtractTextData(data []byte, url string) ([]byte, error) {
+ if len(data) > 0 {
+ return data, nil
+ }
+ if url != "" {
+ if strings.HasPrefix(url, "data:") {
+ _, decodedData, err := utilfn.DecodeDataURL(url)
+ if err != nil {
+ return nil, fmt.Errorf("failed to decode data URL for text/plain file: %w", err)
+ }
+ return decodedData, nil
+ }
+ return nil, fmt.Errorf("dropping text/plain file with URL (must be fetched and converted to data)")
+ }
+ return nil, fmt.Errorf("text/plain file part missing data")
+}
+
+// FormatAttachedTextFile formats a text file attachment with proper encoding and deterministic suffix
+func FormatAttachedTextFile(fileName string, textContent []byte) string {
+ if fileName == "" {
+ fileName = "untitled.txt"
+ }
+
+ encodedFileName := strings.ReplaceAll(fileName, `"`, """)
+ quotedFileName := strconv.Quote(encodedFileName)
+
+ textStr := string(textContent)
+ deterministicSuffix := GenerateDeterministicSuffix(textStr, fileName)
+ return fmt.Sprintf("\n%s\n", deterministicSuffix, quotedFileName, textStr, deterministicSuffix)
+}
+
+// FormatAttachedDirectoryListing formats a directory listing attachment with proper encoding and deterministic suffix
+func FormatAttachedDirectoryListing(directoryName, jsonContent string) string {
+ if directoryName == "" {
+ directoryName = "unnamed-directory"
+ }
+
+ encodedDirName := strings.ReplaceAll(directoryName, `"`, """)
+ quotedDirName := strconv.Quote(encodedDirName)
+
+ deterministicSuffix := GenerateDeterministicSuffix(jsonContent, directoryName)
+ return fmt.Sprintf("\n%s\n", deterministicSuffix, quotedDirName, jsonContent, deterministicSuffix)
+}
+
+// ConvertDataUserFile converts OpenAI attached file/directory blocks to UIMessagePart
+// Returns (found, part) where found indicates if the prefix was matched,
+// and part is the converted UIMessagePart (can be nil if parsing failed)
+func ConvertDataUserFile(blockText string) (bool, *uctypes.UIMessagePart) {
+ if strings.HasPrefix(blockText, "' {
+ return true, nil
+ }
+
+ openTag := blockText[:openTagEnd]
+ fileName, ok := ExtractXmlAttribute(openTag, "file_name")
+ if !ok {
+ return true, nil
+ }
+
+ return true, &uctypes.UIMessagePart{
+ Type: "data-userfile",
+ Data: uctypes.UIMessageDataUserFile{
+ FileName: fileName,
+ MimeType: "text/plain",
+ },
+ }
+ }
+
+ if strings.HasPrefix(blockText, "' {
+ return true, nil
+ }
+
+ openTag := blockText[:openTagEnd]
+ directoryName, ok := ExtractXmlAttribute(openTag, "directory_name")
+ if !ok {
+ return true, nil
+ }
+
+ return true, &uctypes.UIMessagePart{
+ Type: "data-userfile",
+ Data: uctypes.UIMessageDataUserFile{
+ FileName: directoryName,
+ MimeType: "directory",
+ },
+ }
+ }
+
+ return false, nil
+}
+
+func JsonEncodeRequestBody(reqBody any) (bytes.Buffer, error) {
+ var buf bytes.Buffer
+ encoder := json.NewEncoder(&buf)
+ encoder.SetEscapeHTML(false)
+ err := encoder.Encode(reqBody)
+ if err != nil {
+ return buf, err
+ }
+ return buf, nil
+}
diff --git a/pkg/aiusechat/anthropic/anthropic-backend.go b/pkg/aiusechat/anthropic/anthropic-backend.go
index 3717e0cc0d..c2eb3a519d 100644
--- a/pkg/aiusechat/anthropic/anthropic-backend.go
+++ b/pkg/aiusechat/anthropic/anthropic-backend.go
@@ -480,16 +480,14 @@ func RunAnthropicChatStep(
if rateLimitInfo.PReq == 0 && rateLimitInfo.Req > 0 {
// Premium requests exhausted, but regular requests available
stopReason := &uctypes.WaveStopReason{
- Kind: uctypes.StopKindPremiumRateLimit,
- RateLimitInfo: rateLimitInfo,
+ Kind: uctypes.StopKindPremiumRateLimit,
}
return stopReason, nil, rateLimitInfo, nil
}
if rateLimitInfo.Req == 0 {
// All requests exhausted
stopReason := &uctypes.WaveStopReason{
- Kind: uctypes.StopKindRateLimit,
- RateLimitInfo: rateLimitInfo,
+ Kind: uctypes.StopKindRateLimit,
}
return stopReason, nil, rateLimitInfo, nil
}
@@ -590,8 +588,6 @@ func handleAnthropicStreamingResp(
rtnStopReason = &uctypes.WaveStopReason{
Kind: uctypes.StopKindDone,
RawReason: state.stopFromDelta,
- MessageID: state.msgID,
- Model: state.model,
}
return rtnStopReason, state.rtnMessage
}
@@ -849,41 +845,30 @@ func handleAnthropicEvent(
switch reason {
case "tool_use":
return nil, &uctypes.WaveStopReason{
- Kind: uctypes.StopKindToolUse,
- RawReason: reason,
- MessageID: state.msgID,
- Model: state.model,
- ToolCalls: state.toolCalls,
- FinishStep: true,
+ Kind: uctypes.StopKindToolUse,
+ RawReason: reason,
+ ToolCalls: state.toolCalls,
}
case "max_tokens":
return nil, &uctypes.WaveStopReason{
Kind: uctypes.StopKindMaxTokens,
RawReason: reason,
- MessageID: state.msgID,
- Model: state.model,
}
case "refusal":
return nil, &uctypes.WaveStopReason{
Kind: uctypes.StopKindContent,
RawReason: reason,
- MessageID: state.msgID,
- Model: state.model,
}
case "pause_turn":
return nil, &uctypes.WaveStopReason{
Kind: uctypes.StopKindPauseTurn,
RawReason: reason,
- MessageID: state.msgID,
- Model: state.model,
}
default:
// end_turn, stop_sequence (treat as end of this call)
return nil, &uctypes.WaveStopReason{
Kind: uctypes.StopKindDone,
RawReason: reason,
- MessageID: state.msgID,
- Model: state.model,
}
}
diff --git a/pkg/aiusechat/openai/openai-backend.go b/pkg/aiusechat/openai/openai-backend.go
index c05745ee6a..f551a76845 100644
--- a/pkg/aiusechat/openai/openai-backend.go
+++ b/pkg/aiusechat/openai/openai-backend.go
@@ -93,7 +93,7 @@ type OpenAIMessageContent struct {
Name string `json:"name,omitempty"`
}
-func (c *OpenAIMessageContent) Clean() *OpenAIMessageContent {
+func (c *OpenAIMessageContent) clean() *OpenAIMessageContent {
if c.PreviewUrl == "" {
return c
}
@@ -102,17 +102,17 @@ func (c *OpenAIMessageContent) Clean() *OpenAIMessageContent {
return &rtn
}
-func (m *OpenAIMessage) CleanAndCopy() *OpenAIMessage {
+func (m *OpenAIMessage) cleanAndCopy() *OpenAIMessage {
rtn := &OpenAIMessage{Role: m.Role}
rtn.Content = make([]OpenAIMessageContent, len(m.Content))
for idx, block := range m.Content {
- cleaned := block.Clean()
+ cleaned := block.clean()
rtn.Content[idx] = *cleaned
}
return rtn
}
-func (f *OpenAIFunctionCallInput) Clean() *OpenAIFunctionCallInput {
+func (f *OpenAIFunctionCallInput) clean() *OpenAIFunctionCallInput {
if f.ToolUseData == nil {
return f
}
@@ -481,10 +481,10 @@ func RunOpenAIChatStep(
// Convert to appropriate input type based on what's populated
if chatMsg.Message != nil {
// Clean message to remove preview URLs
- cleanedMsg := chatMsg.Message.CleanAndCopy()
+ cleanedMsg := chatMsg.Message.cleanAndCopy()
inputs = append(inputs, *cleanedMsg)
} else if chatMsg.FunctionCall != nil {
- cleanedFunctionCall := chatMsg.FunctionCall.Clean()
+ cleanedFunctionCall := chatMsg.FunctionCall.clean()
inputs = append(inputs, *cleanedFunctionCall)
} else if chatMsg.FunctionCallOutput != nil {
inputs = append(inputs, *chatMsg.FunctionCallOutput)
@@ -526,16 +526,14 @@ func RunOpenAIChatStep(
if rateLimitInfo.PReq == 0 && rateLimitInfo.Req > 0 {
// Premium requests exhausted, but regular requests available
stopReason := &uctypes.WaveStopReason{
- Kind: uctypes.StopKindPremiumRateLimit,
- RateLimitInfo: rateLimitInfo,
+ Kind: uctypes.StopKindPremiumRateLimit,
}
return stopReason, nil, rateLimitInfo, nil
}
if rateLimitInfo.Req == 0 {
// All requests exhausted
stopReason := &uctypes.WaveStopReason{
- Kind: uctypes.StopKindRateLimit,
- RateLimitInfo: rateLimitInfo,
+ Kind: uctypes.StopKindRateLimit,
}
return stopReason, nil, rateLimitInfo, nil
}
@@ -797,8 +795,6 @@ func handleOpenAIEvent(
Kind: uctypes.StopKindError,
ErrorType: "api",
ErrorText: errorMsg,
- MessageID: state.msgID,
- Model: state.model,
}, nil
}
@@ -831,8 +827,6 @@ func handleOpenAIEvent(
Kind: stopKind,
RawReason: reason,
ErrorText: errorMsg,
- MessageID: state.msgID,
- Model: state.model,
}, finalMessages
}
@@ -847,8 +841,6 @@ func handleOpenAIEvent(
return &uctypes.WaveStopReason{
Kind: stopKind,
RawReason: ev.Response.Status,
- MessageID: state.msgID,
- Model: state.model,
ToolCalls: toolCalls,
}, finalMessages
@@ -860,22 +852,8 @@ func handleOpenAIEvent(
}
if st := state.blockMap[ev.ItemId]; st != nil && st.kind == openaiBlockToolUse {
st.partialJSON = append(st.partialJSON, []byte(ev.Delta)...)
-
toolDef := state.chatOpts.GetToolDefinition(st.toolName)
- if toolDef != nil && toolDef.ToolProgressDesc != nil {
- parsedJSON, err := utilfn.ParsePartialJson(st.partialJSON)
- if err == nil {
- statusLines, err := toolDef.ToolProgressDesc(parsedJSON)
- if err == nil {
- progressData := &uctypes.UIMessageDataToolProgress{
- ToolCallId: st.toolCallID,
- ToolName: st.toolName,
- StatusLines: statusLines,
- }
- _ = sse.AiMsgData("data-toolprogress", "progress-"+st.toolCallID, progressData)
- }
- }
- }
+ sendToolProgress(st, toolDef, sse, st.partialJSON, true)
}
return nil, nil
@@ -888,28 +866,10 @@ func handleOpenAIEvent(
// Get the function call info from the block state
if st := state.blockMap[ev.ItemId]; st != nil && st.kind == openaiBlockToolUse {
- // raw := json.RawMessage(ev.Arguments)
- // no longer send tool inputs to fe
- // _ = sse.AiMsgToolInputAvailable(st.toolCallID, st.toolName, raw)
-
toolDef := state.chatOpts.GetToolDefinition(st.toolName)
toolUseData := createToolUseData(st.toolCallID, st.toolName, toolDef, ev.Arguments, state.chatOpts)
state.toolUseData[st.toolCallID] = toolUseData
-
- if toolDef != nil && toolDef.ToolProgressDesc != nil {
- var parsedJSON any
- if err := json.Unmarshal([]byte(ev.Arguments), &parsedJSON); err == nil {
- statusLines, err := toolDef.ToolProgressDesc(parsedJSON)
- if err == nil {
- progressData := &uctypes.UIMessageDataToolProgress{
- ToolCallId: st.toolCallID,
- ToolName: st.toolName,
- StatusLines: statusLines,
- }
- _ = sse.AiMsgData("data-toolprogress", "progress-"+st.toolCallID, progressData)
- }
- }
- }
+ sendToolProgress(st, toolDef, sse, []byte(ev.Arguments), false)
}
return nil, nil
@@ -966,6 +926,32 @@ func handleOpenAIEvent(
}
}
+func sendToolProgress(st *openaiBlockState, toolDef *uctypes.ToolDefinition, sse *sse.SSEHandlerCh, jsonData []byte, usePartialParse bool) {
+ if toolDef == nil || toolDef.ToolProgressDesc == nil {
+ return
+ }
+ var parsedJSON any
+ var err error
+ if usePartialParse {
+ parsedJSON, err = utilfn.ParsePartialJson(jsonData)
+ } else {
+ err = json.Unmarshal(jsonData, &parsedJSON)
+ }
+ if err != nil {
+ return
+ }
+ statusLines, err := toolDef.ToolProgressDesc(parsedJSON)
+ if err != nil {
+ return
+ }
+ progressData := &uctypes.UIMessageDataToolProgress{
+ ToolCallId: st.toolCallID,
+ ToolName: st.toolName,
+ StatusLines: statusLines,
+ }
+ _ = sse.AiMsgData("data-toolprogress", "progress-"+st.toolCallID, progressData)
+}
+
func createToolUseData(toolCallID, toolName string, toolDef *uctypes.ToolDefinition, arguments string, chatOpts uctypes.WaveChatOpts) *uctypes.UIMessageDataToolUse {
toolUseData := &uctypes.UIMessageDataToolUse{
ToolCallId: toolCallID,
diff --git a/pkg/aiusechat/openai/openai-convertmessage.go b/pkg/aiusechat/openai/openai-convertmessage.go
index 604284bbb8..d2dc594a21 100644
--- a/pkg/aiusechat/openai/openai-convertmessage.go
+++ b/pkg/aiusechat/openai/openai-convertmessage.go
@@ -4,22 +4,18 @@
package openai
import (
- "bytes"
"context"
- "crypto/sha256"
"encoding/base64"
- "encoding/hex"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
- "strconv"
"strings"
"github.com/google/uuid"
+ "github.com/wavetermdev/waveterm/pkg/aiusechat/aiutil"
"github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes"
- "github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/wavebase"
)
@@ -28,46 +24,46 @@ const (
OpenAIDefaultMaxTokens = 4096
)
-// extractXmlAttribute extracts an attribute value from an XML-like tag.
-// Expects double-quoted strings where internal quotes are encoded as ".
-// Returns the unquoted value and true if found, or empty string and false if not found or invalid.
-func extractXmlAttribute(tag, attrName string) (string, bool) {
- attrStart := strings.Index(tag, attrName+"=")
- if attrStart == -1 {
- return "", false
- }
-
- pos := attrStart + len(attrName+"=")
- start := strings.Index(tag[pos:], `"`)
- if start == -1 {
- return "", false
- }
- start += pos
-
- end := strings.Index(tag[start+1:], `"`)
- if end == -1 {
- return "", false
- }
- end += start + 1
+// convertContentBlockToParts converts a single content block to UIMessageParts
+func convertContentBlockToParts(block OpenAIMessageContent, role string) []uctypes.UIMessagePart {
+ var parts []uctypes.UIMessagePart
- quotedValue := tag[start : end+1]
- value, err := strconv.Unquote(quotedValue)
- if err != nil {
- return "", false
+ switch block.Type {
+ case "input_text", "output_text":
+ if found, part := aiutil.ConvertDataUserFile(block.Text); found {
+ if part != nil {
+ parts = append(parts, *part)
+ }
+ } else {
+ parts = append(parts, uctypes.UIMessagePart{
+ Type: "text",
+ Text: block.Text,
+ })
+ }
+ case "input_image":
+ if role == "user" {
+ parts = append(parts, uctypes.UIMessagePart{
+ Type: "data-userfile",
+ Data: uctypes.UIMessageDataUserFile{
+ MimeType: "image/*",
+ PreviewUrl: block.PreviewUrl,
+ },
+ })
+ }
+ case "input_file":
+ if role == "user" {
+ parts = append(parts, uctypes.UIMessagePart{
+ Type: "data-userfile",
+ Data: uctypes.UIMessageDataUserFile{
+ FileName: block.Filename,
+ MimeType: "application/pdf",
+ PreviewUrl: block.PreviewUrl,
+ },
+ })
+ }
}
- value = strings.ReplaceAll(value, """, `"`)
- return value, true
-}
-
-// generateDeterministicSuffix creates an 8-character hash from input strings
-func generateDeterministicSuffix(inputs ...string) string {
- hasher := sha256.New()
- for _, input := range inputs {
- hasher.Write([]byte(input))
- }
- hash := hasher.Sum(nil)
- return hex.EncodeToString(hash)[:8]
+ return parts
}
// appendToLastUserMessage appends a text block to the last user message in the inputs slice
@@ -146,12 +142,11 @@ type OpenAIRequestTool struct {
// ConvertToolDefinitionToOpenAI converts a generic ToolDefinition to OpenAI format
func ConvertToolDefinitionToOpenAI(tool uctypes.ToolDefinition) OpenAIRequestTool {
- cleanedTool := tool.Clean()
return OpenAIRequestTool{
- Name: cleanedTool.Name,
- Description: cleanedTool.Description,
- Parameters: cleanedTool.InputSchema,
- Strict: cleanedTool.Strict,
+ Name: tool.Name,
+ Description: tool.Description,
+ Parameters: tool.InputSchema,
+ Strict: tool.Strict,
Type: "function",
}
}
@@ -218,14 +213,13 @@ func buildOpenAIHTTPRequest(ctx context.Context, inputs []any, chatOpts uctypes.
maxTokens = OpenAIDefaultMaxTokens
}
+ // injected data
if chatOpts.TabState != "" {
appendToLastUserMessage(inputs, chatOpts.TabState)
}
-
if chatOpts.AppStaticFiles != "" {
appendToLastUserMessage(inputs, "\n"+chatOpts.AppStaticFiles+"\n")
}
-
if chatOpts.AppGoFile != "" {
appendToLastUserMessage(inputs, "\n"+chatOpts.AppGoFile+"\n")
}
@@ -276,29 +270,18 @@ func buildOpenAIHTTPRequest(ctx context.Context, inputs []any, chatOpts uctypes.
}
}
- // Set temperature if provided
- if opts.APIVersion != "" && opts.APIVersion != OpenAIDefaultAPIVersion {
- // Temperature and other parameters could be set here based on config
- // For now, using defaults
- }
-
debugPrintReq(reqBody, endpoint)
// Encode request body
- var buf bytes.Buffer
- encoder := json.NewEncoder(&buf)
- encoder.SetEscapeHTML(false)
- err := encoder.Encode(reqBody)
+ buf, err := aiutil.JsonEncodeRequestBody(reqBody)
if err != nil {
return nil, err
}
-
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, &buf)
if err != nil {
return nil, err
}
-
// Set headers
req.Header.Set("Content-Type", "application/json")
if opts.APIToken != "" {
@@ -309,11 +292,7 @@ func buildOpenAIHTTPRequest(ctx context.Context, inputs []any, chatOpts uctypes.
req.Header.Set("X-Wave-ClientId", chatOpts.ClientId)
}
req.Header.Set("X-Wave-APIType", "openai")
- if chatOpts.BuilderId != "" {
- req.Header.Set("X-Wave-RequestType", "waveapps-builder")
- } else {
- req.Header.Set("X-Wave-RequestType", "waveai")
- }
+ req.Header.Set("X-Wave-RequestType", chatOpts.GetWaveRequestType())
return req, nil
}
@@ -330,23 +309,9 @@ func convertFileAIMessagePart(part uctypes.AIMessagePart) (*OpenAIMessageContent
// Handle different file types
switch {
case strings.HasPrefix(part.MimeType, "image/"):
- // Handle images
- var imageUrl string
-
- if part.URL != "" {
- // Validate URL protocol - only allow data:, http:, https:
- if !strings.HasPrefix(part.URL, "data:") &&
- !strings.HasPrefix(part.URL, "http://") &&
- !strings.HasPrefix(part.URL, "https://") {
- return nil, fmt.Errorf("unsupported URL protocol in file part: %s", part.URL)
- }
- imageUrl = part.URL
- } else if len(part.Data) > 0 {
- // Convert raw data to base64 data URL
- base64Data := base64.StdEncoding.EncodeToString(part.Data)
- imageUrl = fmt.Sprintf("data:%s;base64,%s", part.MimeType, base64Data)
- } else {
- return nil, fmt.Errorf("file part missing both url and data")
+ imageUrl, err := aiutil.ExtractImageUrl(part.Data, part.URL, part.MimeType)
+ if err != nil {
+ return nil, err
}
return &OpenAIMessageContent{
@@ -375,35 +340,11 @@ func convertFileAIMessagePart(part uctypes.AIMessagePart) (*OpenAIMessageContent
}, nil
case part.MimeType == "text/plain":
- var textContent string
-
- if len(part.Data) > 0 {
- textContent = string(part.Data)
- } else if part.URL != "" {
- if strings.HasPrefix(part.URL, "data:") {
- _, decodedData, err := utilfn.DecodeDataURL(part.URL)
- if err != nil {
- return nil, fmt.Errorf("failed to decode data URL for text/plain file: %w", err)
- }
- textContent = string(decodedData)
- } else {
- return nil, fmt.Errorf("dropping text/plain file with URL (must be fetched and converted to data)")
- }
- } else {
- return nil, fmt.Errorf("text/plain file part missing data")
- }
-
- fileName := part.FileName
- if fileName == "" {
- fileName = "untitled.txt"
+ textData, err := aiutil.ExtractTextData(part.Data, part.URL)
+ if err != nil {
+ return nil, err
}
-
- encodedFileName := strings.ReplaceAll(fileName, `"`, """)
- quotedFileName := strconv.Quote(encodedFileName)
-
- deterministicSuffix := generateDeterministicSuffix(textContent, fileName)
- formattedText := fmt.Sprintf("\n%s\n", deterministicSuffix, quotedFileName, textContent, deterministicSuffix)
-
+ formattedText := aiutil.FormatAttachedTextFile(part.FileName, textData)
return &OpenAIMessageContent{
Type: "input_text",
Text: formattedText,
@@ -417,16 +358,7 @@ func convertFileAIMessagePart(part uctypes.AIMessagePart) (*OpenAIMessageContent
return nil, fmt.Errorf("directory listing part missing data")
}
- directoryName := part.FileName
- if directoryName == "" {
- directoryName = "unnamed-directory"
- }
-
- encodedDirName := strings.ReplaceAll(directoryName, `"`, """)
- quotedDirName := strconv.Quote(encodedDirName)
-
- deterministicSuffix := generateDeterministicSuffix(jsonContent, directoryName)
- formattedText := fmt.Sprintf("\n%s\n", deterministicSuffix, quotedDirName, jsonContent, deterministicSuffix)
+ formattedText := aiutil.FormatAttachedDirectoryListing(part.FileName, jsonContent)
return &OpenAIMessageContent{
Type: "input_text",
@@ -540,89 +472,17 @@ func ConvertToolResultsToOpenAIChatMessage(toolResults []uctypes.AIToolResult) (
return messages, nil
}
-// ConvertToUIMessage converts an OpenAIChatMessage to a UIMessage
-func (m *OpenAIChatMessage) ConvertToUIMessage() *uctypes.UIMessage {
+// convertToUIMessage converts an OpenAIChatMessage to a UIMessage
+func (m *OpenAIChatMessage) convertToUIMessage() *uctypes.UIMessage {
var parts []uctypes.UIMessagePart
var role string
// Handle different message types
if m.Message != nil {
role = m.Message.Role
- // Iterate over all content blocks
for _, block := range m.Message.Content {
- switch block.Type {
- case "input_text", "output_text":
- if strings.HasPrefix(block.Text, "' {
- continue
- }
-
- openTag := block.Text[:openTagEnd]
- fileName, ok := extractXmlAttribute(openTag, "file_name")
- if !ok {
- continue
- }
-
- parts = append(parts, uctypes.UIMessagePart{
- Type: "data-userfile",
- Data: uctypes.UIMessageDataUserFile{
- FileName: fileName,
- MimeType: "text/plain",
- },
- })
- } else if strings.HasPrefix(block.Text, "' {
- continue
- }
-
- openTag := block.Text[:openTagEnd]
- directoryName, ok := extractXmlAttribute(openTag, "directory_name")
- if !ok {
- continue
- }
-
- parts = append(parts, uctypes.UIMessagePart{
- Type: "data-userfile",
- Data: uctypes.UIMessageDataUserFile{
- FileName: directoryName,
- MimeType: "directory",
- },
- })
- } else {
- parts = append(parts, uctypes.UIMessagePart{
- Type: "text",
- Text: block.Text,
- })
- }
- case "input_image":
- // Convert image blocks to data-userfile UIMessagePart (only for user role)
- if role == "user" {
- parts = append(parts, uctypes.UIMessagePart{
- Type: "data-userfile",
- Data: uctypes.UIMessageDataUserFile{
- MimeType: "image/*",
- PreviewUrl: block.PreviewUrl,
- },
- })
- }
- case "input_file":
- // Convert file blocks to data-userfile UIMessagePart (only for user role)
- if role == "user" {
- parts = append(parts, uctypes.UIMessagePart{
- Type: "data-userfile",
- Data: uctypes.UIMessageDataUserFile{
- FileName: block.Filename,
- MimeType: "application/pdf",
- PreviewUrl: block.PreviewUrl,
- },
- })
- }
- default:
- // Skip unknown types
- continue
- }
+ blockParts := convertContentBlockToParts(block, role)
+ parts = append(parts, blockParts...)
}
} else if m.FunctionCall != nil {
// Handle function call input
@@ -638,11 +498,9 @@ func (m *OpenAIChatMessage) ConvertToUIMessage() *uctypes.UIMessage {
// FunctionCallOutput messages are not converted to UIMessage
return nil
}
-
if len(parts) == 0 {
return nil
}
-
return &uctypes.UIMessage{
ID: m.MessageId,
Role: role,
@@ -655,21 +513,17 @@ func ConvertAIChatToUIChat(aiChat uctypes.AIChat) (*uctypes.UIChat, error) {
if aiChat.APIType != "openai" {
return nil, fmt.Errorf("APIType must be 'openai', got '%s'", aiChat.APIType)
}
-
uiMessages := make([]uctypes.UIMessage, 0, len(aiChat.NativeMessages))
-
for i, nativeMsg := range aiChat.NativeMessages {
openaiMsg, ok := nativeMsg.(*OpenAIChatMessage)
if !ok {
return nil, fmt.Errorf("message %d: expected *OpenAIChatMessage, got %T", i, nativeMsg)
}
-
- uiMsg := openaiMsg.ConvertToUIMessage()
+ uiMsg := openaiMsg.convertToUIMessage()
if uiMsg != nil {
uiMessages = append(uiMessages, *uiMsg)
}
}
-
return &uctypes.UIChat{
ChatId: aiChat.ChatId,
APIType: aiChat.APIType,
diff --git a/pkg/aiusechat/uctypes/usechat-types.go b/pkg/aiusechat/uctypes/usechat-types.go
index 47ada4c7b8..859b7218ce 100644
--- a/pkg/aiusechat/uctypes/usechat-types.go
+++ b/pkg/aiusechat/uctypes/usechat-types.go
@@ -193,25 +193,15 @@ type WaveToolCall struct {
type WaveStopReason struct {
Kind StopReasonKind `json:"kind"`
RawReason string `json:"raw_reason,omitempty"`
- MessageID string `json:"message_id,omitempty"`
- Model string `json:"model,omitempty"`
-
ToolCalls []WaveToolCall `json:"tool_calls,omitempty"`
-
- ErrorType string `json:"error_type,omitempty"`
- ErrorText string `json:"error_text,omitempty"`
-
- RateLimitInfo *RateLimitInfo `json:"ratelimitinfo,omitempty"` // set when Kind is StopKindPremiumRateLimit or StopKindRateLimit
-
- FinishStep bool `json:"finish_step,omitempty"`
+ ErrorType string `json:"error_type,omitempty"`
+ ErrorText string `json:"error_text,omitempty"`
}
// Wave Specific parameter used to signal to our step function that this is a continuation step, not an initial step
type WaveContinueResponse struct {
- MessageID string `json:"message_id,omitempty"`
- Model string `json:"model,omitempty"`
- ContinueFromKind StopReasonKind `json:"continue_from_kind"`
- ContinueFromRawReason string `json:"continue_from_raw_reason,omitempty"`
+ Model string `json:"model,omitempty"`
+ ContinueFromKind StopReasonKind `json:"continue_from_kind"`
}
// Wave Specific AI opts for configuration
@@ -273,6 +263,13 @@ type AIMetrics struct {
ThinkingMode string `json:"thinkingmode,omitempty"`
}
+type AIFunctionCallInput struct {
+ CallId string `json:"call_id"`
+ Name string `json:"name"`
+ Arguments string `json:"arguments"`
+ ToolUseData *UIMessageDataToolUse `json:"toolusedata,omitempty"`
+}
+
// GenAIMessage interface for messages stored in conversations
// All messages must have a unique identifier for idempotency checks
type GenAIMessage interface {
@@ -478,6 +475,14 @@ func (opts *WaveChatOpts) GetToolDefinition(toolName string) *ToolDefinition {
return nil
}
+func (opts *WaveChatOpts) GetWaveRequestType() string {
+ if opts.BuilderId != "" {
+ return "waveapps-builder"
+ } else {
+ return "waveai"
+ }
+}
+
type ProxyErrorResponse struct {
Success bool `json:"success"`
Error string `json:"error"`
diff --git a/pkg/aiusechat/usechat-backend.go b/pkg/aiusechat/usechat-backend.go
new file mode 100644
index 0000000000..adebb11282
--- /dev/null
+++ b/pkg/aiusechat/usechat-backend.go
@@ -0,0 +1,157 @@
+// Copyright 2025, Command Line Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package aiusechat
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/wavetermdev/waveterm/pkg/aiusechat/anthropic"
+ "github.com/wavetermdev/waveterm/pkg/aiusechat/openai"
+ "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes"
+ "github.com/wavetermdev/waveterm/pkg/web/sse"
+)
+
+// UseChatBackend defines the interface for AI chat backend providers (OpenAI, Anthropic, etc.)
+// This interface abstracts the provider-specific API calls needed by the usechat system.
+type UseChatBackend interface {
+ // RunChatStep executes a single step in the chat conversation with the AI backend.
+ // Returns the stop reason, native messages from the response, rate limit info, and any error.
+ // The cont parameter allows continuing from a previous response (e.g., after rate limiting).
+ RunChatStep(
+ ctx context.Context,
+ sseHandler *sse.SSEHandlerCh,
+ chatOpts uctypes.WaveChatOpts,
+ cont *uctypes.WaveContinueResponse,
+ ) (*uctypes.WaveStopReason, []uctypes.GenAIMessage, *uctypes.RateLimitInfo, error)
+
+ // UpdateToolUseData updates the tool use data for a specific tool call in the chat.
+ // This is used to update the UI state for tool execution (approval status, results, etc.)
+ UpdateToolUseData(chatId string, toolCallId string, toolUseData *uctypes.UIMessageDataToolUse) error
+
+ // ConvertToolResultsToNativeChatMessage converts tool execution results into native chat messages
+ // that can be sent back to the AI backend. Returns a slice of messages (some backends may
+ // require multiple messages per tool result).
+ ConvertToolResultsToNativeChatMessage(toolResults []uctypes.AIToolResult) ([]uctypes.GenAIMessage, error)
+
+ // ConvertAIMessageToNativeChatMessage converts a generic AIMessage (from the user)
+ // into the backend's native message format for sending to the API.
+ ConvertAIMessageToNativeChatMessage(message uctypes.AIMessage) (uctypes.GenAIMessage, error)
+
+ // GetFunctionCallInputByToolCallId retrieves the function call input data for a specific
+ // tool call ID from the chat history. Returns the function call structure
+ // or nil if not found.
+ GetFunctionCallInputByToolCallId(aiChat uctypes.AIChat, toolCallId string) *uctypes.AIFunctionCallInput
+
+ // ConvertAIChatToUIChat converts a stored AIChat (with native backend messages) into
+ // a UI-friendly UIChat format that can be displayed in the frontend.
+ ConvertAIChatToUIChat(aiChat uctypes.AIChat) (*uctypes.UIChat, error)
+}
+
+// Compile-time interface checks
+var _ UseChatBackend = (*openaiResponsesBackend)(nil)
+var _ UseChatBackend = (*anthropicBackend)(nil)
+
+// GetBackendByAPIType returns the appropriate UseChatBackend implementation for the given API type
+func GetBackendByAPIType(apiType string) (UseChatBackend, error) {
+ switch apiType {
+ case APIType_OpenAI:
+ return &openaiResponsesBackend{}, nil
+ case APIType_Anthropic:
+ return &anthropicBackend{}, nil
+ default:
+ return nil, fmt.Errorf("unsupported API type: %s", apiType)
+ }
+}
+
+// openaiResponsesBackend implements UseChatBackend for OpenAI API
+type openaiResponsesBackend struct{}
+
+func (b *openaiResponsesBackend) RunChatStep(
+ ctx context.Context,
+ sseHandler *sse.SSEHandlerCh,
+ chatOpts uctypes.WaveChatOpts,
+ cont *uctypes.WaveContinueResponse,
+) (*uctypes.WaveStopReason, []uctypes.GenAIMessage, *uctypes.RateLimitInfo, error) {
+ stopReason, msgs, rateLimitInfo, err := openai.RunOpenAIChatStep(ctx, sseHandler, chatOpts, cont)
+ var genMsgs []uctypes.GenAIMessage
+ for _, msg := range msgs {
+ genMsgs = append(genMsgs, msg)
+ }
+ return stopReason, genMsgs, rateLimitInfo, err
+}
+
+func (b *openaiResponsesBackend) UpdateToolUseData(chatId string, toolCallId string, toolUseData *uctypes.UIMessageDataToolUse) error {
+ return openai.UpdateToolUseData(chatId, toolCallId, toolUseData)
+}
+
+func (b *openaiResponsesBackend) ConvertToolResultsToNativeChatMessage(toolResults []uctypes.AIToolResult) ([]uctypes.GenAIMessage, error) {
+ msgs, err := openai.ConvertToolResultsToOpenAIChatMessage(toolResults)
+ if err != nil {
+ return nil, err
+ }
+ var genMsgs []uctypes.GenAIMessage
+ for _, msg := range msgs {
+ genMsgs = append(genMsgs, msg)
+ }
+ return genMsgs, nil
+}
+
+func (b *openaiResponsesBackend) ConvertAIMessageToNativeChatMessage(message uctypes.AIMessage) (uctypes.GenAIMessage, error) {
+ return openai.ConvertAIMessageToOpenAIChatMessage(message)
+}
+
+func (b *openaiResponsesBackend) GetFunctionCallInputByToolCallId(aiChat uctypes.AIChat, toolCallId string) *uctypes.AIFunctionCallInput {
+ openaiInput := openai.GetFunctionCallInputByToolCallId(aiChat, toolCallId)
+ if openaiInput == nil {
+ return nil
+ }
+ return &uctypes.AIFunctionCallInput{
+ CallId: openaiInput.CallId,
+ Name: openaiInput.Name,
+ Arguments: openaiInput.Arguments,
+ ToolUseData: openaiInput.ToolUseData,
+ }
+}
+
+func (b *openaiResponsesBackend) ConvertAIChatToUIChat(aiChat uctypes.AIChat) (*uctypes.UIChat, error) {
+ return openai.ConvertAIChatToUIChat(aiChat)
+}
+
+// anthropicBackend implements UseChatBackend for Anthropic API
+type anthropicBackend struct{}
+
+func (b *anthropicBackend) RunChatStep(
+ ctx context.Context,
+ sseHandler *sse.SSEHandlerCh,
+ chatOpts uctypes.WaveChatOpts,
+ cont *uctypes.WaveContinueResponse,
+) (*uctypes.WaveStopReason, []uctypes.GenAIMessage, *uctypes.RateLimitInfo, error) {
+ stopReason, msg, rateLimitInfo, err := anthropic.RunAnthropicChatStep(ctx, sseHandler, chatOpts, cont)
+ return stopReason, []uctypes.GenAIMessage{msg}, rateLimitInfo, err
+}
+
+func (b *anthropicBackend) UpdateToolUseData(chatId string, toolCallId string, toolUseData *uctypes.UIMessageDataToolUse) error {
+ return fmt.Errorf("UpdateToolUseData not implemented for anthropic backend")
+}
+
+func (b *anthropicBackend) ConvertToolResultsToNativeChatMessage(toolResults []uctypes.AIToolResult) ([]uctypes.GenAIMessage, error) {
+ msg, err := anthropic.ConvertToolResultsToAnthropicChatMessage(toolResults)
+ if err != nil {
+ return nil, err
+ }
+ return []uctypes.GenAIMessage{msg}, nil
+}
+
+func (b *anthropicBackend) ConvertAIMessageToNativeChatMessage(message uctypes.AIMessage) (uctypes.GenAIMessage, error) {
+ return anthropic.ConvertAIMessageToAnthropicChatMessage(message)
+}
+
+func (b *anthropicBackend) GetFunctionCallInputByToolCallId(aiChat uctypes.AIChat, toolCallId string) *uctypes.AIFunctionCallInput {
+ return nil
+}
+
+func (b *anthropicBackend) ConvertAIChatToUIChat(aiChat uctypes.AIChat) (*uctypes.UIChat, error) {
+ return anthropic.ConvertAIChatToUIChat(aiChat)
+}
diff --git a/pkg/aiusechat/usechat-utils.go b/pkg/aiusechat/usechat-utils.go
index 0b71fb1a56..72a5948d2b 100644
--- a/pkg/aiusechat/usechat-utils.go
+++ b/pkg/aiusechat/usechat-utils.go
@@ -4,10 +4,6 @@
package aiusechat
import (
- "fmt"
-
- "github.com/wavetermdev/waveterm/pkg/aiusechat/anthropic"
- "github.com/wavetermdev/waveterm/pkg/aiusechat/openai"
"github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes"
)
@@ -75,18 +71,12 @@ func ConvertAIChatToUIChat(aiChat *uctypes.AIChat) (*uctypes.UIChat, error) {
return nil, nil
}
- var uiChat *uctypes.UIChat
- var err error
-
- switch aiChat.APIType {
- case "openai":
- uiChat, err = openai.ConvertAIChatToUIChat(*aiChat)
- case "anthropic":
- uiChat, err = anthropic.ConvertAIChatToUIChat(*aiChat)
- default:
- return nil, fmt.Errorf("unsupported APIType: %s", aiChat.APIType)
+ backend, err := GetBackendByAPIType(aiChat.APIType)
+ if err != nil {
+ return nil, err
}
+ uiChat, err := backend.ConvertAIChatToUIChat(*aiChat)
if err != nil {
return nil, err
}
diff --git a/pkg/aiusechat/usechat.go b/pkg/aiusechat/usechat.go
index aa7137e72a..d15872c66c 100644
--- a/pkg/aiusechat/usechat.go
+++ b/pkg/aiusechat/usechat.go
@@ -16,9 +16,7 @@ import (
"time"
"github.com/google/uuid"
- "github.com/wavetermdev/waveterm/pkg/aiusechat/anthropic"
"github.com/wavetermdev/waveterm/pkg/aiusechat/chatstore"
- "github.com/wavetermdev/waveterm/pkg/aiusechat/openai"
"github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes"
"github.com/wavetermdev/waveterm/pkg/telemetry"
"github.com/wavetermdev/waveterm/pkg/telemetry/telemetrydata"
@@ -203,25 +201,13 @@ func GetGlobalRateLimit() *uctypes.RateLimitInfo {
return globalRateLimitInfo
}
-func runAIChatStep(ctx context.Context, sseHandler *sse.SSEHandlerCh, chatOpts uctypes.WaveChatOpts, cont *uctypes.WaveContinueResponse) (*uctypes.WaveStopReason, []uctypes.GenAIMessage, error) {
- if chatOpts.Config.APIType == APIType_Anthropic {
- stopReason, msg, rateLimitInfo, err := anthropic.RunAnthropicChatStep(ctx, sseHandler, chatOpts, cont)
- updateRateLimit(rateLimitInfo)
- return stopReason, []uctypes.GenAIMessage{msg}, err
+func runAIChatStep(ctx context.Context, sseHandler *sse.SSEHandlerCh, backend UseChatBackend, chatOpts uctypes.WaveChatOpts, cont *uctypes.WaveContinueResponse) (*uctypes.WaveStopReason, []uctypes.GenAIMessage, error) {
+ if chatOpts.Config.APIType == APIType_OpenAI && shouldUseChatCompletionsAPI(chatOpts.Config.Model) {
+ return nil, nil, fmt.Errorf("Chat completions API not available (must use newer OpenAI models)")
}
- if chatOpts.Config.APIType == APIType_OpenAI {
- if shouldUseChatCompletionsAPI(chatOpts.Config.Model) {
- return nil, nil, fmt.Errorf("Chat completions API not available (must use newer OpenAI models)")
- }
- stopReason, msgs, rateLimitInfo, err := openai.RunOpenAIChatStep(ctx, sseHandler, chatOpts, cont)
- updateRateLimit(rateLimitInfo)
- var messages []uctypes.GenAIMessage
- for _, msg := range msgs {
- messages = append(messages, msg)
- }
- return stopReason, messages, err
- }
- return nil, nil, fmt.Errorf("Invalid APIType %q", chatOpts.Config.APIType)
+ stopReason, messages, rateLimitInfo, err := backend.RunChatStep(ctx, sseHandler, chatOpts, cont)
+ updateRateLimit(rateLimitInfo)
+ return stopReason, messages, err
}
func getUsage(msgs []uctypes.GenAIMessage) uctypes.AIUsage {
@@ -249,17 +235,13 @@ func GetChatUsage(chat *uctypes.AIChat) uctypes.AIUsage {
return usage
}
-func updateToolUseDataInChat(chatOpts uctypes.WaveChatOpts, toolCallID string, toolUseData *uctypes.UIMessageDataToolUse) {
- if chatOpts.Config.APIType == APIType_OpenAI {
- if err := openai.UpdateToolUseData(chatOpts.ChatId, toolCallID, toolUseData); err != nil {
- log.Printf("failed to update tool use data in chat: %v\n", err)
- }
- } else if chatOpts.Config.APIType == APIType_Anthropic {
- log.Printf("warning: UpdateToolUseData not implemented for anthropic\n")
+func updateToolUseDataInChat(backend UseChatBackend, chatOpts uctypes.WaveChatOpts, toolCallID string, toolUseData *uctypes.UIMessageDataToolUse) {
+ if err := backend.UpdateToolUseData(chatOpts.ChatId, toolCallID, toolUseData); err != nil {
+ log.Printf("failed to update tool use data in chat: %v\n", err)
}
}
-func processToolCallInternal(toolCall uctypes.WaveToolCall, chatOpts uctypes.WaveChatOpts, toolDef *uctypes.ToolDefinition, sseHandler *sse.SSEHandlerCh) uctypes.AIToolResult {
+func processToolCallInternal(backend UseChatBackend, toolCall uctypes.WaveToolCall, chatOpts uctypes.WaveChatOpts, toolDef *uctypes.ToolDefinition, sseHandler *sse.SSEHandlerCh) uctypes.AIToolResult {
if toolCall.ToolUseData == nil {
return uctypes.AIToolResult{
ToolName: toolCall.Name,
@@ -293,7 +275,7 @@ func processToolCallInternal(toolCall uctypes.WaveToolCall, chatOpts uctypes.Wav
}
// ToolVerifyInput can modify the toolusedata. re-send it here.
_ = sseHandler.AiMsgData("data-tooluse", toolCall.ID, *toolCall.ToolUseData)
- updateToolUseDataInChat(chatOpts, toolCall.ID, toolCall.ToolUseData)
+ updateToolUseDataInChat(backend, chatOpts, toolCall.ID, toolCall.ToolUseData)
}
if toolCall.ToolUseData.Approval == uctypes.ApprovalNeedsApproval {
@@ -322,7 +304,7 @@ func processToolCallInternal(toolCall uctypes.WaveToolCall, chatOpts uctypes.Wav
// this still happens here because we need to update the FE to say the tool call was approved
_ = sseHandler.AiMsgData("data-tooluse", toolCall.ID, *toolCall.ToolUseData)
- updateToolUseDataInChat(chatOpts, toolCall.ID, toolCall.ToolUseData)
+ updateToolUseDataInChat(backend, chatOpts, toolCall.ID, toolCall.ToolUseData)
}
toolCall.ToolUseData.RunTs = time.Now().UnixMilli()
@@ -338,12 +320,12 @@ func processToolCallInternal(toolCall uctypes.WaveToolCall, chatOpts uctypes.Wav
return result
}
-func processToolCall(toolCall uctypes.WaveToolCall, chatOpts uctypes.WaveChatOpts, sseHandler *sse.SSEHandlerCh, metrics *uctypes.AIMetrics) uctypes.AIToolResult {
+func processToolCall(backend UseChatBackend, toolCall uctypes.WaveToolCall, chatOpts uctypes.WaveChatOpts, sseHandler *sse.SSEHandlerCh, metrics *uctypes.AIMetrics) uctypes.AIToolResult {
inputJSON, _ := json.Marshal(toolCall.Input)
logutil.DevPrintf("TOOLUSE name=%s id=%s input=%s approval=%q\n", toolCall.Name, toolCall.ID, utilfn.TruncateString(string(inputJSON), 40), toolCall.ToolUseData.Approval)
toolDef := chatOpts.GetToolDefinition(toolCall.Name)
- result := processToolCallInternal(toolCall, chatOpts, toolDef, sseHandler)
+ result := processToolCallInternal(backend, toolCall, chatOpts, toolDef, sseHandler)
if result.ErrorText != "" {
log.Printf(" error=%s\n", result.ErrorText)
@@ -358,13 +340,13 @@ func processToolCall(toolCall uctypes.WaveToolCall, chatOpts uctypes.WaveChatOpt
if toolCall.ToolUseData != nil {
_ = sseHandler.AiMsgData("data-tooluse", toolCall.ID, *toolCall.ToolUseData)
- updateToolUseDataInChat(chatOpts, toolCall.ID, toolCall.ToolUseData)
+ updateToolUseDataInChat(backend, chatOpts, toolCall.ID, toolCall.ToolUseData)
}
return result
}
-func processToolCalls(stopReason *uctypes.WaveStopReason, chatOpts uctypes.WaveChatOpts, sseHandler *sse.SSEHandlerCh, metrics *uctypes.AIMetrics) {
+func processToolCalls(backend UseChatBackend, stopReason *uctypes.WaveStopReason, chatOpts uctypes.WaveChatOpts, sseHandler *sse.SSEHandlerCh, metrics *uctypes.AIMetrics) {
for _, toolCall := range stopReason.ToolCalls {
activeToolMap.Set(toolCall.ID, true)
defer activeToolMap.Delete(toolCall.ID)
@@ -375,7 +357,7 @@ func processToolCalls(stopReason *uctypes.WaveStopReason, chatOpts uctypes.WaveC
if toolCall.ToolUseData != nil {
log.Printf("AI data-tooluse %s\n", toolCall.ID)
_ = sseHandler.AiMsgData("data-tooluse", toolCall.ID, *toolCall.ToolUseData)
- updateToolUseDataInChat(chatOpts, toolCall.ID, toolCall.ToolUseData)
+ updateToolUseDataInChat(backend, chatOpts, toolCall.ID, toolCall.ToolUseData)
if toolCall.ToolUseData.Approval == uctypes.ApprovalNeedsApproval && chatOpts.RegisterToolApproval != nil {
chatOpts.RegisterToolApproval(toolCall.ID)
}
@@ -384,30 +366,21 @@ func processToolCalls(stopReason *uctypes.WaveStopReason, chatOpts uctypes.WaveC
var toolResults []uctypes.AIToolResult
for _, toolCall := range stopReason.ToolCalls {
- result := processToolCall(toolCall, chatOpts, sseHandler, metrics)
+ result := processToolCall(backend, toolCall, chatOpts, sseHandler, metrics)
toolResults = append(toolResults, result)
}
- if chatOpts.Config.APIType == APIType_OpenAI {
- toolResultMsgs, err := openai.ConvertToolResultsToOpenAIChatMessage(toolResults)
- if err != nil {
- log.Printf("Failed to convert tool results to OpenAI messages: %v", err)
- } else {
- for _, msg := range toolResultMsgs {
- chatstore.DefaultChatStore.PostMessage(chatOpts.ChatId, &chatOpts.Config, msg)
- }
- }
+ toolResultMsgs, err := backend.ConvertToolResultsToNativeChatMessage(toolResults)
+ if err != nil {
+ log.Printf("Failed to convert tool results to native chat messages: %v", err)
} else {
- toolResultMsg, err := anthropic.ConvertToolResultsToAnthropicChatMessage(toolResults)
- if err != nil {
- log.Printf("Failed to convert tool results to Anthropic message: %v", err)
- } else {
- chatstore.DefaultChatStore.PostMessage(chatOpts.ChatId, &chatOpts.Config, toolResultMsg)
+ for _, msg := range toolResultMsgs {
+ chatstore.DefaultChatStore.PostMessage(chatOpts.ChatId, &chatOpts.Config, msg)
}
}
}
-func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, chatOpts uctypes.WaveChatOpts) (*uctypes.AIMetrics, error) {
+func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, backend UseChatBackend, chatOpts uctypes.WaveChatOpts) (*uctypes.AIMetrics, error) {
if !activeChats.SetUnless(chatOpts.ChatId, true) {
return nil, fmt.Errorf("chat %s is already running", chatOpts.ChatId)
}
@@ -441,7 +414,7 @@ func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, chatOpts uctyp
chatOpts.AppStaticFiles = appStaticFiles
}
}
- stopReason, rtnMessage, err := runAIChatStep(ctx, sseHandler, chatOpts, cont)
+ stopReason, rtnMessage, err := runAIChatStep(ctx, sseHandler, backend, chatOpts, cont)
metrics.RequestCount++
if chatOpts.Config.IsPremiumModel() {
metrics.PremiumReqCount++
@@ -474,30 +447,21 @@ func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, chatOpts uctyp
chatstore.DefaultChatStore.PostMessage(chatOpts.ChatId, &chatOpts.Config, msg)
}
}
+ firstStep = false
if stopReason != nil && stopReason.Kind == uctypes.StopKindPremiumRateLimit && chatOpts.Config.APIType == APIType_OpenAI && chatOpts.Config.Model == uctypes.PremiumOpenAIModel {
log.Printf("Premium rate limit hit with gpt-5.1, switching to gpt-5-mini\n")
cont = &uctypes.WaveContinueResponse{
- MessageID: "",
- Model: uctypes.DefaultOpenAIModel,
- ContinueFromKind: uctypes.StopKindPremiumRateLimit,
- ContinueFromRawReason: stopReason.RawReason,
+ Model: uctypes.DefaultOpenAIModel,
+ ContinueFromKind: uctypes.StopKindPremiumRateLimit,
}
- firstStep = false
continue
}
if stopReason != nil && stopReason.Kind == uctypes.StopKindToolUse {
metrics.ToolUseCount += len(stopReason.ToolCalls)
- processToolCalls(stopReason, chatOpts, sseHandler, metrics)
-
- var messageID string
- if len(rtnMessage) > 0 && rtnMessage[0] != nil {
- messageID = rtnMessage[0].GetMessageId()
- }
+ processToolCalls(backend, stopReason, chatOpts, sseHandler, metrics)
cont = &uctypes.WaveContinueResponse{
- MessageID: messageID,
- Model: chatOpts.Config.Model,
- ContinueFromKind: uctypes.StopKindToolUse,
- ContinueFromRawReason: stopReason.RawReason,
+ Model: chatOpts.Config.Model,
+ ContinueFromKind: uctypes.StopKindToolUse,
}
continue
}
@@ -563,22 +527,14 @@ func ResolveToolCall(toolDef *uctypes.ToolDefinition, toolCall uctypes.WaveToolC
func WaveAIPostMessageWrap(ctx context.Context, sseHandler *sse.SSEHandlerCh, message *uctypes.AIMessage, chatOpts uctypes.WaveChatOpts) error {
startTime := time.Now()
- // Convert AIMessage to Anthropic chat message
- var convertedMessage uctypes.GenAIMessage
- if chatOpts.Config.APIType == APIType_Anthropic {
- var err error
- convertedMessage, err = anthropic.ConvertAIMessageToAnthropicChatMessage(*message)
- if err != nil {
- return fmt.Errorf("message conversion failed: %w", err)
- }
- } else if chatOpts.Config.APIType == APIType_OpenAI {
- var err error
- convertedMessage, err = openai.ConvertAIMessageToOpenAIChatMessage(*message)
- if err != nil {
- return fmt.Errorf("message conversion failed: %w", err)
- }
- } else {
- return fmt.Errorf("unsupported APIType %q", chatOpts.Config.APIType)
+ // Convert AIMessage to native chat message using backend
+ backend, err := GetBackendByAPIType(chatOpts.Config.APIType)
+ if err != nil {
+ return err
+ }
+ convertedMessage, err := backend.ConvertAIMessageToNativeChatMessage(*message)
+ if err != nil {
+ return fmt.Errorf("message conversion failed: %w", err)
}
// Post message to chat store
@@ -586,7 +542,7 @@ func WaveAIPostMessageWrap(ctx context.Context, sseHandler *sse.SSEHandlerCh, me
return fmt.Errorf("failed to store message: %w", err)
}
- metrics, err := RunAIChat(ctx, sseHandler, chatOpts)
+ metrics, err := RunAIChat(ctx, sseHandler, backend, chatOpts)
if metrics != nil {
metrics.RequestDuration = int(time.Since(startTime).Milliseconds())
for _, part := range message.Parts {
@@ -803,15 +759,12 @@ func CreateWriteTextFileDiff(ctx context.Context, chatId string, toolCallId stri
return nil, nil, fmt.Errorf("chat not found: %s", chatId)
}
- if aiChat.APIType == APIType_Anthropic {
- return nil, nil, fmt.Errorf("CreateWriteTextFileDiff is not implemented for Anthropic")
- }
-
- if aiChat.APIType != APIType_OpenAI {
- return nil, nil, fmt.Errorf("unsupported API type: %s", aiChat.APIType)
+ backend, err := GetBackendByAPIType(aiChat.APIType)
+ if err != nil {
+ return nil, nil, err
}
- funcCallInput := openai.GetFunctionCallInputByToolCallId(*aiChat, toolCallId)
+ funcCallInput := backend.GetFunctionCallInputByToolCallId(*aiChat, toolCallId)
if funcCallInput == nil {
return nil, nil, fmt.Errorf("tool call not found: %s", toolCallId)
}
@@ -865,7 +818,6 @@ func CreateWriteTextFileDiff(ctx context.Context, chatId string, toolCallId stri
return originalContent, modifiedContent, nil
}
-
type StaticFileInfo struct {
Name string `json:"name"`
Size int64 `json:"size"`
@@ -879,7 +831,7 @@ func generateBuilderAppData(appId string) (string, string, error) {
if err == nil {
appGoFile = string(fileData.Contents)
}
-
+
staticFilesJSON := ""
allFiles, err := waveappstore.ListAllAppFiles(appId)
if err == nil {
@@ -894,7 +846,7 @@ func generateBuilderAppData(appId string) (string, string, error) {
})
}
}
-
+
if len(staticFiles) > 0 {
staticFilesBytes, marshalErr := json.Marshal(staticFiles)
if marshalErr == nil {
@@ -902,6 +854,6 @@ func generateBuilderAppData(appId string) (string, string, error) {
}
}
}
-
+
return appGoFile, staticFilesJSON, nil
}