In [16]:
# Recreate an empty ./code directory: the %%writefile cells below populate it, and the
# deploy cell uploads it as the PyTorchModel source_dir.
!rm -rf code && mkdir code

In [17]:
%%writefile code/requirements.txt
-i https://pypi.tuna.tsinghua.edu.cn/simple
diffusers
ftfy
spacy
boto3
sagemaker
nvgpu
sentencepiece
protobuf>=3.19.5,<3.20.1
transformers==4.32.0
icetk
cpm_kernels
accelerate
colorama
bitsandbytes
transformers_stream_generator
xformers

Writing code/requirements.txt


In [18]:
%%writefile code/inference.py
import os
import json
import uuid
import io
import sys
import traceback

from PIL import Image

import requests
import boto3
import sagemaker
import torch
import shutil

from torch import autocast
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation.utils import GenerationConfig


import subprocess

# Local directory the model artifacts are downloaded into; model_fn loads from here.
LLM_NAME = "/opt/amazon/var/run/baichuan2"

# Start from an empty directory so files from an earlier download are not mixed in.
if os.path.exists(LLM_NAME):
    shutil.rmtree(LLM_NAME)

# makedirs (unlike mkdir) also creates any missing parent directories.
os.makedirs(LLM_NAME)

# Replace the default with your own S3 model path, or override it without editing this
# file through the BAICHUAN2_S3_LOCATION environment variable (e.g. the model's `env`).
s3_location = os.environ.get(
    "BAICHUAN2_S3_LOCATION",
    "s3://sagemaker-cn-north-1-086238767671/baichuan/Baichuan2-13B-Chat-4bits/",
)

# Pass the arguments as a list (no shell involved) and fail immediately if the copy
# fails, rather than continuing and erroring later with a misleading "model not found".
subprocess.run(["aws", "s3", "cp", s3_location, LLM_NAME, "--recursive"], check=True)

tokenizer = AutoTokenizer.from_pretrained(LLM_NAME, trust_remote_code=True)


def preprocess(text):
    """Replace real newline and tab characters with their two-character backslash escapes.

    ``postprocess`` applies the reverse mapping to the model's reply.
    NOTE(review): Baichuan2's chat input does not obviously require this escaping; it
    may be inherited from another model's handler — confirm before relying on it.
    """
    # Newlines are handled before tabs, matching the reverse order in postprocess.
    for raw_char, escaped in (("\n", "\\n"), ("\t", "\\t")):
        text = text.replace(raw_char, escaped)
    return text

def postprocess(text):
    """Turn two-character backslash escapes for newline and tab back into real characters.

    This is the inverse of ``preprocess``, applied to the model's reply.
    """
    unescaped = text.replace("\\n", "\n")
    unescaped = unescaped.replace("\\t", "\t")
    return unescaped

def answer(text, sample=True, top_p=0.45, temperature=0.01, model=None):
    """Run a single-turn chat with the loaded model and return its reply.

    Args:
        text: The user's prompt; escaped with ``preprocess`` before being sent.
        sample: Whether to sample (``do_sample``) instead of decoding greedily.
        top_p: Nucleus-sampling cumulative-probability cutoff.
        temperature: Sampling temperature.
        model: The model returned by ``model_fn``. Required despite the ``None`` default.

    Returns:
        The model's reply after ``postprocess``.

    Raises:
        ValueError: If no model is given.
    """
    import copy

    if model is None:
        raise ValueError("answer() requires a loaded model")

    messages = [{"role": "user", "content": preprocess(text)}]

    # The sampling parameters must actually reach generation; otherwise they are dead
    # arguments. Apply them to a copy so the model's stored defaults stay untouched.
    generation_config = copy.deepcopy(model.generation_config)
    generation_config.do_sample = sample
    generation_config.top_p = top_p
    generation_config.temperature = temperature

    # NOTE(review): relies on the checkpoint's remote-code ``chat`` accepting a
    # ``generation_config`` keyword (Baichuan2's does) — confirm for other checkpoints.
    response = model.chat(tokenizer, messages, generation_config=generation_config)
    return postprocess(response)


def model_fn(model_dir):
    """Load the causal-LM model used for inference.

    ``model_dir`` is not used: the weights are read from ``LLM_NAME``, where this
    module copied them from S3 when it was imported.
    """
    print("=================model_fn_Start=================")
    loaded_model = AutoModelForCausalLM.from_pretrained(
        LLM_NAME,
        device_map="auto",
        trust_remote_code=True,
    )
    # Use the generation settings stored alongside the checkpoint.
    loaded_model.generation_config = GenerationConfig.from_pretrained(LLM_NAME)
    print("=================model_fn_End=================")
    return loaded_model


