# Response streaming using LMI container on SageMaker
In this tutorial, you will use LMI container from DLC to SageMaker and run inference with it. 
We will use the SageMaker Runtime `InvokeEndpointWithResponseStream` API which returns model responses in a stream-like manner.

Note that this API is currently in beta. DO NOT onboard any production use-cases to this API during this beta program.

Please make sure the following permission granted before running the notebook:

- S3 bucket push access
- SageMaker access

## Step 1: Let's bump up SageMaker and import stuff

The wheel installed here is a private preview wheel, you need to add into allowlist to run this function

In [None]:
%pip install sagemaker pip --upgrade  --quiet

In [None]:
# Note the following may error depending on which awscli is installed in your jupyter kernel, 
# but that is ok 

%pip install botocore-*-py3-none-any.whl boto3-*-py3-none-any.whl --force

In [None]:
!aws configure add-model --service-model file://runtime.sagemaker-2017-05-13.normal.json --service-name sagemaker-runtime-demo

In [83]:
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers

boto3_session=boto3.session.Session(region_name="us-west-2")
smr = boto3.client('sagemaker-runtime-demo')
sm = boto3.client('sagemaker')
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session(boto3_session, sagemaker_client=sm, sagemaker_runtime_client=smr)  # sagemaker session for interacting with different AWS APIs
region = sess._region_name  # region name of the current SageMaker Studio environment
account_id = sess.account_id()  # account_id of the current SageMaker Studio environment

In [84]:
print(f"Role: {role}")

Role: arn:aws:iam::687912291502:role/webui-notebook-stack-ExecutionRole-62U5FV4LJQS


## Step 2: Start preparing model artifacts
In LMI contianer, we expect some artifacts to help setting up the model
- serving.properties (required): Defines the model server settings
- model.py (optional): A python file to define the core inference logic
- requirements.txt (optional): Any additional pip wheel need to install

In [85]:
%%writefile serving.properties
engine=Python
option.model_id=THUDM/chatglm-6b
option.tensor_parallel_degree=1
option.dtype=fp16
option.enable_streaming=True
gpu.maxWorkers=1
option.predict_timeout=240

Writing serving.properties


In [86]:
%%writefile requirements.txt
transformers==4.27.1
cpm_kernels

Writing requirements.txt


In [87]:
%%writefile model.py
from djl_python import Input, Output
import os
import deepspeed
import torch
from djl_python.streaming_utils import StreamingUtils
import transformers
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
    set_seed,
)

model = None
tokenizer =None

def model_fn(model_dir):
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    print("=================model_fn_Start=================")
    model_s3_path = os.environ['MODEL_S3_PATH']
    print("=================model s3 path=================="+model_s3_path)
    os.system("sudo find / -name s5cmd")
    os.system("s5cmd sync {0} {1}".format(model_s3_path+"*","/tmp/model/"))
    if os.environ["MODEL_TYPE"] == "ptuning":
        config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128)
        model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)
        prefix_state_dict = torch.load(os.path.join("/tmp/model/", "pytorch_model.bin"))
        new_prefix_state_dict = {}
        for k, v in prefix_state_dict.items():
            if k.startswith("transformer.prefix_encoder."):
                new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
        model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

    elif os.environ["MODEL_TYPE"] == "full turning":
        print("====================load full turning=================")
        config = AutoConfig.from_pretrained("/tmp/model/", trust_remote_code=True, pre_seq_len=128)
        model = AutoModel.from_pretrained("/tmp/model/", trust_remote_code=True)
    else:
        print("====================load normal ======================")
        config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
        model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)

    #model = model.to("cuda")
    model = model.quantize(4)
    model = model.half().cuda()
    model = model.eval()
    print("=================model_fn_End=================")
    return model,tokenizer



def get_model(properties):
    return model_fn(None)

