Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactoring tests with mock servers (#30) #356

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 5 additions & 11 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"errors"
"io"
"net/http"
"net/http/httptest"
"os"
"testing"

Expand Down Expand Up @@ -226,18 +225,13 @@ func TestAPIErrorUnmarshalJSONInvalidMessage(t *testing.T) {
}

func TestRequestError(t *testing.T) {
var err error

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
}))
defer ts.Close()
})

config := DefaultConfig("dummy")
config.BaseURL = ts.URL
c := NewClientWithConfig(config)
ctx := context.Background()
_, err = c.ListEngines(ctx)
_, err := client.ListEngines(context.Background())
checks.HasError(t, err, "ListEngines did not fail")

var reqErr *RequestError
Expand Down
162 changes: 162 additions & 0 deletions audio_api_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
package openai_test

import (
"bytes"
"context"
"errors"
"io"
"mime"
"mime/multipart"
"net/http"
"path/filepath"
"strings"
"testing"

. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
)

// TestAudio Tests the transcription and translation endpoints of the API using the mocked server.
func TestAudio(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)

testcases := []struct {
name string
createFn func(context.Context, AudioRequest) (AudioResponse, error)
}{
{
"transcribe",
client.CreateTranscription,
},
{
"translate",
client.CreateTranslation,
},
}

ctx := context.Background()

dir, cleanup := test.CreateTestDirectory(t)
defer cleanup()

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
path := filepath.Join(dir, "fake.mp3")
test.CreateTestFile(t, path)

req := AudioRequest{
FilePath: path,
Model: "whisper-3",
}
_, err := tc.createFn(ctx, req)
checks.NoError(t, err, "audio API error")
})

t.Run(tc.name+" (with reader)", func(t *testing.T) {
req := AudioRequest{
FilePath: "fake.webm",
Reader: bytes.NewBuffer([]byte(`some webm binary data`)),
Model: "whisper-3",
}
_, err := tc.createFn(ctx, req)
checks.NoError(t, err, "audio API error")
})
}
}

func TestAudioWithOptionalArgs(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)

testcases := []struct {
name string
createFn func(context.Context, AudioRequest) (AudioResponse, error)
}{
{
"transcribe",
client.CreateTranscription,
},
{
"translate",
client.CreateTranslation,
},
}

ctx := context.Background()

dir, cleanup := test.CreateTestDirectory(t)
defer cleanup()

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
path := filepath.Join(dir, "fake.mp3")
test.CreateTestFile(t, path)

req := AudioRequest{
FilePath: path,
Model: "whisper-3",
Prompt: "用简体中文",
Temperature: 0.5,
Language: "zh",
Format: AudioResponseFormatSRT,
}
_, err := tc.createFn(ctx, req)
checks.NoError(t, err, "audio API error")
})
}
}

// handleAudioEndpoint Handles the completion endpoint by the test server.
func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) {
var err error

// audio endpoints only accept POST requests
if r.Method != "POST" {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}

mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
if err != nil {
http.Error(w, "failed to parse media type", http.StatusBadRequest)
return
}

if !strings.HasPrefix(mediaType, "multipart") {
http.Error(w, "request is not multipart", http.StatusBadRequest)
}

boundary, ok := params["boundary"]
if !ok {
http.Error(w, "no boundary in params", http.StatusBadRequest)
return
}

fileData := &bytes.Buffer{}
mr := multipart.NewReader(r.Body, boundary)
part, err := mr.NextPart()
if err != nil && errors.Is(err, io.EOF) {
http.Error(w, "error accessing file", http.StatusBadRequest)
return
}
if _, err = io.Copy(fileData, part); err != nil {
http.Error(w, "failed to copy file", http.StatusInternalServerError)
return
}

if len(fileData.Bytes()) == 0 {
w.WriteHeader(http.StatusInternalServerError)
http.Error(w, "received empty file data", http.StatusBadRequest)
return
}

if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil {
http.Error(w, "failed to write body", http.StatusInternalServerError)
return
}
}
166 changes: 0 additions & 166 deletions audio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,182 +2,16 @@ package openai //nolint:testpackage // testing private field

import (
"bytes"
"context"
"errors"
"fmt"
"io"
"mime"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"strings"
"testing"

"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
)

// TestAudio Tests the transcription and translation endpoints of the API using the mocked server.
func TestAudio(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)

testcases := []struct {
name string
createFn func(context.Context, AudioRequest) (AudioResponse, error)
}{
{
"transcribe",
client.CreateTranscription,
},
{
"translate",
client.CreateTranslation,
},
}

ctx := context.Background()

dir, cleanup := test.CreateTestDirectory(t)
defer cleanup()

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
path := filepath.Join(dir, "fake.mp3")
test.CreateTestFile(t, path)

req := AudioRequest{
FilePath: path,
Model: "whisper-3",
}
_, err = tc.createFn(ctx, req)
checks.NoError(t, err, "audio API error")
})

t.Run(tc.name+" (with reader)", func(t *testing.T) {
req := AudioRequest{
FilePath: "fake.webm",
Reader: bytes.NewBuffer([]byte(`some webm binary data`)),
Model: "whisper-3",
}
_, err = tc.createFn(ctx, req)
checks.NoError(t, err, "audio API error")
})
}
}

func TestAudioWithOptionalArgs(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)

testcases := []struct {
name string
createFn func(context.Context, AudioRequest) (AudioResponse, error)
}{
{
"transcribe",
client.CreateTranscription,
},
{
"translate",
client.CreateTranslation,
},
}

ctx := context.Background()

dir, cleanup := test.CreateTestDirectory(t)
defer cleanup()

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
path := filepath.Join(dir, "fake.mp3")
test.CreateTestFile(t, path)

req := AudioRequest{
FilePath: path,
Model: "whisper-3",
Prompt: "用简体中文",
Temperature: 0.5,
Language: "zh",
Format: AudioResponseFormatSRT,
}
_, err = tc.createFn(ctx, req)
checks.NoError(t, err, "audio API error")
})
}
}

// handleAudioEndpoint Handles the completion endpoint by the test server.
func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) {
var err error

// audio endpoints only accept POST requests
if r.Method != "POST" {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}

mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
if err != nil {
http.Error(w, "failed to parse media type", http.StatusBadRequest)
return
}

if !strings.HasPrefix(mediaType, "multipart") {
http.Error(w, "request is not multipart", http.StatusBadRequest)
}

boundary, ok := params["boundary"]
if !ok {
http.Error(w, "no boundary in params", http.StatusBadRequest)
return
}

fileData := &bytes.Buffer{}
mr := multipart.NewReader(r.Body, boundary)
part, err := mr.NextPart()
if err != nil && errors.Is(err, io.EOF) {
http.Error(w, "error accessing file", http.StatusBadRequest)
return
}
if _, err = io.Copy(fileData, part); err != nil {
http.Error(w, "failed to copy file", http.StatusInternalServerError)
return
}

if len(fileData.Bytes()) == 0 {
w.WriteHeader(http.StatusInternalServerError)
http.Error(w, "received empty file data", http.StatusBadRequest)
return
}

if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil {
http.Error(w, "failed to write body", http.StatusInternalServerError)
return
}
}

func TestAudioWithFailingFormBuilder(t *testing.T) {
dir, cleanup := test.CreateTestDirectory(t)
defer cleanup()
Expand Down