def input_fn(request_body, request_content_type):
    """Deserialize a JSON request into the dict consumed by ``predict_fn``.

    Expected body shape::

        {"ask": "<user prompt>"}

    When ``ask`` is omitted it defaults to a sample prompt ("write an article
    titled 'The City of the Future'").

    Args:
        request_body: Raw JSON payload (str or bytes).
        request_content_type: Request MIME type; only logged, not checked.

    Returns:
        The decoded JSON object, with ``ask`` filled in if it was missing.

    Raises:
        ValueError: If the body is not valid JSON, or is valid JSON that is not an object.
    """
    print(f"=================input_fn=================\n{request_content_type}\n{request_body}")
    input_data = json.loads(request_body)
    # A JSON list/string/number would otherwise fail (or misbehave) on the 'ask' lookup.
    if not isinstance(input_data, dict):
        raise ValueError(
            f"Request body must be a JSON object, got {type(input_data).__name__}"
        )
    input_data.setdefault('ask', "写一个文章，题目是未来城市")
    return input_data


def predict_fn(input_data, model):
    """Answer the request's ``ask`` prompt with the loaded model.

    Optional ``sample``, ``top_p`` and ``temperature`` keys in ``input_data`` are
    forwarded to ``answer``; keys the client omits fall back to ``answer``'s defaults.

    Returns:
        The model's reply, or ``'Not found answer'`` if inference raised (the
        traceback is printed to stdout).
    """
    print("=================predict_fn=================")
    print('input_data: ', input_data)

    try:
        # Forward only the generation options the client actually supplied.
        generation_kwargs = {
            key: input_data[key]
            for key in ('sample', 'top_p', 'temperature')
            if key in input_data
        }
        result = answer(input_data['ask'], model=model, **generation_kwargs)
        print(f'====result {result}====')
        return result
    except Exception as ex:
        # Best effort: log the failure and return a fallback string instead of erroring.
        traceback.print_exc(file=sys.stdout)
        print(f"=================Exception================={ex}")

    return 'Not found answer'


def output_fn(prediction, content_type):
    """Serialize the prediction as a JSON object of the form ``{"answer": prediction}``.

    ``content_type`` is only logged; the response is always JSON.
    """
    print(content_type)
    payload = {'answer': prediction}
    return json.dumps(payload)

Writing code/inference.py


In [19]:
import boto3
import sagemaker

account_id = boto3.client('sts').get_caller_identity().get('Account')
region_name = boto3.session.Session().region_name

sagemaker_session = sagemaker.Session()
bucket = sagemaker_session.default_bucket()
role = sagemaker.get_execution_role()

print(role)
print(bucket)
print(region_name)

!touch dummy
!tar czvf model.tar.gz dummy
assets_dir = 's3://{0}/{1}/assets/'.format(bucket, 'llm_baichuan2_13b-chat-4bits')
model_data = 's3://{0}/{1}/assets/model.tar.gz'.format(bucket, 'llm_baichuan2_13b-chat-4bits')
!aws s3 cp model.tar.gz $assets_dir
!rm -f dummy model.tar.gz

model_name = None
entry_point = 'inference.py'
framework_version = '2.0.1'
py_version = 'py310'
model_environment = {
    'SAGEMAKER_MODEL_SERVER_TIMEOUT':'600', 
    'SAGEMAKER_MODEL_SERVER_WORKERS':'1', 
}

from sagemaker.pytorch.model import PyTorchModel

model = PyTorchModel(
    name = model_name,
    model_data = model_data,
    entry_point = entry_point,
    source_dir = './code',
    role = role,
    framework_version = framework_version, 
    py_version = py_version,
    env = model_environment
)

endpoint_name = 'pytorch-inference-baichuan2'
# instance_type = 'ml.p3.2xlarge'
instance_type='ml.g4dn.2xlarge' 

instance_count = 1

from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer
predictor = model.deploy(
    endpoint_name = endpoint_name,
    instance_type = instance_type, 
    initial_instance_count = instance_count,
    serializer = JSONSerializer(),
    deserializer = JSONDeserializer()
)

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml
arn:aws-cn:iam::086238767671:role/NotebookStack-SmartSearchNotebookRole6F6BB12B-1N9S5RYMN3KUG
sagemaker-cn-north-1-086238767671
cn-north-1
dummy
upload: ./model.tar.gz to s3://sagemaker-cn-north-1-086238767671/llm_baichuan2_13b-chat-4bits/assets/model.tar.gz
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not

In [None]:
### Endpoint 创建约需 10 分钟，期间可在 CloudWatch 日志中检查有无异常

In [22]:
# Smoke-test the deployed endpoint with a short greeting.
inputs = {"ask": '你好'}

response = predictor.predict(inputs)
print(response["answer"])


你好！今天我能为您提供什么帮助？


### 删除 SageMaker Endpoint
测试完成后删除推理服务：Endpoint 在被删除之前会一直保持运行。

In [None]:
# Opt-in cleanup: the endpoint keeps running until it is deleted. Set this to True once
# you are finished testing. The default (False) leaves the endpoint in place, which
# keeps "Run All" from tearing it down right after deployment.
DELETE_ENDPOINT = False

if DELETE_ENDPOINT:
    predictor.delete_endpoint()