1 change: 1 addition & 0 deletions dsp/modules/__init__.py
@@ -1,3 +1,4 @@
from .anthropic import Claude
from .azure_openai import AzureOpenAI
from .bedrock import *
from .cache_utils import *
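With the re-export above, the new wrapper becomes importable from the package root. A minimal sketch of the resulting import path:

```python
# The line added to dsp/modules/__init__.py re-exports the wrapper defined below,
# so callers can import it directly from the package:
from dsp.modules import Claude
```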
140 changes: 140 additions & 0 deletions dsp/modules/anthropic.py
@@ -0,0 +1,140 @@
import logging
import os
from typing import Any, Optional

import backoff

from dsp.modules.lm import LM

try:
    import anthropic
    anthropic_rate_limit = anthropic.RateLimitError
except ImportError:
    # fall back to a generic Exception so the backoff decorator below still resolves
    anthropic_rate_limit = Exception


logger = logging.getLogger(__name__)

BASE_URL = "https://api.anthropic.com/v1/messages"


def backoff_hdlr(details):
"""Handler from https://pypi.org/project/backoff/"""
print(
"Backing off {wait:0.1f} seconds after {tries} tries "
"calling function {target} with kwargs "
"{kwargs}".format(**details),
)


def giveup_hdlr(details):
    """Decide whether to give up on a retry; rate-limit errors are always retried."""
    if "rate limits" in details.message:
        return False
    return True


class Claude(LM):
"""Wrapper around anthropic's API. Supports both the Anthropic and Azure APIs."""
def __init__(
self,
model: str = "claude-instant-1.2",
api_key: Optional[str] = None,
api_base: Optional[str] = None,
**kwargs,
):
super().__init__(model)

        try:
            from anthropic import Anthropic
        except ImportError as err:
            raise ImportError("Claude requires `pip install anthropic`.") from err

self.provider = "anthropic"
self.api_key = api_key = os.environ.get("ANTHROPIC_API_KEY") if api_key is None else api_key
self.api_base = BASE_URL if api_base is None else api_base

        self.kwargs = {
            "temperature": kwargs.get("temperature", 0.0),
            "max_tokens": min(kwargs.get("max_tokens", 4096), 4096),
            "top_p": kwargs.get("top_p", 1.0),
            "top_k": kwargs.get("top_k", 1),
            "n": kwargs.pop("n", kwargs.pop("num_generations", 1)),
            **kwargs,
        }
self.kwargs["model"] = model
self.history: list[dict[str, Any]] = []
self.client = Anthropic(api_key=api_key)

    def log_usage(self, response):
        """Log the total token count from an Anthropic API response."""
        usage_data = response.usage
        if usage_data:
            total_tokens = usage_data.input_tokens + usage_data.output_tokens
            logger.info("Anthropic response used %d tokens", total_tokens)

def basic_request(self, prompt: str, **kwargs):
raw_kwargs = kwargs

kwargs = {**self.kwargs, **kwargs}
        # caching mechanism requires hashable kwargs
        kwargs["messages"] = [{"role": "user", "content": prompt}]
        # the Messages API has no `n` parameter; __call__ loops over generations instead
        kwargs.pop("n")
response = self.client.messages.create(**kwargs)

history = {
"prompt": prompt,
"response": response,
"kwargs": kwargs,
"raw_kwargs": raw_kwargs,
}
self.history.append(history)

return response

@backoff.on_exception(
backoff.expo,
(anthropic_rate_limit),
max_time=1000,
max_tries=8,
on_backoff=backoff_hdlr,
giveup=giveup_hdlr,
)
def request(self, prompt: str, **kwargs):
"""Handles retrieval of completions from Anthropic whilst handling API errors"""
return self.basic_request(prompt, **kwargs)

def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs):
"""Retrieves completions from Anthropic.

Args:
prompt (str): prompt to send to Anthropic
            only_completed (bool, optional): return only completed responses, ignoring completions truncated due to length. Defaults to True.
return_sorted (bool, optional): sort the completion choices using the returned probabilities. Defaults to False.

Returns:
list[str]: list of completion choices
"""

assert only_completed, "for now"
assert return_sorted is False, "for now"


        # Per the examples at https://docs.anthropic.com/claude/reference/messages-examples,
        # `max_tokens` may be set deliberately to request shorter responses, so
        # stop_reason == "max_tokens" alone is not a reliable indicator of an
        # unintentionally incomplete response.
        # if only_completed and response.stop_reason != "end_turn":
        #     choices = []

        n = kwargs.pop("n", 1)
        completions = []
        for _ in range(n):
            response = self.request(prompt, **kwargs)
            # TODO: log LLM usage here rather than via the hardcoded OpenAI usage path
            # if dsp.settings.log_openai_usage:
            #     self.log_usage(response)
            if only_completed and response.stop_reason == "max_tokens":
                continue
            # extend (not reassign) so all n generations are collected
            completions.extend(c.text for c in response.content)
        return completions
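A minimal usage sketch of the wrapper as it lands in this diff, assuming `ANTHROPIC_API_KEY` is set in the environment and the optional `anthropic` package is installed (the prompt string is illustrative):

```python
from dsp.modules import Claude

# Constructor defaults mirror the diff: claude-instant-1.2, temperature 0.0,
# max_tokens capped at 4096. Generation kwargs can be overridden per instance.
lm = Claude(model="claude-instant-1.2", max_tokens=512)

# __call__ returns a list[str] of completion texts; n > 1 issues n separate
# Messages API calls, since the endpoint has no native `n` parameter.
completions = lm("In one sentence, what does a retriever do in a RAG pipeline?")
print(completions[0])
```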
37 changes: 36 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -34,6 +34,7 @@ dependencies = [
]

[project.optional-dependencies]
anthropic = ["anthropic~=0.18.0"]
chromadb = ["chromadb~=0.4.14"]
qdrant = ["qdrant-client~=1.6.2", "fastembed~=0.1.0"]
marqo = ["marqo"]
@@ -84,6 +85,7 @@ tqdm = "^4.66.1"
datasets = "^2.14.6"
requests = "^2.31.0"
optuna = "^3.4.0"
anthropic = { version = "^0.18.0", optional = true }
chromadb = { version = "^0.4.14", optional = true }
fastembed = { version = "^0.1.0", optional = true }
marqo = { version = "*", optional = true }
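With the extra declared in both the PEP 621 table and the Poetry dependency group, the SDK can be installed from a checkout via `pip install ".[anthropic]"` (or `poetry install --extras anthropic`); a plain `pip install anthropic`, as the ImportError message in `dsp/modules/anthropic.py` suggests, also satisfies the import guard.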