diff --git a/docs/api/language_model_clients/Groq.md b/docs/api/language_model_clients/Groq.md
new file mode 100644
index 0000000000..5f48accf4b
--- /dev/null
+++ b/docs/api/language_model_clients/Groq.md
@@ -0,0 +1,51 @@
+---
+sidebar_position: 9
+---
+
+# dspy.GROQ
+
+### Usage
+
+```python
+lm = dspy.GROQ(model='mixtral-8x7b-32768', api_key="gsk_***")
+```
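+
+Once constructed, the client can be registered as the default LM for DSPy programs. The snippet below is a minimal sketch: the model name and the `gsk_***` key are placeholders, and it assumes the optional `groq` package is installed.
+
+```python
+import dspy
+
+# Placeholder credentials; substitute a real Groq API key.
+lm = dspy.GROQ(model='mixtral-8x7b-32768', api_key="gsk_***")
+dspy.settings.configure(lm=lm)
+```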
+
+### Constructor
+
+The constructor initializes the base class `LM` and verifies the provided arguments, such as the `api_key` required by the Groq API client. The `kwargs` attribute is initialized with default values for the text generation parameters sent to the Groq API, such as `temperature`, `max_tokens`, `top_p`, `frequency_penalty`, `presence_penalty`, and `n`.
+
+```python
+class GroqLM(LM):
+    def __init__(
+        self,
+        api_key: str,
+        model: str = "mixtral-8x7b-32768",
+        **kwargs,
+    ):
+```
+
+**Parameters:**
+- `api_key` str: API provider authentication token. Required.
+- `model` str: Model name. Defaults to "mixtral-8x7b-32768". Other options include "llama2-70b-4096" and "gemma-7b-it".
+- `**kwargs`: Additional language model arguments to pass to the API provider.
+
+### Methods
+
+#### `def __call__(self, prompt: str, only_completed: bool = True, return_sorted: bool = False, **kwargs) -> list[dict[str, Any]]:`
+
+Retrieves completions from Groq by calling `request`.
+
+Internally, the method handles the specifics of preparing the request prompt and corresponding payload to obtain the response.
+
+After generation, the generated content can be accessed via `choice["message"]["content"]`.
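+
+As a rough sketch, a direct call looks like the following, where `lm` is the client constructed in the Usage section and the prompt is an arbitrary example:
+
+```python
+completions = lm("Say hello in one short sentence.")  # list of completion strings
+print(completions[0])
+```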
+
+**Parameters:**
+- `prompt` (_str_): Prompt to send to Groq.
+- `only_completed` (_bool_, _optional_): Flag to return only completed responses and ignore completions truncated due to length. Defaults to True.
+- `return_sorted` (_bool_, _optional_): Flag to sort the completion choices using the returned averaged log-probabilities. Defaults to False.
+- `**kwargs`: Additional keyword arguments for the completion request.
+
+**Returns:**
+- `List[Dict[str, Any]]`: List of completion choices.
\ No newline at end of file
diff --git a/dsp/modules/__init__.py b/dsp/modules/__init__.py
index 79e51402d5..b52d663cd5 100644
--- a/dsp/modules/__init__.py
+++ b/dsp/modules/__init__.py
@@ -8,6 +8,7 @@
 from .databricks import *
 from .google import *
 from .gpt3 import *
+from .groq_client import *
 from .hf import HFModel
 from .hf_client import Anyscale, HFClientTGI, Together
 from .mistral import *
@@ -15,3 +16,4 @@
 from .pyserini import *
 from .sbert import *
 from .sentence_vectorizer import *
+
diff --git a/dsp/modules/groq_client.py b/dsp/modules/groq_client.py
new file mode 100644
index 0000000000..0f1d2ffe40
--- /dev/null
+++ b/dsp/modules/groq_client.py
@@ -0,0 +1,169 @@
+import logging
+from typing import Any
+
+import backoff
+
+try:
+    import groq
+    from groq import Groq
+    groq_api_error = (groq.APIError, groq.RateLimitError)
+except ImportError:
+    groq_api_error = (Exception,)
+
+import dsp
+from dsp.modules.lm import LM
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(message)s",
+    handlers=[logging.FileHandler("groq_usage.log")],
+)
+
+
+def backoff_hdlr(details):
+    """Handler from https://pypi.org/project/backoff/"""
+    print(
+        "Backing off {wait:0.1f} seconds after {tries} tries "
+        "calling function {target} with kwargs "
+        "{kwargs}".format(**details),
+    )
+
+
+class GroqLM(LM):
+    """Wrapper around Groq's API.
+
+    Args:
+        model (str, optional): Groq-supported LLM model to use. Defaults to "mixtral-8x7b-32768".
+        api_key (str): API provider authentication token. Required.
+        **kwargs: Additional arguments to pass to the API provider.
+    """
+
+    def __init__(
+        self,
+        api_key: str,
+        model: str = "mixtral-8x7b-32768",
+        **kwargs,
+    ):
+        super().__init__(model)
+        self.provider = "groq"
+        if api_key:
+            self.api_key = api_key
+            self.client = Groq(api_key=api_key)
+        else:
+            raise ValueError("api_key is required for groq")
+
+        self.kwargs = {
+            "temperature": 0.0,
+            "max_tokens": 150,
+            "top_p": 1,
+            "frequency_penalty": 0,
+            "presence_penalty": 0,
+            "n": 1,
+            **kwargs,
+        }
+        # Only set the model once it is confirmed to be available on Groq.
+        models = self.client.models.list().data
+        if models is not None and model in [m.id for m in models]:
+            self.kwargs["model"] = model
+        self.history: list[dict[str, Any]] = []
+
+    def log_usage(self, response):
+        """Log the total tokens from the Groq API response."""
+        usage_data = getattr(response, "usage", None)
+        if usage_data:
+            total_tokens = getattr(usage_data, "total_tokens", None)
+            logging.info(f"{total_tokens}")
+
+    def basic_request(self, prompt: str, **kwargs):
+        raw_kwargs = kwargs
+
+        kwargs = {**self.kwargs, **kwargs}
+
+        kwargs["messages"] = [{"role": "user", "content": prompt}]
+        response = self.chat_request(**kwargs)
+
+        history = {
+            "prompt": prompt,
+            "response": response.choices[0].message.content,
+            "kwargs": kwargs,
+            "raw_kwargs": raw_kwargs,
+        }
+
+        self.history.append(history)
+
+        return response
+
+    @backoff.on_exception(
+        backoff.expo,
+        groq_api_error,
+        max_time=1000,
+        on_backoff=backoff_hdlr,
+    )
+    def request(self, prompt: str, **kwargs):
+        """Handles retrieval of model completions whilst handling rate limiting and caching."""
+        if "model_type" in kwargs:
+            del kwargs["model_type"]
+
+        return self.basic_request(prompt, **kwargs)
+
+    def _get_choice_text(self, choice) -> str:
+        return choice.message.content
+
+    def chat_request(self, **kwargs):
+        """Handles retrieval of model completions whilst handling rate limiting and caching."""
+        response = self.client.chat.completions.create(**kwargs)
+        return response
+
+    def __call__(
+        self,
+        prompt: str,
+        only_completed: bool = True,
+        return_sorted: bool = False,
+        **kwargs,
+    ) -> list[dict[str, Any]]:
+        """Retrieves completions from the model.
+
+        Args:
+            prompt (str): prompt to send to the model
+            only_completed (bool, optional): return only completed responses and ignore completions truncated due to length. Defaults to True.
+            return_sorted (bool, optional): sort the completion choices using the returned probabilities. Defaults to False.
+
+        Returns:
+            list[dict[str, Any]]: list of completion choices
+        """
+
+        assert only_completed, "for now"
+        assert return_sorted is False, "for now"
+        response = self.request(prompt, **kwargs)
+
+        if dsp.settings.log_openai_usage:
+            self.log_usage(response)
+
+        choices = response.choices
+
+        completions = [self._get_choice_text(c) for c in choices]
+        if return_sorted and kwargs.get("n", 1) > 1:
+            scored_completions = []
+
+            for c in choices:
+                tokens, logprobs = (
+                    c["logprobs"]["tokens"],
+                    c["logprobs"]["token_logprobs"],
+                )
+
+                if "<|endoftext|>" in tokens:
+                    index = tokens.index("<|endoftext|>") + 1
+                    tokens, logprobs = tokens[:index], logprobs[:index]
+
+                avglog = sum(logprobs) / len(logprobs)
+                scored_completions.append((avglog, self._get_choice_text(c)))
+
+            scored_completions = sorted(scored_completions, reverse=True)
+            completions = [c for _, c in scored_completions]
+
+        return completions
\ No newline at end of file
diff --git a/dsp/modules/hf_client.py b/dsp/modules/hf_client.py
index 56bbbb2ab9..4e9ed472e8 100644
--- a/dsp/modules/hf_client.py
+++ b/dsp/modules/hf_client.py
@@ -435,4 +435,4 @@ def _generate(self, prompt, **kwargs):
 
 @CacheMemory.cache
 def send_hfsglang_request_v00(arg, **kwargs):
-    return requests.post(arg, **kwargs)
+    return requests.post(arg, **kwargs)
\ No newline at end of file
diff --git a/dsp/modules/lm.py b/dsp/modules/lm.py
index 796d8b11bf..a3347557ac 100644
--- a/dsp/modules/lm.py
+++ b/dsp/modules/lm.py
@@ -46,7 +46,7 @@ def inspect_history(self, n: int = 1, skip: int = 0):
             prompt = x["prompt"]
 
             if prompt != last_prompt:
-                if provider == "clarifai" or provider == "google":
+                if provider == "clarifai" or provider == "google" or provider == "groq":
                     printed.append((prompt, x["response"]))
                 elif provider == "anthropic":
                     blocks = [{"text": block.text} for block in x["response"].content if block.type == "text"]
@@ -80,6 +80,8 @@ def inspect_history(self, n: int = 1, skip: int = 0):
                 text = " " + self._get_choice_text(choices[0]).strip()
             elif provider == "clarifai":
                 text = choices
+            elif provider == "groq":
+                text = " " + choices
             elif provider == "google":
                 text = choices[0].parts[0].text
             elif provider == "mistral":
diff --git a/dspy/__init__.py b/dspy/__init__.py
index f2edcbfce6..75d2332815 100644
--- a/dspy/__init__.py
+++ b/dspy/__init__.py
@@ -21,6 +21,7 @@
 Pyserini = dsp.PyseriniRetriever
 Clarifai = dsp.ClarifaiLLM
 Google = dsp.Google
+GROQ = dsp.GroqLM
 
 HFClientTGI = dsp.HFClientTGI
 HFClientVLLM = HFClientVLLM
diff --git a/pyproject.toml b/pyproject.toml
index 85ea9a1691..146b4cfa0e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -57,6 +57,7 @@ docs = [
     "autodoc_pydantic",
     "sphinx-reredirects>=0.1.2",
     "sphinx-automodapi==0.16.0",
+
 ]
 
 dev = ["pytest>=6.2.5"]
@@ -108,6 +109,7 @@ sphinx_rtd_theme = { version = "*", optional = true }
 autodoc_pydantic = { version = "*", optional = true }
 sphinx-reredirects = { version = "^0.1.2", optional = true }
 sphinx-automodapi = { version = "0.16.0", optional = true }
+groq = {version = "^0.4.2", optional = true }
 rich = "^13.7.1"
 psycopg2 = {version = "^2.9.9", optional = true}
 pgvector = {version = "^0.2.5", optional = true}