Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat Add headers to openai responses #506

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,21 @@ type AudioResponse struct {
Transient bool `json:"transient"`
} `json:"segments"`
Text string `json:"text"`

httpHeader
}

type audioTextResponse struct {
Text string `json:"text"`

httpHeader
}

func (r *audioTextResponse) ToAudioResponse() AudioResponse {
return AudioResponse{
Text: r.Text,
httpHeader: r.httpHeader,
}
}

// CreateTranscription — API call to create a transcription. Returns transcribed text.
Expand Down Expand Up @@ -104,7 +119,9 @@ func (c *Client) callAudioAPI(
if request.HasJSONResponse() {
err = c.sendRequest(req, &response)
} else {
err = c.sendRequest(req, &response.Text)
var textResponse audioTextResponse
err = c.sendRequest(req, &textResponse)
response = textResponse.ToAudioResponse()
}
if err != nil {
return AudioResponse{}, err
Expand Down
2 changes: 2 additions & 0 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ type ChatCompletionResponse struct {
Model string `json:"model"`
Choices []ChatCompletionChoice `json:"choices"`
Usage Usage `json:"usage"`

httpHeader
}

// CreateChatCompletion — API call to Create a completion for the chat message.
Expand Down
30 changes: 30 additions & 0 deletions chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ import (
"github.com/sashabaranov/go-openai/jsonschema"
)

const (
xCustomHeader = "X-CUSTOM-HEADER"
xCustomHeaderValue = "test"
)

func TestChatCompletionsWrongModel(t *testing.T) {
config := DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
Expand Down Expand Up @@ -68,6 +73,30 @@ func TestChatCompletions(t *testing.T) {
checks.NoError(t, err, "CreateChatCompletion error")
}

// TestCompletions Tests the completions endpoint of the API using the mocked server.
func TestChatCompletionsWithHeaders(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
})
checks.NoError(t, err, "CreateChatCompletion error")

a := resp.Header().Get(xCustomHeader)
_ = a
if resp.Header().Get(xCustomHeader) != xCustomHeaderValue {
t.Errorf("expected header %s to be %s", xCustomHeader, xCustomHeaderValue)
}
}

// TestChatCompletionsFunctions tests including a function call.
func TestChatCompletionsFunctions(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
Expand Down Expand Up @@ -281,6 +310,7 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
TotalTokens: inputTokens + completionTokens,
}
resBytes, _ = json.Marshal(res)
w.Header().Set(xCustomHeader, xCustomHeaderValue)
fmt.Fprintln(w, string(resBytes))
}

Expand Down
20 changes: 19 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,20 @@ type Client struct {
createFormBuilder func(io.Writer) utils.FormBuilder
}

type Response interface {
SetHeader(http.Header)
}

type httpHeader http.Header

func (h *httpHeader) SetHeader(header http.Header) {
*h = httpHeader(header)
}

func (h httpHeader) Header() http.Header {
return http.Header(h)
}

