### SageMaker streaming LLM Inference 部署(chatglm)


#### 前提条件
1. 亚马逊云账号
2. 建议使用ml.g5.xlarge

### Notebook部署步骤
1. 升级boto3, sagemaker python sdk
2. 编译docker image
3. 部署AIGC推理服务
    * 配置模型参数
    * 部署SageMaker Endpoint 
4. 测试
5. 清除资源


### 1. 升级boto3, sagemaker python sdk

In [None]:
!pip install --upgrade boto3 sagemaker

In [None]:
#导入对应的库

import re
import os
import json
import uuid

import numpy as np
import pandas as pd
from time import gmtime, strftime


import boto3
import sagemaker

from sagemaker import get_execution_role,session

# Resolve the execution role, the SageMaker session's bucket (as returned by
# default_bucket()) and the boto3 region used by the rest of the notebook.
role = get_execution_role()

sage_session = session.Session()
bucket = sage_session.default_bucket()
aws_region = boto3.Session().region_name

# Print a short summary of the environment, one item per line.
_summary_lines = [
    f'sagemaker sdk version: {sagemaker.__version__}',
    f'role:  {role}  ',
    f'bucket:  {bucket}',
]
print('\n'.join(_summary_lines))




### 2. 编译docker image

In [None]:
!./build_push.sh

### 3. 部署AIGC推理服务

#### 3.1 创建dummy model_data 文件(真正的模型使用code/inference.py进行加载)

In [None]:
# Create an empty placeholder file and pack it, together with
# sagemaker-logo-small.png, into model.tar.gz.
!touch dummy
!tar czvf model.tar.gz dummy sagemaker-logo-small.png
# S3 prefix and full object URI (under the session bucket) for the archive.
assets_dir = 's3://{0}/{1}/assets/'.format(bucket, 'chatglm')
model_data = 's3://{0}/{1}/assets/model.tar.gz'.format(bucket, 'chatglm')
# Upload the archive to the prefix, then remove the local temporary files.
!aws s3 cp model.tar.gz $assets_dir
!rm -f dummy model.tar.gz

#### 3.2 创建 model 配置

In [None]:

# Region and account id are needed to address the ECR image.
boto3_session = boto3.session.Session()
current_region = boto3_session.region_name
account_id = boto3.client("sts").get_caller_identity()["Account"]

# SageMaker client used for create_model.
client = boto3.client('sagemaker')

# Use the docker image built in step 2.
container = f'{account_id}.dkr.ecr.{current_region}.amazonaws.com/streaming-chat'
model_data = f's3://{bucket}/chatglm/assets/model.tar.gz'
model_name = 'AIGC-Quick-Kit-' + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
role = get_execution_role()

# Environment variables passed to the container.
_container_environment = {
    'SAGEMAKER_MODEL_SERVER_TIMEOUT': '600',
    'SAGEMAKER_MODEL_SERVER_WORKERS': '1',
    # Alternative MODEL_TYPE values left disabled by the author:
    # 'ptuning', 'full turning'.
    'MODEL_TYPE': 'normal',
    # 'MODEL_S3_PATH': 's3://sagemaker-us-west-2-687912291502/llm/models/chatglm/simple/adgen-chatglm-6b-ft/'
}

primary_container = {
    'Image': container,
    'ModelDataUrl': model_data,
    'Environment': _container_environment,
}

# Register the model with SageMaker.
create_model_response = client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer=primary_container,
)

In [None]:
# A single timestamp is used to name both the endpoint config and its variant.
_time_tag = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
endpoint_config_name = f'AIGC-Quick-Kit-{_time_tag}'
_variant_name = f'AIGC-Quick-Kit-{_time_tag}'

# The only production variant: one ml.g5.xlarge instance serving model_name.
_production_variant = {
    'VariantName': _variant_name,
    'ModelName': model_name,
    'InitialInstanceCount': 1,
    'InstanceType': 'ml.g5.xlarge',
    'InitialVariantWeight': 1,
}

response = client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[_production_variant],
)

#### 3.3 部署SageMaker endpoint

In [None]:
# Name the endpoint with a random UUID suffix.
endpoint_name = 'AIGC-Quick-Kit-' + str(uuid.uuid4())

_create_endpoint_args = {
    'EndpointName': endpoint_name,
    'EndpointConfigName': endpoint_config_name,
}
response = client.create_endpoint(**_create_endpoint_args)

# Tell the user (in Chinese) that the endpoint is being created and that its
# status can be followed in the console.
print(f'终端节点:{endpoint_name} 正在创建中，首次启动中会加载模型，请耐心等待, 请在控制台上查看状态')


In [None]:
import time
import uuid
import io
import traceback
from PIL import Image


