Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
hayeah committed May 19, 2024
1 parent 67fe23e commit f8d19ae
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 57 deletions.
4 changes: 2 additions & 2 deletions run.go
Original file line number Diff line number Diff line change
Expand Up @@ -495,12 +495,12 @@ func (c *Client) CreateThreadAndStream(
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Connection", "keep-alive")

resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
resp, err := c.config.HTTPClient.Do(req)
if err != nil {
return
}

if resp.StatusCode != 200 {
if resp.StatusCode != http.StatusOK {
resp.Body.Close()
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
Expand Down
24 changes: 12 additions & 12 deletions sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"strings"
)

// NewEOLSplitterFunc returns a bufio.SplitFunc tied to a new EOLSplitter instance
// NewEOLSplitterFunc returns a bufio.SplitFunc tied to a new EOLSplitter instance.
func NewEOLSplitterFunc() bufio.SplitFunc {
splitter := NewEOLSplitter()
return splitter.Split
Expand All @@ -23,6 +23,8 @@ func NewEOLSplitter() *EOLSplitter {
return &EOLSplitter{prevCR: false}
}

const crlfLen = 2

// Split function to handle CR LF, CR, and LF as end-of-line.
func (s *EOLSplitter) Split(data []byte, atEOF bool) (advance int, token []byte, err error) {
// Check if the previous data ended with a CR
Expand All @@ -38,7 +40,7 @@ func (s *EOLSplitter) Split(data []byte, atEOF bool) (advance int, token []byte,
if data[i] == '\r' {
if i+1 < len(data) && data[i+1] == '\n' {
// Found CR LF
return i + 2, data[:i], nil
return i + crlfLen, data[:i], nil
}
// Found CR
if !atEOF && i == len(data)-1 {
Expand Down Expand Up @@ -119,29 +121,27 @@ func (s *SSEScanner) Next() bool {
}

seenNonEmptyLine = true

if strings.HasPrefix(line, "id: ") {
switch {
case strings.HasPrefix(line, "id: "):
event.ID = strings.TrimPrefix(line, "id: ")
} else if strings.HasPrefix(line, "data: ") {
case strings.HasPrefix(line, "data: "):
dataLines = append(dataLines, strings.TrimPrefix(line, "data: "))
} else if strings.HasPrefix(line, "event: ") {
case strings.HasPrefix(line, "event: "):
event.Event = strings.TrimPrefix(line, "event: ")
} else if strings.HasPrefix(line, "retry: ") {
case strings.HasPrefix(line, "retry: "):
retry, err := strconv.Atoi(strings.TrimPrefix(line, "retry: "))
if err == nil {
event.Retry = retry
}

// ignore invalid retry values
} else if strings.HasPrefix(line, ":") {
case strings.HasPrefix(line, ":"):
if s.readComment {
event.Comment = strings.TrimPrefix(line, ":")
}

// ignore comment line
default:
// ignore unknown lines
}

// ignore unknown lines
}

s.err = s.scanner.Err()
Expand Down
32 changes: 17 additions & 15 deletions sse_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package openai
package openai_test

import (
"bufio"
"io"
"reflect"
"strings"
"testing"

"github.com/sashabaranov/go-openai"
)

// ChunksReader simulates a reader that splits the input across multiple reads.
Expand Down Expand Up @@ -55,7 +57,7 @@ func TestEolSplitter(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
reader := strings.NewReader(test.input)
scanner := bufio.NewScanner(reader)
scanner.Split(NewEOLSplitterFunc())
scanner.Split(openai.NewEOLSplitterFunc())

var lines []string
for scanner.Scan() {
Expand Down Expand Up @@ -97,7 +99,7 @@ func TestEolSplitterBoundaryCondition(t *testing.T) {
// Custom reader to simulate the boundary condition
reader := NewChunksReader(c.input)
scanner := bufio.NewScanner(reader)
scanner.Split(NewEOLSplitterFunc())
scanner.Split(openai.NewEOLSplitterFunc())

var lines []string
for scanner.Scan() {
Expand All @@ -124,11 +126,11 @@ func TestEolSplitterBoundaryCondition(t *testing.T) {
func TestSSEScanner(t *testing.T) {
tests := []struct {
raw string
want []ServerSentEvent
want []openai.ServerSentEvent
}{
{
raw: `data: hello world`,
want: []ServerSentEvent{
want: []openai.ServerSentEvent{
{
Data: "hello world",
},
Expand All @@ -137,7 +139,7 @@ func TestSSEScanner(t *testing.T) {
{
raw: `event: hello
data: hello world`,
want: []ServerSentEvent{
want: []openai.ServerSentEvent{
{
Event: "hello",
Data: "hello world",
Expand All @@ -150,7 +152,7 @@ data: {
data: "msg": "hello world",
data: "id": 12345
data: }`,
want: []ServerSentEvent{
want: []openai.ServerSentEvent{
{
Event: "hello-json",
Data: "{\n\"msg\": \"hello world\",\n\"id\": 12345\n}",
Expand All @@ -161,7 +163,7 @@ data: }`,
raw: `data: hello world
data: hello again`,
want: []ServerSentEvent{
want: []openai.ServerSentEvent{
{
Data: "hello world",
},
Expand All @@ -173,7 +175,7 @@ data: hello again`,
{
raw: `retry: 10000
data: hello world`,
want: []ServerSentEvent{
want: []openai.ServerSentEvent{
{
Retry: 10000,
Data: "hello world",
Expand All @@ -184,7 +186,7 @@ data: hello again`,
raw: `retry: 10000
retry: 20000`,
want: []ServerSentEvent{
want: []openai.ServerSentEvent{
{
Retry: 10000,
},
Expand All @@ -200,7 +202,7 @@ id: message-id
retry: 20000
event: hello-event
data: hello`,
want: []ServerSentEvent{
want: []openai.ServerSentEvent{
{
ID: "message-id",
Retry: 20000,
Expand All @@ -222,7 +224,7 @@ id: message 2
retry: 20000
event: hello-event 2
`,
want: []ServerSentEvent{
want: []openai.ServerSentEvent{
{
ID: "message 1",
Retry: 10000,
Expand Down Expand Up @@ -254,10 +256,10 @@ event: hello-event 2
}
}

func runSSEScanTest(t *testing.T, raw string, want []ServerSentEvent) {
sseScanner := NewSSEScanner(strings.NewReader(raw), false)
func runSSEScanTest(t *testing.T, raw string, want []openai.ServerSentEvent) {
sseScanner := openai.NewSSEScanner(strings.NewReader(raw), false)

var got []ServerSentEvent
var got []openai.ServerSentEvent
for sseScanner.Next() {
got = append(got, sseScanner.Scan())
}
Expand Down
29 changes: 14 additions & 15 deletions stream_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ type StreamRawEvent struct {
type StreamDone struct {
}

// Define StreamThreadMessageDelta
type StreamThreadMessageDelta struct {
ID string `json:"id"`
Object string `json:"object"`
Expand Down Expand Up @@ -75,7 +74,7 @@ type StreamerV2 struct {
buffer []byte
}

// Close closes the underlying io.ReadCloser
// Close closes the underlying io.ReadCloser.
func (s *StreamerV2) Close() error {
return s.r.Close()
}
Expand Down Expand Up @@ -106,30 +105,30 @@ func (s *StreamerV2) Next() bool {
return true
}

// Read implements io.Reader of the text deltas of thread.message.delta events
func (r *StreamerV2) Read(p []byte) (int, error) {
// Read implements io.Reader of the text deltas of thread.message.delta events.
func (s *StreamerV2) Read(p []byte) (int, error) {
// If we have data in the buffer, copy it to p first.
if len(r.buffer) > 0 {
n := copy(p, r.buffer)
r.buffer = r.buffer[n:]
if len(s.buffer) > 0 {
n := copy(p, s.buffer)
s.buffer = s.buffer[n:]
return n, nil
}

for r.Next() {
for s.Next() {
// Read only text deltas
text, ok := r.MessageDeltaText()
text, ok := s.MessageDeltaText()
if !ok {
continue
}

r.buffer = []byte(text)
n := copy(p, r.buffer)
r.buffer = r.buffer[n:]
s.buffer = []byte(text)
n := copy(p, s.buffer)
s.buffer = s.buffer[n:]
return n, nil
}

// Check for streamer error
if err := r.Err(); err != nil {
if err := s.Err(); err != nil {
return 0, err
}

Expand All @@ -145,7 +144,7 @@ func (s *StreamerV2) Text() (string, bool) {
return s.MessageDeltaText()
}

// MessageDeltaText returns text delta if the current event is a "thread.message.delta"
// MessageDeltaText returns text delta if the current event is a "thread.message.delta".
func (s *StreamerV2) MessageDeltaText() (string, bool) {
event, ok := s.next.(StreamThreadMessageDelta)
if !ok {
Expand All @@ -157,7 +156,7 @@ func (s *StreamerV2) MessageDeltaText() (string, bool) {
if content.Text != nil {
// Can we return the first text we find? Does OpenAI stream ever
// return multiple text contents in a delta?
text = text + content.Text.Value
text += content.Text.Value
}
}

Expand Down
29 changes: 16 additions & 13 deletions stream_v2_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package openai
//nolint:lll
package openai_test

import (
"encoding/json"
"io"
"reflect"
"strings"
"testing"

"github.com/sashabaranov/go-openai"
)

func TestNewStreamTextReader(t *testing.T) {
Expand All @@ -19,7 +22,7 @@ data: {"id":"msg_KFiZxHhXYQo6cGFnGjRDHSee","object":"thread.message.delta","delt
event: done
data: [DONE]
`
reader := NewStreamerV2(strings.NewReader(raw))
reader := openai.NewStreamerV2(strings.NewReader(raw))

expected := "helloworld"
buffer := make([]byte, len(expected))
Expand Down Expand Up @@ -65,7 +68,7 @@ event: done
data: [DONE]
`

scanner := NewStreamerV2(strings.NewReader(raw))
scanner := openai.NewStreamerV2(strings.NewReader(raw))
var events []any

for scanner.Next() {
Expand All @@ -74,26 +77,26 @@ data: [DONE]
}

expectedValues := []any{
StreamRawEvent{
openai.StreamRawEvent{
Type: "thread.created",
Data: json.RawMessage(`{"id":"thread_vMWb8sJ14upXpPO2VbRpGTYD","object":"thread","created_at":1715864046,"metadata":{},"tool_resources":{"code_interpreter":{"file_ids":[]}}}`),
},
StreamThreadMessageDelta{
openai.StreamThreadMessageDelta{
ID: "msg_KFiZxHhXYQo6cGFnGjRDHSee",
Object: "thread.message.delta",
Delta: Delta{
Content: []DeltaContent{
Delta: openai.Delta{
Content: []openai.DeltaContent{
{
Index: 0,
Type: "text",
Text: &DeltaText{
Text: &openai.DeltaText{
Value: "hello",
},
},
},
},
},
StreamDone{},
openai.StreamDone{},
}

if len(events) != len(expectedValues) {
Expand All @@ -119,25 +122,25 @@ func TestStreamThreadMessageDeltaJSON(t *testing.T) {
name: "DeltaContent with Text",
jsonData: `{"index":0,"type":"text","text":{"value":"hello"}}`,
expectType: "text",
expectValue: &DeltaText{Value: "hello"},
expectValue: &openai.DeltaText{Value: "hello"},
},
{
name: "DeltaContent with ImageFile",
jsonData: `{"index":1,"type":"image_file","image_file":{"file_id":"file123","detail":"An image"}}`,
expectType: "image_file",
expectValue: &DeltaImageFile{FileID: "file123", Detail: "An image"},
expectValue: &openai.DeltaImageFile{FileID: "file123", Detail: "An image"},
},
{
name: "DeltaContent with ImageURL",
jsonData: `{"index":2,"type":"image_url","image_url":{"url":"https://example.com/image.jpg","detail":"low"}}`,
expectType: "image_url",
expectValue: &DeltaImageURL{URL: "https://example.com/image.jpg", Detail: "low"},
expectValue: &openai.DeltaImageURL{URL: "https://example.com/image.jpg", Detail: "low"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var content DeltaContent
var content openai.DeltaContent
err := json.Unmarshal([]byte(tt.jsonData), &content)
if err != nil {
t.Fatalf("Error unmarshalling JSON: %v", err)
Expand Down

0 comments on commit f8d19ae

Please sign in to comment.