### 1. 下载模型到本地

In [None]:
# For notebook instances (Amazon Linux)
!sudo yum update -y
!sudo yum install amazon-linux-extras
!sudo amazon-linux-extras install epel -y
!sudo yum update -y
!sudo yum install git-lfs git -y

In [None]:
# Download a snapshot of the model repository to local disk (needs about 25 GB of free space).
# This takes roughly 15-30 minutes. While the cell marker on the left still shows [*] the
# download is running; once * becomes a number (e.g. [3]) the cell has finished.

from pathlib import Path
local_model_path = Path("./Baichuan2-13B-Chat-4bits")
local_model_path.mkdir(exist_ok=True)
model_name = "Baichuan-inc/Baichuan2-13B-Chat-4bits"
clone_path = f"https://www.wisemodel.cn/{model_name}.git"
print(clone_path)

!git lfs install
!git clone $clone_path
# Drop the git metadata so it is not part of the directory that is synced to S3 later.
!cd ./Baichuan2-13B-Chat-4bits && rm -rf .git

### 2. 把模型拷贝到S3为后续部署做准备

In [None]:
import sagemaker
import boto3

# AWS service clients: S3, SageMaker control plane, SageMaker runtime.
s3_client = boto3.client("s3")
sm_client = boto3.client("sagemaker")
smr_client = boto3.client("sagemaker-runtime")

role = sagemaker.get_execution_role()
sagemaker_session = sagemaker.Session()
sagemaker_session_bucket = sagemaker_session.default_bucket()

# Use the public region attribute instead of the private `_region_name`.
region = sagemaker_session.boto_region_name
account_id = sagemaker_session.account_id()
# Same default bucket as above; reuse the value rather than asking the session again.
bucket = sagemaker_session_bucket

# Last path component of the repository id, e.g. "org/name" -> "name".
model_basename = model_name.split('/')[-1]

# S3 key prefix for the packaged inference code (model.tar.gz).
s3_code_prefix = f"lmi_inference_code/{model_basename}"

# S3 location the model weights are synced to.
s3_location = f"s3://{sagemaker_session_bucket}/llm_model/{model_basename}/"

# You can also replace local_model_path with the path to your own model,
# e.g. "model_snapshot_path=./chatglm3-6b"; that folder must contain config.json.
model_snapshot_path = local_model_path

print(f"model_snapshot_path: {model_snapshot_path}")
print("s3_location:",s3_location)
print("s3_code_prefix:",s3_code_prefix)

In [None]:
# Upload the local model directory to the S3 location computed above.
!aws s3 sync $model_snapshot_path $s3_location

### 3. 模型部署准备（entrypoint脚本，容器镜像，服务配置）

In [None]:
# Build the DJL inference container image URI for the current region.
# Regions whose name contains "cn-" use a different registry account and the
# ".com.cn" domain suffix; the repository name and tag are the same everywhere.
image_tag = "djl-inference:0.24.0-deepspeed0.10.0-cu118"
if "cn-" in region:
    registry = f"727897471807.dkr.ecr.{region}.amazonaws.com.cn"
else:
    registry = f"763104351884.dkr.ecr.{region}.amazonaws.com"
inference_image_uri = f"{registry}/{image_tag}"

print(f"Image going to be used is ---- > {inference_image_uri}")

In [None]:
# Directory that collects the files packaged into model.tar.gz
# (model.py, serving.properties, requirements.txt).
!mkdir -p LLM_baichuan_deploy_code

In [None]:
%%writefile LLM_baichuan_deploy_code/model.py
from djl_python import Input, Output
import torch
import logging
import math
import os
from transformers import pipeline, AutoModel, AutoTokenizer, AutoModelForCausalLM
import transformers
import torch
import deepspeed

from transformers.generation.utils import GenerationConfig


def load_model(properties):
    """Load the causal-LM model and its tokenizer.

    Args:
        properties: Mapping of serving options. ``model_dir`` must be present;
            when ``model_id`` is also present it takes precedence as the
            location the model is loaded from.

    Returns:
        A ``(model, tokenizer)`` tuple. The model's ``generation_config`` is
        replaced with the one loaded from the same location.

    Raises:
        KeyError: If ``properties`` has no ``model_dir`` entry.
    """
    model_location = properties['model_dir']
    if "model_id" in properties:
        model_location = properties['model_id']
    logging.info(f"Loading model in {model_location}")

    print('============================tokenizer====================')

    # trust_remote_code=True lets transformers run the custom code shipped
    # alongside the model files.
    tokenizer = AutoTokenizer.from_pretrained(model_location, trust_remote_code=True)

    print('============================model====================')

    model = AutoModelForCausalLM.from_pretrained(
        model_location,
        device_map="auto",
        trust_remote_code=True,
    )

    model.generation_config = GenerationConfig.from_pretrained(model_location)

    return model, tokenizer


