diff --git a/run.go b/run.go index 5d852e1b..0463615b 100644 --- a/run.go +++ b/run.go @@ -495,12 +495,12 @@ func (c *Client) CreateThreadAndStream( req.Header.Set("Cache-Control", "no-cache") req.Header.Set("Connection", "keep-alive") - resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() + resp, err := c.config.HTTPClient.Do(req) if err != nil { return } - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { resp.Body.Close() return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) } diff --git a/sse.go b/sse.go index 076650b8..fe5a5c5f 100644 --- a/sse.go +++ b/sse.go @@ -7,7 +7,7 @@ import ( "strings" ) -// NewEOLSplitterFunc returns a bufio.SplitFunc tied to a new EOLSplitter instance +// NewEOLSplitterFunc returns a bufio.SplitFunc tied to a new EOLSplitter instance. func NewEOLSplitterFunc() bufio.SplitFunc { splitter := NewEOLSplitter() return splitter.Split @@ -23,6 +23,8 @@ func NewEOLSplitter() *EOLSplitter { return &EOLSplitter{prevCR: false} } +const crlfLen = 2 + // Split function to handle CR LF, CR, and LF as end-of-line. func (s *EOLSplitter) Split(data []byte, atEOF bool) (advance int, token []byte, err error) { // Check if the previous data ended with a CR @@ -38,7 +40,7 @@ func (s *EOLSplitter) Split(data []byte, atEOF bool) (advance int, token []byte, if data[i] == '\r' { if i+1 < len(data) && data[i+1] == '\n' { // Found CR LF - return i + 2, data[:i], nil + return i + crlfLen, data[:i], nil } // Found CR if !atEOF && i == len(data)-1 { @@ -119,29 +121,27 @@ func (s *SSEScanner) Next() bool { } seenNonEmptyLine = true - - if strings.HasPrefix(line, "id: ") { + switch { + case strings.HasPrefix(line, "id: "): event.ID = strings.TrimPrefix(line, "id: ") - } else if strings.HasPrefix(line, "data: ") { + case strings.HasPrefix(line, "data: "): dataLines = append(dataLines, strings.TrimPrefix(line, "data: ")) - } else if strings.HasPrefix(line, "event: ") { + case strings.HasPrefix(line, "event: "): event.Event = strings.TrimPrefix(line, "event: ") - } else if strings.HasPrefix(line, "retry: ") { + case strings.HasPrefix(line, "retry: "): retry, err := strconv.Atoi(strings.TrimPrefix(line, "retry: ")) if err == nil { event.Retry = retry } - // ignore invalid retry values - } else if strings.HasPrefix(line, ":") { + case strings.HasPrefix(line, ":"): if s.readComment { event.Comment = strings.TrimPrefix(line, ":") } - // ignore comment line + default: + // ignore unknown lines } - - // ignore unknown lines } s.err = s.scanner.Err() diff --git a/sse_test.go b/sse_test.go index 71fcb3ce..73c458d4 100644 --- a/sse_test.go +++ b/sse_test.go @@ -1,4 +1,4 @@ -package openai +package openai_test import ( "bufio" @@ -6,6 +6,8 @@ import ( "reflect" "strings" "testing" + + "github.com/sashabaranov/go-openai" ) // ChunksReader simulates a reader that splits the input across multiple reads. @@ -55,7 +57,7 @@ func TestEolSplitter(t *testing.T) { t.Run(test.name, func(t *testing.T) { reader := strings.NewReader(test.input) scanner := bufio.NewScanner(reader) - scanner.Split(NewEOLSplitterFunc()) + scanner.Split(openai.NewEOLSplitterFunc()) var lines []string for scanner.Scan() { @@ -97,7 +99,7 @@ func TestEolSplitterBoundaryCondition(t *testing.T) { // Custom reader to simulate the boundary condition reader := NewChunksReader(c.input) scanner := bufio.NewScanner(reader) - scanner.Split(NewEOLSplitterFunc()) + scanner.Split(openai.NewEOLSplitterFunc()) var lines []string for scanner.Scan() { @@ -124,11 +126,11 @@ func TestEolSplitterBoundaryCondition(t *testing.T) { func TestSSEScanner(t *testing.T) { tests := []struct { raw string - want []ServerSentEvent + want []openai.ServerSentEvent }{ { raw: `data: hello world`, - want: []ServerSentEvent{ + want: []openai.ServerSentEvent{ { Data: "hello world", }, @@ -137,7 +139,7 @@ func TestSSEScanner(t *testing.T) { { raw: `event: hello data: hello world`, - want: []ServerSentEvent{ + want: []openai.ServerSentEvent{ { Event: "hello", Data: "hello world", @@ -150,7 +152,7 @@ data: { data: "msg": "hello world", data: "id": 12345 data: }`, - want: []ServerSentEvent{ + want: []openai.ServerSentEvent{ { Event: "hello-json", Data: "{\n\"msg\": \"hello world\",\n\"id\": 12345\n}", @@ -161,7 +163,7 @@ data: }`, raw: `data: hello world data: hello again`, - want: []ServerSentEvent{ + want: []openai.ServerSentEvent{ { Data: "hello world", }, @@ -173,7 +175,7 @@ data: hello again`, { raw: `retry: 10000 data: hello world`, - want: []ServerSentEvent{ + want: []openai.ServerSentEvent{ { Retry: 10000, Data: "hello world", @@ -184,7 +186,7 @@ data: hello again`, raw: `retry: 10000 retry: 20000`, - want: []ServerSentEvent{ + want: []openai.ServerSentEvent{ { Retry: 10000, }, @@ -200,7 +202,7 @@ id: message-id retry: 20000 event: hello-event data: hello`, - want: []ServerSentEvent{ + want: []openai.ServerSentEvent{ { ID: "message-id", Retry: 20000, @@ -222,7 +224,7 @@ id: message 2 retry: 20000 event: hello-event 2 `, - want: []ServerSentEvent{ + want: []openai.ServerSentEvent{ { ID: "message 1", Retry: 10000, @@ -254,10 +256,10 @@ event: hello-event 2 } } -func runSSEScanTest(t *testing.T, raw string, want []ServerSentEvent) { - sseScanner := NewSSEScanner(strings.NewReader(raw), false) +func runSSEScanTest(t *testing.T, raw string, want []openai.ServerSentEvent) { + sseScanner := openai.NewSSEScanner(strings.NewReader(raw), false) - var got []ServerSentEvent + var got []openai.ServerSentEvent for sseScanner.Next() { got = append(got, sseScanner.Scan()) } diff --git a/stream_v2.go b/stream_v2.go index e15f9498..8ff7362e 100644 --- a/stream_v2.go +++ b/stream_v2.go @@ -13,7 +13,6 @@ type StreamRawEvent struct { type StreamDone struct { } -// Define StreamThreadMessageDelta type StreamThreadMessageDelta struct { ID string `json:"id"` Object string `json:"object"` @@ -75,7 +74,7 @@ type StreamerV2 struct { buffer []byte } -// Close closes the underlying io.ReadCloser +// Close closes the underlying io.ReadCloser. func (s *StreamerV2) Close() error { return s.r.Close() } @@ -106,30 +105,30 @@ func (s *StreamerV2) Next() bool { return true } -// Read implements io.Reader of the text deltas of thread.message.delta events -func (r *StreamerV2) Read(p []byte) (int, error) { +// Read implements io.Reader of the text deltas of thread.message.delta events. +func (s *StreamerV2) Read(p []byte) (int, error) { // If we have data in the buffer, copy it to p first. - if len(r.buffer) > 0 { - n := copy(p, r.buffer) - r.buffer = r.buffer[n:] + if len(s.buffer) > 0 { + n := copy(p, s.buffer) + s.buffer = s.buffer[n:] return n, nil } - for r.Next() { + for s.Next() { // Read only text deltas - text, ok := r.MessageDeltaText() + text, ok := s.MessageDeltaText() if !ok { continue } - r.buffer = []byte(text) - n := copy(p, r.buffer) - r.buffer = r.buffer[n:] + s.buffer = []byte(text) + n := copy(p, s.buffer) + s.buffer = s.buffer[n:] return n, nil } // Check for streamer error - if err := r.Err(); err != nil { + if err := s.Err(); err != nil { return 0, err } @@ -145,7 +144,7 @@ func (s *StreamerV2) Text() (string, bool) { return s.MessageDeltaText() } -// MessageDeltaText returns text delta if the current event is a "thread.message.delta" +// MessageDeltaText returns text delta if the current event is a "thread.message.delta". func (s *StreamerV2) MessageDeltaText() (string, bool) { event, ok := s.next.(StreamThreadMessageDelta) if !ok { @@ -157,7 +156,7 @@ func (s *StreamerV2) MessageDeltaText() (string, bool) { if content.Text != nil { // Can we return the first text we find? Does OpenAI stream ever // return multiple text contents in a delta? - text = text + content.Text.Value + text += content.Text.Value } } diff --git a/stream_v2_test.go b/stream_v2_test.go index a92f793b..3ca92f3a 100644 --- a/stream_v2_test.go +++ b/stream_v2_test.go @@ -1,4 +1,5 @@ -package openai +//nolint:lll +package openai_test import ( "encoding/json" @@ -6,6 +7,8 @@ import ( "reflect" "strings" "testing" + + "github.com/sashabaranov/go-openai" ) func TestNewStreamTextReader(t *testing.T) { @@ -19,7 +22,7 @@ data: {"id":"msg_KFiZxHhXYQo6cGFnGjRDHSee","object":"thread.message.delta","delt event: done data: [DONE] ` - reader := NewStreamerV2(strings.NewReader(raw)) + reader := openai.NewStreamerV2(strings.NewReader(raw)) expected := "helloworld" buffer := make([]byte, len(expected)) @@ -65,7 +68,7 @@ event: done data: [DONE] ` - scanner := NewStreamerV2(strings.NewReader(raw)) + scanner := openai.NewStreamerV2(strings.NewReader(raw)) var events []any for scanner.Next() { @@ -74,26 +77,26 @@ data: [DONE] } expectedValues := []any{ - StreamRawEvent{ + openai.StreamRawEvent{ Type: "thread.created", Data: json.RawMessage(`{"id":"thread_vMWb8sJ14upXpPO2VbRpGTYD","object":"thread","created_at":1715864046,"metadata":{},"tool_resources":{"code_interpreter":{"file_ids":[]}}}`), }, - StreamThreadMessageDelta{ + openai.StreamThreadMessageDelta{ ID: "msg_KFiZxHhXYQo6cGFnGjRDHSee", Object: "thread.message.delta", - Delta: Delta{ - Content: []DeltaContent{ + Delta: openai.Delta{ + Content: []openai.DeltaContent{ { Index: 0, Type: "text", - Text: &DeltaText{ + Text: &openai.DeltaText{ Value: "hello", }, }, }, }, }, - StreamDone{}, + openai.StreamDone{}, } if len(events) != len(expectedValues) { @@ -119,25 +122,25 @@ func TestStreamThreadMessageDeltaJSON(t *testing.T) { name: "DeltaContent with Text", jsonData: `{"index":0,"type":"text","text":{"value":"hello"}}`, expectType: "text", - expectValue: &DeltaText{Value: "hello"}, + expectValue: &openai.DeltaText{Value: "hello"}, }, { name: "DeltaContent with ImageFile", jsonData: `{"index":1,"type":"image_file","image_file":{"file_id":"file123","detail":"An image"}}`, expectType: "image_file", - expectValue: &DeltaImageFile{FileID: "file123", Detail: "An image"}, + expectValue: &openai.DeltaImageFile{FileID: "file123", Detail: "An image"}, }, { name: "DeltaContent with ImageURL", jsonData: `{"index":2,"type":"image_url","image_url":{"url":"https://example.com/image.jpg","detail":"low"}}`, expectType: "image_url", - expectValue: &DeltaImageURL{URL: "https://example.com/image.jpg", Detail: "low"}, + expectValue: &openai.DeltaImageURL{URL: "https://example.com/image.jpg", Detail: "low"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var content DeltaContent + var content openai.DeltaContent err := json.Unmarshal([]byte(tt.jsonData), &content) if err != nil { t.Fatalf("Error unmarshalling JSON: %v", err)