diff --git a/client.go b/client.go index 5779a8e1c..ae82aefd6 100644 --- a/client.go +++ b/client.go @@ -21,12 +21,19 @@ type Client struct { } // NewClient creates new OpenAI API client. -func NewClient(authToken string) *Client { +func NewClient(authToken string, options ...ConfigOption) *Client { config := DefaultConfig(authToken) + + for _, opt := range options { + opt(&config) + } + return NewClientWithConfig(config) } // NewClientWithConfig creates new OpenAI API client for specified config. +// +// Deprecated: Please use NewClient with options. func NewClientWithConfig(config ClientConfig) *Client { return &Client{ config: config, @@ -39,11 +46,9 @@ func NewClientWithConfig(config ClientConfig) *Client { // NewOrgClient creates new OpenAI API client for specified Organization ID. // -// Deprecated: Please use NewClientWithConfig. +// Deprecated: Please use NewClient with options. func NewOrgClient(authToken, org string) *Client { - config := DefaultConfig(authToken) - config.OrgID = org - return NewClientWithConfig(config) + return NewClient(authToken, WithOrgID(org)) } type requestOptions struct { diff --git a/client_test.go b/client_test.go index 29d84edfa..09287be39 100644 --- a/client_test.go +++ b/client_test.go @@ -37,6 +37,41 @@ func TestClient(t *testing.T) { } } +func TestClientWithOptions(t *testing.T) { + const mockToken = "mock token" + const baseURL = "https://example.com" + const orgID = "myorg" + const apiType = APITypeAzure + const apiVersion = "2023-03-01" + const emptyMsgLimit = uint(10) + cli := NewClient(mockToken, + WithBaseURL(baseURL), + WithOrgID(orgID), + WithAPIType(apiType), + WithAPIVersion(apiVersion), + WithEmptyMessagesLimit(emptyMsgLimit), + ) + + if cli.config.authToken != mockToken { + t.Errorf("Client does not contain proper token") + } + if cli.config.BaseURL != baseURL { + t.Errorf("Client does not contain proper baseURL") + } + if cli.config.OrgID != orgID { + t.Errorf("Client does not contain proper orgID") + } + if cli.config.APIType != apiType { + t.Errorf("Client does not contain proper apiType") + } + if cli.config.APIVersion != apiVersion { + t.Errorf("Client does not contain proper apiVersion") + } + if cli.config.EmptyMessagesLimit != emptyMsgLimit { + t.Errorf("Client does not contain proper emptyMessagesLimit") + } +} + func TestDecodeResponse(t *testing.T) { stringInput := "" diff --git a/options.go b/options.go new file mode 100644 index 000000000..25046446e --- /dev/null +++ b/options.go @@ -0,0 +1,34 @@ +package openai + +type ConfigOption func(config *ClientConfig) + +// WithBaseURL configures base url which should start with "https", e.g. https://exmample.com +func WithBaseURL(baseURL string) ConfigOption { + return func(config *ClientConfig) { + config.BaseURL = baseURL + } +} + +func WithAPIType(apiType APIType) ConfigOption { + return func(config *ClientConfig) { + config.APIType = apiType + } +} + +func WithAPIVersion(apiVersion string) ConfigOption { + return func(config *ClientConfig) { + config.APIVersion = apiVersion + } +} + +func WithOrgID(orgID string) ConfigOption { + return func(config *ClientConfig) { + config.OrgID = orgID + } +} + +func WithEmptyMessagesLimit(limit uint) ConfigOption { + return func(config *ClientConfig) { + config.EmptyMessagesLimit = limit + } +}