Copyright (c) Microsoft Corporation. All rights reserved.  
Licensed under the MIT License.

# Inference PyTorch GPT-Neo Model with ONNX Runtime on CPU

In this tutorial, you will learn how to load a GPT-Neo model from PyTorch, convert it to ONNX, and run inference with ONNX Runtime using IO Binding. Past state is used to improve performance.

## Prerequisites ##

If you have Jupyter Notebook, you can run this notebook directly. We will use pip to install or upgrade [PyTorch](https://pytorch.org/), [OnnxRuntime](https://microsoft.github.io/onnxruntime/) and other required packages.

Otherwise, you can set up a new environment. First, install [Anaconda](https://www.anaconda.com/distribution/). Then open an Anaconda prompt window and run the following commands:

```console
conda create -n cpu_env python=3.8
conda activate cpu_env
conda install jupyter
jupyter notebook
```
The last command launches Jupyter Notebook, and we can open this notebook in a browser to continue.

In [1]:
import os
import numpy as np
import torch
import onnxruntime as ort

In [2]:
!pip install coloredlogs


## Convert GPT-Neo model from PyTorch to ONNX ##

We provide a script, [convert_to_onnx.py](https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/python/tools/transformers/convert_to_onnx.py), that can help you convert GPT-2 with past state to ONNX.

The script accepts a pretrained model name or the path of a checkpoint directory as input, and converts the model to ONNX. It also verifies that the ONNX model generates the same output as the PyTorch model. The usage is:
```
python -m onnxruntime.transformers.convert_to_onnx -m model_name_or_path --output gpt2.onnx -o -p fp32|fp16|int8
```
The -p option chooses the precision: fp32 (float32), fp16 (mixed precision) or int8 (quantization). The -o option generates an optimized model, which is required for fp16 or int8.

Here we use a pretrained model as an example:
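
As a rough illustration of how these options combine, the sketch below assembles the command line as an argument list. `build_convert_command` is a hypothetical helper written for this notebook, not part of ONNX Runtime; it only uses the flags shown above (`-m`, `--output`, `-o`, `-p`).

```python
import sys

def build_convert_command(model_name_or_path, output_path, precision="fp32", optimize=None):
    """Assemble the convert_to_onnx command line described above (hypothetical helper)."""
    if precision not in ("fp32", "fp16", "int8"):
        raise ValueError(f"unsupported precision: {precision}")
    # The text above states that -o is required for fp16 and int8,
    # so it is switched on by default for those precisions.
    if optimize is None:
        optimize = precision in ("fp16", "int8")
    if precision in ("fp16", "int8") and not optimize:
        raise ValueError("-o is required for fp16 or int8")
    cmd = [sys.executable, "-m", "onnxruntime.transformers.convert_to_onnx",
           "-m", model_name_or_path, "--output", output_path, "-p", precision]
    if optimize:
        cmd.append("-o")
    return cmd

cmd = build_convert_command("gpt2", "gpt2.onnx", precision="int8")
print(" ".join(cmd[1:]))
```

The resulting list can be passed to `subprocess.run(cmd, check=True)` to run the conversion from a script instead of a shell.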

In [3]:
num_attention_heads = 12
hidden_size = 768
num_layer = 12
model_name_or_path = "EleutherAI/gpt-neo-125M"
device = torch.device("cpu")
ort_session = ort.InferenceSession("models/onnx/gpt-neo-default-past-lm.onnx")
for input in ort_session.get_inputs():
    print(input)
print("Outputs:")
for output in ort_session.get_outputs():
    print(output)

NodeArg(name='input_ids', type='tensor(int64)', shape=['batch', 'sequence'])
NodeArg(name='past_key_values.0.key', type='tensor(float)', shape=['batch', 12, 'past_sequence', 64])
NodeArg(name='past_key_values.0.value', type='tensor(float)', shape=['batch', 12, 'past_sequence', 64])
NodeArg(name='past_key_values.1.key', type='tensor(float)', shape=['batch', 12, 'past_sequence', 64])
NodeArg(name='past_key_values.1.value', type='tensor(float)', shape=['batch', 12, 'past_sequence', 64])
NodeArg(name='past_key_values.2.key', type='tensor(float)', shape=['batch', 12, 'past_sequence', 64])
NodeArg(name='past_key_values.2.value', type='tensor(float)', shape=['batch', 12, 'past_sequence', 64])
NodeArg(name='past_key_values.3.key', type='tensor(float)', shape=['batch', 12, 'past_sequence', 64])
NodeArg(name='past_key_values.3.value', type='tensor(float)', shape=['batch', 12, 'past_sequence', 64])
NodeArg(name='past_key_values.4.key', type='tensor(float)', shape=['batch', 12, 'past_sequence', 64

## PyTorch Inference using Huggingface Transformers ##

In the following, we will use an example input to get the output from PyTorch for comparison.
For the first inference, there is no past state yet, so we prepare an empty state as input.
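
To make the idea of an empty past state concrete, here is a minimal NumPy sketch. It assumes the layout printed earlier for this model, `[batch, heads, past_sequence, head_dim]`, with one key and one value tensor per layer; `make_empty_past` is a hypothetical helper name.

```python
import numpy as np

num_layer = 12
num_attention_heads = 12
hidden_size = 768
head_dim = hidden_size // num_attention_heads  # 64 for this configuration

def make_empty_past(batch_size):
    """Return one zero-length key and one zero-length value tensor per layer."""
    shape = (batch_size, num_attention_heads, 0, head_dim)
    return [np.empty(shape, dtype=np.float32) for _ in range(num_layer * 2)]

past = make_empty_past(batch_size=1)

# Appending the keys of a 3-token prompt along the sequence axis (axis 2)
# grows the cache from length 0 to length 3.
new_keys = np.zeros((1, num_attention_heads, 3, head_dim), dtype=np.float32)
grown = np.concatenate([past[0], new_keys], axis=2)
print(past[0].shape, grown.shape)
```

A zero-length sequence axis lets the first step use the same input signature as every later step, so no special case is needed for the prompt.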

In [10]:
from transformers import AutoTokenizer

EXAMPLE_Text = ['Oh no what']

def get_tokenizer(model_name_or_path):
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    tokenizer.padding_side = "right"
    tokenizer.pad_token = tokenizer.eos_token
    # tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    return tokenizer

def get_example_inputs(prompt_text=EXAMPLE_Text):    
    tokenizer = get_tokenizer(model_name_or_path)
    encodings_dict = tokenizer.batch_encode_plus(prompt_text, padding=True)
    print(encodings_dict)
        
    input_ids = torch.tensor(encodings_dict['input_ids'], dtype=torch.int64)
    attention_mask = torch.tensor(encodings_dict['attention_mask'], dtype=torch.float32)
    position_ids = (attention_mask.long().cumsum(-1) - 1)
    position_ids.masked_fill_(position_ids < 0, 0)

    #Empty Past State for generating first word
    empty_past = []
    batch_size = input_ids.size(0)
    past_shape = [batch_size, 12, 0, 64]
    for i in range(num_layer * 2):
        empty_past.append(torch.empty(past_shape).type(torch.float32).to(device))
       
    return input_ids, attention_mask, position_ids, empty_past

input_ids, attention_mask, position_ids, empty_past = get_example_inputs()

{'input_ids': [[5812, 645, 644]], 'attention_mask': [[1, 1, 1]]}


In [11]:
inputs = {}
shape_name_mapping = {
    'sequence': input_ids.size(1),
    'batch': 1,
    'past_sequence': 0,
    'past_sequence + sequence': input_ids.size(1)
}
type_name_mapping = {
    'tensor(int64)': np.int64,
    'tensor(float)': np.float32
}
def map_shape(x):
    if type(x) is str:
        return shape_name_mapping[x]
    return x

def test_generation(tokenizer, num_tokens_to_produce = 30):
    eos_token_id = tokenizer.eos_token_id
    input_ids, attention_mask, position_ids, past = get_example_inputs(["In a world where"])
    batch_size = input_ids.size(0)
    has_eos = torch.zeros(batch_size, dtype=torch.bool)
    all_token_ids = input_ids.clone()

    for step in range(num_tokens_to_produce):
        inputs = {
            
        }
        for input in ort_session.get_inputs():
            processed_shape = list(map(map_shape, input.shape))
            inputs[input.name] = np.zeros(processed_shape, dtype = type_name_mapping[input.type])
        
        attention_mask = attention_mask.numpy()
        input_ids = input_ids.numpy()
        inputs['input_ids'] = input_ids
        for i in range(0, num_layer):
            inputs[f'past_key_values.{i}.key'] = np.ascontiguousarray(past[i * 2])
            inputs[f'past_key_values.{i}.value'] = np.ascontiguousarray(past[(i * 2) + 1])
        inputs['attention_mask'] = attention_mask
        outputs = ort_session.run(None, inputs)
        next_token_logits = outputs[0][:, -1, :]
        next_token_logits = torch.Tensor(next_token_logits)
        # Greedy approach is used here. You can easily extend it to use beam search and sampling to pick next tokens.
        next_tokens = torch.argmax(next_token_logits, dim=-1)
        has_eos = has_eos | (next_tokens == eos_token_id)
        tokens_to_add = next_tokens.masked_fill(has_eos, eos_token_id)
        all_token_ids = torch.cat([all_token_ids, tokens_to_add.unsqueeze(-1)], dim=-1)

        # Update input_ids, attention_mask, position_ids and past
        input_ids = tokens_to_add.clone().detach().reshape([batch_size, 1]).to(device)
        position_ids = (position_ids[:,-1] + 1).reshape(batch_size,1)
        # The mask must cover the cached tokens plus the new one, so extend it by one column.
        attention_mask = torch.Tensor(attention_mask)
        attention_mask = torch.cat([attention_mask, torch.ones([batch_size, 1]).type_as(attention_mask)], 1)

        past = []
        for i in range(num_layer * 2):
            past_i = outputs[i + 1]
            past.append(past_i)
        if torch.all(has_eos):
            break

    for i, output in enumerate(all_token_ids):
        print("------------")
        print(tokenizer.decode(output, skip_special_tokens=True))

tokenizer = get_tokenizer(model_name_or_path)
test_generation(tokenizer)


## ONNX Runtime Inference ##

We can use ONNX Runtime for inference. The inputs are a dictionary that maps each input name to a numpy array, and the output is a list of numpy arrays. Note that both inputs and outputs are on the CPU. When you run inference on a GPU, data is copied between CPU and GPU for the inputs and outputs.
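
Because the feed is a plain dictionary, it is easy to pass an array with the wrong dtype or rank. The sketch below checks a feed against input signatures written in the style printed by `get_inputs()` earlier. `check_feed` is a hypothetical helper, and only the first layer's past inputs are listed to keep the example short.

```python
import numpy as np

# (name, type, shape) triples; strings in a shape are symbolic dimensions.
signatures = [
    ("input_ids", "tensor(int64)", ["batch", "sequence"]),
    ("past_key_values.0.key", "tensor(float)", ["batch", 12, "past_sequence", 64]),
    ("past_key_values.0.value", "tensor(float)", ["batch", 12, "past_sequence", 64]),
]
dtype_of = {"tensor(int64)": np.int64, "tensor(float)": np.float32}

def check_feed(feed, signatures):
    """Return a list of problems; an empty list means the feed looks consistent."""
    problems = []
    for name, type_name, shape in signatures:
        if name not in feed:
            problems.append(f"missing input: {name}")
            continue
        array = feed[name]
        if array.dtype != dtype_of[type_name]:
            problems.append(f"{name}: dtype {array.dtype}, expected {type_name}")
        if array.ndim != len(shape):
            problems.append(f"{name}: rank {array.ndim}, expected {len(shape)}")
            continue
        # Only fixed (integer) dimensions can be checked without running the model.
        for axis, dim in enumerate(shape):
            if isinstance(dim, int) and array.shape[axis] != dim:
                problems.append(f"{name}: axis {axis} is {array.shape[axis]}, expected {dim}")
    return problems

feed = {
    "input_ids": np.array([[5812, 645, 644]], dtype=np.int64),
    "past_key_values.0.key": np.empty((1, 12, 0, 64), dtype=np.float32),
    "past_key_values.0.value": np.empty((1, 12, 0, 64), dtype=np.float32),
}
print(check_feed(feed, signatures))
```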

Let's create an inference session for ONNX Runtime given the exported ONNX model, and see the output.

In [None]:
import onnxruntime
import numpy

input_ids, attention_mask, position_ids, empty_past = get_example_inputs()

onnx_model_path = "gpt2.onnx"
session = onnxruntime.InferenceSession(onnx_model_path)
ort_inputs = {'input_ids': numpy.ascontiguousarray(input_ids.cpu().numpy()),
              'attention_mask' : numpy.ascontiguousarray(attention_mask.cpu().numpy()),
              'position_ids': numpy.ascontiguousarray(position_ids.cpu().numpy())
             }
for i, past_i in enumerate(empty_past):
    ort_inputs[f'past_{i}'] = numpy.ascontiguousarray(past_i.cpu().numpy())
ort_outputs = session.run(None, ort_inputs)

We can compare the outputs from PyTorch and ONNX Runtime. Logits are very close (max difference is 1E-4).

In [None]:
logits_masked_diff = (torch_output[0] - ort_outputs[0]) * attention_mask.unsqueeze(2)
max_logits_diff = logits_masked_diff.abs().max()
print("max logits diff (ignored padding)", max_logits_diff)
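
The masked comparison above can be tried on its own with made-up data. This NumPy sketch assumes logits shaped `[batch, sequence, vocab]` and a mask shaped `[batch, sequence]`; the arrays are random placeholders, not real model outputs, and `max_masked_diff` is a hypothetical helper.

```python
import numpy as np

def max_masked_diff(reference, candidate, attention_mask):
    """Largest absolute logit difference over non-padding positions."""
    diff = (reference - candidate) * attention_mask[:, :, np.newaxis]
    return float(np.abs(diff).max())

rng = np.random.default_rng(0)
reference = rng.standard_normal((2, 4, 8)).astype(np.float32)  # [batch, sequence, vocab]
candidate = reference + np.float32(1e-5)                       # small deviation everywhere
candidate[1, 3, :] += 5.0                                      # large error at a padded position
attention_mask = np.array([[1, 1, 1, 1], [1, 1, 1, 0]], dtype=np.float32)

masked = max_masked_diff(reference, candidate, attention_mask)
unmasked = float(np.abs(reference - candidate).max())
print(masked, unmasked)
```

Multiplying by the mask zeroes the difference at padded positions, so a large but irrelevant error there does not affect the reported maximum.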

## ONNX Runtime Inference with IO Binding ##

To avoid copying input and output data, ONNX Runtime also supports IO Binding. The user can provide buffers for inputs and outputs. For GPU inference, the buffers can live on the GPU to reduce memory copies between CPU and GPU, which helps high-performance GPU inference. For GPT-2, IO Binding might improve performance when the batch size or (past) sequence length is large.
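
Preallocating output buffers requires knowing the output shapes before the run. The sketch below computes them for the past-state layout used by the GPT-Neo model in this notebook. The output names (`logits`, `present.{i}.key`, `present.{i}.value`), the helper name `sketch_output_shapes`, and the default vocabulary size are assumptions for illustration, not values read from the model.

```python
from math import prod

def sketch_output_shapes(batch_size, past_sequence_length, sequence_length,
                         num_layer=12, num_attention_heads=12,
                         hidden_size=768, vocab_size=50257):
    """Shapes of the logits and the updated key/value cache (hypothetical names)."""
    head_dim = hidden_size // num_attention_heads
    total_length = past_sequence_length + sequence_length
    shapes = {"logits": [batch_size, sequence_length, vocab_size]}
    for i in range(num_layer):
        cache_shape = [batch_size, num_attention_heads, total_length, head_dim]
        shapes[f"present.{i}.key"] = cache_shape
        shapes[f"present.{i}.value"] = list(cache_shape)
    return shapes

shapes = sketch_output_shapes(batch_size=1, past_sequence_length=0, sequence_length=3)
# Bytes needed if every output is stored as float32 (4 bytes per element).
buffer_bytes = sum(prod(shape) for shape in shapes.values()) * 4
print(len(shapes), shapes["logits"], shapes["present.0.key"], buffer_bytes)
```

The cache outputs grow with the past length, which is one reason binding reusable buffers becomes more attractive as the batch size or sequence length increases.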

In [None]:
def inference_with_io_binding(session, config, input_ids, position_ids, attention_mask, past):
    output_shapes = Gpt2Helper.get_output_shapes(batch_size=input_ids.size(0),
                                                 past_sequence_length=past[0].size(3),
                                                 sequence_length=input_ids.size(1),
                                                 config=config)
    output_buffers = Gpt2Helper.get_output_buffers(output_shapes, device)

    io_binding = Gpt2Helper.prepare_io_binding(session, input_ids, position_ids, attention_mask, past,
                                               output_buffers, output_shapes)
    session.run_with_iobinding(io_binding)

    outputs = Gpt2Helper.get_outputs_from_io_binding_buffer(session, output_buffers, output_shapes,
                                                            return_numpy=False)
    return outputs

We can see that the result is exactly the same with and without IO Binding:

In [None]:
input_ids, attention_mask, position_ids, empty_past = get_example_inputs()
outputs = inference_with_io_binding(session, config, input_ids, position_ids, attention_mask, empty_past)
for i in range(len(outputs)):
    assert torch.eq(outputs[i], torch.from_numpy(ort_outputs[i])).all()
print("IO Binding result is good")

## Batch Text Generation ##

Here is an example of text generation using ONNX Runtime or PyTorch. For ONNX Runtime, IO Binding is used for better performance.

In [None]:
def test_generation(tokenizer, input_text, ort_session=None, num_tokens_to_produce = 30):
    use_onnxruntime = (ort_session is not None)
    print("Text generation using", "OnnxRuntime" if use_onnxruntime else "PyTorch", "...")
    eos_token_id = tokenizer.eos_token_id
    
    input_ids, attention_mask, position_ids, past = get_example_inputs(input_text)
    batch_size = input_ids.size(0)

    has_eos = torch.zeros(batch_size, dtype=torch.bool)

    all_token_ids = input_ids.clone()

    for step in range(num_tokens_to_produce):
        if ort_session is not None:
            outputs = inference_with_io_binding(ort_session, config, input_ids, position_ids, attention_mask, past)
        else:
            outputs = torch_model(input_ids, attention_mask=attention_mask, position_ids=position_ids, past=past)  

        next_token_logits = outputs[0][:, -1, :]
        # Greedy approach is used here. You can easily extend it to use beam search and sampling to pick next tokens.
        next_tokens = torch.argmax(next_token_logits, dim=-1)

        has_eos = has_eos | (next_tokens == eos_token_id)
        tokens_to_add = next_tokens.masked_fill(has_eos, eos_token_id)
        all_token_ids = torch.cat([all_token_ids, tokens_to_add.unsqueeze(-1)], dim=-1)

        # Update input_ids, attention_mask, position_ids and past
        input_ids = tokens_to_add.clone().detach().reshape([batch_size, 1]).to(device)    
        position_ids = (position_ids[:,-1] + 1).reshape(batch_size,1)
        attention_mask = torch.cat([attention_mask, torch.ones([batch_size, 1]).type_as(attention_mask)], 1).to(device)    

        past = []
        if not use_onnxruntime:
            past = list(outputs[1]) # past in torch output is tuple
        else:
            for i in range(num_layer):
                past_i = torch.from_numpy(outputs[i + 1]) if isinstance(outputs[i + 1], numpy.ndarray) else outputs[i + 1].clone().detach()
                past.append(past_i.to(device))

        if torch.all(has_eos):
            break

    for i, output in enumerate(all_token_ids):
        print("------------")
        print(tokenizer.decode(output, skip_special_tokens=True))
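
The greedy loop with end-of-sequence handling can be illustrated without any model. In this NumPy sketch, `toy_logits` is a hypothetical stand-in that always prefers the next token id, and the past-state cache is omitted for brevity, so the whole sequence is passed on every step.

```python
import numpy as np

def greedy_generate(logits_fn, input_ids, eos_token_id, max_new_tokens):
    """Greedy decoding: append the argmax token each step; finished rows keep emitting EOS."""
    all_ids = input_ids.copy()
    has_eos = np.zeros(input_ids.shape[0], dtype=bool)
    for _ in range(max_new_tokens):
        next_logits = logits_fn(all_ids)               # [batch, vocab]
        next_tokens = next_logits.argmax(axis=-1)
        has_eos |= next_tokens == eos_token_id
        next_tokens = np.where(has_eos, eos_token_id, next_tokens)
        all_ids = np.concatenate([all_ids, next_tokens[:, None]], axis=1)
        if has_eos.all():
            break
    return all_ids

VOCAB, EOS = 5, 4

def toy_logits(ids):
    # Hypothetical model: the highest logit is at (last token + 1) % VOCAB.
    logits = np.zeros((ids.shape[0], VOCAB))
    logits[np.arange(ids.shape[0]), (ids[:, -1] + 1) % VOCAB] = 1.0
    return logits

result = greedy_generate(toy_logits, np.array([[0], [2]]), EOS, max_new_tokens=10)
print(result)
```

The second row reaches the end-of-sequence token first and is padded with it until the first row finishes, which mirrors the `has_eos` masking in the cell above.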

In [None]:
tokenizer = get_tokenizer(model_name_or_path)
input_text = EXAMPLE_Text
test_generation(tokenizer, input_text, ort_session=session)

Next, we run again with PyTorch and can see that the result is exactly the same.

In [None]:
test_generation(tokenizer, input_text)

## Int8 Quantization ##
Next, we will apply dynamic quantization to the model. We optimize the model before quantization to get better performance.

Note that text generation results from fp32 and int8 models can be quite different. You should evaluate the accuracy metrics that matter for your application on both the fp32 and int8 models. If the quality of the int8 results is acceptable, you will find that the int8 model is faster than the fp32 model in inference.

Note that you can leverage [quantization aware training (QAT)](https://pytorch.org/blog/introduction-to-quantization-on-pytorch/) for accuracy improvement if needed.
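
To show where the accuracy difference comes from, here is a minimal NumPy sketch of 8-bit quantization arithmetic. It uses a symmetric per-tensor scheme chosen for brevity; this is an assumption for illustration and not necessarily the scheme ONNX Runtime's quantizer applies. The weights are random placeholders.

```python
import numpy as np

def quantize_int8(x):
    """Symmetric per-tensor quantization: map [-max|x|, max|x|] onto [-127, 127]."""
    max_abs = float(np.abs(x).max())
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    """Recover an approximation of the original float values."""
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
weights = rng.standard_normal((768, 64)).astype(np.float32)

q, scale = quantize_int8(weights)
restored = dequantize(q, scale)
max_error = float(np.abs(weights - restored).max())
print(q.dtype, q.nbytes, weights.nbytes, max_error <= scale)
```

Each value is stored in one byte instead of four, and the rounding error per element is bounded by roughly half the scale. That error accumulates across layers, which is why generated text can diverge from the fp32 result.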

In [None]:
from onnxruntime.transformers.quantize_helper import QuantizeHelper

optimized_fp32_model_path = "gpt2_fp32.onnx"
quantized_int8_model_path = "gpt2_int8.onnx"
Gpt2Helper.optimize_onnx("gpt2.onnx", optimized_fp32_model_path, False, model.config.num_attention_heads, model.config.hidden_size)
QuantizeHelper.quantize_onnx_model(optimized_fp32_model_path, quantized_int8_model_path)

In [None]:
session_int8 = onnxruntime.InferenceSession(quantized_int8_model_path)
input_text = ['bert model optimization']
test_generation(tokenizer, input_text, ort_session=session_int8, num_tokens_to_produce=14)

## Benchmark ##
There is a tool, benchmark_gpt2.py, which measures the performance of GPT-2 with PyTorch and with ONNX Runtime, with and without IO Binding.

In [None]:
import sys
!{sys.executable} -m onnxruntime.transformers.benchmark_gpt2 -m gpt2 -o

In [None]:
!{sys.executable} -m onnxruntime.transformers.benchmark_gpt2 -m gpt2 -o --precision int8

We can see that the quantized model has a significant speedup (close to 2x).
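
For a quick comparison inside the notebook, average latency can also be measured by hand. This is a minimal sketch using only the standard library, not how benchmark_gpt2.py works internally; `average_latency_ms` and `speedup` are hypothetical helpers, and the two workloads are placeholders for real inference calls.

```python
import time

def average_latency_ms(fn, warmup=3, runs=20):
    """Average wall-clock time of fn() in milliseconds, after a few warm-up calls."""
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    return (time.perf_counter() - start) * 1000.0 / runs

def speedup(baseline_ms, candidate_ms):
    """How many times faster the candidate is than the baseline."""
    return baseline_ms / candidate_ms

# Placeholder workloads standing in for fp32 and int8 inference calls.
slow_ms = average_latency_ms(lambda: sum(range(200_000)))
fast_ms = average_latency_ms(lambda: sum(range(100_000)))
print(f"{slow_ms:.3f} ms vs {fast_ms:.3f} ms")
```

With real sessions, the callables would wrap the inference call, for example `lambda: session.run(None, ort_inputs)` for the fp32 model and the same call on `session_int8` for the quantized one.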

### Test Environment ###
The following is the hardware of the test machine, and software version:

In [None]:
!{sys.executable} -m onnxruntime.transformers.machine_info --silent