# Module-level S3 resource handle; draw_image reads objects through it.
s3_resource = boto3.resource('s3')

def get_bucket_and_key(s3uri):
    """Split an ``s3://bucket/key`` URI into ``(bucket, key)``.

    The first five characters are taken to be the ``s3://`` scheme and are
    dropped without validation. Everything up to the next ``/`` is the bucket,
    the remainder is the key.

    Args:
        s3uri: URI string of the form ``s3://bucket/key``.

    Returns:
        A ``(bucket, key)`` tuple of strings. ``key`` is the empty string when
        the URI names only a bucket (with or without a trailing slash).
    """
    # partition() copes with a URI that has no "/" after the bucket; slicing on
    # the result of str.find() would use -1 there and mangle both parts.
    bucket, _, key = s3uri[5:].partition('/')
    return bucket, key


def predict_async(endpoint_name,payload):
    """Upload *payload* to S3 as JSON and invoke the endpoint asynchronously.

    The payload is stored under ``stablediffusion/asyncinvoke/input/`` in the
    notebook-level ``bucket`` using a random UUID file name. The endpoint is
    then invoked with that object as its input location, and the reported
    ``OutputLocation`` (empty string when absent) is passed to
    ``wait_async_result``.

    Args:
        endpoint_name: Name of the SageMaker endpoint to invoke.
        payload: JSON-serialisable request body.

    Returns:
        None.
    """
    sm_runtime = boto3.client('runtime.sagemaker')
    object_name = f'{uuid.uuid4()}.json'
    s3 = boto3.resource('s3')

    # Write the UTF-8 encoded JSON request to the input prefix.
    request_body = json.dumps(payload).encode('utf-8')
    input_object = s3.Object(bucket, f'stablediffusion/asyncinvoke/input/{object_name}')
    input_object.put(Body=request_body)

    input_location = f's3://{bucket}/stablediffusion/asyncinvoke/input/{object_name}'
    print(f'input_location: {input_location}')

    invoke_response = sm_runtime.invoke_endpoint_async(
        EndpointName=endpoint_name,
        InputLocation=input_location,
    )
    wait_async_result(invoke_response.get("OutputLocation", ''))


def s3_object_exists(s3_path):
    """Report whether the object named by *s3_path* can be found in S3.

    Issues a ``head_object`` request for the bucket/key parsed from the URI.

    Args:
        s3_path: ``s3://bucket/key`` URI of the object to look for.

    Returns:
        True when ``head_object`` succeeds. False (after printing a waiting
        message) when the service answers with an error response.

    Raises:
        Any non-service error (for example invalid parameters or missing
        credentials) propagates instead of being reported as "not completed".
    """
    s3 = boto3.client('s3')
    bucket, key = get_bucket_and_key(s3_path)
    try:
        s3.head_object(Bucket=bucket, Key=key)
    except s3.exceptions.ClientError:
        # An error response from S3 is treated as "result not written yet".
        print("job is not completed, waiting...")
        return False
    return True
    
def draw_image(output_location):
    """Load the prediction JSON at *output_location* and show each image.

    The JSON document is expected to carry a ``result`` list of S3 URIs. Each
    referenced object is downloaded, opened as an image, scaled to half its
    size and shown.

    Args:
        output_location: ``s3://bucket/key`` URI of the prediction JSON.

    Returns:
        None. This is a best-effort helper: any failure is reported on stdout
        together with its traceback and is not re-raised.
    """
    try:
        bucket, key = get_bucket_and_key(output_location)
        body = s3_resource.Object(bucket, key).get()['Body'].read().decode('utf-8')
        predictions = json.loads(body)
        print(predictions['result'])
        for image_uri in predictions['result']:
            image_bucket, image_key = get_bucket_and_key(image_uri)
            image_bytes = s3_resource.Object(image_bucket, image_key).get()['Body'].read()
            image = Image.open(io.BytesIO(image_bytes))
            # Resize the image to 50% of its original size.
            half = 0.5
            out_image = image.resize(tuple(int(half * s) for s in image.size))
            out_image.show()
    except Exception:
        print("result is not completed, waiting...")
        # Show the real cause so parsing or decoding errors are not mistaken
        # for a result that simply is not there yet.
        traceback.print_exc()
    

    
def wait_async_result(output_location,timeout=60):
    """Poll S3 for the async result and draw it once it appears.

    Checks every 5 seconds whether *output_location* exists. As soon as it
    does, the result is passed to ``draw_image``. Polling stops once *timeout*
    seconds have elapsed.

    Args:
        output_location: ``s3://bucket/key`` URI where the result is expected.
        timeout: Maximum number of seconds to keep polling.

    Returns:
        None.
    """
    poll_interval = 5  # seconds between two existence checks
    # Track real elapsed time so the timeout is honoured; a counter that is
    # never advanced would make this loop run forever.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if s3_object_exists(output_location):
            print("have async result")
            draw_image(output_location)
            return
        time.sleep(poll_interval)
    print(f"no async result at {output_location} after {timeout}s, giving up")

            
