# Rinna NeoX SageMaker Inference

This notebook contains sample code for deploying [Rinna NeoX](https://huggingface.co/rinna/japanese-gpt-neox-3.6b-instruction-sft) on SageMaker.

In [2]:
%pip install sagemaker pip boto3 botocore --upgrade  --quiet

Note: you may need to restart the kernel to use updated packages.


In [3]:
from sagemaker.huggingface import get_huggingface_llm_image_uri
from sagemaker.huggingface import HuggingFaceModel
from sagemaker.session import Session
import sagemaker

sagemaker_session = Session()
role = sagemaker_session.get_caller_identity_arn()

## Deploy Model

In [7]:
hf_model_id = "rinna/japanese-gpt-neox-3.6b-instruction-sft-v2" # model id from huggingface.co/models
number_of_gpus = 1 # number of gpus to use for inference and tensor parallelism
health_check_timeout = 300 # Increase the timeout for the health check to 5 minutes for downloading the model
instance_type = "ml.g5.2xlarge" # instance type to use for deployment

In [8]:
llm_image = get_huggingface_llm_image_uri(
    "huggingface",
    version="0.9.3"
)
endpoint_name = sagemaker.utils.name_from_base("tgi-model-rinna-3-6b")
llm_model = HuggingFaceModel(
    role=role,
    image_uri=llm_image,
    env={
        'HF_MODEL_ID': hf_model_id,
        # 'REVISION': '2140541486bfb31269acd035edd51208da40185b',
        # 'HF_MODEL_QUANTIZE': "bitsandbytes", # uncomment to quantize
        'SM_NUM_GPUS': str(number_of_gpus),
        'MAX_INPUT_LENGTH': "1024",  # Max length of input text
        'MAX_TOTAL_TOKENS': "2048",  # Max length of the generation (including input text)
        "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
    }
)
llm = llm_model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    container_startup_health_check_timeout=health_check_timeout,
    endpoint_name=endpoint_name,
)

----------!
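
The container reads its settings from environment variables, and every value is passed as a string. Below is a minimal sketch of a helper that assembles the same `env` dictionary and rejects a token budget where the input limit is not below the total limit. `build_tgi_env` is a hypothetical name, and the check rests on the assumption that `MAX_INPUT_LENGTH` has to be smaller than `MAX_TOTAL_TOKENS` so that at least one token can be generated.

```python
def build_tgi_env(model_id, num_gpus=1, max_input_length=1024, max_total_tokens=2048):
    """Return the container environment as a dict of strings (hypothetical helper)."""
    # Assumption: the input budget must leave room for at least one generated token.
    if max_input_length >= max_total_tokens:
        raise ValueError("max_input_length must be smaller than max_total_tokens")
    return {
        "HF_MODEL_ID": model_id,
        "SM_NUM_GPUS": str(num_gpus),
        "MAX_INPUT_LENGTH": str(max_input_length),
        "MAX_TOTAL_TOKENS": str(max_total_tokens),
        "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
    }


env = build_tgi_env("rinna/japanese-gpt-neox-3.6b-instruction-sft-v2")
print(env)
```

The resulting dictionary could be passed as `env=` to `HuggingFaceModel` in place of the literal used above.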

## Run Inference

In [18]:
import json
import boto3
import logging
import io

boto3.set_stream_logger("",logging.INFO)
smr = boto3.client('sagemaker-runtime')

endpoint_name = llm.endpoint_name

stop_token = '</s>'
prompt = [
    {
        "speaker": "ユーザー",
        "text": "日本で旅行におすすめのスポットは？"
    }
]
prompt = [
    f"{uttr['speaker']}: {uttr['text']}"
    for uttr in prompt
]
prompt = "<NL>".join(prompt)
prompt = (
    prompt
    + "<NL>"
    + "システム: "
)
print(prompt)
body = {
    "inputs":prompt,
    "parameters":{
        "max_new_tokens": 256,
        # "return_full_text": False,
        "do_sample": True,
        "temperature": 0.1,
        "stop": [stop_token]
    },
    # "stream": True
}

ユーザー: 日本で旅行におすすめのスポットは？<NL>システム: 
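
The prompt construction above can be wrapped in a small function so that longer conversations reuse the same format. This is a sketch only; `build_prompt` is a hypothetical helper name, and the speaker labels and `<NL>` separator are taken from the cell above.

```python
def build_prompt(turns, system="システム", sep="<NL>"):
    """Join conversation turns in the 'speaker: text' format used above."""
    lines = [f"{turn['speaker']}: {turn['text']}" for turn in turns]
    # The trailing "システム: " cues the model to answer as the system speaker.
    return sep.join(lines) + sep + f"{system}: "


turns = [{"speaker": "ユーザー", "text": "日本で旅行におすすめのスポットは？"}]
print(build_prompt(turns))
```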


In [19]:
response = smr.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType='application/json',
    Accept='application/json',
    Body=json.dumps(body)
)
print(json.loads(response['Body'].read()))

[{'generated_text': 'ユーザー: 日本で旅行におすすめのスポットは？<NL>システム: 「お前は、私がお前を好きだと言ったら、私を好きになるのか?」'}]
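
The response above echoes the prompt in front of the generated text. A minimal sketch of a post-processing step follows; `extract_reply` is a hypothetical helper, and it assumes that `<NL>` stands for a newline in the model's output, as it does in the prompt format.

```python
def extract_reply(generated_text, prompt, stop_token="</s>", newline_token="<NL>"):
    """Strip the echoed prompt and the stop token, then restore newlines."""
    reply = generated_text
    if reply.startswith(prompt):
        reply = reply[len(prompt):]
    if reply.endswith(stop_token):
        reply = reply[: -len(stop_token)]
    # Assumption: the model emits <NL> wherever a line break is intended.
    return reply.replace(newline_token, "\n").strip()


# Placeholder strings for illustration, not real model output.
example_prompt = "ユーザー: question<NL>システム: "
example_output = example_prompt + "first line<NL>second line</s>"
print(extract_reply(example_output, example_prompt))
```

Alternatively, enabling the commented-out `"return_full_text": False` parameter in the request body asks the container to omit the prompt from the response.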


### Run Inference (Streaming Output)

In [11]:
body["stream"] = True

In [12]:
class LineIterator:
    """
    A helper class for parsing the byte stream input from TGI container. 
    
    The output of the model will be in the following format:
    ```
    b'data:{"token": {"text": " a"}}\n\n'
    b'data:{"token": {"text": " challenging"}}\n\n'
    b'data:{"token": {"text": " problem"
    b'}}'
    ...
    ```
    
    While usually each PayloadPart event from the event stream will contain a byte array 
    with a full json, this is not guaranteed and some of the json objects may be split across
    PayloadPart events. For example:
    ```
    {'PayloadPart': {'Bytes': b'{"outputs": '}}
    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
    ```
    
    This class accounts for this by appending each chunk to an internal buffer and
    returning complete lines (ending with a '\n' character) through the iterator protocol.
    It tracks the position of the last read so that earlier bytes are not returned again,
    and it keeps any pending bytes that do not end with a '\n' so that truncated lines
    are concatenated with the next chunk.
    """
    
    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord('\n'):
                self.read_pos += len(line)
                return line[:-1]
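            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                # Stream exhausted: return any trailing bytes that lack a final
                # newline, then stop. Retrying here would loop forever.
                if line:
                    self.read_pos += len(line)
                    return line
                raise
            if 'PayloadPart' not in chunk:
                print('Unknown event type: ' + str(chunk))
                continue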
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk['PayloadPart']['Bytes'])
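
The buffering idea can be checked offline, without an endpoint, by feeding hand-built events to a compact generator. This is a sketch of the same technique rather than a replacement for the class; `iter_lines` and `fake_events` are hypothetical names, and the events imitate the `PayloadPart` shape described in the docstring.

```python
def iter_lines(events):
    """Yield complete lines (without the newline) from PayloadPart events."""
    buffer = b""
    for event in events:
        if "PayloadPart" not in event:
            continue  # ignore event types this sketch does not handle
        buffer += event["PayloadPart"]["Bytes"]
        # Emit every complete line; keep the unfinished tail for the next event.
        *complete, buffer = buffer.split(b"\n")
        yield from complete
    if buffer:
        yield buffer  # trailing bytes that never received a final newline


# The second JSON object is deliberately split across two events.
fake_events = [
    {"PayloadPart": {"Bytes": b'data:{"token": {"text": " a"}}\n\n'}},
    {"PayloadPart": {"Bytes": b'data:{"token": {"text": " prob'}},
    {"PayloadPart": {"Bytes": b'lem"}}\n\n'}},
]
lines = list(iter_lines(fake_events))
print([line for line in lines if line])
```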

In [13]:
resp = smr.invoke_endpoint_with_response_stream(EndpointName=endpoint_name, Body=json.dumps(body), ContentType='application/json')
print(resp)
event_stream = resp['Body']
start_json = b'{'
for line in LineIterator(event_stream):
    if line != b'' and start_json in line:
        data = json.loads(line[line.find(start_json):].decode('utf-8'))
        if data['token']['text'] != stop_token:
            print(data['token']['text'],end='')

{'ResponseMetadata': {'RequestId': '783d7425-c38f-47ca-937a-e886a75795eb', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '783d7425-c38f-47ca-937a-e886a75795eb', 'x-amzn-invoked-production-variant': 'AllTraffic', 'x-amzn-sagemaker-content-type': 'text/event-stream', 'date': 'Mon, 11 Sep 2023 01:49:01 GMT', 'content-type': 'application/vnd.amazon.eventstream', 'transfer-encoding': 'chunked', 'connection': 'keep-alive'}, 'RetryAttempts': 0}, 'ContentType': 'text/event-stream', 'InvokedProductionVariant': 'AllTraffic', 'Body': <botocore.eventstream.EventStream object at 0x7fd88a963a90>}
「はい、もちろんです。お手伝いできることがあれば、お知らせください。」
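
To keep the streamed answer as a single string instead of printing token by token, the same parsing can accumulate the token texts. A minimal sketch, assuming each non-empty line has the `data:{"token": {"text": ...}}` shape shown in the `LineIterator` docstring; `collect_text` is a hypothetical helper and the sample lines are hand-written.

```python
import json


def collect_text(lines, stop_token="</s>"):
    """Concatenate token texts from 'data:{...}' lines, dropping the stop token."""
    pieces = []
    for line in lines:
        start = line.find(b"{")
        if start == -1:
            continue  # blank separator lines carry no JSON
        data = json.loads(line[start:].decode("utf-8"))
        text = data.get("token", {}).get("text", "")
        if text != stop_token:
            pieces.append(text)
    return "".join(pieces)


sample = [
    b'data:{"token": {"text": " a"}}',
    b"",
    b'data:{"token": {"text": " problem"}}',
    b'data:{"token": {"text": "</s>"}}',
]
print(collect_text(sample))
```

With a live endpoint, `collect_text(LineIterator(resp['Body']))` would be the corresponding call.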

## Delete Endpoint

In [6]:
llm.delete_model()
llm.delete_endpoint()

NameError: name 'llm' is not defined
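
The `NameError` above appears when the cell runs in a kernel where `llm` is not defined, for example after a restart. The endpoint still exists in that case and can be removed by name. Below is a hedged sketch that takes a SageMaker client as an argument; `cleanup_endpoint` is a hypothetical helper, and it assumes the client exposes the boto3 `sagemaker` operations `describe_endpoint`, `describe_endpoint_config`, `delete_endpoint`, `delete_endpoint_config`, and `delete_model`.

```python
def cleanup_endpoint(sm_client, endpoint_name):
    """Delete an endpoint, its endpoint config, and the models behind it by name."""
    # Look up the related resource names before anything is deleted.
    endpoint = sm_client.describe_endpoint(EndpointName=endpoint_name)
    config_name = endpoint["EndpointConfigName"]
    config = sm_client.describe_endpoint_config(EndpointConfigName=config_name)
    model_names = [variant["ModelName"] for variant in config["ProductionVariants"]]

    sm_client.delete_endpoint(EndpointName=endpoint_name)
    sm_client.delete_endpoint_config(EndpointConfigName=config_name)
    for name in model_names:
        sm_client.delete_model(ModelName=name)
    return config_name, model_names


# Usage (requires AWS credentials; the endpoint name is a placeholder):
# import boto3
# cleanup_endpoint(boto3.client("sagemaker"), "<your-endpoint-name>")
```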