## Kullm 모델을 DJL로 배포하기

In [None]:
# Quietly (-q) install the listed packages with pip via an IPython shell escape.
# NOTE(review): none of these packages is imported directly in the cells shown
# here; presumably huggingface_hub arrives as a dependency of transformers - confirm.
!pip install -q transformers accelerate sentencepiece bitsandbytes

In [None]:
import boto3
import sagemaker
from sagemaker.utils import name_from_base
from sagemaker import image_uris

In [None]:
# IAM role later passed as ExecutionRoleArn when the SageMaker model is created.
role = sagemaker.get_execution_role()

# A single SageMaker session supplies the control-plane client, the runtime
# (invocation) client, and the default S3 bucket name used below.
sagemaker_session = sagemaker.Session()
default_bucket = sagemaker_session.default_bucket()
sm_client = sagemaker_session.sagemaker_client
sm_runtime_client = sagemaker_session.sagemaker_runtime_client

In [None]:
from huggingface_hub import snapshot_download
from pathlib import Path
import os

# Hugging Face Hub repository to download. Swap in the commented line below
# to use the smaller 5.8B variant instead.
model_name = "nlpai-lab/kullm-polyglot-12.8b-v2"
# model_name = "nlpai-lab/kullm-polyglot-5.8b-v2"

# Local directory used as the download cache; created if it does not exist.
local_model_path = Path("./pretrained-models")
local_model_path.mkdir(exist_ok=True)

# Only files matching these glob patterns are fetched from the repository.
allow_patterns = [
    "*.json",
    "*.pt",
    "*.bin",
    "*.txt",
    "*.model",
    "*.py",
]

# The returned path is what gets uploaded to S3 in a later cell.
model_download_path = snapshot_download(
    repo_id=model_name, cache_dir=local_model_path, allow_patterns=allow_patterns
)

In [None]:
s3_model_prefix = "llm/kullm/model"  # folder where model checkpoint will go

In [None]:
base_model_s3 = f"{s3_model_prefix}/kullm-13b"

In [None]:
sagemaker_session = sagemaker.Session()
s3_model_artifact = sagemaker_session.upload_data(path=model_download_path, key_prefix=base_model_s3)

In [None]:
print(f"Model s3 uri : {s3_model_artifact}")

In [None]:
# llm_engine = "deepspeed"
llm_engine = "fastertransformer"

In [None]:
framework_name = f"djl-{llm_engine}"
# inference_image_uri = image_uris.retrieve(
#     framework=framework_name, region=sagemaker_session.boto_session.region_name, version="0.21.0"
# )

inference_image_uri = image_uris.retrieve(
    framework=framework_name, region=sagemaker_session.boto_session.region_name, version="0.22.1"
)

print(f"Inference container uri: {inference_image_uri}")

In [None]:
src_dir_name = f"kullm-13b-src"
s3_target = f"s3://{sagemaker_session.default_bucket()}/llm/kullm/code/"

In [None]:
!rm -rf {src_dir_name}.tar.gz
!tar zcvf {src_dir_name}.tar.gz {src_dir_name} --exclude ".ipynb_checkpoints" --exclude "__pycache__"
!aws s3 cp {src_dir_name}.tar.gz {s3_target}

In [None]:
model_uri = f"{s3_target}{src_dir_name}.tar.gz"
print(model_uri)

In [None]:
# Generate a unique SageMaker model name from a fixed base.
# NOTE: this rebinds `model_name`, which earlier held the Hugging Face repo id.
model_name = name_from_base("kullm-13b-djl")
print(model_name)

# Container spec: the inference image plus the uploaded code tarball.
primary_container = {"Image": inference_image_uri, "ModelDataUrl": model_uri}

create_model_response = sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer=primary_container,
)
model_arn = create_model_response["ModelArn"]

print(f"Created Model: {model_arn}")

In [None]:
# Instance type for the endpoint; uncomment an alternative to switch.
instance_type = "ml.g4dn.xlarge"
# instance_type = "ml.g4dn.2xlarge"
# instance_type = "ml.g5.4xlarge"

# Config and endpoint names are both derived from the model name.
endpoint_config_name = f"{model_name}-config"
endpoint_name = f"{model_name}-endpoint"

# Single production variant: one instance, with a 600-second container
# startup health-check timeout.
production_variant = {
    "VariantName": "variant1",
    "ModelName": model_name,
    "InstanceType": instance_type,
    "InitialInstanceCount": 1,
    "ContainerStartupHealthCheckTimeoutInSeconds": 600,
}

endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[production_variant],
)
print(endpoint_config_response)

In [None]:
# Start endpoint creation from the config defined above. The call returns
# immediately; the following cell polls until creation finishes.
create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)
endpoint_arn = create_endpoint_response["EndpointArn"]
print(f"Created Endpoint: {endpoint_arn}")

In [None]:
import time

# Poll the endpoint once a minute until it leaves the "Creating" state.
resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

# The loop above also exits when creation fails. Stop here with the reported
# reason instead of letting the invocation cell fail later with a less
# helpful error.
if status != "InService":
    failure_reason = resp.get("FailureReason", "no failure reason reported")
    raise RuntimeError(
        f"Endpoint {endpoint_name} ended in status {status}: {failure_reason}"
    )

In [None]:
import json

In [None]:
# Prompt sent to the endpoint; uncomment one of the Korean alternatives to
# switch. Roughly translated, they read: "Which is better to travel to,
# Kazakhstan or Vietnam?" and "How can I become rich?".
# prompt = "카자흐스탄과 베트남 중에서 어디가 더 여행하기 좋아?"
# prompt = "어떻게 하면 부자가 될 수 있을까?"
prompt = "What is the easiest way to become a rich?"
print(prompt)

In [None]:
%%time
prompts = [prompt]

response_model = sm_runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    Body=json.dumps(
        {
            "input_text": prompts,
            "instruction": "입력된 질문에 대해서 정확하고 자세한 답변을 해 주세요.",
            "parameters": {
                "max_new_tokens": 512,
                "temperature": 0.7,
                "top_p": 0.7,
            },
        }
    ),
    ContentType="application/json",
)

In [None]:
# Read the response body and decode it as UTF-8 text for display.
output = response_model["Body"].read().decode("utf-8")
print(output)

### 테스트 결과 속도 면에서 많은 차이는 없지만 예상대로 g5가 좀 더 속도가 빠르다.

- g5.4xlarge 사용 시 : 30 sec ~ 50 sec
- g4dn.xlarge 사용 시 : 30 sec ~ timeout (시간이 좀 더 오래 걸림)
