In [None]:
%%sh
pip install boto3 python-dotenv tqdm

In [None]:
import os
import time
from tqdm import tqdm
import boto3
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

# Retrieve AWS info from environment variables
aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
region_name = os.getenv('AWS_REGION')
aws_account_id = os.getenv('AWS_ACCOUNT_NUMBER')

# Choose a role in your AWS account that trusts sagemaker
# This one has the AmazonSageMakerFullAccess policy attached
role = "ServiceRoleSagemaker"
execution_role_arn = f"arn:aws:iam::{aws_account_id}:role/{role}"

In [None]:
# Got this from console.aws.amazon.com/marketplace/home#/subscriptions
# and selecting this model and choosing "Configure"
model_package = "arn:aws:sagemaker:us-east-1:865070037744:model-package/cohere-rerank-multilingual-v3--13dba038aab73b11b3f0b17fbdb48ea0"

# Only this and ml.g5.2xlarge are supported
instance_type = "ml.g5.xlarge"

name = "rerank-3-demo"

sm = boto3.client('sagemaker', region_name=region_name)

# Create the model
sm.create_model(
    ModelName=name,
    ExecutionRoleArn=execution_role_arn,
    PrimaryContainer={"ModelPackageName": model_package},
    EnableNetworkIsolation=True
)

# Create the endpoint config
sm.create_endpoint_config(
    EndpointConfigName=name,
    ProductionVariants=[
        {
            "VariantName": "variant-1",
            "ModelName": name,
            "InstanceType": instance_type,
            "InitialInstanceCount": 1,
        }
    ],
)

sm.create_endpoint(
    EndpointName=name,
    EndpointConfigName=name,
)

status = "Creating"
with tqdm(total=None, desc="Creating Endpoint", unit="check") as pbar:
    while status != "InService":
        response = sm.describe_endpoint(EndpointName=name)
        status = response["EndpointStatus"]
        pbar.set_postfix({"Status": status})
        pbar.update(1)
        time.sleep(10)

print(f"Endpoint creation completed. Status: {status}")

In [None]:
try:
    sm_runtime = boto3.client(
        'sagemaker-runtime',
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        region_name=region_name
    )
except Exception as e:
    print(f"Error creating SageMaker runtime client: {e}")

In [None]:
import json

# Input data
input_data = {
    "query": "What is the capital of the United States?",
    "rank_fields": ["Title", "Content"],
    "documents": [
        {
            "Title": "Facts about Carson City", 
            "Content": "Carson City is the capital city of \
                the American state of Nevada."
        },
        {
            "Title": "उत्तरी मारियाना द्वीप समूह के राष्ट्रमंडल का इतिहास", 
            "Content": "उत्तरी मारियाना द्वीप समूह का राष्ट्रमंडल प्रशांत महासागर में \
                द्वीपों का एक समूह है।"
        },
        {
            "Title": "Los Estados Unidos",
            "Content": "Washington, DC es la capital de los \
                Estados Unidos."
        }
    ],
    "top_n": 3
}

# Invoke the endpoint
response = sm_runtime.invoke_endpoint(
    EndpointName=name,
    ContentType='application/json',
    Body=json.dumps(input_data)
)

# Get the results
result = json.loads(response['Body'].read().decode())
print(result)

In [None]:
%%sh
pip install tabulate

In [8]:
from tabulate import tabulate

# Format the output
print("\nQuery:", input_data['query'])
print("\nRanked Results:")

# Prepare data for tabulate
table_data = []
for rank, item in enumerate(result['results'], 1):
    index = item['index']
    score = item['relevance_score']
    title = input_data['documents'][index]['Title']
    content = input_data['documents'][index]['Content']
    table_data.append([rank, f"{score:.4f}", title, content])

# Print table
print(
    tabulate(
        table_data,
        headers=[
            'Rank',
            'Score',
            'Title',
            'Content'
        ],
        tablefmt='grid')
    )

# Print additional information
print(f"\nRequest ID: {result['id']}")
print(f"API Version: {result['meta']['api_version']['version']}")
print(f"Billed Units: {result['meta']['billed_units']['search_units']}")


Query: What is the capital of the United States?

Ranked Results:
+--------+---------+------------------------------------+----------------------------------------------------------------------------------+
|   Rank |   Score | Title                              | Content                                                                          |
|      1 |  0.9994 | Los Estados Unidos                 | Washington, DC es la capital de los                 Estados Unidos.              |
+--------+---------+------------------------------------+----------------------------------------------------------------------------------+
|      2 |  0.1247 | Facts about Carson City            | Carson City is the capital city of                 the American state of Nevada. |
+--------+---------+------------------------------------+----------------------------------------------------------------------------------+
|      3 |  0.0001 | उत्तरी मारियाना द्वीप समूह के राष्ट्रमंडल का इतिहास | उत्तरी मारिय