# Deploy HF LLMs to Inferentia2 and SageMaker

**SageMaker Studio Kernel**: Python 3 (PyTorch 1.13 Python 3.9 CPU Optimized)  
**Instance**: ml.t3.medium

## 1) Update SageMaker SDK

In [None]:
# Upgrade the SageMaker Python SDK and install sagemaker-studio-image-build, which provides
# the `sm-docker` command used in section 2.1 to build the custom Neuron images.
%pip install -U sagemaker sagemaker-studio-image-build

## 2) Initialize session

In [None]:
import os
import re
import boto3
import sagemaker

# Minimum SageMaker SDK version this notebook was written against.
MIN_SAGEMAKER_VERSION = (2, 146, 0)

def sagemaker_version_tuple(version):
    """Return the first three numeric components of *version*, e.g. '2.146.0' -> (2, 146, 0)."""
    return tuple(int(part) for part in re.findall(r"\d+", version)[:3])

print(sagemaker.__version__)
# Compare versions numerically: a plain string comparison is lexicographic, so e.g.
# "2.99.0" >= "2.146.0" would be True and an outdated SDK would go unnoticed.
if sagemaker_version_tuple(sagemaker.__version__) < MIN_SAGEMAKER_VERSION:
    print("You need to upgrade or restart the kernel if you already upgraded")

sess = sagemaker.Session()
role = sagemaker.get_execution_role()
bucket = sess.default_bucket()
region = sess.boto_region_name
account_id = boto3.client("sts").get_caller_identity()["Account"]

## ATTENTION: Copy your HF Access token to the following variable
HF_TOKEN=None

assert HF_TOKEN is not None, "Go to your HF account and get an access token. Set HF_TOKEN to your token"
# src/ holds the scripts and requirements uploaded with the compile job (section 3/4)
os.makedirs("src", exist_ok=True)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {bucket}")
print(f"sagemaker session region: {region}")

## 2.1) Build NeuronSDK 2.16 images

In [None]:
import os
# One Docker build context per custom image: training (compile job) and inference (endpoint).
for build_dir in ("container/training", "container/inference"):
    os.makedirs(build_dir, exist_ok=True)

In [None]:
%%writefile container/training/Dockerfile
ARG REGION
# Start from the Neuron SDK 2.15 PyTorch training DLC and upgrade the Neuron runtime
# packages and Python libraries to the versions pinned below (tagged sdk2.16.0 at build time).
FROM 763104351884.dkr.ecr.${REGION}.amazonaws.com/pytorch-training-neuronx:1.13.1-neuronx-py310-sdk2.15.0-ubuntu20.04
# apt-get is the stable, script-safe CLI (apt warns that its CLI is not stable);
# drop the package lists afterwards to keep the layer small.
RUN apt-get update -y && \
    apt-get install -y aws-neuronx-collectives aws-neuronx-runtime-lib aws-neuronx-tools && \
    rm -rf /var/lib/apt/lists/*
# --no-cache-dir keeps pip's download cache out of the image.
RUN pip3 install --no-cache-dir --extra-index-url https://pip.repos.neuron.amazonaws.com \
    transformers==4.36.2 \
    libneuronxla==0.5.669 \
    torch-neuronx==1.13.1.1.13.0 \
    transformers-neuronx==0.9.474 \
    torch-xla==1.13.1+torchneurond \
    neuronx-cc==2.12.54.0+f631c2365 \
    neuronx-hwm==2.12.0.0+422c9037c

In [None]:
# Build container/training and push it to this account's ECR as the repository:tag below;
# the estimator in section 5 uses it via image_uri. $region expands the Python variable set
# in section 2 and fills the Dockerfile's REGION build argument.
!sm-docker build container/training \
    --repository pytorch-training-neuronx:1.13.1-neuronx-py310-sdk2.16.0-ubuntu20.04 \
    --build-arg REGION=$region

In [None]:
%%writefile container/inference/Dockerfile
ARG REGION
# Start from the Neuron SDK 2.15 PyTorch inference DLC and upgrade the Neuron runtime
# packages and Python libraries to the versions pinned below (tagged sdk2.16.0 at build time).
FROM 763104351884.dkr.ecr.${REGION}.amazonaws.com/pytorch-inference-neuronx:1.13.1-neuronx-py310-sdk2.15.0-ubuntu20.04
# apt-get is the stable, script-safe CLI (apt warns that its CLI is not stable);
# drop the package lists afterwards to keep the layer small.
RUN apt-get update -y && \
    apt-get install -y aws-neuronx-collectives aws-neuronx-runtime-lib aws-neuronx-tools && \
    rm -rf /var/lib/apt/lists/*
# --no-cache-dir keeps pip's download cache out of the image.
RUN pip3 install --no-cache-dir --extra-index-url https://pip.repos.neuron.amazonaws.com \
    transformers==4.36.2 \
    libneuronxla==0.5.669 \
    torch-neuronx==1.13.1.1.13.0 \
    transformers-neuronx==0.9.474 \
    torch-xla==1.13.1+torchneurond \
    neuronx-cc==2.12.54.0+f631c2365 \
    neuronx-hwm==2.12.0.0+422c9037c

In [None]:
# Build container/inference and push it to this account's ECR as the repository:tag below;
# the PyTorchModel in section 6 uses it via image_uri. $region expands the Python variable
# set in section 2 and fills the Dockerfile's REGION build argument.
!sm-docker build container/inference \
    --repository pytorch-inference-neuronx:1.13.1-neuronx-py310-sdk2.16.0-ubuntu20.04 \
    --build-arg REGION=$region

## 3) Define the additional packages required to compile and serve the model

In [None]:
%%writefile src/requirements.txt
# Python dependencies for the scripts in src/. compile.py also copies this file into
# <model_dir>/code next to inference.py. NOTE(review): presumably installed by the serving
# container when the endpoint starts — confirm for this image.
--extra-index-url https://pip.repos.neuron.amazonaws.com
transformers==4.36.2
transformers-neuronx==0.9.474

## 4) Now create the Python script for compiling and deploying the model

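### 4.1) This script downloads the model weights from HF, splits them into multiple files, and compiles the model for a given number of cores. It is also copied into the model artifacts as `code/inference.py`, where its `model_fn`, `input_fn` and `predict_fn` functions serve the endpoint.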

In [None]:
%%writefile src/compile.py
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0
import os
# Neuron runtime/compiler settings are placed in the environment before torch and
# transformers_neuronx are imported below. NOTE(review): presumably they are read at
# import/initialization time, so keep these lines ahead of those imports.
# Core count: TP_DEGREE if set, else SM_NUM_NEURONS, else '2'.
os.environ['NEURON_RT_NUM_CORES']=os.environ.get('TP_DEGREE', os.environ.get('SM_NUM_NEURONS', '2'))
# Compiler flags: discard the compiler log file and compile with the transformer model type.
os.environ["NEURON_CC_FLAGS"] = "--logfile=/dev/null --model-type=transformer"
import sys
import json
import time
import torch
import shutil
import argparse
import importlib
import traceback
import transformers_neuronx

from importlib import reload
from threading import Thread
from huggingface_hub import login
from filelock import Timeout, FileLock
from transformers_neuronx import constants
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_neuronx.module import save_pretrained_split
from transformers.generation.streamers import TextIteratorStreamer
from transformers_neuronx.config import NeuronConfig, QuantizationConfig

# send_intermediate_predict_response is needed only to stream partial results from predict_fn.
try:
    from ts.protocol.otf_message_handler import send_intermediate_predict_response
except ModuleNotFoundError as e:
    # this is required only for inference not for training
    # NOTE(review): when this import fails the name stays undefined, so a request with
    # stream=True would fail with NameError inside predict_fn.
    print("Package TS not found. Streaming disabled.")

# File-based lock shared by every process on the host that uses this path; model_fn holds
# it while loading so that workers load their model copies one at a time.
lock_path='/tmp/new_packages.lock'
lock = FileLock(lock_path)

def compile_or_load_model(model_dir, model_arch, **kwargs):
    '''
    Load the split checkpoint in <model_dir>/model-split and move it to Neuron.

    model_dir: directory containing "model-split" (written by save_pretrained_split in
        the __main__ section). NEURONX_DUMP_TO is pointed at <model_dir>/neuron_cache.
    model_arch: transformers_neuronx sub-package name, e.g. "llama" or "mistral"; the
        class <Arch>ForSampling is taken from transformers_neuronx.<arch>.model.
    kwargs: forwarded unchanged to <Arch>ForSampling.from_pretrained.

    Returns the model after model.to_neuron().
    NOTE(review): whether to_neuron() reuses previously compiled artifacts from
    neuron_cache instead of recompiling depends on transformers_neuronx — confirm.
    '''
    os.environ['NEURONX_DUMP_TO'] = os.path.join(model_dir, "neuron_cache")
    print(kwargs)
    t=time.time()
    print(f"Loading... Model arch: {model_arch}")
    # Resolve the class with getattr on the imported module instead of eval() on a
    # string built from model_arch (which comes from params.json / the command line).
    arch_module = importlib.import_module(f"transformers_neuronx.{model_arch.lower()}.model")
    AutoModelForSampling = getattr(arch_module, f"{model_arch.title()}ForSampling")

    model = AutoModelForSampling.from_pretrained(os.path.join(model_dir, "model-split"), **kwargs)
    model.to_neuron()
    print(f"Elapsed: {time.time()-t}s")
    return model


def model_fn(model_dir, context=None):
    '''
    Model-loading hook: build the Neuron model, tokenizer and streamer from model_dir.

    model_dir: model artifacts written by the __main__ section of this script
        (params.json, tokenizer files, model-split/).
    context: unused.

    Returns (model, tokenizer, streamer), which predict_fn unpacks.
    '''
    # this lock is necessary to serialize model loading
    # when you have multiple workers trying to load different
    # copies using the same hardware.
    # The `with` statement releases it even if loading raises, so a failed load can
    # never leave the lock held and block every other worker.
    print("Waiting for the lock acquire...")
    with lock:
        with open(os.path.join(model_dir, "params.json"), 'r') as params_file:
            conf = json.load(params_file)
        fields = ["batch_size", "tp_degree", "amp", "n_positions", "gqa"]
        kwargs = {f:conf[f] for f in fields}

        # 's8' is mapped to amp='bf16' plus an s8 QuantizationConfig.
        neuron_config = NeuronConfig()
        if kwargs['amp'] == 's8':
            neuron_config.quant=QuantizationConfig(quant_dtype='s8', dequant_dtype='bf16')
            kwargs['amp'] = 'bf16'
            kwargs['neuron_config'] = neuron_config

        # gqa holds the name of a constants.GQA member; look it up by name instead of eval().
        if kwargs.get('gqa') is not None:
            neuron_config.group_query_attention = constants.GQA[kwargs['gqa']]
            kwargs['neuron_config'] = neuron_config

        model = compile_or_load_model(model_dir, conf['model_arch'], **kwargs)
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        streamer = TextIteratorStreamer(tokenizer)
    print("Lock released")
    return model,tokenizer,streamer

def input_fn(input_data, content_type, context=None):
    '''
    Deserialize a request into the tuple consumed by predict_fn.

    input_data: JSON document (str or bytes) with a required "prompt" key and optional
        "sequence_length" (default 2048), "top_k" (50), "top_p" (1.0),
        "temperature" (1.0) and "stream" (False) keys.
    content_type: must be 'application/json'.
    context: unused.

    Returns (prompt, seq_len, top_k, top_p, temperature, stream).

    Raises:
        ValueError: "prompt" is missing or shorter than 3 characters.
        Exception: content_type is not 'application/json'.
    '''
    if content_type != 'application/json':
        raise Exception(f"Unsupported mime type: {content_type}. Supported: application/json. Expected keys: prompt,optional[sequence_length,top_k,top_p,temperature,stream]")

    req = json.loads(input_data)
    prompt = req.get('prompt')
    if prompt is None or len(prompt) < 3:
        # raise("...") on a plain string fails with "TypeError: exceptions must derive
        # from BaseException" and hides this message; raise a real exception instead.
        raise ValueError("Invalid prompt. Provide an input like: {'prompt': 'text text text'}")

    seq_len = req.get('sequence_length', 2048)
    top_k = req.get('top_k', 50)
    top_p = req.get('top_p', 1.0)
    temperature = req.get('temperature', 1.0)
    stream = req.get('stream', False) # enables streaming
    return prompt,seq_len,top_k,top_p,temperature,stream

def predict_fn(input_object, model_tokenizer_streamer, context=None):
    '''
    Generate a completion for a request parsed by input_fn.

    input_object: (prompt, seq_len, top_k, top_p, temperature, stream) as returned by input_fn.
    model_tokenizer_streamer: (model, tokenizer, streamer) as returned by model_fn.
    context: request context forwarded to send_intermediate_predict_response; only
        used when streaming.

    When stream is False, returns a list with one decoded string per sequence produced
    by model.sample. When stream is True, each non-empty text chunk read from the
    streamer is sent as an intermediate response. The function then raises
    Warning("__END_OF_PREDICTION__") instead of returning (see the comment below). If
    sampling failed in the background thread, that error is re-raised instead.
    '''
    model,tokenizer,streamer = model_tokenizer_streamer
    prompt,seq_len,top_k,top_p,temperature,stream = input_object
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    if stream:
        # stream the tokens/words to the client as soon as they are decoded
        errors = []

        def predict(model, input_ids, sequence_length, top_k, top_p, temperature, streamer):
            try:
                with torch.inference_mode():
                    model.sample(input_ids=input_ids, sequence_length=sequence_length, top_k=top_k, top_p=top_p, temperature=temperature, streamer=streamer)
            except Exception as err:
                # NOTE(review): presumably model.sample ends the streamer only on normal
                # completion. The streamer was created without a timeout, so the
                # `for part in streamer` loop below would otherwise wait forever.
                # Close the stream and hand the error back to the request thread.
                errors.append(err)
                streamer.end()

        generation_kwargs = dict(model=model, input_ids=input_ids, sequence_length=seq_len, top_k=top_k, top_p=top_p, temperature=temperature, streamer=streamer)
        thread = Thread(target=predict, kwargs=generation_kwargs)
        thread.start()
        for part in streamer:
            if len(part) == 0: continue
            send_intermediate_predict_response([part], context.request_ids, "Intermediate Prediction success", 200, context)
        thread.join()
        if errors:
            raise errors[0]
        # Do not return anything when streaming, otherwise it will kill the worker
        # this is a workaround that needs to be handled by the client
        raise Warning("__END_OF_PREDICTION__")
    else:
        # collect all the words/tokens before sending it to the customer
        with torch.inference_mode():
            generated_sequences = model.sample(input_ids=input_ids, sequence_length=seq_len, top_k=top_k, top_p=top_p, temperature=temperature)
            return [tokenizer.decode(s) for s in generated_sequences]

if __name__=='__main__':
    # Compile job entry point:
    #   1. download the HF model and save it split into model_dir/model-split;
    #   2. save the tokenizer;
    #   3. copy this script (as inference.py) and requirements.txt into model_dir/code;
    #   4. write params.json;
    #   5. compile the model for Neuron.
    parser = argparse.ArgumentParser()

    # hyperparameters sent by the client are passed as command-line arguments to the script.
    parser.add_argument("--model_id", type=str, default="meta-llama/Llama-2-7b-chat-hf")
    parser.add_argument("--model_arch", type=str, required=True)
    parser.add_argument("--hf_access_token", type=str, default=None)
    # SM_MODEL_DIR is set inside SageMaker training jobs. Using os.environ.get keeps the
    # script usable elsewhere, where --model_dir then has to be passed explicitly.
    parser.add_argument("--model_dir", type=str, default=os.environ.get("SM_MODEL_DIR"),
                        required="SM_MODEL_DIR" not in os.environ)

    parser.add_argument("--tp_degree", type=int, default=2)
    parser.add_argument("--n_positions", type=int, default=2048)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--dtype", type=str, default='bf16', choices=['s8', 'bf16', 'fp16', 'fp32'])
    parser.add_argument("--gqa", type=str, default=None, choices=[i.name for i in constants.GQA])

    args, _ = parser.parse_known_args()

    # Fail fast on an unsupported architecture before downloading any weights: both the
    # module and its <Arch>ForSampling class must exist. getattr replaces eval() on a
    # string built from a command-line argument.
    arch_module = importlib.import_module(f"transformers_neuronx.{args.model_arch.lower()}.model")
    AutoModelForSampling = getattr(arch_module, f"{args.model_arch.title()}ForSampling")

    if args.hf_access_token:
        login(args.hf_access_token)
    print("Loading model...")
    t=time.time()
    model = AutoModelForCausalLM.from_pretrained(args.model_id)
    print(f"Elapsed: {time.time()-t}s, Spliting and saving...")
    t=time.time()
    save_pretrained_split(model, os.path.join(args.model_dir, "model-split"))
    print(f"Elapsed: {time.time()-t}s, Done")
    print("Saving tokenizer...")
    t=time.time()
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    tokenizer.save_pretrained(args.model_dir)
    print(f"Elapsed: {time.time()-t}s, Done")
    print("Copying inference.py")
    # requirements.txt is resolved next to this script rather than relative to the
    # current working directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    code_path = os.path.join(args.model_dir, "code")
    os.makedirs(code_path, exist_ok=True)
    shutil.copy(__file__, os.path.join(code_path, "inference.py"))
    shutil.copy(os.path.join(script_dir, "requirements.txt"), os.path.join(code_path, "requirements.txt"))
    kwargs = {
        "batch_size": args.batch_size,
        "amp": args.dtype,
        "tp_degree": args.tp_degree,
        "n_positions": args.n_positions,
        "gqa": args.gqa
    }
    # params.json is read back by model_fn when the endpoint loads the model.
    with open(os.path.join(args.model_dir, "params.json"), "w") as c:
        conf = dict(kwargs)
        conf['model_arch'] = args.model_arch
        json.dump(conf, c)

    # Same translation as in model_fn: 's8' becomes amp='bf16' plus an s8 QuantizationConfig,
    # and gqa names a constants.GQA member (looked up by name, not eval()).
    neuron_config = NeuronConfig()
    if kwargs['amp'] == 's8':
        neuron_config.quant=QuantizationConfig(quant_dtype='s8', dequant_dtype='bf16')
        kwargs['amp'] = 'bf16'
        kwargs['neuron_config'] = neuron_config

    if kwargs.get('gqa') is not None:
        neuron_config.group_query_attention = constants.GQA[kwargs['gqa']]
        kwargs['neuron_config'] = neuron_config

    # compile_or_load_model points NEURONX_DUMP_TO at model_dir/neuron_cache.
    compile_or_load_model(args.model_dir, args.model_arch, **kwargs)

## 5) SageMaker (training) Job that will download, split and compile the model

In [None]:
# Settings for the compile job (section 5) and the deployment (section 6).
tp_degree = 2               # NeuronCores used by one copy of the model (tensor-parallel degree)
dtype = 'bf16'              # s8, bf16, fp16, fp32
batch_size = 1
sentence_len = 1024         # sent to compile.py as n_positions
model_id = "mistralai/Mistral-7B-Instruct-v0.1"  # "meta-llama/Llama-2-7b-chat-hf"
model_arch = "mistral"      # llama
gqa = "SHARD_OVER_HEADS"    # None
assert tp_degree in (2, 8), "2 = cheapest option with higher latency; 8 = more efficient with lower latency;"

In [None]:
import json
import logging
from sagemaker.pytorch import PyTorch

# NOTE(review): the previous cell asserts tp_degree is 2 or 8, so this always selects
# ml.trn1.32xlarge; ml.trn1.2xlarge would only be chosen for tp_degree=1.
instance_type='ml.trn1.32xlarge' if tp_degree > 1 else 'ml.trn1.2xlarge'
print(f"Instance type: {instance_type}")

# Each hyperparameter is passed to compile.py as a command-line argument.
# - batch_size is forwarded so the value set in the previous cell is actually used.
# - Entries whose value is None (e.g. gqa=None) are dropped, so compile.py falls back to
#   its own default instead of being sent an explicit None.
# NOTE(review): hyperparameters, including the HF token, are visible in the training job
# description; consider a secrets store for production.
hyperparameters = {
    "hf_access_token": HF_TOKEN,
    "model_id": model_id,
    "model_arch": model_arch,
    "gqa": gqa,
    "tp_degree": tp_degree,
    "n_positions": sentence_len,
    "dtype": dtype,
    "batch_size": batch_size,
}
hyperparameters = {k: v for k, v in hyperparameters.items() if v is not None}

estimator = PyTorch(
    entry_point="compile.py", # Specify your train script
    source_dir="src",
    role=role,
    sagemaker_session=sess,
    instance_count=1,
    instance_type=instance_type,
    output_path=f"s3://{bucket}/output",
    disable_profiler=True,
    disable_output_compression=True,

    # custom training image built in section 2.1
    image_uri=f"{account_id}.dkr.ecr.{region}.amazonaws.com/pytorch-training-neuronx:1.13.1-neuronx-py310-sdk2.16.0-ubuntu20.04",

    volume_size = 128,
    hyperparameters=hyperparameters,
)
estimator.framework_version = '1.13.1' # workaround when using image_uri

In [None]:
# Start the compile job and block until it finishes.
# this takes ~21mins on a trn1.32xlarge and ~40mins on a trn1.2xlarge
estimator.fit(wait=True)

## 6) Deploy the compiled model to a SageMaker endpoint on inf2
Depending on the size of the deployed instance and the number of cores used by the model (**tp_degree**), SageMaker can launch multiple workers. A worker is a standalone Python process that manages one copy of the model. SageMaker puts a load balancer on top of all these processes and distributes the load automatically for your clients. This means you can increase throughput by launching multiple workers that serve different clients in parallel.

For instance, if you set **tp_degree** to 8 and deploy your model to an **ml.inf2.48xlarge**, SageMaker can launch 3 workers with 3 copies of the model. This instance has 24 cores and, in this scenario, each model copy uses 8 of them. As a result, 3 clients can invoke the endpoint and be served at the same time.

In [None]:
import logging
from sagemaker.utils import name_from_base
from sagemaker.pytorch.model import PyTorchModel

# Depending on the inf2 instance you deploy the model to, you get more or fewer NeuronCores.
# Each SageMaker worker loads one copy of the model onto tp_degree cores, so we launch
# num_cores // tp_degree workers (e.g. 24 // 8 = 3 on ml.inf2.48xlarge).

instance_type_idx=0
## Attention: ml.inf2.xlarge doesn't have enough memory to work with llama7b
instance_types=['ml.inf2.8xlarge', 'ml.inf2.24xlarge','ml.inf2.48xlarge']
num_cores=[2,12,24]  # NeuronCores of each entry in instance_types, same order
deploy_instance_type=instance_types[instance_type_idx]
num_workers=num_cores[instance_type_idx]//tp_degree
assert num_workers > 0, f"Instance {deploy_instance_type} doesn't support tp_degree={tp_degree}"

print(f"Instance type: {deploy_instance_type}. Num SM workers: {num_workers}")
pytorch_model = PyTorchModel(
    # custom inference image built in section 2.1
    image_uri=f"{account_id}.dkr.ecr.{region}.amazonaws.com/pytorch-inference-neuronx:1.13.1-neuronx-py310-sdk2.16.0-ubuntu20.04",
    model_data=estimator.model_data,
    role=role,
    name=name_from_base(model_arch),
    sagemaker_session=sess,
    container_log_level=logging.DEBUG,
    model_server_workers=num_workers,
    framework_version="1.13.1",
    env = {
        # NOTE(review): presumably the model server's request timeout in seconds, kept
        # high so long generations are not cut off — confirm for this container.
        'SAGEMAKER_MODEL_SERVER_TIMEOUT' : '3600',
    },
    # for production it is important to define vpc_config and use a vpc_endpoint
    #vpc_config={
    #    'Subnets': ['<SUBNET1>', '<SUBNET2>'],
    #    'SecurityGroupIds': ['<SECURITYGROUP1>', '<DEFAULTSECURITYGROUP>']
    #}
)
# NOTE(review): private SDK attribute; presumably tells the SDK that model_data already
# holds compiled artifacts — confirm against the installed sagemaker version.
pytorch_model._is_compiled_model = True

In [None]:
# Endpoint settings. Downloading the artifacts and loading the model takes a while,
# hence the 600 s download and startup health-check timeouts.
deploy_config = dict(
    initial_instance_count=1,
    instance_type=instance_types[instance_type_idx],
    volume_size=128,
    model_data_download_timeout=600,
    container_startup_health_check_timeout=600,
)
predictor = pytorch_model.deploy(**deploy_config)

## 7) Run a simple test to check the endpoint

In [None]:
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

# Requests go out as JSON and responses are parsed from JSON back into Python objects.
predictor.serializer, predictor.deserializer = JSONSerializer(), JSONDeserializer()

In [None]:
import re
import time

def extract_answer(pred):
    '''
    Return the model's reply from a decoded Llama-2/Mistral chat completion.

    Takes the text after the last "[/INST]" (dropping one leading space) and cuts it at
    the last "</s>" if there is one. Unlike a single-line regex match, this also works
    when the answer spans several lines or was truncated before "</s>".
    '''
    answer = pred.rpartition("[/INST]")[2]
    if answer.startswith(" "):
        answer = answer[1:]
    head, sep, _ = answer.rpartition("</s>")
    return head if sep else answer

def predict(text):
    '''Send `text` to the endpoint and return (answer, number of words, elapsed seconds).'''
    global predictor
    t=time.time()
    pred = predictor.predict({"prompt": text })[0]
    print(pred)
    elapsed = time.time()-t
    answer = extract_answer(pred)
    # split() on any whitespace so words separated by newlines are counted too
    num_words = len(answer.split())
    return answer,num_words,elapsed

text = "[INST]Hi, my name is Robot. How are you?[/INST]"
answer, num_words, elapsed = predict(text)
words_per_sec = num_words / elapsed
print(f"Num Words: {num_words}, Words/sec: {words_per_sec:.04f}, Elapsed time: {elapsed:.04f}s\nAnswer: {answer}")

### 7.1) Stream the prediction word by word

In [None]:
import json
import boto3
import codecs

sm_client = boto3.client('sagemaker-runtime')

prompt="""[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>

What is a whole food plant diet? [/INST]
"""
body = json.dumps({'prompt': prompt, 'sequence_length': 512, 'temperature': 1.0, 'stream': True}).encode('utf-8')
resp = sm_client.invoke_endpoint_with_response_stream(
    EndpointName=predictor.endpoint_name,
    Body=body,
    ContentType='application/json',
    Accept='application/json',
)
# Payload parts are raw byte chunks, and a multi-byte UTF-8 character may be split across
# two parts. Decode incrementally so a split character is completed with the next part
# instead of raising UnicodeDecodeError.
decoder = codecs.getincrementaldecoder('utf-8')()
eop=False
for e in resp['Body']:
    tok = decoder.decode(e['PayloadPart']['Bytes'])
    # predict_fn ends a streamed prediction by raising Warning("__END_OF_PREDICTION__");
    # stop printing once that marker arrives.
    if tok.startswith("__END_OF_PREDICTION__"): eop = True
    if not eop: print(tok, end='')

### 7.2) Now, launch multiple threads in parallel to simulate concurrent clients
Only meaningful when **num_workers > 1**.

In [None]:
import time
from multiprocessing.pool import ThreadPool

# Send num_workers identical requests at the same time; with one model copy per worker
# the endpoint can serve them in parallel.
with ThreadPool(num_workers) as pool:
    started = time.time()
    resp = pool.map(predict, [text] * num_workers)
    total_elapsed = time.time() - started
    print(f"Total elapsed time for {num_workers} workers: {total_elapsed}")

    for answer, num_words, elapsed in resp:
        print(f" :: Num Words: {num_words}, Words/sec: {num_words/elapsed:.04f}, Elapsed time: {elapsed:.04f}s\nAnswer: {answer}")

## 8) Cleanup
Delete the model and the endpoint to stop paying for the provisioned resources.

In [None]:
# Delete the model first, then the endpoint.
# NOTE(review): keep this order — the SDK presumably looks up the model names through the
# endpoint configuration, which delete_endpoint() removes.
predictor.delete_model()
predictor.delete_endpoint()