<a target="_blank" href="https://colab.research.google.com/github/cohere-ai/notebooks/blob/main/notebooks/llmu/co_aws_ch5_rerank_sm.ipynb"> <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Cohere Rerank on Amazon SageMaker



In [None]:
! pip install cohere cohere-aws boto3

In [None]:
import os
import boto3
import cohere
import cohere_aws
from cohere_aws import Client

In [None]:
import cohere

# AWS credentials, gathered once so both SDK clients below use the same values.
# Contact your AWS administrator for the credentials.
# Prefer exporting them as environment variables before starting the kernel;
# a value typed in here takes precedence, and any entry left blank falls back
# to the environment variable of the same name.
aws_credentials = {
    "AWS_ACCESS_KEY_ID": "",
    "AWS_SECRET_ACCESS_KEY": "",
    "AWS_SESSION_TOKEN": "",
}
aws_credentials = {
    name: value or os.environ.get(name, "")
    for name, value in aws_credentials.items()
}

# Create SageMaker client via the native Cohere SDK
# NOTE(review): the endpoint-creation cell takes its region from
# boto3.Session().region_name; confirm that it matches the region used here.
co = cohere.SagemakerClient(
    aws_region="us-east-1",
    aws_access_key=aws_credentials["AWS_ACCESS_KEY_ID"],
    aws_secret_key=aws_credentials["AWS_SECRET_ACCESS_KEY"],
    aws_session_token=aws_credentials["AWS_SESSION_TOKEN"],
)

# For creating an endpoint, you need to use the cohere_aws client, which takes
# the AWS credentials from environment variables. Only non-empty values are
# exported, so credentials already present in the environment are never
# overwritten with blank placeholders.
os.environ.update(
    {name: value for name, value in aws_credentials.items() if value}
)

In [None]:
# Create SageMaker endpoint via the cohere_aws SDK
#
# The model package ARN differs per region: each region has its own publisher
# account id, while the package name is shared.
cohere_package = "cohere-rerank-english-v3-01-d3687e0d2e3a366bb904275616424807"
model_package_map = {
    "us-east-1": f"arn:aws:sagemaker:us-east-1:865070037744:model-package/{cohere_package}",
    "us-east-2": f"arn:aws:sagemaker:us-east-2:057799348421:model-package/{cohere_package}",
    "us-west-1": f"arn:aws:sagemaker:us-west-1:382657785993:model-package/{cohere_package}",
    "us-west-2": f"arn:aws:sagemaker:us-west-2:594846645681:model-package/{cohere_package}",
    "ca-central-1": f"arn:aws:sagemaker:ca-central-1:470592106596:model-package/{cohere_package}",
    "eu-central-1": f"arn:aws:sagemaker:eu-central-1:446921602837:model-package/{cohere_package}",
    "eu-west-1": f"arn:aws:sagemaker:eu-west-1:985815980388:model-package/{cohere_package}",
    "eu-west-2": f"arn:aws:sagemaker:eu-west-2:856760150666:model-package/{cohere_package}",
    "eu-west-3": f"arn:aws:sagemaker:eu-west-3:843114510376:model-package/{cohere_package}",
    "eu-north-1": f"arn:aws:sagemaker:eu-north-1:136758871317:model-package/{cohere_package}",
    "ap-southeast-1": f"arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/{cohere_package}",
    "ap-southeast-2": f"arn:aws:sagemaker:ap-southeast-2:666831318237:model-package/{cohere_package}",
    "ap-northeast-2": f"arn:aws:sagemaker:ap-northeast-2:745090734665:model-package/{cohere_package}",
    "ap-northeast-1": f"arn:aws:sagemaker:ap-northeast-1:977537786026:model-package/{cohere_package}",
    "ap-south-1": f"arn:aws:sagemaker:ap-south-1:077584701553:model-package/{cohere_package}",
    "sa-east-1": f"arn:aws:sagemaker:sa-east-1:270155090741:model-package/{cohere_package}",
}

