In [2]:
import torch
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification
import onnxruntime as ort
from onnxruntime_tools import optimizer
import argparse
import pandas as pd
import numpy as np
# from transformers.convert_graph_to_onnx import convert
import os
import time
import torch.nn.functional as F
import onnx

In [6]:

def preprocess(tokenizer, text, max_seq_length=128):
    """Tokenize ``text`` into fixed-length single-sentence BERT inputs.

    The sequence is framed as ``[CLS] tokens [SEP]``, truncated so the framed
    sequence never exceeds ``max_seq_length``, then right-padded with id 0.

    Args:
        tokenizer: object exposing ``tokenize(str) -> list[str]`` and
            ``convert_tokens_to_ids(list[str]) -> list[int]`` (e.g. a
            ``BertTokenizer``).
        text: raw input string.
        max_seq_length: total sequence length including ``[CLS]``/``[SEP]``;
            must be at least 2. Defaults to 128 (the previous hard-coded value).

    Returns:
        Tuple ``(input_ids, input_mask, segment_ids)`` of ``torch.long``
        tensors, each of shape ``(1, max_seq_length)``.

    Raises:
        ValueError: if ``max_seq_length`` is smaller than 2.
    """
    if max_seq_length < 2:
        raise ValueError("max_seq_length must be at least 2 to fit [CLS] and [SEP]")

    # Reserve two positions for [CLS]/[SEP]. Without truncation, long texts
    # produced tensors longer than max_seq_length.
    tokens = ["[CLS]"] + tokenizer.tokenize(text)[: max_seq_length - 2] + ["[SEP]"]
    input_ids = list(tokenizer.convert_tokens_to_ids(tokens))
    # 1 marks real tokens, 0 marks padding.
    input_mask = [1] * len(input_ids)

    padding = max_seq_length - len(input_ids)
    input_ids += [0] * padding
    input_mask += [0] * padding
    # Single-sentence input: every position belongs to segment 0.
    segment_ids = [0] * max_seq_length

    return (
        torch.tensor([input_ids], dtype=torch.long),
        torch.tensor([input_mask], dtype=torch.long),
        torch.tensor([segment_ids], dtype=torch.long),
    )



def convert_bert_to_onnx(text, model_dir, onnx_model_path):
    """Export a ``BertForSequenceClassification`` checkpoint to ONNX.

    ``text`` is only used to build the example inputs that
    ``torch.onnx.export`` traces. Batch and sequence dimensions of all
    inputs, and the batch dimension of the output, are declared dynamic.

    Args:
        text: sample sentence used to trace the model.
        model_dir: directory loadable by ``BertConfig``/``BertTokenizer``/
            ``BertForSequenceClassification.from_pretrained``.
        onnx_model_path: destination file for the exported ONNX graph.
    """
    config = BertConfig.from_pretrained(model_dir)
    tokenizer = BertTokenizer.from_pretrained(model_dir)
    model = BertForSequenceClassification.from_pretrained(model_dir, config=config)
    model.to("cpu")
    # Trace in inference mode so dropout is not active in the exported graph.
    model.eval()
    input_ids, input_mask, segment_ids = preprocess(tokenizer, text)

    input_names = ["input_ids", "input_mask", "segment_ids"]
    output_names = ["output"]
    # dynamic_axes keys must match input/output names exactly. The previous
    # keys 'input_id'/'sequence_id' were not valid names, so torch.onnx.export
    # ignored them (with only a warning) and those inputs stayed fixed-shape.
    # Values map axis index -> symbolic dimension name.
    dynamic_axes = {name: {0: 'batch_size', 1: 'sequence'} for name in input_names}
    dynamic_axes['output'] = {0: 'batch_size'}

    print('starting export')
    torch.onnx.export(model, (input_ids, input_mask, segment_ids), onnx_model_path,
                      input_names=input_names, output_names=output_names,
                      opset_version=10, do_constant_folding=True,
                      dynamic_axes=dynamic_axes, verbose=False)

    print("SST model convert to onnx format successfully")


In [8]:
# Input/output locations for the ONNX export below.
args = dict(
    eval_data_path='data/jun3_10Klabels/data_binary_pos_neg_balanced/val_is_hired_1mo.csv',
    model_dir='/scratch/da2734/twitter/jobs/training_binary/simple_transformers_manu_bertbase/is_hired_1mo/',
    onnx_model_path='/scratch/da2734/twitter/jobs/dhaval_test.onnx',
)

args['model_dir'], args['onnx_model_path']

('/scratch/da2734/twitter/jobs/training_binary/simple_transformers_manu_bertbase/is_hired_1mo/',
 '/scratch/da2734/twitter/jobs/dhaval_test.onnx')

In [9]:
# Export the checkpoint in args['model_dir'] to args['onnx_model_path'],
# tracing with a short sample sentence.
convert_bert_to_onnx(
    text='tick tock',
    model_dir=args['model_dir'],
    onnx_model_path=args['onnx_model_path'],
)


starting export
SST model convert to onnx format successfully
