From ffb255a235a532bc844d55df32735af8d4c0a438 Mon Sep 17 00:00:00 2001 From: agunapal Date: Thu, 9 Nov 2023 19:32:52 +0000 Subject: [PATCH 1/5] vLLM example with mistral7B model --- examples/large_models/vllm/mistral/Readme.md | 61 +++++++++++++ .../vllm/mistral/config.properties | 5 ++ .../vllm/mistral/custom_handler.py | 88 +++++++++++++++++++ .../vllm/mistral/model-config.yaml | 13 +++ .../vllm/mistral/requirements.txt | 1 + .../large_models/vllm/mistral/sample_text.txt | 1 + 6 files changed, 169 insertions(+) create mode 100644 examples/large_models/vllm/mistral/Readme.md create mode 100644 examples/large_models/vllm/mistral/config.properties create mode 100644 examples/large_models/vllm/mistral/custom_handler.py create mode 100644 examples/large_models/vllm/mistral/model-config.yaml create mode 100644 examples/large_models/vllm/mistral/requirements.txt create mode 100644 examples/large_models/vllm/mistral/sample_text.txt diff --git a/examples/large_models/vllm/mistral/Readme.md b/examples/large_models/vllm/mistral/Readme.md new file mode 100644 index 0000000000..ba15867758 --- /dev/null +++ b/examples/large_models/vllm/mistral/Readme.md @@ -0,0 +1,61 @@ +# Example showing inference with vLLM with mistralai/Mistral-7B-v0.1 model + +This is an example showing how to integrate [vLLM](https://github.com/vllm-project/vllm) with TorchServe and run inference on `mistralai/Mistral-7B-v0.1` model. +vLLM achieves high throughput using PagedAttention. More details can be found [here](https://vllm.ai/) + +Install vLLM with the following + +``` +pip install -r requirements.txt +``` +### Step 1: Login to HuggingFace + +Login with a HuggingFace account +``` +huggingface-cli login +# or using an environment variable +huggingface-cli login --token $HUGGINGFACE_TOKEN +``` + +```bash +python ../Download_model.py --model_path model --model_name mistralai/Mistral-7B-v0.1 +``` +Model will be saved in the following path, `mistralai/Mistral-7B-v0.1`. 
+ +### Step 2: Generate MAR file + +Add the downloaded path to " model_path:" in `model-config.yaml` and run the following. + +```bash +torch-model-archiver --model-name mistral7b --version 1.0 --handler custom_handler.py --config-file model-config.yaml -r requirements.txt --archive-format tgz +``` + +### Step 3: Add the mar file to model store + +```bash +mkdir model_store +mv mistral7b.tar.gz model_store +``` + +### Step 3: Start torchserve + +Update config.properties and start torchserve + +```bash +torchserve --start --ncs --ts-config config.properties --model-store model_store --models mistral7b.tar.gz +``` + +### Step 4: Run inference + +```bash +curl -v "http://localhost:8080/predictions/mistral7b" -T sample_text.txt +``` + +results in the following output +``` +Mayonnaise is made of eggs, oil, vinegar, salt and pepper. Using an electric blender, combine all the ingredients and beat at high speed for 4 to 5 minutes. + +Try it with some mustard and paprika mixed in, and a bit of sweetener if you like. But use real mayonnaise or it isn’t the same. Marlou + +What in the world is mayonnaise? 
+``` diff --git a/examples/large_models/vllm/mistral/config.properties b/examples/large_models/vllm/mistral/config.properties new file mode 100644 index 0000000000..67f62d182f --- /dev/null +++ b/examples/large_models/vllm/mistral/config.properties @@ -0,0 +1,5 @@ +inference_address=http://127.0.0.1:8080 +management_address=http://127.0.0.1:8081 +metrics_address=http://127.0.0.1:8082 +enable_envvars_config=true +install_py_dep_per_model=true diff --git a/examples/large_models/vllm/mistral/custom_handler.py b/examples/large_models/vllm/mistral/custom_handler.py new file mode 100644 index 0000000000..5e8b2f8d29 --- /dev/null +++ b/examples/large_models/vllm/mistral/custom_handler.py @@ -0,0 +1,88 @@ +import logging +from abc import ABC + +import torch +import vllm +from vllm import LLM, SamplingParams + +from ts.context import Context +from ts.torch_handler.base_handler import BaseHandler + +logger = logging.getLogger(__name__) +logger.info("vLLM version %s", vllm.__version__) + + +class LlamaHandler(BaseHandler, ABC): + """ + Transformers handler class for sequence, token classification and question answering. + """ + + def __init__(self): + super(LlamaHandler, self).__init__() + self.max_new_tokens = None + self.tokenizer = None + self.initialized = False + + def initialize(self, ctx: Context): + """In this initialize function, the HF large model is loaded and + partitioned using DeepSpeed. + Args: + ctx (context): It is a JSON Object containing information + pertaining to the model artifacts parameters. 
+ """ + model_dir = ctx.system_properties.get("model_dir") + self.max_new_tokens = int(ctx.model_yaml_config["handler"]["max_new_tokens"]) + model_name = ctx.model_yaml_config["handler"]["model_name"] + model_path = ctx.model_yaml_config["handler"]["model_path"] + seed = int(ctx.model_yaml_config["handler"]["manual_seed"]) + torch.manual_seed(seed) + + self.model = LLM(model=model_path) + + logger.info("Model %s loaded successfully", ctx.model_name) + self.initialized = True + + def preprocess(self, requests): + """ + Basic text preprocessing, based on the user's choice of application mode. + Args: + requests (list): A list of dictionaries with a "data" or "body" field, each + containing the input text to be processed. + Returns: + tuple: A tuple with two tensors: the batch of input ids and the batch of + attention masks. + """ + input_texts = [data.get("data") or data.get("body") for data in requests] + #return torch.as_tensor(input_texts, device=self.device) + input_texts = [ input_text.decode("utf-8") for input_text in input_texts if isinstance(input_text, (bytes, bytearray))] + return input_texts + + + def inference(self, input_batch): + """ + Predicts the class (or classes) of the received text using the serialized transformers + checkpoint. + Args: + input_batch (tuple): A tuple with two tensors: the batch of input ids and the batch + of attention masks, as returned by the preprocess function. + Returns: + list: A list of strings with the predicted values for each input text in the batch. + """ + logger.info(f"Input text is {input_batch}") + sampling_params = SamplingParams(max_tokens=self.max_new_tokens) + outputs = self.model.generate( + input_batch, sampling_params=sampling_params + ) + + + logger.info("Generated text: %s", outputs) + return outputs + + def postprocess(self, inference_output): + """Post Process Function converts the predicted response into Torchserve readable format. 
+ Args: + inference_output (list): It contains the predicted response of the input text. + Returns: + (list): Returns a list of the Predictions and Explanations. + """ + return [inference_output[0].outputs[0].text] diff --git a/examples/large_models/vllm/mistral/model-config.yaml b/examples/large_models/vllm/mistral/model-config.yaml new file mode 100644 index 0000000000..c4282cac71 --- /dev/null +++ b/examples/large_models/vllm/mistral/model-config.yaml @@ -0,0 +1,13 @@ +# TorchServe frontend parameters +minWorkers: 1 +maxWorkers: 1 +maxBatchDelay: 100 +responseTimeout: 1200 +deviceType: "gpu" + +handler: + model_name: "mistralai/Mistral-7B-v0.1" + model_path: "/home/ubuntu/serve/examples/large_models/vllm/mistral/model/models--mistralai--Mistral-7B-v0.1/snapshots/5e9c98b96d071dce59368012254c55b0ec6f8658" + max_new_tokens: 100 + manual_seed: 40 + fast_kernels: True diff --git a/examples/large_models/vllm/mistral/requirements.txt b/examples/large_models/vllm/mistral/requirements.txt new file mode 100644 index 0000000000..76f11f1540 --- /dev/null +++ b/examples/large_models/vllm/mistral/requirements.txt @@ -0,0 +1 @@ +vllm \ No newline at end of file diff --git a/examples/large_models/vllm/mistral/sample_text.txt b/examples/large_models/vllm/mistral/sample_text.txt new file mode 100644 index 0000000000..edfe9f4c10 --- /dev/null +++ b/examples/large_models/vllm/mistral/sample_text.txt @@ -0,0 +1 @@ +what is the recipe of mayonnaise? 
\ No newline at end of file From 4dd890f18e09feea8933d6b01d5cfe2757bba52d Mon Sep 17 00:00:00 2001 From: agunapal Date: Thu, 9 Nov 2023 19:36:51 +0000 Subject: [PATCH 2/5] spellcheck --- ts_scripts/spellcheck_conf/wordlist.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ts_scripts/spellcheck_conf/wordlist.txt b/ts_scripts/spellcheck_conf/wordlist.txt index b2b05a22b7..0ec2048ccd 100644 --- a/ts_scripts/spellcheck_conf/wordlist.txt +++ b/ts_scripts/spellcheck_conf/wordlist.txt @@ -1137,3 +1137,6 @@ Naver FlashAttention GenAI prem +vLLM +mistralai +PagedAttention \ No newline at end of file From 5fedf7a96fd4dd9afd064264b506b0a3f19991ad Mon Sep 17 00:00:00 2001 From: agunapal Date: Thu, 9 Nov 2023 19:38:50 +0000 Subject: [PATCH 3/5] lint --- .../large_models/vllm/mistral/custom_handler.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/large_models/vllm/mistral/custom_handler.py b/examples/large_models/vllm/mistral/custom_handler.py index 5e8b2f8d29..ae4c29ab57 100644 --- a/examples/large_models/vllm/mistral/custom_handler.py +++ b/examples/large_models/vllm/mistral/custom_handler.py @@ -38,7 +38,7 @@ def initialize(self, ctx: Context): torch.manual_seed(seed) self.model = LLM(model=model_path) - + logger.info("Model %s loaded successfully", ctx.model_name) self.initialized = True @@ -53,11 +53,14 @@ def preprocess(self, requests): attention masks. 
""" input_texts = [data.get("data") or data.get("body") for data in requests] - #return torch.as_tensor(input_texts, device=self.device) - input_texts = [ input_text.decode("utf-8") for input_text in input_texts if isinstance(input_text, (bytes, bytearray))] + # return torch.as_tensor(input_texts, device=self.device) + input_texts = [ + input_text.decode("utf-8") + for input_text in input_texts + if isinstance(input_text, (bytes, bytearray)) + ] return input_texts - def inference(self, input_batch): """ Predicts the class (or classes) of the received text using the serialized transformers @@ -70,10 +73,7 @@ def inference(self, input_batch): """ logger.info(f"Input text is {input_batch}") sampling_params = SamplingParams(max_tokens=self.max_new_tokens) - outputs = self.model.generate( - input_batch, sampling_params=sampling_params - ) - + outputs = self.model.generate(input_batch, sampling_params=sampling_params) logger.info("Generated text: %s", outputs) return outputs From a34a3059aa86fa302d4bf256fe3f627193716146 Mon Sep 17 00:00:00 2001 From: agunapal Date: Mon, 13 Nov 2023 20:43:13 +0000 Subject: [PATCH 4/5] Added support for multiple GPUs --- examples/large_models/vllm/mistral/Readme.md | 2 +- examples/large_models/vllm/mistral/custom_handler.py | 3 ++- examples/large_models/vllm/mistral/model-config.yaml | 7 ++++++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/examples/large_models/vllm/mistral/Readme.md b/examples/large_models/vllm/mistral/Readme.md index ba15867758..b88c91efe9 100644 --- a/examples/large_models/vllm/mistral/Readme.md +++ b/examples/large_models/vllm/mistral/Readme.md @@ -18,7 +18,7 @@ huggingface-cli login --token $HUGGINGFACE_TOKEN ``` ```bash -python ../Download_model.py --model_path model --model_name mistralai/Mistral-7B-v0.1 +python ../../Huggingface_accelerate/Download_model.py --model_path model --model_name mistralai/Mistral-7B-v0.1 ``` Model will be saved in the following path, `mistralai/Mistral-7B-v0.1`. 
+ diff --git a/examples/large_models/vllm/mistral/custom_handler.py b/examples/large_models/vllm/mistral/custom_handler.py index ae4c29ab57..cedca0c5bb 100644 --- a/examples/large_models/vllm/mistral/custom_handler.py +++ b/examples/large_models/vllm/mistral/custom_handler.py @@ -34,10 +34,11 @@ def initialize(self, ctx: Context): self.max_new_tokens = int(ctx.model_yaml_config["handler"]["max_new_tokens"]) model_name = ctx.model_yaml_config["handler"]["model_name"] model_path = ctx.model_yaml_config["handler"]["model_path"] + tp_size = ctx.model_yaml_config["handler"]["tensor_parallel_size"] seed = int(ctx.model_yaml_config["handler"]["manual_seed"]) torch.manual_seed(seed) - self.model = LLM(model=model_path) + self.model = LLM(model=model_path, tensor_parallel_size=tp_size) logger.info("Model %s loaded successfully", ctx.model_name) self.initialized = True diff --git a/examples/large_models/vllm/mistral/model-config.yaml b/examples/large_models/vllm/mistral/model-config.yaml index c4282cac71..dbd73251b5 100644 --- a/examples/large_models/vllm/mistral/model-config.yaml +++ b/examples/large_models/vllm/mistral/model-config.yaml @@ -4,10 +4,15 @@ maxWorkers: 1 maxBatchDelay: 100 responseTimeout: 1200 deviceType: "gpu" +# example of user specified GPU deviceIds +deviceIds: [0,1,2,3] # setting CUDA_VISIBLE_DEVICES + +torchrun: + nproc-per-node: 4 handler: model_name: "mistralai/Mistral-7B-v0.1" model_path: "/home/ubuntu/serve/examples/large_models/vllm/mistral/model/models--mistralai--Mistral-7B-v0.1/snapshots/5e9c98b96d071dce59368012254c55b0ec6f8658" max_new_tokens: 100 manual_seed: 40 - fast_kernels: True + tensor_parallel_size : 4 From 639f080f54811db212d1fbbf214c61bbaddaf74c Mon Sep 17 00:00:00 2001 From: agunapal Date: Fri, 8 Dec 2023 23:55:34 +0000 Subject: [PATCH 5/5] addressed review comments --- examples/large_models/vllm/mistral/Readme.md | 6 ---- .../vllm/mistral/custom_handler.py | 31 +++++++++---------- .../vllm/mistral/model-config.yaml | 2 +- 3 files 
changed, 15 insertions(+), 24 deletions(-) diff --git a/examples/large_models/vllm/mistral/Readme.md b/examples/large_models/vllm/mistral/Readme.md index b88c91efe9..267d57e11d 100644 --- a/examples/large_models/vllm/mistral/Readme.md +++ b/examples/large_models/vllm/mistral/Readme.md @@ -3,11 +3,6 @@ This is an example showing how to integrate [vLLM](https://github.com/vllm-project/vllm) with TorchServe and run inference on `mistralai/Mistral-7B-v0.1` model. vLLM achieves high throughput using PagedAttention. More details can be found [here](https://vllm.ai/) -Install vLLM with the following - -``` -pip install -r requirements.txt -``` ### Step 1: Login to HuggingFace Login with a HuggingFace account @@ -39,7 +34,6 @@ mv mistral7b.tar.gz model_store ### Step 3: Start torchserve -Update config.properties and start torchserve ```bash torchserve --start --ncs --ts-config config.properties --model-store model_store --models mistral7b.tar.gz diff --git a/examples/large_models/vllm/mistral/custom_handler.py b/examples/large_models/vllm/mistral/custom_handler.py index cedca0c5bb..626b62333a 100644 --- a/examples/large_models/vllm/mistral/custom_handler.py +++ b/examples/large_models/vllm/mistral/custom_handler.py @@ -1,5 +1,4 @@ import logging -from abc import ABC import torch import vllm @@ -12,20 +11,19 @@ logger.info("vLLM version %s", vllm.__version__) -class LlamaHandler(BaseHandler, ABC): +class CustomHandler(BaseHandler): """ - Transformers handler class for sequence, token classification and question answering. + Custom Handler for integrating vLLM """ def __init__(self): - super(LlamaHandler, self).__init__() + super().__init__() self.max_new_tokens = None self.tokenizer = None self.initialized = False def initialize(self, ctx: Context): - """In this initialize function, the HF large model is loaded and - partitioned using DeepSpeed. 
+ """In this initialize function, the model is loaded Args: ctx (context): It is a JSON Object containing information pertaining to the model artifacts parameters. @@ -34,7 +32,7 @@ def initialize(self, ctx: Context): self.max_new_tokens = int(ctx.model_yaml_config["handler"]["max_new_tokens"]) model_name = ctx.model_yaml_config["handler"]["model_name"] model_path = ctx.model_yaml_config["handler"]["model_path"] - tp_size = ctx.model_yaml_config["handler"]["tensor_parallel_size"] + tp_size = ctx.model_yaml_config["torchrun"]["nproc-per-node"] seed = int(ctx.model_yaml_config["handler"]["manual_seed"]) torch.manual_seed(seed) @@ -45,7 +43,7 @@ def initialize(self, ctx: Context): def preprocess(self, requests): """ - Basic text preprocessing, based on the user's choice of application mode. + Pre-processing of prompts being sent to TorchServe Args: requests (list): A list of dictionaries with a "data" or "body" field, each containing the input text to be processed. @@ -54,21 +52,19 @@ def preprocess(self, requests): attention masks. """ input_texts = [data.get("data") or data.get("body") for data in requests] - # return torch.as_tensor(input_texts, device=self.device) input_texts = [ input_text.decode("utf-8") - for input_text in input_texts if isinstance(input_text, (bytes, bytearray)) + else input_text + for input_text in input_texts ] return input_texts def inference(self, input_batch): """ - Predicts the class (or classes) of the received text using the serialized transformers - checkpoint. + Generates the model response for the given prompt Args: - input_batch (tuple): A tuple with two tensors: the batch of input ids and the batch - of attention masks, as returned by the preprocess function. + input_batch : List of input text prompts as returned by the preprocess function. Returns: list: A list of strings with the predicted values for each input text in the batch. 
""" @@ -80,10 +76,11 @@ def inference(self, input_batch): return outputs def postprocess(self, inference_output): - """Post Process Function converts the predicted response into Torchserve readable format. + """Post Process Function returns the text response from the vLLM output. Args: - inference_output (list): It contains the predicted response of the input text. + inference_output (list): It contains the response of vLLM Returns: (list): Returns a list of the Predictions and Explanations. """ - return [inference_output[0].outputs[0].text] + + return [inf_output.outputs[0].text for inf_output in inference_output] diff --git a/examples/large_models/vllm/mistral/model-config.yaml b/examples/large_models/vllm/mistral/model-config.yaml index dbd73251b5..9ec8f6ec46 100644 --- a/examples/large_models/vllm/mistral/model-config.yaml +++ b/examples/large_models/vllm/mistral/model-config.yaml @@ -1,6 +1,7 @@ # TorchServe frontend parameters minWorkers: 1 maxWorkers: 1 +batchSize: 2 maxBatchDelay: 100 responseTimeout: 1200 deviceType: "gpu" @@ -15,4 +16,3 @@ handler: model_path: "/home/ubuntu/serve/examples/large_models/vllm/mistral/model/models--mistralai--Mistral-7B-v0.1/snapshots/5e9c98b96d071dce59368012254c55b0ec6f8658" max_new_tokens: 100 manual_seed: 40 - tensor_parallel_size : 4