In [1]:
import torch
import numpy as np
import tensorrt as trt
import onnx

from transformers import BertTokenizer, BertForSequenceClassification

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
task_name = "mnli" # cola, mnli, qnli, qqp
batch_size = 32
max_length = 128
onnx_filename = 'bert-base.onnx'
tensorrt_file_name = 'bert-base.plan'

In [3]:
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
saved_path = f'../ignore/task/bert-base_{task_name}.pt'


task = {
    "qnli":{
        "num_labels": 2,
        "test_dataset_name": "validation",
        "tokenize": lambda data:tokenizer(data['question'], data['sentence'], truncation=True, max_length=max_length, padding='max_length')
    },
    "mnli":{
        "num_labels": 3,
        "test_dataset_name": "validation_matched",
        "tokenize": lambda data:tokenizer(data['premise'], data['hypothesis'], truncation=True, max_length=max_length, padding='max_length')
    },
    "qqp":{
        "num_labels": 2,
        "test_dataset_name": "validation",
        "tokenize": lambda data:tokenizer(data['question1'], data['question2'], truncation=True, max_length=max_length, padding='max_length')
    },
    "cola":{
        "num_labels": 2,
        "test_dataset_name": "validation",
        "tokenize": lambda data:tokenizer(data['sentence'], truncation=True, max_length=max_length, padding='max_length')
    }
}
task = task[task_name]
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=task["num_labels"])

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
model = torch.load(saved_path)
model.cuda()
model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

convert model into onnx

In [5]:
input_ids = torch.ones((batch_size, max_length), dtype=torch.long).cuda()
attention_mask = torch.ones((batch_size, max_length), dtype=torch.long).cuda()
token_type_ids = torch.ones((batch_size, max_length), dtype=torch.long).cuda()
 
torch.onnx.export(
    model,
    (input_ids, attention_mask, token_type_ids),
    onnx_filename,
    input_names=['input_ids', 'attention_mask', 'token_type_ids'],
    output_names=['outputs'],
    export_params=True
)

verbose: False, log level: Level.ERROR



convert onnx into tensorRT

In [6]:
onnx_model = onnx.load(onnx_filename)

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(EXPLICIT_BATCH)
config = builder.create_builder_config()
parser = trt.OnnxParser(network, TRT_LOGGER)

config.max_workspace_size = (1 << 30)
# config.in

 
with open(onnx_filename, 'rb') as model:
    if not parser.parse(model.read()):
        for error in range(parser.num_errors):
            print (parser.get_error(error))
 
engine = builder.build_engine(network, config)
buf = engine.serialize()
with open(tensorrt_file_name, 'wb') as f:
    f.write(buf)


[09/13/2023-12:02:11] [TRT] [W] parsers/onnx/onnx2trt_utils.cpp:368: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.


  engine = builder.build_engine(network, config)
