51 changes: 51 additions & 0 deletions docs/api/language_model_clients/Groq.md
@@ -0,0 +1,51 @@
---
sidebar_position: 9
---

# dspy.GROQ

### Usage

```python
lm = dspy.GROQ(model='mixtral-8x7b-32768', api_key="gsk_***")
```
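
Once constructed, the client can be registered as the default LM for DSPy and queried directly. A minimal sketch, assuming a valid Groq key is exported in a `GROQ_API_KEY` environment variable (the variable name and prompt are illustrative):

```python
import os

import dspy

# Illustrative: read the key from an environment variable rather than hardcoding it.
lm = dspy.GROQ(model='mixtral-8x7b-32768', api_key=os.environ["GROQ_API_KEY"])

# Register the client as the default LM for DSPy modules.
dspy.settings.configure(lm=lm)

# Calling the wrapper directly returns a list of completion strings.
print(lm("Say hello in five words or fewer.")[0])
```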

### Constructor

The constructor initializes the base class `LM` and verifies the provided arguments, such as the `api_key` for the Groq API. The `kwargs` attribute is initialized with default values for the text generation parameters sent to the Groq API, such as `temperature`, `max_tokens`, `top_p`, `frequency_penalty`, `presence_penalty`, and `n`.

```python
class GroqLM(LM):
def __init__(
self,
api_key: str,
model: str = "mixtral-8x7b-32768",
**kwargs,
):
```

**Parameters:**
- `api_key` (_str_): API provider authentication token. Required; the constructor raises a `ValueError` if it is missing.
- `model` (_str_, _optional_): Model name. Defaults to `"mixtral-8x7b-32768"`. Other options include `"llama2-70b-4096"` and `"gemma-7b-it"`.
- `**kwargs`: Additional language model arguments to pass to the API provider.
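
Generation defaults can be overridden at construction time through `**kwargs`. A minimal sketch (the key and parameter values are placeholders):

```python
lm = dspy.GROQ(
    api_key="gsk_***",        # placeholder key
    model="llama2-70b-4096",
    temperature=0.7,          # overrides the 0.0 default
    max_tokens=512,           # overrides the 150 default
)
```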

### Methods

#### `def __call__(self, prompt: str, only_completed: bool = True, return_sorted: bool = False, **kwargs) -> list[str]:`

Retrieves completions from Groq by calling `request`.

Internally, the method handles the specifics of preparing the request prompt and corresponding payload to obtain the response.

After generation, the completion text of each returned choice is read from `choice.message.content`.

**Parameters:**
- `prompt` (_str_): Prompt to send to Groq.
- `only_completed` (_bool_, _optional_): Flag to return only completed responses, ignoring completions truncated due to length. Defaults to True.
- `return_sorted` (_bool_, _optional_): Flag to sort the completion choices using the returned averaged log-probabilities. Defaults to False.
- `**kwargs`: Additional keyword arguments for completion request.

**Returns:**
- `list[str]`: List of completion strings, one per returned choice.
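
Per-call keyword arguments are merged over the constructor defaults before the request is sent. A minimal sketch (the prompt is illustrative):

```python
# `max_tokens` overrides the stored default for this call only.
completions = lm("What is the capital of France?", max_tokens=64)
print(completions[0])
```
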
2 changes: 2 additions & 0 deletions dsp/modules/__init__.py
@@ -8,10 +8,12 @@
 from .databricks import *
 from .google import *
 from .gpt3 import *
+from .groq_client import *
 from .hf import HFModel
 from .hf_client import Anyscale, HFClientTGI, Together
 from .mistral import *
 from .ollama import *
 from .pyserini import *
 from .sbert import *
 from .sentence_vectorizer import *

169 changes: 169 additions & 0 deletions dsp/modules/groq_client.py
@@ -0,0 +1,169 @@
import logging
from typing import Any

import backoff

try:
import groq
from groq import Groq
groq_api_error = (groq.APIError, groq.RateLimitError)
except ImportError:
    # Fall back to a generic exception tuple when the groq package is absent.
    groq_api_error = (Exception,)


import dsp
from dsp.modules.lm import LM

# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(message)s",
handlers=[logging.FileHandler("groq_usage.log")],
)


def backoff_hdlr(details):
"""Handler from https://pypi.org/project/backoff/"""
print(
"Backing off {wait:0.1f} seconds after {tries} tries "
"calling function {target} with kwargs "
"{kwargs}".format(**details),
)


class GroqLM(LM):
"""Wrapper around groq's API.

Args:
        model (str, optional): Groq-supported LLM to use. Defaults to "mixtral-8x7b-32768".
        api_key (str): API provider authentication token. Required; a ValueError is raised if it is missing.
**kwargs: Additional arguments to pass to the API provider.
"""

def __init__(
self,
api_key: str,
model: str = "mixtral-8x7b-32768",
**kwargs,
):
super().__init__(model)
self.provider = "groq"
if api_key:
self.api_key = api_key
            self.client = Groq(api_key=api_key)
else:
raise ValueError("api_key is required for groq")

self.kwargs = {
"temperature": 0.0,
"max_tokens": 150,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
"n": 1,
**kwargs,
}
        # Validate the requested model against the models the Groq API exposes.
        models = self.client.models.list().data
        if models is not None:
            if model in [m.id for m in models]:
                self.kwargs["model"] = model
            else:
                raise ValueError(f"model {model!r} is not available from the Groq API")
self.history: list[dict[str, Any]] = []

def log_usage(self, response):
"""Log the total tokens from the Groq API response."""
        # The Groq SDK returns a response object, not a dict, so read attributes.
        usage_data = getattr(response, "usage", None)
        if usage_data:
            total_tokens = getattr(usage_data, "total_tokens", None)
logging.info(f"{total_tokens}")

def basic_request(self, prompt: str, **kwargs):
raw_kwargs = kwargs

kwargs = {**self.kwargs, **kwargs}

kwargs["messages"] = [{"role": "user", "content": prompt}]
response = self.chat_request(**kwargs)

history = {
"prompt": prompt,
"response": response.choices[0].message.content,
"kwargs": kwargs,
"raw_kwargs": raw_kwargs,
}

self.history.append(history)

return response

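    # Retry transient Groq API failures with exponential backoff (up to ~1000s total).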
@backoff.on_exception(
backoff.expo,
groq_api_error,
max_time=1000,
on_backoff=backoff_hdlr,
)
def request(self, prompt: str, **kwargs):
"""Handles retreival of model completions whilst handling rate limiting and caching."""
if "model_type" in kwargs:
del kwargs["model_type"]

return self.basic_request(prompt, **kwargs)

def _get_choice_text(self, choice) -> str:
return choice.message.content

def chat_request(self, **kwargs):
"""Handles retreival of model completions whilst handling rate limiting and caching."""
response = self.client.chat.completions.create(**kwargs)
return response

def __call__(
self,
prompt: str,
only_completed: bool = True,
return_sorted: bool = False,
**kwargs,
    ) -> list[str]:
"""Retrieves completions from model.

Args:
prompt (str): prompt to send to model
            only_completed (bool, optional): return only completed responses, ignoring completions truncated due to length. Defaults to True.
return_sorted (bool, optional): sort the completion choices using the returned probabilities. Defaults to False.

Returns:
            list[str]: list of completion texts, one per returned choice
"""

assert only_completed, "for now"
assert return_sorted is False, "for now"
response = self.request(prompt, **kwargs)

if dsp.settings.log_openai_usage:
self.log_usage(response)

choices = response.choices

completions = [self._get_choice_text(c) for c in choices]
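        # NOTE: currently unreachable, because the assert above pins return_sorted to False.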
if return_sorted and kwargs.get("n", 1) > 1:
scored_completions = []

for c in choices:
tokens, logprobs = (
c["logprobs"]["tokens"],
c["logprobs"]["token_logprobs"],
)

if "<|endoftext|>" in tokens:
index = tokens.index("<|endoftext|>") + 1
tokens, logprobs = tokens[:index], logprobs[:index]

avglog = sum(logprobs) / len(logprobs)
scored_completions.append((avglog, self._get_choice_text(c)))

scored_completions = sorted(scored_completions, reverse=True)
completions = [c for _, c in scored_completions]

return completions
2 changes: 1 addition & 1 deletion dsp/modules/hf_client.py
@@ -435,4 +435,4 @@ def _generate(self, prompt, **kwargs):
 
 @CacheMemory.cache
 def send_hfsglang_request_v00(arg, **kwargs):
-    return requests.post(arg, **kwargs)
\ No newline at end of file
+    return requests.post(arg, **kwargs)
8 changes: 5 additions & 3 deletions dsp/modules/lm.py
@@ -77,9 +77,11 @@ def inspect_history(self, n: int = 1, skip: int = 0):
 if provider == "cohere":
     text = choices
 elif provider == "openai" or provider == "ollama":
-    text = " " + self._get_choice_text(choices[0]).strip()
-elif provider == "clarifai":
-    text = choices
+    text = ' ' + self._get_choice_text(choices[0]).strip()
+elif provider == "clarifai" or provider == "claude" :
+    text=choices
+elif provider == "groq":
+    text = ' ' + choices
 elif provider == "google":
     text = choices[0].parts[0].text
 elif provider == "mistral":
1 change: 1 addition & 0 deletions dspy/__init__.py
@@ -21,6 +21,7 @@
 Pyserini = dsp.PyseriniRetriever
 Clarifai = dsp.ClarifaiLLM
 Google = dsp.Google
+GROQ = dsp.GroqLM
 
 HFClientTGI = dsp.HFClientTGI
 HFClientVLLM = HFClientVLLM
23 changes: 21 additions & 2 deletions poetry.lock


2 changes: 2 additions & 0 deletions pyproject.toml
@@ -57,6 +57,7 @@ docs = [
     "autodoc_pydantic",
     "sphinx-reredirects>=0.1.2",
     "sphinx-automodapi==0.16.0",
+
 ]
 dev = ["pytest>=6.2.5"]

@@ -108,6 +109,7 @@ sphinx_rtd_theme = { version = "*", optional = true }
 autodoc_pydantic = { version = "*", optional = true }
 sphinx-reredirects = { version = "^0.1.2", optional = true }
 sphinx-automodapi = { version = "0.16.0", optional = true }
+groq = {version = "^0.4.2", optional = true }
 rich = "^13.7.1"
 psycopg2 = {version = "^2.9.9", optional = true}
 pgvector = {version = "^0.2.5", optional = true}