# Region of the default boto3 session. This is None when no region has been
# configured, which is reported by the same check as an unlisted region.
region = boto3.Session().region_name

if region not in model_package_map:
    # ValueError is still an Exception, so existing `except Exception` handlers
    # keep working; the message now says what was detected and what is allowed.
    raise ValueError(
        f"UNSUPPORTED REGION: {region!r}. "
        f"Configure one of: {', '.join(model_package_map)}"
    )

model_package_arn = model_package_map[region]

co_aws = Client(region_name=region)

# The endpoint name chosen here is the value later passed as `model` to
# co.rerank(); keep the two in sync if you rename it.
co_aws.create_endpoint(arn=model_package_arn, endpoint_name="my-rerank-v3", instance_type="ml.g5.xlarge", n_instances=1)

In [None]:
# Sample customer-support emails used as the candidate documents for reranking.
# Each entry becomes a dict with a "Title" and a "Content" field; several
# emails deliberately share near-identical content.
documents = [
    {"Title": title, "Content": content}
    for title, content in [
        (
            "Incorrect Password",
            "Hello, I have been trying to access my account for the past hour and it keeps saying my password is incorrect. Can you please help me?",
        ),
        (
            "Confirmation Email Missed",
            "Hi, I recently purchased a product from your website but I never received a confirmation email. Can you please look into this for me?",
        ),
        (
            "Questions about Return Policy",
            "Hello, I have a question about the return policy for this product. I purchased it a few weeks ago and it is defective.",
        ),
        (
            "Customer Support is Busy",
            "Good morning, I have been trying to reach your customer support team for the past week but I keep getting a busy signal. Can you please help me?",
        ),
        (
            "Received Wrong Item",
            "Hi, I have a question about my recent order. I received the wrong item and I need to return it.",
        ),
        (
            "Customer Service is Unavailable",
            "Hello, I have been trying to reach your customer support team for the past hour but I keep getting a busy signal. Can you please help me?",
        ),
        (
            "Return Policy for Defective Product",
            "Hi, I have a question about the return policy for this product. I purchased it a few weeks ago and it is defective.",
        ),
        (
            "Wrong Item Received",
            "Good morning, I have a question about my recent order. I received the wrong item and I need to return it.",
        ),
        (
            "Return Defective Product",
            "Hello, I have a question about the return policy for this product. I purchased it a few weeks ago and it is defective.",
        ),
    ]
]

In [None]:
# Rerank the sample emails against a natural-language query and keep the
# three best matches.
query = 'What emails have been about refunds?'

response = co.rerank(
    # Same string as the endpoint_name used when the endpoint was created.
    model="my-rerank-v3",
    query=query,
    documents=documents,
    # Fields of each document dict that are passed to the reranker.
    rank_fields=["Title", "Content"],
    top_n=3,
)

In [None]:
# Print the reranked documents, numbered from 1 in the order the results
# were returned.
print("Documents", "\n")

for rank, result in enumerate(response.results, start=1):
    # result.index is used to look the document up in the `documents` list.
    matched_document = documents[result.index]
    print(f"#{rank}:\n{matched_document}\n")

Documents 

#1:
{'Title': 'Questions about Return Policy', 'Content': 'Hello, I have a question about the return policy for this product. I purchased it a few weeks ago and it is defective.'}

#2:
{'Title': 'Return Policy for Defective Product', 'Content': 'Hi, I have a question about the return policy for this product. I purchased it a few weeks ago and it is defective.'}

#3:
{'Title': 'Return Defective Product', 'Content': 'Hello, I have a question about the return policy for this product. I purchased it a few weeks ago and it is defective.'}



In [None]:
# Clean up: delete the endpoint, then close the cohere_aws client.
# try/finally ensures close() still runs if delete_endpoint() raises.
try:
    co_aws.delete_endpoint()
finally:
    co_aws.close()