def check_sendpoint_status(endpoint_name):
    """Verify that the named SageMaker endpoint is in service.

    Args:
        endpoint_name: Name of the endpoint to describe.

    Returns:
        None. Prints a confirmation when the status is ``InService``.

    Raises:
        Exception: If the endpoint status is anything other than ``InService``.
    """
    sm_client = boto3.client('sagemaker')
    status = sm_client.describe_endpoint(EndpointName=endpoint_name)['EndpointStatus']
    if status == 'InService':
        print(f'{endpoint_name} is ready, status: {status}')
        return
    raise Exception (f'{endpoint_name} not ready , please wait....')

#### 检查endpoint 状态

In [None]:
# Raises while the endpoint is not yet 'InService'; re-run this cell until it
# prints the ready message.
check_sendpoint_status(endpoint_name)

4.0 Basic txt2image 测试

In [None]:
# AIGC Quick Kit txt2img request, sent through the async helper and timed.
# NOTE(review): these keys look like a Stable Diffusion txt2img request, while
# this notebook deploys chatglm — confirm the container accepts this payload.
inputs_txt2img = dict(
    prompt="callie, inkling, splatoon, cute, forest background",
    negative_prompt="(worst quality, low quality:1.4), simple background",
    steps=20,
    sampler="euler_a",
    seed=52362,
    height=512,
    width=512,
    count=2,
)

start = time.time()
predict_async(endpoint_name, inputs_txt2img)
_elapsed = time.time() - start
print(f"Time taken: {_elapsed}s")

In [None]:

# Request that adds a "canny" control_net_model and an input image URL.
# NOTE(review): this looks like a Stable Diffusion / ControlNet request, while
# this notebook deploys chatglm — confirm the container accepts this payload.
payload = dict(
    prompt="taylor swift, best quality, extremely detailed",
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    steps=20,
    sampler="euler_a",
    seed=43768,
    height=512,
    width=512,
    count=2,
    control_net_model="canny",
    input_image="https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png",
)

predict_async(endpoint_name, payload)


In [None]:
# Request that uses the "mlsd" control_net_model with a room input image.
# NOTE(review): this looks like a Stable Diffusion / ControlNet request, while
# this notebook deploys chatglm — confirm the container accepts this payload.
payload = dict(
    prompt="room",
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    steps=20,
    sampler="euler_a",
    seed=43768,
    height=512,
    width=512,
    count=2,
    control_net_model="mlsd",
    input_image="https://huggingface.co/lllyasviel/sd-controlnet-mlsd/resolve/main/images/room.png",
)

predict_async(endpoint_name, payload)

### 4. 测试

In [None]:
import json
import boto3

# Runtime (invocation) client, kept under its own name so that the SageMaker
# service client bound to `client` by the deployment cells is not overwritten.
runtime_client = boto3.client('runtime.sagemaker')


def query_endpoint_with_json_payload(encoded_json):
    """Invoke the endpoint with an already-encoded JSON body.

    Args:
        encoded_json: Request body (bytes) sent with content type
            ``application/json``.

    Returns:
        The raw ``invoke_endpoint`` response.
    """
    return runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='application/json',
        Body=encoded_json,
    )


def parse_response_texts(query_response):
    """Return the ``answer`` field of the JSON document in the response body.

    Args:
        query_response: Response whose ``Body`` entry is a readable stream
            containing JSON.

    Returns:
        The value stored under ``"answer"``.
    """
    model_predictions = json.loads(query_response['Body'].read())
    return model_predictions["answer"]


payload = {"ask": """新加坡有什么好玩的么？"""}

# Send the payload defined just above as UTF-8 encoded JSON.
query_response = query_endpoint_with_json_payload(json.dumps(payload).encode('utf-8'))
generated_texts = parse_response_texts(query_response)
print(generated_texts)

In [None]:
# Create the SageMaker service client here instead of relying on whichever
# client an earlier cell last bound to `client`: the delete_* operations exist
# only on the 'sagemaker' client, not on the runtime client.
sm_client = boto3.client('sagemaker')

response = sm_client.delete_endpoint(
    EndpointName=endpoint_name
)

response = sm_client.delete_endpoint_config(
    EndpointConfigName=endpoint_config_name
)

# Also remove the model registered in step 3.2 so that no resource is left behind.
response = sm_client.delete_model(
    ModelName=model_name
)

# Tell the user (in Chinese) that the endpoint has been removed.
print(f'终端节点:{endpoint_name} 已经被清除，请在控制台上查看状态')
