diff --git a/chat_test.go b/chat_test.go
index 520bf5ca..4b831ff6 100644
--- a/chat_test.go
+++ b/chat_test.go
@@ -97,6 +97,24 @@ func TestChatCompletions(t *testing.T) {
 	checks.NoError(t, err, "CreateChatCompletion error")
 }
 
+// TestGroqChatCompletions Tests the completions endpoint of the API using the mocked server.
+func TestGroqChatCompletions(t *testing.T) {
+	client, server, teardown := setupGroqTestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
+	_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
+		MaxTokens: 5,
+		Model:     openai.Groq.Mixtral8x7b,
+		Messages: []openai.ChatCompletionMessage{
+			{
+				Role:    openai.ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+	})
+	checks.NoError(t, err, "CreateChatCompletion error")
+}
+
 // TestCompletions Tests the completions endpoint of the API using the mocked server.
 func TestChatCompletionsWithHeaders(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()
diff --git a/completion.go b/completion.go
index ab1dbd6c..37ac8d67 100644
--- a/completion.go
+++ b/completion.go
@@ -109,6 +109,19 @@ var disabledModelsForEndpoints = map[string]map[string]bool{
 	},
 }
 
+// Groq contains models that work with Groq's OpenAI Compatibility API.
+//
+// Usage Examples: openai.Groq.Mixtral8x7b, openai.Groq.LLaMA270b, openai.Groq.Gemma7bIT.
+var Groq = struct {
+	Mixtral8x7b string
+	LLaMA270b   string
+	Gemma7bIT   string
+}{
+	Mixtral8x7b: "mixtral-8x7b-32768",
+	LLaMA270b:   "llama2-70b-4096",
+	Gemma7bIT:   "gemma-7b-it",
+}
+
 func checkEndpointSupportsModel(endpoint, model string) bool {
 	return !disabledModelsForEndpoints[endpoint][model]
 }
diff --git a/config.go b/config.go
index c58b71ec..5b383c3e 100644
--- a/config.go
+++ b/config.go
@@ -6,7 +6,9 @@ import (
 )
 
 const (
-	openaiAPIURLv1                 = "https://api.openai.com/v1"
+	openaiAPIURLv1 = "https://api.openai.com/v1"
+	groqAPIURLv1   = "https://api.groq.com/openai/v1"
+
 	defaultEmptyMessagesLimit uint = 300
 
 	azureAPIPrefix         = "openai"
@@ -19,6 +21,7 @@ const (
 	APITypeOpenAI  APIType = "OPEN_AI"
 	APITypeAzure   APIType = "AZURE"
 	APITypeAzureAD APIType = "AZURE_AD"
+	APITypeGroq    APIType = "GROQ"
 )
 
 const AzureAPIKeyHeader = "api-key"
@@ -67,6 +70,20 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig {
 	}
 }
 
+// DefaultGroqConfig takes a Groq auth token and returns a ClientConfig that works with Groq's OpenAI Compatibility API.
+func DefaultGroqConfig(authToken string) ClientConfig {
+	return ClientConfig{
+		authToken: authToken,
+		BaseURL:   groqAPIURLv1,
+		APIType:   APITypeGroq,
+		OrgID:     "",
+
+		HTTPClient: &http.Client{},
+
+		EmptyMessagesLimit: defaultEmptyMessagesLimit,
+	}
+}
+
 func (ClientConfig) String() string {
 	return "<OpenAI API ClientConfig>"
 }
diff --git a/example_test.go b/example_test.go
index de67c57c..797a4b2d 100644
--- a/example_test.go
+++ b/example_test.go
@@ -345,3 +345,38 @@ func ExampleAPIError() {
 		}
 	}
 }
+
+// ExampleDefaultGroqConfig demonstrates how to create a new DefaultGroqConfig, create a Groq client,
+// use a hosted Groq model, and create a chat completion.
+func ExampleDefaultGroqConfig() {
+	config := openai.DefaultGroqConfig(os.Getenv("GROQ_API_KEY"))
+	client := openai.NewClientWithConfig(config)
+
+	req := openai.ChatCompletionRequest{
+		Model: openai.Groq.Mixtral8x7b,
+		Messages: []openai.ChatCompletionMessage{
+			{
+				Role:    openai.ChatMessageRoleSystem,
+				Content: "you are a helpful chatbot",
+			},
+		},
+	}
+	fmt.Println("Conversation")
+	fmt.Println("---------------------")
+	fmt.Print("> ")
+	s := bufio.NewScanner(os.Stdin)
+	for s.Scan() {
+		req.Messages = append(req.Messages, openai.ChatCompletionMessage{
+			Role:    openai.ChatMessageRoleUser,
+			Content: s.Text(),
+		})
+		resp, err := client.CreateChatCompletion(context.Background(), req)
+		if err != nil {
+			fmt.Printf("ChatCompletion error: %v\n", err)
+			continue
+		}
+		fmt.Printf("%s\n\n", resp.Choices[0].Message.Content)
+		req.Messages = append(req.Messages, resp.Choices[0].Message)
+		fmt.Print("> ")
+	}
+}
diff --git a/openai_test.go b/openai_test.go
index 729d8880..fa008e32 100644
--- a/openai_test.go
+++ b/openai_test.go
@@ -27,6 +27,17 @@ func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, tea
 	return
 }
 
+func setupGroqTestServer() (client *openai.Client, server *test.ServerTest, teardown func()) {
+	server = test.NewTestServer()
+	ts := server.OpenAITestServer()
+	ts.Start()
+	teardown = ts.Close
+	config := openai.DefaultGroqConfig(test.GetTestToken())
+	config.BaseURL = ts.URL + "/v1"
+	client = openai.NewClientWithConfig(config)
+	return
+}
+
 // numTokens Returns the number of GPT-3 encoded tokens in the given text.
 // This function approximates based on the rule of thumb stated by OpenAI:
 // https://beta.openai.com/tokenizer