From 77d49227647fa8d429dd0f90214156b10b641696 Mon Sep 17 00:00:00 2001 From: cta2106 <32028497+cta2106@users.noreply.github.com> Date: Mon, 25 Mar 2024 21:19:21 -0400 Subject: [PATCH 1/4] feat(bedrock/claude-3): Add support for chat format --- dsp/modules/bedrock.py | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/dsp/modules/bedrock.py b/dsp/modules/bedrock.py index 252c87fe86..77565cf6b0 100644 --- a/dsp/modules/bedrock.py +++ b/dsp/modules/bedrock.py @@ -1,19 +1,24 @@ from __future__ import annotations - import json -from typing import Any, Optional - +from typing import Any, Optional, List from dsp.modules.aws_lm import AWSLM +from dataclasses import dataclass + + +@dataclass +class ChatMessage: + role: str + content: str class Bedrock(AWSLM): def __init__( - self, - region_name: str, - model: str, - profile_name: Optional[str] = None, - input_output_ratio: int = 3, - max_new_tokens: int = 1500, + self, + region_name: str, + model: str, + profile_name: Optional[str] = None, + input_output_ratio: int = 3, + max_new_tokens: int = 1500, ) -> None: """Use an AWS Bedrock language model. NOTE: You must first configure your AWS credentials with the AWS CLI before using this model! @@ -37,19 +42,24 @@ def __init__( ) self._validate_model(model) self.provider = "claude" if "claude" in model.lower() else "bedrock" + self.use_messages = "claude-3" in model.lower() def _validate_model(self, model: str) -> None: if "claude" not in model.lower(): raise NotImplementedError("Only claude models are supported as of now") - def _create_body(self, prompt: str, **kwargs) -> dict[str, str | float]: + def _create_body(self, prompt: str, messages: Optional[List[ChatMessage]] = None, **kwargs) -> dict[ + str, str | float | list]: base_args: dict[str, Any] = { "max_tokens_to_sample": self._max_new_tokens, } for k, v in kwargs.items(): base_args[k] = v query_args: dict[str, Any] = self._sanitize_kwargs(base_args) - query_args["prompt"] = prompt + if self.use_messages: + query_args["messages"] = [vars(m) for m in messages] + else: + query_args["prompt"] = self._format_prompt(prompt) # AWS Bedrock forbids these keys if "max_tokens" in query_args: max_tokens: int = query_args["max_tokens"] @@ -71,7 +81,7 @@ def _call_model(self, body: str) -> str: return completion def _extract_input_parameters( - self, body: dict[Any, Any], + self, body: dict[Any, Any], ) -> dict[str, str | float | int]: return body From e7b50ddcc7097023d6a1ec754c36ba718ae17977 Mon Sep 17 00:00:00 2001 From: cta2106 <32028497+cta2106@users.noreply.github.com> Date: Mon, 25 Mar 2024 22:13:48 -0400 Subject: [PATCH 2/4] feat(bedrock/claude-3): Add support for chat format --- dsp/modules/bedrock.py | 53 +++++++++++++++++++++++++++++++----------- 1 file changed, 39 insertions(+), 14 deletions(-) diff --git a/dsp/modules/bedrock.py b/dsp/modules/bedrock.py index 77565cf6b0..dff4afcdac 100644 --- a/dsp/modules/bedrock.py +++ b/dsp/modules/bedrock.py @@ -1,6 +1,6 @@ from __future__ import annotations import json -from typing import Any, Optional, List +from typing import Any, Optional from dsp.modules.aws_lm import AWSLM from dataclasses import dataclass @@ -48,25 +48,35 @@ def _validate_model(self, model: str) -> None: if "claude" not in model.lower(): raise NotImplementedError("Only claude models are supported as of now") - def _create_body(self, prompt: str, messages: Optional[List[ChatMessage]] = None, **kwargs) -> dict[ - str, str | float | list]: + def _create_body(self, prompt: str, system_prompt: Optional[str] = None, **kwargs) -> dict[str, Any]: base_args: dict[str, Any] = { + "anthropic_version": "bedrock-2023-05-31", "max_tokens_to_sample": self._max_new_tokens, } for k, v in kwargs.items(): base_args[k] = v - query_args: dict[str, Any] = self._sanitize_kwargs(base_args) + if self.use_messages: - query_args["messages"] = [vars(m) for m in messages] + messages = [ChatMessage(role="user", content=prompt)] + if system_prompt: + messages.insert(0, ChatMessage(role="system", content=system_prompt)) + else: + messages.insert(0, ChatMessage(role="system", content="You are a helpful AI assistant.")) + serialized_messages = [vars(m) for m in messages if m.role != "system"] + system_message = next(m["content"] for m in [vars(m) for m in messages if m.role == "system"]) + query_args = { + "messages": serialized_messages, + "system": system_message, + "anthropic_version": base_args["anthropic_version"], + "max_tokens": base_args["max_tokens_to_sample"], + } else: - query_args["prompt"] = self._format_prompt(prompt) - # AWS Bedrock forbids these keys - if "max_tokens" in query_args: - max_tokens: int = query_args["max_tokens"] - input_tokens: int = self._estimate_tokens(prompt) - max_tokens_to_sample: int = max_tokens - input_tokens - del query_args["max_tokens"] - query_args["max_tokens_to_sample"] = max_tokens_to_sample + query_args = { + "prompt": self._format_prompt(prompt), + "anthropic_version": base_args["anthropic_version"], + "max_tokens_to_sample": base_args["max_tokens_to_sample"], + } + return query_args def _call_model(self, body: str) -> str: @@ -77,7 +87,22 @@ def _call_model(self, body: str) -> str: contentType="application/json", ) response_body = json.loads(response["body"].read()) - completion = response_body["completion"] + + if self.use_messages: # Claude-3 model + try: + completion = response_body['content'][0]['text'] + except (KeyError, IndexError): + raise ValueError("Unexpected response format from the Claude-3 model.") + else: # Other models + expected_keys = ["completion", "text"] + found_key = next((key for key in expected_keys if key in response_body), None) + + if found_key: + completion = response_body[found_key] + else: + raise ValueError( + f"Unexpected response format from the model. Expected one of {', '.join(expected_keys)} keys.") + return completion def _extract_input_parameters( From 12e3b130167ea0e9158bbf278fa27966804965b6 Mon Sep 17 00:00:00 2001 From: cta2106 <32028497+cta2106@users.noreply.github.com> Date: Mon, 25 Mar 2024 22:24:13 -0400 Subject: [PATCH 3/4] feat(bedrock/claude-3): Add support for chat format --- dsp/modules/bedrock.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/dsp/modules/bedrock.py b/dsp/modules/bedrock.py index dff4afcdac..8afe51d3a5 100644 --- a/dsp/modules/bedrock.py +++ b/dsp/modules/bedrock.py @@ -51,11 +51,12 @@ def _validate_model(self, model: str) -> None: def _create_body(self, prompt: str, system_prompt: Optional[str] = None, **kwargs) -> dict[str, Any]: base_args: dict[str, Any] = { "anthropic_version": "bedrock-2023-05-31", - "max_tokens_to_sample": self._max_new_tokens, } for k, v in kwargs.items(): base_args[k] = v + query_args: dict[str, Any] = self._sanitize_kwargs(base_args) + if self.use_messages: messages = [ChatMessage(role="user", content=prompt)] if system_prompt: @@ -64,18 +65,12 @@ def _create_body(self, prompt: str, system_prompt: Optional[str] = None, **kwarg messages.insert(0, ChatMessage(role="system", content="You are a helpful AI assistant.")) serialized_messages = [vars(m) for m in messages if m.role != "system"] system_message = next(m["content"] for m in [vars(m) for m in messages if m.role == "system"]) - query_args = { - "messages": serialized_messages, - "system": system_message, - "anthropic_version": base_args["anthropic_version"], - "max_tokens": base_args["max_tokens_to_sample"], - } + query_args["messages"] = serialized_messages + query_args["system"] = system_message + query_args["max_tokens"] = self._max_new_tokens else: - query_args = { - "prompt": self._format_prompt(prompt), - "anthropic_version": base_args["anthropic_version"], - "max_tokens_to_sample": base_args["max_tokens_to_sample"], - } + query_args["prompt"] = self._format_prompt(prompt) + query_args["max_tokens_to_sample"] = self._max_new_tokens return query_args @@ -111,4 +106,4 @@ def _extract_input_parameters( return body def _format_prompt(self, raw_prompt: str) -> str: - return "\n\nHuman: " + raw_prompt + "\n\nAssistant:" + return "\n\nHuman: " + raw_prompt + "\n\nAssistant:" \ No newline at end of file From 99f3cc8cfd8daf0c3b0781d4c32793e29a74349f Mon Sep 17 00:00:00 2001 From: cta2106 <32028497+cta2106@users.noreply.github.com> Date: Thu, 4 Apr 2024 19:49:07 -0400 Subject: [PATCH 4/4] linted code with ruff --- dsp/modules/bedrock.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/dsp/modules/bedrock.py b/dsp/modules/bedrock.py index 8afe51d3a5..737dd61c0a 100644 --- a/dsp/modules/bedrock.py +++ b/dsp/modules/bedrock.py @@ -1,8 +1,10 @@ from __future__ import annotations + import json -from typing import Any, Optional -from dsp.modules.aws_lm import AWSLM from dataclasses import dataclass +from typing import Any + +from dsp.modules.aws_lm import AWSLM @dataclass @@ -16,7 +18,7 @@ def __init__( self, region_name: str, model: str, - profile_name: Optional[str] = None, + profile_name: str | None = None, input_output_ratio: int = 3, max_new_tokens: int = 1500, ) -> None: @@ -48,7 +50,7 @@ def _validate_model(self, model: str) -> None: if "claude" not in model.lower(): raise NotImplementedError("Only claude models are supported as of now") - def _create_body(self, prompt: str, system_prompt: Optional[str] = None, **kwargs) -> dict[str, Any]: + def _create_body(self, prompt: str, system_prompt: str | None = None, **kwargs) -> dict[str, Any]: base_args: dict[str, Any] = { "anthropic_version": "bedrock-2023-05-31", }