## Deploy StableVicuna 13B

- In this example, we deploy the 13B model on an `ml.g4dn.2xlarge` instance using 8-bit quantization with DJL.
- We also use asynchronous inference this time, which is a good choice for LLMs because inference takes a long time.


### Container used for deployment
- `763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-fastertransformer5.3.0-cu117`

In [None]:
import boto3
import sagemaker
from sagemaker.utils import name_from_base
from sagemaker import image_uris

In [None]:
# One SageMaker session; the control-plane and runtime clients are taken from it.
sagemaker_session = sagemaker.Session()
sm_client = sagemaker_session.sagemaker_client
sm_runtime_client = sagemaker_session.sagemaker_runtime_client

# Execution role passed to create_model further down.
role = sagemaker.get_execution_role()

# Stand-alone S3 client used for uploading inputs and reading results.
s3_client = boto3.client("s3")

In [None]:
# Engine identifier; interpolated into the DJL framework name ("djl-<engine>").
llm_engine = "fastertransformer"

In [None]:
# Resolve the DJL container image for the chosen engine in the session's region.
framework_name = f"djl-{llm_engine}"
inference_image_uri = image_uris.retrieve(
    framework=framework_name,
    version="0.21.0",
    region=sagemaker_session.boto_session.region_name,
)

print(f"Inference container uri: {inference_image_uri}")

In [None]:
# Local directory that gets archived into <src_dir_name>.tar.gz and uploaded.
src_dir_name = "stable-vicuna-src"

In [None]:
# S3 prefix (in the session's default bucket) that receives the code archive.
# The trailing slash matters: the archive file name is appended directly to it.
s3_target = f"s3://{sagemaker_session.default_bucket()}/llm/stable-vicuna/code/"

In [None]:
!rm -rf {src_dir_name}.tar.gz
!tar zcvf {src_dir_name}.tar.gz {src_dir_name} --exclude ".ipynb_checkpoints" --exclude "__pycache__"
!aws s3 cp {src_dir_name}.tar.gz {s3_target}

In [None]:
# Full S3 URI of the uploaded archive: code prefix + archive file name.
model_uri = f"{s3_target}{src_dir_name}.tar.gz"
print(model_uri)

In [None]:
# Register a SageMaker model that pairs the DJL container image with the
# uploaded code archive. The model name is derived from a fixed base string
# via name_from_base; the literal needs no f-string since it has no placeholders.
model_name = name_from_base("stable-vicuna-13b-djl")
print(model_name)

create_model_response = sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer={"Image": inference_image_uri, "ModelDataUrl": model_uri},
)
model_arn = create_model_response["ModelArn"]

print(f"Created Model: {model_arn}")

In [None]:
# Per-model S3 prefix handed to the endpoint config as the async output path.
async_output_uri = f"s3://{sagemaker_session.default_bucket()}/llm/outputs/{model_name}/"
print(async_output_uri)

In [None]:
# Names for the endpoint config and the endpoint, both derived from the model name.
endpoint_config_name = f"{model_name}-config"
endpoint_name = f"{model_name}-endpoint"

instance_type = "ml.g4dn.2xlarge"

# Single production variant: one instance, with a 600-second container
# startup health-check timeout.
production_variant = {
    "VariantName": "variant1",
    "ModelName": model_name,
    "InstanceType": instance_type,
    "InitialInstanceCount": 1,
    "ContainerStartupHealthCheckTimeoutInSeconds": 600,
}

# Async inference settings: results go under async_output_uri, and each
# instance handles one invocation at a time.
async_inference_config = {
    "OutputConfig": {"S3OutputPath": async_output_uri},
    "ClientConfig": {"MaxConcurrentInvocationsPerInstance": 1},
}

endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[production_variant],
    AsyncInferenceConfig=async_inference_config,
)
print(endpoint_config_response)

In [None]:
# Create the endpoint from the config built above. endpoint_name is already a
# string, so it is passed directly instead of being re-wrapped in an f-string.
create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)
print(f"Created Endpoint: {create_endpoint_response['EndpointArn']}")

In [None]:
import time

# Poll the endpoint once a minute until it leaves the "Creating" state,
# printing the status after every describe call.
while True:
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)
    if status != "Creating":
        break
    time.sleep(60)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

# Any terminal status other than "InService" means the endpoint is not usable;
# show the reason reported by describe_endpoint (when present) so the failure
# is visible here rather than only at invocation time.
if status != "InService":
    print("Failure reason: " + resp.get("FailureReason", "not reported"))

In [None]:
import json
import uuid

In [None]:
# Alternative prompts kept for quick experiments:
# prompt = "How to upload json text to S3 without saving to local, in python?"
# prompt = "Can you draw a picture which contains pigs flying in the sky?"

# Two-line prompt in the "### Human: / ### Assistant:" layout. The text ends
# immediately after "### Assistant:" with no trailing newline.
prompt = (
    "### Human: How to upload json text to S3 without saving to local file system in python?\n"
    "### Assistant:"
)

In [None]:
# Request payload: the prompt list under "text" plus the generation parameters.
prompts = [prompt]

input_data = dict(
    text=prompts,
    parameters=dict(max_new_tokens=256, temperature=0.5, top_p=0.5),
)

print(input_data)

In [None]:
# Session default bucket; used for the input upload and for reading the result.
default_bucket = sagemaker_session.default_bucket()

In [None]:
# Upload the JSON-serialized payload to S3 under a unique, per-model key.
# Note: despite its name, s3_uri holds the object key only; the full
# s3:// URI is input_data_uri.
s3_uri = f"llm/inputs/{model_name}/{uuid.uuid4()}.json"
input_data_uri = f"s3://{default_bucket}/{s3_uri}"

s3_client.put_object(
    Bucket=default_bucket,
    Key=s3_uri,
    Body=json.dumps(input_data),
)

In [None]:
# S3 URI of the uploaded request payload, passed as InputLocation when invoking.
input_location = input_data_uri

In [None]:
%%time
response = sm_runtime_client.invoke_endpoint_async(
    EndpointName=endpoint_name, 
    InputLocation=input_location
)
output_location = response["OutputLocation"]

In [None]:
print(output_location)
# Keep everything after the third "/" of the output location, i.e. drop the
# "s3://<bucket>/" prefix and retain the object key.
output_key_uri = "/".join(output_location.split("/", 3)[3:])

### Check the result

- This is asynchronous inference, so you need to check whether the output object already exists in S3.
- In a real service architecture, the S3 output would trigger a Lambda function or another event handler via SNS, SQS, EventBridge, ...

In [None]:

# Try to fetch the async inference result from S3.
# A single get_object call both checks for existence and retrieves the body.
# Only a "missing object" error is treated as "result not ready yet"; any
# other failure (permissions, networking, interrupts) is re-raised so it is
# not mistaken for a pending result.
try:
    result_obj = s3_client.get_object(Bucket=default_bucket, Key=output_key_uri)
except s3_client.exceptions.ClientError as err:
    if err.response["Error"]["Code"] not in ("NoSuchKey", "404"):
        raise
    print("Data is not exist yet. Wait until inference finished or check the CW log")
else:
    text = result_obj["Body"].read().decode("utf-8")
    print(text)