def stream(query, history):
    if query is None or history is None:
        yield {"query": "", "response": "", "history": [], "finished": True}
    size = 0
    response = ""
    for response, history in model.stream_chat(tokenizer, query, history):
        this_response = response[size:]
        history = [list(h) for h in history]
        size = len(response)
        #yield {"delta": this_response, "response": response, "finished": False}
        yield {"response": response}
    #yield {"query": query, "delta": "[EOS]", "response": response, "history": history, "finished": True}
    yield {"response": response}




def handle(inputs: Input) -> None:
    global model, tokenizer
    
    if not model:
        model,tokenizer = get_model(None)
    
    if inputs.is_empty():
        # Model server makes an empty call to warmup the model on startup
        return None

    data = inputs.get_as_json()
    outputs = Output()
    outputs.add_property("content-type", "application/jsonlines")
    outputs.add_stream_content(stream(data['query'],data['history']))
    return outputs

Writing model.py


In [88]:
%%sh
mkdir mymodel
mv serving.properties mymodel/
mv model.py mymodel/
mv requirements.txt mymodel/
# remove the following lines if not needed
tar czvf mymodel.tar.gz mymodel/
rm -rf mymodel

mymodel/
mymodel/serving.properties
mymodel/requirements.txt
mymodel/model.py


## Step 3: Start building SageMaker endpoint
In this step, we will build SageMaker endpoint from scratch

### Getting the container image URI

Available framework are:
- djl-deepspeed (0.20.0, 0.21.0)
- djl-fastertransformer (0.21.0)

In [89]:
image_uri = image_uris.retrieve(
    framework="djl-deepspeed",
    region=sess.boto_session.region_name,
    version="0.22.1"
)

### Upload artifact on S3 and create SageMaker model

In [90]:
s3_code_prefix = "large-model-lmi/code"
bucket = sess.default_bucket()  # bucket to house artifacts
code_artifact = sess.upload_data("mymodel.tar.gz", bucket, s3_code_prefix)
print(f"S3 Code or Model tar ball uploaded to --- > {code_artifact}")
env = {"HUGGINGFACE_HUB_CACHE": "/tmp", "TRANSFORMERS_CACHE": "/tmp",
    #'MODEL_TYPE'                    : 'ptuning',
    #'MODEL_TYPE'                    : 'full turning',
    'MODEL_TYPE'                    : 'normal',
    #'MODEL_S3_PATH'                 : 's3://sagemaker-us-west-2-687912291502/llm/models/chatglm/simple/adgen-chatglm-6b-ft/'
    'MODEL_S3_PATH'                 : 's3://sagemaker-us-west-2-687912291502/llm/models/chatglm/simple/adgen-chatglm-6b-ft/'
}

model = Model(sagemaker_session=sess, image_uri=image_uri, model_data=code_artifact, env=env, role=role)

S3 Code or Model tar ball uploaded to --- > s3://sagemaker-us-west-2-687912291502/large-model-lmi/code/mymodel.tar.gz


### 4.2 Create SageMaker endpoint

You need to specify the instance to use and endpoint names

In [91]:
instance_type = "ml.g5.4xlarge"
endpoint_name = sagemaker.utils.name_from_base("streaming-test")
model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name
)

-----------------!

## Step 5: Test and benchmark the inference

In [92]:
import io


class StreamScanner:
    """
    A helper class for parsing the InvokeEndpointWithResponseStream event stream. 
    
    The output of the model will be in the following format:
    ```
    b'{"outputs": [" a"]}\n'
    b'{"outputs": [" challenging"]}\n'
    b'{"outputs": [" problem"]}\n'
    ...
    ```
    
    While usually each PayloadPart event from the event stream will contain a byte array 
    with a full json, this is not guaranteed and some of the json objects may be split across
    PayloadPart events. For example:
    ```
    {'PayloadPart': {'Bytes': b'{"outputs": '}}
    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
    ```
    
    This class accounts for this by concatenating bytes written via the 'write' function
    and then exposing a method which will return lines (ending with a '\n' character) within
    the buffer via the 'readlines' function. It maintains the position of the last read 
    position to ensure that previous bytes are not exposed again. 
    """
    
    def __init__(self):
        self.buff = io.BytesIO()
        self.read_pos = 0
        
    def write(self, content):
        self.buff.seek(0, io.SEEK_END)
        self.buff.write(content)
        
    def readlines(self):
        self.buff.seek(self.read_pos)
        for line in self.buff.readlines():
            if line[-1] != b'\n':
                self.read_pos += len(line)
                yield line[:-1]
                
    def reset(self):
        self.read_pos = 0

