# dsp/modules/anthropic.py — reconstructed from a whitespace-mangled patch.
# The same patch also adds `from .anthropic import Claude` to
# dsp/modules/__init__.py and registers the optional `anthropic` extra.
import logging
import os
from typing import Any, Optional

import backoff

from dsp.modules.lm import LM

try:
    import anthropic

    anthropic_rate_limit = anthropic.RateLimitError
except ImportError:
    # Keep this module importable without the optional `anthropic` extra;
    # backoff then retries on a base type that real API errors won't match.
    anthropic_rate_limit = Exception


logger = logging.getLogger(__name__)

BASE_URL = "https://api.anthropic.com/v1/messages"


def backoff_hdlr(details):
    """Handler from https://pypi.org/project/backoff/"""
    print(
        "Backing off {wait:0.1f} seconds after {tries} tries "
        "calling function {target} with kwargs "
        "{kwargs}".format(**details),
    )


def giveup_hdlr(details):
    """Decide when to give up on retry: keep retrying only rate-limit errors."""
    # backoff passes the raised exception here; not every exception exposes
    # `.message`, so fall back to "" instead of raising AttributeError.
    if "rate limits" in getattr(details, "message", ""):
        return False
    return True


class Claude(LM):
    """Wrapper around Anthropic's messages API (Claude models)."""

    def __init__(
        self,
        model: str = "claude-instant-1.2",
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(model)

        try:
            # Imported lazily so the class fails with a clear message when the
            # optional dependency is missing.  (RateLimitError was imported
            # here before but never used.)
            from anthropic import Anthropic
        except ImportError as err:
            raise ImportError("Claude requires `pip install anthropic`.") from err

        self.provider = "anthropic"
        self.api_key = api_key = os.environ.get("ANTHROPIC_API_KEY") if api_key is None else api_key
        self.api_base = BASE_URL if api_base is None else api_base

        self.kwargs = {
            "temperature": 0.0 if "temperature" not in kwargs else kwargs["temperature"],
            # The messages API caps max_tokens; clamp rather than let the
            # request fail server-side.
            "max_tokens": min(kwargs.get("max_tokens", 4096), 4096),
            "top_p": 1.0 if "top_p" not in kwargs else kwargs["top_p"],
            "top_k": 1 if "top_k" not in kwargs else kwargs["top_k"],
            "n": kwargs.pop("n", kwargs.pop("num_generations", 1)),
            **kwargs,
        }
        self.kwargs["model"] = model
        self.history: list[dict[str, Any]] = []
        self.client = Anthropic(api_key=api_key)

    def log_usage(self, response):
        """Log the total tokens from the Anthropic API response."""
        usage_data = response.usage
        if usage_data:
            total_tokens = usage_data.input_tokens + usage_data.output_tokens
            logger.info("%s", total_tokens)

    def basic_request(self, prompt: str, **kwargs):
        """Send one prompt to the messages API and record the exchange in history."""
        raw_kwargs = kwargs

        kwargs = {**self.kwargs, **kwargs}
        # caching mechanism requires hashable kwargs
        kwargs["messages"] = [{"role": "user", "content": prompt}]
        # `n` is a dsp-level concept; the Anthropic API rejects unknown params.
        kwargs.pop("n", None)
        # BUGFIX: was a stray debug `print(kwargs)` that leaked every request
        # (prompt included) to stdout; use debug-level logging instead.
        logger.debug("Anthropic request kwargs: %s", kwargs)
        response = self.client.messages.create(**kwargs)

        self.history.append(
            {
                "prompt": prompt,
                "response": response,
                "kwargs": kwargs,
                "raw_kwargs": raw_kwargs,
            }
        )

        return response

    @backoff.on_exception(
        backoff.expo,
        (anthropic_rate_limit),
        max_time=1000,
        max_tries=8,
        on_backoff=backoff_hdlr,
        giveup=giveup_hdlr,
    )
    def request(self, prompt: str, **kwargs):
        """Retrieve completions from Anthropic, retrying on rate-limit errors."""
        return self.basic_request(prompt, **kwargs)

    def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs):
        """Retrieves completions from Anthropic.

        Args:
            prompt (str): prompt to send to Anthropic
            only_completed (bool, optional): return only completed responses
                and ignore completions truncated by length. Defaults to True.
            return_sorted (bool, optional): sort the completion choices using
                the returned probabilities. Defaults to False.

        Returns:
            list[str]: list of completion choices
        """
        assert only_completed, "for now"
        assert return_sorted is False, "for now"

        # per https://docs.anthropic.com/claude/reference/messages-examples
        # max_tokens can be used as a proxy to return smaller responses, so
        # stop_reason == "max_tokens" is only a weak signal of an incomplete
        # response unless it contradicts the user's intent.
        n = kwargs.pop("n", 1)
        completions = []
        for _ in range(n):
            response = self.request(prompt, **kwargs)
            # TODO: log LLM usage instead of the hard-coded OpenAI usage
            # if dsp.settings.log_openai_usage:
            #     self.log_usage(response)
            if only_completed and response.stop_reason == "max_tokens":
                continue
            # BUGFIX: accumulate across the n requests; the original assigned
            # `completions = [...]` each iteration, so for n > 1 it returned
            # only the final response's content.
            completions.extend(c.text for c in response.content)
        return completions
"sha256:f5d1caafd43f6cc933a79753a93531605095f040a384f6a900c3de9c3fb6694e"}, +] + +[package.dependencies] +anyio = ">=3.5.0,<5" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +tokenizers = ">=0.13.0" +typing-extensions = ">=4.7,<5" + +[package.extras] +bedrock = ["boto3 (>=1.28.57)", "botocore (>=1.31.57)"] +vertex = ["google-auth (>=2,<3)"] + [[package]] name = "anyio" version = "4.3.0" @@ -917,6 +941,17 @@ files = [ graph = ["objgraph (>=1.7.2)"] profile = ["gprof2dot (>=2022.7.29)"] +[[package]] +name = "distro" +version = "1.9.0" +description = "Distro - an OS platform information API" +optional = true +python-versions = ">=3.6" +files = [ + {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, + {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, +] + [[package]] name = "dnspython" version = "2.6.1" @@ -6193,4 +6228,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.12" -content-hash = "3d584715c2e1e090a222116cc007e4c6689c475f6ccc281796a221be59b615e2" +content-hash = "f7a5ab7c85e79920d41e45e9bbd17f0dbc1180c52d027235a656c270d9e79346" \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index cb8462de12..04204686d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ ] [project.optional-dependencies] +anthropic = ["anthropic~=0.18.0"] chromadb = ["chromadb~=0.4.14"] qdrant = ["qdrant-client~=1.6.2", "fastembed~=0.1.0"] marqo = ["marqo"] @@ -84,6 +85,7 @@ tqdm = "^4.66.1" datasets = "^2.14.6" requests = "^2.31.0" optuna = "^3.4.0" +anthropic = { version = "^0.18.0", optional = true } chromadb = { version = "^0.4.14", optional = true } fastembed = { version = "^0.1.0", optional = true } marqo = { version = "*", optional = true }