TRTLLM new API support (NVIDIA#9003)
* Add trtllm checkpoint

* Change model config

* fix no query_group

* Using build API

* Change export to new API

* Update generate API

* Fix runtime config

* Fix for llama

* Fix for ptuning

* Fix TP issue

* Change TP rank for building weight dict

* Add lora config

* add prompt embedding table config

* Fix PP issue

* PP layers fix

* Fix no prompt task ids

* Add bos for Gemma

* Add multi block mode

* Embedding and layernorm for PP

* MPI multiprocess support for multinode

* Only output text on first rank

* Change to ModelRunnerCpp

* Add falcon

* Add rotary_pct default value

* Falcon fix

* Add MOE config

* Fix MOE weight dict

* Clean code

* Add rotary_base

* Fix MOE config

* Fix falcon new architecture

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix Gemma 7B

* Add rotary_scaling

* Apply isort and black reformatting

Signed-off-by: oyilmaz-nvidia <oyilmaz-nvidia@users.noreply.github.com>

---------

Signed-off-by: oyilmaz-nvidia <oyilmaz-nvidia@users.noreply.github.com>
Co-authored-by: abharwani <abharwani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Onur Yilmaz <35306097+oyilmaz-nvidia@users.noreply.github.com>
Co-authored-by: oyilmaz-nvidia <oyilmaz-nvidia@users.noreply.github.com>
Co-authored-by: Eric Harper <complex451@gmail.com>
6 people committed May 13, 2024
1 parent 18cce97 commit 5c9c15c
Showing 8 changed files with 468 additions and 200 deletions.
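This commit swaps the old model_config_to_tensorrt_llm export path for the new TensorRT-LLM checkpoint and build API (nemo_to_trtllm_config plus build_and_save_engine). Below is a minimal usage sketch of the exporter this diff touches, assuming the TensorRTLLM wrapper in nemo/export/tensorrt_llm.py; the checkpoint path, model type, and token limits are placeholders rather than values taken from the commit.

```python
# Hedged usage sketch only: paths, model type, and limits are illustrative.
from nemo.export import TensorRTLLM

exporter = TensorRTLLM(model_dir="/tmp/trtllm_engine")    # where engines are written
exporter.export(
    nemo_checkpoint_path="/models/example-llama.nemo",    # placeholder .nemo checkpoint
    model_type="llama",                                   # decoder type string
    max_input_token=1024,
    max_output_token=256,
    max_batch_size=8,
    use_parallel_embedding=False,                         # flag introduced by this commit
)
print(exporter.forward(["What does the new export path change?"], max_output_token=32))
```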
138 changes: 75 additions & 63 deletions nemo/export/tensorrt_llm.py
@@ -30,9 +30,10 @@
from nemo.export.tarutils import TarPath, unpack_tarball
from nemo.export.trt_llm.model_config_trt import model_config_to_tensorrt_llm
from nemo.export.trt_llm.nemo.nemo_ckpt_convert import build_tokenizer
from nemo.export.trt_llm.nemo_utils import get_tokenzier, nemo_llm_model_to_model_config, nemo_llm_to_model_config
from nemo.export.trt_llm.nemo_utils import get_tokenzier, nemo_llm_model_to_model_config, nemo_to_trtllm_config
from nemo.export.trt_llm.qnemo import qnemo_to_tensorrt_llm
from nemo.export.trt_llm.qnemo.tokenizer_utils import get_nmt_tokenizer
from nemo.export.trt_llm.tensorrt_llm_build import build_and_save_engine
from nemo.export.trt_llm.tensorrt_llm_run import generate, generate_streaming, load, load_refit
from nemo.export.trt_llm.utils import is_nemo_file

@@ -115,6 +116,7 @@ def export(
max_output_token: int = 256,
max_batch_size: int = 8,
max_prompt_embedding_table_size=None,
use_parallel_embedding: bool = False,
use_inflight_batching: bool = False,
enable_context_fmha: bool = True,
paged_kv_cache: bool = False,
@@ -188,65 +190,70 @@ def export(

self.model = None

tmp_dir = tempfile.TemporaryDirectory()
nemo_export_dir = Path(tmp_dir.name)
if tensorrt_llm.mpi_rank() == 0:
tmp_dir = tempfile.TemporaryDirectory()
nemo_export_dir = Path(tmp_dir.name)

if nemo_checkpoint_path.endswith("qnemo"):
if os.path.isdir(nemo_checkpoint_path):
nemo_export_dir = nemo_checkpoint_path
if nemo_checkpoint_path.endswith("qnemo"):
if os.path.isdir(nemo_checkpoint_path):
nemo_export_dir = nemo_checkpoint_path
else:
unpack_tarball(nemo_checkpoint_path, tmp_dir.name)
nemo_checkpoint_path = tmp_dir.name
self.tokenizer = get_nmt_tokenizer(nemo_checkpoint_path)

qnemo_to_tensorrt_llm(
nemo_checkpoint_path=nemo_checkpoint_path,
engine_dir=self.model_dir,
max_input_len=max_input_token,
max_output_len=max_output_token,
max_batch_size=max_batch_size,
max_prompt_embedding_table_size=max_prompt_embedding_table_size,
lora_target_modules=lora_target_modules,
)
else:
unpack_tarball(nemo_checkpoint_path, tmp_dir.name)
nemo_checkpoint_path = tmp_dir.name
self.tokenizer = get_nmt_tokenizer(nemo_checkpoint_path)

qnemo_to_tensorrt_llm(
nemo_checkpoint_path=nemo_checkpoint_path,
engine_dir=self.model_dir,
max_input_len=max_input_token,
max_output_len=max_output_token,
max_batch_size=max_batch_size,
max_prompt_embedding_table_size=max_prompt_embedding_table_size,
lora_target_modules=lora_target_modules,
)
else:
model_configs, self.tokenizer = nemo_llm_to_model_config(
in_file=nemo_checkpoint_path,
decoder_type=model_type,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
nemo_export_dir=nemo_export_dir,
save_nemo_model_config=save_nemo_model_config,
)
weights_dicts, model_configs, self.tokenizer = nemo_to_trtllm_config(
in_file=nemo_checkpoint_path,
decoder_type=model_type,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
use_parallel_embedding=use_parallel_embedding,
nemo_export_dir=nemo_export_dir,
save_nemo_model_config=save_nemo_model_config,
)

model_config_to_tensorrt_llm(
model_configs,
self.model_dir,
world_size=tensor_parallel_size * pipeline_parallel_size,
max_input_len=max_input_token,
max_output_len=max_output_token,
max_batch_size=max_batch_size,
max_prompt_embedding_table_size=max_prompt_embedding_table_size,
use_inflight_batching=use_inflight_batching,
paged_kv_cache=paged_kv_cache,
enable_context_fmha=enable_context_fmha,
enable_multi_block_mode=enable_multi_block_mode,
use_lora_plugin=use_lora_plugin,
lora_target_modules=lora_target_modules,
max_lora_rank=max_lora_rank,
)
for weight_dict, model_config in zip(weights_dicts, model_configs):
build_and_save_engine(
max_input_len=max_input_token,
max_output_len=max_output_token,
max_batch_size=max_batch_size,
model_config=model_config,
model_weights=weight_dict,
model_dir=self.model_dir,
model_type=model_type,
lora_ckpt_list=self.lora_ckpt_list,
use_lora_plugin=use_lora_plugin,
max_lora_rank=max_lora_rank,
lora_target_modules=lora_target_modules,
max_prompt_embedding_table_size=max_prompt_embedding_table_size,
enable_multi_block_mode=enable_multi_block_mode,
)

tokenizer_path = os.path.join(nemo_export_dir, "tokenizer.model")
if os.path.exists(tokenizer_path):
shutil.copy(tokenizer_path, self.model_dir)
else:
self.tokenizer.save_pretrained(os.path.join(self.model_dir, 'huggingface_tokenizer'))
tokenizer_path = os.path.join(nemo_export_dir, "tokenizer.model")
if os.path.exists(tokenizer_path):
shutil.copy(tokenizer_path, self.model_dir)
else:
self.tokenizer.save_pretrained(os.path.join(self.model_dir, 'huggingface_tokenizer'))

nemo_model_config = os.path.join(nemo_export_dir, "model_config.yaml")
if os.path.exists(nemo_model_config):
shutil.copy(nemo_model_config, self.model_dir)

nemo_model_config = os.path.join(nemo_export_dir, "model_config.yaml")
if os.path.exists(nemo_model_config):
shutil.copy(nemo_model_config, self.model_dir)
tmp_dir.cleanup()

tmp_dir.cleanup()
if tensorrt_llm.mpi_world_size() > 1:
tensorrt_llm.mpi_barrier()

if load_model:
self._load()
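Engine building is now gated on the MPI rank so multi-node jobs convert and build exactly once: rank 0 writes the TRT-LLM checkpoint and engines, then every rank synchronizes on a barrier before loading. A simplified sketch of that control flow, leaving out the qnemo branch, tokenizer copying, and temp-dir cleanup handled above:

```python
# Simplified sketch of the rank-gated build pattern used in export().
import tensorrt_llm

def build_then_load(build_engines, load_engines):
    # Only the first MPI rank converts the NeMo checkpoint and writes engines.
    if tensorrt_llm.mpi_rank() == 0:
        build_engines()
    # All ranks wait here so no rank loads a half-written engine directory.
    if tensorrt_llm.mpi_world_size() > 1:
        tensorrt_llm.mpi_barrier()
    # Every rank then loads its shard of the engine for generation.
    load_engines()
```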
@@ -279,7 +286,9 @@ def build(

# Build or refit TRT-LLM engine from a nemo model.
model_configs = nemo_llm_model_to_model_config(
nemo_model=nemo_model, decoder_type=model_type, nemo_model_config=nemo_model_config,
nemo_model=nemo_model,
decoder_type=model_type,
nemo_model_config=nemo_model_config,
)

model_config_to_tensorrt_llm(
@@ -298,7 +307,9 @@ def refit(
)

def refit(
self, nemo_model, nemo_model_config,
self,
nemo_model,
nemo_model_config,
):
assert self.use_refit, "TRT-LLM model must be built() with refit=True"

@@ -329,7 +340,6 @@ def forward(
output_log_probs: bool = False,
**sampling_kwargs,
):

"""
Exports nemo checkpoints to TensorRT-LLM.
Expand Down Expand Up @@ -394,7 +404,7 @@ def forward(
), "Task: {0} doesn't exist in the task list.".format(task_ids[i])
input_task_ids.append(self.task_ids[task_ids[i]])
if not streaming:
if torch.distributed.is_initialized():
if torch.distributed.is_initialized() or tensorrt_llm.mpi_world_size() > 1:
multiprocessed_env = True
else:
multiprocessed_env = False
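forward() now treats a plain MPI launch the same as a torch.distributed job when deciding whether generation runs in a multiprocessed environment, which is what enables the multinode support mentioned in the commit message. The check reduces to the following sketch:

```python
# Equivalent of the broadened multi-process check above (sketch).
import torch
import tensorrt_llm

def is_multiprocessed() -> bool:
    return torch.distributed.is_initialized() or tensorrt_llm.mpi_world_size() > 1
```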
@@ -478,7 +488,7 @@ def get_hidden_size(self):
if self.config is None:
return None
else:
return self.config["builder_config"]["hidden_size"]
return self.config["pretrained_config"]["hidden_size"]

@property
def get_triton_input(self):
@@ -665,7 +675,9 @@ def _get_prompt_embedding_table_ckpt(self, prompt_embeddings_checkpoint_path):
return weights.cpu().detach()

def _get_prompt_embedding_table(
self, prompt_embeddings_table=None, prompt_embeddings_checkpoint_path=None,
self,
prompt_embeddings_table=None,
prompt_embeddings_checkpoint_path=None,
):
if prompt_embeddings_table is not None and prompt_embeddings_checkpoint_path is not None:
LOGGER.warning(
@@ -694,15 +706,15 @@ def _get_prompt_embedding_table(
raise TypeError(prompt_embeddings_checkpoint_path + " is not a nemo file.")
prompt_embeddings_table = self._get_prompt_embedding_table_ckpt(prompt_embeddings_checkpoint_path)

dtype = self.config['builder_config']['precision']
dtype = self.config['pretrained_config']['dtype']
prompt_embeddings_table = prompt_embeddings_table.to(
dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)
).cuda()

if prompt_embeddings_table.size(dim=1) != self.config["builder_config"]["hidden_size"]:
if prompt_embeddings_table.size(dim=1) != self.config["pretrained_config"]["hidden_size"]:
raise Exception(
"Hidden dimension of the model is {0} and does not match with the dimension of the prompt table.".format(
self.config["builder_config"]["hidden_size"]
self.config["pretrained_config"]["hidden_size"]
)
)

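With the new API the loaded engine config exposes a pretrained_config section instead of the legacy builder_config, so hidden_size now comes from pretrained_config.hidden_size and the dtype from pretrained_config.dtype (previously builder_config.precision). A small sketch of reading the renamed keys, assuming the engine directory contains a config.json with that section:

```python
# Sketch: read the renamed keys from a new-API engine config (assumed config.json layout).
import json
import os

def read_engine_config(engine_dir: str):
    with open(os.path.join(engine_dir, "config.json")) as f:
        config = json.load(f)
    hidden_size = config["pretrained_config"]["hidden_size"]  # was builder_config["hidden_size"]
    dtype = config["pretrained_config"]["dtype"]              # was builder_config["precision"]
    return hidden_size, dtype
```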
8 changes: 8 additions & 0 deletions nemo/export/trt_llm/decoder/__init__.py
@@ -40,6 +40,14 @@
DECODER_GEMMA: GemmaDecoderLayerConfigBuilder,
}

DECODER_MODEL_TYPE = {
DECODER_GPT2: 'GPTForCausalLM',
DECODER_GPTNEXT: 'GPTForCausalLM',
DECODER_LLAMA: 'LLaMAForCausalLM',
DECODER_GEMMA: 'GemmaForCausalLM',
DECODER_FALCON: 'FalconForCausalLM',
}


def build_decoder_layer_config(layer, decoder: str, dtype=trt.float16, rank=0, tensor_parallel=1):
"""Builds the decoder layer config with the input torch module."""
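The new DECODER_MODEL_TYPE table maps NeMo decoder-type strings to the TensorRT-LLM architecture names expected by the new checkpoint format. A hypothetical lookup helper on top of it, assuming the DECODER_* constants are plain strings such as "llama" and "falcon":

```python
# Hypothetical helper around the DECODER_MODEL_TYPE table added in this commit.
from nemo.export.trt_llm.decoder import DECODER_MODEL_TYPE

def trtllm_architecture_for(decoder_type: str) -> str:
    """Return the TensorRT-LLM architecture name for a NeMo decoder type."""
    try:
        return DECODER_MODEL_TYPE[decoder_type]
    except KeyError:
        raise ValueError(
            f"Unsupported decoder type '{decoder_type}'; expected one of {sorted(DECODER_MODEL_TYPE)}"
        )

# e.g. trtllm_architecture_for("llama") -> "LLaMAForCausalLM"
```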