# Module-level cache: populated by the first call to handle() and reused afterwards.
model = None
tokenizer = None
generator = None  # not assigned or read anywhere else in the visible code


def handle(inputs: Input):
    """Answer a single-turn chat request.

    The model and tokenizer are loaded on the first call and cached in the
    module-level globals.

    Args:
        inputs: Request object. Its JSON body must contain an ``"ask"`` key
            holding the user's message.

    Returns:
        ``None`` when the request is empty; otherwise an ``Output`` whose JSON
        body is ``{"answer": <model response>}``.

    Raises:
        KeyError: If the JSON body has no ``"ask"`` key.
    """
    global model, tokenizer
    # Test for None explicitly instead of relying on the truthiness of the
    # model object.
    if model is None:
        model, tokenizer = load_model(inputs.get_properties())

    if inputs.is_empty():
        return None

    data = inputs.get_as_json()
    input_sentences = data["ask"]
    # NOTE: a "parameters" entry in the request body is currently ignored;
    # generation uses the model's own generation_config.

    messages = [{"role": "user", "content": input_sentences}]
    response = model.chat(tokenizer, messages)

    return Output().add_as_json({"answer": response})

In [None]:
%%writefile LLM_baichuan_deploy_code/serving.properties
# The S3 model location below is only a placeholder; the following cell rewrites it
# with the real upload location before the archive is built.
engine=Python
option.tensor_parallel_degree=1
option.model_id=s3://sagemaker-us-west-2-687912291502/llm/models/LLM_baichuan_model/

In [None]:
# Write the model's S3 location into serving.properties, replacing the placeholder
# model_id line. IPython substitutes {s3_location} before the shell runs sed.
!sed -i 's|option.model_id=.*|option.model_id={s3_location}|' LLM_baichuan_deploy_code/serving.properties

In [None]:
%%writefile LLM_baichuan_deploy_code/requirements.txt
# Install from the Tsinghua (TUNA) PyPI mirror instead of the default index.
-i https://pypi.tuna.tsinghua.edu.cn/simple
transformers==4.33.1
accelerate>=0.17.1
einops

In [None]:
!rm model.tar.gz
!cd LLM_baichuan_deploy_code && rm -rf ".ipynb_checkpoints"
!tar czvf model.tar.gz LLM_baichuan_deploy_code

In [None]:
# Upload the packaged inference code and keep the resulting S3 URI for create_model.
s3_code_artifact = sagemaker_session.upload_data(
    path="model.tar.gz",
    bucket=bucket,
    key_prefix=s3_code_prefix,
)
print(f"S3 Code or Model tar ball uploaded to --- > {s3_code_artifact}")

### 4. 创建模型 & 创建endpoint

In [None]:
from sagemaker.utils import name_from_base
import boto3

# One fixed name is shared by the SageMaker model, the endpoint config and the endpoint.
# It lives in its own variable so the repository id stored in `model_name` (used to
# build the S3 prefixes in an earlier cell) is not overwritten when this cell runs.
sagemaker_model_name = 'pytorch-inference-llm-v1'
print(sagemaker_model_name)
print(f"Image going to be used is ---- > {inference_image_uri}")

# Register the model: container image plus the packaged inference code.
create_model_response = sm_client.create_model(
    ModelName=sagemaker_model_name,
    ExecutionRoleArn=role,
    PrimaryContainer={
        "Image": inference_image_uri,
        "ModelDataUrl": s3_code_artifact
    },
    
)
model_arn = create_model_response["ModelArn"]

print(f"Created Model: {model_arn}")

endpoint_config_name = sagemaker_model_name
endpoint_name = sagemaker_model_name
endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "VariantName": "variant1",
            "ModelName": sagemaker_model_name,
            "InstanceType": "ml.g4dn.4xlarge",
            "InitialInstanceCount": 1,
            # "VolumeSizeInGB" : 400,
            # "ModelDataDownloadTimeoutInSeconds": 2400,
            # Allow up to 15 minutes for the container to pass its health check.
            "ContainerStartupHealthCheckTimeoutInSeconds": 15*60,
        },
    ],
)

print(endpoint_config_response)

# Start creating the endpoint; the call returns before the endpoint is ready.
create_endpoint_response = sm_client.create_endpoint(
    EndpointName=f"{endpoint_name}", EndpointConfigName=endpoint_config_name
)
print(f"Created Endpoint: {create_endpoint_response['EndpointArn']}")

#### 持续检测模型部署进度

In [None]:
import time

# Poll once a minute until the endpoint leaves the "Creating" state.
while True:
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)
    if status != "Creating":
        break
    time.sleep(60)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)
# Show why creation did not succeed instead of leaving the reader with a bare status.
if status != "InService":
    print("FailureReason: " + resp.get("FailureReason", "<not reported>"))

In [None]:
import json

