# AOT 方式使用 TensorRT-LMI 部署 ChatGLM3-6B-32K

In [None]:
%pip install sagemaker --upgrade  --quiet

In [None]:
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers

# IAM execution role used for the endpoint.
role = sagemaker.get_execution_role()
# SageMaker session used for interacting with the different AWS APIs.
sess = sagemaker.session.Session()
# Region of the current SageMaker Studio environment. Read it through the
# public boto session (the same accessor used when retrieving the serving
# image) instead of the private `_region_name` attribute.
region = sess.boto_session.region_name
# Account id of the current SageMaker Studio environment.
account_id = sess.account_id()

## Pull 转换模型时需要的镜像

In [None]:
# Fetch an ECR Public login password (region us-east-1) and pipe it to `docker login` for public.ecr.aws/s0w3f1p2.
!aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/s0w3f1p2

In [None]:
# Pull the image that is run later in this notebook to convert the model.
!docker pull public.ecr.aws/s0w3f1p2/tensorrt-lmi-xq:v1

## 转换模型

In [None]:
%%writefile serving.properties
option.model_id=THUDM/chatglm3-6b-32k
option.tensor_parallel_degree=2
option.max_rolling_batch_size=64
option.dtype=fp16
option.chatglm_model_version=chatglm3_6b_32k
option.trust_remote_code=True

In [None]:
# Print the generated serving.properties so its contents can be checked.
!cat serving.properties

In [None]:
# Base name shared by the local output directory and the S3 prefixes.
file_name = "chatglm3_6b_32k_aot"

# Local directory (relative to the working directory) that holds
# serving.properties and is mounted into the conversion container.
MODEL_REPO_DIR = f"output/{file_name}"
# Create the model repository directory and move serving.properties into it.
# `$MODEL_REPO_DIR` is the Python variable, interpolated by IPython into the shell command.
!mkdir -p $MODEL_REPO_DIR
!mv serving.properties $MODEL_REPO_DIR

In [None]:
# S3 destination for the converted model artifacts.
# NOTE(review): the bucket name is hard-coded — presumably it has to be
# replaced with a bucket you own before running; confirm.
bucket="llm-trt"
s3_model_prefix = f"lmi/{file_name}"
s3url=f"s3://{bucket}/{s3_model_prefix}"
# Bare expression: shows the resulting S3 URL as the cell output.
s3url

In [None]:
# Print the canonical absolute path of the model repository directory.
! readlink -f $MODEL_REPO_DIR

In [None]:
from pathlib import Path

# Absolute path of the model repository on the host: the current working
# directory joined with MODEL_REPO_DIR. It is used as the Docker volume
# source and as the source of the S3 sync.
current_path = Path.cwd()
trt_model_path = current_path / MODEL_REPO_DIR
print(trt_model_path)

In [None]:
# Run /opt/djl/partition/trt_llm_partition.py inside the pulled image with all GPUs.
# The host directory `$trt_model_path` (Python variable, interpolated by IPython) is
# mounted at /tmp/trtllm and passed both as the properties directory and as the
# TensorRT-LLM model repo; tensor parallel degree is 2, matching serving.properties.
!docker run --runtime=nvidia --gpus all --shm-size 12gb \
-v $trt_model_path:/tmp/trtllm \
-p 8080:8080 \
public.ecr.aws/s0w3f1p2/tensorrt-lmi-xq:v1 /opt/djl/partition/trt_llm_partition.py \
--properties_dir /tmp/trtllm \
--trt_llm_model_repo /tmp/trtllm \
--tensor_parallel_degree 2

In [None]:
# Move the tokenizer config from ./model_tokenizer into the model repository as tokenizer_config.json.
# NOTE(review): presumably this replaces/supplies the tokenizer config of the converted model — confirm.
!mv ./model_tokenizer/chatglm3_6B_32K_tokenizer_config.json $trt_model_path/tokenizer_config.json

## 上传转换后的模型到 S3

In [None]:
# Sync the local model repository to the S3 location held in `s3url`.
!aws s3 sync $trt_model_path $s3url

## 创建并上传部署所需的配置文件到 S3
 - 修改 serving.properties 中 model_id 为上传的 S3 模型地址
 - 根据自己的输入输出，写 model.py 文件

