Inf2 example (#2399)
* adding inf2 example

* fix the inference func

* Add batch size note

* Fix INF2 example handler (#2378)

* fix INF2 example handler

* Add logging for padding in inf2 handler

* update response timeout and model

* Update documentation to show opt-6.7b as the example model

* Update model batch log

---------

Co-authored-by: Naman Nandan <namannan@amazon.com>

* Update requirements and sample text file

* fix neuron core allocation to worker process

* Fix linter errors and update documentation

* enable core allocation verification in handler

* fix lint error

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-45-0.ec2.internal>
Co-authored-by: Hamid Shojanazeri <hamid.nazeri2010@gmail.com>
Co-authored-by: Naman Nandan <namannan@amazon.com>
4 people committed Jun 16, 2023
1 parent 679b33d commit 4e21262
Showing 7 changed files with 312 additions and 1 deletion.
77 changes: 77 additions & 0 deletions examples/large_models/inferentia2/Readme.md
@@ -0,0 +1,77 @@
# Large model inference on Inferentia2

This document describes how to serve large HuggingFace (HF) models on [AWS Inferentia2](https://aws.amazon.com/ec2/instance-types/inf2/) instances.

Inferentia2 uses the [Neuron SDK](https://aws.amazon.com/machine-learning/neuron/), which is built on top of the PyTorch XLA stack. For large model inference, the [`transformers-neuronx`](https://github.com/aws-neuron/transformers-neuronx) package is used, which takes care of model partitioning and running inference.

Let's take a look at the steps to prepare our model for inference on Inf2 instances.

**Note**: To run the model on an Inf2 instance, the model is compiled as a preprocessing step. The compilation uses a specific batch size to generate the model graph, and the same batch size must be used when running inference. This example uses a batch size of 2; if you change it, make sure to register the model with a matching batch size.
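
For orientation, here is a minimal sketch (not one of the example files) of the `transformers-neuronx` flow that the handler in this example follows; the paths and parameters mirror `model-config.yaml` and the steps below.

```python
# Illustrative sketch only: load the split checkpoint, compile for the Neuron
# cores, and sample. Assumes the split checkpoint from Step 3 exists at ./opt-6.7b-split.
from transformers import AutoTokenizer
from transformers_neuronx.opt.model import OPTForSampling

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-6.7b")
model = OPTForSampling.from_pretrained(
    "./opt-6.7b-split",  # split checkpoint produced in Step 3
    batch_size=2,        # batch size baked into the compiled graph
    tp_degree=2,         # number of Neuron cores used for tensor parallelism
    amp="f16",
)
model.to_neuron()  # compiles the model for the Neuron cores

input_ids = tokenizer(
    "Today the weather is really nice and I am planning on", return_tensors="pt"
).input_ids
input_ids = input_ids.repeat(2, 1)       # the compiled graph expects exactly batch_size sequences
generated = model.sample(input_ids, 50)  # 50 matches max_length in model-config.yaml
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```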

### Step 1: Inf2 instance

Launch an Inf2 instance and SSH into it. Make sure to use the following DLAMI, as it comes with PyTorch and the necessary packages for the AWS Neuron SDK pre-installed.
DLAMI Name: `Deep Learning AMI Neuron PyTorch 1.13.0 (Ubuntu 20.04) 20230226 Amazon Machine Image (AMI)`

### Step 2: Package Installations

Follow the steps below to complete the package installations.

```bash

# Update Neuron Runtime
sudo apt-get install aws-neuronx-collectives=2.* -y
sudo apt-get install aws-neuronx-runtime-lib=2.* -y

# Activate Python venv
source /opt/aws_neuron_venv_pytorch/bin/activate

# Set pip repository pointing to the Neuron repository
python -m pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com

# Update Neuron Compiler and Framework
python -m pip install --upgrade neuronx-cc==2.* torch-neuronx torchvision

pip install git+https://github.com/aws-neuron/transformers-neuronx.git transformers -U

```
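
Optionally, you can check that the instance's Neuron devices are visible before proceeding. This assumes the Neuron system tools (the `aws-neuronx-tools` package) are installed, which is typically the case on the Neuron DLAMI.

```bash
# List the Neuron devices and cores visible on this instance
neuron-ls
```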



### Step 3: Save the model split checkpoints compatible with `transformers-neuronx`

```bash
python save_split_checkpoints.py --model_name facebook/opt-6.7b --save_path './opt-6.7b-split'

```


### Step 4: Generate the Tar/MAR file

Navigate to the `large_models/inferentia2` directory.

```bash
torch-model-archiver --model-name opt --version 1.0 --handler inf2_handler.py --extra-files ./opt-6.7b-split -r requirements.txt --config-file model-config.yaml --archive-format tgz

```

### Step 5: Add the model archive to the model store

```bash
mkdir model_store
mv opt.tar.gz model_store
```

### Step 6: Start TorchServe

Update `config.properties` and start TorchServe.

```bash
torchserve --ncs --start --model-store model_store --models opt.tar.gz
```
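
Because the model is compiled for a fixed batch size (see the note above), you may instead want to start TorchServe without `--models` and register the model through the management API with a matching frontend batch size. A sketch, assuming the default management port 8081:

```bash
# Start TorchServe with an empty model store, then register opt.tar.gz with a
# frontend batch size that matches the compiled batch size (2 in this example).
torchserve --ncs --start --model-store model_store
curl -X POST "http://localhost:8081/models?url=opt.tar.gz&batch_size=2&max_batch_delay=100&initial_workers=1"
```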

### Step 7: Run inference

```bash
curl -v "http://localhost:8080/predictions/opt" -T sample_text.txt
```
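
The same request can also be sent from Python; below is a minimal sketch using the `requests` library, assuming TorchServe is running locally on the default inference port 8080.

```python
# Minimal Python client equivalent to the curl call above.
import requests

with open("sample_text.txt", "rb") as f:
    response = requests.post("http://localhost:8080/predictions/opt", data=f.read())

print(response.status_code)
print(response.text)  # generated continuation of the input prompt
```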
161 changes: 161 additions & 0 deletions examples/large_models/inferentia2/inf2_handler.py
@@ -0,0 +1,161 @@
import logging
import os
from abc import ABC

import torch
import torch_neuronx
import transformers
from transformers import AutoTokenizer
from transformers_neuronx.opt.model import OPTForSampling

from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)
logger.info("Transformers version %s", transformers.__version__)


class LLMHandler(BaseHandler, ABC):
    """
    Transformers handler class for text generation on Inferentia2.
    """

    def __init__(self):
        super(LLMHandler, self).__init__()
        self.initialized = False

    def initialize(self, ctx):
        """In this initialize function, the HF large model is loaded and
        compiled with transformers-neuronx, sharding it across Neuron cores
        via tensor parallelism.
        Args:
            ctx (context): It is a JSON Object containing information
            pertaining to the model artifacts parameters.
        """

        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")

        # settings for model compilation and loading
        seed = ctx.model_yaml_config["handler"]["manual_seed"]
        tp_degree = ctx.model_yaml_config["handler"]["tp_degree"]
        amp = ctx.model_yaml_config["handler"]["amp"]
        model_name = ctx.model_yaml_config["handler"]["model_name"]

        # allocate "tp_degree" number of neuron cores to the worker process
        os.environ["NEURON_RT_NUM_CORES"] = str(tp_degree)
        try:
            num_neuron_cores_available = (
                torch_neuronx.xla_impl.data_parallel.device_count()
            )
            assert num_neuron_cores_available >= int(tp_degree)
        except (RuntimeError, AssertionError) as error:
            raise RuntimeError(
                "Required number of neuron cores for tp_degree "
                + str(tp_degree)
                + " are not available: "
                + str(error)
            )

        torch.manual_seed(seed)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, return_tensors="pt")
        self.tokenizer.pad_token = self.tokenizer.eos_token

        logger.info("Starting to compile the model")

        self.batch_size = ctx.model_yaml_config["handler"]["batch_size"]
        self.model = OPTForSampling.from_pretrained(
            model_dir, batch_size=self.batch_size, tp_degree=tp_degree, amp=amp
        )
        self.model.to_neuron()
        logger.info("Model has been successfully compiled")

        self.max_length = ctx.model_yaml_config["handler"]["max_length"]

        self.initialized = True

    def preprocess(self, requests):
        """
        Basic text preprocessing, based on the user's choice of application mode.
        Args:
            requests (list): A list of dictionaries with a "data" or "body" field, each
            containing the input text to be processed.
        Returns:
            tuple: A tuple with two tensors: the batch of input ids and the batch of
            attention masks.
        """
        input_texts = [data.get("data") or data.get("body") for data in requests]
        input_ids_batch, attention_mask_batch = [], []
        for input_text in input_texts:
            input_ids, attention_mask = self.encode_input_text(input_text)
            input_ids_batch.append(input_ids)
            attention_mask_batch.append(attention_mask)
        input_ids_batch = torch.cat(input_ids_batch, dim=0)
        attention_mask_batch = torch.cat(attention_mask_batch, dim=0)
        return input_ids_batch, attention_mask_batch

    def encode_input_text(self, input_text):
        """
        Encodes a single input text using the tokenizer.
        Args:
            input_text (str): The input text to be encoded.
        Returns:
            tuple: A tuple with two tensors: the encoded input ids and the attention mask.
        """
        if isinstance(input_text, (bytes, bytearray)):
            input_text = input_text.decode("utf-8")
        logger.info("Received text: '%s'", input_text)
        inputs = self.tokenizer.encode_plus(
            input_text,
            max_length=self.max_length,
            padding=True,
            add_special_tokens=True,
            return_tensors="pt",
            truncation=True,
        )
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        return input_ids, attention_mask

    def inference(self, input_batch):
        """
        Generates text continuations for the received inputs using the compiled
        transformers-neuronx model.
        Args:
            input_batch (tuple): A tuple with two tensors: the batch of input ids and the batch
            of attention masks, as returned by the preprocess function.
        Returns:
            list: A list of strings with the generated text for each input text in the batch.
        """
        input_ids_batch = input_batch[0]

        # insert padding if a partial batch was received
        num_inferences = len(input_ids_batch)
        logger.info("Number of inference requests in batch: %s", num_inferences)
        logger.info("Model batch size: %s", self.batch_size)
        padding = self.batch_size - num_inferences
        if padding > 0:
            logger.info("Padding input batch with %s padding inputs", padding)
            # pad the batch dimension (rows) up to the compiled batch size
            pad = torch.nn.ConstantPad2d((0, 0, 0, padding), value=0)
            input_ids_batch = pad(input_ids_batch)

        outputs = self.model.sample(
            input_ids_batch,
            self.max_length,
        )

        inferences = self.tokenizer.batch_decode(
            outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        inferences = inferences[:num_inferences]

        logger.info("Generated text: %s", inferences)
        return inferences

    def postprocess(self, inference_output):
        """Post Process Function converts the predicted response into Torchserve readable format.
        Args:
            inference_output (list): It contains the predicted response of the input text.
        Returns:
            (list): Returns a list of the predictions.
        """
        return inference_output
12 changes: 12 additions & 0 deletions examples/large_models/inferentia2/model-config.yaml
@@ -0,0 +1,12 @@
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 100
responseTimeout: 900

handler:
    max_length: 50
    manual_seed: 40
    batch_size: 2
    tp_degree: 2
    amp: f16
    model_name: facebook/opt-6.7b
3 changes: 3 additions & 0 deletions examples/large_models/inferentia2/requirements.txt
@@ -0,0 +1,3 @@
torch-neuronx
transformers-neuronx
transformers
1 change: 1 addition & 0 deletions examples/large_models/inferentia2/sample_text.txt
@@ -0,0 +1 @@
Today the weather is really nice and I am planning on
53 changes: 53 additions & 0 deletions examples/large_models/inferentia2/save_split_checkpoints.py
@@ -0,0 +1,53 @@
import argparse
import os

import torch
from transformers.models.opt import OPTForCausalLM
from transformers_neuronx.module import save_pretrained_split


def create_directory_if_not_exists(path_str: str) -> str:
    """Creates a directory if it doesn't exist, and returns the directory path."""
    if os.path.isdir(path_str):
        return path_str
    elif input(f"{path_str} does not exist, create directory? [y/n]").lower() == "y":
        os.makedirs(path_str)
        return path_str
    else:
        raise NotADirectoryError(path_str)


def amp_callback(model: OPTForCausalLM, dtype: torch.dtype) -> None:
    """Casts attention and MLP to low precision only; layernorms stay as f32."""
    for block in model.model.decoder.layers:
        block.self_attn.to(dtype)
        block.fc1.to(dtype)
        block.fc2.to(dtype)
    model.lm_head.to(dtype)


# Define and parse command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument(
    "--model_name", "-m", type=str, required=True, help="HuggingFace model name"
)
parser.add_argument(
    "--save_path",
    type=str,
    default="./model-splits",
    help="Output directory for downloaded model files",
)
args = parser.parse_args()

save_path = create_directory_if_not_exists(args.save_path)

# Load HuggingFace model
hf_model = OPTForCausalLM.from_pretrained(args.model_name, low_cpu_mem_usage=True)

# Apply Automatic Mixed Precision (AMP)
amp_callback(hf_model, torch.float16)

# Save the model
save_pretrained_split(hf_model, args.save_path)

print(f"Files for '{args.model_name}' have been downloaded to '{args.save_path}'.")
6 changes: 5 additions & 1 deletion ts_scripts/spellcheck_conf/wordlist.txt
@@ -1055,4 +1055,8 @@ largemodels
torchpippy
InferenceSession
maxRetryTimeoutInSec
neuronx
neuronx
AMI
DLAMI
XLA
inferentia
