## Testing the RedPajama 7B model

- Redpajama 7B chat model : https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Chat
- Download models from HF model hub (RedPajama 7B Chat)
- Local testing
- DJL deploy and testing



In [None]:
!pip install -q transformers accelerate sentencepiece bitsandbytes

In [None]:
import sagemaker
import transformers
print(sagemaker.__version__)
print(transformers.__version__)

In [None]:
!pip list | grep scipy

In [None]:
from huggingface_hub import snapshot_download
from pathlib import Path
import os

# Directory passed to snapshot_download as its cache_dir; created next to the
# notebook when it does not exist yet.
local_model_path = Path("./pretrained-models")
local_model_path.mkdir(exist_ok=True)
model_name = "togethercomputer/RedPajama-INCITE-7B-Chat"
# Glob patterns handed to snapshot_download to restrict which repo files are fetched.
allow_patterns = ["*.json", "*.pt", "*.bin", "*.txt", "*.model", "*.py"]

# The returned path is reused later for local loading and for the S3 upload.
model_download_path = snapshot_download(
    repo_id=model_name,
    cache_dir=local_model_path,
    allow_patterns=allow_patterns,
)

In [None]:
print(f"Local model download path: {model_download_path}")

### Local mode testing

- Test the model in local mode

In [None]:
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
# Load the tokenizer and the model from the locally downloaded snapshot.
# The model is requested with float16 tensors and 8-bit loading
# (bitsandbytes is installed in the first cell); device placement is left to
# device_map='auto'.
tokenizer = AutoTokenizer.from_pretrained(model_download_path)
model = AutoModelForCausalLM.from_pretrained(
    model_download_path,
    device_map='auto',
    torch_dtype=torch.float16,
    load_in_8bit=True)

In [None]:
# query = "could you recommend the places in korea to travel with my baby and wife?"
# query = "How to convert standard s3 class to glacier with code in Java?"
query = "Could you show me the code sample to upload large file on s3 in typescript?"

In [None]:
prompt = f"<human>: {query}\n<bot>:"

In [None]:
print(prompt)

In [None]:
# Stopping condition from: https://discuss.huggingface.co/t/implimentation-of-stopping-criteria-list/20040/7

from transformers import StoppingCriteria, StoppingCriteriaList

stop_words = ["<human>:", "<bot>:"]

