## LLaVA on SageMaker

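- Test the LLaVA model on SageMaker using an asynchronous (async) endpoint.
- DeepSpeed and FasterTransformer are not actually used; the model is served through DJL (Deep Java Library) Serving.
- Workflow: upload the model to S3 -> modify the code -> upload the code to S3 -> create the SageMaker endpoint -> verify the result.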
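Every artifact in this walkthrough lives under one bucket. As a minimal sketch, the prefixes used by the cells of this notebook can be collected in one place; `s3_layout` is a hypothetical helper, and the bucket and model names in the usage lines are placeholders.

```python
def s3_layout(bucket: str, model_name: str) -> dict:
    """Collect the S3 locations used in this notebook (hypothetical helper)."""
    base = f"s3://{bucket}/llm"
    return {
        "model": f"{base}/llava/llava-v15/model",      # model weights
        "code": f"{base}/llava/llava-v15/code/",       # packaged llava-src.tar.gz
        "inputs": f"{base}/inputs/{model_name}/",      # async request payloads
        "outputs": f"{base}/outputs/{model_name}/",    # async responses (S3OutputPath)
    }


layout = s3_layout("my-bucket", "llava-djl-demo")
for name, uri in layout.items():
    print(f"{name:8s} {uri}")
```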


In [9]:
# Restore variables saved with %store in an earlier notebook (e.g. model_artifact)
%store -r

In [10]:
model_artifact

's3://sagemaker-us-west-2-723597067299/llm/llava/llava-v15/model'

In [11]:
import boto3
import sagemaker
from sagemaker.utils import name_from_base
from sagemaker import image_uris

In [12]:
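# Only selects the DJL container flavor (djl-deepspeed); the LLaVA handler itself does not use DeepSpeed
llm_engine = "deepspeed"
# llm_engine = "fastertransformer"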

In [29]:
sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()
sm_client = sagemaker_session.sagemaker_client
sm_runtime_client = sagemaker_session.sagemaker_runtime_client
s3_client = boto3.client('s3')
default_bucket = sagemaker_session.default_bucket()

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml


In [14]:
framework_name = f"djl-{llm_engine}"
inference_image_uri = image_uris.retrieve(
    framework=framework_name, region=sagemaker_session.boto_session.region_name, version="0.23.0"
)

print(f"Inference container uri: {inference_image_uri}")

Inference container uri: 763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118
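The retrieved URI encodes the serving stack in its tag: DJL Serving 0.23.0 bundled with DeepSpeed 0.9.5 on CUDA 11.8 (`cu118`). A small parser makes these parts explicit, which can help when comparing images obtained with other `version` values. It is a sketch that assumes the usual `<account>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>` form; `parse_ecr_image_uri` is a hypothetical name.

```python
def parse_ecr_image_uri(uri: str) -> dict:
    """Split an ECR image URI into account, region, repository and tag."""
    registry, _, repo_and_tag = uri.partition("/")
    repository, sep, tag = repo_and_tag.rpartition(":")
    if not sep:  # no explicit tag in the URI
        repository, tag = repo_and_tag, ""
    # The registry host looks like <account>.dkr.ecr.<region>.amazonaws.com
    host_parts = registry.split(".")
    return {
        "account": host_parts[0],
        "region": host_parts[3],
        "repository": repository,
        "tag": tag,
    }


info = parse_ecr_image_uri(
    "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118"
)
print(info["repository"], info["tag"])  # djl-inference 0.23.0-deepspeed0.9.5-cu118
```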


In [15]:
s3_target = f"s3://{sagemaker_session.default_bucket()}/llm/llava/llava-v15/code/"
print(s3_target)

s3://sagemaker-us-west-2-723597067299/llm/llava/llava-v15/code/


In [16]:
!rm -f llava-src.tar.gz
!tar zcvf llava-src.tar.gz --exclude ".ipynb_checkpoints" --exclude "__pycache__" llava-src
!aws s3 cp llava-src.tar.gz {s3_target}

llava-src/
llava-src/model.py
llava-src/requirements.txt
llava-src/serving.properties
llava-src/run_llava_local.py
upload: ./llava-src.tar.gz to s3://sagemaker-us-west-2-723597067299/llm/llava/llava-v15/code/llava-src.tar.gz
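The same packaging step can also be done in plain Python, which avoids depending on how the local `tar` binary handles option order (newer GNU tar releases treat `--exclude` as positional). This sketch uses the standard `tarfile` module with a `filter` that drops `.ipynb_checkpoints` and `__pycache__`; `package_source` is a hypothetical helper. The resulting archive could then be uploaded with boto3's `s3_client.upload_file(filename, bucket, key)` instead of `aws s3 cp`.

```python
import tarfile
import tempfile
from pathlib import Path

EXCLUDED_DIRS = {".ipynb_checkpoints", "__pycache__"}


def _skip_excluded(info: tarfile.TarInfo):
    # Returning None tells tarfile to skip this member (and, for a directory, its contents)
    if EXCLUDED_DIRS.intersection(Path(info.name).parts):
        return None
    return info


def package_source(src_dir: str, archive_path: str) -> list:
    """Create a .tar.gz of src_dir (top-level folder kept) and return its member names."""
    with tarfile.open(archive_path, "w:gz") as tar:
        tar.add(src_dir, arcname=Path(src_dir).name, filter=_skip_excluded)
    with tarfile.open(archive_path, "r:gz") as tar:
        return tar.getnames()


# Usage on a throwaway directory that mimics llava-src/
with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp, "llava-src")
    (src / "__pycache__").mkdir(parents=True)
    (src / "model.py").write_text("# handler\n")
    (src / "__pycache__" / "model.cpython-310.pyc").write_bytes(b"")
    names = package_source(str(src), str(Path(tmp, "llava-src.tar.gz")))

print(names)  # ['llava-src', 'llava-src/model.py']
```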


In [17]:
model_uri = f"{s3_target}llava-src.tar.gz"
print(model_uri)

s3://sagemaker-us-west-2-723597067299/llm/llava/llava-v15/code/llava-src.tar.gz


### Deployment with a SageMaker endpoint

- Deploy the model as an asynchronous (async) endpoint.

In [18]:
model_name = name_from_base("llava-djl")
print(model_name)

create_model_response = sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer={"Image": inference_image_uri, "ModelDataUrl": model_uri},
)
model_arn = create_model_response["ModelArn"]

print(f"Created Model: {model_arn}")

llava-djl-2023-10-27-09-18-01-391
Created Model: arn:aws:sagemaker:us-west-2:723597067299:model/llava-djl-2023-10-27-09-18-01-391


In [19]:
default_bucket = sagemaker_session.default_bucket()
async_output_uri = f"s3://{default_bucket}/llm/outputs/{model_name}/"
print(async_output_uri)

s3://sagemaker-us-west-2-723597067299/llm/outputs/llava-djl-2023-10-27-09-18-01-391/


In [20]:
instance_type = "ml.g5.2xlarge"
# instance_type = "ml.g5.xlarge"
# instance_type = "ml.g4dn.xlarge"

endpoint_config_name = f"{model_name}-async-config"
endpoint_name = f"{model_name}-async-endpoint"

endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "VariantName": "variant1",
            "ModelName": model_name,
            "InstanceType": instance_type,
            "InitialInstanceCount": 1,
            "ContainerStartupHealthCheckTimeoutInSeconds": 600,
        },
    ],
    AsyncInferenceConfig={
        "OutputConfig": {
            "S3OutputPath": async_output_uri,
        },
        "ClientConfig": {
            "MaxConcurrentInvocationsPerInstance": 1
        }
    }
)
print(endpoint_config_response)

{'EndpointConfigArn': 'arn:aws:sagemaker:us-west-2:723597067299:endpoint-config/llava-djl-2023-10-27-09-18-01-391-async-config', 'ResponseMetadata': {'RequestId': 'cfe2e969-71de-4617-91fc-03444fc3d180', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': 'cfe2e969-71de-4617-91fc-03444fc3d180', 'content-type': 'application/x-amz-json-1.1', 'content-length': '127', 'date': 'Fri, 27 Oct 2023 09:18:11 GMT'}, 'RetryAttempts': 0}}
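Because the cell above keeps alternative instance types commented out, it can be convenient to build the same request from parameters. This is a sketch, not part of the SageMaker SDK: `build_async_endpoint_config` is a hypothetical helper that returns the keyword arguments used above, so it could be passed as `sm_client.create_endpoint_config(**cfg)`.

```python
def build_async_endpoint_config(
    config_name: str,
    model_name: str,
    instance_type: str,
    output_uri: str,
    max_concurrency: int = 1,
    startup_timeout_s: int = 600,
) -> dict:
    """Keyword arguments for sm_client.create_endpoint_config (hypothetical helper)."""
    return {
        "EndpointConfigName": config_name,
        "ProductionVariants": [
            {
                "VariantName": "variant1",
                "ModelName": model_name,
                "InstanceType": instance_type,
                "InitialInstanceCount": 1,
                # Allow time for the model to load before the container health check must pass
                "ContainerStartupHealthCheckTimeoutInSeconds": startup_timeout_s,
            }
        ],
        "AsyncInferenceConfig": {
            "OutputConfig": {"S3OutputPath": output_uri},
            # Cap in-flight requests per instance (1 in the notebook)
            "ClientConfig": {"MaxConcurrentInvocationsPerInstance": max_concurrency},
        },
    }


cfg = build_async_endpoint_config(
    "llava-djl-demo-async-config",
    "llava-djl-demo",
    "ml.g5.xlarge",
    "s3://my-bucket/llm/outputs/llava-djl-demo/",
)
# With a real client: sm_client.create_endpoint_config(**cfg)
print(cfg["ProductionVariants"][0]["InstanceType"])  # ml.g5.xlarge
```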


In [21]:
create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)
print(f"Created Endpoint: {create_endpoint_response['EndpointArn']}")

Created Endpoint: arn:aws:sagemaker:us-west-2:723597067299:endpoint/llava-djl-2023-10-27-09-18-01-391-async-endpoint


In [24]:
import time

resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

Status: InService
Arn: arn:aws:sagemaker:us-west-2:723597067299:endpoint/llava-djl-2023-10-27-09-18-01-391-async-endpoint
Status: InService
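The loop above exits as soon as the status leaves `Creating`, so a `Failed` endpoint would just be printed and the notebook would carry on, and a stuck endpoint would be polled forever. A sketch that adds a timeout and fails loudly, assuming `describe` behaves like `sm_client.describe_endpoint`; boto3 also provides a built-in waiter, `sm_client.get_waiter("endpoint_in_service")`, for the same purpose. `wait_for_endpoint` is a hypothetical name.

```python
import time


def wait_for_endpoint(describe, endpoint_name, poll_s=60, timeout_s=3600, sleep=time.sleep):
    """Poll until the endpoint leaves 'Creating'; raise unless it ends 'InService'."""
    waited = 0
    while True:
        status = describe(EndpointName=endpoint_name)["EndpointStatus"]
        print("Status: " + status)
        if status != "Creating":
            break
        if waited >= timeout_s:
            raise TimeoutError(f"{endpoint_name} still Creating after {waited}s")
        sleep(poll_s)
        waited += poll_s
    if status != "InService":
        raise RuntimeError(f"{endpoint_name} ended in status {status}")
    return status


# Usage with a fake describe function; in the notebook pass sm_client.describe_endpoint
statuses = iter(["Creating", "Creating", "InService"])
fake_describe = lambda EndpointName: {"EndpointStatus": next(statuses)}
final_status = wait_for_endpoint(fake_describe, "demo", sleep=lambda s: None)
```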


In [30]:
import json
import uuid

In [37]:
prompt = "where is the 25 exist in the matrix table?"

In [38]:
payload = {
    "text": [prompt],
    "input_image_s3": "s3://sagemaker-us-west-2-723597067299/llm/llava/input-samples/test_01.jpg",
}
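The request body is plain JSON: the prompt wrapped in a list under `text` and the image location under `input_image_s3`. These field names are whatever `llava-src/model.py` parses, so the sketch below only mirrors them. `make_async_request` is a hypothetical helper that builds the body together with a unique S3 input key of the same shape the upload cell uses.

```python
import json
import uuid


def make_async_request(prompt: str, image_s3_uri: str, model_name: str):
    """Return (s3_key, json_body) for one async request (hypothetical helper)."""
    if not image_s3_uri.startswith("s3://"):
        raise ValueError(f"expected an s3:// URI, got {image_s3_uri!r}")
    body = json.dumps({"text": [prompt], "input_image_s3": image_s3_uri})
    # A random UUID keeps concurrent requests from overwriting each other's input
    key = f"llm/inputs/{model_name}/{uuid.uuid4()}.json"
    return key, body


key, body = make_async_request(
    "Describe this image.",
    "s3://my-bucket/llm/llava/input-samples/test_01.jpg",
    "llava-djl-demo",
)
print(key)  # llm/inputs/llava-djl-demo/<random-uuid>.json
```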

In [39]:
# Upload the JSON payload to S3; the async endpoint reads its input from there
input_key = f"llm/inputs/{model_name}/{uuid.uuid4()}.json"
s3_client.put_object(
    Bucket=default_bucket,
    Key=input_key,
    Body=json.dumps(payload),
    ContentType="application/json",
)

input_location = f"s3://{default_bucket}/{input_key}"

In [40]:
response = sm_runtime_client.invoke_endpoint_async(
    EndpointName=endpoint_name, 
    InputLocation=input_location,
    ContentType="application/json"
)
output_location = response["OutputLocation"]
print(output_location)
output_key_uri = "/".join(output_location.split("/")[3:])

s3://sagemaker-us-west-2-723597067299/llm/outputs/llava-djl-2023-10-27-09-18-01-391/4eb045e7-5999-4aff-93d4-5d6d390ee843.out
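The cell above splits the output URI on `/` and reuses `default_bucket` as the bucket. That works here because `S3OutputPath` points at the default bucket; parsing both parts from the returned URI with `urllib.parse` keeps the retrieval code correct if the output path ever moves. `parse_s3_uri` is a hypothetical helper.

```python
from urllib.parse import urlparse


def parse_s3_uri(uri: str):
    """Split 's3://bucket/key' into (bucket, key)."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"not an S3 URI: {uri!r}")
    return parsed.netloc, parsed.path.lstrip("/")


bucket, key = parse_s3_uri(
    "s3://sagemaker-us-west-2-723597067299/llm/outputs/"
    "llava-djl-2023-10-27-09-18-01-391/4eb045e7-5999-4aff-93d4-5d6d390ee843.out"
)
print(bucket)  # sagemaker-us-west-2-723597067299
print(key)     # llm/outputs/llava-djl-2023-10-27-09-18-01-391/4eb045e7-5999-4aff-93d4-5d6d390ee843.out
```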


In [43]:
from botocore.exceptions import ClientError

try:
    # head_object raises ClientError (HTTP 404) until the async result has been written
    s3_client.head_object(Bucket=default_bucket, Key=output_key_uri)
    text_obj = s3_client.get_object(Bucket=default_bucket, Key=output_key_uri)["Body"].read()
    text = text_obj.decode("utf-8")
    print(text)
except ClientError:
    print("Output does not exist yet. Wait until inference finishes or check the CloudWatch logs.")

The 25 exists in the matrix table in the middle row, specifically in the third column.</s>
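Checking the output object once means re-running the cell by hand until inference finishes. A retry loop with exponential backoff is one alternative. In this sketch `fetch` is a hypothetical zero-argument callable that returns the object bytes, or `None` while the object is still missing (for example, a wrapper around `s3_client.get_object` that catches `ClientError`); `wait_for_output` is likewise a hypothetical name.

```python
import time


def wait_for_output(fetch, timeout_s=600, initial_delay_s=5, max_delay_s=60, sleep=time.sleep):
    """Call fetch() until it returns bytes, doubling the delay between attempts."""
    delay, waited = initial_delay_s, 0
    while True:
        body = fetch()
        if body is not None:
            return body
        if waited >= timeout_s:
            raise TimeoutError(f"no output after {waited}s; check the CloudWatch logs")
        sleep(delay)
        waited += delay
        delay = min(delay * 2, max_delay_s)  # 5, 10, 20, 40, 60, 60, ...


# Usage with a fake fetch that succeeds on the third attempt
attempts = iter([None, None, b"The answer.</s>"])
delays = []
result = wait_for_output(lambda: next(attempts), sleep=delays.append)
print(result, delays)  # b'The answer.</s>' [5, 10]
```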


In [44]:
output = text.split("</s>")[0]
print(output)

The 25 exists in the matrix table in the middle row, specifically in the third column.
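The raw output ends with `</s>`, which appears to be the Llama-style end-of-sequence token emitted by the model rather than part of the answer. A slightly more defensive version of the split above also accepts the raw bytes read from S3 and trims surrounding whitespace; `clean_output` is a hypothetical name.

```python
def clean_output(raw, eos: str = "</s>") -> str:
    """Return the generated text before the first end-of-sequence marker."""
    if isinstance(raw, bytes):
        raw = raw.decode("utf-8")
    return raw.split(eos, 1)[0].strip()


answer = clean_output(
    b"The 25 exists in the matrix table in the middle row, specifically in the third column.</s>"
)
print(answer)
```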
