diff --git a/dsp/modules/bedrock.py b/dsp/modules/bedrock.py
index 252c87fe86..737dd61c0a 100644
--- a/dsp/modules/bedrock.py
+++ b/dsp/modules/bedrock.py
@@ -1,19 +1,26 @@
 from __future__ import annotations
 
 import json
-from typing import Any, Optional
+from dataclasses import dataclass
+from typing import Any
 
 from dsp.modules.aws_lm import AWSLM
 
 
+@dataclass
+class ChatMessage:
+    role: str
+    content: str
+
+
 class Bedrock(AWSLM):
     def __init__(
-            self,
-            region_name: str,
-            model: str,
-            profile_name: Optional[str] = None,
-            input_output_ratio: int = 3,
-            max_new_tokens: int = 1500,
+        self,
+        region_name: str,
+        model: str,
+        profile_name: str | None = None,
+        input_output_ratio: int = 3,
+        max_new_tokens: int = 1500,
     ) -> None:
         """Use an AWS Bedrock language model.
         NOTE: You must first configure your AWS credentials with the AWS CLI before using this model!
@@ -37,26 +44,36 @@ def __init__(
         )
         self._validate_model(model)
         self.provider = "claude" if "claude" in model.lower() else "bedrock"
+        self.use_messages = "claude-3" in model.lower()
 
     def _validate_model(self, model: str) -> None:
         if "claude" not in model.lower():
             raise NotImplementedError("Only claude models are supported as of now")
 
-    def _create_body(self, prompt: str, **kwargs) -> dict[str, str | float]:
+    def _create_body(self, prompt: str, system_prompt: str | None = None, **kwargs) -> dict[str, Any]:
         base_args: dict[str, Any] = {
-            "max_tokens_to_sample": self._max_new_tokens,
+            "anthropic_version": "bedrock-2023-05-31",
         }
         for k, v in kwargs.items():
             base_args[k] = v
+
         query_args: dict[str, Any] = self._sanitize_kwargs(base_args)
-        query_args["prompt"] = prompt
-        # AWS Bedrock forbids these keys
-        if "max_tokens" in query_args:
-            max_tokens: int = query_args["max_tokens"]
-            input_tokens: int = self._estimate_tokens(prompt)
-            max_tokens_to_sample: int = max_tokens - input_tokens
-            del query_args["max_tokens"]
-            query_args["max_tokens_to_sample"] = max_tokens_to_sample
+
+        if self.use_messages:
+            messages = [ChatMessage(role="user", content=prompt)]
+            if system_prompt:
+                messages.insert(0, ChatMessage(role="system", content=system_prompt))
+            else:
+                messages.insert(0, ChatMessage(role="system", content="You are a helpful AI assistant."))
+            serialized_messages = [vars(m) for m in messages if m.role != "system"]
+            system_message = next(m["content"] for m in [vars(m) for m in messages if m.role == "system"])
+            query_args["messages"] = serialized_messages
+            query_args["system"] = system_message
+            query_args["max_tokens"] = self._max_new_tokens
+        else:
+            query_args["prompt"] = self._format_prompt(prompt)
+            query_args["max_tokens_to_sample"] = self._max_new_tokens
+
         return query_args
 
     def _call_model(self, body: str) -> str:
@@ -67,13 +84,28 @@ def _call_model(self, body: str) -> str:
             contentType="application/json",
         )
         response_body = json.loads(response["body"].read())
-        completion = response_body["completion"]
+
+        if self.use_messages:  # Claude-3 model
+            try:
+                completion = response_body['content'][0]['text']
+            except (KeyError, IndexError):
+                raise ValueError("Unexpected response format from the Claude-3 model.")
+        else:  # Other models
+            expected_keys = ["completion", "text"]
+            found_key = next((key for key in expected_keys if key in response_body), None)
+
+            if found_key:
+                completion = response_body[found_key]
+            else:
+                raise ValueError(
+                    f"Unexpected response format from the model. Expected one of {', '.join(expected_keys)} keys.")
+
         return completion
 
     def _extract_input_parameters(
-            self, body: dict[Any, Any],
+        self, body: dict[Any, Any],
     ) -> dict[str, str | float | int]:
         return body
 
     def _format_prompt(self, raw_prompt: str) -> str:
-        return "\n\nHuman: " + raw_prompt + "\n\nAssistant:"
+        return "\n\nHuman: " + raw_prompt + "\n\nAssistant:"
\ No newline at end of file
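
For reviewers, a minimal usage sketch (not part of the patch) of the two request-body shapes the patched _create_body emits, keyed off use_messages. The region, model id, and prompt text below are illustrative assumptions; it also assumes boto3 is installed and AWS credentials are configured, per the docstring's NOTE.

    import json

    import boto3

    # Claude-3 path (use_messages=True): Anthropic Messages API schema.
    claude3_body = {
        "anthropic_version": "bedrock-2023-05-31",
        "system": "You are a helpful AI assistant.",
        "messages": [{"role": "user", "content": "Summarize DSPy in one sentence."}],
        "max_tokens": 1500,
    }

    # Older Claude path (use_messages=False): Text Completions schema,
    # shown for contrast; not sent below.
    claude2_body = {
        "prompt": "\n\nHuman: Summarize DSPy in one sentence.\n\nAssistant:",
        "max_tokens_to_sample": 1500,
    }

    client = boto3.client("bedrock-runtime", region_name="us-west-2")  # region is an assumption
    response = client.invoke_model(
        modelId="anthropic.claude-3-sonnet-20240229-v1:0",  # model id is an assumption
        body=json.dumps(claude3_body),
        accept="application/json",
        contentType="application/json",
    )
    response_body = json.loads(response["body"].read())
    print(response_body["content"][0]["text"])  # Claude-3 nests the completion text here

The branch exists because Bedrock's Claude-3 models only accept the Messages API schema (top-level "messages", "system", and "max_tokens", plus a required "anthropic_version"), while earlier Claude models expect a "\n\nHuman: ...\n\nAssistant:" prompt with "max_tokens_to_sample"; the response shapes differ accordingly, which is what the widened key handling in _call_model accounts for.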