// NewClient creates new OpenAI API client.
func NewClient(authToken string) *Client {
config := DefaultConfig(authToken)
Expand Down Expand Up @@ -82,7 +96,7 @@ func (c *Client) newRequest(ctx context.Context, method, url string, setters ...
return req, nil
}

func (c *Client) sendRequest(req *http.Request, v any) error {
func (c *Client) sendRequest(req *http.Request, v Response) error {
req.Header.Set("Accept", "application/json; charset=utf-8")

// Check whether Content-Type is already set, Upload Files API requires
Expand All @@ -103,6 +117,10 @@ func (c *Client) sendRequest(req *http.Request, v any) error {
return c.handleErrorResp(res)
}

if v != nil {
v.SetHeader(res.Header)
}

return decodeResponse(res.Body, v)
}

Expand Down
2 changes: 2 additions & 0 deletions completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ type CompletionResponse struct {
Model string `json:"model"`
Choices []CompletionChoice `json:"choices"`
Usage Usage `json:"usage"`

httpHeader
}

// CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well
Expand Down
2 changes: 2 additions & 0 deletions edits.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ type EditsResponse struct {
Created int64 `json:"created"`
Usage Usage `json:"usage"`
Choices []EditsChoice `json:"choices"`

httpHeader
}

// Edits Perform an API call to the Edits endpoint.
Expand Down
4 changes: 4 additions & 0 deletions embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ type EmbeddingResponse struct {
Data []Embedding `json:"data"`
Model EmbeddingModel `json:"model"`
Usage Usage `json:"usage"`

httpHeader
}

type base64String string
Expand Down Expand Up @@ -182,6 +184,8 @@ type EmbeddingResponseBase64 struct {
Data []Base64Embedding `json:"data"`
Model EmbeddingModel `json:"model"`
Usage Usage `json:"usage"`

httpHeader
}

// ToEmbeddingResponse converts an embeddingResponseBase64 to an EmbeddingResponse.
Expand Down
4 changes: 4 additions & 0 deletions engines.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@ type Engine struct {
Object string `json:"object"`
Owner string `json:"owner"`
Ready bool `json:"ready"`

httpHeader
}

// EnginesList is a list of engines.
type EnginesList struct {
Engines []Engine `json:"data"`

httpHeader
}

// ListEngines Lists the currently available engines, and provides basic
Expand Down
4 changes: 4 additions & 0 deletions files.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,15 @@ type File struct {
Status string `json:"status"`
Purpose string `json:"purpose"`
StatusDetails string `json:"status_details"`

httpHeader
}

// FilesList is a list of files that belong to the user or organization.
type FilesList struct {
Files []File `json:"data"`

httpHeader
}

// CreateFile uploads a jsonl file to GPT3
Expand Down
8 changes: 8 additions & 0 deletions fine_tunes.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ type FineTune struct {
ValidationFiles []File `json:"validation_files"`
TrainingFiles []File `json:"training_files"`
UpdatedAt int64 `json:"updated_at"`

httpHeader
}

// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API.
Expand Down Expand Up @@ -69,6 +71,8 @@ type FineTuneHyperParams struct {
type FineTuneList struct {
Object string `json:"object"`
Data []FineTune `json:"data"`

httpHeader
}

// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API.
Expand All @@ -77,6 +81,8 @@ type FineTuneList struct {
type FineTuneEventList struct {
Object string `json:"object"`
Data []FineTuneEvent `json:"data"`

httpHeader
}

// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API.
Expand All @@ -86,6 +92,8 @@ type FineTuneDeleteResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Deleted bool `json:"deleted"`

httpHeader
}

// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API.
Expand Down
4 changes: 4 additions & 0 deletions fine_tuning_job.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ type FineTuningJob struct {
ValidationFile string `json:"validation_file,omitempty"`
ResultFiles []string `json:"result_files"`
TrainedTokens int `json:"trained_tokens"`

httpHeader
}

type Hyperparameters struct {
Expand All @@ -39,6 +41,8 @@ type FineTuningJobEventList struct {
Object string `json:"object"`
Data []FineTuneEvent `json:"data"`
HasMore bool `json:"has_more"`

httpHeader
}

type FineTuningJobEvent struct {
Expand Down
2 changes: 2 additions & 0 deletions image.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ type ImageRequest struct {
type ImageResponse struct {
Created int64 `json:"created,omitempty"`
Data []ImageResponseDataInner `json:"data,omitempty"`

httpHeader
}

// ImageResponseDataInner represents a response data structure for image API.
Expand Down
6 changes: 6 additions & 0 deletions models.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ type Model struct {
Permission []Permission `json:"permission"`
Root string `json:"root"`
Parent string `json:"parent"`

httpHeader
}

// Permission struct represents an OpenAPI permission.
Expand All @@ -38,11 +40,15 @@ type FineTuneModelDeleteResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Deleted bool `json:"deleted"`

httpHeader
}

// ModelsList is a list of models, including those that belong to the user or organization.
type ModelsList struct {
Models []Model `json:"data"`

httpHeader
}

// ListModels Lists the currently available models,
Expand Down
2 changes: 2 additions & 0 deletions moderation.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ type ModerationResponse struct {
ID string `json:"id"`
Model string `json:"model"`
Results []Result `json:"results"`

httpHeader
}

// Moderations — perform a moderation api call over a string.
Expand Down
Loading