In [None]:
# Write the deployment-time serving.properties in one call. Compared with the
# conversion-time file, model_id points at the converted model on S3, and the
# output formatter and rolling-batch backend are set as well.
with open('serving.properties', 'w') as properties_file:
    properties_file.writelines(
        [
            f"option.model_id={s3url}\n",
            "option.tensor_parallel_degree=2\n",
            "option.max_rolling_batch_size=64\n",
            "option.dtype=fp16\n",
            "option.chatglm_model_version=chatglm3_6b_32k\n",
            "option.trust_remote_code=True\n",
            "option.output_formatter=jsonlines\n",
            "option.rolling_batch=trtllm\n",
        ]
    )

In [None]:
# Print the deployment serving.properties so the S3 model_id can be checked.
!cat serving.properties

In [None]:
%%sh
# Package serving.properties into mymodel.tar.gz (under a top-level mymodel/ directory),
# then remove the temporary directory; only the tarball is left behind.
mkdir mymodel
mv serving.properties mymodel/
tar czvf mymodel.tar.gz mymodel/
rm -rf mymodel

In [None]:
# Upload the packaged configuration tarball to the session's default bucket.
# Note that `bucket` is rebound here; `s3url` was already computed from the
# earlier value and is not affected.
bucket = sess.default_bucket()  # bucket to house artifacts
s3_code_prefix = f"large-model-lmi/{file_name}"
code_artifact = sess.upload_data(
    path="mymodel.tar.gz",
    bucket=bucket,
    key_prefix=s3_code_prefix,
)
print(f"S3 Code or Model tar ball uploaded to --- > {code_artifact}")

## 设置部署使用的镜像初始化model

In [None]:
# Look up the serving container image for the "djl-tensorrtllm" framework,
# version 0.26.0, in the session's region.
image_uri = image_uris.retrieve(
    framework="djl-tensorrtllm",
    version="0.26.0",
    region=sess.boto_session.region_name,
)

# SageMaker model definition: serving image, the uploaded configuration
# tarball, and the execution role.
model = Model(
    image_uri=image_uri,
    model_data=code_artifact,
    role=role,
)

## 开始部署

In [None]:
# Unique endpoint name derived from a fixed base, and the instance type to host it.
endpoint_name = sagemaker.utils.name_from_base("chatglm3-6b-32k-aot")
instance_type = "ml.g5.12xlarge"

# Create the endpoint. If the container needs longer to start,
# `container_startup_health_check_timeout=3600` can additionally be passed.
model.deploy(
    endpoint_name=endpoint_name,
    instance_type=instance_type,
    initial_instance_count=1,
)

# Client for the endpoint: request payloads are serialized as JSON. No
# deserializer is passed here, so the caller decodes the response itself.
predictor = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sess,
    serializer=serializers.JSONSerializer(),
)

## 测试

In [None]:
# Chat role used in the prompt template. It deliberately does NOT reuse the
# name `role`: that global holds the SageMaker execution role passed to
# `Model(...)`, and overwriting it with "user" would break re-running the
# model-creation cell after this one.
chat_role = "user"
input_text = "世界上第二高的山峰是哪座"

# Send one prompt in the ChatGLM3 chat format together with sampling parameters.
response = predictor.predict(
    {
        "inputs": f"<|{chat_role}|>{input_text}<|assistant|>\n",
        "parameters": {
            "max_new_tokens": 128,
            "temperature": 0.8,
            "do_sample": True,
            "top_p": 0.8,
            "return_log_probs": False,
        },
    }
)

# The predictor has no deserializer, so decode the raw response as UTF-8 text.
print(type(response))
text = str(response, 'utf-8')
print(text)
# To print only the answer, the text can be parsed as JSON, e.g.
# json.loads(text)["generated_text"].strip() (needs `import json`).
# NOTE(review): assumes the body is a single JSON object — confirm against the
# jsonlines output format configured in serving.properties.

## 删除部署的 endpoint 以及对应的 config

In [None]:
# Tear down everything created by the deployment: the endpoint, its endpoint
# configuration (deleted under the same name as the endpoint), and the model.
sess.delete_endpoint(endpoint_name=endpoint_name)
sess.delete_endpoint_config(endpoint_config_name=endpoint_name)
model.delete_model()