# AOT 方式使用 Tensorrt-LMI 部署Baichuan2-13B

In [None]:
# Upgrade the SageMaker Python SDK in the notebook kernel (output suppressed by --quiet).
%pip install sagemaker --upgrade  --quiet

In [None]:
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers

# Resolve the execution context once; later cells reuse these names.
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
# Read the region through the public attribute instead of the private `_region_name`.
region = sess.boto_region_name  # region name of the current session
account_id = sess.account_id()  # account id of the current session

## Pull 转换模型时需要的镜像

In [None]:
# Log Docker in to the public ECR registry that hosts the conversion image.
!aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/s0w3f1p2

In [None]:
# Pull the image that the model conversion step runs in.
!docker pull public.ecr.aws/s0w3f1p2/tensorrt-lmi-xq:v1

## 转换模型

In [None]:
%%writefile serving.properties
option.model_id=baichuan-inc/Baichuan2-13B-Chat
option.tensor_parallel_degree=2
option.max_rolling_batch_size=64
option.dtype=fp16
option.baichuan_model_version=v2_13b
option.trust_remote_code=True

In [None]:
# Print the file just written to check its contents.
!cat serving.properties

In [None]:
# Local working directory for the conversion step (relative to the notebook's
# current directory). serving.properties is moved into it.
MODEL_REPO_DIR="./output/baichuan_v2_13B_aot2"
!mkdir -p $MODEL_REPO_DIR
!mv serving.properties $MODEL_REPO_DIR

In [None]:
# S3 destination for the converted model.
# NOTE(review): the deployment serving.properties written later in this
# notebook hard-codes this same URL as option.model_id - keep the two in sync.
bucket="llm-trt"
s3_model_prefix = "lmi/baichuan_v2_13B_aot_2p_2_64"
s3url=f"s3://{bucket}/{s3_model_prefix}"
s3url  # display the resulting URL

In [None]:
# Print the absolute path of MODEL_REPO_DIR.
! readlink -f $MODEL_REPO_DIR

In [None]:
# Run the partition script inside the conversion image. The host directory is
# mounted at /tmp/trtllm, which is used both as the properties directory and as
# the output model repo; tensor parallel degree is 2.
# NOTE(review): the host path is hard-coded; presumably it is the absolute path
# of MODEL_REPO_DIR printed by the readlink cell - confirm before running.
!docker run --runtime=nvidia --gpus all --shm-size 12gb \
-v /home/ec2-user/SageMaker/output/baichuan_v2_13B_aot2:/tmp/trtllm \
-p 8080:8080 \
public.ecr.aws/s0w3f1p2/tensorrt-lmi-xq:v1 /opt/djl/partition/trt_llm_partition.py \
--properties_dir /tmp/trtllm \
--trt_llm_model_repo /tmp/trtllm \
--tensor_parallel_degree 2

## 上传转换后的模型到 S3

In [None]:
# Sync the local model directory to the S3 location held in `s3url`.
# NOTE(review): the local path is hard-coded to the same directory that the
# docker run cell mounts - keep the two paths identical.
!aws s3 sync /home/ec2-user/SageMaker/output/baichuan_v2_13B_aot2 $s3url

## 创建并上传部署所需的配置文件到 S3
 - 修改 serving.properties 中 model_id 为上传的 S3 模型地址
 - 根据自己的输入输出，写 model.py 文件

In [None]:
%%writefile serving.properties
option.model_id=s3://llm-trt/lmi/baichuan_v2_13B_aot_2p_2_64
option.tensor_parallel_degree=2
option.max_rolling_batch_size=64
option.dtype=fp16
option.rolling_batch=trtllm
option.baichuan_model_version=v2_13b
option.trust_remote_code=True

In [None]:
%%writefile model.py
from djl_python.tensorrt_llm import TRTLLMService
from djl_python.inputs import Input
from djl_python.outputs import Output
from djl_python.encode_decode import encode, decode
import logging
import json
import types
import re

_service = TRTLLMService()

def custom_output_formatter(token, first_token, last_token, details, generated_tokens):
    """
    Return the full generated text once, when the last token arrives.

    Replace this function with your custom output formatter.

    Args:
        token (Token): Token object for the current step (unused)
        first_token (bool): If first token (unused)
        last_token (bool): If last token
        details (dict): Miscellaneous information (unused)
        generated_tokens (str): Text generated so far for this request

    Returns:
        (str): Empty string for every token except the last one; for the
            last token, ``generated_tokens`` with a trailing "。</s>" removed

    """
    if not last_token:
        # Intermediate tokens contribute nothing to the response.
        return ""
    # Strip the trailing full stop + end-of-sequence marker, if present.
    return re.sub(r"。</s>$", "", generated_tokens)

