**`examples/FasterTransformer_HuggingFace_Bert/Bert_FT_trace.py`** (new file, 158 lines)
```python
# This script is adapted from the run_glue example of NVIDIA FasterTransformer,
# https://github.com/NVIDIA/FasterTransformer/blob/main/sample/pytorch/run_glue.py

import argparse
import logging
import os
import random
import timeit

import numpy as np
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
from tqdm import tqdm, trange

from transformers import (
    BertConfig,
    BertTokenizer,
)
from utils.modeling_bert import BertForSequenceClassification, BertForQuestionAnswering
from transformers import glue_compute_metrics as compute_metrics
from transformers import glue_convert_examples_to_features as convert_examples_to_features
from transformers import glue_output_modes as output_modes
from transformers import glue_processors as processors


logger = logging.getLogger(__name__)


def set_seed(args):
    """Seed Python, NumPy, and PyTorch RNGs so the tracing inputs are reproducible."""
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name",
    )
    parser.add_argument(
        "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name",
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.",
    )
    parser.add_argument(
        "--mode",
        default="sequence_classification",
        help="Set the model up for sequence classification or question answering",
    )
    parser.add_argument(
        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.",
    )
    parser.add_argument(
        "--batch_size", default=8, type=int, help="Batch size for tracing.",
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    # parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    parser.add_argument("--model_type", type=str, help="ori, ths, thsext")
    parser.add_argument("--data_type", type=str, help="fp32, fp16")
    parser.add_argument('--ths_path', type=str, default='./lib/libpyt_fastertransformer.so',
                        help='path of the pyt_fastertransformer dynamic lib file')
    # Note: with action='store_false' the next two options default to True,
    # and passing the flag turns them off.
    parser.add_argument('--remove_padding', action='store_false',
                        help='Remove the padding of sentences of encoder.')
    parser.add_argument('--allow_gemm_test', action='store_false',
                        help='Allow the GEMM algorithm selection test inside FasterTransformer.')

    args = parser.parse_args()
    # FasterTransformer requires a GPU; fall back to CPU only for the plain model.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # Set seed
    set_seed(args)

    tokenizer = BertTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )

    logger.info("Parameters %s", args)
    checkpoints = [args.model_name_or_path]
    for checkpoint in checkpoints:
        # 'ths' and 'thsext' model types go through TorchScript.
        use_ths = args.model_type.startswith('ths')
        if args.mode == "sequence_classification":
            model = BertForSequenceClassification.from_pretrained(checkpoint, torchscript=use_ths)
        elif args.mode == "question_answering":
            model = BertForQuestionAnswering.from_pretrained(checkpoint, torchscript=use_ths)
        else:
            raise ValueError("--mode must be sequence_classification or question_answering")
        model.to(args.device)

        if args.data_type == 'fp16':
            logger.info("Use fp16")
            model.half()
        if args.model_type == 'thsext':
            logger.info("Use custom BERT encoder for TorchScript")
            from utils.encoder import EncoderWeights, CustomEncoder
            weights = EncoderWeights(
                model.config.num_hidden_layers, model.config.hidden_size,
                torch.load(os.path.join(checkpoint, 'pytorch_model.bin'), map_location='cpu'))
            weights.to_cuda()
            if args.data_type == 'fp16':
                weights.to_half()
            enc = CustomEncoder(model.config.num_hidden_layers,
                                model.config.num_attention_heads,
                                model.config.hidden_size // model.config.num_attention_heads,
                                weights,
                                remove_padding=args.remove_padding,
                                allow_gemm_test=args.allow_gemm_test,
                                path=os.path.abspath(args.ths_path))
            enc_ = torch.jit.script(enc)
            model.replace_encoder(enc_)
        if use_ths:
            logger.info("Use TorchScript mode")
            # Trace with a dummy batch of all-ones token ids and a full attention mask.
            fake_input_id = torch.LongTensor(args.batch_size, args.max_seq_length)
            fake_input_id.fill_(1)
            fake_input_id = fake_input_id.to(args.device)
            fake_mask = torch.ones(args.batch_size, args.max_seq_length).to(args.device)
            fake_type_id = fake_input_id.clone().detach()
            if args.data_type == 'fp16':
                fake_mask = fake_mask.half()
            model.eval()
            with torch.no_grad():
                print("********** input id and mask sizes **********", fake_input_id.size(), fake_mask.size())
                model_ = torch.jit.trace(model, (fake_input_id, fake_mask))
                model = model_
            torch.jit.save(model, "traced_model.pt")


if __name__ == "__main__":
    main()
```
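For reference, here is a minimal sketch of loading and running the traced model that this script produces. It is not part of the commit; the library path, batch shape, and fp16 mask are assumptions matching the defaults used above:

```python
import torch

# The custom FasterTransformer ops must be registered before deserializing a
# 'thsext' trace (path assumption: run from the FasterTransformer build directory).
torch.classes.load_library("./lib/libpyt_fastertransformer.so")

model = torch.jit.load("traced_model.pt").cuda().eval()

ids = torch.ones(1, 128, dtype=torch.long, device="cuda")  # dummy token ids
mask = torch.ones(1, 128, device="cuda").half()            # fp16 mask, matching the fp16 trace
with torch.no_grad():
    print(model(ids, mask))  # logits (classification) or start/end scores (QA)
```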
**`examples/FasterTransformer_HuggingFace_Bert/README.md`** (new file, 148 lines)
## Faster Transformer

Batch inferencing with Transformers faces two challenges:

- Large batch sizes suffer from higher latency, while with small or medium batch sizes execution becomes kernel launch latency bound.
- Padding wastes a lot of compute: an input of shape (batch_size, seq_length) has to be padded to (batch_size, max_length), and the gap between avg_length and max_length results in a considerable waste of computation. Increasing the batch size worsens this situation; see the sketch below for a rough estimate.
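As a back-of-the-envelope illustration (not part of the example), the fraction of encoder compute spent on padding is roughly the fraction of padded tokens in the batch, ignoring the quadratic attention term:

```python
def padding_waste(seq_lengths):
    """Rough fraction of per-token compute spent on padding tokens
    when a batch is padded to its longest sequence."""
    max_length = max(seq_lengths)
    real_tokens = sum(seq_lengths)
    padded_tokens = max_length * len(seq_lengths)
    return 1.0 - real_tokens / padded_tokens

# A batch whose lengths range from 20 to 128 tokens wastes about half its compute:
print(padding_waste([20, 45, 60, 128]))  # ~0.51
```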
[Faster Transformer](https://github.com/NVIDIA/FasterTransformer/blob/main/sample/pytorch/run_glue.py) (FT) from NVIDIA, together with [Effective Transformer](https://github.com/bytedance/effective_transformer) (EFFT), which is built on top of FT, addresses both challenges by fusing CUDA kernels and dynamically removing padding during computation. The current implementation of FT supports BERT-like encoder and decoder layers. In this example, we show how to get a TorchScripted (traced) EFFT variant of HuggingFace (HF) Bert models for sequence classification and question answering, and how to serve it.
### How to get a TorchScripted (traced) EFFT version of an HF Bert model and serve it

**Requirements**

Running Faster Transformer is currently recommended through the [NVIDIA docker and NGC containers](https://github.com/NVIDIA/FasterTransformer#requirements); it also requires a [Volta](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/), [Turing](https://www.nvidia.com/en-us/geforce/turing/), or [Ampere](https://www.nvidia.com/en-us/data-center/nvidia-ampere-gpu-architecture/) based GPU. For this example we used a **g4dn.2xlarge** EC2 instance, which has a T4 GPU.
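You can confirm the GPU architecture from Python; Volta is SM 7.0, Turing is SM 7.5, and Ampere is SM 8.x (a quick check, not part of the example):

```python
import torch

# Prints the device name and compute capability, e.g. "Tesla T4 SM 7.5".
major, minor = torch.cuda.get_device_capability(0)
print(torch.cuda.get_device_name(0), f"SM {major}.{minor}")
```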
**Set up a GPU machine that meets the requirements and connect to it.**
```bash
# Sign up for NGC (https://ngc.nvidia.com) and get an API key
docker login nvcr.io
# Username: $oauthtoken
# Password: <your NGC API key>

docker pull nvcr.io/nvidia/pytorch:20.12-py3

nvidia-docker run -ti --gpus all --rm nvcr.io/nvidia/pytorch:20.12-py3 bash

git clone https://github.com/NVIDIA/FasterTransformer.git

cd FasterTransformer

mkdir -p build

cd build

# -DSM selects the GPU architecture: 60 (P40), 61 (P4), 70 (V100), 75 (T4), 80 (A100)
cmake -DSM=75 -DCMAKE_BUILD_TYPE=Release -DBUILD_PYT=ON ..

make

pip install transformers==2.5.1

cd /workspace

# clone TorchServe to access the examples
git clone https://github.com/pytorch/serve.git

# install TorchServe
cd serve

python ts_scripts/install_dependencies.py --cuda=cu102

pip install torchserve torch-model-archiver torch-workflow-archiver

# copy the tracing script into the FasterTransformer build
cp examples/FasterTransformer_HuggingFace_Bert/Bert_FT_trace.py /workspace/FasterTransformer/build/pytorch
```
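Before tracing, it is worth sanity-checking that the TorchScript extension built correctly (a quick check run from /workspace/FasterTransformer/build, not part of the example):

```python
import torch

# Registers the custom FasterTransformer ops; raises if the build is broken.
torch.classes.load_library("./lib/libpyt_fastertransformer.so")
print("libpyt_fastertransformer.so loaded successfully")
```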
Now we are ready to produce the TorchScripted file. As mentioned at the beginning, two modes are supported: Bert for sequence classification and Bert for question answering. For this step we first need to download the model weights, in the same way as in the [HuggingFace example](https://github.com/pytorch/serve/tree/master/examples/Huggingface_Transformers).

#### Sequence classification EFFT traced model and serving
```bash
# Sequence classification
python ../Huggingface_Transformers/Download_Transformer_models.py

# This will download the model weights into the ../Huggingface_Transformers/Transformer_model directory

cd /workspace/FasterTransformer/build/

# This will generate the traced model "traced_model.pt"
# --data_type can be fp16 or fp32
python pytorch/Bert_FT_trace.py --mode sequence_classification --model_name_or_path "/workspace/serve/examples/Huggingface_Transformers/Transformer_model" --tokenizer_name "bert-base-uncased" --batch_size 1 --data_type fp16 --model_type thsext

cd -

# make sure to set "save_mode":"torchscript" and "FasterTransformer":true in
# ../Huggingface_Transformers/setup_config.json:
{
    "model_name":"bert-base-uncased",
    "mode":"sequence_classification",
    "do_lower_case":true,
    "num_labels":"0",
    "save_mode":"torchscript",
    "max_length":"128",
    "captum_explanation":false,
    "embedding_name": "bert",
    "FasterTransformer":true
}

torch-model-archiver --model-name BERTSeqClassification --version 1.0 --serialized-file /workspace/FasterTransformer/build/traced_model.pt --handler ../Huggingface_Transformers/Transformer_handler_generalized.py --extra-files "../Huggingface_Transformers/setup_config.json,../Huggingface_Transformers/Seq_classification_artifacts/index_to_name.json,/workspace/FasterTransformer/build/lib/libpyt_fastertransformer.so"

mkdir model_store

mv BERTSeqClassification.mar model_store/

torchserve --start --model-store model_store --models my_tc=BERTSeqClassification.mar --ncs

curl -X POST http://127.0.0.1:8080/predictions/my_tc -T ../Huggingface_Transformers/Seq_classification_artifacts/sample_text_captum_input.txt
```
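The same prediction request can also be sent from Python instead of curl (a sketch; assumes the server is running locally with the model registered as my_tc):

```python
import requests

# POST the sample input file to the TorchServe inference endpoint.
with open("../Huggingface_Transformers/Seq_classification_artifacts/sample_text_captum_input.txt", "rb") as f:
    resp = requests.post("http://127.0.0.1:8080/predictions/my_tc", data=f)
print(resp.status_code, resp.text)
```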
#### Question answering EFFT traced model and serving
```bash
# Question answering

# first change ../Huggingface_Transformers/setup_config.json as follows to download the pretrained weights:
{
    "model_name":"bert-base-uncased",
    "mode":"question_answering",
    "do_lower_case":true,
    "num_labels":"0",
    "save_mode":"pretrained",
    "max_length":"128",
    "captum_explanation":false,
    "embedding_name": "bert",
    "FasterTransformer":true
}

python ../Huggingface_Transformers/Download_Transformer_models.py

# This will download the model weights into the ../Huggingface_Transformers/Transformer_model directory

cd /workspace/FasterTransformer/build/

# This will generate the traced model "traced_model.pt"
# --data_type can be fp16 or fp32
python pytorch/Bert_FT_trace.py --mode question_answering --model_name_or_path "/workspace/serve/examples/Huggingface_Transformers/Transformer_model" --tokenizer_name "bert-base-uncased" --batch_size 1 --data_type fp16 --model_type thsext

cd -

# make sure to change "save_mode" to "torchscript" in ../Huggingface_Transformers/setup_config.json before archiving

torch-model-archiver --model-name BERTQA --version 1.0 --serialized-file /workspace/FasterTransformer/build/traced_model.pt --handler ../Huggingface_Transformers/Transformer_handler_generalized.py --extra-files "../Huggingface_Transformers/setup_config.json,/workspace/FasterTransformer/build/lib/libpyt_fastertransformer.so"

mkdir model_store

mv BERTQA.mar model_store/

torchserve --start --model-store model_store --models my_tc=BERTQA.mar --ncs

curl -X POST http://127.0.0.1:8080/predictions/my_tc -T ../Huggingface_Transformers/QA_artifacts/sample_text_captum_input.txt
```
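To verify the deployment, you can also query the TorchServe management API (a sketch; assumes the default management port 8081):

```python
import requests

# List registered models and describe the one we just registered.
print(requests.get("http://127.0.0.1:8081/models").json())
print(requests.get("http://127.0.0.1:8081/models/my_tc").json())
```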