# Must be the same name the endpoint was created with.
endpoint_name = "pytorch-inference-llm-v1"
prompts1 = """
你是MySQL的专家。给定一个输入问题，创建一个语法正确的MySQL查询语句。
除非用户在问题中指定了要获得的特定数量的示例，否则使用LIMIT子句查询最多3个结果。您可以对结果进行排序，以返回数据库中信息量最大的数据。您必须仅查询回答问题所需的列。将每个列名用反引号（`）括起来，表示为分隔的标识符。
请注意，仅可以使用在下面这些表中看到的列名，不要查询不存在的列。此外，还要注意哪个列在哪个表中。如果问题涉及”今天”，请注意使用CURDATE()函数获取当前日期.

使用如下格式:
Question: 具体的问题
SQLQuery: 运行的sql语句
SQLResult: SQLQuery运行的结果
Answer: 最终的回答


使用如下的表:
CREATE TABLE customer (
	c_customer_sk INTEGER NOT NULL, 
	c_customer_id CHAR(16) NOT NULL, 
	c_current_cdemo_sk INTEGER, 
	c_current_hdemo_sk INTEGER, 
	c_current_addr_sk INTEGER, 
	c_first_shipto_date_sk INTEGER, 
	c_first_sales_date_sk INTEGER, 
	c_salutation CHAR(10), 
	c_first_name CHAR(20), 
	c_last_name CHAR(30), 
	c_preferred_cust_flag CHAR(1), 
	c_birth_day INTEGER, 
	c_birth_month INTEGER, 
	c_birth_year INTEGER, 
	c_birth_country VARCHAR(20), 
	c_login CHAR(13), 
	c_email_address CHAR(50), 
	c_last_review_date CHAR(10), 
	PRIMARY KEY (c_customer_sk)
)ENGINE=InnoDB DEFAULT CHARSET=utf8


CREATE TABLE web_sales (
	ws_sold_date_sk INTEGER, 
	ws_sold_time_sk INTEGER, 
	ws_ship_date_sk INTEGER, 
	ws_item_sk INTEGER NOT NULL, 
	ws_bill_customer_sk INTEGER, 
	ws_bill_cdemo_sk INTEGER, 
	ws_bill_hdemo_sk INTEGER, 
	ws_bill_addr_sk INTEGER, 
	ws_ship_customer_sk INTEGER, 
	ws_ship_cdemo_sk INTEGER, 
	ws_ship_hdemo_sk INTEGER, 
	ws_ship_addr_sk INTEGER, 
	ws_web_page_sk INTEGER, 
	ws_web_site_sk INTEGER, 
	ws_ship_mode_sk INTEGER, 
	ws_warehouse_sk INTEGER, 
	ws_promo_sk INTEGER, 
	ws_order_number INTEGER NOT NULL, 
	ws_quantity INTEGER, 
	ws_wholesale_cost DECIMAL(7, 2), 
	ws_list_price DECIMAL(7, 2), 
	ws_sales_price DECIMAL(7, 2), 
	ws_ext_discount_amt DECIMAL(7, 2), 
	ws_ext_sales_price DECIMAL(7, 2), 
	ws_ext_wholesale_cost DECIMAL(7, 2), 
	ws_ext_list_price DECIMAL(7, 2), 
	ws_ext_tax DECIMAL(7, 2), 
	ws_coupon_amt DECIMAL(7, 2), 
	ws_ext_ship_cost DECIMAL(7, 2), 
	ws_net_paid DECIMAL(7, 2), 
	ws_net_paid_inc_tax DECIMAL(7, 2), 
	ws_net_paid_inc_ship DECIMAL(7, 2), 
	ws_net_paid_inc_ship_tax DECIMAL(7, 2), 
	ws_net_profit DECIMAL(7, 2), 
	PRIMARY KEY (ws_item_sk, ws_order_number)
)ENGINE=InnoDB DEFAULT CHARSET=utf8

Question: 我需要知道销售报表中，下单金额最大的客户email地址
"""

# Alternative sample prompts; only prompts3 is sent below.
prompts2 = "给我一个青海和甘肃旅游的路线，8天7晚"
prompts3 = "好累啊"

# Generation settings sent with the request. The handler written to model.py
# reads only the "ask" field, so these values currently have no effect.
parameters = dict(
    do_sample=False,
    top_p=0.9,
    temperature=1,
    max_new_tokens=300,
    repetition_penalty=1.03,
)

request_body = json.dumps({"ask": prompts3, "parameters": parameters})
response_model = smr_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="application/json",
    Body=request_body,
)

# Last expression of the cell: the notebook displays the decoded response text.
response_model['Body'].read().decode("utf-8")

#### 清除模型Endpoint和config

In [None]:
# Delete the endpoint created above.
!aws sagemaker delete-endpoint --endpoint-name pytorch-inference-llm-v1

In [None]:
# Delete the endpoint configuration created above.
!aws sagemaker delete-endpoint-config --endpoint-config-name pytorch-inference-llm-v1

In [None]:
# Delete the SageMaker model created above.
!aws sagemaker delete-model --model-name pytorch-inference-llm-v1