Conversation

Contributor

drawal1 commented Apr 4, 2024

This check-in replaces the old Bedrock module, which was incomplete and obsolete. The new code refactors the module into aws_providers and aws_models and allows extending to new models.
Supported AWS providers: Bedrock and Sagemaker
Supported models: Mistral, Anthropic, Llama (extensible design for adding custom models; see the sketch after the test harness below)

"""Test harness: Testing the AWS modules"""
import dspy
from dspy import AWSAnthropic, AWSMeta, AWSMistral, Bedrock, Sagemaker

class QASignature(dspy.Signature):
    """answer the question"""

    question = dspy.InputField()
    answer = dspy.OutputField()


def test_aws_models(lm):
    """Test the models on the given AWS provider"""

    predict_func = dspy.ChainOfThought(QASignature)
    with dspy.context(lm=lm):
        answer = predict_func(question="What is the capital of France?")
        assert "paris" in str(answer).lower()
        print(answer)


if __name__ == "__main__":
    # NOTE: Configure your AWS credentials with the AWS CLI before running this test!

    bedrock = Bedrock(region_name="us-west-2")
    test_aws_models(AWSMistral(bedrock, model="mistral.mixtral-8x7b-instruct-v0:1"))
    test_aws_models(AWSMistral(bedrock, "mistral.mistral-7b-instruct-v0:2"))
    test_aws_models(AWSAnthropic(bedrock, "anthropic.claude-3-haiku-20240307-v1:0"))
    test_aws_models(AWSAnthropic(bedrock, "anthropic.claude-3-sonnet-20240229-v1:0"))
    # this is slower than molasses and generates irrelevant content after the answer!!
    # You may have to wait for 10-15min before it returns
    test_aws_models(AWSMeta(bedrock, "meta.llama2-70b-chat-v1"))

    # NOTE: Configure your Sagemaker endpoints before running this test!
    sagemaker = Sagemaker(region_name="us-west-2")
    # NOTE: Replace model value below with your own endpoint name
    test_aws_models(AWSMistral(sagemaker, model="g5-48xlarge-mixtral-8x7b"))
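
(A sketch of the "extensible design" noted above: a custom model would subclass AWSModel and override its formatting hooks. The hook names below, _format_prompt and _create_body, are illustrative assumptions rather than the module's actual interface.)

from dsp.modules.aws_models import AWSModel


class AWSCustomModel(AWSModel):
    """Hypothetical adapter for a custom model hosted on Bedrock or Sagemaker."""

    def _format_prompt(self, raw_prompt: str) -> str:
        # Wrap the prompt in whatever template the custom model expects.
        return f"<s>[INST] {raw_prompt} [/INST]"

    def _create_body(self, prompt: str, **kwargs) -> dict:
        # Map generic kwargs (max_tokens, temperature, ...) onto the
        # model-specific request schema.
        return {"prompt": self._format_prompt(prompt), **kwargs}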

@xiaochuan-du left a comment

Do you think lm.py needs to be updated as well? Otherwise, llm_model.inspect_history(n=1) may throw errors.
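
(For context, a minimal repro of this concern, reusing the Bedrock setup from the test harness above: inspect_history pretty-prints entries that each LM appends to self.history, so it fails if the new AWS classes record history in a different shape.)

lm = AWSMistral(Bedrock(region_name="us-west-2"),
                "mistral.mixtral-8x7b-instruct-v0:1")
with dspy.context(lm=lm):
    dspy.ChainOfThought(QASignature)(question="What is the capital of France?")

# Throws if the AWS classes don't populate self.history in the format
# that the printer in lm.py expects.
lm.inspect_history(n=1)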

Contributor Author

drawal1 commented Apr 5, 2024

Good catch!

from .cache_utils import *
from .clarifai import *
from .cohere import *
from .cache_utils import * # noqa: F403
Collaborator

@drawal1 can you remove all extra comments here?

from typing import Any

from dsp.modules.lm import LM
from shared.src.models.utils.aws_providers import AWSProvider, Bedrock, Sagemaker
Collaborator

where is this import coming from? can you wrap any external imports in a try/except block as done here?
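
(The guarding pattern being referenced, sketched here for boto3; the attribute name and error message are illustrative, not the PR's exact code.)

try:
    import boto3
except ImportError:
    boto3 = None


class Bedrock:
    """Sketch: defer the import failure until the provider is constructed."""

    def __init__(self, region_name: str) -> None:
        if boto3 is None:
            raise ImportError("boto3 is required for AWS providers: pip install boto3")
        self.predictor = boto3.client("bedrock-runtime", region_name=region_name)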

right now.
"""
if not only_completed:
    raise NotImplementedError("Error, only_completed not yet supported!")
Collaborator

Minor styling, but can you follow this here?



class AWSModel(LM):
"""This class adds support for an AWS model."""
Collaborator

Can you add clearer documentation on how this class works with the other classes?
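
(For reference, the collaboration the docstring could spell out, as exercised in the test harness above: an AWSProvider subclass (Bedrock, Sagemaker) owns the boto3 client for a region, while an AWSModel subclass (AWSMistral, AWSAnthropic, AWSMeta) owns a model's prompt/response format and delegates the network call to whichever provider it wraps.)

bedrock = Bedrock(region_name="us-west-2")                    # provider: owns the boto3 client
lm = AWSMistral(bedrock, "mistral.mistral-7b-instruct-v0:2")  # model: owns the prompt format

# The same model class runs against a different provider unchanged, e.g. a
# Sagemaker endpoint serving Mistral (the endpoint name here is hypothetical):
sagemaker = Sagemaker(region_name="us-west-2")
lm2 = AWSMistral(sagemaker, "my-mistral-endpoint")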

profile_name (str, optional): boto3 credentials profile.
batch_n_enabled (bool): If False, call the LM N times rather than batching.
"""
import boto3 # pylint: disable=import-outside-toplevel
Collaborator

same comment as above - wrap external imports in a try/except block as done here?

dspy/__init__.py Outdated
from .primitives import *
from .retrieve import *
from .signatures import *
from .predict import * # noqa: F403
Collaborator

same comment as above - remove any extra comments

Google = dsp.Google

HFClientTGI = dsp.HFClientTGI
HFClientVLLM = HFClientVLLM
Collaborator

why is this deleted?

pgvector = {version = "^0.2.5", optional = true}
structlog = "^24.1.0"


Collaborator

remove change

from .aws_models import AWSAnthropic, AWSLlama2, AWSMistral, AWSModel
from .aws_providers import Bedrock, Sagemaker
from .azure_openai import AzureOpenAI
from .bedrock import *
Collaborator

why is importing bedrock removed from here?

@arnavsinghvi11
Collaborator

Hi @drawal1, left some comments on the PR. It's also failing some of the CI tests. Please run ruff check . --fix-only and address the other errors found in the build tests.

It would be great if you could add both documentation and integration tests for the AWS model support as well. Feel free to follow these LM docs for reference.
