## Deploy LLM using DJL

- Reference doc : https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-large-model-tutorials-fastertransformer.html
- Reference blog : https://aws.amazon.com/ko/blogs/machine-learning/deploy-large-models-at-high-performance-using-fastertransformer-on-amazon-sagemaker/
- DJL container list : https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-large-model-dlc.html
- Code example : https://github.com/aws/amazon-sagemaker-examples/blob/main/inference/generativeai/llm-workshop/lab4-openchatkit/deploy_openchatkit_on_sagemaker_deepspeed.ipynb


### Why use DJL

- DJL is effective for deploying LLMs. General-purpose inference objects such as `PyTorchModel` or `HuggingFaceModel` are not well suited to deploying LLMs.

In [None]:
# Restore the variables previously saved with IPython's %store magic
# into this notebook's namespace.
%store -r

In [None]:
# model_download_path

In [None]:
# S3 location of the StableLM model artifact.
# NOTE(review): the bucket name embeds a specific region and account ID;
# point this at your own bucket before running.
# NOTE(review): no other cell shown here references this variable; presumably
# the serving configuration inside the source directory uses the same path.
# TODO confirm.
stablelm_model_artifact = "s3://sagemaker-us-west-2-723597067299/llm/stablelm/model/base-7b"

In [None]:
import boto3
import sagemaker
from sagemaker.utils import name_from_base
from sagemaker import image_uris

In [None]:
# One SageMaker session supplies both low-level clients used in later cells:
# the control-plane client (models, endpoint configs, endpoints) and the
# runtime client (endpoint invocation).
sagemaker_session = sagemaker.Session()
sm_client, sm_runtime_client = (
    sagemaker_session.sagemaker_client,
    sagemaker_session.sagemaker_runtime_client,
)

# IAM role that is passed to create_model as the model's execution role.
role = sagemaker.get_execution_role()

In [None]:
# Engine name; later cells use it to build both the DJL image framework name
# ("djl-<engine>") and the local source directory name ("djl-<engine>-src").
llm_engine = "deepspeed"
# Alternative engine (uncomment to switch):
# llm_engine = "fastertransformer"

In [None]:
# Look up the DJL inference container image for the selected engine in the
# session's region, pinned to container version 0.21.0.
framework_name = "djl-" + llm_engine
inference_image_uri = image_uris.retrieve(
    framework=framework_name,
    region=sagemaker_session.boto_session.region_name,
    version="0.21.0",
)

print("Inference container uri: " + inference_image_uri)

In [None]:
# Local directory (one per engine) that gets archived and uploaded to S3
# in a later cell.
src_dir_name = f"djl-{llm_engine}-src"

In [None]:
# Destination prefix in the session's default bucket for the code archive.
# The trailing slash matters: the archive file name is appended to it directly.
s3_target = f"s3://{sagemaker_session.default_bucket()}/llm/stable-lm/code/"

In [None]:
!rm -rf {src_dir_name}.tar.gz
!tar zcvf {src_dir_name}.tar.gz {src_dir_name} --exclude ".ipynb_checkpoints" --exclude "__pycache__"
!aws s3 cp {src_dir_name}.tar.gz {s3_target}

In [None]:
# Full S3 URI of the uploaded code archive; passed to create_model as
# ModelDataUrl.
model_uri = s3_target + src_dir_name + ".tar.gz"
print(model_uri)

In [None]:
# Generate a unique model name from the base name (plain string literal: the
# original f-string had no placeholders).
model_name = name_from_base("stable-lm-7b-djl")
print(model_name)

# Register a SageMaker model that pairs the DJL inference container with the
# uploaded code archive.
create_model_response = sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer={"Image": inference_image_uri, "ModelDataUrl": model_uri},
)
model_arn = create_model_response["ModelArn"]

print(f"Created Model: {model_arn}")

In [None]:
# Instance type that hosts the endpoint (alternative kept for reference).
# instance_type = "ml.g4dn.xlarge"
instance_type = "ml.g5.4xlarge"

# Both names are derived from the model name so the resources are easy to pair.
endpoint_config_name = model_name + "-config"
endpoint_name = model_name + "-endpoint"

# Single-variant endpoint configuration: one instance, with a 600-second
# container startup health-check timeout.
endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        dict(
            VariantName="variant1",
            ModelName=model_name,
            InstanceType=instance_type,
            InitialInstanceCount=1,
            ContainerStartupHealthCheckTimeoutInSeconds=600,
        ),
    ],
)
print(endpoint_config_response)

In [None]:
# Start creating the endpoint from the endpoint configuration. The call
# returns immediately; the endpoint status is polled in a later cell.
create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name,  # already a str; no f-string wrapper needed
    EndpointConfigName=endpoint_config_name,
)
print(f"Created Endpoint: {create_endpoint_response['EndpointArn']}")

In [None]:
import time

# Poll the endpoint once a minute until it leaves the "Creating" state.
resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

# Fail loudly when creation did not succeed, so that later cells do not try to
# invoke a broken endpoint. FailureReason is read defensively with a fallback
# message in case the response does not carry it.
if status == "Failed":
    failure_reason = resp.get("FailureReason", "no failure reason reported")
    raise RuntimeError(
        f"Endpoint {endpoint_name} failed to deploy: {failure_reason}"
    )

In [None]:
import json

In [None]:
# system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
# - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
# - StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
# - StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
# - StableLM will refuse to participate in anything that could harm a human.
# """

# prompt = f"{system_prompt}<|USER|>What kind of animals can I adopt?\n<|ASSISTANT|>"

In [None]:
# Plain-text prompt sent to the endpoint in the invocation cell.
prompt = "How can I buy some great phone in vietnam?"

In [None]:
%%time
# The request carries a list of prompts (a single one here).
prompts = [prompt]

# JSON request body.
# NOTE(review): the "text" / "parameters" schema is presumably what the
# inference handler inside the source directory expects; confirm against
# that code.
request_body = json.dumps(
    {
        "text": prompts,
        "parameters": {
            "max_new_tokens": 512,
            "temperature": 0.7,
        },
    }
)

response_model = sm_runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    Body=request_body,
    ContentType="application/json",
)

In [None]:
# Read the whole response body and decode it as UTF-8 text.
output = response_model["Body"].read().decode("utf-8")

In [None]:
# Show the decoded response text returned by the endpoint.
print(output)