class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the sequence ends with one of the stop sequences.

    Args:
        stops: Iterable of token-id tensors, one per stop sequence. ``None``
            (the default) or an empty iterable means the criterion never fires.
        encounters: Accepted for backward compatibility. It is stored on the
            instance but does not influence the stopping decision.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # ``None`` replaces the former mutable ``[]`` default argument.
        # ``reshape(-1)`` turns a 0-dim tensor (a single-token stop word after
        # ``squeeze()``) into a 1-D tensor so ``len()`` works on it.
        self.stops = [stop.reshape(-1) for stop in (stops if stops is not None else [])]
        self.encounters = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when ``input_ids[0]`` ends with any stop sequence.

        Only the first sequence of the batch is inspected. ``scores`` and any
        extra keyword arguments are accepted but not used.
        """
        generated = input_ids[0]

        # Move the stop tensors to the device generation runs on. This happens
        # at most once per device instead of assuming "cuda" at construction.
        if any(stop.device != generated.device for stop in self.stops):
            self.stops = [stop.to(generated.device) for stop in self.stops]

        for stop in self.stops:
            stop_len = len(stop)
            # An empty stop, or one longer than the sequence so far, cannot
            # match; skipping it also avoids a shape-mismatch error.
            if stop_len == 0 or stop_len > generated.shape[0]:
                continue
            if torch.equal(generated[-stop_len:], stop):
                return True

        return False
    
# Tokenize each stop word into a 1-D tensor of token ids. The batch axis is
# indexed with [0] instead of squeeze(), so a stop word that tokenizes to a
# single id stays 1-D (squeeze() would collapse it to a 0-dim tensor, on which
# len() fails inside StoppingCriteriaSub).
stop_words_ids = [
    tokenizer(stop_word, return_tensors='pt')['input_ids'][0]
    for stop_word in stop_words
]
print(f"Stop word ids: {stop_words_ids}")
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

In [None]:
%%time
# Tokenize the prompt and move the tensors to the model's device.
inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
# Prompt length in tokens; used afterwards to slice the prompt off the output.
input_length = inputs.input_ids.shape[1]
outputs = model.generate(
    **inputs,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.5,
    top_p=0.5,
    top_k=50,
    return_dict_in_generate=True,
    # `early_stopping` was removed: it only controls beam-based generation and
    # has no effect when sampling with a single beam.
    # Set the padding id explicitly instead of relying on the implicit
    # fallback to the end-of-sequence token.
    pad_token_id=tokenizer.eos_token_id,
    stopping_criteria=stopping_criteria,
)


In [None]:
token = outputs.sequences[0, input_length:]
output_str = tokenizer.decode(token)

In [None]:
# print(output_str)

In [None]:
def remove_stopword(output, stop_words):
    """Strip one trailing stop word from ``output``.

    The stop words are tried in order and only the first one that ``output``
    ends with is removed. When none matches, ``output`` is returned unchanged.
    """
    # Empty stop words are skipped: they can never remove anything.
    matched = next(
        (word for word in stop_words if word and output.endswith(word)),
        None,
    )
    if matched is None:
        return output
    return output[: len(output) - len(matched)]

result = remove_stopword(output_str, stop_words)
print(result)

### SageMaker Deployment testing

- Deploy model to SageMaker endpoint using DJL


### TODO
- DeepSpeed wrapping (in progress)
- int8 quantization g4dn.2xlarge deployment
- Async inference


### Test
- g5.4xlarge int8 : 15~20s
- g5.4xlarge fp16 deepspeed : 15~20s (The result is strange)
- g4dn.2xlarge int8 : very slow ...!


In [None]:
s3_model_prefix = "llm/redpajama/model"  # folder where model checkpoint will go

In [None]:
base_model_s3 = f"{s3_model_prefix}/chat-7b"

In [None]:
sagemaker_session = sagemaker.Session()
s3_model_artifact = sagemaker_session.upload_data(path=model_download_path, key_prefix=base_model_s3)

In [None]:
print(f"Model s3 uri : {s3_model_artifact}")

In [None]:
import boto3
import sagemaker
from sagemaker.utils import name_from_base
from sagemaker import image_uris

In [None]:
sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()
sm_client = sagemaker_session.sagemaker_client
sm_runtime_client = sagemaker_session.sagemaker_runtime_client

In [None]:
print(f"sagemaker role: {role}")

In [None]:
# llm_engine = "deepspeed"
llm_engine = "fastertransformer"

In [None]:
# The framework identifier is built from the selected engine
# (e.g. "djl-fastertransformer").
framework_name = f"djl-{llm_engine}"
# Look up the container image URI for that framework in the session's region,
# pinned to version 0.21.0.
inference_image_uri = image_uris.retrieve(
    framework=framework_name, region=sagemaker_session.boto_session.region_name, version="0.21.0"
)

print(f"Inference container uri: {inference_image_uri}")

In [None]:
src_dir_name = f"redpajama-7b-src"
s3_target = f"s3://{sagemaker_session.default_bucket()}/llm/redpajama/code/"

In [None]:
!rm -rf {src_dir_name}.tar.gz
!tar zcvf {src_dir_name}.tar.gz {src_dir_name} --exclude ".ipynb_checkpoints" --exclude "__pycache__"
!aws s3 cp {src_dir_name}.tar.gz {s3_target}

In [None]:
model_uri = f"{s3_target}{src_dir_name}.tar.gz"
print(model_uri)

In [None]:
# NOTE: this rebinds `model_name`, which earlier held the Hugging Face repo id;
# from here on it is the generated SageMaker model name.
model_name = name_from_base(f"redpajama-7b-djl")
print(model_name)

# Register the model: the inference container image plus the source archive
# that was copied to S3.
create_model_response = sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer={"Image": inference_image_uri, "ModelDataUrl": model_uri},
)
model_arn = create_model_response["ModelArn"]

print(f"Created Model: {model_arn}")

In [None]:
# Instance type for the real-time endpoint; the commented-out line is an
# alternative.
instance_type = "ml.g4dn.2xlarge"
# instance_type = "ml.g5.4xlarge"

# Both names are derived from the SageMaker model name.
endpoint_config_name = f"{model_name}-config"
endpoint_name = f"{model_name}-endpoint"

# Single-variant endpoint config: one instance and a 600-second container
# startup health-check timeout.
endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "VariantName": "variant1",
            "ModelName": model_name,
            "InstanceType": instance_type,
            "InitialInstanceCount": 1,
            "ContainerStartupHealthCheckTimeoutInSeconds": 600,
        },
    ],
)
print(endpoint_config_response)

In [None]:
create_endpoint_response = sm_client.create_endpoint(
    EndpointName=f"{endpoint_name}", EndpointConfigName=endpoint_config_name
)
print(f"Created Endpoint: {create_endpoint_response['EndpointArn']}")

In [None]:
import time

# Poll once a minute until the endpoint leaves the "Creating" state, printing
# the status after every describe call.
while True:
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)
    if status != "Creating":
        break
    time.sleep(60)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

# Stop here when deployment failed, so later cells do not invoke an endpoint
# that never came up.
if status == "Failed":
    raise RuntimeError(
        f"Endpoint {endpoint_name} failed to deploy: "
        f"{resp.get('FailureReason', 'no failure reason reported')}"
    )

In [None]:
import json

In [None]:
# query = "Do you know why the italy and spain had a economic crisis before?"
query = "Can you recommend my newborn baby's name?"

prompt = f"<human>: {query}\n<bot>:"

print(prompt)

In [None]:
%%time
# The JSON request carries a list of prompts under "text" and the generation
# settings under "parameters".
# NOTE(review): the handler that interprets this payload is not visible here —
# presumably it lives in the packaged source directory; confirm the key names.
prompts = [prompt]

response_model = sm_runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    Body=json.dumps(
        {
            "text": prompts,
            "parameters": {
                # "max_new_tokens": 512,
                "max_new_tokens": 128,
                "temperature": 0.5,
                "do_sample": True,
                "top_p": 0.5,
                "top_k": 50,
                "early_stopping": True
            },
        }
    ),
    ContentType="application/json",
)

In [None]:
output = str(response_model["Body"].read(), "utf-8")
print(output)

### Deploy to async endpoint

- LLM inference takes a long time, so real-time inference is not a good fit; an asynchronous endpoint is used instead.

In [None]:
default_bucket = sagemaker_session.default_bucket()
async_output_uri = f"s3://{default_bucket}/llm/outputs/{model_name}/"
print(async_output_uri)

In [None]:
# Instance type for the async endpoint; the commented-out line is an
# alternative.
instance_type = "ml.g4dn.xlarge"
# instance_type = "ml.g5.2xlarge"

# NOTE: these rebind the names used for the real-time endpoint earlier in the
# notebook; from here on they refer to the async endpoint.
endpoint_config_name = f"{model_name}-async-config"
endpoint_name = f"{model_name}-async-endpoint"

# Same single-variant layout as the real-time config, plus an
# AsyncInferenceConfig: results go under async_output_uri and each instance is
# limited to one concurrent invocation.
endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "VariantName": "variant1",
            "ModelName": model_name,
            "InstanceType": instance_type,
            "InitialInstanceCount": 1,
            "ContainerStartupHealthCheckTimeoutInSeconds": 600,
        },
    ],
    AsyncInferenceConfig={
        "OutputConfig": {
            "S3OutputPath": async_output_uri,
        },
        "ClientConfig": {
            "MaxConcurrentInvocationsPerInstance": 1
        }
    }
)
print(endpoint_config_response)

In [None]:
create_endpoint_response = sm_client.create_endpoint(
    EndpointName=f"{endpoint_name}", EndpointConfigName=endpoint_config_name
)
print(f"Created Endpoint: {create_endpoint_response['EndpointArn']}")

In [None]:
# Poll once a minute until the async endpoint leaves the "Creating" state,
# printing the status after every describe call.
while True:
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)
    if status != "Creating":
        break
    time.sleep(60)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

# Stop here when deployment failed, so later cells do not invoke an endpoint
# that never came up.
if status == "Failed":
    raise RuntimeError(
        f"Endpoint {endpoint_name} failed to deploy: "
        f"{resp.get('FailureReason', 'no failure reason reported')}"
    )

In [None]:
import uuid
import boto3
s3_client = boto3.client('s3')

In [None]:
prompts = [prompt]
input_data = {
    "text": prompts,
    "parameters": {
        # "max_new_tokens": 512,
        "max_new_tokens": 128,
        "temperature": 0.5,
        "do_sample": True,
        "top_p": 0.5,
        "top_k": 50,
        "early_stopping": True
    },
}
print(input_data)

In [None]:
# Upload the JSON request to S3; its s3:// URI is what the async invocation
# receives as input location.
# Despite its name, `s3_uri` holds only the object key; the full URI is built
# a few lines further down.
s3_uri = f"llm/inputs/{model_name}/{uuid.uuid4()}.json"
s3_client.put_object(
    Bucket=default_bucket,
    Key=s3_uri,
    Body=json.dumps(input_data))

input_data_uri = f"s3://{default_bucket}/{s3_uri}"
input_location = input_data_uri

In [None]:
# Invoke the async endpoint with the S3 location of the uploaded request.
response = sm_runtime_client.invoke_endpoint_async(
    EndpointName=endpoint_name, 
    InputLocation=input_location
)
# The response carries the location where the result is expected.
output_location = response["OutputLocation"]
print(output_location)
# Drop the first three "/"-separated parts (for an "s3://bucket/key" URI: the
# scheme, the empty part and the bucket) to keep only the object key.
output_key_uri = "/".join(output_location.split("/")[3:])

In [None]:
# Fetch the async result if it has been written. A missing object means the
# inference has not finished (or failed). Any other error, such as missing
# permissions or a wrong bucket, is re-raised instead of being hidden by a
# bare `except:`.
try:
    text_obj = s3_client.get_object(Bucket=default_bucket, Key=output_key_uri)['Body'].read()
except s3_client.exceptions.ClientError as err:
    if err.response["Error"]["Code"] not in ("NoSuchKey", "404"):
        raise
    print("Data is not exist yet. Wait until inference finished or check the CW log")
else:
    text = text_obj.decode('utf-8')
    print(text)