From 8a8fb3203fd4a92ec067172876a00b9d222d3c02 Mon Sep 17 00:00:00 2001 From: Naman Nandan Date: Wed, 31 May 2023 18:16:54 -0700 Subject: [PATCH 1/5] fix INF2 example handler --- .../large_models/inferentia2/inf2_handler.py | 53 +++++++++++-------- .../inferentia2/model-config.yaml | 6 +-- 2 files changed, 34 insertions(+), 25 deletions(-) diff --git a/examples/large_models/inferentia2/inf2_handler.py b/examples/large_models/inferentia2/inf2_handler.py index 14a1bbbb2b..0d5a3fe8cd 100644 --- a/examples/large_models/inferentia2/inf2_handler.py +++ b/examples/large_models/inferentia2/inf2_handler.py @@ -1,15 +1,14 @@ import logging -import time from abc import ABC import requests import torch import transformers -from transformers import AutoModelForCausalLM, AutoTokenizer -from ts.torch_handler.base_handler import BaseHandler - +from transformers import AutoTokenizer from transformers_neuronx.opt.model import OPTForSampling +from ts.torch_handler.base_handler import BaseHandler + logger = logging.getLogger(__name__) logger.info("Transformers version %s", transformers.__version__) @@ -30,33 +29,32 @@ def initialize(self, ctx): ctx (context): It is a JSON Object containing information pertaining to the model artefacts parameters. """ - + self.manifest = ctx.manifest properties = ctx.system_properties model_dir = properties.get("model_dir") - - + # settings for model compiliation and loading seed = ctx.model_yaml_config["handler"]["manual_seed"] - batch_size = ctx.model_yaml_config["handler"]["batch_size"] + self.batch_size = ctx.model_yaml_config["handler"]["batch_size"] tp_degree = ctx.model_yaml_config["handler"]["tp_degree"] amp = ctx.model_yaml_config["handler"]["amp"] model_name = ctx.model_yaml_config["handler"]["model_name"] - + torch.manual_seed(seed) self.tokenizer = AutoTokenizer.from_pretrained(model_name, return_tensors="pt") - + self.tokenizer.pad_token = self.tokenizer.eos_token + logger.info("Starting to compile the model") - self.model = OPTForSampling.from_pretrained(model_dir, batch_size=batch_size, tp_degree=tp_degree, amp=amp) + self.model = OPTForSampling.from_pretrained( + model_dir, batch_size=self.batch_size, tp_degree=tp_degree, amp=amp + ) self.model.to_neuron() logger.info("Model has been successfully compiled") - self.max_length = ctx.model_yaml_config["handler"]["max_length"] - - self.initialized = True def preprocess(self, requests): @@ -93,9 +91,10 @@ def encode_input_text(self, input_text): inputs = self.tokenizer.encode_plus( input_text, max_length=self.max_length, - pad_to_max_length=True, + padding=True, add_special_tokens=True, return_tensors="pt", + truncation=True, ) input_ids = inputs["input_ids"] attention_mask = inputs["attention_mask"] @@ -112,17 +111,27 @@ def inference(self, input_batch): list: A list of strings with the predicted values for each input text in the batch. 
""" input_ids_batch, attention_mask_batch = input_batch - input_ids_batch = input_ids_batch + + # insert padding if a partial batch was received + num_inferences = len(input_ids_batch) + logger.info("Input ids batch: %s", input_ids_batch) + logger.info("Num inferences: %s", num_inferences) + padding = self.batch_size - num_inferences + if padding > 0: + pad = torch.nn.ConstantPad1d((0, 0, 0, padding), value=0) + input_ids_batch = pad(input_ids_batch) + attention_mask_batch = pad(attention_mask_batch) + outputs = self.model.sample( input_ids_batch, - max_length=30, + self.max_length, + ) + + inferences = self.tokenizer.batch_decode( + outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False ) + inferences = inferences[:num_inferences] - inferences = [ - self.tokenizer.batch_decode( - outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False - ) - ] logger.info("Generated text: %s", inferences) return inferences diff --git a/examples/large_models/inferentia2/model-config.yaml b/examples/large_models/inferentia2/model-config.yaml index c15b086cde..d53545964a 100644 --- a/examples/large_models/inferentia2/model-config.yaml +++ b/examples/large_models/inferentia2/model-config.yaml @@ -1,12 +1,12 @@ minWorkers: 1 maxWorkers: 1 maxBatchDelay: 100 -responseTimeout: 120 +responseTimeout: 1800 handler: max_length: 50 manual_seed: 40 batch_size: 2 - tp_degree: 4 + tp_degree: 2 amp: f16 - model_name: facebook/opt-6.7b + model_name: facebook/opt-13b From 8f423b0b2d1d82ee2c0ec3b263f1834ece452df5 Mon Sep 17 00:00:00 2001 From: Naman Nandan Date: Mon, 5 Jun 2023 14:15:52 -0700 Subject: [PATCH 2/5] Add logging for padding in inf2 handler --- examples/large_models/inferentia2/inf2_handler.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/large_models/inferentia2/inf2_handler.py b/examples/large_models/inferentia2/inf2_handler.py index 0d5a3fe8cd..6d8960258b 100644 --- a/examples/large_models/inferentia2/inf2_handler.py +++ b/examples/large_models/inferentia2/inf2_handler.py @@ -36,7 +36,6 @@ def initialize(self, ctx): # settings for model compiliation and loading seed = ctx.model_yaml_config["handler"]["manual_seed"] - self.batch_size = ctx.model_yaml_config["handler"]["batch_size"] tp_degree = ctx.model_yaml_config["handler"]["tp_degree"] amp = ctx.model_yaml_config["handler"]["amp"] model_name = ctx.model_yaml_config["handler"]["model_name"] @@ -47,6 +46,7 @@ def initialize(self, ctx): logger.info("Starting to compile the model") + self.batch_size = ctx.model_yaml_config["handler"]["batch_size"] self.model = OPTForSampling.from_pretrained( model_dir, batch_size=self.batch_size, tp_degree=tp_degree, amp=amp ) @@ -110,17 +110,17 @@ def inference(self, input_batch): Returns: list: A list of strings with the predicted values for each input text in the batch. 
""" - input_ids_batch, attention_mask_batch = input_batch + input_ids_batch = input_batch[0] # insert padding if a partial batch was received num_inferences = len(input_ids_batch) - logger.info("Input ids batch: %s", input_ids_batch) - logger.info("Num inferences: %s", num_inferences) + logger.info("Number of inference requests in batch: %s", num_inferences) + logger.info("Batch size: %s", self.batch_size) padding = self.batch_size - num_inferences if padding > 0: + logger.info("Padding input batch with %s padding inputs", padding) pad = torch.nn.ConstantPad1d((0, 0, 0, padding), value=0) input_ids_batch = pad(input_ids_batch) - attention_mask_batch = pad(attention_mask_batch) outputs = self.model.sample( input_ids_batch, From 4100979caaca3c058f4a190356a2ab12cb74d718 Mon Sep 17 00:00:00 2001 From: Naman Nandan Date: Mon, 5 Jun 2023 17:19:08 -0700 Subject: [PATCH 3/5] update response timeout and model --- examples/large_models/inferentia2/model-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/large_models/inferentia2/model-config.yaml b/examples/large_models/inferentia2/model-config.yaml index d53545964a..1258f86546 100644 --- a/examples/large_models/inferentia2/model-config.yaml +++ b/examples/large_models/inferentia2/model-config.yaml @@ -1,7 +1,7 @@ minWorkers: 1 maxWorkers: 1 maxBatchDelay: 100 -responseTimeout: 1800 +responseTimeout: 900 handler: max_length: 50 @@ -9,4 +9,4 @@ handler: batch_size: 2 tp_degree: 2 amp: f16 - model_name: facebook/opt-13b + model_name: facebook/opt-6.7b From 0aa3be69e3b800d3bda9fa3d5b585f3dc1449aa4 Mon Sep 17 00:00:00 2001 From: Naman Nandan Date: Mon, 5 Jun 2023 17:29:05 -0700 Subject: [PATCH 4/5] Update documentation to show opt-6.7b as the example model --- examples/large_models/inferentia2/Readme.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/large_models/inferentia2/Readme.md b/examples/large_models/inferentia2/Readme.md index 48064193d9..d47d5be96e 100644 --- a/examples/large_models/inferentia2/Readme.md +++ b/examples/large_models/inferentia2/Readme.md @@ -1,8 +1,8 @@ -# Large model inference on Inferentia2 +# Large model inference on Inferentia2 This document briefs on serving large HuggingFace (HF) models on [AWS Inferentia2](https://aws.amazon.com/ec2/instance-types/inf2/), Inf2 instances. We use -Inferentia2 uses [Neuron SDK](https://aws.amazon.com/machine-learning/neuron/) which is build on top of PyTorch XLA stack. For large model inference [`transformers-neuronx`](https://github.com/aws-neuron/transformers-neuronx) package is used that takes care of model partitioning and running inference. +Inferentia2 uses [Neuron SDK](https://aws.amazon.com/machine-learning/neuron/) which is build on top of PyTorch XLA stack. For large model inference [`transformers-neuronx`](https://github.com/aws-neuron/transformers-neuronx) package is used that takes care of model partitioning and running inference. Lets have a look at the steps to prepare our model for inference on Inf2 instances. 
@@ -23,10 +23,10 @@ Follow the steps below to complete package installations sudo apt-get install aws-neuronx-collectives=2.* -y sudo apt-get install aws-neuronx-runtime-lib=2.* -y -# Activate Python venv -source /opt/aws_neuron_venv_pytorch/bin/activate +# Activate Python venv +source /opt/aws_neuron_venv_pytorch/bin/activate -# Set pip repository pointing to the Neuron repository +# Set pip repository pointing to the Neuron repository python -m pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com # Update Neuron Compiler and Framework @@ -41,7 +41,7 @@ pip install git+https://github.com/aws-neuron/transformers-neuronx.git transform ### Step 2: Save the model splitted checkpoints compatibale with `transformers-neuronx` ```bash - python save_split_checkpoints.py --model_name facebook/opt-13b --save_path './opt-13b-split' + python save_split_checkpoints.py --model_name facebook/opt-6.7b --save_path './opt-6.7b-split' ``` @@ -51,7 +51,7 @@ pip install git+https://github.com/aws-neuron/transformers-neuronx.git transform Navigate up to `large_model/inferentia2` directory. ```bash -torch-model-archiver --model-name opt --version 1.0 --handler inf2_handler.py --extra-files ./opt-13b-split -r requirements.txt --config-file model-config.yaml --archive-format tgz +torch-model-archiver --model-name opt --version 1.0 --handler inf2_handler.py --extra-files ./opt-6.7b-split -r requirements.txt --config-file model-config.yaml --archive-format tgz ``` From a5ff03e82cb36abb2a8e56de5eb003abd6ee5d74 Mon Sep 17 00:00:00 2001 From: Naman Nandan Date: Mon, 5 Jun 2023 17:32:27 -0700 Subject: [PATCH 5/5] Update model batch log --- examples/large_models/inferentia2/inf2_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/large_models/inferentia2/inf2_handler.py b/examples/large_models/inferentia2/inf2_handler.py index 6d8960258b..f3ab59fa83 100644 --- a/examples/large_models/inferentia2/inf2_handler.py +++ b/examples/large_models/inferentia2/inf2_handler.py @@ -115,7 +115,7 @@ def inference(self, input_batch): # insert padding if a partial batch was received num_inferences = len(input_ids_batch) logger.info("Number of inference requests in batch: %s", num_inferences) - logger.info("Batch size: %s", self.batch_size) + logger.info("Model batch size: %s", self.batch_size) padding = self.batch_size - num_inferences if padding > 0: logger.info("Padding input batch with %s padding inputs", padding)