# Invoke SGT-Llama Endpoint

This notebook demonstrates how to invoke an SGT-Llama-Fused AWS SageMaker endpoint via an assumed role.

## AWS Helpers

In [1]:
# Install dependencies
%pip install boto3 'types-boto3[sagemaker,sagemaker-runtime]'

Note: you may need to restart the kernel to use updated packages.


In [2]:
import json
from typing import Any, Final

import boto3
from types_boto3_sagemaker_runtime.client import SageMakerRuntimeClient

AWS_REGION: Final[str] = "us-east-1"


def assume_role(
    role_arn: str,
    boto_session: boto3.Session | None = None,
    session_name: str = "AssumeRoleSession",
    profile_name: str | None = None,
) -> boto3.Session:
    """Create a session using an assumed role.

    The session uses temporary credentials, so it is not advised for use
    in long-running jobs.

    Args:
        role_arn: The ARN of the role to assume.
        boto_session: The boto3 session used to call STS. If None, a session
            is created from `profile_name`.
        session_name: The name of the assumed-role session.
        profile_name: The name of the AWS profile to use if boto_session is not provided.

    Returns:
        A boto3 session that uses the assumed role's credentials.
    """
    if boto_session is None:
        boto_session = boto3.Session(profile_name=profile_name)

    sts_client = boto_session.client("sts")

    response = sts_client.assume_role(RoleArn=role_arn, RoleSessionName=session_name)
    credentials = response["Credentials"]

    # Create a new session with the assumed role's credentials
    assumed_session = boto3.Session(
        aws_access_key_id=credentials["AccessKeyId"],
        aws_secret_access_key=credentials["SecretAccessKey"],
        aws_session_token=credentials["SessionToken"],
        region_name=AWS_REGION,
    )

    return assumed_session


def invoke_sagemaker_endpoint(
    endpoint_name: str, payload: dict[str, Any], b3_session: boto3.Session
) -> dict[str, Any]:
    """Invoke a SageMaker endpoint with the given JSON payload.

    Args:
        endpoint_name: The name of the SageMaker endpoint.
        payload: The JSON-serializable request body to send to the endpoint.
        b3_session: A boto3 session with authorization to invoke endpoints.

    Returns:
        The decoded JSON response from the SageMaker endpoint.
    """
    runtime_client: SageMakerRuntimeClient = b3_session.client("sagemaker-runtime")
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=json.dumps(payload),
    )
    # The body is a streaming object: read it fully, then decode the JSON.
    return json.loads(response["Body"].read().decode("utf-8"))

## Set Up Boto3 Sessions

In [None]:
AWS_PROFILE: Final[str] = "<< YOUR PROFILE NAME HERE >>"
# SESSION = boto3.Session() # Use this if you'd like to create your own session rather than use an AWS profile.
SAGEMAKER_INVOKE_ROLE: Final[str] = "arn:aws:iam::423667443377:role/InvokeSageMakerEndpoint"
ENDPOINT_NAME: Final[str] = "fused-sgt-llama-8B-g54x"

sagemaker_b3_session: boto3.Session = assume_role(
    role_arn=SAGEMAKER_INVOKE_ROLE,
    profile_name=AWS_PROFILE,
    # boto_session=SESSION,
)

## Invoke Endpoint

The endpoint adheres to the following API contract, based on the [Bedrock Anthropic Claude Messages API](https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html).

### Request

```json
{
    "messages": [
        { "role": str, "content": str }  // Message
    ],
    "max_tokens": int,
    "temperature": float,
    "top_p": float
}
```

* `messages` [`list[dict]`]: The conversation so far. Each element is a message whose `content` was sent by a `role`.
    *   `role`: Either `"user"` or `"assistant"`.
    *   `content`: The text sent by that role.

* `max_tokens` [`int`]: The maximum number of tokens to generate in the response. Default is 2048.
* `temperature` [`float`]: Controls the randomness of sampling; lower values make the output more deterministic. Default is 0.01.
* `top_p` [`float`]: The cumulative probability cutoff for nucleus sampling. Default is 0.999. Adjust either `temperature` or `top_p`, but not both.

### Response

```json
{
    "id": str,
    "content": list[str],
    "model": str
}
```

* `id` [`str`]: A UUID generated for the response.
* `content` [`list[str]`]: A list with a single string that is the response of the model.
* `model` [`str`]: The name of the model that generated the response.
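`invoke_sagemaker_endpoint` decodes the streaming `Body` and returns the dict above, so the generated text sits at `response["content"][0]`. Below is a minimal sketch of the same decoding plus a stricter accessor, using two hypothetical helpers, `parse_response_body` and `response_text`. `io.BytesIO` stands in for the botocore `StreamingBody` so the example runs without an endpoint. Rejecting anything other than a single-element `content` list is an assumption based on the contract.

```python
import io
import json
from typing import Any, Protocol


class Readable(Protocol):
    """Anything with a read() that returns bytes, e.g. a StreamingBody."""

    def read(self) -> bytes: ...


def parse_response_body(body: Readable) -> dict[str, Any]:
    """Decode a JSON response body, such as response["Body"] from invoke_endpoint."""
    return json.loads(body.read().decode("utf-8"))


def response_text(response: dict[str, Any]) -> str:
    """Return the single generated string from the `content` list."""
    content = response.get("content")
    if not isinstance(content, list) or len(content) != 1:
        raise ValueError(f"expected a single-element content list, got {content!r}")
    return content[0]


# Offline stand-in for the endpoint's streaming body.
fake_body = io.BytesIO(
    b'{"id": "abc", "content": ["The capital of France is Paris."], "model": "SGT-Llama"}'
)
print(response_text(parse_response_body(fake_body)))
```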

## A Chat Example

In [4]:
REQUEST: dict[str, Any] = {
    "messages": [
        {"role": "user", "content": "What is the capital of France?"},
    ],
    "max_tokens": 120,
}

response = invoke_sagemaker_endpoint(ENDPOINT_NAME, REQUEST, b3_session=sagemaker_b3_session)
response

{'id': '3f881cd3-7467-4c68-9f2d-23c5faadc54b',
 'content': ['The capital of France is Paris.'],
 'model': 'SGT-Llama'}

In [5]:
chat_request: dict[str, Any] = {
    "messages": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": response["content"][0]},
        {"role": "user", "content": "What landmark is there?"},
    ],
    "max_tokens": 120,
}

In [6]:
response = invoke_sagemaker_endpoint(ENDPOINT_NAME, chat_request, b3_session=sagemaker_b3_session)
response

{'id': 'a626fc37-8a51-49d1-ac9c-152fb64245be',
 'content': ["One of the most famous landmarks in Paris is the Eiffel Tower (La Tour Eiffel in French). It's a iconic iron lattice tower built for the 1889 World's Fair and has become a symbol of Paris and France."],
 'model': 'SGT-Llama'}
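The second request rebuilds the history by hand, copying the first reply into an `assistant` message. The hypothetical `ChatSession` class below is a minimal sketch that automates this. It appends each user turn and model reply to a running history and sends the whole history on every call. The `invoke` callable is injected so the class can be tested with a fake, and the `max_tokens=120` default simply matches the examples above.

```python
from collections.abc import Callable
from typing import Any

# A function that takes a request dict and returns a response dict.
InvokeFn = Callable[[dict[str, Any]], dict[str, Any]]


class ChatSession:
    """Keep a running message history and send it on every turn (hypothetical)."""

    def __init__(self, invoke: InvokeFn, max_tokens: int = 120) -> None:
        self.invoke = invoke
        self.max_tokens = max_tokens
        self.messages: list[dict[str, str]] = []

    def send(self, text: str) -> str:
        """Send a user message with the full history and return the reply text."""
        self.messages.append({"role": "user", "content": text})
        # Copy the history so the request is not mutated by later turns.
        request = {"messages": list(self.messages), "max_tokens": self.max_tokens}
        try:
            response = self.invoke(request)
            reply = response["content"][0]
        except Exception:
            # Drop the unanswered user turn so the history stays consistent.
            self.messages.pop()
            raise
        self.messages.append({"role": "assistant", "content": reply})
        return reply
```

You can wire it to the notebook's helpers with `ChatSession(lambda req: invoke_sagemaker_endpoint(ENDPOINT_NAME, req, b3_session=sagemaker_b3_session))`. Calling `send("What is the capital of France?")` and then `send("What landmark is there?")` should then issue the same two requests as the cells above.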