Inf2 example #2399

Merged
merged 11 commits on Jun 16, 2023
77 changes: 77 additions & 0 deletions examples/large_models/inferentia2/Readme.md
@@ -0,0 +1,77 @@
# Large model inference on Inferentia2

This document describes how to serve large HuggingFace (HF) models on [AWS Inferentia2](https://aws.amazon.com/ec2/instance-types/inf2/) instances.

Inferentia2 uses the [Neuron SDK](https://aws.amazon.com/machine-learning/neuron/), which is built on top of the PyTorch XLA stack. For large model inference, the [`transformers-neuronx`](https://github.com/aws-neuron/transformers-neuronx) package is used, which takes care of model partitioning and running inference.

Let's take a look at the steps to prepare our model for inference on Inf2 instances.

**Note**: To run the model on an Inf2 instance, the model is compiled as a preprocessing step. The compilation process generates the model graph for a specific batch size, so inference must use the same batch size that was used during compilation. This example uses a batch size of 2; make sure to change it and register the model according to your batch size.
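
For reference, the batch size is read from the `handler` section of `model-config.yaml` (the full file appears later in this PR):

```yaml
# excerpt from model-config.yaml (full file included below)
handler:
  batch_size: 2   # must match the batch size the model was compiled with
```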

### Step 1: Inf2 instance

Launch an Inf2 instance and SSH into it. Make sure to use the following DLAMI, as it comes with PyTorch and the necessary AWS Neuron SDK packages pre-installed.
DLAMI Name: `Deep Learning AMI Neuron PyTorch 1.13.0 (Ubuntu 20.04) 20230226 Amazon Machine Image (AMI)`

### Step 2: Package Installations

Follow the steps below to complete the package installations:

```bash

# Update Neuron Runtime
sudo apt-get install aws-neuronx-collectives=2.* -y
sudo apt-get install aws-neuronx-runtime-lib=2.* -y

# Activate Python venv
source /opt/aws_neuron_venv_pytorch/bin/activate

# Set pip repository pointing to the Neuron repository
python -m pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com

# Update Neuron Compiler and Framework
python -m pip install --upgrade neuronx-cc==2.* torch-neuronx torchvision

pip install git+https://github.com/aws-neuron/transformers-neuronx.git transformers -U

```



### Step 3: Save the model split checkpoints compatible with `transformers-neuronx`

```bash
python save_split_checkpoints.py --model_name facebook/opt-6.7b --save_path './opt-6.7b-split'

```


### Step 4: Generate the TAR/MAR file

Navigate up to the `large_models/inferentia2` directory.

```bash
torch-model-archiver --model-name opt --version 1.0 --handler inf2_handler.py --extra-files ./opt-6.7b-split -r requirements.txt --config-file model-config.yaml --archive-format tgz

```

### Step 5: Add the MAR file to the model store

```bash
mkdir model_store
mv opt.tar.gz model_store
```

### Step 6: Start TorchServe

Update `config.properties` and start TorchServe.
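
The exact contents of `config.properties` depend on your deployment; a minimal illustrative example (these values are assumptions, not part of this PR) that also installs the `requirements.txt` bundled into the MAR file:

```properties
inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081
# install the dependencies listed in the model archive's requirements.txt
install_py_dep_per_model=true
```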

```bash
torchserve --ncs --start --model-store model_store --models opt.tar.gz
```

### Step 6: Run inference

```bash
curl -v "http://localhost:8080/predictions/opt" -T sample_text.txt
```
148 changes: 148 additions & 0 deletions examples/large_models/inferentia2/inf2_handler.py
@@ -0,0 +1,148 @@
import logging
import os
from abc import ABC

import torch
import transformers
from transformers import AutoTokenizer
from transformers_neuronx.opt.model import OPTForSampling

from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)
logger.info("Transformers version %s", transformers.__version__)


class LLMHandler(BaseHandler, ABC):
"""
Transformers handler class for text generation with large HuggingFace models on Inferentia2 using transformers-neuronx.
"""

def __init__(self):
super(LLMHandler, self).__init__()
self.initialized = False

def initialize(self, ctx):
"""In this initialize function, the HF large model is loaded and
partitioned into multiple stages each on one device using PiPPy.
Args:
ctx (context): It is a JSON Object containing information
pertaining to the model artifact parameters.
"""

self.manifest = ctx.manifest
properties = ctx.system_properties
model_dir = properties.get("model_dir")

# settings for model compilation and loading
seed = ctx.model_yaml_config["handler"]["manual_seed"]
tp_degree = ctx.model_yaml_config["handler"]["tp_degree"]
amp = ctx.model_yaml_config["handler"]["amp"]
model_name = ctx.model_yaml_config["handler"]["model_name"]

# allocate "tp_degree" number of neuron cores to the worker process
os.environ["NEURON_RT_NUM_CORES"] = str(tp_degree)
Collaborator:

> How do you make sure neuron has enough number of cores to support tp_degree?

@namannandan (Collaborator, Author) on Jun 12, 2023:

> I believe torch-neuronx currently does not have an API that provides the number of available (unallocated) neuron cores. Here, if the required number of neuron cores, i.e. tp_degree, are not available, then the model loading will fail with an error of the form:
>
> `ERROR  TDRV:db_vtpb_get_mla_and_tpb                 Could not find VNC id 1`

@namannandan (Collaborator, Author):

> Turns out that torch-neuronx does have a method to query the number of available unallocated cores using `torch_neuronx.xla_impl.data_parallel.device_count()`. Updated the handler to verify that the necessary number of cores are available before proceeding with model loading.
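
A minimal sketch of the check described in the comment above, assuming `torch_neuronx.xla_impl.data_parallel.device_count()` reports the number of unallocated Neuron cores (variable names are illustrative):

```python
import torch_neuronx.xla_impl.data_parallel as data_parallel

# Sketch: verify that enough unallocated Neuron cores exist before loading the model.
# tp_degree is read from model-config.yaml, as in the handler above.
available_cores = data_parallel.device_count()
if available_cores < tp_degree:
    raise RuntimeError(
        f"Insufficient Neuron cores: tp_degree={tp_degree} requested, "
        f"but only {available_cores} are available"
    )
```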


torch.manual_seed(seed)
self.tokenizer = AutoTokenizer.from_pretrained(model_name, return_tensors="pt")
self.tokenizer.pad_token = self.tokenizer.eos_token

logger.info("Starting to compile the model")

self.batch_size = ctx.model_yaml_config["handler"]["batch_size"]
self.model = OPTForSampling.from_pretrained(
model_dir, batch_size=self.batch_size, tp_degree=tp_degree, amp=amp
)
self.model.to_neuron()
logger.info("Model has been successfully compiled")

self.max_length = ctx.model_yaml_config["handler"]["max_length"]

self.initialized = True

def preprocess(self, requests):
"""
Basic text preprocessing, based on the user's choice of application mode.
Args:
requests (list): A list of dictionaries with a "data" or "body" field, each
containing the input text to be processed.
Returns:
tuple: A tuple with two tensors: the batch of input ids and the batch of
attention masks.
"""
input_texts = [data.get("data") or data.get("body") for data in requests]
input_ids_batch, attention_mask_batch = [], []
for input_text in input_texts:
input_ids, attention_mask = self.encode_input_text(input_text)
input_ids_batch.append(input_ids)
attention_mask_batch.append(attention_mask)
input_ids_batch = torch.cat(input_ids_batch, dim=0)
attention_mask_batch = torch.cat(attention_mask_batch, dim=0)
return input_ids_batch, attention_mask_batch

def encode_input_text(self, input_text):
"""
Encodes a single input text using the tokenizer.
Args:
input_text (str): The input text to be encoded.
Returns:
tuple: A tuple with two tensors: the encoded input ids and the attention mask.
"""
if isinstance(input_text, (bytes, bytearray)):
input_text = input_text.decode("utf-8")
logger.info("Received text: '%s'", input_text)
inputs = self.tokenizer.encode_plus(
input_text,
max_length=self.max_length,
padding=True,
add_special_tokens=True,
return_tensors="pt",
truncation=True,
)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
return input_ids, attention_mask

def inference(self, input_batch):
"""
Predicts the class (or classes) of the received text using the serialized transformers
checkpoint.
Args:
input_batch (tuple): A tuple with two tensors: the batch of input ids and the batch
of attention masks, as returned by the preprocess function.
Returns:
list: A list of strings with the predicted values for each input text in the batch.
"""
input_ids_batch = input_batch[0]

# insert padding if a partial batch was received
num_inferences = len(input_ids_batch)
logger.info("Number of inference requests in batch: %s", num_inferences)
logger.info("Model batch size: %s", self.batch_size)
padding = self.batch_size - num_inferences
if padding > 0:
logger.info("Padding input batch with %s padding inputs", padding)
pad = torch.nn.ConstantPad1d((0, 0, 0, padding), value=0)
input_ids_batch = pad(input_ids_batch)

outputs = self.model.sample(
input_ids_batch,
self.max_length,
)

inferences = self.tokenizer.batch_decode(
outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
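# keep only the outputs for the actual requests, dropping those generated for the padding inputs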
inferences = inferences[:num_inferences]

logger.info("Generated text: %s", inferences)
return inferences

def postprocess(self, inference_output):
"""Post Process Function converts the predicted response into Torchserve readable format.
Args:
inference_output (list): It contains the predicted response of the input text.
Returns:
(list): Returns a list of the generated text predictions.
"""
return inference_output
12 changes: 12 additions & 0 deletions examples/large_models/inferentia2/model-config.yaml
@@ -0,0 +1,12 @@
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 100
responseTimeout: 900

handler:
max_length: 50
manual_seed: 40
batch_size: 2
tp_degree: 2
amp: f16
model_name: facebook/opt-6.7b
3 changes: 3 additions & 0 deletions examples/large_models/inferentia2/requirements.txt
@@ -0,0 +1,3 @@
transformers
transformers-neuronx
torch==1.13.1
1 change: 1 addition & 0 deletions examples/large_models/inferentia2/sample_text.txt
@@ -0,0 +1 @@
Today the weather is really nice and I am planning on
55 changes: 55 additions & 0 deletions examples/large_models/inferentia2/save_split_checkpoints.py
@@ -0,0 +1,55 @@
import argparse
import os

import torch
from transformers.models.opt import OPTForCausalLM
from transformers_neuronx.module import save_pretrained_split


def create_directory_if_not_exists(path_str: str) -> str:
"""Creates a directory if it doesn't exist, and returns the directory path."""
if os.path.isdir(path_str):
return path_str
elif input(f"{path_str} does not exist, create directory? [y/n]").lower() == "y":
os.makedirs(path_str)
return path_str
else:
raise NotADirectoryError(path_str)


def amp_callback(model: OPTForCausalLM, dtype: torch.dtype) -> None:
"""Casts attention and MLP to low precision only; layernorms stay as f32."""
for block in model.model.decoder.layers:
block.self_attn.to(dtype)
block.fc1.to(dtype)
block.fc2.to(dtype)
model.lm_head.to(dtype)


# Define and parse command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name", "-m", type=str, required=True, help="HuggingFace model name"
)
parser.add_argument(
"--save_path",
type=str,
default="./model-splits",
help="Output directory for downloaded model files",
)
args = parser.parse_args()

save_path = create_directory_if_not_exists(args.save_path)

# Load HuggingFace model
hf_model = OPTForCausalLM.from_pretrained(args.model_name, low_cpu_mem_usage=True)

# Apply Automatic Mixed Precision (AMP)
amp_callback(hf_model, torch.float16)

# Save the model
save_pretrained_split(hf_model, args.save_path)

print(f"Files for '{args.model_name}' have been downloaded to '{args.save_path}'.")


6 changes: 5 additions & 1 deletion ts_scripts/spellcheck_conf/wordlist.txt
@@ -1055,4 +1055,8 @@ largemodels
torchpippy
InferenceSession
maxRetryTimeoutInSec
neuronx
neuronx
AMI
DLAMI
XLA
inferentia