In [112]:
import json
import re
def extract_unicode_chars(text):
    pattern = r'\\u([\dA-Fa-f]{4})'
    result = re.sub(pattern, lambda m: chr(int(m.group(1), 16)), text)
    return result


                
endpoint_name="streaming-test-2023-06-29-14-40-53-355"
body = {"query": "有什么缓解颈椎劳损的方法？", "history":[],"parameters": {"max_new_tokens":200, "enable_sampling": "true"}}
resp = smr.invoke_endpoint_with_response_stream(EndpointName=endpoint_name, Body=json.dumps(body), ContentType="application/json")
event_stream = resp['Body']
scanner = StreamScanner()
for event in event_stream:
    eventJson=event['PayloadPart']['Bytes'].decode('utf-8')
    output=extract_unicode_chars(eventJson)
    print(output)

    #m = p.search(eventJson)
    #if m!= None:
    #   print(repr(m.group()))
    #print(json.loads(eventJson)['outputs']['response'])
    #scanner.write(event['PayloadPart']['Bytes'])
    #for line in scanner.readlines():
    #    resp = json.loads(line)
    #    #print(resp.get("outputs")[0], end='')
    #    print(resp.get("outputs"), end='')
        

{"outputs": {"response": ""}}

{"outputs": {"response": "颈椎"}}

{"outputs": {"response": "颈椎劳损"}}

{"outputs": {"response": "颈椎劳损是一种常见的"}}

{"outputs": {"response": "颈椎劳损是一种常见的疾病"}}

{"outputs": {"response": "颈椎劳损是一种常见的疾病，"}}

{"outputs": {"response": "颈椎劳损是一种常见的疾病，可能会导致"}}

{"outputs": {"response": "颈椎劳损是一种常见的疾病，可能会导致疼痛"}}

{"outputs": {"response": "颈椎劳损是一种常见的疾病，可能会导致疼痛、"}}

{"outputs": {"response": "颈椎劳损是一种常见的疾病，可能会导致疼痛、僵硬"}}

{"outputs": {"response": "颈椎劳损是一种常见的疾病，可能会导致疼痛、僵硬和"}}

{"outputs": {"response": "颈椎劳损是一种常见的疾病，可能会导致疼痛、僵硬和肌肉"}}

{"outputs": {"response": "颈椎劳损是一种常见的疾病，可能会导致疼痛、僵硬和肌肉疲劳"}}

{"outputs": {"response": "颈椎劳损是一种常见的疾病，可能会导致疼痛、僵硬和肌肉疲劳等问题"}}

{"outputs": {"response": "颈椎劳损是一种常见的疾病，可能会导致疼痛、僵硬和肌肉疲劳等问题。"}
}

{"outputs": {"response": "颈椎劳损是一种常见的疾病，可能会导致疼痛、僵硬和肌肉疲劳等问题。以下"}}

{"outputs": {"response": "颈椎劳损是一种常见的疾病，可能会导致疼痛、僵硬和肌肉疲劳等问题。以下是"}}

{"outputs": {"response": "颈椎劳损是一种常见的疾病，可能会导致疼痛、僵硬和肌肉疲劳等问题。以下是一些"}}

{"outputs": {"response": "颈椎劳损是一种常见的疾病，可能会导致疼痛、僵硬和肌肉疲劳等问题。以下是一些缓解"}}

{

## Clean up the environment

In [None]:
sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_name)
model.delete_model()