Description
System Info
- CPU architecture: x86_64
- CPU/Host memory size: 64 GB (AWS g5.4xlarge)
- GPU properties:
  - GPU name: NVIDIA A10G
  - GPU memory size: 24 GB
- Libraries: TensorRT-LLM
  - TensorRT-LLM version: v0.9 (tag)
  - Container: nvidia/cuda:12.1.0-devel-ubuntu22.04
- NVIDIA driver version: 12.2
- OS: Ubuntu
- Additional information:
  - Instance type: AWS g5.4xlarge
  - Model being quantized: Llama-2-7b-chat-hf-instruct
  - Quantization format: W4A16_AWQ
  - Max sequence length used: 4096
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Deploy the quantized model with the Triton server using the following config:
config.txt
Used the following script to send the async requests:
```python
import os
import queue
from functools import partial
from typing import List, Dict, Any, Optional

import numpy as np
import tritonclient.grpc as grpcclient
from transformers import AutoTokenizer
from tritonclient.utils import InferenceServerException, np_to_triton_dtype


def prepare_tensor(name, input):
    t = grpcclient.InferInput(name, input.shape,
                              np_to_triton_dtype(input.dtype))
    t.set_data_from_numpy(input)
    return t


def prepare_outputs(output_names):
    outputs = []
    for output_name in output_names:
        outputs.append(grpcclient.InferRequestedOutput(output_name))
    return outputs


class UserData:
    def __init__(self):
        self._completed_requests = queue.Queue()


def callback(idx, user_data, result, error):
    if error:
        user_data._completed_requests.put((idx, error))
    else:
        user_data._completed_requests.put((idx, result))


class TensorRTLLMClient:
    def __init__(self, url: str, model_name: str, tokenizer_path: str):
        self.url = url
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, padding_side='left', trust_remote_code=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.client = grpcclient.InferenceServerClient(
            url=self.url,
            # ssl=True
        )
        self.end_id_data = np.array([[
            self.tokenizer.encode(
                self.tokenizer.eos_token,
                add_special_tokens=False)[0]
        ]], dtype=np.int32)
        self.lora_cache = {}

    def prepare_inputs(self, input_ids_data, input_lengths_data, request_output_len_data,
                       beam_width_data, temperature_data, repetition_penalty_data,
                       frequency_penalty_data, lora_task_id_data,
                       lora_weights_data, lora_config_data,
                       # top_k_data,
                       top_p_data,
                       ):
        inputs = [
            prepare_tensor("input_ids", input_ids_data),
            prepare_tensor("input_lengths", input_lengths_data),
            prepare_tensor("request_output_len", request_output_len_data),
            prepare_tensor("beam_width", beam_width_data),
            prepare_tensor("temperature", temperature_data),
            prepare_tensor("end_id", self.end_id_data),
            prepare_tensor("pad_id", self.end_id_data),
            # prepare_tensor("runtime_top_k", top_k_data),
            prepare_tensor("runtime_top_p", top_p_data),
        ]
        if lora_task_id_data is not None:
            inputs += [prepare_tensor("lora_task_id", lora_task_id_data)]
        if lora_weights_data is not None:
            inputs += [
                prepare_tensor("lora_weights", lora_weights_data),
                prepare_tensor("lora_config", lora_config_data),
            ]
        if repetition_penalty_data is not None:
            inputs += [
                prepare_tensor("repetition_penalty", repetition_penalty_data),
            ]
        if frequency_penalty_data is not None:
            inputs += [
                prepare_tensor("frequency_penalty", frequency_penalty_data),
            ]
        return inputs

    def load_lora(self, lora_path: str, lora_task_id: Optional[int] = None):
        if lora_path == "":
            return None, None, None
        lora_weights_data = None
        lora_config_data = None
        if (lora_path != ""):
            lora_weights_data = np.load(
                os.path.join(lora_path, "model.lora_weights.npy"))
            try:
                lora_config_data = np.load(
                    os.path.join(lora_path, "model.lora_config.npy"))
            except Exception:
                lora_config_data = np.load(
                    os.path.join(lora_path, "model.lora_keys.npy"))
        lora_task_id_data = None
        if lora_task_id is not None:
            lora_task_id_data = np.array([[lora_task_id]], dtype=np.uint64)
        return lora_weights_data, lora_config_data, lora_task_id_data

    def generate(self,
                 prompts: List[str],
                 lora_paths: List[str],
                 request_output_len: int = 16,
                 beam_width: int = 1,
                 top_k: int = 1,
                 top_p: float = 0.6,
                 temperature: float = 0.8,
                 repetition_penalty: float = 1.1,
                 frequency_penalty: float = None,
                 ) -> List[str]:
        assert len(prompts) == len(lora_paths), "Number of prompts and lora paths should be the same"
        user_data = UserData()
        for idx, (prompt, lora_path) in enumerate(zip(prompts, lora_paths)):
            if lora_path in self.lora_cache:
                lora_weights_data = None
                lora_config_data = None
                lora_task_id_data = self.lora_cache[lora_path]
            else:
                hashed_lora_task_id: int = hash(lora_path)
                lora_weights_data, lora_config_data, lora_task_id_data = self.load_lora(lora_path, lora_task_id=hashed_lora_task_id)
                self.lora_cache[lora_path] = lora_task_id_data

            input_ids = [self.tokenizer.encode(prompt)]
            input_ids_data = np.array(input_ids, dtype=np.int32)
            input_lengths = [[len(ii)] for ii in input_ids]
            input_lengths_data = np.array(input_lengths, dtype=np.int32)
            request_output_len = [[request_output_len]]
            request_output_len_data = np.array(request_output_len, dtype=np.int32)
            beam_width = [[beam_width]]
            beam_width_data = np.array(beam_width, dtype=np.int32)
            # top_k = [[top_k]]
            # top_k_data = np.array(top_k, dtype=np.int32)
            top_p = [[top_p]]
            top_p_data = np.array(top_p, dtype=np.float32)
            temperature = [[temperature]]
            temperature_data = np.array(temperature, dtype=np.float32)
            repetition_penalty_data = None
            if repetition_penalty is not None:
                repetition_penalty = [[repetition_penalty]]
                repetition_penalty_data = np.array(repetition_penalty,
                                                   dtype=np.float32)
            frequency_penalty_data = None
            if frequency_penalty is not None:
                frequency_penalty = [[frequency_penalty]]
                frequency_penalty_data = np.array(frequency_penalty, dtype=np.float32)

            inputs = self.prepare_inputs(
                input_ids_data, input_lengths_data, request_output_len_data,
                beam_width_data, temperature_data, repetition_penalty_data,
                frequency_penalty_data, lora_task_id_data, lora_weights_data,
                lora_config_data,
                # top_k_data,
                top_p_data
            )
            outputs = prepare_outputs(["output_ids"])
            self.client.async_infer(
                model_name=self.model_name,
                inputs=inputs,
                outputs=outputs,
                request_id=str(idx),
                callback=partial(callback, idx, user_data)
            )
            print("Sent request number", idx)

        # Parse the response when type(infer_future) == grpcclient.InferResult
        generated_texts = [None] * len(prompts)
        expected_responses = len(prompts)
        while True:
            try:
                idx, result = user_data._completed_requests.get(block=True, timeout=20)
                if type(result) == grpcclient.InferResult:
                    output_ids = result.as_numpy("output_ids")
                    generated_text = self.tokenizer.decode(output_ids[0][0], skip_special_tokens=True)
                    generated_texts[idx] = generated_text
                    print("Received response number", idx)
                    count_non_none = sum([1 for gt in generated_texts if gt is not None])
                    if count_non_none == expected_responses:
                        break
                elif type(result) == InferenceServerException:
                    print("Error in response number", idx)
                    print(result.status())
                    print(result.message())
                else:
                    print("Unknown response type in response number", idx)
                    break
            except queue.Empty:
                print("All responses received")
                break
        return generated_texts
```
Inference script:
```python
client = TensorRTLLMClient(
    url="localhost:8001",
    model_name="tensorrt_llm",
    tokenizer_path="/path/to/tokenizer"
)

return_value = client.generate(
    prompts=[
        "Hello, ",
        "Hi, "
    ],
    lora_paths=[
        "/path/to/lora/adapter1",
        "/path/to/lora/adapter2"
    ],
    request_output_len=512
)
```
Expected behavior
The generate method should process the prompts using the specified LoRA adapters and return generated texts without errors.
The backend will handle the batching in-flight.
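For reference, a minimal sketch of what a successful run is expected to produce, reusing return_value from the inference script above (the exact completions are illustrative, not taken from an actual run):

```python
# Sketch only: generate() should return one decoded completion per prompt,
# in the same order the requests were sent.
assert len(return_value) == 2
for idx, completion in enumerate(return_value):
    assert completion is not None      # every request received a response
    print(f"Prompt {idx} completion: {completion!r}")
```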
Actual behavior
The method fails with an error related to the shape of the 'repetition_penalty' input. The error message is:
```
Error in response number 1
StatusCode.INVALID_ARGUMENT
[request id: 1] unexpected shape for input 'repetition_penalty' for model 'tensorrt_llm'. Expected [-1,1], got [1,1,1,1]. NOTE: Setting a non-zero max_batch_size in the model config requires a batch dimension to be prepended to each input shape. If you want to specify the full shape including the batch dim in your input dims config, try setting max_batch_size to zero. See the model configuration docs for more info on max_batch_size.
```
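For reference, a minimal shape comparison in plain NumPy (independent of the scripts above; the variable names and values are illustrative). Per the error, the backend expects 'repetition_penalty' with shape [-1, 1], i.e. a leading batch dimension plus one value per request, whereas the request carried a rank-4 tensor:

```python
import numpy as np

# Shape the error message says is expected: [-1, 1] -> one value per request
# with a batch dimension prepended.
expected = np.array([[1.1]], dtype=np.float32)
print(expected.shape)  # (1, 1)

# Shape reported in the error: [1, 1, 1, 1] -> the same scalar nested two
# levels deeper, e.g. the result of wrapping an already-wrapped [[1.1]] again.
got = np.array([[[[1.1]]]], dtype=np.float32)
print(got.shape)  # (1, 1, 1, 1)
```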
Additional notes
- The client successfully sends both requests (as indicated by the "Sent request number" messages) but encounters an error when processing the responses.
- Only one of the two requests (number 0) appears to have been processed successfully before the error occurred.