Merge branch 'master' into master
msaroufim committed Sep 28, 2021
2 parents 6d20feb + 89ccbf6 commit 7e34851
Showing 4 changed files with 313 additions and 1 deletion.
158 changes: 158 additions & 0 deletions examples/FasterTransformer_HuggingFace_Bert/Bert_FT_trace.py
@@ -0,0 +1,158 @@
# This script is adapted from the run_glue example of NVIDIA FasterTransformer, https://github.com/NVIDIA/FasterTransformer/blob/main/sample/pytorch/run_glue.py

import argparse
import logging
import os
import random
import timeit

import numpy as np
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
from tqdm import tqdm, trange

from transformers import (
BertConfig,
BertTokenizer,
)
from utils.modeling_bert import BertForSequenceClassification, BertForQuestionAnswering
from transformers import glue_compute_metrics as compute_metrics
from transformers import glue_convert_examples_to_features as convert_examples_to_features
from transformers import glue_output_modes as output_modes
from transformers import glue_processors as processors


logger = logging.getLogger(__name__)


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name",
    )

    parser.add_argument(
        "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name",
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.",
    )
    parser.add_argument("--mode", default="sequence_classification", help="Set the model for sequence classification or question answering")
    parser.add_argument(
        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.",
    )

    parser.add_argument(
        "--batch_size", default=8, type=int, help="Batch size for tracing.",
    )

    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    # parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    parser.add_argument("--model_type", type=str, help="ori, ths, thsext")
    parser.add_argument("--data_type", type=str, help="fp32, fp16")
    parser.add_argument('--ths_path', type=str, default='./lib/libpyt_fastertransformer.so',
                        help='path of the pyt_fastertransformer dynamic lib file')
    parser.add_argument('--remove_padding', action='store_false',
                        help='Remove the padding of sentences of encoder.')
    parser.add_argument('--allow_gemm_test', action='store_false',
                        help='Whether to allow the GEMM test inside FasterTransformer.')

    args = parser.parse_args()

    if torch.cuda.is_available():
        device = torch.device("cuda")
        args.device = device
    else:
        # FasterTransformer itself requires a GPU; this fallback only keeps the script from failing early
        args.device = torch.device("cpu")

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.device else logging.WARN,
    )

    # Set seed
    set_seed(args)

    tokenizer = BertTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )

    logger.info("Parameters %s", args)

    checkpoints = [args.model_name_or_path]
    for checkpoint in checkpoints:
        use_ths = args.model_type.startswith('ths')
        if args.mode == "sequence_classification":
            model = BertForSequenceClassification.from_pretrained(checkpoint, torchscript=use_ths)
        elif args.mode == "question_answering":
            model = BertForQuestionAnswering.from_pretrained(checkpoint, torchscript=use_ths)
        model.to(args.device)

        if args.data_type == 'fp16':
            logger.info("Use fp16")
            model.half()
        if args.model_type == 'thsext':
            logger.info("Use custom BERT encoder for TorchScript")
            from utils.encoder import EncoderWeights, CustomEncoder
            weights = EncoderWeights(
                model.config.num_hidden_layers, model.config.hidden_size,
                torch.load(os.path.join(checkpoint, 'pytorch_model.bin'), map_location='cpu'))
            weights.to_cuda()
            if args.data_type == 'fp16':
                weights.to_half()
            enc = CustomEncoder(model.config.num_hidden_layers,
                                model.config.num_attention_heads,
                                model.config.hidden_size // model.config.num_attention_heads,
                                weights,
                                remove_padding=args.remove_padding,
                                allow_gemm_test=args.allow_gemm_test,
                                path=os.path.abspath(args.ths_path))
            enc_ = torch.jit.script(enc)
            model.replace_encoder(enc_)
        if use_ths:
            logger.info("Use TorchScript mode")
            fake_input_id = torch.LongTensor(args.batch_size, args.max_seq_length)
            fake_input_id.fill_(1)
            fake_input_id = fake_input_id.to(args.device)
            fake_mask = torch.ones(args.batch_size, args.max_seq_length).to(args.device)
            fake_type_id = fake_input_id.clone().detach()
            if args.data_type == 'fp16':
                fake_mask = fake_mask.half()
            model.eval()
            with torch.no_grad():
                print("********** input id and mask sizes ******", fake_input_id.size(), fake_mask.size())
                model_ = torch.jit.trace(model, (fake_input_id, fake_mask))
                model = model_
            torch.jit.save(model, "traced_model.pt")


if __name__ == "__main__":
    main()
148 changes: 148 additions & 0 deletions examples/FasterTransformer_HuggingFace_Bert/README.md
@@ -0,0 +1,148 @@
## Faster Transformer

Batch inferencing with Transformers faces two challenges:

- Large batch sizes suffer from higher latency, while small or medium-sized batches become bound by kernel launch latency.
- Padding wastes a lot of compute: a batch of shape (batchsize, seq_length) has to be padded to (batchsize, max_length), and the gap between avg_length and max_length results in a considerable waste of computation; increasing the batch size makes this worse, as estimated in the sketch below.
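
The padding cost is easy to estimate; the sketch below uses made-up sequence lengths to show how much of a padded batch is wasted compute.

```python
# Back-of-the-envelope estimate of padding waste for one batch (hypothetical lengths)
import numpy as np

seq_lengths = np.array([23, 41, 67, 15, 128, 52, 34, 88])          # tokens per sequence
padded_tokens = seq_lengths.max() * len(seq_lengths)               # compute done on (batchsize, max_length)
useful_tokens = seq_lengths.sum()                                   # tokens that actually carry information
print(f"wasted compute: {1 - useful_tokens / padded_tokens:.0%}")  # ~56% for this batch
```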

[Faster Transformer](https://github.com/NVIDIA/FasterTransformer/blob/main/sample/pytorch/run_glue.py) (FT) from Nvidia, together with [Effective Transformer](https://github.com/bytedance/effective_transformer) (EFFT) which is built on top of FT, addresses both challenges by fusing CUDA kernels and by dynamically removing padding during computation. The current [Faster Transformer](https://github.com/NVIDIA/FasterTransformer/blob/main/sample/pytorch/run_glue.py) implementation supports BERT-like encoder and decoder layers. In this example, we show how to get a TorchScripted (traced) EFFT variant of a HuggingFace (HF) BERT model for sequence classification or question answering and how to serve it.


### How to get a TorchScripted (traced) EFFT version of an HF BERT model and serve it

**Requirements**

Running Faster Transformer at this point is recommended through [NVIDIA docker and NGC containers](https://github.com/NVIDIA/FasterTransformer#requirements), and it requires a [Volta](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/), [Turing](https://www.nvidia.com/en-us/geforce/turing/), or [Ampere](https://www.nvidia.com/en-us/data-center/nvidia-ampere-gpu-architecture/) based GPU. For this example we used a **g4dn.2xlarge** EC2 instance, which has a T4 GPU.
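
The `-DSM` value passed to cmake in the setup below must match your GPU's compute capability. If you are unsure which value to use, a quick sketch (assuming a CUDA-enabled PyTorch install) prints it:

```python
# Print the SM value to pass to cmake's -DSM flag (e.g. 75 on a T4)
import torch

major, minor = torch.cuda.get_device_capability(0)
print(f"-DSM={major}{minor}")
```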

**Set up a GPU machine that meets the requirements and connect to it.**

```bash
### Sign up for NGC https://ngc.nvidia.com and get API key###
docker login nvcr.io
Username: $oauthtoken
Password: API key

docker pull nvcr.io/nvidia/pytorch:20.12-py3

nvidia-docker run -ti --gpus all --rm nvcr.io/nvidia/pytorch:20.12-py3 bash

git clone https://github.com/NVIDIA/FasterTransformer.git

cd FasterTransformer

mkdir -p build

cd build

cmake -DSM=75 -DCMAKE_BUILD_TYPE=Release -DBUILD_PYT=ON ..  # -DSM: 60 (P40), 61 (P4), 70 (V100), 75 (T4), 80 (A100)

make

pip install transformers==2.5.1

cd /workspace

# clone Torchserve to access examples
git clone https://github.com/pytorch/serve.git

# install torchserve
cd serve

python ts_scripts/install_dependencies.py --cuda=cu102

pip install torchserve torch-model-archiver torch-workflow-archiver

cp /workspace/serve/examples/FasterTransformer_HuggingFace_Bert/Bert_FT_trace.py /workspace/FasterTransformer/build/pytorch


```

Now we are ready to create the TorchScripted file. As mentioned at the beginning, two modes are supported: BERT for sequence classification and BERT for question answering. For this step we need to download the model weights, which we do the same way as in the [HuggingFace example](https://github.com/pytorch/serve/tree/master/examples/Huggingface_Transformers).

#### Sequence classification EFFT Traced model and serving

```bash
# Sequence classification
python ../Huggingface_Transformers/Download_Transformer_models.py

# This will download the model weights into the ../Huggingface_Transformers/Transformer_model directory

cd /workspace/FasterTransformer/build/

# This will generate the Traced model "traced_model.pt"
# --data_type can be fp16 or fp32
python pytorch/Bert_FT_trace.py --mode sequence_classification --model_name_or_path "/workspace/serve/examples/Huggingface_Transformers/Transformer_model" --tokenizer_name "bert-base-uncased" --batch_size 1 --data_type fp16 --model_type thsext

cd -

# Make sure ../Huggingface_Transformers/setup_config.json has "save_mode":"torchscript" and "FasterTransformer":true

# For sequence classification, ../Huggingface_Transformers/setup_config.json should look like:
{
"model_name":"bert-base-uncased",
"mode":"sequence_classification",
"do_lower_case":true,
"num_labels":"0",
"save_mode":"torchscript",
"max_length":"128",
"captum_explanation":false,
"embedding_name": "bert",
"FasterTransformer":true
}

torch-model-archiver --model-name BERTSeqClassification --version 1.0 --serialized-file /workspace/FasterTransformer/build/traced_model.pt --handler ../Huggingface_Transformers/Transformer_handler_generalized.py --extra-files "../Huggingface_Transformers/setup_config.json,../Huggingface_Transformers/Seq_classification_artifacts/index_to_name.json,/workspace/FasterTransformer/build/lib/libpyt_fastertransformer.so"

mkdir model_store

mv BERTSeqClassification.mar model_store/

torchserve --start --model-store model_store --models my_tc=BERTSeqClassification.mar --ncs

curl -X POST http://127.0.0.1:8080/predictions/my_tc -T ../Huggingface_Transformers/Seq_classification_artifacts/sample_text_captum_input.txt

```
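
Before archiving, you can optionally sanity check the traced model. The snippet below is only a sketch; it assumes the fp16 `thsext` trace produced above and is run from `/workspace/FasterTransformer/build`:

```python
# Minimal sanity check for traced_model.pt (sketch; paths assume /workspace/FasterTransformer/build)
import torch

# The custom FasterTransformer ops must be registered before loading the traced graph
torch.classes.load_library("lib/libpyt_fastertransformer.so")

model = torch.jit.load("traced_model.pt").cuda().eval()
input_ids = torch.ones(1, 128, dtype=torch.long, device="cuda")  # matches --batch_size 1, --max_seq_length 128
attention_mask = torch.ones(1, 128, device="cuda").half()        # the fp16 trace expects a half-precision mask
with torch.no_grad():
    print(model(input_ids, attention_mask))
```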

#### Question answering EFFT Traced model and serving

```bash
# Question answering

# Set ../Huggingface_Transformers/setup_config.json as follows before downloading the weights:
{
"model_name":"bert-base-uncased",
"mode":"question_answering",
"do_lower_case":true,
"num_labels":"0",
"save_mode":"pretrained",
"max_length":"128",
"captum_explanation":false,
"embedding_name": "bert",
"FasterTransformer":true
}
python ../Huggingface_Transformers/Download_Transformer_models.py

# This will download the model weights into the ../Huggingface_Transformers/Transformer_model directory

cd /workspace/FasterTransformer/build/

# This will generate the Traced model "traced_model.pt"
# --data_type can be fp16 or fp32
python pytorch/Bert_FT_trace.py --mode question_answering --model_name_or_path "/workspace/serve/examples/Huggingface_Transformers/Transformer_model" --tokenizer_name "bert-base-uncased" --batch_size 1 --data_type fp16 --model_type thsext

cd -

# Make sure to change "save_mode" to "torchscript" in ../Huggingface_Transformers/setup_config.json before archiving

torch-model-archiver --model-name BERTQA --version 1.0 --serialized-file /workspace/FasterTransformer/build/traced_model.pt --handler ../Huggingface_Transformers/Transformer_handler_generalized.py --extra-files "../Huggingface_Transformers/setup_config.json,/workspace/FasterTransformer/build/lib/libpyt_fastertransformer.so"

mkdir model_store

mv BERTQA.mar model_store/

torchserve --start --model-store model_store --models my_tc=BERTQA.mar --ncs

curl -X POST http://127.0.0.1:8080/predictions/my_tc -T ../Huggingface_Transformers/QA_artifacts/sample_text_captum_input.txt

```
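
The same request can also be sent from Python instead of curl, for example:

```python
# Equivalent of the curl call above, using the requests library
import requests

with open("../Huggingface_Transformers/QA_artifacts/sample_text_captum_input.txt", "rb") as f:
    response = requests.post("http://127.0.0.1:8080/predictions/my_tc", data=f.read())
print(response.status_code, response.text)
```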

5 changes: 5 additions & 0 deletions examples/Huggingface_Transformers/Transformer_handler_generalized.py
@@ -52,6 +52,11 @@ def initialize(self, ctx):
        else:
            logger.warning("Missing the setup_config.json file.")

        # Load the shared object of the compiled FasterTransformer library if FasterTransformer is enabled
        if self.setup_config["FasterTransformer"]:
            faster_transformer_compiled_path = os.path.join(model_dir, "libpyt_fastertransformer.so")
            torch.classes.load_library(faster_transformer_compiled_path)

        # Loading the model and tokenizer from checkpoint and config files based on the user's choice of mode
        # further setup config can be added.
        if self.setup_config["save_mode"] == "torchscript":
3 changes: 2 additions & 1 deletion examples/Huggingface_Transformers/setup_config.json
@@ -6,5 +6,6 @@
"save_mode":"pretrained",
"max_length":"150",
"captum_explanation":true,
"embedding_name": "bert"
"embedding_name": "bert",
"FasterTransformer":false
}