def custom_input_formatter(self, inputs):
    """
    Parse a batch of requests into prompts, sizes and generation parameters.

    Each prompt is wrapped as ``<reserved_106>{prompt}<reserved_107>``
    (presumably the Baichuan2 chat role markers - confirm against the
    tokenizer). Replace this function with your custom input formatter.

    Args:
        self: The service object this function is bound to (unused here)
        inputs (Input): The batched request; each item is decoded according
            to its "Content-Type" property

    Returns:
        (tuple): input_data (list), input_size (list), parameters (list), errors (dict), batch (list)
    """
    input_data = []
    input_size = []
    parameters = []
    errors = {}
    batch = inputs.get_batches()
    for i, item in enumerate(batch):
        try:
            content_type = item.get_property("Content-Type")
            input_map = decode(item, content_type)
        except Exception as e:  # pylint: disable=broad-except
            logging.warning(f"Parse input failed: {i}")
            # Record a zero-size entry so input_size stays aligned with batch.
            input_size.append(0)
            errors[i] = str(e)
            continue

        # Fall back to the whole payload when there is no "inputs" key.
        _inputs = input_map.pop("inputs", input_map)
        # Debug level only: request payloads should not be written to stdout
        # on every call.
        logging.debug("custom_input_formatter inputs: %s", _inputs)
        if not isinstance(_inputs, list):
            _inputs = [_inputs]
        _inputs = [f"<reserved_106>{x}<reserved_107>" for x in _inputs]
        input_data.extend(_inputs)
        input_size.append(len(_inputs))

        _param = input_map.pop("parameters", {})
        if "cached_prompt" in input_map:
            _param["cached_prompt"] = input_map.pop("cached_prompt")
        # Use the server-provided seed only when the request did not set one.
        if "seed" not in _param and item.contains_key("seed"):
            _param["seed"] = item.get_as_string(key="seed")
        # Every prompt of this request shares the same parameter dict.
        parameters.extend(_param for _ in _inputs)

    return input_data, input_size, parameters, errors, batch

def handle(inputs: Input):
    """
    Entry point called for each request.

    On the first call the shared service is initialized with the custom
    output formatter and its ``parse_input`` is replaced by the custom input
    formatter. An empty request returns None; any other request is passed to
    the service for inference.
    """
    if not _service.initialized:
        # One-time setup: the service object is kept at module level.
        properties = inputs.get_properties()
        print(f"props: {properties}")
        properties['output_formatter'] = custom_output_formatter
        _service.initialize(properties)
        # Bind the formatter as a method so it receives the service as `self`.
        _service.parse_input = types.MethodType(custom_input_formatter, _service)

    # An empty input is the initialization request and has nothing to infer.
    return None if inputs.is_empty() else _service.inference(inputs)


In [None]:
%%sh
# Package serving.properties and model.py into mymodel.tar.gz.
# Stop at the first failing command instead of building a broken archive.
set -e
# Start from a clean staging directory so a rerun cannot pick up stale files.
rm -rf mymodel
mkdir -p mymodel
mv serving.properties model.py mymodel/
tar czvf mymodel.tar.gz mymodel/
rm -rf mymodel

In [None]:
# Upload the packaged code tarball to the session's default bucket.
# Note: this rebinds `bucket`, which held the model bucket name earlier.
bucket = sess.default_bucket()  # bucket to house artifacts
s3_code_prefix = "large-model-lmi/baichuan-v2-13B-aot-code"
code_artifact = sess.upload_data(
    path="mymodel.tar.gz",
    bucket=bucket,
    key_prefix=s3_code_prefix,
)
print(f"S3 Code or Model tar ball uploaded to --- > {code_artifact}")

## 设置部署使用的镜像初始化model

In [None]:
# Look up the "djl-tensorrtllm" serving image, version 0.26.0, for the
# session's region.
image_uri = image_uris.retrieve(
    framework="djl-tensorrtllm",
    version="0.26.0",
    region=sess.boto_session.region_name,
)

# The model is defined by the serving image plus the uploaded code tarball.
model = Model(
    role=role,
    image_uri=image_uri,
    model_data=code_artifact,
)

## 开始部署

In [None]:
# Instance type for the endpoint, and an endpoint name generated from a fixed base.
instance_type = "ml.g5.12xlarge"
endpoint_name = sagemaker.utils.name_from_base("baichuan-v2-13B-aot-2p-64")

# Create the endpoint with a single instance.
model.deploy(initial_instance_count=1,
             instance_type=instance_type,
             endpoint_name=endpoint_name,
             # Optional: uncomment to set the container startup health-check timeout.
             # container_startup_health_check_timeout=3600
            )

# Requests are serialized as JSON. No deserializer is set here; the test cell
# decodes the response itself with str(response, 'utf-8').
predictor = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sess,
    serializer=serializers.JSONSerializer()
)

## 测试

In [None]:
# Send one prompt with its generation parameters and show the reply as text.
response = predictor.predict(
    {
        "inputs": "世界上第二高的山峰是哪座",
        "parameters": {
            "max_new_tokens": 128,
            "top_k": 5,
            "repetition_penalty": 1.05,
            "top_p": 0.85,
            "pad_id": 0,
            "temperature": 0.3,
        },
    }
)

# The predictor has no deserializer configured, so decode the reply here.
text = response.decode('utf-8')
text

## 删除部署的 endpoint 以及对应的 config

In [None]:
# Clean up: delete the endpoint, then the endpoint config (looked up under the
# same name as the endpoint), then the model.
sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_name)
model.delete_model()